Compare commits
7 commits: 62fa2f41ae ... b404191ffa

- b404191ffa
- 9a76624904
- c39b70f6fa
- f86457fc38
- 81749ee18e
- 9992b226ea
- 10199e4171
4
.kiro/steering/focus.md
Normal file
@@ -0,0 +1,4 @@
---
inclusion: manual
---
Focus only on web\dashboard.py and its dependencies, besides the usual support files (.env, launch.json, etc.). We're developing this dashboard as our project's main entry point and interaction surface.
3
.kiro/steering/specs.md
Normal file
@@ -0,0 +1,3 @@
---
inclusion: manual
---
41
.vscode/launch.json
vendored
@@ -138,46 +138,7 @@
                "order": 2
            }
        },
        {
            "name": "🧠 CNN Development Pipeline (Training + Analysis)",
            "configurations": [
                "🧠 Enhanced CNN Training with Backtesting",
                "🧪 CNN Live Training with Analysis",
                "📈 TensorBoard Monitor (All Runs)"
            ],
            "stopAll": true,
            "presentation": {
                "hidden": false,
                "group": "Development",
                "order": 3
            }
        },
        {
            "name": "🎯 Enhanced Trading System (1s Bars + Cache + Monitor)",
            "configurations": [
                "🎯 Enhanced Scalping Dashboard (1s Bars + 15min Cache)",
                "🌙 Overnight Training Monitor (504M Model)"
            ],
            "stopAll": true,
            "presentation": {
                "hidden": false,
                "group": "Enhanced Trading",
                "order": 4
            }
        },
        {
            "name": "🔥 COB Dashboard + 400M RL Trading System",
            "configurations": [
                "📈 COB Data Provider Dashboard",
                "🔥 Real-time RL COB Trader (400M Parameters)"
            ],
            "stopAll": true,
            "presentation": {
                "hidden": false,
                "group": "COB Trading",
                "order": 5
            }
        },

        {
            "name": "🌐 COBY Multi-Exchange System (Full Stack)",
            "configurations": [
@@ -1071,8 +1071,9 @@ class DQNAgent:
 
         # If no experiences provided, sample from memory
         if experiences is None:
-            # Skip if memory is too small
-            if len(self.memory) < self.batch_size:
+            # Skip if memory is too small (allow early training for GPU warmup)
+            min_required = min(getattr(self, 'batch_size', 32), 16)
+            if len(self.memory) < min_required:
                 return 0.0
 
         # Sample random mini-batch from memory
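A minimal sketch of how the relaxed warm-up guard behaves, assuming a deque-based replay memory and a `batch_size` attribute (names and the loss placeholder are illustrative, not the project's exact implementation):

```python
import random
from collections import deque

class ReplaySketch:
    def __init__(self, batch_size: int = 32):
        self.memory = deque(maxlen=10_000)
        self.batch_size = batch_size

    def replay(self) -> float:
        # Allow early training: require at most 16 samples even if batch_size is larger
        min_required = min(getattr(self, 'batch_size', 32), 16)
        if len(self.memory) < min_required:
            return 0.0
        # Sample whatever is available, capped at batch_size
        batch = random.sample(list(self.memory), min(self.batch_size, len(self.memory)))
        return float(len(batch))  # placeholder for the real training loss
```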
@@ -66,9 +66,23 @@ class StandardizedCNN(nn.Module):
         # Output processing layers
         self.output_processor = self._build_output_processor()
 
+        # Optional numeric return head (predicts percent change for 1s, 1m, 1h, 1d)
+        # Uses cnn_features (1024) to regress predicted returns per timeframe
+        self.return_head = nn.Sequential(
+            nn.Linear(1024, 256),
+            nn.ReLU(),
+            nn.Dropout(0.1),
+            nn.Linear(256, 4)  # [return_1s, return_1m, return_1h, return_1d]
+        )
+
         # Device management
         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         self.to(self.device)
+        try:
+            import torch.backends.cudnn as cudnn
+            cudnn.benchmark = True
+        except Exception:
+            pass
 
         logger.info(f"StandardizedCNN '{model_name}' initialized")
         logger.info(f"Expected feature dimension: {self.expected_feature_dim}")
@@ -175,6 +189,9 @@ class StandardizedCNN(nn.Module):
         # Process outputs for standardized format
         action_probs = self.output_processor(cnn_features)  # [batch, 3]
 
+        # Predict numeric returns per timeframe from cnn_features
+        predicted_returns = self.return_head(cnn_features)  # [batch, 4]
+
         # Prepare hidden states for cross-model feeding
         hidden_states = {
             'processed_features': processed_features.detach(),
@@ -186,7 +203,7 @@ class StandardizedCNN(nn.Module):
             'attention_weights': torch.ones(batch_size, 1, device=x.device)  # Placeholder
         }
 
-        return action_probs, hidden_states
+        return action_probs, hidden_states, predicted_returns.detach()
 
     def predict_from_base_input(self, base_input: BaseDataInput) -> ModelOutput:
         """
@@ -210,7 +227,7 @@ class StandardizedCNN(nn.Module):
 
         with torch.no_grad():
             # Forward pass
-            action_probs, hidden_states = self.forward(input_tensor)
+            action_probs, hidden_states, predicted_returns = self.forward(input_tensor)
 
             # Get action and confidence
             action_probs_np = action_probs.squeeze(0).cpu().numpy()
@@ -233,6 +250,19 @@ class StandardizedCNN(nn.Module):
                 'market_conditions': self._interpret_advanced_predictions(hidden_states.get('advanced_predictions'))
             }
 
+            # Add numeric predicted returns per timeframe if available
+            try:
+                pr = predicted_returns.squeeze(0).cpu().numpy().tolist()
+                # Ensure length 4; if not, handle safely
+                if isinstance(pr, list) and len(pr) >= 4:
+                    predictions['predicted_returns'] = pr[:4]
+                    predictions['predicted_return_1s'] = float(pr[0])
+                    predictions['predicted_return_1m'] = float(pr[1])
+                    predictions['predicted_return_1h'] = float(pr[2])
+                    predictions['predicted_return_1d'] = float(pr[3])
+            except Exception:
+                pass
+
             # Prepare hidden states for cross-model feeding (convert tensors to numpy)
             cross_model_states = {}
             for key, tensor in hidden_states.items():
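To make the purpose of the new return head concrete, here is a minimal, self-contained sketch of how per-timeframe predicted returns could be trained against realized returns with an MSE loss. The head dimensions (1024 -> 256 -> 4) mirror the diff above; the realized-return tensor, batch data, and optimizer setup are illustrative assumptions, not code from this repository.

```python
import torch
import torch.nn as nn

# Return head as added in StandardizedCNN.__init__ (1024-dim CNN features -> 4 timeframe returns)
return_head = nn.Sequential(
    nn.Linear(1024, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, 4),  # [return_1s, return_1m, return_1h, return_1d]
)
optimizer = torch.optim.Adam(return_head.parameters(), lr=1e-4)

cnn_features = torch.randn(8, 1024)          # batch of CNN feature vectors (placeholder data)
realized_returns = torch.randn(8, 4) * 1e-3  # observed fractional price changes per timeframe

predicted_returns = return_head(cnn_features)            # [batch, 4]
loss = nn.functional.mse_loss(predicted_returns, realized_returns)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```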
11
NN/training/samples/readme.md
Normal file
@@ -0,0 +1,11 @@
<!-- example text dump: -->
symbol MAIN SYMBOL (ETH) REF1 (BTC) REF2 (SPX) REF3 (SOL)
timeframe 1s 1m 1h 1d 1s 1s 1s
datapoint O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp O H L C V Timestamp
2025-01-15T10:00:00Z 3421.5 3421.75 3421.25 3421.6 125.4 2025-01-15T10:00:00Z 3422.1 3424.8 3420.5 3423.25 1245.7 2025-01-15T10:00:00Z 3420 3428.5 3418.75 3425.1 12847.2 2025-01-15T10:00:00Z 3415.25 3435.6 3410.8 3430.4 145238.6 2025-01-15T10:00:00Z 97850.2 97852.4 97848.1 97851.3 8.7 2025-01-15T10:00:00Z 5925.4 5926.1 5924.8 5925.7 0 2025-01-15T10:00:00Z 191.22 191.45 191.08 191.35 1247.3
2025-01-15T10:00:01Z 3421.6 3421.85 3421.45 3421.75 98.2 2025-01-15T10:01:00Z 3423.25 3425.9 3421.8 3424.6 1189.3 2025-01-15T11:00:00Z 3425.1 3432.2 3422.4 3429.8 11960.5 2025-01-16T10:00:00Z 3430.4 3445.2 3425.15 3440.85 138947.1 2025-01-15T10:00:01Z 97851.3 97853.8 97849.5 97852.9 9.1 2025-01-15T10:00:01Z 5925.7 5926.3 5925.2 5925.9 0 2025-01-15T10:00:01Z 191.35 191.58 191.15 191.48 1156.7
2025-01-15T10:00:02Z 3421.75 3421.95 3421.55 3421.8 110.6 2025-01-15T10:02:00Z 3424.6 3427.15 3423.4 3425.9 1356.8 2025-01-15T12:00:00Z 3429.8 3436.7 3427.2 3434.5 13205.9 2025-01-17T10:00:00Z 3440.85 3455.3 3438.9 3450.75 142568.3 2025-01-15T10:00:02Z 97852.9 97855.2 97850.7 97854.6 7.9 2025-01-15T10:00:02Z 5925.9 5926.5 5925.4 5926.1 0 2025-01-15T10:00:02Z 191.48 191.72 191.28 191.61 1298.4
2025-01-15T10:00:03Z 3421.8 3422.05 3421.65 3421.9 87.3 2025-01-15T10:03:00Z 3425.9 3428.4 3424.2 3427.1 1423.5 2025-01-15T13:00:00Z 3434.5 3441.8 3432.1 3438.2 14087.6 2025-01-18T10:00:00Z 3450.75 3465.4 3448.6 3460.2 149825.7 2025-01-15T10:00:03Z 97854.6 97857.1 97852.3 97856.8 8.4 2025-01-15T10:00:03Z 5926.1 5926.7 5925.6 5926.3 0 2025-01-15T10:00:03Z 191.61 191.85 191.42 191.74 1187.9
2025-01-15T10:00:04Z 3421.9 3422.15 3421.75 3422.0 134.7 2025-01-15T10:04:00Z 3427.1 3429.6 3425.8 3428.3 1298.2 2025-01-15T14:00:00Z 3438.2 3445.6 3436.4 3442.1 12734.8 2025-01-19T10:00:00Z 3460.2 3475.8 3457.4 3470.6 156742.4 2025-01-15T10:00:04Z 97856.8 97859.4 97854.9 97858.2 9.2 2025-01-15T10:00:04Z 5926.3 5926.9 5925.8 5926.5 0 2025-01-15T10:00:04Z 191.74 191.98 191.55 191.87 1342.6
2025-01-15T10:00:05Z 3422.0 3422.25 3421.85 3422.1 156.8 2025-01-15T10:05:00Z 3428.3 3430.8 3426.9 3429.5 1467.9 2025-01-15T15:00:00Z 3442.1 3449.3 3440.7 3446.8 11823.4 2025-01-20T10:00:00Z 3470.6 3485.2 3467.9 3480.1 163456.8 2025-01-15T10:00:05Z 97858.2 97860.7 97856.4 97859.8 8.8 2025-01-15T10:00:05Z 5926.5 5927.1 5926.0 5926.7 0 2025-01-15T10:00:05Z 191.87 192.11 191.68 192.0 1278.5
@@ -110,4 +110,10 @@ I want it more to be a part of a proper rewardfunction bias rather than a algori
THINK REALLY HARD

do we evaluate and reward/punish each model at each reference?
do we evaluate and reward/punish each model at each reference?


In our realtime Reinforcement Learning training, how do we calculate the score (reward/penalty)?
Let's use the mean squared difference between the prediction and the empirical outcome. We should do a training run at each inference, using the last inference's prediction and the current price as the outcome. Do that for up to the 6 last predictions, calculating accuracy separately for each, to get a better picture of the ability to predict a couple of timeframes into the future. In addition to the frequent inference every 1 or 5 s (I forgot the current CNN rate), do an inference at each new timeframe interval. The model should get the full data (multi-timeframe: ETH (main) 1s 1m 1h 1d, and 1m for BTC, SPX and one more) but should also know which timeframe it is predicting. We predict only on the main symbol, so in 4 timeframes; but on every hour we will do 4 inferences - one for each timeframe.
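The note above describes the intended scoring rule in prose. A small worked sketch of that rule, assuming prediction and outcome are expressed as fractional price changes (the exact normalization used in the project may differ):

```python
import numpy as np

def mse_score(predicted_return: float, actual_return: float) -> float:
    """Squared error between predicted and realized return; lower is better."""
    return (predicted_return - actual_return) ** 2

# Last 6 predictions for one timeframe vs. what the price actually did
predicted = np.array([0.0010, -0.0005, 0.0020, 0.0000, 0.0008, -0.0012])
actual    = np.array([0.0007, -0.0002, 0.0011, 0.0003, 0.0010, -0.0025])

errors = (predicted - actual) ** 2
print("per-prediction squared error:", errors)
print("mean over last 6 predictions:", errors.mean())
```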
@@ -6,6 +6,15 @@ system:
   log_level: "INFO"  # DEBUG, INFO, WARNING, ERROR
   session_timeout: 3600  # Session timeout in seconds
 
+# LLM Proxy Configuration
+llm_proxy:
+  base_url: "http://localhost:1234"  # LLM server base URL
+  model: "openai/gpt-oss-20b"  # Model name
+  temperature: 0.7  # Response creativity (0.0-1.0)
+  max_tokens: -1  # Max response tokens (-1 for unlimited)
+  timeout: 30  # Request timeout in seconds
+  api_key: null  # API key if required
+
 # Cold Start Mode Configuration
 cold_start:
   enabled: true  # Enable cold start mode logic
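For orientation, here is a hedged sketch of how a client might consume the llm_proxy settings above. It assumes the local server exposes an OpenAI-compatible /v1/chat/completions route (typical for LM Studio-style servers on port 1234); that route, the `requests` dependency, and the `ask_llm` helper are assumptions for illustration, not the project's actual proxy code.

```python
import requests  # assumption: plain HTTP client; the project may use its own proxy class

cfg = {
    "base_url": "http://localhost:1234",
    "model": "openai/gpt-oss-20b",
    "temperature": 0.7,
    "max_tokens": -1,
    "timeout": 30,
    "api_key": None,
}

def ask_llm(prompt: str) -> str:
    headers = {"Content-Type": "application/json"}
    if cfg["api_key"]:
        headers["Authorization"] = f"Bearer {cfg['api_key']}"
    resp = requests.post(
        f"{cfg['base_url']}/v1/chat/completions",  # assumed OpenAI-compatible route
        json={
            "model": cfg["model"],
            "messages": [{"role": "user", "content": prompt}],
            "temperature": cfg["temperature"],
            "max_tokens": cfg["max_tokens"],
        },
        headers=headers,
        timeout=cfg["timeout"],
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]
```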
@@ -1268,15 +1268,12 @@ class DataProvider:
                 logger.debug(f"No valid candles generated for {symbol}")
                 return None
 
-            # Convert to DataFrame (timestamps remain UTC tz-aware)
+            # Convert to DataFrame and normalize timestamps to UTC tz-aware
             df = pd.DataFrame(candles)
             # Ensure timestamps are timezone-aware (UTC to match COB WebSocket data)
             if not df.empty and 'timestamp' in df.columns:
-                # Normalize to UTC tz-aware using pandas idioms
-                if df['timestamp'].dt.tz is None:
-                    df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
-                else:
-                    df['timestamp'] = df['timestamp'].dt.tz_convert('UTC')
+                # Coerce to datetime with UTC; avoid .dt on non-datetimelike
+                df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True, errors='coerce')
+                df = df.dropna(subset=['timestamp'])
 
             df = df.sort_values('timestamp').reset_index(drop=True)
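A quick illustration of why the new normalization is more robust: pd.to_datetime(..., utc=True, errors='coerce') handles naive timestamps, tz-aware timestamps, and garbage values in one pass, whereas the old branch on .dt.tz raises if the column is not datetime-like yet. The example data is made up.

```python
import pandas as pd

df = pd.DataFrame({
    "timestamp": ["2025-01-15 10:00:00",        # naive string
                  "2025-01-15 10:00:01+00:00",  # tz-aware string
                  "not-a-date"],                # garbage -> NaT -> dropped
    "close": [3421.6, 3421.75, 3421.8],
})

df["timestamp"] = pd.to_datetime(df["timestamp"], utc=True, errors="coerce")
df = df.dropna(subset=["timestamp"]).sort_values("timestamp").reset_index(drop=True)
print(df)
```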
@@ -1315,13 +1312,19 @@
 
         # For 1s timeframe, try to generate from WebSocket ticks first
         if timeframe == '1s':
-            # logger.info(f"Attempting to generate 1s candles from WebSocket ticks for {symbol}")
-            generated_df = self._generate_1s_candles_from_ticks(symbol, limit)
+            # Attempt to generate from WebSocket ticks, but throttle attempts to avoid spam
+            if not hasattr(self, '_last_1s_generation_attempt'):
+                self._last_1s_generation_attempt = {}
+            now_ts = time.time()
+            last_attempt = self._last_1s_generation_attempt.get(symbol, 0)
+            generated_df = None
+            if now_ts - last_attempt >= 1.5:
+                self._last_1s_generation_attempt[symbol] = now_ts
+                generated_df = self._generate_1s_candles_from_ticks(symbol, limit)
             if generated_df is not None and not generated_df.empty:
                 # logger.info(f"Successfully generated 1s candles from WebSocket ticks for {symbol}")
                 return generated_df
             else:
-                logger.info(f"Could not generate 1s candles from ticks for {symbol}; trying Binance API")
+                logger.debug(f"Could not generate 1s candles from ticks for {symbol}; trying Binance API")
 
         # Convert symbol format
         binance_symbol = symbol.replace('/', '').upper()
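The same per-symbol throttle pattern, factored into a tiny helper for clarity (illustrative only; the project keeps the logic inline as shown above):

```python
import time

class PerKeyThrottle:
    """Allows an action per key at most once every `interval` seconds."""
    def __init__(self, interval: float = 1.5):
        self.interval = interval
        self._last: dict[str, float] = {}

    def allow(self, key: str) -> bool:
        now = time.time()
        if now - self._last.get(key, 0.0) >= self.interval:
            self._last[key] = now
            return True
        return False

throttle = PerKeyThrottle(1.5)
if throttle.allow("ETH/USDT"):
    pass  # call _generate_1s_candles_from_ticks(...) here
```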
@@ -446,8 +446,8 @@ class EnhancedCOBWebSocket:
             # Add ping/pong handling and proper connection management
             async with websockets_connect(
                 ws_url,
-                ping_interval=20,  # Binance sends ping every 20 seconds
-                ping_timeout=60,   # Binance disconnects after 1 minute without pong
+                ping_interval=25,  # Slightly longer than default
+                ping_timeout=90,   # Allow longer time before timeout
                 close_timeout=10
             ) as websocket:
                 # Connection successful
@@ -539,8 +539,11 @@
 
                 # Wait before reconnecting
                 status.increase_reconnect_delay()
-                logger.info(f"Waiting {status.reconnect_delay:.1f}s before reconnecting {symbol}")
-                await asyncio.sleep(status.reconnect_delay)
+                # Add jitter to avoid synchronized reconnects
+                jitter = 0.5 + (random.random() * 1.5)
+                delay = status.reconnect_delay * jitter
+                logger.info(f"Waiting {delay:.1f}s before reconnecting {symbol}")
+                await asyncio.sleep(delay)
 
     async def _process_websocket_message(self, symbol: str, data: Dict):
         """Process WebSocket message and convert to COB format
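For context, the jittered delay drawn above is uniform in [0.5x, 2.0x) of the current backoff delay. A compact sketch of the overall reconnect pattern, assuming exponential backoff with a cap (in the real class the growing delay is managed by status.increase_reconnect_delay()):

```python
import asyncio
import random

async def reconnect_with_backoff(connect, base_delay: float = 1.0, max_delay: float = 60.0):
    delay = base_delay
    while True:
        try:
            await connect()
            return
        except Exception:
            # Exponential backoff with multiplicative jitter in [0.5, 2.0)
            jitter = 0.5 + random.random() * 1.5
            await asyncio.sleep(min(delay * jitter, max_delay))
            delay = min(delay * 2, max_delay)
```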
494
core/enhanced_reward_calculator.py
Normal file
@@ -0,0 +1,494 @@
"""
Enhanced Reward Calculator for Reinforcement Learning Training

This module implements a comprehensive reward calculation system based on mean squared error
between predictions and empirical outcomes. It tracks multiple timeframes separately and
maintains prediction history for accurate reward computation.

Key Features:
- MSE-based reward calculation for prediction accuracy
- Multi-timeframe prediction tracking (1s, 1m, 1h, 1d)
- Separate accuracy tracking for each timeframe
- Prediction history tracking (last 6 predictions per timeframe)
- Real-time training at each inference
- Timeframe-aware inference scheduling
"""

import time
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Any
from collections import deque
from datetime import datetime, timedelta
import numpy as np
import threading
from enum import Enum

logger = logging.getLogger(__name__)


class TimeFrame(Enum):
    """Supported timeframes for prediction"""
    SECONDS_1 = "1s"
    MINUTES_1 = "1m"
    HOURS_1 = "1h"
    DAYS_1 = "1d"

@dataclass
|
||||
class PredictionRecord:
|
||||
"""Individual prediction record with outcome tracking"""
|
||||
timestamp: datetime
|
||||
symbol: str
|
||||
timeframe: TimeFrame
|
||||
predicted_price: float
|
||||
predicted_direction: int # -1: down, 0: neutral, 1: up
|
||||
confidence: float
|
||||
current_price: float
|
||||
model_name: str
|
||||
# Optional state vector used for prediction/training (standardized feature/state)
|
||||
state_vector: Optional[list] = None
|
||||
|
||||
# Outcome fields (set when outcome is determined)
|
||||
actual_price: Optional[float] = None
|
||||
actual_direction: Optional[int] = None
|
||||
outcome_timestamp: Optional[datetime] = None
|
||||
mse_reward: Optional[float] = None
|
||||
direction_correct: Optional[bool] = None
|
||||
is_evaluated: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TimeframeAccuracy:
|
||||
"""Accuracy tracking for a specific timeframe"""
|
||||
timeframe: TimeFrame
|
||||
total_predictions: int = 0
|
||||
correct_directions: int = 0
|
||||
total_mse: float = 0.0
|
||||
prediction_history: deque = field(default_factory=lambda: deque(maxlen=6))
|
||||
|
||||
@property
|
||||
def direction_accuracy(self) -> float:
|
||||
"""Calculate directional accuracy percentage"""
|
||||
if self.total_predictions == 0:
|
||||
return 0.0
|
||||
return (self.correct_directions / self.total_predictions) * 100.0
|
||||
|
||||
@property
|
||||
def average_mse(self) -> float:
|
||||
"""Calculate average MSE"""
|
||||
if self.total_predictions == 0:
|
||||
return 0.0
|
||||
return self.total_mse / self.total_predictions
|
||||
|
||||
|
||||
class EnhancedRewardCalculator:
|
||||
"""
|
||||
Enhanced reward calculator using MSE and multi-timeframe tracking
|
||||
|
||||
This calculator:
|
||||
1. Tracks predictions for multiple timeframes separately
|
||||
2. Calculates MSE-based rewards when outcomes are available
|
||||
3. Maintains prediction history for last 6 predictions per timeframe
|
||||
4. Provides separate accuracy metrics for each timeframe
|
||||
5. Enables real-time training at each inference
|
||||
"""
|
||||
|
||||
def __init__(self, symbols: List[str] = None):
|
||||
"""Initialize the enhanced reward calculator"""
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
self.timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1, TimeFrame.HOURS_1, TimeFrame.DAYS_1]
|
||||
|
||||
# Prediction storage: symbol -> timeframe -> deque of PredictionRecord
|
||||
self.predictions: Dict[str, Dict[TimeFrame, deque]] = {}
|
||||
|
||||
# Accuracy tracking: symbol -> timeframe -> TimeframeAccuracy
|
||||
self.accuracy_tracker: Dict[str, Dict[TimeFrame, TimeframeAccuracy]] = {}
|
||||
|
||||
# Evaluation timeouts for each timeframe (in seconds)
|
||||
self.evaluation_timeouts = {
|
||||
TimeFrame.SECONDS_1: 5, # Evaluate 1s predictions after 5 seconds
|
||||
TimeFrame.MINUTES_1: 60, # Evaluate 1m predictions after 1 minute
|
||||
TimeFrame.HOURS_1: 300, # Evaluate 1h predictions after 5 minutes
|
||||
TimeFrame.DAYS_1: 900 # Evaluate 1d predictions after 15 minutes
|
||||
}
|
||||
|
||||
# Price data cache for outcome evaluation
|
||||
self.price_cache: Dict[str, List[Tuple[datetime, float]]] = {}
|
||||
self.price_cache_max_size = 1000
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Initialize data structures
|
||||
self._initialize_data_structures()
|
||||
|
||||
logger.info(f"EnhancedRewardCalculator initialized for symbols: {self.symbols}")
|
||||
logger.info(f"Timeframes: {[tf.value for tf in self.timeframes]}")
|
||||
logger.info(f"Evaluation timeouts: {[(tf.value, timeout) for tf, timeout in self.evaluation_timeouts.items()]}")
|
||||
|
||||
def _initialize_data_structures(self):
|
||||
"""Initialize nested data structures"""
|
||||
for symbol in self.symbols:
|
||||
self.predictions[symbol] = {}
|
||||
self.accuracy_tracker[symbol] = {}
|
||||
self.price_cache[symbol] = []
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
self.predictions[symbol][timeframe] = deque(maxlen=100) # Keep last 100 predictions
|
||||
self.accuracy_tracker[symbol][timeframe] = TimeframeAccuracy(timeframe)
|
||||
|
||||
    def add_prediction(self,
                       symbol: str,
                       timeframe: TimeFrame,
                       predicted_price: float,
                       predicted_direction: int,
                       confidence: float,
                       current_price: float,
                       model_name: str,
                       predicted_return: Optional[float] = None,
                       state_vector: Optional[list] = None) -> str:
|
||||
"""
|
||||
Add a new prediction to track
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
timeframe: Timeframe for this prediction
|
||||
predicted_price: Model's predicted price
|
||||
predicted_direction: Predicted direction (-1, 0, 1)
|
||||
confidence: Model's confidence (0.0 to 1.0)
|
||||
current_price: Current market price
|
||||
model_name: Name of the model making prediction
|
||||
|
||||
Returns:
|
||||
Unique prediction ID for later reference
|
||||
"""
|
||||
with self.lock:
|
||||
prediction = PredictionRecord(
|
||||
timestamp=datetime.now(),
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
predicted_price=predicted_price,
|
||||
predicted_direction=predicted_direction,
|
||||
confidence=confidence,
|
||||
current_price=current_price,
|
||||
model_name=model_name,
|
||||
state_vector=state_vector
|
||||
)
|
||||
|
||||
# If predicted_return provided, prefer computing implied predicted_price
|
||||
# to avoid synthetic price fabrication
|
||||
try:
|
||||
if predicted_return is not None and current_price > 0:
|
||||
prediction.predicted_price = current_price * (1.0 + float(predicted_return))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Store prediction
|
||||
if symbol not in self.predictions:
|
||||
self._initialize_data_structures()
|
||||
|
||||
self.predictions[symbol][timeframe].append(prediction)
|
||||
|
||||
# Add to accuracy tracker history
|
||||
self.accuracy_tracker[symbol][timeframe].prediction_history.append(prediction)
|
||||
|
||||
prediction_id = f"{symbol}_{timeframe.value}_{prediction.timestamp.isoformat()}_{model_name}"
|
||||
|
||||
logger.debug(f"Added prediction: {prediction_id}, predicted_price={predicted_price:.4f}, "
|
||||
f"direction={predicted_direction}, confidence={confidence:.3f}")
|
||||
|
||||
return prediction_id
|
||||
|
||||
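A short usage sketch of the tracking API above (values are illustrative; in the live system prices come from the data provider rather than hard-coded numbers):

```python
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame

calc = EnhancedRewardCalculator(symbols=['ETH/USDT'])

# Record a 1s-horizon prediction made by some model at the current price
pred_id = calc.add_prediction(
    symbol='ETH/USDT',
    timeframe=TimeFrame.SECONDS_1,
    predicted_price=3422.1,
    predicted_direction=1,   # expects price to go up
    confidence=0.62,
    current_price=3421.6,
    model_name='enhanced_cnn',
)

# Feed prices as they arrive; once the 1s evaluation timeout has elapsed the prediction can be scored
calc.update_price('ETH/USDT', 3422.0)
rewards = calc.evaluate_predictions('ETH/USDT')
```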
def update_price(self, symbol: str, price: float, timestamp: datetime = None):
|
||||
"""
|
||||
Update current price for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
price: Current price
|
||||
timestamp: Price timestamp (defaults to now)
|
||||
"""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
|
||||
with self.lock:
|
||||
if symbol not in self.price_cache:
|
||||
self.price_cache[symbol] = []
|
||||
|
||||
self.price_cache[symbol].append((timestamp, price))
|
||||
|
||||
# Maintain cache size
|
||||
if len(self.price_cache[symbol]) > self.price_cache_max_size:
|
||||
self.price_cache[symbol] = self.price_cache[symbol][-self.price_cache_max_size:]
|
||||
|
||||
def evaluate_predictions(self, symbol: str = None) -> Dict[str, List[Tuple[PredictionRecord, float]]]:
|
||||
"""
|
||||
Evaluate pending predictions and calculate rewards
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol to evaluate (None for all symbols)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to list of (prediction, reward) tuples
|
||||
"""
|
||||
results = {}
|
||||
symbols_to_evaluate = [symbol] if symbol else self.symbols
|
||||
|
||||
with self.lock:
|
||||
for sym in symbols_to_evaluate:
|
||||
if sym not in self.predictions:
|
||||
continue
|
||||
|
||||
results[sym] = []
|
||||
current_time = datetime.now()
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
predictions_to_evaluate = []
|
||||
|
||||
# Find predictions ready for evaluation
|
||||
for prediction in self.predictions[sym][timeframe]:
|
||||
if prediction.is_evaluated:
|
||||
continue
|
||||
|
||||
time_elapsed = (current_time - prediction.timestamp).total_seconds()
|
||||
timeout = self.evaluation_timeouts[timeframe]
|
||||
|
||||
if time_elapsed >= timeout:
|
||||
predictions_to_evaluate.append(prediction)
|
||||
|
||||
# Evaluate predictions
|
||||
for prediction in predictions_to_evaluate:
|
||||
reward = self._calculate_prediction_reward(prediction)
|
||||
if reward is not None:
|
||||
results[sym].append((prediction, reward))
|
||||
|
||||
# Update accuracy tracking
|
||||
self._update_accuracy_tracking(sym, timeframe, prediction)
|
||||
|
||||
return results
|
||||
|
||||
def _calculate_prediction_reward(self, prediction: PredictionRecord) -> Optional[float]:
|
||||
"""
|
||||
Calculate MSE-based reward for a prediction
|
||||
|
||||
Args:
|
||||
prediction: Prediction record to evaluate
|
||||
|
||||
Returns:
|
||||
Calculated reward or None if outcome cannot be determined
|
||||
"""
|
||||
# Get actual price at evaluation time
|
||||
actual_price = self._get_price_at_time(
|
||||
prediction.symbol,
|
||||
prediction.timestamp + timedelta(seconds=self.evaluation_timeouts[prediction.timeframe])
|
||||
)
|
||||
|
||||
if actual_price is None:
|
||||
logger.debug(f"Cannot evaluate prediction - no price data available for {prediction.symbol}")
|
||||
return None
|
||||
|
||||
# Calculate price change and direction
|
||||
price_change = actual_price - prediction.current_price
|
||||
actual_direction = 1 if price_change > 0 else (-1 if price_change < 0 else 0)
|
||||
|
||||
# Calculate MSE reward
|
||||
price_error = actual_price - prediction.predicted_price
|
||||
mse = price_error ** 2
|
||||
|
||||
# Normalize MSE to a reasonable reward scale (lower MSE = higher reward)
|
||||
# Use exponential decay to heavily penalize large errors
|
||||
max_mse = (prediction.current_price * 0.1) ** 2 # 10% price change as max expected error
|
||||
normalized_mse = min(mse / max_mse, 1.0)
|
||||
mse_reward = np.exp(-5 * normalized_mse) # Exponential decay, range [exp(-5), 1]
|
||||
|
||||
# Direction accuracy bonus/penalty
|
||||
direction_correct = (prediction.predicted_direction == actual_direction)
|
||||
direction_bonus = 0.5 if direction_correct else -0.5
|
||||
|
||||
# Confidence scaling
|
||||
confidence_weight = prediction.confidence
|
||||
|
||||
# Final reward calculation
|
||||
base_reward = mse_reward + direction_bonus
|
||||
final_reward = base_reward * confidence_weight
|
||||
|
||||
# Update prediction record
|
||||
prediction.actual_price = actual_price
|
||||
prediction.actual_direction = actual_direction
|
||||
prediction.outcome_timestamp = datetime.now()
|
||||
prediction.mse_reward = final_reward
|
||||
prediction.direction_correct = direction_correct
|
||||
prediction.is_evaluated = True
|
||||
|
||||
logger.debug(f"Evaluated prediction: {prediction.symbol} {prediction.timeframe.value}, "
|
||||
f"MSE={mse:.6f}, direction_correct={direction_correct}, "
|
||||
f"confidence={confidence_weight:.3f}, reward={final_reward:.4f}")
|
||||
|
||||
return final_reward
|
||||
|
||||
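To make the reward shaping above concrete, here is the same formula traced on made-up numbers: max_mse is (10% of the current price) squared, the MSE term decays exponentially, and the +/-0.5 direction bonus is scaled by confidence.

```python
import math

current_price   = 3400.0
predicted_price = 3410.0
actual_price    = 3405.0
confidence      = 0.6

price_error    = actual_price - predicted_price         # -5.0
mse            = price_error ** 2                       # 25.0
max_mse        = (current_price * 0.1) ** 2             # 115600.0
normalized_mse = min(mse / max_mse, 1.0)                # ~0.000216
mse_reward     = math.exp(-5 * normalized_mse)          # ~0.999

direction_bonus = 0.5                                   # predicted up, price went up
final_reward    = (mse_reward + direction_bonus) * confidence  # ~0.90
print(final_reward)
```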
def _get_price_at_time(self, symbol: str, target_time: datetime) -> Optional[float]:
|
||||
"""
|
||||
Get price for symbol at a specific time
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
target_time: Target timestamp
|
||||
|
||||
Returns:
|
||||
Price at target time or None if not available
|
||||
"""
|
||||
if symbol not in self.price_cache or not self.price_cache[symbol]:
|
||||
return None
|
||||
|
||||
# Find closest price to target time
|
||||
closest_price = None
|
||||
min_time_diff = float('inf')
|
||||
|
||||
for timestamp, price in self.price_cache[symbol]:
|
||||
time_diff = abs((timestamp - target_time).total_seconds())
|
||||
if time_diff < min_time_diff:
|
||||
min_time_diff = time_diff
|
||||
closest_price = price
|
||||
|
||||
# Only return price if it's within reasonable time window (30 seconds)
|
||||
if min_time_diff <= 30:
|
||||
return closest_price
|
||||
|
||||
return None
|
||||
|
||||
def _update_accuracy_tracking(self, symbol: str, timeframe: TimeFrame, prediction: PredictionRecord):
|
||||
"""Update accuracy tracking for a timeframe"""
|
||||
tracker = self.accuracy_tracker[symbol][timeframe]
|
||||
|
||||
tracker.total_predictions += 1
|
||||
if prediction.direction_correct:
|
||||
tracker.correct_directions += 1
|
||||
|
||||
if prediction.mse_reward is not None:
|
||||
# Convert reward back to MSE for tracking
|
||||
# Since reward = exp(-5 * normalized_mse), we can reverse it
|
||||
normalized_mse = -np.log(max(prediction.mse_reward, 0.001)) / 5
|
||||
max_mse = (prediction.current_price * 0.1) ** 2
|
||||
mse = normalized_mse * max_mse
|
||||
tracker.total_mse += mse
|
||||
|
||||
def get_accuracy_summary(self, symbol: str = None) -> Dict[str, Dict[str, Dict[str, float]]]:
|
||||
"""
|
||||
Get accuracy summary for all or specific symbol
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol (None for all)
|
||||
|
||||
Returns:
|
||||
Nested dictionary with accuracy metrics
|
||||
"""
|
||||
summary = {}
|
||||
symbols_to_summarize = [symbol] if symbol else self.symbols
|
||||
|
||||
with self.lock:
|
||||
for sym in symbols_to_summarize:
|
||||
if sym not in self.accuracy_tracker:
|
||||
continue
|
||||
|
||||
summary[sym] = {}
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
tracker = self.accuracy_tracker[sym][timeframe]
|
||||
summary[sym][timeframe.value] = {
|
||||
'total_predictions': tracker.total_predictions,
|
||||
'direction_accuracy': tracker.direction_accuracy,
|
||||
'average_mse': tracker.average_mse,
|
||||
'recent_predictions': len(tracker.prediction_history)
|
||||
}
|
||||
|
||||
return summary
|
||||
|
||||
def get_training_data(self, symbol: str, timeframe: TimeFrame,
|
||||
max_samples: int = 50) -> List[Tuple[PredictionRecord, float]]:
|
||||
"""
|
||||
Get recent evaluated predictions for training
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Specific timeframe
|
||||
max_samples: Maximum number of samples to return
|
||||
|
||||
Returns:
|
||||
List of (prediction, reward) tuples ready for training
|
||||
"""
|
||||
training_data = []
|
||||
|
||||
with self.lock:
|
||||
if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
|
||||
return training_data
|
||||
|
||||
evaluated_predictions = [
|
||||
p for p in self.predictions[symbol][timeframe]
|
||||
if p.is_evaluated and p.mse_reward is not None
|
||||
]
|
||||
|
||||
# Get most recent evaluated predictions
|
||||
recent_predictions = list(evaluated_predictions)[-max_samples:]
|
||||
|
||||
for prediction in recent_predictions:
|
||||
training_data.append((prediction, prediction.mse_reward))
|
||||
|
||||
return training_data
|
||||
|
||||
def cleanup_old_predictions(self, days_to_keep: int = 7):
|
||||
"""
|
||||
Clean up old predictions to manage memory
|
||||
|
||||
Args:
|
||||
days_to_keep: Number of days of predictions to keep
|
||||
"""
|
||||
cutoff_time = datetime.now() - timedelta(days=days_to_keep)
|
||||
|
||||
with self.lock:
|
||||
for symbol in self.predictions:
|
||||
for timeframe in self.timeframes:
|
||||
# Filter out old predictions
|
||||
old_count = len(self.predictions[symbol][timeframe])
|
||||
|
||||
self.predictions[symbol][timeframe] = deque(
|
||||
[p for p in self.predictions[symbol][timeframe]
|
||||
if p.timestamp > cutoff_time],
|
||||
maxlen=100
|
||||
)
|
||||
|
||||
new_count = len(self.predictions[symbol][timeframe])
|
||||
removed_count = old_count - new_count
|
||||
|
||||
if removed_count > 0:
|
||||
logger.info(f"Cleaned up {removed_count} old predictions for "
|
||||
f"{symbol} {timeframe.value}")
|
||||
|
||||
def force_evaluate_timeframe_predictions(self, symbol: str, timeframe: TimeFrame) -> List[Tuple[PredictionRecord, float]]:
|
||||
"""
|
||||
Force evaluation of all pending predictions for a specific timeframe
|
||||
Useful for immediate training needs
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Specific timeframe to evaluate
|
||||
|
||||
Returns:
|
||||
List of (prediction, reward) tuples
|
||||
"""
|
||||
results = []
|
||||
|
||||
with self.lock:
|
||||
if symbol not in self.predictions or timeframe not in self.predictions[symbol]:
|
||||
return results
|
||||
|
||||
# Evaluate all non-evaluated predictions
|
||||
for prediction in self.predictions[symbol][timeframe]:
|
||||
if not prediction.is_evaluated:
|
||||
reward = self._calculate_prediction_reward(prediction)
|
||||
if reward is not None:
|
||||
results.append((prediction, reward))
|
||||
self._update_accuracy_tracking(symbol, timeframe, prediction)
|
||||
|
||||
return results
|
||||
|
346
core/enhanced_reward_system_integration.py
Normal file
@@ -0,0 +1,346 @@
"""
Enhanced Reward System Integration

This module provides a simple integration point for the new MSE-based reward system
with the existing trading orchestrator and training infrastructure.

Key Features:
- Easy integration with existing TradingOrchestrator
- Minimal changes required to existing code
- Backward compatibility maintained
- Enhanced performance monitoring
- Real-time training with MSE rewards
"""

import asyncio
import logging
from typing import Optional, Dict, Any
from datetime import datetime

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator
from core.enhanced_rl_training_adapter import EnhancedRLTrainingAdapter
from core.unified_training_manager import UnifiedTrainingManager

logger = logging.getLogger(__name__)

class EnhancedRewardSystemIntegration:
|
||||
"""
|
||||
Complete integration of the enhanced reward system
|
||||
|
||||
This class provides a single integration point that can be easily added
|
||||
to the existing TradingOrchestrator to enable MSE-based rewards and
|
||||
multi-timeframe training.
|
||||
"""
|
||||
|
||||
def __init__(self, orchestrator: Any, symbols: list = None):
|
||||
"""
|
||||
Initialize the enhanced reward system integration
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance
|
||||
symbols: List of symbols to track (defaults to ETH/USDT, BTC/USDT)
|
||||
"""
|
||||
self.orchestrator = orchestrator
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
|
||||
# Initialize core components
|
||||
self.reward_calculator = EnhancedRewardCalculator(symbols=self.symbols)
|
||||
|
||||
self.inference_coordinator = TimeframeInferenceCoordinator(
|
||||
reward_calculator=self.reward_calculator,
|
||||
data_provider=getattr(orchestrator, 'data_provider', None),
|
||||
symbols=self.symbols
|
||||
)
|
||||
|
||||
self.training_adapter = EnhancedRLTrainingAdapter(
|
||||
reward_calculator=self.reward_calculator,
|
||||
inference_coordinator=self.inference_coordinator,
|
||||
orchestrator=orchestrator,
|
||||
training_system=getattr(orchestrator, 'enhanced_training_system', None)
|
||||
)
|
||||
|
||||
# Unified Training Manager (always available)
|
||||
self.unified_training = UnifiedTrainingManager(
|
||||
orchestrator=orchestrator,
|
||||
reward_system=self,
|
||||
)
|
||||
|
||||
# Integration state
|
||||
self.is_running = False
|
||||
self.integration_stats = {
|
||||
'start_time': None,
|
||||
'total_predictions_tracked': 0,
|
||||
'total_rewards_calculated': 0,
|
||||
'total_training_batches': 0
|
||||
}
|
||||
|
||||
logger.info(f"EnhancedRewardSystemIntegration initialized for symbols: {self.symbols}")
|
||||
|
||||
async def start_integration(self):
|
||||
"""Start the enhanced reward system integration"""
|
||||
if self.is_running:
|
||||
logger.warning("Enhanced reward system already running")
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Starting Enhanced Reward System Integration")
|
||||
|
||||
# Start core components
|
||||
await self.inference_coordinator.start_coordination()
|
||||
await self.training_adapter.start_training_loop()
|
||||
await self.unified_training.start()
|
||||
|
||||
# Start price monitoring
|
||||
asyncio.create_task(self._price_monitoring_loop())
|
||||
|
||||
self.is_running = True
|
||||
self.integration_stats['start_time'] = datetime.now().isoformat()
|
||||
|
||||
logger.info("Enhanced Reward System Integration started successfully")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting enhanced reward system integration: {e}")
|
||||
await self.stop_integration()
|
||||
|
||||
async def stop_integration(self):
|
||||
"""Stop the enhanced reward system integration"""
|
||||
if not self.is_running:
|
||||
return
|
||||
|
||||
try:
|
||||
logger.info("Stopping Enhanced Reward System Integration")
|
||||
|
||||
# Stop components
|
||||
await self.inference_coordinator.stop_coordination()
|
||||
await self.training_adapter.stop_training_loop()
|
||||
await self.unified_training.stop()
|
||||
|
||||
self.is_running = False
|
||||
|
||||
logger.info("Enhanced Reward System Integration stopped")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping enhanced reward system integration: {e}")
|
||||
|
||||
async def _price_monitoring_loop(self):
|
||||
"""Monitor prices and update the reward calculator"""
|
||||
while self.is_running:
|
||||
try:
|
||||
# Update current prices for all symbols
|
||||
for symbol in self.symbols:
|
||||
current_price = await self._get_current_price(symbol)
|
||||
if current_price > 0:
|
||||
self.reward_calculator.update_price(symbol, current_price)
|
||||
|
||||
# Sleep for 1 second between updates
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error in price monitoring loop: {e}")
|
||||
await asyncio.sleep(5.0) # Wait longer on error
|
||||
|
||||
async def _get_current_price(self, symbol: str) -> float:
|
||||
"""Get current price for a symbol"""
|
||||
try:
|
||||
if hasattr(self.orchestrator, 'data_provider'):
|
||||
price = self.orchestrator.data_provider.get_current_price(symbol)
|
||||
return float(price) if price is not None else 0.0
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
|
||||
return 0.0
|
||||
|
||||
def add_prediction_manually(self, symbol: str, timeframe_str: str,
|
||||
predicted_price: float, predicted_direction: int,
|
||||
confidence: float, current_price: float,
|
||||
model_name: str) -> str:
|
||||
"""
|
||||
Manually add a prediction to the reward calculator
|
||||
|
||||
This method allows existing code to easily integrate with the new reward system
|
||||
without major changes.
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol (e.g., 'ETH/USDT')
|
||||
timeframe_str: Timeframe string ('1s', '1m', '1h', '1d')
|
||||
predicted_price: Model's predicted price
|
||||
predicted_direction: Predicted direction (-1, 0, 1)
|
||||
confidence: Model's confidence (0.0 to 1.0)
|
||||
current_price: Current market price
|
||||
model_name: Name of the model making prediction
|
||||
|
||||
Returns:
|
||||
Unique prediction ID
|
||||
"""
|
||||
try:
|
||||
# Convert timeframe string to enum
|
||||
timeframe = TimeFrame(timeframe_str)
|
||||
|
||||
prediction_id = self.reward_calculator.add_prediction(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
predicted_price=predicted_price,
|
||||
predicted_direction=predicted_direction,
|
||||
confidence=confidence,
|
||||
current_price=current_price,
|
||||
model_name=model_name
|
||||
)
|
||||
|
||||
self.integration_stats['total_predictions_tracked'] += 1
|
||||
|
||||
return prediction_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding prediction manually: {e}")
|
||||
return ""
|
||||
|
||||
def get_model_accuracy(self, model_name: str = None, symbol: str = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get accuracy statistics for models
|
||||
|
||||
Args:
|
||||
model_name: Specific model name (None for all)
|
||||
symbol: Specific symbol (None for all)
|
||||
|
||||
Returns:
|
||||
Dictionary with accuracy statistics
|
||||
"""
|
||||
try:
|
||||
accuracy_summary = self.reward_calculator.get_accuracy_summary(symbol)
|
||||
|
||||
if model_name:
|
||||
# Filter by model name in prediction history
|
||||
# This would require enhancing the reward calculator to track by model
|
||||
pass
|
||||
|
||||
return accuracy_summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting model accuracy: {e}")
|
||||
return {}
|
||||
|
||||
def force_evaluation_and_training(self, symbol: str = None, timeframe_str: str = None):
|
||||
"""
|
||||
Force immediate evaluation and training for debugging/testing
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol (None for all)
|
||||
timeframe_str: Specific timeframe (None for all)
|
||||
"""
|
||||
try:
|
||||
if timeframe_str:
|
||||
timeframe = TimeFrame(timeframe_str)
|
||||
symbols_to_process = [symbol] if symbol else self.symbols
|
||||
|
||||
for sym in symbols_to_process:
|
||||
# Force evaluation of predictions
|
||||
results = self.reward_calculator.force_evaluate_timeframe_predictions(sym, timeframe)
|
||||
logger.info(f"Force evaluated {len(results)} predictions for {sym} {timeframe.value}")
|
||||
else:
|
||||
# Evaluate all pending predictions
|
||||
for sym in (self.symbols if not symbol else [symbol]):
|
||||
results = self.reward_calculator.evaluate_predictions(sym)
|
||||
if sym in results:
|
||||
logger.info(f"Force evaluated {len(results[sym])} predictions for {sym}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in force evaluation and training: {e}")
|
||||
|
||||
def get_integration_statistics(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive integration statistics"""
|
||||
try:
|
||||
stats = self.integration_stats.copy()
|
||||
|
||||
# Add component statistics
|
||||
stats['inference_coordinator'] = self.inference_coordinator.get_inference_statistics()
|
||||
stats['training_adapter'] = self.training_adapter.get_training_statistics()
|
||||
stats['reward_calculator'] = self.reward_calculator.get_accuracy_summary()
|
||||
|
||||
# Add system status
|
||||
stats['is_running'] = self.is_running
|
||||
stats['components_running'] = {
|
||||
'inference_coordinator': self.inference_coordinator.running,
|
||||
'training_adapter': self.training_adapter.running
|
||||
}
|
||||
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting integration statistics: {e}")
|
||||
return {'error': str(e)}
|
||||
|
||||
def cleanup_old_data(self, days_to_keep: int = 7):
|
||||
"""Clean up old prediction data to manage memory"""
|
||||
try:
|
||||
self.reward_calculator.cleanup_old_predictions(days_to_keep)
|
||||
logger.info(f"Cleaned up prediction data older than {days_to_keep} days")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up old data: {e}")
|
||||
|
||||
|
||||
# Utility functions for easy integration
|
||||
|
||||
def integrate_enhanced_rewards(orchestrator: Any, symbols: list = None) -> EnhancedRewardSystemIntegration:
|
||||
"""
|
||||
Utility function to easily integrate enhanced rewards with an existing orchestrator
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance
|
||||
symbols: List of symbols to track
|
||||
|
||||
Returns:
|
||||
EnhancedRewardSystemIntegration instance
|
||||
"""
|
||||
integration = EnhancedRewardSystemIntegration(orchestrator, symbols)
|
||||
|
||||
# Add integration as an attribute to the orchestrator for easy access
|
||||
setattr(orchestrator, 'enhanced_reward_system', integration)
|
||||
|
||||
logger.info("Enhanced reward system integrated with orchestrator")
|
||||
return integration
|
||||
|
||||
|
||||
async def start_enhanced_rewards_for_orchestrator(orchestrator: Any, symbols: list = None):
|
||||
"""
|
||||
Start enhanced rewards for an existing orchestrator
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance
|
||||
symbols: List of symbols to track
|
||||
"""
|
||||
if not hasattr(orchestrator, 'enhanced_reward_system'):
|
||||
integrate_enhanced_rewards(orchestrator, symbols)
|
||||
|
||||
await orchestrator.enhanced_reward_system.start_integration()
|
||||
|
||||
|
||||
def add_prediction_to_enhanced_rewards(orchestrator: Any, symbol: str, timeframe: str,
|
||||
predicted_price: float, direction: int, confidence: float,
|
||||
current_price: float, model_name: str) -> str:
|
||||
"""
|
||||
Helper function to add predictions to enhanced rewards from existing code
|
||||
|
||||
Args:
|
||||
orchestrator: TradingOrchestrator instance with enhanced_reward_system
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe string
|
||||
predicted_price: Predicted price
|
||||
direction: Predicted direction (-1, 0, 1)
|
||||
confidence: Model confidence
|
||||
current_price: Current market price
|
||||
model_name: Model name
|
||||
|
||||
Returns:
|
||||
Prediction ID
|
||||
"""
|
||||
if hasattr(orchestrator, 'enhanced_reward_system'):
|
||||
return orchestrator.enhanced_reward_system.add_prediction_manually(
|
||||
symbol, timeframe, predicted_price, direction, confidence, current_price, model_name
|
||||
)
|
||||
|
||||
logger.warning("Enhanced reward system not integrated with orchestrator")
|
||||
return ""
|
||||
|
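A minimal sketch of how these helpers are meant to be wired together, assuming an existing TradingOrchestrator instance named `orchestrator` (construction of the orchestrator itself is outside this module):

```python
import asyncio
from core.enhanced_reward_system_integration import (
    integrate_enhanced_rewards,
    start_enhanced_rewards_for_orchestrator,
    add_prediction_to_enhanced_rewards,
)

async def main(orchestrator):
    # Attach the reward system and start its background loops
    integrate_enhanced_rewards(orchestrator, symbols=['ETH/USDT', 'BTC/USDT'])
    await start_enhanced_rewards_for_orchestrator(orchestrator)

    # Existing inference code can then report predictions with one call
    add_prediction_to_enhanced_rewards(
        orchestrator, 'ETH/USDT', '1m',
        predicted_price=3425.0, direction=1, confidence=0.55,
        current_price=3421.6, model_name='dqn_agent',
    )

# asyncio.run(main(orchestrator))  # run inside the application's event loop
```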
595
core/enhanced_rl_training_adapter.py
Normal file
@@ -0,0 +1,595 @@
"""
Enhanced RL Training Adapter

This module integrates the new MSE-based reward system with existing RL training pipelines.
It provides a bridge between the timeframe-aware inference coordinator and the existing
model training infrastructure.

Key Features:
- Integration with EnhancedRewardCalculator
- Adaptation of existing RL models to the new reward system
- Real-time training triggers based on prediction outcomes
- Multi-timeframe training coordination
- Backward compatibility with existing training infrastructure
"""

import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union
from dataclasses import dataclass
import numpy as np
import threading

from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame, PredictionRecord
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext

logger = logging.getLogger(__name__)

@dataclass
|
||||
class TrainingBatch:
|
||||
"""Training batch for RL models with enhanced reward data"""
|
||||
model_name: str
|
||||
symbol: str
|
||||
timeframe: TimeFrame
|
||||
states: List[np.ndarray]
|
||||
actions: List[int]
|
||||
rewards: List[float]
|
||||
next_states: List[np.ndarray]
|
||||
dones: List[bool]
|
||||
confidences: List[float]
|
||||
prediction_records: List[PredictionRecord]
|
||||
batch_timestamp: datetime
|
||||
|
||||
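For orientation, a TrainingBatch as defined above could be assembled like this (toy values; in the adapter the fields are filled from evaluated PredictionRecords and real DQN states):

```python
import numpy as np
from datetime import datetime
from core.enhanced_reward_calculator import TimeFrame
from core.enhanced_rl_training_adapter import TrainingBatch

batch = TrainingBatch(
    model_name='dqn_agent',
    symbol='ETH/USDT',
    timeframe=TimeFrame.SECONDS_1,
    states=[np.zeros(100, dtype=np.float32)],
    actions=[2],                      # BUY in the 0=SELL, 1=HOLD, 2=BUY convention
    rewards=[0.85],
    next_states=[np.zeros(100, dtype=np.float32)],
    dones=[False],
    confidences=[0.6],
    prediction_records=[],            # evaluated PredictionRecord objects in practice
    batch_timestamp=datetime.now(),
)
```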
|
||||
class EnhancedRLTrainingAdapter:
|
||||
"""
|
||||
Adapter that integrates new reward system with existing RL training infrastructure
|
||||
|
||||
This adapter:
|
||||
1. Bridges new reward calculator with existing RL models
|
||||
2. Converts prediction records to RL training format
|
||||
3. Triggers real-time training based on reward evaluation
|
||||
4. Maintains compatibility with existing training systems
|
||||
5. Coordinates multi-timeframe training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reward_calculator: EnhancedRewardCalculator,
|
||||
inference_coordinator: TimeframeInferenceCoordinator,
|
||||
orchestrator: Any = None,
|
||||
training_system: Any = None):
|
||||
"""
|
||||
Initialize the enhanced RL training adapter
|
||||
|
||||
Args:
|
||||
reward_calculator: Enhanced reward calculator instance
|
||||
inference_coordinator: Timeframe inference coordinator
|
||||
orchestrator: Trading orchestrator (optional)
|
||||
training_system: Enhanced realtime training system (optional)
|
||||
"""
|
||||
self.reward_calculator = reward_calculator
|
||||
self.inference_coordinator = inference_coordinator
|
||||
self.orchestrator = orchestrator
|
||||
self.training_system = training_system
|
||||
|
||||
# Model registry for training functions
|
||||
self.model_trainers: Dict[str, Any] = {}
|
||||
|
||||
# Training configuration
|
||||
self.min_batch_size = 8 # Minimum samples for training
|
||||
self.max_batch_size = 64 # Maximum samples per training batch
|
||||
self.training_interval_seconds = 5.0 # How often to check for training opportunities
|
||||
|
||||
# Training statistics
|
||||
self.training_stats = {
|
||||
'total_training_batches': 0,
|
||||
'successful_training_calls': 0,
|
||||
'failed_training_calls': 0,
|
||||
'last_training_time': None,
|
||||
'training_times_per_model': {},
|
||||
'average_batch_sizes': {}
|
||||
}
|
||||
|
||||
# State conversion helpers
|
||||
self.state_builders: Dict[str, Any] = {}
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Running state
|
||||
self.running = False
|
||||
self.training_task: Optional[asyncio.Task] = None
|
||||
|
||||
logger.info("EnhancedRLTrainingAdapter initialized")
|
||||
self._register_default_model_handlers()
|
||||
|
||||
def _register_default_model_handlers(self):
|
||||
"""Register default model handlers for existing models"""
|
||||
# Register inference functions with the coordinator
|
||||
if self.inference_coordinator:
|
||||
self.inference_coordinator.register_model_inference_function(
|
||||
'dqn_agent', self._dqn_inference_wrapper
|
||||
)
|
||||
self.inference_coordinator.register_model_inference_function(
|
||||
'cob_rl', self._cob_rl_inference_wrapper
|
||||
)
|
||||
self.inference_coordinator.register_model_inference_function(
|
||||
'enhanced_cnn', self._cnn_inference_wrapper
|
||||
)
|
||||
|
||||
async def _dqn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
|
||||
"""Wrapper for DQN model inference"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
|
||||
# Get base data for the symbol
|
||||
base_data = await self._get_base_data(context.symbol)
|
||||
if base_data is None:
|
||||
return None
|
||||
|
||||
# Convert to DQN state format
|
||||
state = self._convert_to_dqn_state(base_data, context)
|
||||
|
||||
# Run DQN prediction
|
||||
if hasattr(self.orchestrator.rl_agent, 'act'):
|
||||
action_idx = self.orchestrator.rl_agent.act(state)
|
||||
# Try to extract confidence from agent if available
|
||||
confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', None)
|
||||
if confidence is None:
|
||||
confidence = 0.5
|
||||
|
||||
# Convert action to prediction format
|
||||
action_names = ['SELL', 'HOLD', 'BUY']
|
||||
direction = action_idx - 1 # Convert 0,1,2 to -1,0,1
|
||||
|
||||
# Use real current price
|
||||
current_price = self._safe_get_current_price(context.symbol)
|
||||
|
||||
# Do not fabricate price; set predicted_price only if model provides numeric target later
|
||||
return {
|
||||
'predicted_price': current_price, # same as current when no numeric target available
|
||||
'current_price': current_price,
|
||||
'direction': direction,
|
||||
'confidence': float(confidence),
|
||||
'action': action_names[action_idx],
|
||||
'model_state': (state.tolist() if hasattr(state, 'tolist') else state),
|
||||
'context': context
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in DQN inference wrapper: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _cob_rl_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
|
||||
"""Wrapper for COB RL model inference"""
|
||||
try:
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'realtime_rl_trader') and
|
||||
self.orchestrator.realtime_rl_trader):
|
||||
|
||||
# Get COB features
|
||||
features = await self._get_cob_features(context.symbol)
|
||||
if features is None:
|
||||
return None
|
||||
|
||||
# Run COB RL prediction
|
||||
prediction = self.orchestrator.realtime_rl_trader._predict(context.symbol, features)
|
||||
|
||||
if prediction:
|
||||
current_price = self._safe_get_current_price(context.symbol)
|
||||
# If 'change' is available assume it is a fractional return
|
||||
change = prediction.get('change', None)
|
||||
predicted_price = current_price * (1 + change) if (change is not None and current_price) else current_price
|
||||
|
||||
return {
|
||||
'predicted_price': predicted_price,
|
||||
'current_price': current_price,
|
||||
'direction': prediction.get('direction', 0),
|
||||
'confidence': prediction.get('confidence', 0.0),
|
||||
'change': prediction.get('change', 0.0),
|
||||
'model_features': features,
|
||||
'context': context
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in COB RL inference wrapper: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _cnn_inference_wrapper(self, context: InferenceContext) -> Optional[Dict[str, Any]]:
|
||||
"""Wrapper for CNN model inference"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'model_registry'):
|
||||
# Find CNN models in registry
|
||||
for model_name, model in self.orchestrator.model_registry.models.items():
|
||||
if 'cnn' in model_name.lower():
|
||||
# Get base data
|
||||
base_data = await self._get_base_data(context.symbol)
|
||||
if base_data is None:
|
||||
continue
|
||||
|
||||
# Run CNN prediction
|
||||
if hasattr(model, 'predict_from_base_input'):
|
||||
model_output = model.predict_from_base_input(base_data)
|
||||
|
||||
# Extract current price from data provider
|
||||
current_price = self._safe_get_current_price(context.symbol)
|
||||
|
||||
# Extract prediction data
|
||||
predictions = model_output.predictions
|
||||
action = predictions.get('action', 'HOLD')
|
||||
confidence = predictions.get('confidence', 0.0)
|
||||
|
||||
# Convert action to direction only for classification signal
|
||||
direction = {'BUY': 1, 'SELL': -1, 'HOLD': 0}.get(action, 0)
|
||||
|
||||
# Use numeric predicted return if provided (no synthetic fabrication)
|
||||
pr_map = {
|
||||
TimeFrame.SECONDS_1: 'predicted_return_1s',
|
||||
TimeFrame.MINUTES_1: 'predicted_return_1m',
|
||||
TimeFrame.HOURS_1: 'predicted_return_1h',
|
||||
TimeFrame.DAYS_1: 'predicted_return_1d',
|
||||
}
|
||||
ret_key = pr_map.get(context.target_timeframe)
|
||||
predicted_return = None
|
||||
if ret_key and ret_key in predictions:
|
||||
predicted_return = float(predictions.get(ret_key))
|
||||
|
||||
predicted_price = current_price * (1 + predicted_return) if (predicted_return is not None and current_price) else current_price
|
||||
|
||||
# Also attach DQN-formatted state if available for training consumption
|
||||
dqn_state = self._convert_to_dqn_state(base_data, context)
|
||||
return {
|
||||
'predicted_price': predicted_price,
|
||||
'current_price': current_price,
|
||||
'direction': direction,
|
||||
'confidence': confidence,
|
||||
'predicted_return': predicted_return,
|
||||
'action': action,
|
||||
'model_output': model_output,
|
||||
'model_state': (dqn_state.tolist() if hasattr(dqn_state, 'tolist') else dqn_state),
|
||||
'context': context
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN inference wrapper: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_base_data(self, symbol: str) -> Optional[Any]:
|
||||
"""Get base data for a symbol"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
|
||||
# Use orchestrator's data provider
|
||||
return await self.orchestrator._build_base_data(symbol)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting base data for {symbol}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_cob_features(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Get COB features for a symbol"""
|
||||
try:
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'realtime_rl_trader') and
|
||||
self.orchestrator.realtime_rl_trader):
|
||||
|
||||
# Get latest features from COB trader
|
||||
feature_buffers = self.orchestrator.realtime_rl_trader.feature_buffers
|
||||
if symbol in feature_buffers and feature_buffers[symbol]:
|
||||
latest_data = feature_buffers[symbol][-1]
|
||||
return latest_data.get('features')
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting COB features for {symbol}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _safe_get_current_price(self, symbol: str) -> float:
|
||||
"""Get current price for a symbol via DataProvider API"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
|
||||
price = self.orchestrator.data_provider.get_current_price(symbol)
|
||||
return float(price) if price is not None else 0.0
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting current price for {symbol}: {e}")
|
||||
return 0.0
|
||||
|
||||
def _convert_to_dqn_state(self, base_data: Any, context: InferenceContext) -> np.ndarray:
|
||||
"""Convert base data to DQN state format"""
|
||||
try:
|
||||
# Use existing state building logic if available
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'enhanced_training_system') and
|
||||
hasattr(self.orchestrator.enhanced_training_system, '_build_dqn_state')):
|
||||
|
||||
return self.orchestrator.enhanced_training_system._build_dqn_state(
|
||||
base_data, context.symbol
|
||||
)
|
||||
|
||||
# Fallback: create simple state representation
|
||||
feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
|
||||
if feature_vector:
|
||||
return np.array(feature_vector, dtype=np.float32)
|
||||
|
||||
# Last resort: create minimal state
|
||||
return np.zeros(100, dtype=np.float32)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting to DQN state: {e}")
|
||||
return np.zeros(100, dtype=np.float32)
|
||||
|
||||
async def start_training_loop(self):
|
||||
"""Start the enhanced training loop"""
|
||||
if self.running:
|
||||
logger.warning("Training loop already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
self.training_task = asyncio.create_task(self._training_loop())
|
||||
logger.info("Enhanced RL training loop started")
|
||||
|
||||
async def stop_training_loop(self):
|
||||
"""Stop the enhanced training loop"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
if self.training_task:
|
||||
self.training_task.cancel()
|
||||
try:
|
||||
await self.training_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("Enhanced RL training loop stopped")
|
||||
|
||||
async def _training_loop(self):
|
||||
"""Main training loop that processes evaluated predictions"""
|
||||
logger.info("Starting enhanced RL training loop")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Process training for each symbol and timeframe
|
||||
for symbol in self.reward_calculator.symbols:
|
||||
for timeframe in [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
|
||||
TimeFrame.HOURS_1, TimeFrame.DAYS_1]:
|
||||
|
||||
# Get training data for this symbol/timeframe
|
||||
training_data = self.reward_calculator.get_training_data(
|
||||
symbol, timeframe, self.max_batch_size
|
||||
)
|
||||
|
||||
if len(training_data) >= self.min_batch_size:
|
||||
await self._process_training_batch(symbol, timeframe, training_data)
|
||||
|
||||
# Sleep between training checks
|
||||
await asyncio.sleep(self.training_interval_seconds)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in training loop: {e}")
|
||||
await asyncio.sleep(10) # Wait longer on error
|
||||
|
||||
async def _process_training_batch(self, symbol: str, timeframe: TimeFrame,
|
||||
training_data: List[Tuple[PredictionRecord, float]]):
|
||||
"""
|
||||
Process a training batch for a specific symbol/timeframe
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe for training
|
||||
training_data: List of (prediction_record, reward) tuples
|
||||
"""
|
||||
try:
|
||||
# Group training data by model
|
||||
model_batches = {}
|
||||
|
||||
for prediction_record, reward in training_data:
|
||||
model_name = prediction_record.model_name
|
||||
if model_name not in model_batches:
|
||||
model_batches[model_name] = []
|
||||
model_batches[model_name].append((prediction_record, reward))
|
||||
|
||||
# Process each model's batch
|
||||
for model_name, model_data in model_batches.items():
|
||||
if len(model_data) >= self.min_batch_size:
|
||||
await self._train_model_batch(model_name, symbol, timeframe, model_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing training batch for {symbol} {timeframe.value}: {e}")
|
||||
|
||||
async def _train_model_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
|
||||
training_data: List[Tuple[PredictionRecord, float]]):
|
||||
"""
|
||||
Train a specific model with a batch of data
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to train
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe for training
|
||||
training_data: List of (prediction_record, reward) tuples
|
||||
"""
|
||||
try:
|
||||
training_start = time.time()
|
||||
|
||||
# Convert to training batch format
|
||||
batch = self._create_training_batch(model_name, symbol, timeframe, training_data)
|
||||
|
||||
if batch is None:
|
||||
return
|
||||
|
||||
# Call appropriate training function based on model type
|
||||
success = False
|
||||
|
||||
if 'dqn' in model_name.lower():
|
||||
success = await self._train_dqn_model(batch)
|
||||
elif 'cob' in model_name.lower():
|
||||
success = await self._train_cob_rl_model(batch)
|
||||
elif 'cnn' in model_name.lower():
|
||||
success = await self._train_cnn_model(batch)
|
||||
else:
|
||||
logger.warning(f"Unknown model type for training: {model_name}")
|
||||
|
||||
# Update statistics
|
||||
training_time = time.time() - training_start
|
||||
self._update_training_stats(model_name, batch, success, training_time)
|
||||
|
||||
if success:
|
||||
logger.info(f"Successfully trained {model_name} on {symbol} {timeframe.value} "
|
||||
f"with {len(training_data)} samples in {training_time:.3f}s")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training model {model_name}: {e}")
|
||||
self._update_training_stats(model_name, None, False, 0)
|
||||
|
||||
def _create_training_batch(self, model_name: str, symbol: str, timeframe: TimeFrame,
|
||||
training_data: List[Tuple[PredictionRecord, float]]) -> Optional[TrainingBatch]:
|
||||
"""Create a training batch from prediction records and rewards"""
|
||||
try:
|
||||
states = []
|
||||
actions = []
|
||||
rewards = []
|
||||
next_states = []
|
||||
dones = []
|
||||
confidences = []
|
||||
prediction_records = []
|
||||
|
||||
for prediction_record, reward in training_data:
|
||||
# Extract state information
|
||||
# This would need to be adapted based on how states are stored
|
||||
state = np.zeros(100)
|
||||
next_state = state.copy() # Simplified next state
|
||||
|
||||
# Convert direction to action
|
||||
direction = prediction_record.predicted_direction
|
||||
action = direction + 1 # Convert -1,0,1 to 0,1,2
|
||||
|
||||
states.append(state)
|
||||
actions.append(action)
|
||||
rewards.append(reward)
|
||||
next_states.append(next_state)
|
||||
dones.append(True) # Each prediction is treated as terminal
|
||||
confidences.append(prediction_record.confidence)
|
||||
prediction_records.append(prediction_record)
|
||||
|
||||
return TrainingBatch(
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
states=states,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_states=next_states,
|
||||
dones=dones,
|
||||
confidences=confidences,
|
||||
prediction_records=prediction_records,
|
||||
batch_timestamp=datetime.now()
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating training batch: {e}")
|
||||
return None
|
||||
|
||||
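# Illustrative sketch of the direction-to-action shift used in _create_training_batch above
# (assumed convention after the +1 shift: 0 = SELL, 1 = HOLD, 2 = BUY):
#   predicted_direction = -1  ->  action = 0
#   predicted_direction =  0  ->  action = 1
#   predicted_direction = +1  ->  action = 2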
async def _train_dqn_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train DQN model with batch data"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
|
||||
rl_agent = self.orchestrator.rl_agent
|
||||
|
||||
# Add experiences to memory
|
||||
for i in range(len(batch.states)):
|
||||
if hasattr(rl_agent, 'remember'):
|
||||
rl_agent.remember(
|
||||
state=batch.states[i],
|
||||
action=batch.actions[i],
|
||||
reward=batch.rewards[i],
|
||||
next_state=batch.next_states[i],
|
||||
done=batch.dones[i]
|
||||
)
|
||||
|
||||
# Trigger training if enough experiences
|
||||
if hasattr(rl_agent, 'replay') and hasattr(rl_agent, 'memory'):
|
||||
if len(rl_agent.memory) >= getattr(rl_agent, 'batch_size', 32):
|
||||
loss = rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"DQN training loss: {loss:.6f}")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training DQN model: {e}")
|
||||
return False
|
||||
|
||||
async def _train_cob_rl_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train COB RL model with batch data"""
|
||||
try:
|
||||
if (self.orchestrator and
|
||||
hasattr(self.orchestrator, 'realtime_rl_trader') and
|
||||
self.orchestrator.realtime_rl_trader):
|
||||
|
||||
# Use COB RL trainer if available
|
||||
# This is a placeholder - implement based on actual COB RL training interface
|
||||
logger.debug(f"COB RL training batch: {len(batch.states)} samples")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training COB RL model: {e}")
|
||||
return False
|
||||
|
||||
async def _train_cnn_model(self, batch: TrainingBatch) -> bool:
|
||||
"""Train CNN model with batch data"""
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'enhanced_training_system'):
|
||||
# Use enhanced training system for CNN training
|
||||
# This is a placeholder - implement based on actual CNN training interface
|
||||
logger.debug(f"CNN training batch: {len(batch.states)} samples")
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training CNN model: {e}")
|
||||
return False
|
||||
|
||||
def _update_training_stats(self, model_name: str, batch: Optional[TrainingBatch],
|
||||
success: bool, training_time: float):
|
||||
"""Update training statistics"""
|
||||
with self.lock:
|
||||
self.training_stats['total_training_batches'] += 1
|
||||
|
||||
if success:
|
||||
self.training_stats['successful_training_calls'] += 1
|
||||
else:
|
||||
self.training_stats['failed_training_calls'] += 1
|
||||
|
||||
self.training_stats['last_training_time'] = datetime.now().isoformat()
|
||||
|
||||
# Model-specific stats
|
||||
if model_name not in self.training_stats['training_times_per_model']:
|
||||
self.training_stats['training_times_per_model'][model_name] = []
|
||||
self.training_stats['average_batch_sizes'][model_name] = []
|
||||
|
||||
self.training_stats['training_times_per_model'][model_name].append(training_time)
|
||||
|
||||
if batch:
|
||||
self.training_stats['average_batch_sizes'][model_name].append(len(batch.states))
|
||||
|
||||
def get_training_statistics(self) -> Dict[str, Any]:
|
||||
"""Get training statistics"""
|
||||
with self.lock:
|
||||
stats = self.training_stats.copy()
|
||||
|
||||
# Calculate averages
|
||||
for model_name in stats['training_times_per_model']:
|
||||
times = stats['training_times_per_model'][model_name]
|
||||
if times:
|
||||
stats[f'{model_name}_avg_training_time'] = sum(times) / len(times)
|
||||
|
||||
sizes = stats['average_batch_sizes'][model_name]
|
||||
if sizes:
|
||||
stats[f'{model_name}_avg_batch_size'] = sum(sizes) / len(sizes)
|
||||
|
||||
return stats
|
||||
|
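# Illustrative usage sketch for get_training_statistics() above (instance and model
# names are assumptions; derived keys exist only for models that have trained):
#   stats = trainer.get_training_statistics()
#   total = stats['total_training_batches']
#   dqn_time = stats.get('dqn_agent_avg_training_time')   # seconds per batch, if present
#   dqn_size = stats.get('dqn_agent_avg_batch_size')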
383
core/llm_proxy.py
Normal file
@@ -0,0 +1,383 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
LLM Proxy Model - Interface for LLM-based trading signals
|
||||
Sends market data to LLM endpoint and parses responses for trade signals
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import requests
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class LLMTradeSignal:
|
||||
"""Trade signal from LLM"""
|
||||
symbol: str
|
||||
action: str # 'BUY', 'SELL', 'HOLD'
|
||||
confidence: float # 0.0 to 1.0
|
||||
reasoning: str
|
||||
price_target: Optional[float] = None
|
||||
stop_loss: Optional[float] = None
|
||||
timestamp: Optional[datetime] = None
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""LLM configuration"""
|
||||
base_url: str = "http://localhost:1234"
|
||||
model: str = "openai/gpt-oss-20b"
|
||||
temperature: float = 0.7
|
||||
max_tokens: int = -1
|
||||
timeout: int = 30
|
||||
api_key: Optional[str] = None
|
||||
|
||||
class LLMProxy:
|
||||
"""
|
||||
LLM Proxy for trading signal generation
|
||||
|
||||
Features:
|
||||
- Configurable LLM endpoint and model
|
||||
- Processes market data from TextDataExporter files
|
||||
- Generates structured trading signals
|
||||
- Thread-safe operations
|
||||
- Error handling and retry logic
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
config: Optional[LLMConfig] = None,
|
||||
data_dir: str = "NN/training/samples/txt"):
|
||||
"""
|
||||
Initialize LLM proxy
|
||||
|
||||
Args:
|
||||
config: LLM configuration
|
||||
data_dir: Directory to watch for market data files
|
||||
"""
|
||||
self.config = config or LLMConfig()
|
||||
self.data_dir = data_dir
|
||||
|
||||
# Processing state
|
||||
self.is_running = False
|
||||
self.processing_thread = None
|
||||
self.processed_files = set()
|
||||
|
||||
# Signal storage
|
||||
self.latest_signals: Dict[str, LLMTradeSignal] = {}
|
||||
self.signal_history: List[LLMTradeSignal] = []
|
||||
self.lock = threading.Lock()
|
||||
|
||||
# System prompt for trading
|
||||
self.system_prompt = """You are an expert cryptocurrency trading analyst.
|
||||
You will receive market data for ETH (main symbol) with reference data for BTC and SPX.
|
||||
Analyze the multi-timeframe data (1s, 1m, 1h, 1d) and provide trading recommendations.
|
||||
|
||||
Respond ONLY with valid JSON in this format:
|
||||
{
|
||||
"action": "BUY|SELL|HOLD",
|
||||
"confidence": 0.0-1.0,
|
||||
"reasoning": "brief analysis",
|
||||
"price_target": number_or_null,
|
||||
"stop_loss": number_or_null
|
||||
}
|
||||
|
||||
Consider market correlations, timeframe divergences, and risk management.
|
||||
"""
|
||||
|
||||
logger.info(f"LLM Proxy initialized - Model: {self.config.model}")
|
||||
logger.info(f"Watching directory: {self.data_dir}")
|
||||
|
||||
def start(self):
|
||||
"""Start LLM processing"""
|
||||
if self.is_running:
|
||||
logger.warning("LLM proxy already running")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.processing_thread = threading.Thread(target=self._processing_loop, daemon=True)
|
||||
self.processing_thread.start()
|
||||
logger.info("LLM proxy started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop LLM processing"""
|
||||
self.is_running = False
|
||||
if self.processing_thread:
|
||||
self.processing_thread.join(timeout=5)
|
||||
logger.info("LLM proxy stopped")
|
||||
|
||||
def _processing_loop(self):
|
||||
"""Main processing loop - checks for new files"""
|
||||
while self.is_running:
|
||||
try:
|
||||
self._check_for_new_files()
|
||||
time.sleep(5) # Check every 5 seconds
|
||||
except Exception as e:
|
||||
logger.error(f"Error in LLM processing loop: {e}")
|
||||
time.sleep(5)
|
||||
|
||||
def _check_for_new_files(self):
|
||||
"""Check for new market data files"""
|
||||
try:
|
||||
if not os.path.exists(self.data_dir):
|
||||
return
|
||||
|
||||
txt_files = [f for f in os.listdir(self.data_dir)
|
||||
if f.endswith('.txt') and f.startswith('market_data_')]
|
||||
|
||||
for filename in txt_files:
|
||||
if filename not in self.processed_files:
|
||||
filepath = os.path.join(self.data_dir, filename)
|
||||
self._process_file(filepath, filename)
|
||||
self.processed_files.add(filename)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking for new files: {e}")
|
||||
|
||||
def _process_file(self, filepath: str, filename: str):
|
||||
"""Process a market data file"""
|
||||
try:
|
||||
logger.info(f"Processing market data file: {filename}")
|
||||
|
||||
# Read and parse market data
|
||||
market_data = self._parse_market_data(filepath)
|
||||
if not market_data:
|
||||
logger.warning(f"No valid market data in {filename}")
|
||||
return
|
||||
|
||||
# Generate LLM prompt
|
||||
prompt = self._create_trading_prompt(market_data)
|
||||
|
||||
# Send to LLM
|
||||
response = self._query_llm(prompt)
|
||||
if not response:
|
||||
logger.warning(f"No response from LLM for {filename}")
|
||||
return
|
||||
|
||||
# Parse response
|
||||
signal = self._parse_llm_response(response, market_data)
|
||||
if signal:
|
||||
with self.lock:
|
||||
self.latest_signals['ETH'] = signal
|
||||
self.signal_history.append(signal)
|
||||
# Keep only last 100 signals
|
||||
if len(self.signal_history) > 100:
|
||||
self.signal_history = self.signal_history[-100:]
|
||||
|
||||
logger.info(f"Generated signal: {signal.action} ({signal.confidence:.2f}) - {signal.reasoning}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {filename}: {e}")
|
||||
|
||||
def _parse_market_data(self, filepath: str) -> Optional[Dict[str, Any]]:
|
||||
"""Parse market data from text file"""
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
if len(lines) < 4:  # Need the 3 header lines plus at least one data line
|
||||
return None
|
||||
|
||||
# Find data line (skip headers)
|
||||
data_line = None
|
||||
for line in lines[3:]: # Skip the 3 header lines
|
||||
if line.strip() and not line.startswith('symbol'):
|
||||
data_line = line.strip()
|
||||
break
|
||||
|
||||
if not data_line:
|
||||
return None
|
||||
|
||||
# Parse tab-separated data
|
||||
parts = data_line.split('\t')
|
||||
if len(parts) < 25:  # Need timestamp + 4 ETH timeframes x 6 fields each
|
||||
return None
|
||||
|
||||
# Extract structured data
|
||||
parsed_data = {
|
||||
'timestamp': parts[0],
|
||||
'eth_1s': self._extract_ohlcv(parts[1:7]),
|
||||
'eth_1m': self._extract_ohlcv(parts[7:13]),
|
||||
'eth_1h': self._extract_ohlcv(parts[13:19]),
|
||||
'eth_1d': self._extract_ohlcv(parts[19:25]),
|
||||
'btc_1s': self._extract_ohlcv(parts[25:31]) if len(parts) > 25 else None,
|
||||
'spx_1s': self._extract_ohlcv(parts[31:37]) if len(parts) > 31 else None
|
||||
}
|
||||
|
||||
return parsed_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing market data: {e}")
|
||||
return None
|
||||
|
||||
def _extract_ohlcv(self, data_parts: List[str]) -> Dict[str, float]:
|
||||
"""Extract OHLCV data from parts"""
|
||||
try:
|
||||
return {
|
||||
'open': float(data_parts[0]) if data_parts[0] != '0' else 0.0,
|
||||
'high': float(data_parts[1]) if data_parts[1] != '0' else 0.0,
|
||||
'low': float(data_parts[2]) if data_parts[2] != '0' else 0.0,
|
||||
'close': float(data_parts[3]) if data_parts[3] != '0' else 0.0,
|
||||
'volume': float(data_parts[4]) if data_parts[4] != '0' else 0.0,
|
||||
'timestamp': data_parts[5]
|
||||
}
|
||||
except (ValueError, IndexError):
|
||||
return {'open': 0.0, 'high': 0.0, 'low': 0.0, 'close': 0.0, 'volume': 0.0, 'timestamp': ''}
|
||||
|
||||
def _create_trading_prompt(self, market_data: Dict[str, Any]) -> str:
|
||||
"""Create trading prompt from market data"""
|
||||
prompt = f"""Market Data Analysis for ETH/USDT:
|
||||
|
||||
Timestamp: {market_data['timestamp']}
|
||||
|
||||
ETH Multi-timeframe Data:
|
||||
1s: O:{market_data['eth_1s']['open']:.2f} H:{market_data['eth_1s']['high']:.2f} L:{market_data['eth_1s']['low']:.2f} C:{market_data['eth_1s']['close']:.2f} V:{market_data['eth_1s']['volume']:.1f}
|
||||
1m: O:{market_data['eth_1m']['open']:.2f} H:{market_data['eth_1m']['high']:.2f} L:{market_data['eth_1m']['low']:.2f} C:{market_data['eth_1m']['close']:.2f} V:{market_data['eth_1m']['volume']:.1f}
|
||||
1h: O:{market_data['eth_1h']['open']:.2f} H:{market_data['eth_1h']['high']:.2f} L:{market_data['eth_1h']['low']:.2f} C:{market_data['eth_1h']['close']:.2f} V:{market_data['eth_1h']['volume']:.1f}
|
||||
1d: O:{market_data['eth_1d']['open']:.2f} H:{market_data['eth_1d']['high']:.2f} L:{market_data['eth_1d']['low']:.2f} C:{market_data['eth_1d']['close']:.2f} V:{market_data['eth_1d']['volume']:.1f}
|
||||
"""
|
||||
|
||||
if market_data.get('btc_1s'):
|
||||
prompt += f"\nBTC Reference (1s): O:{market_data['btc_1s']['open']:.2f} H:{market_data['btc_1s']['high']:.2f} L:{market_data['btc_1s']['low']:.2f} C:{market_data['btc_1s']['close']:.2f} V:{market_data['btc_1s']['volume']:.1f}"
|
||||
|
||||
if market_data.get('spx_1s'):
|
||||
prompt += f"\nSPX Reference (1s): O:{market_data['spx_1s']['open']:.2f} H:{market_data['spx_1s']['high']:.2f} L:{market_data['spx_1s']['low']:.2f} C:{market_data['spx_1s']['close']:.2f}"
|
||||
|
||||
prompt += "\n\nProvide trading recommendation based on this multi-timeframe analysis."
|
||||
return prompt
|
||||
|
||||
def _query_llm(self, prompt: str) -> Optional[str]:
|
||||
"""Send query to LLM endpoint"""
|
||||
try:
|
||||
url = f"{self.config.base_url}/v1/chat/completions"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
if self.config.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.config.api_key}"
|
||||
|
||||
payload = {
|
||||
"model": self.config.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": self.system_prompt},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
"temperature": self.config.temperature,
|
||||
"max_tokens": self.config.max_tokens,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=headers,
|
||||
data=json.dumps(payload),
|
||||
timeout=self.config.timeout
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if 'choices' in result and len(result['choices']) > 0:
|
||||
return result['choices'][0]['message']['content']
|
||||
else:
|
||||
logger.error(f"LLM API error: {response.status_code} - {response.text}")
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying LLM: {e}")
|
||||
return None
|
||||
|
||||
def _parse_llm_response(self, response: str, market_data: Dict[str, Any]) -> Optional[LLMTradeSignal]:
|
||||
"""Parse LLM response into trade signal"""
|
||||
try:
|
||||
# Try to extract JSON from response
|
||||
response = response.strip()
|
||||
if response.startswith('```json'):
|
||||
response = response[7:]
|
||||
if response.endswith('```'):
|
||||
response = response[:-3]
|
||||
|
||||
# Parse JSON
|
||||
data = json.loads(response)
|
||||
|
||||
# Validate required fields
|
||||
if 'action' not in data or 'confidence' not in data:
|
||||
logger.warning("LLM response missing required fields")
|
||||
return None
|
||||
|
||||
# Create signal
|
||||
signal = LLMTradeSignal(
|
||||
symbol='ETH/USDT',
|
||||
action=data['action'].upper(),
|
||||
confidence=float(data['confidence']),
|
||||
reasoning=data.get('reasoning', ''),
|
||||
price_target=data.get('price_target'),
|
||||
stop_loss=data.get('stop_loss'),
|
||||
timestamp=datetime.now()
|
||||
)
|
||||
|
||||
# Validate action
|
||||
if signal.action not in ['BUY', 'SELL', 'HOLD']:
|
||||
logger.warning(f"Invalid action: {signal.action}")
|
||||
return None
|
||||
|
||||
# Validate confidence
|
||||
signal.confidence = max(0.0, min(1.0, signal.confidence))
|
||||
|
||||
return signal
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing LLM response: {e}")
|
||||
logger.debug(f"Response was: {response}")
|
||||
return None
|
||||
|
||||
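# Illustrative sketch (hypothetical response) of how _parse_llm_response above behaves:
# it strips optional ```json fences, loads the JSON, validates 'action', and clamps
# 'confidence' into [0, 1].
#   raw = '```json\n{"action": "buy", "confidence": 1.4, "reasoning": "breakout"}\n```'
#   signal = self._parse_llm_response(raw, market_data={})
#   # -> LLMTradeSignal(symbol='ETH/USDT', action='BUY', confidence=1.0, reasoning='breakout', ...)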
def get_latest_signal(self, symbol: str = 'ETH') -> Optional[LLMTradeSignal]:
|
||||
"""Get latest trading signal for symbol"""
|
||||
with self.lock:
|
||||
return self.latest_signals.get(symbol)
|
||||
|
||||
def get_signal_history(self, limit: int = 10) -> List[LLMTradeSignal]:
|
||||
"""Get recent signal history"""
|
||||
with self.lock:
|
||||
return self.signal_history[-limit:] if self.signal_history else []
|
||||
|
||||
def update_config(self, config: LLMConfig):
|
||||
"""Update LLM configuration"""
|
||||
self.config = config
|
||||
logger.info(f"LLM config updated - Model: {self.config.model}, Base URL: {self.config.base_url}")
|
||||
|
||||
def get_status(self) -> Dict[str, Any]:
|
||||
"""Get LLM proxy status"""
|
||||
with self.lock:
|
||||
return {
|
||||
'is_running': self.is_running,
|
||||
'config': {
|
||||
'base_url': self.config.base_url,
|
||||
'model': self.config.model,
|
||||
'temperature': self.config.temperature
|
||||
},
|
||||
'processed_files': len(self.processed_files),
|
||||
'total_signals': len(self.signal_history),
|
||||
'latest_signals': {k: {
|
||||
'action': v.action,
|
||||
'confidence': v.confidence,
|
||||
'timestamp': v.timestamp.isoformat() if v.timestamp else None
|
||||
} for k, v in self.latest_signals.items()}
|
||||
}
|
||||
|
||||
# Convenience functions
|
||||
def create_llm_proxy(config: Optional[LLMConfig] = None, **kwargs) -> LLMProxy:
|
||||
"""Create LLM proxy instance"""
|
||||
return LLMProxy(config=config, **kwargs)
|
||||
|
||||
def create_llm_config(base_url: str = "http://localhost:1234",
|
||||
model: str = "openai/gpt-oss-20b",
|
||||
**kwargs) -> LLMConfig:
|
||||
"""Create LLM configuration"""
|
||||
return LLMConfig(base_url=base_url, model=model, **kwargs)
|
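# Illustrative usage sketch (endpoint and directory are assumptions matching the defaults
# above; requires an OpenAI-compatible server running at base_url):
#   config = create_llm_config(base_url="http://localhost:1234", model="openai/gpt-oss-20b")
#   proxy = create_llm_proxy(config=config, data_dir="NN/training/samples/txt")
#   proxy.start()                          # begins polling for market_data_*.txt files
#   signal = proxy.get_latest_signal('ETH')
#   if signal:
#       print(signal.action, signal.confidence, signal.reasoning)
#   proxy.stop()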
@@ -28,6 +28,10 @@ import shutil
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
# Text export integration
|
||||
from .text_export_integration import TextExportManager
|
||||
from .llm_proxy import LLMProxy, LLMConfig
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
|
||||
@@ -568,6 +572,8 @@ class TradingOrchestrator:
|
||||
self._initialize_decision_fusion() # Initialize fusion system
|
||||
self._initialize_transformer_model() # Initialize transformer model
|
||||
self._initialize_enhanced_training_system() # Initialize real-time training
|
||||
self._initialize_text_export_manager() # Initialize text data export
|
||||
self._initialize_llm_proxy() # Initialize LLM proxy for trading signals
|
||||
|
||||
def _normalize_model_name(self, name: str) -> str:
|
||||
"""Map various registry/UI names to canonical toggle keys."""
|
||||
@@ -1518,9 +1524,89 @@ class TradingOrchestrator:
|
||||
with open(self.ui_state_file, "w") as f:
|
||||
json.dump(ui_state, f, indent=4)
|
||||
logger.debug(f"UI state saved to {self.ui_state_file}")
|
||||
# Also append a session snapshot for persistence across restarts
|
||||
self._append_session_snapshot()
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving UI state: {e}")
|
||||
|
||||
def _append_session_snapshot(self):
|
||||
"""Append current session metrics to persistent JSON until cleared manually."""
|
||||
try:
|
||||
session_file = os.path.join("data", "session_state.json")
|
||||
os.makedirs(os.path.dirname(session_file), exist_ok=True)
|
||||
|
||||
# Load existing
|
||||
existing = {}
|
||||
if os.path.exists(session_file):
|
||||
try:
|
||||
with open(session_file, "r", encoding="utf-8") as f:
|
||||
existing = json.load(f) or {}
|
||||
except Exception:
|
||||
existing = {}
|
||||
|
||||
# Collect metrics
|
||||
balance = 0.0
|
||||
pnl_total = 0.0
|
||||
closed_trades = []
|
||||
try:
|
||||
if hasattr(self, "trading_executor") and self.trading_executor:
|
||||
balance = float(getattr(self.trading_executor, "account_balance", 0.0) or 0.0)
|
||||
if hasattr(self.trading_executor, "trade_history"):
|
||||
for t in self.trading_executor.trade_history:
|
||||
try:
|
||||
closed_trades.append({
|
||||
"symbol": t.symbol,
|
||||
"side": t.side,
|
||||
"qty": t.quantity,
|
||||
"entry": t.entry_price,
|
||||
"exit": t.exit_price,
|
||||
"pnl": t.pnl,
|
||||
"timestamp": getattr(t, "timestamp", None)
|
||||
})
|
||||
pnl_total += float(t.pnl or 0.0)
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Models and performance (best-effort)
|
||||
models = {}
|
||||
try:
|
||||
models = {
|
||||
"dqn": {
|
||||
"available": bool(getattr(self, "rl_agent", None)),
|
||||
"last_losses": getattr(getattr(self, "rl_agent", None), "losses", [])[-10:] if getattr(getattr(self, "rl_agent", None), "losses", None) else []
|
||||
},
|
||||
"cnn": {
|
||||
"available": bool(getattr(self, "cnn_model", None))
|
||||
},
|
||||
"cob_rl": {
|
||||
"available": bool(getattr(self, "cob_rl_agent", None))
|
||||
},
|
||||
"decision_fusion": {
|
||||
"available": bool(getattr(self, "decision_model", None))
|
||||
}
|
||||
}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
snapshot = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"balance": balance,
|
||||
"session_pnl": pnl_total,
|
||||
"closed_trades": closed_trades,
|
||||
"models": models
|
||||
}
|
||||
|
||||
if "snapshots" not in existing:
|
||||
existing["snapshots"] = []
|
||||
existing["snapshots"].append(snapshot)
|
||||
|
||||
with open(session_file, "w", encoding="utf-8") as f:
|
||||
json.dump(existing, f, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"Error appending session snapshot: {e}")
|
||||
|
||||
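# Illustrative sketch of one appended snapshot in data/session_state.json (values are
# hypothetical; the structure mirrors _append_session_snapshot above):
#   {
#     "snapshots": [
#       {
#         "timestamp": "2025-01-15T14:32:05.123456",
#         "balance": 1000.0,
#         "session_pnl": 12.34,
#         "closed_trades": [
#           {"symbol": "ETH/USDT", "side": "LONG", "qty": 0.05,
#            "entry": 3000.0, "exit": 3010.0, "pnl": 0.5, "timestamp": null}
#         ],
#         "models": {"dqn": {"available": true, "last_losses": []},
#                    "cnn": {"available": true},
#                    "cob_rl": {"available": false},
#                    "decision_fusion": {"available": true}}
#       }
#     ]
#   }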
def get_model_toggle_state(self, model_name: str) -> Dict[str, bool]:
|
||||
"""Get toggle state for a model"""
|
||||
key = self._normalize_model_name(model_name)
|
||||
@@ -6894,11 +6980,29 @@ class TradingOrchestrator:
|
||||
try:
|
||||
if not self.training_enabled or not self.enhanced_training_system:
|
||||
logger.warning("Enhanced training system not available")
|
||||
# Even when the enhanced training system is unavailable, still start the enhanced reward system (and its timeframe coordinator)
|
||||
try:
|
||||
from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
|
||||
import asyncio as _asyncio
|
||||
_asyncio.create_task(start_enhanced_rewards_for_orchestrator(self, symbols=[self.symbol] + self.ref_symbols))
|
||||
logger.info("Enhanced reward system started (without enhanced training)")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting enhanced reward system: {e}")
|
||||
return False
|
||||
|
||||
if hasattr(self.enhanced_training_system, "start_training"):
|
||||
self.enhanced_training_system.start_training()
|
||||
logger.info("Enhanced real-time training started")
|
||||
|
||||
# Start Enhanced Reward System integration
|
||||
try:
|
||||
from core.enhanced_reward_system_integration import start_enhanced_rewards_for_orchestrator
|
||||
# Fire and forget task to start integration
|
||||
import asyncio as _asyncio
|
||||
_asyncio.create_task(start_enhanced_rewards_for_orchestrator(self, symbols=[self.symbol] + self.ref_symbols))
|
||||
logger.info("Enhanced reward system started")
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting enhanced reward system: {e}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -6925,6 +7029,180 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error stopping enhanced training: {e}")
|
||||
return False
|
||||
|
||||
def _initialize_text_export_manager(self):
|
||||
"""Initialize the text data export manager"""
|
||||
try:
|
||||
self.text_export_manager = TextExportManager(
|
||||
data_provider=self.data_provider,
|
||||
orchestrator=self
|
||||
)
|
||||
|
||||
# Configure with current symbols
|
||||
export_config = {
|
||||
'main_symbol': self.symbol,
|
||||
'ref1_symbol': self.ref_symbols[0] if self.ref_symbols else 'BTC/USDT',
|
||||
'ref2_symbol': 'SPX', # Default to SPX for now
|
||||
'export_dir': 'NN/training/samples/txt'
|
||||
}
|
||||
|
||||
self.text_export_manager.export_config.update(export_config)
|
||||
logger.info("Text export manager initialized")
|
||||
logger.info(f" - Main symbol: {export_config['main_symbol']}")
|
||||
logger.info(f" - Reference symbols: {export_config['ref1_symbol']}, {export_config['ref2_symbol']}")
|
||||
logger.info(f" - Export directory: {export_config['export_dir']}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing text export manager: {e}")
|
||||
self.text_export_manager = None
|
||||
|
||||
def _initialize_llm_proxy(self):
|
||||
"""Initialize LLM proxy for trading signals"""
|
||||
try:
|
||||
# Get LLM configuration from config file or use defaults
|
||||
llm_config = self.config.get('llm_proxy', {})
|
||||
|
||||
llm_proxy_config = LLMConfig(
|
||||
base_url=llm_config.get('base_url', 'http://localhost:1234'),
|
||||
model=llm_config.get('model', 'openai/gpt-oss-20b'),
|
||||
temperature=llm_config.get('temperature', 0.7),
|
||||
max_tokens=llm_config.get('max_tokens', -1),
|
||||
timeout=llm_config.get('timeout', 30),
|
||||
api_key=llm_config.get('api_key')
|
||||
)
|
||||
|
||||
self.llm_proxy = LLMProxy(
|
||||
config=llm_proxy_config,
|
||||
data_dir='NN/training/samples/txt'
|
||||
)
|
||||
|
||||
logger.info("LLM proxy initialized")
|
||||
logger.info(f" - Model: {llm_proxy_config.model}")
|
||||
logger.info(f" - Base URL: {llm_proxy_config.base_url}")
|
||||
logger.info(f" - Temperature: {llm_proxy_config.temperature}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing LLM proxy: {e}")
|
||||
self.llm_proxy = None
|
||||
|
||||
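# Illustrative sketch of the optional 'llm_proxy' section read from self.config above
# (keys mirror the defaults; the surrounding config file format is an assumption):
#   llm_proxy:
#     base_url: "http://localhost:1234"
#     model: "openai/gpt-oss-20b"
#     temperature: 0.7
#     max_tokens: -1
#     timeout: 30
#     api_key: null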
def start_text_export(self) -> bool:
|
||||
"""Start text data export"""
|
||||
try:
|
||||
if not hasattr(self, 'text_export_manager') or not self.text_export_manager:
|
||||
logger.warning("Text export manager not initialized")
|
||||
return False
|
||||
|
||||
return self.text_export_manager.start_export()
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting text export: {e}")
|
||||
return False
|
||||
|
||||
def stop_text_export(self) -> bool:
|
||||
"""Stop text data export"""
|
||||
try:
|
||||
if not hasattr(self, 'text_export_manager') or not self.text_export_manager:
|
||||
return True
|
||||
|
||||
return self.text_export_manager.stop_export()
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping text export: {e}")
|
||||
return False
|
||||
|
||||
def get_text_export_status(self) -> Dict[str, Any]:
|
||||
"""Get text export status"""
|
||||
try:
|
||||
if not hasattr(self, 'text_export_manager') or not self.text_export_manager:
|
||||
return {'enabled': False, 'initialized': False, 'error': 'Not initialized'}
|
||||
|
||||
return self.text_export_manager.get_export_status()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting text export status: {e}")
|
||||
return {'enabled': False, 'initialized': False, 'error': str(e)}
|
||||
|
||||
def start_llm_proxy(self) -> bool:
|
||||
"""Start LLM proxy for trading signals"""
|
||||
try:
|
||||
if not hasattr(self, 'llm_proxy') or not self.llm_proxy:
|
||||
logger.warning("LLM proxy not initialized")
|
||||
return False
|
||||
|
||||
self.llm_proxy.start()
|
||||
logger.info("LLM proxy started")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting LLM proxy: {e}")
|
||||
return False
|
||||
|
||||
def stop_llm_proxy(self) -> bool:
|
||||
"""Stop LLM proxy"""
|
||||
try:
|
||||
if not hasattr(self, 'llm_proxy') or not self.llm_proxy:
|
||||
return True
|
||||
|
||||
self.llm_proxy.stop()
|
||||
logger.info("LLM proxy stopped")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping LLM proxy: {e}")
|
||||
return False
|
||||
|
||||
def get_llm_proxy_status(self) -> Dict[str, Any]:
|
||||
"""Get LLM proxy status"""
|
||||
try:
|
||||
if not hasattr(self, 'llm_proxy') or not self.llm_proxy:
|
||||
return {'enabled': False, 'initialized': False, 'error': 'Not initialized'}
|
||||
|
||||
return self.llm_proxy.get_status()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LLM proxy status: {e}")
|
||||
return {'enabled': False, 'initialized': False, 'error': str(e)}
|
||||
|
||||
def get_latest_llm_signal(self, symbol: str = 'ETH'):
|
||||
"""Get latest LLM trading signal"""
|
||||
try:
|
||||
if not hasattr(self, 'llm_proxy') or not self.llm_proxy:
|
||||
return None
|
||||
|
||||
return self.llm_proxy.get_latest_signal(symbol)
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting LLM signal: {e}")
|
||||
return None
|
||||
|
||||
def update_llm_config(self, new_config: Dict[str, Any]) -> bool:
|
||||
"""Update LLM proxy configuration"""
|
||||
try:
|
||||
if not hasattr(self, 'llm_proxy') or not self.llm_proxy:
|
||||
logger.warning("LLM proxy not initialized")
|
||||
return False
|
||||
|
||||
# Create new config
|
||||
llm_proxy_config = LLMConfig(
|
||||
base_url=new_config.get('base_url', 'http://localhost:1234'),
|
||||
model=new_config.get('model', 'openai/gpt-oss-20b'),
|
||||
temperature=new_config.get('temperature', 0.7),
|
||||
max_tokens=new_config.get('max_tokens', -1),
|
||||
timeout=new_config.get('timeout', 30),
|
||||
api_key=new_config.get('api_key')
|
||||
)
|
||||
|
||||
# Stop current proxy
|
||||
was_running = self.llm_proxy.is_running
|
||||
if was_running:
|
||||
self.llm_proxy.stop()
|
||||
|
||||
# Update config
|
||||
self.llm_proxy.update_config(llm_proxy_config)
|
||||
|
||||
# Restart if it was running
|
||||
if was_running:
|
||||
self.llm_proxy.start()
|
||||
|
||||
logger.info("LLM proxy configuration updated")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating LLM config: {e}")
|
||||
return False
|
||||
|
||||
def get_enhanced_training_stats(self) -> Dict[str, Any]:
|
||||
"""Get enhanced training system statistics with orchestrator integration"""
|
||||
try:
|
||||
|
364
core/text_data_exporter.py
Normal file
@@ -0,0 +1,364 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Text Data Exporter - CSV Interface for External Systems
|
||||
Exports market data in CSV format for integration with text-based systems
|
||||
"""
|
||||
|
||||
import os
|
||||
import csv
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class MarketDataPoint:
|
||||
"""Single market data point"""
|
||||
symbol: str
|
||||
timeframe: str
|
||||
open: float
|
||||
high: float
|
||||
low: float
|
||||
close: float
|
||||
volume: float
|
||||
timestamp: datetime
|
||||
|
||||
class TextDataExporter:
|
||||
"""
|
||||
Exports market data to CSV files for external text-based systems
|
||||
|
||||
Features:
|
||||
- Multi-symbol support (MAIN + REF1 + REF2)
|
||||
- Multi-timeframe (1s, 1m, 1h, 1d)
|
||||
- File rotation every minute
|
||||
- Overwrites within the same minute
|
||||
- Thread-safe operations
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
data_provider=None,
|
||||
export_dir: str = "NN/training/samples/txt",
|
||||
main_symbol: str = "ETH/USDT",
|
||||
ref1_symbol: str = "BTC/USDT",
|
||||
ref2_symbol: str = "SPX"):
|
||||
"""
|
||||
Initialize text data exporter
|
||||
|
||||
Args:
|
||||
data_provider: Data provider instance
|
||||
export_dir: Directory for CSV exports
|
||||
main_symbol: Main trading symbol (ETH)
|
||||
ref1_symbol: Reference symbol 1 (BTC)
|
||||
ref2_symbol: Reference symbol 2 (SPX)
|
||||
"""
|
||||
self.data_provider = data_provider
|
||||
self.export_dir = export_dir
|
||||
self.main_symbol = main_symbol
|
||||
self.ref1_symbol = ref1_symbol
|
||||
self.ref2_symbol = ref2_symbol
|
||||
|
||||
# Timeframes to export
|
||||
self.timeframes = ['1s', '1m', '1h', '1d']
|
||||
|
||||
# File management
|
||||
self.current_minute = None
|
||||
self.current_filename = None
|
||||
self.export_lock = threading.Lock()
|
||||
|
||||
# Running state
|
||||
self.is_running = False
|
||||
self.export_thread = None
|
||||
|
||||
# Create export directory
|
||||
os.makedirs(self.export_dir, exist_ok=True)
|
||||
|
||||
logger.info(f"Text Data Exporter initialized - Export dir: {self.export_dir}")
|
||||
logger.info(f"Symbols: MAIN={main_symbol}, REF1={ref1_symbol}, REF2={ref2_symbol}")
|
||||
|
||||
def start(self):
|
||||
"""Start the data export process"""
|
||||
if self.is_running:
|
||||
logger.warning("Text data exporter already running")
|
||||
return
|
||||
|
||||
self.is_running = True
|
||||
self.export_thread = threading.Thread(target=self._export_loop, daemon=True)
|
||||
self.export_thread.start()
|
||||
logger.info("Text data exporter started")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the data export process"""
|
||||
self.is_running = False
|
||||
if self.export_thread:
|
||||
self.export_thread.join(timeout=5)
|
||||
logger.info("Text data exporter stopped")
|
||||
|
||||
def _export_loop(self):
|
||||
"""Main export loop - runs every second"""
|
||||
while self.is_running:
|
||||
try:
|
||||
self._export_current_data()
|
||||
time.sleep(1) # Export every second
|
||||
except Exception as e:
|
||||
logger.error(f"Error in export loop: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def _export_current_data(self):
|
||||
"""Export current market data to CSV"""
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
current_minute_key = current_time.strftime("%Y%m%d_%H%M")
|
||||
|
||||
# Check if we need a new file (new minute)
|
||||
if self.current_minute != current_minute_key:
|
||||
self.current_minute = current_minute_key
|
||||
self.current_filename = f"market_data_{current_minute_key}.txt"
|
||||
logger.info(f"Starting new export file: {self.current_filename}")
|
||||
|
||||
# Gather data for all symbols and timeframes
|
||||
export_data = self._gather_export_data()
|
||||
|
||||
if export_data:
|
||||
self._write_csv_file(export_data)
|
||||
else:
|
||||
logger.debug("No data available for export")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error exporting data: {e}")
|
||||
|
||||
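# Illustrative sketch (hypothetical time) of the per-minute file rotation used above:
#   datetime(2025, 1, 15, 14, 32, 7).strftime("%Y%m%d_%H%M")  ->  "20250115_1432"
#   filename = "market_data_20250115_1432.txt"
# Exports within the same minute rewrite this file; a new minute starts a new file.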
def _gather_export_data(self) -> List[Dict[str, Any]]:
|
||||
"""Gather market data for all symbols and timeframes"""
|
||||
export_rows = []
|
||||
|
||||
if not self.data_provider:
|
||||
return export_rows
|
||||
|
||||
symbols = [
|
||||
("MAIN", self.main_symbol),
|
||||
("REF1", self.ref1_symbol),
|
||||
("REF2", self.ref2_symbol)
|
||||
]
|
||||
|
||||
for symbol_type, symbol in symbols:
|
||||
for timeframe in self.timeframes:
|
||||
try:
|
||||
# Get latest data for this symbol/timeframe
|
||||
data_point = self._get_latest_data(symbol, timeframe)
|
||||
if data_point:
|
||||
export_rows.append({
|
||||
'symbol_type': symbol_type,
|
||||
'symbol': symbol,
|
||||
'timeframe': timeframe,
|
||||
'open': data_point.open,
|
||||
'high': data_point.high,
|
||||
'low': data_point.low,
|
||||
'close': data_point.close,
|
||||
'volume': data_point.volume,
|
||||
'timestamp': data_point.timestamp
|
||||
})
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting data for {symbol} {timeframe}: {e}")
|
||||
|
||||
return export_rows
|
||||
|
||||
def _get_latest_data(self, symbol: str, timeframe: str) -> Optional[MarketDataPoint]:
|
||||
"""Get latest market data for symbol/timeframe"""
|
||||
try:
|
||||
if not hasattr(self.data_provider, 'get_latest_candle'):
|
||||
return None
|
||||
|
||||
# Try to get latest candle data
|
||||
candle = self.data_provider.get_latest_candle(symbol, timeframe)
|
||||
if not candle:
|
||||
return None
|
||||
|
||||
# Convert to MarketDataPoint
|
||||
return MarketDataPoint(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
open=float(candle.get('open', 0)),
|
||||
high=float(candle.get('high', 0)),
|
||||
low=float(candle.get('low', 0)),
|
||||
close=float(candle.get('close', 0)),
|
||||
volume=float(candle.get('volume', 0)),
|
||||
timestamp=candle.get('timestamp', datetime.now())
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting latest data for {symbol} {timeframe}: {e}")
|
||||
return None
|
||||
|
||||
def _write_csv_file(self, export_data: List[Dict[str, Any]]):
|
||||
"""Write data to TXT file in tab-separated format"""
|
||||
if not export_data:
|
||||
return
|
||||
|
||||
filepath = os.path.join(self.export_dir, self.current_filename)
|
||||
|
||||
with self.export_lock:
|
||||
try:
|
||||
# Group data by symbol type for organized output
|
||||
grouped_data = self._group_data_by_symbol(export_data)
|
||||
|
||||
with open(filepath, 'w', encoding='utf-8') as txtfile:
|
||||
# Write in the format specified in readme.md sample
|
||||
self._write_tab_format(txtfile, grouped_data)
|
||||
|
||||
logger.debug(f"Exported {len(export_data)} data points to {filepath}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error writing TXT file {filepath}: {e}")
|
||||
|
||||
def _create_csv_header(self) -> List[str]:
|
||||
"""Create CSV header based on specification"""
|
||||
header = ['symbol']
|
||||
|
||||
# Add columns for each symbol type and timeframe
|
||||
for symbol_type in ['MAIN', 'REF1', 'REF2']:
|
||||
for timeframe in self.timeframes:
|
||||
prefix = f"{symbol_type}_{timeframe}"
|
||||
header.extend([
|
||||
f"{prefix}_O", # Open
|
||||
f"{prefix}_H", # High
|
||||
f"{prefix}_L", # Low
|
||||
f"{prefix}_C", # Close
|
||||
f"{prefix}_V", # Volume
|
||||
f"{prefix}_T" # Timestamp
|
||||
])
|
||||
|
||||
return header
|
||||
|
||||
def _group_data_by_symbol(self, export_data: List[Dict[str, Any]]) -> Dict[str, Dict[str, Dict[str, Any]]]:
|
||||
"""Group data by symbol type and timeframe"""
|
||||
grouped = {}
|
||||
|
||||
for data_point in export_data:
|
||||
symbol_type = data_point['symbol_type']
|
||||
timeframe = data_point['timeframe']
|
||||
|
||||
if symbol_type not in grouped:
|
||||
grouped[symbol_type] = {}
|
||||
|
||||
grouped[symbol_type][timeframe] = data_point
|
||||
|
||||
return grouped
|
||||
|
||||
def _format_csv_rows(self, grouped_data: Dict[str, Dict[str, Dict[str, Any]]]) -> List[Dict[str, Any]]:
|
||||
"""Format data into CSV rows"""
|
||||
rows = []
|
||||
|
||||
# Create a single row with all data
|
||||
row = {'symbol': f"{self.main_symbol.split('/')[0]} ({self.ref1_symbol.split('/')[0]}, {self.ref2_symbol})"}
|
||||
|
||||
for symbol_type in ['MAIN', 'REF1', 'REF2']:
|
||||
symbol_data = grouped_data.get(symbol_type, {})
|
||||
|
||||
for timeframe in self.timeframes:
|
||||
prefix = f"{symbol_type}_{timeframe}"
|
||||
data_point = symbol_data.get(timeframe)
|
||||
|
||||
if data_point:
|
||||
row[f"{prefix}_O"] = f"{data_point['open']:.6f}"
|
||||
row[f"{prefix}_H"] = f"{data_point['high']:.6f}"
|
||||
row[f"{prefix}_L"] = f"{data_point['low']:.6f}"
|
||||
row[f"{prefix}_C"] = f"{data_point['close']:.6f}"
|
||||
row[f"{prefix}_V"] = f"{data_point['volume']:.2f}"
|
||||
row[f"{prefix}_T"] = data_point['timestamp'].strftime("%Y-%m-%d %H:%M:%S")
|
||||
else:
|
||||
# Empty values if no data
|
||||
row[f"{prefix}_O"] = ""
|
||||
row[f"{prefix}_H"] = ""
|
||||
row[f"{prefix}_L"] = ""
|
||||
row[f"{prefix}_C"] = ""
|
||||
row[f"{prefix}_V"] = ""
|
||||
row[f"{prefix}_T"] = ""
|
||||
|
||||
rows.append(row)
|
||||
return rows
|
||||
|
||||
def _write_tab_format(self, txtfile, grouped_data: Dict[str, Dict[str, Dict[str, Any]]]):
|
||||
"""Write data in tab-separated format like readme.md sample"""
|
||||
# Write header structure
|
||||
txtfile.write("symbol\tMAIN SYMBOL (ETH)\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\t\tREF1 (BTC)\t\t\t\t\t\tREF2 (SPX)\t\t\t\t\t\tREF3 (SOL)\n")
|
||||
txtfile.write("timeframe\t1s\t\t\t\t\t\t1m\t\t\t\t\t\t1h\t\t\t\t\t\t1d\t\t\t\t\t\t1s\t\t\t\t\t\t1s\t\t\t\t\t\t1s\n")
|
||||
txtfile.write("datapoint\tO\tH\tL\tC\tV\tTimestamp\tO\tH\tL\tC\tV\tTimestamp\tO\tH\tL\tC\tV\tTimestamp\tO\tH\tL\tC\tV\tTimestamp\tO\tH\tL\tC\tV\tTimestamp\tO\tH\tL\tC\tV\tTimestamp\tO\tH\tL\tC\tV\tTimestamp\n")
|
||||
|
||||
# Write data row
|
||||
row_parts = []
|
||||
current_time = datetime.now()
|
||||
|
||||
# Timestamp first
|
||||
row_parts.append(current_time.strftime("%Y-%m-%dT%H:%M:%SZ"))
|
||||
|
||||
# ETH data for all timeframes (1s, 1m, 1h, 1d)
|
||||
main_data = grouped_data.get('MAIN', {})
|
||||
for timeframe in ['1s', '1m', '1h', '1d']:
|
||||
data_point = main_data.get(timeframe)
|
||||
if data_point:
|
||||
row_parts.extend([
|
||||
f"{data_point['open']:.2f}",
|
||||
f"{data_point['high']:.2f}",
|
||||
f"{data_point['low']:.2f}",
|
||||
f"{data_point['close']:.2f}",
|
||||
f"{data_point['volume']:.1f}",
|
||||
data_point['timestamp'].strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
])
|
||||
else:
|
||||
row_parts.extend(["0", "0", "0", "0", "0", current_time.strftime("%Y-%m-%dT%H:%M:%SZ")])
|
||||
|
||||
# REF1 (BTC), REF2 (SPX), REF3 (SOL) - 1s timeframe only
|
||||
for ref_type in ['REF1', 'REF2']: # REF3 will be added by LLM proxy
|
||||
ref_data = grouped_data.get(ref_type, {})
|
||||
data_point = ref_data.get('1s')
|
||||
if data_point:
|
||||
row_parts.extend([
|
||||
f"{data_point['open']:.2f}",
|
||||
f"{data_point['high']:.2f}",
|
||||
f"{data_point['low']:.2f}",
|
||||
f"{data_point['close']:.2f}",
|
||||
f"{data_point['volume']:.1f}",
|
||||
data_point['timestamp'].strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
])
|
||||
else:
|
||||
row_parts.extend(["0", "0", "0", "0", "0", current_time.strftime("%Y-%m-%dT%H:%M:%SZ")])
|
||||
|
||||
# Add placeholder for REF3 (SOL) - will be filled by LLM proxy
|
||||
row_parts.extend(["0", "0", "0", "0", "0", current_time.strftime("%Y-%m-%dT%H:%M:%SZ")])
|
||||
|
||||
txtfile.write("\t".join(row_parts) + "\n")
|
||||
|
||||
def get_current_filename(self) -> Optional[str]:
|
||||
"""Get current export filename"""
|
||||
return self.current_filename
|
||||
|
||||
def get_export_stats(self) -> Dict[str, Any]:
|
||||
"""Get export statistics"""
|
||||
stats = {
|
||||
'is_running': self.is_running,
|
||||
'export_dir': self.export_dir,
|
||||
'current_filename': self.current_filename,
|
||||
'symbols': {
|
||||
'main': self.main_symbol,
|
||||
'ref1': self.ref1_symbol,
|
||||
'ref2': self.ref2_symbol
|
||||
},
|
||||
'timeframes': self.timeframes
|
||||
}
|
||||
|
||||
# Add file count
|
||||
try:
|
||||
files = [f for f in os.listdir(self.export_dir) if f.endswith('.txt')]
|
||||
stats['total_files'] = len(files)
|
||||
except Exception:
|
||||
stats['total_files'] = 0
|
||||
|
||||
return stats
|
||||
|
||||
# Convenience function for integration
|
||||
def create_text_exporter(data_provider=None, **kwargs) -> TextDataExporter:
|
||||
"""Create and return a TextDataExporter instance"""
|
||||
return TextDataExporter(data_provider=data_provider, **kwargs)
|
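# Illustrative usage sketch (my_provider is an assumed object exposing
# get_latest_candle(symbol, timeframe); symbols mirror the defaults above):
#   exporter = create_text_exporter(data_provider=my_provider,
#                                   export_dir="NN/training/samples/txt",
#                                   main_symbol="ETH/USDT",
#                                   ref1_symbol="BTC/USDT",
#                                   ref2_symbol="SPX")
#   exporter.start()              # writes one tab-separated snapshot per second
#   print(exporter.get_export_stats())
#   exporter.stop()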
233
core/text_export_integration.py
Normal file
@@ -0,0 +1,233 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Text Export Integration - Connects TextDataExporter with existing data systems
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from .text_data_exporter import TextDataExporter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TextExportManager:
|
||||
"""
|
||||
Manages text data export integration with the trading system
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider=None, orchestrator=None):
|
||||
"""
|
||||
Initialize text export manager
|
||||
|
||||
Args:
|
||||
data_provider: Main data provider instance
|
||||
orchestrator: Trading orchestrator instance
|
||||
"""
|
||||
self.data_provider = data_provider
|
||||
self.orchestrator = orchestrator
|
||||
self.text_exporter: Optional[TextDataExporter] = None
|
||||
|
||||
# Configuration
|
||||
self.export_enabled = False
|
||||
self.export_config = {
|
||||
'main_symbol': 'ETH/USDT',
|
||||
'ref1_symbol': 'BTC/USDT',
|
||||
'ref2_symbol': 'SPX', # Will need to be mapped to available data
|
||||
'export_dir': 'NN/training/samples/txt'
|
||||
}
|
||||
|
||||
def initialize_exporter(self, config: Optional[Dict[str, Any]] = None):
|
||||
"""Initialize the text data exporter"""
|
||||
try:
|
||||
if config:
|
||||
self.export_config.update(config)
|
||||
|
||||
# Create enhanced data provider wrapper
|
||||
enhanced_provider = EnhancedDataProviderWrapper(
|
||||
self.data_provider,
|
||||
self.orchestrator
|
||||
)
|
||||
|
||||
# Create text exporter
|
||||
self.text_exporter = TextDataExporter(
|
||||
data_provider=enhanced_provider,
|
||||
export_dir=self.export_config['export_dir'],
|
||||
main_symbol=self.export_config['main_symbol'],
|
||||
ref1_symbol=self.export_config['ref1_symbol'],
|
||||
ref2_symbol=self.export_config['ref2_symbol']
|
||||
)
|
||||
|
||||
logger.info("Text data exporter initialized successfully")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing text exporter: {e}")
|
||||
return False
|
||||
|
||||
def start_export(self):
|
||||
"""Start text data export"""
|
||||
if not self.text_exporter:
|
||||
if not self.initialize_exporter():
|
||||
logger.error("Cannot start export - initialization failed")
|
||||
return False
|
||||
|
||||
try:
|
||||
self.text_exporter.start()
|
||||
self.export_enabled = True
|
||||
logger.info("Text data export started")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error starting text export: {e}")
|
||||
return False
|
||||
|
||||
def stop_export(self):
|
||||
"""Stop text data export"""
|
||||
if self.text_exporter:
|
||||
try:
|
||||
self.text_exporter.stop()
|
||||
self.export_enabled = False
|
||||
logger.info("Text data export stopped")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping text export: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_export_status(self) -> Dict[str, Any]:
|
||||
"""Get current export status"""
|
||||
status = {
|
||||
'enabled': self.export_enabled,
|
||||
'initialized': self.text_exporter is not None,
|
||||
'config': self.export_config.copy()
|
||||
}
|
||||
|
||||
if self.text_exporter:
|
||||
status.update(self.text_exporter.get_export_stats())
|
||||
|
||||
return status
|
||||
|
||||
def update_config(self, new_config: Dict[str, Any]):
|
||||
"""Update export configuration"""
|
||||
old_enabled = self.export_enabled
|
||||
|
||||
# Stop if running
|
||||
if old_enabled:
|
||||
self.stop_export()
|
||||
|
||||
# Update config
|
||||
self.export_config.update(new_config)
|
||||
|
||||
# Reinitialize
|
||||
self.text_exporter = None
|
||||
|
||||
# Restart if was enabled
|
||||
if old_enabled:
|
||||
self.start_export()
|
||||
|
||||
logger.info(f"Text export config updated: {new_config}")
|
||||
|
||||
class EnhancedDataProviderWrapper:
|
||||
"""
|
||||
Wrapper around the existing data provider to provide the interface
|
||||
expected by TextDataExporter
|
||||
"""
|
||||
|
||||
def __init__(self, data_provider, orchestrator=None):
|
||||
self.data_provider = data_provider
|
||||
self.orchestrator = orchestrator
|
||||
|
||||
# Timeframe mapping
|
||||
self.timeframe_map = {
|
||||
'1s': '1s',
|
||||
'1m': '1m',
|
||||
'1h': '1h',
|
||||
'1d': '1d'
|
||||
}
|
||||
|
||||
def get_latest_candle(self, symbol: str, timeframe: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get latest candle data for symbol/timeframe"""
|
||||
try:
|
||||
# Handle special symbols
|
||||
if symbol == 'SPX':
|
||||
return self._get_spx_data()
|
||||
|
||||
# Map timeframe
|
||||
mapped_timeframe = self.timeframe_map.get(timeframe, timeframe)
|
||||
|
||||
# Try different methods to get data
|
||||
candle_data = None
|
||||
|
||||
# Method 1: Direct candle data
|
||||
if hasattr(self.data_provider, 'get_latest_candle'):
|
||||
candle_data = self.data_provider.get_latest_candle(symbol, mapped_timeframe)
|
||||
|
||||
# Method 2: From candle buffer
|
||||
elif hasattr(self.data_provider, 'candle_buffer'):
|
||||
buffer_key = f"{symbol}_{mapped_timeframe}"
|
||||
if buffer_key in self.data_provider.candle_buffer:
|
||||
candles = self.data_provider.candle_buffer[buffer_key]
|
||||
if candles:
|
||||
latest = candles[-1]
|
||||
candle_data = {
|
||||
'open': latest.get('open', 0),
|
||||
'high': latest.get('high', 0),
|
||||
'low': latest.get('low', 0),
|
||||
'close': latest.get('close', 0),
|
||||
'volume': latest.get('volume', 0),
|
||||
'timestamp': latest.get('timestamp', datetime.now())
|
||||
}
|
||||
|
||||
# Method 3: From tick data (for 1s timeframe)
|
||||
elif mapped_timeframe == '1s' and hasattr(self.data_provider, 'latest_prices'):
|
||||
if symbol in self.data_provider.latest_prices:
|
||||
price = self.data_provider.latest_prices[symbol]
|
||||
candle_data = {
|
||||
'open': price,
|
||||
'high': price,
|
||||
'low': price,
|
||||
'close': price,
|
||||
'volume': 0,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
return candle_data
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting candle data for {symbol} {timeframe}: {e}")
|
||||
return None
|
||||
|
||||
def _get_spx_data(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get SPX data - placeholder for now"""
|
||||
# For now, return mock SPX data
|
||||
# In production, this would connect to a stock data provider
|
||||
return {
|
||||
'open': 5500.0,
|
||||
'high': 5520.0,
|
||||
'low': 5495.0,
|
||||
'close': 5510.0,
|
||||
'volume': 1000000,
|
||||
'timestamp': datetime.now()
|
||||
}
|
||||
|
||||
# Integration helper functions
|
||||
def setup_text_export(data_provider=None, orchestrator=None, config: Optional[Dict[str, Any]] = None) -> TextExportManager:
|
||||
"""Setup text export with default configuration"""
|
||||
manager = TextExportManager(data_provider, orchestrator)
|
||||
|
||||
if config:
|
||||
manager.export_config.update(config)
|
||||
|
||||
return manager
|
||||
|
||||
def start_text_export_service(data_provider=None, orchestrator=None, auto_start: bool = True) -> TextExportManager:
|
||||
"""Start text export service with auto-initialization"""
|
||||
manager = setup_text_export(data_provider, orchestrator)
|
||||
|
||||
if auto_start:
|
||||
if manager.initialize_exporter():
|
||||
manager.start_export()
|
||||
logger.info("Text export service started successfully")
|
||||
else:
|
||||
logger.error("Failed to start text export service")
|
||||
|
||||
return manager
|
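For reference, a minimal usage sketch of the helper functions above. It assumes the project's `DataProvider` (imported as in other scripts in this repo) and that the helpers are importable from the module above; the `export_config` key shown is hypothetical.

```python
# Minimal sketch: wire the text export service to an existing DataProvider.
from core.data_provider import DataProvider

provider = DataProvider()

# One-call setup: initializes the exporter and starts exporting immediately
manager = start_text_export_service(data_provider=provider, auto_start=True)

# Or configure first, then start explicitly
manager = setup_text_export(provider, config={'export_interval_seconds': 10})  # hypothetical config key
if manager.initialize_exporter():
    manager.start_export()
```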
496
core/timeframe_inference_coordinator.py
Normal file
@@ -0,0 +1,496 @@
|
||||
"""
|
||||
Timeframe-Aware Inference Coordinator
|
||||
|
||||
This module coordinates model inference across multiple timeframes with proper scheduling.
|
||||
It ensures that models know which timeframe they are predicting on and handles the
|
||||
complex scheduling requirements for multi-timeframe predictions.
|
||||
|
||||
Key Features:
|
||||
- Timeframe-aware model inference
|
||||
- Hourly multi-timeframe inference (4 predictions per hour)
|
||||
- Frequent inference at 1-5 second intervals
|
||||
- Prediction context management
|
||||
- Integration with enhanced reward calculator
|
||||
"""
|
||||
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import threading
|
||||
from enum import Enum
|
||||
|
||||
from core.enhanced_reward_calculator import EnhancedRewardCalculator, TimeFrame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceContext:
|
||||
"""Context information for a model inference"""
|
||||
symbol: str
|
||||
timeframe: TimeFrame
|
||||
timestamp: datetime
|
||||
target_timeframe: TimeFrame # Which timeframe we're predicting for
|
||||
is_hourly_inference: bool = False
|
||||
inference_type: str = "regular" # "regular", "hourly", "continuous"
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceSchedule:
|
||||
"""Schedule configuration for different inference types"""
|
||||
continuous_interval_seconds: float = 5.0 # Continuous inference every 5 seconds
|
||||
hourly_timeframes: List[TimeFrame] = None # Timeframes for hourly inference
|
||||
|
||||
def __post_init__(self):
|
||||
if self.hourly_timeframes is None:
|
||||
self.hourly_timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
|
||||
TimeFrame.HOURS_1, TimeFrame.DAYS_1]
|
||||
|
||||
|
||||
class TimeframeInferenceCoordinator:
|
||||
"""
|
||||
Coordinates timeframe-aware model inference with proper scheduling
|
||||
|
||||
This coordinator:
|
||||
1. Manages continuous inference every 1-5 seconds on main timeframe
|
||||
2. Schedules hourly multi-timeframe inference (4 predictions per hour)
|
||||
3. Ensures models know which timeframe they're predicting on
|
||||
4. Integrates with enhanced reward calculator for training
|
||||
5. Handles prediction context and metadata
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reward_calculator: EnhancedRewardCalculator,
|
||||
data_provider: Any = None,
|
||||
symbols: List[str] = None):
|
||||
"""
|
||||
Initialize the timeframe inference coordinator
|
||||
|
||||
Args:
|
||||
reward_calculator: Enhanced reward calculator instance
|
||||
data_provider: Data provider for market data
|
||||
symbols: List of symbols to coordinate inference for
|
||||
"""
|
||||
self.reward_calculator = reward_calculator
|
||||
self.data_provider = data_provider
|
||||
self.symbols = symbols or ['ETH/USDT', 'BTC/USDT']
|
||||
|
||||
# Inference schedule configuration
|
||||
self.schedule = InferenceSchedule()
|
||||
|
||||
# Model registry - stores inference functions for different models
|
||||
self.model_inference_functions: Dict[str, Callable] = {}
|
||||
|
||||
# Tracking inference state
|
||||
self.last_continuous_inference: Dict[str, datetime] = {}
|
||||
self.last_hourly_inference: Dict[str, datetime] = {}
|
||||
self.next_hourly_inference: Dict[str, datetime] = {}
|
||||
|
||||
# Active inference tasks
|
||||
self.inference_tasks: List[asyncio.Task] = []
|
||||
self.running = False
|
||||
|
||||
# Thread safety
|
||||
self.lock = threading.RLock()
|
||||
|
||||
# Performance metrics
|
||||
self.inference_stats = {
|
||||
'continuous_inferences': 0,
|
||||
'hourly_inferences': 0,
|
||||
'failed_inferences': 0,
|
||||
'average_inference_time_ms': 0.0
|
||||
}
|
||||
|
||||
self._initialize_schedules()
|
||||
|
||||
logger.info(f"TimeframeInferenceCoordinator initialized for symbols: {self.symbols}")
|
||||
logger.info(f"Continuous inference interval: {self.schedule.continuous_interval_seconds}s")
|
||||
logger.info(f"Hourly timeframes: {[tf.value for tf in self.schedule.hourly_timeframes]}")
|
||||
|
||||
def _initialize_schedules(self):
|
||||
"""Initialize inference schedules for all symbols"""
|
||||
current_time = datetime.now()
|
||||
|
||||
for symbol in self.symbols:
|
||||
self.last_continuous_inference[symbol] = current_time
|
||||
self.last_hourly_inference[symbol] = current_time
|
||||
|
||||
# Schedule next hourly inference at the top of the next hour
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
self.next_hourly_inference[symbol] = next_hour
|
||||
|
||||
def register_model_inference_function(self, model_name: str, inference_func: Callable):
|
||||
"""
|
||||
Register a model's inference function
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
inference_func: Async function that takes InferenceContext and returns prediction
|
||||
"""
|
||||
self.model_inference_functions[model_name] = inference_func
|
||||
logger.info(f"Registered inference function for model: {model_name}")
|
||||
|
||||
async def start_coordination(self):
|
||||
"""Start the inference coordination system"""
|
||||
if self.running:
|
||||
logger.warning("Inference coordination already running")
|
||||
return
|
||||
|
||||
self.running = True
|
||||
logger.info("Starting timeframe inference coordination")
|
||||
|
||||
# Start continuous inference tasks for each symbol
|
||||
for symbol in self.symbols:
|
||||
task = asyncio.create_task(self._continuous_inference_loop(symbol))
|
||||
self.inference_tasks.append(task)
|
||||
|
||||
# Start hourly inference scheduler
|
||||
task = asyncio.create_task(self._hourly_inference_scheduler())
|
||||
self.inference_tasks.append(task)
|
||||
|
||||
# Start reward evaluation loop
|
||||
task = asyncio.create_task(self._reward_evaluation_loop())
|
||||
self.inference_tasks.append(task)
|
||||
|
||||
logger.info(f"Started {len(self.inference_tasks)} inference coordination tasks")
|
||||
|
||||
async def stop_coordination(self):
|
||||
"""Stop the inference coordination system"""
|
||||
if not self.running:
|
||||
return
|
||||
|
||||
self.running = False
|
||||
logger.info("Stopping timeframe inference coordination")
|
||||
|
||||
# Cancel all tasks
|
||||
for task in self.inference_tasks:
|
||||
task.cancel()
|
||||
|
||||
# Wait for tasks to complete
|
||||
await asyncio.gather(*self.inference_tasks, return_exceptions=True)
|
||||
self.inference_tasks.clear()
|
||||
|
||||
logger.info("Inference coordination stopped")
|
||||
|
||||
async def _continuous_inference_loop(self, symbol: str):
|
||||
"""
|
||||
Continuous inference loop for a specific symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol to run inference for
|
||||
"""
|
||||
logger.info(f"Starting continuous inference loop for {symbol}")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
|
||||
# Check if it's time for continuous inference
|
||||
last_inference = self.last_continuous_inference[symbol]
|
||||
time_since_last = (current_time - last_inference).total_seconds()
|
||||
|
||||
if time_since_last >= self.schedule.continuous_interval_seconds:
|
||||
# Run continuous inference on primary timeframe (1s)
|
||||
context = InferenceContext(
|
||||
symbol=symbol,
|
||||
timeframe=TimeFrame.SECONDS_1,
|
||||
timestamp=current_time,
|
||||
target_timeframe=TimeFrame.SECONDS_1,
|
||||
is_hourly_inference=False,
|
||||
inference_type="continuous"
|
||||
)
|
||||
|
||||
await self._execute_inference(context)
|
||||
self.last_continuous_inference[symbol] = current_time
|
||||
self.inference_stats['continuous_inferences'] += 1
|
||||
|
||||
# Sleep for a short interval to avoid busy waiting
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in continuous inference loop for {symbol}: {e}")
|
||||
await asyncio.sleep(1.0) # Wait longer on error
|
||||
|
||||
async def _hourly_inference_scheduler(self):
|
||||
"""Scheduler for hourly multi-timeframe inference and timeframe-boundary triggers"""
|
||||
logger.info("Starting hourly inference scheduler")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
|
||||
# Check if any symbol needs hourly inference
|
||||
for symbol in self.symbols:
|
||||
if current_time >= self.next_hourly_inference[symbol]:
|
||||
await self._execute_hourly_inference(symbol, current_time)
|
||||
|
||||
# Schedule next hourly inference
|
||||
next_hour = current_time.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1)
|
||||
self.next_hourly_inference[symbol] = next_hour
|
||||
self.last_hourly_inference[symbol] = current_time
|
||||
|
||||
# Trigger at each new timeframe boundary: 1m, 1h, 1d
|
||||
if current_time.second == 0:
|
||||
# New minute
|
||||
await self._execute_boundary_inference(symbol, current_time, TimeFrame.MINUTES_1)
|
||||
if current_time.minute == 0 and current_time.second == 0:
|
||||
# New hour
|
||||
await self._execute_boundary_inference(symbol, current_time, TimeFrame.HOURS_1)
|
||||
if current_time.hour == 0 and current_time.minute == 0 and current_time.second == 0:
|
||||
# New day
|
||||
await self._execute_boundary_inference(symbol, current_time, TimeFrame.DAYS_1)
|
||||
|
||||
# Sleep for 1 second so the second == 0 boundary checks above are not skipped
|
||||
await asyncio.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in hourly inference scheduler: {e}")
|
||||
await asyncio.sleep(60) # Wait longer on error
|
||||
|
||||
async def _execute_boundary_inference(self, symbol: str, timestamp: datetime, timeframe: TimeFrame):
|
||||
"""Execute an inference exactly at timeframe boundary"""
|
||||
try:
|
||||
context = InferenceContext(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
timestamp=timestamp,
|
||||
target_timeframe=timeframe,
|
||||
is_hourly_inference=False,
|
||||
inference_type="boundary"
|
||||
)
|
||||
await self._execute_inference(context)
|
||||
except Exception as e:
|
||||
logger.debug(f"Boundary inference error for {symbol} {timeframe.value}: {e}")
|
||||
|
||||
async def _execute_hourly_inference(self, symbol: str, timestamp: datetime):
|
||||
"""
|
||||
Execute hourly multi-timeframe inference for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timestamp: Current timestamp
|
||||
"""
|
||||
logger.info(f"Executing hourly multi-timeframe inference for {symbol}")
|
||||
|
||||
# Run inference for each timeframe
|
||||
for timeframe in self.schedule.hourly_timeframes:
|
||||
context = InferenceContext(
|
||||
symbol=symbol,
|
||||
timeframe=timeframe,
|
||||
timestamp=timestamp,
|
||||
target_timeframe=timeframe,
|
||||
is_hourly_inference=True,
|
||||
inference_type="hourly"
|
||||
)
|
||||
|
||||
await self._execute_inference(context)
|
||||
self.inference_stats['hourly_inferences'] += 1
|
||||
|
||||
# Small delay between timeframe inferences
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
async def _execute_inference(self, context: InferenceContext):
|
||||
"""
|
||||
Execute inference for a specific context
|
||||
|
||||
Args:
|
||||
context: Inference context containing all necessary information
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Run inference for all registered models
|
||||
for model_name, inference_func in self.model_inference_functions.items():
|
||||
try:
|
||||
# Execute model inference
|
||||
prediction = await inference_func(context)
|
||||
|
||||
if prediction is not None:
|
||||
# Add prediction to reward calculator
|
||||
prediction_id = self.reward_calculator.add_prediction(
|
||||
symbol=context.symbol,
|
||||
timeframe=context.target_timeframe,
|
||||
predicted_price=prediction.get('predicted_price', 0.0),
|
||||
predicted_direction=prediction.get('direction', 0),
|
||||
confidence=prediction.get('confidence', 0.0),
|
||||
current_price=prediction.get('current_price', 0.0),
|
||||
model_name=model_name,
|
||||
predicted_return=prediction.get('predicted_return'),
|
||||
state_vector=prediction.get('model_state') or prediction.get('model_features')
|
||||
)
|
||||
|
||||
logger.debug(f"Added prediction {prediction_id} from {model_name} "
|
||||
f"for {context.symbol} {context.target_timeframe.value}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error running inference for model {model_name}: {e}")
|
||||
self.inference_stats['failed_inferences'] += 1
|
||||
|
||||
# Update inference timing stats
|
||||
inference_time_ms = (time.time() - start_time) * 1000
|
||||
self._update_inference_timing(inference_time_ms)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing inference for context {context}: {e}")
|
||||
self.inference_stats['failed_inferences'] += 1
|
||||
|
||||
def _update_inference_timing(self, inference_time_ms: float):
|
||||
"""Update inference timing statistics"""
|
||||
total_inferences = (self.inference_stats['continuous_inferences'] +
|
||||
self.inference_stats['hourly_inferences'])
|
||||
|
||||
if total_inferences > 0:
|
||||
current_avg = self.inference_stats['average_inference_time_ms']
|
||||
new_avg = ((current_avg * (total_inferences - 1)) + inference_time_ms) / total_inferences
|
||||
self.inference_stats['average_inference_time_ms'] = new_avg
|
||||
|
||||
async def _reward_evaluation_loop(self):
|
||||
"""Continuous loop for evaluating prediction rewards"""
|
||||
logger.info("Starting reward evaluation loop")
|
||||
|
||||
while self.running:
|
||||
try:
|
||||
# Update price cache if data provider available
|
||||
if self.data_provider:
|
||||
# _update_price_cache is async but internally calls the synchronous DataProvider.get_current_price
|
||||
await self._update_price_cache()
|
||||
|
||||
# Evaluate predictions and get training data
|
||||
for symbol in self.symbols:
|
||||
evaluation_results = self.reward_calculator.evaluate_predictions(symbol)
|
||||
|
||||
if symbol in evaluation_results and evaluation_results[symbol]:
|
||||
logger.debug(f"Evaluated {len(evaluation_results[symbol])} predictions for {symbol}")
|
||||
|
||||
# Here you could trigger training for models that have new evaluated predictions
|
||||
await self._trigger_model_training(symbol, evaluation_results[symbol])
|
||||
|
||||
# Sleep for evaluation interval
|
||||
await asyncio.sleep(10) # Evaluate every 10 seconds
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in reward evaluation loop: {e}")
|
||||
await asyncio.sleep(30) # Wait longer on error
|
||||
|
||||
async def _update_price_cache(self):
|
||||
"""Update price cache with current market prices"""
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
# Get current price from data provider
|
||||
if hasattr(self.data_provider, 'get_current_price'):
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
if current_price:
|
||||
self.reward_calculator.update_price(symbol, current_price)
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error updating price cache: {e}")
|
||||
|
||||
async def _trigger_model_training(self, symbol: str, evaluation_results: List[Any]):
|
||||
"""
|
||||
Trigger model training based on evaluation results
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
evaluation_results: List of (prediction, reward) tuples
|
||||
"""
|
||||
try:
|
||||
# Group by model and timeframe for targeted training
|
||||
training_groups = {}
|
||||
|
||||
for prediction_record, reward in evaluation_results:
|
||||
model_name = prediction_record.model_name
|
||||
timeframe = prediction_record.timeframe
|
||||
|
||||
key = f"{model_name}_{timeframe.value}"
|
||||
if key not in training_groups:
|
||||
training_groups[key] = []
|
||||
|
||||
training_groups[key].append((prediction_record, reward))
|
||||
|
||||
# Trigger training for each group
|
||||
for group_key, training_data in training_groups.items():
|
||||
model_name, timeframe_str = group_key.rsplit('_', 1)  # rsplit: model names may contain underscores (e.g. dqn_agent)
|
||||
timeframe = TimeFrame(timeframe_str)
|
||||
|
||||
logger.info(f"Triggering training for {model_name} on {symbol} {timeframe.value} "
|
||||
f"with {len(training_data)} samples")
|
||||
|
||||
# Here you would call the specific model's training function
|
||||
# This is a placeholder - you'll need to implement the actual training calls
|
||||
await self._call_model_training(model_name, symbol, timeframe, training_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering model training: {e}")
|
||||
|
||||
async def _call_model_training(self, model_name: str, symbol: str,
|
||||
timeframe: TimeFrame, training_data: List[Any]):
|
||||
"""
|
||||
Call model-specific training function
|
||||
|
||||
Args:
|
||||
model_name: Name of the model to train
|
||||
symbol: Trading symbol
|
||||
timeframe: Timeframe for training
|
||||
training_data: List of (prediction, reward) tuples
|
||||
"""
|
||||
# This is a placeholder for model-specific training calls
|
||||
# You'll need to implement this based on your specific model interfaces
|
||||
logger.debug(f"Training call for {model_name}: {len(training_data)} samples")
|
||||
|
||||
def get_inference_statistics(self) -> Dict[str, Any]:
|
||||
"""Get inference coordination statistics"""
|
||||
with self.lock:
|
||||
stats = self.inference_stats.copy()
|
||||
|
||||
# Add scheduling information
|
||||
stats['symbols'] = self.symbols
|
||||
stats['continuous_interval_seconds'] = self.schedule.continuous_interval_seconds
|
||||
stats['hourly_timeframes'] = [tf.value for tf in self.schedule.hourly_timeframes]
|
||||
stats['next_hourly_inferences'] = {
|
||||
symbol: timestamp.isoformat()
|
||||
for symbol, timestamp in self.next_hourly_inference.items()
|
||||
}
|
||||
|
||||
# Add accuracy summary from reward calculator
|
||||
stats['accuracy_summary'] = self.reward_calculator.get_accuracy_summary()
|
||||
|
||||
return stats
|
||||
|
||||
def force_hourly_inference(self, symbol: str = None):
|
||||
"""
|
||||
Force immediate hourly inference for symbol(s)
|
||||
|
||||
Args:
|
||||
symbol: Specific symbol (None for all symbols)
|
||||
"""
|
||||
symbols_to_process = [symbol] if symbol else self.symbols
|
||||
|
||||
async def _force_inference():
|
||||
current_time = datetime.now()
|
||||
for sym in symbols_to_process:
|
||||
await self._execute_hourly_inference(sym, current_time)
|
||||
|
||||
# Schedule the inference
|
||||
if self.running:
|
||||
asyncio.create_task(_force_inference())
|
||||
else:
|
||||
logger.warning("Cannot force inference - coordinator not running")
|
||||
|
||||
def get_prediction_history(self, symbol: str, timeframe: TimeFrame,
|
||||
max_samples: int = 50) -> List[Any]:
|
||||
"""
|
||||
Get prediction history for training
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
timeframe: Specific timeframe
|
||||
max_samples: Maximum samples to return
|
||||
|
||||
Returns:
|
||||
List of training samples
|
||||
"""
|
||||
return self.reward_calculator.get_training_data(symbol, timeframe, max_samples)
|
||||
|
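A minimal usage sketch of the coordinator defined above. The `EnhancedRewardCalculator` constructor arguments and the placeholder inference function are assumptions; the registration call, start/stop methods, and prediction-dict keys follow the code above.

```python
import asyncio

from core.enhanced_reward_calculator import EnhancedRewardCalculator
from core.timeframe_inference_coordinator import TimeframeInferenceCoordinator, InferenceContext

async def demo_model_inference(context: InferenceContext):
    # Placeholder model: return the fields _execute_inference forwards to the reward calculator
    current_price = 3150.0  # a live implementation would read this from the data provider
    return {
        'predicted_price': current_price * 1.001,
        'direction': 1,
        'confidence': 0.7,
        'current_price': current_price,
    }

async def main():
    calculator = EnhancedRewardCalculator(symbols=['ETH/USDT'])  # constructor args assumed
    coordinator = TimeframeInferenceCoordinator(
        reward_calculator=calculator,
        data_provider=None,           # price-cache updates are skipped without a provider
        symbols=['ETH/USDT'],
    )
    coordinator.register_model_inference_function('demo_model', demo_model_inference)

    await coordinator.start_coordination()
    await asyncio.sleep(30)           # let a few continuous inferences run
    print(coordinator.get_inference_statistics())
    await coordinator.stop_coordination()

asyncio.run(main())
```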
130
core/unified_training_manager.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""
|
||||
Unified Training Manager
|
||||
|
||||
Combines the previous built-in (normal) training and the EnhancedRealtimeTrainingSystem ideas
|
||||
into a single orchestrator-agnostic manager. Keeps orchestrator lean by moving training logic here.
|
||||
|
||||
Key responsibilities:
|
||||
- Subscribe to model predictions/outcomes and perform online updates (DQN/COB RL/CNN)
|
||||
- Schedule periodic training (intervals) and replay-based training
|
||||
- Integrate with Enhanced Reward System for evaluated rewards
|
||||
- Work regardless of enhanced system availability
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from core.enhanced_reward_calculator import TimeFrame
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class UnifiedTrainingManager:
|
||||
"""Unified training controller decoupled from the orchestrator."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
orchestrator: Any,
|
||||
reward_system: Any = None,
|
||||
dqn_interval_s: int = 5,
|
||||
cob_rl_interval_s: int = 1,
|
||||
cnn_interval_s: int = 10,
|
||||
min_dqn_experiences: int = 16,
|
||||
):
|
||||
self.orchestrator = orchestrator
|
||||
self.reward_system = reward_system
|
||||
self.dqn_interval_s = dqn_interval_s
|
||||
self.cob_rl_interval_s = cob_rl_interval_s
|
||||
self.cnn_interval_s = cnn_interval_s
|
||||
self.min_dqn_experiences = min_dqn_experiences
|
||||
|
||||
self.running = False
|
||||
self._tasks: List[asyncio.Task] = []
|
||||
|
||||
async def start(self):
|
||||
if self.running:
|
||||
logger.warning("UnifiedTrainingManager already running")
|
||||
return
|
||||
self.running = True
|
||||
logger.info("UnifiedTrainingManager started")
|
||||
# Periodic trainers
|
||||
self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
|
||||
self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
|
||||
self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))
|
||||
# Reward-driven trainer
|
||||
if self.reward_system is not None:
|
||||
self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))
|
||||
|
||||
async def stop(self):
|
||||
self.running = False
|
||||
for t in self._tasks:
|
||||
t.cancel()
|
||||
await asyncio.gather(*self._tasks, return_exceptions=True)
|
||||
self._tasks.clear()
|
||||
logger.info("UnifiedTrainingManager stopped")
|
||||
|
||||
async def _dqn_trainer_loop(self):
|
||||
while self.running:
|
||||
try:
|
||||
rl_agent = getattr(self.orchestrator, 'rl_agent', None)
|
||||
if rl_agent and hasattr(rl_agent, 'memory'):
|
||||
if len(rl_agent.memory) >= self.min_dqn_experiences and hasattr(rl_agent, 'replay'):
|
||||
loss = rl_agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"DQN replay loss: {loss}")
|
||||
await asyncio.sleep(self.dqn_interval_s)
|
||||
except Exception as e:
|
||||
logger.error(f"DQN trainer loop error: {e}")
|
||||
await asyncio.sleep(self.dqn_interval_s)
|
||||
|
||||
async def _cob_rl_trainer_loop(self):
|
||||
while self.running:
|
||||
try:
|
||||
cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
|
||||
if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
|
||||
if len(getattr(cob_agent, 'memory', [])) >= 8:
|
||||
loss = cob_agent.replay()
|
||||
if loss is not None:
|
||||
logger.debug(f"COB RL replay loss: {loss}")
|
||||
await asyncio.sleep(self.cob_rl_interval_s)
|
||||
except Exception as e:
|
||||
logger.error(f"COB RL trainer loop error: {e}")
|
||||
await asyncio.sleep(self.cob_rl_interval_s)
|
||||
|
||||
async def _cnn_trainer_loop(self):
|
||||
while self.running:
|
||||
try:
|
||||
# Placeholder: hook to your CNN trainer if available
|
||||
await asyncio.sleep(self.cnn_interval_s)
|
||||
except Exception as e:
|
||||
logger.error(f"CNN trainer loop error: {e}")
|
||||
await asyncio.sleep(self.cnn_interval_s)
|
||||
|
||||
async def _reward_driven_training_loop(self):
|
||||
while self.running:
|
||||
try:
|
||||
# Pull evaluated samples and feed to respective models
|
||||
symbols = getattr(self.reward_system.reward_calculator, 'symbols', []) if hasattr(self.reward_system, 'reward_calculator') else []
|
||||
for sym in symbols:
|
||||
# Use short horizon for fast feedback
|
||||
samples = self.reward_system.reward_calculator.get_training_data(sym, TimeFrame.SECONDS_1, max_samples=64)
|
||||
if not samples:
|
||||
continue
|
||||
# Currently DQN batch: add to memory and let replay loop train
|
||||
rl_agent = getattr(self.orchestrator, 'rl_agent', None)
|
||||
if rl_agent and hasattr(rl_agent, 'remember'):
|
||||
for rec, reward in samples:
|
||||
# Use state vector captured at inference time when available
|
||||
state = rec.state_vector if getattr(rec, 'state_vector', None) else []
|
||||
if not state:
|
||||
continue
|
||||
action = rec.predicted_direction + 1
|
||||
rl_agent.remember(state, action, reward, state, True)
|
||||
await asyncio.sleep(2)
|
||||
except Exception as e:
|
||||
logger.error(f"Reward-driven training loop error: {e}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
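A minimal usage sketch, assuming an already-constructed orchestrator (and optionally the enhanced reward system) from the rest of the project:

```python
import asyncio

from core.unified_training_manager import UnifiedTrainingManager

async def run_training(orchestrator, reward_system=None):
    trainer = UnifiedTrainingManager(
        orchestrator=orchestrator,
        reward_system=reward_system,   # enables the reward-driven loop when provided
        dqn_interval_s=5,
        cob_rl_interval_s=1,
        cnn_interval_s=10,
    )
    await trainer.start()
    try:
        await asyncio.sleep(60)        # let the periodic trainer loops run
    finally:
        await trainer.stop()
```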
349
docs/ENHANCED_REWARD_SYSTEM.md
Normal file
@@ -0,0 +1,349 @@
|
||||
# Enhanced Reward System for Reinforcement Learning Training
|
||||
|
||||
## Overview
|
||||
|
||||
This document describes the implementation of an enhanced reward system for your reinforcement learning trading models. The system uses **mean squared error (MSE) between predictions and empirical outcomes** as the primary reward mechanism, with support for multiple timeframes and comprehensive accuracy tracking.
|
||||
|
||||
## Key Features
|
||||
|
||||
### ✅ MSE-Based Reward Calculation
|
||||
- Uses mean squared difference between predicted and actual prices
|
||||
- Exponential decay function heavily penalizes large prediction errors
|
||||
- Direction accuracy bonus/penalty system
|
||||
- Confidence-weighted final rewards
|
||||
|
||||
### ✅ Multi-Timeframe Support
|
||||
- Separate tracking for **1s, 1m, 1h, 1d** timeframes
|
||||
- Independent accuracy metrics for each timeframe
|
||||
- Timeframe-specific evaluation timeouts
|
||||
- Models know which timeframe they're predicting on
|
||||
|
||||
### ✅ Prediction History Tracking
|
||||
- Maintains last **6 predictions per timeframe** per symbol
|
||||
- Comprehensive prediction records with outcomes
|
||||
- Historical accuracy analysis
|
||||
- Memory-efficient with automatic cleanup
|
||||
|
||||
### ✅ Real-Time Training
|
||||
- Training triggered at each inference when outcomes are available
|
||||
- Separate training batches for each model and timeframe
|
||||
- Automatic evaluation of predictions after appropriate timeouts
|
||||
- Integration with existing RL training infrastructure
|
||||
|
||||
### ✅ Enhanced Inference Scheduling
|
||||
- **Continuous inference** every 1-5 seconds on primary timeframe
|
||||
- **Hourly multi-timeframe inference** (4 predictions per hour - one for each timeframe)
|
||||
- Timeframe-aware inference context
|
||||
- Proper scheduling and coordination
|
||||
|
||||
## Architecture
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[Market Data] --> B[Timeframe Inference Coordinator]
|
||||
B --> C[Model Inference]
|
||||
C --> D[Enhanced Reward Calculator]
|
||||
D --> E[Prediction Tracking]
|
||||
E --> F[Outcome Evaluation]
|
||||
F --> G[MSE Reward Calculation]
|
||||
G --> H[Enhanced RL Training Adapter]
|
||||
H --> I[Model Training]
|
||||
I --> J[Performance Monitoring]
|
||||
```
|
||||
|
||||
## Core Components
|
||||
|
||||
### 1. EnhancedRewardCalculator (`core/enhanced_reward_calculator.py`)
|
||||
|
||||
**Purpose**: Central reward calculation engine using MSE methodology
|
||||
|
||||
**Key Methods**:
|
||||
- `add_prediction()` - Track new predictions
|
||||
- `evaluate_predictions()` - Calculate rewards when outcomes available
|
||||
- `get_accuracy_summary()` - Comprehensive accuracy metrics
|
||||
- `get_training_data()` - Extract training samples for models
|
||||
|
||||
**Reward Formula**:
|
||||
```python
|
||||
# MSE calculation
|
||||
price_error = actual_price - predicted_price
|
||||
mse = price_error ** 2
|
||||
|
||||
# Normalize to reasonable scale
|
||||
max_mse = (current_price * 0.1) ** 2 # 10% as max expected error
|
||||
normalized_mse = min(mse / max_mse, 1.0)
|
||||
|
||||
# Exponential decay (heavily penalize large errors)
|
||||
mse_reward = exp(-5 * normalized_mse) # Range: [exp(-5), 1]
|
||||
|
||||
# Direction bonus/penalty
|
||||
direction_bonus = 0.5 if direction_correct else -0.5
|
||||
|
||||
# Final reward (confidence weighted)
|
||||
final_reward = (mse_reward + direction_bonus) * confidence
|
||||
```
|
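To make the formula concrete, a small worked example with purely illustrative numbers (not taken from live data):

```python
import math

# Prediction: 3155.0, actual outcome: 3152.0, price at prediction time: 3150.0, confidence: 0.78
price_error = 3152.0 - 3155.0                  # -3.0
mse = price_error ** 2                         # 9.0
max_mse = (3150.0 * 0.1) ** 2                  # 99225.0
normalized_mse = min(mse / max_mse, 1.0)       # ~9.1e-05
mse_reward = math.exp(-5 * normalized_mse)     # ~0.9995
direction_bonus = 0.5                          # predicted up, price moved up
final_reward = (mse_reward + direction_bonus) * 0.78   # ~1.17
```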
||||
|
||||
### 2. TimeframeInferenceCoordinator (`core/timeframe_inference_coordinator.py`)
|
||||
|
||||
**Purpose**: Coordinates timeframe-aware model inference with proper scheduling
|
||||
|
||||
**Key Features**:
|
||||
- **Continuous inference loop** for each symbol (every 5 seconds)
|
||||
- **Hourly multi-timeframe scheduler** (4 predictions per hour)
|
||||
- **Inference context management** (models know target timeframe)
|
||||
- **Automatic reward evaluation** and training triggers
|
||||
|
||||
**Scheduling**:
|
||||
- **Every 5 seconds**: Inference on primary timeframe (1s)
|
||||
- **Every hour**: One inference for each timeframe (1s, 1m, 1h, 1d)
|
||||
- **Evaluation timeouts**: 5s for 1s predictions, 60s for 1m, 300s for 1h, 900s for 1d
|
||||
|
||||
### 3. EnhancedRLTrainingAdapter (`core/enhanced_rl_training_adapter.py`)
|
||||
|
||||
**Purpose**: Bridge between new reward system and existing RL training infrastructure
|
||||
|
||||
**Key Features**:
|
||||
- **Model inference wrappers** for DQN, COB RL, and CNN models
|
||||
- **Training batch creation** from prediction records and rewards
|
||||
- **Real-time training triggers** based on evaluation results
|
||||
- **Backward compatibility** with existing training systems
|
||||
|
||||
### 4. EnhancedRewardSystemIntegration (`core/enhanced_reward_system_integration.py`)
|
||||
|
||||
**Purpose**: Simple integration point for existing systems
|
||||
|
||||
**Key Features**:
|
||||
- **One-line integration** with existing TradingOrchestrator
|
||||
- **Helper functions** for easy prediction tracking
|
||||
- **Comprehensive monitoring** and statistics
|
||||
- **Minimal code changes** required
|
||||
|
||||
## Integration Guide
|
||||
|
||||
### Step 1: Import Required Components
|
||||
|
||||
Add to your `orchestrator.py`:
|
||||
|
||||
```python
|
||||
from core.enhanced_reward_system_integration import (
|
||||
integrate_enhanced_rewards,
|
||||
add_prediction_to_enhanced_rewards
|
||||
)
|
||||
```
|
||||
|
||||
### Step 2: Initialize in TradingOrchestrator
|
||||
|
||||
In your `TradingOrchestrator.__init__()`:
|
||||
|
||||
```python
|
||||
# Add this line after existing initialization
|
||||
integrate_enhanced_rewards(self, symbols=['ETH/USDT', 'BTC/USDT'])
|
||||
```
|
||||
|
||||
### Step 3: Start the System
|
||||
|
||||
In your `TradingOrchestrator.run()` method:
|
||||
|
||||
```python
|
||||
# Add this line after initialization
|
||||
await self.enhanced_reward_system.start_integration()
|
||||
```
|
||||
|
||||
### Step 4: Track Predictions
|
||||
|
||||
In your model inference methods (CNN, DQN, COB RL):
|
||||
|
||||
```python
|
||||
# Example in CNN inference
|
||||
prediction_id = add_prediction_to_enhanced_rewards(
|
||||
self, # orchestrator instance
|
||||
symbol, # 'ETH/USDT'
|
||||
timeframe, # '1s', '1m', '1h', '1d'
|
||||
predicted_price, # model's price prediction
|
||||
direction, # -1 (down), 0 (neutral), 1 (up)
|
||||
confidence, # 0.0 to 1.0
|
||||
current_price, # current market price
|
||||
'enhanced_cnn' # model name
|
||||
)
|
||||
```
|
||||
|
||||
### Step 5: Monitor Performance
|
||||
|
||||
```python
|
||||
# Get comprehensive statistics
|
||||
stats = self.enhanced_reward_system.get_integration_statistics()
|
||||
accuracy = self.enhanced_reward_system.get_model_accuracy()
|
||||
|
||||
# Force evaluation for testing
|
||||
self.enhanced_reward_system.force_evaluation_and_training('ETH/USDT', '1s')
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
|
||||
See `examples/enhanced_reward_system_example.py` for a complete demonstration.
|
||||
|
||||
```bash
|
||||
python examples/enhanced_reward_system_example.py
|
||||
```
|
||||
|
||||
## Performance Benefits
|
||||
|
||||
### 🎯 Better Accuracy Measurement
|
||||
- **MSE rewards** provide nuanced feedback vs. simple directional accuracy
|
||||
- **Price prediction accuracy** measured alongside direction accuracy
|
||||
- **Confidence-weighted rewards** encourage well-calibrated predictions
|
||||
|
||||
### 📊 Multi-Timeframe Intelligence
|
||||
- **Separate tracking** prevents timeframe confusion
|
||||
- **Timeframe-specific evaluation** accounts for different market dynamics
|
||||
- **Comprehensive accuracy picture** across all prediction horizons
|
||||
|
||||
### ⚡ Real-Time Learning
|
||||
- **Immediate training** when prediction outcomes available
|
||||
- **No batch delays** - models learn from every prediction
|
||||
- **Adaptive training frequency** based on prediction evaluation
|
||||
|
||||
### 🔄 Enhanced Inference Scheduling
|
||||
- **Optimal prediction frequency** balances real-time response with computational efficiency
|
||||
- **Hourly multi-timeframe predictions** provide comprehensive market coverage
|
||||
- **Context-aware models** make better predictions knowing their target timeframe
|
||||
|
||||
## Configuration
|
||||
|
||||
### Evaluation Timeouts (Configurable in EnhancedRewardCalculator)
|
||||
|
||||
```python
|
||||
evaluation_timeouts = {
|
||||
TimeFrame.SECONDS_1: 5, # Evaluate 1s predictions after 5 seconds
|
||||
TimeFrame.MINUTES_1: 60, # Evaluate 1m predictions after 1 minute
|
||||
TimeFrame.HOURS_1: 300, # Evaluate 1h predictions after 5 minutes
|
||||
TimeFrame.DAYS_1: 900 # Evaluate 1d predictions after 15 minutes
|
||||
}
|
||||
```
|
||||
|
||||
### Inference Scheduling (Configurable in TimeframeInferenceCoordinator)
|
||||
|
||||
```python
|
||||
schedule = InferenceSchedule(
|
||||
continuous_interval_seconds=5.0, # Continuous inference every 5 seconds
|
||||
hourly_timeframes=[TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
|
||||
TimeFrame.HOURS_1, TimeFrame.DAYS_1]
|
||||
)
|
||||
```
|
||||
|
||||
### Training Configuration (Configurable in EnhancedRLTrainingAdapter)
|
||||
|
||||
```python
|
||||
min_batch_size = 8 # Minimum samples for training
|
||||
max_batch_size = 64 # Maximum samples per training batch
|
||||
training_interval_seconds = 5.0 # Training check frequency
|
||||
```
|
||||
|
||||
## Monitoring and Statistics
|
||||
|
||||
### Integration Statistics
|
||||
|
||||
```python
|
||||
stats = enhanced_reward_system.get_integration_statistics()
|
||||
```
|
||||
|
||||
Returns (see the short sketch after this list):
|
||||
- System running status
|
||||
- Total predictions tracked
|
||||
- Component status
|
||||
- Inference and training statistics
|
||||
- Performance metrics
|
||||
|
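As a quick sketch of how these fields are read (only the three keys below are confirmed by the example script; any other fields are implementation details):

```python
stats = enhanced_reward_system.get_integration_statistics()

print(stats.get('is_running', False))            # True while the integration is active
print(stats.get('start_time', 'N/A'))            # when the system was started
print(stats.get('total_predictions_tracked', 0)) # count across all symbols and timeframes
```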
||||
### Model Accuracy
|
||||
|
||||
```python
|
||||
accuracy = enhanced_reward_system.get_model_accuracy()
|
||||
```
|
||||
|
||||
Returns for each symbol and timeframe:
|
||||
- Total predictions made
|
||||
- Direction accuracy percentage
|
||||
- Average MSE
|
||||
- Recent prediction count
|
||||
|
||||
### Real-Time Monitoring
|
||||
|
||||
The system provides comprehensive logging at different levels:
|
||||
- `INFO`: Major system events, training results
|
||||
- `DEBUG`: Detailed prediction tracking, reward calculations
|
||||
- `ERROR`: System errors and recovery actions
|
||||
|
||||
## Backward Compatibility
|
||||
|
||||
The enhanced reward system is designed to be **fully backward compatible**:
|
||||
|
||||
✅ **Existing models continue to work** without modification
|
||||
✅ **Existing training systems** remain functional
|
||||
✅ **Existing reward calculations** can run in parallel
|
||||
✅ **Gradual migration** - enable for specific models incrementally
|
||||
|
||||
## Testing and Validation
|
||||
|
||||
### Force Evaluation for Testing
|
||||
|
||||
```python
|
||||
# Force immediate evaluation of all predictions
|
||||
enhanced_reward_system.force_evaluation_and_training()
|
||||
|
||||
# Force evaluation for specific symbol/timeframe
|
||||
enhanced_reward_system.force_evaluation_and_training('ETH/USDT', '1s')
|
||||
```
|
||||
|
||||
### Manual Prediction Addition
|
||||
|
||||
```python
|
||||
# Add predictions manually for testing
|
||||
prediction_id = enhanced_reward_system.add_prediction_manually(
|
||||
symbol='ETH/USDT',
|
||||
timeframe_str='1s',
|
||||
predicted_price=3150.50,
|
||||
predicted_direction=1,
|
||||
confidence=0.85,
|
||||
current_price=3150.00,
|
||||
model_name='test_model'
|
||||
)
|
||||
```
|
||||
|
||||
## Memory Management
|
||||
|
||||
The system includes automatic memory management:
|
||||
|
||||
- **Automatic prediction cleanup** (configurable retention period)
|
||||
- **Circular buffers** for prediction history (max 100 per timeframe)
|
||||
- **Price cache management** (max 1000 price points per symbol)
|
||||
- **Efficient storage** using deques and compressed data structures
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
The architecture supports easy extension for:
|
||||
|
||||
1. **Additional timeframes** (30s, 5m, 15m, etc.)
|
||||
2. **Custom reward functions** (Sharpe ratio, maximum drawdown, etc.)
|
||||
3. **Multi-symbol correlation** rewards
|
||||
4. **Advanced statistical metrics** (Sortino ratio, Calmar ratio)
|
||||
5. **Model ensemble** reward aggregation
|
||||
6. **A/B testing** framework for reward functions
|
||||
|
||||
## Conclusion
|
||||
|
||||
The Enhanced Reward System provides a comprehensive foundation for improving RL model training through:
|
||||
|
||||
- **Precise MSE-based rewards** that accurately measure prediction quality
|
||||
- **Multi-timeframe intelligence** that prevents confusion between different prediction horizons
|
||||
- **Real-time learning** that maximizes training opportunities
|
||||
- **Easy integration** that requires minimal changes to existing code
|
||||
- **Comprehensive monitoring** that provides insights into model performance
|
||||
|
||||
This system addresses the specific requirements you outlined:
|
||||
✅ MSE-based accuracy calculation
|
||||
✅ Training at each inference using last prediction vs. current outcome
|
||||
✅ Separate accuracy tracking for up to 6 last predictions per timeframe
|
||||
✅ Models know which timeframe they're predicting on
|
||||
✅ Hourly multi-timeframe inference (4 predictions per hour)
|
||||
✅ Integration with existing 1-5 second inference frequency
|
||||
|
265
examples/enhanced_reward_system_example.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""
|
||||
Enhanced Reward System Integration Example
|
||||
|
||||
This example demonstrates how to integrate the new MSE-based reward system
|
||||
with the existing trading orchestrator and models.
|
||||
|
||||
Usage:
|
||||
python examples/enhanced_reward_system_example.py
|
||||
|
||||
This example shows:
|
||||
1. How to integrate the enhanced reward system with TradingOrchestrator
|
||||
2. How to add predictions from existing models
|
||||
3. How to monitor accuracy and training statistics
|
||||
4. How the system handles multi-timeframe predictions and training
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Import the integration components
|
||||
from core.enhanced_reward_system_integration import (
|
||||
integrate_enhanced_rewards,
|
||||
start_enhanced_rewards_for_orchestrator,
|
||||
add_prediction_to_enhanced_rewards
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def demonstrate_enhanced_reward_integration():
|
||||
"""Demonstrate the enhanced reward system integration"""
|
||||
|
||||
print("=" * 80)
|
||||
print("ENHANCED REWARD SYSTEM INTEGRATION DEMONSTRATION")
|
||||
print("=" * 80)
|
||||
|
||||
# Note: This is a demonstration - in real usage, you would use your actual orchestrator
|
||||
# For this example, we'll create a mock orchestrator
|
||||
|
||||
print("\n1. Setting up mock orchestrator...")
|
||||
mock_orchestrator = create_mock_orchestrator()
|
||||
|
||||
print("\n2. Integrating enhanced reward system...")
|
||||
# This is the main integration step - just one line!
|
||||
enhanced_rewards = integrate_enhanced_rewards(mock_orchestrator, ['ETH/USDT', 'BTC/USDT'])
|
||||
|
||||
print("\n3. Starting enhanced reward system...")
|
||||
await start_enhanced_rewards_for_orchestrator(mock_orchestrator)
|
||||
|
||||
print("\n4. System is now running with enhanced rewards!")
|
||||
print(" - CNN predictions every 10 seconds (current rate)")
|
||||
print(" - Continuous inference every 5 seconds")
|
||||
print(" - Hourly multi-timeframe inference (4 predictions per hour)")
|
||||
print(" - Real-time MSE-based reward calculation")
|
||||
print(" - Automatic training when predictions are evaluated")
|
||||
|
||||
# Demonstrate adding predictions from existing models
|
||||
await demonstrate_prediction_tracking(mock_orchestrator)
|
||||
|
||||
# Demonstrate monitoring and statistics
|
||||
await demonstrate_monitoring(mock_orchestrator)
|
||||
|
||||
# Demonstrate force evaluation for testing
|
||||
await demonstrate_force_evaluation(mock_orchestrator)
|
||||
|
||||
print("\n8. Stopping enhanced reward system...")
|
||||
await mock_orchestrator.enhanced_reward_system.stop_integration()
|
||||
|
||||
print("\n✅ Enhanced Reward System demonstration completed successfully!")
|
||||
print("\nTo integrate with your actual system:")
|
||||
print("1. Add these imports to your orchestrator file")
|
||||
print("2. Call integrate_enhanced_rewards(your_orchestrator) in __init__")
|
||||
print("3. Call await start_enhanced_rewards_for_orchestrator(your_orchestrator) in run()")
|
||||
print("4. Use add_prediction_to_enhanced_rewards() in your model inference code")
|
||||
|
||||
|
||||
async def demonstrate_prediction_tracking(orchestrator):
|
||||
"""Demonstrate how to track predictions from existing models"""
|
||||
|
||||
print("\n5. Demonstrating prediction tracking...")
|
||||
|
||||
# Simulate predictions from different models and timeframes
|
||||
predictions = [
|
||||
# CNN predictions for multiple timeframes
|
||||
('ETH/USDT', '1s', 3150.50, 1, 0.85, 3150.00, 'enhanced_cnn'),
|
||||
('ETH/USDT', '1m', 3155.00, 1, 0.78, 3150.00, 'enhanced_cnn'),
|
||||
('ETH/USDT', '1h', 3200.00, 1, 0.72, 3150.00, 'enhanced_cnn'),
|
||||
('ETH/USDT', '1d', 3300.00, 1, 0.65, 3150.00, 'enhanced_cnn'),
|
||||
|
||||
# DQN predictions
|
||||
('ETH/USDT', '1s', 3149.00, -1, 0.70, 3150.00, 'dqn_agent'),
|
||||
('BTC/USDT', '1s', 51200.00, 1, 0.75, 51150.00, 'dqn_agent'),
|
||||
|
||||
# COB RL predictions
|
||||
('ETH/USDT', '1s', 3151.20, 1, 0.88, 3150.00, 'cob_rl'),
|
||||
('BTC/USDT', '1s', 51180.00, 1, 0.82, 51150.00, 'cob_rl'),
|
||||
]
|
||||
|
||||
prediction_ids = []
|
||||
for symbol, timeframe, pred_price, direction, confidence, curr_price, model in predictions:
|
||||
prediction_id = add_prediction_to_enhanced_rewards(
|
||||
orchestrator, symbol, timeframe, pred_price, direction, confidence, curr_price, model
|
||||
)
|
||||
prediction_ids.append(prediction_id)
|
||||
print(f" ✓ Added prediction: {model} predicts {symbol} {timeframe} "
|
||||
f"direction={direction} confidence={confidence:.2f}")
|
||||
|
||||
print(f" 📊 Total predictions added: {len(prediction_ids)}")
|
||||
|
||||
|
||||
async def demonstrate_monitoring(orchestrator):
|
||||
"""Demonstrate monitoring and statistics"""
|
||||
|
||||
print("\n6. Demonstrating monitoring and statistics...")
|
||||
|
||||
# Wait a bit for some processing
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Get integration statistics
|
||||
stats = orchestrator.enhanced_reward_system.get_integration_statistics()
|
||||
|
||||
print(" 📈 Integration Statistics:")
|
||||
print(f" - System running: {stats.get('is_running', False)}")
|
||||
print(f" - Start time: {stats.get('start_time', 'N/A')}")
|
||||
print(f" - Predictions tracked: {stats.get('total_predictions_tracked', 0)}")
|
||||
|
||||
# Get accuracy summary
|
||||
accuracy = orchestrator.enhanced_reward_system.get_model_accuracy()
|
||||
print("\n 🎯 Accuracy Summary by Symbol and Timeframe:")
|
||||
for symbol, timeframes in accuracy.items():
|
||||
print(f" - {symbol}:")
|
||||
for timeframe, metrics in timeframes.items():
|
||||
print(f" - {timeframe}: {metrics['total_predictions']} predictions, "
|
||||
f"{metrics['direction_accuracy']:.1f}% accuracy")
|
||||
|
||||
|
||||
async def demonstrate_force_evaluation(orchestrator):
|
||||
"""Demonstrate force evaluation for testing"""
|
||||
|
||||
print("\n7. Demonstrating force evaluation for testing...")
|
||||
|
||||
# Simulate some price changes by updating prices
|
||||
print(" 💰 Simulating price changes...")
|
||||
orchestrator.enhanced_reward_system.reward_calculator.update_price('ETH/USDT', 3152.50)
|
||||
orchestrator.enhanced_reward_system.reward_calculator.update_price('BTC/USDT', 51175.00)
|
||||
|
||||
# Force evaluation of all predictions
|
||||
print(" ⚡ Force evaluating all predictions...")
|
||||
orchestrator.enhanced_reward_system.force_evaluation_and_training()
|
||||
|
||||
# Get updated statistics
|
||||
await asyncio.sleep(1)
|
||||
stats = orchestrator.enhanced_reward_system.get_integration_statistics()
|
||||
|
||||
print(" 📊 Updated statistics after evaluation:")
|
||||
accuracy = orchestrator.enhanced_reward_system.get_model_accuracy()
|
||||
total_evaluated = sum(
|
||||
sum(tf_data['total_predictions'] for tf_data in symbol_data.values())
|
||||
for symbol_data in accuracy.values()
|
||||
)
|
||||
print(f" - Total predictions evaluated: {total_evaluated}")
|
||||
|
||||
|
||||
def create_mock_orchestrator():
|
||||
"""Create a mock orchestrator for demonstration purposes"""
|
||||
|
||||
class MockDataProvider:
|
||||
def __init__(self):
|
||||
self.current_prices = {
|
||||
'ETH/USDT': 3150.00,
|
||||
'BTC/USDT': 51150.00
|
||||
}
|
||||
|
||||
class MockOrchestrator:
|
||||
def __init__(self):
|
||||
self.data_provider = MockDataProvider()
|
||||
# Add other mock attributes as needed
|
||||
|
||||
return MockOrchestrator()
|
||||
|
||||
|
||||
def show_integration_instructions():
|
||||
"""Show step-by-step integration instructions"""
|
||||
|
||||
print("\n" + "=" * 80)
|
||||
print("INTEGRATION INSTRUCTIONS FOR YOUR ACTUAL SYSTEM")
|
||||
print("=" * 80)
|
||||
|
||||
print("""
|
||||
To integrate the enhanced reward system with your actual TradingOrchestrator:
|
||||
|
||||
1. ADD IMPORTS to your orchestrator.py:
|
||||
```python
|
||||
from core.enhanced_reward_system_integration import (
|
||||
integrate_enhanced_rewards,
|
||||
add_prediction_to_enhanced_rewards
|
||||
)
|
||||
```
|
||||
|
||||
2. INTEGRATE in TradingOrchestrator.__init__():
|
||||
```python
|
||||
# Add this line in your __init__ method
|
||||
integrate_enhanced_rewards(self, symbols=['ETH/USDT', 'BTC/USDT'])
|
||||
```
|
||||
|
||||
3. START in TradingOrchestrator.run():
|
||||
```python
|
||||
# Add this line in your run() method, after initialization
|
||||
await self.enhanced_reward_system.start_integration()
|
||||
```
|
||||
|
||||
4. ADD PREDICTIONS in your model inference code:
|
||||
```python
|
||||
# In your CNN/DQN/COB model inference methods, add:
|
||||
prediction_id = add_prediction_to_enhanced_rewards(
|
||||
self, # orchestrator instance
|
||||
symbol, # e.g., 'ETH/USDT'
|
||||
timeframe, # e.g., '1s', '1m', '1h', '1d'
|
||||
predicted_price, # model's price prediction
|
||||
direction, # -1 (down), 0 (neutral), 1 (up)
|
||||
confidence, # 0.0 to 1.0
|
||||
current_price, # current market price
|
||||
model_name # e.g., 'enhanced_cnn', 'dqn_agent'
|
||||
)
|
||||
```
|
||||
|
||||
5. MONITOR with:
|
||||
```python
|
||||
# Get statistics anytime
|
||||
stats = self.enhanced_reward_system.get_integration_statistics()
|
||||
accuracy = self.enhanced_reward_system.get_model_accuracy()
|
||||
```
|
||||
|
||||
The system will automatically:
|
||||
- Track predictions for multiple timeframes separately
|
||||
- Calculate MSE-based rewards when outcomes are available
|
||||
- Trigger real-time training with enhanced rewards
|
||||
- Maintain accuracy statistics for each model and timeframe
|
||||
- Handle hourly multi-timeframe inference scheduling
|
||||
|
||||
Key Benefits:
|
||||
✅ MSE-based accuracy measurement (better than simple directional accuracy)
|
||||
✅ Separate tracking for up to 6 last predictions per timeframe
|
||||
✅ Real-time training at each inference when outcomes available
|
||||
✅ Multi-timeframe prediction support (1s, 1m, 1h, 1d)
|
||||
✅ Hourly inference on all timeframes (4 predictions per hour)
|
||||
✅ Models know which timeframe they're predicting on
|
||||
✅ Backward compatible with existing code
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the demonstration
|
||||
asyncio.run(demonstrate_enhanced_reward_integration())
|
||||
|
||||
# Show integration instructions
|
||||
show_integration_instructions()
|
||||
|
@@ -1,74 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Force refresh dashboard model states to show correct DQN status
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import requests
|
||||
import time
|
||||
sys.path.append('.')
|
||||
|
||||
def force_refresh_dashboard():
|
||||
"""Force refresh the dashboard to show correct model states"""
|
||||
print("=== Forcing Dashboard Refresh ===")
|
||||
|
||||
# Try to hit the dashboard endpoint to force a refresh
|
||||
dashboard_url = "http://localhost:8050"
|
||||
|
||||
try:
|
||||
print(f"Attempting to refresh dashboard at {dashboard_url}...")
|
||||
|
||||
# Try to access the main dashboard page
|
||||
response = requests.get(dashboard_url, timeout=5)
|
||||
if response.status_code == 200:
|
||||
print("✅ Dashboard is accessible and responding")
|
||||
else:
|
||||
print(f"⚠️ Dashboard responded with status code: {response.status_code}")
|
||||
|
||||
# Try to access the model states API endpoint if it exists
|
||||
try:
|
||||
api_response = requests.get(f"{dashboard_url}/api/model-states", timeout=5)
|
||||
if api_response.status_code == 200:
|
||||
print("✅ Model states API is accessible")
|
||||
else:
|
||||
print(f"⚠️ Model states API responded with: {api_response.status_code}")
|
||||
except:
|
||||
print("ℹ️ No model states API endpoint found (this is normal)")
|
||||
|
||||
except requests.exceptions.ConnectionError:
|
||||
print("❌ Dashboard is not running or not accessible")
|
||||
print("Please start the dashboard with: python run_clean_dashboard.py")
|
||||
except Exception as e:
|
||||
print(f"❌ Error accessing dashboard: {e}")
|
||||
|
||||
# Also verify the model states are correct
|
||||
print("\n=== Verifying Model States ===")
|
||||
try:
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
dp = DataProvider()
|
||||
orch = TradingOrchestrator(data_provider=dp)
|
||||
|
||||
states = orch.get_model_states()
|
||||
if states and 'dqn' in states:
|
||||
dqn_state = states['dqn']
|
||||
checkpoint_loaded = dqn_state.get('checkpoint_loaded', False)
|
||||
checkpoint_filename = dqn_state.get('checkpoint_filename', 'None')
|
||||
|
||||
print(f"✅ DQN checkpoint_loaded: {checkpoint_loaded}")
|
||||
print(f"✅ DQN checkpoint_filename: {checkpoint_filename}")
|
||||
|
||||
if checkpoint_loaded:
|
||||
print("🎯 DQN should show as ACTIVE with [CKPT], not FRESH")
|
||||
else:
|
||||
print("⚠️ DQN checkpoint not loaded - will show as FRESH")
|
||||
else:
|
||||
print("❌ No DQN state found")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error verifying model states: {e}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
force_refresh_dashboard()
|
@@ -162,8 +162,8 @@ training:
|
||||
|
||||
# RL specific training
|
||||
rl_training_interval: 300 # Train RL every 5 minutes (was 1 hour)
|
||||
min_experiences: 50 # Reduced from 100 for faster learning
|
||||
training_steps_per_cycle: 20 # Increased from 10 for more learning
|
||||
min_experiences: 16 # Lowered to trigger replay sooner in cold-start
|
||||
training_steps_per_cycle: 32 # More steps per cycle to use GPU effectively
|
||||
|
||||
model_type: "optimized_short_term"
|
||||
use_realtime: true
|
||||
|
@@ -1756,32 +1756,25 @@ class CleanTradingDashboard:
|
||||
# Original training metrics callback - temporarily disabled for testing
|
||||
# @self.app.callback(
|
||||
# Output('training-metrics', 'children'),
|
||||
# Lightweight, cached training metrics panel
|
||||
self._training_panel_cache = {"content": None, "ts": 0.0}
|
||||
self._training_panel_ttl = 3.0 # seconds
|
||||
|
||||
@self.app.callback(
|
||||
Output('training-metrics', 'children'),
|
||||
[Input('slow-interval-component', 'n_intervals'),
|
||||
Input('fast-interval-component', 'n_intervals'), # Add fast interval for testing
|
||||
Input('refresh-training-metrics-btn', 'n_clicks')] # Add manual refresh button
|
||||
Input('fast-interval-component', 'n_intervals'),
|
||||
Input('refresh-training-metrics-btn', 'n_clicks')]
|
||||
)
|
||||
def update_training_metrics(slow_intervals, fast_intervals, n_clicks):
|
||||
"""Update training metrics using new clean panel implementation"""
|
||||
logger.info(f"update_training_metrics callback triggered with slow_intervals={slow_intervals}, fast_intervals={fast_intervals}, n_clicks={n_clicks}")
|
||||
try:
|
||||
# Import compact training panel
|
||||
now_ts = time.time()
|
||||
# Serve cached panel if fresh
|
||||
if self._training_panel_cache["content"] is not None and (now_ts - self._training_panel_cache["ts"]) < self._training_panel_ttl:
|
||||
raise PreventUpdate
|
||||
|
||||
from web.models_training_panel import ModelsTrainingPanel
|
||||
|
||||
# Create panel instance with orchestrator
|
||||
panel = ModelsTrainingPanel(orchestrator=self.orchestrator)
|
||||
|
||||
# Ensure enhanced training system is initialized and running
|
||||
try:
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'initialize_enhanced_training_system'):
|
||||
self.orchestrator.initialize_enhanced_training_system()
|
||||
if self.orchestrator and hasattr(self.orchestrator, 'start_enhanced_training'):
|
||||
self.orchestrator.start_enhanced_training()
|
||||
except Exception as _ets_ex:
|
||||
logger.warning(f"TRAINING: Failed to start orchestrator enhanced training system: {_ets_ex}")
|
||||
|
||||
# Prefer create_panel if available; fallback to render
|
||||
if hasattr(panel, 'create_panel'):
|
||||
panel_content = panel.create_panel()
|
||||
elif hasattr(panel, 'render'):
|
||||
@@ -1789,20 +1782,14 @@ class CleanTradingDashboard:
|
||||
else:
|
||||
panel_content = [html.Div("Training panel not available", className="text-muted small")]
|
||||
|
||||
logger.info("Successfully created training metrics panel")
|
||||
self._training_panel_cache["content"] = panel_content
|
||||
self._training_panel_cache["ts"] = now_ts
|
||||
return panel_content
|
||||
|
||||
except PreventUpdate:
|
||||
logger.info("PreventUpdate raised in training metrics callback")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating training metrics with new panel: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return [
|
||||
html.P("Error loading training panel", className="text-danger small"),
|
||||
html.P(f"Details: {str(e)}", className="text-muted small")
|
||||
]
|
||||
logger.debug(f"Error updating training metrics: {e}")
|
||||
return self._training_panel_cache.get("content") or [html.Div("Training panel unavailable", className="text-muted small")]
|
||||
|
||||
# Universal model toggle callback using pattern matching
|
||||
@self.app.callback(
|
||||
@@ -8149,12 +8136,12 @@ class CleanTradingDashboard:
|
||||
'price_at_prediction': self._get_current_price(symbol)
|
||||
}
|
||||
|
||||
# Sleep for 10 seconds (0.1Hz prediction rate for cold start)
|
||||
time.sleep(10.0)
|
||||
# Sleep for 2 seconds to improve GPU utilization and responsiveness
|
||||
time.sleep(2.0)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in CNN prediction worker: {e}")
|
||||
time.sleep(10.0) # Wait same interval on error
|
||||
time.sleep(2.0) # Wait same interval on error
|
||||
|
||||
# Start the worker thread
|
||||
import threading
|
||||
|
Reference in New Issue
Block a user