dqn model data fix

This commit is contained in:
Dobromir Popov
2025-07-29 00:09:13 +03:00
parent 548c0d5e0f
commit e1e453c204
2 changed files with 88 additions and 74 deletions

View File

@ -1167,7 +1167,7 @@ class DQNAgent:
# Handle empty dictionary case # Handle empty dictionary case
if not state: if not state:
logger.error("No numerical values found in state dict, using default state") logger.error("Empty state dictionary received, using default state")
expected_size = getattr(self, 'state_size', 403) expected_size = getattr(self, 'state_size', 403)
if isinstance(expected_size, tuple): if isinstance(expected_size, tuple):
expected_size = np.prod(expected_size) expected_size = np.prod(expected_size)

View File

@ -6348,9 +6348,10 @@ class TradingOrchestrator:
action = decision.action action = decision.action
confidence = decision.confidence confidence = decision.confidence
# Get current market data for training context # Get current market data for training context - use same data source as CNN model
market_data = self._get_current_market_data(symbol) base_data = self.build_base_data_input(symbol)
if not market_data: if not base_data:
logger.warning(f"No base data available for training {symbol}, skipping model training")
return return
# Track if any model was trained for checkpoint saving # Track if any model was trained for checkpoint saving
@ -6359,8 +6360,8 @@ class TradingOrchestrator:
# Train DQN agent if available and enabled # Train DQN agent if available and enabled
if self.rl_agent and hasattr(self.rl_agent, "add_experience") and self.is_model_training_enabled("dqn"): if self.rl_agent and hasattr(self.rl_agent, "add_experience") and self.is_model_training_enabled("dqn"):
try: try:
# Create state representation # Create state representation from base_data (same as CNN model)
state = self._create_state_for_training(symbol, market_data) state = self._create_state_from_base_data(symbol, base_data)
# Map action to DQN action space - CONSISTENT ACTION MAPPING # Map action to DQN action space - CONSISTENT ACTION MAPPING
action_mapping = {"BUY": 0, "SELL": 1, "HOLD": 2} action_mapping = {"BUY": 0, "SELL": 1, "HOLD": 2}
@ -6389,9 +6390,9 @@ class TradingOrchestrator:
# Train CNN model if available and enabled # Train CNN model if available and enabled
if self.cnn_model and hasattr(self.cnn_model, "add_training_sample") and self.is_model_training_enabled("cnn"): if self.cnn_model and hasattr(self.cnn_model, "add_training_sample") and self.is_model_training_enabled("cnn"):
try: try:
# Create CNN input features # Create CNN input features from base_data (same as inference)
cnn_features = self._create_cnn_features_for_training( cnn_features = self._create_cnn_features_from_base_data(
symbol, market_data symbol, base_data
) )
# Create target based on action # Create target based on action
@ -6573,88 +6574,101 @@ class TradingOrchestrator:
def _get_current_market_data(self, symbol: str) -> Optional[Dict]: def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
"""Get current market data for training context""" """Get current market data for training context"""
try: try:
if self.data_provider: if not self.data_provider:
# Get recent data for training logger.warning(f"No data provider available for {symbol}")
df = self.data_provider.get_historical_data(symbol, "1m", limit=100) return None
if df is not None and not df.empty:
return { # Get recent data for training
"ohlcv": df.tail(50).to_dict("records"), # Last 50 candles df = self.data_provider.get_historical_data(symbol, "1m", limit=100)
"current_price": float(df["close"].iloc[-1]), if df is not None and not df.empty:
"volume": float(df["volume"].iloc[-1]), return {
"timestamp": df.index[-1], "ohlcv": df.tail(50).to_dict("records"), # Last 50 candles
} "current_price": float(df["close"].iloc[-1]),
return None "volume": float(df["volume"].iloc[-1]),
"timestamp": df.index[-1],
}
else:
logger.warning(f"No historical data available for {symbol}")
return None
except Exception as e: except Exception as e:
logger.debug(f"Error getting market data for training: {e}") logger.error(f"Error getting market data for training {symbol}: {e}")
return None return None
def _create_state_for_training(self, symbol: str, market_data: Dict) -> np.ndarray: def _create_state_from_base_data(self, symbol: str, base_data: Any) -> np.ndarray:
"""Create state representation for DQN training""" """Create state representation for DQN training from base_data (same as CNN model)"""
try: try:
# Create a basic state representation # Validate base_data
ohlcv_data = market_data.get("ohlcv", []) if not base_data or not hasattr(base_data, 'get_feature_vector'):
if not ohlcv_data: logger.warning(f"Invalid base_data for {symbol}: {type(base_data)}")
return np.zeros(100) # Default state size return np.zeros(403) # Default state size for DQN
# Get feature vector from base_data (same as CNN model)
features = base_data.get_feature_vector()
if not features or len(features) == 0:
logger.warning(f"No features available from base_data for {symbol}, using default state")
return np.zeros(403) # Default state size for DQN
# Extract features from recent candles # Convert to numpy array
features = [] state = np.array(features, dtype=np.float32)
for candle in ohlcv_data[-20:]: # Last 20 candles
features.extend( # Ensure correct dimensions for DQN (403 features)
[ if len(state) != 403:
candle.get("open", 0), if len(state) < 403:
candle.get("high", 0), # Pad with zeros
candle.get("low", 0), padded_state = np.zeros(403, dtype=np.float32)
candle.get("close", 0), padded_state[:len(state)] = state
candle.get("volume", 0), state = padded_state
] else:
) # Truncate
state = state[:403]
# Pad or truncate to expected size
state = np.array(features[:100])
if len(state) < 100:
state = np.pad(state, (0, 100 - len(state)), "constant")
return state return state
except Exception as e: except Exception as e:
logger.debug(f"Error creating state for training: {e}") logger.error(f"Error creating state from base_data for {symbol}: {e}")
return np.zeros(100) return np.zeros(403) # Default state size for DQN
def _create_cnn_features_for_training(
self, symbol: str, market_data: Dict
def _create_cnn_features_from_base_data(
self, symbol: str, base_data: Any
) -> np.ndarray: ) -> np.ndarray:
"""Create CNN features for training""" """Create CNN features for training from base_data (same as inference)"""
try: try:
# Similar to state creation but formatted for CNN # Validate base_data
ohlcv_data = market_data.get("ohlcv", []) if not base_data or not hasattr(base_data, 'get_feature_vector'):
if not ohlcv_data: logger.warning(f"Invalid base_data for CNN training {symbol}: {type(base_data)}")
return np.zeros((1, 100)) return np.zeros((1, 403)) # Default CNN input size
# Get feature vector from base_data (same as CNN inference)
features = base_data.get_feature_vector()
if not features or len(features) == 0:
logger.warning(f"No features available from base_data for CNN training {symbol}, using default")
return np.zeros((1, 403)) # Default CNN input size
# Create feature matrix # Convert to numpy array and reshape for CNN
features = [] cnn_features = np.array(features, dtype=np.float32).reshape(1, -1)
for candle in ohlcv_data[-20:]:
features.extend( # Ensure correct dimensions for CNN (403 features)
[ if cnn_features.shape[1] != 403:
candle.get("open", 0), if cnn_features.shape[1] < 403:
candle.get("high", 0), # Pad with zeros
candle.get("low", 0), padded_features = np.zeros((1, 403), dtype=np.float32)
candle.get("close", 0), padded_features[0, :cnn_features.shape[1]] = cnn_features[0]
candle.get("volume", 0), cnn_features = padded_features
] else:
) # Truncate
cnn_features = cnn_features[:, :403]
# Reshape for CNN input
cnn_features = np.array(features[:100]).reshape(1, -1)
if cnn_features.shape[1] < 100:
cnn_features = np.pad(
cnn_features, ((0, 0), (0, 100 - cnn_features.shape[1])), "constant"
)
return cnn_features return cnn_features
except Exception as e: except Exception as e:
logger.debug(f"Error creating CNN features for training: {e}") logger.error(f"Error creating CNN features from base_data for {symbol}: {e}")
return np.zeros((1, 100)) return np.zeros((1, 403)) # Default CNN input size
def _create_cob_state_for_training(self, symbol: str, cob_data: Dict) -> np.ndarray: def _create_cob_state_for_training(self, symbol: str, cob_data: Dict) -> np.ndarray:
"""Create COB state representation for training""" """Create COB state representation for training"""