test cases

Dobromir Popov
2025-06-25 14:45:37 +03:00
parent 4a1170d593
commit 4afa147bd1
5 changed files with 1039 additions and 247 deletions

1
.gitignore vendored

@@ -39,3 +39,4 @@ NN/models/saved/hybrid_stats_20250409_022901.json
*.png
closed_trades_history.json
data/cnn_training/cnn_training_data*
testcases/*

546
core/trade_data_manager.py Normal file

@@ -0,0 +1,546 @@
#!/usr/bin/env python3
"""
Trade Data Manager - Centralized trade data capture and training case management
Handles:
- Comprehensive model input capture during trade execution
- Storage in testcases structure (positive/negative)
- Case indexing and management
- Integration with existing negative case trainer
- Cold start training data preparation
"""
import os
import json
import pickle
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
class TradeDataManager:
"""Centralized manager for trade data capture and training case storage"""
def __init__(self, base_dir: str = "testcases"):
self.base_dir = base_dir
self.cases_cache = {} # In-memory cache of recent cases
self.max_cache_size = 100
# Initialize directory structure
self._setup_directory_structure()
logger.info(f"TradeDataManager initialized with base directory: {base_dir}")
def _setup_directory_structure(self):
"""Setup the testcases directory structure"""
try:
for case_type in ['positive', 'negative']:
for subdir in ['cases', 'sessions', 'models']:
dir_path = os.path.join(self.base_dir, case_type, subdir)
os.makedirs(dir_path, exist_ok=True)
logger.debug("Directory structure setup complete")
except Exception as e:
logger.error(f"Error setting up directory structure: {e}")
def capture_comprehensive_model_inputs(self, symbol: str, action: str, current_price: float,
orchestrator=None, data_provider=None) -> Dict[str, Any]:
"""Capture comprehensive model inputs for cold start training"""
try:
logger.info(f"Capturing model inputs for {action} trade on {symbol} at ${current_price:.2f}")
model_inputs = {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'action': action,
'price': current_price,
'capture_type': 'trade_execution'
}
# 1. Market State Features
try:
market_state = self._get_comprehensive_market_state(symbol, current_price, data_provider)
model_inputs['market_state'] = market_state
logger.debug(f"Captured market state: {len(market_state)} features")
except Exception as e:
logger.warning(f"Error capturing market state: {e}")
model_inputs['market_state'] = {}
# 2. CNN Features and Predictions
try:
cnn_data = self._get_cnn_features_and_predictions(symbol, orchestrator)
model_inputs['cnn_features'] = cnn_data.get('features', {})
model_inputs['cnn_predictions'] = cnn_data.get('predictions', {})
logger.debug(f"Captured CNN data: {len(cnn_data)} items")
except Exception as e:
logger.warning(f"Error capturing CNN data: {e}")
model_inputs['cnn_features'] = {}
model_inputs['cnn_predictions'] = {}
# 3. DQN/RL State Features
try:
dqn_state = self._get_dqn_state_features(symbol, current_price, orchestrator)
model_inputs['dqn_state'] = dqn_state
logger.debug(f"Captured DQN state: {len(dqn_state) if dqn_state else 0} features")
except Exception as e:
logger.warning(f"Error capturing DQN state: {e}")
model_inputs['dqn_state'] = {}
# 4. COB (Order Book) Features
try:
cob_data = self._get_cob_features_for_training(symbol, orchestrator)
model_inputs['cob_features'] = cob_data
logger.debug(f"Captured COB features: {len(cob_data) if cob_data else 0} features")
except Exception as e:
logger.warning(f"Error capturing COB features: {e}")
model_inputs['cob_features'] = {}
# 5. Technical Indicators
try:
technical_indicators = self._get_technical_indicators(symbol, data_provider)
model_inputs['technical_indicators'] = technical_indicators
logger.debug(f"Captured technical indicators: {len(technical_indicators)} indicators")
except Exception as e:
logger.warning(f"Error capturing technical indicators: {e}")
model_inputs['technical_indicators'] = {}
# 6. Recent Price History (for context)
try:
price_history = self._get_recent_price_history(symbol, data_provider, periods=50)
model_inputs['price_history'] = price_history
logger.debug(f"Captured price history: {len(price_history)} periods")
except Exception as e:
logger.warning(f"Error capturing price history: {e}")
model_inputs['price_history'] = []
total_features = sum(len(v) if isinstance(v, (dict, list)) else 1 for v in model_inputs.values())
logger.info(f"✅ Captured {total_features} total features for cold start training")
return model_inputs
except Exception as e:
logger.error(f"Error capturing model inputs: {e}")
return {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'action': action,
'price': current_price,
'error': str(e)
}
def store_trade_for_training(self, trade_record: Dict[str, Any]) -> Optional[str]:
"""Store trade for future cold start training in testcases structure"""
try:
# Determine if this will be a positive or negative case based on eventual P&L
pnl = trade_record.get('pnl', 0)
case_type = "positive" if pnl >= 0 else "negative"
# Create testcases directory structure
case_dir = os.path.join(self.base_dir, case_type)
cases_dir = os.path.join(case_dir, "cases")
# Create unique case ID
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
symbol_clean = trade_record['symbol'].replace('/', '')
case_id = f"{case_type}_{timestamp}_{symbol_clean}_pnl_{pnl:.4f}".replace('.', 'p').replace('-', 'neg')
# Store comprehensive case data as pickle (for complex model inputs)
case_filepath = os.path.join(cases_dir, f"{case_id}.pkl")
with open(case_filepath, 'wb') as f:
pickle.dump(trade_record, f)
# Store JSON summary for easy viewing
json_filepath = os.path.join(cases_dir, f"{case_id}.json")
json_summary = {
'case_id': case_id,
'timestamp': trade_record.get('entry_time', datetime.now()).isoformat() if hasattr(trade_record.get('entry_time'), 'isoformat') else str(trade_record.get('entry_time')),
'symbol': trade_record['symbol'],
'side': trade_record['side'],
'entry_price': trade_record['entry_price'],
'pnl': pnl,
'confidence': trade_record.get('confidence', 0),
'trade_type': trade_record.get('trade_type', 'unknown'),
'model_inputs_captured': bool(trade_record.get('model_inputs_at_entry')),
'training_ready': trade_record.get('training_ready', False),
'feature_counts': {
'market_state': len(trade_record.get('entry_market_state', {})),
'cnn_features': len(trade_record.get('model_inputs_at_entry', {}).get('cnn_features', {})),
'dqn_state': len(trade_record.get('model_inputs_at_entry', {}).get('dqn_state', {})),
'cob_features': len(trade_record.get('model_inputs_at_entry', {}).get('cob_features', {})),
'technical_indicators': len(trade_record.get('model_inputs_at_entry', {}).get('technical_indicators', {})),
'price_history': len(trade_record.get('model_inputs_at_entry', {}).get('price_history', []))
}
}
with open(json_filepath, 'w') as f:
json.dump(json_summary, f, indent=2, default=str)
# Update case index
self._update_case_index(case_dir, case_id, json_summary, case_type)
# Add to cache
self.cases_cache[case_id] = json_summary
if len(self.cases_cache) > self.max_cache_size:
# Remove oldest entry
oldest_key = next(iter(self.cases_cache))
del self.cases_cache[oldest_key]
logger.info(f"✅ Stored {case_type} case for training: {case_id}")
logger.info(f" PKL: {case_filepath}")
logger.info(f" JSON: {json_filepath}")
return case_id
except Exception as e:
logger.error(f"Error storing trade for training: {e}")
import traceback
logger.error(traceback.format_exc())
return None
def _update_case_index(self, case_dir: str, case_id: str, case_summary: Dict[str, Any], case_type: str):
"""Update the case index file"""
try:
index_file = os.path.join(case_dir, "case_index.json")
# Load existing index or create new one
if os.path.exists(index_file):
with open(index_file, 'r') as f:
index_data = json.load(f)
else:
index_data = {"cases": [], "last_updated": None}
# Add new case
index_entry = {
"case_id": case_id,
"timestamp": case_summary['timestamp'],
"symbol": case_summary['symbol'],
"pnl": case_summary['pnl'],
"training_priority": self._calculate_training_priority(case_summary, case_type),
"retraining_count": 0,
"feature_counts": case_summary['feature_counts']
}
index_data["cases"].append(index_entry)
index_data["last_updated"] = datetime.now().isoformat()
# Save updated index
with open(index_file, 'w') as f:
json.dump(index_data, f, indent=2)
logger.debug(f"Updated case index: {case_id}")
except Exception as e:
logger.error(f"Error updating case index: {e}")
def _calculate_training_priority(self, case_summary: Dict[str, Any], case_type: str) -> int:
"""Calculate training priority based on case characteristics"""
try:
pnl = abs(case_summary.get('pnl', 0))
confidence = case_summary.get('confidence', 0)
# Higher priority for larger losses/gains and high confidence wrong predictions
if case_type == "negative":
# Larger losses get higher priority, especially with high confidence
priority = min(5, int(pnl * 10) + int(confidence * 2))
else:
# Profits get medium priority unless very large
priority = min(3, int(pnl * 5) + 1)
return max(1, priority) # Minimum priority of 1
except Exception:
return 1 # Default priority
def get_training_cases(self, case_type: str = "negative", limit: int = 50) -> List[Dict[str, Any]]:
"""Get training cases for model training"""
try:
case_dir = os.path.join(self.base_dir, case_type)
index_file = os.path.join(case_dir, "case_index.json")
if not os.path.exists(index_file):
return []
with open(index_file, 'r') as f:
index_data = json.load(f)
# Sort by training priority (highest first) and limit
cases = sorted(index_data["cases"],
key=lambda x: x.get("training_priority", 1),
reverse=True)[:limit]
return cases
except Exception as e:
logger.error(f"Error getting training cases: {e}")
return []
def load_case_data(self, case_id: str, case_type: str = None) -> Optional[Dict[str, Any]]:
"""Load full case data from pickle file"""
try:
# Determine case type if not provided
if case_type is None:
case_type = "positive" if "positive" in case_id else "negative"
case_filepath = os.path.join(self.base_dir, case_type, "cases", f"{case_id}.pkl")
if not os.path.exists(case_filepath):
logger.warning(f"Case file not found: {case_filepath}")
return None
with open(case_filepath, 'rb') as f:
case_data = pickle.load(f)
return case_data
except Exception as e:
logger.error(f"Error loading case data for {case_id}: {e}")
return None
def cleanup_old_cases(self, days_to_keep: int = 30):
"""Clean up old test cases to manage storage"""
try:
from datetime import timedelta
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
for case_type in ['positive', 'negative']:
case_dir = os.path.join(self.base_dir, case_type)
cases_dir = os.path.join(case_dir, "cases")
if not os.path.exists(cases_dir):
continue
# Get case index
index_file = os.path.join(case_dir, "case_index.json")
if os.path.exists(index_file):
with open(index_file, 'r') as f:
index_data = json.load(f)
# Filter cases to keep
cases_to_keep = []
cases_removed = 0
for case in index_data["cases"]:
case_date = datetime.fromisoformat(case["timestamp"])
if case_date > cutoff_date:
cases_to_keep.append(case)
else:
# Remove case files
case_id = case["case_id"]
pkl_file = os.path.join(cases_dir, f"{case_id}.pkl")
json_file = os.path.join(cases_dir, f"{case_id}.json")
for file_path in [pkl_file, json_file]:
if os.path.exists(file_path):
os.remove(file_path)
cases_removed += 1
# Update index
index_data["cases"] = cases_to_keep
index_data["last_updated"] = datetime.now().isoformat()
with open(index_file, 'w') as f:
json.dump(index_data, f, indent=2)
if cases_removed > 0:
logger.info(f"Cleaned up {cases_removed} old {case_type} cases")
except Exception as e:
logger.error(f"Error cleaning up old cases: {e}")
# Helper methods for feature extraction
def _get_comprehensive_market_state(self, symbol: str, current_price: float, data_provider) -> Dict[str, float]:
"""Get comprehensive market state features"""
try:
if not data_provider:
return {'current_price': current_price}
market_state = {'current_price': current_price}
# Get historical data for features
df = data_provider.get_historical_data(symbol, '1m', limit=100)
if df is not None and not df.empty:
prices = df['close'].values
volumes = df['volume'].values
# Price features
market_state['price_sma_5'] = float(prices[-5:].mean())
market_state['price_sma_20'] = float(prices[-20:].mean())
market_state['price_std_20'] = float(prices[-20:].std())
market_state['price_rsi'] = self._calculate_rsi(prices, 14)
# Volume features
market_state['volume_current'] = float(volumes[-1])
market_state['volume_sma_20'] = float(volumes[-20:].mean())
market_state['volume_ratio'] = float(volumes[-1] / volumes[-20:].mean())
# Trend features
market_state['price_momentum_5'] = float((prices[-1] - prices[-5]) / prices[-5])
market_state['price_momentum_20'] = float((prices[-1] - prices[-20]) / prices[-20])
# Add timestamp features
now = datetime.now()
market_state['hour_of_day'] = now.hour
market_state['minute_of_hour'] = now.minute
market_state['day_of_week'] = now.weekday()
return market_state
except Exception as e:
logger.warning(f"Error getting market state: {e}")
return {'current_price': current_price}
def _calculate_rsi(self, prices, period=14):
"""Calculate RSI indicator"""
try:
deltas = np.diff(prices)
gains = np.where(deltas > 0, deltas, 0)
losses = np.where(deltas < 0, -deltas, 0)
avg_gain = np.mean(gains[-period:])
avg_loss = np.mean(losses[-period:])
if avg_loss == 0:
return 100.0
rs = avg_gain / avg_loss
rsi = 100 - (100 / (1 + rs))
return float(rsi)
except:
return 50.0 # Neutral RSI
def _get_cnn_features_and_predictions(self, symbol: str, orchestrator) -> Dict[str, Any]:
"""Get CNN features and predictions from orchestrator"""
try:
if not orchestrator:
return {}
cnn_data = {}
# Get CNN features if available
if hasattr(orchestrator, 'latest_cnn_features'):
cnn_features = getattr(orchestrator, 'latest_cnn_features', {}).get(symbol)
if cnn_features is not None:
cnn_data['features'] = cnn_features.tolist() if hasattr(cnn_features, 'tolist') else cnn_features
# Get CNN predictions if available
if hasattr(orchestrator, 'latest_cnn_predictions'):
cnn_predictions = getattr(orchestrator, 'latest_cnn_predictions', {}).get(symbol)
if cnn_predictions is not None:
cnn_data['predictions'] = cnn_predictions.tolist() if hasattr(cnn_predictions, 'tolist') else cnn_predictions
return cnn_data
except Exception as e:
logger.debug(f"Error getting CNN data: {e}")
return {}
def _get_dqn_state_features(self, symbol: str, current_price: float, orchestrator) -> Dict[str, Any]:
"""Get DQN state features from orchestrator"""
try:
if not orchestrator:
return {}
# Get DQN state from orchestrator if available
if hasattr(orchestrator, 'build_comprehensive_rl_state'):
rl_state = orchestrator.build_comprehensive_rl_state(symbol)
if rl_state is not None:
return {
'state_vector': rl_state.tolist() if hasattr(rl_state, 'tolist') else rl_state,
'state_size': len(rl_state) if hasattr(rl_state, '__len__') else 0
}
return {}
except Exception as e:
logger.debug(f"Error getting DQN state: {e}")
return {}
def _get_cob_features_for_training(self, symbol: str, orchestrator) -> Dict[str, Any]:
"""Get COB features for training"""
try:
if not orchestrator:
return {}
cob_data = {}
# Get COB features from orchestrator
if hasattr(orchestrator, 'latest_cob_features'):
cob_features = getattr(orchestrator, 'latest_cob_features', {}).get(symbol)
if cob_features is not None:
cob_data['features'] = cob_features.tolist() if hasattr(cob_features, 'tolist') else cob_features
# Get COB snapshot
if hasattr(orchestrator, 'cob_integration') and orchestrator.cob_integration:
if hasattr(orchestrator.cob_integration, 'get_cob_snapshot'):
cob_snapshot = orchestrator.cob_integration.get_cob_snapshot(symbol)
if cob_snapshot:
cob_data['snapshot_available'] = True
cob_data['bid_levels'] = len(getattr(cob_snapshot, 'consolidated_bids', []))
cob_data['ask_levels'] = len(getattr(cob_snapshot, 'consolidated_asks', []))
else:
cob_data['snapshot_available'] = False
return cob_data
except Exception as e:
logger.debug(f"Error getting COB features: {e}")
return {}
def _get_technical_indicators(self, symbol: str, data_provider) -> Dict[str, float]:
"""Get technical indicators"""
try:
if not data_provider:
return {}
indicators = {}
# Get recent price data
df = data_provider.get_historical_data(symbol, '1m', limit=50)
if df is not None and not df.empty:
closes = df['close'].values
highs = df['high'].values
lows = df['low'].values
volumes = df['volume'].values
# Moving averages
indicators['sma_10'] = float(closes[-10:].mean())
indicators['sma_20'] = float(closes[-20:].mean())
# Bollinger Bands
sma_20 = closes[-20:].mean()
std_20 = closes[-20:].std()
indicators['bb_upper'] = float(sma_20 + 2 * std_20)
indicators['bb_lower'] = float(sma_20 - 2 * std_20)
indicators['bb_position'] = float((closes[-1] - indicators['bb_lower']) / (indicators['bb_upper'] - indicators['bb_lower']))
# MACD
ema_12 = closes[-12:].mean() # Simplified
ema_26 = closes[-26:].mean() # Simplified
indicators['macd'] = float(ema_12 - ema_26)
# Volatility
indicators['volatility'] = float(std_20 / sma_20)
return indicators
except Exception as e:
logger.debug(f"Error calculating technical indicators: {e}")
return {}
def _get_recent_price_history(self, symbol: str, data_provider, periods: int = 50) -> List[float]:
"""Get recent price history"""
try:
if not data_provider:
return []
df = data_provider.get_historical_data(symbol, '1m', limit=periods)
if df is not None and not df.empty:
return df['close'].tolist()
return []
except Exception as e:
logger.debug(f"Error getting price history: {e}")
return []
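A minimal usage sketch for the TradeDataManager API above (hypothetical values, not part of the commit): with orchestrator and data_provider left as None, only the basic price/action fields are captured and the feature blocks stay empty.

from datetime import datetime
from core.trade_data_manager import TradeDataManager

manager = TradeDataManager(base_dir="testcases")

# Capture whatever model inputs are available at trade entry
inputs = manager.capture_comprehensive_model_inputs(
    symbol="ETH/USDT", action="BUY", current_price=2450.0,
    orchestrator=None, data_provider=None)

# Once the trade closes, store it as a positive or negative case
trade_record = {
    "symbol": "ETH/USDT",
    "side": "BUY",
    "entry_price": 2450.0,
    "entry_time": datetime.now(),
    "pnl": -1.25,                      # loss -> stored under testcases/negative/cases/
    "confidence": 0.8,
    "model_inputs_at_entry": inputs,
    "training_ready": True,
}
case_id = manager.store_trade_for_training(trade_record)

# Later, pull the highest-priority negative cases back for retraining
for case in manager.get_training_cases(case_type="negative", limit=10):
    full_record = manager.load_case_data(case["case_id"])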

304
core/training_integration.py Normal file

@@ -0,0 +1,304 @@
#!/usr/bin/env python3
"""
Training Integration - Handles cold start training and model learning integration
Manages:
- Cold start training triggers from trade outcomes
- Reward calculation based on P&L
- Integration with DQN, CNN, and COB RL models
- Training session management
"""
import logging
from datetime import datetime
from typing import Dict, List, Any, Optional
import numpy as np
logger = logging.getLogger(__name__)
class TrainingIntegration:
"""Manages training integration for cold start learning"""
def __init__(self, orchestrator=None):
self.orchestrator = orchestrator
self.training_sessions = {}
self.min_confidence_threshold = 0.3
logger.info("TrainingIntegration initialized")
def trigger_cold_start_training(self, trade_record: Dict[str, Any], case_id: str = None) -> bool:
"""Trigger cold start training when trades close with known outcomes"""
try:
if not trade_record.get('model_inputs_at_entry'):
logger.warning("No model inputs captured for training - skipping")
return False
pnl = trade_record.get('pnl', 0)
confidence = trade_record.get('confidence', 0)
logger.info(f"Triggering cold start training for trade with P&L: ${pnl:.4f}")
# Calculate training reward based on P&L and confidence
reward = self._calculate_training_reward(pnl, confidence)
# Train DQN on trade outcome
dqn_success = self._train_dqn_on_trade_outcome(trade_record, reward)
# Train CNN if available (placeholder for now)
cnn_success = self._train_cnn_on_trade_outcome(trade_record, reward)
# Train COB RL if available (placeholder for now)
cob_success = self._train_cob_rl_on_trade_outcome(trade_record, reward)
# Log training results
training_success = any([dqn_success, cnn_success, cob_success])
if training_success:
logger.info(f"Cold start training completed - DQN: {dqn_success}, CNN: {cnn_success}, COB: {cob_success}")
else:
logger.warning("Cold start training failed for all models")
return training_success
except Exception as e:
logger.error(f"Error in cold start training: {e}")
return False
def _calculate_training_reward(self, pnl: float, confidence: float) -> float:
"""Calculate training reward based on P&L and confidence"""
try:
# Base reward is proportional to P&L
base_reward = pnl
# Adjust for confidence - penalize high confidence wrong predictions more
if pnl < 0 and confidence > 0.7:
# High confidence loss - significant negative reward
confidence_adjustment = -confidence * 2
elif pnl > 0 and confidence > 0.7:
# High confidence gain - boost reward
confidence_adjustment = confidence * 1.5
else:
# Low confidence - minimal adjustment
confidence_adjustment = 0
final_reward = base_reward + confidence_adjustment
# Normalize to [-1, 1] range for training stability
normalized_reward = np.tanh(final_reward / 10.0)
logger.debug(f"Training reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
return float(normalized_reward)
except Exception as e:
logger.error(f"Error calculating training reward: {e}")
return 0.0
def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
"""Train DQN agent on trade outcome"""
try:
if not self.orchestrator:
logger.warning("No orchestrator available for DQN training")
return False
# Get DQN agent
if not hasattr(self.orchestrator, 'dqn_agent') or not self.orchestrator.dqn_agent:
logger.warning("DQN agent not available for training")
return False
# Extract DQN state from model inputs
model_inputs = trade_record.get('model_inputs_at_entry', {})
dqn_state = model_inputs.get('dqn_state', {}).get('state_vector')
if not dqn_state:
logger.warning("No DQN state available for training")
return False
# Convert action to DQN action index
action = trade_record.get('side', 'HOLD').upper()
action_map = {'BUY': 0, 'SELL': 1, 'HOLD': 2}
action_idx = action_map.get(action, 2)
# Create next state (simplified - could be current market state)
next_state = dqn_state # Placeholder - should be state after trade
# Store experience in DQN memory
dqn_agent = self.orchestrator.dqn_agent
if hasattr(dqn_agent, 'store_experience'):
dqn_agent.store_experience(
state=np.array(dqn_state),
action=action_idx,
reward=reward,
next_state=np.array(next_state),
done=True # Trade is complete
)
# Trigger training if enough experiences
if hasattr(dqn_agent, 'replay') and len(getattr(dqn_agent, 'memory', [])) > 32:
dqn_agent.replay(batch_size=32)
logger.info("DQN training step completed")
return True
else:
logger.warning("DQN agent doesn't support experience storage")
return False
except Exception as e:
logger.error(f"Error training DQN on trade outcome: {e}")
return False
def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
"""Train CNN on trade outcome (placeholder)"""
try:
if not self.orchestrator:
return False
# Check if CNN is available
if not hasattr(self.orchestrator, 'williams_cnn') or not self.orchestrator.williams_cnn:
logger.debug("CNN not available for training")
return False
# Get CNN features from model inputs
model_inputs = trade_record.get('model_inputs_at_entry', {})
cnn_features = model_inputs.get('cnn_features')
cnn_predictions = model_inputs.get('cnn_predictions')
if not cnn_features or not cnn_predictions:
logger.debug("No CNN features available for training")
return False
# CNN training would go here - requires more specific implementation
# For now, just log that we could train CNN
logger.debug(f"CNN training opportunity: features={len(cnn_features)}, predictions={len(cnn_predictions)}")
return True
except Exception as e:
logger.debug(f"Error in CNN training: {e}")
return False
def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float) -> bool:
"""Train COB RL on trade outcome (placeholder)"""
try:
if not self.orchestrator:
return False
# Check if COB integration is available
if not hasattr(self.orchestrator, 'cob_integration') or not self.orchestrator.cob_integration:
logger.debug("COB integration not available for training")
return False
# Get COB features from model inputs
model_inputs = trade_record.get('model_inputs_at_entry', {})
cob_features = model_inputs.get('cob_features')
if not cob_features:
logger.debug("No COB features available for training")
return False
# COB RL training would go here - requires more specific implementation
# For now, just log that we could train COB RL
logger.debug(f"COB RL training opportunity: features={len(cob_features)}")
return True
except Exception as e:
logger.debug(f"Error in COB RL training: {e}")
return False
def get_training_status(self) -> Dict[str, Any]:
"""Get current training integration status"""
try:
status = {
'orchestrator_available': self.orchestrator is not None,
'training_sessions': len(self.training_sessions),
'last_update': datetime.now().isoformat()
}
if self.orchestrator:
status['dqn_available'] = hasattr(self.orchestrator, 'dqn_agent') and self.orchestrator.dqn_agent is not None
status['cnn_available'] = hasattr(self.orchestrator, 'williams_cnn') and self.orchestrator.williams_cnn is not None
status['cob_available'] = hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration is not None
return status
except Exception as e:
logger.error(f"Error getting training status: {e}")
return {'error': str(e)}
def start_training_session(self, session_name: str, config: Dict[str, Any] = None) -> str:
"""Start a new training session"""
try:
session_id = f"{session_name}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
session_data = {
'session_id': session_id,
'session_name': session_name,
'start_time': datetime.now().isoformat(),
'config': config or {},
'trades_processed': 0,
'successful_trainings': 0,
'failed_trainings': 0
}
self.training_sessions[session_id] = session_data
logger.info(f"Started training session: {session_id}")
return session_id
except Exception as e:
logger.error(f"Error starting training session: {e}")
return ""
def end_training_session(self, session_id: str) -> Dict[str, Any]:
"""End a training session and return summary"""
try:
if session_id not in self.training_sessions:
logger.warning(f"Training session not found: {session_id}")
return {}
session_data = self.training_sessions[session_id]
session_data['end_time'] = datetime.now().isoformat()
# Calculate session duration
start_time = datetime.fromisoformat(session_data['start_time'])
end_time = datetime.fromisoformat(session_data['end_time'])
duration = (end_time - start_time).total_seconds()
session_data['duration_seconds'] = duration
# Calculate success rate
total_attempts = session_data['successful_trainings'] + session_data['failed_trainings']
session_data['success_rate'] = session_data['successful_trainings'] / total_attempts if total_attempts > 0 else 0
logger.info(f"Ended training session: {session_id}")
logger.info(f" Duration: {duration:.1f}s")
logger.info(f" Trades processed: {session_data['trades_processed']}")
logger.info(f" Success rate: {session_data['success_rate']:.2%}")
# Remove from active sessions
completed_session = self.training_sessions.pop(session_id)
return completed_session
except Exception as e:
logger.error(f"Error ending training session: {e}")
return {}
def update_session_stats(self, session_id: str, trade_processed: bool = True, training_success: bool = False):
"""Update training session statistics"""
try:
if session_id not in self.training_sessions:
return
session = self.training_sessions[session_id]
if trade_processed:
session['trades_processed'] += 1
if training_success:
session['successful_trainings'] += 1
else:
session['failed_trainings'] += 1
except Exception as e:
logger.error(f"Error updating session stats: {e}")


@@ -152,7 +152,7 @@ def start_clean_dashboard_with_training():
enhanced_rl_training=True, # Enable RL training
model_registry={}
)
logger.info("Enhanced Trading Orchestrator created with training enabled")
logger.info("Enhanced Trading Orchestrator created with training enabled")
# Create trading executor
trading_executor = TradingExecutor()
@@ -166,7 +166,7 @@ def start_clean_dashboard_with_training():
orchestrator=orchestrator,
trading_executor=trading_executor
)
logger.info("Clean Trading Dashboard created")
logger.info("Clean Trading Dashboard created")
# Start training pipeline in background thread
def training_worker():
@@ -178,7 +178,7 @@ def start_clean_dashboard_with_training():
training_thread = threading.Thread(target=training_worker, daemon=True)
training_thread.start()
logger.info("Training pipeline started in background")
logger.info("Training pipeline started in background")
# Wait a moment for training to initialize
time.sleep(3)


@@ -663,9 +663,9 @@ class CleanTradingDashboard:
color='rgba(0, 255, 100, 0.9)',
line=dict(width=3, color='green')
),
name='EXECUTED BUY',
showlegend=True,
hovertemplate="<b>EXECUTED BUY TRADE</b><br>" +
hovertemplate="<b>EXECUTED BUY TRADE</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Confidence: %{customdata:.1%}<extra></extra>",
@@ -687,9 +687,9 @@ class CleanTradingDashboard:
color='rgba(255, 100, 100, 0.9)',
line=dict(width=3, color='red')
),
name='EXECUTED SELL',
showlegend=True,
hovertemplate="<b>EXECUTED SELL TRADE</b><br>" +
hovertemplate="<b>EXECUTED SELL TRADE</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Confidence: %{customdata:.1%}<extra></extra>",
@@ -768,9 +768,9 @@ class CleanTradingDashboard:
color='rgba(0, 255, 100, 1.0)',
line=dict(width=2, color='green')
),
name='BUY (Executed)',
showlegend=False,
hovertemplate="<b>BUY EXECUTED</b><br>" +
hovertemplate="<b>BUY EXECUTED</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Confidence: %{customdata:.1%}<extra></extra>",
@@ -822,9 +822,9 @@ class CleanTradingDashboard:
color='rgba(255, 100, 100, 1.0)',
line=dict(width=2, color='red')
),
name='SELL (Executed)',
showlegend=False,
hovertemplate="<b>SELL EXECUTED</b><br>" +
hovertemplate="<b>SELL EXECUTED</b><br>" +
"Price: $%{y:.2f}<br>" +
"Time: %{x}<br>" +
"Confidence: %{customdata:.1%}<extra></extra>",
@@ -1540,8 +1540,16 @@ class CleanTradingDashboard:
logger.warning("No current price available for manual trade")
return
# CAPTURE ALL MODEL INPUTS FOR COLD START TRAINING
model_inputs = self._capture_comprehensive_model_inputs(symbol, action, current_price)
# CAPTURE ALL MODEL INPUTS FOR COLD START TRAINING using core TradeDataManager
try:
from core.trade_data_manager import TradeDataManager
trade_data_manager = TradeDataManager()
model_inputs = trade_data_manager.capture_comprehensive_model_inputs(
symbol, action, current_price, self.orchestrator, self.data_provider
)
except Exception as e:
logger.warning(f"Failed to capture model inputs via TradeDataManager: {e}")
model_inputs = {}
# Create manual trading decision
decision = {
@@ -1588,8 +1596,13 @@ class CleanTradingDashboard:
# Add to closed trades for display
self.closed_trades.append(trade_record)
# Store for cold start training when trade closes
self._store_trade_for_training(trade_record)
# Store for cold start training when trade closes using core TradeDataManager
try:
case_id = trade_data_manager.store_trade_for_training(trade_record)
if case_id:
logger.info(f"Trade stored for training with case ID: {case_id}")
except Exception as e:
logger.warning(f"Failed to store trade for training: {e}")
# Update session metrics
if action == 'BUY':
@@ -1600,8 +1613,17 @@ class CleanTradingDashboard:
self.session_pnl += demo_pnl
trade_record['pnl'] = demo_pnl
# TRIGGER COLD START TRAINING on profitable demo trade
self._trigger_cold_start_training(trade_record, demo_pnl)
# TRIGGER COLD START TRAINING on profitable demo trade using core TrainingIntegration
try:
from core.training_integration import TrainingIntegration
training_integration = TrainingIntegration(self.orchestrator)
training_success = training_integration.trigger_cold_start_training(trade_record, case_id)
if training_success:
logger.info("Cold start training completed successfully")
else:
logger.warning("Cold start training failed")
except Exception as e:
logger.warning(f"Failed to trigger cold start training: {e}")
else:
decision['executed'] = False
@@ -1625,89 +1647,7 @@ class CleanTradingDashboard:
except Exception as e:
logger.error(f"Error executing manual {action}: {e}")
def _capture_comprehensive_model_inputs(self, symbol: str, action: str, current_price: float) -> Dict[str, Any]:
"""Capture comprehensive model inputs for cold start training"""
try:
logger.info(f"Capturing model inputs for {action} trade on {symbol} at ${current_price:.2f}")
model_inputs = {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'action': action,
'price': current_price,
'capture_type': 'trade_execution'
}
# 1. Market State Features
try:
market_state = self._get_comprehensive_market_state(symbol, current_price)
model_inputs['market_state'] = market_state
logger.debug(f"Captured market state: {len(market_state)} features")
except Exception as e:
logger.warning(f"Error capturing market state: {e}")
model_inputs['market_state'] = {}
# 2. CNN Features and Predictions
try:
cnn_data = self._get_cnn_features_and_predictions(symbol)
model_inputs['cnn_features'] = cnn_data.get('features', {})
model_inputs['cnn_predictions'] = cnn_data.get('predictions', {})
logger.debug(f"Captured CNN data: {len(cnn_data)} items")
except Exception as e:
logger.warning(f"Error capturing CNN data: {e}")
model_inputs['cnn_features'] = {}
model_inputs['cnn_predictions'] = {}
# 3. DQN/RL State Features
try:
dqn_state = self._get_dqn_state_features(symbol, current_price)
model_inputs['dqn_state'] = dqn_state
logger.debug(f"Captured DQN state: {len(dqn_state) if dqn_state else 0} features")
except Exception as e:
logger.warning(f"Error capturing DQN state: {e}")
model_inputs['dqn_state'] = {}
# 4. COB (Order Book) Features
try:
cob_data = self._get_cob_features_for_training(symbol)
model_inputs['cob_features'] = cob_data
logger.debug(f"Captured COB features: {len(cob_data) if cob_data else 0} features")
except Exception as e:
logger.warning(f"Error capturing COB features: {e}")
model_inputs['cob_features'] = {}
# 5. Technical Indicators
try:
technical_indicators = self._get_technical_indicators(symbol)
model_inputs['technical_indicators'] = technical_indicators
logger.debug(f"Captured technical indicators: {len(technical_indicators)} indicators")
except Exception as e:
logger.warning(f"Error capturing technical indicators: {e}")
model_inputs['technical_indicators'] = {}
# 6. Recent Price History (for context)
try:
price_history = self._get_recent_price_history(symbol, periods=50)
model_inputs['price_history'] = price_history
logger.debug(f"Captured price history: {len(price_history)} periods")
except Exception as e:
logger.warning(f"Error capturing price history: {e}")
model_inputs['price_history'] = []
total_features = sum(len(v) if isinstance(v, (dict, list)) else 1 for v in model_inputs.values())
logger.info(f"✅ Captured {total_features} total features for cold start training")
return model_inputs
except Exception as e:
logger.error(f"Error capturing model inputs: {e}")
return {
'timestamp': datetime.now().isoformat(),
'symbol': symbol,
'action': action,
'price': current_price,
'error': str(e)
}
# Model input capture moved to core.trade_data_manager.TradeDataManager
def _get_comprehensive_market_state(self, symbol: str, current_price: float) -> Dict[str, float]:
"""Get comprehensive market state features"""
@@ -1885,150 +1825,9 @@ class CleanTradingDashboard:
logger.debug(f"Error getting price history: {e}")
return []
def _store_trade_for_training(self, trade_record: Dict[str, Any]):
"""Store trade for future cold start training"""
try:
# Create training data storage directory
import os
training_dir = "training_data"
os.makedirs(training_dir, exist_ok=True)
# Trade storage moved to core.trade_data_manager.TradeDataManager
# Store trade data with timestamp
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"trade_{trade_record['symbol'].replace('/', '')}_{timestamp}.json"
filepath = os.path.join(training_dir, filename)
import json
with open(filepath, 'w') as f:
json.dump(trade_record, f, indent=2, default=str)
logger.info(f"✅ Stored trade data for training: {filepath}")
# Also store in memory for immediate access
if not hasattr(self, 'stored_trades'):
self.stored_trades = []
self.stored_trades.append(trade_record)
# Keep only last 100 trades in memory
if len(self.stored_trades) > 100:
self.stored_trades = self.stored_trades[-100:]
except Exception as e:
logger.error(f"Error storing trade for training: {e}")
def _trigger_cold_start_training(self, trade_record: Dict[str, Any], pnl: float):
"""Trigger cold start training when we have trade outcome"""
try:
logger.info(f"🔥 TRIGGERING COLD START TRAINING")
logger.info(f"Trade: {trade_record['side']} {trade_record['symbol']} @ ${trade_record['entry_price']:.2f}")
logger.info(f"P&L: ${pnl:.4f} ({'PROFIT' if pnl > 0 else 'LOSS'})")
# Calculate reward based on P&L
reward = self._calculate_training_reward(pnl, trade_record)
# Send to DQN agent if available
if hasattr(self.orchestrator, 'sensitivity_dqn_agent') and self.orchestrator.sensitivity_dqn_agent:
self._train_dqn_on_trade_outcome(trade_record, reward)
# Send to CNN if available
if hasattr(self.orchestrator, 'williams_structure') and self.orchestrator.williams_structure:
self._train_cnn_on_trade_outcome(trade_record, reward)
# Send to COB RL if available
if hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration:
self._train_cob_rl_on_trade_outcome(trade_record, reward)
logger.info(f"✅ Cold start training triggered with reward: {reward:.4f}")
except Exception as e:
logger.error(f"Error triggering cold start training: {e}")
def _calculate_training_reward(self, pnl: float, trade_record: Dict[str, Any]) -> float:
"""Calculate training reward based on trade outcome"""
try:
# Base reward from P&L
base_reward = pnl * 100 # Scale up for training
# Confidence adjustment (higher confidence wrong predictions get bigger penalties)
confidence = trade_record.get('confidence', 0.5)
if pnl < 0: # Loss
confidence_penalty = confidence * 2 # Higher confidence losses hurt more
base_reward *= (1 + confidence_penalty)
else: # Profit
confidence_bonus = confidence * 0.5 # Higher confidence wins get small bonus
base_reward *= (1 + confidence_bonus)
# Time-based adjustment (faster profits are better)
# For demo trades, just use a small bonus
time_bonus = 0.1 if pnl > 0 else 0
final_reward = base_reward + time_bonus
logger.debug(f"Reward calculation: P&L={pnl:.4f}, Confidence={confidence:.2f}, Final={final_reward:.4f}")
return final_reward
except Exception as e:
logger.warning(f"Error calculating reward: {e}")
return pnl # Fallback to simple P&L
def _train_dqn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float):
"""Train DQN agent on trade outcome"""
try:
dqn_agent = self.orchestrator.sensitivity_dqn_agent
# Get the state that was used for the decision
model_inputs = trade_record.get('model_inputs_at_entry', {})
dqn_state = model_inputs.get('dqn_state', {}).get('state_vector', [])
if not dqn_state:
logger.debug("No DQN state available for training")
return
# Convert to numpy array
state = np.array(dqn_state, dtype=np.float32)
# Map action to DQN action space
action = 1 if trade_record['side'] == 'BUY' else 0
# Create next state (current market state after trade)
current_state = self._get_dqn_state_features(trade_record['symbol'], trade_record['entry_price'])
next_state = np.array(current_state.get('state_vector', state), dtype=np.float32)
# Add experience to DQN memory
dqn_agent.remember(state, action, reward, next_state, True) # done=True for completed trade
# Trigger training if enough experiences
if len(dqn_agent.memory) >= dqn_agent.batch_size:
loss = dqn_agent.replay()
if loss:
logger.info(f"🧠 DQN trained on trade outcome - Loss: {loss:.6f}, Reward: {reward:.4f}")
except Exception as e:
logger.debug(f"Error training DQN on trade outcome: {e}")
def _train_cnn_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float):
"""Train CNN on trade outcome (simplified for now)"""
try:
# CNN training requires more complex setup - log for now
logger.info(f"📊 CNN training opportunity: {trade_record['side']} with reward {reward:.4f}")
# In future: extract CNN features from model_inputs and create training sample
except Exception as e:
logger.debug(f"Error training CNN on trade outcome: {e}")
def _train_cob_rl_on_trade_outcome(self, trade_record: Dict[str, Any], reward: float):
"""Train COB RL on trade outcome (simplified for now)"""
try:
# COB RL training requires accessing the 400M parameter model - log for now
logger.info(f"📈 COB RL training opportunity: {trade_record['side']} with reward {reward:.4f}")
# In future: access COB RL model and create training sample
except Exception as e:
logger.debug(f"Error training COB RL on trade outcome: {e}")
# Cold start training moved to core.training_integration.TrainingIntegration
def _clear_session(self):
"""Clear session data"""
@@ -2487,9 +2286,151 @@ class CleanTradingDashboard:
except Exception as e:
logger.error(f"Error handling universal stream data: {e}")
# Factory function for easy creation
def _update_case_index(self, case_dir: str, case_id: str, case_summary: Dict[str, Any], case_type: str):
"""Update the case index file with new case information"""
try:
import json
import os
index_filepath = os.path.join(case_dir, "case_index.json")
# Load existing index or create new one
if os.path.exists(index_filepath):
with open(index_filepath, 'r') as f:
index_data = json.load(f)
else:
index_data = {
"cases": [],
"last_updated": datetime.now().isoformat(),
"case_type": case_type,
"total_cases": 0
}
# Add new case to index
pnl = case_summary.get('pnl', 0)
training_priority = 1 # Default priority
# Calculate training priority based on P&L and confidence
if case_type == "negative":
# Higher priority for bigger losses
if abs(pnl) > 10:
training_priority = 5 # Very high priority
elif abs(pnl) > 5:
training_priority = 4
elif abs(pnl) > 1:
training_priority = 3
else:
training_priority = 2
else: # positive
# Higher priority for high-confidence profitable trades
confidence = case_summary.get('confidence', 0)
if pnl > 5 and confidence > 0.8:
training_priority = 5
elif pnl > 1 and confidence > 0.6:
training_priority = 4
elif pnl > 0.5:
training_priority = 3
else:
training_priority = 2
case_entry = {
"case_id": case_id,
"timestamp": case_summary['timestamp'],
"symbol": case_summary['symbol'],
"side": case_summary['side'],
"entry_price": case_summary['entry_price'],
"pnl": pnl,
"confidence": case_summary.get('confidence', 0),
"trade_type": case_summary.get('trade_type', 'unknown'),
"training_priority": training_priority,
"retraining_count": 0,
"model_inputs_captured": case_summary.get('model_inputs_captured', False),
"feature_counts": case_summary.get('feature_counts', {}),
"created_at": datetime.now().isoformat()
}
# Add to cases list
index_data["cases"].append(case_entry)
index_data["last_updated"] = datetime.now().isoformat()
index_data["total_cases"] = len(index_data["cases"])
# Sort by training priority (highest first) and timestamp (newest first)
index_data["cases"].sort(key=lambda x: (-x['training_priority'], -time.mktime(datetime.fromisoformat(x['timestamp']).timetuple())))
# Keep only last 1000 cases to prevent index from getting too large
if len(index_data["cases"]) > 1000:
index_data["cases"] = index_data["cases"][:1000]
index_data["total_cases"] = 1000
# Save updated index
with open(index_filepath, 'w') as f:
json.dump(index_data, f, indent=2, default=str)
logger.debug(f"Updated {case_type} case index: {len(index_data['cases'])} total cases")
except Exception as e:
logger.error(f"Error updating case index: {e}")
def get_testcase_summary(self) -> Dict[str, Any]:
"""Get summary of stored testcases for display"""
try:
import os
import json
summary = {
'positive_cases': 0,
'negative_cases': 0,
'total_cases': 0,
'latest_cases': [],
'high_priority_cases': 0
}
base_dir = "testcases"
for case_type in ['positive', 'negative']:
case_dir = os.path.join(base_dir, case_type)
index_filepath = os.path.join(case_dir, "case_index.json")
if os.path.exists(index_filepath):
with open(index_filepath, 'r') as f:
index_data = json.load(f)
case_count = len(index_data.get('cases', []))
summary[f'{case_type}_cases'] = case_count
summary['total_cases'] += case_count
# Get high priority cases
high_priority = len([c for c in index_data.get('cases', []) if c.get('training_priority', 1) >= 4])
summary['high_priority_cases'] += high_priority
# Get latest cases
latest = index_data.get('cases', [])[:5] # Top 5 latest
for case in latest:
case['case_type'] = case_type
summary['latest_cases'].extend(latest)
# Sort latest cases by timestamp
summary['latest_cases'].sort(key=lambda x: x.get('timestamp', ''), reverse=True)
# Keep only top 10 latest cases
summary['latest_cases'] = summary['latest_cases'][:10]
return summary
except Exception as e:
logger.error(f"Error getting testcase summary: {e}")
return {
'positive_cases': 0,
'negative_cases': 0,
'total_cases': 0,
'latest_cases': [],
'high_priority_cases': 0,
'error': str(e)
}
def create_clean_dashboard(data_provider=None, orchestrator=None, trading_executor=None):
"""Create a clean trading dashboard instance"""
"""Factory function to create a CleanTradingDashboard instance"""
return CleanTradingDashboard(
data_provider=data_provider,
orchestrator=orchestrator,