Feed COB data to the models
This commit is contained in:
@ -3630,12 +3630,16 @@ class CleanTradingDashboard:
|
||||
def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
|
||||
"""Feed COB data to models for training and inference"""
|
||||
try:
|
||||
# Calculate cumulative imbalance for model feeding
|
||||
cumulative_imbalance = self._calculate_cumulative_imbalance(symbol)
|
||||
|
||||
# Create 15-second history for model feeding
|
||||
history_data = {
|
||||
'symbol': symbol,
|
||||
'current_snapshot': cob_snapshot,
|
||||
'history': self.cob_data_history[symbol][-15:], # Last 15 seconds
|
||||
'bucketed_data': self.cob_bucketed_data[symbol],
|
||||
'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
|
||||
'timestamp': cob_snapshot['timestamp']
|
||||
}
|
||||
|
||||
@ -3643,7 +3647,7 @@ class CleanTradingDashboard:
|
||||
if hasattr(self.orchestrator, '_on_cob_dashboard_data'):
|
||||
try:
|
||||
self.orchestrator._on_cob_dashboard_data(symbol, history_data)
|
||||
logger.debug(f"COB data fed to orchestrator for {symbol}")
|
||||
logger.debug(f"COB data fed to orchestrator for {symbol} with cumulative imbalance: {cumulative_imbalance}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Error feeding COB data to orchestrator: {e}")
|
||||
|
||||
@ -4026,6 +4030,10 @@ class CleanTradingDashboard:
|
||||
current_price = self._get_current_price('ETH/USDT')
|
||||
if not current_price:
|
||||
return training_data
|
||||
|
||||
# Get cumulative imbalance for training
|
||||
cumulative_imbalance = self._calculate_cumulative_imbalance('ETH/USDT')
|
||||
|
||||
df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=50)
|
||||
if df is not None and not df.empty:
|
||||
for i in range(1, min(len(df), 20)):
|
||||
@ -4035,6 +4043,7 @@ class CleanTradingDashboard:
|
||||
sample = {
|
||||
'timestamp': df.index[i], 'price': curr_price, 'prev_price': prev_price,
|
||||
'price_change': price_change, 'volume': float(df['volume'].iloc[i]),
|
||||
'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
|
||||
'action': 'BUY' if price_change > 0.001 else 'SELL' if price_change < -0.001 else 'HOLD'
|
||||
}
|
||||
training_data.append(sample)
|
||||
@ -4043,7 +4052,8 @@ class CleanTradingDashboard:
|
||||
for tick in recent_ticks:
|
||||
sample = {
|
||||
'timestamp': tick.get('datetime', datetime.now()), 'price': tick.get('price', current_price),
|
||||
'volume': tick.get('volume', 0), 'tick_data': True
|
||||
'volume': tick.get('volume', 0), 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
|
||||
'tick_data': True
|
||||
}
|
||||
training_data.append(sample)
|
||||
return training_data
|
||||
@ -4058,13 +4068,34 @@ class CleanTradingDashboard:
|
||||
return
|
||||
agent = self.orchestrator.rl_agent
|
||||
training_samples = 0
|
||||
total_loss = 0
|
||||
loss_count = 0
|
||||
|
||||
for data in market_data[-10:]:
|
||||
try:
|
||||
price = data.get('price', 0)
|
||||
prev_price = data.get('prev_price', price)
|
||||
price_change = data.get('price_change', 0)
|
||||
volume = data.get('volume', 0)
|
||||
state = np.array([price / 10000, price_change, volume / 1000000, 1.0 if price > prev_price else 0.0, abs(price_change) * 100])
|
||||
cumulative_imbalance = data.get('cumulative_imbalance', {})
|
||||
|
||||
# Extract imbalance values for state
|
||||
imbalance_1s = cumulative_imbalance.get('1s', 0.0)
|
||||
imbalance_5s = cumulative_imbalance.get('5s', 0.0)
|
||||
imbalance_15s = cumulative_imbalance.get('15s', 0.0)
|
||||
imbalance_60s = cumulative_imbalance.get('60s', 0.0)
|
||||
|
||||
state = np.array([
|
||||
price / 10000,
|
||||
price_change,
|
||||
volume / 1000000,
|
||||
1.0 if price > prev_price else 0.0,
|
||||
abs(price_change) * 100,
|
||||
imbalance_1s,
|
||||
imbalance_5s,
|
||||
imbalance_15s,
|
||||
imbalance_60s
|
||||
])
|
||||
if hasattr(agent, 'state_dim') and len(state) < agent.state_dim:
|
||||
padded_state = np.zeros(agent.state_dim)
|
||||
padded_state[:len(state)] = state
|
||||
@ -4079,17 +4110,58 @@ class CleanTradingDashboard:
|
||||
training_samples += 1
|
||||
except Exception as e:
|
||||
logger.debug(f"Error adding market experience to DQN memory: {e}")
|
||||
|
||||
if hasattr(agent, 'memory') and len(agent.memory) >= 32:
|
||||
for _ in range(3):
|
||||
try:
|
||||
loss = agent.replay()
|
||||
if loss is not None:
|
||||
total_loss += loss
|
||||
loss_count += 1
|
||||
self.orchestrator.update_model_loss('dqn', loss)
|
||||
if not hasattr(agent, 'losses'): agent.losses = []
|
||||
agent.losses.append(loss)
|
||||
if len(agent.losses) > 1000: agent.losses = agent.losses[-1000:]
|
||||
except Exception as e:
|
||||
logger.debug(f"DQN training step failed: {e}")
|
||||
|
||||
# Save checkpoint after training
|
||||
if loss_count > 0:
|
||||
try:
|
||||
from utils.checkpoint_manager import save_checkpoint
|
||||
avg_loss = total_loss / loss_count
|
||||
|
||||
# Prepare checkpoint data
|
||||
checkpoint_data = {
|
||||
'model_state_dict': agent.model.state_dict() if hasattr(agent, 'model') else None,
|
||||
'target_model_state_dict': agent.target_model.state_dict() if hasattr(agent, 'target_model') else None,
|
||||
'optimizer_state_dict': agent.optimizer.state_dict() if hasattr(agent, 'optimizer') else None,
|
||||
'memory_size': len(agent.memory),
|
||||
'training_samples': training_samples,
|
||||
'losses': agent.losses[-100:] if hasattr(agent, 'losses') else []
|
||||
}
|
||||
|
||||
performance_metrics = {
|
||||
'loss': avg_loss,
|
||||
'memory_size': len(agent.memory),
|
||||
'training_samples': training_samples,
|
||||
'model_parameters': sum(p.numel() for p in agent.model.parameters()) if hasattr(agent, 'model') else 0
|
||||
}
|
||||
|
||||
metadata = save_checkpoint(
|
||||
model=checkpoint_data,
|
||||
model_name="dqn_agent",
|
||||
model_type="dqn",
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata={'training_iterations': loss_count}
|
||||
)
|
||||
|
||||
if metadata:
|
||||
logger.info(f"DQN checkpoint saved: {metadata.checkpoint_id} (loss={avg_loss:.4f})")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving DQN checkpoint: {e}")
|
||||
|
||||
logger.info(f"DQN TRAINING: Added {training_samples} experiences, memory size: {len(agent.memory)}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error in real DQN training: {e}")
|
||||
@ -4109,10 +4181,17 @@ class CleanTradingDashboard:
|
||||
current_price = current_data.get('price', 0)
|
||||
next_price = next_data.get('price', current_price)
|
||||
price_change = (next_price - current_price) / current_price if current_price > 0 else 0
|
||||
cumulative_imbalance = current_data.get('cumulative_imbalance', {})
|
||||
|
||||
features = np.random.randn(100)
|
||||
features[0] = current_price / 10000
|
||||
features[1] = price_change
|
||||
features[2] = current_data.get('volume', 0) / 1000000
|
||||
# Add cumulative imbalance features
|
||||
features[3] = cumulative_imbalance.get('1s', 0.0)
|
||||
features[4] = cumulative_imbalance.get('5s', 0.0)
|
||||
features[5] = cumulative_imbalance.get('15s', 0.0)
|
||||
features[6] = cumulative_imbalance.get('60s', 0.0)
|
||||
if price_change > 0.001: target = 2
|
||||
elif price_change < -0.001: target = 0
|
||||
else: target = 1
|
||||
|
Reference in New Issue
Block a user