test cases
core/trade_data_manager.py (new file, 546 lines)
@@ -0,0 +1,546 @@
#!/usr/bin/env python3
"""
Trade Data Manager - Centralized trade data capture and training case management

Handles:
- Comprehensive model input capture during trade execution
- Storage in the testcases structure (positive/negative)
- Case indexing and management
- Integration with the existing negative case trainer
- Cold start training data preparation
"""

import os
import json
import pickle
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
import numpy as np

logger = logging.getLogger(__name__)


class TradeDataManager:
    """Centralized manager for trade data capture and training case storage"""

    def __init__(self, base_dir: str = "testcases"):
        self.base_dir = base_dir
        self.cases_cache = {}  # In-memory cache of recent case summaries
        self.max_cache_size = 100

        # Initialize directory structure
        self._setup_directory_structure()

        logger.info(f"TradeDataManager initialized with base directory: {base_dir}")

    def _setup_directory_structure(self):
        """Set up the testcases directory structure"""
        try:
            for case_type in ['positive', 'negative']:
                for subdir in ['cases', 'sessions', 'models']:
                    dir_path = os.path.join(self.base_dir, case_type, subdir)
                    os.makedirs(dir_path, exist_ok=True)

            logger.debug("Directory structure setup complete")

        except Exception as e:
            logger.error(f"Error setting up directory structure: {e}")

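    # Resulting on-disk layout (sketch):
    #   testcases/
    #     positive/{cases, sessions, models}/
    #     negative/{cases, sessions, models}/
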
    def capture_comprehensive_model_inputs(self, symbol: str, action: str, current_price: float,
                                           orchestrator=None, data_provider=None) -> Dict[str, Any]:
        """Capture comprehensive model inputs for cold start training"""
        try:
            logger.info(f"Capturing model inputs for {action} trade on {symbol} at ${current_price:.2f}")

            model_inputs = {
                'timestamp': datetime.now().isoformat(),
                'symbol': symbol,
                'action': action,
                'price': current_price,
                'capture_type': 'trade_execution'
            }

            # 1. Market state features
            try:
                market_state = self._get_comprehensive_market_state(symbol, current_price, data_provider)
                model_inputs['market_state'] = market_state
                logger.debug(f"Captured market state: {len(market_state)} features")
            except Exception as e:
                logger.warning(f"Error capturing market state: {e}")
                model_inputs['market_state'] = {}

            # 2. CNN features and predictions
            try:
                cnn_data = self._get_cnn_features_and_predictions(symbol, orchestrator)
                model_inputs['cnn_features'] = cnn_data.get('features', {})
                model_inputs['cnn_predictions'] = cnn_data.get('predictions', {})
                logger.debug(f"Captured CNN data: {len(cnn_data)} items")
            except Exception as e:
                logger.warning(f"Error capturing CNN data: {e}")
                model_inputs['cnn_features'] = {}
                model_inputs['cnn_predictions'] = {}

            # 3. DQN/RL state features
            try:
                dqn_state = self._get_dqn_state_features(symbol, current_price, orchestrator)
                model_inputs['dqn_state'] = dqn_state
                logger.debug(f"Captured DQN state: {len(dqn_state) if dqn_state else 0} features")
            except Exception as e:
                logger.warning(f"Error capturing DQN state: {e}")
                model_inputs['dqn_state'] = {}

            # 4. COB (order book) features
            try:
                cob_data = self._get_cob_features_for_training(symbol, orchestrator)
                model_inputs['cob_features'] = cob_data
                logger.debug(f"Captured COB features: {len(cob_data) if cob_data else 0} features")
            except Exception as e:
                logger.warning(f"Error capturing COB features: {e}")
                model_inputs['cob_features'] = {}

            # 5. Technical indicators
            try:
                technical_indicators = self._get_technical_indicators(symbol, data_provider)
                model_inputs['technical_indicators'] = technical_indicators
                logger.debug(f"Captured technical indicators: {len(technical_indicators)} indicators")
            except Exception as e:
                logger.warning(f"Error capturing technical indicators: {e}")
                model_inputs['technical_indicators'] = {}

            # 6. Recent price history (for context)
            try:
                price_history = self._get_recent_price_history(symbol, data_provider, periods=50)
                model_inputs['price_history'] = price_history
                logger.debug(f"Captured price history: {len(price_history)} periods")
            except Exception as e:
                logger.warning(f"Error capturing price history: {e}")
                model_inputs['price_history'] = []

            total_features = sum(len(v) if isinstance(v, (dict, list)) else 1 for v in model_inputs.values())
            logger.info(f"✅ Captured {total_features} total features for cold start training")

            return model_inputs

        except Exception as e:
            logger.error(f"Error capturing model inputs: {e}")
            return {
                'timestamp': datetime.now().isoformat(),
                'symbol': symbol,
                'action': action,
                'price': current_price,
                'error': str(e)
            }

    def store_trade_for_training(self, trade_record: Dict[str, Any]) -> Optional[str]:
        """Store a trade for future cold start training in the testcases structure"""
        try:
            # Classify as a positive or negative case based on realized P&L
            pnl = trade_record.get('pnl', 0)
            case_type = "positive" if pnl >= 0 else "negative"

            # Resolve the testcases directory paths
            case_dir = os.path.join(self.base_dir, case_type)
            cases_dir = os.path.join(case_dir, "cases")

            # Create a unique, filesystem-safe case ID
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            symbol_clean = trade_record['symbol'].replace('/', '')
            case_id = f"{case_type}_{timestamp}_{symbol_clean}_pnl_{pnl:.4f}".replace('.', 'p').replace('-', 'neg')

            # Store comprehensive case data as pickle (for complex model inputs)
            case_filepath = os.path.join(cases_dir, f"{case_id}.pkl")
            with open(case_filepath, 'wb') as f:
                pickle.dump(trade_record, f)

            # Store a JSON summary for easy viewing
            entry_time = trade_record.get('entry_time', datetime.now())
            json_filepath = os.path.join(cases_dir, f"{case_id}.json")
            json_summary = {
                'case_id': case_id,
                'timestamp': entry_time.isoformat() if hasattr(entry_time, 'isoformat') else str(entry_time),
                'symbol': trade_record['symbol'],
                'side': trade_record['side'],
                'entry_price': trade_record['entry_price'],
                'pnl': pnl,
                'confidence': trade_record.get('confidence', 0),
                'trade_type': trade_record.get('trade_type', 'unknown'),
                'model_inputs_captured': bool(trade_record.get('model_inputs_at_entry')),
                'training_ready': trade_record.get('training_ready', False),
                'feature_counts': {
                    'market_state': len(trade_record.get('entry_market_state', {})),
                    'cnn_features': len(trade_record.get('model_inputs_at_entry', {}).get('cnn_features', {})),
                    'dqn_state': len(trade_record.get('model_inputs_at_entry', {}).get('dqn_state', {})),
                    'cob_features': len(trade_record.get('model_inputs_at_entry', {}).get('cob_features', {})),
                    'technical_indicators': len(trade_record.get('model_inputs_at_entry', {}).get('technical_indicators', {})),
                    'price_history': len(trade_record.get('model_inputs_at_entry', {}).get('price_history', []))
                }
            }

            with open(json_filepath, 'w') as f:
                json.dump(json_summary, f, indent=2, default=str)

            # Update the case index
            self._update_case_index(case_dir, case_id, json_summary, case_type)

            # Add to the cache, evicting the oldest entry when full
            self.cases_cache[case_id] = json_summary
            if len(self.cases_cache) > self.max_cache_size:
                oldest_key = next(iter(self.cases_cache))
                del self.cases_cache[oldest_key]

            logger.info(f"✅ Stored {case_type} case for training: {case_id}")
            logger.info(f"   PKL: {case_filepath}")
            logger.info(f"   JSON: {json_filepath}")

            return case_id

        except Exception as e:
            logger.error(f"Error storing trade for training: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None

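    # Worked example (illustrative values): a losing ETH/USDT trade stored at
    # 2024-01-15 10:30:00 with pnl=-0.0123 produces
    #   "negative_20240115_103000_ETHUSDT_pnl_-0.0123"
    # which the '.'->'p' and '-'->'neg' replacements turn into the filesystem-safe
    #   "negative_20240115_103000_ETHUSDT_pnl_neg0p0123".
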
    def _update_case_index(self, case_dir: str, case_id: str, case_summary: Dict[str, Any], case_type: str):
        """Update the case index file"""
        try:
            index_file = os.path.join(case_dir, "case_index.json")

            # Load the existing index or create a new one
            if os.path.exists(index_file):
                with open(index_file, 'r') as f:
                    index_data = json.load(f)
            else:
                index_data = {"cases": [], "last_updated": None}

            # Add the new case
            index_entry = {
                "case_id": case_id,
                "timestamp": case_summary['timestamp'],
                "symbol": case_summary['symbol'],
                "pnl": case_summary['pnl'],
                "training_priority": self._calculate_training_priority(case_summary, case_type),
                "retraining_count": 0,
                "feature_counts": case_summary['feature_counts']
            }

            index_data["cases"].append(index_entry)
            index_data["last_updated"] = datetime.now().isoformat()

            # Save the updated index
            with open(index_file, 'w') as f:
                json.dump(index_data, f, indent=2)

            logger.debug(f"Updated case index: {case_id}")

        except Exception as e:
            logger.error(f"Error updating case index: {e}")

    def _calculate_training_priority(self, case_summary: Dict[str, Any], case_type: str) -> int:
        """Calculate training priority based on case characteristics"""
        try:
            pnl = abs(case_summary.get('pnl', 0))
            confidence = case_summary.get('confidence', 0)

            # Higher priority for larger losses/gains and for high-confidence wrong predictions
            if case_type == "negative":
                # Larger losses get higher priority, especially with high confidence
                priority = min(5, int(pnl * 10) + int(confidence * 2))
            else:
                # Profits get medium priority unless very large
                priority = min(3, int(pnl * 5) + 1)

            return max(1, priority)  # Minimum priority of 1

        except Exception:
            return 1  # Default priority

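    # Worked example (hypothetical values): a negative case with pnl=-0.35 and
    # confidence=0.8 scores min(5, int(0.35 * 10) + int(0.8 * 2)) = min(5, 3 + 1) = 4,
    # so confident losses are replayed before marginal ones.
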
    def get_training_cases(self, case_type: str = "negative", limit: int = 50) -> List[Dict[str, Any]]:
        """Get training cases for model training"""
        try:
            case_dir = os.path.join(self.base_dir, case_type)
            index_file = os.path.join(case_dir, "case_index.json")

            if not os.path.exists(index_file):
                return []

            with open(index_file, 'r') as f:
                index_data = json.load(f)

            # Sort by training priority (highest first) and apply the limit
            cases = sorted(index_data["cases"],
                           key=lambda x: x.get("training_priority", 1),
                           reverse=True)[:limit]

            return cases

        except Exception as e:
            logger.error(f"Error getting training cases: {e}")
            return []

    def load_case_data(self, case_id: str, case_type: str = None) -> Optional[Dict[str, Any]]:
        """Load full case data from the pickle file"""
        try:
            # Infer the case type from the ID prefix if not provided
            if case_type is None:
                case_type = "positive" if "positive" in case_id else "negative"

            case_filepath = os.path.join(self.base_dir, case_type, "cases", f"{case_id}.pkl")

            if not os.path.exists(case_filepath):
                logger.warning(f"Case file not found: {case_filepath}")
                return None

            with open(case_filepath, 'rb') as f:
                case_data = pickle.load(f)

            return case_data

        except Exception as e:
            logger.error(f"Error loading case data for {case_id}: {e}")
            return None

    def cleanup_old_cases(self, days_to_keep: int = 30):
        """Clean up old test cases to manage storage"""
        try:
            from datetime import timedelta
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            for case_type in ['positive', 'negative']:
                case_dir = os.path.join(self.base_dir, case_type)
                cases_dir = os.path.join(case_dir, "cases")

                if not os.path.exists(cases_dir):
                    continue

                # Load the case index
                index_file = os.path.join(case_dir, "case_index.json")
                if os.path.exists(index_file):
                    with open(index_file, 'r') as f:
                        index_data = json.load(f)

                    # Keep recent cases; remove files for expired ones
                    cases_to_keep = []
                    cases_removed = 0

                    for case in index_data["cases"]:
                        case_date = datetime.fromisoformat(case["timestamp"])
                        if case_date > cutoff_date:
                            cases_to_keep.append(case)
                        else:
                            # Remove the case files
                            case_id = case["case_id"]
                            pkl_file = os.path.join(cases_dir, f"{case_id}.pkl")
                            json_file = os.path.join(cases_dir, f"{case_id}.json")

                            for file_path in [pkl_file, json_file]:
                                if os.path.exists(file_path):
                                    os.remove(file_path)

                            cases_removed += 1

                    # Update the index
                    index_data["cases"] = cases_to_keep
                    index_data["last_updated"] = datetime.now().isoformat()

                    with open(index_file, 'w') as f:
                        json.dump(index_data, f, indent=2)

                    if cases_removed > 0:
                        logger.info(f"Cleaned up {cases_removed} old {case_type} cases")

        except Exception as e:
            logger.error(f"Error cleaning up old cases: {e}")

    # Helper methods for feature extraction

    def _get_comprehensive_market_state(self, symbol: str, current_price: float, data_provider) -> Dict[str, float]:
        """Get comprehensive market state features"""
        try:
            if not data_provider:
                return {'current_price': current_price}

            market_state = {'current_price': current_price}

            # Get historical data for features
            df = data_provider.get_historical_data(symbol, '1m', limit=100)
            if df is not None and not df.empty:
                prices = df['close'].values
                volumes = df['volume'].values

                # Price features
                market_state['price_sma_5'] = float(prices[-5:].mean())
                market_state['price_sma_20'] = float(prices[-20:].mean())
                market_state['price_std_20'] = float(prices[-20:].std())
                market_state['price_rsi'] = self._calculate_rsi(prices, 14)

                # Volume features
                market_state['volume_current'] = float(volumes[-1])
                market_state['volume_sma_20'] = float(volumes[-20:].mean())
                market_state['volume_ratio'] = float(volumes[-1] / volumes[-20:].mean())

                # Trend features
                market_state['price_momentum_5'] = float((prices[-1] - prices[-5]) / prices[-5])
                market_state['price_momentum_20'] = float((prices[-1] - prices[-20]) / prices[-20])

            # Timestamp features (always available, even without historical data)
            now = datetime.now()
            market_state['hour_of_day'] = now.hour
            market_state['minute_of_hour'] = now.minute
            market_state['day_of_week'] = now.weekday()

            return market_state

        except Exception as e:
            logger.warning(f"Error getting market state: {e}")
            return {'current_price': current_price}

    def _calculate_rsi(self, prices, period=14):
        """Calculate the RSI indicator over the trailing window"""
        try:
            deltas = np.diff(prices)
            gains = np.where(deltas > 0, deltas, 0)
            losses = np.where(deltas < 0, -deltas, 0)

            avg_gain = np.mean(gains[-period:])
            avg_loss = np.mean(losses[-period:])

            if avg_loss == 0:
                return 100.0

            rs = avg_gain / avg_loss
            rsi = 100 - (100 / (1 + rs))
            return float(rsi)
        except Exception:
            return 50.0  # Neutral RSI

    def _get_cnn_features_and_predictions(self, symbol: str, orchestrator) -> Dict[str, Any]:
        """Get CNN features and predictions from the orchestrator"""
        try:
            if not orchestrator:
                return {}

            cnn_data = {}

            # Get CNN features if available
            if hasattr(orchestrator, 'latest_cnn_features'):
                cnn_features = getattr(orchestrator, 'latest_cnn_features', {}).get(symbol)
                if cnn_features is not None:
                    cnn_data['features'] = cnn_features.tolist() if hasattr(cnn_features, 'tolist') else cnn_features

            # Get CNN predictions if available
            if hasattr(orchestrator, 'latest_cnn_predictions'):
                cnn_predictions = getattr(orchestrator, 'latest_cnn_predictions', {}).get(symbol)
                if cnn_predictions is not None:
                    cnn_data['predictions'] = cnn_predictions.tolist() if hasattr(cnn_predictions, 'tolist') else cnn_predictions

            return cnn_data

        except Exception as e:
            logger.debug(f"Error getting CNN data: {e}")
            return {}

    def _get_dqn_state_features(self, symbol: str, current_price: float, orchestrator) -> Dict[str, Any]:
        """Get DQN state features from the orchestrator"""
        try:
            if not orchestrator:
                return {}

            # Build the comprehensive RL state if the orchestrator supports it
            if hasattr(orchestrator, 'build_comprehensive_rl_state'):
                rl_state = orchestrator.build_comprehensive_rl_state(symbol)
                if rl_state is not None:
                    return {
                        'state_vector': rl_state.tolist() if hasattr(rl_state, 'tolist') else rl_state,
                        'state_size': len(rl_state) if hasattr(rl_state, '__len__') else 0
                    }

            return {}

        except Exception as e:
            logger.debug(f"Error getting DQN state: {e}")
            return {}

    def _get_cob_features_for_training(self, symbol: str, orchestrator) -> Dict[str, Any]:
        """Get COB features for training"""
        try:
            if not orchestrator:
                return {}

            cob_data = {}

            # Get COB features from the orchestrator
            if hasattr(orchestrator, 'latest_cob_features'):
                cob_features = getattr(orchestrator, 'latest_cob_features', {}).get(symbol)
                if cob_features is not None:
                    cob_data['features'] = cob_features.tolist() if hasattr(cob_features, 'tolist') else cob_features

            # Get a COB snapshot
            if hasattr(orchestrator, 'cob_integration') and orchestrator.cob_integration:
                if hasattr(orchestrator.cob_integration, 'get_cob_snapshot'):
                    cob_snapshot = orchestrator.cob_integration.get_cob_snapshot(symbol)
                    if cob_snapshot:
                        cob_data['snapshot_available'] = True
                        cob_data['bid_levels'] = len(getattr(cob_snapshot, 'consolidated_bids', []))
                        cob_data['ask_levels'] = len(getattr(cob_snapshot, 'consolidated_asks', []))
                    else:
                        cob_data['snapshot_available'] = False

            return cob_data

        except Exception as e:
            logger.debug(f"Error getting COB features: {e}")
            return {}

    def _get_technical_indicators(self, symbol: str, data_provider) -> Dict[str, float]:
        """Get technical indicators"""
        try:
            if not data_provider:
                return {}

            indicators = {}

            # Get recent price data
            df = data_provider.get_historical_data(symbol, '1m', limit=50)
            if df is not None and not df.empty:
                closes = df['close'].values
                highs = df['high'].values
                lows = df['low'].values
                volumes = df['volume'].values

                # Moving averages
                indicators['sma_10'] = float(closes[-10:].mean())
                indicators['sma_20'] = float(closes[-20:].mean())

                # Bollinger Bands (guard against zero band width on flat prices)
                sma_20 = closes[-20:].mean()
                std_20 = closes[-20:].std()
                indicators['bb_upper'] = float(sma_20 + 2 * std_20)
                indicators['bb_lower'] = float(sma_20 - 2 * std_20)
                band_width = indicators['bb_upper'] - indicators['bb_lower']
                indicators['bb_position'] = float((closes[-1] - indicators['bb_lower']) / band_width) if band_width > 0 else 0.5

                # MACD (simplified: simple means stand in for the 12/26-period EMAs)
                ema_12 = closes[-12:].mean()
                ema_26 = closes[-26:].mean()
                indicators['macd'] = float(ema_12 - ema_26)

                # Volatility
                indicators['volatility'] = float(std_20 / sma_20)

            return indicators

        except Exception as e:
            logger.debug(f"Error calculating technical indicators: {e}")
            return {}

    def _get_recent_price_history(self, symbol: str, data_provider, periods: int = 50) -> List[float]:
        """Get recent price history"""
        try:
            if not data_provider:
                return []

            df = data_provider.get_historical_data(symbol, '1m', limit=periods)
            if df is not None and not df.empty:
                return df['close'].tolist()
            return []
        except Exception as e:
            logger.debug(f"Error getting price history: {e}")
            return []
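A minimal usage sketch of the capture/store/retrieve flow above. The variable names are illustrative, and the trade_record keys mirror exactly what store_trade_for_training reads; any data_provider passed in is assumed to expose get_historical_data(symbol, timeframe, limit=...) as the helpers expect:

from core.trade_data_manager import TradeDataManager

manager = TradeDataManager(base_dir="testcases")

# Capture model inputs at entry (orchestrator/data_provider are optional).
inputs = manager.capture_comprehensive_model_inputs(
    'ETH/USDT', 'BUY', 3050.25, orchestrator=None, data_provider=None)

# After the trade closes, store it with its realized P&L.
trade_record = {
    'symbol': 'ETH/USDT',
    'side': 'BUY',
    'entry_price': 3050.25,
    'pnl': -0.0123,                 # negative -> stored under testcases/negative
    'confidence': 0.8,
    'model_inputs_at_entry': inputs,
    'training_ready': True,
}
case_id = manager.store_trade_for_training(trade_record)

# Highest-priority cases first, plus the full pickled record for one case.
cases = manager.get_training_cases(case_type="negative", limit=10)
full_record = manager.load_case_data(case_id)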

core/training_integration.py (new file, 304 lines)
@@ -0,0 +1,304 @@
#!/usr/bin/env python3
"""
Training Integration - Handles cold start training and model learning integration

Manages:
- Cold start training triggers from trade outcomes
- Reward calculation based on P&L
- Integration with DQN, CNN, and COB RL models
- Training session management
"""

import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
import numpy as np

logger = logging.getLogger(__name__)


class TrainingIntegration:
    """Manages training integration for cold start learning"""

    def __init__(self, orchestrator=None):
        self.orchestrator = orchestrator
        self.training_sessions = {}
        self.min_confidence_threshold = 0.3

        logger.info("TrainingIntegration initialized")

    def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: str = None) -> bool:
        """Trigger cold start training when a trade closes with a known outcome"""
        try:
            if not trade_record.get('model_inputs_at_entry'):
                logger.warning("No model inputs captured for training - skipping")
                return False

            pnl = trade_record.get('pnl', 0)
            confidence = trade_record.get('confidence', 0)

            logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")

            # Calculate the training reward based on P&L and confidence
            reward = self._calculate_training_reward(pnl, confidence)

            # Train each model on the trade outcome
            dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)
            cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)     # placeholder for now
            cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)  # placeholder for now

            # Log the training results
            training_success = any([dqn_success, cnn_success, cob_success])
            if training_success:
                logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
            else:
                logger.warning("Cold start training failed for all models")

            return training_success

        except Exception as e:
            logger.error(f"Error in cold start training: {e}")
            return False

    def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
        """Calculate the training reward based on P&L and confidence"""
        try:
            # Base reward is proportional to P&L
            base_reward = pnl

            # Adjust for confidence - penalize high-confidence wrong predictions more
            if pnl < 0 and confidence > 0.7:
                # High-confidence loss - significant negative adjustment
                confidence_adjustment = -confidence * 2
            elif pnl > 0 and confidence > 0.7:
                # High-confidence gain - boost the reward
                confidence_adjustment = confidence * 1.5
            else:
                # Low confidence - no adjustment
                confidence_adjustment = 0

            final_reward = base_reward + confidence_adjustment

            # Squash into (-1, 1) with tanh for training stability
            normalized_reward = np.tanh(final_reward / 10.0)

            logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")

            return float(normalized_reward)

        except Exception as e:
            logger.error(f"Error calculating training reward: {e}")
            return 0.0

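    # Worked example (hypothetical values): a losing trade with pnl=-2.0 at
    # confidence=0.8 gets base_reward=-2.0 and adjustment=-1.6, so
    # reward = tanh(-3.6 / 10) ≈ -0.345; the same loss at confidence=0.5
    # yields tanh(-0.2) ≈ -0.197, so confident mistakes are punished harder.
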
    def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train the DQN agent on a trade outcome"""
        try:
            if not self.orchestrator:
                logger.warning("No orchestrator available for DQN training")
                return False

            # Get the DQN agent
            if not hasattr(self.orchestrator, 'dqn_agent') or not self.orchestrator.dqn_agent:
                logger.warning("DQN agent not available for training")
                return False

            # Extract the DQN state from the captured model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            dqn_state = model_inputs.get('dqn_state', {}).get('state_vector')

            if not dqn_state:
                logger.warning("No DQN state available for training")
                return False

            # Convert the trade side to a DQN action index
            action = trade_record.get('side', 'HOLD').upper()
            action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
            action_idx = action_map.get(action, 2)

            # Next state is a placeholder - ideally the market state after the trade
            next_state = dqn_state

            # Store the experience in DQN memory
            dqn_agent = self.orchestrator.dqn_agent
            if hasattr(dqn_agent, 'store_experience'):
                dqn_agent.store_experience(
                    state=np.array(dqn_state),
                    action=action_idx,
                    reward=reward,
                    next_state=np.array(next_state),
                    done=True  # Trade is complete
                )

                # Trigger a training step once enough experiences have accumulated
                if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
                    dqn_agent.replay(batch_size=32)
                    logger.info("DQN training step completed")

                return True
            else:
                logger.warning("DQN agent doesn't support experience storage")
                return False

        except Exception as e:
            logger.error(f"Error training DQN on trade outcome: {e}")
            return False

    def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train the CNN on a trade outcome (placeholder)"""
        try:
            if not self.orchestrator:
                return False

            # Check whether the CNN is available
            if not hasattr(self.orchestrator, 'williams_cnn') or not self.orchestrator.williams_cnn:
                logger.debug("CNN not available for training")
                return False

            # Get the CNN features from the captured model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cnn_features = model_inputs.get('cnn_features')
            cnn_predictions = model_inputs.get('cnn_predictions')

            if not cnn_features or not cnn_predictions:
                logger.debug("No CNN features available for training")
                return False

            # Actual CNN training would go here; it needs a model-specific implementation.
            # For now, just log that a training opportunity exists.
            logger.debug(f"CNN training opportunity: features={len(cnn_features)}, predictions={len(cnn_predictions)}")

            return True

        except Exception as e:
            logger.debug(f"Error in CNN training: {e}")
            return False

    def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
        """Train COB RL on a trade outcome (placeholder)"""
        try:
            if not self.orchestrator:
                return False

            # Check whether COB integration is available
            if not hasattr(self.orchestrator, 'cob_integration') or not self.orchestrator.cob_integration:
                logger.debug("COB integration not available for training")
                return False

            # Get the COB features from the captured model inputs
            model_inputs = trade_record.get('model_inputs_at_entry', {})
            cob_features = model_inputs.get('cob_features')

            if not cob_features:
                logger.debug("No COB features available for training")
                return False

            # Actual COB RL training would go here; it needs a model-specific implementation.
            # For now, just log that a training opportunity exists.
            logger.debug(f"COB RL training opportunity: features={len(cob_features)}")

            return True

        except Exception as e:
            logger.debug(f"Error in COB RL training: {e}")
            return False

    def get_training_status(self) -> Dict[str, Any]:
        """Get the current training integration status"""
        try:
            status = {
                'orchestrator_available': self.orchestrator is not None,
                'training_sessions': len(self.training_sessions),
                'last_update': datetime.now().isoformat()
            }

            if self.orchestrator:
                status['dqn_available'] = hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent is not None
                status['cnn_available'] = hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn is not None
                status['cob_available'] = hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration is not None

            return status

        except Exception as e:
            logger.error(f"Error getting training status: {e}")
            return {'error': str(e)}

    def start_training_session(self, session_name: str, config: Dict[str, Any] = None) -> str:
        """Start a new training session"""
        try:
            session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

            session_data = {
                'session_id': session_id,
                'session_name': session_name,
                'start_time': datetime.now().isoformat(),
                'config': config or {},
                'trades_processed': 0,
                'successful_trainings': 0,
                'failed_trainings': 0
            }

            self.training_sessions[session_id] = session_data

            logger.info(f"Started training session: {session_id}")

            return session_id

        except Exception as e:
            logger.error(f"Error starting training session: {e}")
            return ""

    def end_training_session(self, session_id: str) -> Dict[str, Any]:
        """End a training session and return its summary"""
        try:
            if session_id not in self.training_sessions:
                logger.warning(f"Training session not found: {session_id}")
                return {}

            session_data = self.training_sessions[session_id]
            session_data['end_time'] = datetime.now().isoformat()

            # Calculate the session duration
            start_time = datetime.fromisoformat(session_data['start_time'])
            end_time = datetime.fromisoformat(session_data['end_time'])
            duration = (end_time - start_time).total_seconds()
            session_data['duration_seconds'] = duration

            # Calculate the success rate
            total_attempts = session_data['successful_trainings'] + session_data['failed_trainings']
            session_data['success_rate'] = session_data['successful_trainings'] / total_attempts if total_attempts > 0 else 0

            logger.info(f"Ended training session: {session_id}")
            logger.info(f"  Duration: {duration:.1f}s")
            logger.info(f"  Trades processed: {session_data['trades_processed']}")
            logger.info(f"  Success rate: {session_data['success_rate']:.2%}")

            # Remove from active sessions
            completed_session = self.training_sessions.pop(session_id)

            return completed_session

        except Exception as e:
            logger.error(f"Error ending training session: {e}")
            return {}

    def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
        """Update training session statistics"""
        try:
            if session_id not in self.training_sessions:
                return

            session = self.training_sessions[session_id]

            if trade_processed:
                session['trades_processed'] += 1

            if training_success:
                session['successful_trainings'] += 1
            else:
                session['failed_trainings'] += 1

        except Exception as e:
            logger.error(f"Error updating session stats: {e}")
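A minimal end-to-end sketch of the session flow, assuming an orchestrator whose dqn_agent exposes store_experience(...) and replay(batch_size=...) as the DQN path above expects; my_orchestrator and closed_trades are hypothetical stand-ins for the application's objects:

from core.training_integration import TrainingIntegration

integration = TrainingIntegration(orchestrator=my_orchestrator)
session_id = integration.start_training_session("cold_start_replay")

for trade_record in closed_trades:  # e.g. records built by TradeDataManager
    ok = integration.trigger_cold_start_training(trade_record)
    integration.update_session_stats(session_id, trade_processed=True, training_success=ok)

summary = integration.end_training_session(session_id)
print(summary['success_rate'], summary['duration_seconds'])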