init gogo

crypto/gogo/training/train.py | 157 lines (new file)

@@ -0,0 +1,157 @@
# training/train.py
import torch
import torch.nn as nn
import torch.optim as optim
from data.data_utils import preprocess_data, create_mask
from model.transformer import Transformer
from data.live_data import LiveDataManager
from visualization.plotting import plot_live_data
import asyncio
import time
import os
from datetime import datetime
from collections import deque

# --- Directories for saving models ---
LAST_DIR = os.path.join("models", "last")
BEST_DIR = os.path.join("models", "best")
os.makedirs(LAST_DIR, exist_ok=True)
os.makedirs(BEST_DIR, exist_ok=True)

# -------------------------------------
# Checkpoint Functions
# -------------------------------------
def maintain_checkpoint_directory(directory, max_files=10):
    # Keep at most `max_files` checkpoints, deleting the oldest ones first.
    files = os.listdir(directory)
    if len(files) > max_files:
        full_paths = [os.path.join(directory, f) for f in files]
        full_paths.sort(key=os.path.getmtime)
        for f in full_paths[: len(files) - max_files]:
            os.remove(f)

def get_best_models(directory):
    # Best checkpoints are named best_<loss>_epoch_<n>_<timestamp>.pt,
    # so parts[1] holds the recorded loss. Files that don't match are skipped.
    best_files = []
    for file in os.listdir(directory):
        parts = file.split("_")
        try:
            loss = float(parts[1])
            best_files.append((loss, file))
        except Exception:
            continue
    return best_files

def save_checkpoint(model, epoch, total_loss, last_dir=LAST_DIR, best_dir=BEST_DIR):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    last_filename = f"model_last_epoch_{epoch}_{timestamp}.pt"
    last_path = os.path.join(last_dir, last_filename)
    torch.save({
        "epoch": epoch,
        "total_loss": total_loss,
        "model_state_dict": model.state_dict()
    }, last_path)
    maintain_checkpoint_directory(last_dir, max_files=10)

    # Keep the 10 checkpoints with the lowest loss. When the directory is full,
    # replace the worst (highest-loss) checkpoint if the new loss beats it.
    best_models = get_best_models(best_dir)
    add_to_best = False
    if len(best_models) < 10:
        add_to_best = True
    else:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if total_loss < worst_loss:
            add_to_best = True
            os.remove(os.path.join(best_dir, worst_file))
    if add_to_best:
        best_filename = f"best_{total_loss:.4f}_epoch_{epoch}_{timestamp}.pt"
        best_path = os.path.join(best_dir, best_filename)
        torch.save({
            "epoch": epoch,
            "total_loss": total_loss,
            "model_state_dict": model.state_dict()
        }, best_path)
        maintain_checkpoint_directory(best_dir, max_files=10)
    print(f"Saved checkpoint for epoch {epoch} with loss {total_loss:.4f}")

def load_best_checkpoint(model, best_dir=BEST_DIR):
    best_models = get_best_models(best_dir)
    if not best_models:
        return None
    # The lowest recorded loss is the best model.
    best_loss, best_file = min(best_models, key=lambda x: x[0])
    path = os.path.join(best_dir, best_file)
    print(f"Loading best model from checkpoint: {best_file} with loss {best_loss:.4f}")
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint["model_state_dict"])
    return checkpoint

async def train(model, data_manager, optimizer, criterion_candle, criterion_volume, criterion_ticks, num_epochs=10, device='cpu'):
    model.to(device)
    model.train()
    trade_history = deque(maxlen=100)

    # Load best checkpoint if available.
    load_best_checkpoint(model, BEST_DIR)
    await data_manager._fetch_initial_candles()

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()
        total_loss = 0

        while True:  # Continuously train on live data; this inner loop runs until the process is stopped.
            await data_manager.fetch_and_process_ticks()
            candles, ticks = await data_manager.get_data()
            if len(candles) < data_manager.window_size:
                # print("Waiting for enough data...")  # kept quiet to avoid flooding the console
                await asyncio.sleep(1)  # wait and try again
                continue

            candle_features, tick_features, future_candle, future_volume, future_ticks = preprocess_data(candles, ticks)

            # Skip if preprocessing fails (e.g., not enough data)
            if candle_features is None:
                await asyncio.sleep(1)
                continue
            # Convert to PyTorch tensors, cast to float32 to match the model weights,
            # and move everything to the target device.
            candle_features = torch.tensor(candle_features, dtype=torch.float32).unsqueeze(0).to(device)  # add batch dimension
            tick_features = torch.tensor(tick_features, dtype=torch.float32).unsqueeze(0).to(device)
            future_candle = torch.tensor(future_candle, dtype=torch.float32).unsqueeze(0).to(device)
            future_volume = torch.tensor(future_volume, dtype=torch.float32).unsqueeze(0).to(device)
            future_ticks = torch.tensor(future_ticks, dtype=torch.float32).unsqueeze(0).to(device)

            future_candle_mask = create_mask(candle_features.size(1)).to(device)
            future_ticks_mask = create_mask(tick_features.size(1)).to(device)

            optimizer.zero_grad()
            future_candle_pred, future_volume_pred, future_ticks_pred = model(candle_features, tick_features, future_candle_mask, future_ticks_mask)

            # Calculate the loss for each prediction head.
            loss_candle = criterion_candle(future_candle_pred.squeeze(1), future_candle)
            loss_volume = criterion_volume(future_volume_pred.squeeze(1), future_volume)
            loss_ticks = criterion_ticks(future_ticks_pred.squeeze(1), future_ticks)

            # Combine losses (you can add weights to each loss component)
            total_loss = loss_candle + loss_volume + loss_ticks
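            # Illustrative weighted variant (the weights below are arbitrary placeholders,
            # to be tuned for the task at hand):
            # total_loss = 1.0 * loss_candle + 0.5 * loss_volume + 0.5 * loss_ticks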
            total_loss.backward()
            optimizer.step()

            print(f"Epoch: {epoch}, Candle Loss: {loss_candle.item():.4f}, Volume Loss: {loss_volume.item():.4f}, Tick Loss: {loss_ticks.item():.4f}, Total: {total_loss.item():.4f}")

            # Save checkpoint. `epoch % 1 == 0` is always true, so this currently
            # checkpoints on every iteration; raise the modulus to save less often.
            if epoch % 1 == 0:
                save_checkpoint(model, epoch, total_loss.item(), LAST_DIR, BEST_DIR)

            # --- Basic Trading Logic (Illustrative) ---
            # This is a very simplified example. In a real system, you would have
            # much more sophisticated entry/exit logic, risk management, etc.
            predicted_close = future_candle_pred[0, 0, 3].item()  # predicted close (assumes index 3 is the close feature)
            current_close = candles[-1]['close']

            if predicted_close > current_close * 1.005:  # Example: Buy if predicted close is 0.5% higher
                trade_history.append({"type": "buy", "price": current_close, "time": time.time()})
                print(f"BUY signal at {current_close}")
            elif predicted_close < current_close * 0.995:  # Example: Sell if predicted close is 0.5% lower
                trade_history.append({"type": "sell", "price": current_close, "time": time.time()})
                print(f"SELL signal at {current_close}")

            # Plot data
            if len(trade_history) > 0:  # only after the first trade
                plot_live_data(candles, list(trade_history))
            await asyncio.sleep(1)  # Adjust sleep time as needed
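

# --- Illustrative entry point ---
# A minimal sketch of how train() could be wired up. The constructor calls for
# Transformer and LiveDataManager are placeholders (their real signatures live in
# model/transformer.py and data/live_data.py), and Adam / MSELoss are just one
# reasonable choice for the optimizer and the three criteria.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer()             # placeholder: pass the actual model hyperparameters
    data_manager = LiveDataManager()  # placeholder: pass the actual exchange/symbol settings
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion_candle = nn.MSELoss()
    criterion_volume = nn.MSELoss()
    criterion_ticks = nn.MSELoss()
    asyncio.run(train(model, data_manager, optimizer,
                      criterion_candle, criterion_volume, criterion_ticks,
                      num_epochs=10, device=device))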