suggestions

This commit is contained in:
Dobromir Popov 2025-02-12 01:38:05 +02:00
parent 33a5588539
commit c8b0f77d32
10 changed files with 109 additions and 10 deletions

31
crypto/gogo/.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,31 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Train (Live Data)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/main.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": ["--mode", "train"]
},
{
"name": "Backtest (Historical Data)",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/main.py",
"console": "integratedTerminal",
"justMyCode": false,
"args": ["--mode", "backtest"]
},
{
"name": "Unit Tests",
"type": "python",
"request": "launch",
"program": "${workspaceFolder}/tests/test_data.py",
"console": "integratedTerminal",
"justMyCode": false
}
]
}

11
crypto/gogo/.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,11 @@
{
"python.testing.unittestArgs": [
"-v",
"-s",
"./tests",
"-p",
"test_*.py"
],
"python.testing.pytestEnabled": false,
"python.testing.unittestEnabled": true
}

View File

@ -0,0 +1 @@
{"1m": [], "5m": [], "15m": [], "1h": [], "1d": []}

View File

@ -1,5 +1,3 @@
# main.py
import asyncio import asyncio
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -17,8 +15,9 @@ from model.trading_model import TradingModel
from training.rl_agent import ContinuousRLAgent, ReplayBuffer from training.rl_agent import ContinuousRLAgent, ReplayBuffer
from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf from data.data_utils import get_aligned_candle_with_index, get_features_for_tf
import argparse
async def main(): async def main_training():
symbol = 'BTC/USDT' symbol = 'BTC/USDT'
data_manager = LiveDataManager(symbol) data_manager = LiveDataManager(symbol)
@ -293,4 +292,11 @@ def plot_trade_history(candles, trade_history):
plt.show() plt.show()
if __name__ == '__main__': if __name__ == '__main__':
asyncio.run(main_backtest()) parser = argparse.ArgumentParser(description='Trading Bot Modes')
parser.add_argument('--mode', type=str, default='backtest', choices=['train', 'backtest'], help='Choose mode: train or backtest')
args = parser.parse_args()
if args.mode == 'train':
asyncio.run(main_training())
elif args.mode == 'backtest':
asyncio.run(main_backtest())

View File

@ -219,4 +219,4 @@ if __name__ == '__main__':
print("Future Candle Prediction Shape:", future_candle_pred.shape) # Expected: [batch_size, 1, 5] print("Future Candle Prediction Shape:", future_candle_pred.shape) # Expected: [batch_size, 1, 5]
print("Future Volume Prediction Shape:", future_volume_pred.shape) # Expected: [batch_size, 1, 1] print("Future Volume Prediction Shape:", future_volume_pred.shape) # Expected: [batch_size, 1, 1]
print("Future Ticks Prediction Shape:", future_ticks_pred.shape) # Expected: [batch_size, 30, 2] print("Future Ticks Prediction Shape:", future_ticks_pred.shape) # Expected: [batch_size, 30, 2]

View File

@ -1,4 +1,3 @@
To run this code: To run this code:
Install Dependencies: pip install -r requirements.txt Install Dependencies: pip install -r requirements.txt
@ -28,4 +27,4 @@ Overfitting: Monitor for overfitting (the model performing well on training data
Memory usage: the code uses a deque to store the data. This prevents out of memory errors and keeps only the most recent N samples. Memory usage: the code uses a deque to store the data. This prevents out of memory errors and keeps only the most recent N samples.
Learned indicators: This is a complex part. you can create a new NN, that will be trained to predict the next candle data based only on HLOCV. the weights of this NN can be used as new indicators, concatenated to the others. Learned indicators: This is a complex part. you can create a new NN, that will be trained to predict the next candle data based only on HLOCV. the weights of this NN can be used as new indicators, concatenated to the others.

View File

@ -3,4 +3,4 @@ ccxt
python-dotenv python-dotenv
torch torch
numpy numpy
matplotlib matplotlib

View File

@ -0,0 +1,51 @@
import asyncio
import unittest
from unittest.mock import patch, AsyncMock
from data.live_data import LiveDataManager
import ccxt.async_support as ccxt
class TestLiveData(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
self.symbol = 'BTC/USDT'
self.data_manager = LiveDataManager(self.symbol)
self.exchange = self.data_manager.exchange
async def asyncTearDown(self):
await self.data_manager.close()
@patch('data.live_data.ccxt.mexc.fetch_trades', new_callable=AsyncMock)
async def test_fetch_and_process_ticks(self, mock_fetch_trades):
# Mock the exchange response
mock_fetch_trades.return_value = [
{'timestamp': 1678886400000, 'symbol': 'BTC/USDT', 'price': 25000.0, 'quantity': 1.0},
{'timestamp': 1678886460000, 'symbol': 'BTC/USDT', 'price': 25050.0, 'quantity': 0.5}
]
# Call the method
await self.data_manager.fetch_and_process_ticks()
# Assert that the ticks are processed
candles, ticks = await self.data_manager.get_data()
self.assertEqual(len(ticks), 2)
self.assertEqual(len(candles), 1)
@patch('data.live_data.ccxt.mexc.fetch_ohlcv', new_callable=AsyncMock)
async def test_fetch_initial_candles(self, mock_fetch_ohlcv):
# Mock the exchange response
mock_fetch_ohlcv.return_value = [
[1678886400000, 25000.0, 25050.0, 24950.0, 25025.0, 100.0],
[1678886460000, 25025.0, 25100.0, 25000.0, 25075.0, 120.0]
]
# Call the method
await self.data_manager._fetch_initial_candles()
# Assert that the candles are fetched and formatted
candles, _ = await self.data_manager.get_data()
self.assertEqual(len(candles), 2)
self.assertEqual(candles[0]['open'], 25000.0)
self.assertEqual(candles[0]['close'], 25025.0)
if __name__ == '__main__':
unittest.main()

View File

@ -154,4 +154,4 @@ async def train(model, data_manager, optimizer, criterion_candle, criterion_volu
# Plot data # Plot data
if len(trade_history)>0: # only after the first trade if len(trade_history)>0: # only after the first trade
plot_live_data(candles, list(trade_history)) plot_live_data(candles, list(trade_history))
await asyncio.sleep(1) # Adjust sleep time as needed await asyncio.sleep(1) # Adjust sleep time as needed

View File

@ -32,4 +32,4 @@ def plot_live_data(candles, trade_history):
ax.grid(True) ax.grid(True)
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()