From 991cf5727435a903df9118d9f6deda0b87fb015d Mon Sep 17 00:00:00 2001
From: Dobromir Popov
Date: Mon, 17 Mar 2025 02:17:43 +0200
Subject: [PATCH] fix model loading in live mode

torch.load() defaults to weights_only=True as of PyTorch 2.6, which
rejects checkpoints that pickle non-allowlisted globals such as
numpy.core.multiarray.scalar. Try the safe path first, retry with
safe_globals for the numpy scalar type, and only fall back to
weights_only=False as a last resort.

Also buffer (state, action, reward, next_state, done) transitions in
the live trading loop so the agent keeps learning online, and save
the updated model every 100 steps.
---
 crypto/gogo2/main.py | 57 ++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 55 insertions(+), 2 deletions(-)

diff --git a/crypto/gogo2/main.py b/crypto/gogo2/main.py
index cd2e3db..1a3f9fc 100644
--- a/crypto/gogo2/main.py
+++ b/crypto/gogo2/main.py
@@ -1695,7 +1695,22 @@ class Agent:
 
     def load(self, path="models/trading_agent.pt"):
         if os.path.isfile(path):
-            checkpoint = torch.load(path)
+            try:
+                # First try with weights_only=True (safer)
+                checkpoint = torch.load(path, weights_only=True)
+            except Exception as e:
+                logger.warning(f"Failed to load with weights_only=True: {e}")
+                try:
+                    # Try with safe_globals for numpy.scalar
+                    import numpy as np
+                    from torch.serialization import safe_globals
+                    with safe_globals([np.core.multiarray.scalar]):
+                        checkpoint = torch.load(path, weights_only=True)
+                except Exception as e2:
+                    logger.warning(f"Failed with safe_globals: {e2}")
+                    # Fall back to weights_only=False if needed
+                    checkpoint = torch.load(path, weights_only=False)
+
             self.policy_net.load_state_dict(checkpoint['policy_net'])
             self.target_net.load_state_dict(checkpoint['target_net'])
             self.optimizer.load_state_dict(checkpoint['optimizer'])
@@ -2154,6 +2169,13 @@ async def live_trading(agent, env, exchange, demo=True):
     # Main trading loop
     step_counter = 0
 
+    # For online learning
+    states = []
+    actions = []
+    rewards = []
+    next_states = []
+    dones = []
+
     while True:
         # Wait for the next candle (1 minute)
         await asyncio.sleep(5)  # Check every 5 seconds
@@ -2175,7 +2197,38 @@ async def live_trading(agent, env, exchange, demo=True):
         action = agent.select_action(state, training=False)
 
         # Take action
-        _, reward, _ = env.step(action)
+        next_state, reward, done = env.step(action)
+
+        # Store experience for online learning
+        states.append(state)
+        actions.append(action)
+        rewards.append(reward)
+        next_states.append(next_state)
+        dones.append(done)
+
+        # Online learning - update the model with new experiences
+        if len(states) >= 10:  # Batch size for online learning
+            # Store experiences in replay memory
+            for i in range(len(states)):
+                agent.memory.push(states[i], actions[i], rewards[i], next_states[i], dones[i])
+
+            # Learn from experiences if we have enough samples
+            if len(agent.memory) > 32:
+                loss = agent.learn()
+                if loss is not None:
+                    agent.writer.add_scalar('Live/Loss', loss, step_counter)
+
+            # Clear the temporary storage
+            states = []
+            actions = []
+            rewards = []
+            next_states = []
+            dones = []
+
+        # Save the updated model periodically
+        if step_counter % 100 == 0:
+            agent.save("models/trading_agent_live_updated.pt")
+            logger.info("Updated model saved during live trading")
+
         # Log trading activity
         action_names = ["HOLD", "BUY", "SELL", "CLOSE"]
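
Note: the fallback chain in Agent.load() can be exercised on its own.
Below is a minimal standalone sketch of the same loading strategy; the
helper name load_checkpoint is hypothetical and not part of this
patch. It assumes torch >= 2.5 (which added the
torch.serialization.safe_globals context manager) and uses the
np.core.multiarray.scalar spelling from the patch, which NumPy 2.x
renames to np._core.multiarray.scalar.

    import logging

    import numpy as np
    import torch
    from torch.serialization import safe_globals

    logger = logging.getLogger(__name__)

    def load_checkpoint(path):
        # Hypothetical helper mirroring Agent.load()'s fallback chain.
        # Preferred path: safe, weights-only unpickling (the default
        # behaviour since PyTorch 2.6).
        try:
            return torch.load(path, weights_only=True)
        except Exception as exc:
            logger.warning("weights_only=True failed: %s", exc)
        # Retry with the numpy scalar global allowlisted, covering
        # checkpoints that pickled numpy scalars alongside tensors.
        try:
            with safe_globals([np.core.multiarray.scalar]):
                return torch.load(path, weights_only=True)
        except Exception as exc:
            logger.warning("safe_globals retry failed: %s", exc)
        # Last resort: full unpickling; only safe for checkpoints
        # from a trusted source.
        return torch.load(path, weights_only=False)

Called as load_checkpoint("models/trading_agent.pt"), it returns the
raw checkpoint dict, so the state_dict loading stays with the caller.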
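
Note: the new live-loop bookkeeping is a staging buffer flushed into
replay memory every 10 steps. A condensed sketch of the same pattern
follows; ReplayMemory, learn(), and record() are illustrative
stand-ins for the agent's own memory and DQN update, not classes from
this patch.

    import random
    from collections import deque

    class ReplayMemory:
        """Minimal ring-buffer replay memory."""
        def __init__(self, capacity=10_000):
            self.buffer = deque(maxlen=capacity)

        def push(self, state, action, reward, next_state, done):
            self.buffer.append((state, action, reward, next_state, done))

        def sample(self, batch_size):
            return random.sample(self.buffer, batch_size)

        def __len__(self):
            return len(self.buffer)

    memory = ReplayMemory()
    staged = []  # transitions gathered since the last flush

    def learn(batch_size=32):
        # Stand-in for agent.learn(): sample a batch, return a loss.
        batch = memory.sample(batch_size)
        return 0.0

    def record(transition, flush_every=10, min_samples=32):
        # Mirror the loop: stage, flush every 10 steps, then learn.
        staged.append(transition)
        if len(staged) >= flush_every:
            for t in staged:
                memory.push(*t)
            staged.clear()
            if len(memory) > min_samples:
                loss = learn()  # logged to TensorBoard in the patch

Because the staging lists are always flushed in full, pushing each
transition straight into agent.memory and calling agent.learn() every
10 steps would yield the same replay contents; the buffering mainly
keeps the push, learn, and logging steps together in one block.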