diff --git a/crypto/gogo/.vscode/launch.json b/crypto/gogo/.vscode/launch.json new file mode 100644 index 0000000..cac00ca --- /dev/null +++ b/crypto/gogo/.vscode/launch.json @@ -0,0 +1,31 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Train (Live Data)", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/main.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": ["--mode", "train"] + }, + { + "name": "Backtest (Historical Data)", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/main.py", + "console": "integratedTerminal", + "justMyCode": false, + "args": ["--mode", "backtest"] + }, + { + "name": "Unit Tests", + "type": "python", + "request": "launch", + "program": "${workspaceFolder}/tests/test_data.py", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} diff --git a/crypto/gogo/.vscode/settings.json b/crypto/gogo/.vscode/settings.json new file mode 100644 index 0000000..e9e6a80 --- /dev/null +++ b/crypto/gogo/.vscode/settings.json @@ -0,0 +1,11 @@ +{ + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": false, + "python.testing.unittestEnabled": true +} \ No newline at end of file diff --git a/crypto/gogo/candles_cache.json b/crypto/gogo/candles_cache.json new file mode 100644 index 0000000..29b5f7f --- /dev/null +++ b/crypto/gogo/candles_cache.json @@ -0,0 +1 @@ +{"1m": [], "5m": [], "15m": [], "1h": [], "1d": []} \ No newline at end of file diff --git a/crypto/gogo/main.py b/crypto/gogo/main.py index 60e9124..7bfd97f 100644 --- a/crypto/gogo/main.py +++ b/crypto/gogo/main.py @@ -1,5 +1,3 @@ -# main.py - import asyncio import torch import torch.nn as nn @@ -17,8 +15,9 @@ from model.trading_model import TradingModel from training.rl_agent import ContinuousRLAgent, ReplayBuffer from training.train_historical import train_on_historical_data, load_best_checkpoint, save_candles_cache, CACHE_FILE, BEST_DIR from data.data_utils import get_aligned_candle_with_index, get_features_for_tf +import argparse -async def main(): +async def main_training(): symbol = 'BTC/USDT' data_manager = LiveDataManager(symbol) @@ -293,4 +292,11 @@ def plot_trade_history(candles, trade_history): plt.show() if __name__ == '__main__': - asyncio.run(main_backtest()) + parser = argparse.ArgumentParser(description='Trading Bot Modes') + parser.add_argument('--mode', type=str, default='backtest', choices=['train', 'backtest'], help='Choose mode: train or backtest') + args = parser.parse_args() + + if args.mode == 'train': + asyncio.run(main_training()) + elif args.mode == 'backtest': + asyncio.run(main_backtest()) diff --git a/crypto/gogo/model/transformer.py b/crypto/gogo/model/transformer.py index c6365b6..89532b8 100644 --- a/crypto/gogo/model/transformer.py +++ b/crypto/gogo/model/transformer.py @@ -219,4 +219,4 @@ if __name__ == '__main__': print("Future Candle Prediction Shape:", future_candle_pred.shape) # Expected: [batch_size, 1, 5] print("Future Volume Prediction Shape:", future_volume_pred.shape) # Expected: [batch_size, 1, 1] - print("Future Ticks Prediction Shape:", future_ticks_pred.shape) # Expected: [batch_size, 30, 2] \ No newline at end of file + print("Future Ticks Prediction Shape:", future_ticks_pred.shape) # Expected: [batch_size, 30, 2] diff --git a/crypto/gogo/readme.md b/crypto/gogo/readme.md index 65cbefe..e0f3e9e 100644 --- a/crypto/gogo/readme.md +++ b/crypto/gogo/readme.md @@ -1,4 +1,3 @@ - To run this code: Install Dependencies: pip install -r requirements.txt @@ -28,4 +27,4 @@ Overfitting: Monitor for overfitting (the model performing well on training data Memory usage: the code uses a deque to store the data. This prevents out of memory errors and keeps only the most recent N samples. -Learned indicators: This is a complex part. you can create a new NN, that will be trained to predict the next candle data based only on HLOCV. the weights of this NN can be used as new indicators, concatenated to the others. \ No newline at end of file +Learned indicators: This is a complex part. you can create a new NN, that will be trained to predict the next candle data based only on HLOCV. the weights of this NN can be used as new indicators, concatenated to the others. diff --git a/crypto/gogo/requirements.txt b/crypto/gogo/requirements.txt index 8b295a3..3a6e03d 100644 --- a/crypto/gogo/requirements.txt +++ b/crypto/gogo/requirements.txt @@ -3,4 +3,4 @@ ccxt python-dotenv torch numpy -matplotlib \ No newline at end of file +matplotlib diff --git a/crypto/gogo/tests/test_data.py b/crypto/gogo/tests/test_data.py new file mode 100644 index 0000000..6c174e5 --- /dev/null +++ b/crypto/gogo/tests/test_data.py @@ -0,0 +1,51 @@ +import asyncio +import unittest +from unittest.mock import patch, AsyncMock +from data.live_data import LiveDataManager +import ccxt.async_support as ccxt + +class TestLiveData(unittest.IsolatedAsyncioTestCase): + + async def asyncSetUp(self): + self.symbol = 'BTC/USDT' + self.data_manager = LiveDataManager(self.symbol) + self.exchange = self.data_manager.exchange + + async def asyncTearDown(self): + await self.data_manager.close() + + @patch('data.live_data.ccxt.mexc.fetch_trades', new_callable=AsyncMock) + async def test_fetch_and_process_ticks(self, mock_fetch_trades): + # Mock the exchange response + mock_fetch_trades.return_value = [ + {'timestamp': 1678886400000, 'symbol': 'BTC/USDT', 'price': 25000.0, 'quantity': 1.0}, + {'timestamp': 1678886460000, 'symbol': 'BTC/USDT', 'price': 25050.0, 'quantity': 0.5} + ] + + # Call the method + await self.data_manager.fetch_and_process_ticks() + + # Assert that the ticks are processed + candles, ticks = await self.data_manager.get_data() + self.assertEqual(len(ticks), 2) + self.assertEqual(len(candles), 1) + + @patch('data.live_data.ccxt.mexc.fetch_ohlcv', new_callable=AsyncMock) + async def test_fetch_initial_candles(self, mock_fetch_ohlcv): + # Mock the exchange response + mock_fetch_ohlcv.return_value = [ + [1678886400000, 25000.0, 25050.0, 24950.0, 25025.0, 100.0], + [1678886460000, 25025.0, 25100.0, 25000.0, 25075.0, 120.0] + ] + + # Call the method + await self.data_manager._fetch_initial_candles() + + # Assert that the candles are fetched and formatted + candles, _ = await self.data_manager.get_data() + self.assertEqual(len(candles), 2) + self.assertEqual(candles[0]['open'], 25000.0) + self.assertEqual(candles[0]['close'], 25025.0) + +if __name__ == '__main__': + unittest.main() diff --git a/crypto/gogo/training/train.py b/crypto/gogo/training/train.py index 408643e..255a12c 100644 --- a/crypto/gogo/training/train.py +++ b/crypto/gogo/training/train.py @@ -154,4 +154,4 @@ async def train(model, data_manager, optimizer, criterion_candle, criterion_volu # Plot data if len(trade_history)>0: # only after the first trade plot_live_data(candles, list(trade_history)) - await asyncio.sleep(1) # Adjust sleep time as needed \ No newline at end of file + await asyncio.sleep(1) # Adjust sleep time as needed diff --git a/crypto/gogo/visualization/plotting.py b/crypto/gogo/visualization/plotting.py index 5d4323f..0c97387 100644 --- a/crypto/gogo/visualization/plotting.py +++ b/crypto/gogo/visualization/plotting.py @@ -32,4 +32,4 @@ def plot_live_data(candles, trade_history): ax.grid(True) plt.tight_layout() - plt.show() \ No newline at end of file + plt.show()