init gogo
This commit is contained in:
47
crypto/gogo/main.py
Normal file
47
crypto/gogo/main.py
Normal file
@ -0,0 +1,47 @@
|
||||
# main.py
|
||||
|
||||
import asyncio
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from data.live_data import LiveDataManager
|
||||
from model.transformer import Transformer
|
||||
from training.train import train
|
||||
from data.data_utils import preprocess_data # Import preprocess_data
|
||||
|
||||
async def main():
|
||||
symbol = 'BTC/USDT'
|
||||
data_manager = LiveDataManager(symbol)
|
||||
|
||||
# Model parameters (adjust for ~1B parameters)
|
||||
input_dim = 6 + len([5, 10, 20, 60, 120, 200]) # OHLCV + EMAs
|
||||
d_model = 512
|
||||
num_heads = 8
|
||||
num_layers = 6
|
||||
d_ff = 2048
|
||||
dropout = 0.1
|
||||
|
||||
model = Transformer(input_dim, d_model, num_heads, num_layers, d_ff, dropout)
|
||||
|
||||
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
|
||||
|
||||
# Define loss functions
|
||||
criterion_candle = nn.MSELoss()
|
||||
criterion_volume = nn.MSELoss() # Consider a different loss for volume if needed
|
||||
criterion_ticks = nn.MSELoss()
|
||||
|
||||
# Check for CUDA availability and set device
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device('cuda')
|
||||
print("Using CUDA")
|
||||
else:
|
||||
device = torch.device('cpu')
|
||||
print("Using CPU")
|
||||
try:
|
||||
await train(model, data_manager, optimizer, criterion_candle, criterion_volume, criterion_ticks, num_epochs=10, device=device)
|
||||
except KeyboardInterrupt:
|
||||
print("Training stopped manually.")
|
||||
finally:
|
||||
await data_manager.close()
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
Reference in New Issue
Block a user