Fix model loading in live mode

This commit is contained in:
Dobromir Popov 2025-03-17 02:17:43 +02:00
parent 485c61cf8c
commit 991cf57274

View File

@@ -1695,7 +1695,22 @@ class Agent:
def load(self, path="models/trading_agent.pt"):
if os.path.isfile(path):
checkpoint = torch.load(path)
try:
# First try with weights_only=True (safer)
checkpoint = torch.load(path, weights_only=True)
except Exception as e:
logger.warning(f"Failed to load with weights_only=True: {e}")
try:
# Try with safe_globals for numpy.scalar
import numpy as np
from torch.serialization import safe_globals
with safe_globals([np.core.multiarray.scalar]):
checkpoint = torch.load(path, weights_only=True)
except Exception as e2:
logger.warning(f"Failed with safe_globals: {e2}")
# Fall back to weights_only=False if needed
checkpoint = torch.load(path, weights_only=False)
self.policy_net.load_state_dict(checkpoint['policy_net'])
self.target_net.load_state_dict(checkpoint['target_net'])
self.optimizer.load_state_dict(checkpoint['optimizer'])
@@ -2154,6 +2169,13 @@ async def live_trading(agent, env, exchange, demo=True):
# Main trading loop
step_counter = 0
# For online learning
states = []
actions = []
rewards = []
next_states = []
dones = []
while True:
# Wait for the next candle (1 minute)
await asyncio.sleep(5) # Check every 5 seconds
@@ -2175,7 +2197,38 @@ async def live_trading(agent, env, exchange, demo=True):
action = agent.select_action(state, training=False)
# Take action
_, reward, _ = env.step(action)
next_state, reward, done = env.step(action)
# Store experience for online learning
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(done)
# Online learning - update the model with new experiences
if len(states) >= 10: # Batch size for online learning
# Store experiences in replay memory
for i in range(len(states)):
agent.memory.push(states[i], actions[i], rewards[i], next_states[i], dones[i])
# Learn from experiences if we have enough samples
if len(agent.memory) > 32:
loss = agent.learn()
if loss is not None:
agent.writer.add_scalar('Live/Loss', loss, step_counter)
# Clear the temporary storage
states = []
actions = []
rewards = []
next_states = []
dones = []
# Save the updated model periodically
if step_counter % 100 == 0:
agent.save("models/trading_agent_live_updated.pt")
logger.info("Updated model saved during live trading")
# Log trading activity
action_names = ["HOLD", "BUY", "SELL", "CLOSE"]