diff --git a/NN/models/dqn_agent.py b/NN/models/dqn_agent.py
index a077359..6477d0f 100644
--- a/NN/models/dqn_agent.py
+++ b/NN/models/dqn_agent.py
@@ -28,14 +28,14 @@ class DQNAgent:
                  window_size: int,
                  num_features: int,
                  timeframes: List[str],
-                 learning_rate: float = 0.001,
-                 gamma: float = 0.99,
+                 learning_rate: float = 0.0005,  # Reduced learning rate for more stability
+                 gamma: float = 0.97,  # Slightly reduced discount factor
                  epsilon: float = 1.0,
-                 epsilon_min: float = 0.01,
-                 epsilon_decay: float = 0.995,
-                 memory_size: int = 10000,
-                 batch_size: int = 64,
-                 target_update: int = 10):
+                 epsilon_min: float = 0.05,  # Increased minimum epsilon for more exploration
+                 epsilon_decay: float = 0.9975,  # Slower decay rate
+                 memory_size: int = 20000,  # Increased memory size
+                 batch_size: int = 128,  # Larger batch size
+                 target_update: int = 5):  # More frequent target updates
 
         self.state_size = state_size
         self.action_size = action_size
@@ -70,23 +70,25 @@ class DQNAgent:
         ).to(self.device)
         self.target_net.load_state_dict(self.policy_net.state_dict())
 
-        # Initialize optimizer
-        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate)
+        # Initialize optimizer with weight decay (gradient clipping is applied in replay())
+        self.optimizer = optim.Adam(self.policy_net.parameters(), lr=learning_rate, weight_decay=1e-5)
 
-        # Initialize memory
+        # Initialize memories with different priorities
        self.memory = deque(maxlen=memory_size)
-
-        # Special memory for extrema samples to use for targeted learning
-        self.extrema_memory = deque(maxlen=memory_size // 5)  # Smaller size for extrema examples
+        self.extrema_memory = deque(maxlen=memory_size // 4)  # For extrema points
+        self.positive_memory = deque(maxlen=memory_size // 4)  # For positive rewards
 
         # Training metrics
         self.update_count = 0
         self.losses = []
+        self.avg_reward = 0
+        self.no_improvement_count = 0
+        self.best_reward = float('-inf')
 
     def remember(self, state: np.ndarray, action: int, reward: float,
                  next_state: np.ndarray, done: bool, is_extrema: bool = False):
         """
-        Store experience in memory
+        Store experience in memory with prioritization
 
         Args:
             state: Current state
@@ -97,28 +99,124 @@ class DQNAgent:
             is_extrema: Whether this is a local extrema sample (for specialized learning)
         """
         experience = (state, action, reward, next_state, done)
+
+        # Always add to main memory
         self.memory.append(experience)
 
-        # If this is an extrema sample, also add to specialized memory
+        # Add to specialized memories if applicable
         if is_extrema:
             self.extrema_memory.append(experience)
+
+        # Store positive experiences separately for prioritized replay
+        if reward > 0:
+            self.positive_memory.append(experience)
 
-    def act(self, state: np.ndarray) -> int:
-        """Choose action using epsilon-greedy policy"""
-        if random.random() < self.epsilon:
+    def act(self, state: np.ndarray, explore=True) -> int:
+        """Choose action using epsilon-greedy policy with an explore flag"""
+        if explore and random.random() < self.epsilon:
             return random.randrange(self.action_size)
 
         with torch.no_grad():
-            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
-            action_probs, extrema_pred = self.policy_net(state)
+            # Ensure state is normalized before inference
+            state_tensor = self._normalize_state(state)
+            state_tensor = torch.FloatTensor(state_tensor).unsqueeze(0).to(self.device)
+            action_probs, extrema_pred = self.policy_net(state_tensor)
             return action_probs.argmax().item()
 
-    def replay(self, use_extrema=False) -> float:
+    def _normalize_state(self, state: np.ndarray) -> np.ndarray:
+        """Normalize the state data to prevent 
numerical issues""" + # Handle NaN and infinite values + state = np.nan_to_num(state, nan=0.0, posinf=1.0, neginf=-1.0) + + # Check if state is 1D array (happens in some environments) + if len(state.shape) == 1: + # If 1D, we need to normalize the whole array + normalized_state = state.copy() + + # Convert any timestamp or non-numeric data to float + for i in range(len(normalized_state)): + # Check for timestamp-like objects + if hasattr(normalized_state[i], 'timestamp') and callable(getattr(normalized_state[i], 'timestamp')): + # Convert timestamp to float (seconds since epoch) + normalized_state[i] = float(normalized_state[i].timestamp()) + elif not isinstance(normalized_state[i], (int, float, np.number)): + # Set non-numeric data to 0 + normalized_state[i] = 0.0 + + # Ensure all values are float + normalized_state = normalized_state.astype(np.float32) + + # Simple min-max normalization for 1D state + state_min = np.min(normalized_state) + state_max = np.max(normalized_state) + if state_max > state_min: + normalized_state = (normalized_state - state_min) / (state_max - state_min) + return normalized_state + + # Handle 2D arrays + normalized_state = np.zeros_like(state, dtype=np.float32) + + # Convert any timestamp or non-numeric data to float + for i in range(state.shape[0]): + for j in range(state.shape[1]): + if hasattr(state[i, j], 'timestamp') and callable(getattr(state[i, j], 'timestamp')): + # Convert timestamp to float (seconds since epoch) + normalized_state[i, j] = float(state[i, j].timestamp()) + elif isinstance(state[i, j], (int, float, np.number)): + normalized_state[i, j] = state[i, j] + else: + # Set non-numeric data to 0 + normalized_state[i, j] = 0.0 + + # Loop through each timeframe's features in the combined state + feature_count = state.shape[1] // len(self.timeframes) + + for tf_idx in range(len(self.timeframes)): + start_idx = tf_idx * feature_count + end_idx = start_idx + feature_count + + # Extract this timeframe's features + tf_features = normalized_state[:, start_idx:end_idx] + + # Normalize OHLCV data by the first close price in the window + # This makes price movements relative rather than absolute + price_idx = 3 # Assuming close price is at index 3 + if price_idx < tf_features.shape[1]: + reference_price = np.mean(tf_features[:, price_idx]) + if reference_price != 0: + # Normalize price-related columns (OHLC) + for i in range(4): # First 4 columns are OHLC + if i < tf_features.shape[1]: + normalized_state[:, start_idx + i] = tf_features[:, i] / reference_price + + # Normalize volume using mean and std + vol_idx = 4 # Assuming volume is at index 4 + if vol_idx < tf_features.shape[1]: + vol_mean = np.mean(tf_features[:, vol_idx]) + vol_std = np.std(tf_features[:, vol_idx]) + if vol_std > 0: + normalized_state[:, start_idx + vol_idx] = (tf_features[:, vol_idx] - vol_mean) / vol_std + else: + normalized_state[:, start_idx + vol_idx] = 0 + + # Other features (technical indicators) - normalize with min-max scaling + for i in range(5, feature_count): + if i < tf_features.shape[1]: + feature_min = np.min(tf_features[:, i]) + feature_max = np.max(tf_features[:, i]) + if feature_max > feature_min: + normalized_state[:, start_idx + i] = (tf_features[:, i] - feature_min) / (feature_max - feature_min) + else: + normalized_state[:, start_idx + i] = 0 + + return normalized_state + + def replay(self, use_prioritized=True) -> float: """ - Train on a batch of experiences + Train on a batch of experiences with prioritized sampling Args: - use_extrema: Whether to include extrema 
samples in training + use_prioritized: Whether to use prioritized replay Returns: float: Loss value @@ -126,55 +224,67 @@ class DQNAgent: if len(self.memory) < self.batch_size: return 0.0 - # Sample batch - mix regular and extrema samples + # Sample batch with prioritization batch = [] - if use_extrema and len(self.extrema_memory) > self.batch_size // 4: - # Get some extrema samples - extrema_count = min(self.batch_size // 3, len(self.extrema_memory)) - extrema_samples = random.sample(list(self.extrema_memory), extrema_count) + + if use_prioritized and len(self.positive_memory) > 0 and len(self.extrema_memory) > 0: + # Prioritized sampling from different memory types + positive_count = min(self.batch_size // 4, len(self.positive_memory)) + extrema_count = min(self.batch_size // 4, len(self.extrema_memory)) + regular_count = self.batch_size - positive_count - extrema_count - # Get regular samples for the rest - regular_count = self.batch_size - extrema_count + positive_samples = random.sample(list(self.positive_memory), positive_count) + extrema_samples = random.sample(list(self.extrema_memory), extrema_count) regular_samples = random.sample(list(self.memory), regular_count) - # Combine samples - batch = extrema_samples + regular_samples + batch = positive_samples + extrema_samples + regular_samples else: # Standard sampling batch = random.sample(self.memory, self.batch_size) states, actions, rewards, next_states, dones = zip(*batch) + # Normalize states before training + normalized_states = np.array([self._normalize_state(state) for state in states]) + normalized_next_states = np.array([self._normalize_state(state) for state in next_states]) + # Convert to tensors and move to device - states = torch.FloatTensor(np.array(states)).to(self.device) - actions = torch.LongTensor(actions).to(self.device) - rewards = torch.FloatTensor(rewards).to(self.device) - next_states = torch.FloatTensor(np.array(next_states)).to(self.device) - dones = torch.FloatTensor(dones).to(self.device) + states_tensor = torch.FloatTensor(normalized_states).to(self.device) + actions_tensor = torch.LongTensor(actions).to(self.device) + rewards_tensor = torch.FloatTensor(rewards).to(self.device) + next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device) + dones_tensor = torch.FloatTensor(dones).to(self.device) # Get current Q values - current_q_values, extrema_pred = self.policy_net(states) - current_q_values = current_q_values.gather(1, actions.unsqueeze(1)) + current_q_values, extrema_pred = self.policy_net(states_tensor) + current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1)) - # Get next Q values from target network + # Get next Q values from target network (Double DQN approach) with torch.no_grad(): - next_q_values, _ = self.target_net(next_states) - next_q_values = next_q_values.max(1)[0] - target_q_values = rewards + (1 - dones) * self.gamma * next_q_values + # Get actions from policy network + next_actions, _ = self.policy_net(next_states_tensor) + next_actions = next_actions.max(1)[1].unsqueeze(1) + + # Get Q values from target network for those actions + next_q_values, _ = self.target_net(next_states_tensor) + next_q_values = next_q_values.gather(1, next_actions).squeeze(1) + + # Compute target Q values + target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values - # Compute Q-learning loss - q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values) + # Clamp target values to prevent extreme values + target_q_values = 
torch.clamp(target_q_values, -100, 100) - # If we have extrema labels (not in this implementation yet), - # we could add an additional loss for extrema prediction - # This would require labels for whether each state is near an extrema - - # Total loss is just Q-learning loss for now - loss = q_loss + # Compute Huber loss (more robust to outliers than MSE) + loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values) # Optimize self.optimizer.zero_grad() loss.backward() + + # Apply gradient clipping to prevent exploding gradients + nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) + self.optimizer.step() # Update target network if needed @@ -200,37 +310,77 @@ class DQNAgent: """ if len(states) == 0: return 0.0 + + # Normalize states + normalized_states = np.array([self._normalize_state(state) for state in states]) + normalized_next_states = np.array([self._normalize_state(state) for state in next_states]) # Convert to tensors - states = torch.FloatTensor(np.array(states)).to(self.device) - actions = torch.LongTensor(actions).to(self.device) - rewards = torch.FloatTensor(rewards).to(self.device) - next_states = torch.FloatTensor(np.array(next_states)).to(self.device) - dones = torch.FloatTensor(dones).to(self.device) + states_tensor = torch.FloatTensor(normalized_states).to(self.device) + actions_tensor = torch.LongTensor(actions).to(self.device) + rewards_tensor = torch.FloatTensor(rewards).to(self.device) + next_states_tensor = torch.FloatTensor(normalized_next_states).to(self.device) + dones_tensor = torch.FloatTensor(dones).to(self.device) # Forward pass - current_q_values, extrema_pred = self.policy_net(states) - current_q_values = current_q_values.gather(1, actions.unsqueeze(1)) + current_q_values, extrema_pred = self.policy_net(states_tensor) + current_q_values = current_q_values.gather(1, actions_tensor.unsqueeze(1)) - # Get next Q values + # Get next Q values (Double DQN approach) with torch.no_grad(): - next_q_values, _ = self.target_net(next_states) - next_q_values = next_q_values.max(1)[0] - target_q_values = rewards + (1 - dones) * self.gamma * next_q_values + next_actions, _ = self.policy_net(next_states_tensor) + next_actions = next_actions.max(1)[1].unsqueeze(1) - # Higher weight for extrema training - q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values) + next_q_values, _ = self.target_net(next_states_tensor) + next_q_values = next_q_values.gather(1, next_actions).squeeze(1) + + target_q_values = rewards_tensor + (1 - dones_tensor) * self.gamma * next_q_values + + # Clamp target values + target_q_values = torch.clamp(target_q_values, -100, 100) - # Full loss is just Q-learning loss + # Use Huber loss for extrema training + q_loss = nn.SmoothL1Loss()(current_q_values.squeeze(), target_q_values) + + # Full loss loss = q_loss # Optimize self.optimizer.zero_grad() loss.backward() + nn.utils.clip_grad_norm_(self.policy_net.parameters(), max_norm=1.0) self.optimizer.step() return loss.item() + def update_learning_metrics(self, episode_reward, best_reward_threshold=0.01): + """Update learning metrics and perform learning rate adjustments if needed""" + # Update average reward with exponential moving average + if self.avg_reward == 0: + self.avg_reward = episode_reward + else: + self.avg_reward = 0.95 * self.avg_reward + 0.05 * episode_reward + + # Check if we're making sufficient progress + if episode_reward > (1 + best_reward_threshold) * self.best_reward: + self.best_reward = episode_reward + self.no_improvement_count = 0 + return True # 
Improved + else: + self.no_improvement_count += 1 + + # If no improvement for a while, adjust learning rate + if self.no_improvement_count >= 10: + current_lr = self.optimizer.param_groups[0]['lr'] + new_lr = current_lr * 0.5 + if new_lr >= 1e-6: # Don't reduce below minimum threshold + for param_group in self.optimizer.param_groups: + param_group['lr'] = new_lr + logger.info(f"Reducing learning rate from {current_lr} to {new_lr}") + self.no_improvement_count = 0 + + return False # No improvement + def save(self, path: str): """Save model and agent state""" os.makedirs(os.path.dirname(path), exist_ok=True) @@ -246,9 +396,13 @@ class DQNAgent: 'epsilon': self.epsilon, 'update_count': self.update_count, 'losses': self.losses, - 'optimizer_state': self.optimizer.state_dict() + 'optimizer_state': self.optimizer.state_dict(), + 'best_reward': self.best_reward, + 'avg_reward': self.avg_reward } + torch.save(state, f"{path}_agent_state.pt") + logger.info(f"Agent state saved to {path}_agent_state.pt") def load(self, path: str): """Load model and agent state""" @@ -259,8 +413,19 @@ class DQNAgent: self.target_net.load(f"{path}_target") # Load agent state - state = torch.load(f"{path}_agent_state.pt") - self.epsilon = state['epsilon'] - self.update_count = state['update_count'] - self.losses = state['losses'] - self.optimizer.load_state_dict(state['optimizer_state']) \ No newline at end of file + try: + agent_state = torch.load(f"{path}_agent_state.pt", map_location=self.device) + self.epsilon = agent_state['epsilon'] + self.update_count = agent_state['update_count'] + self.losses = agent_state['losses'] + self.optimizer.load_state_dict(agent_state['optimizer_state']) + + # Load additional metrics if they exist + if 'best_reward' in agent_state: + self.best_reward = agent_state['best_reward'] + if 'avg_reward' in agent_state: + self.avg_reward = agent_state['avg_reward'] + + logger.info(f"Agent state loaded from {path}_agent_state.pt") + except FileNotFoundError: + logger.warning(f"Agent state file not found at {path}_agent_state.pt, using default values") \ No newline at end of file diff --git a/NN/models/simple_cnn.py b/NN/models/simple_cnn.py index f6293c0..b4b637b 100644 --- a/NN/models/simple_cnn.py +++ b/NN/models/simple_cnn.py @@ -44,13 +44,46 @@ class PricePatternAttention(nn.Module): return output, attn_weights +class AdaptiveNorm(nn.Module): + """ + Adaptive normalization layer that chooses between different normalization + methods based on input dimensions + """ + def __init__(self, num_features): + super(AdaptiveNorm, self).__init__() + self.batch_norm = nn.BatchNorm1d(num_features, affine=True) + self.group_norm = nn.GroupNorm(min(32, num_features), num_features) + self.layer_norm = nn.LayerNorm([num_features, 1]) + + def forward(self, x): + # Check input dimensions + batch_size, channels, seq_len = x.size() + + # Choose normalization method: + # - Batch size > 1 and seq_len > 1: BatchNorm + # - Batch size == 1 or seq_len == 1: GroupNorm + # - Fallback for extreme cases: LayerNorm + if batch_size > 1 and seq_len > 1: + return self.batch_norm(x) + elif seq_len > 1: + return self.group_norm(x) + else: + # For 1D inputs (seq_len=1), we need to adjust the layer norm + # to the actual input size + if not hasattr(self, 'layer_norm_1d') or self.layer_norm_1d.normalized_shape[0] != channels: + self.layer_norm_1d = nn.LayerNorm([channels, seq_len]).to(x.device) + return self.layer_norm_1d(x) + class CNNModelPyTorch(nn.Module): """ CNN model for trading with multiple timeframes """ - def 
__init__(self, window_size, num_features, output_size, timeframes): + def __init__(self, window_size=20, num_features=5, output_size=3, timeframes=None): super(CNNModelPyTorch, self).__init__() + if timeframes is None: + timeframes = [1] + self.window_size = window_size self.num_features = num_features self.output_size = output_size @@ -73,27 +106,28 @@ class CNNModelPyTorch(nn.Module): """Create all model layers with current feature dimensions""" # Convolutional layers - use total_features as input channels self.conv1 = nn.Conv1d(self.total_features, 64, kernel_size=3, padding=1) - self.bn1 = nn.BatchNorm1d(64) + self.norm1 = AdaptiveNorm(64) + self.dropout1 = nn.Dropout(0.2) self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1) - self.bn2 = nn.BatchNorm1d(128) + self.norm2 = AdaptiveNorm(128) + self.dropout2 = nn.Dropout(0.3) self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1) - self.bn3 = nn.BatchNorm1d(256) + self.norm3 = AdaptiveNorm(256) + self.dropout3 = nn.Dropout(0.4) # Add price pattern attention layer self.attention = PricePatternAttention(256) # Extrema detection specialized convolutional layer - self.extrema_conv = nn.Conv1d(256, 128, kernel_size=5, padding=2) - self.extrema_bn = nn.BatchNorm1d(128) + self.extrema_conv = nn.Conv1d(256, 128, kernel_size=3, padding=1) # Smaller kernel for small inputs + self.extrema_norm = AdaptiveNorm(128) - # Calculate size after convolutions - adjusted for attention output - conv_output_size = self.window_size * 256 - - # Fully connected layers - self.fc1 = nn.Linear(conv_output_size, 512) + # Fully connected layers - input size will be determined dynamically + self.fc1 = None # Will be initialized in forward pass self.fc2 = nn.Linear(512, 256) + self.dropout_fc = nn.Dropout(0.5) # Advantage and Value streams (Dueling DQN architecture) self.fc3 = nn.Linear(256, self.output_size) # Advantage stream @@ -131,46 +165,96 @@ class CNNModelPyTorch(nn.Module): # Ensure input is on the correct device x = x.to(self.device) - # Check and handle if input dimensions don't match model expectations - batch_size, window_len, feature_dim = x.size() - if feature_dim != self.total_features: - logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers") - self.rebuild_conv_layers(feature_dim) + # Check input dimensions and reshape as needed + if len(x.size()) == 2: + # If input is [batch_size, features], reshape to [batch_size, features, 1] + batch_size, feature_dim = x.size() + + # Check and handle if input features don't match model expectations + if feature_dim != self.total_features: + logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers") + self.rebuild_conv_layers(feature_dim) + + # For 1D input, use a sequence length of 1 + seq_len = 1 + x = x.unsqueeze(2) # Reshape to [batch, features, 1] + elif len(x.size()) == 3: + # Standard case: [batch_size, window_size, features] + batch_size, seq_len, feature_dim = x.size() + + # Check and handle if input dimensions don't match model expectations + if feature_dim != self.total_features: + logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers") + self.rebuild_conv_layers(feature_dim) + + # Reshape input: [batch, window_size, features] -> [batch, features, window_size] + x = x.permute(0, 2, 1) + else: + raise ValueError(f"Unexpected input shape: {x.size()}, expected 2D or 3D tensor") - # Reshape input: [batch, 
window_size, features] -> [batch, channels, window_size] - x = x.permute(0, 2, 1) - - # Convolutional layers - x = F.relu(self.bn1(self.conv1(x))) - x = F.relu(self.bn2(self.conv2(x))) - x = F.relu(self.bn3(self.conv3(x))) + # Convolutional layers with dropout - safely handle small spatial dimensions + try: + x = self.dropout1(F.relu(self.norm1(self.conv1(x)))) + x = self.dropout2(F.relu(self.norm2(self.conv2(x)))) + x = self.dropout3(F.relu(self.norm3(self.conv3(x)))) + except Exception as e: + logger.warning(f"Error in convolutional layers: {str(e)}") + # Fallback for very small inputs: skip some convolutions + if seq_len < 3: + # Apply a simpler convolution for very small inputs + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + # Skip last conv if we get dimension errors + try: + x = F.relu(self.conv3(x)) + except: + pass # Store conv features for extrema detection conv_features = x - # Reshape for attention: [batch, channels, window_size] -> [batch, window_size, channels] - x_attention = x.permute(0, 2, 1) + # Get the current shape after convolutions + _, channels, conv_seq_len = x.size() - # Apply attention - attention_output, attention_weights = self.attention(x_attention) + # Initialize fc1 if not created yet or if the shape has changed + if self.fc1 is None: + flattened_size = channels * conv_seq_len + logger.info(f"Initializing fc1 with input size {flattened_size}") + self.fc1 = nn.Linear(flattened_size, 512).to(self.device) - # We'll use attention directly without the residual connection - # to avoid dimension mismatch issues - attention_reshaped = attention_output.permute(0, 2, 1) # [batch, channels, window_size] + # Apply extrema detection safely + try: + extrema_features = F.relu(self.extrema_norm(self.extrema_conv(conv_features))) + except Exception as e: + logger.warning(f"Error in extrema detection: {str(e)}") + extrema_features = conv_features # Fallback - # Apply extrema detection specialized layer - extrema_features = F.relu(self.extrema_bn(self.extrema_conv(conv_features))) + # Handle attention for small sequence lengths + if conv_seq_len > 1: + # Reshape for attention: [batch, channels, seq_len] -> [batch, seq_len, channels] + x_attention = x.permute(0, 2, 1) + + # Apply attention + try: + attention_output, attention_weights = self.attention(x_attention) + except Exception as e: + logger.warning(f"Error in attention layer: {str(e)}") + # Fallback: don't use attention - # Use attention features directly instead of residual connection - # to avoid dimension mismatches - x = conv_features # Just use the convolutional features + # Flatten - get the actual shape for this batch + flattened_size = channels * conv_seq_len + x = x.view(batch_size, flattened_size) - # Flatten - x = x.view(batch_size, -1) + # Check if we need to recreate fc1 with the correct size + if self.fc1.in_features != flattened_size: + logger.info(f"Recreating fc1 layer to match input size {flattened_size}") + self.fc1 = nn.Linear(flattened_size, 512).to(self.device) + # Reinitialize optimizer after changing the model + self.optimizer = optim.Adam(self.parameters(), lr=0.001) - # Fully connected layers + # Fully connected layers with dropout x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) + x = self.dropout_fc(F.relu(self.fc2(x))) # Split into advantage and value streams advantage = self.fc3(x) diff --git a/NN/utils/trading_env.py b/NN/utils/trading_env.py index c33fc09..d02602c 100644 --- a/NN/utils/trading_env.py +++ b/NN/utils/trading_env.py @@ -3,6 +3,10 @@ import gym from gym import spaces 
from typing import Dict, Tuple, List import pandas as pd +import logging + +# Configure logger +logger = logging.getLogger(__name__) class TradingEnvironment(gym.Env): """ @@ -12,97 +16,284 @@ class TradingEnvironment(gym.Env): data: pd.DataFrame, initial_balance: float = 100.0, fee_rate: float = 0.0002, - max_steps: int = 1000): + max_steps: int = 1000, + window_size: int = 20, + risk_aversion: float = 0.2, # Controls how much to penalize volatility + price_scaling: str = 'zscore', # 'zscore', 'minmax', or 'raw' + reward_scaling: float = 10.0, # Scale factor for rewards + episode_penalty: float = 0.1): # Penalty for active positions at end of episode super(TradingEnvironment, self).__init__() self.data = data self.initial_balance = initial_balance self.fee_rate = fee_rate self.max_steps = max_steps + self.window_size = window_size + self.risk_aversion = risk_aversion + self.price_scaling = price_scaling + self.reward_scaling = reward_scaling + self.episode_penalty = episode_penalty - # Action space: 0 (SELL), 1 (HOLD), 2 (BUY) + # Preprocess data if needed + self._preprocess_data() + + # Action space: 0 (BUY), 1 (SELL), 2 (HOLD) self.action_space = spaces.Discrete(3) # Observation space: price data, technical indicators, and account state + feature_dim = self.data.shape[1] + 3 # Adding position, equity, unrealized_pnl self.observation_space = spaces.Box( low=-np.inf, high=np.inf, - shape=(data.shape[1],), # Number of features + shape=(feature_dim,), dtype=np.float32 ) # Initialize state self.reset() + def _preprocess_data(self): + """Preprocess data - normalize or standardize features""" + # Store the original data for reference + self.original_data = self.data.copy() + + # Normalize price data based on the selected method + if self.price_scaling == 'zscore': + # For each feature, apply z-score normalization + for col in self.data.columns: + if col in ['open', 'high', 'low', 'close']: + mean = self.data[col].mean() + std = self.data[col].std() + if std > 0: + self.data[col] = (self.data[col] - mean) / std + # Normalize volume separately + elif col == 'volume': + mean = self.data[col].mean() + std = self.data[col].std() + if std > 0: + self.data[col] = (self.data[col] - mean) / std + + elif self.price_scaling == 'minmax': + # For each feature, apply min-max scaling + for col in self.data.columns: + min_val = self.data[col].min() + max_val = self.data[col].max() + if max_val > min_val: + self.data[col] = (self.data[col] - min_val) / (max_val - min_val) + def reset(self) -> np.ndarray: """Reset the environment to initial state""" - self.current_step = 0 + self.current_step = self.window_size self.balance = self.initial_balance - self.position = 0 # 0: no position, 1: long position + self.position = 0 # 0: no position, 1: long position, -1: short position self.entry_price = 0 + self.entry_time = 0 self.total_trades = 0 self.winning_trades = 0 + self.losing_trades = 0 self.total_pnl = 0 self.balance_history = [self.initial_balance] + self.equity_history = [self.initial_balance] self.max_balance = self.initial_balance + self.max_drawdown = 0 + + # Trading performance metrics + self.trade_durations = [] # Track how long trades are held + self.returns = [] # Track returns of each trade + + # For analyzing trade clustering + self.last_action_time = 0 + self.actions_taken = [] return self._get_observation() def _get_observation(self) -> np.ndarray: - """Get current observation state""" - return self.data.iloc[self.current_step].values + """Get current observation state with account information""" + 
# Get market data for the current step
+        market_data = self.data.iloc[self.current_step].values
+        
+        # Get current price
+        current_price = self.original_data.iloc[self.current_step]['close']
+        
+        # Calculate unrealized PnL
+        unrealized_pnl = 0
+        if self.position != 0:
+            price_diff = current_price - self.entry_price
+            unrealized_pnl = self.position * price_diff
+        
+        # Calculate total equity (balance + unrealized PnL)
+        equity = self.balance + unrealized_pnl
+        
+        # Normalize account state
+        normalized_position = self.position  # -1, 0, or 1
+        normalized_equity = equity / self.initial_balance - 1.0  # Percent change from initial
+        normalized_unrealized_pnl = unrealized_pnl / self.initial_balance if self.initial_balance > 0 else 0
+        
+        # Combine market data with account state
+        account_state = np.array([normalized_position, normalized_equity, normalized_unrealized_pnl])
+        observation = np.concatenate([market_data, account_state])
+        
+        # Handle any NaN values
+        observation = np.nan_to_num(observation, nan=0.0)
+        
+        return observation
 
-    def _calculate_reward(self, action: int) -> float:
-        """Calculate reward based on action and outcome"""
-        current_price = self.data.iloc[self.current_step]['close']
+    def _calculate_reward(self, action: int) -> Tuple[float, float]:
+        """
+        Calculate reward based on action and outcome with improved risk-adjusted metrics
 
-        # If we have an open position
-        if self.position != 0:
-            # Calculate PnL
-            pnl = self.position * (current_price - self.entry_price) / self.entry_price
-            fees = self.fee_rate * 2  # Entry and exit fees
+        Args:
+            action: The action taken (0=BUY, 1=SELL, 2=HOLD)
 
-            # Close position
-            if (action == 0 and self.position > 0) or (action == 2 and self.position < 0):
-                net_pnl = pnl - fees
-                self.total_pnl += net_pnl
-                self.balance *= (1 + net_pnl)
+        Returns:
+            Tuple[float, float]: (reward, realized PnL of any position closed this step)
+        """
+        # Get the current price
+        current_price = self.original_data.iloc[self.current_step]['close']
+        
+        # Default reward is slightly negative to discourage excessive trading
+        reward = -0.0001
+        pnl = 0.0
+        
+        # Handle different actions based on current position
+        if self.position == 0:  # No position
+            if action == 0:  # BUY
+                self.position = 1
+                self.entry_price = current_price
+                self.entry_time = self.current_step
+                reward = -self.fee_rate  # Small penalty for trading cost
+            
+            elif action == 1:  # SELL (start short position)
+                self.position = -1
+                self.entry_price = current_price
+                self.entry_time = self.current_step
+                reward = -self.fee_rate  # Small penalty for trading cost
+            
+            # else action == 2 (HOLD) - keep the small negative reward
+        
+        elif self.position > 0:  # Long position
+            if action == 1:  # SELL (close long)
+                # Calculate profit/loss
+                price_diff = current_price - self.entry_price
+                pnl = price_diff / self.entry_price - 2 * self.fee_rate  # Account for entry and exit fees
+                
+                # Adjust reward based on PnL and risk
+                reward = pnl * self.reward_scaling
+                
+                # Track trade performance
+                self.total_trades += 1
+                if pnl > 0:
+                    self.winning_trades += 1
+                else:
+                    self.losing_trades += 1
+                
+                # Calculate trade duration
+                trade_duration = self.current_step - self.entry_time
+                self.trade_durations.append(trade_duration)
+                
+                # Update returns list
+                self.returns.append(pnl)
+                
+                # Update balance and reset position
+                self.balance *= (1 + pnl)
                 self.balance_history.append(self.balance)
                 self.max_balance = max(self.max_balance, self.balance)
+                self.total_pnl += pnl
 
-                self.total_trades += 1
-                if net_pnl > 0:
-                    self.winning_trades += 1
-
-                # Reward based on PnL
-                reward = net_pnl * 100  # Scale up for better learning
-
-                # Additional reward for win rate
-                win_rate = self.winning_trades / max(1, self.total_trades)
-                reward += win_rate * 0.1
-
+                # Reset position
                 self.position = 0
-                return reward
+            
+            elif action == 0:  # BUY (while already long)
+                # Penalize trying to increase an already active position
+                reward = -0.001
+            
+            # else action == 2 (HOLD) - calculate unrealized P&L for reward
+            else:
+                price_diff = current_price - self.entry_price
+                unrealized_pnl = price_diff / self.entry_price
+                
+                # Small reward/penalty based on unrealized P&L
+                reward = unrealized_pnl * 0.05  # Scale down to encourage holding good positions
+        
+        elif self.position < 0:  # Short position
+            if action == 0:  # BUY (close short)
+                # Calculate profit/loss
+                price_diff = self.entry_price - current_price
+                pnl = price_diff / self.entry_price - 2 * self.fee_rate  # Account for entry and exit fees
+                
+                # Adjust reward based on PnL and risk
+                reward = pnl * self.reward_scaling
+                
+                # Track trade performance
+                self.total_trades += 1
+                if pnl > 0:
+                    self.winning_trades += 1
+                else:
+                    self.losing_trades += 1
+                
+                # Calculate trade duration
+                trade_duration = self.current_step - self.entry_time
+                self.trade_durations.append(trade_duration)
+                
+                # Update returns list
+                self.returns.append(pnl)
+                
+                # Update balance and reset position
+                self.balance *= (1 + pnl)
+                self.balance_history.append(self.balance)
+                self.max_balance = max(self.max_balance, self.balance)
+                self.total_pnl += pnl
+                
+                # Reset position
+                self.position = 0
+            
+            elif action == 1:  # SELL (while already short)
+                # Penalize trying to increase an already active position
+                reward = -0.001
+            
+            # else action == 2 (HOLD) - calculate unrealized P&L for reward
+            else:
+                price_diff = self.entry_price - current_price
+                unrealized_pnl = price_diff / self.entry_price
+                
+                # Small reward/penalty based on unrealized P&L
+                reward = unrealized_pnl * 0.05  # Scale down to encourage holding good positions
+        
+        # Record the action
+        self.actions_taken.append(action)
+        self.last_action_time = self.current_step
+        
+        # Update equity history (balance + unrealized P&L)
+        current_equity = self.balance
+        if self.position != 0:
+            # Calculate unrealized P&L
+            if self.position > 0:  # Long
+                price_diff = current_price - self.entry_price
+                unrealized_pnl = price_diff / self.entry_price * self.balance
+            else:  # Short
+                price_diff = self.entry_price - current_price
+                unrealized_pnl = price_diff / self.entry_price * self.balance
+            
+            current_equity = self.balance + unrealized_pnl
 
-            # Hold position
-            return pnl * 0.1  # Small reward for holding profitable positions
+        self.equity_history.append(current_equity)
 
-        # No position
-        if action == 1:  # HOLD
-            return 0
+        # Calculate current drawdown
+        peak_equity = max(self.equity_history)
+        current_drawdown = (peak_equity - current_equity) / peak_equity if peak_equity > 0 else 0
+        self.max_drawdown = max(self.max_drawdown, current_drawdown)
 
-        # Open new position
-        if action in [0, 2]:  # SELL or BUY
-            self.position = -1 if action == 0 else 1
-            self.entry_price = current_price
-            return -self.fee_rate  # Small penalty for trading
+        # Apply risk aversion factor - penalize volatility
+        if len(self.returns) > 1:
+            returns_std = np.std(self.returns)
+            reward -= returns_std * self.risk_aversion
 
-        return 0
+        return reward, pnl
 
     def step(self, action: int) -> Tuple[np.ndarray, float, bool, Dict]:
         """Execute one step in the environment"""
-        # Calculate reward
-        reward = self._calculate_reward(action)
+        # Calculate reward and update state
+        reward, pnl = self._calculate_reward(action)
 
         # Move to next step
         self.current_step += 1
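As a quick sanity check on the close-position arithmetic above, here is a minimal standalone sketch (not part of the patch): closed_long_reward is a hypothetical helper that mirrors the SELL-closes-long branch, using this diff's constructor defaults (fee_rate=0.0002, reward_scaling=10.0).

def closed_long_reward(entry_price: float, exit_price: float,
                       fee_rate: float = 0.0002,
                       reward_scaling: float = 10.0):
    """Fractional PnL net of entry and exit fees, scaled into a reward."""
    pnl = (exit_price - entry_price) / entry_price - 2 * fee_rate
    return pnl * reward_scaling, pnl

reward, pnl = closed_long_reward(100.0, 101.0)
print(f"pnl={pnl:.4f}, reward={reward:.4f}")  # pnl=0.0096, reward=0.0960

A 1% favorable move nets 0.96% after the two fees, so the scaled reward stays small; the risk-aversion term then subtracts returns_std * risk_aversion on top of this.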
@@ -110,26 +301,70 @@ class TradingEnvironment(gym.Env): # Check if episode is done done = self.current_step >= min(self.max_steps - 1, len(self.data) - 1) + # Apply penalty if episode ends with open position + if done and self.position != 0: + reward -= self.episode_penalty + + # Force close the position at the end if still open + current_price = self.original_data.iloc[self.current_step]['close'] + if self.position > 0: # Long position + price_diff = current_price - self.entry_price + pnl = price_diff / self.entry_price - 2 * self.fee_rate + else: # Short position + price_diff = self.entry_price - current_price + pnl = price_diff / self.entry_price - 2 * self.fee_rate + + # Update balance + self.balance *= (1 + pnl) + self.total_pnl += pnl + + # Track trade + self.total_trades += 1 + if pnl > 0: + self.winning_trades += 1 + else: + self.losing_trades += 1 + + # Reset position + self.position = 0 + # Get next observation observation = self._get_observation() - # Calculate max drawdown - max_drawdown = 0 - if len(self.balance_history) > 1: - peak = self.balance_history[0] - for balance in self.balance_history: - peak = max(peak, balance) - drawdown = (peak - balance) / peak - max_drawdown = max(max_drawdown, drawdown) + # Calculate sharpe ratio and sortino ratio if possible + sharpe_ratio = 0 + sortino_ratio = 0 + win_rate = self.winning_trades / max(1, self.total_trades) + + if len(self.returns) > 1: + mean_return = np.mean(self.returns) + std_return = np.std(self.returns) + if std_return > 0: + sharpe_ratio = mean_return / std_return + + # For sortino, we only consider downside deviation + downside_returns = [r for r in self.returns if r < 0] + if downside_returns: + downside_deviation = np.std(downside_returns) + if downside_deviation > 0: + sortino_ratio = mean_return / downside_deviation + + # Calculate average trade duration + avg_trade_duration = np.mean(self.trade_durations) if self.trade_durations else 0 # Additional info info = { 'balance': self.balance, 'position': self.position, 'total_trades': self.total_trades, - 'win_rate': self.winning_trades / max(1, self.total_trades), + 'win_rate': win_rate, 'total_pnl': self.total_pnl, - 'max_drawdown': max_drawdown + 'max_drawdown': self.max_drawdown, + 'sharpe_ratio': sharpe_ratio, + 'sortino_ratio': sortino_ratio, + 'avg_trade_duration': avg_trade_duration, + 'pnl': pnl, + 'gain': (self.balance - self.initial_balance) / self.initial_balance } return observation, reward, done, info @@ -143,20 +378,19 @@ class TradingEnvironment(gym.Env): print(f"Total Trades: {self.total_trades}") print(f"Win Rate: {self.winning_trades/max(1, self.total_trades):.2%}") print(f"Total PnL: ${self.total_pnl:.2f}") - print(f"Max Drawdown: {self._calculate_max_drawdown():.2%}") + print(f"Max Drawdown: {self.max_drawdown:.2%}") + print(f"Sharpe Ratio: {self._calculate_sharpe_ratio():.4f}") print("-" * 50) - - def _calculate_max_drawdown(self): - """Calculate maximum drawdown from balance history""" - if len(self.balance_history) <= 1: + + def _calculate_sharpe_ratio(self): + """Calculate Sharpe ratio from returns""" + if len(self.returns) < 2: return 0.0 - peak = self.balance_history[0] - max_drawdown = 0.0 + mean_return = np.mean(self.returns) + std_return = np.std(self.returns) - for balance in self.balance_history: - peak = max(peak, balance) - drawdown = (peak - balance) / peak - max_drawdown = max(max_drawdown, drawdown) + if std_return == 0: + return 0.0 - return max_drawdown \ No newline at end of file + return mean_return / std_return \ No newline at 
end of file diff --git a/realtime.py b/realtime.py index 0bfa47f..b5ea1e5 100644 --- a/realtime.py +++ b/realtime.py @@ -22,6 +22,112 @@ import tzlocal import threading import random import dash_bootstrap_components as dbc +import uuid + +class BinanceHistoricalData: + """ + Class for fetching historical price data from Binance. + """ + def __init__(self): + self.base_url = "https://api.binance.com/api/v3" + + def get_historical_candles(self, symbol, interval_seconds=3600, limit=1000): + """ + Fetch historical candles from Binance API. + + Args: + symbol (str): Trading pair symbol (e.g., "BTC/USDT") + interval_seconds (int): Timeframe in seconds (e.g., 3600 for 1h) + limit (int): Number of candles to fetch + + Returns: + pd.DataFrame: DataFrame with OHLCV data + """ + # Convert interval_seconds to Binance interval format + interval_map = { + 1: "1s", + 60: "1m", + 300: "5m", + 900: "15m", + 1800: "30m", + 3600: "1h", + 14400: "4h", + 86400: "1d" + } + + interval = interval_map.get(interval_seconds, "1h") + + # Format symbol for Binance API (remove slash) + formatted_symbol = symbol.replace("/", "") + + try: + # Build URL for klines endpoint + url = f"{self.base_url}/klines" + params = { + "symbol": formatted_symbol, + "interval": interval, + "limit": limit + } + + # Make the request + response = requests.get(url, params=params) + response.raise_for_status() + + # Parse the response + data = response.json() + + # Create dataframe + df = pd.DataFrame(data, columns=[ + "timestamp", "open", "high", "low", "close", "volume", + "close_time", "quote_asset_volume", "number_of_trades", + "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore" + ]) + + # Convert timestamp to datetime + df["timestamp"] = pd.to_datetime(df["timestamp"], unit="ms") + + # Convert price columns to float + for col in ["open", "high", "low", "close", "volume"]: + df[col] = df[col].astype(float) + + # Sort by timestamp + df = df.sort_values("timestamp") + + logger.info(f"Fetched {len(df)} candles for {symbol} ({interval})") + return df + + except Exception as e: + logger.error(f"Error fetching historical data from Binance: {str(e)}") + # Return empty dataframe on error + return pd.DataFrame() + + def get_recent_trades(self, symbol, limit=1000): + """Get recent trades for a symbol""" + formatted_symbol = symbol.replace("/", "") + + try: + url = f"{self.base_url}/trades" + params = { + "symbol": formatted_symbol, + "limit": limit + } + + response = requests.get(url, params=params) + response.raise_for_status() + + data = response.json() + + # Create dataframe + df = pd.DataFrame(data) + df["time"] = pd.to_datetime(df["time"], unit="ms") + df["price"] = df["price"].astype(float) + df["qty"] = df["qty"].astype(float) + + return df + + except Exception as e: + logger.error(f"Error fetching recent trades: {str(e)}") + return pd.DataFrame() # Configure logging with more detailed format logging.basicConfig( @@ -63,7 +169,7 @@ def setup_neural_network(): try: # Get configuration from environment variables or use defaults - symbol = os.environ.get('NN_SYMBOL', 'BTC/USDT') + symbol = os.environ.get('NN_SYMBOL', 'ETH/USDT') timeframes = os.environ.get('NN_TIMEFRAMES', '1m,5m,1h,4h,1d').split(',') output_size = int(os.environ.get('NN_OUTPUT_SIZE', '3')) # 3 for BUY/HOLD/SELL @@ -216,1217 +322,192 @@ def convert_to_local_time(timestamp): logger.error(f"Error converting timestamp to local time: {str(e)}") return timestamp -class TradeTickStorage: - """Storage for trade ticks with a maximum age limit""" - - def 
__init__(self, symbol: str = None, max_age_seconds: int = 3600, use_sample_data: bool = True, log_no_ticks_warning: bool = False): # 1 hour by default, up from 30 min - """Initialize the tick storage - - Args: - symbol: Trading symbol - max_age_seconds: Maximum age for ticks to be stored - use_sample_data: If True, generate sample ticks when no real ticks available - log_no_ticks_warning: If True, log a warning when no ticks are available - """ - self.symbol = symbol +class TickStorage: + """Simple storage for ticks and candles""" + def __init__(self): self.ticks = [] - self.max_age_seconds = max_age_seconds - self.last_cleanup_time = time.time() - self.cleanup_interval = 60 # run cleanup every 60 seconds - self.cache_dir = "cache" - self.use_sample_data = use_sample_data - self.log_no_ticks_warning = log_no_ticks_warning - self.last_sample_price = 83000.0 # Starting price for sample data (for BTC) - self.last_sample_time = time.time() * 1000 # Starting time for sample data - self.last_tick_time = 0 # Initialize last_tick_time attribute - self.tick_count = 0 # Initialize tick_count attribute + self.candles = {} + self.latest_price = None - # Create cache directory if it doesn't exist - if not os.path.exists(self.cache_dir): - os.makedirs(self.cache_dir) - - # Try to load cached ticks - self._load_cached_ticks() + def add_tick(self, price, volume=0, timestamp=None): + """Add a tick to the storage""" + if timestamp is None: + timestamp = datetime.now() - logger.info(f"Initialized TradeTickStorage for {symbol} with max age: {max_age_seconds} seconds, cleanup interval: {self.cleanup_interval} seconds") - - def add_tick(self, tick: Dict): - """Add a tick to storage - - Args: - tick: Tick data dict with fields: - price: The price - volume: The volume - timestamp: Timestamp in milliseconds - """ - if not tick: - return - - # Check if we need to generate a timestamp - if 'timestamp' not in tick: - tick['timestamp'] = int(time.time() * 1000) # Current time in ms - - # Ensure timestamp is an integer (milliseconds since epoch) - if not isinstance(tick['timestamp'], int): - try: - # Try to convert from float or string - tick['timestamp'] = int(float(tick['timestamp'])) - except (ValueError, TypeError): - # If conversion fails, use current time - tick['timestamp'] = int(time.time() * 1000) - - # Set default volume if not present - if 'volume' not in tick: - tick['volume'] = 0.01 # Default small volume - - # Add tick to storage with a copy to avoid mutation - self.ticks.append(tick.copy()) - - # Keep track of latest tick for stats - self.last_tick_time = max(self.last_tick_time, tick['timestamp']) - - # Cache every 100 ticks to avoid data loss - self.tick_count += 1 - if self.tick_count % 100 == 0: - self._cache_ticks() - - # Periodically clean up old ticks - if self.tick_count % 1000 == 0: - self._cleanup() - - def _cleanup(self): - """Remove ticks older than max_age_seconds""" - # Get current time in milliseconds - now = int(time.time() * 1000) - - # Remove old ticks - cutoff = now - (self.max_age_seconds * 1000) - original_count = len(self.ticks) - self.ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff] - removed = original_count - len(self.ticks) - - if removed > 0: - logger.debug(f"Cleaned up {removed} old ticks, remaining: {len(self.ticks)}") - - def _load_cached_ticks(self): - """Load cached ticks from disk on startup""" - # Create symbol-specific filename - symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower() - cache_file = os.path.join(self.cache_dir, 
f"{symbol_safe}_recent_ticks.csv") - - if not os.path.exists(cache_file): - logger.info(f"No cached ticks found for {self.symbol}") - return - - try: - # Check if cache is fresh (less than 10 minutes old) - file_age = time.time() - os.path.getmtime(cache_file) - if file_age > 600: # 10 minutes - logger.info(f"Cached ticks for {self.symbol} are too old ({file_age:.1f}s), skipping") - return - - # Load cached ticks - tick_df = pd.read_csv(cache_file) - if tick_df.empty: - logger.info(f"Cached ticks file for {self.symbol} is empty") - return - - # Convert to list of dicts and add to storage - cached_ticks = tick_df.to_dict('records') - self.ticks.extend(cached_ticks) - logger.info(f"Loaded {len(cached_ticks)} cached ticks for {self.symbol} from {cache_file}") - except Exception as e: - logger.error(f"Error loading cached ticks for {self.symbol}: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - - def _cache_ticks(self): - """Cache recent ticks to disk""" - if not self.ticks: - return - - # Get ticks from last 10 minutes - now = int(time.time() * 1000) # Current time in ms - cutoff = now - (600 * 1000) # 10 minutes in ms - recent_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff] - - if not recent_ticks: - logger.debug("No recent ticks to cache") - return - - # Create symbol-specific filename - symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower() - cache_file = os.path.join(self.cache_dir, f"{symbol_safe}_recent_ticks.csv") - - # Save to disk - try: - tick_df = pd.DataFrame(recent_ticks) - tick_df.to_csv(cache_file, index=False) - logger.info(f"Cached {len(recent_ticks)} recent ticks for {self.symbol} to {cache_file}") - except Exception as e: - logger.error(f"Error caching ticks: {str(e)}") - - def get_latest_price(self) -> Optional[float]: - """Get the latest price from the most recent tick""" - if self.ticks: - return self.ticks[-1].get('price') - return None - - def get_price_stats(self) -> Dict: - """Get stats about the prices in storage""" - if not self.ticks: - return { - 'min': None, - 'max': None, - 'latest': None, - 'count': 0, - 'age_seconds': 0 - } - - prices = [tick['price'] for tick in self.ticks] - latest_timestamp = self.ticks[-1]['timestamp'] - oldest_timestamp = self.ticks[0]['timestamp'] - - return { - 'min': min(prices), - 'max': max(prices), - 'latest': prices[-1], - 'count': len(prices), - 'age_seconds': (latest_timestamp - oldest_timestamp) / 1000 - } - - def get_ticks_as_df(self) -> pd.DataFrame: - """Return ticks as a DataFrame""" - if not self.ticks: - logger.warning("No ticks available for DataFrame conversion") - return pd.DataFrame() - - # Ensure we have fresh data - self._cleanup() - - # Create a new list from ticks to avoid modifying the original data - ticks_data = self.ticks.copy() - - # Ensure we have the latest ticks at the end of the DataFrame - ticks_data.sort(key=lambda x: x['timestamp']) - - df = pd.DataFrame(ticks_data) - if not df.empty: - logger.debug(f"Converting timestamps for {len(df)} ticks") - # Ensure timestamp column exists - if 'timestamp' not in df.columns: - logger.error("Tick data missing timestamp column") - return pd.DataFrame() - - # Check timestamp datatype before conversion - sample_ts = df['timestamp'].iloc[0] if len(df) > 0 else None - logger.debug(f"Sample timestamp before conversion: {sample_ts}, type: {type(sample_ts)}") - - # Convert timestamps to datetime - try: - df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') - logger.debug(f"Timestamps converted to datetime 
successfully") - if len(df) > 0: - logger.debug(f"Sample converted timestamp: {df['timestamp'].iloc[0]}") - except Exception as e: - logger.error(f"Error converting timestamps: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - return pd.DataFrame() - return df - - def get_candles(self, interval_seconds: int = 1, start_time_ms: int = None, end_time_ms: int = None) -> pd.DataFrame: - """Generate candlestick data from ticks - - Args: - interval_seconds: Interval in seconds for each candle - start_time_ms: Start time in milliseconds - end_time_ms: End time in milliseconds - - Returns: - DataFrame with candlestick data - """ - # Get filtered ticks - ticks = self.get_ticks_from_time(start_time_ms, end_time_ms) - - if not ticks: - if self.use_sample_data: - # Generate multiple sample ticks to create several candles - current_time = int(time.time() * 1000) - sample_ticks = [] - - # Generate ticks for the past 10 intervals - for i in range(20): - # Base price with some trend - base_price = self.last_sample_price * (1 + 0.0001 * (10 - i)) - - # Add some randomness to the price - random_factor = random.uniform(-0.002, 0.002) # Small random change - tick_price = base_price * (1 + random_factor) - - # Create timestamp with appropriate offset - tick_time = current_time - (i * interval_seconds * 1000 // 2) - - sample_tick = { - 'price': tick_price, - 'volume': random.uniform(0.01, 0.5), - 'timestamp': tick_time, - 'is_sample': True - } - - sample_ticks.append(sample_tick) - - # Update the last sample values - self.last_sample_price = sample_ticks[0]['price'] - self.last_sample_time = sample_ticks[0]['timestamp'] - - # Add the sample ticks in chronological order - for tick in sorted(sample_ticks, key=lambda x: x['timestamp']): - self.add_tick(tick) - - # Try again with the new ticks - ticks = self.get_ticks_from_time(start_time_ms, end_time_ms) - - if not ticks and self.log_no_ticks_warning: - logger.warning("Still no ticks available after adding sample data") - elif self.log_no_ticks_warning: - logger.warning("No ticks available for candle formation") - return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) - else: - return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) - - # Ensure ticks are up to date - try: - self._cleanup() - except Exception as cleanup_error: - logger.error(f"Error cleaning up ticks: {str(cleanup_error)}") - - df = pd.DataFrame(ticks) - if df.empty: - logger.warning("Tick DataFrame is empty after filtering/conversion") - return pd.DataFrame() - - logger.info(f"Preparing to create candles from {len(df)} ticks with {interval_seconds}s interval") - - # First, ensure all required columns exist - required_columns = ['timestamp', 'price', 'volume'] - for col in required_columns: - if col not in df.columns: - logger.error(f"Required column '{col}' missing from tick data") - return pd.DataFrame() - - # Make sure DataFrame has no duplicated timestamps before setting index - try: - if 'timestamp' in df.columns: - # Check for duplicate timestamps - duplicate_count = df['timestamp'].duplicated().sum() - if duplicate_count > 0: - logger.warning(f"Found {duplicate_count} duplicate timestamps, keeping the last occurrence") - # Keep the last occurrence of each timestamp - df = df.drop_duplicates(subset='timestamp', keep='last') - - # Convert timestamp to datetime if it's not already - if not pd.api.types.is_datetime64_any_dtype(df['timestamp']): - logger.debug("Converting timestamp to datetime") - # Try multiple 
approaches to convert timestamps - try: - # First, try to convert from milliseconds (integer timestamps) - if pd.api.types.is_integer_dtype(df['timestamp']): - df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') - else: - # Otherwise try standard conversion - df['timestamp'] = pd.to_datetime(df['timestamp']) - except Exception as conv_error: - logger.error(f"Error converting timestamps to datetime: {str(conv_error)}") - # Try a fallback approach - try: - # Fallback for integer timestamps - if df['timestamp'].iloc[0] > 1000000000000: # Check if milliseconds timestamp - df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') - else: # Otherwise assume seconds - df['timestamp'] = pd.to_datetime(df['timestamp'], unit='s') - except Exception as fallback_error: - logger.error(f"Fallback timestamp conversion failed: {str(fallback_error)}") - return pd.DataFrame() - - # Use timestamp column for resampling - df = df.set_index('timestamp') - except Exception as prep_error: - logger.error(f"Error preprocessing DataFrame for resampling: {str(prep_error)}") - import traceback - logger.error(traceback.format_exc()) - return pd.DataFrame() - - # Create interval string for resampling - use 's' instead of deprecated 'S' - interval_str = f'{interval_seconds}s' - - # Resample to create OHLCV candles with multiple fallback options - logger.debug(f"Resampling with interval: {interval_str}") - - candles = None - - # First attempt - individual column resampling - try: - # Check that price column exists and has enough data - if 'price' not in df.columns: - raise ValueError("Price column missing from DataFrame") - - if len(df) < 2: - logger.warning("Not enough data points for resampling, using direct data") - # For single data point, create a single candle - if len(df) == 1: - price_val = df['price'].iloc[0] - volume_val = df['volume'].iloc[0] if 'volume' in df.columns else 0 - timestamp_val = df.index[0] - - candles = pd.DataFrame({ - 'timestamp': [timestamp_val], - 'open': [price_val], - 'high': [price_val], - 'low': [price_val], - 'close': [price_val], - 'volume': [volume_val] - }) - return candles - else: - # No data - return pd.DataFrame() - - # Resample and aggregate each column separately - open_df = df['price'].resample(interval_str).first() - high_df = df['price'].resample(interval_str).max() - low_df = df['price'].resample(interval_str).min() - close_df = df['price'].resample(interval_str).last() - volume_df = df['volume'].resample(interval_str).sum() - - # Check for length mismatches before combining - expected_length = len(open_df) - if (len(high_df) != expected_length or - len(low_df) != expected_length or - len(close_df) != expected_length or - len(volume_df) != expected_length): - logger.warning("Length mismatch in resampled columns, falling back to alternative method") - raise ValueError("Length mismatch") - - # Combine into a single DataFrame - candles = pd.DataFrame({ - 'open': open_df, - 'high': high_df, - 'low': low_df, - 'close': close_df, - 'volume': volume_df - }) - logger.debug(f"Successfully created {len(candles)} candles with individual column resampling") - except Exception as resample_error: - logger.error(f"Error in individual column resampling: {str(resample_error)}") - - # Second attempt - built-in agg method - try: - logger.debug("Trying fallback resampling method with agg()") - candles = df.resample(interval_str).agg({ - 'price': ['first', 'max', 'min', 'last'], - 'volume': 'sum' - }) - # Flatten MultiIndex columns - candles.columns = ['open', 'high', 'low', 'close', 
'volume'] - logger.debug(f"Successfully created {len(candles)} candles with agg() method") - except Exception as agg_error: - logger.error(f"Error in agg() resampling: {str(agg_error)}") - - # Third attempt - manual candle construction - try: - logger.debug("Trying manual candle construction method") - resampler = df.resample(interval_str) - candle_data = [] - - for name, group in resampler: - if not group.empty: - candle = { - 'timestamp': name, - 'open': group['price'].iloc[0], - 'high': group['price'].max(), - 'low': group['price'].min(), - 'close': group['price'].iloc[-1], - 'volume': group['volume'].sum() if 'volume' in group.columns else 0 - } - candle_data.append(candle) - - if candle_data: - candles = pd.DataFrame(candle_data) - logger.debug(f"Successfully created {len(candles)} candles with manual method") - else: - logger.warning("No candles created with manual method") - return pd.DataFrame() - except Exception as manual_error: - logger.error(f"Error in manual candle construction: {str(manual_error)}") - import traceback - logger.error(traceback.format_exc()) - return pd.DataFrame() - - # Ensure the result isn't empty - if candles is None or candles.empty: - logger.warning("No candles were created after all resampling attempts") - return pd.DataFrame() - - # Reset index to get timestamp as column - try: - candles = candles.reset_index() - except Exception as reset_error: - logger.error(f"Error resetting index: {str(reset_error)}") - # Try to create a new DataFrame with the timestamp index as a column - try: - timestamp_col = candles.index.to_list() - candles_dict = candles.to_dict('list') - candles_dict['timestamp'] = timestamp_col - candles = pd.DataFrame(candles_dict) - except Exception as fallback_error: - logger.error(f"Error in fallback index reset: {str(fallback_error)}") - return pd.DataFrame() - - # Ensure no NaN values - try: - nan_count_before = candles.isna().sum().sum() - if nan_count_before > 0: - logger.warning(f"Found {nan_count_before} NaN values in candles, dropping them") - - candles = candles.dropna() - except Exception as nan_error: - logger.error(f"Error handling NaN values: {str(nan_error)}") - # Try to fill NaN values instead of dropping - try: - candles = candles.fillna(method='ffill').fillna(method='bfill') - except: - pass - - logger.debug(f"Generated {len(candles)} candles from {len(df)} ticks") - return candles - - def get_candle_stats(self) -> Dict: - """Get statistics about cached candles for different intervals""" - stats = {} - - # Define intervals to check - intervals = [1, 5, 15, 60, 300, 900, 3600] - - for interval in intervals: - candles = self.get_candles(interval_seconds=interval) - count = len(candles) if not candles.empty else 0 - - # Get time range if we have candles - time_range = None - if count > 0: - try: - start_time = candles['timestamp'].min() - end_time = candles['timestamp'].max() - if isinstance(start_time, pd.Timestamp): - start_time = start_time.strftime('%Y-%m-%d %H:%M:%S') - if isinstance(end_time, pd.Timestamp): - end_time = end_time.strftime('%Y-%m-%d %H:%M:%S') - time_range = f"{start_time} to {end_time}" - except: - time_range = "Unknown" - - stats[f"{interval}s"] = { - 'count': count, - 'time_range': time_range - } - - return stats - - def get_ticks_from_time(self, start_time_ms: int = None, end_time_ms: int = None) -> List[Dict]: - """Get ticks within a specific time range - - Args: - start_time_ms: Start time in milliseconds (None for no lower bound) - end_time_ms: End time in milliseconds (None for no upper bound) - 
- Returns: - List of ticks within the time range - """ - if not self.ticks: - return [] - - # Ensure ticks are updated - self._cleanup() - - # Apply time filters if specified - filtered_ticks = self.ticks - if start_time_ms is not None: - filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] >= start_time_ms] - if end_time_ms is not None: - filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] <= end_time_ms] - - logger.debug(f"Retrieved {len(filtered_ticks)} ticks from time range {start_time_ms} to {end_time_ms}") - return filtered_ticks - - def get_time_based_stats(self) -> Dict: - """Get statistics about the ticks organized by time periods - - Returns: - Dictionary with statistics for different time periods - """ - if not self.ticks: - return { - 'total_ticks': 0, - 'periods': {} - } - - # Ensure ticks are updated - self._cleanup() - - now = int(time.time() * 1000) # Current time in ms - - # Define time periods to analyze - periods = { - '1min': now - (60 * 1000), - '5min': now - (5 * 60 * 1000), - '15min': now - (15 * 60 * 1000), - '30min': now - (30 * 60 * 1000) + tick = { + 'price': price, + 'volume': volume, + 'timestamp': timestamp } - stats = { - 'total_ticks': len(self.ticks), - 'oldest_tick': self.ticks[0]['timestamp'] if self.ticks else None, - 'newest_tick': self.ticks[-1]['timestamp'] if self.ticks else None, - 'time_span_seconds': (self.ticks[-1]['timestamp'] - self.ticks[0]['timestamp']) / 1000 if self.ticks else 0, - 'periods': {} + self.ticks.append(tick) + self.latest_price = price + + # Keep only last 10000 ticks + if len(self.ticks) > 10000: + self.ticks = self.ticks[-10000:] + + # Update candles + self._update_candles(tick) + + def get_latest_price(self): + """Get the latest price""" + return self.latest_price + + def _update_candles(self, tick): + """Update candles with the new tick""" + intervals = { + '1m': 60, + '5m': 300, + '15m': 900, + '1h': 3600, + '4h': 14400, + '1d': 86400 } - # Calculate stats for each period - for period_name, cutoff_time in periods.items(): - period_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff_time] + for interval_key, seconds in intervals.items(): + if interval_key not in self.candles: + self.candles[interval_key] = [] + + # Get or create the current candle + current_candle = self._get_current_candle(interval_key, tick['timestamp'], seconds) - if period_ticks: - prices = [tick['price'] for tick in period_ticks] - volumes = [tick.get('volume', 0) for tick in period_ticks] - - period_stats = { - 'tick_count': len(period_ticks), - 'min_price': min(prices) if prices else None, - 'max_price': max(prices) if prices else None, - 'avg_price': sum(prices) / len(prices) if prices else None, - 'last_price': period_ticks[-1]['price'] if period_ticks else None, - 'total_volume': sum(volumes), - 'ticks_per_second': len(period_ticks) / (int(period_name[:-3]) * 60) if period_ticks else 0 - } - - stats['periods'][period_name] = period_stats + # Update the candle with the new tick + if current_candle['high'] < tick['price']: + current_candle['high'] = tick['price'] + if current_candle['low'] > tick['price']: + current_candle['low'] = tick['price'] + current_candle['close'] = tick['price'] + current_candle['volume'] += tick['volume'] + + def _get_current_candle(self, interval_key, timestamp, interval_seconds): + """Get the current candle for the given interval, or create a new one""" + # Calculate the candle start time + candle_start = timestamp.replace( + microsecond=0, + second=0, + 
minute=(timestamp.minute // (interval_seconds // 60)) * (interval_seconds // 60)
+        )
-            logger.debug(f"Generated time-based stats: {len(stats['periods'])} periods")
-            return stats
-
-class CandlestickData:
-    def __init__(self, max_length: int = 300):
-        self.timestamps = deque(maxlen=max_length)
-        self.opens = deque(maxlen=max_length)
-        self.highs = deque(maxlen=max_length)
-        self.lows = deque(maxlen=max_length)
-        self.closes = deque(maxlen=max_length)
-        self.volumes = deque(maxlen=max_length)
-        self.current_candle = {
-            'timestamp': None,
-            'open': None,
-            'high': None,
-            'low': None,
-            'close': None,
+        if interval_seconds >= 3600:  # For hourly or higher
+            hours = (timestamp.hour // (interval_seconds // 3600)) * (interval_seconds // 3600)
+            candle_start = candle_start.replace(hour=hours)
+
+        if interval_seconds >= 86400:  # For daily
+            candle_start = candle_start.replace(hour=0)
+
+        # Check if we already have a candle for this time
+        for candle in self.candles[interval_key]:
+            if candle['timestamp'] == candle_start:
+                return candle
+
+        # Create a new candle. Note: `tick` is not in scope here, so the new
+        # candle is seeded from self.latest_price, which add_tick() sets from
+        # the current tick before _update_candles() is called.
+        candle = {
+            'timestamp': candle_start,
+            'open': self.latest_price,
+            'high': self.latest_price,
+            'low': self.latest_price,
+            'close': self.latest_price,
             'volume': 0
         }
-        self.candle_interval = 1  # 1 second by default
-
-    def update_from_trade(self, trade: Dict):
-        timestamp = trade['timestamp']
-        price = trade['price']
-        volume = trade.get('volume', 0)
-
-        # Round timestamp to nearest candle interval
-        candle_timestamp = int(timestamp / (self.candle_interval * 1000)) * (self.candle_interval * 1000)
-
-        if self.current_candle['timestamp'] != candle_timestamp:
-            # Save current candle if it exists
-            if self.current_candle['timestamp'] is not None:
-                self.timestamps.append(self.current_candle['timestamp'])
-                self.opens.append(self.current_candle['open'])
-                self.highs.append(self.current_candle['high'])
-                self.lows.append(self.current_candle['low'])
-                self.closes.append(self.current_candle['close'])
-                self.volumes.append(self.current_candle['volume'])
-                logger.debug(f"New candle saved: {self.current_candle}")
-
-            # Start new candle
-            self.current_candle = {
-                'timestamp': candle_timestamp,
-                'open': price,
-                'high': price,
-                'low': price,
-                'close': price,
-                'volume': volume
-            }
-            logger.debug(f"New candle started: {self.current_candle}")
-        else:
-            # Update current candle
-            if self.current_candle['high'] is None or price > self.current_candle['high']:
-                self.current_candle['high'] = price
-            if self.current_candle['low'] is None or price < self.current_candle['low']:
-                self.current_candle['low'] = price
-            self.current_candle['close'] = price
-            self.current_candle['volume'] += volume
-            logger.debug(f"Updated current candle: {self.current_candle}")
-
-    def get_dataframe(self) -> pd.DataFrame:
-        # Include current candle in the dataframe if it exists
-        timestamps = list(self.timestamps)
-        opens = list(self.opens)
-        highs = list(self.highs)
-        lows = list(self.lows)
-        closes = list(self.closes)
-        volumes = list(self.volumes)
-
-        if self.current_candle['timestamp'] is not None:
-            timestamps.append(self.current_candle['timestamp'])
-            opens.append(self.current_candle['open'])
-            highs.append(self.current_candle['high'])
-            lows.append(self.current_candle['low'])
-            closes.append(self.current_candle['close'])
-            volumes.append(self.current_candle['volume'])
-
-        df = pd.DataFrame({
-            'timestamp': timestamps,
-            'open': opens,
-            'high': highs,
-            'low': lows,
-            'close': closes,
-            'volume': volumes
-        })
-        if not 
df.empty: - df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + + self.candles[interval_key].append(candle) + return candle + + def get_candles(self, interval='1m'): + """Get candles for the given interval""" + if interval not in self.candles or not self.candles[interval]: + return None + + # Convert to DataFrame + df = pd.DataFrame(self.candles[interval]) + df.set_index('timestamp', inplace=True) return df - -class BinanceWebSocket: - """Binance WebSocket implementation for real-time tick data""" - def __init__(self, symbol: str): - self.symbol = symbol.replace('/', '').lower() - self.ws = None - self.running = False - self.reconnect_delay = 1 - self.max_reconnect_delay = 60 - self.message_count = 0 - - # Binance WebSocket configuration - self.ws_url = f"wss://stream.binance.com:9443/ws/{self.symbol}@trade" - logger.info(f"Initialized Binance WebSocket for symbol: {self.symbol}") - - async def connect(self): - while True: - try: - logger.info(f"Attempting to connect to {self.ws_url}") - self.ws = await websockets.connect(self.ws_url) - logger.info("WebSocket connection established") - - self.running = True - self.reconnect_delay = 1 - logger.info(f"Successfully connected to Binance WebSocket for {self.symbol}") - return True - except Exception as e: - logger.error(f"WebSocket connection error: {str(e)}") - await asyncio.sleep(self.reconnect_delay) - self.reconnect_delay = min(self.reconnect_delay * 2, self.max_reconnect_delay) - continue - - async def receive(self) -> Optional[Dict]: - if not self.ws: - return None + def load_from_file(self, file_path): + """Load ticks from a file""" try: - message = await self.ws.recv() - self.message_count += 1 - - if self.message_count % 100 == 0: # Log every 100th message to avoid spam - logger.info(f"Received message #{self.message_count}") - logger.debug(f"Raw message: {message[:200]}...") - - data = json.loads(message) - - # Process trade data - if 'e' in data and data['e'] == 'trade': - trade_data = { - 'timestamp': data['T'], # Trade time - 'price': float(data['p']), # Price - 'volume': float(data['q']), # Quantity - 'type': 'trade' - } - logger.debug(f"Processed trade data: {trade_data}") - return trade_data - - return None - except websockets.exceptions.ConnectionClosed: - logger.warning("WebSocket connection closed") - self.running = False - return None - except json.JSONDecodeError as e: - logger.error(f"JSON decode error: {str(e)}, message: {message[:200]}...") - return None - except Exception as e: - logger.error(f"Error receiving message: {str(e)}") - return None - - async def close(self): - """Close the WebSocket connection""" - if self.ws: - await self.ws.close() - self.running = False - logger.info("WebSocket connection closed") - -class BinanceHistoricalData: - """Fetch historical candle data from Binance""" - - def __init__(self): - self.base_url = "https://api.binance.com/api/v3/klines" - # Create a cache directory if it doesn't exist - self.cache_dir = os.path.join(os.getcwd(), "cache") - os.makedirs(self.cache_dir, exist_ok=True) - logger.info(f"Initialized BinanceHistoricalData with cache directory: {self.cache_dir}") - - def _get_interval_string(self, interval_seconds: int) -> str: - """Convert interval seconds to Binance interval string""" - if interval_seconds == 60: # 1m - return "1m" - elif interval_seconds == 300: # 5m - return "5m" - elif interval_seconds == 900: # 15m - return "15m" - elif interval_seconds == 1800: # 30m - return "30m" - elif interval_seconds == 3600: # 1h - return "1h" - elif interval_seconds 
== 14400: # 4h - return "4h" - elif interval_seconds == 86400: # 1d - return "1d" - else: - # Default to 1m if not recognized - logger.warning(f"Unrecognized interval {interval_seconds}s, defaulting to 1m") - return "1m" - - def _get_cache_filename(self, symbol: str, interval: str) -> str: - """Generate cache filename for the symbol and interval""" - # Replace any slashes in symbol with underscore - safe_symbol = symbol.replace("/", "_") - return os.path.join(self.cache_dir, f"{safe_symbol}_{interval}_candles.csv") - - def _load_from_cache(self, symbol: str, interval: str) -> Optional[pd.DataFrame]: - """Load candle data from cache if available and not expired""" - filename = self._get_cache_filename(symbol, interval) - - if not os.path.exists(filename): - logger.debug(f"No cache file found for {symbol} {interval}") - return None - - # Check if cache is fresh (less than 1 hour old for anything but 1d, 1 day for 1d) - file_age = time.time() - os.path.getmtime(filename) - max_age = 86400 if interval == "1d" else 3600 # 1 day for 1d, 1 hour for others - - if file_age > max_age: - logger.debug(f"Cache for {symbol} {interval} is expired ({file_age:.1f}s old)") - return None - - try: - df = pd.read_csv(filename) - # Convert timestamp string back to datetime - df['timestamp'] = pd.to_datetime(df['timestamp']) - logger.info(f"Loaded {len(df)} candles from cache for {symbol} {interval}") - return df - except Exception as e: - logger.error(f"Error loading from cache: {str(e)}") - return None - - def _save_to_cache(self, df: pd.DataFrame, symbol: str, interval: str) -> bool: - """Save candle data to cache""" - if df.empty: - logger.warning(f"No data to cache for {symbol} {interval}") - return False - - filename = self._get_cache_filename(symbol, interval) - try: - df.to_csv(filename, index=False) - logger.info(f"Cached {len(df)} candles for {symbol} {interval} to {filename}") - return True - except Exception as e: - logger.error(f"Error saving to cache: {str(e)}") - return False - - def get_historical_candles(self, symbol: str, interval_seconds: int, limit: int = 500) -> pd.DataFrame: - """Get historical candle data for the specified symbol and interval""" - # Convert to Binance format - clean_symbol = symbol.replace("/", "") - interval = self._get_interval_string(interval_seconds) - - # Try to load from cache first - cached_data = self._load_from_cache(symbol, interval) - if cached_data is not None and len(cached_data) >= limit: - return cached_data.tail(limit) - - # Fetch from API if not cached or insufficient - try: - logger.info(f"Fetching {limit} historical candles for {symbol} ({interval}) from Binance API") - - params = { - "symbol": clean_symbol, - "interval": interval, - "limit": limit - } - - response = requests.get(self.base_url, params=params) - response.raise_for_status() # Raise exception for HTTP errors - - # Process the data - candles = response.json() - - if not candles: - logger.warning(f"No candles returned from Binance for {symbol} {interval}") - return pd.DataFrame() - - # Convert to DataFrame - Binance returns data in this format: - # [ - # [ - # 1499040000000, // Open time - # "0.01634790", // Open - # "0.80000000", // High - # "0.01575800", // Low - # "0.01577100", // Close - # "148976.11427815", // Volume - # ... // Ignore the rest - # ], - # ... 
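
For reference, the cache-expiry rule used by _load_from_cache above (one day for 1d candles, one hour for everything else) is a plain file-mtime check. A minimal sketch with illustrative names:

import os
import time

def cache_is_fresh(path: str, interval: str) -> bool:
    """True if the cache file exists and is younger than its allowed age."""
    if not os.path.exists(path):
        return False
    max_age = 86400 if interval == "1d" else 3600  # seconds
    return (time.time() - os.path.getmtime(path)) <= max_age
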
- # ] - - df = pd.DataFrame(candles, columns=[ - "timestamp", "open", "high", "low", "close", "volume", - "close_time", "quote_asset_volume", "number_of_trades", - "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore" - ]) - - # Convert types - df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') - for col in ["open", "high", "low", "close", "volume"]: - df[col] = df[col].astype(float) - - # Keep only needed columns - df = df[["timestamp", "open", "high", "low", "close", "volume"]] - - # Cache the results - self._save_to_cache(df, symbol, interval) - - logger.info(f"Successfully fetched {len(df)} candles for {symbol} {interval}") - return df - - except Exception as e: - logger.error(f"Error fetching historical data for {symbol} {interval}: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - return pd.DataFrame() - - -class ExchangeWebSocket: - """Generic WebSocket interface for cryptocurrency exchanges""" - def __init__(self, symbol: str, exchange: str = "binance"): - self.symbol = symbol - self.exchange = exchange.lower() - self.ws = None - - # Initialize the appropriate WebSocket implementation - if self.exchange == "binance": - self.ws = BinanceWebSocket(symbol) - elif self.exchange == "mexc": - self.ws = MEXCWebSocket(symbol) - else: - raise ValueError(f"Unsupported exchange: {exchange}") - - async def connect(self): - """Connect to the exchange WebSocket""" - return await self.ws.connect() - - async def receive(self) -> Optional[Dict]: - """Receive data from the WebSocket""" - return await self.ws.receive() - - async def close(self): - """Close the WebSocket connection""" - await self.ws.close() - - @property - def running(self): - """Check if the WebSocket is running""" - return self.ws.running if self.ws else False - -class CandleCache: - def __init__(self, max_candles: int = 5000): - self.candles = { - '1s': deque(maxlen=max_candles), - '1m': deque(maxlen=max_candles), - '1h': deque(maxlen=max_candles), - '1d': deque(maxlen=max_candles) - } - logger.info(f"Initialized CandleCache with max candles: {max_candles}") - - def add_candles(self, interval: str, new_candles: pd.DataFrame): - if interval in self.candles and not new_candles.empty: - # Convert DataFrame to list of dicts to avoid pandas issues - for _, row in new_candles.iterrows(): - candle_dict = row.to_dict() - self.candles[interval].append(candle_dict) - logger.debug(f"Added {len(new_candles)} candles to {interval} cache") - - def get_recent_candles(self, interval: str, count: int = 500) -> pd.DataFrame: - if interval in self.candles and self.candles[interval]: - # Convert deque to list of dicts first - all_candles = list(self.candles[interval]) - # Check if we're requesting more candles than we have - if count > len(all_candles): - logger.debug(f"Requested {count} candles, but only have {len(all_candles)} for {interval}") - count = len(all_candles) - - recent_candles = all_candles[-count:] - logger.debug(f"Returning {len(recent_candles)} recent candles for {interval} (requested {count})") - - # Create DataFrame and ensure timestamp is datetime type - df = pd.DataFrame(recent_candles) - if not df.empty and 'timestamp' in df.columns: - try: - if not pd.api.types.is_datetime64_any_dtype(df['timestamp']): - df['timestamp'] = pd.to_datetime(df['timestamp']) - except Exception as e: - logger.warning(f"Error converting timestamps in get_recent_candles: {str(e)}") - - return df - - logger.debug(f"No candles available for {interval}") - return pd.DataFrame() - - def update_cache(self, 
interval: str, new_candles: pd.DataFrame): - """ - Update the candle cache for a specific timeframe with new candles - - Args: - interval: The timeframe interval ('1s', '1m', '1h', '1d') - new_candles: DataFrame with new candles to add to the cache - """ - if interval not in self.candles: - logger.warning(f"Invalid interval {interval} for cache update") - return - - if new_candles is None or new_candles.empty: - logger.debug(f"No new candles to update {interval} cache") - return - - # Check if timestamp column exists - if 'timestamp' not in new_candles.columns: - logger.warning(f"No timestamp column in new candles for {interval}") - return - - try: - # Ensure timestamp is datetime for proper comparison - try: - if not pd.api.types.is_datetime64_any_dtype(new_candles['timestamp']): - logger.debug(f"Converting timestamps to datetime for {interval}") - new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp']) - except Exception as e: - logger.error(f"Error converting timestamps: {str(e)}") - # Try a different approach - try: - new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp'], errors='coerce') - # Drop any rows where conversion failed - new_candles = new_candles.dropna(subset=['timestamp']) - if new_candles.empty: - logger.warning(f"All timestamps conversion failed for {interval}") - return - except Exception as e2: - logger.error(f"Second attempt to convert timestamps failed: {str(e2)}") - return - - # Create a copy to avoid modifying the original - new_candles_copy = new_candles.copy() - - # If we have no candles in cache, add all new candles - if not self.candles[interval]: - logger.debug(f"No existing candles for {interval}, adding all {len(new_candles_copy)} candles") - self.add_candles(interval, new_candles_copy) - return - - # Get the timestamp from the last cached candle - last_cached_candle = self.candles[interval][-1] - if not isinstance(last_cached_candle, dict): - logger.warning(f"Last cached candle is not a dictionary for {interval}") - last_cached_candle = {'timestamp': None} - - if 'timestamp' not in last_cached_candle: - logger.warning(f"No timestamp in last cached candle for {interval}") - last_cached_candle['timestamp'] = None - - last_cached_time = last_cached_candle['timestamp'] - logger.debug(f"Last cached timestamp for {interval}: {last_cached_time}") - - # If last_cached_time is None, add all candles - if last_cached_time is None: - logger.debug(f"No valid last cached timestamp, adding all {len(new_candles_copy)} candles for {interval}") - self.add_candles(interval, new_candles_copy) - return - - # Convert last_cached_time to datetime if needed - if not isinstance(last_cached_time, (pd.Timestamp, datetime)): - try: - last_cached_time = pd.to_datetime(last_cached_time) - except Exception as e: - logger.error(f"Cannot convert last cached time to datetime: {str(e)}") - # Add all candles as fallback - self.add_candles(interval, new_candles_copy) - return - - # Make a backup of current cache before filtering - cache_backup = list(self.candles[interval]) - - # Filter new candles that are after the last cached candle - try: - filtered_candles = new_candles_copy[new_candles_copy['timestamp'] > last_cached_time] - - if not filtered_candles.empty: - logger.debug(f"Adding {len(filtered_candles)} new candles for {interval}") - self.add_candles(interval, filtered_candles) + df = pd.read_csv(file_path) + for _, row in df.iterrows(): + if 'timestamp' in row: + timestamp = pd.to_datetime(row['timestamp']) else: - # No new candles after last cached time, check 
for missing candles - try: - # Get unique timestamps in cache - cached_timestamps = set() - for candle in self.candles[interval]: - if isinstance(candle, dict) and 'timestamp' in candle: - ts = candle['timestamp'] - if isinstance(ts, (pd.Timestamp, datetime)): - cached_timestamps.add(ts) - else: - try: - cached_timestamps.add(pd.to_datetime(ts)) - except: - pass - - # Find candles in new_candles that aren't in the cache - missing_candles = new_candles_copy[~new_candles_copy['timestamp'].isin(cached_timestamps)] - - if not missing_candles.empty: - logger.info(f"Found {len(missing_candles)} missing candles for {interval}") - self.add_candles(interval, missing_candles) - else: - logger.debug(f"No new or missing candles to add for {interval}") - except Exception as missing_error: - logger.error(f"Error checking for missing candles: {str(missing_error)}") - except Exception as filter_error: - logger.error(f"Error filtering candles by timestamp: {str(filter_error)}") - # Restore from backup - self.candles[interval] = deque(cache_backup, maxlen=self.candles[interval].maxlen) - # Try adding all candles as fallback - self.add_candles(interval, new_candles_copy) + timestamp = None + + self.add_tick( + price=row.get('price', row.get('close', 0)), + volume=row.get('volume', 0), + timestamp=timestamp + ) + logger.info(f"Loaded {len(df)} ticks from {file_path}") except Exception as e: - logger.error(f"Unhandled error updating cache for {interval}: {str(e)}") + logger.error(f"Error loading ticks from file: {str(e)}") + + def load_historical_data(self, historical_data, symbol): + """Load historical data""" + try: + df = historical_data.get_historical_candles(symbol) + if df is not None and not df.empty: + for _, row in df.iterrows(): + self.add_tick( + price=row['close'], + volume=row['volume'], + timestamp=row['timestamp'] + ) + logger.info(f"Loaded {len(df)} historical candles for {symbol}") + except Exception as e: + logger.error(f"Error loading historical data: {str(e)}") import traceback logger.error(traceback.format_exc()) - def get_candles(self, timeframe: str, count: int = 500) -> pd.DataFrame: - """ - Get candles for a specific timeframe. This is an alias for get_recent_candles - to maintain compatibility with code that expects this method name. 
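
Taken together, the new TickStorage methods form a small replay-style API: feed ticks in (live, from a CSV, or synthesized from historical candle closes as above) and read aggregated candles back out. A usage sketch, assuming add_tick stores the supplied pandas Timestamp unchanged:

import pandas as pd

storage = TickStorage()
now = pd.Timestamp.now(tz='UTC')
storage.add_tick(price=50000.0, volume=0.25, timestamp=now)
storage.add_tick(price=50010.0, volume=0.10, timestamp=now + pd.Timedelta(seconds=5))

df = storage.get_candles(interval='1m')  # DataFrame indexed by candle start, or None
if df is not None:
    print(df[['open', 'high', 'low', 'close', 'volume']].tail())
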
+class Position: + """Class representing a trading position""" + def __init__(self, action, entry_price, amount, timestamp=None, trade_id=None): + self.action = action # BUY or SELL + self.entry_price = entry_price + self.amount = amount + self.timestamp = timestamp or datetime.now() + self.status = "OPEN" # OPEN or CLOSED + self.exit_price = None + self.exit_timestamp = None + self.pnl = 0.0 + self.trade_id = trade_id or str(uuid.uuid4()) - Args: - timeframe: The timeframe interval ('1s', '1m', '1h', '1d') - count: Maximum number of candles to return + def close(self, exit_price, exit_timestamp=None): + """Close the position with an exit price""" + self.exit_price = exit_price + self.exit_timestamp = exit_timestamp or datetime.now() + self.status = "CLOSED" + + # Calculate PnL + if self.action == "BUY": + self.pnl = (self.exit_price - self.entry_price) * self.amount + else: # SELL + self.pnl = (self.entry_price - self.exit_price) * self.amount - Returns: - DataFrame containing the candles - """ - try: - logger.debug(f"Getting {count} candles for {timeframe} via get_candles()") - return self.get_recent_candles(timeframe, count) - except Exception as e: - logger.error(f"Error in get_candles for {timeframe}: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - return pd.DataFrame() + return self.pnl class RealTimeChart: """Real-time chart using Dash and Plotly""" - def __init__(self, symbol="BTC/USDT", use_sample_data=False, log_no_ticks_warning=True): - """Initialize a new RealTimeChart - - Args: - symbol: Trading pair symbol (e.g., BTC/USDT) - use_sample_data: Whether to use sample data when no real data is available - log_no_ticks_warning: Whether to log warnings when no ticks are available - """ + def __init__(self, symbol, data_path=None, historical_data=None, exchange=None, timeframe='1m'): + """Initialize the RealTimeChart class""" self.symbol = symbol - self.use_sample_data = use_sample_data + self.exchange = exchange + self.app = dash.Dash(__name__, external_stylesheets=[dbc.themes.DARKLY]) + self.tick_storage = TickStorage() + self.historical_data = historical_data + self.data_path = data_path + self.current_interval = '1m' # Default interval + self.fig = None # Will hold the main chart figure + self.positions = [] # List to hold position objects + self.balance = 1000.0 # Starting balance + self.last_action = None # Last trading action - # Initialize variables for trading info display - self.current_signal = 'HOLD' - self.signal_time = datetime.now() - self.current_position = 0.0 - self.session_balance = 100.0 # Start with $100 balance - self.session_pnl = 0.0 - - # Initialize NN signals and trades lists - self.nn_signals = [] - self.trades = [] - - # Use existing timezone variable instead of trying to detect again - logger.info(f"Using timezone: {tz_name}") - - # Initialize tick storage - logger.info(f"Initializing RealTimeChart for {symbol}") - self.tick_storage = TradeTickStorage( - symbol=symbol, - max_age_seconds=3600, # Keep ticks for 1 hour - use_sample_data=use_sample_data, - log_no_ticks_warning=log_no_ticks_warning - ) - - # Initialize candlestick data for backward compatibility - self.candlestick_data = CandlestickData(max_length=5000) - - # Initialize candle cache - self.candle_cache = CandleCache(max_candles=5000) - - # Initialize OHLCV cache dictionaries for different timeframes - self.ohlcv_cache = { - '1s': None, - '5s': None, - '15s': None, - '60s': None, - '300s': None, - '900s': None, - '3600s': None, - '1m': None, - '5m': None, - '15m': None, - 
'1h': None, - '1d': None - } - - # Historical data handler - self.historical_data = BinanceHistoricalData() - - # Flag for first render to force data loading - self.first_render = True - - # Last time candles were saved to disk - self.last_cache_save_time = time.time() - - # Initialize Dash app - self.app = dash.Dash( - __name__, - external_stylesheets=[dbc.themes.DARKLY], - suppress_callback_exceptions=True, - meta_tags=[{"name": "viewport", "content": "width=device-width, initial-scale=1"}] - ) - - # Set up layout and callbacks self._setup_app_layout() + + # Run the app in a separate thread + threading.Thread(target=self._run_app, daemon=True).start() def _setup_app_layout(self): """Set up the app layout and callbacks""" @@ -1458,78 +539,100 @@ class RealTimeChart: self._setup_interval_callback(button_style, active_button_style) self._setup_chart_callback() self._setup_position_list_callback() + self._setup_trading_status_callback() # We've removed the ticks callback, so don't call it # self._setup_ticks_callback() def _get_chart_layout(self, button_style, active_button_style): - # Chart page layout + """Get the chart layout""" return html.Div([ - # Chart title and interval buttons + # Trading stats header at the top html.Div([ - html.H2(f"{self.symbol} Real-Time Chart", style={ - 'textAlign': 'center', - 'color': '#FFFFFF', - 'marginBottom': '10px' - }), - - # Store interval data - dcc.Store(id='interval-store', data={'interval': 1}), - - # Interval selection buttons html.Div([ - html.Button('1s', id='btn-1s', n_clicks=0, style=active_button_style), - html.Button('5s', id='btn-5s', n_clicks=0, style=button_style), - html.Button('15s', id='btn-15s', n_clicks=0, style=button_style), - html.Button('30s', id='btn-30s', n_clicks=0, style=button_style), - html.Button('1m', id='btn-1m', n_clicks=0, style=button_style), + html.Div([ + html.Span("Signal: ", style={'fontWeight': 'bold', 'marginRight': '5px'}), + html.Span("NONE", id='current-signal-value', style={'color': 'white'}) + ], style={'marginRight': '20px', 'display': 'inline-block'}), + html.Div([ + html.Span("Position: ", style={'fontWeight': 'bold', 'marginRight': '5px'}), + html.Span("NONE", id='current-position-value', style={'color': 'white'}) + ], style={'marginRight': '20px', 'display': 'inline-block'}), + html.Div([ + html.Span("Balance: ", style={'fontWeight': 'bold', 'marginRight': '5px'}), + html.Span("$0.00", id='current-balance-value', style={'color': 'white'}) + ], style={'marginRight': '20px', 'display': 'inline-block'}), + html.Div([ + html.Span("Session PnL: ", style={'fontWeight': 'bold', 'marginRight': '5px'}), + html.Span("$0.00", id='current-pnl-value', style={'color': 'white'}) + ], style={'display': 'inline-block'}) ], style={ - 'display': 'flex', - 'justifyContent': 'center', - 'marginBottom': '15px' - }), - - # Interval component for updates - set to refresh every 500ms - dcc.Interval( - id='interval-component', - interval=300, # Refresh every 300ms for better real-time updates - n_intervals=0 - ), - - # Main chart - dcc.Graph(id='live-chart', style={'height': '75vh'}), - - # Last 5 Positions component - html.Div([ - html.H4("Last 5 Positions", style={'textAlign': 'center', 'color': '#FFFFFF'}), - html.Table([ - html.Thead( - html.Tr([ - html.Th("Action", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), - html.Th("Size", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), - html.Th("Entry Price", style={'padding': '8px', 'border': '1px solid #444', 
'backgroundColor': '#333'}), - html.Th("Exit Price", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), - html.Th("PnL", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), - html.Th("Time", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}) - ]) - ), - html.Tbody(id='position-table-body') - ], style={ - 'width': '100%', - 'borderCollapse': 'collapse', - 'color': '#FFFFFF', - 'backgroundColor': '#222' - }) - ], style={'marginTop': '20px', 'marginBottom': '20px', 'overflowX': 'auto'}), - - # Chart acknowledgment - html.Div("Real-time trading chart with ML signals", style={ - 'textAlign': 'center', - 'color': '#AAAAAA', - 'fontSize': '12px', - 'marginTop': '5px' + 'padding': '10px', + 'backgroundColor': '#222222', + 'borderRadius': '5px', + 'marginBottom': '10px', + 'border': '1px solid #444444' }) - ]) - ]) + ]), + + # Recent Trades Table (compact at top) + html.Div([ + html.H4("Recent Trades", style={'color': 'white', 'margin': '5px 0', 'fontSize': '14px'}), + html.Table([ + html.Thead(html.Tr([ + html.Th("Status", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}), + html.Th("Amount", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}), + html.Th("Entry Price", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}), + html.Th("Exit Price", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}), + html.Th("PnL", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}), + html.Th("Time", style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px', 'fontWeight': 'bold', 'backgroundColor': '#333333'}) + ])), + html.Tbody(id='position-list', children=[ + html.Tr([html.Td("No positions yet", colSpan=6, style={'textAlign': 'center', 'padding': '4px', 'fontSize': '12px'})]) + ]) + ], style={ + 'width': '100%', + 'borderCollapse': 'collapse', + 'fontSize': '12px', + 'backgroundColor': '#222222', + 'color': 'white', + 'marginBottom': '10px' + }) + ], style={'marginBottom': '10px'}), + + # Chart area + dcc.Graph( + id='real-time-chart', + style={'height': 'calc(100vh - 250px)'}, # Adjusted to account for header and table + config={ + 'displayModeBar': True, + 'scrollZoom': True, + 'modeBarButtonsToRemove': ['lasso2d', 'select2d'] + } + ), + + # Interval selector + html.Div([ + html.Button('1m', id='1m-interval', n_clicks=0, style=active_button_style if self.current_interval == '1m' else button_style), + html.Button('5m', id='5m-interval', n_clicks=0, style=active_button_style if self.current_interval == '5m' else button_style), + html.Button('15m', id='15m-interval', n_clicks=0, style=active_button_style if self.current_interval == '15m' else button_style), + html.Button('1h', id='1h-interval', n_clicks=0, style=active_button_style if self.current_interval == '1h' else button_style), + html.Button('4h', id='4h-interval', n_clicks=0, style=active_button_style if self.current_interval == '4h' else button_style), + html.Button('1d', id='1d-interval', n_clicks=0, style=active_button_style if self.current_interval == '1d' else button_style), + ], style={'textAlign': 'center', 'marginTop': '10px'}), + + # Interval component for 
automatic updates + dcc.Interval( + id='chart-interval', + interval=300, # Refresh every 300ms for better real-time updates + n_intervals=0 + ) + ], style={ + 'backgroundColor': '#121212', + 'padding': '20px', + 'color': 'white', + 'height': '100vh', + 'boxSizing': 'border-box' + }) def _get_ticks_layout(self): # Ticks data page layout @@ -1615,906 +718,176 @@ class RealTimeChart: ]) def _setup_interval_callback(self, button_style, active_button_style): - # Callback to update interval based on button clicks and update button styles + """Set up the callback for interval selection buttons""" @self.app.callback( - [Output('interval-store', 'data'), - Output('btn-1s', 'style'), - Output('btn-5s', 'style'), - Output('btn-15s', 'style'), - Output('btn-30s', 'style'), - Output('btn-1m', 'style')], - [Input('btn-1s', 'n_clicks'), - Input('btn-5s', 'n_clicks'), - Input('btn-15s', 'n_clicks'), - Input('btn-30s', 'n_clicks'), - Input('btn-1m', 'n_clicks')], - [dash.dependencies.State('interval-store', 'data')] + [ + Output('1m-interval', 'style'), + Output('5m-interval', 'style'), + Output('15m-interval', 'style'), + Output('1h-interval', 'style'), + Output('4h-interval', 'style'), + Output('1d-interval', 'style') + ], + [ + Input('1m-interval', 'n_clicks'), + Input('5m-interval', 'n_clicks'), + Input('15m-interval', 'n_clicks'), + Input('1h-interval', 'n_clicks'), + Input('4h-interval', 'n_clicks'), + Input('1d-interval', 'n_clicks') + ] ) - def update_interval(n1, n5, n15, n30, n60, data): - ctx = dash.callback_context - if not ctx.triggered: - # Default state (1s selected) - return ({'interval': 1}, - active_button_style, button_style, button_style, button_style, button_style) + def update_interval_buttons(n1, n5, n15, n1h, n4h, n1d): + ctx = callback_context + # Default styles (all inactive) + styles = { + '1m': button_style.copy(), + '5m': button_style.copy(), + '15m': button_style.copy(), + '1h': button_style.copy(), + '4h': button_style.copy(), + '1d': button_style.copy() + } + + # If no button clicked yet, use default interval + if not ctx.triggered: + styles[self.current_interval] = active_button_style.copy() + return [styles['1m'], styles['5m'], styles['15m'], styles['1h'], styles['4h'], styles['1d']] + + # Get the button ID that was clicked button_id = ctx.triggered[0]['prop_id'].split('.')[0] - if button_id == 'btn-1s': - return ({'interval': 1}, - active_button_style, button_style, button_style, button_style, button_style) - elif button_id == 'btn-5s': - return ({'interval': 5}, - button_style, active_button_style, button_style, button_style, button_style) - elif button_id == 'btn-15s': - return ({'interval': 15}, - button_style, button_style, active_button_style, button_style, button_style) - elif button_id == 'btn-30s': - return ({'interval': 30}, - button_style, button_style, button_style, active_button_style, button_style) - elif button_id == 'btn-1m': - return ({'interval': 60}, - button_style, button_style, button_style, button_style, active_button_style) + # Map button ID to interval + interval_map = { + '1m-interval': '1m', + '5m-interval': '5m', + '15m-interval': '15m', + '1h-interval': '1h', + '4h-interval': '4h', + '1d-interval': '1d' + } - # Default case - keep current interval and highlight appropriate button - current_interval = data.get('interval', 1) - styles = [button_style] * 5 # All inactive by default + # Update the current interval based on clicked button + self.current_interval = interval_map.get(button_id, self.current_interval) - # Set active style based on 
current interval - if current_interval == 1: - styles[0] = active_button_style - elif current_interval == 5: - styles[1] = active_button_style - elif current_interval == 15: - styles[2] = active_button_style - elif current_interval == 30: - styles[3] = active_button_style - elif current_interval == 60: - styles[4] = active_button_style + # Set active style for selected interval + styles[self.current_interval] = active_button_style.copy() - return (data, *styles) + # Update the chart with the new interval + self._update_chart() + + return [styles['1m'], styles['5m'], styles['15m'], styles['1h'], styles['4h'], styles['1d']] def _setup_chart_callback(self): - # Callback to update the chart + """Set up the callback for the chart updates""" @self.app.callback( - Output('live-chart', 'figure'), - [Input('interval-component', 'n_intervals'), - Input('interval-store', 'data')] + Output('real-time-chart', 'figure'), + [Input('chart-interval', 'n_intervals')] ) - def update_chart(n, interval_data): + def update_chart(n_intervals): try: - interval = interval_data.get('interval', 1) - logger.debug(f"Updating chart with interval {interval}") - - # Get candlesticks data for the selected interval - try: - df = self.tick_storage.get_candles(interval_seconds=interval) - if df.empty and self.ohlcv_cache[f'{interval}s' if interval < 60 else '1m'] is not None: - df = self.ohlcv_cache[f'{interval}s' if interval < 60 else '1m'] - except Exception as e: - logger.error(f"Error getting candles: {str(e)}") - df = pd.DataFrame() - - # Get data for other timeframes (1m, 1h, 1d) - df_1m = None - df_1h = None - df_1d = None - - try: - # Get 1m candles - if self.ohlcv_cache['1m'] is not None and not self.ohlcv_cache['1m'].empty: - df_1m = self.ohlcv_cache['1m'].copy() - else: - df_1m = self.tick_storage.get_candles(interval_seconds=60) - - # Get 1h candles - if self.ohlcv_cache['1h'] is not None and not self.ohlcv_cache['1h'].empty: - df_1h = self.ohlcv_cache['1h'].copy() - else: - df_1h = self.tick_storage.get_candles(interval_seconds=3600) - - # Get 1d candles - if self.ohlcv_cache['1d'] is not None and not self.ohlcv_cache['1d'].empty: - df_1d = self.ohlcv_cache['1d'].copy() - else: - df_1d = self.tick_storage.get_candles(interval_seconds=86400) - - # Limit the number of candles to display but show more for context - if df_1m is not None and not df_1m.empty: - df_1m = df_1m.tail(600) # Show 600 1m candles for better context - 10 hours - if df_1h is not None and not df_1h.empty: - df_1h = df_1h.tail(480) # Show 480 hours of hourly data for better context - 20 days - if df_1d is not None and not df_1d.empty: - df_1d = df_1d.tail(365*2) # Show 2 years of daily data for better context - - except Exception as e: - logger.error(f"Error getting additional timeframes: {str(e)}") - - # Create layout with 5 rows (main chart, volume, 1m, 1h, 1d) - fig = make_subplots( - rows=5, cols=1, - vertical_spacing=0.03, - subplot_titles=( - f'{self.symbol} Price ({interval}s)', - 'Volume', - '1-Minute Chart', - '1-Hour Chart', - '1-Day Chart' - ), - row_heights=[0.4, 0.15, 0.15, 0.15, 0.15] - ) - - # Add candlestick chart for main timeframe - if not df.empty and 'open' in df.columns: - # Candlestick chart - fig.add_trace( - go.Candlestick( - x=df.index, - open=df['open'], - high=df['high'], - low=df['low'], - close=df['close'], - name='OHLC' - ), - row=1, col=1 - ) - - # Calculate y-axis range with padding - low_min = df['low'].min() - high_max = df['high'].max() - price_range = high_max - low_min - y_min = low_min - (price_range * 0.05) 
# 5% padding below - y_max = high_max + (price_range * 0.05) # 5% padding above - - # Set y-axis range to ensure candles are visible - fig.update_yaxes(range=[y_min, y_max], row=1, col=1) - - # Volume bars - colors = ['rgba(0,255,0,0.7)' if close >= open else 'rgba(255,0,0,0.7)' - for open, close in zip(df['open'], df['close'])] - - fig.add_trace( - go.Bar( - x=df.index, - y=df['volume'], - name='Volume', - marker_color=colors - ), - row=2, col=1 - ) - - # Add buy/sell markers and PnL annotations to the main chart - if hasattr(self, 'trades') and self.trades: - buy_times = [] - buy_prices = [] - buy_markers = [] - sell_times = [] - sell_prices = [] - sell_markers = [] - - # Filter trades to only include recent ones (last 100) - recent_trades = self.trades[-100:] - - for trade in recent_trades: - # Convert timestamp to datetime if it's not already - trade_time = trade.get('timestamp') - if isinstance(trade_time, (int, float)): - trade_time = pd.to_datetime(trade_time, unit='ms') - - price = trade.get('price', 0) - pnl = trade.get('pnl', None) - action = trade.get('action', 'SELL') # Default to SELL - - if action == 'BUY': - buy_times.append(trade_time) - buy_prices.append(price) - buy_markers.append("") - elif action == 'SELL': - sell_times.append(trade_time) - sell_prices.append(price) - # Add PnL as marker text if available - if pnl is not None: - pnl_text = f"{pnl:.4f}" if abs(pnl) < 0.01 else f"{pnl:.2f}" - sell_markers.append(pnl_text) - else: - sell_markers.append("") - - # Add buy markers - if buy_times: - fig.add_trace( - go.Scatter( - x=buy_times, - y=buy_prices, - mode='markers', - name='Buy', - marker=dict( - symbol='triangle-up', - size=12, - color='rgba(0,255,0,0.8)', - line=dict(width=1, color='darkgreen') - ), - text=buy_markers, - hoverinfo='x+y+text', - showlegend=True # Ensure Buy appears in legend - ), - row=1, col=1 - ) - - # Add vertical and horizontal connecting lines for buys - for i, (btime, bprice) in enumerate(zip(buy_times, buy_prices)): - # Add vertical dashed line to time axis - fig.add_shape( - type="line", - x0=btime, x1=btime, - y0=y_min, y1=bprice, - line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"), - row=1, col=1 - ) - # Add horizontal dashed line showing the price level - fig.add_shape( - type="line", - x0=df.index.min(), x1=btime, - y0=bprice, y1=bprice, - line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"), - row=1, col=1 - ) - - # Add sell markers with PnL annotations - if sell_times: - fig.add_trace( - go.Scatter( - x=sell_times, - y=sell_prices, - mode='markers+text', - name='Sell', - marker=dict( - symbol='triangle-down', - size=12, - color='rgba(255,0,0,0.8)', - line=dict(width=1, color='darkred') - ), - text=sell_markers, - textposition='top center', - textfont=dict(size=10), - hoverinfo='x+y+text', - showlegend=True # Ensure Sell appears in legend - ), - row=1, col=1 - ) - - # Add vertical and horizontal connecting lines for sells - for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)): - # Add vertical dashed line to time axis - fig.add_shape( - type="line", - x0=stime, x1=stime, - y0=y_min, y1=sprice, - line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"), - row=1, col=1 - ) - # Add horizontal dashed line showing the price level - fig.add_shape( - type="line", - x0=df.index.min(), x1=stime, - y0=sprice, y1=sprice, - line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"), - row=1, col=1 - ) - - # Add connecting lines between consecutive buy-sell pairs - if len(buy_times) > 0 and len(sell_times) > 0: - # 
Create pairs of buy-sell trades based on timestamps - pairs = [] - buys_copy = list(zip(buy_times, buy_prices)) - - for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)): - # Find the most recent buy before this sell - matching_buy = None - for j, (btime, bprice) in enumerate(buys_copy): - if btime < stime: - matching_buy = (btime, bprice) - buys_copy.pop(j) # Remove this buy to prevent reuse - break - - if matching_buy: - pairs.append((matching_buy, (stime, sprice))) - - # Add connecting lines for each pair - for (btime, bprice), (stime, sprice) in pairs: - # Draw line connecting the buy and sell points - fig.add_shape( - type="line", - x0=btime, x1=stime, - y0=bprice, y1=sprice, - line=dict( - color="rgba(255,255,255,0.5)", - width=1, - dash="dot" - ), - row=1, col=1 - ) - - # Add 1m chart - if df_1m is not None and not df_1m.empty and 'open' in df_1m.columns: - fig.add_trace( - go.Candlestick( - x=df_1m.index, - open=df_1m['open'], - high=df_1m['high'], - low=df_1m['low'], - close=df_1m['close'], - name='1m', - showlegend=False - ), - row=3, col=1 - ) - - # Set appropriate date format for 1m chart - fig.update_xaxes( - title_text="", - row=3, - col=1, - tickformat="%H:%M", - tickmode="auto", - nticks=12 - ) - - # Add buy/sell markers to 1m chart if they fall within the visible timeframe - if hasattr(self, 'trades') and self.trades: - # Filter trades visible in 1m timeframe - min_time = df_1m.index.min() - max_time = df_1m.index.max() - - # Ensure min_time and max_time are pandas.Timestamp objects - if isinstance(min_time, (int, float)): - min_time = pd.to_datetime(min_time, unit='ms') - if isinstance(max_time, (int, float)): - max_time = pd.to_datetime(max_time, unit='ms') - - # Collect only trades within this timeframe - minute_buy_times = [] - minute_buy_prices = [] - minute_sell_times = [] - minute_sell_prices = [] - - for trade in self.trades[-100:]: - trade_time = trade.get('timestamp') - if isinstance(trade_time, (int, float)): - # Convert numeric timestamp to datetime - trade_time = pd.to_datetime(trade_time, unit='ms') - elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime): - # Skip trades with invalid timestamp format - continue - - # Check if trade falls within 1m chart timeframe - try: - if min_time <= trade_time <= max_time: - price = trade.get('price', 0) - action = trade.get('action', 'SELL') - - if action == 'BUY': - minute_buy_times.append(trade_time) - minute_buy_prices.append(price) - elif action == 'SELL': - minute_sell_times.append(trade_time) - minute_sell_prices.append(price) - except TypeError: - # If comparison fails due to type mismatch, log the error and skip this trade - logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}") - continue - - # Add buy markers to 1m chart - if minute_buy_times: - fig.add_trace( - go.Scatter( - x=minute_buy_times, - y=minute_buy_prices, - mode='markers', - name='Buy (1m)', - marker=dict( - symbol='triangle-up', - size=8, - color='rgba(0,255,0,0.8)', - line=dict(width=1, color='darkgreen') - ), - showlegend=False, - hoverinfo='x+y' - ), - row=3, col=1 - ) - - # Add sell markers to 1m chart - if minute_sell_times: - fig.add_trace( - go.Scatter( - x=minute_sell_times, - y=minute_sell_prices, - mode='markers', - name='Sell (1m)', - marker=dict( - symbol='triangle-down', - size=8, - color='rgba(255,0,0,0.8)', - line=dict(width=1, color='darkred') - ), - showlegend=False, - hoverinfo='x+y' - ), - row=3, col=1 - ) - - # Add 1h chart 
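
The removed pairing loop above walks each sell back to an earlier, not-yet-used buy before drawing connector lines. The same matching rule as a standalone helper (a sketch; pair_trades is an illustrative name, not part of the patch):

def pair_trades(buys, sells):
    """Match each sell to the latest earlier buy; inputs are (time, price) tuples."""
    pairs = []
    pool = sorted(buys)                  # oldest first
    for stime, sprice in sorted(sells):
        earlier = [b for b in pool if b[0] < stime]
        if earlier:
            buy = earlier[-1]            # most recent buy before this sell
            pool.remove(buy)             # consume it so it is not paired twice
            pairs.append((buy, (stime, sprice)))
    return pairs
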
- if df_1h is not None and not df_1h.empty and 'open' in df_1h.columns: - fig.add_trace( - go.Candlestick( - x=df_1h.index, - open=df_1h['open'], - high=df_1h['high'], - low=df_1h['low'], - close=df_1h['close'], - name='1h', - showlegend=False - ), - row=4, col=1 - ) - - # Set appropriate date format for 1h chart - fig.update_xaxes( - title_text="", - row=4, - col=1, - tickformat="%m-%d %H:%M", - tickmode="auto", - nticks=8 - ) - - # Add buy/sell markers to 1h chart if they fall within the visible timeframe - if hasattr(self, 'trades') and self.trades: - # Filter trades visible in 1h timeframe - min_time = df_1h.index.min() - max_time = df_1h.index.max() - - # Ensure min_time and max_time are pandas.Timestamp objects - if isinstance(min_time, (int, float)): - min_time = pd.to_datetime(min_time, unit='ms') - if isinstance(max_time, (int, float)): - max_time = pd.to_datetime(max_time, unit='ms') - - # Collect only trades within this timeframe - hour_buy_times = [] - hour_buy_prices = [] - hour_sell_times = [] - hour_sell_prices = [] - - for trade in self.trades[-200:]: # Check more trades for longer timeframe - trade_time = trade.get('timestamp') - if isinstance(trade_time, (int, float)): - # Convert numeric timestamp to datetime - trade_time = pd.to_datetime(trade_time, unit='ms') - elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime): - # Skip trades with invalid timestamp format - continue - - # Check if trade falls within 1h chart timeframe - try: - if min_time <= trade_time <= max_time: - price = trade.get('price', 0) - action = trade.get('action', 'SELL') - - if action == 'BUY': - hour_buy_times.append(trade_time) - hour_buy_prices.append(price) - elif action == 'SELL': - hour_sell_times.append(trade_time) - hour_sell_prices.append(price) - except TypeError: - # If comparison fails due to type mismatch, log the error and skip this trade - logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}") - continue - - # Add buy markers to 1h chart - if hour_buy_times: - fig.add_trace( - go.Scatter( - x=hour_buy_times, - y=hour_buy_prices, - mode='markers', - name='Buy (1h)', - marker=dict( - symbol='triangle-up', - size=6, - color='rgba(0,255,0,0.8)', - line=dict(width=1, color='darkgreen') - ), - showlegend=False, - hoverinfo='x+y' - ), - row=4, col=1 - ) - - # Add sell markers to 1h chart - if hour_sell_times: - fig.add_trace( - go.Scatter( - x=hour_sell_times, - y=hour_sell_prices, - mode='markers', - name='Sell (1h)', - marker=dict( - symbol='triangle-down', - size=6, - color='rgba(255,0,0,0.8)', - line=dict(width=1, color='darkred') - ), - showlegend=False, - hoverinfo='x+y' - ), - row=4, col=1 - ) - # Add 1d chart - if df_1d is not None and not df_1d.empty and 'open' in df_1d.columns: - fig.add_trace( - go.Candlestick( - x=df_1d.index, - open=df_1d['open'], - high=df_1d['high'], - low=df_1d['low'], - close=df_1d['close'], - name='1d', - showlegend=False - ), - row=5, col=1 - ) - - # Set appropriate date format for 1d chart - fig.update_xaxes( - title_text="", - row=5, - col=1, - tickformat="%Y-%m-%d", - tickmode="auto", - nticks=10 - ) - - # Add buy/sell markers to 1d chart if they fall within the visible timeframe - if hasattr(self, 'trades') and self.trades: - # Filter trades visible in 1d timeframe - min_time = df_1d.index.min() - max_time = df_1d.index.max() - - # Ensure min_time and max_time are pandas.Timestamp objects - if isinstance(min_time, (int, float)): - min_time = 
pd.to_datetime(min_time, unit='ms') - if isinstance(max_time, (int, float)): - max_time = pd.to_datetime(max_time, unit='ms') - - # Collect only trades within this timeframe - day_buy_times = [] - day_buy_prices = [] - day_sell_times = [] - day_sell_prices = [] - - for trade in self.trades[-300:]: # Check more trades for daily timeframe - trade_time = trade.get('timestamp') - if isinstance(trade_time, (int, float)): - # Convert numeric timestamp to datetime - trade_time = pd.to_datetime(trade_time, unit='ms') - elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime): - # Skip trades with invalid timestamp format - continue - - # Check if trade falls within 1d chart timeframe - try: - if min_time <= trade_time <= max_time: - price = trade.get('price', 0) - action = trade.get('action', 'SELL') - - if action == 'BUY': - day_buy_times.append(trade_time) - day_buy_prices.append(price) - elif action == 'SELL': - day_sell_times.append(trade_time) - day_sell_prices.append(price) - except TypeError: - # If comparison fails due to type mismatch, log the error and skip this trade - logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}") - continue - - # Add buy markers to 1d chart - if day_buy_times: - fig.add_trace( - go.Scatter( - x=day_buy_times, - y=day_buy_prices, - mode='markers', - name='Buy (1d)', - marker=dict( - symbol='triangle-up', - size=5, - color='rgba(0,255,0,0.8)', - line=dict(width=1, color='darkgreen') - ), - showlegend=False, - hoverinfo='x+y' - ), - row=5, col=1 - ) - - # Add sell markers to 1d chart - if day_sell_times: - fig.add_trace( - go.Scatter( - x=day_sell_times, - y=day_sell_prices, - mode='markers', - name='Sell (1d)', - marker=dict( - symbol='triangle-down', - size=5, - color='rgba(255,0,0,0.8)', - line=dict(width=1, color='darkred') - ), - showlegend=False, - hoverinfo='x+y' - ), - row=5, col=1 - ) - - # Add trading info annotation if available - if hasattr(self, 'current_signal') and self.current_signal: - signal_color = "#33DD33" if self.current_signal == "BUY" else "#FF4444" if self.current_signal == "SELL" else "#BBBBBB" - - # Format position value - position_text = f"{self.current_position:.4f}" if self.current_position < 0.01 else f"{self.current_position:.2f}" - - # Format PnL with color based on value - pnl_color = "#33DD33" if self.session_pnl >= 0 else "#FF4444" - pnl_text = f"{self.session_pnl:.4f}" - - # Create trading info text - info_text = ( - f"Signal: {self.current_signal} | " - f"Position: {position_text} | " - f"Balance: ${self.session_balance:.2f} | " - f"PnL: {pnl_text}" - ) - - # Add annotation - fig.add_annotation( - x=0.5, y=1.05, - xref="paper", yref="paper", - text=info_text, - showarrow=False, - font=dict(size=14, color="white"), - bgcolor="rgba(50,50,50,0.6)", - borderwidth=1, - borderpad=6, - align="center" - ) - - # Update layout - fig.update_layout( - title_text=f"{self.symbol} Real-Time Data", - title_x=0.5, - xaxis_rangeslider_visible=False, - height=1000, # Increased height to accommodate all charts - template='plotly_dark', - paper_bgcolor='rgba(0,0,0,0)', - plot_bgcolor='rgba(25,25,50,1)' - ) - - # Update axes styling for all subplots - fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') - fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') - - # Hide rangesliders for all candlestick charts - fig.update_layout( - xaxis_rangeslider_visible=False, - xaxis2_rangeslider_visible=False, - 
xaxis3_rangeslider_visible=False, - xaxis4_rangeslider_visible=False, - xaxis5_rangeslider_visible=False - ) - - # Improve date formatting for the main chart - fig.update_xaxes( - title_text="", - row=1, - col=1, - tickformat="%H:%M:%S", - tickmode="auto", - nticks=15 - ) - - return fig + # Create the main figure if it doesn't exist yet + if self.fig is None: + self._initialize_chart() + + # Update the chart data + self._update_chart() + return self.fig except Exception as e: logger.error(f"Error updating chart: {str(e)}") import traceback logger.error(traceback.format_exc()) - fig = go.Figure() - fig.add_annotation( - x=0.5, y=0.5, - text=f"Error updating chart: {str(e)}", - showarrow=False, - font=dict(size=14, color="red"), - xref="paper", yref="paper" - ) - return fig + + # Return empty figure on error + return { + 'data': [], + 'layout': { + 'title': 'Error updating chart', + 'annotations': [{ + 'text': str(e), + 'showarrow': False, + 'font': {'color': 'red'} + }] + } + } def _setup_position_list_callback(self): - """Callback to update the position list with the latest 5 trades""" + """Set up the callback for the position list""" @self.app.callback( - Output('position-table-body', 'children'), - [Input('interval-component', 'n_intervals')] + Output('position-list', 'children'), + [Input('chart-interval', 'n_intervals')] ) def update_position_list(n): - try: - # Get the last 5 positions - if not hasattr(self, 'trades') or not self.trades: - return [html.Tr([html.Td("No trades yet", colSpan=6, style={'textAlign': 'center', 'color': '#AAAAAA', 'padding': '10px'})])] - - # Collect BUY and SELL trades to pair them - buy_trades = {} - sell_trades = {} - position_rows = [] - - # Process all trades to match buy and sell pairs - for trade in self.trades: - action = trade.get('action', 'UNKNOWN') - if action == 'BUY': - # Store buy trades by timestamp to match with sells later - buy_trades[trade.get('timestamp')] = trade - elif action == 'SELL': - # Store sell trades - sell_trades[trade.get('timestamp')] = trade - - # Create position entries (matched pairs or single trades) - position_entries = [] - - # First add matched pairs (BUY trades with matching close_price) - for timestamp, buy_trade in buy_trades.items(): - if 'close_price' in buy_trade: - # This is a closed position - entry_price = buy_trade.get('price', 'N/A') - exit_price = buy_trade.get('close_price', 'N/A') - entry_time = buy_trade.get('timestamp') - exit_time = buy_trade.get('close_timestamp') - amount = buy_trade.get('amount', 0.1) - - # Calculate PnL if not provided - pnl = buy_trade.get('pnl') - if pnl is None and isinstance(entry_price, (int, float)) and isinstance(exit_price, (int, float)): - pnl = (exit_price - entry_price) * amount - - position_entries.append({ - 'action': 'CLOSED', - 'entry_price': entry_price, - 'exit_price': exit_price, - 'amount': amount, - 'pnl': pnl, - 'time': entry_time, - 'exit_time': exit_time, - 'status': 'CLOSED' - }) - else: - # This is an open position (BUY without a matching SELL) - current_price = self.tick_storage.get_latest_price() or buy_trade.get('price', 0) - amount = buy_trade.get('amount', 0.1) - entry_price = buy_trade.get('price', 0) - - # Calculate unrealized PnL - unrealized_pnl = (current_price - entry_price) * amount if isinstance(current_price, (int, float)) and isinstance(entry_price, (int, float)) else None - - position_entries.append({ - 'action': 'BUY', - 'entry_price': entry_price, - 'exit_price': current_price, # Use current price as potential exit - 'amount': amount, - 
'pnl': unrealized_pnl, - 'time': buy_trade.get('timestamp'), - 'status': 'OPEN' - }) - - # Add standalone SELL trades that don't have a matching BUY - for timestamp, sell_trade in sell_trades.items(): - # Check if this SELL is already accounted for in a closed position - already_matched = False - for entry in position_entries: - if entry.get('status') == 'CLOSED' and entry.get('exit_time') == timestamp: - already_matched = True - break - - if not already_matched: - position_entries.append({ - 'action': 'SELL', - 'entry_price': 'N/A', - 'exit_price': sell_trade.get('price', 'N/A'), - 'amount': sell_trade.get('amount', 0.1), - 'pnl': sell_trade.get('pnl'), - 'time': sell_trade.get('timestamp'), - 'status': 'STANDALONE' - }) - - # Sort by time (most recent first) and take last 5 - position_entries.sort(key=lambda x: x['time'] if isinstance(x['time'], datetime) else datetime.now(), reverse=True) - position_entries = position_entries[:5] - - # Convert to table rows - for entry in position_entries: - action = entry['action'] - entry_price = entry['entry_price'] - exit_price = entry['exit_price'] - amount = entry['amount'] - pnl = entry['pnl'] - time_obj = entry['time'] - status = entry['status'] - - # Format time - if isinstance(time_obj, datetime): - # If trade is from a different day, include the date - today = datetime.now().date() - if time_obj.date() == today: - time_str = time_obj.strftime('%H:%M:%S') - else: - time_str = time_obj.strftime('%m-%d %H:%M:%S') - else: - time_str = str(time_obj) - - # Format prices with proper decimal places - if isinstance(entry_price, (int, float)): - entry_price_str = f"${entry_price:.2f}" - else: - entry_price_str = str(entry_price) - - if isinstance(exit_price, (int, float)): - exit_price_str = f"${exit_price:.2f}" - else: - exit_price_str = str(exit_price) - - # Format PnL - if pnl is not None and isinstance(pnl, (int, float)): - pnl_str = f"${pnl:.2f}" - pnl_color = '#00FF00' if pnl >= 0 else '#FF0000' - else: - pnl_str = 'N/A' - pnl_color = '#FFFFFF' - - # Set action/status color and text - if status == 'OPEN': - status_color = '#00AAFF' # Blue for open positions - status_text = "OPEN (BUY)" - elif status == 'CLOSED': - if pnl is not None and isinstance(pnl, (int, float)): - status_color = '#00FF00' if pnl >= 0 else '#FF0000' # Green/Red based on profit - else: - status_color = '#FFCC00' # Yellow if PnL unknown - status_text = "CLOSED" - elif action == 'BUY': - status_color = '#00FF00' - status_text = "BUY" - elif action == 'SELL': - status_color = '#FF0000' - status_text = "SELL" - else: - status_color = '#FFFFFF' - status_text = action - - # Create table row - position_rows.append(html.Tr([ - html.Td(status_text, style={'color': status_color, 'padding': '8px', 'border': '1px solid #444'}), - html.Td(f"{amount} BTC", style={'padding': '8px', 'border': '1px solid #444'}), - html.Td(entry_price_str, style={'padding': '8px', 'border': '1px solid #444'}), - html.Td(exit_price_str, style={'padding': '8px', 'border': '1px solid #444'}), - html.Td(pnl_str, style={'color': pnl_color, 'padding': '8px', 'border': '1px solid #444'}), - html.Td(time_str, style={'padding': '8px', 'border': '1px solid #444'}) - ])) - - return position_rows - except Exception as e: - logger.error(f"Error updating position table: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - return [html.Tr([html.Td(f"Error: {str(e)}", colSpan=6, style={'color': '#FF0000', 'padding': '10px'})])] + if not self.positions: + return [html.Tr([html.Td("No positions yet", 
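+                # colSpan=6 spans all six position-table columns
+                # (status, amount, entry price, exit price, PnL, time)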
colSpan=6)])] + return self._get_position_list_rows() + + def _setup_trading_status_callback(self): + """Set up the callback for the trading status fields""" + @self.app.callback( + [ + Output('current-signal-value', 'children'), + Output('current-position-value', 'children'), + Output('current-balance-value', 'children'), + Output('current-pnl-value', 'children'), + Output('current-signal-value', 'style'), + Output('current-position-value', 'style') + ], + [Input('chart-interval', 'n_intervals')] + ) + def update_trading_status(n): + # Get the current signal + current_signal = "NONE" + signal_style = {'color': 'white'} + + if hasattr(self, 'last_action') and self.last_action: + current_signal = self.last_action + if current_signal == "BUY": + signal_style = {'color': 'green', 'fontWeight': 'bold'} + elif current_signal == "SELL": + signal_style = {'color': 'red', 'fontWeight': 'bold'} + + # Get the current position + current_position = "NONE" + position_style = {'color': 'white'} + + # Check if we have any open positions + open_positions = [p for p in self.positions if p.status == "OPEN"] + if open_positions: + current_position = f"{open_positions[0].action} {open_positions[0].amount:.4f}" + if open_positions[0].action == "BUY": + position_style = {'color': 'green', 'fontWeight': 'bold'} + else: + position_style = {'color': 'red', 'fontWeight': 'bold'} + + # Get the current balance and session PnL + current_balance = f"${self.balance:.2f}" if hasattr(self, 'balance') else "$0.00" + + # Calculate session PnL + session_pnl = 0 + for position in self.positions: + if position.status == "CLOSED": + session_pnl += position.pnl + + # Format PnL with color + pnl_text = f"${session_pnl:.2f}" + + return current_signal, current_position, current_balance, pnl_text, signal_style, position_style + + def _add_manual_trade_inputs(self): + # Add manual trade inputs + self.app.layout.children.append( + html.Div([ + html.H3("Add Manual Trade"), + dcc.Input(id='manual-price', type='number', placeholder='Price'), + dcc.Input(id='manual-volume', type='number', placeholder='Volume'), + dcc.Input(id='manual-pnl', type='number', placeholder='PnL'), + dcc.Input(id='manual-action', type='text', placeholder='Action'), + html.Button('Add Trade', id='add-manual-trade') + ]) + ) def _interval_to_seconds(self, interval_key: str) -> int: """Convert interval key to seconds""" @@ -2649,356 +1022,104 @@ class RealTimeChart: logger.info(f"Waiting 5 seconds before reconnecting {self.symbol} WebSocket...") await asyncio.sleep(5) - def run(self, host='localhost', port=8050): - """Run the Dash app - - Args: - host: Hostname to run on - port: Port to run on - """ - logger.info(f"Starting Dash app on {host}:{port}") - - # Ensure interval component is created - if not hasattr(self, 'app') or not self.app.layout: - logger.error("App layout not initialized properly") - return - - # If interval-component is not in the layout, add it - if 'interval-component' not in str(self.app.layout): - logger.warning("Interval component not found in layout, adding it") - self.app.layout.children.append( - dcc.Interval( - id='interval-component', - interval=500, # 500ms for real-time updates - n_intervals=0 - ) - ) - - # Start websocket connection in a separate thread - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - self.websocket_thread = threading.Thread(target=lambda: asyncio.run(self.start_websocket())) - self.websocket_thread.daemon = True - self.websocket_thread.start() - - # Ensure historical data is loaded before starting - 
self._load_historical_data() - + def _run_app(self): + """Run the Dash app""" try: - self.app.run(host=host, port=port, debug=False) + logger.info(f"Starting Dash app for {self.symbol}") + # Updated to use app.run instead of app.run_server (which is deprecated) + self.app.run(debug=False, use_reloader=False, port=8050) except Exception as e: logger.error(f"Error running Dash app: {str(e)}") - finally: - # Ensure resources are cleaned up - self._save_candles_to_disk(force=True) - logger.info("Dash app stopped") + logger.error(traceback.format_exc()) + + return - def _load_historical_data(self): - """Load historical data for all timeframes from Binance API and local cache""" + def add_trade(self, price, timestamp=None, pnl=None, amount=0.1, action="BUY", trade_type="MARKET"): + """Add a trade to the chart + + Args: + price: Trade price + timestamp: Trade timestamp (datetime or milliseconds) + pnl: Profit and Loss (for SELL trades) + amount: Trade amount + action: Trade action (BUY or SELL) + trade_type: Trade type (MARKET, LIMIT, etc.) + """ try: - logger.info(f"Loading historical data for {self.symbol}...") - - # Define intervals to fetch - intervals = { - '1s': 1, - '1m': 60, - '1h': 3600, - '1d': 86400 - } - - # Track load status - load_status = {interval: False for interval in intervals.keys()} - - # First try to load from local cache files - logger.info("Step 1: Loading from local cache files...") - for interval_key, interval_seconds in intervals.items(): - try: - cache_file = os.path.join(self.historical_data.cache_dir, - f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv") - - logger.info(f"Checking for cached {interval_key} data at {cache_file}") - if os.path.exists(cache_file): - # Check if cache is fresh (less than 1 day old for anything but 1d, 3 days for 1d) - file_age = time.time() - os.path.getmtime(cache_file) - max_age = 259200 if interval_key == '1d' else 86400 # 3 days for 1d, 1 day for others - logger.info(f"Cache file age: {file_age:.1f}s, max allowed: {max_age}s") + # Convert timestamp to datetime if it's a number + if timestamp is None: + timestamp = datetime.now() + elif isinstance(timestamp, (int, float)): + timestamp = datetime.fromtimestamp(timestamp / 1000) + + # Process the trade based on action + if action == "BUY": + # Create a new position + position = Position( + action="BUY", + entry_price=price, + amount=amount, + timestamp=timestamp + ) + self.positions.append(position) + + # Update last action + self.last_action = "BUY" + + elif action == "SELL": + # Find an open BUY position to close, or create a new SELL position + open_buy_position = None + for pos in self.positions: + if pos.status == "OPEN" and pos.action == "BUY": + open_buy_position = pos + break - if file_age <= max_age: - logger.info(f"Loading {interval_key} candles from cache") - cached_df = pd.read_csv(cache_file) - if not cached_df.empty: - # Diagnostic info about the loaded data - logger.info(f"Loaded {len(cached_df)} candles from {cache_file}") - logger.info(f"Columns: {cached_df.columns.tolist()}") - logger.info(f"First few rows: {cached_df.head(2).to_dict('records')}") - - # Convert timestamp string back to datetime - if 'timestamp' in cached_df.columns: - try: - if not pd.api.types.is_datetime64_any_dtype(cached_df['timestamp']): - cached_df['timestamp'] = pd.to_datetime(cached_df['timestamp']) - logger.info("Successfully converted timestamps to datetime") - except Exception as e: - logger.warning(f"Could not convert timestamp column for {interval_key}: {str(e)}") - - # Only keep 
the last 2000 candles for memory efficiency - if len(cached_df) > 2000: - cached_df = cached_df.tail(2000) - logger.info(f"Truncated to last 2000 candles") - - # Add to cache - for _, row in cached_df.iterrows(): - candle_dict = row.to_dict() - self.candle_cache.candles[interval_key].append(candle_dict) - - # Update ohlcv_cache - self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000) - logger.info(f"Successfully loaded {len(self.ohlcv_cache[interval_key])} cached {interval_key} candles") - - if len(self.ohlcv_cache[interval_key]) >= 500: - load_status[interval_key] = True - # Skip fetching from API if we loaded from cache (except for 1d timeframe which we always refresh) - if interval_key != '1d': - continue - else: - logger.info(f"Cache file for {interval_key} is too old ({file_age:.1f}s)") - else: - logger.info(f"No cache file found for {interval_key}") - except Exception as e: - logger.error(f"Error loading cached {interval_key} candles: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - - # For timeframes other than 1s, fetch from API as backup or for fresh data - logger.info("Step 2: Fetching data from API for missing timeframes...") - for interval_key, interval_seconds in intervals.items(): - # Skip 1s for API requests - if interval_key == '1s' or load_status[interval_key]: - logger.info(f"Skipping API fetch for {interval_key}: already loaded or 1s timeframe") - continue + if open_buy_position: + # Close the position + pnl_value = open_buy_position.close(price, timestamp) - # Fetch historical data from API - try: - logger.info(f"Fetching {interval_key} candles from API for {self.symbol}") - historical_df = self.historical_data.get_historical_candles( - symbol=self.symbol, - interval_seconds=interval_seconds, - limit=500 # Get 500 candles + # Update balance + self.balance += pnl_value + + # If pnl was provided, use it instead + if pnl is not None: + open_buy_position.pnl = pnl + self.balance = self.balance - pnl_value + pnl + + else: + # Create a standalone SELL position + position = Position( + action="SELL", + entry_price=price, + amount=amount, + timestamp=timestamp ) - if not historical_df.empty: - logger.info(f"Loaded {len(historical_df)} historical candles for {self.symbol} {interval_key} from API") + # Set it as closed with the same price + position.close(price, timestamp) + + # Set PnL if provided + if pnl is not None: + position.pnl = pnl + self.balance += pnl - # If we already have data in cache, merge with new data to avoid duplicates - if self.ohlcv_cache[interval_key] is not None and not self.ohlcv_cache[interval_key].empty: - existing_df = self.ohlcv_cache[interval_key] - # Get the latest timestamp from existing data - latest_time = existing_df['timestamp'].max() - # Only keep newer records from API - new_candles = historical_df[historical_df['timestamp'] > latest_time] - if not new_candles.empty: - logger.info(f"Adding {len(new_candles)} new candles to existing {interval_key} cache") - # Add to cache - for _, row in new_candles.iterrows(): - candle_dict = row.to_dict() - self.candle_cache.candles[interval_key].append(candle_dict) - else: - # No existing data, add all from API - for _, row in historical_df.iterrows(): - candle_dict = row.to_dict() - self.candle_cache.candles[interval_key].append(candle_dict) - - # Update ohlcv_cache with combined data - self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000) - logger.info(f"Total {interval_key} candles in cache: 
{len(self.ohlcv_cache[interval_key])}") - - if len(self.ohlcv_cache[interval_key]) >= 500: - load_status[interval_key] = True - else: - logger.warning(f"No historical data available from API for {self.symbol} {interval_key}") - except Exception as e: - logger.error(f"Error fetching {interval_key} data from API: {str(e)}") - import traceback - logger.error(traceback.format_exc()) + self.positions.append(position) + + # Update last action + self.last_action = "SELL" - # Log summary of loaded data - logger.info("Historical data load summary:") - for interval_key in intervals.keys(): - count = len(self.ohlcv_cache[interval_key]) if self.ohlcv_cache[interval_key] is not None else 0 - status = "Success" if load_status[interval_key] else "Failed" - if count > 0 and count < 500: - status = "Partial" - logger.info(f"{interval_key}: {count} candles - {status}") + # Log the trade + logger.info(f"Added {action} trade: price={price}, amount={amount}, time={timestamp}, PnL={pnl}") + + # Trigger more frequent chart updates for immediate visibility + if hasattr(self, 'fig') and self.fig is not None: + self._update_chart() except Exception as e: - logger.error(f"Error in _load_historical_data: {str(e)}") + logger.error(f"Error adding trade: {str(e)}") import traceback logger.error(traceback.format_exc()) - def _save_candles_to_disk(self, force=False): - """Save current candle cache to disk for persistence between runs""" - try: - # Only save if we have data and sufficient time has passed (every 5 minutes) - current_time = time.time() - if not force and current_time - self.last_cache_save_time < 300: # 5 minutes - return - - # Save each timeframe's candles to disk - for interval_key, candles in self.candle_cache.candles.items(): - if candles: - # Convert to DataFrame - df = pd.DataFrame(list(candles)) - if not df.empty: - # Ensure timestamp is properly formatted - if 'timestamp' in df.columns: - try: - if not pd.api.types.is_datetime64_any_dtype(df['timestamp']): - df['timestamp'] = pd.to_datetime(df['timestamp']) - except: - logger.warning(f"Could not convert timestamp column for {interval_key}") - - # Save to disk in the cache directory - cache_file = os.path.join(self.historical_data.cache_dir, - f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv") - df.to_csv(cache_file, index=False) - logger.info(f"Saved {len(df)} {interval_key} candles to {cache_file}") - - self.last_cache_save_time = current_time - logger.info(f"Saved all candle caches to disk at {datetime.now()}") - except Exception as e: - logger.error(f"Error saving candles to disk: {str(e)}") - import traceback - logger.error(traceback.format_exc()) - - def add_nn_signal(self, signal_type, timestamp, probability=None): - """Add a neural network signal to be displayed on the chart - - Args: - signal_type: The type of signal (BUY, SELL, HOLD) - timestamp: The timestamp for the signal - probability: Optional probability/confidence value - """ - if signal_type not in ['BUY', 'SELL', 'HOLD']: - logger.warning(f"Invalid NN signal type: {signal_type}") - return - - # Convert timestamp to datetime if it's not already - if not isinstance(timestamp, datetime): - try: - if isinstance(timestamp, str): - timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) - elif isinstance(timestamp, (int, float)): - timestamp = datetime.fromtimestamp(timestamp / 1000.0) - except Exception as e: - logger.error(f"Error converting timestamp for NN signal: {str(e)}") - timestamp = datetime.now() - - # Add the signal to our list - 
self.nn_signals.append({ - 'type': signal_type, - 'timestamp': timestamp, - 'probability': probability, - 'added': datetime.now() - }) - - # Only keep the most recent 50 signals - if len(self.nn_signals) > 50: - self.nn_signals = self.nn_signals[-50:] - - logger.info(f"Added NN signal: {signal_type} at {timestamp}") - - def add_trade(self, price, timestamp, pnl=None, amount=0.1, action=None, type=None): - """Add a trade to be displayed on the chart - - Args: - price: The price at which the trade was executed - timestamp: The timestamp for the trade - pnl: Optional profit and loss value for the trade - amount: Amount traded - action: The type of trade (BUY or SELL) - alternative to type parameter - type: The type of trade (BUY or SELL) - alternative to action parameter - """ - # Handle both action and type parameters for backward compatibility - trade_type = type or action - - # Default to BUY if trade_type is None or not specified - if trade_type is None: - logger.warning(f"Trade type not specified in add_trade call, defaulting to BUY. Price: {price}, Timestamp: {timestamp}") - trade_type = "BUY" - - if isinstance(trade_type, int): - trade_type = "BUY" if trade_type == 0 else "SELL" - - # Ensure trade_type is uppercase if it's a string - if isinstance(trade_type, str): - trade_type = trade_type.upper() - - if trade_type not in ['BUY', 'SELL']: - logger.warning(f"Invalid trade type: {trade_type} (value type: {type(trade_type).__name__}), defaulting to BUY. Price: {price}, Timestamp: {timestamp}") - trade_type = "BUY" - - # Convert timestamp to datetime if it's not already - if not isinstance(timestamp, datetime): - try: - if isinstance(timestamp, str): - timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00')) - elif isinstance(timestamp, (int, float)): - timestamp = datetime.fromtimestamp(timestamp / 1000.0) - except Exception as e: - logger.error(f"Error converting timestamp for trade: {str(e)}") - timestamp = datetime.now() - - # Create the trade object - trade = { - 'price': price, - 'timestamp': timestamp, - 'pnl': pnl, - 'amount': amount, - 'action': trade_type - } - - # Add to our trades list - if not hasattr(self, 'trades'): - self.trades = [] - - # If this is a SELL trade, try to find the corresponding BUY trade and update it with close_price - if trade_type == 'SELL' and len(self.trades) > 0: - for i in range(len(self.trades) - 1, -1, -1): - prev_trade = self.trades[i] - if prev_trade.get('action') == 'BUY' and 'close_price' not in prev_trade: - # Found a BUY trade without a close_price, consider it the matching trade - prev_trade['close_price'] = price - prev_trade['close_timestamp'] = timestamp - logger.info(f"Updated BUY trade at {prev_trade['timestamp']} with close price {price}") - break - - self.trades.append(trade) - - # Log the trade for debugging - pnl_str = f" with PnL: {pnl}" if pnl is not None else "" - logger.info(f"Added trade: {trade_type} {amount} at price {price} at time {timestamp}{pnl_str}") - - # Trigger a more frequent update of the chart by scheduling a callback - # This helps ensure the trade appears immediately on the chart - if hasattr(self, 'app') and self.app is not None: - try: - # Only update if we have a dash app running - # This is a workaround to make trades appear immediately - callback_context = dash.callback_context - # Force an update by triggering the callback - for callback_id, callback_info in self.app.callback_map.items(): - if 'live-chart' in callback_id: - # Found the chart callback, try to trigger it - 
logger.debug(f"Triggering chart update callback after trade") - callback_info['callback']() - break - except Exception as e: - # If callback triggering fails, it's not critical - logger.debug(f"Failed to trigger chart update: {str(e)}") - pass - - return trade - def update_trading_info(self, signal=None, position=None, balance=None, pnl=None): """Update the current trading information to be displayed on the chart @@ -3026,9 +1147,291 @@ class RealTimeChart: logger.debug(f"Updated trading info: Signal={self.current_signal}, Position={self.current_position}, Balance=${self.session_balance:.2f}, PnL={self.session_pnl:.4f}") + def _get_position_list_rows(self): + """Generate rows for the position table""" + if not self.positions: + return [html.Tr([html.Td("No positions yet", colSpan=6)])] + + position_rows = [] + + # Sort positions by time (most recent first) + sorted_positions = sorted(self.positions, + key=lambda x: x.timestamp if hasattr(x, 'timestamp') else datetime.now(), + reverse=True) + + # Take only the most recent 5 positions + for position in sorted_positions[:5]: + # Format time + time_obj = position.timestamp if hasattr(position, 'timestamp') else datetime.now() + if isinstance(time_obj, datetime): + # If trade is from a different day, include the date + today = datetime.now().date() + if time_obj.date() == today: + time_str = time_obj.strftime('%H:%M:%S') + else: + time_str = time_obj.strftime('%m-%d %H:%M:%S') + else: + time_str = str(time_obj) + + # Format prices with proper decimal places + entry_price = position.entry_price if hasattr(position, 'entry_price') else 'N/A' + if isinstance(entry_price, (int, float)): + entry_price_str = f"${entry_price:.6f}" + else: + entry_price_str = str(entry_price) + + # For exit price, use close_price for closed positions or current market price for open ones + if position.status == "CLOSED" and hasattr(position, 'exit_price'): + exit_price = position.exit_price + else: + exit_price = self.tick_storage.get_latest_price() if position.status == "OPEN" else 'N/A' + + if isinstance(exit_price, (int, float)): + exit_price_str = f"${exit_price:.6f}" + else: + exit_price_str = str(exit_price) + + # Format amount + amount = position.amount if hasattr(position, 'amount') else 0.1 + amount_str = f"{amount:.4f} BTC" + + # Format PnL + if position.status == "CLOSED": + pnl = position.pnl if hasattr(position, 'pnl') else 0 + pnl_str = f"${pnl:.2f}" + pnl_color = '#00FF00' if pnl >= 0 else '#FF0000' + elif position.status == "OPEN" and position.action == "BUY": + # Calculate unrealized PnL for open positions + if isinstance(exit_price, (int, float)) and isinstance(entry_price, (int, float)): + unrealized_pnl = (exit_price - entry_price) * amount + pnl_str = f"${unrealized_pnl:.2f} (unrealized)" + pnl_color = '#00FF00' if unrealized_pnl >= 0 else '#FF0000' + else: + pnl_str = 'N/A' + pnl_color = '#FFFFFF' + else: + pnl_str = 'N/A' + pnl_color = '#FFFFFF' + + # Set action/status color and text + if position.status == 'OPEN': + status_color = '#00AAFF' # Blue for open positions + status_text = f"OPEN ({position.action})" + elif position.status == 'CLOSED': + if hasattr(position, 'pnl') and isinstance(position.pnl, (int, float)): + status_color = '#00FF00' if position.pnl >= 0 else '#FF0000' # Green/Red based on profit + else: + status_color = '#FFCC00' # Yellow if PnL unknown + status_text = "CLOSED" + else: + status_color = '#00FF00' if position.action == 'BUY' else '#FF0000' + status_text = position.action + + # Create table row with more compact styling 
+ position_rows.append(html.Tr([ + html.Td(status_text, style={'color': status_color, 'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}), + html.Td(amount_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}), + html.Td(entry_price_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}), + html.Td(exit_price_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}), + html.Td(pnl_str, style={'color': pnl_color, 'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}), + html.Td(time_str, style={'padding': '4px 8px', 'border': '1px solid #444', 'fontSize': '12px'}) + ])) + + return position_rows + + def _initialize_chart(self): + """Initialize the chart figure""" + # Create a figure with subplots for price and volume + self.fig = make_subplots( + rows=2, + cols=1, + shared_xaxes=True, + vertical_spacing=0.03, + row_heights=[0.8, 0.2], + subplot_titles=(f"{self.symbol} Price Chart", "Volume") + ) + + # Set up initial empty traces + self.fig.add_trace( + go.Candlestick( + x=[], open=[], high=[], low=[], close=[], + name='Price', + increasing={'line': {'color': '#26A69A', 'width': 1}, 'fillcolor': '#26A69A'}, + decreasing={'line': {'color': '#EF5350', 'width': 1}, 'fillcolor': '#EF5350'} + ), + row=1, col=1 + ) + + # Add volume trace + self.fig.add_trace( + go.Bar( + x=[], y=[], + name='Volume', + marker={'color': '#888888'} + ), + row=2, col=1 + ) + + # Add empty traces for buy/sell markers + self.fig.add_trace( + go.Scatter( + x=[], y=[], + mode='markers', + name='BUY', + marker=dict( + symbol='triangle-up', + size=12, + color='rgba(0,255,0,0.8)', + line=dict(width=1, color='darkgreen') + ), + showlegend=True + ), + row=1, col=1 + ) + + self.fig.add_trace( + go.Scatter( + x=[], y=[], + mode='markers', + name='SELL', + marker=dict( + symbol='triangle-down', + size=12, + color='rgba(255,0,0,0.8)', + line=dict(width=1, color='darkred') + ), + showlegend=True + ), + row=1, col=1 + ) + + # Update layout + self.fig.update_layout( + title=f"{self.symbol} Real-Time Trading Chart", + title_x=0.5, + template='plotly_dark', + paper_bgcolor='rgba(0,0,0,0)', + plot_bgcolor='rgba(25,25,50,1)', + height=800, + xaxis_rangeslider_visible=False, + legend=dict( + orientation="h", + yanchor="bottom", + y=1.02, + xanchor="center", + x=0.5 + ) + ) + + # Update axes styling + self.fig.update_xaxes( + showgrid=True, + gridwidth=1, + gridcolor='rgba(128,128,128,0.2)', + zeroline=False + ) + + self.fig.update_yaxes( + showgrid=True, + gridwidth=1, + gridcolor='rgba(128,128,128,0.2)', + zeroline=False + ) + + # Do an initial update to populate the chart + self._update_chart() + + def _update_chart(self): + """Update the chart with the latest data""" + try: + # Get candlesticks data for the current interval + df = self.tick_storage.get_candles(interval=self.current_interval) + + if df is None or df.empty: + logger.warning(f"No candle data available for {self.current_interval}") + return + + # Limit the number of candles to display (show 500 for context) + df = df.tail(500) + + # Update candlestick data + self.fig.update_traces( + x=df.index, + open=df['open'], + high=df['high'], + low=df['low'], + close=df['close'], + selector=dict(type='candlestick') + ) + + # Update volume bars with colors based on price movement + colors = ['rgba(0,255,0,0.5)' if close >= open else 'rgba(255,0,0,0.5)' + for open, close in zip(df['open'], df['close'])] + + self.fig.update_traces( + x=df.index, + y=df['volume'], + 
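+                # marker_color applies the per-bar colors computed above; note
+                # that the comprehension reuses the names open/close, shadowing
+                # the Python built-ins inside its own scope (harmless here, but
+                # easy to misread).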
marker_color=colors,
+                selector=dict(type='bar')
+            )
+            
+            # Calculate y-axis range with padding for better visibility
+            if len(df) > 0:
+                low_min = df['low'].min()
+                high_max = df['high'].max()
+                price_range = high_max - low_min
+                y_min = low_min - (price_range * 0.05)  # 5% padding below
+                y_max = high_max + (price_range * 0.05)  # 5% padding above
+                
+                # Update y-axis range
+                self.fig.update_yaxes(range=[y_min, y_max], row=1, col=1)
+            
+            # Update Buy/Sell markers
+            if hasattr(self, 'positions') and self.positions:
+                # Collect buy and sell points
+                buy_times = []
+                buy_prices = []
+                sell_times = []
+                sell_prices = []
+                
+                for position in self.positions:
+                    # Handle buy trades
+                    if position.action == "BUY":
+                        buy_times.append(position.timestamp)
+                        buy_prices.append(position.entry_price)
+                    
+                    # Handle sell trades or closed positions
+                    if position.status == "CLOSED" and hasattr(position, 'exit_timestamp') and hasattr(position, 'exit_price'):
+                        sell_times.append(position.exit_timestamp)
+                        sell_prices.append(position.exit_price)
+                
+                # Update buy markers trace
+                self.fig.update_traces(
+                    x=buy_times,
+                    y=buy_prices,
+                    selector=dict(name='BUY')
+                )
+                
+                # Update sell markers trace
+                self.fig.update_traces(
+                    x=sell_times,
+                    y=sell_prices,
+                    selector=dict(name='SELL')
+                )
+            
+            # Update chart title with the current interval
+            self.fig.update_layout(
+                title=f"{self.symbol} Real-Time Chart ({self.current_interval})"
+            )
+            
+        except Exception as e:
+            logger.error(f"Error in _update_chart: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
+
 async def main():
     global charts  # Make charts globally accessible for NN integration
-    symbols = ["ETH/USDT", "BTC/USDT"]
+    symbols = ["ETH/USDT", "ETH/USDT"]
     logger.info(f"Starting application for symbols: {symbols}")
     
     # Initialize neural network if enabled
diff --git a/realtime_old.py b/realtime_old.py
new file mode 100644
index 0000000..0bfa47f
--- /dev/null
+++ b/realtime_old.py
@@ -0,0 +1,3083 @@
+import asyncio
+import json
+import logging
+
+from typing import Dict, List, Optional
+import websockets
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+import dash
+from dash import html, dcc
+from dash.dependencies import Input, Output
+import pandas as pd
+import numpy as np
+from collections import deque
+import time
+from threading import Thread
+import requests
+import os
+from datetime import datetime, timedelta
+import pytz
+import tzlocal
+import threading
+import random
+import dash_bootstrap_components as dbc
+
+# Configure logging with more detailed format
+logging.basicConfig(
+    level=logging.INFO,  # Set to logging.DEBUG for more detailed logs
+    format='%(asctime)s - %(levelname)s - [%(filename)s:%(lineno)d] - %(message)s',
+    handlers=[
+        logging.StreamHandler(),
+        logging.FileHandler('realtime_chart.log')
+    ]
+)
+logger = logging.getLogger(__name__)
+
+# Neural Network integration (conditional import)
+NN_ENABLED = os.environ.get('ENABLE_NN_MODELS', '0') == '1'
+nn_orchestrator = None
+nn_inference_thread = None
+
+if NN_ENABLED:
+    try:
+        import sys
+        # Add project root to sys.path if needed
+        project_root = os.path.dirname(os.path.abspath(__file__))
+        if project_root not in sys.path:
+            sys.path.append(project_root)
+        
+        from NN.main import NeuralNetworkOrchestrator
+        logger.info("Neural Network module enabled")
+    except ImportError as e:
+        logger.warning(f"Failed to import Neural Network module, disabling NN features: {str(e)}")
+        NN_ENABLED = False
+
+# NN utility functions
+def setup_neural_network():
"""Initialize the neural network components if enabled""" + global nn_orchestrator, NN_ENABLED + + if not NN_ENABLED: + return False + + try: + # Get configuration from environment variables or use defaults + symbol = os.environ.get('NN_SYMBOL', 'BTC/USDT') + timeframes = os.environ.get('NN_TIMEFRAMES', '1m,5m,1h,4h,1d').split(',') + output_size = int(os.environ.get('NN_OUTPUT_SIZE', '3')) # 3 for BUY/HOLD/SELL + + # Configure the orchestrator + config = { + 'symbol': symbol, + 'timeframes': timeframes, + 'window_size': int(os.environ.get('NN_WINDOW_SIZE', '20')), + 'n_features': 5, # OHLCV + 'output_size': output_size, + 'model_dir': 'NN/models/saved', + 'data_dir': 'NN/data' + } + + # Initialize the orchestrator + logger.info(f"Initializing Neural Network Orchestrator with config: {config}") + nn_orchestrator = NeuralNetworkOrchestrator(config) + + # Start inference thread if enabled + inference_interval = int(os.environ.get('NN_INFERENCE_INTERVAL', '60')) + if inference_interval > 0: + start_nn_inference_thread(inference_interval) + + return True + except Exception as e: + logger.error(f"Error setting up neural network: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + NN_ENABLED = False + return False + +def start_nn_inference_thread(interval_seconds): + """Start a background thread to periodically run inference with the neural network""" + global nn_inference_thread + + if not NN_ENABLED or nn_orchestrator is None: + logger.warning("Cannot start inference thread - Neural Network not enabled or initialized") + return False + + def inference_worker(): + """Worker function for the inference thread""" + model_type = os.environ.get('NN_MODEL_TYPE', 'cnn') + timeframe = os.environ.get('NN_TIMEFRAME', '1h') + + logger.info(f"Starting neural network inference thread with {interval_seconds}s interval") + logger.info(f"Using model type: {model_type}, timeframe: {timeframe}") + + # Wait a bit for charts to initialize + time.sleep(5) + + # Track active charts + active_charts = [] + + while True: + try: + # Find active charts if we don't have them yet + if not active_charts and 'charts' in globals(): + active_charts = globals()['charts'] + logger.info(f"Found {len(active_charts)} active charts for NN signals") + + # Run inference + result = nn_orchestrator.run_inference_pipeline( + model_type=model_type, + timeframe=timeframe + ) + + if result: + # Log the result + logger.info(f"Neural network inference result: {result}") + + # Add signal to charts + if active_charts: + try: + if 'action' in result: + action = result['action'] + timestamp = datetime.fromisoformat(result['timestamp'].replace('Z', '+00:00')) + + # Get probability if available + probability = None + if 'probability' in result: + probability = result['probability'] + elif 'probabilities' in result: + probability = result['probabilities'].get(action, None) + + # Add signal to each chart + for chart in active_charts: + if hasattr(chart, 'add_nn_signal'): + chart.add_nn_signal(action, timestamp, probability) + except Exception as e: + logger.error(f"Error adding NN signal to chart: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + + # Sleep for the interval + time.sleep(interval_seconds) + + except Exception as e: + logger.error(f"Error in inference thread: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + time.sleep(5) # Wait a bit before retrying + + # Create and start the thread + nn_inference_thread = threading.Thread(target=inference_worker, daemon=True) + 
nn_inference_thread.start() + + return True + +# Try to get local timezone, default to Sofia/EET if not available +try: + local_timezone = tzlocal.get_localzone() + # Get timezone name safely + try: + tz_name = str(local_timezone) + # Handle case where it might be zoneinfo.ZoneInfo object instead of pytz timezone + if hasattr(local_timezone, 'zone'): + tz_name = local_timezone.zone + elif hasattr(local_timezone, 'key'): + tz_name = local_timezone.key + else: + tz_name = str(local_timezone) + except: + tz_name = "Local" + logger.info(f"Detected local timezone: {local_timezone} ({tz_name})") +except Exception as e: + logger.warning(f"Could not detect local timezone: {str(e)}. Defaulting to Sofia/EET") + local_timezone = pytz.timezone('Europe/Sofia') + tz_name = "Europe/Sofia" + +def convert_to_local_time(timestamp): + """Convert timestamp to local timezone""" + try: + if isinstance(timestamp, pd.Timestamp): + dt = timestamp.to_pydatetime() + elif isinstance(timestamp, np.datetime64): + dt = pd.Timestamp(timestamp).to_pydatetime() + elif isinstance(timestamp, str): + dt = pd.to_datetime(timestamp).to_pydatetime() + else: + dt = timestamp + + # If datetime is naive (no timezone), assume it's UTC + if dt.tzinfo is None: + dt = dt.replace(tzinfo=pytz.UTC) + + # Convert to local timezone + local_dt = dt.astimezone(local_timezone) + return local_dt + except Exception as e: + logger.error(f"Error converting timestamp to local time: {str(e)}") + return timestamp + +class TradeTickStorage: + """Storage for trade ticks with a maximum age limit""" + + def __init__(self, symbol: str = None, max_age_seconds: int = 3600, use_sample_data: bool = True, log_no_ticks_warning: bool = False): # 1 hour by default, up from 30 min + """Initialize the tick storage + + Args: + symbol: Trading symbol + max_age_seconds: Maximum age for ticks to be stored + use_sample_data: If True, generate sample ticks when no real ticks available + log_no_ticks_warning: If True, log a warning when no ticks are available + """ + self.symbol = symbol + self.ticks = [] + self.max_age_seconds = max_age_seconds + self.last_cleanup_time = time.time() + self.cleanup_interval = 60 # run cleanup every 60 seconds + self.cache_dir = "cache" + self.use_sample_data = use_sample_data + self.log_no_ticks_warning = log_no_ticks_warning + self.last_sample_price = 83000.0 # Starting price for sample data (for BTC) + self.last_sample_time = time.time() * 1000 # Starting time for sample data + self.last_tick_time = 0 # Initialize last_tick_time attribute + self.tick_count = 0 # Initialize tick_count attribute + + # Create cache directory if it doesn't exist + if not os.path.exists(self.cache_dir): + os.makedirs(self.cache_dir) + + # Try to load cached ticks + self._load_cached_ticks() + + logger.info(f"Initialized TradeTickStorage for {symbol} with max age: {max_age_seconds} seconds, cleanup interval: {self.cleanup_interval} seconds") + + def add_tick(self, tick: Dict): + """Add a tick to storage + + Args: + tick: Tick data dict with fields: + price: The price + volume: The volume + timestamp: Timestamp in milliseconds + """ + if not tick: + return + + # Check if we need to generate a timestamp + if 'timestamp' not in tick: + tick['timestamp'] = int(time.time() * 1000) # Current time in ms + + # Ensure timestamp is an integer (milliseconds since epoch) + if not isinstance(tick['timestamp'], int): + try: + # Try to convert from float or string + tick['timestamp'] = int(float(tick['timestamp'])) + except (ValueError, TypeError): + # If conversion fails, 
use current time + tick['timestamp'] = int(time.time() * 1000) + + # Set default volume if not present + if 'volume' not in tick: + tick['volume'] = 0.01 # Default small volume + + # Add tick to storage with a copy to avoid mutation + self.ticks.append(tick.copy()) + + # Keep track of latest tick for stats + self.last_tick_time = max(self.last_tick_time, tick['timestamp']) + + # Cache every 100 ticks to avoid data loss + self.tick_count += 1 + if self.tick_count % 100 == 0: + self._cache_ticks() + + # Periodically clean up old ticks + if self.tick_count % 1000 == 0: + self._cleanup() + + def _cleanup(self): + """Remove ticks older than max_age_seconds""" + # Get current time in milliseconds + now = int(time.time() * 1000) + + # Remove old ticks + cutoff = now - (self.max_age_seconds * 1000) + original_count = len(self.ticks) + self.ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff] + removed = original_count - len(self.ticks) + + if removed > 0: + logger.debug(f"Cleaned up {removed} old ticks, remaining: {len(self.ticks)}") + + def _load_cached_ticks(self): + """Load cached ticks from disk on startup""" + # Create symbol-specific filename + symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower() + cache_file = os.path.join(self.cache_dir, f"{symbol_safe}_recent_ticks.csv") + + if not os.path.exists(cache_file): + logger.info(f"No cached ticks found for {self.symbol}") + return + + try: + # Check if cache is fresh (less than 10 minutes old) + file_age = time.time() - os.path.getmtime(cache_file) + if file_age > 600: # 10 minutes + logger.info(f"Cached ticks for {self.symbol} are too old ({file_age:.1f}s), skipping") + return + + # Load cached ticks + tick_df = pd.read_csv(cache_file) + if tick_df.empty: + logger.info(f"Cached ticks file for {self.symbol} is empty") + return + + # Convert to list of dicts and add to storage + cached_ticks = tick_df.to_dict('records') + self.ticks.extend(cached_ticks) + logger.info(f"Loaded {len(cached_ticks)} cached ticks for {self.symbol} from {cache_file}") + except Exception as e: + logger.error(f"Error loading cached ticks for {self.symbol}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + + def _cache_ticks(self): + """Cache recent ticks to disk""" + if not self.ticks: + return + + # Get ticks from last 10 minutes + now = int(time.time() * 1000) # Current time in ms + cutoff = now - (600 * 1000) # 10 minutes in ms + recent_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff] + + if not recent_ticks: + logger.debug("No recent ticks to cache") + return + + # Create symbol-specific filename + symbol_safe = self.symbol.replace("/", "_").replace("-", "_").lower() + cache_file = os.path.join(self.cache_dir, f"{symbol_safe}_recent_ticks.csv") + + # Save to disk + try: + tick_df = pd.DataFrame(recent_ticks) + tick_df.to_csv(cache_file, index=False) + logger.info(f"Cached {len(recent_ticks)} recent ticks for {self.symbol} to {cache_file}") + except Exception as e: + logger.error(f"Error caching ticks: {str(e)}") + + def get_latest_price(self) -> Optional[float]: + """Get the latest price from the most recent tick""" + if self.ticks: + return self.ticks[-1].get('price') + return None + + def get_price_stats(self) -> Dict: + """Get stats about the prices in storage""" + if not self.ticks: + return { + 'min': None, + 'max': None, + 'latest': None, + 'count': 0, + 'age_seconds': 0 + } + + prices = [tick['price'] for tick in self.ticks] + latest_timestamp = self.ticks[-1]['timestamp'] + 
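+        # Assumes ticks are stored in arrival order (add_tick appends and
+        # _cleanup preserves order), so index 0 is the oldest tick and
+        # index -1 the newest.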
oldest_timestamp = self.ticks[0]['timestamp'] + + return { + 'min': min(prices), + 'max': max(prices), + 'latest': prices[-1], + 'count': len(prices), + 'age_seconds': (latest_timestamp - oldest_timestamp) / 1000 + } + + def get_ticks_as_df(self) -> pd.DataFrame: + """Return ticks as a DataFrame""" + if not self.ticks: + logger.warning("No ticks available for DataFrame conversion") + return pd.DataFrame() + + # Ensure we have fresh data + self._cleanup() + + # Create a new list from ticks to avoid modifying the original data + ticks_data = self.ticks.copy() + + # Ensure we have the latest ticks at the end of the DataFrame + ticks_data.sort(key=lambda x: x['timestamp']) + + df = pd.DataFrame(ticks_data) + if not df.empty: + logger.debug(f"Converting timestamps for {len(df)} ticks") + # Ensure timestamp column exists + if 'timestamp' not in df.columns: + logger.error("Tick data missing timestamp column") + return pd.DataFrame() + + # Check timestamp datatype before conversion + sample_ts = df['timestamp'].iloc[0] if len(df) > 0 else None + logger.debug(f"Sample timestamp before conversion: {sample_ts}, type: {type(sample_ts)}") + + # Convert timestamps to datetime + try: + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + logger.debug(f"Timestamps converted to datetime successfully") + if len(df) > 0: + logger.debug(f"Sample converted timestamp: {df['timestamp'].iloc[0]}") + except Exception as e: + logger.error(f"Error converting timestamps: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return pd.DataFrame() + return df + + def get_candles(self, interval_seconds: int = 1, start_time_ms: int = None, end_time_ms: int = None) -> pd.DataFrame: + """Generate candlestick data from ticks + + Args: + interval_seconds: Interval in seconds for each candle + start_time_ms: Start time in milliseconds + end_time_ms: End time in milliseconds + + Returns: + DataFrame with candlestick data + """ + # Get filtered ticks + ticks = self.get_ticks_from_time(start_time_ms, end_time_ms) + + if not ticks: + if self.use_sample_data: + # Generate multiple sample ticks to create several candles + current_time = int(time.time() * 1000) + sample_ticks = [] + + # Generate ticks for the past 10 intervals + for i in range(20): + # Base price with some trend + base_price = self.last_sample_price * (1 + 0.0001 * (10 - i)) + + # Add some randomness to the price + random_factor = random.uniform(-0.002, 0.002) # Small random change + tick_price = base_price * (1 + random_factor) + + # Create timestamp with appropriate offset + tick_time = current_time - (i * interval_seconds * 1000 // 2) + + sample_tick = { + 'price': tick_price, + 'volume': random.uniform(0.01, 0.5), + 'timestamp': tick_time, + 'is_sample': True + } + + sample_ticks.append(sample_tick) + + # Update the last sample values + self.last_sample_price = sample_ticks[0]['price'] + self.last_sample_time = sample_ticks[0]['timestamp'] + + # Add the sample ticks in chronological order + for tick in sorted(sample_ticks, key=lambda x: x['timestamp']): + self.add_tick(tick) + + # Try again with the new ticks + ticks = self.get_ticks_from_time(start_time_ms, end_time_ms) + + if not ticks and self.log_no_ticks_warning: + logger.warning("Still no ticks available after adding sample data") + elif self.log_no_ticks_warning: + logger.warning("No ticks available for candle formation") + return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume']) + else: + return pd.DataFrame(columns=['timestamp', 'open', 'high', 
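+                    # An empty frame with the expected OHLCV columns lets
+                    # downstream chart code index them without a KeyError.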
'low', 'close', 'volume']) + + # Ensure ticks are up to date + try: + self._cleanup() + except Exception as cleanup_error: + logger.error(f"Error cleaning up ticks: {str(cleanup_error)}") + + df = pd.DataFrame(ticks) + if df.empty: + logger.warning("Tick DataFrame is empty after filtering/conversion") + return pd.DataFrame() + + logger.info(f"Preparing to create candles from {len(df)} ticks with {interval_seconds}s interval") + + # First, ensure all required columns exist + required_columns = ['timestamp', 'price', 'volume'] + for col in required_columns: + if col not in df.columns: + logger.error(f"Required column '{col}' missing from tick data") + return pd.DataFrame() + + # Make sure DataFrame has no duplicated timestamps before setting index + try: + if 'timestamp' in df.columns: + # Check for duplicate timestamps + duplicate_count = df['timestamp'].duplicated().sum() + if duplicate_count > 0: + logger.warning(f"Found {duplicate_count} duplicate timestamps, keeping the last occurrence") + # Keep the last occurrence of each timestamp + df = df.drop_duplicates(subset='timestamp', keep='last') + + # Convert timestamp to datetime if it's not already + if not pd.api.types.is_datetime64_any_dtype(df['timestamp']): + logger.debug("Converting timestamp to datetime") + # Try multiple approaches to convert timestamps + try: + # First, try to convert from milliseconds (integer timestamps) + if pd.api.types.is_integer_dtype(df['timestamp']): + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + else: + # Otherwise try standard conversion + df['timestamp'] = pd.to_datetime(df['timestamp']) + except Exception as conv_error: + logger.error(f"Error converting timestamps to datetime: {str(conv_error)}") + # Try a fallback approach + try: + # Fallback for integer timestamps + if df['timestamp'].iloc[0] > 1000000000000: # Check if milliseconds timestamp + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + else: # Otherwise assume seconds + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='s') + except Exception as fallback_error: + logger.error(f"Fallback timestamp conversion failed: {str(fallback_error)}") + return pd.DataFrame() + + # Use timestamp column for resampling + df = df.set_index('timestamp') + except Exception as prep_error: + logger.error(f"Error preprocessing DataFrame for resampling: {str(prep_error)}") + import traceback + logger.error(traceback.format_exc()) + return pd.DataFrame() + + # Create interval string for resampling - use 's' instead of deprecated 'S' + interval_str = f'{interval_seconds}s' + + # Resample to create OHLCV candles with multiple fallback options + logger.debug(f"Resampling with interval: {interval_str}") + + candles = None + + # First attempt - individual column resampling + try: + # Check that price column exists and has enough data + if 'price' not in df.columns: + raise ValueError("Price column missing from DataFrame") + + if len(df) < 2: + logger.warning("Not enough data points for resampling, using direct data") + # For single data point, create a single candle + if len(df) == 1: + price_val = df['price'].iloc[0] + volume_val = df['volume'].iloc[0] if 'volume' in df.columns else 0 + timestamp_val = df.index[0] + + candles = pd.DataFrame({ + 'timestamp': [timestamp_val], + 'open': [price_val], + 'high': [price_val], + 'low': [price_val], + 'close': [price_val], + 'volume': [volume_val] + }) + return candles + else: + # No data + return pd.DataFrame() + + # Resample and aggregate each column separately + open_df = 
df['price'].resample(interval_str).first() + high_df = df['price'].resample(interval_str).max() + low_df = df['price'].resample(interval_str).min() + close_df = df['price'].resample(interval_str).last() + volume_df = df['volume'].resample(interval_str).sum() + + # Check for length mismatches before combining + expected_length = len(open_df) + if (len(high_df) != expected_length or + len(low_df) != expected_length or + len(close_df) != expected_length or + len(volume_df) != expected_length): + logger.warning("Length mismatch in resampled columns, falling back to alternative method") + raise ValueError("Length mismatch") + + # Combine into a single DataFrame + candles = pd.DataFrame({ + 'open': open_df, + 'high': high_df, + 'low': low_df, + 'close': close_df, + 'volume': volume_df + }) + logger.debug(f"Successfully created {len(candles)} candles with individual column resampling") + except Exception as resample_error: + logger.error(f"Error in individual column resampling: {str(resample_error)}") + + # Second attempt - built-in agg method + try: + logger.debug("Trying fallback resampling method with agg()") + candles = df.resample(interval_str).agg({ + 'price': ['first', 'max', 'min', 'last'], + 'volume': 'sum' + }) + # Flatten MultiIndex columns + candles.columns = ['open', 'high', 'low', 'close', 'volume'] + logger.debug(f"Successfully created {len(candles)} candles with agg() method") + except Exception as agg_error: + logger.error(f"Error in agg() resampling: {str(agg_error)}") + + # Third attempt - manual candle construction + try: + logger.debug("Trying manual candle construction method") + resampler = df.resample(interval_str) + candle_data = [] + + for name, group in resampler: + if not group.empty: + candle = { + 'timestamp': name, + 'open': group['price'].iloc[0], + 'high': group['price'].max(), + 'low': group['price'].min(), + 'close': group['price'].iloc[-1], + 'volume': group['volume'].sum() if 'volume' in group.columns else 0 + } + candle_data.append(candle) + + if candle_data: + candles = pd.DataFrame(candle_data) + logger.debug(f"Successfully created {len(candles)} candles with manual method") + else: + logger.warning("No candles created with manual method") + return pd.DataFrame() + except Exception as manual_error: + logger.error(f"Error in manual candle construction: {str(manual_error)}") + import traceback + logger.error(traceback.format_exc()) + return pd.DataFrame() + + # Ensure the result isn't empty + if candles is None or candles.empty: + logger.warning("No candles were created after all resampling attempts") + return pd.DataFrame() + + # Reset index to get timestamp as column + try: + candles = candles.reset_index() + except Exception as reset_error: + logger.error(f"Error resetting index: {str(reset_error)}") + # Try to create a new DataFrame with the timestamp index as a column + try: + timestamp_col = candles.index.to_list() + candles_dict = candles.to_dict('list') + candles_dict['timestamp'] = timestamp_col + candles = pd.DataFrame(candles_dict) + except Exception as fallback_error: + logger.error(f"Error in fallback index reset: {str(fallback_error)}") + return pd.DataFrame() + + # Ensure no NaN values + try: + nan_count_before = candles.isna().sum().sum() + if nan_count_before > 0: + logger.warning(f"Found {nan_count_before} NaN values in candles, dropping them") + + candles = candles.dropna() + except Exception as nan_error: + logger.error(f"Error handling NaN values: {str(nan_error)}") + # Try to fill NaN values instead of dropping + try: + candles = 
candles.ffill().bfill()
+            except Exception:
+                pass
+        
+        logger.debug(f"Generated {len(candles)} candles from {len(df)} ticks")
+        return candles
+    
+    def get_candle_stats(self) -> Dict:
+        """Get statistics about cached candles for different intervals"""
+        stats = {}
+        
+        # Define intervals to check
+        intervals = [1, 5, 15, 60, 300, 900, 3600]
+        
+        for interval in intervals:
+            candles = self.get_candles(interval_seconds=interval)
+            count = len(candles) if not candles.empty else 0
+            
+            # Get time range if we have candles
+            time_range = None
+            if count > 0:
+                try:
+                    start_time = candles['timestamp'].min()
+                    end_time = candles['timestamp'].max()
+                    if isinstance(start_time, pd.Timestamp):
+                        start_time = start_time.strftime('%Y-%m-%d %H:%M:%S')
+                    if isinstance(end_time, pd.Timestamp):
+                        end_time = end_time.strftime('%Y-%m-%d %H:%M:%S')
+                    time_range = f"{start_time} to {end_time}"
+                except Exception:
+                    time_range = "Unknown"
+            
+            stats[f"{interval}s"] = {
+                'count': count,
+                'time_range': time_range
+            }
+        
+        return stats
+    
+    def get_ticks_from_time(self, start_time_ms: int = None, end_time_ms: int = None) -> List[Dict]:
+        """Get ticks within a specific time range
+        
+        Args:
+            start_time_ms: Start time in milliseconds (None for no lower bound)
+            end_time_ms: End time in milliseconds (None for no upper bound)
+        
+        Returns:
+            List of ticks within the time range
+        """
+        if not self.ticks:
+            return []
+        
+        # Ensure ticks are updated
+        self._cleanup()
+        
+        # Apply time filters if specified
+        filtered_ticks = self.ticks
+        if start_time_ms is not None:
+            filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] >= start_time_ms]
+        if end_time_ms is not None:
+            filtered_ticks = [tick for tick in filtered_ticks if tick['timestamp'] <= end_time_ms]
+        
+        logger.debug(f"Retrieved {len(filtered_ticks)} ticks from time range {start_time_ms} to {end_time_ms}")
+        return filtered_ticks
+    
+    def get_time_based_stats(self) -> Dict:
+        """Get statistics about the ticks organized by time periods
+        
+        Returns:
+            Dictionary with statistics for different time periods
+        """
+        if not self.ticks:
+            return {
+                'total_ticks': 0,
+                'periods': {}
+            }
+        
+        # Ensure ticks are updated
+        self._cleanup()
+        
+        now = int(time.time() * 1000)  # Current time in ms
+        
+        # Define time periods to analyze
+        periods = {
+            '1min': now - (60 * 1000),
+            '5min': now - (5 * 60 * 1000),
+            '15min': now - (15 * 60 * 1000),
+            '30min': now - (30 * 60 * 1000)
+        }
+        
+        stats = {
+            'total_ticks': len(self.ticks),
+            'oldest_tick': self.ticks[0]['timestamp'] if self.ticks else None,
+            'newest_tick': self.ticks[-1]['timestamp'] if self.ticks else None,
+            'time_span_seconds': (self.ticks[-1]['timestamp'] - self.ticks[0]['timestamp']) / 1000 if self.ticks else 0,
+            'periods': {}
+        }
+        
+        # Calculate stats for each period
+        for period_name, cutoff_time in periods.items():
+            period_ticks = [tick for tick in self.ticks if tick['timestamp'] >= cutoff_time]
+            
+            if period_ticks:
+                prices = [tick['price'] for tick in period_ticks]
+                volumes = [tick.get('volume', 0) for tick in period_ticks]
+                
+                period_stats = {
+                    'tick_count': len(period_ticks),
+                    'min_price': min(prices) if prices else None,
+                    'max_price': max(prices) if prices else None,
+                    'avg_price': sum(prices) / len(prices) if prices else None,
+                    'last_price': period_ticks[-1]['price'] if period_ticks else None,
+                    'total_volume': sum(volumes),
+                    'ticks_per_second': len(period_ticks) / (int(period_name[:-3]) * 60) if period_ticks else 0
+                }
+                
+                stats['periods'][period_name] = period_stats
+
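+        # Illustrative shape of the returned dict (values here are made up):
+        #   {'total_ticks': 1200,
+        #    'oldest_tick': 1711000000000, 'newest_tick': 1711000600000,
+        #    'time_span_seconds': 600.0,
+        #    'periods': {'1min': {'tick_count': 120, 'min_price': 82950.0,
+        #                         'ticks_per_second': 2.0, ...}, ...}}
+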
logger.debug(f"Generated time-based stats: {len(stats['periods'])} periods") + return stats + +class CandlestickData: + def __init__(self, max_length: int = 300): + self.timestamps = deque(maxlen=max_length) + self.opens = deque(maxlen=max_length) + self.highs = deque(maxlen=max_length) + self.lows = deque(maxlen=max_length) + self.closes = deque(maxlen=max_length) + self.volumes = deque(maxlen=max_length) + self.current_candle = { + 'timestamp': None, + 'open': None, + 'high': None, + 'low': None, + 'close': None, + 'volume': 0 + } + self.candle_interval = 1 # 1 second by default + + def update_from_trade(self, trade: Dict): + timestamp = trade['timestamp'] + price = trade['price'] + volume = trade.get('volume', 0) + + # Round timestamp to nearest candle interval + candle_timestamp = int(timestamp / (self.candle_interval * 1000)) * (self.candle_interval * 1000) + + if self.current_candle['timestamp'] != candle_timestamp: + # Save current candle if it exists + if self.current_candle['timestamp'] is not None: + self.timestamps.append(self.current_candle['timestamp']) + self.opens.append(self.current_candle['open']) + self.highs.append(self.current_candle['high']) + self.lows.append(self.current_candle['low']) + self.closes.append(self.current_candle['close']) + self.volumes.append(self.current_candle['volume']) + logger.debug(f"New candle saved: {self.current_candle}") + + # Start new candle + self.current_candle = { + 'timestamp': candle_timestamp, + 'open': price, + 'high': price, + 'low': price, + 'close': price, + 'volume': volume + } + logger.debug(f"New candle started: {self.current_candle}") + else: + # Update current candle + if self.current_candle['high'] is None or price > self.current_candle['high']: + self.current_candle['high'] = price + if self.current_candle['low'] is None or price < self.current_candle['low']: + self.current_candle['low'] = price + self.current_candle['close'] = price + self.current_candle['volume'] += volume + logger.debug(f"Updated current candle: {self.current_candle}") + + def get_dataframe(self) -> pd.DataFrame: + # Include current candle in the dataframe if it exists + timestamps = list(self.timestamps) + opens = list(self.opens) + highs = list(self.highs) + lows = list(self.lows) + closes = list(self.closes) + volumes = list(self.volumes) + + if self.current_candle['timestamp'] is not None: + timestamps.append(self.current_candle['timestamp']) + opens.append(self.current_candle['open']) + highs.append(self.current_candle['high']) + lows.append(self.current_candle['low']) + closes.append(self.current_candle['close']) + volumes.append(self.current_candle['volume']) + + df = pd.DataFrame({ + 'timestamp': timestamps, + 'open': opens, + 'high': highs, + 'low': lows, + 'close': closes, + 'volume': volumes + }) + if not df.empty: + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + return df + +class BinanceWebSocket: + """Binance WebSocket implementation for real-time tick data""" + def __init__(self, symbol: str): + self.symbol = symbol.replace('/', '').lower() + self.ws = None + self.running = False + self.reconnect_delay = 1 + self.max_reconnect_delay = 60 + self.message_count = 0 + + # Binance WebSocket configuration + self.ws_url = f"wss://stream.binance.com:9443/ws/{self.symbol}@trade" + logger.info(f"Initialized Binance WebSocket for symbol: {self.symbol}") + + async def connect(self): + while True: + try: + logger.info(f"Attempting to connect to {self.ws_url}") + self.ws = await websockets.connect(self.ws_url) + logger.info("WebSocket 
connection established") + + self.running = True + self.reconnect_delay = 1 + logger.info(f"Successfully connected to Binance WebSocket for {self.symbol}") + return True + except Exception as e: + logger.error(f"WebSocket connection error: {str(e)}") + await asyncio.sleep(self.reconnect_delay) + self.reconnect_delay = min(self.reconnect_delay * 2, self.max_reconnect_delay) + continue + + async def receive(self) -> Optional[Dict]: + if not self.ws: + return None + + try: + message = await self.ws.recv() + self.message_count += 1 + + if self.message_count % 100 == 0: # Log every 100th message to avoid spam + logger.info(f"Received message #{self.message_count}") + logger.debug(f"Raw message: {message[:200]}...") + + data = json.loads(message) + + # Process trade data + if 'e' in data and data['e'] == 'trade': + trade_data = { + 'timestamp': data['T'], # Trade time + 'price': float(data['p']), # Price + 'volume': float(data['q']), # Quantity + 'type': 'trade' + } + logger.debug(f"Processed trade data: {trade_data}") + return trade_data + + return None + except websockets.exceptions.ConnectionClosed: + logger.warning("WebSocket connection closed") + self.running = False + return None + except json.JSONDecodeError as e: + logger.error(f"JSON decode error: {str(e)}, message: {message[:200]}...") + return None + except Exception as e: + logger.error(f"Error receiving message: {str(e)}") + return None + + async def close(self): + """Close the WebSocket connection""" + if self.ws: + await self.ws.close() + self.running = False + logger.info("WebSocket connection closed") + +class BinanceHistoricalData: + """Fetch historical candle data from Binance""" + + def __init__(self): + self.base_url = "https://api.binance.com/api/v3/klines" + # Create a cache directory if it doesn't exist + self.cache_dir = os.path.join(os.getcwd(), "cache") + os.makedirs(self.cache_dir, exist_ok=True) + logger.info(f"Initialized BinanceHistoricalData with cache directory: {self.cache_dir}") + + def _get_interval_string(self, interval_seconds: int) -> str: + """Convert interval seconds to Binance interval string""" + if interval_seconds == 60: # 1m + return "1m" + elif interval_seconds == 300: # 5m + return "5m" + elif interval_seconds == 900: # 15m + return "15m" + elif interval_seconds == 1800: # 30m + return "30m" + elif interval_seconds == 3600: # 1h + return "1h" + elif interval_seconds == 14400: # 4h + return "4h" + elif interval_seconds == 86400: # 1d + return "1d" + else: + # Default to 1m if not recognized + logger.warning(f"Unrecognized interval {interval_seconds}s, defaulting to 1m") + return "1m" + + def _get_cache_filename(self, symbol: str, interval: str) -> str: + """Generate cache filename for the symbol and interval""" + # Replace any slashes in symbol with underscore + safe_symbol = symbol.replace("/", "_") + return os.path.join(self.cache_dir, f"{safe_symbol}_{interval}_candles.csv") + + def _load_from_cache(self, symbol: str, interval: str) -> Optional[pd.DataFrame]: + """Load candle data from cache if available and not expired""" + filename = self._get_cache_filename(symbol, interval) + + if not os.path.exists(filename): + logger.debug(f"No cache file found for {symbol} {interval}") + return None + + # Check if cache is fresh (less than 1 hour old for anything but 1d, 1 day for 1d) + file_age = time.time() - os.path.getmtime(filename) + max_age = 86400 if interval == "1d" else 3600 # 1 day for 1d, 1 hour for others + + if file_age > max_age: + logger.debug(f"Cache for {symbol} {interval} is expired 
({file_age:.1f}s old)") + return None + + try: + df = pd.read_csv(filename) + # Convert timestamp string back to datetime + df['timestamp'] = pd.to_datetime(df['timestamp']) + logger.info(f"Loaded {len(df)} candles from cache for {symbol} {interval}") + return df + except Exception as e: + logger.error(f"Error loading from cache: {str(e)}") + return None + + def _save_to_cache(self, df: pd.DataFrame, symbol: str, interval: str) -> bool: + """Save candle data to cache""" + if df.empty: + logger.warning(f"No data to cache for {symbol} {interval}") + return False + + filename = self._get_cache_filename(symbol, interval) + try: + df.to_csv(filename, index=False) + logger.info(f"Cached {len(df)} candles for {symbol} {interval} to {filename}") + return True + except Exception as e: + logger.error(f"Error saving to cache: {str(e)}") + return False + + def get_historical_candles(self, symbol: str, interval_seconds: int, limit: int = 500) -> pd.DataFrame: + """Get historical candle data for the specified symbol and interval""" + # Convert to Binance format + clean_symbol = symbol.replace("/", "") + interval = self._get_interval_string(interval_seconds) + + # Try to load from cache first + cached_data = self._load_from_cache(symbol, interval) + if cached_data is not None and len(cached_data) >= limit: + return cached_data.tail(limit) + + # Fetch from API if not cached or insufficient + try: + logger.info(f"Fetching {limit} historical candles for {symbol} ({interval}) from Binance API") + + params = { + "symbol": clean_symbol, + "interval": interval, + "limit": limit + } + + response = requests.get(self.base_url, params=params) + response.raise_for_status() # Raise exception for HTTP errors + + # Process the data + candles = response.json() + + if not candles: + logger.warning(f"No candles returned from Binance for {symbol} {interval}") + return pd.DataFrame() + + # Convert to DataFrame - Binance returns data in this format: + # [ + # [ + # 1499040000000, // Open time + # "0.01634790", // Open + # "0.80000000", // High + # "0.01575800", // Low + # "0.01577100", // Close + # "148976.11427815", // Volume + # ... // Ignore the rest + # ], + # ... 
+ # ] + + df = pd.DataFrame(candles, columns=[ + "timestamp", "open", "high", "low", "close", "volume", + "close_time", "quote_asset_volume", "number_of_trades", + "taker_buy_base_asset_volume", "taker_buy_quote_asset_volume", "ignore" + ]) + + # Convert types + df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms') + for col in ["open", "high", "low", "close", "volume"]: + df[col] = df[col].astype(float) + + # Keep only needed columns + df = df[["timestamp", "open", "high", "low", "close", "volume"]] + + # Cache the results + self._save_to_cache(df, symbol, interval) + + logger.info(f"Successfully fetched {len(df)} candles for {symbol} {interval}") + return df + + except Exception as e: + logger.error(f"Error fetching historical data for {symbol} {interval}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return pd.DataFrame() + + +class ExchangeWebSocket: + """Generic WebSocket interface for cryptocurrency exchanges""" + def __init__(self, symbol: str, exchange: str = "binance"): + self.symbol = symbol + self.exchange = exchange.lower() + self.ws = None + + # Initialize the appropriate WebSocket implementation + if self.exchange == "binance": + self.ws = BinanceWebSocket(symbol) + elif self.exchange == "mexc": + self.ws = MEXCWebSocket(symbol) + else: + raise ValueError(f"Unsupported exchange: {exchange}") + + async def connect(self): + """Connect to the exchange WebSocket""" + return await self.ws.connect() + + async def receive(self) -> Optional[Dict]: + """Receive data from the WebSocket""" + return await self.ws.receive() + + async def close(self): + """Close the WebSocket connection""" + await self.ws.close() + + @property + def running(self): + """Check if the WebSocket is running""" + return self.ws.running if self.ws else False + +class CandleCache: + def __init__(self, max_candles: int = 5000): + self.candles = { + '1s': deque(maxlen=max_candles), + '1m': deque(maxlen=max_candles), + '1h': deque(maxlen=max_candles), + '1d': deque(maxlen=max_candles) + } + logger.info(f"Initialized CandleCache with max candles: {max_candles}") + + def add_candles(self, interval: str, new_candles: pd.DataFrame): + if interval in self.candles and not new_candles.empty: + # Convert DataFrame to list of dicts to avoid pandas issues + for _, row in new_candles.iterrows(): + candle_dict = row.to_dict() + self.candles[interval].append(candle_dict) + logger.debug(f"Added {len(new_candles)} candles to {interval} cache") + + def get_recent_candles(self, interval: str, count: int = 500) -> pd.DataFrame: + if interval in self.candles and self.candles[interval]: + # Convert deque to list of dicts first + all_candles = list(self.candles[interval]) + # Check if we're requesting more candles than we have + if count > len(all_candles): + logger.debug(f"Requested {count} candles, but only have {len(all_candles)} for {interval}") + count = len(all_candles) + + recent_candles = all_candles[-count:] + logger.debug(f"Returning {len(recent_candles)} recent candles for {interval} (requested {count})") + + # Create DataFrame and ensure timestamp is datetime type + df = pd.DataFrame(recent_candles) + if not df.empty and 'timestamp' in df.columns: + try: + if not pd.api.types.is_datetime64_any_dtype(df['timestamp']): + df['timestamp'] = pd.to_datetime(df['timestamp']) + except Exception as e: + logger.warning(f"Error converting timestamps in get_recent_candles: {str(e)}") + + return df + + logger.debug(f"No candles available for {interval}") + return pd.DataFrame() + + def update_cache(self, 
interval: str, new_candles: pd.DataFrame): + """ + Update the candle cache for a specific timeframe with new candles + + Args: + interval: The timeframe interval ('1s', '1m', '1h', '1d') + new_candles: DataFrame with new candles to add to the cache + """ + if interval not in self.candles: + logger.warning(f"Invalid interval {interval} for cache update") + return + + if new_candles is None or new_candles.empty: + logger.debug(f"No new candles to update {interval} cache") + return + + # Check if timestamp column exists + if 'timestamp' not in new_candles.columns: + logger.warning(f"No timestamp column in new candles for {interval}") + return + + try: + # Ensure timestamp is datetime for proper comparison + try: + if not pd.api.types.is_datetime64_any_dtype(new_candles['timestamp']): + logger.debug(f"Converting timestamps to datetime for {interval}") + new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp']) + except Exception as e: + logger.error(f"Error converting timestamps: {str(e)}") + # Try a different approach + try: + new_candles['timestamp'] = pd.to_datetime(new_candles['timestamp'], errors='coerce') + # Drop any rows where conversion failed + new_candles = new_candles.dropna(subset=['timestamp']) + if new_candles.empty: + logger.warning(f"All timestamps conversion failed for {interval}") + return + except Exception as e2: + logger.error(f"Second attempt to convert timestamps failed: {str(e2)}") + return + + # Create a copy to avoid modifying the original + new_candles_copy = new_candles.copy() + + # If we have no candles in cache, add all new candles + if not self.candles[interval]: + logger.debug(f"No existing candles for {interval}, adding all {len(new_candles_copy)} candles") + self.add_candles(interval, new_candles_copy) + return + + # Get the timestamp from the last cached candle + last_cached_candle = self.candles[interval][-1] + if not isinstance(last_cached_candle, dict): + logger.warning(f"Last cached candle is not a dictionary for {interval}") + last_cached_candle = {'timestamp': None} + + if 'timestamp' not in last_cached_candle: + logger.warning(f"No timestamp in last cached candle for {interval}") + last_cached_candle['timestamp'] = None + + last_cached_time = last_cached_candle['timestamp'] + logger.debug(f"Last cached timestamp for {interval}: {last_cached_time}") + + # If last_cached_time is None, add all candles + if last_cached_time is None: + logger.debug(f"No valid last cached timestamp, adding all {len(new_candles_copy)} candles for {interval}") + self.add_candles(interval, new_candles_copy) + return + + # Convert last_cached_time to datetime if needed + if not isinstance(last_cached_time, (pd.Timestamp, datetime)): + try: + last_cached_time = pd.to_datetime(last_cached_time) + except Exception as e: + logger.error(f"Cannot convert last cached time to datetime: {str(e)}") + # Add all candles as fallback + self.add_candles(interval, new_candles_copy) + return + + # Make a backup of current cache before filtering + cache_backup = list(self.candles[interval]) + + # Filter new candles that are after the last cached candle + try: + filtered_candles = new_candles_copy[new_candles_copy['timestamp'] > last_cached_time] + + if not filtered_candles.empty: + logger.debug(f"Adding {len(filtered_candles)} new candles for {interval}") + self.add_candles(interval, filtered_candles) + else: + # No new candles after last cached time, check for missing candles + try: + # Get unique timestamps in cache + cached_timestamps = set() + for candle in self.candles[interval]: 
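+                            # Normalize every cached timestamp to a pandas Timestamp so the
+                            # isin() comparison against new_candles['timestamp'] below compares
+                            # like types; values that cannot be parsed are skipped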
+ if isinstance(candle, dict) and 'timestamp' in candle: + ts = candle['timestamp'] + if isinstance(ts, (pd.Timestamp, datetime)): + cached_timestamps.add(ts) + else: + try: + cached_timestamps.add(pd.to_datetime(ts)) + except: + pass + + # Find candles in new_candles that aren't in the cache + missing_candles = new_candles_copy[~new_candles_copy['timestamp'].isin(cached_timestamps)] + + if not missing_candles.empty: + logger.info(f"Found {len(missing_candles)} missing candles for {interval}") + self.add_candles(interval, missing_candles) + else: + logger.debug(f"No new or missing candles to add for {interval}") + except Exception as missing_error: + logger.error(f"Error checking for missing candles: {str(missing_error)}") + except Exception as filter_error: + logger.error(f"Error filtering candles by timestamp: {str(filter_error)}") + # Restore from backup + self.candles[interval] = deque(cache_backup, maxlen=self.candles[interval].maxlen) + # Try adding all candles as fallback + self.add_candles(interval, new_candles_copy) + except Exception as e: + logger.error(f"Unhandled error updating cache for {interval}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + + def get_candles(self, timeframe: str, count: int = 500) -> pd.DataFrame: + """ + Get candles for a specific timeframe. This is an alias for get_recent_candles + to maintain compatibility with code that expects this method name. + + Args: + timeframe: The timeframe interval ('1s', '1m', '1h', '1d') + count: Maximum number of candles to return + + Returns: + DataFrame containing the candles + """ + try: + logger.debug(f"Getting {count} candles for {timeframe} via get_candles()") + return self.get_recent_candles(timeframe, count) + except Exception as e: + logger.error(f"Error in get_candles for {timeframe}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return pd.DataFrame() + +class RealTimeChart: + """Real-time chart using Dash and Plotly""" + + def __init__(self, symbol="BTC/USDT", use_sample_data=False, log_no_ticks_warning=True): + """Initialize a new RealTimeChart + + Args: + symbol: Trading pair symbol (e.g., BTC/USDT) + use_sample_data: Whether to use sample data when no real data is available + log_no_ticks_warning: Whether to log warnings when no ticks are available + """ + self.symbol = symbol + self.use_sample_data = use_sample_data + + # Initialize variables for trading info display + self.current_signal = 'HOLD' + self.signal_time = datetime.now() + self.current_position = 0.0 + self.session_balance = 100.0 # Start with $100 balance + self.session_pnl = 0.0 + + # Initialize NN signals and trades lists + self.nn_signals = [] + self.trades = [] + + # Use existing timezone variable instead of trying to detect again + logger.info(f"Using timezone: {tz_name}") + + # Initialize tick storage + logger.info(f"Initializing RealTimeChart for {symbol}") + self.tick_storage = TradeTickStorage( + symbol=symbol, + max_age_seconds=3600, # Keep ticks for 1 hour + use_sample_data=use_sample_data, + log_no_ticks_warning=log_no_ticks_warning + ) + + # Initialize candlestick data for backward compatibility + self.candlestick_data = CandlestickData(max_length=5000) + + # Initialize candle cache + self.candle_cache = CandleCache(max_candles=5000) + + # Initialize OHLCV cache dictionaries for different timeframes + self.ohlcv_cache = { + '1s': None, + '5s': None, + '15s': None, + '60s': None, + '300s': None, + '900s': None, + '3600s': None, + '1m': None, + '5m': None, + '15m': None, + '1h': None, + 
'1d': None + } + + # Historical data handler + self.historical_data = BinanceHistoricalData() + + # Flag for first render to force data loading + self.first_render = True + + # Last time candles were saved to disk + self.last_cache_save_time = time.time() + + # Initialize Dash app + self.app = dash.Dash( + __name__, + external_stylesheets=[dbc.themes.DARKLY], + suppress_callback_exceptions=True, + meta_tags=[{"name": "viewport", "content": "width=device-width, initial-scale=1"}] + ) + + # Set up layout and callbacks + self._setup_app_layout() + + def _setup_app_layout(self): + """Set up the app layout and callbacks""" + # Define styling for interval buttons + button_style = { + 'backgroundColor': '#2C2C2C', + 'color': 'white', + 'border': 'none', + 'padding': '10px 15px', + 'margin': '5px', + 'borderRadius': '5px', + 'cursor': 'pointer', + 'fontWeight': 'bold' + } + + active_button_style = { + **button_style, + 'backgroundColor': '#4CAF50', + 'boxShadow': '0 2px 4px rgba(0,0,0,0.5)' + } + + # Create tab layout + self.app.layout = dbc.Tabs([ + dbc.Tab(self._get_chart_layout(button_style, active_button_style), label="Chart", tab_id="chart-tab"), + # No longer need ticks tab as it's causing errors + ], id="tabs") + + # Set up callbacks + self._setup_interval_callback(button_style, active_button_style) + self._setup_chart_callback() + self._setup_position_list_callback() + # We've removed the ticks callback, so don't call it + # self._setup_ticks_callback() + + def _get_chart_layout(self, button_style, active_button_style): + # Chart page layout + return html.Div([ + # Chart title and interval buttons + html.Div([ + html.H2(f"{self.symbol} Real-Time Chart", style={ + 'textAlign': 'center', + 'color': '#FFFFFF', + 'marginBottom': '10px' + }), + + # Store interval data + dcc.Store(id='interval-store', data={'interval': 1}), + + # Interval selection buttons + html.Div([ + html.Button('1s', id='btn-1s', n_clicks=0, style=active_button_style), + html.Button('5s', id='btn-5s', n_clicks=0, style=button_style), + html.Button('15s', id='btn-15s', n_clicks=0, style=button_style), + html.Button('30s', id='btn-30s', n_clicks=0, style=button_style), + html.Button('1m', id='btn-1m', n_clicks=0, style=button_style), + ], style={ + 'display': 'flex', + 'justifyContent': 'center', + 'marginBottom': '15px' + }), + + # Interval component for updates - set to refresh every 500ms + dcc.Interval( + id='interval-component', + interval=300, # Refresh every 300ms for better real-time updates + n_intervals=0 + ), + + # Main chart + dcc.Graph(id='live-chart', style={'height': '75vh'}), + + # Last 5 Positions component + html.Div([ + html.H4("Last 5 Positions", style={'textAlign': 'center', 'color': '#FFFFFF'}), + html.Table([ + html.Thead( + html.Tr([ + html.Th("Action", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), + html.Th("Size", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), + html.Th("Entry Price", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), + html.Th("Exit Price", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), + html.Th("PnL", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}), + html.Th("Time", style={'padding': '8px', 'border': '1px solid #444', 'backgroundColor': '#333'}) + ]) + ), + html.Tbody(id='position-table-body') + ], style={ + 'width': '100%', + 'borderCollapse': 'collapse', + 'color': '#FFFFFF', + 'backgroundColor': '#222' + }) + 
], style={'marginTop': '20px', 'marginBottom': '20px', 'overflowX': 'auto'}), + + # Chart acknowledgment + html.Div("Real-time trading chart with ML signals", style={ + 'textAlign': 'center', + 'color': '#AAAAAA', + 'fontSize': '12px', + 'marginTop': '5px' + }) + ]) + ]) + + def _get_ticks_layout(self): + # Ticks data page layout + return html.Div([ + # Header and controls + html.Div([ + html.H2(f"{self.symbol} Raw Tick Data (Last 5 Minutes)", style={ + 'textAlign': 'center', + 'color': '#FFFFFF', + 'margin': '10px 0' + }), + + # Refresh button + html.Button('Refresh Data', id='refresh-ticks-btn', n_clicks=0, style={ + 'backgroundColor': '#4CAF50', + 'color': 'white', + 'padding': '10px 20px', + 'margin': '10px auto', + 'border': 'none', + 'borderRadius': '5px', + 'fontSize': '14px', + 'cursor': 'pointer', + 'display': 'block' + }), + + # Time window selector + html.Div([ + html.Label("Time Window:", style={'color': 'white', 'marginRight': '10px'}), + dcc.Dropdown( + id='time-window-dropdown', + options=[ + {'label': 'Last 1 minute', 'value': 60}, + {'label': 'Last 5 minutes', 'value': 300}, + {'label': 'Last 15 minutes', 'value': 900}, + {'label': 'Last 30 minutes', 'value': 1800}, + ], + value=300, # Default to 5 minutes + style={'width': '200px', 'backgroundColor': '#2C2C2C', 'color': 'black'} + ) + ], style={ + 'display': 'flex', + 'alignItems': 'center', + 'justifyContent': 'center', + 'margin': '10px' + }), + ], style={ + 'backgroundColor': '#2C2C2C', + 'padding': '10px', + 'borderRadius': '5px', + 'marginBottom': '15px' + }), + + # Stats cards + html.Div(id='tick-stats-cards', style={ + 'display': 'flex', + 'flexWrap': 'wrap', + 'justifyContent': 'space-around', + 'marginBottom': '15px' + }), + + # Ticks data table + html.Div(id='ticks-table-container', style={ + 'backgroundColor': '#232323', + 'padding': '10px', + 'borderRadius': '5px', + 'overflowX': 'auto' + }), + + # Price movement chart + html.Div([ + html.H3("Price Movement", style={ + 'textAlign': 'center', + 'color': '#FFFFFF', + 'margin': '10px 0' + }), + dcc.Graph(id='tick-price-chart') + ], style={ + 'backgroundColor': '#232323', + 'padding': '10px', + 'borderRadius': '5px', + 'marginTop': '15px' + }) + ]) + + def _setup_interval_callback(self, button_style, active_button_style): + # Callback to update interval based on button clicks and update button styles + @self.app.callback( + [Output('interval-store', 'data'), + Output('btn-1s', 'style'), + Output('btn-5s', 'style'), + Output('btn-15s', 'style'), + Output('btn-30s', 'style'), + Output('btn-1m', 'style')], + [Input('btn-1s', 'n_clicks'), + Input('btn-5s', 'n_clicks'), + Input('btn-15s', 'n_clicks'), + Input('btn-30s', 'n_clicks'), + Input('btn-1m', 'n_clicks')], + [dash.dependencies.State('interval-store', 'data')] + ) + def update_interval(n1, n5, n15, n30, n60, data): + ctx = dash.callback_context + if not ctx.triggered: + # Default state (1s selected) + return ({'interval': 1}, + active_button_style, button_style, button_style, button_style, button_style) + + button_id = ctx.triggered[0]['prop_id'].split('.')[0] + + if button_id == 'btn-1s': + return ({'interval': 1}, + active_button_style, button_style, button_style, button_style, button_style) + elif button_id == 'btn-5s': + return ({'interval': 5}, + button_style, active_button_style, button_style, button_style, button_style) + elif button_id == 'btn-15s': + return ({'interval': 15}, + button_style, button_style, active_button_style, button_style, button_style) + elif button_id == 'btn-30s': + return 
({'interval': 30}, + button_style, button_style, button_style, active_button_style, button_style) + elif button_id == 'btn-1m': + return ({'interval': 60}, + button_style, button_style, button_style, button_style, active_button_style) + + # Default case - keep current interval and highlight appropriate button + current_interval = data.get('interval', 1) + styles = [button_style] * 5 # All inactive by default + + # Set active style based on current interval + if current_interval == 1: + styles[0] = active_button_style + elif current_interval == 5: + styles[1] = active_button_style + elif current_interval == 15: + styles[2] = active_button_style + elif current_interval == 30: + styles[3] = active_button_style + elif current_interval == 60: + styles[4] = active_button_style + + return (data, *styles) + + def _setup_chart_callback(self): + # Callback to update the chart + @self.app.callback( + Output('live-chart', 'figure'), + [Input('interval-component', 'n_intervals'), + Input('interval-store', 'data')] + ) + def update_chart(n, interval_data): + try: + interval = interval_data.get('interval', 1) + logger.debug(f"Updating chart with interval {interval}") + + # Get candlesticks data for the selected interval + try: + df = self.tick_storage.get_candles(interval_seconds=interval) + if df.empty and self.ohlcv_cache[f'{interval}s' if interval < 60 else '1m'] is not None: + df = self.ohlcv_cache[f'{interval}s' if interval < 60 else '1m'] + except Exception as e: + logger.error(f"Error getting candles: {str(e)}") + df = pd.DataFrame() + + # Get data for other timeframes (1m, 1h, 1d) + df_1m = None + df_1h = None + df_1d = None + + try: + # Get 1m candles + if self.ohlcv_cache['1m'] is not None and not self.ohlcv_cache['1m'].empty: + df_1m = self.ohlcv_cache['1m'].copy() + else: + df_1m = self.tick_storage.get_candles(interval_seconds=60) + + # Get 1h candles + if self.ohlcv_cache['1h'] is not None and not self.ohlcv_cache['1h'].empty: + df_1h = self.ohlcv_cache['1h'].copy() + else: + df_1h = self.tick_storage.get_candles(interval_seconds=3600) + + # Get 1d candles + if self.ohlcv_cache['1d'] is not None and not self.ohlcv_cache['1d'].empty: + df_1d = self.ohlcv_cache['1d'].copy() + else: + df_1d = self.tick_storage.get_candles(interval_seconds=86400) + + # Limit the number of candles to display but show more for context + if df_1m is not None and not df_1m.empty: + df_1m = df_1m.tail(600) # Show 600 1m candles for better context - 10 hours + if df_1h is not None and not df_1h.empty: + df_1h = df_1h.tail(480) # Show 480 hours of hourly data for better context - 20 days + if df_1d is not None and not df_1d.empty: + df_1d = df_1d.tail(365*2) # Show 2 years of daily data for better context + + except Exception as e: + logger.error(f"Error getting additional timeframes: {str(e)}") + + # Create layout with 5 rows (main chart, volume, 1m, 1h, 1d) + fig = make_subplots( + rows=5, cols=1, + vertical_spacing=0.03, + subplot_titles=( + f'{self.symbol} Price ({interval}s)', + 'Volume', + '1-Minute Chart', + '1-Hour Chart', + '1-Day Chart' + ), + row_heights=[0.4, 0.15, 0.15, 0.15, 0.15] + ) + + # Add candlestick chart for main timeframe + if not df.empty and 'open' in df.columns: + # Candlestick chart + fig.add_trace( + go.Candlestick( + x=df.index, + open=df['open'], + high=df['high'], + low=df['low'], + close=df['close'], + name='OHLC' + ), + row=1, col=1 + ) + + # Calculate y-axis range with padding + low_min = df['low'].min() + high_max = df['high'].max() + price_range = high_max - low_min + y_min = 
low_min - (price_range * 0.05) # 5% padding below + y_max = high_max + (price_range * 0.05) # 5% padding above + + # Set y-axis range to ensure candles are visible + fig.update_yaxes(range=[y_min, y_max], row=1, col=1) + + # Volume bars + colors = ['rgba(0,255,0,0.7)' if close >= open else 'rgba(255,0,0,0.7)' + for open, close in zip(df['open'], df['close'])] + + fig.add_trace( + go.Bar( + x=df.index, + y=df['volume'], + name='Volume', + marker_color=colors + ), + row=2, col=1 + ) + + # Add buy/sell markers and PnL annotations to the main chart + if hasattr(self, 'trades') and self.trades: + buy_times = [] + buy_prices = [] + buy_markers = [] + sell_times = [] + sell_prices = [] + sell_markers = [] + + # Filter trades to only include recent ones (last 100) + recent_trades = self.trades[-100:] + + for trade in recent_trades: + # Convert timestamp to datetime if it's not already + trade_time = trade.get('timestamp') + if isinstance(trade_time, (int, float)): + trade_time = pd.to_datetime(trade_time, unit='ms') + + price = trade.get('price', 0) + pnl = trade.get('pnl', None) + action = trade.get('action', 'SELL') # Default to SELL + + if action == 'BUY': + buy_times.append(trade_time) + buy_prices.append(price) + buy_markers.append("") + elif action == 'SELL': + sell_times.append(trade_time) + sell_prices.append(price) + # Add PnL as marker text if available + if pnl is not None: + pnl_text = f"{pnl:.4f}" if abs(pnl) < 0.01 else f"{pnl:.2f}" + sell_markers.append(pnl_text) + else: + sell_markers.append("") + + # Add buy markers + if buy_times: + fig.add_trace( + go.Scatter( + x=buy_times, + y=buy_prices, + mode='markers', + name='Buy', + marker=dict( + symbol='triangle-up', + size=12, + color='rgba(0,255,0,0.8)', + line=dict(width=1, color='darkgreen') + ), + text=buy_markers, + hoverinfo='x+y+text', + showlegend=True # Ensure Buy appears in legend + ), + row=1, col=1 + ) + + # Add vertical and horizontal connecting lines for buys + for i, (btime, bprice) in enumerate(zip(buy_times, buy_prices)): + # Add vertical dashed line to time axis + fig.add_shape( + type="line", + x0=btime, x1=btime, + y0=y_min, y1=bprice, + line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"), + row=1, col=1 + ) + # Add horizontal dashed line showing the price level + fig.add_shape( + type="line", + x0=df.index.min(), x1=btime, + y0=bprice, y1=bprice, + line=dict(color="rgba(0,255,0,0.5)", width=1, dash="dash"), + row=1, col=1 + ) + + # Add sell markers with PnL annotations + if sell_times: + fig.add_trace( + go.Scatter( + x=sell_times, + y=sell_prices, + mode='markers+text', + name='Sell', + marker=dict( + symbol='triangle-down', + size=12, + color='rgba(255,0,0,0.8)', + line=dict(width=1, color='darkred') + ), + text=sell_markers, + textposition='top center', + textfont=dict(size=10), + hoverinfo='x+y+text', + showlegend=True # Ensure Sell appears in legend + ), + row=1, col=1 + ) + + # Add vertical and horizontal connecting lines for sells + for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)): + # Add vertical dashed line to time axis + fig.add_shape( + type="line", + x0=stime, x1=stime, + y0=y_min, y1=sprice, + line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"), + row=1, col=1 + ) + # Add horizontal dashed line showing the price level + fig.add_shape( + type="line", + x0=df.index.min(), x1=stime, + y0=sprice, y1=sprice, + line=dict(color="rgba(255,0,0,0.5)", width=1, dash="dash"), + row=1, col=1 + ) + + # Add connecting lines between consecutive buy-sell pairs + if len(buy_times) > 0 
and len(sell_times) > 0: + # Create pairs of buy-sell trades based on timestamps + pairs = [] + buys_copy = list(zip(buy_times, buy_prices)) + + for i, (stime, sprice) in enumerate(zip(sell_times, sell_prices)): + # Find the most recent buy before this sell + matching_buy = None + for j, (btime, bprice) in enumerate(buys_copy): + if btime < stime: + matching_buy = (btime, bprice) + buys_copy.pop(j) # Remove this buy to prevent reuse + break + + if matching_buy: + pairs.append((matching_buy, (stime, sprice))) + + # Add connecting lines for each pair + for (btime, bprice), (stime, sprice) in pairs: + # Draw line connecting the buy and sell points + fig.add_shape( + type="line", + x0=btime, x1=stime, + y0=bprice, y1=sprice, + line=dict( + color="rgba(255,255,255,0.5)", + width=1, + dash="dot" + ), + row=1, col=1 + ) + + # Add 1m chart + if df_1m is not None and not df_1m.empty and 'open' in df_1m.columns: + fig.add_trace( + go.Candlestick( + x=df_1m.index, + open=df_1m['open'], + high=df_1m['high'], + low=df_1m['low'], + close=df_1m['close'], + name='1m', + showlegend=False + ), + row=3, col=1 + ) + + # Set appropriate date format for 1m chart + fig.update_xaxes( + title_text="", + row=3, + col=1, + tickformat="%H:%M", + tickmode="auto", + nticks=12 + ) + + # Add buy/sell markers to 1m chart if they fall within the visible timeframe + if hasattr(self, 'trades') and self.trades: + # Filter trades visible in 1m timeframe + min_time = df_1m.index.min() + max_time = df_1m.index.max() + + # Ensure min_time and max_time are pandas.Timestamp objects + if isinstance(min_time, (int, float)): + min_time = pd.to_datetime(min_time, unit='ms') + if isinstance(max_time, (int, float)): + max_time = pd.to_datetime(max_time, unit='ms') + + # Collect only trades within this timeframe + minute_buy_times = [] + minute_buy_prices = [] + minute_sell_times = [] + minute_sell_prices = [] + + for trade in self.trades[-100:]: + trade_time = trade.get('timestamp') + if isinstance(trade_time, (int, float)): + # Convert numeric timestamp to datetime + trade_time = pd.to_datetime(trade_time, unit='ms') + elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime): + # Skip trades with invalid timestamp format + continue + + # Check if trade falls within 1m chart timeframe + try: + if min_time <= trade_time <= max_time: + price = trade.get('price', 0) + action = trade.get('action', 'SELL') + + if action == 'BUY': + minute_buy_times.append(trade_time) + minute_buy_prices.append(price) + elif action == 'SELL': + minute_sell_times.append(trade_time) + minute_sell_prices.append(price) + except TypeError: + # If comparison fails due to type mismatch, log the error and skip this trade + logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}") + continue + + # Add buy markers to 1m chart + if minute_buy_times: + fig.add_trace( + go.Scatter( + x=minute_buy_times, + y=minute_buy_prices, + mode='markers', + name='Buy (1m)', + marker=dict( + symbol='triangle-up', + size=8, + color='rgba(0,255,0,0.8)', + line=dict(width=1, color='darkgreen') + ), + showlegend=False, + hoverinfo='x+y' + ), + row=3, col=1 + ) + + # Add sell markers to 1m chart + if minute_sell_times: + fig.add_trace( + go.Scatter( + x=minute_sell_times, + y=minute_sell_prices, + mode='markers', + name='Sell (1m)', + marker=dict( + symbol='triangle-down', + size=8, + color='rgba(255,0,0,0.8)', + line=dict(width=1, color='darkred') + ), + showlegend=False, + hoverinfo='x+y' + ), + row=3, 
col=1 + ) + + # Add 1h chart + if df_1h is not None and not df_1h.empty and 'open' in df_1h.columns: + fig.add_trace( + go.Candlestick( + x=df_1h.index, + open=df_1h['open'], + high=df_1h['high'], + low=df_1h['low'], + close=df_1h['close'], + name='1h', + showlegend=False + ), + row=4, col=1 + ) + + # Set appropriate date format for 1h chart + fig.update_xaxes( + title_text="", + row=4, + col=1, + tickformat="%m-%d %H:%M", + tickmode="auto", + nticks=8 + ) + + # Add buy/sell markers to 1h chart if they fall within the visible timeframe + if hasattr(self, 'trades') and self.trades: + # Filter trades visible in 1h timeframe + min_time = df_1h.index.min() + max_time = df_1h.index.max() + + # Ensure min_time and max_time are pandas.Timestamp objects + if isinstance(min_time, (int, float)): + min_time = pd.to_datetime(min_time, unit='ms') + if isinstance(max_time, (int, float)): + max_time = pd.to_datetime(max_time, unit='ms') + + # Collect only trades within this timeframe + hour_buy_times = [] + hour_buy_prices = [] + hour_sell_times = [] + hour_sell_prices = [] + + for trade in self.trades[-200:]: # Check more trades for longer timeframe + trade_time = trade.get('timestamp') + if isinstance(trade_time, (int, float)): + # Convert numeric timestamp to datetime + trade_time = pd.to_datetime(trade_time, unit='ms') + elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime): + # Skip trades with invalid timestamp format + continue + + # Check if trade falls within 1h chart timeframe + try: + if min_time <= trade_time <= max_time: + price = trade.get('price', 0) + action = trade.get('action', 'SELL') + + if action == 'BUY': + hour_buy_times.append(trade_time) + hour_buy_prices.append(price) + elif action == 'SELL': + hour_sell_times.append(trade_time) + hour_sell_prices.append(price) + except TypeError: + # If comparison fails due to type mismatch, log the error and skip this trade + logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}") + continue + + # Add buy markers to 1h chart + if hour_buy_times: + fig.add_trace( + go.Scatter( + x=hour_buy_times, + y=hour_buy_prices, + mode='markers', + name='Buy (1h)', + marker=dict( + symbol='triangle-up', + size=6, + color='rgba(0,255,0,0.8)', + line=dict(width=1, color='darkgreen') + ), + showlegend=False, + hoverinfo='x+y' + ), + row=4, col=1 + ) + + # Add sell markers to 1h chart + if hour_sell_times: + fig.add_trace( + go.Scatter( + x=hour_sell_times, + y=hour_sell_prices, + mode='markers', + name='Sell (1h)', + marker=dict( + symbol='triangle-down', + size=6, + color='rgba(255,0,0,0.8)', + line=dict(width=1, color='darkred') + ), + showlegend=False, + hoverinfo='x+y' + ), + row=4, col=1 + ) + # Add 1d chart + if df_1d is not None and not df_1d.empty and 'open' in df_1d.columns: + fig.add_trace( + go.Candlestick( + x=df_1d.index, + open=df_1d['open'], + high=df_1d['high'], + low=df_1d['low'], + close=df_1d['close'], + name='1d', + showlegend=False + ), + row=5, col=1 + ) + + # Set appropriate date format for 1d chart + fig.update_xaxes( + title_text="", + row=5, + col=1, + tickformat="%Y-%m-%d", + tickmode="auto", + nticks=10 + ) + + # Add buy/sell markers to 1d chart if they fall within the visible timeframe + if hasattr(self, 'trades') and self.trades: + # Filter trades visible in 1d timeframe + min_time = df_1d.index.min() + max_time = df_1d.index.max() + + # Ensure min_time and max_time are pandas.Timestamp objects + if isinstance(min_time, (int, float)): + 
min_time = pd.to_datetime(min_time, unit='ms') + if isinstance(max_time, (int, float)): + max_time = pd.to_datetime(max_time, unit='ms') + + # Collect only trades within this timeframe + day_buy_times = [] + day_buy_prices = [] + day_sell_times = [] + day_sell_prices = [] + + for trade in self.trades[-300:]: # Check more trades for daily timeframe + trade_time = trade.get('timestamp') + if isinstance(trade_time, (int, float)): + # Convert numeric timestamp to datetime + trade_time = pd.to_datetime(trade_time, unit='ms') + elif not isinstance(trade_time, pd.Timestamp) and not isinstance(trade_time, datetime): + # Skip trades with invalid timestamp format + continue + + # Check if trade falls within 1d chart timeframe + try: + if min_time <= trade_time <= max_time: + price = trade.get('price', 0) + action = trade.get('action', 'SELL') + + if action == 'BUY': + day_buy_times.append(trade_time) + day_buy_prices.append(price) + elif action == 'SELL': + day_sell_times.append(trade_time) + day_sell_prices.append(price) + except TypeError: + # If comparison fails due to type mismatch, log the error and skip this trade + logger.warning(f"Type mismatch in timestamp comparison: min_time={type(min_time)}, trade_time={type(trade_time)}") + continue + + # Add buy markers to 1d chart + if day_buy_times: + fig.add_trace( + go.Scatter( + x=day_buy_times, + y=day_buy_prices, + mode='markers', + name='Buy (1d)', + marker=dict( + symbol='triangle-up', + size=5, + color='rgba(0,255,0,0.8)', + line=dict(width=1, color='darkgreen') + ), + showlegend=False, + hoverinfo='x+y' + ), + row=5, col=1 + ) + + # Add sell markers to 1d chart + if day_sell_times: + fig.add_trace( + go.Scatter( + x=day_sell_times, + y=day_sell_prices, + mode='markers', + name='Sell (1d)', + marker=dict( + symbol='triangle-down', + size=5, + color='rgba(255,0,0,0.8)', + line=dict(width=1, color='darkred') + ), + showlegend=False, + hoverinfo='x+y' + ), + row=5, col=1 + ) + + # Add trading info annotation if available + if hasattr(self, 'current_signal') and self.current_signal: + signal_color = "#33DD33" if self.current_signal == "BUY" else "#FF4444" if self.current_signal == "SELL" else "#BBBBBB" + + # Format position value + position_text = f"{self.current_position:.4f}" if self.current_position < 0.01 else f"{self.current_position:.2f}" + + # Format PnL with color based on value + pnl_color = "#33DD33" if self.session_pnl >= 0 else "#FF4444" + pnl_text = f"{self.session_pnl:.4f}" + + # Create trading info text + info_text = ( + f"Signal: {self.current_signal} | " + f"Position: {position_text} | " + f"Balance: ${self.session_balance:.2f} | " + f"PnL: {pnl_text}" + ) + + # Add annotation + fig.add_annotation( + x=0.5, y=1.05, + xref="paper", yref="paper", + text=info_text, + showarrow=False, + font=dict(size=14, color="white"), + bgcolor="rgba(50,50,50,0.6)", + borderwidth=1, + borderpad=6, + align="center" + ) + + # Update layout + fig.update_layout( + title_text=f"{self.symbol} Real-Time Data", + title_x=0.5, + xaxis_rangeslider_visible=False, + height=1000, # Increased height to accommodate all charts + template='plotly_dark', + paper_bgcolor='rgba(0,0,0,0)', + plot_bgcolor='rgba(25,25,50,1)' + ) + + # Update axes styling for all subplots + fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') + fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)') + + # Hide rangesliders for all candlestick charts + fig.update_layout( + xaxis_rangeslider_visible=False, + 
xaxis2_rangeslider_visible=False, + xaxis3_rangeslider_visible=False, + xaxis4_rangeslider_visible=False, + xaxis5_rangeslider_visible=False + ) + + # Improve date formatting for the main chart + fig.update_xaxes( + title_text="", + row=1, + col=1, + tickformat="%H:%M:%S", + tickmode="auto", + nticks=15 + ) + + return fig + + except Exception as e: + logger.error(f"Error updating chart: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + fig = go.Figure() + fig.add_annotation( + x=0.5, y=0.5, + text=f"Error updating chart: {str(e)}", + showarrow=False, + font=dict(size=14, color="red"), + xref="paper", yref="paper" + ) + return fig + + def _setup_position_list_callback(self): + """Callback to update the position list with the latest 5 trades""" + @self.app.callback( + Output('position-table-body', 'children'), + [Input('interval-component', 'n_intervals')] + ) + def update_position_list(n): + try: + # Get the last 5 positions + if not hasattr(self, 'trades') or not self.trades: + return [html.Tr([html.Td("No trades yet", colSpan=6, style={'textAlign': 'center', 'color': '#AAAAAA', 'padding': '10px'})])] + + # Collect BUY and SELL trades to pair them + buy_trades = {} + sell_trades = {} + position_rows = [] + + # Process all trades to match buy and sell pairs + for trade in self.trades: + action = trade.get('action', 'UNKNOWN') + if action == 'BUY': + # Store buy trades by timestamp to match with sells later + buy_trades[trade.get('timestamp')] = trade + elif action == 'SELL': + # Store sell trades + sell_trades[trade.get('timestamp')] = trade + + # Create position entries (matched pairs or single trades) + position_entries = [] + + # First add matched pairs (BUY trades with matching close_price) + for timestamp, buy_trade in buy_trades.items(): + if 'close_price' in buy_trade: + # This is a closed position + entry_price = buy_trade.get('price', 'N/A') + exit_price = buy_trade.get('close_price', 'N/A') + entry_time = buy_trade.get('timestamp') + exit_time = buy_trade.get('close_timestamp') + amount = buy_trade.get('amount', 0.1) + + # Calculate PnL if not provided + pnl = buy_trade.get('pnl') + if pnl is None and isinstance(entry_price, (int, float)) and isinstance(exit_price, (int, float)): + pnl = (exit_price - entry_price) * amount + + position_entries.append({ + 'action': 'CLOSED', + 'entry_price': entry_price, + 'exit_price': exit_price, + 'amount': amount, + 'pnl': pnl, + 'time': entry_time, + 'exit_time': exit_time, + 'status': 'CLOSED' + }) + else: + # This is an open position (BUY without a matching SELL) + current_price = self.tick_storage.get_latest_price() or buy_trade.get('price', 0) + amount = buy_trade.get('amount', 0.1) + entry_price = buy_trade.get('price', 0) + + # Calculate unrealized PnL + unrealized_pnl = (current_price - entry_price) * amount if isinstance(current_price, (int, float)) and isinstance(entry_price, (int, float)) else None + + position_entries.append({ + 'action': 'BUY', + 'entry_price': entry_price, + 'exit_price': current_price, # Use current price as potential exit + 'amount': amount, + 'pnl': unrealized_pnl, + 'time': buy_trade.get('timestamp'), + 'status': 'OPEN' + }) + + # Add standalone SELL trades that don't have a matching BUY + for timestamp, sell_trade in sell_trades.items(): + # Check if this SELL is already accounted for in a closed position + already_matched = False + for entry in position_entries: + if entry.get('status') == 'CLOSED' and entry.get('exit_time') == timestamp: + already_matched = True + break + + if not 
already_matched: + position_entries.append({ + 'action': 'SELL', + 'entry_price': 'N/A', + 'exit_price': sell_trade.get('price', 'N/A'), + 'amount': sell_trade.get('amount', 0.1), + 'pnl': sell_trade.get('pnl'), + 'time': sell_trade.get('timestamp'), + 'status': 'STANDALONE' + }) + + # Sort by time (most recent first) and take last 5 + position_entries.sort(key=lambda x: x['time'] if isinstance(x['time'], datetime) else datetime.now(), reverse=True) + position_entries = position_entries[:5] + + # Convert to table rows + for entry in position_entries: + action = entry['action'] + entry_price = entry['entry_price'] + exit_price = entry['exit_price'] + amount = entry['amount'] + pnl = entry['pnl'] + time_obj = entry['time'] + status = entry['status'] + + # Format time + if isinstance(time_obj, datetime): + # If trade is from a different day, include the date + today = datetime.now().date() + if time_obj.date() == today: + time_str = time_obj.strftime('%H:%M:%S') + else: + time_str = time_obj.strftime('%m-%d %H:%M:%S') + else: + time_str = str(time_obj) + + # Format prices with proper decimal places + if isinstance(entry_price, (int, float)): + entry_price_str = f"${entry_price:.2f}" + else: + entry_price_str = str(entry_price) + + if isinstance(exit_price, (int, float)): + exit_price_str = f"${exit_price:.2f}" + else: + exit_price_str = str(exit_price) + + # Format PnL + if pnl is not None and isinstance(pnl, (int, float)): + pnl_str = f"${pnl:.2f}" + pnl_color = '#00FF00' if pnl >= 0 else '#FF0000' + else: + pnl_str = 'N/A' + pnl_color = '#FFFFFF' + + # Set action/status color and text + if status == 'OPEN': + status_color = '#00AAFF' # Blue for open positions + status_text = "OPEN (BUY)" + elif status == 'CLOSED': + if pnl is not None and isinstance(pnl, (int, float)): + status_color = '#00FF00' if pnl >= 0 else '#FF0000' # Green/Red based on profit + else: + status_color = '#FFCC00' # Yellow if PnL unknown + status_text = "CLOSED" + elif action == 'BUY': + status_color = '#00FF00' + status_text = "BUY" + elif action == 'SELL': + status_color = '#FF0000' + status_text = "SELL" + else: + status_color = '#FFFFFF' + status_text = action + + # Create table row + position_rows.append(html.Tr([ + html.Td(status_text, style={'color': status_color, 'padding': '8px', 'border': '1px solid #444'}), + html.Td(f"{amount} BTC", style={'padding': '8px', 'border': '1px solid #444'}), + html.Td(entry_price_str, style={'padding': '8px', 'border': '1px solid #444'}), + html.Td(exit_price_str, style={'padding': '8px', 'border': '1px solid #444'}), + html.Td(pnl_str, style={'color': pnl_color, 'padding': '8px', 'border': '1px solid #444'}), + html.Td(time_str, style={'padding': '8px', 'border': '1px solid #444'}) + ])) + + return position_rows + except Exception as e: + logger.error(f"Error updating position table: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + return [html.Tr([html.Td(f"Error: {str(e)}", colSpan=6, style={'color': '#FF0000', 'padding': '10px'})])] + + def _interval_to_seconds(self, interval_key: str) -> int: + """Convert interval key to seconds""" + mapping = { + '1s': 1, + '1m': 60, + '1h': 3600, + '1d': 86400 + } + return mapping.get(interval_key, 1) + + async def start_websocket(self): + ws = ExchangeWebSocket(self.symbol) + connection_attempts = 0 + max_attempts = 10 # Maximum connection attempts before longer waiting period + + while True: # Keep trying to maintain connection + connection_attempts += 1 + if not await ws.connect(): + logger.error(f"Failed to connect 
to exchange for {self.symbol}") + # Gradually increase wait time based on number of connection failures + wait_time = min(5 * connection_attempts, 60) # Cap at 60 seconds + logger.warning(f"Waiting {wait_time} seconds before retry (attempt {connection_attempts})") + + if connection_attempts >= max_attempts: + logger.warning(f"Reached {max_attempts} connection attempts, taking a longer break") + await asyncio.sleep(120) # 2 minutes wait after max attempts + connection_attempts = 0 # Reset counter + else: + await asyncio.sleep(wait_time) + continue + + # Successfully connected + connection_attempts = 0 + + try: + logger.info(f"WebSocket connected for {self.symbol}, beginning data collection") + tick_count = 0 + last_tick_count_log = time.time() + last_status_report = time.time() + + # Track stats for reporting + price_min = float('inf') + price_max = float('-inf') + price_last = None + volume_total = 0 + start_collection_time = time.time() + + while True: + if not ws.running: + logger.warning(f"WebSocket connection lost for {self.symbol}, breaking loop") + break + + data = await ws.receive() + if data: + if data.get('type') == 'kline': + # Use kline data directly for candlestick + trade_data = { + 'timestamp': data['timestamp'], + 'price': data['price'], + 'volume': data['volume'], + 'open': data['open'], + 'high': data['high'], + 'low': data['low'] + } + logger.debug(f"Received kline data: {data}") + else: + # Use trade data + trade_data = { + 'timestamp': data['timestamp'], + 'price': data['price'], + 'volume': data['volume'] + } + + # Update stats + price = trade_data['price'] + volume = trade_data['volume'] + price_min = min(price_min, price) + price_max = max(price_max, price) + price_last = price + volume_total += volume + + # Store raw tick in the tick storage + self.tick_storage.add_tick(trade_data) + tick_count += 1 + + # Also update the old candlestick data for backward compatibility + # Add check to ensure the candlestick_data attribute exists before using it + if hasattr(self, 'candlestick_data'): + self.candlestick_data.update_from_trade(trade_data) + + # Log tick counts periodically + current_time = time.time() + if current_time - last_tick_count_log >= 10: # Log every 10 seconds + elapsed = current_time - last_tick_count_log + tps = tick_count / elapsed if elapsed > 0 else 0 + logger.info(f"{self.symbol}: Collected {tick_count} ticks in last {elapsed:.1f}s ({tps:.2f} ticks/sec), total: {len(self.tick_storage.ticks)}") + last_tick_count_log = current_time + tick_count = 0 + + # Check if ticks are being converted to candles + if len(self.tick_storage.ticks) > 0: + sample_df = self.tick_storage.get_candles(interval_seconds=1) + logger.info(f"{self.symbol}: Sample candle count: {len(sample_df)}") + + # Periodic status report (every 60 seconds) + if current_time - last_status_report >= 60: + elapsed_total = current_time - start_collection_time + logger.info(f"{self.symbol} Status Report:") + logger.info(f" Collection time: {elapsed_total:.1f} seconds") + logger.info(f" Price range: {price_min:.2f} - {price_max:.2f} (last: {price_last:.2f})") + logger.info(f" Total volume: {volume_total:.8f}") + logger.info(f" Active ticks in storage: {len(self.tick_storage.ticks)}") + + # Reset stats for next period + last_status_report = current_time + price_min = float('inf') if price_last is None else price_last + price_max = float('-inf') if price_last is None else price_last + volume_total = 0 + + await asyncio.sleep(0.01) + except websockets.exceptions.ConnectionClosed as e: + 
logger.error(f"WebSocket connection closed for {self.symbol}: {str(e)}") + except Exception as e: + logger.error(f"Error in WebSocket loop for {self.symbol}: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + finally: + logger.info(f"Closing WebSocket connection for {self.symbol}") + await ws.close() + + logger.info(f"Waiting 5 seconds before reconnecting {self.symbol} WebSocket...") + await asyncio.sleep(5) + + def run(self, host='localhost', port=8050): + """Run the Dash app + + Args: + host: Hostname to run on + port: Port to run on + """ + logger.info(f"Starting Dash app on {host}:{port}") + + # Ensure interval component is created + if not hasattr(self, 'app') or not self.app.layout: + logger.error("App layout not initialized properly") + return + + # If interval-component is not in the layout, add it + if 'interval-component' not in str(self.app.layout): + logger.warning("Interval component not found in layout, adding it") + self.app.layout.children.append( + dcc.Interval( + id='interval-component', + interval=500, # 500ms for real-time updates + n_intervals=0 + ) + ) + + # Start websocket connection in a separate thread + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self.websocket_thread = threading.Thread(target=lambda: asyncio.run(self.start_websocket())) + self.websocket_thread.daemon = True + self.websocket_thread.start() + + # Ensure historical data is loaded before starting + self._load_historical_data() + + try: + self.app.run(host=host, port=port, debug=False) + except Exception as e: + logger.error(f"Error running Dash app: {str(e)}") + finally: + # Ensure resources are cleaned up + self._save_candles_to_disk(force=True) + logger.info("Dash app stopped") + + def _load_historical_data(self): + """Load historical data for all timeframes from Binance API and local cache""" + try: + logger.info(f"Loading historical data for {self.symbol}...") + + # Define intervals to fetch + intervals = { + '1s': 1, + '1m': 60, + '1h': 3600, + '1d': 86400 + } + + # Track load status + load_status = {interval: False for interval in intervals.keys()} + + # First try to load from local cache files + logger.info("Step 1: Loading from local cache files...") + for interval_key, interval_seconds in intervals.items(): + try: + cache_file = os.path.join(self.historical_data.cache_dir, + f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv") + + logger.info(f"Checking for cached {interval_key} data at {cache_file}") + if os.path.exists(cache_file): + # Check if cache is fresh (less than 1 day old for anything but 1d, 3 days for 1d) + file_age = time.time() - os.path.getmtime(cache_file) + max_age = 259200 if interval_key == '1d' else 86400 # 3 days for 1d, 1 day for others + logger.info(f"Cache file age: {file_age:.1f}s, max allowed: {max_age}s") + + if file_age <= max_age: + logger.info(f"Loading {interval_key} candles from cache") + cached_df = pd.read_csv(cache_file) + if not cached_df.empty: + # Diagnostic info about the loaded data + logger.info(f"Loaded {len(cached_df)} candles from {cache_file}") + logger.info(f"Columns: {cached_df.columns.tolist()}") + logger.info(f"First few rows: {cached_df.head(2).to_dict('records')}") + + # Convert timestamp string back to datetime + if 'timestamp' in cached_df.columns: + try: + if not pd.api.types.is_datetime64_any_dtype(cached_df['timestamp']): + cached_df['timestamp'] = pd.to_datetime(cached_df['timestamp']) + logger.info("Successfully converted timestamps to datetime") + except Exception as e: + 
logger.warning(f"Could not convert timestamp column for {interval_key}: {str(e)}") + + # Only keep the last 2000 candles for memory efficiency + if len(cached_df) > 2000: + cached_df = cached_df.tail(2000) + logger.info(f"Truncated to last 2000 candles") + + # Add to cache + for _, row in cached_df.iterrows(): + candle_dict = row.to_dict() + self.candle_cache.candles[interval_key].append(candle_dict) + + # Update ohlcv_cache + self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000) + logger.info(f"Successfully loaded {len(self.ohlcv_cache[interval_key])} cached {interval_key} candles") + + if len(self.ohlcv_cache[interval_key]) >= 500: + load_status[interval_key] = True + # Skip fetching from API if we loaded from cache (except for 1d timeframe which we always refresh) + if interval_key != '1d': + continue + else: + logger.info(f"Cache file for {interval_key} is too old ({file_age:.1f}s)") + else: + logger.info(f"No cache file found for {interval_key}") + except Exception as e: + logger.error(f"Error loading cached {interval_key} candles: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + + # For timeframes other than 1s, fetch from API as backup or for fresh data + logger.info("Step 2: Fetching data from API for missing timeframes...") + for interval_key, interval_seconds in intervals.items(): + # Skip 1s for API requests + if interval_key == '1s' or load_status[interval_key]: + logger.info(f"Skipping API fetch for {interval_key}: already loaded or 1s timeframe") + continue + + # Fetch historical data from API + try: + logger.info(f"Fetching {interval_key} candles from API for {self.symbol}") + historical_df = self.historical_data.get_historical_candles( + symbol=self.symbol, + interval_seconds=interval_seconds, + limit=500 # Get 500 candles + ) + + if not historical_df.empty: + logger.info(f"Loaded {len(historical_df)} historical candles for {self.symbol} {interval_key} from API") + + # If we already have data in cache, merge with new data to avoid duplicates + if self.ohlcv_cache[interval_key] is not None and not self.ohlcv_cache[interval_key].empty: + existing_df = self.ohlcv_cache[interval_key] + # Get the latest timestamp from existing data + latest_time = existing_df['timestamp'].max() + # Only keep newer records from API + new_candles = historical_df[historical_df['timestamp'] > latest_time] + if not new_candles.empty: + logger.info(f"Adding {len(new_candles)} new candles to existing {interval_key} cache") + # Add to cache + for _, row in new_candles.iterrows(): + candle_dict = row.to_dict() + self.candle_cache.candles[interval_key].append(candle_dict) + else: + # No existing data, add all from API + for _, row in historical_df.iterrows(): + candle_dict = row.to_dict() + self.candle_cache.candles[interval_key].append(candle_dict) + + # Update ohlcv_cache with combined data + self.ohlcv_cache[interval_key] = self.candle_cache.get_recent_candles(interval_key, count=2000) + logger.info(f"Total {interval_key} candles in cache: {len(self.ohlcv_cache[interval_key])}") + + if len(self.ohlcv_cache[interval_key]) >= 500: + load_status[interval_key] = True + else: + logger.warning(f"No historical data available from API for {self.symbol} {interval_key}") + except Exception as e: + logger.error(f"Error fetching {interval_key} data from API: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + + # Log summary of loaded data + logger.info("Historical data load summary:") + for interval_key in intervals.keys(): + count 
= len(self.ohlcv_cache[interval_key]) if self.ohlcv_cache[interval_key] is not None else 0
+                status = "Success" if load_status[interval_key] else "Failed"
+                if count > 0 and count < 500:
+                    status = "Partial"
+                logger.info(f"{interval_key}: {count} candles - {status}")
+
+        except Exception as e:
+            logger.error(f"Error in _load_historical_data: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
+
+    def _save_candles_to_disk(self, force=False):
+        """Save current candle cache to disk for persistence between runs"""
+        try:
+            # Only save if we have data and sufficient time has passed (every 5 minutes)
+            current_time = time.time()
+            if not force and current_time - self.last_cache_save_time < 300:  # 5 minutes
+                return
+
+            # Save each timeframe's candles to disk
+            for interval_key, candles in self.candle_cache.candles.items():
+                if candles:
+                    # Convert to DataFrame
+                    df = pd.DataFrame(list(candles))
+                    if not df.empty:
+                        # Ensure timestamp is properly formatted
+                        if 'timestamp' in df.columns:
+                            try:
+                                if not pd.api.types.is_datetime64_any_dtype(df['timestamp']):
+                                    df['timestamp'] = pd.to_datetime(df['timestamp'])
+                            except Exception as e:
+                                # Avoid a bare except so real errors stay visible in the logs
+                                logger.warning(f"Could not convert timestamp column for {interval_key}: {str(e)}")
+
+                        # Save to disk in the cache directory
+                        cache_file = os.path.join(self.historical_data.cache_dir,
+                                                  f"{self.symbol.replace('/', '_')}_{interval_key}_candles.csv")
+                        df.to_csv(cache_file, index=False)
+                        logger.info(f"Saved {len(df)} {interval_key} candles to {cache_file}")
+
+            self.last_cache_save_time = current_time
+            logger.info(f"Saved all candle caches to disk at {datetime.now()}")
+        except Exception as e:
+            logger.error(f"Error saving candles to disk: {str(e)}")
+            import traceback
+            logger.error(traceback.format_exc())
+
+    def add_nn_signal(self, signal_type, timestamp, probability=None):
+        """Add a neural network signal to be displayed on the chart
+
+        Args:
+            signal_type: The type of signal (BUY, SELL, HOLD)
+            timestamp: The timestamp for the signal
+            probability: Optional probability/confidence value
+        """
+        if signal_type not in ['BUY', 'SELL', 'HOLD']:
+            logger.warning(f"Invalid NN signal type: {signal_type}")
+            return
+
+        # Convert timestamp to datetime if it's not already
+        if not isinstance(timestamp, datetime):
+            try:
+                if isinstance(timestamp, str):
+                    timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
+                elif isinstance(timestamp, (int, float)):
+                    timestamp = datetime.fromtimestamp(timestamp / 1000.0)
+            except Exception as e:
+                logger.error(f"Error converting timestamp for NN signal: {str(e)}")
+                timestamp = datetime.now()
+
+        # Add the signal to our list
+        self.nn_signals.append({
+            'type': signal_type,
+            'timestamp': timestamp,
+            'probability': probability,
+            'added': datetime.now()
+        })
+
+        # Only keep the most recent 50 signals
+        if len(self.nn_signals) > 50:
+            self.nn_signals = self.nn_signals[-50:]
+
+        logger.info(f"Added NN signal: {signal_type} at {timestamp}")
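The ISO-string / epoch-millisecond timestamp handling in add_nn_signal is duplicated verbatim in add_trade below. If a third caller appears, it could be factored into a small shared helper along these lines (a sketch only; _parse_timestamp is a hypothetical name, not part of this diff, and it assumes the module-level datetime and logger already imported in this file):

    def _parse_timestamp(self, timestamp):
        """Best-effort conversion of ISO strings or epoch-milliseconds to datetime."""
        if isinstance(timestamp, datetime):
            return timestamp
        try:
            if isinstance(timestamp, str):
                return datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
            if isinstance(timestamp, (int, float)):
                return datetime.fromtimestamp(timestamp / 1000.0)
        except Exception as e:
            logger.error(f"Error converting timestamp: {str(e)}")
        # Fall back to the current time, matching the existing behavior
        return datetime.now()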
+    def add_trade(self, price, timestamp, pnl=None, amount=0.1, action=None, type=None):
+        """Add a trade to be displayed on the chart
+
+        Args:
+            price: The price at which the trade was executed
+            timestamp: The timestamp for the trade
+            pnl: Optional profit and loss value for the trade
+            amount: Amount traded
+            action: The type of trade (BUY or SELL) - alternative to type parameter
+            type: The type of trade (BUY or SELL) - alternative to action parameter
+        """
+        # Handle both action and type parameters for backward compatibility.
+        # Compare against None explicitly: `type or action` would discard a
+        # legitimate integer 0 (BUY) passed via type.
+        trade_type = type if type is not None else action
+
+        # Default to BUY if trade_type is None or not specified
+        if trade_type is None:
+            logger.warning(f"Trade type not specified in add_trade call, defaulting to BUY. Price: {price}, Timestamp: {timestamp}")
+            trade_type = "BUY"
+
+        if isinstance(trade_type, int):
+            trade_type = "BUY" if trade_type == 0 else "SELL"
+
+        # Ensure trade_type is uppercase if it's a string
+        if isinstance(trade_type, str):
+            trade_type = trade_type.upper()
+
+        if trade_type not in ['BUY', 'SELL']:
+            # Use __class__ here: the `type` parameter shadows the builtin type()
+            logger.warning(f"Invalid trade type: {trade_type} (value type: {trade_type.__class__.__name__}), defaulting to BUY. Price: {price}, Timestamp: {timestamp}")
+            trade_type = "BUY"
+
+        # Convert timestamp to datetime if it's not already
+        if not isinstance(timestamp, datetime):
+            try:
+                if isinstance(timestamp, str):
+                    timestamp = datetime.fromisoformat(timestamp.replace('Z', '+00:00'))
+                elif isinstance(timestamp, (int, float)):
+                    timestamp = datetime.fromtimestamp(timestamp / 1000.0)
+            except Exception as e:
+                logger.error(f"Error converting timestamp for trade: {str(e)}")
+                timestamp = datetime.now()
+
+        # Create the trade object
+        trade = {
+            'price': price,
+            'timestamp': timestamp,
+            'pnl': pnl,
+            'amount': amount,
+            'action': trade_type
+        }
+
+        # Add to our trades list
+        if not hasattr(self, 'trades'):
+            self.trades = []
+
+        # If this is a SELL trade, try to find the corresponding BUY trade and update it with close_price
+        if trade_type == 'SELL' and len(self.trades) > 0:
+            for i in range(len(self.trades) - 1, -1, -1):
+                prev_trade = self.trades[i]
+                if prev_trade.get('action') == 'BUY' and 'close_price' not in prev_trade:
+                    # Found a BUY trade without a close_price, consider it the matching trade
+                    prev_trade['close_price'] = price
+                    prev_trade['close_timestamp'] = timestamp
+                    logger.info(f"Updated BUY trade at {prev_trade['timestamp']} with close price {price}")
+                    break
+
+        self.trades.append(trade)
+
+        # Log the trade for debugging
+        pnl_str = f" with PnL: {pnl}" if pnl is not None else ""
+        logger.info(f"Added trade: {trade_type} {amount} at price {price} at time {timestamp}{pnl_str}")
+
+        # Trigger a more frequent update of the chart by scheduling a callback
+        # This helps ensure the trade appears immediately on the chart
+        if hasattr(self, 'app') and self.app is not None:
+            try:
+                # Only update if we have a dash app running
+                # This is a workaround to make trades appear immediately
+                callback_context = dash.callback_context
+                # Force an update by triggering the callback
+                for callback_id, callback_info in self.app.callback_map.items():
+                    if 'live-chart' in callback_id:
+                        # Found the chart callback, try to trigger it
+                        logger.debug("Triggering chart update callback after trade")
+                        callback_info['callback']()
+                        break
+            except Exception as e:
+                # If callback triggering fails, it's not critical
+                logger.debug(f"Failed to trigger chart update: {str(e)}")
+
+        return trade
+
+    def update_trading_info(self, signal=None, position=None, balance=None, pnl=None):
+        """Update the current trading information to be displayed on the chart
+
+        Args:
+            signal: Current signal (BUY, SELL, HOLD)
+            position: Current position size
+            balance: Current session balance
+            pnl: Current session PnL
+        """
+        if signal is not None:
+            if signal in ['BUY', 'SELL', 'HOLD']:
+                self.current_signal = signal
+                self.signal_time = datetime.now()
+            else:
+                logger.warning(f"Invalid signal type: {signal}")
+
+        if position is not None:
+            self.current_position = position
+
+        if balance is not None:
+            self.session_balance = balance
+
+        if pnl is not None:
self.session_pnl = pnl + + logger.debug(f"Updated trading info: Signal={self.current_signal}, Position={self.current_position}, Balance=${self.session_balance:.2f}, PnL={self.session_pnl:.4f}") + +async def main(): + global charts # Make charts globally accessible for NN integration + symbols = ["ETH/USDT", "BTC/USDT"] + logger.info(f"Starting application for symbols: {symbols}") + + # Initialize neural network if enabled + if NN_ENABLED: + logger.info("Initializing Neural Network integration...") + if setup_neural_network(): + logger.info("Neural Network integration initialized successfully") + else: + logger.warning("Neural Network integration failed to initialize") + + charts = [] + websocket_tasks = [] + + # Create a chart and websocket task for each symbol + for symbol in symbols: + chart = RealTimeChart(symbol) + charts.append(chart) + websocket_tasks.append(asyncio.create_task(chart.start_websocket())) + + # Run Dash in a separate thread to not block the event loop + server_threads = [] + for i, chart in enumerate(charts): + port = 8050 + i # Use different ports for each chart + logger.info(f"Starting chart for {chart.symbol} on port {port}") + thread = Thread(target=lambda c=chart, p=port: c.run(port=p)) # Ensure correct port is passed + thread.daemon = True + thread.start() + server_threads.append(thread) + logger.info(f"Thread started for {chart.symbol} on port {port}") + + try: + # Keep the main task running + while True: + await asyncio.sleep(1) + except KeyboardInterrupt: + logger.info("Shutting down...") + except Exception as e: + logger.error(f"Unexpected error: {str(e)}") + finally: + for task in websocket_tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Application terminated by user") + diff --git a/train_improved_rl.py b/train_improved_rl.py new file mode 100644 index 0000000..56792da --- /dev/null +++ b/train_improved_rl.py @@ -0,0 +1,547 @@ +#!/usr/bin/env python +""" +Improved RL Trading with Enhanced Training and Monitoring + +This script provides an improved version of the RL training process, +implementing better normalization, reward structure, and model training. 
+""" + +import os +import sys +import logging +import argparse +import time +from datetime import datetime +import numpy as np +import torch +import pandas as pd +import matplotlib.pyplot as plt +from pathlib import Path + +# Add project directory to path if needed +project_root = os.path.dirname(os.path.abspath(__file__)) +if project_root not in sys.path: + sys.path.append(project_root) + +# Import our custom modules +from NN.models.dqn_agent import DQNAgent +from NN.utils.trading_env import TradingEnvironment +from NN.utils.data_interface import DataInterface +from realtime import BinanceHistoricalData, RealTimeChart + +# Configure logging +log_filename = f'improved_rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log' +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', + handlers=[ + logging.FileHandler(log_filename), + logging.StreamHandler() + ] +) +logger = logging.getLogger('improved_rl') + +# Parse command line arguments +parser = argparse.ArgumentParser(description='Improved RL Trading with Enhanced Training') +parser.add_argument('--episodes', type=int, default=20, help='Number of episodes to train') +parser.add_argument('--visualize', action='store_true', help='Visualize trades during training') +parser.add_argument('--save-path', type=str, default='NN/models/saved/improved_dqn_agent', help='Path to save trained model') +parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol') +args = parser.parse_args() + +def create_training_environment(symbol, window_size=20): + """Create and prepare the training environment with data""" + logger.info(f"Setting up training environment for {symbol}") + + # Fetch historical data from multiple timeframes + data_interface = DataInterface(symbol) + + # Use Binance data provider for fetching data + historical_data = BinanceHistoricalData() + + # Fetch data for each timeframe + df_1m = historical_data.get_historical_candles(symbol, interval_seconds=60, limit=1000) + df_5m = historical_data.get_historical_candles(symbol, interval_seconds=300, limit=1000) + df_15m = historical_data.get_historical_candles(symbol, interval_seconds=900, limit=500) + + # Ensure all dataframes have index as timestamp type + if df_1m is not None and not df_1m.empty: + if 'timestamp' in df_1m.columns: + df_1m = df_1m.set_index('timestamp') + + if df_5m is not None and not df_5m.empty: + if 'timestamp' in df_5m.columns: + df_5m = df_5m.set_index('timestamp') + + if df_15m is not None and not df_15m.empty: + if 'timestamp' in df_15m.columns: + df_15m = df_15m.set_index('timestamp') + + # Preprocess data (add technical indicators) + df_1m = preprocess_dataframe(df_1m) + df_5m = preprocess_dataframe(df_5m) + df_15m = preprocess_dataframe(df_15m) + + # Create environment with all timeframes + env = create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size) + + return env, (df_1m, df_5m, df_15m) + +def preprocess_dataframe(df): + """Add technical indicators and preprocess dataframe""" + if df is None or df.empty: + return None + + # Drop any missing values + df = df.dropna() + + # Ensure it has OHLCV columns + required_columns = ['open', 'high', 'low', 'close', 'volume'] + missing_columns = [col for col in required_columns if col not in df.columns] + + if missing_columns: + logger.warning(f"Missing required columns: {missing_columns}") + for col in missing_columns: + # Fill with close price for OHLC if missing + if col in ['open', 'high', 'low'] and 'close' in df.columns: + df[col] = 
df['close'] + # Fill with zeros for volume if missing + elif col == 'volume': + df[col] = 0 + + # Add simple technical indicators + # 1. Simple Moving Averages + df['sma_5'] = df['close'].rolling(window=5).mean() + df['sma_10'] = df['close'].rolling(window=10).mean() + + # 2. Relative Strength Index (RSI) + delta = df['close'].diff() + gain = (delta.where(delta > 0, 0)).rolling(window=14).mean() + loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean() + rs = gain / loss + df['rsi'] = 100 - (100 / (1 + rs)) + + # 3. Bollinger Bands + df['bb_middle'] = df['close'].rolling(window=20).mean() + df['bb_std'] = df['close'].rolling(window=20).std() + df['bb_upper'] = df['bb_middle'] + 2 * df['bb_std'] + df['bb_lower'] = df['bb_middle'] - 2 * df['bb_std'] + + # 4. MACD + df['ema_12'] = df['close'].ewm(span=12, adjust=False).mean() + df['ema_26'] = df['close'].ewm(span=26, adjust=False).mean() + df['macd'] = df['ema_12'] - df['ema_26'] + df['macd_signal'] = df['macd'].ewm(span=9, adjust=False).mean() + + # 5. Price rate of change + df['roc'] = df['close'].pct_change(periods=10) * 100 + + # Fill any remaining NaN values with 0 + df = df.fillna(0) + + return df + +def create_multi_timeframe_env(df_1m, df_5m, df_15m, window_size=20): + """Create a custom environment that handles multiple timeframes""" + + # Ensure we have complete data for all timeframes + min_required_length = window_size + 100 # Add buffer for training + + if (df_1m is None or len(df_1m) < min_required_length or + df_5m is None or len(df_5m) < min_required_length or + df_15m is None or len(df_15m) < min_required_length): + raise ValueError(f"Insufficient data for training. Need at least {min_required_length} candles per timeframe.") + + # Ensure we only use the last N valid data points + df_1m = df_1m.iloc[-900:].copy() if len(df_1m) > 900 else df_1m.copy() + df_5m = df_5m.iloc[-180:].copy() if len(df_5m) > 180 else df_5m.copy() + df_15m = df_15m.iloc[-60:].copy() if len(df_15m) > 60 else df_15m.copy() + + # Reset index to make sure we have continuous integers + df_1m = df_1m.reset_index(drop=True) + df_5m = df_5m.reset_index(drop=True) + df_15m = df_15m.reset_index(drop=True) + + # For simplicity, we'll use the 1m data as the base environment + # The other timeframes will be incorporated through observation + + env = TradingEnvironment( + data=df_1m, + initial_balance=100.0, + fee_rate=0.0005, # 0.05% fee (typical for crypto exchanges) + max_steps=len(df_1m) - window_size - 50, # Leave some room at the end + window_size=window_size, + risk_aversion=0.2, # Moderately risk-averse + price_scaling='zscore', # Use z-score normalization + reward_scaling=10.0, # Scale rewards for better learning + episode_penalty=0.2 # Penalty for holding positions at end of episode + ) + + return env + +def initialize_agent(env, window_size=20, num_features=0, timeframes=None): + """Initialize the DQN agent with appropriate parameters""" + if timeframes is None: + timeframes = ['1m', '5m', '15m'] + + # Calculate input dimensions + state_dim = env.observation_space.shape[0] + action_dim = env.action_space.n + + # If num_features wasn't provided, infer from environment + if num_features == 0: + # Calculate features per timeframe from state dimension and number of timeframes + # Accounting for the 3 additional features (position, equity, unrealized_pnl) + num_features = (state_dim - 3) // len(timeframes) + + logger.info(f"Initializing DQN agent: state_dim={state_dim}, action_dim={action_dim}, features={num_features}") + + agent = DQNAgent( + 
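+        # Hyperparameters are passed explicitly (mirroring the tuned defaults
+        # commented below) so a training run stays reproducible even if the
+        # agent's defaults change later.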
state_size=state_dim,
+        action_size=action_dim,
+        window_size=window_size,
+        num_features=num_features,
+        timeframes=timeframes,
+        learning_rate=0.0005,  # Start with a moderate learning rate
+        gamma=0.97,            # Slightly reduced discount factor for stable learning
+        epsilon=1.0,           # Start with full exploration
+        epsilon_min=0.05,      # Maintain some exploration even at the end
+        epsilon_decay=0.9975,  # Slower decay for more exploration
+        memory_size=20000,     # Larger replay buffer
+        batch_size=128,        # Larger batch size for more stable gradients
+        target_update=5        # More frequent target network updates
+    )
+
+    return agent
+
+def train_agent(env, agent, num_episodes=20, visualize=False, chart=None, save_path=None, save_freq=5):
+    """
+    Train the DQN agent with improved training loop
+
+    Args:
+        env: The trading environment
+        agent: The DQN agent
+        num_episodes: Number of episodes to train
+        visualize: Whether to visualize trades during training
+        chart: The visualization chart (if visualize=True)
+        save_path: Path to save the model
+        save_freq: How often to save checkpoints (in episodes)
+
+    Returns:
+        tuple: (rewards, win_rates, best_reward, best_model_path)
+    """
+    logger.info(f"Starting training for {num_episodes} episodes")
+
+    # Initialize metrics tracking
+    rewards = []
+    win_rates = []
+    total_train_time = 0
+    best_reward = float('-inf')
+    best_model_path = None
+
+    # Create directory for checkpoints if needed
+    if save_path:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        checkpoint_dir = os.path.join(os.path.dirname(save_path), 'checkpoints')
+        os.makedirs(checkpoint_dir, exist_ok=True)
+
+    # For tracking improvement
+    last_improved_episode = 0
+    patience = 10  # Episodes to wait for improvement before early stopping
+
+    for episode in range(num_episodes):
+        start_time = time.time()
+
+        # Reset environment and get initial state
+        state = env.reset()
+        done = False
+        episode_reward = 0
+        step = 0
+
+        # Action metrics for this episode
+        actions_taken = {0: 0, 1: 0, 2: 0}  # Track BUY, SELL, HOLD actions
+
+        while not done:
+            # Select action
+            action = agent.act(state)
+
+            # Execute action
+            next_state, reward, done, info = env.step(action)
+
+            # Store experience in replay buffer
+            is_extrema = False  # In a real implementation, detect extrema points
+            agent.remember(state, action, reward, next_state, done, is_extrema)
+
+            # Learn from experience
+            if len(agent.memory) >= agent.batch_size:
+                use_prioritized = episode > 1  # Enable prioritized replay after the first two episodes
+                loss = agent.replay(use_prioritized=use_prioritized)
+
+            # Update state and metrics
+            state = next_state
+            episode_reward += reward
+            actions_taken[action] += 1
+
+            # Log the first 10 steps, then every 100th step
+            if step % 100 == 0 or step < 10:
+                action_str = "BUY" if action == 0 else "SELL" if action == 1 else "HOLD"
+                current_price = info.get('current_price', 0)
+                pnl = info.get('pnl', 0)
+                balance = info.get('balance', 0)
+
+                logger.info(f"Episode {episode}, Step {step}: Action={action_str}, "
+                            f"Reward={reward:.4f}, Balance=${balance:.2f}, PnL={pnl:.4f}")
+
+                # Add trade to visualization if enabled
+                if visualize and chart and action in [0, 1]:  # BUY or SELL
+                    chart.add_trade(
+                        price=current_price,
+                        timestamp=datetime.now(),
+                        amount=0.1,
+                        pnl=pnl,
+                        action=action_str
+                    )
+
+            step += 1
+
+        # Episode finished - calculate metrics
+        episode_time = time.time() - start_time
+        total_train_time += episode_time
+
+        # Get environment info
+        win_rate = env.winning_trades / max(1, env.total_trades)
+        trades = env.total_trades
+        balance = env.balance
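+        # Episode summary: gain below is the fractional return on the starting
+        # balance; winning_trades / total_trades / max_drawdown are attributes
+        # the TradingEnvironment is expected to expose.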
+ gain = (balance - env.initial_balance) / env.initial_balance + max_drawdown = env.max_drawdown + + # Record metrics + rewards.append(episode_reward) + win_rates.append(win_rate) + + # Update agent's learning metrics + improved = agent.update_learning_metrics(episode_reward) + + # If this is best performance, save the model + if episode_reward > best_reward: + best_reward = episode_reward + if save_path: + best_model_path = f"{save_path}_best" + agent.save(best_model_path) + logger.info(f"New best model saved to {best_model_path} (reward: {best_reward:.2f})") + last_improved_episode = episode + + # Regular checkpoint saving + if save_path and episode % save_freq == 0: + checkpoint_path = os.path.join(checkpoint_dir, f"checkpoint_episode_{episode}") + agent.save(checkpoint_path) + + # Print episode summary + actions_summary = ", ".join([f"{k}:{v}" for k, v in actions_taken.items()]) + logger.info(f"Episode {episode} completed in {episode_time:.2f}s") + logger.info(f" Total reward: {episode_reward:.4f}") + logger.info(f" Actions taken: {actions_summary}") + logger.info(f" Trades: {trades}, Win rate: {win_rate:.2%}") + logger.info(f" Balance: ${balance:.2f}, Gain: {gain:.2%}") + logger.info(f" Max Drawdown: {max_drawdown:.2%}") + + # Early stopping check + if episode - last_improved_episode >= patience: + logger.info(f"No improvement for {patience} episodes. Early stopping.") + break + + # Training complete + avg_time_per_episode = total_train_time / max(1, len(rewards)) + logger.info(f"Training completed in {total_train_time:.2f}s ({avg_time_per_episode:.2f}s per episode)") + + # Save final model + if save_path: + agent.save(f"{save_path}_final") + logger.info(f"Final model saved to {save_path}_final") + + # Return training metrics + return rewards, win_rates, best_reward, best_model_path + +def plot_training_results(rewards, win_rates, save_dir=None): + """Plot training metrics and save the figure""" + plt.figure(figsize=(12, 8)) + + # Plot rewards + plt.subplot(2, 1, 1) + plt.plot(rewards, 'b-') + plt.title('Training Rewards per Episode') + plt.xlabel('Episode') + plt.ylabel('Total Reward') + plt.grid(True) + + # Plot win rates + plt.subplot(2, 1, 2) + plt.plot(win_rates, 'g-') + plt.title('Win Rate per Episode') + plt.xlabel('Episode') + plt.ylabel('Win Rate') + plt.grid(True) + + plt.tight_layout() + + # Save figure if directory provided + if save_dir: + os.makedirs(save_dir, exist_ok=True) + plt.savefig(os.path.join(save_dir, f'training_results_{datetime.now().strftime("%Y%m%d_%H%M%S")}.png')) + + plt.close() + +def evaluate_agent(env, agent, num_episodes=5, visualize=False, chart=None): + """ + Evaluate a trained agent on the environment + + Args: + env: The trading environment + agent: The trained DQN agent + num_episodes: Number of evaluation episodes + visualize: Whether to visualize trades + chart: The visualization chart (if visualize=True) + + Returns: + dict: Evaluation metrics + """ + logger.info(f"Evaluating agent over {num_episodes} episodes") + + # Metrics to track + total_rewards = [] + total_trades = [] + win_rates = [] + sharpe_ratios = [] + sortino_ratios = [] + max_drawdowns = [] + final_balances = [] + + for episode in range(num_episodes): + # Reset environment + state = env.reset() + done = False + episode_reward = 0 + + # Run episode without exploration + while not done: + action = agent.act(state, explore=False) # No exploration during evaluation + next_state, reward, done, info = env.step(action) + + episode_reward += reward + state = next_state + + # Add trade to 
visualization if enabled + if visualize and chart and action in [0, 1]: # BUY or SELL + action_str = "BUY" if action == 0 else "SELL" + current_price = info.get('current_price', 0) + pnl = info.get('pnl', 0) + + chart.add_trade( + price=current_price, + timestamp=datetime.now(), + amount=0.1, + pnl=pnl, + action=action_str + ) + + # Record metrics + total_rewards.append(episode_reward) + total_trades.append(env.total_trades) + win_rates.append(env.winning_trades / max(1, env.total_trades)) + sharpe_ratios.append(info.get('sharpe_ratio', 0)) + sortino_ratios.append(info.get('sortino_ratio', 0)) + max_drawdowns.append(env.max_drawdown) + final_balances.append(env.balance) + + logger.info(f"Evaluation episode {episode} - Reward: {episode_reward:.4f}, " + f"Balance: ${env.balance:.2f}, Win rate: {win_rates[-1]:.2%}") + + # Calculate average metrics + avg_reward = np.mean(total_rewards) + avg_trades = np.mean(total_trades) + avg_win_rate = np.mean(win_rates) + avg_sharpe = np.mean(sharpe_ratios) + avg_sortino = np.mean(sortino_ratios) + avg_max_drawdown = np.mean(max_drawdowns) + avg_final_balance = np.mean(final_balances) + + # Log evaluation summary + logger.info("Evaluation completed:") + logger.info(f" Average reward: {avg_reward:.4f}") + logger.info(f" Average trades per episode: {avg_trades:.2f}") + logger.info(f" Average win rate: {avg_win_rate:.2%}") + logger.info(f" Average Sharpe ratio: {avg_sharpe:.4f}") + logger.info(f" Average Sortino ratio: {avg_sortino:.4f}") + logger.info(f" Average max drawdown: {avg_max_drawdown:.2%}") + logger.info(f" Average final balance: ${avg_final_balance:.2f}") + + # Return evaluation metrics + return { + 'avg_reward': avg_reward, + 'avg_trades': avg_trades, + 'avg_win_rate': avg_win_rate, + 'avg_sharpe': avg_sharpe, + 'avg_sortino': avg_sortino, + 'avg_max_drawdown': avg_max_drawdown, + 'avg_final_balance': avg_final_balance + } + +def main(): + """Main function to run the improved RL training""" + start_time = time.time() + logger.info(f"Starting improved RL training for {args.symbol}") + + # Create environment + env, data_frames = create_training_environment(args.symbol) + + # Initialize visualization if enabled + chart = None + if args.visualize: + logger.info("Initializing visualization chart") + chart = RealTimeChart(args.symbol) + time.sleep(2) # Give time for chart to initialize + + # Initialize agent + agent = initialize_agent(env) + + # Train agent + rewards, win_rates, best_reward, best_model_path = train_agent( + env=env, + agent=agent, + num_episodes=args.episodes, + visualize=args.visualize, + chart=chart, + save_path=args.save_path + ) + + # Plot training results + plot_dir = os.path.join(os.path.dirname(args.save_path), 'plots') + plot_training_results(rewards, win_rates, save_dir=plot_dir) + + # Evaluate best model + logger.info("Evaluating best model") + + # Load best model for evaluation + if best_model_path: + best_agent = initialize_agent(env) + best_agent.load(best_model_path) + + # Evaluate the best model + eval_metrics = evaluate_agent( + env=env, + agent=best_agent, + visualize=args.visualize, + chart=chart + ) + + # Log evaluation results + logger.info("Best model evaluation complete:") + for metric, value in eval_metrics.items(): + logger.info(f" {metric}: {value}") + + # Total run time + total_time = time.time() - start_time + logger.info(f"Total run time: {total_time:.2f} seconds") + +if __name__ == "__main__": + main() \ No newline at end of file
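Since parser.parse_args() runs at import time, train_improved_rl.py is meant to be executed directly rather than imported from other modules. A typical invocation, using only the flags defined in the argparse block above (the save path shown is the script's default, included for illustration), would look like:

    python train_improved_rl.py --episodes 50 --symbol ETH/USDT --visualize
    python train_improved_rl.py --episodes 20 --save-path NN/models/saved/improved_dqn_agent

Training then writes periodic checkpoints under a checkpoints/ directory next to the save path, keeps the best-reward weights at <save-path>_best, saves the final weights at <save-path>_final, and emits reward/win-rate plots into a sibling plots/ directory.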