#!/usr/bin/env python3
"""
Trade Data Manager - Centralized trade data capture and training case management

Handles:
- Comprehensive model input capture during trade execution
- Storage in the testcases structure (positive/negative)
- Case indexing and management
- Integration with the existing negative case trainer
- Cold start training data preparation
"""

import os
import json
import pickle
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional

import numpy as np

logger = logging.getLogger(__name__)


class TradeDataManager:
    """Centralized manager for trade data capture and training case storage"""

    def __init__(self, base_dir: str = "testcases"):
        self.base_dir = base_dir
        self.cases_cache = {}  # In-memory cache of recent cases
        self.max_cache_size = 100

        # Initialize directory structure
        self._setup_directory_structure()

        logger.info(f"TradeDataManager initialized with base directory: {base_dir}")

    def _setup_directory_structure(self):
        """Set up the testcases directory structure."""
        try:
            # Create base directories, including the 'base' directory for
            # still-open (not yet classified) trades
            for case_type in ['positive', 'negative', 'base']:
                for subdir in ['cases', 'sessions', 'models']:
                    dir_path = os.path.join(self.base_dir, case_type, subdir)
                    os.makedirs(dir_path, exist_ok=True)

            logger.debug("Directory structure setup complete")

        except Exception as e:
            logger.error(f"Error setting up directory structure: {e}")
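
    # Resulting on-disk layout (illustrative):
    #
    #   testcases/
    #       positive/{cases,sessions,models}
    #       negative/{cases,sessions,models}
    #       base/{cases,sessions,models}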

    def capture_comprehensive_model_inputs(self, symbol: str, action: str, current_price: float,
                                           orchestrator=None, data_provider=None) -> Dict[str, Any]:
        """Capture comprehensive model inputs for cold start training."""
        try:
            logger.info(f"Capturing model inputs for {action} trade on {symbol} at ${current_price:.2f}")

            model_inputs = {
                'timestamp': datetime.now().isoformat(),
                'symbol': symbol,
                'action': action,
                'price': current_price,
                'capture_type': 'trade_execution'
            }

            # 1. Market State Features
            try:
                market_state = self._get_comprehensive_market_state(symbol, current_price, data_provider)
                model_inputs['market_state'] = market_state
                logger.debug(f"Captured market state: {len(market_state)} features")
            except Exception as e:
                logger.warning(f"Error capturing market state: {e}")
                model_inputs['market_state'] = {}

            # 2. CNN Features and Predictions
            try:
                cnn_data = self._get_cnn_features_and_predictions(symbol, orchestrator)
                model_inputs['cnn_features'] = cnn_data.get('features', {})
                model_inputs['cnn_predictions'] = cnn_data.get('predictions', {})
                logger.debug(f"Captured CNN data: {len(cnn_data)} items")
            except Exception as e:
                logger.warning(f"Error capturing CNN data: {e}")
                model_inputs['cnn_features'] = {}
                model_inputs['cnn_predictions'] = {}

            # 3. DQN/RL State Features
            try:
                dqn_state = self._get_dqn_state_features(symbol, current_price, orchestrator)
                model_inputs['dqn_state'] = dqn_state
                logger.debug(f"Captured DQN state: {len(dqn_state) if dqn_state else 0} features")
            except Exception as e:
                logger.warning(f"Error capturing DQN state: {e}")
                model_inputs['dqn_state'] = {}

            # 4. COB (Order Book) Features
            try:
                cob_data = self._get_cob_features_for_training(symbol, orchestrator)
                model_inputs['cob_features'] = cob_data
                logger.debug(f"Captured COB features: {len(cob_data) if cob_data else 0} features")
            except Exception as e:
                logger.warning(f"Error capturing COB features: {e}")
                model_inputs['cob_features'] = {}

            # 5. Technical Indicators
            try:
                technical_indicators = self._get_technical_indicators(symbol, data_provider)
                model_inputs['technical_indicators'] = technical_indicators
                logger.debug(f"Captured technical indicators: {len(technical_indicators)} indicators")
            except Exception as e:
                logger.warning(f"Error capturing technical indicators: {e}")
                model_inputs['technical_indicators'] = {}

            # 6. Recent Price History (for context)
            try:
                price_history = self._get_recent_price_history(symbol, data_provider, periods=50)
                model_inputs['price_history'] = price_history
                logger.debug(f"Captured price history: {len(price_history)} periods")
            except Exception as e:
                logger.warning(f"Error capturing price history: {e}")
                model_inputs['price_history'] = []

            total_features = sum(len(v) if isinstance(v, (dict, list)) else 1 for v in model_inputs.values())
            logger.info(f"Captured {total_features} total features for cold start training")

            return model_inputs

        except Exception as e:
            logger.error(f"Error capturing model inputs: {e}")
            return {
                'timestamp': datetime.now().isoformat(),
                'symbol': symbol,
                'action': action,
                'price': current_price,
                'error': str(e)
            }

    def store_trade_for_training(self, trade_record: Dict[str, Any]) -> Optional[str]:
        """Store a completed trade for future cold start training in the testcases structure."""
        try:
            # Classify as a positive or negative case based on realized P&L
            pnl = trade_record.get('pnl', 0)
            case_type = "positive" if pnl >= 0 else "negative"

            # Resolve the testcases directory paths
            case_dir = os.path.join(self.base_dir, case_type)
            cases_dir = os.path.join(case_dir, "cases")

            # Create a unique, filesystem-safe case ID
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            symbol_clean = trade_record['symbol'].replace('/', '')
            case_id = f"{case_type}_{timestamp}_{symbol_clean}_pnl_{pnl:.4f}".replace('.', 'p').replace('-', 'neg')
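            # Example of the resulting ID (illustrative): a losing ETH/USDT trade
            # with pnl=-0.0123 becomes "negative_20240101_120000_ETHUSDT_pnl_neg0p0123"
            # ('.' -> 'p' and '-' -> 'neg' keep the name filesystem-safe).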
            # Store comprehensive case data as pickle (for complex model inputs)
            case_filepath = os.path.join(cases_dir, f"{case_id}.pkl")
            with open(case_filepath, 'wb') as f:
                pickle.dump(trade_record, f)

            # Store a JSON summary for easy viewing
            json_filepath = os.path.join(cases_dir, f"{case_id}.json")
            entry_time = trade_record.get('entry_time', datetime.now())
            json_summary = {
                'case_id': case_id,
                'timestamp': entry_time.isoformat() if hasattr(entry_time, 'isoformat') else str(entry_time),
                'symbol': trade_record['symbol'],
                'side': trade_record['side'],
                'entry_price': trade_record['entry_price'],
                'pnl': pnl,
                'confidence': trade_record.get('confidence', 0),
                'trade_type': trade_record.get('trade_type', 'unknown'),
                'model_inputs_captured': bool(trade_record.get('model_inputs_at_entry')),
                'training_ready': trade_record.get('training_ready', False),
                'feature_counts': {
                    'market_state': len(trade_record.get('entry_market_state', {})),
                    'cnn_features': len(trade_record.get('model_inputs_at_entry', {}).get('cnn_features', {})),
                    'dqn_state': len(trade_record.get('model_inputs_at_entry', {}).get('dqn_state', {})),
                    'cob_features': len(trade_record.get('model_inputs_at_entry', {}).get('cob_features', {})),
                    'technical_indicators': len(trade_record.get('model_inputs_at_entry', {}).get('technical_indicators', {})),
                    'price_history': len(trade_record.get('model_inputs_at_entry', {}).get('price_history', []))
                }
            }
            with open(json_filepath, 'w') as f:
                json.dump(json_summary, f, indent=2, default=str)

            # Update the case index
            self._update_case_index(case_dir, case_id, json_summary, case_type)

            # Add to the in-memory cache, evicting the oldest entry when full
            self.cases_cache[case_id] = json_summary
            if len(self.cases_cache) > self.max_cache_size:
                oldest_key = next(iter(self.cases_cache))
                del self.cases_cache[oldest_key]

            logger.info(f"Stored {case_type} case for training: {case_id}")
            logger.info(f"  PKL:  {case_filepath}")
            logger.info(f"  JSON: {json_filepath}")

            return case_id

        except Exception as e:
            logger.error(f"Error storing trade for training: {e}")
            import traceback
            logger.error(traceback.format_exc())
            return None

    def _update_case_index(self, case_dir: str, case_id: str, case_summary: Dict[str, Any], case_type: str):
        """Update the case index file."""
        try:
            index_file = os.path.join(case_dir, "case_index.json")

            # Load the existing index or create a new one
            if os.path.exists(index_file):
                with open(index_file, 'r') as f:
                    index_data = json.load(f)
            else:
                index_data = {"cases": [], "last_updated": None}

            # Add the new case
            index_entry = {
                "case_id": case_id,
                "timestamp": case_summary['timestamp'],
                "symbol": case_summary['symbol'],
                "pnl": case_summary['pnl'],
                "training_priority": self._calculate_training_priority(case_summary, case_type),
                "retraining_count": 0,
                "feature_counts": case_summary['feature_counts']
            }
            index_data["cases"].append(index_entry)
            index_data["last_updated"] = datetime.now().isoformat()

            # Save the updated index
            with open(index_file, 'w') as f:
                json.dump(index_data, f, indent=2)

            logger.debug(f"Updated case index: {case_id}")

        except Exception as e:
            logger.error(f"Error updating case index: {e}")

    def _calculate_training_priority(self, case_summary: Dict[str, Any], case_type: str) -> int:
        """Calculate training priority (1-5) based on case characteristics."""
        try:
            pnl = abs(case_summary.get('pnl', 0))
            confidence = case_summary.get('confidence', 0)

            # Higher priority for larger losses/gains and for high-confidence wrong predictions
            if case_type == "negative":
                # Larger losses get higher priority, especially at high confidence
                priority = min(5, int(pnl * 10) + int(confidence * 2))
            else:
                # Profits get medium priority unless very large
                priority = min(3, int(pnl * 5) + 1)

            return max(1, priority)  # Minimum priority of 1

        except Exception:
            return 1  # Default priority
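
    # Worked example (illustrative): a negative case with pnl=-0.30 at
    # confidence 0.8 scores min(5, int(0.30 * 10) + int(0.8 * 2)) = 4,
    # so confident losses are replayed before marginal ones.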

    def get_training_cases(self, case_type: str = "negative", limit: int = 50) -> List[Dict[str, Any]]:
        """Get training cases for model training, highest priority first."""
        try:
            case_dir = os.path.join(self.base_dir, case_type)
            index_file = os.path.join(case_dir, "case_index.json")

            if not os.path.exists(index_file):
                return []

            with open(index_file, 'r') as f:
                index_data = json.load(f)

            # Sort by training priority (highest first) and limit
            cases = sorted(index_data["cases"],
                           key=lambda x: x.get("training_priority", 1),
                           reverse=True)[:limit]

            return cases

        except Exception as e:
            logger.error(f"Error getting training cases: {e}")
            return []

    def load_case_data(self, case_id: str, case_type: str = None) -> Optional[Dict[str, Any]]:
        """Load full case data from a pickle file."""
        try:
            # Infer the case type from the ID if not provided
            if case_type is None:
                case_type = "positive" if "positive" in case_id else "negative"

            case_filepath = os.path.join(self.base_dir, case_type, "cases", f"{case_id}.pkl")

            if not os.path.exists(case_filepath):
                logger.warning(f"Case file not found: {case_filepath}")
                return None

            with open(case_filepath, 'rb') as f:
                case_data = pickle.load(f)

            return case_data

        except Exception as e:
            logger.error(f"Error loading case data for {case_id}: {e}")
            return None

    def cleanup_old_cases(self, days_to_keep: int = 30):
        """Clean up old test cases to manage storage."""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            for case_type in ['positive', 'negative']:
                case_dir = os.path.join(self.base_dir, case_type)
                cases_dir = os.path.join(case_dir, "cases")

                if not os.path.exists(cases_dir):
                    continue

                # Get the case index
                index_file = os.path.join(case_dir, "case_index.json")
                if os.path.exists(index_file):
                    with open(index_file, 'r') as f:
                        index_data = json.load(f)

                    # Keep recent cases; delete files for expired ones
                    cases_to_keep = []
                    cases_removed = 0

                    for case in index_data["cases"]:
                        case_date = datetime.fromisoformat(case["timestamp"])
                        if case_date > cutoff_date:
                            cases_to_keep.append(case)
                        else:
                            # Remove the case files
                            case_id = case["case_id"]
                            pkl_file = os.path.join(cases_dir, f"{case_id}.pkl")
                            json_file = os.path.join(cases_dir, f"{case_id}.json")

                            for file_path in [pkl_file, json_file]:
                                if os.path.exists(file_path):
                                    os.remove(file_path)

                            cases_removed += 1

                    # Update the index
                    index_data["cases"] = cases_to_keep
                    index_data["last_updated"] = datetime.now().isoformat()

                    with open(index_file, 'w') as f:
                        json.dump(index_data, f, indent=2)

                    if cases_removed > 0:
                        logger.info(f"Cleaned up {cases_removed} old {case_type} cases")

        except Exception as e:
            logger.error(f"Error cleaning up old cases: {e}")

    # Helper methods for feature extraction

    def _get_comprehensive_market_state(self, symbol: str, current_price: float, data_provider) -> Dict[str, float]:
        """Get comprehensive market state features."""
        try:
            if not data_provider:
                return {'current_price': current_price}

            market_state = {'current_price': current_price}

            # Get historical data for features
            df = data_provider.get_historical_data(symbol, '1m', limit=100)
            if df is not None and not df.empty:
                prices = df['close'].values
                volumes = df['volume'].values

                # Price features
                market_state['price_sma_5'] = float(prices[-5:].mean())
                market_state['price_sma_20'] = float(prices[-20:].mean())
                market_state['price_std_20'] = float(prices[-20:].std())
                market_state['price_rsi'] = self._calculate_rsi(prices, 14)

                # Volume features
                market_state['volume_current'] = float(volumes[-1])
                market_state['volume_sma_20'] = float(volumes[-20:].mean())
                market_state['volume_ratio'] = float(volumes[-1] / volumes[-20:].mean())

                # Trend features
                market_state['price_momentum_5'] = float((prices[-1] - prices[-5]) / prices[-5])
                market_state['price_momentum_20'] = float((prices[-1] - prices[-20]) / prices[-20])

            # Add timestamp features
            now = datetime.now()
            market_state['hour_of_day'] = now.hour
            market_state['minute_of_hour'] = now.minute
            market_state['day_of_week'] = now.weekday()

            return market_state

        except Exception as e:
            logger.warning(f"Error getting market state: {e}")
            return {'current_price': current_price}

    def _calculate_rsi(self, prices, period=14):
        """Calculate the RSI indicator."""
        try:
            deltas = np.diff(prices)
            gains = np.where(deltas > 0, deltas, 0)
            losses = np.where(deltas < 0, -deltas, 0)

            avg_gain = np.mean(gains[-period:])
            avg_loss = np.mean(losses[-period:])

            if avg_loss == 0:
                return 100.0

            rs = avg_gain / avg_loss
            rsi = 100 - (100 / (1 + rs))

            return float(rsi)

        except Exception:
            return 50.0  # Neutral RSI
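
    # Note: _calculate_rsi averages the last `period` gains and losses directly
    # rather than applying Wilder's exponential smoothing, so values track but
    # do not exactly match classic RSI implementations.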

    def _get_cnn_features_and_predictions(self, symbol: str, orchestrator) -> Dict[str, Any]:
        """Get CNN features and predictions from the orchestrator."""
        try:
            if not orchestrator:
                return {}

            cnn_data = {}

            # Get CNN features if available
            if hasattr(orchestrator, 'latest_cnn_features'):
                cnn_features = getattr(orchestrator, 'latest_cnn_features', {}).get(symbol)
                if cnn_features is not None:
                    cnn_data['features'] = cnn_features.tolist() if hasattr(cnn_features, 'tolist') else cnn_features

            # Get CNN predictions if available
            if hasattr(orchestrator, 'latest_cnn_predictions'):
                cnn_predictions = getattr(orchestrator, 'latest_cnn_predictions', {}).get(symbol)
                if cnn_predictions is not None:
                    cnn_data['predictions'] = cnn_predictions.tolist() if hasattr(cnn_predictions, 'tolist') else cnn_predictions

            return cnn_data

        except Exception as e:
            logger.debug(f"Error getting CNN data: {e}")
            return {}

    def _get_dqn_state_features(self, symbol: str, current_price: float, orchestrator) -> Dict[str, Any]:
        """Get DQN state features from the orchestrator."""
        try:
            if not orchestrator:
                return {}

            # Get the DQN state from the orchestrator if available
            if hasattr(orchestrator, 'build_comprehensive_rl_state'):
                rl_state = orchestrator.build_comprehensive_rl_state(symbol)
                if rl_state is not None:
                    return {
                        'state_vector': rl_state.tolist() if hasattr(rl_state, 'tolist') else rl_state,
                        'state_size': len(rl_state) if hasattr(rl_state, '__len__') else 0
                    }

            return {}

        except Exception as e:
            logger.debug(f"Error getting DQN state: {e}")
            return {}

    def _get_cob_features_for_training(self, symbol: str, orchestrator) -> Dict[str, Any]:
        """Get COB features for training."""
        try:
            if not orchestrator:
                return {}

            cob_data = {}

            # Get COB features from the orchestrator
            if hasattr(orchestrator, 'latest_cob_features'):
                cob_features = getattr(orchestrator, 'latest_cob_features', {}).get(symbol)
                if cob_features is not None:
                    cob_data['features'] = cob_features.tolist() if hasattr(cob_features, 'tolist') else cob_features

            # Get a COB snapshot
            if hasattr(orchestrator, 'cob_integration') and orchestrator.cob_integration:
                if hasattr(orchestrator.cob_integration, 'get_cob_snapshot'):
                    cob_snapshot = orchestrator.cob_integration.get_cob_snapshot(symbol)
                    if cob_snapshot:
                        cob_data['snapshot_available'] = True
                        cob_data['bid_levels'] = len(getattr(cob_snapshot, 'consolidated_bids', []))
                        cob_data['ask_levels'] = len(getattr(cob_snapshot, 'consolidated_asks', []))
                    else:
                        cob_data['snapshot_available'] = False

            return cob_data

        except Exception as e:
            logger.debug(f"Error getting COB features: {e}")
            return {}

    def _get_technical_indicators(self, symbol: str, data_provider) -> Dict[str, float]:
        """Get technical indicators."""
        try:
            if not data_provider:
                return {}

            indicators = {}

            # Get recent price data
            df = data_provider.get_historical_data(symbol, '1m', limit=50)
            if df is not None and not df.empty:
                closes = df['close'].values

                # Moving averages
                indicators['sma_10'] = float(closes[-10:].mean())
                indicators['sma_20'] = float(closes[-20:].mean())

                # Bollinger Bands
                sma_20 = closes[-20:].mean()
                std_20 = closes[-20:].std()
                indicators['bb_upper'] = float(sma_20 + 2 * std_20)
                indicators['bb_lower'] = float(sma_20 - 2 * std_20)
                indicators['bb_position'] = float((closes[-1] - indicators['bb_lower']) /
                                                  (indicators['bb_upper'] - indicators['bb_lower']))

                # MACD (simplified: SMAs stand in for the 12/26-period EMAs)
                ema_12 = closes[-12:].mean()
                ema_26 = closes[-26:].mean()
                indicators['macd'] = float(ema_12 - ema_26)

                # Volatility
                indicators['volatility'] = float(std_20 / sma_20)

            return indicators

        except Exception as e:
            logger.debug(f"Error calculating technical indicators: {e}")
            return {}

    def _get_recent_price_history(self, symbol: str, data_provider, periods: int = 50) -> List[float]:
        """Get recent price history."""
        try:
            if not data_provider:
                return []

            df = data_provider.get_historical_data(symbol, '1m', limit=periods)
            if df is not None and not df.empty:
                return df['close'].tolist()

            return []

        except Exception as e:
            logger.debug(f"Error getting price history: {e}")
            return []
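
    # --- Two-phase trade lifecycle -------------------------------------------
    # An opening trade cannot be labeled until its P&L is known, so it is
    # parked under testcases/base first; once the position closes,
    # move_base_trade_to_outcome() merges both snapshots and files the result
    # under positive/ or negative/.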

    def store_base_trade_for_later_classification(self, trade_record: Dict[str, Any]) -> Optional[str]:
        """Store an opening trade as a BASE case until the position is closed and P&L is known."""
        try:
            # Store in the base directory (temporary)
            case_dir = os.path.join(self.base_dir, "base")
            cases_dir = os.path.join(case_dir, "cases")

            # Create a unique case ID for the base case
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            symbol_clean = trade_record['symbol'].replace('/', '')
            base_case_id = f"base_{timestamp}_{symbol_clean}_{trade_record['side']}"

            # Store comprehensive case data as pickle
            case_filepath = os.path.join(cases_dir, f"{base_case_id}.pkl")
            with open(case_filepath, 'wb') as f:
                pickle.dump(trade_record, f)

            # Store a JSON summary
            json_filepath = os.path.join(cases_dir, f"{base_case_id}.json")
            entry_ts = trade_record.get('timestamp_entry', datetime.now())
            json_summary = {
                'case_id': base_case_id,
                'timestamp': entry_ts.isoformat() if hasattr(entry_ts, 'isoformat') else str(entry_ts),
                'symbol': trade_record['symbol'],
                'side': trade_record['side'],
                'entry_price': trade_record['entry_price'],
                'leverage': trade_record.get('leverage', 1),
                'quantity': trade_record.get('quantity', 0),
                'trade_status': 'OPENING',
                'confidence': trade_record.get('confidence', 0),
                'trade_type': trade_record.get('trade_type', 'manual'),
                'training_ready': False,  # Not ready until the position is closed
                'feature_counts': {
                    'market_state': len(trade_record.get('model_inputs_at_entry', {})),
                    'cob_features': len(trade_record.get('cob_snapshot_at_entry', {}))
                }
            }
            with open(json_filepath, 'w') as f:
                json.dump(json_summary, f, indent=2, default=str)

            logger.info(f"Stored base case for later classification: {base_case_id}")
            return base_case_id

        except Exception as e:
            logger.error(f"Error storing base trade: {e}")
            return None

    def move_base_trade_to_outcome(self, base_case_id: str, closing_trade_record: Dict[str, Any],
                                   is_positive: bool) -> Optional[str]:
        """Move a base case to positive/negative based on the trade outcome."""
        try:
            # Locate the original base case
            base_case_path = os.path.join(self.base_dir, "base", "cases", f"{base_case_id}.pkl")
            base_json_path = os.path.join(self.base_dir, "base", "cases", f"{base_case_id}.json")

            if not os.path.exists(base_case_path):
                logger.warning(f"Base case not found: {base_case_id}")
                return None

            # Load the opening trade data
            with open(base_case_path, 'rb') as f:
                opening_trade_data = pickle.load(f)

            # Combine the opening and closing data
            combined_trade_record = {
                **opening_trade_data,      # Opening snapshot
                **closing_trade_record,    # Closing snapshot
                'opening_data': opening_trade_data,
                'closing_data': closing_trade_record,
                'trade_complete': True
            }
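            # Dict-merge semantics: on key collisions the closing snapshot wins,
            # while both originals stay intact under 'opening_data'/'closing_data'.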
            # Determine the target directory
            case_type = "positive" if is_positive else "negative"
            case_dir = os.path.join(self.base_dir, case_type)
            cases_dir = os.path.join(case_dir, "cases")

            # Create a new case ID for the final outcome
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            symbol_clean = closing_trade_record['symbol'].replace('/', '')
            pnl_leveraged = closing_trade_record.get('pnl_leveraged', 0)
            final_case_id = f"{case_type}_{timestamp}_{symbol_clean}_pnl_{pnl_leveraged:.4f}".replace('.', 'p').replace('-', 'neg')

            # Store the final case data
            final_case_filepath = os.path.join(cases_dir, f"{final_case_id}.pkl")
            with open(final_case_filepath, 'wb') as f:
                pickle.dump(combined_trade_record, f)

            # Store a JSON summary
            final_json_filepath = os.path.join(cases_dir, f"{final_case_id}.json")
            json_summary = {
                'case_id': final_case_id,
                'original_base_case_id': base_case_id,
                # 'timestamp' and 'pnl' are required by _update_case_index
                'timestamp': str(closing_trade_record.get('timestamp_exit', datetime.now().isoformat())),
                'pnl': pnl_leveraged,
                'timestamp_opened': str(opening_trade_data.get('timestamp_entry', '')),
                'timestamp_closed': str(closing_trade_record.get('timestamp_exit', '')),
                'symbol': closing_trade_record['symbol'],
                'side_opened': opening_trade_data['side'],
                'side_closed': closing_trade_record['side'],
                'entry_price': opening_trade_data['entry_price'],
                'exit_price': closing_trade_record['exit_price'],
                'leverage': closing_trade_record.get('leverage', 1),
                'quantity': closing_trade_record.get('quantity', 0),
                'pnl_raw': closing_trade_record.get('pnl_raw', 0),
                'pnl_leveraged': pnl_leveraged,
                'trade_type': closing_trade_record.get('trade_type', 'manual'),
                'training_ready': True,
                'complete_trade_pair': True,
                'feature_counts': {
                    'opening_market_state': len(opening_trade_data.get('model_inputs_at_entry', {})),
                    'opening_cob_features': len(opening_trade_data.get('cob_snapshot_at_entry', {})),
                    'closing_market_state': len(closing_trade_record.get('model_inputs_at_exit', {})),
                    'closing_cob_features': len(closing_trade_record.get('cob_snapshot_at_exit', {}))
                }
            }
            with open(final_json_filepath, 'w') as f:
                json.dump(json_summary, f, indent=2, default=str)

            # Update the case index
            self._update_case_index(case_dir, final_case_id, json_summary, case_type)

            # Clean up the base case files
            try:
                os.remove(base_case_path)
                os.remove(base_json_path)
                logger.debug(f"Cleaned up base case files: {base_case_id}")
            except Exception as e:
                logger.warning(f"Error cleaning up base case files: {e}")

            logger.info(f"Moved base case to {case_type}: {final_case_id}")
            return final_case_id

        except Exception as e:
            logger.error(f"Error moving base trade to outcome: {e}")
            return None
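

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the production flow). It
# runs with orchestrator/data_provider left as None, so feature capture falls
# back to the price-only defaults above; symbols, prices, and P&L are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    manager = TradeDataManager(base_dir="testcases")

    # Capture whatever model inputs are available at entry
    inputs = manager.capture_comprehensive_model_inputs(
        symbol="ETH/USDT", action="BUY", current_price=3000.0
    )

    # Park the opening trade as a 'base' case until the position closes
    base_id = manager.store_base_trade_for_later_classification({
        'symbol': 'ETH/USDT',
        'side': 'BUY',
        'entry_price': 3000.0,
        'timestamp_entry': datetime.now(),
        'model_inputs_at_entry': inputs,
    })

    # On close, reclassify the case by its realized P&L (positive here)
    if base_id:
        final_id = manager.move_base_trade_to_outcome(
            base_id,
            {
                'symbol': 'ETH/USDT',
                'side': 'SELL',
                'exit_price': 3015.0,
                'pnl_raw': 15.0,
                'pnl_leveraged': 15.0,
                'timestamp_exit': datetime.now(),
            },
            is_positive=True,
        )
        print(f"Final case id: {final_id}")

    # Pull the highest-priority cases for retraining
    for case in manager.get_training_cases(case_type="positive", limit=5):
        print(case["case_id"], case["training_priority"])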