feed COB to the models
This commit is contained in:
@ -3630,12 +3630,16 @@ class CleanTradingDashboard:
|
|||||||
def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
|
def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
|
||||||
"""Feed COB data to models for training and inference"""
|
"""Feed COB data to models for training and inference"""
|
||||||
try:
|
try:
|
||||||
|
# Calculate cumulative imbalance for model feeding
|
||||||
|
cumulative_imbalance = self._calculate_cumulative_imbalance(symbol)
|
||||||
|
|
||||||
# Create 15-second history for model feeding
|
# Create 15-second history for model feeding
|
||||||
history_data = {
|
history_data = {
|
||||||
'symbol': symbol,
|
'symbol': symbol,
|
||||||
'current_snapshot': cob_snapshot,
|
'current_snapshot': cob_snapshot,
|
||||||
'history': self.cob_data_history[symbol][-15:], # Last 15 seconds
|
'history': self.cob_data_history[symbol][-15:], # Last 15 seconds
|
||||||
'bucketed_data': self.cob_bucketed_data[symbol],
|
'bucketed_data': self.cob_bucketed_data[symbol],
|
||||||
|
'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
|
||||||
'timestamp': cob_snapshot['timestamp']
|
'timestamp': cob_snapshot['timestamp']
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3643,7 +3647,7 @@ class CleanTradingDashboard:
|
|||||||
if hasattr(self.orchestrator, '_on_cob_dashboard_data'):
|
if hasattr(self.orchestrator, '_on_cob_dashboard_data'):
|
||||||
try:
|
try:
|
||||||
self.orchestrator._on_cob_dashboard_data(symbol, history_data)
|
self.orchestrator._on_cob_dashboard_data(symbol, history_data)
|
||||||
logger.debug(f"COB data fed to orchestrator for {symbol}")
|
logger.debug(f"COB data fed to orchestrator for {symbol} with cumulative imbalance: {cumulative_imbalance}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error feeding COB data to orchestrator: {e}")
|
logger.debug(f"Error feeding COB data to orchestrator: {e}")
|
||||||
|
|
||||||
@ -4026,6 +4030,10 @@ class CleanTradingDashboard:
|
|||||||
current_price = self._get_current_price('ETH/USDT')
|
current_price = self._get_current_price('ETH/USDT')
|
||||||
if not current_price:
|
if not current_price:
|
||||||
return training_data
|
return training_data
|
||||||
|
|
||||||
|
# Get cumulative imbalance for training
|
||||||
|
cumulative_imbalance = self._calculate_cumulative_imbalance('ETH/USDT')
|
||||||
|
|
||||||
df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=50)
|
df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=50)
|
||||||
if df is not None and not df.empty:
|
if df is not None and not df.empty:
|
||||||
for i in range(1, min(len(df), 20)):
|
for i in range(1, min(len(df), 20)):
|
||||||
@ -4035,6 +4043,7 @@ class CleanTradingDashboard:
|
|||||||
sample = {
|
sample = {
|
||||||
'timestamp': df.index[i], 'price': curr_price, 'prev_price': prev_price,
|
'timestamp': df.index[i], 'price': curr_price, 'prev_price': prev_price,
|
||||||
'price_change': price_change, 'volume': float(df['volume'].iloc[i]),
|
'price_change': price_change, 'volume': float(df['volume'].iloc[i]),
|
||||||
|
'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
|
||||||
'action': 'BUY' if price_change > 0.001 else 'SELL' if price_change < -0.001 else 'HOLD'
|
'action': 'BUY' if price_change > 0.001 else 'SELL' if price_change < -0.001 else 'HOLD'
|
||||||
}
|
}
|
||||||
training_data.append(sample)
|
training_data.append(sample)
|
||||||
@ -4043,7 +4052,8 @@ class CleanTradingDashboard:
|
|||||||
for tick in recent_ticks:
|
for tick in recent_ticks:
|
||||||
sample = {
|
sample = {
|
||||||
'timestamp': tick.get('datetime', datetime.now()), 'price': tick.get('price', current_price),
|
'timestamp': tick.get('datetime', datetime.now()), 'price': tick.get('price', current_price),
|
||||||
'volume': tick.get('volume', 0), 'tick_data': True
|
'volume': tick.get('volume', 0), 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
|
||||||
|
'tick_data': True
|
||||||
}
|
}
|
||||||
training_data.append(sample)
|
training_data.append(sample)
|
||||||
return training_data
|
return training_data
|
||||||
@ -4058,13 +4068,34 @@ class CleanTradingDashboard:
|
|||||||
return
|
return
|
||||||
agent = self.orchestrator.rl_agent
|
agent = self.orchestrator.rl_agent
|
||||||
training_samples = 0
|
training_samples = 0
|
||||||
|
total_loss = 0
|
||||||
|
loss_count = 0
|
||||||
|
|
||||||
for data in market_data[-10:]:
|
for data in market_data[-10:]:
|
||||||
try:
|
try:
|
||||||
price = data.get('price', 0)
|
price = data.get('price', 0)
|
||||||
prev_price = data.get('prev_price', price)
|
prev_price = data.get('prev_price', price)
|
||||||
price_change = data.get('price_change', 0)
|
price_change = data.get('price_change', 0)
|
||||||
volume = data.get('volume', 0)
|
volume = data.get('volume', 0)
|
||||||
state = np.array([price / 10000, price_change, volume / 1000000, 1.0 if price > prev_price else 0.0, abs(price_change) * 100])
|
cumulative_imbalance = data.get('cumulative_imbalance', {})
|
||||||
|
|
||||||
|
# Extract imbalance values for state
|
||||||
|
imbalance_1s = cumulative_imbalance.get('1s', 0.0)
|
||||||
|
imbalance_5s = cumulative_imbalance.get('5s', 0.0)
|
||||||
|
imbalance_15s = cumulative_imbalance.get('15s', 0.0)
|
||||||
|
imbalance_60s = cumulative_imbalance.get('60s', 0.0)
|
||||||
|
|
||||||
|
state = np.array([
|
||||||
|
price / 10000,
|
||||||
|
price_change,
|
||||||
|
volume / 1000000,
|
||||||
|
1.0 if price > prev_price else 0.0,
|
||||||
|
abs(price_change) * 100,
|
||||||
|
imbalance_1s,
|
||||||
|
imbalance_5s,
|
||||||
|
imbalance_15s,
|
||||||
|
imbalance_60s
|
||||||
|
])
|
||||||
if hasattr(agent, 'state_dim') and len(state) < agent.state_dim:
|
if hasattr(agent, 'state_dim') and len(state) < agent.state_dim:
|
||||||
padded_state = np.zeros(agent.state_dim)
|
padded_state = np.zeros(agent.state_dim)
|
||||||
padded_state[:len(state)] = state
|
padded_state[:len(state)] = state
|
||||||
@ -4079,17 +4110,58 @@ class CleanTradingDashboard:
|
|||||||
training_samples += 1
|
training_samples += 1
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Error adding market experience to DQN memory: {e}")
|
logger.debug(f"Error adding market experience to DQN memory: {e}")
|
||||||
|
|
||||||
if hasattr(agent, 'memory') and len(agent.memory) >= 32:
|
if hasattr(agent, 'memory') and len(agent.memory) >= 32:
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
try:
|
try:
|
||||||
loss = agent.replay()
|
loss = agent.replay()
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
|
total_loss += loss
|
||||||
|
loss_count += 1
|
||||||
self.orchestrator.update_model_loss('dqn', loss)
|
self.orchestrator.update_model_loss('dqn', loss)
|
||||||
if not hasattr(agent, 'losses'): agent.losses = []
|
if not hasattr(agent, 'losses'): agent.losses = []
|
||||||
agent.losses.append(loss)
|
agent.losses.append(loss)
|
||||||
if len(agent.losses) > 1000: agent.losses = agent.losses[-1000:]
|
if len(agent.losses) > 1000: agent.losses = agent.losses[-1000:]
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"DQN training step failed: {e}")
|
logger.debug(f"DQN training step failed: {e}")
|
||||||
|
|
||||||
|
# Save checkpoint after training
|
||||||
|
if loss_count > 0:
|
||||||
|
try:
|
||||||
|
from utils.checkpoint_manager import save_checkpoint
|
||||||
|
avg_loss = total_loss / loss_count
|
||||||
|
|
||||||
|
# Prepare checkpoint data
|
||||||
|
checkpoint_data = {
|
||||||
|
'model_state_dict': agent.model.state_dict() if hasattr(agent, 'model') else None,
|
||||||
|
'target_model_state_dict': agent.target_model.state_dict() if hasattr(agent, 'target_model') else None,
|
||||||
|
'optimizer_state_dict': agent.optimizer.state_dict() if hasattr(agent, 'optimizer') else None,
|
||||||
|
'memory_size': len(agent.memory),
|
||||||
|
'training_samples': training_samples,
|
||||||
|
'losses': agent.losses[-100:] if hasattr(agent, 'losses') else []
|
||||||
|
}
|
||||||
|
|
||||||
|
performance_metrics = {
|
||||||
|
'loss': avg_loss,
|
||||||
|
'memory_size': len(agent.memory),
|
||||||
|
'training_samples': training_samples,
|
||||||
|
'model_parameters': sum(p.numel() for p in agent.model.parameters()) if hasattr(agent, 'model') else 0
|
||||||
|
}
|
||||||
|
|
||||||
|
metadata = save_checkpoint(
|
||||||
|
model=checkpoint_data,
|
||||||
|
model_name="dqn_agent",
|
||||||
|
model_type="dqn",
|
||||||
|
performance_metrics=performance_metrics,
|
||||||
|
training_metadata={'training_iterations': loss_count}
|
||||||
|
)
|
||||||
|
|
||||||
|
if metadata:
|
||||||
|
logger.info(f"DQN checkpoint saved: {metadata.checkpoint_id} (loss={avg_loss:.4f})")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving DQN checkpoint: {e}")
|
||||||
|
|
||||||
logger.info(f"DQN TRAINING: Added {training_samples} experiences, memory size: {len(agent.memory)}")
|
logger.info(f"DQN TRAINING: Added {training_samples} experiences, memory size: {len(agent.memory)}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error in real DQN training: {e}")
|
logger.error(f"Error in real DQN training: {e}")
|
||||||
@ -4109,10 +4181,17 @@ class CleanTradingDashboard:
|
|||||||
current_price = current_data.get('price', 0)
|
current_price = current_data.get('price', 0)
|
||||||
next_price = next_data.get('price', current_price)
|
next_price = next_data.get('price', current_price)
|
||||||
price_change = (next_price - current_price) / current_price if current_price > 0 else 0
|
price_change = (next_price - current_price) / current_price if current_price > 0 else 0
|
||||||
|
cumulative_imbalance = current_data.get('cumulative_imbalance', {})
|
||||||
|
|
||||||
features = np.random.randn(100)
|
features = np.random.randn(100)
|
||||||
features[0] = current_price / 10000
|
features[0] = current_price / 10000
|
||||||
features[1] = price_change
|
features[1] = price_change
|
||||||
features[2] = current_data.get('volume', 0) / 1000000
|
features[2] = current_data.get('volume', 0) / 1000000
|
||||||
|
# Add cumulative imbalance features
|
||||||
|
features[3] = cumulative_imbalance.get('1s', 0.0)
|
||||||
|
features[4] = cumulative_imbalance.get('5s', 0.0)
|
||||||
|
features[5] = cumulative_imbalance.get('15s', 0.0)
|
||||||
|
features[6] = cumulative_imbalance.get('60s', 0.0)
|
||||||
if price_change > 0.001: target = 2
|
if price_change > 0.001: target = 2
|
||||||
elif price_change < -0.001: target = 0
|
elif price_change < -0.001: target = 0
|
||||||
else: target = 1
|
else: target = 1
|
||||||
|
Reference in New Issue
Block a user