deduplicate model storage
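This change removes the per-symbol COB RL model copies and checkpoint names in RealtimeRLCOBTrader and points everything at a single model instance owned by TradingOrchestrator, saved under one canonical name. A minimal sketch of the deduplication pattern the diff applies (the stand-in class below is illustrative, not the project's real COBRLModelInterface):

    # Hypothetical stand-in for the shared COB RL model, for illustration only.
    class SharedModel:
        pass

    symbols = ["BTC/USDT", "ETH/USDT"]
    shared = SharedModel()                 # one instance, owned by the orchestrator
    models = {s: shared for s in symbols}  # per-symbol dicts hold references, not copies

    assert all(m is shared for m in models.values())
    # Checkpointing then happens once under "cob_rl_agent",
    # not once per symbol ("cob_rl_btc_usdt", "cob_rl_eth_usdt").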
@@ -346,11 +346,58 @@ class TradingOrchestrator:
             logger.warning("Extrema trainer not available")
             self.extrema_trainer = None

-        # COB RL Model REMOVED - See COB_MODEL_ARCHITECTURE_DOCUMENTATION.md
-        # Reason: Need quality COB data first before evaluating massive parameter benefit
-        # Will recreate improved version when COB data pipeline is fixed
-        logger.info("COB RL model removed - focusing on COB data quality first")
-        self.cob_rl_agent = None
+        # Initialize COB RL Model - UNIFIED with ModelManager
+        try:
+            from NN.models.cob_rl_model import COBRLModelInterface
+
+            # Initialize COB RL model using unified approach
+            self.cob_rl_agent = COBRLModelInterface(
+                model_checkpoint_dir="@checkpoints/cob_rl",
+                device='cuda' if torch.cuda.is_available() else 'cpu'
+            )
+
+            # Add COB RL to model states tracking
+            self.model_states['cob_rl'] = {
+                'initial_loss': None,
+                'current_loss': None,
+                'best_loss': None,
+                'checkpoint_loaded': False
+            }
+
+            # Load best checkpoint using unified ModelManager
+            checkpoint_loaded = False
+            try:
+                from NN.training.model_manager import load_best_checkpoint
+                result = load_best_checkpoint("cob_rl_agent")
+                if result:
+                    file_path, metadata = result
+                    self.model_states['cob_rl']['initial_loss'] = metadata.loss
+                    self.model_states['cob_rl']['current_loss'] = metadata.loss
+                    self.model_states['cob_rl']['best_loss'] = metadata.loss
+                    self.model_states['cob_rl']['checkpoint_loaded'] = True
+                    self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
+                    checkpoint_loaded = True
+                    loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "N/A"
+                    logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
+            except Exception as e:
+                logger.warning(f"Error loading COB RL checkpoint: {e}")
+
+            if not checkpoint_loaded:
+                # New model - no synthetic data, start fresh
+                self.model_states['cob_rl']['initial_loss'] = None
+                self.model_states['cob_rl']['current_loss'] = None
+                self.model_states['cob_rl']['best_loss'] = None
+                self.model_states['cob_rl']['checkpoint_filename'] = 'none (fresh start)'
+                logger.info("COB RL starting fresh - no checkpoint found")
+
+            logger.info("COB RL Agent initialized and integrated with unified ModelManager")
+            logger.info(" - Uses @checkpoints/ directory structure")
+            logger.info(" - Follows same load/save/checkpoint flow as other models")
+            logger.info(" - Integrated with enhanced real-time training system")
+
+        except ImportError as e:
+            logger.warning(f"COB RL Model not available: {e}")
+            self.cob_rl_agent = None

         # Initialize TRANSFORMER Model
         try:

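The hunk above relies on load_best_checkpoint returning a (file_path, metadata) pair whose metadata exposes loss and checkpoint_id. A self-contained sketch of that contract (the stub below stands in for NN.training.model_manager and is an assumption, not the real implementation):

    from typing import Optional, Tuple

    class CheckpointMeta:
        # Stand-in for the metadata object the diff reads .loss and .checkpoint_id from.
        def __init__(self, checkpoint_id: str, loss: Optional[float]):
            self.checkpoint_id = checkpoint_id
            self.loss = loss

    def load_best_checkpoint(model_name: str) -> Optional[Tuple[str, CheckpointMeta]]:
        # Stub: the real lookup lives in NN.training.model_manager.
        return ("checkpoints/cob_rl/best.pt", CheckpointMeta("cob_rl_best", 0.0123))

    state = {'initial_loss': None, 'current_loss': None, 'best_loss': None, 'checkpoint_loaded': False}
    result = load_best_checkpoint("cob_rl_agent")
    if result:
        file_path, metadata = result
        state.update(initial_loss=metadata.loss, current_loss=metadata.loss,
                     best_loss=metadata.loss, checkpoint_loaded=True,
                     checkpoint_filename=metadata.checkpoint_id)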
@@ -34,7 +34,8 @@ import os
 # Local imports
 from .cob_integration import COBIntegration
 from .trading_executor import TradingExecutor
-from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
+# UNIFIED: Import only the interface, models come from orchestrator
+from NN.models.cob_rl_model import COBRLModelInterface

 logger = logging.getLogger(__name__)

@@ -98,51 +99,44 @@ class RealtimeRLCOBTrader:
     Real-time RL trader using COB data with comprehensive subscriber system
     """

-    def __init__(self,
+    def __init__(self,
                  symbols: Optional[List[str]] = None,
                  trading_executor: Optional[TradingExecutor] = None,
-                 model_checkpoint_dir: str = "models/realtime_rl_cob",
+                 orchestrator: Any = None,  # UNIFIED: Use orchestrator's models
                  inference_interval_ms: int = 200,
                  min_confidence_threshold: float = 0.35,  # Lowered from 0.7 for more aggressive trading
-                 required_confident_predictions: int = 3,
-                 checkpoint_manager: Any = None):
+                 required_confident_predictions: int = 3):

         self.symbols = symbols or ['BTC/USDT', 'ETH/USDT']
         self.trading_executor = trading_executor
-        self.model_checkpoint_dir = model_checkpoint_dir
+        self.orchestrator = orchestrator  # UNIFIED: Use orchestrator's models
         self.inference_interval_ms = inference_interval_ms
         self.min_confidence_threshold = min_confidence_threshold
         self.required_confident_predictions = required_confident_predictions

-        # Initialize ModelManager (either provided or get global instance)
-        if checkpoint_manager is None:
-            from NN.training.model_manager import create_model_manager
-            self.checkpoint_manager = create_model_manager()
+        # UNIFIED: Use orchestrator's ModelManager instead of creating our own
+        if self.orchestrator and hasattr(self.orchestrator, 'model_manager'):
+            self.model_manager = self.orchestrator.model_manager
         else:
-            self.checkpoint_manager = checkpoint_manager
+            from NN.training.model_manager import create_model_manager
+            self.model_manager = create_model_manager()

+        # Track start time for training duration calculation
+        self.start_time = datetime.now()  # Initialize start_time
+
-        # Setup device
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        logger.info(f"Using device: {self.device}")
-
-        # Initialize models for each symbol
-        self.models: Dict[str, MassiveRLNetwork] = {}
-        self.optimizers: Dict[str, optim.AdamW] = {}
-        self.scalers: Dict[str, torch.cuda.amp.GradScaler] = {}
-
-        for symbol in self.symbols:
-            model = MassiveRLNetwork().to(self.device)
-            self.models[symbol] = model
-            self.optimizers[symbol] = optim.AdamW(
-                model.parameters(),
-                lr=1e-5,  # Low learning rate for stability
-                weight_decay=1e-6,
-                betas=(0.9, 0.999)
-            )
-            self.scalers[symbol] = torch.cuda.amp.GradScaler()
-        self.start_time = datetime.now()
+        # UNIFIED: Use orchestrator's COB RL model
+        if not self.orchestrator or not hasattr(self.orchestrator, 'cob_rl_agent') or not self.orchestrator.cob_rl_agent:
+            raise ValueError("RealtimeRLCOBTrader requires orchestrator with COB RL model. Please initialize TradingOrchestrator first.")
+
+        # Use orchestrator's unified COB RL model
+        self.cob_rl_model = self.orchestrator.cob_rl_agent
+        self.device = self.orchestrator.cob_rl_agent.device if hasattr(self.orchestrator.cob_rl_agent, 'device') else torch.device('cpu')
+        logger.info(f"Using orchestrator's unified COB RL model on device: {self.device}")
+
+        # Create unified model references for all symbols
+        self.models = {symbol: self.cob_rl_model.model for symbol in self.symbols}
+        self.optimizers = {symbol: self.cob_rl_model.optimizer for symbol in self.symbols}
+        self.scalers = {symbol: self.cob_rl_model.scaler for symbol in self.symbols}

         # Subscriber system for real-time events
         self.prediction_subscribers: List[Callable[[PredictionResult], None]] = []
@@ -906,56 +900,67 @@ class RealtimeRLCOBTrader:
         return reward

     async def _train_batch(self, symbol: str, predictions: List[PredictionResult]) -> float:
-        """Train model on a batch of predictions"""
+        """Train model on a batch of predictions using unified approach"""
         try:
-            model = self.models[symbol]
-            optimizer = self.optimizers[symbol]
-            scaler = self.scalers[symbol]
+            # UNIFIED: Always use orchestrator's COB RL model
+            return self._train_batch_unified(predictions)

         except Exception as e:
             logger.error(f"Error training batch for {symbol}: {e}")
             return 0.0

+    def _train_batch_unified(self, predictions: List[PredictionResult]) -> float:
+        """Train using unified COB RL model from orchestrator"""
+        try:
+            model = self.cob_rl_model.model
+            optimizer = self.cob_rl_model.optimizer
+            scaler = self.cob_rl_model.scaler
+
             model.train()
             optimizer.zero_grad()

             # Prepare batch data
             features = torch.stack([
                 torch.from_numpy(p.features) for p in predictions
             ]).to(self.device)

             # Targets
             direction_targets = torch.tensor([
                 p.actual_direction for p in predictions
             ], dtype=torch.long).to(self.device)

             value_targets = torch.tensor([
                 p.reward for p in predictions
             ], dtype=torch.float32).to(self.device)

             # Forward pass with mixed precision
             with torch.cuda.amp.autocast():
                 outputs = model(features)

                 # Calculate losses
                 direction_loss = nn.CrossEntropyLoss()(outputs['price_logits'], direction_targets)
                 value_loss = nn.MSELoss()(outputs['value'].squeeze(), value_targets)

                 # Confidence loss (encourage high confidence for correct predictions)
                 correct_predictions = (torch.argmax(outputs['price_logits'], dim=1) == direction_targets).float()
                 confidence_loss = nn.BCELoss()(outputs['confidence'].squeeze(), correct_predictions)

                 # Combined loss
                 total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss

             # Backward pass with gradient scaling
             scaler.scale(total_loss).backward()
             scaler.unscale_(optimizer)
             torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
             scaler.step(optimizer)
             scaler.update()

             return total_loss.item()

         except Exception as e:
-            logger.error(f"Error training batch for {symbol}: {e}")
+            logger.error(f"Error in unified training batch: {e}")
             return 0.0

     async def _train_on_trade_execution(self, symbol: str, signals: List[PredictionResult],
                                         action: str, price: float):
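The unified training step above weights three objectives as direction + 0.5 * value + 0.3 * confidence under mixed precision. A small self-contained check of that loss arithmetic on dummy tensors (the 3-class direction head and scalar value/confidence heads are assumptions inferred from the diff, not confirmed shapes):

    import torch
    import torch.nn as nn

    batch = 4
    price_logits = torch.randn(batch, 3)               # assumed 3 direction classes
    value = torch.randn(batch, 1)
    confidence = torch.sigmoid(torch.randn(batch, 1))  # BCELoss expects probabilities in [0, 1]

    direction_targets = torch.randint(0, 3, (batch,))
    value_targets = torch.randn(batch)

    direction_loss = nn.CrossEntropyLoss()(price_logits, direction_targets)
    value_loss = nn.MSELoss()(value.squeeze(), value_targets)
    correct = (price_logits.argmax(dim=1) == direction_targets).float()
    confidence_loss = nn.BCELoss()(confidence.squeeze(), correct)

    total_loss = direction_loss + 0.5 * value_loss + 0.3 * confidence_loss
    print(f"total loss: {total_loss.item():.4f}")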
@@ -1015,68 +1020,99 @@ class RealtimeRLCOBTrader:
                 await asyncio.sleep(60)

     def _save_models(self):
-        """Save all models to disk using CheckpointManager"""
+        """Save models using unified ModelManager approach"""
         try:
-            for symbol in self.symbols:
-                model_name = f"cob_rl_{symbol.replace('/', '_').lower()}"  # Standardize model name for CheckpointManager
-
-                # Prepare performance metrics for CheckpointManager
+            if self.cob_rl_model:
+                # UNIFIED: Use orchestrator's COB RL model with ModelManager
                 performance_metrics = {
-                    'loss': self.training_stats[symbol].get('average_loss', 0.0),
-                    'reward': self.training_stats[symbol].get('average_reward', 0.0),  # Assuming average_reward is tracked
-                    'accuracy': self.training_stats[symbol].get('average_accuracy', 0.0),  # Assuming average_accuracy is tracked
+                    'loss': self._get_average_loss(),
+                    'reward': self._get_average_reward(),
+                    'accuracy': self._get_average_accuracy(),
                 }
-                if self.trading_executor:  # Add check for trading_executor
-                    daily_stats = self.trading_executor.get_daily_stats()
-                    performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0)  # Example, get actual pnl
-                performance_metrics['training_samples'] = self.training_stats[symbol].get('total_training_steps', 0)
-
-                # Prepare training metadata for CheckpointManager
+
+                # Add P&L if trading executor is available
+                if self.trading_executor and hasattr(self.trading_executor, 'get_daily_stats'):
+                    try:
+                        daily_stats = self.trading_executor.get_daily_stats()
+                        performance_metrics['pnl'] = daily_stats.get('total_pnl', 0.0)
+                    except Exception:
+                        performance_metrics['pnl'] = 0.0
+
+                performance_metrics['training_samples'] = sum(
+                    stats.get('total_training_steps', 0) for stats in self.training_stats.values()
+                )
+
+                # Prepare training metadata
                 training_metadata = {
-                    'total_parameters': sum(p.numel() for p in self.models[symbol].parameters()),
-                    'epoch': self.training_stats[symbol].get('total_training_steps', 0),  # Using total_training_steps as pseudo-epoch
+                    'total_parameters': sum(p.numel() for p in self.cob_rl_model.model.parameters()),
+                    'epoch': max(stats.get('total_training_steps', 0) for stats in self.training_stats.values()),
                     'training_time_hours': (datetime.now() - self.start_time).total_seconds() / 3600
                 }

-                self.checkpoint_manager.save_checkpoint(
-                    model=self.models[symbol],
-                    model_name=model_name,
-                    model_type='COB_RL',  # Specify model type
+                # Save using unified ModelManager
+                self.model_manager.save_checkpoint(
+                    model=self.cob_rl_model.model,
+                    model_name="cob_rl_agent",
+                    model_type='COB_RL',
                     performance_metrics=performance_metrics,
                     training_metadata=training_metadata
                 )

-                logger.debug(f"Saved model for {symbol}")
+                logger.info("COB RL model saved using unified ModelManager")
+            else:
+                # This should not happen with proper initialization
+                logger.error("Unified COB RL model not available - check orchestrator initialization")

         except Exception as e:
             logger.error(f"Error saving models: {e}")

     def _load_models(self):
-        """Load existing models from disk using CheckpointManager"""
+        """Load models using unified ModelManager approach"""
         try:
-            for symbol in self.symbols:
-                model_name = f"cob_rl_{symbol.replace('/', '_').lower()}"  # Standardize model name for CheckpointManager
-
-                loaded_checkpoint = self.checkpoint_manager.load_best_checkpoint(model_name)
+            if self.cob_rl_model:
+                # UNIFIED: Load using ModelManager
+                loaded_checkpoint = self.model_manager.load_best_checkpoint("cob_rl_agent")

                 if loaded_checkpoint:
                     model_path, metadata = loaded_checkpoint
                     checkpoint = torch.load(model_path, map_location=self.device)

-                    self.models[symbol].load_state_dict(checkpoint['model_state_dict'])
-                    self.optimizers[symbol].load_state_dict(checkpoint['optimizer_state_dict'])
-
-                    if 'training_stats' in checkpoint:
-                        self.training_stats[symbol].update(checkpoint['training_stats'])
-                    if 'inference_stats' in checkpoint:
-                        self.inference_stats[symbol].update(checkpoint['inference_stats'])
-
-                    logger.info(f"Loaded existing model for {symbol} from checkpoint: {metadata.checkpoint_id}")
+                    self.cob_rl_model.model.load_state_dict(checkpoint['model_state_dict'])
+                    self.cob_rl_model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+                    # Update training stats for all symbols with loaded data
+                    for symbol in self.symbols:
+                        if 'training_stats' in checkpoint:
+                            self.training_stats[symbol].update(checkpoint['training_stats'])
+                        if 'inference_stats' in checkpoint:
+                            self.inference_stats[symbol].update(checkpoint['inference_stats'])
+
+                    logger.info(f"Loaded unified COB RL model from checkpoint: {metadata.checkpoint_id}")
                 else:
-                    logger.info(f"No existing model found for {symbol} via CheckpointManager, starting fresh.")
+                    logger.info("No existing COB RL model found via ModelManager, starting fresh.")
+            else:
+                # This should not happen with proper initialization
+                logger.error("Unified COB RL model not available - check orchestrator initialization")

         except Exception as e:
             logger.error(f"Error loading models: {e}")

+    def _get_average_loss(self) -> float:
+        """Get average loss across all symbols"""
+        losses = [stats.get('average_loss', 0.0) for stats in self.training_stats.values() if stats.get('average_loss') is not None]
+        return sum(losses) / len(losses) if losses else 0.0
+
+    def _get_average_reward(self) -> float:
+        """Get average reward across all symbols"""
+        rewards = [stats.get('average_reward', 0.0) for stats in self.training_stats.values() if stats.get('average_reward') is not None]
+        return sum(rewards) / len(rewards) if rewards else 0.0
+
+    def _get_average_accuracy(self) -> float:
+        """Get average accuracy across all symbols"""
+        accuracies = [stats.get('average_accuracy', 0.0) for stats in self.training_stats.values() if stats.get('average_accuracy') is not None]
+        return sum(accuracies) / len(accuracies) if accuracies else 0.0
+
     def get_performance_stats(self) -> Dict[str, Any]:
         """Get comprehensive performance statistics"""
@@ -1119,36 +1155,49 @@ class RealtimeRLCOBTrader:

 # Example usage
 async def main():
-    """Example usage of RealtimeRLCOBTrader"""
+    """Example usage of unified RealtimeRLCOBTrader"""
+    from ..core.orchestrator import TradingOrchestrator
     from ..core.trading_executor import TradingExecutor

+    # Initialize orchestrator (which now includes unified COB RL model)
+    orchestrator = TradingOrchestrator()
+
     # Initialize trading executor (simulation mode)
     trading_executor = TradingExecutor()

-    # Initialize real-time RL trader
+    # Initialize real-time RL trader with unified orchestrator
     trader = RealtimeRLCOBTrader(
         symbols=['BTC/USDT', 'ETH/USDT'],
         trading_executor=trading_executor,
+        orchestrator=orchestrator,  # UNIFIED: Use orchestrator's models
         inference_interval_ms=200,
         min_confidence_threshold=0.7,
         required_confident_predictions=3
     )

     try:
-        # Start the trader
+        # Start the orchestrator first (initializes all models)
+        await orchestrator.start()
+
+        # Start the trader (uses orchestrator's unified COB RL model)
         await trader.start()

         # Run for demonstration
-        logger.info("Real-time RL COB Trader running...")
+        logger.info("Real-time RL COB Trader running with unified orchestrator...")
         await asyncio.sleep(300)  # Run for 5 minutes

-        # Print performance stats
-        stats = trader.get_performance_stats()
-        logger.info(f"Performance stats: {json.dumps(stats, indent=2, default=str)}")
+        # Print performance stats from both systems
+        orchestrator_stats = orchestrator.get_model_stats()
+        trader_stats = trader.get_performance_stats()
+        logger.info("=== ORCHESTRATOR STATS ===")
+        logger.info(f"Model stats: {json.dumps(orchestrator_stats, indent=2, default=str)}")
+        logger.info("=== TRADER STATS ===")
+        logger.info(f"Performance stats: {json.dumps(trader_stats, indent=2, default=str)}")

     finally:
-        # Stop the trader
+        # Stop both systems
         await trader.stop()
+        await orchestrator.stop()

 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)