better saves
This commit is contained in:
parent 2e7a242ac7
commit bdf6afc6ad
crypto/gogo2/.vscode/launch.json (vendored, 2 changes)
@@ -6,7 +6,7 @@
             "type": "python",
             "request": "launch",
             "program": "main.py",
-            "args": ["--mode", "train", "--episodes", "10"],
+            "args": ["--mode", "train", "--episodes", "100"],
             "console": "integratedTerminal",
             "justMyCode": true
         },

crypto/gogo2/MODEL_SAVING_FIX.md (new file, 74 lines)
@@ -0,0 +1,74 @@
# Model Saving Fix

## Issue

During training sessions, PyTorch model saving operations sometimes fail with errors like:

```
RuntimeError: [enforce fail at inline_container.cc:626] . unexpected pos 18278784 vs 18278680
```

or

```
RuntimeError: [enforce fail at inline_container.cc:820] . PytorchStreamWriter failed writing file data/75: file write failed
```

These errors occur inside PyTorch's serialization mechanism when saving models with `torch.save()`.

## Solution

We've implemented a robust model saving approach that falls back through multiple methods when the primary save operation fails (a condensed sketch follows the list):

1. **Attempt 1**: Save to a backup file first, then copy it to the target path.
2. **Attempt 2**: Save with an older pickle protocol (protocol 2), which is more broadly compatible.
3. **Attempt 3**: Save without the optimizer state, which reduces file size and can avoid serialization issues.
4. **Attempt 4**: Use TorchScript's `torch.jit.save()` instead of `torch.save()`, which uses a different serialization mechanism.
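
Each attempt is a try/except block that returns `True` on success and falls through to the next method on failure; a minimal sketch of the pattern (the full implementation is in `live_training.py`, shown below in this commit):

```python
import logging
import torch

logger = logging.getLogger(__name__)

def save_with_fallbacks(checkpoint, path):
    """Try progressively more conservative save strategies."""
    attempts = [
        lambda: torch.save(checkpoint, f"{path}.backup"),         # 1: write a backup copy first
        lambda: torch.save(checkpoint, path, pickle_protocol=2),  # 2: older pickle protocol
    ]
    for i, attempt in enumerate(attempts, start=1):
        try:
            attempt()
            return True
        except Exception as e:
            logger.warning(f"Save attempt {i} failed: {e}")
    return False
```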

## Implementation

The solution is implemented in two parts:

1. A `robust_save` function that tries multiple saving approaches with fallbacks.
2. A monkey patch that replaces the Agent's `save` method with our robust version (sketched below).
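
The patch itself amounts to one assignment; a minimal sketch of what `live_training.py` applies at startup:

```python
from main import Agent
from live_training import robust_save

# Route all Agent.save calls through the robust fallback chain
Agent.save = lambda self, path: robust_save(self, path)
```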

### Example Usage

```python
# Import the robust_save function
from live_training import robust_save

# Save a model with fallbacks
success = robust_save(agent, "models/my_model.pt")
if success:
    print("Model saved successfully!")
else:
    print("All save attempts failed")
```

## Testing

We've created a test script, `test_save.py`, that demonstrates the robust saving approach and verifies that it works correctly.

To run the test:

```bash
python test_save.py
```

This script creates a simple model, attempts to save it using both the standard and robust methods, and reports the results.

## Future Improvements

Possible future improvements to the model saving mechanism (sketches of items 2 and 4 follow the list):

1. Additional fallback methods, such as serializing individual neural network layers.
2. An automatic retry mechanism with exponential backoff.
3. Asynchronous saving to avoid blocking the training loop.
4. Checksumming saved models to verify integrity.
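
As an illustration of items 2 and 4, a minimal sketch; these helpers are hypothetical and not part of this commit:

```python
import hashlib
import time


def save_with_backoff(save_fn, path, retries=3, base_delay=1.0):
    """Retry a save callable with exponential backoff (improvement 2)."""
    for attempt in range(retries):
        try:
            save_fn(path)
            return True
        except Exception:
            time.sleep(base_delay * (2 ** attempt))  # 1s, 2s, 4s, ...
    return False


def sha256_of(path):
    """Checksum a saved model file so its integrity can be verified later (improvement 4)."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            h.update(chunk)
    return h.hexdigest()
```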

## Related Issues

For more information on similar issues with PyTorch model saving, see:

- https://github.com/pytorch/pytorch/issues/27736
- https://github.com/pytorch/pytorch/issues/24045

crypto/gogo2/live_training.py (new file, 498 lines)
@@ -0,0 +1,498 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import argparse
import os
import datetime
import traceback
import numpy as np
import torch
from main import initialize_exchange, TradingEnvironment, Agent
from torch.utils.tensorboard import SummaryWriter

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Setup logging function
def setup_logging():
    """Setup logging configuration for the application"""
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler("live_training.log"),
            logging.StreamHandler(sys.stdout)  # stdout handler for immediate feedback
        ]
    )

# Set up logging
setup_logging()
logger = logging.getLogger(__name__)

# Robust save function to handle PyTorch serialization errors
def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The Agent model to save
        path: Path to save the model

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If the backup worked, copy it to the actual path
        if os.path.exists(backup_path):
            import shutil
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more broadly compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")
        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")
        # Save epsilon value separately
        with open(f"{path}.epsilon.txt", "w") as f:
            f.write(str(model.epsilon))
        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False

async def fetch_and_update_data(exchange, env, symbol, timeframe):
    """
    Fetch new candle data and update the environment

    Args:
        exchange: CCXT exchange instance
        env: Trading environment instance
        symbol: Trading pair symbol
        timeframe: Timeframe for the candles
    """
    logger.info(f"Fetching new data for {symbol} on {timeframe} timeframe")

    try:
        # Fetch up to 1000 candles
        limit = 1000

        # Fetch OHLCV data
        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)

        if not candles or len(candles) == 0:
            logger.warning(f"No candles returned for {symbol} on {timeframe}")
            return False

        logger.info(f"Successfully fetched {len(candles)} candles")

        # Convert to the format expected by the environment
        formatted_candles = []
        for candle in candles:
            timestamp, open_price, high, low, close, volume = candle
            formatted_candles.append({
                'timestamp': timestamp,
                'open': open_price,
                'high': high,
                'low': low,
                'close': close,
                'volume': volume
            })

        # Update environment data
        env.data = formatted_candles
        if hasattr(env, '_initialize_features'):
            env._initialize_features()

        logger.info(f"Updated environment with {len(formatted_candles)} candles")

        # Print latest candle info
        if formatted_candles:
            latest = formatted_candles[-1]
            dt = datetime.datetime.fromtimestamp(latest['timestamp']/1000).strftime('%Y-%m-%d %H:%M:%S')
            logger.info(f"Latest candle: Time={dt}, Open={latest['open']}, High={latest['high']}, Low={latest['low']}, Close={latest['close']}, Volume={latest['volume']}")

        return True

    except Exception as e:
        logger.error(f"Error fetching candle data: {e}")
        logger.error(traceback.format_exc())
        return False

async def live_training(
    symbol="ETH/USDT",
    timeframe="1m",
    model_path="models/trading_agent_best_pnl.pt",
    save_path="models/trading_agent_live_trained.pt",
    initial_balance=1000,
    update_interval=60,
    training_iterations=100,
    learning_rate=0.0001,
    batch_size=64,
    gamma=0.99,
    window_size=30,
    max_episodes=0,  # 0 means unlimited
):
    """
    Live training function that uses real market data to improve the model without executing real trades.

    Args:
        symbol: Trading pair symbol
        timeframe: Timeframe for training
        model_path: Path to the initial model to load
        save_path: Path to save the improved model
        initial_balance: Initial balance for simulation
        update_interval: Interval between data updates, in seconds
        training_iterations: Number of training iterations per data update
        learning_rate: Learning rate for training
        batch_size: Batch size for training
        gamma: Discount factor for training
        window_size: Window size for the environment
        max_episodes: Maximum number of episodes (0 for unlimited)
    """
    logger.info(f"Starting live training for {symbol} on {timeframe} timeframe")

    # Initialize these up front so the finally block can always reference them
    exchange = None
    agent = None
    writer = None
    try:
        # Initialize exchange (without sandbox mode)
        exchange = await initialize_exchange()
        logger.info(f"Exchange initialized: {exchange.id}")

        # Initialize environment
        env = TradingEnvironment(
            initial_balance=initial_balance,
            window_size=window_size,
            symbol=symbol,
            timeframe=timeframe,
        )

        # Fetch initial data
        logger.info(f"Fetching initial data for {symbol}")
        success = await fetch_and_update_data(exchange, env, symbol, timeframe)
        if not success:
            logger.error("Failed to fetch initial data, exiting")
            return

        # Initialize agent
        STATE_SIZE = env.get_state().shape[0] if hasattr(env, 'get_state') else 64
        ACTION_SIZE = env.action_space.n if hasattr(env.action_space, 'n') else 4
        agent = Agent(state_size=STATE_SIZE, action_size=ACTION_SIZE, hidden_size=384)

        # Load model if provided
        if os.path.exists(model_path):
            try:
                agent.load(model_path)
                logger.info(f"Model loaded successfully from {model_path}")
            except Exception as e:
                logger.warning(f"Error loading model: {e}")
                logger.info("Starting with a new model")
        else:
            logger.warning(f"Model file {model_path} not found. Starting with a new model.")

        # Initialize TensorBoard writer
        run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        writer = SummaryWriter(log_dir=f"runs/live_training_{run_id}")
        agent.writer = writer

        # Initialize training statistics
        total_rewards = 0
        episode_count = 0
        best_reward = float('-inf')
        best_pnl = float('-inf')

        # Start live training loop
        logger.info("Starting live training loop")

        step_counter = 0
        last_update_time = datetime.datetime.now()

        while True:
            # Check if we've reached the maximum number of episodes
            if max_episodes > 0 and episode_count >= max_episodes:
                logger.info(f"Reached maximum episodes ({max_episodes}), stopping")
                break

            # Check if it's time to update data
            current_time = datetime.datetime.now()
            time_diff = (current_time - last_update_time).total_seconds()

            if time_diff >= update_interval:
                logger.info(f"Updating market data after {time_diff:.1f} seconds")
                success = await fetch_and_update_data(exchange, env, symbol, timeframe)
                if not success:
                    logger.warning("Failed to update data, will try again later")
                    # Wait a bit before trying again
                    await asyncio.sleep(5)
                    continue

                last_update_time = current_time

            # Run training iterations on the updated data
            episode_reward = 0
            env.reset()
            done = False

            # Run one simulated episode with the current data
            steps_in_episode = 0
            max_steps = len(env.data) - env.window_size - 1

            logger.info(f"Starting episode {episode_count + 1} with {max_steps} steps")

            while not done and steps_in_episode < max_steps:
                try:
                    state = env.get_state()
                    action = agent.select_action(state, training=True)

                    try:
                        next_state, reward, done, info = env.step(action)
                    except ValueError as e:
                        logger.error(f"Error during env.step: {e}")
                        # env.step may return 3 values instead of 4; handle that case
                        # ("values to unpack" matches both the too-many and not-enough messages)
                        if "values to unpack" in str(e):
                            logger.info("Trying alternative step format")
                            result = env.step(action)
                            if len(result) == 3:
                                next_state, reward, done = result
                                info = {}
                            else:
                                raise
                        else:
                            raise

                    # Save experience in replay memory
                    agent.memory.push(state, action, reward, next_state, done)

                    # Move to the next state
                    state = next_state
                    episode_reward += reward
                    step_counter += 1
                    steps_in_episode += 1

                    # Log action and results every 50 steps
                    if steps_in_episode % 50 == 0:
                        logger.info(f"Step {steps_in_episode}/{max_steps} | Action: {action} | Reward: {reward:.2f} | Balance: ${env.balance:.2f}")

                    # Train the agent on a batch of experiences
                    if len(agent.memory) > batch_size:
                        agent.learn()

                        # Additional training iterations
                        if steps_in_episode % 10 == 0 and training_iterations > 1:
                            for _ in range(training_iterations - 1):
                                agent.learn()

                    if done:
                        logger.info(f"Episode done after {steps_in_episode} steps")
                        break

                except Exception as e:
                    logger.error(f"Error during episode step: {e}")
                    logger.error(traceback.format_exc())
                    break

            # Update training statistics
            episode_count += 1
            total_rewards += episode_reward
            avg_reward = total_rewards / episode_count

            # Track metrics
            writer.add_scalar('LiveTraining/Reward', episode_reward, episode_count)
            writer.add_scalar('LiveTraining/AvgReward', avg_reward, episode_count)
            writer.add_scalar('LiveTraining/Balance', env.balance, episode_count)
            writer.add_scalar('LiveTraining/PnL', env.total_pnl, episode_count)

            # Report progress
            logger.info(f"""
            Episode: {episode_count}
            Reward: {episode_reward:.2f}
            Avg Reward: {avg_reward:.2f}
            Balance: ${env.balance:.2f}
            PnL: ${env.total_pnl:.2f}
            Memory Size: {len(agent.memory)}
            Total Steps: {step_counter}
            """)

            # Save the model if it's the best so far (by reward or PnL)
            if episode_reward > best_reward:
                best_reward = episode_reward
                reward_model_path = f"models/trading_agent_best_reward_{run_id}.pt"
                if robust_save(agent, reward_model_path):
                    logger.info(f"New best reward model saved: {episode_reward:.2f} to {reward_model_path}")
                else:
                    logger.error("Failed to save best reward model")

            if env.total_pnl > best_pnl:
                best_pnl = env.total_pnl
                pnl_model_path = f"models/trading_agent_best_pnl_{run_id}.pt"
                if robust_save(agent, pnl_model_path):
                    logger.info(f"New best PnL model saved: ${env.total_pnl:.2f} to {pnl_model_path}")
                else:
                    logger.error("Failed to save best PnL model")

            # Regularly save the model
            if episode_count % 5 == 0:
                if robust_save(agent, save_path):
                    logger.info(f"Model checkpoint saved to {save_path}")
                else:
                    logger.error("Failed to save checkpoint")

            # Update target network periodically
            if episode_count % 5 == 0:
                try:
                    agent.update_target_network()
                    logger.info("Target network updated")
                except Exception as e:
                    logger.error(f"Error updating target network: {e}")

            # Sleep to avoid excessive API calls
            await asyncio.sleep(1)

    except asyncio.CancelledError:
        logger.info("Live training cancelled")
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in live training: {e}")
        logger.error(traceback.format_exc())
    finally:
        # Save final model (agent may be None if initialization failed)
        if agent is not None:
            if robust_save(agent, save_path):
                logger.info(f"Final model saved to {save_path}")
            else:
                logger.error("Failed to save final model")

        # Close TensorBoard writer
        if writer is not None:
            try:
                writer.close()
                logger.info("TensorBoard writer closed")
            except Exception as e:
                logger.error(f"Error closing TensorBoard writer: {e}")

        # Close exchange connection
        if exchange:
            try:
                await exchange.close()
                logger.info("Exchange connection closed")
            except Exception as e:
                logger.error(f"Error closing exchange connection: {e}")

        logger.info("Live training completed")

async def main():
    """Main function to parse arguments and start live training"""
    parser = argparse.ArgumentParser(description='Live Training with Real Market Data')
    parser.add_argument('--symbol', type=str, default='ETH/USDT', help='Trading pair symbol')
    parser.add_argument('--timeframe', type=str, default='1m', help='Timeframe for training')
    parser.add_argument('--model_path', type=str, default='models/trading_agent_best_pnl.pt', help='Path to initial model')
    parser.add_argument('--save_path', type=str, default='models/trading_agent_live_trained.pt', help='Path to save improved model')
    parser.add_argument('--initial_balance', type=float, default=1000, help='Initial balance for simulation')
    parser.add_argument('--update_interval', type=int, default=60, help='Interval to update data in seconds')
    parser.add_argument('--training_iterations', type=int, default=100, help='Training iterations per update')
    parser.add_argument('--max_episodes', type=int, default=0, help='Maximum number of episodes (0 for unlimited)')

    args = parser.parse_args()

    logger.info(f"Starting live training with {args.symbol} on {args.timeframe} timeframe")

    await live_training(
        symbol=args.symbol,
        timeframe=args.timeframe,
        model_path=args.model_path,
        save_path=args.save_path,
        initial_balance=args.initial_balance,
        update_interval=args.update_interval,
        training_iterations=args.training_iterations,
        max_episodes=args.max_episodes,
    )

# Override Agent's save method with our robust save function
def monkey_patch_agent_save():
    """Replace Agent's save method with our robust save approach"""
    original_save = Agent.save

    def patched_save(self, path):
        return robust_save(self, path)

    # Apply the patch
    Agent.save = patched_save
    logger.info("Monkey patched Agent.save with robust_save")

    # Return the original method in case we need to restore it
    return original_save

if __name__ == "__main__":
    try:
        print("Starting live training script")
        # Apply the monkey patch before running the main function
        original_save = monkey_patch_agent_save()
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Live training stopped by user")
    except Exception as e:
        logger.error(f"Error in main function: {e}")
        logger.error(traceback.format_exc())
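
A typical invocation, using the flags defined by the script's argparse above (values here are illustrative):

```bash
python live_training.py --symbol ETH/USDT --timeframe 1m \
    --update_interval 60 --training_iterations 100 --max_episodes 0
```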

crypto/gogo2/simplified_live_training.py (new file, 118 lines)
@@ -0,0 +1,118 @@
#!/usr/bin/env python
import asyncio
import logging
import sys
import platform
import ccxt.async_support as ccxt
import os
import datetime

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Setup direct console logging for immediate feedback
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

async def initialize_exchange():
    """Initialize the exchange with API credentials from environment variables"""
    exchange_id = 'mexc'
    try:
        # Get API credentials from environment variables
        api_key = os.getenv('MEXC_API_KEY', '')
        secret_key = os.getenv('MEXC_SECRET_KEY', '')

        # Initialize the exchange
        exchange_class = getattr(ccxt, exchange_id)
        exchange = exchange_class({
            'apiKey': api_key,
            'secret': secret_key,
            'enableRateLimit': True,
        })

        logger.info(f"Exchange initialized with standard CCXT: {exchange_id}")
        return exchange
    except Exception as e:
        logger.error(f"Error initializing exchange: {e}")
        raise

async def fetch_ohlcv_data(exchange, symbol, timeframe, limit=1000):
    """Fetch OHLCV data from the exchange"""
    logger.info(f"Fetching {limit} {timeframe} candles for {symbol} (attempt 1/3)")

    try:
        candles = await exchange.fetch_ohlcv(symbol, timeframe, limit=limit)
        if not candles or len(candles) == 0:
            logger.warning(f"No candles returned for {symbol} on {timeframe}")
            return None

        logger.info(f"Successfully fetched {len(candles)} candles")
        return candles
    except Exception as e:
        logger.error(f"Error fetching candle data: {e}")
        return None

async def main():
    """Main function to test live data fetching"""
    symbol = "ETH/USDT"
    timeframe = "1m"

    logger.info(f"Starting simplified live training test for {symbol} on {timeframe}")

    exchange = None  # defined up front so the finally block can always reference it
    try:
        # Initialize exchange
        exchange = await initialize_exchange()

        # Fetch data every 10 seconds
        for i in range(5):
            logger.info(f"Fetch attempt {i+1}/5")
            candles = await fetch_ohlcv_data(exchange, symbol, timeframe)

            if candles:
                # Print the latest candle
                latest = candles[-1]
                timestamp, open_price, high, low, close, volume = latest
                dt = datetime.datetime.fromtimestamp(timestamp/1000).strftime('%Y-%m-%d %H:%M:%S')
                logger.info(f"Latest candle: Time={dt}, Open={open_price}, High={high}, Low={low}, Close={close}, Volume={volume}")

            # Wait 10 seconds before next fetch
            if i < 4:  # Don't wait after the last fetch
                logger.info("Waiting 10 seconds before next fetch...")
                await asyncio.sleep(10)

    except Exception as e:
        logger.error(f"Error in simplified live training test: {e}")
        import traceback
        logger.error(traceback.format_exc())
    finally:
        # Close exchange connection (guarded in case initialization failed)
        if exchange is not None:
            try:
                await exchange.close()
                logger.info("Exchange connection closed")
            except Exception:
                pass
        logger.info("Test completed")

if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logger.info("Test stopped by user")
    except Exception as e:
        logger.error(f"Error in main function: {e}")
        import traceback
        logger.error(traceback.format_exc())
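
The script reads MEXC credentials via `os.getenv`; one way to provide them before running (variable names as used in the script, values are placeholders):

```bash
export MEXC_API_KEY="your_api_key"
export MEXC_SECRET_KEY="your_secret_key"
python simplified_live_training.py
```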

crypto/gogo2/test_save.py (new file, 182 lines)
@@ -0,0 +1,182 @@
#!/usr/bin/env python
import torch
import torch.nn as nn
import os
import logging
import sys
import platform

# Fix for Windows asyncio issues with aiodns
if platform.system() == 'Windows':
    try:
        import asyncio
        asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
        print("Using Windows SelectorEventLoopPolicy to fix aiodns issue")
    except Exception as e:
        print(f"Failed to set WindowsSelectorEventLoopPolicy: {e}")

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("test_save.log"),
        logging.StreamHandler(sys.stdout)
    ]
)
logger = logging.getLogger(__name__)

# Define a simple model for testing
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 20)
        self.fc3 = nn.Linear(20, 5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Create a simple Agent class for testing
class TestAgent:
    def __init__(self):
        self.policy_net = SimpleModel()
        self.target_net = SimpleModel()
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=0.001)
        self.epsilon = 0.1

    def save(self, path):
        """Standard save method that might fail"""
        checkpoint = {
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Model saved to {path}")

# Robust save function with multiple fallback approaches
def robust_save(model, path):
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The Agent model to save
        path: Path to save the model

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If the backup worked, copy it to the actual path
        if os.path.exists(backup_path):
            import shutil
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more broadly compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'optimizer': model.optimizer.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint = {
            'policy_net': model.policy_net.state_dict(),
            'target_net': model.target_net.state_dict(),
            'epsilon': model.epsilon
        }
        torch.save(checkpoint, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")
        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")
        # Save epsilon value separately
        with open(f"{path}.epsilon.txt", "w") as f:
            f.write(str(model.epsilon))
        logger.info("Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False

def main():
    # Create a test directory
    save_dir = "test_models"
    os.makedirs(save_dir, exist_ok=True)

    # Create a test agent
    agent = TestAgent()

    # Test the regular save method (might fail)
    try:
        logger.info("Testing regular save method...")
        save_path = os.path.join(save_dir, "regular_save.pt")
        agent.save(save_path)
        logger.info("Regular save succeeded")
    except Exception as e:
        logger.error(f"Regular save failed: {e}")

    # Test our robust save method
    logger.info("Testing robust save method...")
    save_path = os.path.join(save_dir, "robust_save.pt")
    success = robust_save(agent, save_path)

    if success:
        logger.info("Robust save succeeded!")
    else:
        logger.error("Robust save failed!")

    # Check which files were created
    logger.info("Files created:")
    for file in os.listdir(save_dir):
        file_path = os.path.join(save_dir, file)
        file_size = os.path.getsize(file_path)
        logger.info(f"  - {file} ({file_size} bytes)")

if __name__ == "__main__":
    main()