cleanup_1

Author: Dobromir Popov
Date: 2025-05-24 02:01:07 +03:00
parent f01047f260
commit 509ad0ae17
11 changed files with 1926 additions and 19 deletions

core/__init__.py (Normal file, 0 lines)

core/config.py (Normal file, 239 lines)

@@ -0,0 +1,239 @@
"""
Central Configuration Management
This module handles all configuration for the trading system.
It loads settings from config.yaml and provides easy access to all components.
"""
import os
import yaml
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional
logger = logging.getLogger(__name__)
class Config:
"""Central configuration management for the trading system"""
def __init__(self, config_path: str = "config.yaml"):
"""Initialize configuration from YAML file"""
self.config_path = Path(config_path)
self._config = self._load_config()
self._setup_directories()
def _load_config(self) -> Dict[str, Any]:
"""Load configuration from YAML file"""
try:
if not self.config_path.exists():
logger.warning(f"Config file {self.config_path} not found, using defaults")
return self._get_default_config()
with open(self.config_path, 'r') as f:
config = yaml.safe_load(f)
logger.info(f"Loaded configuration from {self.config_path}")
return config
except Exception as e:
logger.error(f"Error loading config: {e}")
logger.info("Using default configuration")
return self._get_default_config()
def _get_default_config(self) -> Dict[str, Any]:
"""Get default configuration if file is missing"""
return {
'symbols': ['ETH/USDT', 'BTC/USDT'],
'timeframes': ['1m', '5m', '15m', '1h', '4h', '1d'],
'data': {
'provider': 'binance',
'cache_enabled': True,
'cache_dir': 'cache',
'historical_limit': 1000,
'real_time_enabled': True,
'websocket_reconnect': True
},
'cnn': {
'window_size': 20,
'features': ['open', 'high', 'low', 'close', 'volume'],
'hidden_layers': [64, 32, 16],
'dropout': 0.2,
'learning_rate': 0.001,
'batch_size': 32,
'epochs': 100,
'confidence_threshold': 0.6
},
'rl': {
'state_size': 100,
'action_space': 3,
'epsilon': 1.0,
'epsilon_decay': 0.995,
'epsilon_min': 0.01,
'learning_rate': 0.0001,
'gamma': 0.99,
'memory_size': 10000,
'batch_size': 64,
'target_update_freq': 1000
},
'orchestrator': {
'cnn_weight': 0.7,
'rl_weight': 0.3,
'confidence_threshold': 0.5,
'decision_frequency': 60
},
'trading': {
'max_position_size': 0.1,
'stop_loss': 0.02,
'take_profit': 0.05,
'trading_fee': 0.0002,
'min_trade_interval': 60
},
'web': {
'host': '127.0.0.1',
'port': 8050,
'debug': False,
'update_interval': 1000,
'chart_history': 100
},
'logging': {
'level': 'INFO',
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
'file': 'logs/trading.log',
'max_size': 10485760,
'backup_count': 5
},
'performance': {
'use_gpu': True,
'mixed_precision': True,
'num_workers': 4,
'batch_size_multiplier': 1.0
},
'paths': {
'models': 'models',
'data': 'data',
'logs': 'logs',
'cache': 'cache',
'plots': 'plots'
}
}
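    # config.yaml mirrors the dict above one-to-one; an illustrative fragment
    # (keys taken from the defaults, values are examples):
    #   symbols:
    #     - ETH/USDT
    #     - BTC/USDT
    #   cnn:
    #     window_size: 20
    #     learning_rate: 0.001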
def _setup_directories(self):
"""Create necessary directories"""
try:
paths = self._config.get('paths', {})
for path_name, path_value in paths.items():
Path(path_value).mkdir(parents=True, exist_ok=True)
# Also create specific model subdirectories
models_dir = Path(paths.get('models', 'models'))
(models_dir / 'cnn' / 'saved').mkdir(parents=True, exist_ok=True)
(models_dir / 'rl' / 'saved').mkdir(parents=True, exist_ok=True)
except Exception as e:
logger.error(f"Error creating directories: {e}")
# Property accessors for easy access
@property
def symbols(self) -> List[str]:
"""Get list of trading symbols"""
return self._config.get('symbols', ['ETH/USDT'])
@property
def timeframes(self) -> List[str]:
"""Get list of timeframes"""
return self._config.get('timeframes', ['1m', '5m', '1h'])
@property
def data(self) -> Dict[str, Any]:
"""Get data provider settings"""
return self._config.get('data', {})
@property
def cnn(self) -> Dict[str, Any]:
"""Get CNN model settings"""
return self._config.get('cnn', {})
@property
def rl(self) -> Dict[str, Any]:
"""Get RL agent settings"""
return self._config.get('rl', {})
@property
def orchestrator(self) -> Dict[str, Any]:
"""Get orchestrator settings"""
return self._config.get('orchestrator', {})
@property
def trading(self) -> Dict[str, Any]:
"""Get trading execution settings"""
return self._config.get('trading', {})
@property
def web(self) -> Dict[str, Any]:
"""Get web dashboard settings"""
return self._config.get('web', {})
@property
def logging(self) -> Dict[str, Any]:
"""Get logging settings"""
return self._config.get('logging', {})
@property
def performance(self) -> Dict[str, Any]:
"""Get performance settings"""
return self._config.get('performance', {})
@property
def paths(self) -> Dict[str, str]:
"""Get file paths"""
return self._config.get('paths', {})
def get(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key with optional default"""
return self._config.get(key, default)
def update(self, key: str, value: Any):
"""Update configuration value"""
self._config[key] = value
def save(self):
"""Save current configuration back to file"""
try:
with open(self.config_path, 'w') as f:
yaml.dump(self._config, f, default_flow_style=False, indent=2)
logger.info(f"Configuration saved to {self.config_path}")
except Exception as e:
logger.error(f"Error saving configuration: {e}")
# Global configuration instance
_config_instance = None
def get_config(config_path: str = "config.yaml") -> Config:
"""Get global configuration instance (singleton pattern)"""
global _config_instance
if _config_instance is None:
_config_instance = Config(config_path)
return _config_instance
def setup_logging(config: Optional[Config] = None):
    """Set up logging based on configuration"""
if config is None:
config = get_config()
log_config = config.logging
# Create logs directory
log_file = Path(log_config.get('file', 'logs/trading.log'))
log_file.parent.mkdir(parents=True, exist_ok=True)
# Setup logging
logging.basicConfig(
level=getattr(logging, log_config.get('level', 'INFO')),
format=log_config.get('format', '%(asctime)s - %(name)s - %(levelname)s - %(message)s'),
handlers=[
logging.FileHandler(log_file),
logging.StreamHandler()
]
)
logger.info("Logging configured successfully")

core/data_provider.py (Normal file, 437 lines)

@@ -0,0 +1,437 @@
"""
Multi-Timeframe, Multi-Symbol Data Provider
This module consolidates all data functionality including:
- Historical data fetching from Binance API
- Real-time data streaming via WebSocket
- Multi-timeframe candle generation
- Caching and data management
- Technical indicators calculation
"""
import asyncio
import json
import logging
import os
import time
import websockets
import requests
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import ta
from threading import Thread, Lock
from collections import deque
from .config import get_config
logger = logging.getLogger(__name__)
class DataProvider:
"""Unified data provider for historical and real-time market data"""
def __init__(self, symbols: List[str] = None, timeframes: List[str] = None):
"""Initialize the data provider"""
self.config = get_config()
self.symbols = symbols or self.config.symbols
self.timeframes = timeframes or self.config.timeframes
# Data storage
self.historical_data = {} # {symbol: {timeframe: DataFrame}}
self.real_time_data = {} # {symbol: {timeframe: deque}}
self.current_prices = {} # {symbol: float}
# Real-time processing
self.websocket_tasks = {}
self.is_streaming = False
self.data_lock = Lock()
# Cache settings
self.cache_enabled = self.config.data.get('cache_enabled', True)
self.cache_dir = Path(self.config.data.get('cache_dir', 'cache'))
self.cache_dir.mkdir(parents=True, exist_ok=True)
# Timeframe conversion
self.timeframe_seconds = {
'1m': 60, '5m': 300, '15m': 900, '30m': 1800,
'1h': 3600, '4h': 14400, '1d': 86400
}
logger.info(f"DataProvider initialized for symbols: {self.symbols}")
logger.info(f"Timeframes: {self.timeframes}")
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
refresh: bool = False) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data for a symbol and timeframe"""
try:
# Check cache first
if not refresh and self.cache_enabled:
cached_data = self._load_from_cache(symbol, timeframe)
if cached_data is not None and len(cached_data) >= limit * 0.8:
logger.info(f"Using cached data for {symbol} {timeframe}")
return cached_data.tail(limit)
# Fetch from API
logger.info(f"Fetching historical data for {symbol} {timeframe}")
df = self._fetch_from_binance(symbol, timeframe, limit)
if df is not None and not df.empty:
# Add technical indicators
df = self._add_technical_indicators(df)
# Cache the data
if self.cache_enabled:
self._save_to_cache(df, symbol, timeframe)
# Store in memory
if symbol not in self.historical_data:
self.historical_data[symbol] = {}
self.historical_data[symbol][timeframe] = df
return df
logger.warning(f"No data received for {symbol} {timeframe}")
return None
except Exception as e:
logger.error(f"Error fetching historical data for {symbol} {timeframe}: {e}")
return None
def _fetch_from_binance(self, symbol: str, timeframe: str, limit: int) -> Optional[pd.DataFrame]:
"""Fetch data from Binance API"""
try:
# Convert symbol format
binance_symbol = symbol.replace('/', '').upper()
# Convert timeframe
timeframe_map = {
'1m': '1m', '5m': '5m', '15m': '15m', '30m': '30m',
'1h': '1h', '4h': '4h', '1d': '1d'
}
binance_timeframe = timeframe_map.get(timeframe, '1h')
# API request
url = "https://api.binance.com/api/v3/klines"
params = {
'symbol': binance_symbol,
'interval': binance_timeframe,
'limit': limit
}
response = requests.get(url, params=params)
response.raise_for_status()
data = response.json()
# Convert to DataFrame
df = pd.DataFrame(data, columns=[
'timestamp', 'open', 'high', 'low', 'close', 'volume',
'close_time', 'quote_volume', 'trades', 'taker_buy_base',
'taker_buy_quote', 'ignore'
])
# Process columns
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
for col in ['open', 'high', 'low', 'close', 'volume']:
df[col] = df[col].astype(float)
# Keep only OHLCV columns
df = df[['timestamp', 'open', 'high', 'low', 'close', 'volume']]
df = df.sort_values('timestamp').reset_index(drop=True)
logger.info(f"Fetched {len(df)} candles for {symbol} {timeframe}")
return df
except Exception as e:
logger.error(f"Error fetching from Binance API: {e}")
return None
def _add_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
"""Add technical indicators to the DataFrame"""
try:
df = df.copy()
# Moving averages
df['sma_20'] = ta.trend.sma_indicator(df['close'], window=20)
df['sma_50'] = ta.trend.sma_indicator(df['close'], window=50)
df['ema_12'] = ta.trend.ema_indicator(df['close'], window=12)
df['ema_26'] = ta.trend.ema_indicator(df['close'], window=26)
# MACD
macd = ta.trend.MACD(df['close'])
df['macd'] = macd.macd()
df['macd_signal'] = macd.macd_signal()
df['macd_histogram'] = macd.macd_diff()
# RSI
df['rsi'] = ta.momentum.rsi(df['close'], window=14)
# Bollinger Bands
bollinger = ta.volatility.BollingerBands(df['close'])
df['bb_upper'] = bollinger.bollinger_hband()
df['bb_lower'] = bollinger.bollinger_lband()
df['bb_middle'] = bollinger.bollinger_mavg()
# Volume moving average (simple rolling mean since ta.volume.volume_sma doesn't exist)
df['volume_sma'] = df['volume'].rolling(window=20).mean()
# Fill NaN values
df = df.bfill().fillna(0)
return df
except Exception as e:
logger.error(f"Error adding technical indicators: {e}")
return df
def _load_from_cache(self, symbol: str, timeframe: str) -> Optional[pd.DataFrame]:
"""Load data from cache"""
try:
cache_file = self.cache_dir / f"{symbol.replace('/', '')}_{timeframe}.parquet"
if cache_file.exists():
# Check if cache is recent (less than 1 hour old)
cache_age = time.time() - cache_file.stat().st_mtime
if cache_age < 3600: # 1 hour
df = pd.read_parquet(cache_file)
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe}")
return df
else:
logger.debug(f"Cache for {symbol} {timeframe} is too old ({cache_age/3600:.1f}h)")
return None
except Exception as e:
logger.warning(f"Error loading cache for {symbol} {timeframe}: {e}")
return None
def _save_to_cache(self, df: pd.DataFrame, symbol: str, timeframe: str):
"""Save data to cache"""
try:
cache_file = self.cache_dir / f"{symbol.replace('/', '')}_{timeframe}.parquet"
df.to_parquet(cache_file, index=False)
logger.debug(f"Saved {len(df)} rows to cache for {symbol} {timeframe}")
except Exception as e:
logger.warning(f"Error saving cache for {symbol} {timeframe}: {e}")
async def start_real_time_streaming(self):
"""Start real-time data streaming for all symbols"""
if self.is_streaming:
logger.warning("Real-time streaming already active")
return
self.is_streaming = True
logger.info("Starting real-time data streaming")
# Start WebSocket for each symbol
for symbol in self.symbols:
task = asyncio.create_task(self._websocket_stream(symbol))
self.websocket_tasks[symbol] = task
async def stop_real_time_streaming(self):
"""Stop real-time data streaming"""
if not self.is_streaming:
return
logger.info("Stopping real-time data streaming")
self.is_streaming = False
# Cancel all WebSocket tasks
for symbol, task in self.websocket_tasks.items():
if not task.done():
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
self.websocket_tasks.clear()
async def _websocket_stream(self, symbol: str):
"""WebSocket stream for a single symbol"""
binance_symbol = symbol.replace('/', '').lower()
url = f"wss://stream.binance.com:9443/ws/{binance_symbol}@ticker"
while self.is_streaming:
try:
async with websockets.connect(url) as websocket:
logger.info(f"WebSocket connected for {symbol}")
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_tick(symbol, data)
except Exception as e:
logger.warning(f"Error processing tick for {symbol}: {e}")
except Exception as e:
logger.error(f"WebSocket error for {symbol}: {e}")
if self.is_streaming:
logger.info(f"Reconnecting WebSocket for {symbol} in 5 seconds...")
await asyncio.sleep(5)
async def _process_tick(self, symbol: str, tick_data: Dict):
"""Process a single tick and update candles"""
try:
price = float(tick_data.get('c', 0)) # Current price
volume = float(tick_data.get('v', 0)) # 24h Volume
timestamp = pd.Timestamp.now()
# Update current price
with self.data_lock:
self.current_prices[symbol] = price
# Initialize real-time data structure if needed
if symbol not in self.real_time_data:
self.real_time_data[symbol] = {}
for tf in self.timeframes:
self.real_time_data[symbol][tf] = deque(maxlen=1000)
# Create tick record
tick = {
'timestamp': timestamp,
'price': price,
'volume': volume
}
# Update all timeframes
for timeframe in self.timeframes:
self._update_candle(symbol, timeframe, tick)
except Exception as e:
logger.error(f"Error processing tick for {symbol}: {e}")
def _update_candle(self, symbol: str, timeframe: str, tick: Dict):
"""Update candle for specific timeframe"""
try:
timeframe_secs = self.timeframe_seconds.get(timeframe, 3600)
current_time = tick['timestamp']
# Calculate candle start time
candle_start = current_time.floor(f'{timeframe_secs}s')
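            # e.g. a 12:07:42 tick on the 5m timeframe (300s) floors to the 12:05:00
            # bucket, so it updates that candle rather than opening a new one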
# Get current candle queue
candle_queue = self.real_time_data[symbol][timeframe]
# Check if we need a new candle
if not candle_queue or candle_queue[-1]['timestamp'] != candle_start:
# Create new candle
new_candle = {
'timestamp': candle_start,
'open': tick['price'],
'high': tick['price'],
'low': tick['price'],
'close': tick['price'],
'volume': tick['volume']
}
candle_queue.append(new_candle)
else:
# Update existing candle
current_candle = candle_queue[-1]
current_candle['high'] = max(current_candle['high'], tick['price'])
current_candle['low'] = min(current_candle['low'], tick['price'])
current_candle['close'] = tick['price']
                # 'v' from the ticker stream is a rolling 24h volume, not a per-tick
                # quantity, so keep the latest snapshot instead of accumulating it
                current_candle['volume'] = tick['volume']
except Exception as e:
logger.error(f"Error updating candle for {symbol} {timeframe}: {e}")
def get_latest_candles(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
"""Get the latest candles combining historical and real-time data"""
try:
# Get historical data
historical_df = self.get_historical_data(symbol, timeframe, limit=limit)
# Get real-time data
with self.data_lock:
if symbol in self.real_time_data and timeframe in self.real_time_data[symbol]:
real_time_candles = list(self.real_time_data[symbol][timeframe])
if real_time_candles:
# Convert to DataFrame
rt_df = pd.DataFrame(real_time_candles)
if historical_df is not None:
# Combine historical and real-time
# Remove overlapping candles from historical data
if not rt_df.empty:
cutoff_time = rt_df['timestamp'].min()
historical_df = historical_df[historical_df['timestamp'] < cutoff_time]
# Concatenate
combined_df = pd.concat([historical_df, rt_df], ignore_index=True)
else:
combined_df = rt_df
return combined_df.tail(limit)
# Return just historical data if no real-time data
return historical_df.tail(limit) if historical_df is not None else pd.DataFrame()
except Exception as e:
logger.error(f"Error getting latest candles for {symbol} {timeframe}: {e}")
return pd.DataFrame()
def get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price for a symbol"""
with self.data_lock:
return self.current_prices.get(symbol)
def get_feature_matrix(self, symbol: str, timeframes: List[str] = None,
window_size: int = 20) -> Optional[np.ndarray]:
"""Get feature matrix for multiple timeframes"""
try:
if timeframes is None:
timeframes = self.timeframes
features = []
for tf in timeframes:
df = self.get_latest_candles(symbol, tf, limit=window_size + 50)
if df is not None and len(df) >= window_size:
# Select feature columns
feature_cols = ['open', 'high', 'low', 'close', 'volume']
if 'sma_20' in df.columns:
feature_cols.extend(['sma_20', 'rsi', 'macd'])
# Get the latest window
tf_features = df[feature_cols].tail(window_size).values
features.append(tf_features)
else:
logger.warning(f"Insufficient data for {symbol} {tf}")
return None
if features:
# Stack features from all timeframes
return np.stack(features, axis=0) # Shape: (n_timeframes, window_size, n_features)
return None
except Exception as e:
logger.error(f"Error creating feature matrix for {symbol}: {e}")
return None
def health_check(self) -> Dict[str, Any]:
"""Get health status of the data provider"""
status = {
'streaming': self.is_streaming,
'symbols': len(self.symbols),
'timeframes': len(self.timeframes),
'current_prices': len(self.current_prices),
'websocket_tasks': len(self.websocket_tasks),
'historical_data_loaded': {}
}
# Check historical data availability
for symbol in self.symbols:
status['historical_data_loaded'][symbol] = {}
for tf in self.timeframes:
has_data = (symbol in self.historical_data and
tf in self.historical_data[symbol] and
not self.historical_data[symbol][tf].empty)
status['historical_data_loaded'][symbol][tf] = has_data
return status
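
For reference, a minimal usage sketch of the provider (requires network access
to the Binance REST API and the `ta` package; symbols/timeframes are examples):

    from core.data_provider import DataProvider

    provider = DataProvider(symbols=['ETH/USDT'], timeframes=['1m', '1h'])

    # One-off historical fetch; served from the parquet cache on repeat calls
    df = provider.get_historical_data('ETH/USDT', '1h', limit=500)
    if df is not None:
        print(df[['timestamp', 'close', 'rsi']].tail())

    # Stacked model input, shape (n_timeframes, window_size, n_features)
    matrix = provider.get_feature_matrix('ETH/USDT', window_size=20)
    print(matrix.shape if matrix is not None else "insufficient data")

Real-time streaming runs on asyncio (e.g. `await provider.start_real_time_streaming()`
inside an event loop); `health_check()` reports its current status.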

core/orchestrator.py (Normal file, 516 lines)

@@ -0,0 +1,516 @@
"""
Trading Orchestrator - Main Decision Making Module
This is the core orchestrator that:
1. Coordinates CNN and RL modules via model registry
2. Combines their outputs with confidence weighting
3. Makes final trading decisions (BUY/SELL/HOLD)
4. Manages the learning loop between components
5. Ensures memory efficiency (8GB constraint)
"""
import asyncio
import logging
import time
import numpy as np
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass
from .config import get_config
from .data_provider import DataProvider
from models import get_model_registry, ModelInterface, CNNModelInterface, RLAgentInterface
logger = logging.getLogger(__name__)
@dataclass
class Prediction:
"""Represents a prediction from a model"""
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float # 0.0 to 1.0
probabilities: Dict[str, float] # Probabilities for each action
timeframe: str # Timeframe this prediction is for
timestamp: datetime
model_name: str # Name of the model that made this prediction
metadata: Dict[str, Any] = None # Additional model-specific data
@dataclass
class TradingDecision:
"""Final trading decision from the orchestrator"""
action: str # 'BUY', 'SELL', 'HOLD'
confidence: float # Combined confidence
symbol: str
price: float
timestamp: datetime
reasoning: Dict[str, Any] # Why this decision was made
memory_usage: Dict[str, int] # Memory usage of models
class TradingOrchestrator:
"""
Main orchestrator that coordinates multiple AI models for trading decisions
"""
def __init__(self, data_provider: DataProvider = None):
"""Initialize the orchestrator"""
self.config = get_config()
self.data_provider = data_provider or DataProvider()
self.model_registry = get_model_registry()
# Configuration
self.confidence_threshold = self.config.orchestrator.get('confidence_threshold', 0.5)
self.decision_frequency = self.config.orchestrator.get('decision_frequency', 60)
# Dynamic weights (will be adapted based on performance)
self.model_weights = {} # {model_name: weight}
self._initialize_default_weights()
# State tracking
self.last_decision_time = {} # {symbol: datetime}
self.recent_decisions = {} # {symbol: List[TradingDecision]}
self.model_performance = {} # {model_name: {'correct': int, 'total': int, 'accuracy': float}}
# Decision callbacks
self.decision_callbacks = []
logger.info("TradingOrchestrator initialized with modular model system")
logger.info(f"Confidence threshold: {self.confidence_threshold}")
logger.info(f"Decision frequency: {self.decision_frequency}s")
def _initialize_default_weights(self):
"""Initialize default model weights from config"""
self.model_weights = {
'CNN': self.config.orchestrator.get('cnn_weight', 0.7),
'RL': self.config.orchestrator.get('rl_weight', 0.3)
}
def register_model(self, model: ModelInterface, weight: float = None) -> bool:
"""Register a new model with the orchestrator"""
try:
# Register with model registry
if not self.model_registry.register_model(model):
return False
# Set weight
if weight is not None:
self.model_weights[model.name] = weight
elif model.name not in self.model_weights:
self.model_weights[model.name] = 0.1 # Default low weight for new models
# Initialize performance tracking
if model.name not in self.model_performance:
self.model_performance[model.name] = {'correct': 0, 'total': 0, 'accuracy': 0.0}
logger.info(f"Registered {model.name} model with weight {self.model_weights[model.name]}")
self._normalize_weights()
return True
except Exception as e:
logger.error(f"Error registering model {model.name}: {e}")
return False
def unregister_model(self, model_name: str) -> bool:
"""Unregister a model"""
try:
if self.model_registry.unregister_model(model_name):
if model_name in self.model_weights:
del self.model_weights[model_name]
if model_name in self.model_performance:
del self.model_performance[model_name]
self._normalize_weights()
logger.info(f"Unregistered {model_name} model")
return True
return False
except Exception as e:
logger.error(f"Error unregistering model {model_name}: {e}")
return False
def _normalize_weights(self):
"""Normalize model weights to sum to 1.0"""
total_weight = sum(self.model_weights.values())
if total_weight > 0:
for model_name in self.model_weights:
self.model_weights[model_name] /= total_weight
def add_decision_callback(self, callback):
"""Add a callback function to be called when decisions are made"""
self.decision_callbacks.append(callback)
async def make_trading_decision(self, symbol: str) -> Optional[TradingDecision]:
"""
Make a trading decision for a symbol by combining all registered model outputs
"""
try:
current_time = datetime.now()
# Check if enough time has passed since last decision
if symbol in self.last_decision_time:
time_since_last = (current_time - self.last_decision_time[symbol]).total_seconds()
if time_since_last < self.decision_frequency:
return None
# Get current market data
current_price = self.data_provider.get_current_price(symbol)
if current_price is None:
logger.warning(f"No current price available for {symbol}")
return None
# Get predictions from all registered models
predictions = await self._get_all_predictions(symbol)
if not predictions:
logger.warning(f"No predictions available for {symbol}")
return None
# Combine predictions
decision = self._combine_predictions(
symbol=symbol,
price=current_price,
predictions=predictions,
timestamp=current_time
)
# Update state
self.last_decision_time[symbol] = current_time
if symbol not in self.recent_decisions:
self.recent_decisions[symbol] = []
self.recent_decisions[symbol].append(decision)
# Keep only recent decisions (last 100)
if len(self.recent_decisions[symbol]) > 100:
self.recent_decisions[symbol] = self.recent_decisions[symbol][-100:]
# Call decision callbacks
for callback in self.decision_callbacks:
try:
await callback(decision)
except Exception as e:
logger.error(f"Error in decision callback: {e}")
# Clean up memory periodically
if len(self.recent_decisions[symbol]) % 50 == 0:
self.model_registry.cleanup_all_models()
return decision
except Exception as e:
logger.error(f"Error making trading decision for {symbol}: {e}")
return None
async def _get_all_predictions(self, symbol: str) -> List[Prediction]:
"""Get predictions from all registered models"""
predictions = []
for model_name, model in self.model_registry.models.items():
try:
if isinstance(model, CNNModelInterface):
# Get CNN predictions for each timeframe
cnn_predictions = await self._get_cnn_predictions(model, symbol)
predictions.extend(cnn_predictions)
elif isinstance(model, RLAgentInterface):
# Get RL prediction
rl_prediction = await self._get_rl_prediction(model, symbol)
if rl_prediction:
predictions.append(rl_prediction)
else:
# Generic model interface
generic_prediction = await self._get_generic_prediction(model, symbol)
if generic_prediction:
predictions.append(generic_prediction)
except Exception as e:
logger.error(f"Error getting prediction from {model_name}: {e}")
continue
return predictions
async def _get_cnn_predictions(self, model: CNNModelInterface, symbol: str) -> List[Prediction]:
"""Get predictions from CNN model for all timeframes"""
predictions = []
try:
for timeframe in self.config.timeframes:
# Get feature matrix for this timeframe
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=[timeframe],
window_size=model.window_size
)
if feature_matrix is not None:
# Get CNN prediction
try:
action_probs, confidence = model.predict_timeframe(feature_matrix, timeframe)
except AttributeError:
# Fallback to generic predict method
action_probs, confidence = model.predict(feature_matrix)
if action_probs is not None:
# Convert to prediction object
action_names = ['SELL', 'HOLD', 'BUY']
best_action_idx = np.argmax(action_probs)
best_action = action_names[best_action_idx]
prediction = Prediction(
action=best_action,
confidence=float(confidence) if confidence is not None else float(action_probs[best_action_idx]),
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
timeframe=timeframe,
timestamp=datetime.now(),
model_name=model.name,
metadata={'timeframe_specific': True}
)
predictions.append(prediction)
except Exception as e:
logger.error(f"Error getting CNN predictions: {e}")
return predictions
async def _get_rl_prediction(self, model: RLAgentInterface, symbol: str) -> Optional[Prediction]:
"""Get prediction from RL agent"""
try:
# Get current state for RL agent
state = self._get_rl_state(symbol)
if state is None:
return None
# Get RL agent's action and confidence
action_idx, confidence = model.act_with_confidence(state)
action_names = ['SELL', 'HOLD', 'BUY']
action = action_names[action_idx]
# Create prediction object
prediction = Prediction(
action=action,
confidence=float(confidence),
                probabilities={name: float(confidence) if name == action else (1.0 - float(confidence)) / 2.0
                               for name in action_names},
timeframe='mixed', # RL uses mixed timeframes
timestamp=datetime.now(),
model_name=model.name,
metadata={'state_size': len(state)}
)
return prediction
except Exception as e:
logger.error(f"Error getting RL prediction: {e}")
return None
async def _get_generic_prediction(self, model: ModelInterface, symbol: str) -> Optional[Prediction]:
"""Get prediction from generic model"""
try:
# Get feature matrix for the model
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=self.config.timeframes[:3], # Use first 3 timeframes
window_size=20
)
if feature_matrix is not None:
action_probs, confidence = model.predict(feature_matrix)
if action_probs is not None:
action_names = ['SELL', 'HOLD', 'BUY']
best_action_idx = np.argmax(action_probs)
best_action = action_names[best_action_idx]
prediction = Prediction(
action=best_action,
confidence=float(confidence),
probabilities={name: float(prob) for name, prob in zip(action_names, action_probs)},
timeframe='mixed',
timestamp=datetime.now(),
model_name=model.name,
metadata={'generic_model': True}
)
return prediction
return None
except Exception as e:
logger.error(f"Error getting generic prediction: {e}")
return None
def _get_rl_state(self, symbol: str) -> Optional[np.ndarray]:
"""Get current state for RL agent"""
try:
# Get feature matrix for all timeframes
feature_matrix = self.data_provider.get_feature_matrix(
symbol=symbol,
timeframes=self.config.timeframes,
window_size=self.config.rl.get('window_size', 20)
)
if feature_matrix is not None:
# Flatten the feature matrix for RL agent
# Shape: (n_timeframes, window_size, n_features) -> (n_timeframes * window_size * n_features,)
state = feature_matrix.flatten()
# Add additional state information (position, balance, etc.)
# This would come from a portfolio manager in a real implementation
additional_state = np.array([0.0, 1.0, 0.0]) # [position, balance, unrealized_pnl]
return np.concatenate([state, additional_state])
return None
except Exception as e:
logger.error(f"Error creating RL state for {symbol}: {e}")
return None
def _combine_predictions(self, symbol: str, price: float,
predictions: List[Prediction],
timestamp: datetime) -> TradingDecision:
"""Combine all predictions into a final decision"""
try:
reasoning = {
'predictions': len(predictions),
'weights': self.model_weights.copy(),
'models_used': [pred.model_name for pred in predictions]
}
# Initialize action scores
action_scores = {'BUY': 0.0, 'SELL': 0.0, 'HOLD': 0.0}
total_weight = 0.0
# Process all predictions
for pred in predictions:
# Get model weight
model_weight = self.model_weights.get(pred.model_name, 0.1)
# Weight by confidence and timeframe importance
timeframe_weight = self._get_timeframe_weight(pred.timeframe)
weighted_confidence = pred.confidence * timeframe_weight * model_weight
action_scores[pred.action] += weighted_confidence
total_weight += weighted_confidence
# Normalize scores
if total_weight > 0:
for action in action_scores:
action_scores[action] /= total_weight
# Choose best action
best_action = max(action_scores, key=action_scores.get)
best_confidence = action_scores[best_action]
# Apply confidence threshold
if best_confidence < self.confidence_threshold:
best_action = 'HOLD'
reasoning['threshold_applied'] = True
# Get memory usage stats
memory_usage = self.model_registry.get_memory_stats()
# Create final decision
decision = TradingDecision(
action=best_action,
confidence=best_confidence,
symbol=symbol,
price=price,
timestamp=timestamp,
reasoning=reasoning,
memory_usage=memory_usage['models']
)
logger.info(f"Decision for {symbol}: {best_action} (confidence: {best_confidence:.3f})")
logger.debug(f"Memory usage: {memory_usage['total_used_mb']:.1f}MB / {memory_usage['total_limit_mb']:.1f}MB")
return decision
except Exception as e:
logger.error(f"Error combining predictions for {symbol}: {e}")
# Return safe default
return TradingDecision(
action='HOLD',
confidence=0.0,
symbol=symbol,
price=price,
timestamp=timestamp,
reasoning={'error': str(e)},
memory_usage={}
)
def _get_timeframe_weight(self, timeframe: str) -> float:
"""Get importance weight for a timeframe"""
# Higher timeframes get more weight in decision making
weights = {
'1m': 0.1, '5m': 0.2, '15m': 0.3, '30m': 0.4,
'1h': 0.6, '4h': 0.8, '1d': 1.0
}
return weights.get(timeframe, 0.5)
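    # Worked example of the weighting in _combine_predictions (illustrative
    # numbers, default config weights CNN=0.7 / RL=0.3):
    #   CNN, 1h,    BUY  @ 0.8 -> 0.7 * 0.6 * 0.8 = 0.336
    #   CNN, 1d,    HOLD @ 0.5 -> 0.7 * 1.0 * 0.5 = 0.350
    #   RL,  mixed, BUY  @ 0.6 -> 0.3 * 0.5 * 0.6 = 0.090
    # total 0.776; BUY = 0.426/0.776 ~ 0.55, HOLD ~ 0.45, so BUY wins and
    # clears the 0.5 confidence threshold.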
def update_model_performance(self, model_name: str, was_correct: bool):
"""Update performance tracking for a model"""
if model_name in self.model_performance:
self.model_performance[model_name]['total'] += 1
if was_correct:
self.model_performance[model_name]['correct'] += 1
# Update accuracy
total = self.model_performance[model_name]['total']
correct = self.model_performance[model_name]['correct']
self.model_performance[model_name]['accuracy'] = correct / total if total > 0 else 0.0
def adapt_weights(self):
"""Dynamically adapt model weights based on performance"""
try:
for model_name, performance in self.model_performance.items():
if performance['total'] > 0:
# Adjust weight based on relative performance
accuracy = performance['correct'] / performance['total']
self.model_weights[model_name] = accuracy
logger.info(f"Adapted {model_name} weight: {self.model_weights[model_name]}")
except Exception as e:
logger.error(f"Error adapting weights: {e}")
def get_recent_decisions(self, symbol: str, limit: int = 10) -> List[TradingDecision]:
"""Get recent decisions for a symbol"""
if symbol in self.recent_decisions:
return self.recent_decisions[symbol][-limit:]
return []
def get_performance_metrics(self) -> Dict[str, Any]:
"""Get performance metrics for the orchestrator"""
return {
'model_performance': self.model_performance.copy(),
'weights': self.model_weights.copy(),
'configuration': {
'confidence_threshold': self.confidence_threshold,
'decision_frequency': self.decision_frequency
},
'recent_activity': {
symbol: len(decisions) for symbol, decisions in self.recent_decisions.items()
}
}
async def start_continuous_trading(self, symbols: List[str] = None):
"""Start continuous trading decisions for specified symbols"""
if symbols is None:
symbols = self.config.symbols
logger.info(f"Starting continuous trading for symbols: {symbols}")
while True:
try:
# Make decisions for all symbols
for symbol in symbols:
decision = await self.make_trading_decision(symbol)
if decision and decision.action != 'HOLD':
logger.info(f"Trading decision: {decision.action} {symbol} at {decision.price}")
# Wait before next decision cycle
await asyncio.sleep(self.decision_frequency)
except Exception as e:
logger.error(f"Error in continuous trading loop: {e}")
await asyncio.sleep(10) # Wait before retrying
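
For reference, a minimal wiring sketch (the `models` registry package is
imported by this file but not included in this commit, so model registration
is only indicated; decisions stay None until the first tick arrives):

    import asyncio
    from core.config import setup_logging
    from core.data_provider import DataProvider
    from core.orchestrator import TradingOrchestrator

    async def main():
        setup_logging()
        provider = DataProvider(symbols=['ETH/USDT'], timeframes=['1m', '1h'])
        orchestrator = TradingOrchestrator(provider)
        # orchestrator.register_model(my_cnn, weight=0.7)  # any ModelInterface impl
        await provider.start_real_time_streaming()
        await orchestrator.start_continuous_trading(['ETH/USDT'])

    asyncio.run(main())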