Fix model loading in live mode and add online learning during live trading
This commit is contained in:
parent
485c61cf8c
commit
991cf57274
@ -1695,7 +1695,22 @@ class Agent:
|
||||
|
||||
def load(self, path="models/trading_agent.pt"):
|
||||
if os.path.isfile(path):
|
||||
checkpoint = torch.load(path)
|
||||
try:
|
||||
# First try with weights_only=True (safer)
|
||||
checkpoint = torch.load(path, weights_only=True)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load with weights_only=True: {e}")
|
||||
try:
|
||||
# Retry with numpy's multiarray scalar type allow-listed via safe_globals
|
||||
import numpy as np
|
||||
from torch.serialization import safe_globals
|
||||
with safe_globals([np.core.multiarray.scalar]):
|
||||
checkpoint = torch.load(path, weights_only=True)
|
||||
except Exception as e2:
|
||||
logger.warning(f"Failed with safe_globals: {e2}")
|
||||
# Fall back to weights_only=False if needed
|
||||
checkpoint = torch.load(path, weights_only=False)
|
||||
|
||||
self.policy_net.load_state_dict(checkpoint['policy_net'])
|
||||
self.target_net.load_state_dict(checkpoint['target_net'])
|
||||
self.optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
@ -2154,6 +2169,13 @@ async def live_trading(agent, env, exchange, demo=True):
|
||||
# Main trading loop
|
||||
step_counter = 0
|
||||
|
||||
# For online learning
|
||||
states = []
|
||||
actions = []
|
||||
rewards = []
|
||||
next_states = []
|
||||
dones = []
|
||||
|
||||
while True:
|
||||
# Poll for the next 1-minute candle
|
||||
await asyncio.sleep(5) # Check every 5 seconds
|
||||
@ -2175,7 +2197,38 @@ async def live_trading(agent, env, exchange, demo=True):
|
||||
action = agent.select_action(state, training=False)
|
||||
|
||||
# Take action
|
||||
_, reward, _ = env.step(action)
|
||||
next_state, reward, done = env.step(action)
|
||||
|
||||
# Store experience for online learning
|
||||
states.append(state)
|
||||
actions.append(action)
|
||||
rewards.append(reward)
|
||||
next_states.append(next_state)
|
||||
dones.append(done)
|
||||
|
||||
# Online learning - update the model with new experiences
|
||||
if len(states) >= 10: # Batch size for online learning
|
||||
# Store experiences in replay memory
|
||||
for i in range(len(states)):
|
||||
agent.memory.push(states[i], actions[i], rewards[i], next_states[i], dones[i])
|
||||
|
||||
# Learn from experiences if we have enough samples
|
||||
if len(agent.memory) > 32:
|
||||
loss = agent.learn()
|
||||
if loss is not None:
|
||||
agent.writer.add_scalar('Live/Loss', loss, step_counter)
|
||||
|
||||
# Clear the temporary storage
|
||||
states = []
|
||||
actions = []
|
||||
rewards = []
|
||||
next_states = []
|
||||
dones = []
|
||||
|
||||
# Save the updated model periodically
|
||||
if step_counter % 100 == 0:
|
||||
agent.save("models/trading_agent_live_updated.pt")
|
||||
logger.info("Updated model saved during live trading")
|
||||
|
||||
# Log trading activity
|
||||
action_names = ["HOLD", "BUY", "SELL", "CLOSE"]
|
||||
|
Loading…
x
Reference in New Issue
Block a user