enhanced
@@ -1,15 +1,25 @@
 """
-Enhanced RL Trainer with Market Environment Adaptation
+Enhanced RL Trainer with Continuous Learning
 
-This trainer implements:
-1. Continuous learning from orchestrator action evaluations
-2. Environment adaptation based on market regime changes
-3. Multi-symbol coordinated RL training
-4. Experience replay with prioritized sampling
-5. Dynamic reward shaping based on market conditions
+This module implements sophisticated RL training with:
+- Prioritized experience replay
+- Market regime adaptation
+- Continuous learning from trading outcomes
+- Performance tracking and visualization
 """
 
-import asyncioimport asyncioimport loggingimport numpy as npimport torchimport torch.nn as nnimport torch.optim as optimfrom collections import deque, namedtupleimport randomfrom datetime import datetime, timedeltafrom typing import Dict, List, Optional, Tuple, Anyimport matplotlib.pyplot as pltfrom pathlib import Path
+import asyncio
+import logging
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from collections import deque, namedtuple
+import random
+from datetime import datetime, timedelta
+from typing import Dict, List, Optional, Tuple, Any
+import matplotlib.pyplot as plt
+from pathlib import Path
 
 from core.config import get_config
 from core.data_provider import DataProvider
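Note: the updated docstring above advertises prioritized experience replay, but the buffer itself is not part of this hunk. The sketch below is only a minimal illustration of proportional prioritized sampling; the class and method names (SimplePrioritizedReplay, push/sample/update_priorities) are assumptions for illustration, not the repository's actual implementation.

import random
from collections import namedtuple

Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done'])

class SimplePrioritizedReplay:
    """Toy proportional prioritized replay: sample probability ~ priority ** alpha."""

    def __init__(self, capacity: int = 10000, alpha: float = 0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = []

    def push(self, experience: Experience, priority: float = 1.0):
        # New experiences enter with a given (e.g. current max) priority so they are replayed soon
        if len(self.buffer) >= self.capacity:
            self.buffer.pop(0)
            self.priorities.pop(0)
        self.buffer.append(experience)
        self.priorities.append(priority)

    def sample(self, batch_size: int):
        # Proportional sampling; importance-sampling weights omitted for brevity
        scaled = [p ** self.alpha for p in self.priorities]
        total = sum(scaled)
        probs = [s / total for s in scaled]
        idxs = random.choices(range(len(self.buffer)), weights=probs, k=batch_size)
        return [self.buffer[i] for i in idxs], idxs

    def update_priorities(self, idxs, td_errors, eps: float = 1e-5):
        # Larger TD error -> higher replay priority on the next pass
        for i, err in zip(idxs, td_errors):
            self.priorities[i] = abs(err) + eps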
@@ -290,7 +300,22 @@ class EnhancedDQNAgent(nn.Module, RLAgentInterface):
         self.target_value_head.load_state_dict(self.value_head.state_dict())
         self.target_advantage_head.load_state_dict(self.advantage_head.state_dict())
 
-    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]: """Predict action probabilities and confidence (required by ModelInterface)""" action, confidence = self.act_with_confidence(features) # Convert action to probabilities action_probs = np.zeros(self.action_space) action_probs[action] = 1.0 return action_probs, confidence def get_memory_usage(self) -> int: """Get memory usage in MB""" if torch.cuda.is_available(): return torch.cuda.memory_allocated(self.device) // (1024 * 1024) else: param_count = sum(p.numel() for p in self.parameters()) buffer_size = len(self.replay_buffer) * self.state_size * 4 # Rough estimate return (param_count * 4 + buffer_size) // (1024 * 1024)
+    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
+        """Predict action probabilities and confidence (required by ModelInterface)"""
+        action, confidence = self.act_with_confidence(features)
+        # Convert action to probabilities
+        action_probs = np.zeros(self.action_space)
+        action_probs[action] = 1.0
+        return action_probs, confidence
+
+    def get_memory_usage(self) -> int:
+        """Get memory usage in MB"""
+        if torch.cuda.is_available():
+            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
+        else:
+            param_count = sum(p.numel() for p in self.parameters())
+            buffer_size = len(self.replay_buffer) * self.state_size * 4  # Rough estimate
+            return (param_count * 4 + buffer_size) // (1024 * 1024)
 
 class EnhancedRLTrainer:
     """Enhanced RL trainer with continuous learning from market feedback"""
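Note: predict() above returns a one-hot probability vector over the action space together with the confidence from act_with_confidence(). The snippet below is a small worked check of the CPU-side fallback in get_memory_usage(): parameters × 4 bytes (float32) plus replay entries × state_size × 4 bytes, integer-divided down to MB. The figures are made up for illustration.

param_count = 1_000_000      # hypothetical number of network parameters
replay_entries = 10_000      # hypothetical length of the replay buffer
state_size = 100             # hypothetical state dimension

buffer_bytes = replay_entries * state_size * 4           # stored states as float32
estimate_mb = (param_count * 4 + buffer_bytes) // (1024 * 1024)
print(estimate_mb)                                        # 7 (about 7.6 MB before integer division)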
@@ -322,7 +347,10 @@ class EnhancedRLTrainer:
             'epsilon_values': {symbol: [] for symbol in self.config.symbols}
         }
 
-        # Create save directory models_path = self.config.rl.get('model_dir', "models/enhanced_rl") self.save_dir = Path(models_path) self.save_dir.mkdir(parents=True, exist_ok=True)
+        # Create save directory
+        models_path = self.config.rl.get('model_dir', "models/enhanced_rl")
+        self.save_dir = Path(models_path)
+        self.save_dir.mkdir(parents=True, exist_ok=True)
 
         logger.info(f"Enhanced RL trainer initialized for symbols: {self.config.symbols}")