dqn model data fix

Dobromir Popov
2025-07-29 00:09:13 +03:00
parent 548c0d5e0f
commit e1e453c204
2 changed files with 88 additions and 74 deletions
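The change in brief: DQN and CNN training in `TradingOrchestrator` previously built training inputs from `_get_current_market_data` (raw 1m OHLCV candles flattened and padded to 100 values), while inference used `build_base_data_input`. This commit routes both training paths through `base_data.get_feature_vector()` so training and inference see the same features, and normalizes the vector to the 403-feature size the models expect. Below is a minimal sketch of that normalization, assuming plain NumPy and a list of floats; the helper name is illustrative and not part of the commit.

```python
import numpy as np

DQN_STATE_SIZE = 403  # fixed feature count used by both the DQN and CNN training paths

def normalize_feature_vector(features, size=DQN_STATE_SIZE):
    """Pad with zeros or truncate a raw feature list to a fixed-size float32 vector."""
    state = np.asarray(features, dtype=np.float32)
    if state.shape[0] < size:
        padded = np.zeros(size, dtype=np.float32)
        padded[:state.shape[0]] = state
        return padded
    return state[:size]

# In the real code the input comes from base_data.get_feature_vector(); a dummy list stands in here.
dqn_state = normalize_feature_vector([0.1] * 250)  # -> shape (403,), zero-padded
cnn_input = dqn_state.reshape(1, -1)               # -> shape (1, 403), as the CNN path expects
```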


@@ -1167,7 +1167,7 @@ class DQNAgent:
             # Handle empty dictionary case
             if not state:
-                logger.error("No numerical values found in state dict, using default state")
+                logger.error("Empty state dictionary received, using default state")
                 expected_size = getattr(self, 'state_size', 403)
                 if isinstance(expected_size, tuple):
                     expected_size = np.prod(expected_size)
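For context on the DQNAgent fallback above: `state_size` may be a flat integer or a shape tuple, and `np.prod` collapses a tuple into the total element count before the all-zeros default state is built. A standalone illustration follows; the `(13, 31)` factorization is only an example of a 403-element shape, not the model's actual layout.

```python
import numpy as np

def default_state(state_size=403):
    """Build an all-zeros fallback state whether state_size is an int or a shape tuple."""
    expected_size = state_size
    if isinstance(expected_size, tuple):
        expected_size = int(np.prod(expected_size))  # e.g. (13, 31) -> 403
    return np.zeros(expected_size, dtype=np.float32)

assert default_state(403).shape == (403,)
assert default_state((13, 31)).shape == (403,)
```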


@@ -6348,9 +6348,10 @@ class TradingOrchestrator:
         action = decision.action
         confidence = decision.confidence

-        # Get current market data for training context
-        market_data = self._get_current_market_data(symbol)
-        if not market_data:
+        # Get current market data for training context - use same data source as CNN model
+        base_data = self.build_base_data_input(symbol)
+        if not base_data:
+            logger.warning(f"No base data available for training {symbol}, skipping model training")
             return

         # Track if any model was trained for checkpoint saving
@@ -6359,8 +6360,8 @@ class TradingOrchestrator:
         # Train DQN agent if available and enabled
         if self.rl_agent and hasattr(self.rl_agent, "add_experience") and self.is_model_training_enabled("dqn"):
             try:
-                # Create state representation
-                state = self._create_state_for_training(symbol, market_data)
+                # Create state representation from base_data (same as CNN model)
+                state = self._create_state_from_base_data(symbol, base_data)

                 # Map action to DQN action space - CONSISTENT ACTION MAPPING
                 action_mapping = {"BUY": 0, "SELL": 1, "HOLD": 2}
@@ -6389,9 +6390,9 @@ class TradingOrchestrator:
         # Train CNN model if available and enabled
         if self.cnn_model and hasattr(self.cnn_model, "add_training_sample") and self.is_model_training_enabled("cnn"):
             try:
-                # Create CNN input features
-                cnn_features = self._create_cnn_features_for_training(
-                    symbol, market_data
+                # Create CNN input features from base_data (same as inference)
+                cnn_features = self._create_cnn_features_from_base_data(
+                    symbol, base_data
                 )

                 # Create target based on action
@@ -6573,88 +6574,101 @@ class TradingOrchestrator:
     def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
         """Get current market data for training context"""
         try:
-            if self.data_provider:
-                # Get recent data for training
-                df = self.data_provider.get_historical_data(symbol, "1m", limit=100)
-                if df is not None and not df.empty:
-                    return {
-                        "ohlcv": df.tail(50).to_dict("records"), # Last 50 candles
-                        "current_price": float(df["close"].iloc[-1]),
-                        "volume": float(df["volume"].iloc[-1]),
-                        "timestamp": df.index[-1],
-                    }
-            return None
+            if not self.data_provider:
+                logger.warning(f"No data provider available for {symbol}")
+                return None
+
+            # Get recent data for training
+            df = self.data_provider.get_historical_data(symbol, "1m", limit=100)
+            if df is not None and not df.empty:
+                return {
+                    "ohlcv": df.tail(50).to_dict("records"), # Last 50 candles
+                    "current_price": float(df["close"].iloc[-1]),
+                    "volume": float(df["volume"].iloc[-1]),
+                    "timestamp": df.index[-1],
+                }
+            else:
+                logger.warning(f"No historical data available for {symbol}")
+                return None
         except Exception as e:
-            logger.debug(f"Error getting market data for training: {e}")
+            logger.error(f"Error getting market data for training {symbol}: {e}")
             return None
-    def _create_state_for_training(self, symbol: str, market_data: Dict) -> np.ndarray:
-        """Create state representation for DQN training"""
+    def _create_state_from_base_data(self, symbol: str, base_data: Any) -> np.ndarray:
+        """Create state representation for DQN training from base_data (same as CNN model)"""
         try:
-            # Create a basic state representation
-            ohlcv_data = market_data.get("ohlcv", [])
-            if not ohlcv_data:
-                return np.zeros(100) # Default state size
+            # Validate base_data
+            if not base_data or not hasattr(base_data, 'get_feature_vector'):
+                logger.warning(f"Invalid base_data for {symbol}: {type(base_data)}")
+                return np.zeros(403) # Default state size for DQN

-            # Extract features from recent candles
-            features = []
-            for candle in ohlcv_data[-20:]: # Last 20 candles
-                features.extend(
-                    [
-                        candle.get("open", 0),
-                        candle.get("high", 0),
-                        candle.get("low", 0),
-                        candle.get("close", 0),
-                        candle.get("volume", 0),
-                    ]
-                )
+            # Get feature vector from base_data (same as CNN model)
+            features = base_data.get_feature_vector()

-            # Pad or truncate to expected size
-            state = np.array(features[:100])
-            if len(state) < 100:
-                state = np.pad(state, (0, 100 - len(state)), "constant")
+            if not features or len(features) == 0:
+                logger.warning(f"No features available from base_data for {symbol}, using default state")
+                return np.zeros(403) # Default state size for DQN
+
+            # Convert to numpy array
+            state = np.array(features, dtype=np.float32)
+
+            # Ensure correct dimensions for DQN (403 features)
+            if len(state) != 403:
+                if len(state) < 403:
+                    # Pad with zeros
+                    padded_state = np.zeros(403, dtype=np.float32)
+                    padded_state[:len(state)] = state
+                    state = padded_state
+                else:
+                    # Truncate
+                    state = state[:403]

             return state
         except Exception as e:
-            logger.debug(f"Error creating state for training: {e}")
-            return np.zeros(100)
+            logger.error(f"Error creating state from base_data for {symbol}: {e}")
+            return np.zeros(403) # Default state size for DQN
-    def _create_cnn_features_for_training(
-        self, symbol: str, market_data: Dict
+    def _create_cnn_features_from_base_data(
+        self, symbol: str, base_data: Any
     ) -> np.ndarray:
-        """Create CNN features for training"""
+        """Create CNN features for training from base_data (same as inference)"""
         try:
-            # Similar to state creation but formatted for CNN
-            ohlcv_data = market_data.get("ohlcv", [])
-            if not ohlcv_data:
-                return np.zeros((1, 100))
+            # Validate base_data
+            if not base_data or not hasattr(base_data, 'get_feature_vector'):
+                logger.warning(f"Invalid base_data for CNN training {symbol}: {type(base_data)}")
+                return np.zeros((1, 403)) # Default CNN input size

-            # Create feature matrix
-            features = []
-            for candle in ohlcv_data[-20:]:
-                features.extend(
-                    [
-                        candle.get("open", 0),
-                        candle.get("high", 0),
-                        candle.get("low", 0),
-                        candle.get("close", 0),
-                        candle.get("volume", 0),
-                    ]
-                )
+            # Get feature vector from base_data (same as CNN inference)
+            features = base_data.get_feature_vector()

-            # Reshape for CNN input
-            cnn_features = np.array(features[:100]).reshape(1, -1)
-            if cnn_features.shape[1] < 100:
-                cnn_features = np.pad(
-                    cnn_features, ((0, 0), (0, 100 - cnn_features.shape[1])), "constant"
-                )
+            if not features or len(features) == 0:
+                logger.warning(f"No features available from base_data for CNN training {symbol}, using default")
+                return np.zeros((1, 403)) # Default CNN input size
+
+            # Convert to numpy array and reshape for CNN
+            cnn_features = np.array(features, dtype=np.float32).reshape(1, -1)
+
+            # Ensure correct dimensions for CNN (403 features)
+            if cnn_features.shape[1] != 403:
+                if cnn_features.shape[1] < 403:
+                    # Pad with zeros
+                    padded_features = np.zeros((1, 403), dtype=np.float32)
+                    padded_features[0, :cnn_features.shape[1]] = cnn_features[0]
+                    cnn_features = padded_features
+                else:
+                    # Truncate
+                    cnn_features = cnn_features[:, :403]

             return cnn_features
         except Exception as e:
-            logger.debug(f"Error creating CNN features for training: {e}")
-            return np.zeros((1, 100))
+            logger.error(f"Error creating CNN features from base_data for {symbol}: {e}")
+            return np.zeros((1, 403)) # Default CNN input size

     def _create_cob_state_for_training(self, symbol: str, cob_data: Dict) -> np.ndarray:
         """Create COB state representation for training"""