# gogo2/crypto/gogo/main.py
# Dobromir Popov 5606ed3cab init gogo
# 2025-02-12 01:15:44 +02:00
#
# 47 lines
# 1.4 KiB
# Python

# main.py
import asyncio
import torch
import torch.nn as nn
import torch.optim as optim
from data.live_data import LiveDataManager
from model.transformer import Transformer
from training.train import train
from data.data_utils import preprocess_data # Import preprocess_data
async def main():
    """Build the model, optimizer and losses, then train on live market data.

    Creates a ``LiveDataManager`` for the ``BTC/USDT`` symbol, a
    ``Transformer`` with the hyperparameters below, an AdamW optimizer and
    three MSE criteria, selects CUDA when available (CPU otherwise), and
    awaits ``train`` for 10 epochs. The data manager is closed on every exit
    path.

    Raises:
        asyncio.CancelledError: Re-raised after the stop message is printed
            when the task is cancelled, so the event loop can complete its
            own shutdown handling.
    """
    symbol = 'BTC/USDT'
    data_manager = LiveDataManager(symbol)

    # Model hyperparameters.
    # NOTE(review): the original comment says "adjust for ~1B parameters";
    # whether these values reach that size depends on the project's
    # Transformer implementation -- confirm before relying on it.
    ema_periods = (5, 10, 20, 60, 120, 200)
    # 6 base features plus one feature per EMA period. The original comment
    # labels the base features "OHLCV"; presumably a sixth column (e.g. a
    # timestamp) is included -- TODO confirm against the data pipeline.
    input_dim = 6 + len(ema_periods)
    d_model = 512
    num_heads = 8
    num_layers = 6
    d_ff = 2048
    dropout = 0.1

    model = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, dropout)
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

    # One loss per prediction target, all passed through to train().
    criterion_candle = nn.MSELoss()
    criterion_volume = nn.MSELoss()  # Consider a different loss for volume if needed
    criterion_ticks = nn.MSELoss()

    # Prefer the GPU when one is available; the device is handed to train().
    if torch.cuda.is_available():
        device = torch.device('cuda')
        print("Using CUDA")
    else:
        device = torch.device('cpu')
        print("Using CPU")

    try:
        await train(
            model,
            data_manager,
            optimizer,
            criterion_candle,
            criterion_volume,
            criterion_ticks,
            num_epochs=10,
            device=device,
        )
    except KeyboardInterrupt:
        print("Training stopped manually.")
    except asyncio.CancelledError:
        # Under asyncio.run(), Ctrl+C commonly reaches this coroutine as a
        # task cancellation rather than as KeyboardInterrupt, so report the
        # manual stop here as well. Cancellation must keep propagating so the
        # runner can finish its interrupt handling.
        print("Training stopped manually.")
        raise
    finally:
        # Always release the live-data resources, whatever ended training.
        await data_manager.close()
if __name__ == "__main__":
    # Script entry point: create the training coroutine and run it to
    # completion on a fresh event loop. Nothing executes on plain import.
    training_job = main()
    asyncio.run(training_job)