# training/train.py
"""Live training loop and checkpoint management for the transformer model.

Checkpoints are written to two directories:

* ``models/last`` - the most recent checkpoints (oldest pruned by mtime).
* ``models/best`` - the lowest-loss checkpoints; the loss is encoded in the
  file name as ``best_<loss>_epoch_<epoch>_<timestamp>.pt``.
"""
import asyncio
import math
import os
import time
from collections import deque
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim

from data.data_utils import preprocess_data, create_mask
from model.transformer import Transformer
from data.live_data import LiveDataManager
from visualization.plotting import plot_live_data

# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

# Maximum number of files kept in each checkpoint directory.
MAX_CHECKPOINTS = 10


# -------------------------------------
# Checkpoint Functions
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    """Delete the oldest entries of *directory* so at most *max_files* remain.

    "Oldest" is decided by modification time. Every directory entry is
    counted, not only checkpoint files.
    """
    files = os.listdir(directory)
    excess = len(files) - max_files
    if excess <= 0:
        return
    oldest_first = sorted(
        (os.path.join(directory, name) for name in files),
        key=os.path.getmtime,
    )
    for path in oldest_first[:excess]:
        os.remove(path)


def get_best_models(directory):
    """Return ``(loss, filename)`` pairs for the checkpoints in *directory*.

    The loss is parsed from the second ``_``-separated field of the file name
    (``best_<loss>_...``). Entries whose name does not follow that pattern are
    skipped.
    """
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            loss = float(parts[1])
        except (IndexError, ValueError):
            # No second field, or the field is not a number: not a
            # "best" checkpoint file, ignore it.
            continue
        best_files.append((loss, file))
    return best_files


def save_checkpoint(model, epoch, total_loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    """Save *model* as the latest checkpoint and, if good enough, as a best one.

    Args:
        model: Module whose ``state_dict()`` is stored.
        epoch: Epoch number recorded in the checkpoint and in the file name.
        total_loss: Loss value (a Python float) recorded in the checkpoint;
            lower is better.
        last_dir: Directory receiving the rolling "last" checkpoints.
        best_dir: Directory receiving the lowest-loss checkpoints.

    The checkpoint always goes to *last_dir*. It is also written to *best_dir*
    when that directory holds fewer than ``MAX_CHECKPOINTS`` parsed
    checkpoints, or when *total_loss* is lower than the worst (highest) loss
    stored there, in which case that worst checkpoint is removed first.
    """
    # NOTE: the timestamp has one-second resolution, so two saves for the same
    # epoch within the same second write to the same "last" file name.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    state = {
        "epoch": epoch,
        "total_loss": total_loss,
        "model_state_dict": model.state_dict(),
    }

    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    torch.save(state, os.path.join(last_dir, last_filename))
    maintain_checkpoint_directory(last_dir, max_files=MAX_CHECKPOINTS)

    best_models = get_best_models(best_dir)
    add_to_best = False
    # A NaN/inf loss must never be stored as a "best" model: it would also
    # make later min()/max() selections over the parsed losses unreliable.
    if math.isfinite(total_loss):
        if len(best_models) < MAX_CHECKPOINTS:
            add_to_best = True
        else:
            # Lower loss is better, so the candidate for eviction is the
            # checkpoint with the HIGHEST loss, never the current best one.
            worst_loss, worst_file = max(best_models, key=lambda item: item[0])
            if total_loss < worst_loss:
                add_to_best = True
                os.remove(os.path.join(best_dir, worst_file))

    if add_to_best:
        best_filename = f"best_{total_loss:.4f}_epoch_{epoch}_{timestamp}.pt"
        torch.save(state, os.path.join(best_dir, best_filename))
        maintain_checkpoint_directory(best_dir, max_files=MAX_CHECKPOINTS)

    print(f"Saved checkpoint for epoch {epoch} with loss {total_loss:.4f}")


def load_best_checkpoint(model, best_dir=BEST_DIR, map_location=None):
    """Load the lowest-loss checkpoint from *best_dir* into *model*.

    Args:
        model: Module that receives the stored ``model_state_dict``.
        best_dir: Directory searched for ``best_<loss>_...`` files.
        map_location: Forwarded to ``torch.load``; pass the target device so
            a checkpoint saved on another device can still be loaded.

    Returns:
        The loaded checkpoint dict, or ``None`` when *best_dir* contains no
        parseable checkpoint.
    """
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    # Lowest loss wins.
    best_loss, best_file = min(best_models, key=lambda item: item[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted directory.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint


def _to_batch_tensor(values, device):
    """Convert *values* to a tensor with a leading batch axis on *device*."""
    return torch.tensor(values).unsqueeze(0).to(device)


async def train(model, data_manager, optimizer, criterion_candle, criterion_volume,
                criterion_ticks, num_epochs=10, device='cpu', steps_per_epoch=None):
    """Train *model* on live data, checkpointing and emitting trade signals.

    Args:
        model: Module called as ``model(candles, ticks, candle_mask,
            tick_mask)`` and returning candle, volume and tick predictions.
        data_manager: Live data source. Must provide the coroutines
            ``_fetch_initial_candles()``, ``fetch_and_process_ticks()`` and
            ``get_data()`` (returning ``(candles, ticks)``) and a
            ``window_size`` attribute.
        optimizer: Optimizer stepping the model parameters.
        criterion_candle: Loss for the candle prediction.
        criterion_volume: Loss for the volume prediction.
        criterion_ticks: Loss for the tick prediction.
        num_epochs: Number of epochs to run. Only takes effect when
            *steps_per_epoch* is set.
        device: Device the model and tensors are moved to.
        steps_per_epoch: Number of optimizer steps per epoch. ``None``
            (the default) keeps training on live data indefinitely within the
            first epoch.

    Each successful step saves a checkpoint, applies a simple threshold rule
    to produce buy/sell signals, and plots once at least one signal exists.
    """
    model.to(device)
    model.train()
    trade_history = deque(maxlen=100)

    # Load best checkpoint if available (mapped onto the training device).
    load_best_checkpoint(model, BEST_DIR, map_location=device)

    await data_manager._fetch_initial_candles()

    for epoch in range(1, num_epochs + 1):
        steps_done = 0
        # Continuously train on live data; bounded only if steps_per_epoch is set.
        while steps_per_epoch is None or steps_done < steps_per_epoch:
            await data_manager.fetch_and_process_ticks()
            candles, ticks = await data_manager.get_data()

            if len(candles) < data_manager.window_size:
                # Not enough data yet: wait and try again (no print, to
                # avoid flooding the console).
                await asyncio.sleep(1)
                continue

            (candle_features, tick_features,
             future_candle, future_volume, future_ticks) = preprocess_data(candles, ticks)

            # Skip if preprocessing fails (e.g., not enough data)
            if candle_features is None:
                await asyncio.sleep(1)
                continue

            # Convert to PyTorch tensors, add a batch dimension and move to
            # the correct device.
            # NOTE(review): dtype is whatever torch.tensor infers from the
            # preprocess_data output; confirm it matches the model's dtype.
            candle_features = _to_batch_tensor(candle_features, device)
            tick_features = _to_batch_tensor(tick_features, device)
            future_candle = _to_batch_tensor(future_candle, device)
            future_volume = _to_batch_tensor(future_volume, device)
            future_ticks = _to_batch_tensor(future_ticks, device)

            # Masks are sized from dimension 1 of the batched inputs
            # (presumably the sequence length - TODO confirm).
            future_candle_mask = create_mask(candle_features.size(1)).to(device)
            future_ticks_mask = create_mask(tick_features.size(1)).to(device)

            optimizer.zero_grad()
            future_candle_pred, future_volume_pred, future_ticks_pred = model(
                candle_features, tick_features, future_candle_mask, future_ticks_mask
            )

            # Calculate loss; squeeze(1) drops dimension 1 of each prediction
            # when it has size 1 before comparing with the targets.
            loss_candle = criterion_candle(future_candle_pred.squeeze(1), future_candle)
            loss_volume = criterion_volume(future_volume_pred.squeeze(1), future_volume)
            loss_ticks = criterion_ticks(future_ticks_pred.squeeze(1), future_ticks)

            # Combine losses (you can add weights to each loss component)
            total_loss = loss_candle + loss_volume + loss_ticks
            total_loss.backward()
            optimizer.step()
            steps_done += 1

            print(
                f"Epoch: {epoch}, Candle Loss: {loss_candle.item():.4f}, "
                f"Volume Loss: {loss_volume.item():.4f}, "
                f"Tick Loss: {loss_ticks.item():.4f}, Total: {total_loss.item():.4f}"
            )

            # Save a checkpoint after every training step.
            save_checkpoint(model, epoch, total_loss.item(), LAST_DIR, BEST_DIR)

            # --- Basic Trading Logic (Illustrative) ---
            # This is a very simplified example. In a real system, you would have
            # much more sophisticated entry/exit logic, risk management, etc.
            # NOTE(review): index 3 is taken to be the close price, and the
            # prediction is compared with the raw candle close - confirm both
            # are on the same scale.
            predicted_close = future_candle_pred[0, 0, 3].item()
            current_close = candles[-1]['close']

            if predicted_close > current_close * 1.005:
                # Example: Buy if predicted close is 0.5% higher
                trade_history.append({"type": "buy", "price": current_close, "time": time.time()})
                print(f"BUY signal at {current_close}")
            elif predicted_close < current_close * 0.995:
                # Example: Sell if predicted close is 0.5% lower
                trade_history.append({"type": "sell", "price": current_close, "time": time.time()})
                print(f"SELL signal at {current_close}")

            # Plot data only after the first trade.
            if trade_history:
                plot_live_data(candles, list(trade_history))

            await asyncio.sleep(1)  # Adjust sleep time as needed