diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py
index a077359..6477d0f 100644
--- a/NN/models/dqn_agent.py
+++ b/NN/models/dqn_agent.py
@@ -28,14 +28,14 @@ class DQNAgent:
window_size: int,
num_features: int,
timeframes: List[str],
- learning_rate: float = 0.001,
- gamma: float = 0.99,
+ learning_rate: float = 0.0005, # Reduced learning rate for more stability
+ gamma: float = 0.97, # Slightly reduced discount factor
epsilon: float = 1.0,
- epsilon_min: float = 0.01,
- epsilon_decay: float = 0.995,
- memory_size: int = 10000,
- batch_size: int = 64,
- target_update: int = 10):
+ epsilon_min: float = 0.05, # Increased minimum epsilon for more exploration
+ epsilon_decay: float = 0.9975, # Slower decay rate
+ memory_size: int = 20000, # Increased memory size
+ batch_size: int = 128, # Larger batch size
+ target_update: int = 5): # More frequent target updates
self.state_size = state_size
self.action_size = action_size
@@ -70,23 +70,25 @@ class DQNAgent:
).to(self.device)
self.target_net.load_state_dict(self.policy_net.state_dict())
- # Initialize optimizer
- self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
+        # Initialize optimizer with L2 weight decay for regularization
+ self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate, weight_decay=1e-5)
- # Initialize memory
+ # Initialize memories with different priorities
self.memory = deque(maxlen=memory_size)
-
- # Special memory for extrema samples to use for targeted learning
- self.extrema_memory = deque(maxlen=memory_size // 5) # Smaller size for extrema examples
+ self.extrema_memory = deque(maxlen=memory_size // 4) # For extrema points
+ self.positive_memory = deque(maxlen=memory_size // 4) # For positive rewards
# Training metrics
self.update_count = 0
self.losses = []
+ self.avg_reward = 0
+ self.no_improvement_count = 0
+ self.best_reward = float('-inf')
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""
- Store experience in memory
+ Store experience in memory with prioritization
Args:
state: Current state
@@ -97,28 +99,124 @@ class DQNAgent:
is_extrema: Whether this is a local extrema sample (for specialized learning)
"""
experience = (state, action, reward, next_state, done)
+
+ # Always add to main memory
self.memory.append(experience)
- # If this is an extrema sample, also add to specialized memory
+ # Add to specialized memories if applicable
if is_extrema:
self.extrema_memory.append(experience)
+
+ # Store positive experiences separately for prioritized replay
+ if reward > 0:
+ self.positive_memory.append(experience)
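+        # Note: this is a coarse, bucket-based form of prioritized replay;
+        # samples are drawn by bucket rather than weighted by TD error.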
- def act(self, state: np.ndarray) -> int:
- """Choose action using epsilon-greedy policy"""
- if random.random() < self.epsilon:
+ def act(self, state: np.ndarray, explore=True) -> int:
+ """Choose action using epsilon-greedy policy with explore flag"""
+ if explore and random.random() < self.epsilon:
return random.randrange(self.action_size)
with torch.no_grad():
- state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
- action_probs, extrema_pred = self.policy_net(state)
+ # Ensure state is normalized before inference
+ state_tensor = self._normalize_state(state)
+ state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
+ action_probs, extrema_pred = self.policy_net(state_tensor)
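+            # extrema_pred is exposed for diagnostics; only the Q-values
+            # drive action selection here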
return action_probs.argmax().item()
- def replay(self, use_extrema=False) -> float:
+ def _normalize_state(self, state: np.ndarray) -> np.ndarray:
+ """Normalize the state data to prevent numerical issues"""
+        # Handle NaN and infinite values on numeric arrays; object arrays
+        # (e.g. containing timestamps) are cleaned element-wise below
+        if np.issubdtype(state.dtype, np.number):
+            state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0)
+
+ # Check if state is 1D array (happens in some environments)
+ if len(state.shape) == 1:
+ # If 1D, we need to normalize the whole array
+ normalized_state = state.copy()
+
+ # Convert any timestamp or non-numeric data to float
+ for i in range(len(normalized_state)):
+ # Check for timestamp-like objects
+ if hasattr(normalized_state[i], 'timestamp') and callable(getattr(normalized_state[i], 'timestamp')):
+ # Convert timestamp to float (seconds since epoch)
+ normalized_state[i] = float(normalized_state[i].timestamp())
+ elif not isinstance(normalized_state[i], (int, float, np.number)):
+ # Set non-numeric data to 0
+ normalized_state[i] = 0.0
+
+ # Ensure all values are float
+ normalized_state = normalized_state.astype(np.float32)
+
+ # Simple min-max normalization for 1D state
+ state_min = np.min(normalized_state)
+ state_max = np.max(normalized_state)
+ if state_max > state_min:
+ normalized_state = (normalized_state - state_min) / (state_max - state_min)
+ return normalized_state
+
+ # Handle 2D arrays
+ normalized_state = np.zeros_like(state, dtype=np.float32)
+
+ # Convert any timestamp or non-numeric data to float
+ for i in range(state.shape[0]):
+ for j in range(state.shape[1]):
+ if hasattr(state[i, j], 'timestamp') and callable(getattr(state[i, j], 'timestamp')):
+ # Convert timestamp to float (seconds since epoch)
+ normalized_state[i, j] = float(state[i, j].timestamp())
+ elif isinstance(state[i, j], (int, float, np.number)):
+ normalized_state[i, j] = state[i, j]
+ else:
+ # Set non-numeric data to 0
+ normalized_state[i, j] = 0.0
+
+ # Loop through each timeframe's features in the combined state
+ feature_count = state.shape[1] // len(self.timeframes)
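+        # Assumes the feature axis concatenates equal-width per-timeframe
+        # blocks; any remainder columns are left un-normalized.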
+
+ for tf_idx in range(len(self.timeframes)):
+ start_idx = tf_idx * feature_count
+ end_idx = start_idx + feature_count
+
+ # Extract this timeframe's features
+ tf_features = normalized_state[:, start_idx:end_idx]
+
+            # Normalize OHLC data by the mean close price in the window
+ # This makes price movements relative rather than absolute
+ price_idx = 3 # Assuming close price is at index 3
+ if price_idx < tf_features.shape[1]:
+ reference_price = np.mean(tf_features[:, price_idx])
+ if reference_price != 0:
+ # Normalize price-related columns (OHLC)
+ for i in range(4): # First 4 columns are OHLC
+ if i < tf_features.shape[1]:
+ normalized_state[:, start_idx + i] = tf_features[:, i] / reference_price
+
+ # Normalize volume using mean and std
+ vol_idx = 4 # Assuming volume is at index 4
+ if vol_idx < tf_features.shape[1]:
+ vol_mean = np.mean(tf_features[:, vol_idx])
+ vol_std = np.std(tf_features[:, vol_idx])
+ if vol_std > 0:
+ normalized_state[:, start_idx + vol_idx] = (tf_features[:, vol_idx] - vol_mean) / vol_std
+ else:
+ normalized_state[:, start_idx + vol_idx] = 0
+
+ # Other features (technical indicators) - normalize with min-max scaling
+ for i in range(5, feature_count):
+ if i < tf_features.shape[1]:
+ feature_min = np.min(tf_features[:, i])
+ feature_max = np.max(tf_features[:, i])
+ if feature_max > feature_min:
+ normalized_state[:, start_idx + i] = (tf_features[:, i] - feature_min) / (feature_max - feature_min)
+ else:
+ normalized_state[:, start_idx + i] = 0
+
+ return normalized_state
+
+ def replay(self, use_prioritized=True) -> float:
"""
- Train on a batch of experiences
+ Train on a batch of experiences with prioritized sampling
Args:
- use_extrema: Whether to include extrema samples in training
+ use_prioritized: Whether to use prioritized replay
Returns:
float: Loss value
@@ -126,55 +224,67 @@ class DQNAgent:
if len(self.memory) < self.batch_size:
return 0.0
- # Sample batch - mix regular and extrema samples
+ # Sample batch with prioritization
batch = []
- if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
- # Get some extrema samples
- extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
- extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
+
+ if use_prioritized and len(self.positive_memory) > 0 and len(self.extrema_memory) > 0:
+ # Prioritized sampling from different memory types
+ positive_count = min(self.batch_size // 4, len(self.positive_memory))
+ extrema_count = min(self.batch_size // 4, len(self.extrema_memory))
+ regular_count = self.batch_size - positive_count - extrema_count
- # Get regular samples for the rest
- regular_count = self.batch_size - extrema_count
+ positive_samples = random.sample(list(self.positive_memory), positive_count)
+ extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
regular_samples = random.sample(list(self.memory), regular_count)
- # Combine samples
- batch = extrema_samples + regular_samples
+ batch = positive_samples + extrema_samples + regular_samples
else:
# Standard sampling
batch = random.sample(self.memory, self.batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
+ # Normalize states before training
+ normalized_states = np.array([self._normalize_state(state) for state in states])
+ normalized_next_states = np.array([self._normalize_state(state) for state in next_states])
+
# Convert to tensors and move to device
- states = torch.FloatTensor(np.array(states)).to(self.device)
- actions = torch.LongTensor(actions).to(self.device)
- rewards = torch.FloatTensor(rewards).to(self.device)
- next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
- dones = torch.FloatTensor(dones).to(self.device)
+ states_tensor = torch.FloatTensor(normalized_states).to(self.device)
+ actions_tensor = torch.LongTensor(actions).to(self.device)
+ rewards_tensor = torch.FloatTensor(rewards).to(self.device)
+ next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device)
+ dones_tensor = torch.FloatTensor(dones).to(self.device)
# Get current Q values
- current_q_values, extrema_pred = self.policy_net(states)
- current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
+ current_q_values, extrema_pred = self.policy_net(states_tensor)
+ current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1))
- # Get next Q values from target network
+ # Get next Q values from target network (Double DQN approach)
with torch.no_grad():
- next_q_values, _ = self.target_net(next_states)
- next_q_values = next_q_values.max(1)[0]
- target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
+ # Get actions from policy network
+ next_actions, _ = self.policy_net(next_states_tensor)
+ next_actions = next_actions.max(1)[1].unsqueeze(1)
+
+ # Get Q values from target network for those actions
+ next_q_values, _ = self.target_net(next_states_tensor)
+ next_q_values = next_q_values.gather(1, next_actions).squeeze(1)
+
+ # Compute target Q values
+ target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values
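+            # Decoupling action selection (policy net) from evaluation
+            # (target net) reduces the overestimation bias of vanilla DQN.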
- # Compute Q-learning loss
- q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
+ # Clamp target values to prevent extreme values
+ target_q_values = torch.clamp(target_q_values, -100, 100)
- # If we have extrema labels (not in this implementation yet),
- # we could add an additional loss for extrema prediction
- # This would require labels for whether each state is near an extrema
-
- # Total loss is just Q-learning loss for now
- loss = q_loss
+ # Compute Huber loss (more robust to outliers than MSE)
+ loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values)
# Optimize
self.optimizer.zero_grad()
loss.backward()
+
+ # Apply gradient clipping to prevent exploding gradients
+ nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
+
self.optimizer.step()
# Update target network if needed
@@ -200,37 +310,77 @@ class DQNAgent:
"""
if len(states) == 0:
return 0.0
+
+ # Normalize states
+ normalized_states = np.array([self._normalize_state(state) for state in states])
+ normalized_next_states = np.array([self._normalize_state(state) for state in next_states])
# Convert to tensors
- states = torch.FloatTensor(np.array(states)).to(self.device)
- actions = torch.LongTensor(actions).to(self.device)
- rewards = torch.FloatTensor(rewards).to(self.device)
- next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
- dones = torch.FloatTensor(dones).to(self.device)
+ states_tensor = torch.FloatTensor(normalized_states).to(self.device)
+ actions_tensor = torch.LongTensor(actions).to(self.device)
+ rewards_tensor = torch.FloatTensor(rewards).to(self.device)
+ next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device)
+ dones_tensor = torch.FloatTensor(dones).to(self.device)
# Forward pass
- current_q_values, extrema_pred = self.policy_net(states)
- current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
+ current_q_values, extrema_pred = self.policy_net(states_tensor)
+ current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1))
- # Get next Q values
+ # Get next Q values (Double DQN approach)
with torch.no_grad():
- next_q_values, _ = self.target_net(next_states)
- next_q_values = next_q_values.max(1)[0]
- target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
+ next_actions, _ = self.policy_net(next_states_tensor)
+ next_actions = next_actions.max(1)[1].unsqueeze(1)
- # Higher weight for extrema training
- q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
+ next_q_values, _ = self.target_net(next_states_tensor)
+ next_q_values = next_q_values.gather(1, next_actions).squeeze(1)
+
+ target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values
+
+ # Clamp target values
+ target_q_values = torch.clamp(target_q_values, -100, 100)
- # Full loss is just Q-learning loss
+ # Use Huber loss for extrema training
+ q_loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values)
+
+ # Full loss
loss = q_loss
# Optimize
self.optimizer.zero_grad()
loss.backward()
+ nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0)
self.optimizer.step()
return loss.item()
+ def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01):
+ """Update learning metrics and perform learning rate adjustments if needed"""
+ # Update average reward with exponential moving average
+ if self.avg_reward == 0:
+ self.avg_reward = episode_reward
+ else:
+ self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward
+
+ # Check if we're making sufficient progress
+ if episode_reward > (1 + best_reward_threshold) * self.best_reward:
+ self.best_reward = episode_reward
+ self.no_improvement_count = 0
+ return True # Improved
+ else:
+ self.no_improvement_count += 1
+
+ # If no improvement for a while, adjust learning rate
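+            # (a hand-rolled analogue of torch.optim.lr_scheduler.ReduceLROnPlateau
+            # with factor=0.5, patience=10 and a minimum lr of 1e-6)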
+ if self.no_improvement_count >= 10:
+ current_lr = self.optimizer.param_groups[0]['lr']
+ new_lr = current_lr * 0.5
+ if new_lr >= 1e-6: # Don't reduce below minimum threshold
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = new_lr
+ logger.info(f"Reducing learning rate from {current_lr} to {new_lr}")
+ self.no_improvement_count = 0
+
+ return False # No improvement
+
def save(self, path: str):
"""Save model and agent state"""
os.makedirs(os.path.dirname(path), exist_ok=True)
@@ -246,9 +396,13 @@ class DQNAgent:
'epsilon': self.epsilon,
'update_count': self.update_count,
'losses': self.losses,
- 'optimizer_state': self.optimizer.state_dict()
+ 'optimizer_state': self.optimizer.state_dict(),
+ 'best_reward': self.best_reward,
+ 'avg_reward': self.avg_reward
}
+
torch.save(state, f"{path}_agent_state.pt")
+ logger.info(f"Agent state saved to {path}_agent_state.pt")
def load(self, path: str):
"""Load model and agent state"""
@@ -259,8 +413,19 @@ class DQNAgent:
self.target_net.load(f"{path}_target")
# Load agent state
- state = torch.load(f"{path}_agent_state.pt")
- self.epsilon = state['epsilon']
- self.update_count = state['update_count']
- self.losses = state['losses']
- self.optimizer.load_state_dict(state['optimizer_state'])
\ No newline at end of file
+ try:
+ agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device)
+ self.epsilon = agent_state['epsilon']
+ self.update_count = agent_state['update_count']
+ self.losses = agent_state['losses']
+ self.optimizer.load_state_dict(agent_state['optimizer_state'])
+
+ # Load additional metrics if they exist
+ if 'best_reward' in agent_state:
+ self.best_reward = agent_state['best_reward']
+ if 'avg_reward' in agent_state:
+ self.avg_reward = agent_state['avg_reward']
+
+ logger.info(f"Agent state loaded from {path}_agent_state.pt")
+ except FileNotFoundError:
+ logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values")
\ No newline at end of file
diff --git a/NN/models/simple_cnn.py b/NN/models/simple_cnn.py
index f6293c0..b4b637b 100644
--- a/NN/models/simple_cnn.py
+++ b/NN/models/simple_cnn.py
@@ -44,13 +44,46 @@ class PricePatternAttention(nn.Module):
return output, attn_weights
+class AdaptiveNorm(nn.Module):
+ """
+ Adaptive normalization layer that chooses between different normalization
+ methods based on input dimensions
+ """
+ def __init__(self, num_features):
+ super(AdaptiveNorm, self).__init__()
+ self.batch_norm = nn.BatchNorm1d(num_features, affine=True)
+ self.group_norm = nn.GroupNorm(min(32, num_features), num_features)
+        self.layer_norm = nn.LayerNorm([num_features, 1])  # unused placeholder; forward() builds a correctly sized LayerNorm lazily
+
+ def forward(self, x):
+ # Check input dimensions
+ batch_size, channels, seq_len = x.size()
+
+ # Choose normalization method:
+        # - batch_size > 1 and seq_len > 1: BatchNorm
+        # - batch_size == 1 with seq_len > 1: GroupNorm (batch-size independent)
+        # - seq_len == 1: LayerNorm sized lazily to the actual input
+ if batch_size > 1 and seq_len > 1:
+ return self.batch_norm(x)
+ elif seq_len > 1:
+ return self.group_norm(x)
+ else:
+ # For 1D inputs (seq_len=1), we need to adjust the layer norm
+ # to the actual input size
+ if not hasattr(self, 'layer_norm_1d') or self.layer_norm_1d.normalized_shape[0] != channels:
+ self.layer_norm_1d = nn.LayerNorm([channels, seq_len]).to(x.device)
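+                # Caveat: modules created lazily inside forward() are not
+                # registered with an optimizer built before this point.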
+ return self.layer_norm_1d(x)
+
class CNNModelPyTorch(nn.Module):
"""
CNN model for trading with multiple timeframes
"""
- def __init__(self, window_size, num_features, output_size, timeframes):
+ def __init__(self, window_size=20, num_features=5, output_size=3, timeframes=None):
super(CNNModelPyTorch, self).__init__()
+ if timeframes is None:
+ timeframes = [1]
+
self.window_size = window_size
self.num_features = num_features
self.output_size = output_size
@@ -73,27 +106,28 @@ class CNNModelPyTorch(nn.Module):
"""Create all model layers with current feature dimensions"""
# Convolutional layers - use total_features as input channels
self.conv1 = nn.Conv1d(self.total_features, 64, kernel_size=3, padding=1)
- self.bn1 = nn.BatchNorm1d(64)
+ self.norm1 = AdaptiveNorm(64)
+ self.dropout1 = nn.Dropout(0.2)
self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
- self.bn2 = nn.BatchNorm1d(128)
+ self.norm2 = AdaptiveNorm(128)
+ self.dropout2 = nn.Dropout(0.3)
self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
- self.bn3 = nn.BatchNorm1d(256)
+ self.norm3 = AdaptiveNorm(256)
+ self.dropout3 = nn.Dropout(0.4)
# Add price pattern attention layer
self.attention = PricePatternAttention(256)
# Extrema detection specialized convolutional layer
- self.extrema_conv = nn.Conv1d(256, 128, kernel_size=5, padding=2)
- self.extrema_bn = nn.BatchNorm1d(128)
+ self.extrema_conv = nn.Conv1d(256, 128, kernel_size=3, padding=1) # Smaller kernel for small inputs
+ self.extrema_norm = AdaptiveNorm(128)
- # Calculate size after convolutions - adjusted for attention output
- conv_output_size = self.window_size * 256
-
- # Fully connected layers
- self.fc1 = nn.Linear(conv_output_size, 512)
+ # Fully connected layers - input size will be determined dynamically
+ self.fc1 = None # Will be initialized in forward pass
self.fc2 = nn.Linear(512, 256)
+ self.dropout_fc = nn.Dropout(0.5)
# Advantage and Value streams (Dueling DQN architecture)
self.fc3 = nn.Linear(256, self.output_size) # Advantage stream
@@ -131,46 +165,96 @@ class CNNModelPyTorch(nn.Module):
# Ensure input is on the correct device
x = x.to(self.device)
- # Check and handle if input dimensions don't match model expectations
- batch_size, window_len, feature_dim = x.size()
- if feature_dim != self.total_features:
- logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
- self.rebuild_conv_layers(feature_dim)
+ # Check input dimensions and reshape as needed
+ if len(x.size()) == 2:
+ # If input is [batch_size, features], reshape to [batch_size, features, 1]
+ batch_size, feature_dim = x.size()
+
+ # Check and handle if input features don't match model expectations
+ if feature_dim != self.total_features:
+ logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
+ self.rebuild_conv_layers(feature_dim)
+
+ # For 1D input, use a sequence length of 1
+ seq_len = 1
+ x = x.unsqueeze(2) # Reshape to [batch, features, 1]
+ elif len(x.size()) == 3:
+ # Standard case: [batch_size, window_size, features]
+ batch_size, seq_len, feature_dim = x.size()
+
+ # Check and handle if input dimensions don't match model expectations
+ if feature_dim != self.total_features:
+ logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
+ self.rebuild_conv_layers(feature_dim)
+
+ # Reshape input: [batch, window_size, features] -> [batch, features, window_size]
+ x = x.permute(0, 2, 1)
+ else:
+ raise ValueError(f"Unexpected input shape: {x.size()}, expected 2D or 3D tensor")
- # Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
- x = x.permute(0, 2, 1)
-
- # Convolutional layers
- x = F.relu(self.bn1(self.conv1(x)))
- x = F.relu(self.bn2(self.conv2(x)))
- x = F.relu(self.bn3(self.conv3(x)))
+        # Convolutional layers with dropout - safely handle small spatial dimensions
+        x_in = x  # keep the raw input so the fallback path can restart from it
+        try:
+            x = self.dropout1(F.relu(self.norm1(self.conv1(x))))
+            x = self.dropout2(F.relu(self.norm2(self.conv2(x))))
+            x = self.dropout3(F.relu(self.norm3(self.conv3(x))))
+        except Exception as e:
+            logger.warning(f"Error in convolutional layers: {str(e)}")
+            # Fallback for very small inputs: restart from the raw input and
+            # apply the convolutions without normalization or dropout
+            if seq_len < 3:
+                x = F.relu(self.conv1(x_in))
+                x = F.relu(self.conv2(x))
+                # Skip the last conv if it also fails at these dimensions
+                try:
+                    x = F.relu(self.conv3(x))
+                except Exception:
+                    pass
# Store conv features for extrema detection
conv_features = x
- # Reshape for attention: [batch, channels, window_size] -> [batch, window_size, channels]
- x_attention = x.permute(0, 2, 1)
+ # Get the current shape after convolutions
+ _, channels, conv_seq_len = x.size()
- # Apply attention
- attention_output, attention_weights = self.attention(x_attention)
+ # Initialize fc1 if not created yet or if the shape has changed
+ if self.fc1 is None:
+ flattened_size = channels * conv_seq_len
+ logger.info(f"Initializing fc1 with input size {flattened_size}")
+ self.fc1 = nn.Linear(flattened_size, 512).to(self.device)
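+            # Caveat: an optimizer created before this first forward pass will
+            # not include fc1's parameters until it is rebuilt.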
- # We'll use attention directly without the residual connection
- # to avoid dimension mismatch issues
- attention_reshaped = attention_output.permute(0, 2, 1) # [batch, channels, window_size]
+ # Apply extrema detection safely
+ try:
+ extrema_features = F.relu(self.extrema_norm(self.extrema_conv(conv_features)))
+ except Exception as e:
+ logger.warning(f"Error in extrema detection: {str(e)}")
+ extrema_features = conv_features # Fallback
- # Apply extrema detection specialized layer
- extrema_features = F.relu(self.extrema_bn(self.extrema_conv(conv_features)))
+ # Handle attention for small sequence lengths
+ if conv_seq_len > 1:
+ # Reshape for attention: [batch, channels, seq_len] -> [batch, seq_len, channels]
+ x_attention = x.permute(0, 2, 1)
+
+ # Apply attention
+ try:
+ attention_output, attention_weights = self.attention(x_attention)
+ except Exception as e:
+ logger.warning(f"Error in attention layer: {str(e)}")
+ # Fallback: don't use attention
- # Use attention features directly instead of residual connection
- # to avoid dimension mismatches
- x = conv_features # Just use the convolutional features
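+        # Note: attention_output is computed for monitoring only; it is not
+        # fused back into x, so the flattened conv features feed the FC layers.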
+ # Flatten - get the actual shape for this batch
+ flattened_size = channels * conv_seq_len
+ x = x.view(batch_size, flattened_size)
- # Flatten
- x = x.view(batch_size, -1)
+ # Check if we need to recreate fc1 with the correct size
+ if self.fc1.in_features != flattened_size:
+ logger.info(f"Recreating fc1 layer to match input size {flattened_size}")
+ self.fc1 = nn.Linear(flattened_size, 512).to(self.device)
+ # Reinitialize optimizer after changing the model
+ self.optimizer = optim.Adam(self.parameters(), lr=0.001)
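+                # Rebuilding the optimizer discards Adam's moment estimates,
+                # so a brief loss spike after a resize is expected.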
- # Fully connected layers
+ # Fully connected layers with dropout
x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
+ x = self.dropout_fc(F.relu(self.fc2(x)))
# Split into advantage and value streams
advantage = self.fc3(x)
diff --git a/NN/utils/trading_env.py b/NN/utils/trading_env.py
index c33fc09..d02602c 100644
--- a/NN/utils/trading_env.py
+++ b/NN/utils/trading_env.py
@@ -3,6 +3,10 @@ import gym
from gym import spaces
from typing import Dict, Tuple, List
import pandas as pd
+import logging
+
+# Configure logger
+logger = logging.getLogger(__name__)
class TradingEnvironment(gym.Env):
"""
@@ -12,97 +16,284 @@ class TradingEnvironment(gym.Env):
data: pd.DataFrame,
initial_balance: float = 100.0,
fee_rate: float = 0.0002,
- max_steps: int = 1000):
+ max_steps: int = 1000,
+ window_size: int = 20,
+ risk_aversion: float = 0.2, # Controls how much to penalize volatility
+ price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw'
+ reward_scaling: float = 10.0, # Scale factor for rewards
+ episode_penalty: float = 0.1): # Penalty for active positions at end of episode
super(TradingEnvironment, self).__init__()
self.data = data
self.initial_balance = initial_balance
self.fee_rate = fee_rate
self.max_steps = max_steps
+ self.window_size = window_size
+ self.risk_aversion = risk_aversion
+ self.price_scaling = price_scaling
+ self.reward_scaling = reward_scaling
+ self.episode_penalty = episode_penalty
- # Action space: 0 (SELL), 1 (HOLD), 2 (BUY)
+ # Preprocess data if needed
+ self._preprocess_data()
+
+ # Action space: 0 (BUY), 1 (SELL), 2 (HOLD)
self.action_space = spaces.Discrete(3)
# Observation space: price data, technical indicators, and account state
+ feature_dim = self.data.shape[1] + 3 # Adding position, equity, unrealized_pnl
self.observation_space = spaces.Box(
low=-np.inf,
high=np.inf,
- shape=(data.shape[1],), # Number of features
+ shape=(feature_dim,),
dtype=np.float32
)
# Initialize state
self.reset()
+ def _preprocess_data(self):
+ """Preprocess data - normalize or standardize features"""
+ # Store the original data for reference
+ self.original_data = self.data.copy()
+
+ # Normalize price data based on the selected method
+ if self.price_scaling == 'zscore':
+ # For each feature, apply z-score normalization
+ for col in self.data.columns:
+ if col in ['open', 'high', 'low', 'close']:
+ mean = self.data[col].mean()
+ std = self.data[col].std()
+ if std > 0:
+ self.data[col] = (self.data[col] - mean) / std
+ # Normalize volume separately
+ elif col == 'volume':
+ mean = self.data[col].mean()
+ std = self.data[col].std()
+ if std > 0:
+ self.data[col] = (self.data[col] - mean) / std
+
+ elif self.price_scaling == 'minmax':
+ # For each feature, apply min-max scaling
+ for col in self.data.columns:
+ min_val = self.data[col].min()
+ max_val = self.data[col].max()
+ if max_val > min_val:
+ self.data[col] = (self.data[col] - min_val) / (max_val - min_val)
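+
+        # price_scaling == 'raw' intentionally leaves the features unchanged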
+
def reset(self) -> np.ndarray:
"""Reset the environment to initial state"""
- self.current_step = 0
+ self.current_step = self.window_size
self.balance = self.initial_balance
- self.position = 0 # 0: no position, 1: long position
+ self.position = 0 # 0: no position, 1: long position, -1: short position
self.entry_price = 0
+ self.entry_time = 0
self.total_trades = 0
self.winning_trades = 0
+ self.losing_trades = 0
self.total_pnl = 0
self.balance_history = [self.initial_balance]
+ self.equity_history = [self.initial_balance]
self.max_balance = self.initial_balance
+ self.max_drawdown = 0
+
+ # Trading performance metrics
+ self.trade_durations = [] # Track how long trades are held
+ self.returns = [] # Track returns of each trade
+
+ # For analyzing trade clustering
+ self.last_action_time = 0
+ self.actions_taken = []
return self._get_observation()
def _get_observation(self) -> np.ndarray:
- """Get current observation state"""
- return self.data.iloc[self.current_step].values
+ """Get current observation state with account information"""
+ # Get market data for the current step
+ market_data = self.data.iloc[self.current_step].values
+
+ # Get current price
+ current_price = self.original_data.iloc[self.current_step]['close']
+
+ # Calculate unrealized PnL
+ unrealized_pnl = 0
+ if self.position != 0:
+ price_diff = current_price - self.entry_price
+ unrealized_pnl = self.position * price_diff
+
+ # Calculate total equity (balance + unrealized PnL)
+ equity = self.balance + unrealized_pnl
+
+ # Normalize account state
+ normalized_position = self.position # -1, 0, or 1
+ normalized_equity = equity / self.initial_balance - 1.0 # Percent change from initial
+ normalized_unrealized_pnl = unrealized_pnl / self.initial_balance if self.initial_balance > 0 else 0
+
+ # Combine market data with account state
+ account_state = np.array([normalized_position, normalized_equity, normalized_unrealized_pnl])
+ observation = np.concatenate([market_data, account_state])
+
+ # Handle any NaN values
+ observation = np.nan_to_num(observation, nan=0.0)
+
+ return observation
def _calculate_reward(self, action: int) -> float:
- """Calculate reward based on action and outcome"""
- current_price = self.data.iloc[self.current_step]['close']
+ """
+ Calculate reward based on action and outcome with improved risk-adjusted metrics
- # If we have an open position
- if self.position != 0:
- # Calculate PnL
- pnl = self.position * (current_price - self.entry_price) / self.entry_price
- fees = self.fee_rate * 2 # Entry and exit fees
+ Args:
+ action: The action taken (0=BUY, 1=SELL, 2=HOLD)
- # Close position
- if (action == 0 and self.position > 0) or (action == 2 and self.position < 0):
- net_pnl = pnl - fees
- self.total_pnl += net_pnl
- self.balance *= (1 + net_pnl)
+ Returns:
+            Tuple[float, float]: (reward, realized_pnl) for this step
+ """
+        # Get the current close price from the unscaled data
+ current_price = self.original_data.iloc[self.current_step]['close']
+
+ # Default reward is slightly negative to discourage excessive trading
+ reward = -0.0001
+ pnl = 0.0
+
+ # Handle different actions based on current position
+ if self.position == 0: # No position
+ if action == 0: # BUY
+ self.position = 1
+ self.entry_price = current_price
+ self.entry_time = self.current_step
+ reward = -self.fee_rate # Small penalty for trading cost
+
+ elif action == 1: # SELL (start short position)
+ self.position = -1
+ self.entry_price = current_price
+ self.entry_time = self.current_step
+ reward = -self.fee_rate # Small penalty for trading cost
+
+ # else action == 2 (HOLD) - keep the small negative reward
+
+ elif self.position > 0: # Long position
+ if action == 1: # SELL (close long)
+ # Calculate profit/loss
+ price_diff = current_price - self.entry_price
+ pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
+
+ # Adjust reward based on PnL and risk
+ reward = pnl * self.reward_scaling
+
+ # Track trade performance
+ self.total_trades += 1
+ if pnl > 0:
+ self.winning_trades += 1
+ else:
+ self.losing_trades += 1
+
+ # Calculate trade duration
+ trade_duration = self.current_step - self.entry_time
+ self.trade_durations.append(trade_duration)
+
+ # Update returns list
+ self.returns.append(pnl)
+
+ # Update balance and reset position
+ self.balance *= (1 + pnl)
self.balance_history.append(self.balance)
self.max_balance = max(self.max_balance, self.balance)
+ self.total_pnl += pnl
- self.total_trades += 1
- if net_pnl > 0:
- self.winning_trades += 1
-
- # Reward based on PnL
- reward = net_pnl * 100 # Scale up for better learning
-
- # Additional reward for win rate
- win_rate = self.winning_trades / max(1, self.total_trades)
- reward += win_rate * 0.1
-
+ # Reset position
self.position = 0
- return reward
+
+ elif action == 0: # BUY (while already long)
+ # Penalize trying to increase an already active position
+ reward = -0.001
+
+ # else action == 2 (HOLD) - calculate unrealized P&L for reward
+ else:
+ price_diff = current_price - self.entry_price
+ unrealized_pnl = price_diff / self.entry_price
+
+ # Small reward/penalty based on unrealized P&L
+ reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
+
+ elif self.position < 0: # Short position
+ if action == 0: # BUY (close short)
+ # Calculate profit/loss
+ price_diff = self.entry_price - current_price
+ pnl = price_diff / self.entry_price - 2 * self.fee_rate # Account for entry and exit fees
+
+ # Adjust reward based on PnL and risk
+ reward = pnl * self.reward_scaling
+
+ # Track trade performance
+ self.total_trades += 1
+ if pnl > 0:
+ self.winning_trades += 1
+ else:
+ self.losing_trades += 1
+
+ # Calculate trade duration
+ trade_duration = self.current_step - self.entry_time
+ self.trade_durations.append(trade_duration)
+
+ # Update returns list
+ self.returns.append(pnl)
+
+ # Update balance and reset position
+ self.balance *= (1 + pnl)
+ self.balance_history.append(self.balance)
+ self.max_balance = max(self.max_balance, self.balance)
+ self.total_pnl += pnl
+
+ # Reset position
+ self.position = 0
+
+ elif action == 1: # SELL (while already short)
+ # Penalize trying to increase an already active position
+ reward = -0.001
+
+ # else action == 2 (HOLD) - calculate unrealized P&L for reward
+ else:
+ price_diff = self.entry_price - current_price
+ unrealized_pnl = price_diff / self.entry_price
+
+ # Small reward/penalty based on unrealized P&L
+ reward = unrealized_pnl * 0.05 # Scale down to encourage holding good positions
+
+ # Record the action
+ self.actions_taken.append(action)
+ self.last_action_time = self.current_step
+
+ # Update equity history (balance + unrealized P&L)
+ current_equity = self.balance
+ if self.position != 0:
+ # Calculate unrealized P&L
+ if self.position > 0: # Long
+ price_diff = current_price - self.entry_price
+ unrealized_pnl = price_diff / self.entry_price * self.balance
+ else: # Short
+ price_diff = self.entry_price - current_price
+ unrealized_pnl = price_diff / self.entry_price * self.balance
+
+ current_equity = self.balance + unrealized_pnl
- # Hold position
- return pnl * 0.1 # Small reward for holding profitable positions
+ self.equity_history.append(current_equity)
- # No position
- if action == 1: # HOLD
- return 0
+ # Calculate current drawdown
+ peak_equity = max(self.equity_history)
+ current_drawdown = (peak_equity - current_equity) / peak_equity if peak_equity > 0 else 0
+ self.max_drawdown = max(self.max_drawdown, current_drawdown)
- # Open new position
- if action in [0, 2]: # SELL or BUY
- self.position = -1 if action == 0 else 1
- self.entry_price = current_price
- return -self.fee_rate # Small penalty for trading
+ # Apply risk aversion factor - penalize volatility
+ if len(self.returns) > 1:
+ returns_std = np.std(self.returns)
+ reward -= returns_std * self.risk_aversion
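+            # A simple mean-variance style adjustment: strategies with volatile
+            # per-trade returns earn proportionally less reward.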
- return 0
+ return reward, pnl
def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
"""Execute one step in the environment"""
- # Calculate reward
- reward = self._calculate_reward(action)
+ # Calculate reward and update state
+ reward, pnl = self._calculate_reward(action)
# Move to next step
self.current_step += 1
@@ -110,26 +301,70 @@ class TradingEnvironment(gym.Env):
# Check if episode is done
done = self.current_step >= min(self.max_steps - 1, len(self.data) - 1)
+ # Apply penalty if episode ends with open position
+ if done and self.position != 0:
+ reward -= self.episode_penalty
+
+ # Force close the position at the end if still open
+ current_price = self.original_data.iloc[self.current_step]['close']
+ if self.position > 0: # Long position
+ price_diff = current_price - self.entry_price
+ pnl = price_diff / self.entry_price - 2 * self.fee_rate
+ else: # Short position
+ price_diff = self.entry_price - current_price
+ pnl = price_diff / self.entry_price - 2 * self.fee_rate
+
+ # Update balance
+ self.balance *= (1 + pnl)
+ self.total_pnl += pnl
+
+ # Track trade
+ self.total_trades += 1
+ if pnl > 0:
+ self.winning_trades += 1
+ else:
+ self.losing_trades += 1
+
+ # Reset position
+ self.position = 0
+
# Get next observation
observation = self._get_observation()
- # Calculate max drawdown
- max_drawdown = 0
- if len(self.balance_history) > 1:
- peak = self.balance_history[0]
- for balance in self.balance_history:
- peak = max(peak, balance)
- drawdown = (peak - balance) / peak
- max_drawdown = max(max_drawdown, drawdown)
+ # Calculate sharpe ratio and sortino ratio if possible
+ sharpe_ratio = 0
+ sortino_ratio = 0
+ win_rate = self.winning_trades / max(1, self.total_trades)
+
+ if len(self.returns) > 1:
+ mean_return = np.mean(self.returns)
+ std_return = np.std(self.returns)
+ if std_return > 0:
+ sharpe_ratio = mean_return / std_return
+
+ # For sortino, we only consider downside deviation
+ downside_returns = [r for r in self.returns if r < 0]
+ if downside_returns:
+ downside_deviation = np.std(downside_returns)
+ if downside_deviation > 0:
+ sortino_ratio = mean_return / downside_deviation
+
+ # Calculate average trade duration
+ avg_trade_duration = np.mean(self.trade_durations) if self.trade_durations else 0
# Additional info
info = {
'balance': self.balance,
'position': self.position,
'total_trades': self.total_trades,
- 'win_rate': self.winning_trades / max(1, self.total_trades),
+ 'win_rate': win_rate,
'total_pnl': self.total_pnl,
- 'max_drawdown': max_drawdown
+ 'max_drawdown': self.max_drawdown,
+ 'sharpe_ratio': sharpe_ratio,
+ 'sortino_ratio': sortino_ratio,
+ 'avg_trade_duration': avg_trade_duration,
+ 'pnl': pnl,
+ 'gain': (self.balance - self.initial_balance) / self.initial_balance
}
return observation, reward, done, info
@@ -143,20 +378,19 @@ class TradingEnvironment(gym.Env):
print(f"Total Trades: {self.total_trades}")
print(f"Win Rate: {self.winning_trades/max(1, self.total_trades):.2%}")
print(f"Total PnL: ${self.total_pnl:.2f}")
- print(f"Max Drawdown: {self._calculate_max_drawdown():.2%}")
+ print(f"Max Drawdown: {self.max_drawdown:.2%}")
+ print(f"Sharpe Ratio: {self._calculate_sharpe_ratio():.4f}")
print("-" * 50)
-
- def _calculate_max_drawdown(self):
- """Calculate maximum drawdown from balance history"""
- if len(self.balance_history) <= 1:
+
+ def _calculate_sharpe_ratio(self):
+ """Calculate Sharpe ratio from returns"""
+ if len(self.returns) < 2:
return 0.0
- peak = self.balance_history[0]
- max_drawdown = 0.0
+ mean_return = np.mean(self.returns)
+ std_return = np.std(self.returns)
- for balance in self.balance_history:
- peak = max(peak, balance)
- drawdown = (peak - balance) / peak
- max_drawdown = max(max_drawdown, drawdown)
+ if std_return == 0:
+ return 0.0
- return max_drawdown
\ No newline at end of file
+ return mean_return / std_return
\ No newline at end of file
diff --git a/realtime.py b/realtime.py
index 0bfa47f..b5ea1e5 100644
--- a/realtime.py
+++ b/realtime.py
@@ -22,6 +22,112 @@ import tzlocal
import threading
import random
import dash_bootstrap_components as dbc
+import uuid
+
+class BinanceHistoricalData:
+ """
+ Class for fetching historical price data from Binance.
+ """
+ def __init__(self):
+ self.base_url = "https://api.binance.com/api/v3"
+
+ def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000):
+ """
+ Fetch historical candles from Binance API.
+
+ Args:
+ symbol (str): Trading pair symbol (e.g., "BTC/USDT")
+ interval_seconds (int): Timeframe in seconds (e.g., 3600 for 1h)
+ limit (int): Number of candles to fetch
+
+ Returns:
+ pd.DataFrame: DataFrame with OHLCV data
+ """
+ # Convert interval_seconds to Binance interval format
+ interval_map = {
+ 1: "1s",
+ 60: "1m",
+ 300: "5m",
+ 900: "15m",
+ 1800: "30m",
+ 3600: "1h",
+ 14400: "4h",
+ 86400: "1d"
+ }
+
+ interval = interval_map.get(interval_seconds, "1h")
+
+ # Format symbol for Binance API (remove slash)
+ formatted_symbol = symbol.replace("/", "")
+
+ try:
+ # Build URL for klines endpoint
+ url = f"{self.base_url}/klines"
+ params = {
+ "symbol": formatted_symbol,
+ "interval": interval,
+ "limit": limit
+ }
+
+ # Make the request
+ response = requests.get(url, params=params)
+ response.raise_for_status()
+
+ # Parse the response
+ data = response.json()
+
+ # Create dataframe
+ df = pd.DataFrame(data, columns=[
+ "timestamp", "open", "high", "low", "close", "volume",
+ "close_time", "quote_asset_volume", "number_of_trades",
+ "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore"
+ ])
+
+ # Convert timestamp to datetime
+ df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms")
+
+ # Convert price columns to float
+ for col in ["open", "high", "low", "close", "volume"]:
+ df[col] = df[col].astype(float)
+
+ # Sort by timestamp
+ df = df.sort_values("timestamp")
+
+ logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})")
+ return df
+
+ except Exception as e:
+ logger.error(f"Error fetching historical data from Binance: {str(e)}")
+ # Return empty dataframe on error
+ return pd.DataFrame()
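+
+    # Example (hypothetical usage):
+    #   df = BinanceHistoricalData().get_historical_candles("ETH/USDT", 60, 500)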
+
+ def get_recent_trades(self, symbol, limit=1000):
+ """Get recent trades for a symbol"""
+ formatted_symbol = symbol.replace("/", "")
+
+ try:
+ url = f"{self.base_url}/trades"
+ params = {
+ "symbol": formatted_symbol,
+ "limit": limit
+ }
+
+ response = requests.get(url, params=params)
+ response.raise_for_status()
+
+ data = response.json()
+
+ # Create dataframe
+ df = pd.DataFrame(data)
+ df["time"] = pd.to_datetime(df["time"], unit="ms")
+ df["price"] = df["price"].astype(float)
+ df["qty"] = df["qty"].astype(float)
+
+ return df
+
+ except Exception as e:
+ logger.error(f"Error fetching recent trades: {str(e)}")
+ return pd.DataFrame()
# Configure logging with more detailed format
logging.basicConfig(
@@ -63,7 +169,7 @@ def setup_neural_network():
try:
# Get configuration from environment variables or use defaults
- symbol = os.environ.get('NN_SYMBOL', 'BTC/USDT')
+ symbol = os.environ.get('NN_SYMBOL', 'ETH/USDT')
timeframes = os.environ.get('NN_TIMEFRAMES', '1m,5m,1h,4h,1d').split(',')
output_size = int(os.environ.get('NN_OUTPUT_SIZE', '3')) # 3 for BUY/HOLD/SELL
@@ -216,1217 +322,192 @@ def convert_to_local_time(timestamp):
logger.error(f"Error converting timestamp to local time: {str(e)}")
return timestamp
-class TradeTickStorage:
- """Storage for trade ticks with a maximum age limit"""
-
- def __init__(self, symbol: str = None, max_age_seconds: int = 3600, use_sample_data: bool = True, log_no_ticks_warning: bool = False): # 1 hour by default, up from 30 min
- """Initialize the tick storage
-
- Args:
- symbol: Trading symbol
- max_age_seconds: Maximum age for ticks to be stored
- use_sample_data: If True, generate sample ticks when no real ticks available
- log_no_ticks_warning: If True, log a warning when no ticks are available
- """
- self.symbol = symbol
+class TickStorage:
+ """Simple storage for ticks and candles"""
+ def __init__(self):
self.ticks = []
- self.max_age_seconds = max_age_seconds
- self.last_cleanup_time = time.time()
- self.cleanup_interval = 60 # run cleanup every 60 seconds
- self.cache_dir = "cache"
- self.use_sample_data = use_sample_data
- self.log_no_ticks_warning = log_no_ticks_warning
- self.last_sample_price = 83000.0 # Starting price for sample data (for BTC)
- self.last_sample_time = time.time() * 1000 # Starting time for sample data
- self.last_tick_time = 0 # Initialize last_tick_time attribute
- self.tick_count = 0 # Initialize tick_count attribute
+ self.candles = {}
+ self.latest_price = None
- # Create cache directory if it doesn't exist
- if not os.path.exists(self.cache_dir):
- os.makedirs(self.cache_dir)
-
- # Try to load cached ticks
- self._load_cached_ticks()
+ def add_tick(self, price, volume=0, timestamp=None):
+ """Add a tick to the storage"""
+ if timestamp is None:
+ timestamp = datetime.now()
- logger.info(f"Initialized TradeTickStorage for {symbol} with max age: {max_age_seconds} seconds, cleanup interval: {self.cleanup_interval} seconds")
-
- def add_tick(self, tick: Dict):
- """Add a tick to storage
-
- Args:
- tick: Tick data dict with fields:
- price: The price
- volume: The volume
- timestamp: Timestamp in milliseconds
- """
- if not tick:
- return
-
- # Check if we need to generate a timestamp
- if 'timestamp' not in tick:
- tick['timestamp'] = int(time.time() * 1000) # Current time in ms
-
- # Ensure timestamp is an integer (milliseconds since epoch)
- if not isinstance(tick['timestamp'], int):
- try:
- # Try to convert from float or string
- tick['timestamp'] = int(float(tick['timestamp']))
- except (ValueError, TypeError):
- # If conversion fails, use current time
- tick['timestamp'] = int(time.time() * 1000)
-
- # Set default volume if not present
- if 'volume' not in tick:
- tick['volume'] = 0.01 # Default small volume
-
- # Add tick to storage with a copy to avoid mutation
- self.ticks.append(tick.copy())
-
- # Keep track of latest tick for stats
- self.last_tick_time = max(self.last_tick_time, tick['timestamp'])
-
- # Cache every 100 ticks to avoid data loss
- self.tick_count += 1
- if self.tick_count % 100 == 0:
- self._cache_ticks()
-
- # Periodically clean up old ticks
- if self.tick_count % 1000 == 0:
- self._cleanup()
-
- def _cleanup(self):
- """Remove ticks older than max_age_seconds"""
- # Get current time in milliseconds
- now = int(time.time() * 1000)
-
- # Remove old ticks
- cutoff = now - (self.max_age_seconds * 1000)
- original_count = len(self.ticks)
- self.ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff]
- removed = original_count - len(self.ticks)
-
- if removed > 0:
- logger.debug(f"Cleaned up {removed} old ticks, remaining: {len(self.ticks)}")
-
- def _load_cached_ticks(self):
- """Load cached ticks from disk on startup"""
- # Create symbol-specific filename
- symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower()
- cache_file = os.path.join(self.cache_dir, f"{symbol_safe}_recent_ticks.csv")
-
- if not os.path.exists(cache_file):
- logger.info(f"No cached ticks found for {self.symbol}")
- return
-
- try:
- # Check if cache is fresh (less than 10 minutes old)
- file_age = time.time() - os.path.getmtime(cache_file)
- if file_age > 600: # 10 minutes
- logger.info(f"Cached ticks for {self.symbol} are too old ({file_age:.1f}s), skipping")
- return
-
- # Load cached ticks
- tick_df = pd.read_csv(cache_file)
- if tick_df.empty:
- logger.info(f"Cached ticks file for {self.symbol} is empty")
- return
-
- # Convert to list of dicts and add to storage
- cached_ticks = tick_df.to_dict('records')
- self.ticks.extend(cached_ticks)
- logger.info(f"Loaded {len(cached_ticks)} cached ticks for {self.symbol} from {cache_file}")
- except Exception as e:
- logger.error(f"Error loading cached ticks for {self.symbol}: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
-
- def _cache_ticks(self):
- """Cache recent ticks to disk"""
- if not self.ticks:
- return
-
- # Get ticks from last 10 minutes
- now = int(time.time() * 1000) # Current time in ms
- cutoff = now - (600 * 1000) # 10 minutes in ms
- recent_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff]
-
- if not recent_ticks:
- logger.debug("No recent ticks to cache")
- return
-
- # Create symbol-specific filename
- symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower()
- cache_file = os.path.join(self.cache_dir, f"{symbol_safe}_recent_ticks.csv")
-
- # Save to disk
- try:
- tick_df = pd.DataFrame(recent_ticks)
- tick_df.to_csv(cache_file, index=False)
- logger.info(f"Cached {len(recent_ticks)} recent ticks for {self.symbol} to {cache_file}")
- except Exception as e:
- logger.error(f"Error caching ticks: {str(e)}")
-
- def get_latest_price(self) -> Optional[float]:
- """Get the latest price from the most recent tick"""
- if self.ticks:
- return self.ticks[-1].get('price')
- return None
-
- def get_price_stats(self) -> Dict:
- """Get stats about the prices in storage"""
- if not self.ticks:
- return {
- 'min': None,
- 'max': None,
- 'latest': None,
- 'count': 0,
- 'age_seconds': 0
- }
-
- prices = [tick['price'] for tick in self.ticks]
- latest_timestamp = self.ticks[-1]['timestamp']
- oldest_timestamp = self.ticks[0]['timestamp']
-
- return {
- 'min': min(prices),
- 'max': max(prices),
- 'latest': prices[-1],
- 'count': len(prices),
- 'age_seconds': (latest_timestamp - oldest_timestamp) / 1000
- }
-
- def get_ticks_as_df(self) -> pd.DataFrame:
- """Return ticks as a DataFrame"""
- if not self.ticks:
- logger.warning("No ticks available for DataFrame conversion")
- return pd.DataFrame()
-
- # Ensure we have fresh data
- self._cleanup()
-
- # Create a new list from ticks to avoid modifying the original data
- ticks_data = self.ticks.copy()
-
- # Ensure we have the latest ticks at the end of the DataFrame
- ticks_data.sort(key=lambda x: x['timestamp'])
-
- df = pd.DataFrame(ticks_data)
- if not df.empty:
- logger.debug(f"Converting timestamps for {len(df)} ticks")
- # Ensure timestamp column exists
- if 'timestamp' not in df.columns:
- logger.error("Tick data missing timestamp column")
- return pd.DataFrame()
-
- # Check timestamp datatype before conversion
- sample_ts = df['timestamp'].iloc[0] if len(df) > 0 else None
- logger.debug(f"Sample timestamp before conversion: {sample_ts}, type: {type(sample_ts)}")
-
- # Convert timestamps to datetime
- try:
- df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
- logger.debug(f"Timestamps converted to datetime successfully")
- if len(df) > 0:
- logger.debug(f"Sample converted timestamp: {df['timestamp'].iloc[0]}")
- except Exception as e:
- logger.error(f"Error converting timestamps: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
- return pd.DataFrame()
- return df
-
- def get_candles(self, interval_seconds: int = 1, start_time_ms: int = None, end_time_ms: int = None) -> pd.DataFrame:
- """Generate candlestick data from ticks
-
- Args:
- interval_seconds: Interval in seconds for each candle
- start_time_ms: Start time in milliseconds
- end_time_ms: End time in milliseconds
-
- Returns:
- DataFrame with candlestick data
- """
- # Get filtered ticks
- ticks = self.get_ticks_from_time(start_time_ms, end_time_ms)
-
- if not ticks:
- if self.use_sample_data:
- # Generate multiple sample ticks to create several candles
- current_time = int(time.time() * 1000)
- sample_ticks = []
-
- # Generate ticks for the past 10 intervals
- for i in range(20):
- # Base price with some trend
- base_price = self.last_sample_price * (1 + 0.0001 * (10 - i))
-
- # Add some randomness to the price
- random_factor = random.uniform(-0.002, 0.002) # Small random change
- tick_price = base_price * (1 + random_factor)
-
- # Create timestamp with appropriate offset
- tick_time = current_time - (i * interval_seconds * 1000 // 2)
-
- sample_tick = {
- 'price': tick_price,
- 'volume': random.uniform(0.01, 0.5),
- 'timestamp': tick_time,
- 'is_sample': True
- }
-
- sample_ticks.append(sample_tick)
-
- # Update the last sample values
- self.last_sample_price = sample_ticks[0]['price']
- self.last_sample_time = sample_ticks[0]['timestamp']
-
- # Add the sample ticks in chronological order
- for tick in sorted(sample_ticks, key=lambda x: x['timestamp']):
- self.add_tick(tick)
-
- # Try again with the new ticks
- ticks = self.get_ticks_from_time(start_time_ms, end_time_ms)
-
- if not ticks and self.log_no_ticks_warning:
- logger.warning("Still no ticks available after adding sample data")
- elif self.log_no_ticks_warning:
- logger.warning("No ticks available for candle formation")
- return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
- else:
- return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
-
- # Ensure ticks are up to date
- try:
- self._cleanup()
- except Exception as cleanup_error:
- logger.error(f"Error cleaning up ticks: {str(cleanup_error)}")
-
- df = pd.DataFrame(ticks)
- if df.empty:
- logger.warning("Tick DataFrame is empty after filtering/conversion")
- return pd.DataFrame()
-
- logger.info(f"Preparing to create candles from {len(df)} ticks with {interval_seconds}s interval")
-
- # First, ensure all required columns exist
- required_columns = ['timestamp', 'price', 'volume']
- for col in required_columns:
- if col not in df.columns:
- logger.error(f"Required column '{col}' missing from tick data")
- return pd.DataFrame()
-
- # Make sure DataFrame has no duplicated timestamps before setting index
- try:
- if 'timestamp' in df.columns:
- # Check for duplicate timestamps
- duplicate_count = df['timestamp'].duplicated().sum()
- if duplicate_count > 0:
- logger.warning(f"Found {duplicate_count} duplicate timestamps, keeping the last occurrence")
- # Keep the last occurrence of each timestamp
- df = df.drop_duplicates(subset='timestamp', keep='last')
-
- # Convert timestamp to datetime if it's not already
- if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
- logger.debug("Converting timestamp to datetime")
- # Try multiple approaches to convert timestamps
- try:
- # First, try to convert from milliseconds (integer timestamps)
- if pd.api.types.is_integer_dtype(df['timestamp']):
- df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
- else:
- # Otherwise try standard conversion
- df['timestamp'] = pd.to_datetime(df['timestamp'])
- except Exception as conv_error:
- logger.error(f"Error converting timestamps to datetime: {str(conv_error)}")
- # Try a fallback approach
- try:
- # Fallback for integer timestamps
- if df['timestamp'].iloc[0] > 1000000000000: # Check if milliseconds timestamp
- df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
- else: # Otherwise assume seconds
- df['timestamp'] = pd.to_datetime(df['timestamp'], unit='s')
- except Exception as fallback_error:
- logger.error(f"Fallback timestamp conversion failed: {str(fallback_error)}")
- return pd.DataFrame()
-
- # Use timestamp column for resampling
- df = df.set_index('timestamp')
- except Exception as prep_error:
- logger.error(f"Error preprocessing DataFrame for resampling: {str(prep_error)}")
- import traceback
- logger.error(traceback.format_exc())
- return pd.DataFrame()
-
- # Create interval string for resampling - use 's' instead of deprecated 'S'
- interval_str = f'{interval_seconds}s'
-
- # Resample to create OHLCV candles with multiple fallback options
- logger.debug(f"Resampling with interval: {interval_str}")
-
- candles = None
-
- # First attempt - individual column resampling
- try:
- # Check that price column exists and has enough data
- if 'price' not in df.columns:
- raise ValueError("Price column missing from DataFrame")
-
- if len(df) < 2:
- logger.warning("Not enough data points for resampling, using direct data")
- # For single data point, create a single candle
- if len(df) == 1:
- price_val = df['price'].iloc[0]
- volume_val = df['volume'].iloc[0] if 'volume' in df.columns else 0
- timestamp_val = df.index[0]
-
- candles = pd.DataFrame({
- 'timestamp': [timestamp_val],
- 'open': [price_val],
- 'high': [price_val],
- 'low': [price_val],
- 'close': [price_val],
- 'volume': [volume_val]
- })
- return candles
- else:
- # No data
- return pd.DataFrame()
-
- # Resample and aggregate each column separately
- open_df = df['price'].resample(interval_str).first()
- high_df = df['price'].resample(interval_str).max()
- low_df = df['price'].resample(interval_str).min()
- close_df = df['price'].resample(interval_str).last()
- volume_df = df['volume'].resample(interval_str).sum()
-
- # Check for length mismatches before combining
- expected_length = len(open_df)
- if (len(high_df) != expected_length or
- len(low_df) != expected_length or
- len(close_df) != expected_length or
- len(volume_df) != expected_length):
- logger.warning("Length mismatch in resampled columns, falling back to alternative method")
- raise ValueError("Length mismatch")
-
- # Combine into a single DataFrame
- candles = pd.DataFrame({
- 'open': open_df,
- 'high': high_df,
- 'low': low_df,
- 'close': close_df,
- 'volume': volume_df
- })
- logger.debug(f"Successfully created {len(candles)} candles with individual column resampling")
- except Exception as resample_error:
- logger.error(f"Error in individual column resampling: {str(resample_error)}")
-
- # Second attempt - built-in agg method
- try:
- logger.debug("Trying fallback resampling method with agg()")
- candles = df.resample(interval_str).agg({
- 'price': ['first', 'max', 'min', 'last'],
- 'volume': 'sum'
- })
- # Flatten MultiIndex columns
- candles.columns = ['open', 'high', 'low', 'close', 'volume']
- logger.debug(f"Successfully created {len(candles)} candles with agg() method")
- except Exception as agg_error:
- logger.error(f"Error in agg() resampling: {str(agg_error)}")
-
- # Third attempt - manual candle construction
- try:
- logger.debug("Trying manual candle construction method")
- resampler = df.resample(interval_str)
- candle_data = []
-
- for name, group in resampler:
- if not group.empty:
- candle = {
- 'timestamp': name,
- 'open': group['price'].iloc[0],
- 'high': group['price'].max(),
- 'low': group['price'].min(),
- 'close': group['price'].iloc[-1],
- 'volume': group['volume'].sum() if 'volume' in group.columns else 0
- }
- candle_data.append(candle)
-
- if candle_data:
- candles = pd.DataFrame(candle_data)
- logger.debug(f"Successfully created {len(candles)} candles with manual method")
- else:
- logger.warning("No candles created with manual method")
- return pd.DataFrame()
- except Exception as manual_error:
- logger.error(f"Error in manual candle construction: {str(manual_error)}")
- import traceback
- logger.error(traceback.format_exc())
- return pd.DataFrame()
-
- # Ensure the result isn't empty
- if candles is None or candles.empty:
- logger.warning("No candles were created after all resampling attempts")
- return pd.DataFrame()
-
- # Reset index to get timestamp as column
- try:
- candles = candles.reset_index()
- except Exception as reset_error:
- logger.error(f"Error resetting index: {str(reset_error)}")
- # Try to create a new DataFrame with the timestamp index as a column
- try:
- timestamp_col = candles.index.to_list()
- candles_dict = candles.to_dict('list')
- candles_dict['timestamp'] = timestamp_col
- candles = pd.DataFrame(candles_dict)
- except Exception as fallback_error:
- logger.error(f"Error in fallback index reset: {str(fallback_error)}")
- return pd.DataFrame()
-
- # Ensure no NaN values
- try:
- nan_count_before = candles.isna().sum().sum()
- if nan_count_before > 0:
- logger.warning(f"Found {nan_count_before} NaN values in candles, dropping them")
-
- candles = candles.dropna()
- except Exception as nan_error:
- logger.error(f"Error handling NaN values: {str(nan_error)}")
- # Try to fill NaN values instead of dropping
- try:
- candles = candles.fillna(method='ffill').fillna(method='bfill')
- except:
- pass
-
- logger.debug(f"Generated {len(candles)} candles from {len(df)} ticks")
- return candles
-
- def get_candle_stats(self) -> Dict:
- """Get statistics about cached candles for different intervals"""
- stats = {}
-
- # Define intervals to check
- intervals = [1, 5, 15, 60, 300, 900, 3600]
-
- for interval in intervals:
- candles = self.get_candles(interval_seconds=interval)
- count = len(candles) if not candles.empty else 0
-
- # Get time range if we have candles
- time_range = None
- if count > 0:
- try:
- start_time = candles['timestamp'].min()
- end_time = candles['timestamp'].max()
- if isinstance(start_time, pd.Timestamp):
- start_time = start_time.strftime('%Y-%m-%d %H:%M:%S')
- if isinstance(end_time, pd.Timestamp):
- end_time = end_time.strftime('%Y-%m-%d %H:%M:%S')
- time_range = f"{start_time} to {end_time}"
- except:
- time_range = "Unknown"
-
- stats[f"{interval}s"] = {
- 'count': count,
- 'time_range': time_range
- }
-
- return stats
-
- def get_ticks_from_time(self, start_time_ms: int = None, end_time_ms: int = None) -> List[Dict]:
- """Get ticks within a specific time range
-
- Args:
- start_time_ms: Start time in milliseconds (None for no lower bound)
- end_time_ms: End time in milliseconds (None for no upper bound)
-
- Returns:
- List of ticks within the time range
- """
- if not self.ticks:
- return []
-
- # Ensure ticks are updated
- self._cleanup()
-
- # Apply time filters if specified
- filtered_ticks = self.ticks
- if start_time_ms is not None:
- filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] >= start_time_ms]
- if end_time_ms is not None:
- filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] <= end_time_ms]
-
- logger.debug(f"Retrieved {len(filtered_ticks)} ticks from time range {start_time_ms} to {end_time_ms}")
- return filtered_ticks
-
- def get_time_based_stats(self) -> Dict:
- """Get statistics about the ticks organized by time periods
-
- Returns:
- Dictionary with statistics for different time periods
- """
- if not self.ticks:
- return {
- 'total_ticks': 0,
- 'periods': {}
- }
-
- # Ensure ticks are updated
- self._cleanup()
-
- now = int(time.time() * 1000) # Current time in ms
-
- # Define time periods to analyze
- periods = {
- '1min': now - (60 * 1000),
- '5min': now - (5 * 60 * 1000),
- '15min': now - (15 * 60 * 1000),
- '30min': now - (30 * 60 * 1000)
+ tick = {
+ 'price': price,
+ 'volume': volume,
+ 'timestamp': timestamp
}
- stats = {
- 'total_ticks': len(self.ticks),
- 'oldest_tick': self.ticks[0]['timestamp'] if self.ticks else None,
- 'newest_tick': self.ticks[-1]['timestamp'] if self.ticks else None,
- 'time_span_seconds': (self.ticks[-1]['timestamp'] - self.ticks[0]['timestamp']) / 1000 if self.ticks else 0,
- 'periods': {}
+ self.ticks.append(tick)
+ self.latest_price = price
+
+ # Keep only last 10000 ticks
+ if len(self.ticks) > 10000:
+ self.ticks = self.ticks[-10000:]
+
+ # Update candles
+ self._update_candles(tick)
+
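+    # Example (assumed usage; the numbers are illustrative):
+    #   storage = TickStorage()
+    #   storage.add_tick(price=50000.0, volume=0.1, timestamp=datetime.now())
+    #   storage.get_latest_price()  # -> 50000.0
+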
+ def get_latest_price(self):
+ """Get the latest price"""
+ return self.latest_price
+
+ def _update_candles(self, tick):
+ """Update candles with the new tick"""
+ intervals = {
+ '1m': 60,
+ '5m': 300,
+ '15m': 900,
+ '1h': 3600,
+ '4h': 14400,
+ '1d': 86400
}
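+        # These interval keys mirror the chart's selector buttons (1m..1d)
+        # defined further down in RealTimeChart._get_chart_layout.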
- # Calculate stats for each period
- for period_name, cutoff_time in periods.items():
- period_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff_time]
+ for interval_key, seconds in intervals.items():
+ if interval_key not in self.candles:
+ self.candles[interval_key] = []
+
+ # Get or create the current candle
+ current_candle = self._get_current_candle(interval_key, tick['timestamp'], seconds)
- if period_ticks:
- prices = [tick['price'] for tick in period_ticks]
- volumes = [tick.get('volume', 0) for tick in period_ticks]
-
- period_stats = {
- 'tick_count': len(period_ticks),
- 'min_price': min(prices) if prices else None,
- 'max_price': max(prices) if prices else None,
- 'avg_price': sum(prices) / len(prices) if prices else None,
- 'last_price': period_ticks[-1]['price'] if period_ticks else None,
- 'total_volume': sum(volumes),
- 'ticks_per_second': len(period_ticks) / (int(period_name[:-3]) * 60) if period_ticks else 0
- }
-
- stats['periods'][period_name] = period_stats
+ # Update the candle with the new tick
+ if current_candle['high'] < tick['price']:
+ current_candle['high'] = tick['price']
+ if current_candle['low'] > tick['price']:
+ current_candle['low'] = tick['price']
+ current_candle['close'] = tick['price']
+ current_candle['volume'] += tick['volume']
+
+ def _get_current_candle(self, interval_key, timestamp, interval_seconds):
+ """Get the current candle for the given interval, or create a new one"""
+ # Calculate the candle start time
+ candle_start = timestamp.replace(
+ microsecond=0,
+ second=0,
+ minute=(timestamp.minute // (interval_seconds // 60)) * (interval_seconds // 60)
+ )
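+        # e.g. for a 15m interval (interval_seconds=900 -> 900 // 60 = 15),
+        # a tick at 12:47 floors to minute (47 // 15) * 15 = 45, i.e. 12:45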
- logger.debug(f"Generated time-based stats: {len(stats['periods'])} periods")
- return stats
-
-class CandlestickData:
- def __init__(self, max_length: int = 300):
- self.timestamps = deque(maxlen=max_length)
- self.opens = deque(maxlen=max_length)
- self.highs = deque(maxlen=max_length)
- self.lows = deque(maxlen=max_length)
- self.closes = deque(maxlen=max_length)
- self.volumes = deque(maxlen=max_length)
- self.current_candle = {
- 'timestamp': None,
- 'open': None,
- 'high': None,
- 'low': None,
- 'close': None,
+ if interval_seconds >= 3600: # For hourly or higher
+ hours = (timestamp.hour // (interval_seconds // 3600)) * (interval_seconds // 3600)
+ candle_start = candle_start.replace(hour=hours)
+
+        if interval_seconds >= 86400:  # For daily (hour is already 0 from the branch above; kept as a safeguard)
+ candle_start = candle_start.replace(hour=0)
+
+ # Check if we already have a candle for this time
+ for candle in self.candles[interval_key]:
+ if candle['timestamp'] == candle_start:
+ return candle
+
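+        # Note: this lookup is an O(n) scan per tick; since ticks arrive in
+        # time order, checking self.candles[interval_key][-1] first would
+        # avoid walking the whole candle history as it grows.
+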
+ # Create a new candle
+ candle = {
+ 'timestamp': candle_start,
+ 'open': self.latest_price if self.latest_price is not None else tick['price'],
+ 'high': tick['price'],
+ 'low': tick['price'],
+ 'close': tick['price'],
'volume': 0
}
- self.candle_interval = 1 # 1 second by default
-
- def update_from_trade(self, trade: Dict):
- timestamp = trade['timestamp']
- price = trade['price']
- volume = trade.get('volume', 0)
-
- # Round timestamp to nearest candle interval
- candle_timestamp = int(timestamp / (self.candle_interval * 1000)) * (self.candle_interval * 1000)
-
- if self.current_candle['timestamp'] != candle_timestamp:
- # Save current candle if it exists
- if self.current_candle['timestamp'] is not None:
- self.timestamps.append(self.current_candle['timestamp'])
- self.opens.append(self.current_candle['open'])
- self.highs.append(self.current_candle['high'])
- self.lows.append(self.current_candle['low'])
- self.closes.append(self.current_candle['close'])
- self.volumes.append(self.current_candle['volume'])
- logger.debug(f"New candle saved: {self.current_candle}")
-
- # Start new candle
- self.current_candle = {
- 'timestamp': candle_timestamp,
- 'open': price,
- 'high': price,
- 'low': price,
- 'close': price,
- 'volume': volume
- }
- logger.debug(f"New candle started: {self.current_candle}")
- else:
- # Update current candle
- if self.current_candle['high'] is None or price > self.current_candle['high']:
- self.current_candle['high'] = price
- if self.current_candle['low'] is None or price < self.current_candle['low']:
- self.current_candle['low'] = price
- self.current_candle['close'] = price
- self.current_candle['volume'] += volume
- logger.debug(f"Updated current candle: {self.current_candle}")
-
- def get_dataframe(self) -> pd.DataFrame:
- # Include current candle in the dataframe if it exists
- timestamps = list(self.timestamps)
- opens = list(self.opens)
- highs = list(self.highs)
- lows = list(self.lows)
- closes = list(self.closes)
- volumes = list(self.volumes)
-
- if self.current_candle['timestamp'] is not None:
- timestamps.append(self.current_candle['timestamp'])
- opens.append(self.current_candle['open'])
- highs.append(self.current_candle['high'])
- lows.append(self.current_candle['low'])
- closes.append(self.current_candle['close'])
- volumes.append(self.current_candle['volume'])
-
- df = pd.DataFrame({
- 'timestamp': timestamps,
- 'open': opens,
- 'high': highs,
- 'low': lows,
- 'close': closes,
- 'volume': volumes
- })
- if not df.empty:
- df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
+
+ self.candles[interval_key].append(candle)
+ return candle
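+        # Note: add_tick assigns self.latest_price before calling in here, so
+        # a brand-new candle opens at the current tick's price; the previous
+        # candle's close is not carried over.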
+
+ def get_candles(self, interval='1m'):
+ """Get candles for the given interval"""
+ if interval not in self.candles or not self.candles[interval]:
+ return None
+
+ # Convert to DataFrame
+ df = pd.DataFrame(self.candles[interval])
+ df.set_index('timestamp', inplace=True)
return df
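+
+    # Example (assumed): get_candles('5m') yields a DataFrame indexed by candle
+    # start time with open/high/low/close/volume columns, or None when nothing
+    # has been stored for that interval yet.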
-
-class BinanceWebSocket:
- """Binance WebSocket implementation for real-time tick data"""
- def __init__(self, symbol: str):
- self.symbol = symbol.replace('/', '').lower()
- self.ws = None
- self.running = False
- self.reconnect_delay = 1
- self.max_reconnect_delay = 60
- self.message_count = 0
-
- # Binance WebSocket configuration
- self.ws_url = f"wss://stream.binance.com:9443/ws/{self.symbol}@trade"
- logger.info(f"Initialized Binance WebSocket for symbol: {self.symbol}")
-
- async def connect(self):
- while True:
- try:
- logger.info(f"Attempting to connect to {self.ws_url}")
- self.ws = await websockets.connect(self.ws_url)
- logger.info("WebSocket connection established")
-
- self.running = True
- self.reconnect_delay = 1
- logger.info(f"Successfully connected to Binance WebSocket for {self.symbol}")
- return True
- except Exception as e:
- logger.error(f"WebSocket connection error: {str(e)}")
- await asyncio.sleep(self.reconnect_delay)
- self.reconnect_delay = min(self.reconnect_delay * 2, self.max_reconnect_delay)
- continue
-
- async def receive(self) -> Optional[Dict]:
- if not self.ws:
- return None
+ def load_from_file(self, file_path):
+ """Load ticks from a file"""
try:
- message = await self.ws.recv()
- self.message_count += 1
-
- if self.message_count % 100 == 0: # Log every 100th message to avoid spam
- logger.info(f"Received message #{self.message_count}")
- logger.debug(f"Raw message: {message[:200]}...")
-
- data = json.loads(message)
-
- # Process trade data
- if 'e' in data and data['e'] == 'trade':
- trade_data = {
- 'timestamp': data['T'], # Trade time
- 'price': float(data['p']), # Price
- 'volume': float(data['q']), # Quantity
- 'type': 'trade'
- }
- logger.debug(f"Processed trade data: {trade_data}")
- return trade_data
-
- return None
- except websockets.exceptions.ConnectionClosed:
- logger.warning("WebSocket connection closed")
- self.running = False
- return None
- except json.JSONDecodeError as e:
- logger.error(f"JSON decode error: {str(e)}, message: {message[:200]}...")
- return None
- except Exception as e:
- logger.error(f"Error receiving message: {str(e)}")
- return None
-
- async def close(self):
- """Close the WebSocket connection"""
- if self.ws:
- await self.ws.close()
- self.running = False
- logger.info("WebSocket connection closed")
-
-class BinanceHistoricalData:
- """Fetch historical candle data from Binance"""
-
- def __init__(self):
- self.base_url = "https://api.binance.com/api/v3/klines"
- # Create a cache directory if it doesn't exist
- self.cache_dir = os.path.join(os.getcwd(), "cache")
- os.makedirs(self.cache_dir, exist_ok=True)
- logger.info(f"Initialized BinanceHistoricalData with cache directory: {self.cache_dir}")
-
- def _get_interval_string(self, interval_seconds: int) -> str:
- """Convert interval seconds to Binance interval string"""
- if interval_seconds == 60: # 1m
- return "1m"
- elif interval_seconds == 300: # 5m
- return "5m"
- elif interval_seconds == 900: # 15m
- return "15m"
- elif interval_seconds == 1800: # 30m
- return "30m"
- elif interval_seconds == 3600: # 1h
- return "1h"
- elif interval_seconds == 14400: # 4h
- return "4h"
- elif interval_seconds == 86400: # 1d
- return "1d"
- else:
- # Default to 1m if not recognized
- logger.warning(f"Unrecognized interval {interval_seconds}s, defaulting to 1m")
- return "1m"
-
- def _get_cache_filename(self, symbol: str, interval: str) -> str:
- """Generate cache filename for the symbol and interval"""
- # Replace any slashes in symbol with underscore
- safe_symbol = symbol.replace("/", "_")
- return os.path.join(self.cache_dir, f"{safe_symbol}_{interval}_candles.csv")
-
- def _load_from_cache(self, symbol: str, interval: str) -> Optional[pd.DataFrame]:
- """Load candle data from cache if available and not expired"""
- filename = self._get_cache_filename(symbol, interval)
-
- if not os.path.exists(filename):
- logger.debug(f"No cache file found for {symbol} {interval}")
- return None
-
- # Check if cache is fresh (less than 1 hour old for anything but 1d, 1 day for 1d)
- file_age = time.time() - os.path.getmtime(filename)
- max_age = 86400 if interval == "1d" else 3600 # 1 day for 1d, 1 hour for others
-
- if file_age > max_age:
- logger.debug(f"Cache for {symbol} {interval} is expired ({file_age:.1f}s old)")
- return None
-
- try:
- df = pd.read_csv(filename)
- # Convert timestamp string back to datetime
- df['timestamp'] = pd.to_datetime(df['timestamp'])
- logger.info(f"Loaded {len(df)} candles from cache for {symbol} {interval}")
- return df
- except Exception as e:
- logger.error(f"Error loading from cache: {str(e)}")
- return None
-
- def _save_to_cache(self, df: pd.DataFrame, symbol: str, interval: str) -> bool:
- """Save candle data to cache"""
- if df.empty:
- logger.warning(f"No data to cache for {symbol} {interval}")
- return False
-
- filename = self._get_cache_filename(symbol, interval)
- try:
- df.to_csv(filename, index=False)
- logger.info(f"Cached {len(df)} candles for {symbol} {interval} to {filename}")
- return True
- except Exception as e:
- logger.error(f"Error saving to cache: {str(e)}")
- return False
-
- def get_historical_candles(self, symbol: str, interval_seconds: int, limit: int = 500) -> pd.DataFrame:
- """Get historical candle data for the specified symbol and interval"""
- # Convert to Binance format
- clean_symbol = symbol.replace("/", "")
- interval = self._get_interval_string(interval_seconds)
-
- # Try to load from cache first
- cached_data = self._load_from_cache(symbol, interval)
- if cached_data is not None and len(cached_data) >= limit:
- return cached_data.tail(limit)
-
- # Fetch from API if not cached or insufficient
- try:
- logger.info(f"Fetching {limit} historical candles for {symbol} ({interval}) from Binance API")
-
- params = {
- "symbol": clean_symbol,
- "interval": interval,
- "limit": limit
- }
-
- response = requests.get(self.base_url, params=params)
- response.raise_for_status() # Raise exception for HTTP errors
-
- # Process the data
- candles = response.json()
-
- if not candles:
- logger.warning(f"No candles returned from Binance for {symbol} {interval}")
- return pd.DataFrame()
-
- # Convert to DataFrame - Binance returns data in this format:
- # [
- # [
- # 1499040000000, // Open time
- # "0.01634790", // Open
- # "0.80000000", // High
- # "0.01575800", // Low
- # "0.01577100", // Close
- # "148976.11427815", // Volume
- # ... // Ignore the rest
- # ],
- # ...
- # ]
-
- df = pd.DataFrame(candles, columns=[
- "timestamp", "open", "high", "low", "close", "volume",
- "close_time", "quote_asset_volume", "number_of_trades",
- "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore"
- ])
-
- # Convert types
- df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
- for col in ["open", "high", "low", "close", "volume"]:
- df[col] = df[col].astype(float)
-
- # Keep only needed columns
- df = df[["timestamp", "open", "high", "low", "close", "volume"]]
-
- # Cache the results
- self._save_to_cache(df, symbol, interval)
-
- logger.info(f"Successfully fetched {len(df)} candles for {symbol} {interval}")
- return df
-
- except Exception as e:
- logger.error(f"Error fetching historical data for {symbol} {interval}: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
- return pd.DataFrame()
-
-
-class ExchangeWebSocket:
- """Generic WebSocket interface for cryptocurrency exchanges"""
- def __init__(self, symbol: str, exchange: str = "binance"):
- self.symbol = symbol
- self.exchange = exchange.lower()
- self.ws = None
-
- # Initialize the appropriate WebSocket implementation
- if self.exchange == "binance":
- self.ws = BinanceWebSocket(symbol)
- elif self.exchange == "mexc":
- self.ws = MEXCWebSocket(symbol)
- else:
- raise ValueError(f"Unsupported exchange: {exchange}")
-
- async def connect(self):
- """Connect to the exchange WebSocket"""
- return await self.ws.connect()
-
- async def receive(self) -> Optional[Dict]:
- """Receive data from the WebSocket"""
- return await self.ws.receive()
-
- async def close(self):
- """Close the WebSocket connection"""
- await self.ws.close()
-
- @property
- def running(self):
- """Check if the WebSocket is running"""
- return self.ws.running if self.ws else False
-
-class CandleCache:
- def __init__(self, max_candles: int = 5000):
- self.candles = {
- '1s': deque(maxlen=max_candles),
- '1m': deque(maxlen=max_candles),
- '1h': deque(maxlen=max_candles),
- '1d': deque(maxlen=max_candles)
- }
- logger.info(f"Initialized CandleCache with max candles: {max_candles}")
-
- def add_candles(self, interval: str, new_candles: pd.DataFrame):
- if interval in self.candles and not new_candles.empty:
- # Convert DataFrame to list of dicts to avoid pandas issues
- for _, row in new_candles.iterrows():
- candle_dict = row.to_dict()
- self.candles[interval].append(candle_dict)
- logger.debug(f"Added {len(new_candles)} candles to {interval} cache")
-
- def get_recent_candles(self, interval: str, count: int = 500) -> pd.DataFrame:
- if interval in self.candles and self.candles[interval]:
- # Convert deque to list of dicts first
- all_candles = list(self.candles[interval])
- # Check if we're requesting more candles than we have
- if count > len(all_candles):
- logger.debug(f"Requested {count} candles, but only have {len(all_candles)} for {interval}")
- count = len(all_candles)
-
- recent_candles = all_candles[-count:]
- logger.debug(f"Returning {len(recent_candles)} recent candles for {interval} (requested {count})")
-
- # Create DataFrame and ensure timestamp is datetime type
- df = pd.DataFrame(recent_candles)
- if not df.empty and 'timestamp' in df.columns:
- try:
- if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
- df['timestamp'] = pd.to_datetime(df['timestamp'])
- except Exception as e:
- logger.warning(f"Error converting timestamps in get_recent_candles: {str(e)}")
-
- return df
-
- logger.debug(f"No candles available for {interval}")
- return pd.DataFrame()
-
- def update_cache(self, interval: str, new_candles: pd.DataFrame):
- """
- Update the candle cache for a specific timeframe with new candles
-
- Args:
- interval: The timeframe interval ('1s', '1m', '1h', '1d')
- new_candles: DataFrame with new candles to add to the cache
- """
- if interval not in self.candles:
- logger.warning(f"Invalid interval {interval} for cache update")
- return
-
- if new_candles is None or new_candles.empty:
- logger.debug(f"No new candles to update {interval} cache")
- return
-
- # Check if timestamp column exists
- if 'timestamp' not in new_candles.columns:
- logger.warning(f"No timestamp column in new candles for {interval}")
- return
-
- try:
- # Ensure timestamp is datetime for proper comparison
- try:
- if not pd.api.types.is_datetime64_any_dtype(new_candles['timestamp']):
- logger.debug(f"Converting timestamps to datetime for {interval}")
- new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp'])
- except Exception as e:
- logger.error(f"Error converting timestamps: {str(e)}")
- # Try a different approach
- try:
- new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp'], errors='coerce')
- # Drop any rows where conversion failed
- new_candles = new_candles.dropna(subset=['timestamp'])
- if new_candles.empty:
- logger.warning(f"All timestamps conversion failed for {interval}")
- return
- except Exception as e2:
- logger.error(f"Second attempt to convert timestamps failed: {str(e2)}")
- return
-
- # Create a copy to avoid modifying the original
- new_candles_copy = new_candles.copy()
-
- # If we have no candles in cache, add all new candles
- if not self.candles[interval]:
- logger.debug(f"No existing candles for {interval}, adding all {len(new_candles_copy)} candles")
- self.add_candles(interval, new_candles_copy)
- return
-
- # Get the timestamp from the last cached candle
- last_cached_candle = self.candles[interval][-1]
- if not isinstance(last_cached_candle, dict):
- logger.warning(f"Last cached candle is not a dictionary for {interval}")
- last_cached_candle = {'timestamp': None}
-
- if 'timestamp' not in last_cached_candle:
- logger.warning(f"No timestamp in last cached candle for {interval}")
- last_cached_candle['timestamp'] = None
-
- last_cached_time = last_cached_candle['timestamp']
- logger.debug(f"Last cached timestamp for {interval}: {last_cached_time}")
-
- # If last_cached_time is None, add all candles
- if last_cached_time is None:
- logger.debug(f"No valid last cached timestamp, adding all {len(new_candles_copy)} candles for {interval}")
- self.add_candles(interval, new_candles_copy)
- return
-
- # Convert last_cached_time to datetime if needed
- if not isinstance(last_cached_time, (pd.Timestamp, datetime)):
- try:
- last_cached_time = pd.to_datetime(last_cached_time)
- except Exception as e:
- logger.error(f"Cannot convert last cached time to datetime: {str(e)}")
- # Add all candles as fallback
- self.add_candles(interval, new_candles_copy)
- return
-
- # Make a backup of current cache before filtering
- cache_backup = list(self.candles[interval])
-
- # Filter new candles that are after the last cached candle
- try:
- filtered_candles = new_candles_copy[new_candles_copy['timestamp'] > last_cached_time]
-
- if not filtered_candles.empty:
- logger.debug(f"Adding {len(filtered_candles)} new candles for {interval}")
- self.add_candles(interval, filtered_candles)
+ df = pd.read_csv(file_path)
+ for _, row in df.iterrows():
+ if 'timestamp' in row:
+ timestamp = pd.to_datetime(row['timestamp'])
else:
- # No new candles after last cached time, check for missing candles
- try:
- # Get unique timestamps in cache
- cached_timestamps = set()
- for candle in self.candles[interval]:
- if isinstance(candle, dict) and 'timestamp' in candle:
- ts = candle['timestamp']
- if isinstance(ts, (pd.Timestamp, datetime)):
- cached_timestamps.add(ts)
- else:
- try:
- cached_timestamps.add(pd.to_datetime(ts))
- except:
- pass
-
- # Find candles in new_candles that aren't in the cache
- missing_candles = new_candles_copy[~new_candles_copy['timestamp'].isin(cached_timestamps)]
-
- if not missing_candles.empty:
- logger.info(f"Found {len(missing_candles)} missing candles for {interval}")
- self.add_candles(interval, missing_candles)
- else:
- logger.debug(f"No new or missing candles to add for {interval}")
- except Exception as missing_error:
- logger.error(f"Error checking for missing candles: {str(missing_error)}")
- except Exception as filter_error:
- logger.error(f"Error filtering candles by timestamp: {str(filter_error)}")
- # Restore from backup
- self.candles[interval] = deque(cache_backup, maxlen=self.candles[interval].maxlen)
- # Try adding all candles as fallback
- self.add_candles(interval, new_candles_copy)
+                    timestamp = datetime.now()  # Fall back to load time so candle bucketing never sees None
+
+ self.add_tick(
+ price=row.get('price', row.get('close', 0)),
+ volume=row.get('volume', 0),
+ timestamp=timestamp
+ )
+ logger.info(f"Loaded {len(df)} ticks from {file_path}")
except Exception as e:
- logger.error(f"Unhandled error updating cache for {interval}: {str(e)}")
+ logger.error(f"Error loading ticks from file: {str(e)}")
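+
+    # The CSV is assumed to hold either raw ticks (price/volume/timestamp) or
+    # candles, in which case 'close' stands in for the price (see row.get above).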
+
+ def load_historical_data(self, historical_data, symbol):
+ """Load historical data"""
+ try:
+ df = historical_data.get_historical_candles(symbol)
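+            # Assumed provider interface: get_historical_candles(symbol) returns
+            # a DataFrame with timestamp/close/volume columns (the removed
+            # BinanceHistoricalData above also required interval_seconds).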
+ if df is not None and not df.empty:
+ for _, row in df.iterrows():
+ self.add_tick(
+ price=row['close'],
+ volume=row['volume'],
+ timestamp=row['timestamp']
+ )
+ logger.info(f"Loaded {len(df)} historical candles for {symbol}")
+ except Exception as e:
+ logger.error(f"Error loading historical data: {str(e)}")
import traceback
logger.error(traceback.format_exc())
- def get_candles(self, timeframe: str, count: int = 500) -> pd.DataFrame:
- """
- Get candles for a specific timeframe. This is an alias for get_recent_candles
- to maintain compatibility with code that expects this method name.
+class Position:
+ """Class representing a trading position"""
+ def __init__(self, action, entry_price, amount, timestamp=None, trade_id=None):
+ self.action = action # BUY or SELL
+ self.entry_price = entry_price
+ self.amount = amount
+ self.timestamp = timestamp or datetime.now()
+ self.status = "OPEN" # OPEN or CLOSED
+ self.exit_price = None
+ self.exit_timestamp = None
+ self.pnl = 0.0
+ self.trade_id = trade_id or str(uuid.uuid4())
- Args:
- timeframe: The timeframe interval ('1s', '1m', '1h', '1d')
- count: Maximum number of candles to return
+ def close(self, exit_price, exit_timestamp=None):
+ """Close the position with an exit price"""
+ self.exit_price = exit_price
+ self.exit_timestamp = exit_timestamp or datetime.now()
+ self.status = "CLOSED"
+
+ # Calculate PnL
+ if self.action == "BUY":
+ self.pnl = (self.exit_price - self.entry_price) * self.amount
+ else: # SELL
+ self.pnl = (self.entry_price - self.exit_price) * self.amount
- Returns:
- DataFrame containing the candles
- """
- try:
- logger.debug(f"Getting {count} candles for {timeframe} via get_candles()")
- return self.get_recent_candles(timeframe, count)
- except Exception as e:
- logger.error(f"Error in get_candles for {timeframe}: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
- return pd.DataFrame()
+ return self.pnl
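+
+    # Example (assumed numbers): a BUY of 2.0 units entered at 100.0 and
+    # closed at 105.0 yields pnl == (105.0 - 100.0) * 2.0 == 10.0.
+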
class RealTimeChart:
"""Real-time chart using Dash and Plotly"""
- def __init__(self, symbol="BTC/USDT", use_sample_data=False, log_no_ticks_warning=True):
- """Initialize a new RealTimeChart
-
- Args:
- symbol: Trading pair symbol (e.g., BTC/USDT)
- use_sample_data: Whether to use sample data when no real data is available
- log_no_ticks_warning: Whether to log warnings when no ticks are available
- """
+ def __init__(self, symbol, data_path=None, historical_data=None, exchange=None, timeframe='1m'):
+ """Initialize the RealTimeChart class"""
self.symbol = symbol
- self.use_sample_data = use_sample_data
+ self.exchange = exchange
+ self.app = dash.Dash(__name__, external_stylesheets=[dbc.themes.DARKLY])
+ self.tick_storage = TickStorage()
+ self.historical_data = historical_data
+ self.data_path = data_path
+        self.current_interval = timeframe  # Default charting interval (from the constructor)
+ self.fig = None # Will hold the main chart figure
+ self.positions = [] # List to hold position objects
+ self.balance = 1000.0 # Starting balance
+ self.last_action = None # Last trading action
- # Initialize variables for trading info display
- self.current_signal = 'HOLD'
- self.signal_time = datetime.now()
- self.current_position = 0.0
- self.session_balance = 100.0 # Start with $100 balance
- self.session_pnl = 0.0
-
- # Initialize NN signals and trades lists
- self.nn_signals = []
- self.trades = []
-
- # Use existing timezone variable instead of trying to detect again
- logger.info(f"Using timezone: {tz_name}")
-
- # Initialize tick storage
- logger.info(f"Initializing RealTimeChart for {symbol}")
- self.tick_storage = TradeTickStorage(
- symbol=symbol,
- max_age_seconds=3600, # Keep ticks for 1 hour
- use_sample_data=use_sample_data,
- log_no_ticks_warning=log_no_ticks_warning
- )
-
- # Initialize candlestick data for backward compatibility
- self.candlestick_data = CandlestickData(max_length=5000)
-
- # Initialize candle cache
- self.candle_cache = CandleCache(max_candles=5000)
-
- # Initialize OHLCV cache dictionaries for different timeframes
- self.ohlcv_cache = {
- '1s': None,
- '5s': None,
- '15s': None,
- '60s': None,
- '300s': None,
- '900s': None,
- '3600s': None,
- '1m': None,
- '5m': None,
- '15m': None,
- '1h': None,
- '1d': None
- }
-
- # Historical data handler
- self.historical_data = BinanceHistoricalData()
-
- # Flag for first render to force data loading
- self.first_render = True
-
- # Last time candles were saved to disk
- self.last_cache_save_time = time.time()
-
- # Initialize Dash app
- self.app = dash.Dash(
- __name__,
- external_stylesheets=[dbc.themes.DARKLY],
- suppress_callback_exceptions=True,
- meta_tags=[{"name": "viewport", "content": "width=device-width, initial-scale=1"}]
- )
-
- # Set up layout and callbacks
self._setup_app_layout()
+
+ # Run the app in a separate thread
+ threading.Thread(target=self._run_app, daemon=True).start()
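+        # Dash's built-in server blocks, so it runs on a daemon thread and
+        # dies with the main process (assumes _run_app wraps app.run_server).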
def _setup_app_layout(self):
"""Set up the app layout and callbacks"""
@@ -1458,78 +539,100 @@ class RealTimeChart:
self._setup_interval_callback(button_style, active_button_style)
self._setup_chart_callback()
self._setup_position_list_callback()
+ self._setup_trading_status_callback()
# We've removed the ticks callback, so don't call it
# self._setup_ticks_callback()
def _get_chart_layout(self, button_style, active_button_style):
- # Chart page layout
+ """Get the chart layout"""
return html.Div([
- # Chart title and interval buttons
+ # Trading stats header at the top
html.Div([
- html.H2(f"{self.symbol} Real-Time Chart", style={
- 'textAlign': 'center',
- 'color': '#FFFFFF',
- 'marginBottom': '10px'
- }),
-
- # Store interval data
- dcc.Store(id='interval-store', data={'interval': 1}),
-
- # Interval selection buttons
html.Div([
- html.Button('1s', id='btn-1s', n_clicks=0, style=active_button_style),
- html.Button('5s', id='btn-5s', n_clicks=0, style=button_style),
- html.Button('15s', id='btn-15s', n_clicks=0, style=button_style),
- html.Button('30s', id='btn-30s', n_clicks=0, style=button_style),
- html.Button('1m', id='btn-1m', n_clicks=0, style=button_style),
+ html.Div([
+ html.Span("Signal: ", style={'fontWeight': 'bold', 'marginRight': '5px'}),
+ html.Span("NONE", id='current-signal-value', style={'color': 'white'})
+ ], style={'marginRight': '20px', 'display': 'inline-block'}),
+ html.Div([
+ html.Span("Position: ", style={'fontWeight': 'bold', 'marginRight': '5px'}),
+ html.Span("NONE", id='current-position-value', style={'color': 'white'})
+ ], style={'marginRight': '20px', 'display': 'inline-block'}),
+ html.Div([
+ html.Span("Balance: ", style={'fontWeight': 'bold', 'marginRight': '5px'}),
+ html.Span("$0.00", id='current-balance-value', style={'color': 'white'})
+ ], style={'marginRight': '20px', 'display': 'inline-block'}),
+ html.Div([
+ html.Span("Session PnL: ", style={'fontWeight': 'bold', 'marginRight': '5px'}),
+ html.Span("$0.00", id='current-pnl-value', style={'color': 'white'})
+ ], style={'display': 'inline-block'})
], style={
- 'display': 'flex',
- 'justifyContent': 'center',
- 'marginBottom': '15px'
- }),
-
- # Interval component for updates - set to refresh every 500ms
- dcc.Interval(
- id='interval-component',
- interval=300, # Refresh every 300ms for better real-time updates
- n_intervals=0
- ),
-
- # Main chart
- dcc.Graph(id='live-chart', style={'height': '75vh'}),
-
- # Last 5 Positions component
- html.Div([
- html.H4("Last 5 Positions", style={'textAlign': 'center', 'color': '#FFFFFF'}),
- html.Table([
- html.Thead(
- html.Tr([
- html.Th("Action", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
- html.Th("Size", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
- html.Th("Entry Price", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
- html.Th("Exit Price", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
- html.Th("PnL", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
- html.Th("Time", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'})
- ])
- ),
- html.Tbody(id='position-table-body')
- ], style={
- 'width': '100%',
- 'borderCollapse': 'collapse',
- 'color': '#FFFFFF',
- 'backgroundColor': '#222'
- })
- ], style={'marginTop': '20px', 'marginBottom': '20px', 'overflowX': 'auto'}),
-
- # Chart acknowledgment
- html.Div("Real-time trading chart with ML signals", style={
- 'textAlign': 'center',
- 'color': '#AAAAAA',
- 'fontSize': '12px',
- 'marginTop': '5px'
+ 'padding': '10px',
+ 'backgroundColor': '#222222',
+ 'borderRadius': '5px',
+ 'marginBottom': '10px',
+ 'border': '1px solid #444444'
})
- ])
- ])
+ ]),
+
+ # Recent Trades Table (compact at top)
+ html.Div([
+ html.H4("Recent Trades", style={'color': 'white', 'margin': '5px 0', 'fontSize': '14px'}),
+ html.Table([
+ html.Thead(html.Tr([
+ html.Th("Status", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}),
+ html.Th("Amount", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}),
+ html.Th("Entry Price", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}),
+ html.Th("Exit Price", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}),
+ html.Th("PnL", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}),
+ html.Th("Time", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'})
+ ])),
+ html.Tbody(id='position-list', children=[
+ html.Tr([html.Td("No positions yet", colSpan=6, style={'textAlign': 'center', 'padding': '4px', 'fontSize': '12px'})])
+ ])
+ ], style={
+ 'width': '100%',
+ 'borderCollapse': 'collapse',
+ 'fontSize': '12px',
+ 'backgroundColor': '#222222',
+ 'color': 'white',
+ 'marginBottom': '10px'
+ })
+ ], style={'marginBottom': '10px'}),
+
+ # Chart area
+ dcc.Graph(
+ id='real-time-chart',
+ style={'height': 'calc(100vh - 250px)'}, # Adjusted to account for header and table
+ config={
+ 'displayModeBar': True,
+ 'scrollZoom': True,
+ 'modeBarButtonsToRemove': ['lasso2d', 'select2d']
+ }
+ ),
+
+ # Interval selector
+ html.Div([
+ html.Button('1m', id='1m-interval', n_clicks=0, style=active_button_style if self.current_interval == '1m' else button_style),
+ html.Button('5m', id='5m-interval', n_clicks=0, style=active_button_style if self.current_interval == '5m' else button_style),
+ html.Button('15m', id='15m-interval', n_clicks=0, style=active_button_style if self.current_interval == '15m' else button_style),
+ html.Button('1h', id='1h-interval', n_clicks=0, style=active_button_style if self.current_interval == '1h' else button_style),
+ html.Button('4h', id='4h-interval', n_clicks=0, style=active_button_style if self.current_interval == '4h' else button_style),
+ html.Button('1d', id='1d-interval', n_clicks=0, style=active_button_style if self.current_interval == '1d' else button_style),
+ ], style={'textAlign': 'center', 'marginTop': '10px'}),
+
+ # Interval component for automatic updates
+ dcc.Interval(
+ id='chart-interval',
+ interval=300, # Refresh every 300ms for better real-time updates
+ n_intervals=0
+ )
+ ], style={
+ 'backgroundColor': '#121212',
+ 'padding': '20px',
+ 'color': 'white',
+ 'height': '100vh',
+ 'boxSizing': 'border-box'
+ })
def _get_ticks_layout(self):
# Ticks data page layout
@@ -1615,906 +718,176 @@ class RealTimeChart:
])
def _setup_interval_callback(self, button_style, active_button_style):
- # Callback to update interval based on button clicks and update button styles
+ """Set up the callback for interval selection buttons"""
@self.app.callback(
- [Output('interval-store', 'data'),
- Output('btn-1s', 'style'),
- Output('btn-5s', 'style'),
- Output('btn-15s', 'style'),
- Output('btn-30s', 'style'),
- Output('btn-1m', 'style')],
- [Input('btn-1s', 'n_clicks'),
- Input('btn-5s', 'n_clicks'),
- Input('btn-15s', 'n_clicks'),
- Input('btn-30s', 'n_clicks'),
- Input('btn-1m', 'n_clicks')],
- [dash.dependencies.State('interval-store', 'data')]
+ [
+ Output('1m-interval', 'style'),
+ Output('5m-interval', 'style'),
+ Output('15m-interval', 'style'),
+ Output('1h-interval', 'style'),
+ Output('4h-interval', 'style'),
+ Output('1d-interval', 'style')
+ ],
+ [
+ Input('1m-interval', 'n_clicks'),
+ Input('5m-interval', 'n_clicks'),
+ Input('15m-interval', 'n_clicks'),
+ Input('1h-interval', 'n_clicks'),
+ Input('4h-interval', 'n_clicks'),
+ Input('1d-interval', 'n_clicks')
+ ]
)
- def update_interval(n1, n5, n15, n30, n60, data):
- ctx = dash.callback_context
- if not ctx.triggered:
- # Default state (1s selected)
- return ({'interval': 1},
- active_button_style, button_style, button_style, button_style, button_style)
+ def update_interval_buttons(n1, n5, n15, n1h, n4h, n1d):
+            ctx = dash.callback_context  # module-level context; no extra import needed
+ # Default styles (all inactive)
+ styles = {
+ '1m': button_style.copy(),
+ '5m': button_style.copy(),
+ '15m': button_style.copy(),
+ '1h': button_style.copy(),
+ '4h': button_style.copy(),
+ '1d': button_style.copy()
+ }
+
+ # If no button clicked yet, use default interval
+ if not ctx.triggered:
+ styles[self.current_interval] = active_button_style.copy()
+ return [styles['1m'], styles['5m'], styles['15m'], styles['1h'], styles['4h'], styles['1d']]
+
+ # Get the button ID that was clicked
button_id = ctx.triggered[0]['prop_id'].split('.')[0]
- if button_id == 'btn-1s':
- return ({'interval': 1},
- active_button_style, button_style, button_style, button_style, button_style)
- elif button_id == 'btn-5s':
- return ({'interval': 5},
- button_style, active_button_style, button_style, button_style, button_style)
- elif button_id == 'btn-15s':
- return ({'interval': 15},
- button_style, button_style, active_button_style, button_style, button_style)
- elif button_id == 'btn-30s':
- return ({'interval': 30},
- button_style, button_style, button_style, active_button_style, button_style)
- elif button_id == 'btn-1m':
- return ({'interval': 60},
- button_style, button_style, button_style, button_style, active_button_style)
+ # Map button ID to interval
+ interval_map = {
+ '1m-interval': '1m',
+ '5m-interval': '5m',
+ '15m-interval': '15m',
+ '1h-interval': '1h',
+ '4h-interval': '4h',
+ '1d-interval': '1d'
+ }
- # Default case - keep current interval and highlight appropriate button
- current_interval = data.get('interval', 1)
- styles = [button_style] * 5 # All inactive by default
+ # Update the current interval based on clicked button
+ self.current_interval = interval_map.get(button_id, self.current_interval)
- # Set active style based on current interval
- if current_interval == 1:
- styles[0] = active_button_style
- elif current_interval == 5:
- styles[1] = active_button_style
- elif current_interval == 15:
- styles[2] = active_button_style
- elif current_interval == 30:
- styles[3] = active_button_style
- elif current_interval == 60:
- styles[4] = active_button_style
+ # Set active style for selected interval
+ styles[self.current_interval] = active_button_style.copy()
- return (data, *styles)
+            # Rebuild the cached figure for the new interval; the periodic
+            # chart-interval callback then serves the refreshed self.fig
+            self._update_chart()
+
+ return [styles['1m'], styles['5m'], styles['15m'], styles['1h'], styles['4h'], styles['1d']]
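+
+        # Dash fires this callback once per button click; the 300ms
+        # chart-interval callback below then redraws using the updated
+        # self.current_interval.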
def _setup_chart_callback(self):
- # Callback to update the chart
+ """Set up the callback for the chart updates"""
@self.app.callback(
- Output('live-chart', 'figure'),
- [Input('interval-component', 'n_intervals'),
- Input('interval-store', 'data')]
+ Output('real-time-chart', 'figure'),
+ [Input('chart-interval', 'n_intervals')]
)
- def update_chart(n, interval_data):
+ def update_chart(n_intervals):
try:
- interval = interval_data.get('interval', 1)
- logger.debug(f"Updating chart with interval {interval}")
-
- # Get candlesticks data for the selected interval
- try:
- df = self.tick_storage.get_candles(interval_seconds=interval)
- if df.empty and self.ohlcv_cache[f'{interval}s' if interval < 60 else '1m'] is not None:
- df = self.ohlcv_cache[f'{interval}s' if interval < 60 else '1m']
- except Exception as e:
- logger.error(f"Error getting candles: {str(e)}")
- df = pd.DataFrame()
-
- # Get data for other timeframes (1m, 1h, 1d)
- df_1m = None
- df_1h = None
- df_1d = None
-
- try:
- # Get 1m candles
- if self.ohlcv_cache['1m'] is not None and not self.ohlcv_cache['1m'].empty:
- df_1m = self.ohlcv_cache['1m'].copy()
- else:
- df_1m = self.tick_storage.get_candles(interval_seconds=60)
-
- # Get 1h candles
- if self.ohlcv_cache['1h'] is not None and not self.ohlcv_cache['1h'].empty:
- df_1h = self.ohlcv_cache['1h'].copy()
- else:
- df_1h = self.tick_storage.get_candles(interval_seconds=3600)
-
- # Get 1d candles
- if self.ohlcv_cache['1d'] is not None and not self.ohlcv_cache['1d'].empty:
- df_1d = self.ohlcv_cache['1d'].copy()
- else:
- df_1d = self.tick_storage.get_candles(interval_seconds=86400)
-
- # Limit the number of candles to display but show more for context
- if df_1m is not None and not df_1m.empty:
- df_1m = df_1m.tail(600) # Show 600 1m candles for better context - 10 hours
- if df_1h is not None and not df_1h.empty:
- df_1h = df_1h.tail(480) # Show 480 hours of hourly data for better context - 20 days
- if df_1d is not None and not df_1d.empty:
- df_1d = df_1d.tail(365*2) # Show 2 years of daily data for better context
-
- except Exception as e:
- logger.error(f"Error getting additional timeframes: {str(e)}")
-
- # Create layout with 5 rows (main chart, volume, 1m, 1h, 1d)
- fig = make_subplots(
- rows=5, cols=1,
- vertical_spacing=0.03,
- subplot_titles=(
- f'{self.symbol} Price ({interval}s)',
- 'Volume',
- '1-Minute Chart',
- '1-Hour Chart',
- '1-Day Chart'
- ),
- row_heights=[0.4, 0.15, 0.15, 0.15, 0.15]
- )
-
- # Add candlestick chart for main timeframe
- if not df.empty and 'open' in df.columns:
- # Candlestick chart
- fig.add_trace(
- go.Candlestick(
- x=df.index,
- open=df['open'],
- high=df['high'],
- low=df['low'],
- close=df['close'],
- name='OHLC'
- ),
- row=1, col=1
- )
-
- # Calculate y-axis range with padding
- low_min = df['low'].min()
- high_max = df['high'].max()
- price_range = high_max - low_min
- y_min = low_min - (price_range * 0.05) # 5% padding below
- y_max = high_max + (price_range * 0.05) # 5% padding above
-
- # Set y-axis range to ensure candles are visible
- fig.update_yaxes(range=[y_min, y_max], row=1, col=1)
-
- # Volume bars
- colors = ['rgba(0,255,0,0.7)' if close >= open else 'rgba(255,0,0,0.7)'
- for open, close in zip(df['open'], df['close'])]
-
- fig.add_trace(
- go.Bar(
- x=df.index,
- y=df['volume'],
- name='Volume',
- marker_color=colors
- ),
- row=2, col=1
- )
-
- # Add buy/sell markers and PnL annotations to the main chart
- if hasattr(self, 'trades') and self.trades:
- buy_times = []
- buy_prices = []
- buy_markers = []
- sell_times = []
- sell_prices = []
- sell_markers = []
-
- # Filter trades to only include recent ones (last 100)
- recent_trades = self.trades[-100:]
-
- for trade in recent_trades:
- # Convert timestamp to datetime if it's not already
- trade_time = trade.get('timestamp')
- if isinstance(trade_time, (int, float)):
- trade_time = pd.to_datetime(trade_time, unit='ms')
-
- price = trade.get('price', 0)
- pnl = trade.get('pnl', None)
- action = trade.get('action', 'SELL') # Default to SELL
-
- if action == 'BUY':
- buy_times.append(trade_time)
- buy_prices.append(price)
- buy_markers.append("")
- elif action == 'SELL':
- sell_times.append(trade_time)
- sell_prices.append(price)
- # Add PnL as marker text if available
- if pnl is not None:
- pnl_text = f"{pnl:.4f}" if abs(pnl) < 0.01 else f"{pnl:.2f}"
- sell_markers.append(pnl_text)
- else:
- sell_markers.append("")
-
- # Add buy markers
- if buy_times:
- fig.add_trace(
- go.Scatter(
- x=buy_times,
- y=buy_prices,
- mode='markers',
- name='Buy',
- marker=dict(
- symbol='triangle-up',
- size=12,
- color='rgba(0,255,0,0.8)',
- line=dict(width=1, color='darkgreen')
- ),
- text=buy_markers,
- hoverinfo='x+y+text',
- showlegend=True # Ensure Buy appears in legend
- ),
- row=1, col=1
- )
-
- # Add vertical and horizontal connecting lines for buys
- for i, (btime, bprice) in enumerate(zip(buy_times, buy_prices)):
- # Add vertical dashed line to time axis
- fig.add_shape(
- type="line",
- x0=btime, x1=btime,
- y0=y_min, y1=bprice,
- line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"),
- row=1, col=1
- )
- # Add horizontal dashed line showing the price level
- fig.add_shape(
- type="line",
- x0=df.index.min(), x1=btime,
- y0=bprice, y1=bprice,
- line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"),
- row=1, col=1
- )
-
- # Add sell markers with PnL annotations
- if sell_times:
- fig.add_trace(
- go.Scatter(
- x=sell_times,
- y=sell_prices,
- mode='markers+text',
- name='Sell',
- marker=dict(
- symbol='triangle-down',
- size=12,
- color='rgba(255,0,0,0.8)',
- line=dict(width=1, color='darkred')
- ),
- text=sell_markers,
- textposition='top center',
- textfont=dict(size=10),
- hoverinfo='x+y+text',
- showlegend=True # Ensure Sell appears in legend
- ),
- row=1, col=1
- )
-
- # Add vertical and horizontal connecting lines for sells
- for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)):
- # Add vertical dashed line to time axis
- fig.add_shape(
- type="line",
- x0=stime, x1=stime,
- y0=y_min, y1=sprice,
- line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"),
- row=1, col=1
- )
- # Add horizontal dashed line showing the price level
- fig.add_shape(
- type="line",
- x0=df.index.min(), x1=stime,
- y0=sprice, y1=sprice,
- line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"),
- row=1, col=1
- )
-
- # Add connecting lines between consecutive buy-sell pairs
- if len(buy_times) > 0 and len(sell_times) > 0:
- # Create pairs of buy-sell trades based on timestamps
- pairs = []
- buys_copy = list(zip(buy_times, buy_prices))
-
- for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)):
- # Find the most recent buy before this sell
- matching_buy = None
- for j, (btime, bprice) in enumerate(buys_copy):
- if btime < stime:
- matching_buy = (btime, bprice)
- buys_copy.pop(j) # Remove this buy to prevent reuse
- break
-
- if matching_buy:
- pairs.append((matching_buy, (stime, sprice)))
-
- # Add connecting lines for each pair
- for (btime, bprice), (stime, sprice) in pairs:
- # Draw line connecting the buy and sell points
- fig.add_shape(
- type="line",
- x0=btime, x1=stime,
- y0=bprice, y1=sprice,
- line=dict(
- color="rgba(255,255,255,0.5)",
- width=1,
- dash="dot"
- ),
- row=1, col=1
- )
-
- # Add 1m chart
- if df_1m is not None and not df_1m.empty and 'open' in df_1m.columns:
- fig.add_trace(
- go.Candlestick(
- x=df_1m.index,
- open=df_1m['open'],
- high=df_1m['high'],
- low=df_1m['low'],
- close=df_1m['close'],
- name='1m',
- showlegend=False
- ),
- row=3, col=1
- )
-
- # Set appropriate date format for 1m chart
- fig.update_xaxes(
- title_text="",
- row=3,
- col=1,
- tickformat="%H:%M",
- tickmode="auto",
- nticks=12
- )
-
- # Add buy/sell markers to 1m chart if they fall within the visible timeframe
- if hasattr(self, 'trades') and self.trades:
- # Filter trades visible in 1m timeframe
- min_time = df_1m.index.min()
- max_time = df_1m.index.max()
-
- # Ensure min_time and max_time are pandas.Timestamp objects
- if isinstance(min_time, (int, float)):
- min_time = pd.to_datetime(min_time, unit='ms')
- if isinstance(max_time, (int, float)):
- max_time = pd.to_datetime(max_time, unit='ms')
-
- # Collect only trades within this timeframe
- minute_buy_times = []
- minute_buy_prices = []
- minute_sell_times = []
- minute_sell_prices = []
-
- for trade in self.trades[-100:]:
- trade_time = trade.get('timestamp')
- if isinstance(trade_time, (int, float)):
- # Convert numeric timestamp to datetime
- trade_time = pd.to_datetime(trade_time, unit='ms')
- elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime):
- # Skip trades with invalid timestamp format
- continue
-
- # Check if trade falls within 1m chart timeframe
- try:
- if min_time <= trade_time <= max_time:
- price = trade.get('price', 0)
- action = trade.get('action', 'SELL')
-
- if action == 'BUY':
- minute_buy_times.append(trade_time)
- minute_buy_prices.append(price)
- elif action == 'SELL':
- minute_sell_times.append(trade_time)
- minute_sell_prices.append(price)
- except TypeError:
- # If comparison fails due to type mismatch, log the error and skip this trade
- logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}")
- continue
-
- # Add buy markers to 1m chart
- if minute_buy_times:
- fig.add_trace(
- go.Scatter(
- x=minute_buy_times,
- y=minute_buy_prices,
- mode='markers',
- name='Buy (1m)',
- marker=dict(
- symbol='triangle-up',
- size=8,
- color='rgba(0,255,0,0.8)',
- line=dict(width=1, color='darkgreen')
- ),
- showlegend=False,
- hoverinfo='x+y'
- ),
- row=3, col=1
- )
-
- # Add sell markers to 1m chart
- if minute_sell_times:
- fig.add_trace(
- go.Scatter(
- x=minute_sell_times,
- y=minute_sell_prices,
- mode='markers',
- name='Sell (1m)',
- marker=dict(
- symbol='triangle-down',
- size=8,
- color='rgba(255,0,0,0.8)',
- line=dict(width=1, color='darkred')
- ),
- showlegend=False,
- hoverinfo='x+y'
- ),
- row=3, col=1
- )
-
- # Add 1h chart
- if df_1h is not None and not df_1h.empty and 'open' in df_1h.columns:
- fig.add_trace(
- go.Candlestick(
- x=df_1h.index,
- open=df_1h['open'],
- high=df_1h['high'],
- low=df_1h['low'],
- close=df_1h['close'],
- name='1h',
- showlegend=False
- ),
- row=4, col=1
- )
-
- # Set appropriate date format for 1h chart
- fig.update_xaxes(
- title_text="",
- row=4,
- col=1,
- tickformat="%m-%d %H:%M",
- tickmode="auto",
- nticks=8
- )
-
- # Add buy/sell markers to 1h chart if they fall within the visible timeframe
- if hasattr(self, 'trades') and self.trades:
- # Filter trades visible in 1h timeframe
- min_time = df_1h.index.min()
- max_time = df_1h.index.max()
-
- # Ensure min_time and max_time are pandas.Timestamp objects
- if isinstance(min_time, (int, float)):
- min_time = pd.to_datetime(min_time, unit='ms')
- if isinstance(max_time, (int, float)):
- max_time = pd.to_datetime(max_time, unit='ms')
-
- # Collect only trades within this timeframe
- hour_buy_times = []
- hour_buy_prices = []
- hour_sell_times = []
- hour_sell_prices = []
-
- for trade in self.trades[-200:]: # Check more trades for longer timeframe
- trade_time = trade.get('timestamp')
- if isinstance(trade_time, (int, float)):
- # Convert numeric timestamp to datetime
- trade_time = pd.to_datetime(trade_time, unit='ms')
- elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime):
- # Skip trades with invalid timestamp format
- continue
-
- # Check if trade falls within 1h chart timeframe
- try:
- if min_time <= trade_time <= max_time:
- price = trade.get('price', 0)
- action = trade.get('action', 'SELL')
-
- if action == 'BUY':
- hour_buy_times.append(trade_time)
- hour_buy_prices.append(price)
- elif action == 'SELL':
- hour_sell_times.append(trade_time)
- hour_sell_prices.append(price)
- except TypeError:
- # If comparison fails due to type mismatch, log the error and skip this trade
- logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}")
- continue
-
- # Add buy markers to 1h chart
- if hour_buy_times:
- fig.add_trace(
- go.Scatter(
- x=hour_buy_times,
- y=hour_buy_prices,
- mode='markers',
- name='Buy (1h)',
- marker=dict(
- symbol='triangle-up',
- size=6,
- color='rgba(0,255,0,0.8)',
- line=dict(width=1, color='darkgreen')
- ),
- showlegend=False,
- hoverinfo='x+y'
- ),
- row=4, col=1
- )
-
- # Add sell markers to 1h chart
- if hour_sell_times:
- fig.add_trace(
- go.Scatter(
- x=hour_sell_times,
- y=hour_sell_prices,
- mode='markers',
- name='Sell (1h)',
- marker=dict(
- symbol='triangle-down',
- size=6,
- color='rgba(255,0,0,0.8)',
- line=dict(width=1, color='darkred')
- ),
- showlegend=False,
- hoverinfo='x+y'
- ),
- row=4, col=1
- )
- # Add 1d chart
- if df_1d is not None and not df_1d.empty and 'open' in df_1d.columns:
- fig.add_trace(
- go.Candlestick(
- x=df_1d.index,
- open=df_1d['open'],
- high=df_1d['high'],
- low=df_1d['low'],
- close=df_1d['close'],
- name='1d',
- showlegend=False
- ),
- row=5, col=1
- )
-
- # Set appropriate date format for 1d chart
- fig.update_xaxes(
- title_text="",
- row=5,
- col=1,
- tickformat="%Y-%m-%d",
- tickmode="auto",
- nticks=10
- )
-
- # Add buy/sell markers to 1d chart if they fall within the visible timeframe
- if hasattr(self, 'trades') and self.trades:
- # Filter trades visible in 1d timeframe
- min_time = df_1d.index.min()
- max_time = df_1d.index.max()
-
- # Ensure min_time and max_time are pandas.Timestamp objects
- if isinstance(min_time, (int, float)):
- min_time = pd.to_datetime(min_time, unit='ms')
- if isinstance(max_time, (int, float)):
- max_time = pd.to_datetime(max_time, unit='ms')
-
- # Collect only trades within this timeframe
- day_buy_times = []
- day_buy_prices = []
- day_sell_times = []
- day_sell_prices = []
-
- for trade in self.trades[-300:]: # Check more trades for daily timeframe
- trade_time = trade.get('timestamp')
- if isinstance(trade_time, (int, float)):
- # Convert numeric timestamp to datetime
- trade_time = pd.to_datetime(trade_time, unit='ms')
- elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime):
- # Skip trades with invalid timestamp format
- continue
-
- # Check if trade falls within 1d chart timeframe
- try:
- if min_time <= trade_time <= max_time:
- price = trade.get('price', 0)
- action = trade.get('action', 'SELL')
-
- if action == 'BUY':
- day_buy_times.append(trade_time)
- day_buy_prices.append(price)
- elif action == 'SELL':
- day_sell_times.append(trade_time)
- day_sell_prices.append(price)
- except TypeError:
- # If comparison fails due to type mismatch, log the error and skip this trade
- logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}")
- continue
-
- # Add buy markers to 1d chart
- if day_buy_times:
- fig.add_trace(
- go.Scatter(
- x=day_buy_times,
- y=day_buy_prices,
- mode='markers',
- name='Buy (1d)',
- marker=dict(
- symbol='triangle-up',
- size=5,
- color='rgba(0,255,0,0.8)',
- line=dict(width=1, color='darkgreen')
- ),
- showlegend=False,
- hoverinfo='x+y'
- ),
- row=5, col=1
- )
-
- # Add sell markers to 1d chart
- if day_sell_times:
- fig.add_trace(
- go.Scatter(
- x=day_sell_times,
- y=day_sell_prices,
- mode='markers',
- name='Sell (1d)',
- marker=dict(
- symbol='triangle-down',
- size=5,
- color='rgba(255,0,0,0.8)',
- line=dict(width=1, color='darkred')
- ),
- showlegend=False,
- hoverinfo='x+y'
- ),
- row=5, col=1
- )
-
- # Add trading info annotation if available
- if hasattr(self, 'current_signal') and self.current_signal:
- signal_color = "#33DD33" if self.current_signal == "BUY" else "#FF4444" if self.current_signal == "SELL" else "#BBBBBB"
-
- # Format position value
- position_text = f"{self.current_position:.4f}" if self.current_position < 0.01 else f"{self.current_position:.2f}"
-
- # Format PnL with color based on value
- pnl_color = "#33DD33" if self.session_pnl >= 0 else "#FF4444"
- pnl_text = f"{self.session_pnl:.4f}"
-
- # Create trading info text
- info_text = (
- f"Signal: {self.current_signal} | "
- f"Position: {position_text} | "
- f"Balance: ${self.session_balance:.2f} | "
- f"PnL: {pnl_text}"
- )
-
- # Add annotation
- fig.add_annotation(
- x=0.5, y=1.05,
- xref="paper", yref="paper",
- text=info_text,
- showarrow=False,
- font=dict(size=14, color="white"),
- bgcolor="rgba(50,50,50,0.6)",
- borderwidth=1,
- borderpad=6,
- align="center"
- )
-
- # Update layout
- fig.update_layout(
- title_text=f"{self.symbol} Real-Time Data",
- title_x=0.5,
- xaxis_rangeslider_visible=False,
- height=1000, # Increased height to accommodate all charts
- template='plotly_dark',
- paper_bgcolor='rgba(0,0,0,0)',
- plot_bgcolor='rgba(25,25,50,1)'
- )
-
- # Update axes styling for all subplots
- fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
- fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
-
- # Hide rangesliders for all candlestick charts
- fig.update_layout(
- xaxis_rangeslider_visible=False,
- xaxis2_rangeslider_visible=False,
- xaxis3_rangeslider_visible=False,
- xaxis4_rangeslider_visible=False,
- xaxis5_rangeslider_visible=False
- )
-
- # Improve date formatting for the main chart
- fig.update_xaxes(
- title_text="",
- row=1,
- col=1,
- tickformat="%H:%M:%S",
- tickmode="auto",
- nticks=15
- )
-
- return fig
+ # Create the main figure if it doesn't exist yet
+ if self.fig is None:
+ self._initialize_chart()
+
+ # Update the chart data
+ self._update_chart()
+ return self.fig
except Exception as e:
logger.error(f"Error updating chart: {str(e)}")
import traceback
logger.error(traceback.format_exc())
- fig = go.Figure()
- fig.add_annotation(
- x=0.5, y=0.5,
- text=f"Error updating chart: {str(e)}",
- showarrow=False,
- font=dict(size=14, color="red"),
- xref="paper", yref="paper"
- )
- return fig
+
+ # Return empty figure on error
+ return {
+ 'data': [],
+ 'layout': {
+ 'title': 'Error updating chart',
+ 'annotations': [{
+ 'text': str(e),
+ 'showarrow': False,
+ 'font': {'color': 'red'}
+ }]
+ }
+ }
def _setup_position_list_callback(self):
- """Callback to update the position list with the latest 5 trades"""
+ """Set up the callback for the position list"""
@self.app.callback(
- Output('position-table-body', 'children'),
- [Input('interval-component', 'n_intervals')]
+ Output('position-list', 'children'),
+ [Input('chart-interval', 'n_intervals')]
)
def update_position_list(n):
- try:
- # Get the last 5 positions
- if not hasattr(self, 'trades') or not self.trades:
- return [html.Tr([html.Td("No trades yet", colSpan=6, style={'textAlign': 'center', 'color': '#AAAAAA', 'padding': '10px'})])]
-
- # Collect BUY and SELL trades to pair them
- buy_trades = {}
- sell_trades = {}
- position_rows = []
-
- # Process all trades to match buy and sell pairs
- for trade in self.trades:
- action = trade.get('action', 'UNKNOWN')
- if action == 'BUY':
- # Store buy trades by timestamp to match with sells later
- buy_trades[trade.get('timestamp')] = trade
- elif action == 'SELL':
- # Store sell trades
- sell_trades[trade.get('timestamp')] = trade
-
- # Create position entries (matched pairs or single trades)
- position_entries = []
-
- # First add matched pairs (BUY trades with matching close_price)
- for timestamp, buy_trade in buy_trades.items():
- if 'close_price' in buy_trade:
- # This is a closed position
- entry_price = buy_trade.get('price', 'N/A')
- exit_price = buy_trade.get('close_price', 'N/A')
- entry_time = buy_trade.get('timestamp')
- exit_time = buy_trade.get('close_timestamp')
- amount = buy_trade.get('amount', 0.1)
-
- # Calculate PnL if not provided
- pnl = buy_trade.get('pnl')
- if pnl is None and isinstance(entry_price, (int, float)) and isinstance(exit_price, (int, float)):
- pnl = (exit_price - entry_price) * amount
-
- position_entries.append({
- 'action': 'CLOSED',
- 'entry_price': entry_price,
- 'exit_price': exit_price,
- 'amount': amount,
- 'pnl': pnl,
- 'time': entry_time,
- 'exit_time': exit_time,
- 'status': 'CLOSED'
- })
- else:
- # This is an open position (BUY without a matching SELL)
- current_price = self.tick_storage.get_latest_price() or buy_trade.get('price', 0)
- amount = buy_trade.get('amount', 0.1)
- entry_price = buy_trade.get('price', 0)
-
- # Calculate unrealized PnL
- unrealized_pnl = (current_price - entry_price) * amount if isinstance(current_price, (int, float)) and isinstance(entry_price, (int, float)) else None
-
- position_entries.append({
- 'action': 'BUY',
- 'entry_price': entry_price,
- 'exit_price': current_price, # Use current price as potential exit
- 'amount': amount,
- 'pnl': unrealized_pnl,
- 'time': buy_trade.get('timestamp'),
- 'status': 'OPEN'
- })
-
- # Add standalone SELL trades that don't have a matching BUY
- for timestamp, sell_trade in sell_trades.items():
- # Check if this SELL is already accounted for in a closed position
- already_matched = False
- for entry in position_entries:
- if entry.get('status') == 'CLOSED' and entry.get('exit_time') == timestamp:
- already_matched = True
- break
-
- if not already_matched:
- position_entries.append({
- 'action': 'SELL',
- 'entry_price': 'N/A',
- 'exit_price': sell_trade.get('price', 'N/A'),
- 'amount': sell_trade.get('amount', 0.1),
- 'pnl': sell_trade.get('pnl'),
- 'time': sell_trade.get('timestamp'),
- 'status': 'STANDALONE'
- })
-
- # Sort by time (most recent first) and take last 5
- position_entries.sort(key=lambda x: x['time'] if isinstance(x['time'], datetime) else datetime.now(), reverse=True)
- position_entries = position_entries[:5]
-
- # Convert to table rows
- for entry in position_entries:
- action = entry['action']
- entry_price = entry['entry_price']
- exit_price = entry['exit_price']
- amount = entry['amount']
- pnl = entry['pnl']
- time_obj = entry['time']
- status = entry['status']
-
- # Format time
- if isinstance(time_obj, datetime):
- # If trade is from a different day, include the date
- today = datetime.now().date()
- if time_obj.date() == today:
- time_str = time_obj.strftime('%H:%M:%S')
- else:
- time_str = time_obj.strftime('%m-%d %H:%M:%S')
- else:
- time_str = str(time_obj)
-
- # Format prices with proper decimal places
- if isinstance(entry_price, (int, float)):
- entry_price_str = f"${entry_price:.2f}"
- else:
- entry_price_str = str(entry_price)
-
- if isinstance(exit_price, (int, float)):
- exit_price_str = f"${exit_price:.2f}"
- else:
- exit_price_str = str(exit_price)
-
- # Format PnL
- if pnl is not None and isinstance(pnl, (int, float)):
- pnl_str = f"${pnl:.2f}"
- pnl_color = '#00FF00' if pnl >= 0 else '#FF0000'
- else:
- pnl_str = 'N/A'
- pnl_color = '#FFFFFF'
-
- # Set action/status color and text
- if status == 'OPEN':
- status_color = '#00AAFF' # Blue for open positions
- status_text = "OPEN (BUY)"
- elif status == 'CLOSED':
- if pnl is not None and isinstance(pnl, (int, float)):
- status_color = '#00FF00' if pnl >= 0 else '#FF0000' # Green/Red based on profit
- else:
- status_color = '#FFCC00' # Yellow if PnL unknown
- status_text = "CLOSED"
- elif action == 'BUY':
- status_color = '#00FF00'
- status_text = "BUY"
- elif action == 'SELL':
- status_color = '#FF0000'
- status_text = "SELL"
- else:
- status_color = '#FFFFFF'
- status_text = action
-
- # Create table row
- position_rows.append(html.Tr([
- html.Td(status_text, style={'color': status_color, 'padding': '8px', 'border': '1px solid #444'}),
- html.Td(f"{amount} BTC", style={'padding': '8px', 'border': '1px solid #444'}),
- html.Td(entry_price_str, style={'padding': '8px', 'border': '1px solid #444'}),
- html.Td(exit_price_str, style={'padding': '8px', 'border': '1px solid #444'}),
- html.Td(pnl_str, style={'color': pnl_color, 'padding': '8px', 'border': '1px solid #444'}),
- html.Td(time_str, style={'padding': '8px', 'border': '1px solid #444'})
- ]))
-
- return position_rows
- except Exception as e:
- logger.error(f"Error updating position table: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
- return [html.Tr([html.Td(f"Error: {str(e)}", colSpan=6, style={'color': '#FF0000', 'padding': '10px'})])]
+ if not self.positions:
+ return [html.Tr([html.Td("No positions yet", colSpan=6)])]
+ return self._get_position_list_rows()
+
+ def _setup_trading_status_callback(self):
+ """Set up the callback for the trading status fields"""
+ @self.app.callback(
+ [
+ Output('current-signal-value', 'children'),
+ Output('current-position-value', 'children'),
+ Output('current-balance-value', 'children'),
+ Output('current-pnl-value', 'children'),
+ Output('current-signal-value', 'style'),
+ Output('current-position-value', 'style')
+ ],
+ [Input('chart-interval', 'n_intervals')]
+ )
+ def update_trading_status(n):
+ # Get the current signal
+ current_signal = "NONE"
+ signal_style = {'color': 'white'}
+
+ if hasattr(self, 'last_action') and self.last_action:
+ current_signal = self.last_action
+ if current_signal == "BUY":
+ signal_style = {'color': 'green', 'fontWeight': 'bold'}
+ elif current_signal == "SELL":
+ signal_style = {'color': 'red', 'fontWeight': 'bold'}
+
+ # Get the current position
+ current_position = "NONE"
+ position_style = {'color': 'white'}
+
+ # Check if we have any open positions
+ open_positions = [p for p in self.positions if p.status == "OPEN"]
+ if open_positions:
+ current_position = f"{open_positions[0].action} {open_positions[0].amount:.4f}"
+ if open_positions[0].action == "BUY":
+ position_style = {'color': 'green', 'fontWeight': 'bold'}
+ else:
+ position_style = {'color': 'red', 'fontWeight': 'bold'}
+
+ # Get the current balance and session PnL
+ current_balance = f"${self.balance:.2f}" if hasattr(self, 'balance') else "$0.00"
+
+ # Calculate session PnL
+ session_pnl = 0
+ for position in self.positions:
+ if position.status == "CLOSED":
+ session_pnl += position.pnl
+
+ # Format the session PnL for display
+ pnl_text = f"${session_pnl:.2f}"
+
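+ # Dash maps the returned tuple positionally onto the Output list above:
+ # (signal text, position text, balance text, pnl text, signal style, position style)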
+ return current_signal, current_position, current_balance, pnl_text, signal_style, position_style
+
+ def _add_manual_trade_inputs(self):
+ """Add manual trade input controls to the app layout"""
+ self.app.layout.children.append(
+ html.Div([
+ html.H3("Add Manual Trade"),
+ dcc.Input(id='manual-price', type='number', placeholder='Price'),
+ dcc.Input(id='manual-volume', type='number', placeholder='Volume'),
+ dcc.Input(id='manual-pnl', type='number', placeholder='PnL'),
+ dcc.Input(id='manual-action', type='text', placeholder='Action'),
+ html.Button('Add Trade', id='add-manual-trade')
+ ])
+ )
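+
+ # Editorial sketch: this hunk adds the inputs but wires no handler for the
+ # 'add-manual-trade' button. A callback connecting them to add_trade could
+ # look like the following (untested; the component ids match the layout
+ # above, the rest is an assumption, and allow_duplicate requires Dash >= 2.9):
+ #
+ # from dash.dependencies import Input, Output, State
+ #
+ # @self.app.callback(
+ # Output('position-list', 'children', allow_duplicate=True),
+ # [Input('add-manual-trade', 'n_clicks')],
+ # [State('manual-price', 'value'),
+ # State('manual-volume', 'value'),
+ # State('manual-pnl', 'value'),
+ # State('manual-action', 'value')],
+ # prevent_initial_call=True
+ # )
+ # def add_manual_trade(n_clicks, price, volume, pnl, action):
+ # if price is not None:
+ # self.add_trade(price=price, amount=volume or 0.1,
+ # pnl=pnl, action=(action or "BUY").upper())
+ # return self._get_position_list_rows()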
def _interval_to_seconds(self, interval_key: str) -> int:
"""Convert interval key to seconds"""
@@ -2649,356 +1022,104 @@ class RealTimeChart:
logger.info(f"Waiting 5 seconds before reconnecting {self.symbol} WebSocket...")
await asyncio.sleep(5)
- def run(self, host='localhost', port=8050):
- """Run the Dash app
-
- Args:
- host: Hostname to run on
- port: Port to run on
- """
- logger.info(f"Starting Dash app on {host}:{port}")
-
- # Ensure interval component is created
- if not hasattr(self, 'app') or not self.app.layout:
- logger.error("App layout not initialized properly")
- return
-
- # If interval-component is not in the layout, add it
- if 'interval-component' not in str(self.app.layout):
- logger.warning("Interval component not found in layout, adding it")
- self.app.layout.children.append(
- dcc.Interval(
- id='interval-component',
- interval=500, # 500ms for real-time updates
- n_intervals=0
- )
- )
-
- # Start websocket connection in a separate thread
- loop = asyncio.new_event_loop()
- asyncio.set_event_loop(loop)
- self.websocket_thread = threading.Thread(target=lambda: asyncio.run(self.start_websocket()))
- self.websocket_thread.daemon = True
- self.websocket_thread.start()
-
- # Ensure historical data is loaded before starting
- self._load_historical_data()
-
+ def _run_app(self):
+ """Run the Dash app"""
try:
- self.app.run(host=host, port=port, debug=False)
+ logger.info(f"Starting Dash app for {self.symbol}")
+ # Updated to use app.run instead of app.run_server (which is deprecated)
+ self.app.run(debug=False, use_reloader=False, port=8050)
except Exception as e:
logger.error(f"Error running Dash app: {str(e)}")
- finally:
- # Ensure resources are cleaned up
- self._save_candles_to_disk(force=True)
- logger.info("Dash app stopped")
+ import traceback
+ logger.error(traceback.format_exc())
- def _load_historical_data(self):
- """Load historical data for all timeframes from Binance API and local cache"""
+ def add_trade(self, price, timestamp=None, pnl=None, amount=0.1, action="BUY", trade_type="MARKET"):
+ """Add a trade to the chart
+
+ Args:
+ price: Trade price
+ timestamp: Trade timestamp (datetime or milliseconds)
+ pnl: Profit and Loss (for SELL trades)
+ amount: Trade amount
+ action: Trade action (BUY or SELL)
+ trade_type: Trade type (MARKET, LIMIT, etc.)
+ """
try:
- logger.info(f"Loading historical data for {self.symbol}...")
-
- # Define intervals to fetch
- intervals = {
- '1s': 1,
- '1m': 60,
- '1h': 3600,
- '1d': 86400
- }
-
- # Track load status
- load_status = {interval: False for interval in intervals.keys()}
-
- # First try to load from local cache files
- logger.info("Step 1: Loading from local cache files...")
- for interval_key, interval_seconds in intervals.items():
- try:
- cache_file = os.path.join(self.historical_data.cache_dir,
- f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv")
-
- logger.info(f"Checking for cached {interval_key} data at {cache_file}")
- if os.path.exists(cache_file):
- # Check if cache is fresh (less than 1 day old for anything but 1d, 3 days for 1d)
- file_age = time.time() - os.path.getmtime(cache_file)
- max_age = 259200 if interval_key == '1d' else 86400 # 3 days for 1d, 1 day for others
- logger.info(f"Cache file age: {file_age:.1f}s, max allowed: {max_age}s")
+ # Convert timestamp to datetime if it's a number
+ if timestamp is None:
+ timestamp = datetime.now()
+ elif isinstance(timestamp, (int, float)):
+ timestamp = datetime.fromtimestamp(timestamp / 1000)
+
+ # Process the trade based on action
+ if action == "BUY":
+ # Create a new position
+ position = Position(
+ action="BUY",
+ entry_price=price,
+ amount=amount,
+ timestamp=timestamp
+ )
+ self.positions.append(position)
+
+ # Update last action
+ self.last_action = "BUY"
+
+ elif action == "SELL":
+ # Find an open BUY position to close, or create a new SELL position
+ open_buy_position = None
+ for pos in self.positions:
+ if pos.status == "OPEN" and pos.action == "BUY":
+ open_buy_position = pos
+ break
- if file_age <= max_age:
- logger.info(f"Loading {interval_key} candles from cache")
- cached_df = pd.read_csv(cache_file)
- if not cached_df.empty:
- # Diagnostic info about the loaded data
- logger.info(f"Loaded {len(cached_df)} candles from {cache_file}")
- logger.info(f"Columns: {cached_df.columns.tolist()}")
- logger.info(f"First few rows: {cached_df.head(2).to_dict('records')}")
-
- # Convert timestamp string back to datetime
- if 'timestamp' in cached_df.columns:
- try:
- if not pd.api.types.is_datetime64_any_dtype(cached_df['timestamp']):
- cached_df['timestamp'] = pd.to_datetime(cached_df['timestamp'])
- logger.info("Successfully converted timestamps to datetime")
- except Exception as e:
- logger.warning(f"Could not convert timestamp column for {interval_key}: {str(e)}")
-
- # Only keep the last 2000 candles for memory efficiency
- if len(cached_df) > 2000:
- cached_df = cached_df.tail(2000)
- logger.info(f"Truncated to last 2000 candles")
-
- # Add to cache
- for _, row in cached_df.iterrows():
- candle_dict = row.to_dict()
- self.candle_cache.candles[interval_key].append(candle_dict)
-
- # Update ohlcv_cache
- self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000)
- logger.info(f"Successfully loaded {len(self.ohlcv_cache[interval_key])} cached {interval_key} candles")
-
- if len(self.ohlcv_cache[interval_key]) >= 500:
- load_status[interval_key] = True
- # Skip fetching from API if we loaded from cache (except for 1d timeframe which we always refresh)
- if interval_key != '1d':
- continue
- else:
- logger.info(f"Cache file for {interval_key} is too old ({file_age:.1f}s)")
- else:
- logger.info(f"No cache file found for {interval_key}")
- except Exception as e:
- logger.error(f"Error loading cached {interval_key} candles: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
-
- # For timeframes other than 1s, fetch from API as backup or for fresh data
- logger.info("Step 2: Fetching data from API for missing timeframes...")
- for interval_key, interval_seconds in intervals.items():
- # Skip 1s for API requests
- if interval_key == '1s' or load_status[interval_key]:
- logger.info(f"Skipping API fetch for {interval_key}: already loaded or 1s timeframe")
- continue
+ if open_buy_position:
+ # Close the position
+ pnl_value = open_buy_position.close(price, timestamp)
- # Fetch historical data from API
- try:
- logger.info(f"Fetching {interval_key} candles from API for {self.symbol}")
- historical_df = self.historical_data.get_historical_candles(
- symbol=self.symbol,
- interval_seconds=interval_seconds,
- limit=500 # Get 500 candles
+ # Update balance
+ self.balance += pnl_value
+
+ # If pnl was provided, use it instead
+ if pnl is not None:
+ open_buy_position.pnl = pnl
+ self.balance = self.balance - pnl_value + pnl
+
+ else:
+ # Create a standalone SELL position
+ position = Position(
+ action="SELL",
+ entry_price=price,
+ amount=amount,
+ timestamp=timestamp
)
- if not historical_df.empty:
- logger.info(f"Loaded {len(historical_df)} historical candles for {self.symbol} {interval_key} from API")
+ # Set it as closed with the same price
+ position.close(price, timestamp)
+
+ # Set PnL if provided
+ if pnl is not None:
+ position.pnl = pnl
+ self.balance += pnl
- # If we already have data in cache, merge with new data to avoid duplicates
- if self.ohlcv_cache[interval_key] is not None and not self.ohlcv_cache[interval_key].empty:
- existing_df = self.ohlcv_cache[interval_key]
- # Get the latest timestamp from existing data
- latest_time = existing_df['timestamp'].max()
- # Only keep newer records from API
- new_candles = historical_df[historical_df['timestamp'] > latest_time]
- if not new_candles.empty:
- logger.info(f"Adding {len(new_candles)} new candles to existing {interval_key} cache")
- # Add to cache
- for _, row in new_candles.iterrows():
- candle_dict = row.to_dict()
- self.candle_cache.candles[interval_key].append(candle_dict)
- else:
- # No existing data, add all from API
- for _, row in historical_df.iterrows():
- candle_dict = row.to_dict()
- self.candle_cache.candles[interval_key].append(candle_dict)
-
- # Update ohlcv_cache with combined data
- self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000)
- logger.info(f"Total {interval_key} candles in cache: {len(self.ohlcv_cache[interval_key])}")
-
- if len(self.ohlcv_cache[interval_key]) >= 500:
- load_status[interval_key] = True
- else:
- logger.warning(f"No historical data available from API for {self.symbol} {interval_key}")
- except Exception as e:
- logger.error(f"Error fetching {interval_key} data from API: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
+ self.positions.append(position)
+
+ # Update last action
+ self.last_action = "SELL"
- # Log summary of loaded data
- logger.info("Historical data load summary:")
- for interval_key in intervals.keys():
- count = len(self.ohlcv_cache[interval_key]) if self.ohlcv_cache[interval_key] is not None else 0
- status = "Success" if load_status[interval_key] else "Failed"
- if count > 0 and count < 500:
- status = "Partial"
- logger.info(f"{interval_key}: {count} candles - {status}")
+ # Log the trade
+ logger.info(f"Added {action} trade: price={price}, amount={amount}, time={timestamp}, PnL={pnl}")
+
+ # Refresh the chart immediately so the new trade is visible right away
+ if hasattr(self, 'fig') and self.fig is not None:
+ self._update_chart()
except Exception as e:
- logger.error(f"Error in _load_historical_data: {str(e)}")
+ logger.error(f"Error adding trade: {str(e)}")
import traceback
logger.error(traceback.format_exc())
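+
+ # Editorial sketch: the Position class used above is not shown in this diff.
+ # A minimal dataclass satisfying the interface assumed by add_trade and
+ # _get_position_list_rows could look like the following (field and method
+ # names are inferred from their call sites, not confirmed by the source):
+ #
+ # from dataclasses import dataclass
+ # from datetime import datetime
+ # from typing import Optional
+ #
+ # @dataclass
+ # class Position:
+ # action: str # "BUY" or "SELL"
+ # entry_price: float
+ # amount: float
+ # timestamp: datetime
+ # status: str = "OPEN"
+ # exit_price: Optional[float] = None
+ # exit_timestamp: Optional[datetime] = None
+ # pnl: float = 0.0
+ #
+ # def close(self, price: float, timestamp: datetime) -> float:
+ # """Close the position and return the realized PnL."""
+ # self.exit_price = price
+ # self.exit_timestamp = timestamp
+ # self.status = "CLOSED"
+ # direction = 1 if self.action == "BUY" else -1
+ # self.pnl = direction * (price - self.entry_price) * self.amount
+ # return self.pnl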
- def _save_candles_to_disk(self, force=False):
- """Save current candle cache to disk for persistence between runs"""
- try:
- # Only save if we have data and sufficient time has passed (every 5 minutes)
- current_time = time.time()
- if not force and current_time - self.last_cache_save_time < 300: # 5 minutes
- return
-
- # Save each timeframe's candles to disk
- for interval_key, candles in self.candle_cache.candles.items():
- if candles:
- # Convert to DataFrame
- df = pd.DataFrame(list(candles))
- if not df.empty:
- # Ensure timestamp is properly formatted
- if 'timestamp' in df.columns:
- try:
- if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
- df['timestamp'] = pd.to_datetime(df['timestamp'])
- except:
- logger.warning(f"Could not convert timestamp column for {interval_key}")
-
- # Save to disk in the cache directory
- cache_file = os.path.join(self.historical_data.cache_dir,
- f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv")
- df.to_csv(cache_file, index=False)
- logger.info(f"Saved {len(df)} {interval_key} candles to {cache_file}")
-
- self.last_cache_save_time = current_time
- logger.info(f"Saved all candle caches to disk at {datetime.now()}")
- except Exception as e:
- logger.error(f"Error saving candles to disk: {str(e)}")
- import traceback
- logger.error(traceback.format_exc())
-
- def add_nn_signal(self, signal_type, timestamp, probability=None):
- """Add a neural network signal to be displayed on the chart
-
- Args:
- signal_type: The type of signal (BUY, SELL, HOLD)
- timestamp: The timestamp for the signal
- probability: Optional probability/confidence value
- """
- if signal_type not in ['BUY', 'SELL', 'HOLD']:
- logger.warning(f"Invalid NN signal type: {signal_type}")
- return
-
- # Convert timestamp to datetime if it's not already
- if not isinstance(timestamp, datetime):
- try:
- if isinstance(timestamp, str):
- timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
- elif isinstance(timestamp, (int, float)):
- timestamp = datetime.fromtimestamp(timestamp / 1000.0)
- except Exception as e:
- logger.error(f"Error converting timestamp for NN signal: {str(e)}")
- timestamp = datetime.now()
-
- # Add the signal to our list
- self.nn_signals.append({
- 'type': signal_type,
- 'timestamp': timestamp,
- 'probability': probability,
- 'added': datetime.now()
- })
-
- # Only keep the most recent 50 signals
- if len(self.nn_signals) > 50:
- self.nn_signals = self.nn_signals[-50:]
-
- logger.info(f"Added NN signal: {signal_type} at {timestamp}")
-
- def add_trade(self, price, timestamp, pnl=None, amount=0.1, action=None, type=None):
- """Add a trade to be displayed on the chart
-
- Args:
- price: The price at which the trade was executed
- timestamp: The timestamp for the trade
- pnl: Optional profit and loss value for the trade
- amount: Amount traded
- action: The type of trade (BUY or SELL) - alternative to type parameter
- type: The type of trade (BUY or SELL) - alternative to action parameter
- """
- # Handle both action and type parameters for backward compatibility
- trade_type = type or action
-
- # Default to BUY if trade_type is None or not specified
- if trade_type is None:
- logger.warning(f"Trade type not specified in add_trade call, defaulting to BUY. Price: {price}, Timestamp: {timestamp}")
- trade_type = "BUY"
-
- if isinstance(trade_type, int):
- trade_type = "BUY" if trade_type == 0 else "SELL"
-
- # Ensure trade_type is uppercase if it's a string
- if isinstance(trade_type, str):
- trade_type = trade_type.upper()
-
- if trade_type not in ['BUY', 'SELL']:
- logger.warning(f"Invalid trade type: {trade_type} (value type: {type(trade_type).__name__}), defaulting to BUY. Price: {price}, Timestamp: {timestamp}")
- trade_type = "BUY"
-
- # Convert timestamp to datetime if it's not already
- if not isinstance(timestamp, datetime):
- try:
- if isinstance(timestamp, str):
- timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
- elif isinstance(timestamp, (int, float)):
- timestamp = datetime.fromtimestamp(timestamp / 1000.0)
- except Exception as e:
- logger.error(f"Error converting timestamp for trade: {str(e)}")
- timestamp = datetime.now()
-
- # Create the trade object
- trade = {
- 'price': price,
- 'timestamp': timestamp,
- 'pnl': pnl,
- 'amount': amount,
- 'action': trade_type
- }
-
- # Add to our trades list
- if not hasattr(self, 'trades'):
- self.trades = []
-
- # If this is a SELL trade, try to find the corresponding BUY trade and update it with close_price
- if trade_type == 'SELL' and len(self.trades) > 0:
- for i in range(len(self.trades) - 1, -1, -1):
- prev_trade = self.trades[i]
- if prev_trade.get('action') == 'BUY' and 'close_price' not in prev_trade:
- # Found a BUY trade without a close_price, consider it the matching trade
- prev_trade['close_price'] = price
- prev_trade['close_timestamp'] = timestamp
- logger.info(f"Updated BUY trade at {prev_trade['timestamp']} with close price {price}")
- break
-
- self.trades.append(trade)
-
- # Log the trade for debugging
- pnl_str = f" with PnL: {pnl}" if pnl is not None else ""
- logger.info(f"Added trade: {trade_type} {amount} at price {price} at time {timestamp}{pnl_str}")
-
- # Trigger a more frequent update of the chart by scheduling a callback
- # This helps ensure the trade appears immediately on the chart
- if hasattr(self, 'app') and self.app is not None:
- try:
- # Only update if we have a dash app running
- # This is a workaround to make trades appear immediately
- callback_context = dash.callback_context
- # Force an update by triggering the callback
- for callback_id, callback_info in self.app.callback_map.items():
- if 'live-chart' in callback_id:
- # Found the chart callback, try to trigger it
- logger.debug(f"Triggering chart update callback after trade")
- callback_info['callback']()
- break
- except Exception as e:
- # If callback triggering fails, it's not critical
- logger.debug(f"Failed to trigger chart update: {str(e)}")
- pass
-
- return trade
-
def update_trading_info(self, signal=None, position=None, balance=None, pnl=None):
"""Update the current trading information to be displayed on the chart
@@ -3026,9 +1147,291 @@ class RealTimeChart:
logger.debug(f"Updated trading info: Signal={self.current_signal}, Position={self.current_position}, Balance=${self.session_balance:.2f}, PnL={self.session_pnl:.4f}")
+ def _get_position_list_rows(self):
+ """Generate rows for the position table"""
+ if not self.positions:
+ return [html.Tr([html.Td("No positions yet", colSpan=6)])]
+
+ position_rows = []
+
+ # Sort positions by time (most recent first)
+ sorted_positions = sorted(self.positions,
+ key=lambda x: x.timestamp if hasattr(x, 'timestamp') else datetime.now(),
+ reverse=True)
+
+ # Take only the most recent 5 positions
+ for position in sorted_positions[:5]:
+ # Format time
+ time_obj = position.timestamp if hasattr(position, 'timestamp') else datetime.now()
+ if isinstance(time_obj, datetime):
+ # If trade is from a different day, include the date
+ today = datetime.now().date()
+ if time_obj.date() == today:
+ time_str = time_obj.strftime('%H:%M:%S')
+ else:
+ time_str = time_obj.strftime('%m-%d %H:%M:%S')
+ else:
+ time_str = str(time_obj)
+
+ # Format prices with proper decimal places
+ entry_price = position.entry_price if hasattr(position, 'entry_price') else 'N/A'
+ if isinstance(entry_price, (int, float)):
+ entry_price_str = f"${entry_price:.6f}"
+ else:
+ entry_price_str = str(entry_price)
+
+ # For exit price, use close_price for closed positions or current market price for open ones
+ if position.status == "CLOSED" and hasattr(position, 'exit_price'):
+ exit_price = position.exit_price
+ else:
+ exit_price = self.tick_storage.get_latest_price() if position.status == "OPEN" else 'N/A'
+
+ if isinstance(exit_price, (int, float)):
+ exit_price_str = f"${exit_price:.6f}"
+ else:
+ exit_price_str = str(exit_price)
+
+ # Format amount
+ amount = position.amount if hasattr(position, 'amount') else 0.1
+ amount_str = f"{amount:.4f} BTC"
+
+ # Format PnL
+ if position.status == "CLOSED":
+ pnl = position.pnl if hasattr(position, 'pnl') else 0
+ pnl_str = f"${pnl:.2f}"
+ pnl_color = '#00FF00' if pnl >= 0 else '#FF0000'
+ elif position.status == "OPEN" and position.action == "BUY":
+ # Calculate unrealized PnL for open positions
+ if isinstance(exit_price, (int, float)) and isinstance(entry_price, (int, float)):
+ unrealized_pnl = (exit_price - entry_price) * amount
+ pnl_str = f"${unrealized_pnl:.2f} (unrealized)"
+ pnl_color = '#00FF00' if unrealized_pnl >= 0 else '#FF0000'
+ else:
+ pnl_str = 'N/A'
+ pnl_color = '#FFFFFF'
+ else:
+ pnl_str = 'N/A'
+ pnl_color = '#FFFFFF'
+
+ # Set action/status color and text
+ if position.status == 'OPEN':
+ status_color = '#00AAFF' # Blue for open positions
+ status_text = f"OPEN ({position.action})"
+ elif position.status == 'CLOSED':
+ if hasattr(position, 'pnl') and isinstance(position.pnl, (int, float)):
+ status_color = '#00FF00' if position.pnl >= 0 else '#FF0000' # Green/Red based on profit
+ else:
+ status_color = '#FFCC00' # Yellow if PnL unknown
+ status_text = "CLOSED"
+ else:
+ status_color = '#00FF00' if position.action == 'BUY' else '#FF0000'
+ status_text = position.action
+
+ # Create table row with more compact styling
+ position_rows.append(html.Tr([
+ html.Td(status_text, style={'color': status_color, 'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}),
+ html.Td(amount_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}),
+ html.Td(entry_price_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}),
+ html.Td(exit_price_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}),
+ html.Td(pnl_str, style={'color': pnl_color, 'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}),
+ html.Td(time_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'})
+ ]))
+
+ return position_rows
+
+ def _initialize_chart(self):
+ """Initialize the chart figure"""
+ # Create a figure with subplots for price and volume
+ self.fig = make_subplots(
+ rows=2,
+ cols=1,
+ shared_xaxes=True,
+ vertical_spacing=0.03,
+ row_heights=[0.8, 0.2],
+ subplot_titles=(f"{self.symbol} Price Chart", "Volume")
+ )
+
+ # Set up initial empty traces
+ self.fig.add_trace(
+ go.Candlestick(
+ x=[], open=[], high=[], low=[], close=[],
+ name='Price',
+ increasing={'line': {'color': '#26A69A', 'width': 1}, 'fillcolor': '#26A69A'},
+ decreasing={'line': {'color': '#EF5350', 'width': 1}, 'fillcolor': '#EF5350'}
+ ),
+ row=1, col=1
+ )
+
+ # Add volume trace
+ self.fig.add_trace(
+ go.Bar(
+ x=[], y=[],
+ name='Volume',
+ marker={'color': '#888888'}
+ ),
+ row=2, col=1
+ )
+
+ # Add empty traces for buy/sell markers
+ self.fig.add_trace(
+ go.Scatter(
+ x=[], y=[],
+ mode='markers',
+ name='BUY',
+ marker=dict(
+ symbol='triangle-up',
+ size=12,
+ color='rgba(0,255,0,0.8)',
+ line=dict(width=1, color='darkgreen')
+ ),
+ showlegend=True
+ ),
+ row=1, col=1
+ )
+
+ self.fig.add_trace(
+ go.Scatter(
+ x=[], y=[],
+ mode='markers',
+ name='SELL',
+ marker=dict(
+ symbol='triangle-down',
+ size=12,
+ color='rgba(255,0,0,0.8)',
+ line=dict(width=1, color='darkred')
+ ),
+ showlegend=True
+ ),
+ row=1, col=1
+ )
+
+ # Update layout
+ self.fig.update_layout(
+ title=f"{self.symbol} Real-Time Trading Chart",
+ title_x=0.5,
+ template='plotly_dark',
+ paper_bgcolor='rgba(0,0,0,0)',
+ plot_bgcolor='rgba(25,25,50,1)',
+ height=800,
+ xaxis_rangeslider_visible=False,
+ legend=dict(
+ orientation="h",
+ yanchor="bottom",
+ y=1.02,
+ xanchor="center",
+ x=0.5
+ )
+ )
+
+ # Update axes styling
+ self.fig.update_xaxes(
+ showgrid=True,
+ gridwidth=1,
+ gridcolor='rgba(128,128,128,0.2)',
+ zeroline=False
+ )
+
+ self.fig.update_yaxes(
+ showgrid=True,
+ gridwidth=1,
+ gridcolor='rgba(128,128,128,0.2)',
+ zeroline=False
+ )
+
+ # Do an initial update to populate the chart
+ self._update_chart()
+
+ def _update_chart(self):
+ """Update the chart with the latest data"""
+ try:
+ # Get candlesticks data for the current interval
+ df = self.tick_storage.get_candles(interval=self.current_interval)
+
+ if df is None or df.empty:
+ logger.warning(f"No candle data available for {self.current_interval}")
+ return
+
+ # Limit the number of candles to display (show 500 for context)
+ df = df.tail(500)
+
+ # Update candlestick data
+ self.fig.update_traces(
+ x=df.index,
+ open=df['open'],
+ high=df['high'],
+ low=df['low'],
+ close=df['close'],
+ selector=dict(type='candlestick')
+ )
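+
+ # Note: update_traces(selector=...) applies to every trace matching the
+ # selector; this works because the figure holds exactly one candlestick
+ # trace, one bar trace, and uniquely named 'BUY'/'SELL' scatter traces
+ # (see _initialize_chart above).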
+
+ # Update volume bars with colors based on price movement
+ colors = ['rgba(0,255,0,0.5)' if c >= o else 'rgba(255,0,0,0.5)'
+ for o, c in zip(df['open'], df['close'])]
+
+ self.fig.update_traces(
+ x=df.index,
+ y=df['volume'],
+ marker_color=colors,
+ selector=dict(type='bar')
+ )
+
+ # Calculate y-axis range with padding for better visibility
+ if len(df) > 0:
+ low_min = df['low'].min()
+ high_max = df['high'].max()
+ price_range = high_max - low_min
+ y_min = low_min - (price_range * 0.05) # 5% padding below
+ y_max = high_max + (price_range * 0.05) # 5% padding above
+
+ # Update y-axis range
+ self.fig.update_yaxes(range=[y_min, y_max], row=1, col=1)
+
+ # Update Buy/Sell markers
+ if hasattr(self, 'positions') and self.positions:
+ # Collect buy and sell points
+ buy_times = []
+ buy_prices = []
+ sell_times = []
+ sell_prices = []
+
+ for position in self.positions:
+ # Handle buy trades
+ if position.action == "BUY":
+ buy_times.append(position.timestamp)
+ buy_prices.append(position.entry_price)
+
+ # Handle sell trades or closed positions
+ if position.status == "CLOSED" and hasattr(position, 'exit_timestamp') and hasattr(position, 'exit_price'):
+ sell_times.append(position.exit_timestamp)
+ sell_prices.append(position.exit_price)
+
+ # Update buy markers trace
+ self.fig.update_traces(
+ x=buy_times,
+ y=buy_prices,
+ selector=dict(name='BUY')
+ )
+
+ # Update sell markers trace
+ self.fig.update_traces(
+ x=sell_times,
+ y=sell_prices,
+ selector=dict(name='SELL')
+ )
+
+ # Update chart title with the current interval
+ self.fig.update_layout(
+ title=f"{self.symbol} Real-Time Chart ({self.current_interval})"
+ )
+
+ except Exception as e:
+ logger.error(f"Error in _update_chart: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+
async def main():
global charts # Make charts globally accessible for NN integration
- symbols = ["ETH/USDT", "BTC/USDT"]
+ symbols = ["ETH/USDT"]
logger.info(f"Starting application for symbols: {symbols}")
# Initialize neural network if enabled
diff --git a/realtime_old.py b/realtime_old.py
new file mode 100644
index 0000000..0bfa47f
--- /dev/null
+++ b/realtime_old.py
@@ -0,0 +1,3083 @@
+import asyncio
+import json
+import logging
+
+from typing import Dict, List, Optional
+import websockets
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import dash
+from dash import html, dcc
+from dash.dependencies import Input, Output
+import pandas as pd
+import numpy as np
+from collections import deque
+import time
+from threading import Thread
+import requests
+import os
+from datetime import datetime, timedelta
+import pytz
+import tzlocal
+import threading
+import random
+import dash_bootstrap_components as dbc
+
+# Configure logging with more detailed format
+logging.basicConfig(
+ level=logging.INFO, # Set to logging.DEBUG for more detailed logs
+ format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
+ handlers=[
+ logging.StreamHandler(),
+ logging.FileHandler('realtime_chart.log')
+ ]
+)
+logger = logging.getLogger(__name__)
+
+# Neural Network integration (conditional import)
+NN_ENABLED = os.environ.get('ENABLE_NN_MODELS', '0') == '1'
+nn_orchestrator = None
+nn_inference_thread = None
+
+if NN_ENABLED:
+ try:
+ import sys
+ # Add project root to sys.path if needed
+ project_root = os.path.dirname(os.path.abspath(__file__))
+ if project_root not in sys.path:
+ sys.path.append(project_root)
+
+ from NN.main import NeuralNetworkOrchestrator
+ logger.info("Neural Network module enabled")
+ except ImportError as e:
+ logger.warning(f"Failed to import Neural Network module, disabling NN features: {str(e)}")
+ NN_ENABLED = False
+
+# NN utility functions
+def setup_neural_network():
+ """Initialize the neural network components if enabled"""
+ global nn_orchestrator, NN_ENABLED
+
+ if not NN_ENABLED:
+ return False
+
+ try:
+ # Get configuration from environment variables or use defaults
+ symbol = os.environ.get('NN_SYMBOL', 'BTC/USDT')
+ timeframes = os.environ.get('NN_TIMEFRAMES', '1m,5m,1h,4h,1d').split(',')
+ output_size = int(os.environ.get('NN_OUTPUT_SIZE', '3')) # 3 for BUY/HOLD/SELL
+
+ # Configure the orchestrator
+ config = {
+ 'symbol': symbol,
+ 'timeframes': timeframes,
+ 'window_size': int(os.environ.get('NN_WINDOW_SIZE', '20')),
+ 'n_features': 5, # OHLCV
+ 'output_size': output_size,
+ 'model_dir': 'NN/models/saved',
+ 'data_dir': 'NN/data'
+ }
+
+ # Initialize the orchestrator
+ logger.info(f"Initializing Neural Network Orchestrator with config: {config}")
+ nn_orchestrator = NeuralNetworkOrchestrator(config)
+
+ # Start inference thread if enabled
+ inference_interval = int(os.environ.get('NN_INFERENCE_INTERVAL', '60'))
+ if inference_interval > 0:
+ start_nn_inference_thread(inference_interval)
+
+ return True
+ except Exception as e:
+ logger.error(f"Error setting up neural network: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ NN_ENABLED = False
+ return False
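+
+# Usage sketch: the NN pipeline is opt-in via environment variables; an
+# illustrative invocation (values are examples, not defaults) would be:
+#
+# ENABLE_NN_MODELS=1 NN_SYMBOL=ETH/USDT NN_TIMEFRAMES=1m,1h \
+# NN_INFERENCE_INTERVAL=60 python realtime_old.py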
+
+def start_nn_inference_thread(interval_seconds):
+ """Start a background thread to periodically run inference with the neural network"""
+ global nn_inference_thread
+
+ if not NN_ENABLED or nn_orchestrator is None:
+ logger.warning("Cannot start inference thread - Neural Network not enabled or initialized")
+ return False
+
+ def inference_worker():
+ """Worker function for the inference thread"""
+ model_type = os.environ.get('NN_MODEL_TYPE', 'cnn')
+ timeframe = os.environ.get('NN_TIMEFRAME', '1h')
+
+ logger.info(f"Starting neural network inference thread with {interval_seconds}s interval")
+ logger.info(f"Using model type: {model_type}, timeframe: {timeframe}")
+
+ # Wait a bit for charts to initialize
+ time.sleep(5)
+
+ # Track active charts
+ active_charts = []
+
+ while True:
+ try:
+ # Find active charts if we don't have them yet
+ if not active_charts and 'charts' in globals():
+ active_charts = globals()['charts']
+ logger.info(f"Found {len(active_charts)} active charts for NN signals")
+
+ # Run inference
+ result = nn_orchestrator.run_inference_pipeline(
+ model_type=model_type,
+ timeframe=timeframe
+ )
+
+ if result:
+ # Log the result
+ logger.info(f"Neural network inference result: {result}")
+
+ # Add signal to charts
+ if active_charts:
+ try:
+ if 'action' in result:
+ action = result['action']
+ timestamp = datetime.fromisoformat(result['timestamp'].replace('Z', '+00:00'))
+
+ # Get probability if available
+ probability = None
+ if 'probability' in result:
+ probability = result['probability']
+ elif 'probabilities' in result:
+ probability = result['probabilities'].get(action, None)
+
+ # Add signal to each chart
+ for chart in active_charts:
+ if hasattr(chart, 'add_nn_signal'):
+ chart.add_nn_signal(action, timestamp, probability)
+ except Exception as e:
+ logger.error(f"Error adding NN signal to chart: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+
+ # Sleep for the interval
+ time.sleep(interval_seconds)
+
+ except Exception as e:
+ logger.error(f"Error in inference thread: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ time.sleep(5) # Wait a bit before retrying
+
+ # Create and start the thread
+ nn_inference_thread = threading.Thread(target=inference_worker, daemon=True)
+ nn_inference_thread.start()
+
+ return True
+
+# Try to get local timezone, default to Sofia/EET if not available
+try:
+ local_timezone = tzlocal.get_localzone()
+ # Get timezone name safely
+ try:
+ # Handle both pytz timezones (.zone) and zoneinfo.ZoneInfo objects (.key)
+ if hasattr(local_timezone, 'zone'):
+ tz_name = local_timezone.zone
+ elif hasattr(local_timezone, 'key'):
+ tz_name = local_timezone.key
+ else:
+ tz_name = str(local_timezone)
+ except Exception:
+ tz_name = "Local"
+ logger.info(f"Detected local timezone: {local_timezone} ({tz_name})")
+except Exception as e:
+ logger.warning(f"Could not detect local timezone: {str(e)}. Defaulting to Sofia/EET")
+ local_timezone = pytz.timezone('Europe/Sofia')
+ tz_name = "Europe/Sofia"
+
+def convert_to_local_time(timestamp):
+ """Convert timestamp to local timezone"""
+ try:
+ if isinstance(timestamp, pd.Timestamp):
+ dt = timestamp.to_pydatetime()
+ elif isinstance(timestamp, np.datetime64):
+ dt = pd.Timestamp(timestamp).to_pydatetime()
+ elif isinstance(timestamp, str):
+ dt = pd.to_datetime(timestamp).to_pydatetime()
+ else:
+ dt = timestamp
+
+ # If datetime is naive (no timezone), assume it's UTC
+ if dt.tzinfo is None:
+ dt = dt.replace(tzinfo=pytz.UTC)
+
+ # Convert to local timezone
+ local_dt = dt.astimezone(local_timezone)
+ return local_dt
+ except Exception as e:
+ logger.error(f"Error converting timestamp to local time: {str(e)}")
+ return timestamp
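+
+# Example (illustrative): a naive timestamp is assumed to be UTC and shifted
+# into the detected zone; with local_timezone = Europe/Sofia (UTC+2 in winter):
+# convert_to_local_time(pd.Timestamp("2024-01-01 12:00:00"))
+# -> 2024-01-01 14:00:00+02:00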
+
+class TradeTickStorage:
+ """Storage for trade ticks with a maximum age limit"""
+
+ def __init__(self, symbol: str = None, max_age_seconds: int = 3600, use_sample_data: bool = True, log_no_ticks_warning: bool = False): # 1 hour by default, up from 30 min
+ """Initialize the tick storage
+
+ Args:
+ symbol: Trading symbol
+ max_age_seconds: Maximum age for ticks to be stored
+ use_sample_data: If True, generate sample ticks when no real ticks available
+ log_no_ticks_warning: If True, log a warning when no ticks are available
+ """
+ self.symbol = symbol
+ self.ticks = []
+ self.max_age_seconds = max_age_seconds
+ self.last_cleanup_time = time.time()
+ self.cleanup_interval = 60 # run cleanup every 60 seconds
+ self.cache_dir = "cache"
+ self.use_sample_data = use_sample_data
+ self.log_no_ticks_warning = log_no_ticks_warning
+ self.last_sample_price = 83000.0 # Starting price for sample data (for BTC)
+ self.last_sample_time = time.time() * 1000 # Starting time for sample data
+ self.last_tick_time = 0 # Initialize last_tick_time attribute
+ self.tick_count = 0 # Initialize tick_count attribute
+
+ # Create cache directory if it doesn't exist
+ if not os.path.exists(self.cache_dir):
+ os.makedirs(self.cache_dir)
+
+ # Try to load cached ticks
+ self._load_cached_ticks()
+
+ logger.info(f"Initialized TradeTickStorage for {symbol} with max age: {max_age_seconds} seconds, cleanup interval: {self.cleanup_interval} seconds")
+
+ def add_tick(self, tick: Dict):
+ """Add a tick to storage
+
+ Args:
+ tick: Tick data dict with fields:
+ price: The price
+ volume: The volume
+ timestamp: Timestamp in milliseconds
+ """
+ if not tick:
+ return
+
+ # Check if we need to generate a timestamp
+ if 'timestamp' not in tick:
+ tick['timestamp'] = int(time.time() * 1000) # Current time in ms
+
+ # Ensure timestamp is an integer (milliseconds since epoch)
+ if not isinstance(tick['timestamp'], int):
+ try:
+ # Try to convert from float or string
+ tick['timestamp'] = int(float(tick['timestamp']))
+ except (ValueError, TypeError):
+ # If conversion fails, use current time
+ tick['timestamp'] = int(time.time() * 1000)
+
+ # Set default volume if not present
+ if 'volume' not in tick:
+ tick['volume'] = 0.01 # Default small volume
+
+ # Add tick to storage with a copy to avoid mutation
+ self.ticks.append(tick.copy())
+
+ # Keep track of latest tick for stats
+ self.last_tick_time = max(self.last_tick_time, tick['timestamp'])
+
+ # Cache every 100 ticks to avoid data loss
+ self.tick_count += 1
+ if self.tick_count % 100 == 0:
+ self._cache_ticks()
+
+ # Periodically clean up old ticks
+ if self.tick_count % 1000 == 0:
+ self._cleanup()
+
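+ # Usage sketch (hypothetical values): ticks normally arrive from the
+ # websocket handler; missing 'timestamp'/'volume' fields are defaulted above.
+ #
+ # storage = TradeTickStorage(symbol="BTC/USDT")
+ # storage.add_tick({'price': 83000.0, 'volume': 0.05,
+ # 'timestamp': int(time.time() * 1000)})
+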
+ def _cleanup(self):
+ """Remove ticks older than max_age_seconds"""
+ # Get current time in milliseconds
+ now = int(time.time() * 1000)
+
+ # Remove old ticks
+ cutoff = now - (self.max_age_seconds * 1000)
+ original_count = len(self.ticks)
+ self.ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff]
+ removed = original_count - len(self.ticks)
+
+ if removed > 0:
+ logger.debug(f"Cleaned up {removed} old ticks, remaining: {len(self.ticks)}")
+
+ def _load_cached_ticks(self):
+ """Load cached ticks from disk on startup"""
+ # Create symbol-specific filename
+ symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower()
+ cache_file = os.path.join(self.cache_dir, f"{symbol_safe}_recent_ticks.csv")
+
+ if not os.path.exists(cache_file):
+ logger.info(f"No cached ticks found for {self.symbol}")
+ return
+
+ try:
+ # Check if cache is fresh (less than 10 minutes old)
+ file_age = time.time() - os.path.getmtime(cache_file)
+ if file_age > 600: # 10 minutes
+ logger.info(f"Cached ticks for {self.symbol} are too old ({file_age:.1f}s), skipping")
+ return
+
+ # Load cached ticks
+ tick_df = pd.read_csv(cache_file)
+ if tick_df.empty:
+ logger.info(f"Cached ticks file for {self.symbol} is empty")
+ return
+
+ # Convert to list of dicts and add to storage
+ cached_ticks = tick_df.to_dict('records')
+ self.ticks.extend(cached_ticks)
+ logger.info(f"Loaded {len(cached_ticks)} cached ticks for {self.symbol} from {cache_file}")
+ except Exception as e:
+ logger.error(f"Error loading cached ticks for {self.symbol}: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+
+ def _cache_ticks(self):
+ """Cache recent ticks to disk"""
+ if not self.ticks:
+ return
+
+ # Get ticks from last 10 minutes
+ now = int(time.time() * 1000) # Current time in ms
+ cutoff = now - (600 * 1000) # 10 minutes in ms
+ recent_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff]
+
+ if not recent_ticks:
+ logger.debug("No recent ticks to cache")
+ return
+
+ # Create symbol-specific filename
+ symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower()
+ cache_file = os.path.join(self.cache_dir, f"{symbol_safe}_recent_ticks.csv")
+
+ # Save to disk
+ try:
+ tick_df = pd.DataFrame(recent_ticks)
+ tick_df.to_csv(cache_file, index=False)
+ logger.info(f"Cached {len(recent_ticks)} recent ticks for {self.symbol} to {cache_file}")
+ except Exception as e:
+ logger.error(f"Error caching ticks: {str(e)}")
+
+ def get_latest_price(self) -> Optional[float]:
+ """Get the latest price from the most recent tick"""
+ if self.ticks:
+ return self.ticks[-1].get('price')
+ return None
+
+ def get_price_stats(self) -> Dict:
+ """Get stats about the prices in storage"""
+ if not self.ticks:
+ return {
+ 'min': None,
+ 'max': None,
+ 'latest': None,
+ 'count': 0,
+ 'age_seconds': 0
+ }
+
+ prices = [tick['price'] for tick in self.ticks]
+ latest_timestamp = self.ticks[-1]['timestamp']
+ oldest_timestamp = self.ticks[0]['timestamp']
+
+ return {
+ 'min': min(prices),
+ 'max': max(prices),
+ 'latest': prices[-1],
+ 'count': len(prices),
+ 'age_seconds': (latest_timestamp - oldest_timestamp) / 1000
+ }
+
+ def get_ticks_as_df(self) -> pd.DataFrame:
+ """Return ticks as a DataFrame"""
+ if not self.ticks:
+ logger.warning("No ticks available for DataFrame conversion")
+ return pd.DataFrame()
+
+ # Ensure we have fresh data
+ self._cleanup()
+
+ # Create a new list from ticks to avoid modifying the original data
+ ticks_data = self.ticks.copy()
+
+ # Ensure we have the latest ticks at the end of the DataFrame
+ ticks_data.sort(key=lambda x: x['timestamp'])
+
+ df = pd.DataFrame(ticks_data)
+ if not df.empty:
+ logger.debug(f"Converting timestamps for {len(df)} ticks")
+ # Ensure timestamp column exists
+ if 'timestamp' not in df.columns:
+ logger.error("Tick data missing timestamp column")
+ return pd.DataFrame()
+
+ # Check timestamp datatype before conversion
+ sample_ts = df['timestamp'].iloc[0] if len(df) > 0 else None
+ logger.debug(f"Sample timestamp before conversion: {sample_ts}, type: {type(sample_ts)}")
+
+ # Convert timestamps to datetime
+ try:
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
+ logger.debug(f"Timestamps converted to datetime successfully")
+ if len(df) > 0:
+ logger.debug(f"Sample converted timestamp: {df['timestamp'].iloc[0]}")
+ except Exception as e:
+ logger.error(f"Error converting timestamps: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ return pd.DataFrame()
+ return df
+
+ def get_candles(self, interval_seconds: int = 1, start_time_ms: int = None, end_time_ms: int = None) -> pd.DataFrame:
+ """Generate candlestick data from ticks
+
+ Args:
+ interval_seconds: Interval in seconds for each candle
+ start_time_ms: Start time in milliseconds
+ end_time_ms: End time in milliseconds
+
+ Returns:
+ DataFrame with candlestick data
+ """
+ # Get filtered ticks
+ ticks = self.get_ticks_from_time(start_time_ms, end_time_ms)
+
+ if not ticks:
+ if self.use_sample_data:
+ # Generate multiple sample ticks to create several candles
+ current_time = int(time.time() * 1000)
+ sample_ticks = []
+
+ # Generate 20 ticks spanning roughly the past 10 intervals (two per interval)
+ for i in range(20):
+ # Base price with some trend
+ base_price = self.last_sample_price * (1 + 0.0001 * (10 - i))
+
+ # Add some randomness to the price
+ random_factor = random.uniform(-0.002, 0.002) # Small random change
+ tick_price = base_price * (1 + random_factor)
+
+ # Create timestamp with appropriate offset
+ tick_time = current_time - (i * interval_seconds * 1000 // 2)
+
+ sample_tick = {
+ 'price': tick_price,
+ 'volume': random.uniform(0.01, 0.5),
+ 'timestamp': tick_time,
+ 'is_sample': True
+ }
+
+ sample_ticks.append(sample_tick)
+
+ # Update the last sample values
+ self.last_sample_price = sample_ticks[0]['price']
+ self.last_sample_time = sample_ticks[0]['timestamp']
+
+ # Add the sample ticks in chronological order
+ for tick in sorted(sample_ticks, key=lambda x: x['timestamp']):
+ self.add_tick(tick)
+
+ # Try again with the new ticks
+ ticks = self.get_ticks_from_time(start_time_ms, end_time_ms)
+
+ if not ticks and self.log_no_ticks_warning:
+ logger.warning("Still no ticks available after adding sample data")
+ elif self.log_no_ticks_warning:
+ logger.warning("No ticks available for candle formation")
+ return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
+ else:
+ return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
+
+ # Ensure ticks are up to date
+ try:
+ self._cleanup()
+ except Exception as cleanup_error:
+ logger.error(f"Error cleaning up ticks: {str(cleanup_error)}")
+
+ df = pd.DataFrame(ticks)
+ if df.empty:
+ logger.warning("Tick DataFrame is empty after filtering/conversion")
+ return pd.DataFrame()
+
+ logger.info(f"Preparing to create candles from {len(df)} ticks with {interval_seconds}s interval")
+
+ # First, ensure all required columns exist
+ required_columns = ['timestamp', 'price', 'volume']
+ for col in required_columns:
+ if col not in df.columns:
+ logger.error(f"Required column '{col}' missing from tick data")
+ return pd.DataFrame()
+
+ # Make sure DataFrame has no duplicated timestamps before setting index
+ try:
+ if 'timestamp' in df.columns:
+ # Check for duplicate timestamps
+ duplicate_count = df['timestamp'].duplicated().sum()
+ if duplicate_count > 0:
+ logger.warning(f"Found {duplicate_count} duplicate timestamps, keeping the last occurrence")
+ # Keep the last occurrence of each timestamp
+ df = df.drop_duplicates(subset='timestamp', keep='last')
+
+ # Convert timestamp to datetime if it's not already
+ if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
+ logger.debug("Converting timestamp to datetime")
+ # Try multiple approaches to convert timestamps
+ try:
+ # First, try to convert from milliseconds (integer timestamps)
+ if pd.api.types.is_integer_dtype(df['timestamp']):
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
+ else:
+ # Otherwise try standard conversion
+ df['timestamp'] = pd.to_datetime(df['timestamp'])
+ except Exception as conv_error:
+ logger.error(f"Error converting timestamps to datetime: {str(conv_error)}")
+ # Try a fallback approach
+ try:
+ # Fallback for integer timestamps
+ if df['timestamp'].iloc[0] > 1000000000000: # Check if milliseconds timestamp
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
+ else: # Otherwise assume seconds
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='s')
+ except Exception as fallback_error:
+ logger.error(f"Fallback timestamp conversion failed: {str(fallback_error)}")
+ return pd.DataFrame()
+
+ # Use timestamp column for resampling
+ df = df.set_index('timestamp')
+ except Exception as prep_error:
+ logger.error(f"Error preprocessing DataFrame for resampling: {str(prep_error)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ return pd.DataFrame()
+
+ # Create interval string for resampling - use 's' instead of deprecated 'S'
+ interval_str = f'{interval_seconds}s'
+
+ # Resample to create OHLCV candles with multiple fallback options
+ logger.debug(f"Resampling with interval: {interval_str}")
+
+ candles = None
+
+ # First attempt - individual column resampling
+ try:
+ # Check that price column exists and has enough data
+ if 'price' not in df.columns:
+ raise ValueError("Price column missing from DataFrame")
+
+ if len(df) < 2:
+ logger.warning("Not enough data points for resampling, using direct data")
+ # For single data point, create a single candle
+ if len(df) == 1:
+ price_val = df['price'].iloc[0]
+ volume_val = df['volume'].iloc[0] if 'volume' in df.columns else 0
+ timestamp_val = df.index[0]
+
+ candles = pd.DataFrame({
+ 'timestamp': [timestamp_val],
+ 'open': [price_val],
+ 'high': [price_val],
+ 'low': [price_val],
+ 'close': [price_val],
+ 'volume': [volume_val]
+ })
+ return candles
+ else:
+ # No data
+ return pd.DataFrame()
+
+ # Resample and aggregate each column separately
+ open_df = df['price'].resample(interval_str).first()
+ high_df = df['price'].resample(interval_str).max()
+ low_df = df['price'].resample(interval_str).min()
+ close_df = df['price'].resample(interval_str).last()
+ volume_df = df['volume'].resample(interval_str).sum()
+
+ # Check for length mismatches before combining
+ expected_length = len(open_df)
+ if (len(high_df) != expected_length or
+ len(low_df) != expected_length or
+ len(close_df) != expected_length or
+ len(volume_df) != expected_length):
+ logger.warning("Length mismatch in resampled columns, falling back to alternative method")
+ raise ValueError("Length mismatch")
+
+ # Combine into a single DataFrame
+ candles = pd.DataFrame({
+ 'open': open_df,
+ 'high': high_df,
+ 'low': low_df,
+ 'close': close_df,
+ 'volume': volume_df
+ })
+ logger.debug(f"Successfully created {len(candles)} candles with individual column resampling")
+ except Exception as resample_error:
+ logger.error(f"Error in individual column resampling: {str(resample_error)}")
+
+ # Second attempt - built-in agg method
+ try:
+ logger.debug("Trying fallback resampling method with agg()")
+ candles = df.resample(interval_str).agg({
+ 'price': ['first', 'max', 'min', 'last'],
+ 'volume': 'sum'
+ })
+ # Flatten MultiIndex columns
+ candles.columns = ['open', 'high', 'low', 'close', 'volume']
+ logger.debug(f"Successfully created {len(candles)} candles with agg() method")
+ except Exception as agg_error:
+ logger.error(f"Error in agg() resampling: {str(agg_error)}")
+
+ # Third attempt - manual candle construction
+ try:
+ logger.debug("Trying manual candle construction method")
+ resampler = df.resample(interval_str)
+ candle_data = []
+
+ for name, group in resampler:
+ if not group.empty:
+ candle = {
+ 'timestamp': name,
+ 'open': group['price'].iloc[0],
+ 'high': group['price'].max(),
+ 'low': group['price'].min(),
+ 'close': group['price'].iloc[-1],
+ 'volume': group['volume'].sum() if 'volume' in group.columns else 0
+ }
+ candle_data.append(candle)
+
+ if candle_data:
+ candles = pd.DataFrame(candle_data)
+ logger.debug(f"Successfully created {len(candles)} candles with manual method")
+ else:
+ logger.warning("No candles created with manual method")
+ return pd.DataFrame()
+ except Exception as manual_error:
+ logger.error(f"Error in manual candle construction: {str(manual_error)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ return pd.DataFrame()
+
+ # Ensure the result isn't empty
+ if candles is None or candles.empty:
+ logger.warning("No candles were created after all resampling attempts")
+ return pd.DataFrame()
+
+ # Reset index to get timestamp as column
+ try:
+ candles = candles.reset_index()
+ except Exception as reset_error:
+ logger.error(f"Error resetting index: {str(reset_error)}")
+ # Try to create a new DataFrame with the timestamp index as a column
+ try:
+ timestamp_col = candles.index.to_list()
+ candles_dict = candles.to_dict('list')
+ candles_dict['timestamp'] = timestamp_col
+ candles = pd.DataFrame(candles_dict)
+ except Exception as fallback_error:
+ logger.error(f"Error in fallback index reset: {str(fallback_error)}")
+ return pd.DataFrame()
+
+ # Ensure no NaN values
+ try:
+ nan_count_before = candles.isna().sum().sum()
+ if nan_count_before > 0:
+ logger.warning(f"Found {nan_count_before} NaN values in candles, dropping them")
+
+ candles = candles.dropna()
+ except Exception as nan_error:
+ logger.error(f"Error handling NaN values: {str(nan_error)}")
+ # Try to fill NaN values instead of dropping
+ try:
+                candles = candles.ffill().bfill()  # fillna(method=...) is deprecated in modern pandas
+            except Exception:
+                pass
+
+ logger.debug(f"Generated {len(candles)} candles from {len(df)} ticks")
+ return candles
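+
+    # Illustrative sketch (hypothetical, not called anywhere): on a clean,
+    # timestamp-indexed tick DataFrame the happy path above collapses to one
+    # resample/agg call, mirroring the second fallback:
+    #
+    #   ticks = df  # index: DatetimeIndex, columns: 'price', 'volume'
+    #   ohlcv = ticks.resample('1s').agg({'price': ['first', 'max', 'min', 'last'],
+    #                                     'volume': 'sum'})
+    #   ohlcv.columns = ['open', 'high', 'low', 'close', 'volume']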
+
+ def get_candle_stats(self) -> Dict:
+ """Get statistics about cached candles for different intervals"""
+ stats = {}
+
+ # Define intervals to check
+ intervals = [1, 5, 15, 60, 300, 900, 3600]
+
+ for interval in intervals:
+ candles = self.get_candles(interval_seconds=interval)
+ count = len(candles) if not candles.empty else 0
+
+ # Get time range if we have candles
+ time_range = None
+ if count > 0:
+ try:
+ start_time = candles['timestamp'].min()
+ end_time = candles['timestamp'].max()
+ if isinstance(start_time, pd.Timestamp):
+ start_time = start_time.strftime('%Y-%m-%d %H:%M:%S')
+ if isinstance(end_time, pd.Timestamp):
+ end_time = end_time.strftime('%Y-%m-%d %H:%M:%S')
+ time_range = f"{start_time} to {end_time}"
+                except Exception:
+ time_range = "Unknown"
+
+ stats[f"{interval}s"] = {
+ 'count': count,
+ 'time_range': time_range
+ }
+
+ return stats
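+
+    # Shape of the returned mapping (values illustrative):
+    #   {'1s': {'count': 320, 'time_range': '2024-01-01 00:00:00 to 2024-01-01 00:05:20'}, ...}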
+
+ def get_ticks_from_time(self, start_time_ms: int = None, end_time_ms: int = None) -> List[Dict]:
+ """Get ticks within a specific time range
+
+ Args:
+ start_time_ms: Start time in milliseconds (None for no lower bound)
+ end_time_ms: End time in milliseconds (None for no upper bound)
+
+ Returns:
+ List of ticks within the time range
+ """
+ if not self.ticks:
+ return []
+
+ # Ensure ticks are updated
+ self._cleanup()
+
+ # Apply time filters if specified
+ filtered_ticks = self.ticks
+ if start_time_ms is not None:
+ filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] >= start_time_ms]
+ if end_time_ms is not None:
+ filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] <= end_time_ms]
+
+ logger.debug(f"Retrieved {len(filtered_ticks)} ticks from time range {start_time_ms} to {end_time_ms}")
+ return filtered_ticks
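+
+    # Example (illustrative; `storage` is a hypothetical populated instance):
+    #   now_ms = int(time.time() * 1000)
+    #   last_minute = storage.get_ticks_from_time(start_time_ms=now_ms - 60_000)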
+
+ def get_time_based_stats(self) -> Dict:
+ """Get statistics about the ticks organized by time periods
+
+ Returns:
+ Dictionary with statistics for different time periods
+ """
+ if not self.ticks:
+ return {
+ 'total_ticks': 0,
+ 'periods': {}
+ }
+
+ # Ensure ticks are updated
+ self._cleanup()
+
+ now = int(time.time() * 1000) # Current time in ms
+
+ # Define time periods to analyze
+ periods = {
+ '1min': now - (60 * 1000),
+ '5min': now - (5 * 60 * 1000),
+ '15min': now - (15 * 60 * 1000),
+ '30min': now - (30 * 60 * 1000)
+ }
+
+ stats = {
+ 'total_ticks': len(self.ticks),
+ 'oldest_tick': self.ticks[0]['timestamp'] if self.ticks else None,
+ 'newest_tick': self.ticks[-1]['timestamp'] if self.ticks else None,
+ 'time_span_seconds': (self.ticks[-1]['timestamp'] - self.ticks[0]['timestamp']) / 1000 if self.ticks else 0,
+ 'periods': {}
+ }
+
+ # Calculate stats for each period
+ for period_name, cutoff_time in periods.items():
+ period_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff_time]
+
+ if period_ticks:
+ prices = [tick['price'] for tick in period_ticks]
+ volumes = [tick.get('volume', 0) for tick in period_ticks]
+
+ period_stats = {
+ 'tick_count': len(period_ticks),
+ 'min_price': min(prices) if prices else None,
+ 'max_price': max(prices) if prices else None,
+ 'avg_price': sum(prices) / len(prices) if prices else None,
+ 'last_price': period_ticks[-1]['price'] if period_ticks else None,
+ 'total_volume': sum(volumes),
+ 'ticks_per_second': len(period_ticks) / (int(period_name[:-3]) * 60) if period_ticks else 0
+ }
+
+ stats['periods'][period_name] = period_stats
+
+ logger.debug(f"Generated time-based stats: {len(stats['periods'])} periods")
+ return stats
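+
+    # The returned dict looks roughly like (values illustrative):
+    #   {'total_ticks': 1200, 'oldest_tick': ..., 'newest_tick': ...,
+    #    'time_span_seconds': 1795.2,
+    #    'periods': {'1min': {'tick_count': 45, 'min_price': ..., ...}, ...}}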
+
+class CandlestickData:
+ def __init__(self, max_length: int = 300):
+ self.timestamps = deque(maxlen=max_length)
+ self.opens = deque(maxlen=max_length)
+ self.highs = deque(maxlen=max_length)
+ self.lows = deque(maxlen=max_length)
+ self.closes = deque(maxlen=max_length)
+ self.volumes = deque(maxlen=max_length)
+ self.current_candle = {
+ 'timestamp': None,
+ 'open': None,
+ 'high': None,
+ 'low': None,
+ 'close': None,
+ 'volume': 0
+ }
+ self.candle_interval = 1 # 1 second by default
+
+ def update_from_trade(self, trade: Dict):
+ timestamp = trade['timestamp']
+ price = trade['price']
+ volume = trade.get('volume', 0)
+
+ # Round timestamp to nearest candle interval
+ candle_timestamp = int(timestamp / (self.candle_interval * 1000)) * (self.candle_interval * 1000)
+
+ if self.current_candle['timestamp'] != candle_timestamp:
+ # Save current candle if it exists
+ if self.current_candle['timestamp'] is not None:
+ self.timestamps.append(self.current_candle['timestamp'])
+ self.opens.append(self.current_candle['open'])
+ self.highs.append(self.current_candle['high'])
+ self.lows.append(self.current_candle['low'])
+ self.closes.append(self.current_candle['close'])
+ self.volumes.append(self.current_candle['volume'])
+ logger.debug(f"New candle saved: {self.current_candle}")
+
+ # Start new candle
+ self.current_candle = {
+ 'timestamp': candle_timestamp,
+ 'open': price,
+ 'high': price,
+ 'low': price,
+ 'close': price,
+ 'volume': volume
+ }
+ logger.debug(f"New candle started: {self.current_candle}")
+ else:
+ # Update current candle
+ if self.current_candle['high'] is None or price > self.current_candle['high']:
+ self.current_candle['high'] = price
+ if self.current_candle['low'] is None or price < self.current_candle['low']:
+ self.current_candle['low'] = price
+ self.current_candle['close'] = price
+ self.current_candle['volume'] += volume
+ logger.debug(f"Updated current candle: {self.current_candle}")
+
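+    # Worked example of the bucketing above: with candle_interval=1 and a trade
+    # at timestamp 1700000000765 ms, int(1700000000765 / 1000) * 1000 floors it
+    # to 1700000000000, so every trade within that second lands in one candle.
+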
+ def get_dataframe(self) -> pd.DataFrame:
+ # Include current candle in the dataframe if it exists
+ timestamps = list(self.timestamps)
+ opens = list(self.opens)
+ highs = list(self.highs)
+ lows = list(self.lows)
+ closes = list(self.closes)
+ volumes = list(self.volumes)
+
+ if self.current_candle['timestamp'] is not None:
+ timestamps.append(self.current_candle['timestamp'])
+ opens.append(self.current_candle['open'])
+ highs.append(self.current_candle['high'])
+ lows.append(self.current_candle['low'])
+ closes.append(self.current_candle['close'])
+ volumes.append(self.current_candle['volume'])
+
+ df = pd.DataFrame({
+ 'timestamp': timestamps,
+ 'open': opens,
+ 'high': highs,
+ 'low': lows,
+ 'close': closes,
+ 'volume': volumes
+ })
+ if not df.empty:
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
+ return df
+
+class BinanceWebSocket:
+ """Binance WebSocket implementation for real-time tick data"""
+ def __init__(self, symbol: str):
+ self.symbol = symbol.replace('/', '').lower()
+ self.ws = None
+ self.running = False
+ self.reconnect_delay = 1
+ self.max_reconnect_delay = 60
+ self.message_count = 0
+
+ # Binance WebSocket configuration
+ self.ws_url = f"wss://stream.binance.com:9443/ws/{self.symbol}@trade"
+ logger.info(f"Initialized Binance WebSocket for symbol: {self.symbol}")
+
+ async def connect(self):
+ while True:
+ try:
+ logger.info(f"Attempting to connect to {self.ws_url}")
+ self.ws = await websockets.connect(self.ws_url)
+ logger.info("WebSocket connection established")
+
+ self.running = True
+ self.reconnect_delay = 1
+ logger.info(f"Successfully connected to Binance WebSocket for {self.symbol}")
+ return True
+ except Exception as e:
+ logger.error(f"WebSocket connection error: {str(e)}")
+ await asyncio.sleep(self.reconnect_delay)
+ self.reconnect_delay = min(self.reconnect_delay * 2, self.max_reconnect_delay)
+ continue
+
+ async def receive(self) -> Optional[Dict]:
+ if not self.ws:
+ return None
+
+ try:
+ message = await self.ws.recv()
+ self.message_count += 1
+
+ if self.message_count % 100 == 0: # Log every 100th message to avoid spam
+ logger.info(f"Received message #{self.message_count}")
+ logger.debug(f"Raw message: {message[:200]}...")
+
+ data = json.loads(message)
+
+ # Process trade data
+ if 'e' in data and data['e'] == 'trade':
+ trade_data = {
+ 'timestamp': data['T'], # Trade time
+ 'price': float(data['p']), # Price
+ 'volume': float(data['q']), # Quantity
+ 'type': 'trade'
+ }
+ logger.debug(f"Processed trade data: {trade_data}")
+ return trade_data
+
+ return None
+ except websockets.exceptions.ConnectionClosed:
+ logger.warning("WebSocket connection closed")
+ self.running = False
+ return None
+ except json.JSONDecodeError as e:
+ logger.error(f"JSON decode error: {str(e)}, message: {message[:200]}...")
+ return None
+ except Exception as e:
+ logger.error(f"Error receiving message: {str(e)}")
+ return None
+
+ async def close(self):
+ """Close the WebSocket connection"""
+ if self.ws:
+ await self.ws.close()
+ self.running = False
+ logger.info("WebSocket connection closed")
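+
+    # Minimal consume-loop sketch (illustrative; `handle` is a hypothetical
+    # callback, and reconnection after a drop is omitted):
+    #
+    #   async def stream_trades(symbol: str, handle):
+    #       ws = BinanceWebSocket(symbol)
+    #       await ws.connect()
+    #       while ws.running:
+    #           tick = await ws.receive()
+    #           if tick is not None:
+    #               handle(tick)
+    #       await ws.close()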
+
+class BinanceHistoricalData:
+ """Fetch historical candle data from Binance"""
+
+ def __init__(self):
+ self.base_url = "https://api.binance.com/api/v3/klines"
+ # Create a cache directory if it doesn't exist
+ self.cache_dir = os.path.join(os.getcwd(), "cache")
+ os.makedirs(self.cache_dir, exist_ok=True)
+ logger.info(f"Initialized BinanceHistoricalData with cache directory: {self.cache_dir}")
+
+    def _get_interval_string(self, interval_seconds: int) -> str:
+        """Convert interval seconds to Binance interval string"""
+        mapping = {
+            60: "1m", 300: "5m", 900: "15m", 1800: "30m",
+            3600: "1h", 14400: "4h", 86400: "1d"
+        }
+        if interval_seconds in mapping:
+            return mapping[interval_seconds]
+        # Default to 1m if not recognized
+        logger.warning(f"Unrecognized interval {interval_seconds}s, defaulting to 1m")
+        return "1m"
+
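+    # e.g. _get_interval_string(900) -> "15m"; an unsupported value such as 42
+    # logs a warning and falls back to "1m".
+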
+ def _get_cache_filename(self, symbol: str, interval: str) -> str:
+ """Generate cache filename for the symbol and interval"""
+ # Replace any slashes in symbol with underscore
+ safe_symbol = symbol.replace("/", "_")
+ return os.path.join(self.cache_dir, f"{safe_symbol}_{interval}_candles.csv")
+
+ def _load_from_cache(self, symbol: str, interval: str) -> Optional[pd.DataFrame]:
+ """Load candle data from cache if available and not expired"""
+ filename = self._get_cache_filename(symbol, interval)
+
+ if not os.path.exists(filename):
+ logger.debug(f"No cache file found for {symbol} {interval}")
+ return None
+
+        # Cache freshness window: 1 day for the "1d" interval, 1 hour for all others
+ file_age = time.time() - os.path.getmtime(filename)
+ max_age = 86400 if interval == "1d" else 3600 # 1 day for 1d, 1 hour for others
+
+ if file_age > max_age:
+ logger.debug(f"Cache for {symbol} {interval} is expired ({file_age:.1f}s old)")
+ return None
+
+ try:
+ df = pd.read_csv(filename)
+ # Convert timestamp string back to datetime
+ df['timestamp'] = pd.to_datetime(df['timestamp'])
+ logger.info(f"Loaded {len(df)} candles from cache for {symbol} {interval}")
+ return df
+ except Exception as e:
+ logger.error(f"Error loading from cache: {str(e)}")
+ return None
+
+ def _save_to_cache(self, df: pd.DataFrame, symbol: str, interval: str) -> bool:
+ """Save candle data to cache"""
+ if df.empty:
+ logger.warning(f"No data to cache for {symbol} {interval}")
+ return False
+
+ filename = self._get_cache_filename(symbol, interval)
+ try:
+ df.to_csv(filename, index=False)
+ logger.info(f"Cached {len(df)} candles for {symbol} {interval} to {filename}")
+ return True
+ except Exception as e:
+ logger.error(f"Error saving to cache: {str(e)}")
+ return False
+
+ def get_historical_candles(self, symbol: str, interval_seconds: int, limit: int = 500) -> pd.DataFrame:
+ """Get historical candle data for the specified symbol and interval"""
+ # Convert to Binance format
+ clean_symbol = symbol.replace("/", "")
+ interval = self._get_interval_string(interval_seconds)
+
+ # Try to load from cache first
+ cached_data = self._load_from_cache(symbol, interval)
+ if cached_data is not None and len(cached_data) >= limit:
+ return cached_data.tail(limit)
+
+ # Fetch from API if not cached or insufficient
+ try:
+ logger.info(f"Fetching {limit} historical candles for {symbol} ({interval}) from Binance API")
+
+ params = {
+ "symbol": clean_symbol,
+ "interval": interval,
+ "limit": limit
+ }
+
+            response = requests.get(self.base_url, params=params, timeout=10)  # timeout so a hung request can't stall the caller
+ response.raise_for_status() # Raise exception for HTTP errors
+
+ # Process the data
+ candles = response.json()
+
+ if not candles:
+ logger.warning(f"No candles returned from Binance for {symbol} {interval}")
+ return pd.DataFrame()
+
+ # Convert to DataFrame - Binance returns data in this format:
+ # [
+ # [
+ # 1499040000000, // Open time
+ # "0.01634790", // Open
+ # "0.80000000", // High
+ # "0.01575800", // Low
+ # "0.01577100", // Close
+ # "148976.11427815", // Volume
+ # ... // Ignore the rest
+ # ],
+ # ...
+ # ]
+
+ df = pd.DataFrame(candles, columns=[
+ "timestamp", "open", "high", "low", "close", "volume",
+ "close_time", "quote_asset_volume", "number_of_trades",
+ "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore"
+ ])
+
+ # Convert types
+ df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
+ for col in ["open", "high", "low", "close", "volume"]:
+ df[col] = df[col].astype(float)
+
+ # Keep only needed columns
+ df = df[["timestamp", "open", "high", "low", "close", "volume"]]
+
+ # Cache the results
+ self._save_to_cache(df, symbol, interval)
+
+ logger.info(f"Successfully fetched {len(df)} candles for {symbol} {interval}")
+ return df
+
+ except Exception as e:
+ logger.error(f"Error fetching historical data for {symbol} {interval}: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ return pd.DataFrame()
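+
+    # Example (illustrative):
+    #   hist = BinanceHistoricalData()
+    #   df_1m = hist.get_historical_candles("BTC/USDT", interval_seconds=60, limit=500)
+    #   # -> DataFrame with columns: timestamp, open, high, low, close, volume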
+
+
+class ExchangeWebSocket:
+ """Generic WebSocket interface for cryptocurrency exchanges"""
+ def __init__(self, symbol: str, exchange: str = "binance"):
+ self.symbol = symbol
+ self.exchange = exchange.lower()
+ self.ws = None
+
+ # Initialize the appropriate WebSocket implementation
+ if self.exchange == "binance":
+ self.ws = BinanceWebSocket(symbol)
+ elif self.exchange == "mexc":
+ self.ws = MEXCWebSocket(symbol)
+ else:
+ raise ValueError(f"Unsupported exchange: {exchange}")
+
+ async def connect(self):
+ """Connect to the exchange WebSocket"""
+ return await self.ws.connect()
+
+ async def receive(self) -> Optional[Dict]:
+ """Receive data from the WebSocket"""
+ return await self.ws.receive()
+
+ async def close(self):
+ """Close the WebSocket connection"""
+ await self.ws.close()
+
+ @property
+ def running(self):
+ """Check if the WebSocket is running"""
+ return self.ws.running if self.ws else False
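+
+    # Example (illustrative): the wrapper keeps call sites exchange-agnostic:
+    #   ws = ExchangeWebSocket("BTC/USDT", exchange="binance")
+    #   # await ws.connect(); tick = await ws.receive(); await ws.close()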
+
+class CandleCache:
+ def __init__(self, max_candles: int = 5000):
+ self.candles = {
+ '1s': deque(maxlen=max_candles),
+ '1m': deque(maxlen=max_candles),
+ '1h': deque(maxlen=max_candles),
+ '1d': deque(maxlen=max_candles)
+ }
+ logger.info(f"Initialized CandleCache with max candles: {max_candles}")
+
+ def add_candles(self, interval: str, new_candles: pd.DataFrame):
+ if interval in self.candles and not new_candles.empty:
+ # Convert DataFrame to list of dicts to avoid pandas issues
+ for _, row in new_candles.iterrows():
+ candle_dict = row.to_dict()
+ self.candles[interval].append(candle_dict)
+ logger.debug(f"Added {len(new_candles)} candles to {interval} cache")
+
+ def get_recent_candles(self, interval: str, count: int = 500) -> pd.DataFrame:
+ if interval in self.candles and self.candles[interval]:
+ # Convert deque to list of dicts first
+ all_candles = list(self.candles[interval])
+ # Check if we're requesting more candles than we have
+ if count > len(all_candles):
+ logger.debug(f"Requested {count} candles, but only have {len(all_candles)} for {interval}")
+ count = len(all_candles)
+
+ recent_candles = all_candles[-count:]
+ logger.debug(f"Returning {len(recent_candles)} recent candles for {interval} (requested {count})")
+
+ # Create DataFrame and ensure timestamp is datetime type
+ df = pd.DataFrame(recent_candles)
+ if not df.empty and 'timestamp' in df.columns:
+ try:
+ if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
+ df['timestamp'] = pd.to_datetime(df['timestamp'])
+ except Exception as e:
+ logger.warning(f"Error converting timestamps in get_recent_candles: {str(e)}")
+
+ return df
+
+ logger.debug(f"No candles available for {interval}")
+ return pd.DataFrame()
+
+ def update_cache(self, interval: str, new_candles: pd.DataFrame):
+ """
+ Update the candle cache for a specific timeframe with new candles
+
+ Args:
+ interval: The timeframe interval ('1s', '1m', '1h', '1d')
+ new_candles: DataFrame with new candles to add to the cache
+ """
+ if interval not in self.candles:
+ logger.warning(f"Invalid interval {interval} for cache update")
+ return
+
+ if new_candles is None or new_candles.empty:
+ logger.debug(f"No new candles to update {interval} cache")
+ return
+
+ # Check if timestamp column exists
+ if 'timestamp' not in new_candles.columns:
+ logger.warning(f"No timestamp column in new candles for {interval}")
+ return
+
+ try:
+ # Ensure timestamp is datetime for proper comparison
+ try:
+ if not pd.api.types.is_datetime64_any_dtype(new_candles['timestamp']):
+ logger.debug(f"Converting timestamps to datetime for {interval}")
+ new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp'])
+ except Exception as e:
+ logger.error(f"Error converting timestamps: {str(e)}")
+ # Try a different approach
+ try:
+ new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp'], errors='coerce')
+ # Drop any rows where conversion failed
+ new_candles = new_candles.dropna(subset=['timestamp'])
+ if new_candles.empty:
+ logger.warning(f"All timestamps conversion failed for {interval}")
+ return
+ except Exception as e2:
+ logger.error(f"Second attempt to convert timestamps failed: {str(e2)}")
+ return
+
+ # Create a copy to avoid modifying the original
+ new_candles_copy = new_candles.copy()
+
+ # If we have no candles in cache, add all new candles
+ if not self.candles[interval]:
+ logger.debug(f"No existing candles for {interval}, adding all {len(new_candles_copy)} candles")
+ self.add_candles(interval, new_candles_copy)
+ return
+
+ # Get the timestamp from the last cached candle
+ last_cached_candle = self.candles[interval][-1]
+ if not isinstance(last_cached_candle, dict):
+ logger.warning(f"Last cached candle is not a dictionary for {interval}")
+ last_cached_candle = {'timestamp': None}
+
+ if 'timestamp' not in last_cached_candle:
+ logger.warning(f"No timestamp in last cached candle for {interval}")
+ last_cached_candle['timestamp'] = None
+
+ last_cached_time = last_cached_candle['timestamp']
+ logger.debug(f"Last cached timestamp for {interval}: {last_cached_time}")
+
+ # If last_cached_time is None, add all candles
+ if last_cached_time is None:
+ logger.debug(f"No valid last cached timestamp, adding all {len(new_candles_copy)} candles for {interval}")
+ self.add_candles(interval, new_candles_copy)
+ return
+
+ # Convert last_cached_time to datetime if needed
+ if not isinstance(last_cached_time, (pd.Timestamp, datetime)):
+ try:
+ last_cached_time = pd.to_datetime(last_cached_time)
+ except Exception as e:
+ logger.error(f"Cannot convert last cached time to datetime: {str(e)}")
+ # Add all candles as fallback
+ self.add_candles(interval, new_candles_copy)
+ return
+
+ # Make a backup of current cache before filtering
+ cache_backup = list(self.candles[interval])
+
+ # Filter new candles that are after the last cached candle
+ try:
+ filtered_candles = new_candles_copy[new_candles_copy['timestamp'] > last_cached_time]
+
+ if not filtered_candles.empty:
+ logger.debug(f"Adding {len(filtered_candles)} new candles for {interval}")
+ self.add_candles(interval, filtered_candles)
+ else:
+ # No new candles after last cached time, check for missing candles
+ try:
+ # Get unique timestamps in cache
+ cached_timestamps = set()
+ for candle in self.candles[interval]:
+ if isinstance(candle, dict) and 'timestamp' in candle:
+ ts = candle['timestamp']
+ if isinstance(ts, (pd.Timestamp, datetime)):
+ cached_timestamps.add(ts)
+ else:
+ try:
+ cached_timestamps.add(pd.to_datetime(ts))
+                                    except Exception:
+                                        pass
+
+ # Find candles in new_candles that aren't in the cache
+ missing_candles = new_candles_copy[~new_candles_copy['timestamp'].isin(cached_timestamps)]
+
+ if not missing_candles.empty:
+ logger.info(f"Found {len(missing_candles)} missing candles for {interval}")
+ self.add_candles(interval, missing_candles)
+ else:
+ logger.debug(f"No new or missing candles to add for {interval}")
+ except Exception as missing_error:
+ logger.error(f"Error checking for missing candles: {str(missing_error)}")
+ except Exception as filter_error:
+ logger.error(f"Error filtering candles by timestamp: {str(filter_error)}")
+ # Restore from backup
+ self.candles[interval] = deque(cache_backup, maxlen=self.candles[interval].maxlen)
+ # Try adding all candles as fallback
+ self.add_candles(interval, new_candles_copy)
+ except Exception as e:
+ logger.error(f"Unhandled error updating cache for {interval}: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+
+ def get_candles(self, timeframe: str, count: int = 500) -> pd.DataFrame:
+ """
+ Get candles for a specific timeframe. This is an alias for get_recent_candles
+ to maintain compatibility with code that expects this method name.
+
+ Args:
+ timeframe: The timeframe interval ('1s', '1m', '1h', '1d')
+ count: Maximum number of candles to return
+
+ Returns:
+ DataFrame containing the candles
+ """
+ try:
+ logger.debug(f"Getting {count} candles for {timeframe} via get_candles()")
+ return self.get_recent_candles(timeframe, count)
+ except Exception as e:
+ logger.error(f"Error in get_candles for {timeframe}: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ return pd.DataFrame()
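+
+    # Example (illustrative):
+    #   cache = CandleCache(max_candles=1000)
+    #   cache.update_cache('1m', df_1m)        # df_1m has a 'timestamp' column
+    #   recent = cache.get_candles('1m', count=100)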
+
+class RealTimeChart:
+ """Real-time chart using Dash and Plotly"""
+
+ def __init__(self, symbol="BTC/USDT", use_sample_data=False, log_no_ticks_warning=True):
+ """Initialize a new RealTimeChart
+
+ Args:
+ symbol: Trading pair symbol (e.g., BTC/USDT)
+ use_sample_data: Whether to use sample data when no real data is available
+ log_no_ticks_warning: Whether to log warnings when no ticks are available
+ """
+ self.symbol = symbol
+ self.use_sample_data = use_sample_data
+
+ # Initialize variables for trading info display
+ self.current_signal = 'HOLD'
+ self.signal_time = datetime.now()
+ self.current_position = 0.0
+ self.session_balance = 100.0 # Start with $100 balance
+ self.session_pnl = 0.0
+
+ # Initialize NN signals and trades lists
+ self.nn_signals = []
+ self.trades = []
+
+ # Use existing timezone variable instead of trying to detect again
+ logger.info(f"Using timezone: {tz_name}")
+
+ # Initialize tick storage
+ logger.info(f"Initializing RealTimeChart for {symbol}")
+ self.tick_storage = TradeTickStorage(
+ symbol=symbol,
+ max_age_seconds=3600, # Keep ticks for 1 hour
+ use_sample_data=use_sample_data,
+ log_no_ticks_warning=log_no_ticks_warning
+ )
+
+ # Initialize candlestick data for backward compatibility
+ self.candlestick_data = CandlestickData(max_length=5000)
+
+ # Initialize candle cache
+ self.candle_cache = CandleCache(max_candles=5000)
+
+ # Initialize OHLCV cache dictionaries for different timeframes
+ self.ohlcv_cache = {
+ '1s': None,
+ '5s': None,
+ '15s': None,
+ '60s': None,
+ '300s': None,
+ '900s': None,
+ '3600s': None,
+ '1m': None,
+ '5m': None,
+ '15m': None,
+ '1h': None,
+ '1d': None
+ }
+
+ # Historical data handler
+ self.historical_data = BinanceHistoricalData()
+
+ # Flag for first render to force data loading
+ self.first_render = True
+
+ # Last time candles were saved to disk
+ self.last_cache_save_time = time.time()
+
+ # Initialize Dash app
+ self.app = dash.Dash(
+ __name__,
+ external_stylesheets=[dbc.themes.DARKLY],
+ suppress_callback_exceptions=True,
+ meta_tags=[{"name": "viewport", "content": "width=device-width, initial-scale=1"}]
+ )
+
+ # Set up layout and callbacks
+ self._setup_app_layout()
+
+ def _setup_app_layout(self):
+ """Set up the app layout and callbacks"""
+ # Define styling for interval buttons
+ button_style = {
+ 'backgroundColor': '#2C2C2C',
+ 'color': 'white',
+ 'border': 'none',
+ 'padding': '10px 15px',
+ 'margin': '5px',
+ 'borderRadius': '5px',
+ 'cursor': 'pointer',
+ 'fontWeight': 'bold'
+ }
+
+ active_button_style = {
+ **button_style,
+ 'backgroundColor': '#4CAF50',
+ 'boxShadow': '0 2px 4px rgba(0,0,0,0.5)'
+ }
+
+ # Create tab layout
+ self.app.layout = dbc.Tabs([
+ dbc.Tab(self._get_chart_layout(button_style, active_button_style), label="Chart", tab_id="chart-tab"),
+            # Ticks tab removed; its callbacks were error-prone
+ ], id="tabs")
+
+ # Set up callbacks
+ self._setup_interval_callback(button_style, active_button_style)
+ self._setup_chart_callback()
+ self._setup_position_list_callback()
+ # We've removed the ticks callback, so don't call it
+ # self._setup_ticks_callback()
+
+ def _get_chart_layout(self, button_style, active_button_style):
+ # Chart page layout
+ return html.Div([
+ # Chart title and interval buttons
+ html.Div([
+ html.H2(f"{self.symbol} Real-Time Chart", style={
+ 'textAlign': 'center',
+ 'color': '#FFFFFF',
+ 'marginBottom': '10px'
+ }),
+
+ # Store interval data
+ dcc.Store(id='interval-store', data={'interval': 1}),
+
+ # Interval selection buttons
+ html.Div([
+ html.Button('1s', id='btn-1s', n_clicks=0, style=active_button_style),
+ html.Button('5s', id='btn-5s', n_clicks=0, style=button_style),
+ html.Button('15s', id='btn-15s', n_clicks=0, style=button_style),
+ html.Button('30s', id='btn-30s', n_clicks=0, style=button_style),
+ html.Button('1m', id='btn-1m', n_clicks=0, style=button_style),
+ ], style={
+ 'display': 'flex',
+ 'justifyContent': 'center',
+ 'marginBottom': '15px'
+ }),
+
+                # Interval component for updates - refreshes every 300ms
+ dcc.Interval(
+ id='interval-component',
+ interval=300, # Refresh every 300ms for better real-time updates
+ n_intervals=0
+ ),
+
+ # Main chart
+ dcc.Graph(id='live-chart', style={'height': '75vh'}),
+
+ # Last 5 Positions component
+ html.Div([
+ html.H4("Last 5 Positions", style={'textAlign': 'center', 'color': '#FFFFFF'}),
+ html.Table([
+ html.Thead(
+ html.Tr([
+ html.Th("Action", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
+ html.Th("Size", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
+ html.Th("Entry Price", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
+ html.Th("Exit Price", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
+ html.Th("PnL", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}),
+ html.Th("Time", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'})
+ ])
+ ),
+ html.Tbody(id='position-table-body')
+ ], style={
+ 'width': '100%',
+ 'borderCollapse': 'collapse',
+ 'color': '#FFFFFF',
+ 'backgroundColor': '#222'
+ })
+ ], style={'marginTop': '20px', 'marginBottom': '20px', 'overflowX': 'auto'}),
+
+ # Chart acknowledgment
+ html.Div("Real-time trading chart with ML signals", style={
+ 'textAlign': 'center',
+ 'color': '#AAAAAA',
+ 'fontSize': '12px',
+ 'marginTop': '5px'
+ })
+ ])
+ ])
+
+ def _get_ticks_layout(self):
+ # Ticks data page layout
+ return html.Div([
+ # Header and controls
+ html.Div([
+ html.H2(f"{self.symbol} Raw Tick Data (Last 5 Minutes)", style={
+ 'textAlign': 'center',
+ 'color': '#FFFFFF',
+ 'margin': '10px 0'
+ }),
+
+ # Refresh button
+ html.Button('Refresh Data', id='refresh-ticks-btn', n_clicks=0, style={
+ 'backgroundColor': '#4CAF50',
+ 'color': 'white',
+ 'padding': '10px 20px',
+ 'margin': '10px auto',
+ 'border': 'none',
+ 'borderRadius': '5px',
+ 'fontSize': '14px',
+ 'cursor': 'pointer',
+ 'display': 'block'
+ }),
+
+ # Time window selector
+ html.Div([
+ html.Label("Time Window:", style={'color': 'white', 'marginRight': '10px'}),
+ dcc.Dropdown(
+ id='time-window-dropdown',
+ options=[
+ {'label': 'Last 1 minute', 'value': 60},
+ {'label': 'Last 5 minutes', 'value': 300},
+ {'label': 'Last 15 minutes', 'value': 900},
+ {'label': 'Last 30 minutes', 'value': 1800},
+ ],
+ value=300, # Default to 5 minutes
+ style={'width': '200px', 'backgroundColor': '#2C2C2C', 'color': 'black'}
+ )
+ ], style={
+ 'display': 'flex',
+ 'alignItems': 'center',
+ 'justifyContent': 'center',
+ 'margin': '10px'
+ }),
+ ], style={
+ 'backgroundColor': '#2C2C2C',
+ 'padding': '10px',
+ 'borderRadius': '5px',
+ 'marginBottom': '15px'
+ }),
+
+ # Stats cards
+ html.Div(id='tick-stats-cards', style={
+ 'display': 'flex',
+ 'flexWrap': 'wrap',
+ 'justifyContent': 'space-around',
+ 'marginBottom': '15px'
+ }),
+
+ # Ticks data table
+ html.Div(id='ticks-table-container', style={
+ 'backgroundColor': '#232323',
+ 'padding': '10px',
+ 'borderRadius': '5px',
+ 'overflowX': 'auto'
+ }),
+
+ # Price movement chart
+ html.Div([
+ html.H3("Price Movement", style={
+ 'textAlign': 'center',
+ 'color': '#FFFFFF',
+ 'margin': '10px 0'
+ }),
+ dcc.Graph(id='tick-price-chart')
+ ], style={
+ 'backgroundColor': '#232323',
+ 'padding': '10px',
+ 'borderRadius': '5px',
+ 'marginTop': '15px'
+ })
+ ])
+
+ def _setup_interval_callback(self, button_style, active_button_style):
+ # Callback to update interval based on button clicks and update button styles
+ @self.app.callback(
+ [Output('interval-store', 'data'),
+ Output('btn-1s', 'style'),
+ Output('btn-5s', 'style'),
+ Output('btn-15s', 'style'),
+ Output('btn-30s', 'style'),
+ Output('btn-1m', 'style')],
+ [Input('btn-1s', 'n_clicks'),
+ Input('btn-5s', 'n_clicks'),
+ Input('btn-15s', 'n_clicks'),
+ Input('btn-30s', 'n_clicks'),
+ Input('btn-1m', 'n_clicks')],
+ [dash.dependencies.State('interval-store', 'data')]
+ )
+ def update_interval(n1, n5, n15, n30, n60, data):
+ ctx = dash.callback_context
+ if not ctx.triggered:
+ # Default state (1s selected)
+ return ({'interval': 1},
+ active_button_style, button_style, button_style, button_style, button_style)
+
+ button_id = ctx.triggered[0]['prop_id'].split('.')[0]
+
+ if button_id == 'btn-1s':
+ return ({'interval': 1},
+ active_button_style, button_style, button_style, button_style, button_style)
+ elif button_id == 'btn-5s':
+ return ({'interval': 5},
+ button_style, active_button_style, button_style, button_style, button_style)
+ elif button_id == 'btn-15s':
+ return ({'interval': 15},
+ button_style, button_style, active_button_style, button_style, button_style)
+ elif button_id == 'btn-30s':
+ return ({'interval': 30},
+ button_style, button_style, button_style, active_button_style, button_style)
+ elif button_id == 'btn-1m':
+ return ({'interval': 60},
+ button_style, button_style, button_style, button_style, active_button_style)
+
+ # Default case - keep current interval and highlight appropriate button
+ current_interval = data.get('interval', 1)
+ styles = [button_style] * 5 # All inactive by default
+
+ # Set active style based on current interval
+ if current_interval == 1:
+ styles[0] = active_button_style
+ elif current_interval == 5:
+ styles[1] = active_button_style
+ elif current_interval == 15:
+ styles[2] = active_button_style
+ elif current_interval == 30:
+ styles[3] = active_button_style
+ elif current_interval == 60:
+ styles[4] = active_button_style
+
+ return (data, *styles)
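+
+        # Note: dash.callback_context.triggered[0]['prop_id'] has the form
+        # 'btn-5s.n_clicks'; splitting on '.' recovers the component id, which
+        # is why update_interval branches on 'btn-1s' .. 'btn-1m'.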
+
+ def _setup_chart_callback(self):
+ # Callback to update the chart
+ @self.app.callback(
+ Output('live-chart', 'figure'),
+ [Input('interval-component', 'n_intervals'),
+ Input('interval-store', 'data')]
+ )
+ def update_chart(n, interval_data):
+ try:
+ interval = interval_data.get('interval', 1)
+ logger.debug(f"Updating chart with interval {interval}")
+
+ # Get candlesticks data for the selected interval
+ try:
+ df = self.tick_storage.get_candles(interval_seconds=interval)
+                    # Use .get() so intervals without a cache slot (e.g. 30s) don't raise KeyError
+                    cache_key = f'{interval}s' if interval < 60 else '1m'
+                    if df.empty and self.ohlcv_cache.get(cache_key) is not None:
+                        df = self.ohlcv_cache[cache_key]
+ except Exception as e:
+ logger.error(f"Error getting candles: {str(e)}")
+ df = pd.DataFrame()
+
+ # Get data for other timeframes (1m, 1h, 1d)
+ df_1m = None
+ df_1h = None
+ df_1d = None
+
+ try:
+ # Get 1m candles
+ if self.ohlcv_cache['1m'] is not None and not self.ohlcv_cache['1m'].empty:
+ df_1m = self.ohlcv_cache['1m'].copy()
+ else:
+ df_1m = self.tick_storage.get_candles(interval_seconds=60)
+
+ # Get 1h candles
+ if self.ohlcv_cache['1h'] is not None and not self.ohlcv_cache['1h'].empty:
+ df_1h = self.ohlcv_cache['1h'].copy()
+ else:
+ df_1h = self.tick_storage.get_candles(interval_seconds=3600)
+
+ # Get 1d candles
+ if self.ohlcv_cache['1d'] is not None and not self.ohlcv_cache['1d'].empty:
+ df_1d = self.ohlcv_cache['1d'].copy()
+ else:
+ df_1d = self.tick_storage.get_candles(interval_seconds=86400)
+
+ # Limit the number of candles to display but show more for context
+ if df_1m is not None and not df_1m.empty:
+ df_1m = df_1m.tail(600) # Show 600 1m candles for better context - 10 hours
+ if df_1h is not None and not df_1h.empty:
+ df_1h = df_1h.tail(480) # Show 480 hours of hourly data for better context - 20 days
+ if df_1d is not None and not df_1d.empty:
+ df_1d = df_1d.tail(365*2) # Show 2 years of daily data for better context
+
+ except Exception as e:
+ logger.error(f"Error getting additional timeframes: {str(e)}")
+
+ # Create layout with 5 rows (main chart, volume, 1m, 1h, 1d)
+ fig = make_subplots(
+ rows=5, cols=1,
+ vertical_spacing=0.03,
+ subplot_titles=(
+ f'{self.symbol} Price ({interval}s)',
+ 'Volume',
+ '1-Minute Chart',
+ '1-Hour Chart',
+ '1-Day Chart'
+ ),
+ row_heights=[0.4, 0.15, 0.15, 0.15, 0.15]
+ )
+
+ # Add candlestick chart for main timeframe
+ if not df.empty and 'open' in df.columns:
+ # Candlestick chart
+ fig.add_trace(
+ go.Candlestick(
+ x=df.index,
+ open=df['open'],
+ high=df['high'],
+ low=df['low'],
+ close=df['close'],
+ name='OHLC'
+ ),
+ row=1, col=1
+ )
+
+ # Calculate y-axis range with padding
+ low_min = df['low'].min()
+ high_max = df['high'].max()
+ price_range = high_max - low_min
+ y_min = low_min - (price_range * 0.05) # 5% padding below
+ y_max = high_max + (price_range * 0.05) # 5% padding above
+
+ # Set y-axis range to ensure candles are visible
+ fig.update_yaxes(range=[y_min, y_max], row=1, col=1)
+
+ # Volume bars
+                    colors = ['rgba(0,255,0,0.7)' if close_px >= open_px else 'rgba(255,0,0,0.7)'
+                             for open_px, close_px in zip(df['open'], df['close'])]  # avoid shadowing builtin open
+
+ fig.add_trace(
+ go.Bar(
+ x=df.index,
+ y=df['volume'],
+ name='Volume',
+ marker_color=colors
+ ),
+ row=2, col=1
+ )
+
+ # Add buy/sell markers and PnL annotations to the main chart
+ if hasattr(self, 'trades') and self.trades:
+ buy_times = []
+ buy_prices = []
+ buy_markers = []
+ sell_times = []
+ sell_prices = []
+ sell_markers = []
+
+ # Filter trades to only include recent ones (last 100)
+ recent_trades = self.trades[-100:]
+
+ for trade in recent_trades:
+ # Convert timestamp to datetime if it's not already
+ trade_time = trade.get('timestamp')
+ if isinstance(trade_time, (int, float)):
+ trade_time = pd.to_datetime(trade_time, unit='ms')
+
+ price = trade.get('price', 0)
+ pnl = trade.get('pnl', None)
+ action = trade.get('action', 'SELL') # Default to SELL
+
+ if action == 'BUY':
+ buy_times.append(trade_time)
+ buy_prices.append(price)
+ buy_markers.append("")
+ elif action == 'SELL':
+ sell_times.append(trade_time)
+ sell_prices.append(price)
+ # Add PnL as marker text if available
+ if pnl is not None:
+ pnl_text = f"{pnl:.4f}" if abs(pnl) < 0.01 else f"{pnl:.2f}"
+ sell_markers.append(pnl_text)
+ else:
+ sell_markers.append("")
+
+ # Add buy markers
+ if buy_times:
+ fig.add_trace(
+ go.Scatter(
+ x=buy_times,
+ y=buy_prices,
+ mode='markers',
+ name='Buy',
+ marker=dict(
+ symbol='triangle-up',
+ size=12,
+ color='rgba(0,255,0,0.8)',
+ line=dict(width=1, color='darkgreen')
+ ),
+ text=buy_markers,
+ hoverinfo='x+y+text',
+ showlegend=True # Ensure Buy appears in legend
+ ),
+ row=1, col=1
+ )
+
+ # Add vertical and horizontal connecting lines for buys
+ for i, (btime, bprice) in enumerate(zip(buy_times, buy_prices)):
+ # Add vertical dashed line to time axis
+ fig.add_shape(
+ type="line",
+ x0=btime, x1=btime,
+ y0=y_min, y1=bprice,
+ line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"),
+ row=1, col=1
+ )
+ # Add horizontal dashed line showing the price level
+ fig.add_shape(
+ type="line",
+ x0=df.index.min(), x1=btime,
+ y0=bprice, y1=bprice,
+ line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"),
+ row=1, col=1
+ )
+
+ # Add sell markers with PnL annotations
+ if sell_times:
+ fig.add_trace(
+ go.Scatter(
+ x=sell_times,
+ y=sell_prices,
+ mode='markers+text',
+ name='Sell',
+ marker=dict(
+ symbol='triangle-down',
+ size=12,
+ color='rgba(255,0,0,0.8)',
+ line=dict(width=1, color='darkred')
+ ),
+ text=sell_markers,
+ textposition='top center',
+ textfont=dict(size=10),
+ hoverinfo='x+y+text',
+ showlegend=True # Ensure Sell appears in legend
+ ),
+ row=1, col=1
+ )
+
+ # Add vertical and horizontal connecting lines for sells
+ for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)):
+ # Add vertical dashed line to time axis
+ fig.add_shape(
+ type="line",
+ x0=stime, x1=stime,
+ y0=y_min, y1=sprice,
+ line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"),
+ row=1, col=1
+ )
+ # Add horizontal dashed line showing the price level
+ fig.add_shape(
+ type="line",
+ x0=df.index.min(), x1=stime,
+ y0=sprice, y1=sprice,
+ line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"),
+ row=1, col=1
+ )
+
+ # Add connecting lines between consecutive buy-sell pairs
+ if len(buy_times) > 0 and len(sell_times) > 0:
+ # Create pairs of buy-sell trades based on timestamps
+ pairs = []
+ buys_copy = list(zip(buy_times, buy_prices))
+
+ for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)):
+ # Find the most recent buy before this sell
+ matching_buy = None
+ for j, (btime, bprice) in enumerate(buys_copy):
+ if btime < stime:
+ matching_buy = (btime, bprice)
+ buys_copy.pop(j) # Remove this buy to prevent reuse
+ break
+
+ if matching_buy:
+ pairs.append((matching_buy, (stime, sprice)))
+
+ # Add connecting lines for each pair
+ for (btime, bprice), (stime, sprice) in pairs:
+ # Draw line connecting the buy and sell points
+ fig.add_shape(
+ type="line",
+ x0=btime, x1=stime,
+ y0=bprice, y1=sprice,
+ line=dict(
+ color="rgba(255,255,255,0.5)",
+ width=1,
+ dash="dot"
+ ),
+ row=1, col=1
+ )
+
+ # Add 1m chart
+ if df_1m is not None and not df_1m.empty and 'open' in df_1m.columns:
+ fig.add_trace(
+ go.Candlestick(
+ x=df_1m.index,
+ open=df_1m['open'],
+ high=df_1m['high'],
+ low=df_1m['low'],
+ close=df_1m['close'],
+ name='1m',
+ showlegend=False
+ ),
+ row=3, col=1
+ )
+
+ # Set appropriate date format for 1m chart
+ fig.update_xaxes(
+ title_text="",
+ row=3,
+ col=1,
+ tickformat="%H:%M",
+ tickmode="auto",
+ nticks=12
+ )
+
+ # Add buy/sell markers to 1m chart if they fall within the visible timeframe
+ if hasattr(self, 'trades') and self.trades:
+ # Filter trades visible in 1m timeframe
+ min_time = df_1m.index.min()
+ max_time = df_1m.index.max()
+
+ # Ensure min_time and max_time are pandas.Timestamp objects
+ if isinstance(min_time, (int, float)):
+ min_time = pd.to_datetime(min_time, unit='ms')
+ if isinstance(max_time, (int, float)):
+ max_time = pd.to_datetime(max_time, unit='ms')
+
+ # Collect only trades within this timeframe
+ minute_buy_times = []
+ minute_buy_prices = []
+ minute_sell_times = []
+ minute_sell_prices = []
+
+ for trade in self.trades[-100:]:
+ trade_time = trade.get('timestamp')
+ if isinstance(trade_time, (int, float)):
+ # Convert numeric timestamp to datetime
+ trade_time = pd.to_datetime(trade_time, unit='ms')
+ elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime):
+ # Skip trades with invalid timestamp format
+ continue
+
+ # Check if trade falls within 1m chart timeframe
+ try:
+ if min_time <= trade_time <= max_time:
+ price = trade.get('price', 0)
+ action = trade.get('action', 'SELL')
+
+ if action == 'BUY':
+ minute_buy_times.append(trade_time)
+ minute_buy_prices.append(price)
+ elif action == 'SELL':
+ minute_sell_times.append(trade_time)
+ minute_sell_prices.append(price)
+ except TypeError:
+ # If comparison fails due to type mismatch, log the error and skip this trade
+ logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}")
+ continue
+
+ # Add buy markers to 1m chart
+ if minute_buy_times:
+ fig.add_trace(
+ go.Scatter(
+ x=minute_buy_times,
+ y=minute_buy_prices,
+ mode='markers',
+ name='Buy (1m)',
+ marker=dict(
+ symbol='triangle-up',
+ size=8,
+ color='rgba(0,255,0,0.8)',
+ line=dict(width=1, color='darkgreen')
+ ),
+ showlegend=False,
+ hoverinfo='x+y'
+ ),
+ row=3, col=1
+ )
+
+ # Add sell markers to 1m chart
+ if minute_sell_times:
+ fig.add_trace(
+ go.Scatter(
+ x=minute_sell_times,
+ y=minute_sell_prices,
+ mode='markers',
+ name='Sell (1m)',
+ marker=dict(
+ symbol='triangle-down',
+ size=8,
+ color='rgba(255,0,0,0.8)',
+ line=dict(width=1, color='darkred')
+ ),
+ showlegend=False,
+ hoverinfo='x+y'
+ ),
+ row=3, col=1
+ )
+
+ # Add 1h chart
+ if df_1h is not None and not df_1h.empty and 'open' in df_1h.columns:
+ fig.add_trace(
+ go.Candlestick(
+ x=df_1h.index,
+ open=df_1h['open'],
+ high=df_1h['high'],
+ low=df_1h['low'],
+ close=df_1h['close'],
+ name='1h',
+ showlegend=False
+ ),
+ row=4, col=1
+ )
+
+ # Set appropriate date format for 1h chart
+ fig.update_xaxes(
+ title_text="",
+ row=4,
+ col=1,
+ tickformat="%m-%d %H:%M",
+ tickmode="auto",
+ nticks=8
+ )
+
+ # Add buy/sell markers to 1h chart if they fall within the visible timeframe
+ if hasattr(self, 'trades') and self.trades:
+ # Filter trades visible in 1h timeframe
+ min_time = df_1h.index.min()
+ max_time = df_1h.index.max()
+
+ # Ensure min_time and max_time are pandas.Timestamp objects
+ if isinstance(min_time, (int, float)):
+ min_time = pd.to_datetime(min_time, unit='ms')
+ if isinstance(max_time, (int, float)):
+ max_time = pd.to_datetime(max_time, unit='ms')
+
+ # Collect only trades within this timeframe
+ hour_buy_times = []
+ hour_buy_prices = []
+ hour_sell_times = []
+ hour_sell_prices = []
+
+ for trade in self.trades[-200:]: # Check more trades for longer timeframe
+ trade_time = trade.get('timestamp')
+ if isinstance(trade_time, (int, float)):
+ # Convert numeric timestamp to datetime
+ trade_time = pd.to_datetime(trade_time, unit='ms')
+ elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime):
+ # Skip trades with invalid timestamp format
+ continue
+
+ # Check if trade falls within 1h chart timeframe
+ try:
+ if min_time <= trade_time <= max_time:
+ price = trade.get('price', 0)
+ action = trade.get('action', 'SELL')
+
+ if action == 'BUY':
+ hour_buy_times.append(trade_time)
+ hour_buy_prices.append(price)
+ elif action == 'SELL':
+ hour_sell_times.append(trade_time)
+ hour_sell_prices.append(price)
+ except TypeError:
+ # If comparison fails due to type mismatch, log the error and skip this trade
+ logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}")
+ continue
+
+ # Add buy markers to 1h chart
+ if hour_buy_times:
+ fig.add_trace(
+ go.Scatter(
+ x=hour_buy_times,
+ y=hour_buy_prices,
+ mode='markers',
+ name='Buy (1h)',
+ marker=dict(
+ symbol='triangle-up',
+ size=6,
+ color='rgba(0,255,0,0.8)',
+ line=dict(width=1, color='darkgreen')
+ ),
+ showlegend=False,
+ hoverinfo='x+y'
+ ),
+ row=4, col=1
+ )
+
+ # Add sell markers to 1h chart
+ if hour_sell_times:
+ fig.add_trace(
+ go.Scatter(
+ x=hour_sell_times,
+ y=hour_sell_prices,
+ mode='markers',
+ name='Sell (1h)',
+ marker=dict(
+ symbol='triangle-down',
+ size=6,
+ color='rgba(255,0,0,0.8)',
+ line=dict(width=1, color='darkred')
+ ),
+ showlegend=False,
+ hoverinfo='x+y'
+ ),
+ row=4, col=1
+ )
+ # Add 1d chart
+ if df_1d is not None and not df_1d.empty and 'open' in df_1d.columns:
+ fig.add_trace(
+ go.Candlestick(
+ x=df_1d.index,
+ open=df_1d['open'],
+ high=df_1d['high'],
+ low=df_1d['low'],
+ close=df_1d['close'],
+ name='1d',
+ showlegend=False
+ ),
+ row=5, col=1
+ )
+
+ # Set appropriate date format for 1d chart
+ fig.update_xaxes(
+ title_text="",
+ row=5,
+ col=1,
+ tickformat="%Y-%m-%d",
+ tickmode="auto",
+ nticks=10
+ )
+
+ # Add buy/sell markers to 1d chart if they fall within the visible timeframe
+ if hasattr(self, 'trades') and self.trades:
+ # Filter trades visible in 1d timeframe
+ min_time = df_1d.index.min()
+ max_time = df_1d.index.max()
+
+ # Ensure min_time and max_time are pandas.Timestamp objects
+ if isinstance(min_time, (int, float)):
+ min_time = pd.to_datetime(min_time, unit='ms')
+ if isinstance(max_time, (int, float)):
+ max_time = pd.to_datetime(max_time, unit='ms')
+
+ # Collect only trades within this timeframe
+ day_buy_times = []
+ day_buy_prices = []
+ day_sell_times = []
+ day_sell_prices = []
+
+ for trade in self.trades[-300:]: # Check more trades for daily timeframe
+ trade_time = trade.get('timestamp')
+ if isinstance(trade_time, (int, float)):
+ # Convert numeric timestamp to datetime
+ trade_time = pd.to_datetime(trade_time, unit='ms')
+ elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime):
+ # Skip trades with invalid timestamp format
+ continue
+
+ # Check if trade falls within 1d chart timeframe
+ try:
+ if min_time <= trade_time <= max_time:
+ price = trade.get('price', 0)
+ action = trade.get('action', 'SELL')
+
+ if action == 'BUY':
+ day_buy_times.append(trade_time)
+ day_buy_prices.append(price)
+ elif action == 'SELL':
+ day_sell_times.append(trade_time)
+ day_sell_prices.append(price)
+ except TypeError:
+ # If comparison fails due to type mismatch, log the error and skip this trade
+ logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}")
+ continue
+
+ # Add buy markers to 1d chart
+ if day_buy_times:
+ fig.add_trace(
+ go.Scatter(
+ x=day_buy_times,
+ y=day_buy_prices,
+ mode='markers',
+ name='Buy (1d)',
+ marker=dict(
+ symbol='triangle-up',
+ size=5,
+ color='rgba(0,255,0,0.8)',
+ line=dict(width=1, color='darkgreen')
+ ),
+ showlegend=False,
+ hoverinfo='x+y'
+ ),
+ row=5, col=1
+ )
+
+ # Add sell markers to 1d chart
+ if day_sell_times:
+ fig.add_trace(
+ go.Scatter(
+ x=day_sell_times,
+ y=day_sell_prices,
+ mode='markers',
+ name='Sell (1d)',
+ marker=dict(
+ symbol='triangle-down',
+ size=5,
+ color='rgba(255,0,0,0.8)',
+ line=dict(width=1, color='darkred')
+ ),
+ showlegend=False,
+ hoverinfo='x+y'
+ ),
+ row=5, col=1
+ )
+
+ # Add trading info annotation if available
+ if hasattr(self, 'current_signal') and self.current_signal:
+ signal_color = "#33DD33" if self.current_signal == "BUY" else "#FF4444" if self.current_signal == "SELL" else "#BBBBBB"
+
+ # Format position value
+ position_text = f"{self.current_position:.4f}" if self.current_position < 0.01 else f"{self.current_position:.2f}"
+
+ # Format PnL with color based on value
+ pnl_color = "#33DD33" if self.session_pnl >= 0 else "#FF4444"
+ pnl_text = f"{self.session_pnl:.4f}"
+
+ # Create trading info text
+ info_text = (
+ f"Signal: {self.current_signal} | "
+ f"Position: {position_text} | "
+ f"Balance: ${self.session_balance:.2f} | "
+ f"PnL: {pnl_text}"
+ )
+
+ # Add annotation
+ fig.add_annotation(
+ x=0.5, y=1.05,
+ xref="paper", yref="paper",
+ text=info_text,
+ showarrow=False,
+ font=dict(size=14, color="white"),
+ bgcolor="rgba(50,50,50,0.6)",
+ borderwidth=1,
+ borderpad=6,
+ align="center"
+ )
+
+ # Update layout
+ fig.update_layout(
+ title_text=f"{self.symbol} Real-Time Data",
+ title_x=0.5,
+ xaxis_rangeslider_visible=False,
+ height=1000, # Increased height to accommodate all charts
+ template='plotly_dark',
+ paper_bgcolor='rgba(0,0,0,0)',
+ plot_bgcolor='rgba(25,25,50,1)'
+ )
+
+ # Update axes styling for all subplots
+ fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
+ fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
+
+ # Hide rangesliders for all candlestick charts
+ fig.update_layout(
+ xaxis_rangeslider_visible=False,
+ xaxis2_rangeslider_visible=False,
+ xaxis3_rangeslider_visible=False,
+ xaxis4_rangeslider_visible=False,
+ xaxis5_rangeslider_visible=False
+ )
+
+ # Improve date formatting for the main chart
+ fig.update_xaxes(
+ title_text="",
+ row=1,
+ col=1,
+ tickformat="%H:%M:%S",
+ tickmode="auto",
+ nticks=15
+ )
+
+ return fig
+
+ except Exception as e:
+ logger.error(f"Error updating chart: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ fig = go.Figure()
+ fig.add_annotation(
+ x=0.5, y=0.5,
+ text=f"Error updating chart: {str(e)}",
+ showarrow=False,
+ font=dict(size=14, color="red"),
+ xref="paper", yref="paper"
+ )
+ return fig
+
+ def _setup_position_list_callback(self):
+ """Callback to update the position list with the latest 5 trades"""
+ @self.app.callback(
+ Output('position-table-body', 'children'),
+ [Input('interval-component', 'n_intervals')]
+ )
+ def update_position_list(n):
+ try:
+ # Get the last 5 positions
+ if not hasattr(self, 'trades') or not self.trades:
+ return [html.Tr([html.Td("No trades yet", colSpan=6, style={'textAlign': 'center', 'color': '#AAAAAA', 'padding': '10px'})])]
+
+ # Collect BUY and SELL trades to pair them
+ buy_trades = {}
+ sell_trades = {}
+ position_rows = []
+
+ # Process all trades to match buy and sell pairs
+ for trade in self.trades:
+ action = trade.get('action', 'UNKNOWN')
+ if action == 'BUY':
+ # Store buy trades by timestamp to match with sells later
+ buy_trades[trade.get('timestamp')] = trade
+ elif action == 'SELL':
+ # Store sell trades
+ sell_trades[trade.get('timestamp')] = trade
+
+ # Create position entries (matched pairs or single trades)
+ position_entries = []
+
+ # First add matched pairs (BUY trades with matching close_price)
+ for timestamp, buy_trade in buy_trades.items():
+ if 'close_price' in buy_trade:
+ # This is a closed position
+ entry_price = buy_trade.get('price', 'N/A')
+ exit_price = buy_trade.get('close_price', 'N/A')
+ entry_time = buy_trade.get('timestamp')
+ exit_time = buy_trade.get('close_timestamp')
+ amount = buy_trade.get('amount', 0.1)
+
+ # Calculate PnL if not provided
+ pnl = buy_trade.get('pnl')
+ if pnl is None and isinstance(entry_price, (int, float)) and isinstance(exit_price, (int, float)):
+ pnl = (exit_price - entry_price) * amount
+
+ position_entries.append({
+ 'action': 'CLOSED',
+ 'entry_price': entry_price,
+ 'exit_price': exit_price,
+ 'amount': amount,
+ 'pnl': pnl,
+ 'time': entry_time,
+ 'exit_time': exit_time,
+ 'status': 'CLOSED'
+ })
+ else:
+ # This is an open position (BUY without a matching SELL)
+ current_price = self.tick_storage.get_latest_price() or buy_trade.get('price', 0)
+ amount = buy_trade.get('amount', 0.1)
+ entry_price = buy_trade.get('price', 0)
+
+ # Calculate unrealized PnL
+ unrealized_pnl = (current_price - entry_price) * amount if isinstance(current_price, (int, float)) and isinstance(entry_price, (int, float)) else None
+
+ position_entries.append({
+ 'action': 'BUY',
+ 'entry_price': entry_price,
+ 'exit_price': current_price, # Use current price as potential exit
+ 'amount': amount,
+ 'pnl': unrealized_pnl,
+ 'time': buy_trade.get('timestamp'),
+ 'status': 'OPEN'
+ })
+
+ # Add standalone SELL trades that don't have a matching BUY
+ for timestamp, sell_trade in sell_trades.items():
+ # Check if this SELL is already accounted for in a closed position
+ already_matched = False
+ for entry in position_entries:
+ if entry.get('status') == 'CLOSED' and entry.get('exit_time') == timestamp:
+ already_matched = True
+ break
+
+ if not already_matched:
+ position_entries.append({
+ 'action': 'SELL',
+ 'entry_price': 'N/A',
+ 'exit_price': sell_trade.get('price', 'N/A'),
+ 'amount': sell_trade.get('amount', 0.1),
+ 'pnl': sell_trade.get('pnl'),
+ 'time': sell_trade.get('timestamp'),
+ 'status': 'STANDALONE'
+ })
+
+                # Sort by time (most recent first) and take the last 5; entries
+                # without a valid datetime get a datetime.min sentinel so they sort last
+                position_entries.sort(key=lambda x: x['time'] if isinstance(x['time'], datetime) else datetime.min, reverse=True)
+ position_entries = position_entries[:5]
+
+ # Convert to table rows
+ for entry in position_entries:
+ action = entry['action']
+ entry_price = entry['entry_price']
+ exit_price = entry['exit_price']
+ amount = entry['amount']
+ pnl = entry['pnl']
+ time_obj = entry['time']
+ status = entry['status']
+
+ # Format time
+ if isinstance(time_obj, datetime):
+ # If trade is from a different day, include the date
+ today = datetime.now().date()
+ if time_obj.date() == today:
+ time_str = time_obj.strftime('%H:%M:%S')
+ else:
+ time_str = time_obj.strftime('%m-%d %H:%M:%S')
+ else:
+ time_str = str(time_obj)
+
+ # Format prices with proper decimal places
+ if isinstance(entry_price, (int, float)):
+ entry_price_str = f"${entry_price:.2f}"
+ else:
+ entry_price_str = str(entry_price)
+
+ if isinstance(exit_price, (int, float)):
+ exit_price_str = f"${exit_price:.2f}"
+ else:
+ exit_price_str = str(exit_price)
+
+ # Format PnL
+ if pnl is not None and isinstance(pnl, (int, float)):
+ pnl_str = f"${pnl:.2f}"
+ pnl_color = '#00FF00' if pnl >= 0 else '#FF0000'
+ else:
+ pnl_str = 'N/A'
+ pnl_color = '#FFFFFF'
+
+ # Set action/status color and text
+ if status == 'OPEN':
+ status_color = '#00AAFF' # Blue for open positions
+ status_text = "OPEN (BUY)"
+ elif status == 'CLOSED':
+ if pnl is not None and isinstance(pnl, (int, float)):
+ status_color = '#00FF00' if pnl >= 0 else '#FF0000' # Green/Red based on profit
+ else:
+ status_color = '#FFCC00' # Yellow if PnL unknown
+ status_text = "CLOSED"
+ elif action == 'BUY':
+ status_color = '#00FF00'
+ status_text = "BUY"
+ elif action == 'SELL':
+ status_color = '#FF0000'
+ status_text = "SELL"
+ else:
+ status_color = '#FFFFFF'
+ status_text = action
+
+ # Create table row
+ position_rows.append(html.Tr([
+ html.Td(status_text, style={'color': status_color, 'padding': '8px', 'border': '1px solid #444'}),
+ html.Td(f"{amount} BTC", style={'padding': '8px', 'border': '1px solid #444'}),
+ html.Td(entry_price_str, style={'padding': '8px', 'border': '1px solid #444'}),
+ html.Td(exit_price_str, style={'padding': '8px', 'border': '1px solid #444'}),
+ html.Td(pnl_str, style={'color': pnl_color, 'padding': '8px', 'border': '1px solid #444'}),
+ html.Td(time_str, style={'padding': '8px', 'border': '1px solid #444'})
+ ]))
+
+ return position_rows
+ except Exception as e:
+ logger.error(f"Error updating position table: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ return [html.Tr([html.Td(f"Error: {str(e)}", colSpan=6, style={'color': '#FF0000', 'padding': '10px'})])]
+
+ def _interval_to_seconds(self, interval_key: str) -> int:
+ """Convert interval key to seconds"""
+ mapping = {
+ '1s': 1,
+ '1m': 60,
+ '1h': 3600,
+ '1d': 86400
+ }
+ return mapping.get(interval_key, 1)
+
+ async def start_websocket(self):
+ ws = ExchangeWebSocket(self.symbol)
+ connection_attempts = 0
+ max_attempts = 10 # Maximum connection attempts before longer waiting period
+
+ while True: # Keep trying to maintain connection
+ connection_attempts += 1
+ if not await ws.connect():
+ logger.error(f"Failed to connect to exchange for {self.symbol}")
+ # Gradually increase wait time based on number of connection failures
+ wait_time = min(5 * connection_attempts, 60) # Cap at 60 seconds
+ logger.warning(f"Waiting {wait_time} seconds before retry (attempt {connection_attempts})")
+
+ if connection_attempts >= max_attempts:
+ logger.warning(f"Reached {max_attempts} connection attempts, taking a longer break")
+ await asyncio.sleep(120) # 2 minutes wait after max attempts
+ connection_attempts = 0 # Reset counter
+ else:
+ await asyncio.sleep(wait_time)
+ continue
+
+ # Successfully connected
+ connection_attempts = 0
+
+ try:
+ logger.info(f"WebSocket connected for {self.symbol}, beginning data collection")
+ tick_count = 0
+ last_tick_count_log = time.time()
+ last_status_report = time.time()
+
+ # Track stats for reporting
+ price_min = float('inf')
+ price_max = float('-inf')
+ price_last = None
+ volume_total = 0
+ start_collection_time = time.time()
+
+ while True:
+ if not ws.running:
+ logger.warning(f"WebSocket connection lost for {self.symbol}, breaking loop")
+ break
+
+ data = await ws.receive()
+ if data:
+ if data.get('type') == 'kline':
+ # Use kline data directly for candlestick
+ trade_data = {
+ 'timestamp': data['timestamp'],
+ 'price': data['price'],
+ 'volume': data['volume'],
+ 'open': data['open'],
+ 'high': data['high'],
+ 'low': data['low']
+ }
+ logger.debug(f"Received kline data: {data}")
+ else:
+ # Use trade data
+ trade_data = {
+ 'timestamp': data['timestamp'],
+ 'price': data['price'],
+ 'volume': data['volume']
+ }
+
+ # Update stats
+ price = trade_data['price']
+ volume = trade_data['volume']
+ price_min = min(price_min, price)
+ price_max = max(price_max, price)
+ price_last = price
+ volume_total += volume
+
+ # Store raw tick in the tick storage
+ self.tick_storage.add_tick(trade_data)
+ tick_count += 1
+
+ # Also update the old candlestick data for backward compatibility
+ # Add check to ensure the candlestick_data attribute exists before using it
+ if hasattr(self, 'candlestick_data'):
+ self.candlestick_data.update_from_trade(trade_data)
+
+ # Log tick counts periodically
+ current_time = time.time()
+ if current_time - last_tick_count_log >= 10: # Log every 10 seconds
+ elapsed = current_time - last_tick_count_log
+ tps = tick_count / elapsed if elapsed > 0 else 0
+ logger.info(f"{self.symbol}: Collected {tick_count} ticks in last {elapsed:.1f}s ({tps:.2f} ticks/sec), total: {len(self.tick_storage.ticks)}")
+ last_tick_count_log = current_time
+ tick_count = 0
+
+ # Check if ticks are being converted to candles
+ if len(self.tick_storage.ticks) > 0:
+ sample_df = self.tick_storage.get_candles(interval_seconds=1)
+ logger.info(f"{self.symbol}: Sample candle count: {len(sample_df)}")
+
+ # Periodic status report (every 60 seconds)
+ if current_time - last_status_report >= 60:
+ elapsed_total = current_time - start_collection_time
+ logger.info(f"{self.symbol} Status Report:")
+ logger.info(f" Collection time: {elapsed_total:.1f} seconds")
+ logger.info(f" Price range: {price_min:.2f} - {price_max:.2f} (last: {price_last:.2f})")
+ logger.info(f" Total volume: {volume_total:.8f}")
+ logger.info(f" Active ticks in storage: {len(self.tick_storage.ticks)}")
+
+ # Reset stats for next period
+ last_status_report = current_time
+ price_min = float('inf') if price_last is None else price_last
+ price_max = float('-inf') if price_last is None else price_last
+ volume_total = 0
+
+ await asyncio.sleep(0.01)
+ except websockets.exceptions.ConnectionClosed as e:
+ logger.error(f"WebSocket connection closed for {self.symbol}: {str(e)}")
+ except Exception as e:
+ logger.error(f"Error in WebSocket loop for {self.symbol}: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+ finally:
+ logger.info(f"Closing WebSocket connection for {self.symbol}")
+ await ws.close()
+
+ logger.info(f"Waiting 5 seconds before reconnecting {self.symbol} WebSocket...")
+ await asyncio.sleep(5)
+
+ def run(self, host='localhost', port=8050):
+ """Run the Dash app
+
+ Args:
+ host: Hostname to run on
+ port: Port to run on
+ """
+ logger.info(f"Starting Dash app on {host}:{port}")
+
+ # Ensure interval component is created
+ if not hasattr(self, 'app') or not self.app.layout:
+ logger.error("App layout not initialized properly")
+ return
+
+ # If interval-component is not in the layout, add it
+ if 'interval-component' not in str(self.app.layout):
+ logger.warning("Interval component not found in layout, adding it")
+ self.app.layout.children.append(
+ dcc.Interval(
+ id='interval-component',
+ interval=500, # 500ms for real-time updates
+ n_intervals=0
+ )
+ )
+
+        # Start the websocket connection in a background daemon thread.
+        # asyncio.run() creates and manages its own event loop inside that thread,
+        # so no extra loop needs to be created or installed here.
+        self.websocket_thread = threading.Thread(target=lambda: asyncio.run(self.start_websocket()))
+        self.websocket_thread.daemon = True
+        self.websocket_thread.start()
+
+ # Ensure historical data is loaded before starting
+ self._load_historical_data()
+
+ try:
+ self.app.run(host=host, port=port, debug=False)
+ except Exception as e:
+ logger.error(f"Error running Dash app: {str(e)}")
+ finally:
+ # Ensure resources are cleaned up
+ self._save_candles_to_disk(force=True)
+ logger.info("Dash app stopped")
+
+ def _load_historical_data(self):
+ """Load historical data for all timeframes from Binance API and local cache"""
+ try:
+ logger.info(f"Loading historical data for {self.symbol}...")
+
+ # Define intervals to fetch
+ intervals = {
+ '1s': 1,
+ '1m': 60,
+ '1h': 3600,
+ '1d': 86400
+ }
+
+ # Track load status
+ load_status = {interval: False for interval in intervals.keys()}
+
+ # First try to load from local cache files
+ logger.info("Step 1: Loading from local cache files...")
+ for interval_key, interval_seconds in intervals.items():
+ try:
+ cache_file = os.path.join(self.historical_data.cache_dir,
+ f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv")
+
+ logger.info(f"Checking for cached {interval_key} data at {cache_file}")
+ if os.path.exists(cache_file):
+                    # Check if cache is fresh (at most 1 day old, or 3 days for the 1d interval)
+ file_age = time.time() - os.path.getmtime(cache_file)
+ max_age = 259200 if interval_key == '1d' else 86400 # 3 days for 1d, 1 day for others
+ logger.info(f"Cache file age: {file_age:.1f}s, max allowed: {max_age}s")
+
+ if file_age <= max_age:
+ logger.info(f"Loading {interval_key} candles from cache")
+ cached_df = pd.read_csv(cache_file)
+ if not cached_df.empty:
+ # Diagnostic info about the loaded data
+ logger.info(f"Loaded {len(cached_df)} candles from {cache_file}")
+ logger.info(f"Columns: {cached_df.columns.tolist()}")
+ logger.info(f"First few rows: {cached_df.head(2).to_dict('records')}")
+
+ # Convert timestamp string back to datetime
+ if 'timestamp' in cached_df.columns:
+ try:
+ if not pd.api.types.is_datetime64_any_dtype(cached_df['timestamp']):
+ cached_df['timestamp'] = pd.to_datetime(cached_df['timestamp'])
+ logger.info("Successfully converted timestamps to datetime")
+ except Exception as e:
+ logger.warning(f"Could not convert timestamp column for {interval_key}: {str(e)}")
+
+ # Only keep the last 2000 candles for memory efficiency
+ if len(cached_df) > 2000:
+ cached_df = cached_df.tail(2000)
+ logger.info(f"Truncated to last 2000 candles")
+
+ # Add to cache
+ for _, row in cached_df.iterrows():
+ candle_dict = row.to_dict()
+ self.candle_cache.candles[interval_key].append(candle_dict)
+
+ # Update ohlcv_cache
+ self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000)
+ logger.info(f"Successfully loaded {len(self.ohlcv_cache[interval_key])} cached {interval_key} candles")
+
+                                if len(self.ohlcv_cache[interval_key]) >= 500:
+                                    load_status[interval_key] = True
+                                # Note: 1d data is still refreshed from the API below even when its cache is fresh
+ else:
+ logger.info(f"Cache file for {interval_key} is too old ({file_age:.1f}s)")
+ else:
+ logger.info(f"No cache file found for {interval_key}")
+ except Exception as e:
+ logger.error(f"Error loading cached {interval_key} candles: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+
+ # For timeframes other than 1s, fetch from API as backup or for fresh data
+ logger.info("Step 2: Fetching data from API for missing timeframes...")
+ for interval_key, interval_seconds in intervals.items():
+                # Skip 1s (not served by the API); refetch 1d even when cached so it stays fresh
+                if interval_key == '1s' or (load_status[interval_key] and interval_key != '1d'):
+                    logger.info(f"Skipping API fetch for {interval_key}: already loaded or 1s timeframe")
+                    continue
+
+ # Fetch historical data from API
+ try:
+ logger.info(f"Fetching {interval_key} candles from API for {self.symbol}")
+ historical_df = self.historical_data.get_historical_candles(
+ symbol=self.symbol,
+ interval_seconds=interval_seconds,
+ limit=500 # Get 500 candles
+ )
+
+ if not historical_df.empty:
+ logger.info(f"Loaded {len(historical_df)} historical candles for {self.symbol} {interval_key} from API")
+
+ # If we already have data in cache, merge with new data to avoid duplicates
+ if self.ohlcv_cache[interval_key] is not None and not self.ohlcv_cache[interval_key].empty:
+ existing_df = self.ohlcv_cache[interval_key]
+ # Get the latest timestamp from existing data
+ latest_time = existing_df['timestamp'].max()
+ # Only keep newer records from API
+ new_candles = historical_df[historical_df['timestamp'] > latest_time]
+ if not new_candles.empty:
+ logger.info(f"Adding {len(new_candles)} new candles to existing {interval_key} cache")
+ # Add to cache
+ for _, row in new_candles.iterrows():
+ candle_dict = row.to_dict()
+ self.candle_cache.candles[interval_key].append(candle_dict)
+ else:
+ # No existing data, add all from API
+ for _, row in historical_df.iterrows():
+ candle_dict = row.to_dict()
+ self.candle_cache.candles[interval_key].append(candle_dict)
+
+ # Update ohlcv_cache with combined data
+ self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000)
+ logger.info(f"Total {interval_key} candles in cache: {len(self.ohlcv_cache[interval_key])}")
+
+ if len(self.ohlcv_cache[interval_key]) >= 500:
+ load_status[interval_key] = True
+ else:
+ logger.warning(f"No historical data available from API for {self.symbol} {interval_key}")
+ except Exception as e:
+ logger.error(f"Error fetching {interval_key} data from API: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+
+ # Log summary of loaded data
+ logger.info("Historical data load summary:")
+ for interval_key in intervals.keys():
+ count = len(self.ohlcv_cache[interval_key]) if self.ohlcv_cache[interval_key] is not None else 0
+ status = "Success" if load_status[interval_key] else "Failed"
+ if count > 0 and count < 500:
+ status = "Partial"
+ logger.info(f"{interval_key}: {count} candles - {status}")
+
+ except Exception as e:
+ logger.error(f"Error in _load_historical_data: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
+
+ def _save_candles_to_disk(self, force=False):
+ """Save current candle cache to disk for persistence between runs"""
+ try:
+ # Only save if we have data and sufficient time has passed (every 5 minutes)
+ current_time = time.time()
+ if not force and current_time - self.last_cache_save_time < 300: # 5 minutes
+ return
+
+ # Save each timeframe's candles to disk
+ for interval_key, candles in self.candle_cache.candles.items():
+ if candles:
+ # Convert to DataFrame
+ df = pd.DataFrame(list(candles))
+ if not df.empty:
+ # Ensure timestamp is properly formatted
+ if 'timestamp' in df.columns:
+ try:
+ if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
+ df['timestamp'] = pd.to_datetime(df['timestamp'])
+                            except Exception as e:
+                                logger.warning(f"Could not convert timestamp column for {interval_key}: {str(e)}")
+
+ # Save to disk in the cache directory
+ cache_file = os.path.join(self.historical_data.cache_dir,
+ f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv")
+ df.to_csv(cache_file, index=False)
+ logger.info(f"Saved {len(df)} {interval_key} candles to {cache_file}")
+
+ self.last_cache_save_time = current_time
+ logger.info(f"Saved all candle caches to disk at {datetime.now()}")
+ except Exception as e:
+ logger.error(f"Error saving candles to disk: {str(e)}")
+ import traceback
+ logger.error(traceback.format_exc())
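+
+    # A more crash-safe variant (sketch, assuming the same cache layout as above):
+    # write to a temporary file and atomically swap it in, so a crash mid-write
+    # cannot leave a truncated cache file behind:
+    #   tmp_file = cache_file + '.tmp'
+    #   df.to_csv(tmp_file, index=False)
+    #   os.replace(tmp_file, cache_file)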
+
+ def add_nn_signal(self, signal_type, timestamp, probability=None):
+ """Add a neural network signal to be displayed on the chart
+
+ Args:
+ signal_type: The type of signal (BUY, SELL, HOLD)
+ timestamp: The timestamp for the signal
+ probability: Optional probability/confidence value
+ """
+ if signal_type not in ['BUY', 'SELL', 'HOLD']:
+ logger.warning(f"Invalid NN signal type: {signal_type}")
+ return
+
+ # Convert timestamp to datetime if it's not already
+ if not isinstance(timestamp, datetime):
+ try:
+ if isinstance(timestamp, str):
+ timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
+ elif isinstance(timestamp, (int, float)):
+ timestamp = datetime.fromtimestamp(timestamp / 1000.0)
+ except Exception as e:
+ logger.error(f"Error converting timestamp for NN signal: {str(e)}")
+ timestamp = datetime.now()
+
+ # Add the signal to our list
+ self.nn_signals.append({
+ 'type': signal_type,
+ 'timestamp': timestamp,
+ 'probability': probability,
+ 'added': datetime.now()
+ })
+
+ # Only keep the most recent 50 signals
+ if len(self.nn_signals) > 50:
+ self.nn_signals = self.nn_signals[-50:]
+
+ logger.info(f"Added NN signal: {signal_type} at {timestamp}")
+
+ def add_trade(self, price, timestamp, pnl=None, amount=0.1, action=None, type=None):
+ """Add a trade to be displayed on the chart
+
+ Args:
+ price: The price at which the trade was executed
+ timestamp: The timestamp for the trade
+ pnl: Optional profit and loss value for the trade
+ amount: Amount traded
+ action: The type of trade (BUY or SELL) - alternative to type parameter
+ type: The type of trade (BUY or SELL) - alternative to action parameter
+ """
+ # Handle both action and type parameters for backward compatibility
+ trade_type = type or action
+
+ # Default to BUY if trade_type is None or not specified
+ if trade_type is None:
+ logger.warning(f"Trade type not specified in add_trade call, defaulting to BUY. Price: {price}, Timestamp: {timestamp}")
+ trade_type = "BUY"
+
+ if isinstance(trade_type, int):
+ trade_type = "BUY" if trade_type == 0 else "SELL"
+
+ # Ensure trade_type is uppercase if it's a string
+ if isinstance(trade_type, str):
+ trade_type = trade_type.upper()
+
+        if trade_type not in ['BUY', 'SELL']:
+            # The 'type' parameter shadows the builtin type(), so use __class__ for the diagnostic
+            logger.warning(f"Invalid trade type: {trade_type} (value type: {trade_type.__class__.__name__}), defaulting to BUY. Price: {price}, Timestamp: {timestamp}")
+            trade_type = "BUY"
+
+ # Convert timestamp to datetime if it's not already
+ if not isinstance(timestamp, datetime):
+ try:
+ if isinstance(timestamp, str):
+ timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
+ elif isinstance(timestamp, (int, float)):
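+                    # Numeric timestamps are assumed to be milliseconds since the epoch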
+ timestamp = datetime.fromtimestamp(timestamp / 1000.0)
+ except Exception as e:
+ logger.error(f"Error converting timestamp for trade: {str(e)}")
+ timestamp = datetime.now()
+
+ # Create the trade object
+ trade = {
+ 'price': price,
+ 'timestamp': timestamp,
+ 'pnl': pnl,
+ 'amount': amount,
+ 'action': trade_type
+ }
+
+ # Add to our trades list
+ if not hasattr(self, 'trades'):
+ self.trades = []
+
+ # If this is a SELL trade, try to find the corresponding BUY trade and update it with close_price
+ if trade_type == 'SELL' and len(self.trades) > 0:
+ for i in range(len(self.trades) - 1, -1, -1):
+ prev_trade = self.trades[i]
+ if prev_trade.get('action') == 'BUY' and 'close_price' not in prev_trade:
+ # Found a BUY trade without a close_price, consider it the matching trade
+ prev_trade['close_price'] = price
+ prev_trade['close_timestamp'] = timestamp
+ logger.info(f"Updated BUY trade at {prev_trade['timestamp']} with close price {price}")
+ break
+
+ self.trades.append(trade)
+
+ # Log the trade for debugging
+ pnl_str = f" with PnL: {pnl}" if pnl is not None else ""
+ logger.info(f"Added trade: {trade_type} {amount} at price {price} at time {timestamp}{pnl_str}")
+
+ # Trigger a more frequent update of the chart by scheduling a callback
+ # This helps ensure the trade appears immediately on the chart
+ if hasattr(self, 'app') and self.app is not None:
+            try:
+                # Best-effort workaround so the trade appears immediately: find the
+                # live-chart callback in Dash's internal callback_map and invoke it.
+                # This relies on Dash internals and may silently do nothing.
+                for callback_id, callback_info in self.app.callback_map.items():
+                    if 'live-chart' in callback_id:
+                        logger.debug("Triggering chart update callback after trade")
+                        callback_info['callback']()
+                        break
+            except Exception as e:
+                # If callback triggering fails, it's not critical
+                logger.debug(f"Failed to trigger chart update: {str(e)}")
+
+ return trade
+
+ def update_trading_info(self, signal=None, position=None, balance=None, pnl=None):
+ """Update the current trading information to be displayed on the chart
+
+ Args:
+ signal: Current signal (BUY, SELL, HOLD)
+ position: Current position size
+ balance: Current session balance
+ pnl: Current session PnL
+ """
+ if signal is not None:
+ if signal in ['BUY', 'SELL', 'HOLD']:
+ self.current_signal = signal
+ self.signal_time = datetime.now()
+ else:
+ logger.warning(f"Invalid signal type: {signal}")
+
+ if position is not None:
+ self.current_position = position
+
+ if balance is not None:
+ self.session_balance = balance
+
+ if pnl is not None:
+ self.session_pnl = pnl
+
+ logger.debug(f"Updated trading info: Signal={self.current_signal}, Position={self.current_position}, Balance=${self.session_balance:.2f}, PnL={self.session_pnl:.4f}")
+
+async def main():
+ global charts # Make charts globally accessible for NN integration
+ symbols = ["ETH/USDT", "BTC/USDT"]
+ logger.info(f"Starting application for symbols: {symbols}")
+
+ # Initialize neural network if enabled
+ if NN_ENABLED:
+ logger.info("Initializing Neural Network integration...")
+ if setup_neural_network():
+ logger.info("Neural Network integration initialized successfully")
+ else:
+ logger.warning("Neural Network integration failed to initialize")
+
+    charts = []
+
+    # Create a chart for each symbol. Each chart's run() method (started in its own
+    # thread below) launches the websocket collection thread itself, so creating a
+    # second start_websocket() task here would open duplicate exchange connections.
+    for symbol in symbols:
+        chart = RealTimeChart(symbol)
+        charts.append(chart)
+
+ # Run Dash in a separate thread to not block the event loop
+ server_threads = []
+ for i, chart in enumerate(charts):
+ port = 8050 + i # Use different ports for each chart
+ logger.info(f"Starting chart for {chart.symbol} on port {port}")
+ thread = Thread(target=lambda c=chart, p=port: c.run(port=p)) # Ensure correct port is passed
+ thread.daemon = True
+ thread.start()
+ server_threads.append(thread)
+ logger.info(f"Thread started for {chart.symbol} on port {port}")
+
+ try:
+ # Keep the main task running
+ while True:
+ await asyncio.sleep(1)
+ except KeyboardInterrupt:
+ logger.info("Shutting down...")
+ except Exception as e:
+ logger.error(f"Unexpected error: {str(e)}")
+    finally:
+        # Websocket and Dash threads are daemons, so they exit with the process
+        logger.info("Application shutdown complete")
+
+if __name__ == "__main__":
+ try:
+ asyncio.run(main())
+ except KeyboardInterrupt:
+ logger.info("Application terminated by user")
+
diff --git a/train_improved_rl.py b/train_improved_rl.py
new file mode 100644
index 0000000..56792da
--- /dev/null
+++ b/train_improved_rl.py
@@ -0,0 +1,547 @@
+#!/usr/bin/env python
+"""
+Improved RL Trading with Enhanced Training and Monitoring
+
+This script provides an improved version of the RL training process,
+implementing better normalization, reward structure, and model training.
+"""
+
+import os
+import sys
+import logging
+import argparse
+import time
+from datetime import datetime
+import numpy as np
+import torch
+import pandas as pd
+import matplotlib.pyplot as plt
+from pathlib import Path
+
+# Add project directory to path if needed
+project_root = os.path.dirname(os.path.abspath(__file__))
+if project_root not in sys.path:
+ sys.path.append(project_root)
+
+# Import our custom modules
+from NN.models.dqn_agent import DQNAgent
+from NN.utils.trading_env import TradingEnvironment
+from NN.utils.data_interface import DataInterface
+from realtime import BinanceHistoricalData, RealTimeChart
+
+# Configure logging
+log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
+logging.basicConfig(
+ level=logging.INFO,
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+ handlers=[
+ logging.FileHandler(log_filename),
+ logging.StreamHandler()
+ ]
+)
+logger = logging.getLogger('improved_rl')
+
+# Parse command line arguments
+parser = argparse.ArgumentParser(description='Improved RL Trading with Enhanced Training')
+parser.add_argument('--episodes', type=int, default=20, help='Number of episodes to train')
+parser.add_argument('--visualize', action='store_true', help='Visualize trades during training')
+parser.add_argument('--save-path', type=str, default='NN/models/saved/improved_dqn_agent', help='Path to save trained model')
+parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
+args = parser.parse_args()
+
+def create_training_environment(symbol, window_size=20):
+ """Create and prepare the training environment with data"""
+ logger.info(f"Setting up training environment for {symbol}")
+
+ # Fetch historical data from multiple timeframes
+ data_interface = DataInterface(symbol)
+
+ # Use Binance data provider for fetching data
+ historical_data = BinanceHistoricalData()
+
+ # Fetch data for each timeframe
+ df_1m = historical_data.get_historical_candles(symbol, interval_seconds=60, limit=1000)
+ df_5m = historical_data.get_historical_candles(symbol, interval_seconds=300, limit=1000)
+ df_15m = historical_data.get_historical_candles(symbol, interval_seconds=900, limit=500)
+
+ # Ensure all dataframes have index as timestamp type
+ if df_1m is not None and not df_1m.empty:
+ if 'timestamp' in df_1m.columns:
+ df_1m = df_1m.set_index('timestamp')
+
+ if df_5m is not None and not df_5m.empty:
+ if 'timestamp' in df_5m.columns:
+ df_5m = df_5m.set_index('timestamp')
+
+ if df_15m is not None and not df_15m.empty:
+ if 'timestamp' in df_15m.columns:
+ df_15m = df_15m.set_index('timestamp')
+
+ # Preprocess data (add technical indicators)
+ df_1m = preprocess_dataframe(df_1m)
+ df_5m = preprocess_dataframe(df_5m)
+ df_15m = preprocess_dataframe(df_15m)
+
+ # Create environment with all timeframes
+ env = create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size)
+
+ return env, (df_1m, df_5m, df_15m)
+
+def preprocess_dataframe(df):
+ """Add technical indicators and preprocess dataframe"""
+ if df is None or df.empty:
+ return None
+
+ # Drop any missing values
+ df = df.dropna()
+
+ # Ensure it has OHLCV columns
+ required_columns = ['open', 'high', 'low', 'close', 'volume']
+ missing_columns = [col for col in required_columns if col not in df.columns]
+
+ if missing_columns:
+ logger.warning(f"Missing required columns: {missing_columns}")
+ for col in missing_columns:
+ # Fill with close price for OHLC if missing
+ if col in ['open', 'high', 'low'] and 'close' in df.columns:
+ df[col] = df['close']
+ # Fill with zeros for volume if missing
+ elif col == 'volume':
+ df[col] = 0
+
+ # Add simple technical indicators
+ # 1. Simple Moving Averages
+ df['sma_5'] = df['close'].rolling(window=5).mean()
+ df['sma_10'] = df['close'].rolling(window=10).mean()
+
+ # 2. Relative Strength Index (RSI)
+ delta = df['close'].diff()
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
+ rs = gain / loss
+ df['rsi'] = 100 - (100 / (1 + rs))
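+    # Edge cases: a window with no losses makes rs infinite and RSI evaluates to 100;
+    # a completely flat window yields NaN, which the fillna(0) below replaces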
+
+ # 3. Bollinger Bands
+ df['bb_middle'] = df['close'].rolling(window=20).mean()
+ df['bb_std'] = df['close'].rolling(window=20).std()
+ df['bb_upper'] = df['bb_middle'] + 2 * df['bb_std']
+ df['bb_lower'] = df['bb_middle'] - 2 * df['bb_std']
+
+ # 4. MACD
+ df['ema_12'] = df['close'].ewm(span=12, adjust=False).mean()
+ df['ema_26'] = df['close'].ewm(span=26, adjust=False).mean()
+ df['macd'] = df['ema_12'] - df['ema_26']
+ df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean()
+
+ # 5. Price rate of change
+ df['roc'] = df['close'].pct_change(periods=10) * 100
+
+ # Fill any remaining NaN values with 0
+ df = df.fillna(0)
+
+ return df
+
+def create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size=20):
+ """Create a custom environment that handles multiple timeframes"""
+
+ # Ensure we have complete data for all timeframes
+ min_required_length = window_size + 100 # Add buffer for training
+
+ if (df_1m is None or len(df_1m) < min_required_length or
+ df_5m is None or len(df_5m) < min_required_length or
+ df_15m is None or len(df_15m) < min_required_length):
+ raise ValueError(f"Insufficient data for training. Need at least {min_required_length} candles per timeframe.")
+
+ # Ensure we only use the last N valid data points
+ df_1m = df_1m.iloc[-900:].copy() if len(df_1m) > 900 else df_1m.copy()
+ df_5m = df_5m.iloc[-180:].copy() if len(df_5m) > 180 else df_5m.copy()
+ df_15m = df_15m.iloc[-60:].copy() if len(df_15m) > 60 else df_15m.copy()
+
+ # Reset index to make sure we have continuous integers
+ df_1m = df_1m.reset_index(drop=True)
+ df_5m = df_5m.reset_index(drop=True)
+ df_15m = df_15m.reset_index(drop=True)
+
+ # For simplicity, we'll use the 1m data as the base environment
+ # The other timeframes will be incorporated through observation
+
+ env = TradingEnvironment(
+ data=df_1m,
+ initial_balance=100.0,
+ fee_rate=0.0005, # 0.05% fee (typical for crypto exchanges)
+ max_steps=len(df_1m) - window_size - 50, # Leave some room at the end
+ window_size=window_size,
+ risk_aversion=0.2, # Moderately risk-averse
+ price_scaling='zscore', # Use z-score normalization
+ reward_scaling=10.0, # Scale rewards for better learning
+ episode_penalty=0.2 # Penalty for holding positions at end of episode
+ )
+
+ return env
+
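+# A minimal sketch (an assumption, not part of TradingEnvironment's API) of how the
+# 5m/15m features could be folded into the 1m observation: forward-fill each
+# higher-timeframe frame onto the 1m index, then concatenate the feature columns.
+# Assumes the frames still carry their datetime index (i.e. before reset_index above).
+def merge_timeframes_sketch(df_1m, df_5m, df_15m):
+    """Hypothetical helper: align 5m/15m features onto the 1m index via forward-fill."""
+    merged = df_1m.copy()
+    for df, suffix in ((df_5m, '_5m'), (df_15m, '_15m')):
+        aligned = df.add_suffix(suffix).reindex(merged.index, method='ffill')
+        merged = merged.join(aligned)
+    return merged.fillna(0)
+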
+def initialize_agent(env, window_size=20, num_features=0, timeframes=None):
+ """Initialize the DQN agent with appropriate parameters"""
+ if timeframes is None:
+ timeframes = ['1m', '5m', '15m']
+
+ # Calculate input dimensions
+ state_dim = env.observation_space.shape[0]
+ action_dim = env.action_space.n
+
+ # If num_features wasn't provided, infer from environment
+ if num_features == 0:
+ # Calculate features per timeframe from state dimension and number of timeframes
+ # Accounting for the 3 additional features (position, equity, unrealized_pnl)
+ num_features = (state_dim - 3) // len(timeframes)
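+        # e.g. a hypothetical state_dim of 81 with 3 timeframes gives (81 - 3) // 3 = 26 features per timeframe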
+
+ logger.info(f"Initializing DQN agent: state_dim={state_dim}, action_dim={action_dim}, features={num_features}")
+
+ agent = DQNAgent(
+ state_size=state_dim,
+ action_size=action_dim,
+ window_size=window_size,
+ num_features=num_features,
+ timeframes=timeframes,
+ learning_rate=0.0005, # Start with a moderate learning rate
+ gamma=0.97, # Slightly reduced discount factor for stable learning
+ epsilon=1.0, # Start with full exploration
+ epsilon_min=0.05, # Maintain some exploration even at the end
+ epsilon_decay=0.9975, # Slower decay for more exploration
+ memory_size=20000, # Larger replay buffer
+ batch_size=128, # Larger batch size for more stable gradients
+ target_update=5 # More frequent target network updates
+ )
+
+ return agent
+
+def train_agent(env, agent, num_episodes=20, visualize=False, chart=None, save_path=None, save_freq=5):
+ """
+ Train the DQN agent with improved training loop
+
+ Args:
+ env: The trading environment
+ agent: The DQN agent
+ num_episodes: Number of episodes to train
+ visualize: Whether to visualize trades during training
+ chart: The visualization chart (if visualize=True)
+ save_path: Path to save the model
+ save_freq: How often to save checkpoints (in episodes)
+
+ Returns:
+        tuple: (rewards, win_rates, best_reward, best_model_path)
+ """
+ logger.info(f"Starting training for {num_episodes} episodes")
+
+ # Initialize metrics tracking
+ rewards = []
+ win_rates = []
+ total_train_time = 0
+ best_reward = float('-inf')
+ best_model_path = None
+
+ # Create directory for checkpoints if needed
+ if save_path:
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ checkpoint_dir = os.path.join(os.path.dirname(save_path), 'checkpoints')
+ os.makedirs(checkpoint_dir, exist_ok=True)
+
+ # For tracking improvement
+ last_improved_episode = 0
+ patience = 10 # Episodes to wait for improvement before early stopping
+
+ for episode in range(num_episodes):
+ start_time = time.time()
+
+ # Reset environment and get initial state
+ state = env.reset()
+ done = False
+ episode_reward = 0
+ step = 0
+
+ # Action metrics for this episode
+        actions_taken = {0: 0, 1: 0, 2: 0}  # Counts for BUY (0), SELL (1), HOLD (2)
+
+ while not done:
+ # Select action
+ action = agent.act(state)
+
+ # Execute action
+ next_state, reward, done, info = env.step(action)
+
+ # Store experience in replay buffer
+ is_extrema = False # In a real implementation, detect extrema points
+ agent.remember(state, action, reward, next_state, done, is_extrema)
+
+ # Learn from experience
+ if len(agent.memory) >= agent.batch_size:
+                use_prioritized = episode > 1  # Switch to prioritized replay once the early episodes have filled the buffer
+ loss = agent.replay(use_prioritized=use_prioritized)
+
+ # Update state and metrics
+ state = next_state
+ episode_reward += reward
+ actions_taken[action] += 1
+
+            # Log progress every 100 steps, plus the first few steps of each episode
+ if step % 100 == 0 or step < 10:
+ action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
+ current_price = info.get('current_price', 0)
+ pnl = info.get('pnl', 0)
+ balance = info.get('balance', 0)
+
+ logger.info(f"Episode {episode}, Step {step}: Action={action_str}, "
+ f"Reward={reward:.4f}, Balance=${balance:.2f}, PnL={pnl:.4f}")
+
+ # Add trade to visualization if enabled
+ if visualize and chart and action in [0, 1]: # BUY or SELL
+ chart.add_trade(
+ price=current_price,
+ timestamp=datetime.now(),
+ amount=0.1,
+ pnl=pnl,
+ action=action_str
+ )
+
+ step += 1
+
+ # Episode finished - calculate metrics
+ episode_time = time.time() - start_time
+ total_train_time += episode_time
+
+ # Get environment info
+ win_rate = env.winning_trades / max(1, env.total_trades)
+ trades = env.total_trades
+ balance = env.balance
+ gain = (balance - env.initial_balance) / env.initial_balance
+ max_drawdown = env.max_drawdown
+
+ # Record metrics
+ rewards.append(episode_reward)
+ win_rates.append(win_rate)
+
+ # Update agent's learning metrics
+ improved = agent.update_learning_metrics(episode_reward)
+
+ # If this is best performance, save the model
+ if episode_reward > best_reward:
+ best_reward = episode_reward
+ if save_path:
+ best_model_path = f"{save_path}_best"
+ agent.save(best_model_path)
+ logger.info(f"New best model saved to {best_model_path} (reward: {best_reward:.2f})")
+ last_improved_episode = episode
+
+ # Regular checkpoint saving
+ if save_path and episode % save_freq == 0:
+ checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_episode_{episode}")
+ agent.save(checkpoint_path)
+
+ # Print episode summary
+ actions_summary = ", ".join([f"{k}:{v}" for k, v in actions_taken.items()])
+ logger.info(f"Episode {episode} completed in {episode_time:.2f}s")
+ logger.info(f" Total reward: {episode_reward:.4f}")
+ logger.info(f" Actions taken: {actions_summary}")
+ logger.info(f" Trades: {trades}, Win rate: {win_rate:.2%}")
+ logger.info(f" Balance: ${balance:.2f}, Gain: {gain:.2%}")
+ logger.info(f" Max Drawdown: {max_drawdown:.2%}")
+
+ # Early stopping check
+ if episode - last_improved_episode >= patience:
+ logger.info(f"No improvement for {patience} episodes. Early stopping.")
+ break
+
+ # Training complete
+ avg_time_per_episode = total_train_time / max(1, len(rewards))
+ logger.info(f"Training completed in {total_train_time:.2f}s ({avg_time_per_episode:.2f}s per episode)")
+
+ # Save final model
+ if save_path:
+ agent.save(f"{save_path}_final")
+ logger.info(f"Final model saved to {save_path}_final")
+
+ # Return training metrics
+ return rewards, win_rates, best_reward, best_model_path
+
+def plot_training_results(rewards, win_rates, save_dir=None):
+ """Plot training metrics and save the figure"""
+ plt.figure(figsize=(12, 8))
+
+ # Plot rewards
+ plt.subplot(2, 1, 1)
+ plt.plot(rewards, 'b-')
+ plt.title('Training Rewards per Episode')
+ plt.xlabel('Episode')
+ plt.ylabel('Total Reward')
+ plt.grid(True)
+
+ # Plot win rates
+ plt.subplot(2, 1, 2)
+ plt.plot(win_rates, 'g-')
+ plt.title('Win Rate per Episode')
+ plt.xlabel('Episode')
+ plt.ylabel('Win Rate')
+ plt.grid(True)
+
+ plt.tight_layout()
+
+ # Save figure if directory provided
+ if save_dir:
+ os.makedirs(save_dir, exist_ok=True)
+ plt.savefig(os.path.join(save_dir, f'training_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png'))
+
+ plt.close()
+
+def evaluate_agent(env, agent, num_episodes=5, visualize=False, chart=None):
+ """
+ Evaluate a trained agent on the environment
+
+ Args:
+ env: The trading environment
+ agent: The trained DQN agent
+ num_episodes: Number of evaluation episodes
+ visualize: Whether to visualize trades
+ chart: The visualization chart (if visualize=True)
+
+ Returns:
+ dict: Evaluation metrics
+ """
+ logger.info(f"Evaluating agent over {num_episodes} episodes")
+
+ # Metrics to track
+ total_rewards = []
+ total_trades = []
+ win_rates = []
+ sharpe_ratios = []
+ sortino_ratios = []
+ max_drawdowns = []
+ final_balances = []
+
+ for episode in range(num_episodes):
+ # Reset environment
+ state = env.reset()
+ done = False
+ episode_reward = 0
+
+ # Run episode without exploration
+ while not done:
+ action = agent.act(state, explore=False) # No exploration during evaluation
+ next_state, reward, done, info = env.step(action)
+
+ episode_reward += reward
+ state = next_state
+
+ # Add trade to visualization if enabled
+ if visualize and chart and action in [0, 1]: # BUY or SELL
+ action_str = "BUY" if action == 0 else "SELL"
+ current_price = info.get('current_price', 0)
+ pnl = info.get('pnl', 0)
+
+ chart.add_trade(
+ price=current_price,
+ timestamp=datetime.now(),
+ amount=0.1,
+ pnl=pnl,
+ action=action_str
+ )
+
+ # Record metrics
+ total_rewards.append(episode_reward)
+ total_trades.append(env.total_trades)
+ win_rates.append(env.winning_trades / max(1, env.total_trades))
+ sharpe_ratios.append(info.get('sharpe_ratio', 0))
+ sortino_ratios.append(info.get('sortino_ratio', 0))
+ max_drawdowns.append(env.max_drawdown)
+ final_balances.append(env.balance)
+
+ logger.info(f"Evaluation episode {episode} - Reward: {episode_reward:.4f}, "
+ f"Balance: ${env.balance:.2f}, Win rate: {win_rates[-1]:.2%}")
+
+ # Calculate average metrics
+ avg_reward = np.mean(total_rewards)
+ avg_trades = np.mean(total_trades)
+ avg_win_rate = np.mean(win_rates)
+ avg_sharpe = np.mean(sharpe_ratios)
+ avg_sortino = np.mean(sortino_ratios)
+ avg_max_drawdown = np.mean(max_drawdowns)
+ avg_final_balance = np.mean(final_balances)
+
+ # Log evaluation summary
+ logger.info("Evaluation completed:")
+ logger.info(f" Average reward: {avg_reward:.4f}")
+ logger.info(f" Average trades per episode: {avg_trades:.2f}")
+ logger.info(f" Average win rate: {avg_win_rate:.2%}")
+ logger.info(f" Average Sharpe ratio: {avg_sharpe:.4f}")
+ logger.info(f" Average Sortino ratio: {avg_sortino:.4f}")
+ logger.info(f" Average max drawdown: {avg_max_drawdown:.2%}")
+ logger.info(f" Average final balance: ${avg_final_balance:.2f}")
+
+ # Return evaluation metrics
+ return {
+ 'avg_reward': avg_reward,
+ 'avg_trades': avg_trades,
+ 'avg_win_rate': avg_win_rate,
+ 'avg_sharpe': avg_sharpe,
+ 'avg_sortino': avg_sortino,
+ 'avg_max_drawdown': avg_max_drawdown,
+ 'avg_final_balance': avg_final_balance
+ }
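+
+# A minimal sketch (hypothetical helper, not part of TradingEnvironment) for computing
+# an annualized Sharpe ratio directly from per-step returns, in case the environment's
+# info dict does not expose 'sharpe_ratio'. periods_per_year=525600 assumes 1m steps.
+def sharpe_from_returns_sketch(returns, periods_per_year=525600):
+    returns = np.asarray(returns, dtype=float)
+    if returns.size == 0 or returns.std() == 0:
+        return 0.0
+    return float(returns.mean() / returns.std() * np.sqrt(periods_per_year))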
+
+def main():
+ """Main function to run the improved RL training"""
+ start_time = time.time()
+ logger.info(f"Starting improved RL training for {args.symbol}")
+
+ # Create environment
+ env, data_frames = create_training_environment(args.symbol)
+
+ # Initialize visualization if enabled
+ chart = None
+ if args.visualize:
+ logger.info("Initializing visualization chart")
+ chart = RealTimeChart(args.symbol)
+ time.sleep(2) # Give time for chart to initialize
+
+ # Initialize agent
+ agent = initialize_agent(env)
+
+ # Train agent
+ rewards, win_rates, best_reward, best_model_path = train_agent(
+ env=env,
+ agent=agent,
+ num_episodes=args.episodes,
+ visualize=args.visualize,
+ chart=chart,
+ save_path=args.save_path
+ )
+
+ # Plot training results
+ plot_dir = os.path.join(os.path.dirname(args.save_path), 'plots')
+ plot_training_results(rewards, win_rates, save_dir=plot_dir)
+
+    # Evaluate the best model; fall back to the final agent if no best checkpoint was saved,
+    # since best_agent would otherwise be undefined below
+    logger.info("Evaluating best model")
+    if best_model_path:
+        best_agent = initialize_agent(env)
+        best_agent.load(best_model_path)
+    else:
+        logger.warning("No best model checkpoint found; evaluating the final agent instead")
+        best_agent = agent
+
+    eval_metrics = evaluate_agent(
+        env=env,
+        agent=best_agent,
+        visualize=args.visualize,
+        chart=chart
+    )
+
+ # Log evaluation results
+ logger.info("Best model evaluation complete:")
+ for metric, value in eval_metrics.items():
+ logger.info(f" {metric}: {value}")
+
+ # Total run time
+ total_time = time.time() - start_time
+ logger.info(f"Total run time: {total_time:.2f} seconds")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file