dqn model data fix
This commit is contained in:
@ -1167,7 +1167,7 @@ class DQNAgent:
|
||||
|
||||
# Handle empty dictionary case
|
||||
if not state:
|
||||
logger.error("No numerical values found in state dict, using default state")
|
||||
logger.error("Empty state dictionary received, using default state")
|
||||
expected_size = getattr(self, 'state_size', 403)
|
||||
if isinstance(expected_size, tuple):
|
||||
expected_size = np.prod(expected_size)
|
||||
|
@ -6348,9 +6348,10 @@ class TradingOrchestrator:
|
||||
action = decision.action
|
||||
confidence = decision.confidence
|
||||
|
||||
# Get current market data for training context
|
||||
market_data = self._get_current_market_data(symbol)
|
||||
if not market_data:
|
||||
# Get current market data for training context - use same data source as CNN model
|
||||
base_data = self.build_base_data_input(symbol)
|
||||
if not base_data:
|
||||
logger.warning(f"No base data available for training {symbol}, skipping model training")
|
||||
return
|
||||
|
||||
# Track if any model was trained for checkpoint saving
|
||||
@ -6359,8 +6360,8 @@ class TradingOrchestrator:
|
||||
# Train DQN agent if available and enabled
|
||||
if self.rl_agent and hasattr(self.rl_agent, "add_experience") and self.is_model_training_enabled("dqn"):
|
||||
try:
|
||||
# Create state representation
|
||||
state = self._create_state_for_training(symbol, market_data)
|
||||
# Create state representation from base_data (same as CNN model)
|
||||
state = self._create_state_from_base_data(symbol, base_data)
|
||||
|
||||
# Map action to DQN action space - CONSISTENT ACTION MAPPING
|
||||
action_mapping = {"BUY": 0, "SELL": 1, "HOLD": 2}
|
||||
@ -6389,9 +6390,9 @@ class TradingOrchestrator:
|
||||
# Train CNN model if available and enabled
|
||||
if self.cnn_model and hasattr(self.cnn_model, "add_training_sample") and self.is_model_training_enabled("cnn"):
|
||||
try:
|
||||
# Create CNN input features
|
||||
cnn_features = self._create_cnn_features_for_training(
|
||||
symbol, market_data
|
||||
# Create CNN input features from base_data (same as inference)
|
||||
cnn_features = self._create_cnn_features_from_base_data(
|
||||
symbol, base_data
|
||||
)
|
||||
|
||||
# Create target based on action
|
||||
@ -6573,88 +6574,101 @@ class TradingOrchestrator:
|
||||
def _get_current_market_data(self, symbol: str) -> Optional[Dict]:
|
||||
"""Get current market data for training context"""
|
||||
try:
|
||||
if self.data_provider:
|
||||
# Get recent data for training
|
||||
df = self.data_provider.get_historical_data(symbol, "1m", limit=100)
|
||||
if df is not None and not df.empty:
|
||||
return {
|
||||
"ohlcv": df.tail(50).to_dict("records"), # Last 50 candles
|
||||
"current_price": float(df["close"].iloc[-1]),
|
||||
"volume": float(df["volume"].iloc[-1]),
|
||||
"timestamp": df.index[-1],
|
||||
}
|
||||
return None
|
||||
if not self.data_provider:
|
||||
logger.warning(f"No data provider available for {symbol}")
|
||||
return None
|
||||
|
||||
# Get recent data for training
|
||||
df = self.data_provider.get_historical_data(symbol, "1m", limit=100)
|
||||
if df is not None and not df.empty:
|
||||
return {
|
||||
"ohlcv": df.tail(50).to_dict("records"), # Last 50 candles
|
||||
"current_price": float(df["close"].iloc[-1]),
|
||||
"volume": float(df["volume"].iloc[-1]),
|
||||
"timestamp": df.index[-1],
|
||||
}
|
||||
else:
|
||||
logger.warning(f"No historical data available for {symbol}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.debug(f"Error getting market data for training: {e}")
|
||||
logger.error(f"Error getting market data for training {symbol}: {e}")
|
||||
return None
|
||||
|
||||
def _create_state_for_training(self, symbol: str, market_data: Dict) -> np.ndarray:
|
||||
"""Create state representation for DQN training"""
|
||||
def _create_state_from_base_data(self, symbol: str, base_data: Any) -> np.ndarray:
|
||||
"""Create state representation for DQN training from base_data (same as CNN model)"""
|
||||
try:
|
||||
# Create a basic state representation
|
||||
ohlcv_data = market_data.get("ohlcv", [])
|
||||
if not ohlcv_data:
|
||||
return np.zeros(100) # Default state size
|
||||
# Validate base_data
|
||||
if not base_data or not hasattr(base_data, 'get_feature_vector'):
|
||||
logger.warning(f"Invalid base_data for {symbol}: {type(base_data)}")
|
||||
return np.zeros(403) # Default state size for DQN
|
||||
|
||||
# Get feature vector from base_data (same as CNN model)
|
||||
features = base_data.get_feature_vector()
|
||||
|
||||
if not features or len(features) == 0:
|
||||
logger.warning(f"No features available from base_data for {symbol}, using default state")
|
||||
return np.zeros(403) # Default state size for DQN
|
||||
|
||||
# Extract features from recent candles
|
||||
features = []
|
||||
for candle in ohlcv_data[-20:]: # Last 20 candles
|
||||
features.extend(
|
||||
[
|
||||
candle.get("open", 0),
|
||||
candle.get("high", 0),
|
||||
candle.get("low", 0),
|
||||
candle.get("close", 0),
|
||||
candle.get("volume", 0),
|
||||
]
|
||||
)
|
||||
|
||||
# Pad or truncate to expected size
|
||||
state = np.array(features[:100])
|
||||
if len(state) < 100:
|
||||
state = np.pad(state, (0, 100 - len(state)), "constant")
|
||||
# Convert to numpy array
|
||||
state = np.array(features, dtype=np.float32)
|
||||
|
||||
# Ensure correct dimensions for DQN (403 features)
|
||||
if len(state) != 403:
|
||||
if len(state) < 403:
|
||||
# Pad with zeros
|
||||
padded_state = np.zeros(403, dtype=np.float32)
|
||||
padded_state[:len(state)] = state
|
||||
state = padded_state
|
||||
else:
|
||||
# Truncate
|
||||
state = state[:403]
|
||||
|
||||
return state
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error creating state for training: {e}")
|
||||
return np.zeros(100)
|
||||
logger.error(f"Error creating state from base_data for {symbol}: {e}")
|
||||
return np.zeros(403) # Default state size for DQN
|
||||
|
||||
def _create_cnn_features_for_training(
|
||||
self, symbol: str, market_data: Dict
|
||||
|
||||
|
||||
def _create_cnn_features_from_base_data(
|
||||
self, symbol: str, base_data: Any
|
||||
) -> np.ndarray:
|
||||
"""Create CNN features for training"""
|
||||
"""Create CNN features for training from base_data (same as inference)"""
|
||||
try:
|
||||
# Similar to state creation but formatted for CNN
|
||||
ohlcv_data = market_data.get("ohlcv", [])
|
||||
if not ohlcv_data:
|
||||
return np.zeros((1, 100))
|
||||
# Validate base_data
|
||||
if not base_data or not hasattr(base_data, 'get_feature_vector'):
|
||||
logger.warning(f"Invalid base_data for CNN training {symbol}: {type(base_data)}")
|
||||
return np.zeros((1, 403)) # Default CNN input size
|
||||
|
||||
# Get feature vector from base_data (same as CNN inference)
|
||||
features = base_data.get_feature_vector()
|
||||
|
||||
if not features or len(features) == 0:
|
||||
logger.warning(f"No features available from base_data for CNN training {symbol}, using default")
|
||||
return np.zeros((1, 403)) # Default CNN input size
|
||||
|
||||
# Create feature matrix
|
||||
features = []
|
||||
for candle in ohlcv_data[-20:]:
|
||||
features.extend(
|
||||
[
|
||||
candle.get("open", 0),
|
||||
candle.get("high", 0),
|
||||
candle.get("low", 0),
|
||||
candle.get("close", 0),
|
||||
candle.get("volume", 0),
|
||||
]
|
||||
)
|
||||
|
||||
# Reshape for CNN input
|
||||
cnn_features = np.array(features[:100]).reshape(1, -1)
|
||||
if cnn_features.shape[1] < 100:
|
||||
cnn_features = np.pad(
|
||||
cnn_features, ((0, 0), (0, 100 - cnn_features.shape[1])), "constant"
|
||||
)
|
||||
# Convert to numpy array and reshape for CNN
|
||||
cnn_features = np.array(features, dtype=np.float32).reshape(1, -1)
|
||||
|
||||
# Ensure correct dimensions for CNN (403 features)
|
||||
if cnn_features.shape[1] != 403:
|
||||
if cnn_features.shape[1] < 403:
|
||||
# Pad with zeros
|
||||
padded_features = np.zeros((1, 403), dtype=np.float32)
|
||||
padded_features[0, :cnn_features.shape[1]] = cnn_features[0]
|
||||
cnn_features = padded_features
|
||||
else:
|
||||
# Truncate
|
||||
cnn_features = cnn_features[:, :403]
|
||||
|
||||
return cnn_features
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error creating CNN features for training: {e}")
|
||||
return np.zeros((1, 100))
|
||||
logger.error(f"Error creating CNN features from base_data for {symbol}: {e}")
|
||||
return np.zeros((1, 403)) # Default CNN input size
|
||||
|
||||
|
||||
|
||||
def _create_cob_state_for_training(self, symbol: str, cob_data: Dict) -> np.ndarray:
|
||||
"""Create COB state representation for training"""
|
||||
|
Reference in New Issue
Block a user