feed COB to the models

This commit is contained in:
Dobromir Popov
2025-07-02 00:38:29 +03:00
parent 1442e28101
commit 56f1110df3

View File

@ -3630,12 +3630,16 @@ class CleanTradingDashboard:
def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict): def _feed_cob_data_to_models(self, symbol: str, cob_snapshot: dict):
"""Feed COB data to models for training and inference""" """Feed COB data to models for training and inference"""
try: try:
# Calculate cumulative imbalance for model feeding
cumulative_imbalance = self._calculate_cumulative_imbalance(symbol)
# Create 15-second history for model feeding # Create 15-second history for model feeding
history_data = { history_data = {
'symbol': symbol, 'symbol': symbol,
'current_snapshot': cob_snapshot, 'current_snapshot': cob_snapshot,
'history': self.cob_data_history[symbol][-15:], # Last 15 seconds 'history': self.cob_data_history[symbol][-15:], # Last 15 seconds
'bucketed_data': self.cob_bucketed_data[symbol], 'bucketed_data': self.cob_bucketed_data[symbol],
'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
'timestamp': cob_snapshot['timestamp'] 'timestamp': cob_snapshot['timestamp']
} }
@ -3643,7 +3647,7 @@ class CleanTradingDashboard:
if hasattr(self.orchestrator, '_on_cob_dashboard_data'): if hasattr(self.orchestrator, '_on_cob_dashboard_data'):
try: try:
self.orchestrator._on_cob_dashboard_data(symbol, history_data) self.orchestrator._on_cob_dashboard_data(symbol, history_data)
logger.debug(f"COB data fed to orchestrator for {symbol}") logger.debug(f"COB data fed to orchestrator for {symbol} with cumulative imbalance: {cumulative_imbalance}")
except Exception as e: except Exception as e:
logger.debug(f"Error feeding COB data to orchestrator: {e}") logger.debug(f"Error feeding COB data to orchestrator: {e}")
@ -4026,6 +4030,10 @@ class CleanTradingDashboard:
current_price = self._get_current_price('ETH/USDT') current_price = self._get_current_price('ETH/USDT')
if not current_price: if not current_price:
return training_data return training_data
# Get cumulative imbalance for training
cumulative_imbalance = self._calculate_cumulative_imbalance('ETH/USDT')
df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=50) df = self.data_provider.get_historical_data('ETH/USDT', '1m', limit=50)
if df is not None and not df.empty: if df is not None and not df.empty:
for i in range(1, min(len(df), 20)): for i in range(1, min(len(df), 20)):
@ -4035,6 +4043,7 @@ class CleanTradingDashboard:
sample = { sample = {
'timestamp': df.index[i], 'price': curr_price, 'prev_price': prev_price, 'timestamp': df.index[i], 'price': curr_price, 'prev_price': prev_price,
'price_change': price_change, 'volume': float(df['volume'].iloc[i]), 'price_change': price_change, 'volume': float(df['volume'].iloc[i]),
'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
'action': 'BUY' if price_change > 0.001 else 'SELL' if price_change < -0.001 else 'HOLD' 'action': 'BUY' if price_change > 0.001 else 'SELL' if price_change < -0.001 else 'HOLD'
} }
training_data.append(sample) training_data.append(sample)
@ -4043,7 +4052,8 @@ class CleanTradingDashboard:
for tick in recent_ticks: for tick in recent_ticks:
sample = { sample = {
'timestamp': tick.get('datetime', datetime.now()), 'price': tick.get('price', current_price), 'timestamp': tick.get('datetime', datetime.now()), 'price': tick.get('price', current_price),
'volume': tick.get('volume', 0), 'tick_data': True 'volume': tick.get('volume', 0), 'cumulative_imbalance': cumulative_imbalance, # Add cumulative imbalance
'tick_data': True
} }
training_data.append(sample) training_data.append(sample)
return training_data return training_data
@ -4058,13 +4068,34 @@ class CleanTradingDashboard:
return return
agent = self.orchestrator.rl_agent agent = self.orchestrator.rl_agent
training_samples = 0 training_samples = 0
total_loss = 0
loss_count = 0
for data in market_data[-10:]: for data in market_data[-10:]:
try: try:
price = data.get('price', 0) price = data.get('price', 0)
prev_price = data.get('prev_price', price) prev_price = data.get('prev_price', price)
price_change = data.get('price_change', 0) price_change = data.get('price_change', 0)
volume = data.get('volume', 0) volume = data.get('volume', 0)
state = np.array([price / 10000, price_change, volume / 1000000, 1.0 if price > prev_price else 0.0, abs(price_change) * 100]) cumulative_imbalance = data.get('cumulative_imbalance', {})
# Extract imbalance values for state
imbalance_1s = cumulative_imbalance.get('1s', 0.0)
imbalance_5s = cumulative_imbalance.get('5s', 0.0)
imbalance_15s = cumulative_imbalance.get('15s', 0.0)
imbalance_60s = cumulative_imbalance.get('60s', 0.0)
state = np.array([
price / 10000,
price_change,
volume / 1000000,
1.0 if price > prev_price else 0.0,
abs(price_change) * 100,
imbalance_1s,
imbalance_5s,
imbalance_15s,
imbalance_60s
])
if hasattr(agent, 'state_dim') and len(state) < agent.state_dim: if hasattr(agent, 'state_dim') and len(state) < agent.state_dim:
padded_state = np.zeros(agent.state_dim) padded_state = np.zeros(agent.state_dim)
padded_state[:len(state)] = state padded_state[:len(state)] = state
@ -4079,17 +4110,58 @@ class CleanTradingDashboard:
training_samples += 1 training_samples += 1
except Exception as e: except Exception as e:
logger.debug(f"Error adding market experience to DQN memory: {e}") logger.debug(f"Error adding market experience to DQN memory: {e}")
if hasattr(agent, 'memory') and len(agent.memory) >= 32: if hasattr(agent, 'memory') and len(agent.memory) >= 32:
for _ in range(3): for _ in range(3):
try: try:
loss = agent.replay() loss = agent.replay()
if loss is not None: if loss is not None:
total_loss += loss
loss_count += 1
self.orchestrator.update_model_loss('dqn', loss) self.orchestrator.update_model_loss('dqn', loss)
if not hasattr(agent, 'losses'): agent.losses = [] if not hasattr(agent, 'losses'): agent.losses = []
agent.losses.append(loss) agent.losses.append(loss)
if len(agent.losses) > 1000: agent.losses = agent.losses[-1000:] if len(agent.losses) > 1000: agent.losses = agent.losses[-1000:]
except Exception as e: except Exception as e:
logger.debug(f"DQN training step failed: {e}") logger.debug(f"DQN training step failed: {e}")
# Save checkpoint after training
if loss_count > 0:
try:
from utils.checkpoint_manager import save_checkpoint
avg_loss = total_loss / loss_count
# Prepare checkpoint data
checkpoint_data = {
'model_state_dict': agent.model.state_dict() if hasattr(agent, 'model') else None,
'target_model_state_dict': agent.target_model.state_dict() if hasattr(agent, 'target_model') else None,
'optimizer_state_dict': agent.optimizer.state_dict() if hasattr(agent, 'optimizer') else None,
'memory_size': len(agent.memory),
'training_samples': training_samples,
'losses': agent.losses[-100:] if hasattr(agent, 'losses') else []
}
performance_metrics = {
'loss': avg_loss,
'memory_size': len(agent.memory),
'training_samples': training_samples,
'model_parameters': sum(p.numel() for p in agent.model.parameters()) if hasattr(agent, 'model') else 0
}
metadata = save_checkpoint(
model=checkpoint_data,
model_name="dqn_agent",
model_type="dqn",
performance_metrics=performance_metrics,
training_metadata={'training_iterations': loss_count}
)
if metadata:
logger.info(f"DQN checkpoint saved: {metadata.checkpoint_id} (loss={avg_loss:.4f})")
except Exception as e:
logger.error(f"Error saving DQN checkpoint: {e}")
logger.info(f"DQN TRAINING: Added {training_samples} experiences, memory size: {len(agent.memory)}") logger.info(f"DQN TRAINING: Added {training_samples} experiences, memory size: {len(agent.memory)}")
except Exception as e: except Exception as e:
logger.error(f"Error in real DQN training: {e}") logger.error(f"Error in real DQN training: {e}")
@ -4109,10 +4181,17 @@ class CleanTradingDashboard:
current_price = current_data.get('price', 0) current_price = current_data.get('price', 0)
next_price = next_data.get('price', current_price) next_price = next_data.get('price', current_price)
price_change = (next_price - current_price) / current_price if current_price > 0 else 0 price_change = (next_price - current_price) / current_price if current_price > 0 else 0
cumulative_imbalance = current_data.get('cumulative_imbalance', {})
features = np.random.randn(100) features = np.random.randn(100)
features[0] = current_price / 10000 features[0] = current_price / 10000
features[1] = price_change features[1] = price_change
features[2] = current_data.get('volume', 0) / 1000000 features[2] = current_data.get('volume', 0) / 1000000
# Add cumulative imbalance features
features[3] = cumulative_imbalance.get('1s', 0.0)
features[4] = cumulative_imbalance.get('5s', 0.0)
features[5] = cumulative_imbalance.get('15s', 0.0)
features[6] = cumulative_imbalance.get('60s', 0.0)
if price_change > 0.001: target = 2 if price_change > 0.001: target = 2
elif price_change < -0.001: target = 0 elif price_change < -0.001: target = 0
else: target = 1 else: target = 1