suggestions
This commit is contained in:
parent
33a5588539
commit
c8b0f77d32
31
crypto/gogo/.vscode/launch.json
vendored
Normal file
31
crypto/gogo/.vscode/launch.json
vendored
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": [
|
||||||
|
{
|
||||||
|
"name": "Train (Live Data)",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/main.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
|
"args": ["--mode", "train"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Backtest (Historical Data)",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/main.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false,
|
||||||
|
"args": ["--mode", "backtest"]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Unit Tests",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "${workspaceFolder}/tests/test_data.py",
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
11
crypto/gogo/.vscode/settings.json
vendored
Normal file
11
crypto/gogo/.vscode/settings.json
vendored
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"python.testing.unittestArgs": [
|
||||||
|
"-v",
|
||||||
|
"-s",
|
||||||
|
"./tests",
|
||||||
|
"-p",
|
||||||
|
"test_*.py"
|
||||||
|
],
|
||||||
|
"python.testing.pytestEnabled": false,
|
||||||
|
"python.testing.unittestEnabled": true
|
||||||
|
}
|
1
crypto/gogo/candles_cache.json
Normal file
1
crypto/gogo/candles_cache.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"1m": [], "5m": [], "15m": [], "1h": [], "1d": []}
|
@ -1,5 +1,3 @@
|
|||||||
# main.py
|
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -17,8 +15,9 @@ from model.trading_model import TradingModel
|
|||||||
from training.rl_agent import ContinuousRLAgent, ReplayBuffer
|
from training.rl_agent import ContinuousRLAgent, ReplayBuffer
|
||||||
from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR
|
from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR
|
||||||
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf
|
from data.data_utils import get_aligned_candle_with_index, get_features_for_tf
|
||||||
|
import argparse
|
||||||
|
|
||||||
async def main():
|
async def main_training():
|
||||||
symbol = 'BTC/USDT'
|
symbol = 'BTC/USDT'
|
||||||
data_manager = LiveDataManager(symbol)
|
data_manager = LiveDataManager(symbol)
|
||||||
|
|
||||||
@ -293,4 +292,11 @@ def plot_trade_history(candles, trade_history):
|
|||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
asyncio.run(main_backtest())
|
parser = argparse.ArgumentParser(description='Trading Bot Modes')
|
||||||
|
parser.add_argument('--mode', type=str, default='backtest', choices=['train', 'backtest'], help='Choose mode: train or backtest')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.mode == 'train':
|
||||||
|
asyncio.run(main_training())
|
||||||
|
elif args.mode == 'backtest':
|
||||||
|
asyncio.run(main_backtest())
|
||||||
|
@ -219,4 +219,4 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
print("Future Candle Prediction Shape:", future_candle_pred.shape) # Expected: [batch_size, 1, 5]
|
print("Future Candle Prediction Shape:", future_candle_pred.shape) # Expected: [batch_size, 1, 5]
|
||||||
print("Future Volume Prediction Shape:", future_volume_pred.shape) # Expected: [batch_size, 1, 1]
|
print("Future Volume Prediction Shape:", future_volume_pred.shape) # Expected: [batch_size, 1, 1]
|
||||||
print("Future Ticks Prediction Shape:", future_ticks_pred.shape) # Expected: [batch_size, 30, 2]
|
print("Future Ticks Prediction Shape:", future_ticks_pred.shape) # Expected: [batch_size, 30, 2]
|
||||||
|
@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
To run this code:
|
To run this code:
|
||||||
|
|
||||||
Install Dependencies: pip install -r requirements.txt
|
Install Dependencies: pip install -r requirements.txt
|
||||||
@ -28,4 +27,4 @@ Overfitting: Monitor for overfitting (the model performing well on training data
|
|||||||
|
|
||||||
Memory usage: the code uses a deque to store the data. This prevents out of memory errors and keeps only the most recent N samples.
|
Memory usage: the code uses a deque to store the data. This prevents out of memory errors and keeps only the most recent N samples.
|
||||||
|
|
||||||
Learned indicators: This is a complex part. you can create a new NN, that will be trained to predict the next candle data based only on HLOCV. the weights of this NN can be used as new indicators, concatenated to the others.
|
Learned indicators: This is a complex part. you can create a new NN, that will be trained to predict the next candle data based only on HLOCV. the weights of this NN can be used as new indicators, concatenated to the others.
|
||||||
|
@ -3,4 +3,4 @@ ccxt
|
|||||||
python-dotenv
|
python-dotenv
|
||||||
torch
|
torch
|
||||||
numpy
|
numpy
|
||||||
matplotlib
|
matplotlib
|
||||||
|
51
crypto/gogo/tests/test_data.py
Normal file
51
crypto/gogo/tests/test_data.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import asyncio
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import patch, AsyncMock
|
||||||
|
from data.live_data import LiveDataManager
|
||||||
|
import ccxt.async_support as ccxt
|
||||||
|
|
||||||
|
class TestLiveData(unittest.IsolatedAsyncioTestCase):
|
||||||
|
|
||||||
|
async def asyncSetUp(self):
|
||||||
|
self.symbol = 'BTC/USDT'
|
||||||
|
self.data_manager = LiveDataManager(self.symbol)
|
||||||
|
self.exchange = self.data_manager.exchange
|
||||||
|
|
||||||
|
async def asyncTearDown(self):
|
||||||
|
await self.data_manager.close()
|
||||||
|
|
||||||
|
@patch('data.live_data.ccxt.mexc.fetch_trades', new_callable=AsyncMock)
|
||||||
|
async def test_fetch_and_process_ticks(self, mock_fetch_trades):
|
||||||
|
# Mock the exchange response
|
||||||
|
mock_fetch_trades.return_value = [
|
||||||
|
{'timestamp': 1678886400000, 'symbol': 'BTC/USDT', 'price': 25000.0, 'quantity': 1.0},
|
||||||
|
{'timestamp': 1678886460000, 'symbol': 'BTC/USDT', 'price': 25050.0, 'quantity': 0.5}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
await self.data_manager.fetch_and_process_ticks()
|
||||||
|
|
||||||
|
# Assert that the ticks are processed
|
||||||
|
candles, ticks = await self.data_manager.get_data()
|
||||||
|
self.assertEqual(len(ticks), 2)
|
||||||
|
self.assertEqual(len(candles), 1)
|
||||||
|
|
||||||
|
@patch('data.live_data.ccxt.mexc.fetch_ohlcv', new_callable=AsyncMock)
|
||||||
|
async def test_fetch_initial_candles(self, mock_fetch_ohlcv):
|
||||||
|
# Mock the exchange response
|
||||||
|
mock_fetch_ohlcv.return_value = [
|
||||||
|
[1678886400000, 25000.0, 25050.0, 24950.0, 25025.0, 100.0],
|
||||||
|
[1678886460000, 25025.0, 25100.0, 25000.0, 25075.0, 120.0]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the method
|
||||||
|
await self.data_manager._fetch_initial_candles()
|
||||||
|
|
||||||
|
# Assert that the candles are fetched and formatted
|
||||||
|
candles, _ = await self.data_manager.get_data()
|
||||||
|
self.assertEqual(len(candles), 2)
|
||||||
|
self.assertEqual(candles[0]['open'], 25000.0)
|
||||||
|
self.assertEqual(candles[0]['close'], 25025.0)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
@ -154,4 +154,4 @@ async def train(model, data_manager, optimizer, criterion_candle, criterion_volu
|
|||||||
# Plot data
|
# Plot data
|
||||||
if len(trade_history)>0: # only after the first trade
|
if len(trade_history)>0: # only after the first trade
|
||||||
plot_live_data(candles, list(trade_history))
|
plot_live_data(candles, list(trade_history))
|
||||||
await asyncio.sleep(1) # Adjust sleep time as needed
|
await asyncio.sleep(1) # Adjust sleep time as needed
|
||||||
|
@ -32,4 +32,4 @@ def plot_live_data(candles, trade_history):
|
|||||||
ax.grid(True)
|
ax.grid(True)
|
||||||
plt.tight_layout()
|
plt.tight_layout()
|
||||||
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user