Compare commits
4 Commits
d17af5ca4b
...
26eeb9b35b
Author | SHA1 | Date | |
---|---|---|---|
26eeb9b35b | |||
1f60c80d67 | |||
78b4bb0f06 | |||
045780758a |
1
.gitignore
vendored
1
.gitignore
vendored
@ -48,3 +48,4 @@ chrome_user_data/*
|
||||
|
||||
.env
|
||||
.env
|
||||
training_data/*
|
||||
|
@ -74,6 +74,16 @@ Based on the existing implementation in `core/data_provider.py`, we'll enhance i
|
||||
- 1,5,15 and 60s MA of the COB imbalance counting +- 5 COB buckets
|
||||
- ***OUTPUTS***: suggested trade action (BUY/SELL)
|
||||
|
||||
# Standardized input for all models:
|
||||
{
|
||||
'primary_symbol': 'ETH/USDT',
|
||||
'reference_symbol': 'BTC/USDT',
|
||||
'eth_data': {'ETH_1s': df, 'ETH_1m': df, 'ETH_1h': df, 'ETH_1d': df},
|
||||
'btc_data': {'BTC_1s': df},
|
||||
'current_prices': {'ETH': price, 'BTC': price},
|
||||
'data_completeness': {...}
|
||||
}
|
||||
|
||||
### 2. CNN Model
|
||||
|
||||
The CNN Model is responsible for analyzing patterns in market data and predicting pivot points across multiple timeframes.
|
||||
|
@ -197,7 +197,9 @@
|
||||
- Ensure validation occurs before any model inference
|
||||
- _Requirements: 9.1, 9.4_
|
||||
|
||||
- [ ] 5.2. Implement persistent inference history storage
|
||||
- [x] 5.2. Implement persistent inference history storage
|
||||
|
||||
|
||||
- Create InferenceHistoryStore class for persistent storage
|
||||
- Store complete input data packages with each prediction
|
||||
- Include timestamp, symbol, input features, prediction outputs, confidence scores
|
||||
|
@ -70,6 +70,9 @@ class EnhancedCNNAdapter:
|
||||
else:
|
||||
self._load_best_checkpoint()
|
||||
|
||||
# Final device check and move
|
||||
self._ensure_model_on_device()
|
||||
|
||||
logger.info(f"EnhancedCNNAdapter initialized on {self.device}")
|
||||
|
||||
def _initialize_model(self):
|
||||
@ -88,9 +91,10 @@ class EnhancedCNNAdapter:
|
||||
|
||||
# Create model
|
||||
self.model = EnhancedCNN(input_shape=input_shape, n_actions=n_actions)
|
||||
# Ensure model is moved to the correct device
|
||||
self.model.to(self.device)
|
||||
|
||||
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions}")
|
||||
logger.info(f"EnhancedCNN model initialized with input_shape={input_shape}, n_actions={n_actions} on device {self.device}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error initializing EnhancedCNN model: {e}")
|
||||
@ -102,7 +106,9 @@ class EnhancedCNNAdapter:
|
||||
if self.model and os.path.exists(checkpoint_path):
|
||||
success = self.model.load(checkpoint_path)
|
||||
if success:
|
||||
logger.info(f"Loaded model from {checkpoint_path}")
|
||||
# Ensure model is moved to the correct device after loading
|
||||
self.model.to(self.device)
|
||||
logger.info(f"Loaded model from {checkpoint_path} and moved to {self.device}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"Failed to load model from {checkpoint_path}")
|
||||
@ -146,7 +152,9 @@ class EnhancedCNNAdapter:
|
||||
success = self.model.load(best_checkpoint_path)
|
||||
|
||||
if success:
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path}")
|
||||
# Ensure model is moved to the correct device after loading
|
||||
self.model.to(self.device)
|
||||
logger.info(f"Loaded best checkpoint from {best_checkpoint_path} and moved to {self.device}")
|
||||
|
||||
# Log metrics
|
||||
metrics = best_checkpoint_metadata.get('metrics', {})
|
||||
@ -161,7 +169,17 @@ class EnhancedCNNAdapter:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _ensure_model_on_device(self):
|
||||
"""Ensure model and all its components are on the correct device"""
|
||||
try:
|
||||
if self.model:
|
||||
self.model.to(self.device)
|
||||
# Also ensure the model's internal device is set correctly
|
||||
if hasattr(self.model, 'device'):
|
||||
self.model.device = self.device
|
||||
logger.debug(f"Model ensured on device {self.device}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error ensuring model on device: {e}")
|
||||
|
||||
def _create_default_output(self, symbol: str) -> ModelOutput:
|
||||
"""Create default output when prediction fails"""
|
||||
@ -235,6 +253,9 @@ class EnhancedCNNAdapter:
|
||||
if features.dim() == 1:
|
||||
features = features.unsqueeze(0)
|
||||
|
||||
# Ensure model is on correct device before prediction
|
||||
self._ensure_model_on_device()
|
||||
|
||||
# Set model to evaluation mode
|
||||
self.model.eval()
|
||||
|
||||
@ -399,6 +420,9 @@ class EnhancedCNNAdapter:
|
||||
logger.info(f"Not enough training data: {len(self.training_data)} samples, need at least {self.batch_size}")
|
||||
return {'loss': 0.0, 'accuracy': 0.0, 'samples': len(self.training_data)}
|
||||
|
||||
# Ensure model is on correct device before training
|
||||
self._ensure_model_on_device()
|
||||
|
||||
# Set model to training mode
|
||||
self.model.train()
|
||||
|
||||
@ -423,8 +447,8 @@ class EnhancedCNNAdapter:
|
||||
if len(batch) < 2:
|
||||
continue
|
||||
|
||||
# Prepare batch
|
||||
features = torch.stack([sample[0] for sample in batch])
|
||||
# Prepare batch - ensure all tensors are on the correct device
|
||||
features = torch.stack([sample[0].to(self.device) for sample in batch])
|
||||
actions = torch.tensor([sample[1] for sample in batch], dtype=torch.long, device=self.device)
|
||||
rewards = torch.tensor([sample[2] for sample in batch], dtype=torch.float32, device=self.device)
|
||||
|
||||
|
@ -110,7 +110,9 @@ class TradingOrchestrator:
|
||||
self.confidence_threshold_close = self.config.orchestrator.get('confidence_threshold_close', 0.08) # Lowered from 0.10
|
||||
# Decision frequency limit to prevent excessive trading
|
||||
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 30)
|
||||
self.symbols = self.config.get('symbols', ['ETH/USDT', 'BTC/USDT']) # Enhanced to support multiple symbols
|
||||
|
||||
self.symbol = self.config.get('symbol', "ETH/USDT") # main symbol we wre trading and making predictions on. only one!
|
||||
self.ref_symbols = self.config.get('ref_symbols', [ 'BTC/USDT']) # Enhanced to support multiple reference symbols. ToDo: we can add 'SOL/USDT' later
|
||||
|
||||
# NEW: Aggressiveness parameters
|
||||
self.entry_aggressiveness = self.config.orchestrator.get('entry_aggressiveness', 0.5) # 0.0 = conservative, 1.0 = very aggressive
|
||||
@ -153,12 +155,11 @@ class TradingOrchestrator:
|
||||
self.recent_cnn_predictions: Dict[str, deque] = {} # {symbol: List[Dict]} - Recent CNN predictions
|
||||
self.prediction_accuracy_history: Dict[str, deque] = {} # {symbol: List[Dict]} - Prediction accuracy tracking
|
||||
|
||||
# Initialize prediction tracking for each symbol
|
||||
for symbol in self.symbols:
|
||||
self.recent_dqn_predictions[symbol] = deque(maxlen=100)
|
||||
self.recent_cnn_predictions[symbol] = deque(maxlen=50)
|
||||
self.prediction_accuracy_history[symbol] = deque(maxlen=200)
|
||||
self.signal_accumulator[symbol] = []
|
||||
# Initialize prediction tracking for the primary trading symbol only
|
||||
self.recent_dqn_predictions[self.symbol] = deque(maxlen=100)
|
||||
self.recent_cnn_predictions[self.symbol] = deque(maxlen=50)
|
||||
self.prediction_accuracy_history[self.symbol] = deque(maxlen=200)
|
||||
self.signal_accumulator[self.symbol] = []
|
||||
|
||||
# Decision callbacks
|
||||
self.decision_callbacks: List[Any] = []
|
||||
@ -177,7 +178,7 @@ class TradingOrchestrator:
|
||||
self.latest_cob_data: Dict[str, Any] = {} # {symbol: COBSnapshot}
|
||||
self.latest_cob_features: Dict[str, Any] = {} # {symbol: np.ndarray} - CNN features
|
||||
self.latest_cob_state: Dict[str, Any] = {} # {symbol: np.ndarray} - DQN state features
|
||||
self.cob_feature_history: Dict[str, List[Any]] = {symbol: [] for symbol in self.symbols} # Rolling history for models
|
||||
self.cob_feature_history: Dict[str, List[Any]] = {self.symbol: []} # Rolling history for primary trading symbol
|
||||
|
||||
# Enhanced ML Models
|
||||
self.rl_agent: Any = None # DQN Agent
|
||||
@ -204,13 +205,13 @@ class TradingOrchestrator:
|
||||
# Training tracking
|
||||
self.last_trained_symbols: Dict[str, datetime] = {}
|
||||
|
||||
# INFERENCE DATA STORAGE - Store model inputs and outputs for training
|
||||
self.inference_history: Dict[str, deque] = {} # {symbol: deque of inference records}
|
||||
self.max_inference_history = 1000 # Keep last 1000 inference records per symbol
|
||||
# INFERENCE DATA STORAGE - Per-model storage with memory optimization
|
||||
self.inference_history: Dict[str, deque] = {} # {model_name: deque of last 5 inference records}
|
||||
self.max_memory_inferences = 5 # Keep only last 5 inferences in memory per model
|
||||
self.max_disk_files_per_model = 200 # Cap disk files per model
|
||||
|
||||
# Initialize inference history for each symbol
|
||||
for symbol in self.symbols:
|
||||
self.inference_history[symbol] = deque(maxlen=self.max_inference_history)
|
||||
# Initialize inference history for each model (will be populated as models make predictions)
|
||||
# We'll create entries dynamically as models are used
|
||||
|
||||
# ENHANCED: Real-time Training System Integration
|
||||
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
|
||||
@ -223,7 +224,7 @@ class TradingOrchestrator:
|
||||
logger.info(f"Training enabled: {self.training_enabled}")
|
||||
logger.info(f"Confidence threshold: {self.confidence_threshold}")
|
||||
# logger.info(f"Decision frequency: {self.decision_frequency}s")
|
||||
logger.info(f"Symbols: {self.symbols}")
|
||||
logger.info(f"Primary symbol: {self.symbol}, Reference symbols: {self.ref_symbols}")
|
||||
logger.info("Universal Data Adapter integrated for centralized data flow")
|
||||
|
||||
# Start centralized data collection for all models and dashboard
|
||||
@ -298,12 +299,12 @@ class TradingOrchestrator:
|
||||
logger.warning("DQN Agent not available")
|
||||
self.rl_agent = None
|
||||
|
||||
# Initialize CNN Model
|
||||
# Initialize CNN Model with Adapter
|
||||
try:
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move CNN model to the determined device
|
||||
self.cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
@ -331,11 +332,12 @@ class TradingOrchestrator:
|
||||
self.model_states['cnn']['best_loss'] = None
|
||||
logger.info("CNN starting fresh - no checkpoint found")
|
||||
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
logger.info("Enhanced CNN adapter initialized")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_adapter = None # No adapter available
|
||||
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
|
||||
|
||||
@ -358,6 +360,7 @@ class TradingOrchestrator:
|
||||
except ImportError:
|
||||
logger.warning("CNN model not available")
|
||||
self.cnn_model = None
|
||||
self.cnn_adapter = None
|
||||
self.cnn_optimizer = None # Ensure optimizer is also None if model is not available
|
||||
|
||||
# Initialize Extrema Trainer
|
||||
@ -365,7 +368,7 @@ class TradingOrchestrator:
|
||||
from core.extrema_trainer import ExtremaTrainer
|
||||
self.extrema_trainer = ExtremaTrainer(
|
||||
data_provider=self.data_provider,
|
||||
symbols=self.symbols
|
||||
symbols=[self.symbol] # Only primary trading symbol
|
||||
)
|
||||
|
||||
# Load checkpoint and capture initial state
|
||||
@ -617,7 +620,7 @@ class TradingOrchestrator:
|
||||
async def start_continuous_trading(self, symbols: Optional[List[str]] = None):
|
||||
"""Start the continuous trading loop, using a decision model and trading executor"""
|
||||
if symbols is None:
|
||||
symbols = self.symbols
|
||||
symbols = [self.symbol] # Only trade the primary symbol
|
||||
|
||||
if not self.realtime_processing_task:
|
||||
self.realtime_processing_task = asyncio.create_task(self._trading_decision_loop())
|
||||
@ -638,9 +641,9 @@ class TradingOrchestrator:
|
||||
logger.info("Trading decision loop started")
|
||||
while self.running:
|
||||
try:
|
||||
for symbol in self.symbols:
|
||||
await self.make_trading_decision(symbol)
|
||||
await asyncio.sleep(1) # Small delay between symbols
|
||||
# Only make decisions for the primary trading symbol
|
||||
await self.make_trading_decision(self.symbol)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await asyncio.sleep(self.decision_frequency)
|
||||
except Exception as e:
|
||||
@ -767,7 +770,7 @@ class TradingOrchestrator:
|
||||
if COB_INTEGRATION_AVAILABLE and COBIntegration is not None:
|
||||
try:
|
||||
self.cob_integration = COBIntegration(
|
||||
symbols=self.symbols,
|
||||
symbols=[self.symbol] + self.ref_symbols, # Primary + reference symbols
|
||||
data_provider=self.data_provider
|
||||
)
|
||||
logger.info("COB Integration initialized")
|
||||
@ -929,6 +932,11 @@ class TradingOrchestrator:
|
||||
if model.name not in self.model_performance:
|
||||
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
|
||||
|
||||
# Initialize inference history for this model
|
||||
if model.name not in self.inference_history:
|
||||
self.inference_history[model.name] = deque(maxlen=self.max_memory_inferences)
|
||||
logger.debug(f"Initialized inference history for {model.name}")
|
||||
|
||||
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
|
||||
self._normalize_weights()
|
||||
return True
|
||||
@ -1023,6 +1031,9 @@ class TradingOrchestrator:
|
||||
except Exception as e:
|
||||
logger.error(f"Error in decision callback: {e}")
|
||||
|
||||
# Add training samples based on current market conditions
|
||||
await self._add_training_samples_from_predictions(symbol, predictions, current_price)
|
||||
|
||||
# Clean up memory periodically
|
||||
if len(self.recent_decisions[symbol]) % 200 == 0: # Reduced from 50 to 200
|
||||
self.model_registry.cleanup_all_models()
|
||||
@ -1033,6 +1044,47 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error making trading decision for {symbol}: {e}")
|
||||
return None
|
||||
|
||||
async def _add_training_samples_from_predictions(self, symbol: str, predictions: List[Prediction], current_price: float):
|
||||
"""Add training samples to models based on current predictions and market conditions"""
|
||||
try:
|
||||
if not hasattr(self, 'cnn_adapter') or not self.cnn_adapter:
|
||||
return
|
||||
|
||||
# Get recent price data to evaluate if predictions would be correct
|
||||
recent_prices = self.data_provider.get_recent_prices(symbol, limit=10)
|
||||
if not recent_prices or len(recent_prices) < 2:
|
||||
return
|
||||
|
||||
# Calculate recent price change
|
||||
price_change_pct = (current_price - recent_prices[-2]) / recent_prices[-2] * 100
|
||||
|
||||
# Add training samples for CNN predictions
|
||||
for prediction in predictions:
|
||||
if 'cnn' in prediction.model_name.lower():
|
||||
# Determine reward based on prediction accuracy
|
||||
reward = 0.0
|
||||
|
||||
if prediction.action == 'BUY' and price_change_pct > 0.1:
|
||||
reward = min(price_change_pct * 0.1, 1.0) # Positive reward for correct BUY
|
||||
elif prediction.action == 'SELL' and price_change_pct < -0.1:
|
||||
reward = min(abs(price_change_pct) * 0.1, 1.0) # Positive reward for correct SELL
|
||||
elif prediction.action == 'HOLD' and abs(price_change_pct) < 0.1:
|
||||
reward = 0.1 # Small positive reward for correct HOLD
|
||||
else:
|
||||
reward = -0.05 # Small negative reward for incorrect prediction
|
||||
|
||||
# Add training sample
|
||||
self.cnn_adapter.add_training_sample(symbol, prediction.action, reward)
|
||||
logger.debug(f"Added CNN training sample: {prediction.action}, reward={reward:.3f}, price_change={price_change_pct:.2f}%")
|
||||
|
||||
# Trigger training if we have enough samples
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
logger.info(f"CNN training completed: loss={training_results.get('loss', 0):.4f}, accuracy={training_results.get('accuracy', 0):.4f}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding training samples from predictions: {e}")
|
||||
|
||||
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
|
||||
"""Get predictions from all registered models with input data storage"""
|
||||
predictions = []
|
||||
@ -1050,8 +1102,12 @@ class TradingOrchestrator:
|
||||
# Get CNN predictions for each timeframe
|
||||
cnn_predictions = await self._get_cnn_predictions(model, symbol)
|
||||
predictions.extend(cnn_predictions)
|
||||
# Store input data for CNN
|
||||
# Store input data for CNN - store for each prediction
|
||||
model_input = input_data.get('cnn_input')
|
||||
if model_input is not None and cnn_predictions:
|
||||
# Store inference data for each CNN prediction
|
||||
for cnn_pred in cnn_predictions:
|
||||
await self._store_inference_data_async(model_name, model_input, cnn_pred, current_time, symbol)
|
||||
|
||||
elif isinstance(model, RLAgentInterface):
|
||||
# Get RL prediction
|
||||
@ -1061,6 +1117,8 @@ class TradingOrchestrator:
|
||||
prediction = rl_prediction
|
||||
# Store input data for RL
|
||||
model_input = input_data.get('rl_input')
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
else:
|
||||
# Generic model interface
|
||||
@ -1070,78 +1128,173 @@ class TradingOrchestrator:
|
||||
prediction = generic_prediction
|
||||
# Store input data for generic model
|
||||
model_input = input_data.get('generic_input')
|
||||
|
||||
# Store inference data for training
|
||||
if prediction and model_input is not None:
|
||||
self._store_inference_data(symbol, model_name, model_input, prediction, current_time)
|
||||
if model_input is not None:
|
||||
await self._store_inference_data_async(model_name, model_input, prediction, current_time, symbol)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting prediction from {model_name}: {e}")
|
||||
continue
|
||||
|
||||
# Debug: Log inference history status (only if low record count)
|
||||
total_records = sum(len(history) for history in self.inference_history.values())
|
||||
if total_records < 10: # Only log when we have few records
|
||||
logger.debug(f"Total inference records across all models: {total_records}")
|
||||
for model_name, history in self.inference_history.items():
|
||||
logger.debug(f" {model_name}: {len(history)} records")
|
||||
|
||||
# Trigger training based on previous inference data
|
||||
await self._trigger_model_training(symbol)
|
||||
|
||||
return predictions
|
||||
|
||||
async def _collect_model_input_data(self, symbol: str) -> Dict[str, Any]:
|
||||
"""Collect comprehensive input data for all models"""
|
||||
"""Collect standardized input data for all models - ETH primary + BTC reference"""
|
||||
try:
|
||||
input_data = {}
|
||||
# Only collect data for ETH (primary symbol) - we inference only for ETH
|
||||
if symbol != 'ETH/USDT':
|
||||
return {}
|
||||
|
||||
# Get current market data from data provider
|
||||
current_price = self.data_provider.get_current_price(symbol)
|
||||
# Standardized input: 4 ETH timeframes + 1s BTC reference
|
||||
eth_data = {}
|
||||
eth_timeframes = ['1s', '1m', '1h', '1d']
|
||||
|
||||
# Collect OHLCV data for multiple timeframes
|
||||
ohlcv_data = {}
|
||||
timeframes = ['1s', '1m', '1h', '1d']
|
||||
for tf in timeframes:
|
||||
df = self.data_provider.get_historical_data(symbol, tf, limit=300)
|
||||
# Collect ETH data for all timeframes
|
||||
for tf in eth_timeframes:
|
||||
df = self.data_provider.get_historical_data('ETH/USDT', tf, limit=300)
|
||||
if df is not None and not df.empty:
|
||||
ohlcv_data[tf] = df
|
||||
eth_data[f'ETH_{tf}'] = df
|
||||
|
||||
# Collect COB data if available
|
||||
cob_data = self.get_cob_snapshot(symbol)
|
||||
# Collect BTC 1s reference data
|
||||
btc_1s = self.data_provider.get_historical_data('BTC/USDT', '1s', limit=300)
|
||||
btc_data = {}
|
||||
if btc_1s is not None and not btc_1s.empty:
|
||||
btc_data['BTC_1s'] = btc_1s
|
||||
|
||||
# Collect technical indicators
|
||||
technical_indicators = {}
|
||||
if '1h' in ohlcv_data:
|
||||
df = ohlcv_data['1h']
|
||||
if len(df) > 20:
|
||||
technical_indicators['sma_20'] = df['close'].rolling(20).mean().iloc[-1]
|
||||
technical_indicators['rsi'] = self._calculate_rsi(df['close'])
|
||||
# Get current prices
|
||||
eth_price = self.data_provider.get_current_price('ETH/USDT')
|
||||
btc_price = self.data_provider.get_current_price('BTC/USDT')
|
||||
|
||||
# Prepare CNN input
|
||||
cnn_input = self._prepare_cnn_input_data(ohlcv_data, cob_data, technical_indicators)
|
||||
|
||||
# Prepare RL input
|
||||
rl_input = self._prepare_rl_input_data(ohlcv_data, cob_data, technical_indicators)
|
||||
|
||||
# Prepare generic input
|
||||
generic_input = {
|
||||
'symbol': symbol,
|
||||
'current_price': current_price,
|
||||
'ohlcv_data': ohlcv_data,
|
||||
'cob_data': cob_data,
|
||||
'technical_indicators': technical_indicators
|
||||
}
|
||||
|
||||
input_data = {
|
||||
'cnn_input': cnn_input,
|
||||
'rl_input': rl_input,
|
||||
'generic_input': generic_input,
|
||||
# Create standardized input package
|
||||
standardized_input = {
|
||||
'timestamp': datetime.now(),
|
||||
'symbol': symbol
|
||||
'primary_symbol': 'ETH/USDT',
|
||||
'reference_symbol': 'BTC/USDT',
|
||||
'eth_data': eth_data,
|
||||
'btc_data': btc_data,
|
||||
'current_prices': {
|
||||
'ETH': eth_price,
|
||||
'BTC': btc_price
|
||||
},
|
||||
'data_completeness': {
|
||||
'eth_timeframes': len(eth_data),
|
||||
'btc_reference': len(btc_data),
|
||||
'total_expected': 5 # 4 ETH + 1 BTC
|
||||
}
|
||||
}
|
||||
|
||||
return input_data
|
||||
# Create model-specific input data
|
||||
model_inputs = {
|
||||
'cnn_input': standardized_input,
|
||||
'rl_input': standardized_input,
|
||||
'generic_input': standardized_input,
|
||||
'standardized_input': standardized_input
|
||||
}
|
||||
|
||||
return model_inputs
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error collecting model input data for {symbol}: {e}")
|
||||
logger.error(f"Error collecting standardized model input data: {e}")
|
||||
return {}
|
||||
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
|
||||
"""Prepare standardized input data for CNN models"""
|
||||
async def _store_inference_data_async(self, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime, symbol: str = None):
|
||||
"""Store inference data per-model with async file operations and memory optimization"""
|
||||
try:
|
||||
# Only log first few inference records to avoid spam
|
||||
if len(self.inference_history.get(model_name, [])) < 3:
|
||||
logger.debug(f"Storing inference data for {model_name}: {prediction.action} (confidence: {prediction.confidence:.3f})")
|
||||
|
||||
# Extract symbol from prediction if not provided
|
||||
if symbol is None:
|
||||
symbol = getattr(prediction, 'symbol', 'ETH/USDT') # Default to ETH/USDT if not available
|
||||
|
||||
# Create comprehensive inference record
|
||||
inference_record = {
|
||||
'timestamp': timestamp.isoformat(),
|
||||
'symbol': symbol,
|
||||
'model_name': model_name,
|
||||
'model_input': model_input,
|
||||
'prediction': {
|
||||
'action': prediction.action,
|
||||
'confidence': prediction.confidence,
|
||||
'probabilities': prediction.probabilities,
|
||||
'timeframe': prediction.timeframe
|
||||
},
|
||||
'metadata': prediction.metadata or {}
|
||||
}
|
||||
|
||||
# Store in memory (only last 5 per model)
|
||||
if model_name not in self.inference_history:
|
||||
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
|
||||
|
||||
self.inference_history[model_name].append(inference_record)
|
||||
|
||||
# Async file storage (don't wait for completion)
|
||||
asyncio.create_task(self._save_inference_to_disk_async(model_name, inference_record))
|
||||
|
||||
logger.debug(f"Stored inference data for {model_name} (memory: {len(self.inference_history[model_name])}/5)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error storing inference data for {model_name}: {e}")
|
||||
|
||||
async def _save_inference_to_disk_async(self, model_name: str, inference_record: Dict):
|
||||
"""Async save inference record to disk with file capping"""
|
||||
try:
|
||||
# Create model-specific directory
|
||||
model_dir = Path(f"training_data/inference_history/{model_name}")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create filename with timestamp
|
||||
timestamp_str = datetime.fromisoformat(inference_record['timestamp']).strftime('%Y%m%d_%H%M%S_%f')[:-3]
|
||||
filename = f"inference_{timestamp_str}.json"
|
||||
filepath = model_dir / filename
|
||||
|
||||
# Convert to JSON-serializable format
|
||||
serializable_record = self._make_json_serializable(inference_record)
|
||||
|
||||
# Save to file
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(serializable_record, f, indent=2)
|
||||
|
||||
# Cap files per model (keep only latest 200)
|
||||
await self._cap_model_files(model_dir)
|
||||
|
||||
logger.debug(f"Saved inference record to disk: {filepath}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving inference to disk for {model_name}: {e}")
|
||||
|
||||
async def _cap_model_files(self, model_dir: Path):
|
||||
"""Cap the number of files per model to max_disk_files_per_model"""
|
||||
try:
|
||||
# Get all inference files
|
||||
files = list(model_dir.glob("inference_*.json"))
|
||||
|
||||
if len(files) > self.max_disk_files_per_model:
|
||||
# Sort by modification time (oldest first)
|
||||
files.sort(key=lambda x: x.stat().st_mtime)
|
||||
|
||||
# Remove oldest files
|
||||
files_to_remove = files[:-self.max_disk_files_per_model]
|
||||
for file_path in files_to_remove:
|
||||
file_path.unlink()
|
||||
|
||||
logger.debug(f"Removed {len(files_to_remove)} old inference files from {model_dir.name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error capping model files in {model_dir}: {e}")
|
||||
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
|
||||
"""Prepare standardized input data for CNN models with proper GPU device placement"""
|
||||
try:
|
||||
# Create feature matrix from OHLCV data
|
||||
features = []
|
||||
@ -1168,16 +1321,18 @@ class TradingOrchestrator:
|
||||
feature_array = np.pad(feature_array, (0, 300 - len(feature_array)), 'constant')
|
||||
else:
|
||||
feature_array = feature_array[:300]
|
||||
return feature_array.reshape(1, -1)
|
||||
# Convert to tensor and move to GPU
|
||||
return torch.tensor(feature_array.reshape(1, -1), dtype=torch.float32, device=self.device)
|
||||
else:
|
||||
return np.zeros((1, 300))
|
||||
# Return zero tensor on GPU
|
||||
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing CNN input data: {e}")
|
||||
return np.zeros((1, 300))
|
||||
return torch.zeros((1, 300), dtype=torch.float32, device=self.device)
|
||||
|
||||
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> np.ndarray:
|
||||
"""Prepare standardized input data for RL models"""
|
||||
def _prepare_rl_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
|
||||
"""Prepare standardized input data for RL models with proper GPU device placement"""
|
||||
try:
|
||||
# Create state representation
|
||||
state_features = []
|
||||
@ -1205,13 +1360,15 @@ class TradingOrchestrator:
|
||||
state_array = np.pad(state_array, (0, expected_size - len(state_array)), 'constant')
|
||||
else:
|
||||
state_array = state_array[:expected_size]
|
||||
return state_array
|
||||
# Convert to tensor and move to GPU
|
||||
return torch.tensor(state_array, dtype=torch.float32, device=self.device)
|
||||
else:
|
||||
return np.zeros(100)
|
||||
# Return zero tensor on GPU
|
||||
return torch.zeros(100, dtype=torch.float32, device=self.device)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error preparing RL input data: {e}")
|
||||
return np.zeros(100)
|
||||
return torch.zeros(100, dtype=torch.float32, device=self.device)
|
||||
|
||||
def _store_inference_data(self, symbol: str, model_name: str, model_input: Any, prediction: Prediction, timestamp: datetime):
|
||||
"""Store comprehensive inference data for future training with persistent storage"""
|
||||
@ -1262,10 +1419,12 @@ class TradingOrchestrator:
|
||||
'outcome_evaluated': False
|
||||
}
|
||||
|
||||
# Store in memory (inference history)
|
||||
if symbol in self.inference_history:
|
||||
self.inference_history[symbol].append(inference_record)
|
||||
logger.debug(f"Stored inference data for {model_name} on {symbol}")
|
||||
# Store in memory (inference history) - keyed by model_name
|
||||
if model_name not in self.inference_history:
|
||||
self.inference_history[model_name] = deque(maxlen=self.max_memory_inferences)
|
||||
|
||||
self.inference_history[model_name].append(inference_record)
|
||||
logger.debug(f"Stored inference data for {model_name} on {symbol}")
|
||||
|
||||
# Persistent storage to disk (for long-term training data)
|
||||
self._save_inference_to_disk(inference_record)
|
||||
@ -1350,6 +1509,35 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error loading inference history from disk: {e}")
|
||||
return []
|
||||
|
||||
async def load_model_inference_history(self, model_name: str, limit: int = 50) -> List[Dict]:
|
||||
"""Load inference history for a specific model from disk"""
|
||||
try:
|
||||
model_dir = Path(f"training_data/inference_history/{model_name}")
|
||||
if not model_dir.exists():
|
||||
return []
|
||||
|
||||
# Get all inference files
|
||||
files = list(model_dir.glob("inference_*.json"))
|
||||
files.sort(key=lambda x: x.stat().st_mtime, reverse=True) # Newest first
|
||||
|
||||
# Load up to 'limit' files
|
||||
inference_records = []
|
||||
for filepath in files[:limit]:
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
record = json.load(f)
|
||||
inference_records.append(record)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading inference file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loaded {len(inference_records)} inference records for {model_name}")
|
||||
return inference_records
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model inference history for {model_name}: {e}")
|
||||
return []
|
||||
|
||||
def get_model_training_data(self, model_name: str, symbol: str = None) -> List[Dict]:
|
||||
"""Get training data for a specific model"""
|
||||
try:
|
||||
@ -1395,12 +1583,28 @@ class TradingOrchestrator:
|
||||
async def _trigger_model_training(self, symbol: str):
|
||||
"""Trigger training for models based on previous inference data"""
|
||||
try:
|
||||
if not self.training_enabled or symbol not in self.inference_history:
|
||||
if not self.training_enabled:
|
||||
logger.debug("Training disabled, skipping model training")
|
||||
return
|
||||
|
||||
# Get recent inference records
|
||||
recent_records = list(self.inference_history[symbol])
|
||||
if len(recent_records) < 2:
|
||||
# Check if we have any inference history for any model
|
||||
if not self.inference_history:
|
||||
logger.debug("No inference history available for training")
|
||||
return
|
||||
|
||||
# Get recent inference records from all models (not symbol-based)
|
||||
all_recent_records = []
|
||||
for model_name, model_records in self.inference_history.items():
|
||||
all_recent_records.extend(list(model_records))
|
||||
|
||||
# Only log if we have few records (for debugging)
|
||||
if len(all_recent_records) < 5:
|
||||
logger.debug(f"Total inference records for training: {len(all_recent_records)}")
|
||||
for model_name, model_records in self.inference_history.items():
|
||||
logger.debug(f" Model {model_name} has {len(model_records)} inference records")
|
||||
|
||||
if len(all_recent_records) < 2:
|
||||
logger.debug("Not enough inference records for training")
|
||||
return # Need at least 2 records to compare
|
||||
|
||||
# Get current price for outcome evaluation
|
||||
@ -1408,12 +1612,11 @@ class TradingOrchestrator:
|
||||
if current_price is None:
|
||||
return
|
||||
|
||||
# Process records that are old enough to evaluate outcomes
|
||||
cutoff_time = datetime.now() - timedelta(minutes=5) # 5 minutes ago
|
||||
|
||||
for record in recent_records:
|
||||
if record['timestamp'] < cutoff_time:
|
||||
await self._evaluate_and_train_on_record(record, current_price)
|
||||
# Train on the most recent inference record (last prediction made)
|
||||
if all_recent_records:
|
||||
# Get the most recent record for training
|
||||
most_recent_record = max(all_recent_records, key=lambda x: datetime.fromisoformat(x['timestamp']) if isinstance(x['timestamp'], str) else x['timestamp'])
|
||||
await self._evaluate_and_train_on_record(most_recent_record, current_price)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering model training for {symbol}: {e}")
|
||||
@ -1425,6 +1628,10 @@ class TradingOrchestrator:
|
||||
prediction = record['prediction']
|
||||
timestamp = record['timestamp']
|
||||
|
||||
# Convert timestamp string back to datetime if needed
|
||||
if isinstance(timestamp, str):
|
||||
timestamp = datetime.fromisoformat(timestamp)
|
||||
|
||||
# Calculate price change since prediction
|
||||
# This is a simplified outcome evaluation - you might want to make it more sophisticated
|
||||
time_diff = (datetime.now() - timestamp).total_seconds() / 60 # minutes
|
||||
@ -1495,12 +1702,23 @@ class TradingOrchestrator:
|
||||
)
|
||||
logger.debug(f"Added RL training experience: reward={reward}")
|
||||
|
||||
# Train CNN models
|
||||
elif 'cnn' in model_name.lower() and self.cnn_model:
|
||||
if hasattr(self.cnn_model, 'train_on_outcome'):
|
||||
target = 1 if was_correct else 0
|
||||
self.cnn_model.train_on_outcome(model_input, target)
|
||||
logger.debug(f"Trained CNN on outcome: target={target}")
|
||||
# Train CNN models using adapter
|
||||
elif 'cnn' in model_name.lower() and hasattr(self, 'cnn_adapter') and self.cnn_adapter:
|
||||
# Use the adapter's add_training_sample method
|
||||
actual_action = prediction['action']
|
||||
self.cnn_adapter.add_training_sample(record['symbol'], actual_action, reward)
|
||||
logger.debug(f"Added CNN training sample: action={actual_action}, reward={reward}")
|
||||
|
||||
# Trigger training if we have enough samples
|
||||
if len(self.cnn_adapter.training_data) >= self.cnn_adapter.batch_size:
|
||||
training_results = self.cnn_adapter.train(epochs=1)
|
||||
logger.debug(f"CNN training results: {training_results}")
|
||||
|
||||
# Fallback for raw CNN model
|
||||
elif 'cnn' in model_name.lower() and self.cnn_model and hasattr(self.cnn_model, 'train_on_outcome'):
|
||||
target = 1 if was_correct else 0
|
||||
self.cnn_model.train_on_outcome(model_input, target)
|
||||
logger.debug(f"Trained CNN on outcome: target={target}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error training model on outcome: {e}")
|
||||
@ -2147,8 +2365,8 @@ class TradingOrchestrator:
|
||||
return
|
||||
|
||||
if not ENHANCED_TRAINING_AVAILABLE:
|
||||
logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
|
||||
self.training_enabled = False
|
||||
logger.info("EnhancedRealtimeTrainingSystem not available - using built-in training")
|
||||
# Keep training enabled - we have built-in training capabilities
|
||||
return
|
||||
|
||||
# Initialize the enhanced training system
|
||||
|
@ -451,3 +451,35 @@ class StandardizedDataProvider(DataProvider):
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping real-time processing: {e}")
|
||||
|
||||
def get_recent_prices(self, symbol: str, limit: int = 10) -> List[float]:
|
||||
"""
|
||||
Get recent prices for a symbol
|
||||
|
||||
Args:
|
||||
symbol: Trading symbol
|
||||
limit: Number of recent prices to return
|
||||
|
||||
Returns:
|
||||
List[float]: List of recent prices
|
||||
"""
|
||||
try:
|
||||
# Get recent OHLCV data using parent class method
|
||||
df = self.get_historical_data(symbol, '1m', limit)
|
||||
if df is None or df.empty:
|
||||
return []
|
||||
|
||||
# Extract close prices from DataFrame
|
||||
if 'close' in df.columns:
|
||||
prices = df['close'].tolist()
|
||||
return prices[-limit:] # Return most recent prices
|
||||
else:
|
||||
logger.warning(f"No 'close' column found in OHLCV data for {symbol}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting recent prices for {symbol}: {e}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error stopping real-time processing: {e}")
|
@ -16,6 +16,18 @@ ETH:
|
||||
300s of 1s OHLCV data (5 min)
|
||||
300 OHLCV + indicatros bars of each 1m 1h 1d and 1s BTC
|
||||
|
||||
so:
|
||||
|
||||
# Standardized input for all models:
|
||||
{
|
||||
'primary_symbol': 'ETH/USDT',
|
||||
'reference_symbol': 'BTC/USDT',
|
||||
'eth_data': {'ETH_1s': df, 'ETH_1m': df, 'ETH_1h': df, 'ETH_1d': df},
|
||||
'btc_data': {'BTC_1s': df},
|
||||
'current_prices': {'ETH': price, 'BTC': price},
|
||||
'data_completeness': {...}
|
||||
}
|
||||
|
||||
RL model should have also access of the last hidden layers of the CNN model where patterns are learned. it can be empty if CNN model is not active or missing. as well as the output (predictions) of the CNN model for each timeframe (1s 1m 1h 1d) and next expected pivot point
|
||||
|
||||
## CNN model
|
||||
|
141
test_device_fix.py
Normal file
141
test_device_fix.py
Normal file
@ -0,0 +1,141 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify device mismatch fixes for GPU training
|
||||
"""
|
||||
|
||||
import torch
|
||||
import logging
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add the project root to the path
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
from core.data_models import BaseDataInput, OHLCVBar
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_device_consistency():
|
||||
"""Test that all tensors are on the same device"""
|
||||
|
||||
logger.info("Testing device consistency for EnhancedCNN...")
|
||||
|
||||
# Check if CUDA is available
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
try:
|
||||
# Initialize the adapter
|
||||
adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
|
||||
# Verify adapter device
|
||||
logger.info(f"Adapter device: {adapter.device}")
|
||||
logger.info(f"Model device: {next(adapter.model.parameters()).device}")
|
||||
|
||||
# Create sample data
|
||||
sample_ohlcv = [
|
||||
OHLCVBar(
|
||||
symbol="ETH/USDT",
|
||||
timeframe="1s",
|
||||
timestamp=1640995200.0, # 2022-01-01
|
||||
open=50000.0,
|
||||
high=51000.0,
|
||||
low=49000.0,
|
||||
close=50500.0,
|
||||
volume=1000.0
|
||||
)
|
||||
] * 300 # 300 frames
|
||||
|
||||
base_data = BaseDataInput(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=1640995200.0,
|
||||
ohlcv_1s=sample_ohlcv,
|
||||
ohlcv_1m=sample_ohlcv,
|
||||
ohlcv_5m=sample_ohlcv,
|
||||
ohlcv_15m=sample_ohlcv,
|
||||
btc_ohlcv=sample_ohlcv,
|
||||
cob_data={},
|
||||
ma_data={},
|
||||
technical_indicators={},
|
||||
last_predictions={}
|
||||
)
|
||||
|
||||
# Test prediction
|
||||
logger.info("Testing prediction...")
|
||||
prediction = adapter.predict(base_data)
|
||||
logger.info(f"Prediction successful: {prediction.predictions['action']} (confidence: {prediction.confidence:.3f})")
|
||||
|
||||
# Test training sample addition
|
||||
logger.info("Testing training sample addition...")
|
||||
adapter.add_training_sample(base_data, "BUY", 0.1)
|
||||
adapter.add_training_sample(base_data, "SELL", -0.05)
|
||||
adapter.add_training_sample(base_data, "HOLD", 0.02)
|
||||
|
||||
# Test training
|
||||
logger.info("Testing training...")
|
||||
training_results = adapter.train(epochs=1)
|
||||
logger.info(f"Training results: {training_results}")
|
||||
|
||||
logger.info("✅ All device consistency tests passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Device consistency test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def test_orchestrator_inference_history():
|
||||
"""Test that orchestrator properly initializes inference history"""
|
||||
|
||||
logger.info("Testing orchestrator inference history initialization...")
|
||||
|
||||
try:
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.data_provider import DataProvider
|
||||
|
||||
# Initialize orchestrator
|
||||
data_provider = DataProvider()
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Check if inference history is initialized
|
||||
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
|
||||
|
||||
# Check if models are registered
|
||||
logger.info(f"Registered models: {list(orchestrator.model_registry.models.keys())}")
|
||||
|
||||
# Verify each registered model has inference history
|
||||
for model_name in orchestrator.model_registry.models.keys():
|
||||
if model_name in orchestrator.inference_history:
|
||||
logger.info(f"✅ {model_name} has inference history initialized")
|
||||
else:
|
||||
logger.warning(f"❌ {model_name} missing inference history")
|
||||
|
||||
logger.info("✅ Orchestrator inference history test completed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Orchestrator test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Starting device fix verification tests...")
|
||||
|
||||
# Test 1: Device consistency
|
||||
test1_passed = test_device_consistency()
|
||||
|
||||
# Test 2: Orchestrator inference history
|
||||
test2_passed = test_orchestrator_inference_history()
|
||||
|
||||
# Summary
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All tests passed! Device issues should be fixed.")
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please check the logs above.")
|
||||
sys.exit(1)
|
153
test_device_training_fix.py
Normal file
153
test_device_training_fix.py
Normal file
@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script to verify device handling and training sample population fixes
|
||||
"""
|
||||
|
||||
import logging
|
||||
import asyncio
|
||||
import torch
|
||||
from datetime import datetime
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def test_device_handling():
|
||||
"""Test that device handling is working correctly"""
|
||||
try:
|
||||
logger.info("Testing device handling...")
|
||||
|
||||
# Test 1: Check CUDA availability
|
||||
cuda_available = torch.cuda.is_available()
|
||||
device = torch.device("cuda" if cuda_available else "cpu")
|
||||
logger.info(f"CUDA available: {cuda_available}")
|
||||
logger.info(f"Using device: {device}")
|
||||
|
||||
# Test 2: Initialize CNN adapter
|
||||
from core.enhanced_cnn_adapter import EnhancedCNNAdapter
|
||||
|
||||
logger.info("Initializing CNN adapter...")
|
||||
cnn_adapter = EnhancedCNNAdapter(checkpoint_dir="models/enhanced_cnn")
|
||||
|
||||
logger.info(f"CNN adapter device: {cnn_adapter.device}")
|
||||
logger.info(f"CNN model device: {cnn_adapter.model.device}")
|
||||
|
||||
# Test 3: Create test data
|
||||
from core.data_models import BaseDataInput
|
||||
|
||||
logger.info("Creating test BaseDataInput...")
|
||||
base_data = BaseDataInput(
|
||||
symbol="ETH/USDT",
|
||||
timestamp=datetime.now(),
|
||||
ohlcv_1s=[],
|
||||
ohlcv_1m=[],
|
||||
ohlcv_1h=[],
|
||||
ohlcv_1d=[],
|
||||
btc_ohlcv_1s=[],
|
||||
cob_data=None,
|
||||
technical_indicators={},
|
||||
last_predictions={}
|
||||
)
|
||||
|
||||
# Test 4: Make prediction (this should not cause device mismatch)
|
||||
logger.info("Making prediction...")
|
||||
prediction = cnn_adapter.predict(base_data)
|
||||
|
||||
logger.info(f"Prediction successful: {prediction.predictions['action']}")
|
||||
logger.info(f"Confidence: {prediction.confidence:.4f}")
|
||||
|
||||
# Test 5: Add training samples
|
||||
logger.info("Adding training samples...")
|
||||
cnn_adapter.add_training_sample(base_data, "BUY", 0.1)
|
||||
cnn_adapter.add_training_sample(base_data, "SELL", -0.05)
|
||||
cnn_adapter.add_training_sample(base_data, "HOLD", 0.02)
|
||||
|
||||
logger.info(f"Training samples added: {len(cnn_adapter.training_data)}")
|
||||
|
||||
# Test 6: Try training if we have enough samples
|
||||
if len(cnn_adapter.training_data) >= 2:
|
||||
logger.info("Attempting training...")
|
||||
training_results = cnn_adapter.train(epochs=1)
|
||||
logger.info(f"Training results: {training_results}")
|
||||
else:
|
||||
logger.info("Not enough samples for training")
|
||||
|
||||
logger.info("✅ Device handling test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Device handling test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def test_orchestrator_training():
|
||||
"""Test that orchestrator properly adds training samples"""
|
||||
try:
|
||||
logger.info("Testing orchestrator training integration...")
|
||||
|
||||
# Test 1: Initialize orchestrator
|
||||
from core.orchestrator import TradingOrchestrator
|
||||
from core.standardized_data_provider import StandardizedDataProvider
|
||||
|
||||
logger.info("Initializing data provider...")
|
||||
data_provider = StandardizedDataProvider()
|
||||
|
||||
logger.info("Initializing orchestrator...")
|
||||
orchestrator = TradingOrchestrator(data_provider=data_provider)
|
||||
|
||||
# Test 2: Check if CNN adapter is available
|
||||
if hasattr(orchestrator, 'cnn_adapter') and orchestrator.cnn_adapter:
|
||||
logger.info(f"✅ CNN adapter available in orchestrator")
|
||||
logger.info(f"Initial training samples: {len(orchestrator.cnn_adapter.training_data)}")
|
||||
else:
|
||||
logger.warning("⚠️ CNN adapter not available in orchestrator")
|
||||
return False
|
||||
|
||||
# Test 3: Make a trading decision (this should add training samples)
|
||||
logger.info("Making trading decision...")
|
||||
decision = await orchestrator.make_trading_decision("ETH/USDT")
|
||||
|
||||
if decision:
|
||||
logger.info(f"Decision: {decision.action} (confidence: {decision.confidence:.4f})")
|
||||
logger.info(f"Training samples after decision: {len(orchestrator.cnn_adapter.training_data)}")
|
||||
else:
|
||||
logger.warning("No decision made")
|
||||
|
||||
# Test 4: Check inference history
|
||||
logger.info(f"Inference history keys: {list(orchestrator.inference_history.keys())}")
|
||||
for model_name, history in orchestrator.inference_history.items():
|
||||
logger.info(f" {model_name}: {len(history)} records")
|
||||
|
||||
logger.info("✅ Orchestrator training test passed!")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Orchestrator training test failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
async def main():
|
||||
"""Run all tests"""
|
||||
logger.info("Starting device and training fix tests...")
|
||||
|
||||
# Test 1: Device handling
|
||||
test1_passed = test_device_handling()
|
||||
|
||||
# Test 2: Orchestrator training
|
||||
test2_passed = await test_orchestrator_training()
|
||||
|
||||
# Summary
|
||||
logger.info("\n" + "="*50)
|
||||
logger.info("TEST SUMMARY:")
|
||||
logger.info(f"Device handling: {'✅ PASSED' if test1_passed else '❌ FAILED'}")
|
||||
logger.info(f"Orchestrator training: {'✅ PASSED' if test2_passed else '❌ FAILED'}")
|
||||
|
||||
if test1_passed and test2_passed:
|
||||
logger.info("🎉 All tests passed! Device and training issues should be fixed.")
|
||||
else:
|
||||
logger.error("❌ Some tests failed. Please check the logs above.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
Reference in New Issue
Block a user