#!/usr/bin/env python3
"""
Trade Data Manager - Centralized trade data capture and training case management

Handles:
- Comprehensive model input capture during trade execution
- Storage in testcases structure (positive/negative)
- Case indexing and management
- Integration with existing negative case trainer
- Cold start training data preparation
"""

import os
import json
import pickle
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional

import numpy as np

logger = logging.getLogger(__name__)

class TradeDataManager:
    """Centralized manager for trade data capture and training case storage"""

    def __init__(self, base_dir: str = "testcases"):
        self.base_dir = base_dir
        self.cases_cache = {}  # In-memory cache of recent cases
        self.max_cache_size = 100

        # Initialize directory structure
        self._setup_directory_structure()

        logger.info(f"TradeDataManager initialized with base directory: {base_dir}")

    def _setup_directory_structure(self):
        """Setup the testcases directory structure"""
        try:
            for case_type in ['positive', 'negative']:
                for subdir in ['cases', 'sessions', 'models']:
                    dir_path = os.path.join(self.base_dir, case_type, subdir)
                    os.makedirs(dir_path, exist_ok=True)
            logger.debug("Directory structure setup complete")
        except Exception as e:
            logger.error(f"Error setting up directory structure: {e}")
    def capture_comprehensive_model_inputs(self, symbol: str, action: str, current_price: float,
                                           orchestrator=None, data_provider=None) -> Dict[str, Any]:
        """Capture comprehensive model inputs for cold start training"""
        try:
            logger.info(f"Capturing model inputs for {action} trade on {symbol} at ${current_price:.2f}")

            model_inputs = {
                'timestamp': datetime.now().isoformat(),
                'symbol': symbol,
                'action': action,
                'price': current_price,
                'capture_type': 'trade_execution'
            }

            # 1. Market State Features
            try:
                market_state = self._get_comprehensive_market_state(symbol, current_price, data_provider)
                model_inputs['market_state'] = market_state
                logger.debug(f"Captured market state: {len(market_state)} features")
            except Exception as e:
                logger.warning(f"Error capturing market state: {e}")
                model_inputs['market_state'] = {}

            # 2. CNN Features and Predictions
            try:
                cnn_data = self._get_cnn_features_and_predictions(symbol, orchestrator)
                model_inputs['cnn_features'] = cnn_data.get('features', {})
                model_inputs['cnn_predictions'] = cnn_data.get('predictions', {})
                logger.debug(f"Captured CNN data: {len(cnn_data)} items")
            except Exception as e:
                logger.warning(f"Error capturing CNN data: {e}")
                model_inputs['cnn_features'] = {}
                model_inputs['cnn_predictions'] = {}

            # 3. DQN/RL State Features
            try:
                dqn_state = self._get_dqn_state_features(symbol, current_price, orchestrator)
                model_inputs['dqn_state'] = dqn_state
                logger.debug(f"Captured DQN state: {len(dqn_state) if dqn_state else 0} features")
            except Exception as e:
                logger.warning(f"Error capturing DQN state: {e}")
                model_inputs['dqn_state'] = {}

            # 4. COB (Order Book) Features
            try:
                cob_data = self._get_cob_features_for_training(symbol, orchestrator)
                model_inputs['cob_features'] = cob_data
                logger.debug(f"Captured COB features: {len(cob_data) if cob_data else 0} features")
            except Exception as e:
                logger.warning(f"Error capturing COB features: {e}")
                model_inputs['cob_features'] = {}

            # 5. Technical Indicators
            try:
                technical_indicators = self._get_technical_indicators(symbol, data_provider)
                model_inputs['technical_indicators'] = technical_indicators
                logger.debug(f"Captured technical indicators: {len(technical_indicators)} indicators")
            except Exception as e:
                logger.warning(f"Error capturing technical indicators: {e}")
                model_inputs['technical_indicators'] = {}

            # 6. Recent Price History (for context)
            try:
                price_history = self._get_recent_price_history(symbol, data_provider, periods=50)
                model_inputs['price_history'] = price_history
                logger.debug(f"Captured price history: {len(price_history)} periods")
            except Exception as e:
                logger.warning(f"Error capturing price history: {e}")
                model_inputs['price_history'] = []

            total_features = sum(len(v) if isinstance(v, (dict, list)) else 1 for v in model_inputs.values())
            logger.info(f"✅ Captured {total_features} total features for cold start training")

            return model_inputs

        except Exception as e:
            logger.error(f"Error capturing model inputs: {e}")
            return {
                'timestamp': datetime.now().isoformat(),
                'symbol': symbol,
                'action': action,
                'price': current_price,
                'error': str(e)
            }
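
    # Illustrative shape of the returned dict (values hypothetical, for
    # orientation only; the keys are exactly those set in the method above):
    # {
    #     'timestamp': '2025-06-25T14:45:37', 'symbol': 'ETH/USDT',
    #     'action': 'BUY', 'price': 2450.0, 'capture_type': 'trade_execution',
    #     'market_state': {...}, 'cnn_features': {...}, 'cnn_predictions': {...},
    #     'dqn_state': {...}, 'cob_features': {...},
    #     'technical_indicators': {...}, 'price_history': [...]
    # }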
    def store_trade_for_training(self, trade_record: Dict[str, Any]) -> Optional[str]:
        """Store trade for future cold start training in testcases structure"""
        try:
            # Determine if this will be a positive or negative case based on eventual P&L
            pnl = trade_record.get('pnl', 0)
            case_type = "positive" if pnl >= 0 else "negative"

            # Locate the testcases directory structure (created in __init__)
            case_dir = os.path.join(self.base_dir, case_type)
            cases_dir = os.path.join(case_dir, "cases")

            # Create unique case ID
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            symbol_clean = trade_record['symbol'].replace('/', '')
            case_id = f"{case_type}_{timestamp}_{symbol_clean}_pnl_{pnl:.4f}".replace('.', 'p').replace('-', 'neg')

            # Store comprehensive case data as pickle (for complex model inputs)
            case_filepath = os.path.join(cases_dir, f"{case_id}.pkl")
            with open(case_filepath, 'wb') as f:
                pickle.dump(trade_record, f)

            # Store JSON summary for easy viewing
            json_filepath = os.path.join(cases_dir, f"{case_id}.json")
            entry_time = trade_record.get('entry_time', datetime.now())
            json_summary = {
                'case_id': case_id,
                'timestamp': entry_time.isoformat() if hasattr(entry_time, 'isoformat') else str(entry_time),
                'symbol': trade_record['symbol'],
                'side': trade_record['side'],
                'entry_price': trade_record['entry_price'],
                'pnl': pnl,
                'confidence': trade_record.get('confidence', 0),
                'trade_type': trade_record.get('trade_type', 'unknown'),
                'model_inputs_captured': bool(trade_record.get('model_inputs_at_entry')),
                'training_ready': trade_record.get('training_ready', False),
                'feature_counts': {
                    'market_state': len(trade_record.get('entry_market_state', {})),
                    'cnn_features': len(trade_record.get('model_inputs_at_entry', {}).get('cnn_features', {})),
                    'dqn_state': len(trade_record.get('model_inputs_at_entry', {}).get('dqn_state', {})),
                    'cob_features': len(trade_record.get('model_inputs_at_entry', {}).get('cob_features', {})),
                    'technical_indicators': len(trade_record.get('model_inputs_at_entry', {}).get('technical_indicators', {})),
                    'price_history': len(trade_record.get('model_inputs_at_entry', {}).get('price_history', []))
                }
            }
            with open(json_filepath, 'w') as f:
                json.dump(json_summary, f, indent=2, default=str)

            # Update case index
            self._update_case_index(case_dir, case_id, json_summary, case_type)

            # Add to cache, evicting the oldest entry when the cache is full
            self.cases_cache[case_id] = json_summary
            if len(self.cases_cache) > self.max_cache_size:
                oldest_key = next(iter(self.cases_cache))
                del self.cases_cache[oldest_key]

            logger.info(f"✅ Stored {case_type} case for training: {case_id}")
            logger.info(f"   PKL: {case_filepath}")
            logger.info(f"   JSON: {json_filepath}")

            return case_id

        except Exception as e:
            logger.error(f"Error storing trade for training: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None
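
    # Resulting file names encode type, time, symbol, and P&L, with '.' -> 'p'
    # and '-' -> 'neg'; e.g. (hypothetical values):
    #   negative_20250625_144537_ETHUSDT_pnl_neg0p0123.pkl / .json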
    def _update_case_index(self, case_dir: str, case_id: str, case_summary: Dict[str, Any], case_type: str):
        """Update the case index file"""
        try:
            index_file = os.path.join(case_dir, "case_index.json")

            # Load existing index or create new one
            if os.path.exists(index_file):
                with open(index_file, 'r') as f:
                    index_data = json.load(f)
            else:
                index_data = {"cases": [], "last_updated": None}

            # Add new case
            index_entry = {
                "case_id": case_id,
                "timestamp": case_summary['timestamp'],
                "symbol": case_summary['symbol'],
                "pnl": case_summary['pnl'],
                "training_priority": self._calculate_training_priority(case_summary, case_type),
                "retraining_count": 0,
                "feature_counts": case_summary['feature_counts']
            }
            index_data["cases"].append(index_entry)
            index_data["last_updated"] = datetime.now().isoformat()

            # Save updated index
            with open(index_file, 'w') as f:
                json.dump(index_data, f, indent=2)

            logger.debug(f"Updated case index: {case_id}")

        except Exception as e:
            logger.error(f"Error updating case index: {e}")
    def _calculate_training_priority(self, case_summary: Dict[str, Any], case_type: str) -> int:
        """Calculate training priority based on case characteristics"""
        try:
            pnl = abs(case_summary.get('pnl', 0))
            confidence = case_summary.get('confidence', 0)

            # Higher priority for larger losses/gains and high-confidence wrong predictions
            if case_type == "negative":
                # Larger losses get higher priority, especially with high confidence
                priority = min(5, int(pnl * 10) + int(confidence * 2))
            else:
                # Profits get medium priority unless very large
                priority = min(3, int(pnl * 5) + 1)

            return max(1, priority)  # Minimum priority of 1

        except Exception:
            return 1  # Default priority
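
    # Worked example (hypothetical numbers): a losing trade with pnl = -0.30 and
    # confidence 0.8 scores min(5, int(0.30 * 10) + int(0.8 * 2)) = min(5, 3 + 1)
    # = 4, so it is retrained before smaller, lower-confidence losses.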
    def get_training_cases(self, case_type: str = "negative", limit: int = 50) -> List[Dict[str, Any]]:
        """Get training cases for model training"""
        try:
            case_dir = os.path.join(self.base_dir, case_type)
            index_file = os.path.join(case_dir, "case_index.json")

            if not os.path.exists(index_file):
                return []

            with open(index_file, 'r') as f:
                index_data = json.load(f)

            # Sort by training priority (highest first) and limit
            cases = sorted(index_data["cases"],
                           key=lambda x: x.get("training_priority", 1),
                           reverse=True)[:limit]

            return cases

        except Exception as e:
            logger.error(f"Error getting training cases: {e}")
            return []
    def load_case_data(self, case_id: str, case_type: str = None) -> Optional[Dict[str, Any]]:
        """Load full case data from pickle file"""
        try:
            # Infer case type from the ID prefix if not provided
            if case_type is None:
                case_type = "positive" if case_id.startswith("positive") else "negative"

            case_filepath = os.path.join(self.base_dir, case_type, "cases", f"{case_id}.pkl")

            if not os.path.exists(case_filepath):
                logger.warning(f"Case file not found: {case_filepath}")
                return None

            with open(case_filepath, 'rb') as f:
                case_data = pickle.load(f)

            return case_data

        except Exception as e:
            logger.error(f"Error loading case data for {case_id}: {e}")
            return None
    def cleanup_old_cases(self, days_to_keep: int = 30):
        """Clean up old test cases to manage storage"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            for case_type in ['positive', 'negative']:
                case_dir = os.path.join(self.base_dir, case_type)
                cases_dir = os.path.join(case_dir, "cases")

                if not os.path.exists(cases_dir):
                    continue

                # Get case index
                index_file = os.path.join(case_dir, "case_index.json")
                if os.path.exists(index_file):
                    with open(index_file, 'r') as f:
                        index_data = json.load(f)

                    # Filter cases to keep
                    cases_to_keep = []
                    cases_removed = 0

                    for case in index_data["cases"]:
                        try:
                            case_date = datetime.fromisoformat(case["timestamp"])
                        except (ValueError, TypeError):
                            # Keep cases whose timestamp cannot be parsed
                            cases_to_keep.append(case)
                            continue
                        if case_date > cutoff_date:
                            cases_to_keep.append(case)
                        else:
                            # Remove case files
                            case_id = case["case_id"]
                            pkl_file = os.path.join(cases_dir, f"{case_id}.pkl")
                            json_file = os.path.join(cases_dir, f"{case_id}.json")

                            for file_path in [pkl_file, json_file]:
                                if os.path.exists(file_path):
                                    os.remove(file_path)

                            cases_removed += 1

                    # Update index
                    index_data["cases"] = cases_to_keep
                    index_data["last_updated"] = datetime.now().isoformat()

                    with open(index_file, 'w') as f:
                        json.dump(index_data, f, indent=2)

                    if cases_removed > 0:
                        logger.info(f"Cleaned up {cases_removed} old {case_type} cases")

        except Exception as e:
            logger.error(f"Error cleaning up old cases: {e}")
    # Helper methods for feature extraction

    def _get_comprehensive_market_state(self, symbol: str, current_price: float, data_provider) -> Dict[str, float]:
        """Get comprehensive market state features"""
        try:
            if not data_provider:
                return {'current_price': current_price}

            market_state = {'current_price': current_price}

            # Get historical data for features
            df = data_provider.get_historical_data(symbol, '1m', limit=100)
            if df is not None and not df.empty:
                prices = df['close'].values
                volumes = df['volume'].values

                # Price features
                market_state['price_sma_5'] = float(prices[-5:].mean())
                market_state['price_sma_20'] = float(prices[-20:].mean())
                market_state['price_std_20'] = float(prices[-20:].std())
                market_state['price_rsi'] = self._calculate_rsi(prices, 14)

                # Volume features (guard against a zero average volume)
                volume_sma_20 = float(volumes[-20:].mean())
                market_state['volume_current'] = float(volumes[-1])
                market_state['volume_sma_20'] = volume_sma_20
                market_state['volume_ratio'] = float(volumes[-1]) / volume_sma_20 if volume_sma_20 > 0 else 1.0

                # Trend features
                market_state['price_momentum_5'] = float((prices[-1] - prices[-5]) / prices[-5])
                market_state['price_momentum_20'] = float((prices[-1] - prices[-20]) / prices[-20])

            # Add timestamp features
            now = datetime.now()
            market_state['hour_of_day'] = now.hour
            market_state['minute_of_hour'] = now.minute
            market_state['day_of_week'] = now.weekday()

            return market_state

        except Exception as e:
            logger.warning(f"Error getting market state: {e}")
            return {'current_price': current_price}
    def _calculate_rsi(self, prices, period=14):
        """Calculate RSI indicator"""
        try:
            deltas = np.diff(prices)
            gains = np.where(deltas > 0, deltas, 0)
            losses = np.where(deltas < 0, -deltas, 0)

            avg_gain = np.mean(gains[-period:])
            avg_loss = np.mean(losses[-period:])

            if avg_loss == 0:
                return 100.0

            rs = avg_gain / avg_loss
            rsi = 100 - (100 / (1 + rs))
            return float(rsi)

        except Exception:
            return 50.0  # Neutral RSI
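
    # Note: this uses simple averages over the last `period` deltas; classic
    # Wilder RSI applies exponential smoothing, so values can differ slightly.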
    def _get_cnn_features_and_predictions(self, symbol: str, orchestrator) -> Dict[str, Any]:
        """Get CNN features and predictions from orchestrator"""
        try:
            if not orchestrator:
                return {}

            cnn_data = {}

            # Get CNN features if available
            if hasattr(orchestrator, 'latest_cnn_features'):
                cnn_features = getattr(orchestrator, 'latest_cnn_features', {}).get(symbol)
                if cnn_features is not None:
                    cnn_data['features'] = cnn_features.tolist() if hasattr(cnn_features, 'tolist') else cnn_features

            # Get CNN predictions if available
            if hasattr(orchestrator, 'latest_cnn_predictions'):
                cnn_predictions = getattr(orchestrator, 'latest_cnn_predictions', {}).get(symbol)
                if cnn_predictions is not None:
                    cnn_data['predictions'] = cnn_predictions.tolist() if hasattr(cnn_predictions, 'tolist') else cnn_predictions

            return cnn_data

        except Exception as e:
            logger.debug(f"Error getting CNN data: {e}")
            return {}
    def _get_dqn_state_features(self, symbol: str, current_price: float, orchestrator) -> Dict[str, Any]:
        """Get DQN state features from orchestrator"""
        try:
            if not orchestrator:
                return {}

            # Get DQN state from orchestrator if available
            if hasattr(orchestrator, 'build_comprehensive_rl_state'):
                rl_state = orchestrator.build_comprehensive_rl_state(symbol)
                if rl_state is not None:
                    return {
                        'state_vector': rl_state.tolist() if hasattr(rl_state, 'tolist') else rl_state,
                        'state_size': len(rl_state) if hasattr(rl_state, '__len__') else 0
                    }

            return {}

        except Exception as e:
            logger.debug(f"Error getting DQN state: {e}")
            return {}
    def _get_cob_features_for_training(self, symbol: str, orchestrator) -> Dict[str, Any]:
        """Get COB features for training"""
        try:
            if not orchestrator:
                return {}

            cob_data = {}

            # Get COB features from orchestrator
            if hasattr(orchestrator, 'latest_cob_features'):
                cob_features = getattr(orchestrator, 'latest_cob_features', {}).get(symbol)
                if cob_features is not None:
                    cob_data['features'] = cob_features.tolist() if hasattr(cob_features, 'tolist') else cob_features

            # Get COB snapshot
            if hasattr(orchestrator, 'cob_integration') and orchestrator.cob_integration:
                if hasattr(orchestrator.cob_integration, 'get_cob_snapshot'):
                    cob_snapshot = orchestrator.cob_integration.get_cob_snapshot(symbol)
                    if cob_snapshot:
                        cob_data['snapshot_available'] = True
                        cob_data['bid_levels'] = len(getattr(cob_snapshot, 'consolidated_bids', []))
                        cob_data['ask_levels'] = len(getattr(cob_snapshot, 'consolidated_asks', []))
                    else:
                        cob_data['snapshot_available'] = False

            return cob_data

        except Exception as e:
            logger.debug(f"Error getting COB features: {e}")
            return {}
    def _get_technical_indicators(self, symbol: str, data_provider) -> Dict[str, float]:
        """Get technical indicators"""
        try:
            if not data_provider:
                return {}

            indicators = {}

            # Get recent price data
            df = data_provider.get_historical_data(symbol, '1m', limit=50)
            if df is not None and not df.empty:
                closes = df['close'].values

                # Moving averages
                indicators['sma_10'] = float(closes[-10:].mean())
                indicators['sma_20'] = float(closes[-20:].mean())

                # Bollinger Bands (guard against a zero-width band)
                sma_20 = closes[-20:].mean()
                std_20 = closes[-20:].std()
                indicators['bb_upper'] = float(sma_20 + 2 * std_20)
                indicators['bb_lower'] = float(sma_20 - 2 * std_20)
                band_width = indicators['bb_upper'] - indicators['bb_lower']
                indicators['bb_position'] = float((closes[-1] - indicators['bb_lower']) / band_width) if band_width > 0 else 0.5

                # MACD (simplified: plain means instead of true EMAs)
                ema_12 = closes[-12:].mean()
                ema_26 = closes[-26:].mean()
                indicators['macd'] = float(ema_12 - ema_26)

                # Volatility
                indicators['volatility'] = float(std_20 / sma_20)

            return indicators

        except Exception as e:
            logger.debug(f"Error calculating technical indicators: {e}")
            return {}
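
    # A true EMA helper is sketched below for reference (an assumption, not part
    # of the original module); swapping it into the MACD block above would give
    # the standard EMA-based MACD instead of the simplified mean-based version.
    @staticmethod
    def _calculate_ema(values, period: int) -> float:
        """Illustrative exponential moving average over a price array."""
        alpha = 2.0 / (period + 1)
        ema = float(values[0])  # seed with the oldest value
        for value in values[1:]:
            ema = alpha * float(value) + (1.0 - alpha) * ema
        return ema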
    def _get_recent_price_history(self, symbol: str, data_provider, periods: int = 50) -> List[float]:
        """Get recent price history"""
        try:
            if not data_provider:
                return []

            df = data_provider.get_historical_data(symbol, '1m', limit=periods)
            if df is not None and not df.empty:
                return df['close'].tolist()

            return []

        except Exception as e:
            logger.debug(f"Error getting price history: {e}")
            return []
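

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not in the original module). Without
    # a live orchestrator or data_provider, capture falls back to price-only
    # market state and empty feature dicts, which is enough to exercise storage.
    logging.basicConfig(level=logging.INFO)
    manager = TradeDataManager(base_dir="testcases")
    inputs = manager.capture_comprehensive_model_inputs("ETH/USDT", "BUY", 2450.0)
    trade = {
        'symbol': 'ETH/USDT',
        'side': 'BUY',
        'entry_price': 2450.0,
        'entry_time': datetime.now(),
        'pnl': -0.0123,  # hypothetical losing trade -> stored as a negative case
        'confidence': 0.8,
        'model_inputs_at_entry': inputs,
        'training_ready': True,
    }
    case_id = manager.store_trade_for_training(trade)
    print(f"Stored case: {case_id}")
    print(f"Top negative cases: {manager.get_training_cases('negative', limit=5)}")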