Compare commits


13 Commits

Author SHA1 Message Date
Dobromir Popov 2c03675f3c fixed training again 2025-03-18 02:04:58 +02:00
Dobromir Popov bdf6afc6ad better saves 2025-03-17 23:36:44 +02:00
Dobromir Popov 2e7a242ac7 backtseting support 2025-03-17 19:21:43 +02:00
Dobromir Popov 5e9e6360af wip - better training loop; realtime scaffold 2025-03-17 19:17:56 +02:00
Dobromir Popov 4de6352468 added CNN module 2025-03-17 16:18:11 +02:00
Dobromir Popov 58c3a81e6d fixed training again 2025-03-17 04:08:26 +02:00
Dobromir Popov e87207c1fa fix training & demo mode 2025-03-17 03:59:50 +02:00
Dobromir Popov c63b1a2daf added modes scripts 2025-03-17 03:21:51 +02:00
Dobromir Popov 469d681c4b added plots; fixes 2025-03-17 02:46:33 +02:00
Dobromir Popov d9d0ba9da8 added live trade actions (wip) and candle chart 2025-03-17 02:35:15 +02:00
Dobromir Popov 991cf57274 fix model loading in live mode 2025-03-17 02:17:43 +02:00
Dobromir Popov 485c61cf8c show chart 2025-03-17 02:02:05 +02:00
Dobromir Popov 0b6bb000d2 added new params to the NN input 2025-03-17 01:56:07 +02:00
38 changed files with 60660 additions and 194458 deletions

.gitignore

@@ -32,27 +32,6 @@ crypto/sol/.vs/*
 crypto/brian/models/best/*
 crypto/brian/models/last/*
 crypto/brian/live_chart.html
+crypto/gogo2/models/*
 crypto/gogo2/trading_bot.log
 *.log
-crypto/gogo2/checkpoints/trading_agent_episode_*.pt
-*trading_agent_continuous_*.pt
-*trading_agent_episode_*.pt
-crypto/gogo2/models/trading_agent_continuous_150.pt
-crypto/gogo2/checkpoints/trading_agent_episode_0.pt
-crypto/gogo2/checkpoints/trading_agent_episode_10.pt
-crypto/gogo2/checkpoints/trading_agent_episode_20.pt
-crypto/gogo2/checkpoints/trading_agent_episode_40.pt
-crypto/gogo2/models/trading_agent_best_pnl.pt
-crypto/gogo2/models/trading_agent_best_reward.pt
-crypto/gogo2/models/trading_agent_best_winrate.pt
-crypto/gogo2/models/trading_agent_continuous_0.pt
-crypto/gogo2/models/trading_agent_continuous_50.pt
-crypto/gogo2/models/trading_agent_continuous_100.pt
-crypto/gogo2/models/trading_agent_continuous_150.pt
-crypto/gogo2/models/trading_agent_emergency.pt
-crypto/gogo2/models/trading_agent_episode_0.pt
-crypto/gogo2/models/trading_agent_episode_10.pt
-crypto/gogo2/models/trading_agent_episode_20.pt
-crypto/gogo2/models/trading_agent_episode_30.pt
-crypto/gogo2/models/trading_agent_final.pt


@@ -1 +0,0 @@
*.pt filter=lfs diff=lfs merge=lfs -text


@@ -1 +0,0 @@
*.pt


@@ -24,27 +24,53 @@
             "type": "python",
             "request": "launch",
             "program": "main.py",
-            "args": ["--mode", "live", "--demo"],
+            "args": [
+                "--mode", "live",
+                "--demo", "true",
+                "--symbol", "ETH/USDT",
+                "--timeframe", "1m"
+            ],
             "console": "integratedTerminal",
-            "justMyCode": true
+            "justMyCode": true,
+            "env": {
+                "PYTHONUNBUFFERED": "1"
+            }
         },
         {
             "name": "Live Trading (Real)",
             "type": "python",
             "request": "launch",
             "program": "main.py",
-            "args": ["--mode", "live"],
+            "args": [
+                "--mode", "live",
+                "--demo", "false",
+                "--symbol", "ETH/USDT",
+                "--timeframe", "1m",
+                "--leverage", "50"
+            ],
             "console": "integratedTerminal",
-            "justMyCode": true
+            "justMyCode": true,
+            "env": {
+                "PYTHONUNBUFFERED": "1"
+            }
         },
         {
-            "name": "Continuous Training",
+            "name": "Live Trading (BTC Futures)",
             "type": "python",
             "request": "launch",
             "program": "main.py",
-            "args": ["--mode", "continuous", "--refresh-data"],
+            "args": [
+                "--mode", "live",
+                "--demo", "false",
+                "--symbol", "BTC/USDT",
+                "--timeframe", "5m",
+                "--leverage", "20"
+            ],
             "console": "integratedTerminal",
-            "justMyCode": true
+            "justMyCode": true,
+            "env": {
+                "PYTHONUNBUFFERED": "1"
+            }
         }
     ]
 }


@@ -0,0 +1,74 @@
# Model Saving Fix
## Issue
During training sessions, PyTorch model saving operations sometimes fail with errors like:
```
RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 18278784 vs 18278680
```
or
```
RuntimeError: [enforce fail at inline_container.cc:820] . PytorchStreamWriter failed writing file data/75: file write failed
```
These errors occur in the PyTorch serialization mechanism when saving models using `torch.save()`.
## Solution
We've implemented a robust model saving approach that uses multiple fallback methods if the primary save operation fails:
1. **Attempt 1**: Save to a backup file first, then copy to the target path.
2. **Attempt 2**: Use an older pickle protocol (pickle protocol 2) which can be more compatible.
3. **Attempt 3**: Save without the optimizer state, which can reduce file size and avoid serialization issues.
4. **Attempt 4**: Use TorchScript's `torch.jit.save()` instead of `torch.save()`, which uses a different serialization mechanism.
## Implementation
The solution is implemented in two parts:
1. A `robust_save` function that tries multiple saving approaches with fallbacks.
2. A monkey patch that replaces the Agent's `save` method with our robust version.
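
A minimal sketch of that monkey patch, mirroring the version defined in `live_training.py` (it assumes `Agent` is importable from `main` and `robust_save` from `live_training`):

```python
from main import Agent
from live_training import robust_save

def monkey_patch_agent_save():
    """Swap Agent.save for the robust fallback version."""
    original_save = Agent.save  # keep a reference so the patch can be reverted

    def patched_save(self, path):
        return robust_save(self, path)

    Agent.save = patched_save
    return original_save  # restore later with: Agent.save = original_save
```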
### Example Usage
```python
# Import the robust_save function
from live_training import robust_save
# Save a model with fallbacks
success = robust_save(agent, "models/my_model.pt")
if success:
print("Model saved successfully!")
else:
print("All save attempts failed")
```
## Testing
We've created a test script `test_save.py` that demonstrates the robust saving approach and verifies that it works correctly.
To run the test:
```bash
python test_save.py
```
This script creates a simple model, attempts to save it using both the standard and robust methods, and reports on the results.
## Future Improvements
Possible future improvements to the model saving mechanism:
1. Additional fallback methods like serializing individual neural network layers.
2. Automatic retry mechanism with exponential backoff (see the sketch after this list).
3. Asynchronous saving to avoid blocking the training loop.
4. Checksumming saved models to verify integrity.
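
As an illustration of item 2 above, a retry wrapper with exponential backoff around `robust_save` might look like the following sketch (the `save_with_retry` helper, the attempt count, and the delay values are hypothetical, not part of the current code):

```python
import time
import logging

from live_training import robust_save

logger = logging.getLogger(__name__)

def save_with_retry(agent, path, max_attempts=3, base_delay=1.0):
    """Call robust_save repeatedly, backing off exponentially between failures."""
    for attempt in range(1, max_attempts + 1):
        if robust_save(agent, path):
            return True
        if attempt < max_attempts:
            delay = base_delay * (2 ** (attempt - 1))  # 1s, 2s, 4s, ...
            logger.warning(f"Save attempt {attempt} failed, retrying in {delay:.0f}s")
            time.sleep(delay)
    return False
```

Calls like `robust_save(agent, save_path)` in the training loop could then go through this wrapper instead.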
## Related Issues
For more information on similar issues with PyTorch model saving, see:
- https://github.com/pytorch/pytorch/issues/27736
- https://github.com/pytorch/pytorch/issues/24045


@@ -1,9 +1,17 @@
+https://github.com/mexcdevelop/mexc-api-sdk/blob/main/README.md#test-new-order
+python mexc_tick_visualizer.py --symbol BTC/USDT --interval 1.0 --candle 60
 ensure we use GPU if available to train faster. during training we need to have RL loop that looks at streaming data, and retrospective backtesting/training on predictions. sincr the start of the traing we're only loosing. implement robust penalty and analysis when closing a loosing trade and improve the reward function.
+add 1h and 1d OHLCV data to let the model have the price action context
 2025-03-10 12:11:28,651 - INFO - Initialized environment with 500 candles
 C:\Users\popov\miniforge3\Lib\site-packages\torch\nn\modules\transformer.py:385: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.batch_first was not True(use batch_first for better inference performance)
@@ -16,9 +24,4 @@ C:\Users\popov\miniforge3\Lib\site-packages\torch\amp\grad_scaler.py:132: UserWa
 2025-03-10 12:11:30,927 - INFO - Starting training on device: cpu
 2025-03-10 12:11:30,928 - ERROR - Training failed: 'TradingEnvironment' object has no attribute 'initialize_price_predictor'
 2025-03-10 12:11:30,928 - INFO - Exchange connection closed
 Backend tkagg is interactive backend. Turning interactive mode on.
-2025-03-10 12:35:14,489 - INFO - Episode 34: Reward=232.41, Balance=$98.47, Win Rate=70.6%, Trades=17, Episode PnL=$-1.33, Total PnL=$-559.78, Max Drawdown=7.0%, Pred Accuracy=99.9%


@@ -0,0 +1,166 @@
import os
import sys
import logging
import importlib
import asyncio
from dotenv import load_dotenv
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()]
)
logger = logging.getLogger("check_live_trading")
def check_dependencies():
"""Check if all required dependencies are installed"""
required_packages = [
"numpy", "pandas", "matplotlib", "mplfinance", "torch",
"dotenv", "ccxt", "websockets", "tensorboard",
"sklearn", "PIL", "asyncio"
]
missing_packages = []
for package in required_packages:
try:
if package == "dotenv":
importlib.import_module("dotenv")
elif package == "PIL":
importlib.import_module("PIL")
else:
importlib.import_module(package)
logger.info(f"{package} is installed")
except ImportError:
missing_packages.append(package)
logger.error(f"{package} is NOT installed")
if missing_packages:
logger.error(f"Missing packages: {', '.join(missing_packages)}")
logger.info("Install missing packages with: pip install -r requirements.txt")
return False
return True
def check_api_keys():
"""Check if API keys are configured"""
load_dotenv()
api_key = os.getenv('MEXC_API_KEY')
secret_key = os.getenv('MEXC_SECRET_KEY')
if not api_key or api_key == "your_api_key_here" or not secret_key or secret_key == "your_secret_key_here":
logger.error("❌ API keys are not properly configured in .env file")
logger.info("Please update your .env file with valid MEXC API keys")
return False
logger.info("✅ API keys are configured")
return True
def check_model_files():
"""Check if trained model files exist"""
model_files = [
"models/trading_agent_best_pnl.pt",
"models/trading_agent_best_reward.pt",
"models/trading_agent_final.pt"
]
missing_models = []
for model_file in model_files:
if os.path.exists(model_file):
logger.info(f"✅ Model file exists: {model_file}")
else:
missing_models.append(model_file)
logger.error(f"❌ Model file missing: {model_file}")
if missing_models:
logger.warning("Some model files are missing. You need to train the model first.")
return False
return True
async def check_exchange_connection():
"""Test connection to MEXC exchange"""
try:
import ccxt
# Load API keys
load_dotenv()
api_key = os.getenv('MEXC_API_KEY')
secret_key = os.getenv('MEXC_SECRET_KEY')
if api_key == "your_api_key_here" or secret_key == "your_secret_key_here":
logger.warning("⚠️ Using placeholder API keys, skipping exchange connection test")
return False
# Initialize exchange
exchange = ccxt.mexc({
'apiKey': api_key,
'secret': secret_key,
'enableRateLimit': True
})
# Test connection by fetching markets
markets = exchange.fetch_markets()
logger.info(f"✅ Successfully connected to MEXC exchange")
logger.info(f"✅ Found {len(markets)} markets")
return True
except Exception as e:
logger.error(f"❌ Failed to connect to MEXC exchange: {str(e)}")
return False
def check_directories():
"""Check if required directories exist"""
required_dirs = ["models", "runs", "trade_logs"]
for directory in required_dirs:
if not os.path.exists(directory):
logger.info(f"Creating directory: {directory}")
os.makedirs(directory, exist_ok=True)
logger.info("✅ All required directories exist")
return True
async def main():
"""Run all checks"""
logger.info("Running pre-flight checks for live trading...")
checks = [
("Dependencies", check_dependencies()),
("API Keys", check_api_keys()),
("Model Files", check_model_files()),
("Directories", check_directories()),
("Exchange Connection", await check_exchange_connection())
]
# Count failed checks
failed_checks = sum(1 for _, result in checks if not result)
# Print summary
logger.info("\n" + "="*50)
logger.info("LIVE TRADING PRE-FLIGHT CHECK SUMMARY")
logger.info("="*50)
for check_name, result in checks:
status = "✅ PASS" if result else "❌ FAIL"
logger.info(f"{check_name}: {status}")
logger.info("="*50)
if failed_checks == 0:
logger.info("🚀 All checks passed! You're ready for live trading.")
logger.info("\nRun live trading with:")
logger.info("python main.py --mode live --demo true --symbol ETH/USDT --timeframe 1m")
logger.info("\nFor real trading (after updating API keys):")
logger.info("python main.py --mode live --demo false --symbol ETH/USDT --timeframe 1m --leverage 50")
return 0
else:
logger.error(f"{failed_checks} check(s) failed. Please fix the issues before running live trading.")
return 1
if __name__ == "__main__":
exit_code = asyncio.run(main())
sys.exit(exit_code)


@@ -1 +0,0 @@
{"best_reward": 202.7441047517104, "best_pnl": 9.268344827764809, "best_win_rate": 73.33333333333333, "last_episode": 30, "timestamp": "2025-03-10T17:57:19.913481"}


@@ -1,4 +0,0 @@
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda if torch.cuda.is_available() else 'Not available'}")


@@ -0,0 +1,593 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import argparse
import os
import datetime
import traceback
import numpy as np
import torch
import gc
from functools import partial
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter
# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
try:
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
except Exception as e:
print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
# Setup logging function
def setup_logging():
"""Setup logging configuration for the application"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("live_training.log"),
logging.StreamHandler(sys.stdout) # Added stdout handler for immediate feedback
]
)
# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
# Implement a robust save function to handle PyTorch serialization errors
def robust_save(model, path):
"""
Robust model saving with multiple fallback approaches
Args:
model: The Agent model to save
path: Path to save the model
Returns:
bool: True if successful, False otherwise
"""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
# Backup path in case the main save fails
backup_path = f"{path}.backup"
# Clean up GPU memory before saving
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Attempt 1: Try with default settings in a separate file first
try:
logger.info(f"Saving model to {backup_path} (attempt 1)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, backup_path)
logger.info(f"Successfully saved to {backup_path}")
# If backup worked, copy to the actual path
if os.path.exists(backup_path):
import shutil
shutil.copy(backup_path, path)
logger.info(f"Copied backup to {path}")
return True
except Exception as e:
logger.warning(f"First save attempt failed: {e}")
# Attempt 2: Try with pickle protocol 2 (more compatible)
try:
logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path, pickle_protocol=2)
logger.info(f"Successfully saved to {path} with pickle_protocol=2")
return True
except Exception as e:
logger.warning(f"Second save attempt failed: {e}")
# Attempt 3: Try without optimizer state (which can be large and cause issues)
try:
logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path)
logger.info(f"Successfully saved to {path} without optimizer state")
return True
except Exception as e:
logger.warning(f"Third save attempt failed: {e}")
# Attempt 4: Try with torch.jit.save instead
try:
logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
# Save policy network using jit
scripted_policy = torch.jit.script(model.policy_net)
torch.jit.save(scripted_policy, f"{path}.policy.jit")
# Save target network using jit
scripted_target = torch.jit.script(model.target_net)
torch.jit.save(scripted_target, f"{path}.target.jit")
# Save epsilon value separately
with open(f"{path}.epsilon.txt", "w") as f:
f.write(str(model.epsilon))
logger.info(f"Successfully saved model components with jit.save")
return True
except Exception as e:
logger.error(f"All save attempts failed: {e}")
return False
# Implement timeout wrapper for exchange operations
async def with_timeout(coroutine, timeout=30, default=None):
"""
Execute a coroutine with a timeout
Args:
coroutine: The coroutine to execute
timeout: Timeout in seconds
default: Default value to return on timeout
Returns:
The result of the coroutine or default value on timeout
"""
try:
return await asyncio.wait_for(coroutine, timeout=timeout)
except asyncio.TimeoutError:
logger.warning(f"Operation timed out after {timeout} seconds")
return default
except Exception as e:
logger.error(f"Operation failed: {e}")
return default
# Implement fetch_and_update_data function
async def fetch_and_update_data(exchange, env, symbol, timeframe):
"""
Fetch new candle data and update the environment
Args:
exchange: CCXT exchange instance
env: Trading environment instance
symbol: Trading pair symbol
timeframe: Timeframe for the candles
"""
logger.info(f"Fetching new data for {symbol} on {timeframe} timeframe")
try:
# Default to 100 candles if not specified
limit = 1000
# Fetch OHLCV data with timeout
candles = await with_timeout(
exchange.fetch_ohlcv(symbol, timeframe, limit=limit),
timeout=30,
default=[]
)
if not candles or len(candles) == 0:
logger.warning(f"No candles returned for {symbol} on {timeframe}")
return False
logger.info(f"Successfully fetched {len(candles)} candles")
# Convert to format expected by environment
formatted_candles = []
for candle in candles:
timestamp, open_price, high, low, close, volume = candle
formatted_candles.append({
'timestamp': timestamp,
'open': open_price,
'high': high,
'low': low,
'close': close,
'volume': volume
})
# Update environment data
env.data = formatted_candles
if hasattr(env, '_initialize_features'):
env._initialize_features()
logger.info(f"Updated environment with {len(formatted_candles)} candles")
# Print latest candle info
if formatted_candles:
latest = formatted_candles[-1]
dt = datetime.datetime.fromtimestamp(latest['timestamp']/1000).strftime('%Y-%m-%d %H:%M:%S')
logger.info(f"Latest candle: Time={dt}, Open={latest['open']}, High={latest['high']}, Low={latest['low']}, Close={latest['close']}, Volume={latest['volume']}")
return True
except Exception as e:
logger.error(f"Error fetching candle data: {e}")
logger.error(traceback.format_exc())
return False
# Implement memory management function
def manage_memory():
"""
Clean up memory to avoid memory leaks during long running sessions
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
logger.debug("Memory cleaned")
async def live_training(
symbol="ETH/USDT",
timeframe="1m",
model_path="models/trading_agent_best_pnl.pt",
save_path="models/trading_agent_live_trained.pt",
initial_balance=1000,
update_interval=60,
training_iterations=100,
learning_rate=0.0001,
batch_size=64,
gamma=0.99,
window_size=30,
max_episodes=0, # 0 means unlimited
retry_delay=5, # Seconds to wait before retrying after an error
max_retries=3, # Maximum number of retries for operations
):
"""
Live training function that uses real market data to improve the model without executing real trades.
Args:
symbol: Trading pair symbol
timeframe: Timeframe for training
model_path: Path to the initial model to load
save_path: Path to save the improved model
initial_balance: Initial balance for simulation
update_interval: Interval to update data in seconds
training_iterations: Number of training iterations per data update
learning_rate: Learning rate for training
batch_size: Batch size for training
gamma: Discount factor for training
window_size: Window size for the environment
max_episodes: Maximum number of episodes (0 for unlimited)
retry_delay: Seconds to wait before retrying after an error
max_retries: Maximum number of retries for operations
"""
logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")
# Initialize exchange (without sandbox mode)
exchange = None
# Retry loop for exchange initialization
for retry in range(max_retries):
try:
exchange = await initialize_exchange()
logger.info(f"Exchange initialized: {exchange.id}")
break
except Exception as e:
logger.error(f"Error initializing exchange (attempt {retry+1}/{max_retries}): {e}")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
else:
logger.error("Max retries reached. Could not initialize exchange.")
return
try:
# Initialize environment
env = TradingEnvironment(
initial_balance=initial_balance,
window_size=window_size,
symbol=symbol,
timeframe=timeframe,
)
# Fetch initial data (with retries)
logger.info(f"Fetching initial data for {symbol}")
success = False
for retry in range(max_retries):
success = await fetch_and_update_data(exchange, env, symbol, timeframe)
if success:
break
logger.warning(f"Failed to fetch initial data (attempt {retry+1}/{max_retries})")
if retry < max_retries - 1:
logger.info(f"Retrying in {retry_delay} seconds...")
await asyncio.sleep(retry_delay)
if not success:
logger.error("Failed to fetch initial data after multiple attempts, exiting")
return
# Initialize agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384)
# Load model if provided
if os.path.exists(model_path):
try:
agent.load(model_path)
logger.info(f"Model loaded successfully from {model_path}")
except Exception as e:
logger.warning(f"Error loading model: {e}")
logger.info("Starting with a new model")
else:
logger.warning(f"Model file {model_path} not found. Starting with a new model.")
# Initialize TensorBoard writer
run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter(log_dir=f"runs/live_training_{run_id}")
agent.writer = writer
# Initialize training statistics
total_rewards = 0
episode_count = 0
best_reward = float('-inf')
best_pnl = float('-inf')
# Start live training loop
logger.info(f"Starting live training loop")
step_counter = 0
last_update_time = datetime.datetime.now()
# Track consecutive errors to enable circuit breaker
consecutive_errors = 0
max_consecutive_errors = 5
while True:
# Check if we've reached the maximum number of episodes
if max_episodes > 0 and episode_count >= max_episodes:
logger.info(f"Reached maximum episodes ({max_episodes}), stopping")
break
# Check if it's time to update data
current_time = datetime.datetime.now()
time_diff = (current_time - last_update_time).total_seconds()
if time_diff >= update_interval:
logger.info(f"Updating market data after {time_diff:.1f} seconds")
success = await fetch_and_update_data(exchange, env, symbol, timeframe)
if not success:
logger.warning("Failed to update data, will try again later")
# Wait a bit before trying again
await asyncio.sleep(retry_delay)
continue
last_update_time = current_time
# Clean up memory before running an episode
manage_memory()
# Run training iterations on the updated data
episode_reward = 0
env.reset()
done = False
# Run one simulated episode with the current data
steps_in_episode = 0
max_steps = len(env.data) - env.window_size - 1
logger.info(f"Starting episode {episode_count + 1} with {max_steps} steps")
while not done and steps_in_episode < max_steps:
try:
state = env.get_state()
action = agent.select_action(state, training=True)
try:
next_state, reward, done, info = env.step(action)
except ValueError as e:
logger.error(f"Error during env.step: {e}")
# If we get a ValueError, it might be because step is returning 3 values instead of 4
# Let's try to handle this case
if "too many values to unpack" in str(e):
logger.info("Trying alternative step format")
result = env.step(action)
if len(result) == 3:
next_state, reward, done = result
info = {}
else:
raise
else:
raise
# Save experience in replay memory
agent.memory.push(state, action, reward, next_state, done)
# Move to the next state
state = next_state
episode_reward += reward
step_counter += 1
steps_in_episode += 1
# Log action and results every 50 steps
if steps_in_episode % 50 == 0:
logger.info(f"Step {steps_in_episode}/{max_steps} | Action: {action} | Reward: {reward:.2f} | Balance: ${env.balance:.2f}")
# Train the agent on a batch of experiences
if len(agent.memory) > batch_size:
try:
agent.learn()
# Additional training iterations
if steps_in_episode % 10 == 0 and training_iterations > 1:
for _ in range(training_iterations - 1):
agent.learn()
# Reset consecutive errors counter on successful learning
consecutive_errors = 0
except Exception as e:
logger.error(f"Error during learning: {e}")
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
if done:
logger.info(f"Episode done after {steps_in_episode} steps")
break
except Exception as e:
logger.error(f"Error during episode step: {e}")
logger.error(traceback.format_exc())
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
logger.warning(f"Circuit breaker triggered after {max_consecutive_errors} consecutive errors")
break
# Update training statistics
episode_count += 1
total_rewards += episode_reward
avg_reward = total_rewards / episode_count
# Track metrics
writer.add_scalar('LiveTraining/Reward', episode_reward, episode_count)
writer.add_scalar('LiveTraining/AvgReward', avg_reward, episode_count)
writer.add_scalar('LiveTraining/Balance', env.balance, episode_count)
writer.add_scalar('LiveTraining/PnL', env.total_pnl, episode_count)
# Report progress
logger.info(f"""
Episode: {episode_count}
Reward: {episode_reward:.2f}
Avg Reward: {avg_reward:.2f}
Balance: ${env.balance:.2f}
PnL: ${env.total_pnl:.2f}
Memory Size: {len(agent.memory)}
Total Steps: {step_counter}
""")
# Save the model if it's the best so far (by reward or PnL)
if episode_reward > best_reward:
best_reward = episode_reward
reward_model_path = f"models/trading_agent_best_reward_{run_id}.pt"
if robust_save(agent, reward_model_path):
logger.info(f"New best reward model saved: {episode_reward:.2f} to {reward_model_path}")
else:
logger.error(f"Failed to save best reward model")
if env.total_pnl > best_pnl:
best_pnl = env.total_pnl
pnl_model_path = f"models/trading_agent_best_pnl_{run_id}.pt"
if robust_save(agent, pnl_model_path):
logger.info(f"New best PnL model saved: ${env.total_pnl:.2f} to {pnl_model_path}")
else:
logger.error(f"Failed to save best PnL model")
# Regularly save the model
if episode_count % 5 == 0:
if robust_save(agent, save_path):
logger.info(f"Model checkpoint saved to {save_path}")
else:
logger.error(f"Failed to save checkpoint")
# Update target network periodically
if episode_count % 5 == 0:
try:
agent.update_target_network()
logger.info("Target network updated")
except Exception as e:
logger.error(f"Error updating target network: {e}")
# Sleep to avoid excessive API calls
await asyncio.sleep(1)
except asyncio.CancelledError:
logger.info("Live training cancelled")
except KeyboardInterrupt:
logger.info("Live training stopped by user")
except Exception as e:
logger.error(f"Error in live training: {e}")
logger.error(traceback.format_exc())
finally:
# Save final model
if 'agent' in locals():
if robust_save(agent, save_path):
logger.info(f"Final model saved to {save_path}")
else:
logger.error(f"Failed to save final model")
# Close TensorBoard writer
try:
writer.close()
logger.info("TensorBoard writer closed")
except Exception as e:
logger.error(f"Error closing TensorBoard writer: {e}")
# Close exchange connection
if exchange:
try:
await with_timeout(exchange.close(), timeout=10)
logger.info("Exchange connection closed")
except Exception as e:
logger.error(f"Error closing exchange connection: {e}")
# Final memory cleanup
manage_memory()
logger.info("Live training completed")
async def main():
"""Main function to parse arguments and start live training"""
parser = argparse.ArgumentParser(description='Live Training with Real Market Data')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for training')
parser.add_argument('--model_path', type=str, default='models/trading_agent_best_pnl.pt', help='Path to initial model')
parser.add_argument('--save_path', type=str, default='models/trading_agent_live_trained.pt', help='Path to save improved model')
parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance for simulation')
parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')
parser.add_argument('--retry_delay', type=int, default=5, help='Seconds to wait before retrying after an error')
parser.add_argument('--max_retries', type=int, default=3, help='Maximum number of retries for operations')
args = parser.parse_args()
logger.info(f"Starting live training with {args.symbol} on {args.timeframe} timeframe")
await live_training(
symbol=args.symbol,
timeframe=args.timeframe,
model_path=args.model_path,
save_path=args.save_path,
initial_balance=args.initial_balance,
update_interval=args.update_interval,
training_iterations=args.training_iterations,
max_episodes=args.max_episodes,
retry_delay=args.retry_delay,
max_retries=args.max_retries,
)
# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
"""Replace Agent's save method with our robust save approach"""
original_save = Agent.save
def patched_save(self, path):
return robust_save(self, path)
# Apply the patch
Agent.save = patched_save
logger.info("Monkey patched Agent.save with robust_save")
# Return the original method in case we need to restore it
return original_save
# Call the monkey patch function at the appropriate place
if __name__ == "__main__":
try:
print("Starting live training script")
# Apply the monkey patch before running the main function
original_save = monkey_patch_agent_save()
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Live training stopped by user")
except Exception as e:
logger.error(f"Error in main function: {e}")
logger.error(traceback.format_exc())

File diff suppressed because it is too large


@@ -0,0 +1,240 @@
import os
import json
import asyncio
import logging
import datetime
import numpy as np
import pandas as pd
import websockets
from dotenv import load_dotenv
from torch.utils.tensorboard import SummaryWriter
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.FileHandler("mexc_tick_stream.log"), logging.StreamHandler()]
)
logger = logging.getLogger("mexc_tick_stream")
# Load environment variables
load_dotenv()
MEXC_API_KEY = os.getenv('MEXC_API_KEY')
MEXC_SECRET_KEY = os.getenv('MEXC_SECRET_KEY')
class MexcTickStreamer:
def __init__(self, symbol="ETH/USDT", update_interval=1.0):
"""
Initialize the MEXC tick data streamer
Args:
symbol: Trading pair symbol (e.g., "ETH/USDT")
update_interval: How often to update the TensorBoard visualization (in seconds)
"""
self.symbol = symbol.replace("/", "").upper() # Convert to MEXC format (e.g., ETHUSDT)
self.update_interval = update_interval
self.uri = "wss://wbs-api.mexc.com/ws"
self.writer = SummaryWriter(f'runs/mexc_ticks_{self.symbol}')
self.trades = []
self.last_update_time = 0
self.running = False
# For visualization
self.price_history = []
self.volume_history = []
self.buy_volume = 0
self.sell_volume = 0
self.step = 0
async def connect(self):
"""Connect to MEXC WebSocket and subscribe to tick data"""
try:
self.websocket = await websockets.connect(self.uri)
logger.info(f"Connected to MEXC WebSocket for {self.symbol}")
# Subscribe to trade stream (using non-protobuf endpoint for simplicity)
subscribe_msg = {
"method": "SUBSCRIPTION",
"params": [f"spot@public.deals.v3.api@{self.symbol}"]
}
await self.websocket.send(json.dumps(subscribe_msg))
logger.info(f"Subscribed to {self.symbol} tick data")
# Start ping task to keep connection alive
asyncio.create_task(self.ping_loop())
return True
except Exception as e:
logger.error(f"Error connecting to MEXC WebSocket: {e}")
return False
async def ping_loop(self):
"""Send ping messages to keep the connection alive"""
while self.running:
try:
await self.websocket.send(json.dumps({"method": "PING"}))
await asyncio.sleep(30) # Send ping every 30 seconds
except Exception as e:
logger.error(f"Error in ping loop: {e}")
break
async def process_message(self, message):
"""Process incoming WebSocket messages"""
try:
# Try to parse as JSON
try:
data = json.loads(message)
# Handle PONG response
if data.get("msg") == "PONG":
return
# Handle subscription confirmation
if data.get("code") == 0:
logger.info(f"Subscription confirmed: {data.get('msg')}")
return
# Handle trade data in the non-protobuf format
if "c" in data and "d" in data and "deals" in data["d"]:
for trade in data["d"]["deals"]:
# Extract trade data
price = float(trade["p"])
quantity = float(trade["v"])
trade_type = 1 if trade["S"] == 1 else 2 # 1 for buy, 2 for sell
timestamp = trade["t"]
# Store trade data
self.trades.append({
"price": price,
"quantity": quantity,
"type": "buy" if trade_type == 1 else "sell",
"timestamp": timestamp
})
# Update volume counters
if trade_type == 1: # Buy
self.buy_volume += quantity
else: # Sell
self.sell_volume += quantity
# Store for visualization
self.price_history.append(price)
self.volume_history.append(quantity)
# Limit history size to prevent memory issues
if len(self.price_history) > 10000:
self.price_history = self.price_history[-5000:]
self.volume_history = self.volume_history[-5000:]
# Update TensorBoard if enough time has passed
current_time = datetime.datetime.now().timestamp()
if current_time - self.last_update_time >= self.update_interval:
await self.update_tensorboard()
self.last_update_time = current_time
except json.JSONDecodeError:
# If it's not valid JSON, it might be binary protobuf data
logger.debug("Received binary data, skipping (protobuf not implemented)")
except Exception as e:
logger.error(f"Error processing message: {e}")
async def update_tensorboard(self):
"""Update TensorBoard visualizations"""
try:
if not self.price_history:
return
# Calculate metrics
current_price = self.price_history[-1]
avg_price = np.mean(self.price_history[-100:]) if len(self.price_history) >= 100 else np.mean(self.price_history)
price_std = np.std(self.price_history[-100:]) if len(self.price_history) >= 100 else np.std(self.price_history)
# Calculate VWAP (Volume Weighted Average Price)
if len(self.price_history) >= 100 and len(self.volume_history) >= 100:
vwap = np.sum(np.array(self.price_history[-100:]) * np.array(self.volume_history[-100:])) / np.sum(self.volume_history[-100:])
else:
vwap = np.sum(np.array(self.price_history) * np.array(self.volume_history)) / np.sum(self.volume_history) if np.sum(self.volume_history) > 0 else current_price
# Calculate buy/sell ratio
total_volume = self.buy_volume + self.sell_volume
buy_ratio = self.buy_volume / total_volume if total_volume > 0 else 0.5
# Log to TensorBoard
self.writer.add_scalar('Price/Current', current_price, self.step)
self.writer.add_scalar('Price/VWAP', vwap, self.step)
self.writer.add_scalar('Price/StdDev', price_std, self.step)
self.writer.add_scalar('Volume/BuyRatio', buy_ratio, self.step)
self.writer.add_scalar('Volume/Total', total_volume, self.step)
# Create a candlestick-like chart for the last 100 ticks
if len(self.price_history) >= 100:
prices = np.array(self.price_history[-100:])
self.writer.add_histogram('Price/Distribution', prices, self.step)
# Create a custom scalars panel
layout = {
"Price": {
"Current vs VWAP": ["Multiline", ["Price/Current", "Price/VWAP"]],
},
"Volume": {
"Buy Ratio": ["Multiline", ["Volume/BuyRatio"]],
}
}
self.writer.add_custom_scalars(layout)
self.step += 1
logger.info(f"Updated TensorBoard: Price={current_price:.2f}, VWAP={vwap:.2f}, Buy Ratio={buy_ratio:.2f}")
except Exception as e:
logger.error(f"Error updating TensorBoard: {e}")
async def run(self):
"""Main loop to receive and process WebSocket messages"""
self.running = True
self.last_update_time = datetime.datetime.now().timestamp()
if not await self.connect():
logger.error("Failed to connect. Exiting.")
return
try:
while self.running:
message = await self.websocket.recv()
await self.process_message(message)
except websockets.exceptions.ConnectionClosed:
logger.warning("WebSocket connection closed")
except Exception as e:
logger.error(f"Error in run loop: {e}")
finally:
self.running = False
await self.cleanup()
async def cleanup(self):
"""Clean up resources"""
try:
if hasattr(self, 'websocket'):
await self.websocket.close()
self.writer.close()
logger.info("Cleaned up resources")
except Exception as e:
logger.error(f"Error during cleanup: {e}")
async def main():
"""Main entry point"""
# Parse command line arguments
import argparse
parser = argparse.ArgumentParser(description='MEXC Tick Data Streamer')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol (e.g., ETH/USDT)')
parser.add_argument('--interval', type=float, default=1.0, help='TensorBoard update interval in seconds')
args = parser.parse_args()
# Create and run the streamer
streamer = MexcTickStreamer(symbol=args.symbol, update_interval=args.interval)
await streamer.run()
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Program interrupted by user")
except Exception as e:
logger.error(f"Unhandled exception: {e}")


@@ -46,12 +46,6 @@ pip install -r requirements.txt
 ```bash
 MEXC_API_KEY=your_api_key
 MEXC_API_SECRET=your_api_secret
-cuda support
-```bash
-pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
 ```
 ## Usage


@@ -1,10 +1,12 @@
 numpy>=1.21.0
 pandas>=1.3.0
 matplotlib>=3.4.0
+mplfinance>=0.12.7
 torch>=1.9.0
 python-dotenv>=0.19.0
 ccxt>=2.0.0
 websockets>=10.0
 tensorboard>=2.6.0
-scikit-learn
-mplfinance
+scikit-learn>=1.0.0
+Pillow>=9.0.0
+asyncio>=3.4.3

crypto/gogo2/run_demo.py

@@ -0,0 +1,34 @@
#!/usr/bin/env python
import asyncio
import logging
from main import live_trading, setup_logging
# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
async def main():
"""Run a simplified demo trading session with mock data"""
logger.info("Starting simplified demo trading session")
# Run live trading in demo mode with simplified parameters
await live_trading(
symbol="ETH/USDT",
timeframe="1m",
model_path="models/trading_agent_best_pnl.pt",
demo=True,
initial_balance=1000,
update_interval=10, # Update every 10 seconds for faster feedback
max_position_size=0.1,
risk_per_trade=0.02,
stop_loss_pct=0.02,
take_profit_pct=0.04,
)
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Demo trading stopped by user")
except Exception as e:
logger.error(f"Error in demo trading: {e}")


@@ -0,0 +1,40 @@
#!/usr/bin/env python
import asyncio
import argparse
import logging
from main import live_trading, setup_logging
# Set up logging
setup_logging()
logger = logging.getLogger(__name__)
async def main():
parser = argparse.ArgumentParser(description='Run live trading in demo mode')
parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for trading')
parser.add_argument('--model_path', type=str, default='data/best_model.pth', help='Path to the trained model')
parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance')
parser.add_argument('--update_interval', type=int, default=30, help='Interval to update data in seconds')
args = parser.parse_args()
logger.info(f"Starting live trading demo with {args.symbol} on {args.timeframe} timeframe")
# Run live trading in demo mode
await live_trading(
symbol=args.symbol,
timeframe=args.timeframe,
model_path=args.model_path,
demo=True, # Always use demo mode in this script
initial_balance=args.initial_balance,
update_interval=args.update_interval,
# Using default values for other parameters
)
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Live trading demo stopped by user")
except Exception as e:
logger.error(f"Error in live trading demo: {e}")


@@ -0,0 +1,69 @@
import os
import sys
import subprocess
import webbrowser
import time
import argparse
def run_tensorboard():
"""Run TensorBoard server and open browser"""
parser = argparse.ArgumentParser(description='TensorBoard Launcher')
parser.add_argument('--port', type=int, default=6006, help='Port for TensorBoard server')
parser.add_argument('--logdir', type=str, default='runs', help='Log directory for TensorBoard')
parser.add_argument('--no-browser', action='store_true', help='Do not open browser automatically')
args = parser.parse_args()
# Create log directory if it doesn't exist
os.makedirs(args.logdir, exist_ok=True)
# Print banner
print("\n" + "="*60)
print("📊 TRADING BOT - TENSORBOARD MONITORING 📊")
print("="*60)
print(f"Starting TensorBoard server on port {args.port}")
print(f"Log directory: {args.logdir}")
print("Press Ctrl+C to stop the server")
print("="*60 + "\n")
# Start TensorBoard server
cmd = ["tensorboard", "--logdir", args.logdir, "--port", str(args.port)]
try:
# Start TensorBoard process
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
universal_newlines=True
)
# Wait for TensorBoard to start
time.sleep(3)
# Open browser
if not args.no_browser:
url = f"http://localhost:{args.port}"
print(f"Opening browser to {url}")
webbrowser.open(url)
# Print TensorBoard output
while True:
output = process.stdout.readline()
if output == '' and process.poll() is not None:
break
if output:
print(output.strip())
return process.poll()
except KeyboardInterrupt:
print("\nStopping TensorBoard server...")
process.terminate()
return 0
except Exception as e:
print(f"Error running TensorBoard: {str(e)}")
return 1
if __name__ == "__main__":
exit_code = run_tensorboard()
sys.exit(exit_code)

crypto/gogo2/run_tests.py

@@ -0,0 +1,77 @@
#!/usr/bin/env python
"""
Run unit tests for the trading bot.
This script runs the unit tests defined in tests.py and displays the results.
It can run a single test or all tests.
Usage:
python run_tests.py [test_name]
If test_name is provided, only that test will be run.
Otherwise, all tests will be run.
Example:
python run_tests.py TestPeriodicUpdates
python run_tests.py TestBacktesting
python run_tests.py TestBacktestingLastSevenDays
python run_tests.py TestSingleDayBacktesting
python run_tests.py
"""
import sys
import unittest
import logging
from tests import (
TestPeriodicUpdates,
TestBacktesting,
TestBacktestingLastSevenDays,
TestSingleDayBacktesting
)
if __name__ == "__main__":
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()])
# Get the test name from the command line
test_name = sys.argv[1] if len(sys.argv) > 1 else None
# Run the specified test or all tests
if test_name:
logging.info(f"Running test: {test_name}")
if test_name == "TestPeriodicUpdates":
suite = unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates)
elif test_name == "TestBacktesting":
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktesting)
elif test_name == "TestBacktestingLastSevenDays":
suite = unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays)
elif test_name == "TestSingleDayBacktesting":
suite = unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting)
else:
logging.error(f"Unknown test: {test_name}")
logging.info("Available tests: TestPeriodicUpdates, TestBacktesting, TestBacktestingLastSevenDays, TestSingleDayBacktesting")
sys.exit(1)
else:
# Run all tests
logging.info("Running all tests")
suite = unittest.TestSuite()
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestPeriodicUpdates))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktesting))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestBacktestingLastSevenDays))
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestSingleDayBacktesting))
# Run the tests
runner = unittest.TextTestRunner(verbosity=2)
result = runner.run(suite)
# Print summary
print("\nTest Summary:")
print(f" Ran {result.testsRun} tests")
print(f" Errors: {len(result.errors)}")
print(f" Failures: {len(result.failures)}")
print(f" Skipped: {len(result.skipped)}")
# Exit with non-zero status if any tests failed
sys.exit(len(result.errors) + len(result.failures))


@@ -0,0 +1,118 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import ccxt.async_support as ccxt
import os
import datetime
# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
try:
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
except Exception as e:
print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
# Setup direct console logging for immediate feedback
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
async def initialize_exchange():
"""Initialize the exchange with API credentials from environment variables"""
exchange_id = 'mexc'
try:
# Get API credentials from environment variables
api_key = os.getenv('MEXC_API_KEY', '')
secret_key = os.getenv('MEXC_SECRET_KEY', '')
# Initialize the exchange
exchange_class = getattr(ccxt, exchange_id)
exchange = exchange_class({
'apiKey': api_key,
'secret': secret_key,
'enableRateLimit': True,
})
logger.info(f"Exchange initialized with standard CCXT: {exchange_id}")
return exchange
except Exception as e:
logger.error(f"Error initializing exchange: {e}")
raise
async def fetch_ohlcv_data(exchange, symbol, timeframe, limit=1000):
"""Fetch OHLCV data from the exchange"""
logger.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt 1/3)")
try:
candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
if not candles or len(candles) == 0:
logger.warning(f"No candles returned for {symbol} on {timeframe}")
return None
logger.info(f"Successfully fetched {len(candles)} candles")
return candles
except Exception as e:
logger.error(f"Error fetching candle data: {e}")
return None
async def main():
"""Main function to test live data fetching"""
symbol = "ETH/USDT"
timeframe = "1m"
logger.info(f"Starting simplified live training test for {symbol} on {timeframe}")
try:
# Initialize exchange
exchange = await initialize_exchange()
# Fetch data every 10 seconds
for i in range(5):
logger.info(f"Fetch attempt {i+1}/5")
candles = await fetch_ohlcv_data(exchange, symbol, timeframe)
if candles:
# Print the latest candle
latest = candles[-1]
timestamp, open_price, high, low, close, volume = latest
dt = datetime.datetime.fromtimestamp(timestamp/1000).strftime('%Y-%m-%d %H:%M:%S')
logger.info(f"Latest candle: Time={dt}, Open={open_price}, High={high}, Low={low}, Close={close}, Volume={volume}")
# Wait 10 seconds before next fetch
if i < 4: # Don't wait after the last fetch
logger.info("Waiting 10 seconds before next fetch...")
await asyncio.sleep(10)
# Close exchange connection
await exchange.close()
logger.info("Exchange connection closed")
except Exception as e:
logger.error(f"Error in simplified live training test: {e}")
import traceback
logger.error(traceback.format_exc())
finally:
try:
await exchange.close()
except:
pass
logger.info("Test completed")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Test stopped by user")
except Exception as e:
logger.error(f"Error in main function: {e}")
import traceback
logger.error(traceback.format_exc())


@@ -0,0 +1,14 @@
# PowerShell script to start live trading demo and TensorBoard
Write-Host "Starting Trading Bot Live Demo..." -ForegroundColor Green
# Create a new PowerShell window for TensorBoard
Start-Process powershell -ArgumentList "-Command python run_tensorboard.py" -WindowStyle Normal
# Wait a moment for TensorBoard to start
Write-Host "Starting TensorBoard... Please wait" -ForegroundColor Yellow
Start-Sleep -Seconds 5
# Start the live trading demo in the current window
Write-Host "Starting Live Trading Demo with mock data..." -ForegroundColor Green
python run_live_demo.py --symbol ETH/USDT --timeframe 1m --model models/trading_agent_best_pnl.pt --mock


@@ -0,0 +1,227 @@
#!/usr/bin/env python
import os
import logging
import torch
import argparse
import gc
import traceback
import shutil
from main import Agent, robust_save
# Set up logging
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.FileHandler("test_model_save_load.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger(__name__)
def create_test_directory():
"""Create a test directory for saving models"""
test_dir = "test_models"
os.makedirs(test_dir, exist_ok=True)
return test_dir
def test_save_load_cycle(state_size=64, action_size=4, hidden_size=384):
"""Test a full cycle of saving and loading models"""
test_dir = create_test_directory()
# Create a test agent
logger.info(f"Creating test agent with state_size={state_size}, action_size={action_size}, hidden_size={hidden_size}")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Define paths for testing
save_path = os.path.join(test_dir, "test_agent.pt")
# Test saving
logger.info(f"Testing save to {save_path}")
save_success = agent.save(save_path)
if save_success:
logger.info(f"Save successful, model size: {os.path.getsize(save_path)} bytes")
else:
logger.error("Save failed!")
return False
# Memory cleanup
del agent
if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()
# Test loading
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info("Load successful")
# Verify model architecture
logger.info(f"Verifying model architecture")
assert new_agent.state_size == state_size, f"Expected state_size={state_size}, got {new_agent.state_size}"
assert new_agent.action_size == action_size, f"Expected action_size={action_size}, got {new_agent.action_size}"
assert new_agent.hidden_size == hidden_size, f"Expected hidden_size={hidden_size}, got {new_agent.hidden_size}"
logger.info("Model architecture verified correctly")
return True
except Exception as e:
logger.error(f"Error during load or verification: {e}")
logger.error(traceback.format_exc())
return False
def test_robust_save_methods(state_size=64, action_size=4, hidden_size=384):
"""Test all the robust save methods"""
test_dir = create_test_directory()
# Create a test agent
logger.info(f"Creating test agent for robust save testing")
agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
# Test each robust save method
methods = [
("regular", os.path.join(test_dir, "regular_save.pt")),
("backup", os.path.join(test_dir, "backup_save.pt")),
("pickle2", os.path.join(test_dir, "pickle2_save.pt")),
("no_optimizer", os.path.join(test_dir, "no_optimizer_save.pt")),
("jit", os.path.join(test_dir, "jit_save.pt"))
]
results = {}
for method_name, save_path in methods:
logger.info(f"Testing {method_name} save method to {save_path}")
try:
if method_name == "regular":
# Use regular save
success = agent.save(save_path)
elif method_name == "backup":
# Use backup method directly
backup_path = f"{save_path}.backup"
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, backup_path)
shutil.copy(backup_path, save_path)
success = os.path.exists(save_path)
elif method_name == "pickle2":
# Use pickle protocol 2
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'optimizer': agent.optimizer.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path, pickle_protocol=2)
success = os.path.exists(save_path)
elif method_name == "no_optimizer":
# Save without optimizer
checkpoint = {
'policy_net': agent.policy_net.state_dict(),
'target_net': agent.target_net.state_dict(),
'epsilon': agent.epsilon,
'state_size': agent.state_size,
'action_size': agent.action_size,
'hidden_size': agent.hidden_size
}
torch.save(checkpoint, save_path)
success = os.path.exists(save_path)
elif method_name == "jit":
# Use JIT save
try:
scripted_policy = torch.jit.script(agent.policy_net)
torch.jit.save(scripted_policy, f"{save_path}.policy.jit")
scripted_target = torch.jit.script(agent.target_net)
torch.jit.save(scripted_target, f"{save_path}.target.jit")
# Save parameters
with open(f"{save_path}.params.json", "w") as f:
import json
params = {
'epsilon': float(agent.epsilon),
'state_size': int(agent.state_size),
'action_size': int(agent.action_size),
'hidden_size': int(agent.hidden_size)
}
json.dump(params, f)
success = (os.path.exists(f"{save_path}.policy.jit") and
os.path.exists(f"{save_path}.target.jit") and
os.path.exists(f"{save_path}.params.json"))
except Exception as e:
logger.error(f"JIT save failed: {e}")
success = False
if success:
if method_name != "jit":
file_size = os.path.getsize(save_path)
logger.info(f"{method_name} save successful, size: {file_size} bytes")
else:
logger.info(f"{method_name} save successful")
results[method_name] = True
else:
logger.error(f"{method_name} save failed")
results[method_name] = False
except Exception as e:
logger.error(f"Error during {method_name} save: {e}")
logger.error(traceback.format_exc())
results[method_name] = False
# Test loading each saved model
for method_name, save_path in methods:
if not results[method_name]:
logger.info(f"Skipping load test for {method_name} (save failed)")
continue
if method_name == "jit":
logger.info(f"Skipping load test for {method_name} (requires special loading)")
continue
logger.info(f"Testing load from {save_path}")
try:
new_agent = Agent(state_size=state_size, action_size=action_size, hidden_size=hidden_size)
new_agent.load(save_path)
logger.info(f"Load successful for {method_name} save")
except Exception as e:
logger.error(f"Error loading from {method_name} save: {e}")
logger.error(traceback.format_exc())
results[method_name] = "save ok, load failed"  # original used '+=' on a bool, which raises TypeError
# Return summary of results
return results
def main():
parser = argparse.ArgumentParser(description='Test model saving and loading')
parser.add_argument('--state_size', type=int, default=64, help='State size for test model')
parser.add_argument('--action_size', type=int, default=4, help='Action size for test model')
parser.add_argument('--hidden_size', type=int, default=384, help='Hidden size for test model')
parser.add_argument('--test_robust', action='store_true', help='Test all robust save methods')
args = parser.parse_args()
logger.info("Starting model save/load test")
if args.test_robust:
results = test_robust_save_methods(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Robust save method results: {results}")
else:
success = test_save_load_cycle(args.state_size, args.action_size, args.hidden_size)
logger.info(f"Save/load cycle {'successful' if success else 'failed'}")
logger.info("Test completed")
if __name__ == "__main__":
main()
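# Example invocations (the script filename is assumed here):
#   python test_model_save_load.py                 # single save/load cycle
#   python test_model_save_load.py --test_robust   # exercise every save method above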

182
crypto/gogo2/test_save.py Normal file
View File

@ -0,0 +1,182 @@
#!/usr/bin/env python
import torch
import torch.nn as nn
import os
import logging
import sys
import platform
# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
try:
import asyncio
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
except Exception as e:
print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")
# Setup logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("test_save.log"),
logging.StreamHandler(sys.stdout)
]
)
logger = logging.getLogger(__name__)
# Define a simple model for testing
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc1 = nn.Linear(10, 50)
self.fc2 = nn.Linear(50, 20)
self.fc3 = nn.Linear(20, 5)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
return self.fc3(x)
# Create a simple Agent class for testing
class TestAgent:
def __init__(self):
self.policy_net = SimpleModel()
self.target_net = SimpleModel()
self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
self.epsilon = 0.1
def save(self, path):
"""Standard save method that might fail"""
checkpoint = {
'policy_net': self.policy_net.state_dict(),
'target_net': self.target_net.state_dict(),
'optimizer': self.optimizer.state_dict(),
'epsilon': self.epsilon
}
torch.save(checkpoint, path)
logger.info(f"Model saved to {path}")
# Robust save function with multiple fallback approaches
def robust_save(model, path):
"""
Robust model saving with multiple fallback approaches
Args:
model: The Agent model to save
path: Path to save the model
Returns:
bool: True if successful, False otherwise
"""
# Create directory if it doesn't exist
os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
# Backup path in case the main save fails
backup_path = f"{path}.backup"
# Attempt 1: Try with default settings in a separate file first
try:
logger.info(f"Saving model to {backup_path} (attempt 1)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, backup_path)
logger.info(f"Successfully saved to {backup_path}")
# If backup worked, copy to the actual path
if os.path.exists(backup_path):
import shutil
shutil.copy(backup_path, path)
logger.info(f"Copied backup to {path}")
return True
except Exception as e:
logger.warning(f"First save attempt failed: {e}")
# Attempt 2: Try with pickle protocol 2 (more compatible)
try:
logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'optimizer': model.optimizer.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path, pickle_protocol=2)
logger.info(f"Successfully saved to {path} with pickle_protocol=2")
return True
except Exception as e:
logger.warning(f"Second save attempt failed: {e}")
# Attempt 3: Try without optimizer state (which can be large and cause issues)
try:
logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
checkpoint = {
'policy_net': model.policy_net.state_dict(),
'target_net': model.target_net.state_dict(),
'epsilon': model.epsilon
}
torch.save(checkpoint, path)
logger.info(f"Successfully saved to {path} without optimizer state")
return True
except Exception as e:
logger.warning(f"Third save attempt failed: {e}")
# Attempt 4: Try with torch.jit.save instead
try:
logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
# Save policy network using jit
scripted_policy = torch.jit.script(model.policy_net)
torch.jit.save(scripted_policy, f"{path}.policy.jit")
# Save target network using jit
scripted_target = torch.jit.script(model.target_net)
torch.jit.save(scripted_target, f"{path}.target.jit")
# Save epsilon value separately
with open(f"{path}.epsilon.txt", "w") as f:
f.write(str(model.epsilon))
logger.info(f"Successfully saved model components with jit.save")
return True
except Exception as e:
logger.error(f"All save attempts failed: {e}")
return False
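# Companion sketch (not part of the original script): a loader that mirrors the
# fallbacks above, assuming the file layout robust_save produces; robust_load is
# an illustrative name, not an existing API.
def robust_load(model, path):
    """Try the regular checkpoint first, then fall back to the jit.save artifacts."""
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location="cpu")
        model.policy_net.load_state_dict(checkpoint['policy_net'])
        model.target_net.load_state_dict(checkpoint['target_net'])
        if 'optimizer' in checkpoint:
            model.optimizer.load_state_dict(checkpoint['optimizer'])
        model.epsilon = checkpoint.get('epsilon', model.epsilon)
        logger.info(f"Loaded checkpoint from {path}")
        return True
    if os.path.exists(f"{path}.policy.jit"):
        model.policy_net = torch.jit.load(f"{path}.policy.jit")
        model.target_net = torch.jit.load(f"{path}.target.jit")
        with open(f"{path}.epsilon.txt") as f:
            model.epsilon = float(f.read().strip())
        logger.info(f"Loaded jit components for {path}")
        return True
    logger.error(f"No saved files found for {path}")
    return False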
def main():
# Create a test directory
save_dir = "test_models"
os.makedirs(save_dir, exist_ok=True)
# Create a test agent
agent = TestAgent()
# Test the regular save method (might fail)
try:
logger.info("Testing regular save method...")
save_path = os.path.join(save_dir, "regular_save.pt")
agent.save(save_path)
logger.info("Regular save succeeded")
except Exception as e:
logger.error(f"Regular save failed: {e}")
# Test our robust save method
logger.info("Testing robust save method...")
save_path = os.path.join(save_dir, "robust_save.pt")
success = robust_save(agent, save_path)
if success:
logger.info("Robust save succeeded!")
else:
logger.error("Robust save failed!")
# Check which files were created
logger.info("Files created:")
for file in os.listdir(save_dir):
file_path = os.path.join(save_dir, file)
file_size = os.path.getsize(file_path)
logger.info(f" - {file} ({file_size} bytes)")
if __name__ == "__main__":
main()

337
crypto/gogo2/tests.py Normal file
View File

@ -0,0 +1,337 @@
"""
Unit tests for the trading bot.
This file contains tests for various components of the trading bot, including:
1. Periodic candle updates
2. Backtesting on historical data
3. Training on the last 7 days of data
"""
import unittest
import asyncio
import os
import sys
import logging
import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
# Configure logging
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler()])
# Import functionality from main.py
import main
from main import (
CandleCache, BacktestCandles, initialize_exchange,
TradingEnvironment, Agent, train_with_backtesting,
fetch_multi_timeframe_data, train_agent
)
class TestPeriodicUpdates(unittest.TestCase):
"""Test that candle data is periodically updated during training."""
async def async_test_periodic_updates(self):
"""Test that candle data is periodically updated during training."""
logging.info("Testing periodic candle updates...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create candle cache
candle_cache = CandleCache()
# Initial fetch of candle data
candle_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
self.assertIsNotNone(candle_data, "Failed to fetch initial candle data")
self.assertIn('1m', candle_data, "1m candles not found in initial data")
# Check initial data timestamps
initial_1m_candles = candle_data['1m']
self.assertGreater(len(initial_1m_candles), 0, "No 1m candles found in initial data")
initial_timestamp = initial_1m_candles[-1][0]
# Wait for update interval to pass
logging.info("Waiting for update interval to pass (5 seconds for testing)...")
await asyncio.sleep(5) # Short wait for testing
# Force update by setting last_updated to None
candle_cache.last_updated['1m'] = None
# Fetch updated data
updated_data = await fetch_multi_timeframe_data(exchange, "ETH/USDT", candle_cache)
self.assertIsNotNone(updated_data, "Failed to fetch updated candle data")
# Check if data was updated
updated_1m_candles = updated_data['1m']
self.assertGreater(len(updated_1m_candles), 0, "No 1m candles found in updated data")
updated_timestamp = updated_1m_candles[-1][0]
# In a live scenario, this check should pass with real-time updates
# For testing, we just ensure data was fetched
logging.info(f"Initial timestamp: {initial_timestamp}, Updated timestamp: {updated_timestamp}")
self.assertIsNotNone(updated_timestamp, "Updated timestamp is None")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Periodic update test completed")
def test_periodic_updates(self):
"""Run the async test."""
asyncio.run(self.async_test_periodic_updates())
class TestBacktesting(unittest.TestCase):
"""Test backtesting on historical data."""
async def async_test_backtesting(self):
"""Test backtesting on a specific time period."""
logging.info("Testing backtesting with historical data...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create a timestamp for 24 hours ago
now = datetime.datetime.now()
yesterday = now - datetime.timedelta(days=1)
since_timestamp = int(yesterday.timestamp() * 1000) # Convert to milliseconds
# Create a backtesting candle cache
backtest_cache = BacktestCandles(since_timestamp=since_timestamp)
backtest_cache.period_name = "1-day-ago"
# Fetch historical data
candle_data = await backtest_cache.fetch_all_timeframes(exchange, "ETH/USDT")
self.assertIsNotNone(candle_data, "Failed to fetch historical candle data")
self.assertIn('1m', candle_data, "1m candles not found in historical data")
# Check historical data timestamps
minute_candles = candle_data['1m']
self.assertGreater(len(minute_candles), 0, "No minute candles found in historical data")
# Check if timestamps are within the requested range
first_timestamp = minute_candles[0][0]
last_timestamp = minute_candles[-1][0]
logging.info(f"Requested since: {since_timestamp}")
logging.info(f"First timestamp in data: {first_timestamp}")
logging.info(f"Last timestamp in data: {last_timestamp}")
# In real tests, this check should compare timestamps precisely
# For this test, we just ensure data was fetched
self.assertLessEqual(first_timestamp, last_timestamp, "First timestamp should be before last timestamp")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Backtesting fetch test completed")
def test_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_backtesting())
class TestBacktestingLastSevenDays(unittest.TestCase):
"""Test backtesting on the last 7 days of data."""
async def async_test_seven_days_backtesting(self):
"""Test backtesting on the last 7 days."""
logging.info("Testing backtesting on the last 7 days...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create environment with small initial balance for testing
env = TradingEnvironment(
initial_balance=100, # Small balance for testing
leverage=10, # Lower leverage for testing
window_size=50, # Smaller window for faster testing
commission=0.0004 # Standard commission
)
# Create agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
# Initialize empty results dataframe
all_results = pd.DataFrame()
# Run backtesting for the last 7 days, one day at a time
now = datetime.datetime.now()
for day_offset in range(1, 8):
# Calculate time period
end_day = now - datetime.timedelta(days=day_offset-1)
start_day = end_day - datetime.timedelta(days=1)
# Convert to milliseconds
since_timestamp = int(start_day.timestamp() * 1000)
until_timestamp = int(end_day.timestamp() * 1000)
# Period name
period_name = f"Day-{day_offset}"
logging.info(f"Testing backtesting for period: {period_name}")
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
# Run backtesting with a small number of episodes for testing
stats = await train_with_backtesting(
agent=agent,
env=env,
symbol="ETH/USDT",
since_timestamp=since_timestamp,
until_timestamp=until_timestamp,
num_episodes=3, # Use a small number for testing
max_steps_per_episode=200, # Use a small number for testing
period_name=period_name
)
# Check if stats were returned
if stats is None:
logging.warning(f"No stats returned for period: {period_name}")
continue
# Create a dataframe from stats
if len(stats['episode_rewards']) > 0:
df = pd.DataFrame({
'Period': [period_name] * len(stats['episode_rewards']),
'Episode': list(range(1, len(stats['episode_rewards']) + 1)),
'Reward': stats['episode_rewards'],
'Balance': stats['balances'],
'PnL': stats['episode_pnls'],
'Fees': stats['fees'],
'Net_PnL': stats['net_pnl_after_fees']
})
# Append to all results
all_results = pd.concat([all_results, df], ignore_index=True)
logging.info(f"Completed backtesting for period: {period_name}")
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
else:
logging.warning(f"No episodes completed for period: {period_name}")
# Save all results
if not all_results.empty:
all_results.to_csv("all_backtest_results.csv", index=False)
logging.info("Saved all backtest results to all_backtest_results.csv")
# Create plot of results
plt.figure(figsize=(12, 8))
# Plot Net PnL by period
all_results.groupby('Period')['Net_PnL'].last().plot(kind='bar')
plt.title('Net PnL by Training Period (Last Episode)')
plt.ylabel('Net PnL ($)')
plt.tight_layout()
plt.savefig("backtest_results.png")
logging.info("Saved backtest results plot to backtest_results.png")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("7-day backtesting test completed")
def test_seven_days_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_seven_days_backtesting())
class TestSingleDayBacktesting(unittest.TestCase):
"""Test backtesting on a single day of historical data."""
async def async_test_single_day_backtesting(self):
"""Test backtesting on a single day."""
logging.info("Testing backtesting on a single day...")
# Initialize exchange
exchange = await initialize_exchange()
self.assertIsNotNone(exchange, "Failed to initialize exchange")
# Create environment with small initial balance for testing
env = TradingEnvironment(
initial_balance=100, # Small balance for testing
leverage=10, # Lower leverage for testing
window_size=50, # Smaller window for faster testing
commission=0.0004 # Standard commission
)
# Create agent
STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE)
# Calculate time period for 1 day ago
now = datetime.datetime.now()
end_day = now
start_day = end_day - datetime.timedelta(days=1)
# Convert to milliseconds
since_timestamp = int(start_day.timestamp() * 1000)
until_timestamp = int(end_day.timestamp() * 1000)
# Period name
period_name = "Test-Day-1"
logging.info(f"Testing backtesting for period: {period_name}")
logging.info(f" - From: {start_day.strftime('%Y-%m-%d %H:%M:%S')}")
logging.info(f" - To: {end_day.strftime('%Y-%m-%d %H:%M:%S')}")
# Run backtesting with a small number of episodes for testing
stats = await train_with_backtesting(
agent=agent,
env=env,
symbol="ETH/USDT",
since_timestamp=since_timestamp,
until_timestamp=until_timestamp,
num_episodes=2, # Very small number for quick testing
max_steps_per_episode=100, # Very small number for quick testing
period_name=period_name
)
# Check if stats were returned
self.assertIsNotNone(stats, "No stats returned from backtesting")
# Check if episodes were completed
self.assertGreater(len(stats['episode_rewards']), 0, "No episodes completed")
# Log results
logging.info(f"Completed backtesting for period: {period_name}")
logging.info(f" - Episodes: {len(stats['episode_rewards'])}")
logging.info(f" - Final Balance: ${stats['balances'][-1]:.2f}")
logging.info(f" - Net PnL: ${stats['net_pnl_after_fees'][-1]:.2f}")
# Close exchange connection
try:
await exchange.close()
except AttributeError:
# Some exchanges don't have a close method
pass
logging.info("Single day backtesting test completed")
def test_single_day_backtesting(self):
"""Run the async test."""
asyncio.run(self.async_test_single_day_backtesting())
if __name__ == '__main__':
unittest.main()
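# Individual test classes can be run in isolation through the standard unittest
# CLI, e.g. (module name assumed to match the filename tests.py):
#   python -m unittest tests.TestSingleDayBacktesting -v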

View File

@ -0,0 +1 @@
timestamp,action,price,position_size,balance,pnl

View File

@ -0,0 +1 @@
timestamp,action,price,position_size,balance,pnl

View File

@ -0,0 +1 @@
timestamp,action,price,position_size,balance,pnl

View File

@ -0,0 +1 @@
timestamp,action,price,position_size,balance,pnl

File diff suppressed because it is too large

Binary file not shown (image: 170 KiB before, 60 KiB after).

View File

@ -1 +1,11 @@
episode_rewards,episode_lengths,balances,win_rates,episode_pnls,cumulative_pnl,drawdowns,prediction_accuracy
episode_rewards,episode_lengths,balances,win_rates,episode_pnls,cumulative_pnl,drawdowns,prediction_accuracy,trade_analysis
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}
0,1,100,0,0.0,0.0,0.0,0.0,{}


Binary file not shown (image: 86 KiB before).