suggestions
This commit is contained in:
parent
33a5588539
commit
c8b0f77d32
31
crypto/gogo/.vscode/launch.json
vendored
Normal file
31
crypto/gogo/.vscode/launch.json
vendored
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Train (Live Data)",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/main.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
|
"args": ["--mode", "train"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Backtest (Historical Data)",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/main.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
|
"args": ["--mode", "backtest"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Unit Tests",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/tests/test_data.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
11
crypto/gogo/.vscode/settings.json
vendored
Normal file
11
crypto/gogo/.vscode/settings.json
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"python.testing.unittestArgs": [
|
||||||
|
"-v",
|
||||||
|
"-s",
|
||||||
|
"./tests",
|
||||||
|
"-p",
|
||||||
|
"test_*.py"
|
||||||
|
],
|
||||||
|
"python.testing.pytestEnabled": false,
|
||||||
|
"python.testing.unittestEnabled": true
|
||||||
|
}
|
1
crypto/gogo/candles_cache.json
Normal file
1
crypto/gogo/candles_cache.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"1m": [], "5m": [], "15m": [], "1h": [], "1d": []}
|
@ -1,5 +1,3 @@
|
|||||||
# main.py
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -17,8 +15,9 @@ from model.trading_model import TradingModel
|
|||||||
from training.rl_agent import ContinuousRLAgent, ReplayBuffer
|
from training.rl_agent import ContinuousRLAgent, ReplayBuffer
|
||||||
from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR
|
from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR
|
||||||
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf
|
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf
|
||||||
|
import argparse
|
||||||
|
|
||||||
async def main():
|
async def main_training():
|
||||||
symbol = 'BTC/USDT'
|
symbol = 'BTC/USDT'
|
||||||
data_manager = LiveDataManager(symbol)
|
data_manager = LiveDataManager(symbol)
|
||||||
|
|
||||||
@ -293,4 +292,11 @@ def plot_trade_history(candles, trade_history):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Trading Bot Modes')
|
||||||
|
parser.add_argument('--mode', type=str, default='backtest', choices=['train', 'backtest'], help='Choose mode: train or backtest')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.mode == 'train':
|
||||||
|
asyncio.run(main_training())
|
||||||
|
elif args.mode == 'backtest':
|
||||||
asyncio.run(main_backtest())
|
asyncio.run(main_backtest())
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
To run this code:
|
To run this code:
|
||||||
|
|
||||||
Install Dependencies: pip install -r requirements.txt
|
Install Dependencies: pip install -r requirements.txt
|
||||||
|
51
crypto/gogo/tests/test_data.py
Normal file
51
crypto/gogo/tests/test_data.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch, AsyncMock
|
||||||
|
from data.live_data import LiveDataManager
|
||||||
|
import ccxt.async_support as ccxt
|
||||||
|
|
||||||
|
class TestLiveData(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.symbol = 'BTC/USDT'
|
||||||
|
self.data_manager = LiveDataManager(self.symbol)
|
||||||
|
self.exchange = self.data_manager.exchange
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.data_manager.close()
|
||||||
|
|
||||||
|
@patch('data.live_data.ccxt.mexc.fetch_trades', new_callable=AsyncMock)
|
||||||
|
async def test_fetch_and_process_ticks(self, mock_fetch_trades):
|
||||||
|
# Mock the exchange response
|
||||||
|
mock_fetch_trades.return_value = [
|
||||||
|
{'timestamp': 1678886400000, 'symbol': 'BTC/USDT', 'price': 25000.0, 'quantity': 1.0},
|
||||||
|
{'timestamp': 1678886460000, 'symbol': 'BTC/USDT', 'price': 25050.0, 'quantity': 0.5}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
await self.data_manager.fetch_and_process_ticks()
|
||||||
|
|
||||||
|
# Assert that the ticks are processed
|
||||||
|
candles, ticks = await self.data_manager.get_data()
|
||||||
|
self.assertEqual(len(ticks), 2)
|
||||||
|
self.assertEqual(len(candles), 1)
|
||||||
|
|
||||||
|
@patch('data.live_data.ccxt.mexc.fetch_ohlcv', new_callable=AsyncMock)
|
||||||
|
async def test_fetch_initial_candles(self, mock_fetch_ohlcv):
|
||||||
|
# Mock the exchange response
|
||||||
|
mock_fetch_ohlcv.return_value = [
|
||||||
|
[1678886400000, 25000.0, 25050.0, 24950.0, 25025.0, 100.0],
|
||||||
|
[1678886460000, 25025.0, 25100.0, 25000.0, 25075.0, 120.0]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
await self.data_manager._fetch_initial_candles()
|
||||||
|
|
||||||
|
# Assert that the candles are fetched and formatted
|
||||||
|
candles, _ = await self.data_manager.get_data()
|
||||||
|
self.assertEqual(len(candles), 2)
|
||||||
|
self.assertEqual(candles[0]['open'], 25000.0)
|
||||||
|
self.assertEqual(candles[0]['close'], 25025.0)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
Loading…
x
Reference in New Issue
Block a user