fix model loading in live mode
parent 485c61cf8c
commit 991cf57274
@@ -1695,7 +1695,22 @@ class Agent:
     def load(self, path="models/trading_agent.pt"):
         if os.path.isfile(path):
-            checkpoint = torch.load(path)
+            try:
+                # First try with weights_only=True (safer)
+                checkpoint = torch.load(path, weights_only=True)
+            except Exception as e:
+                logger.warning(f"Failed to load with weights_only=True: {e}")
+                try:
+                    # Try with safe_globals for numpy.scalar
+                    import numpy as np
+                    from torch.serialization import safe_globals
+                    with safe_globals([np.core.multiarray.scalar]):
+                        checkpoint = torch.load(path, weights_only=True)
+                except Exception as e2:
+                    logger.warning(f"Failed with safe_globals: {e2}")
+                    # Fall back to weights_only=False if needed
+                    checkpoint = torch.load(path, weights_only=False)
+
             self.policy_net.load_state_dict(checkpoint['policy_net'])
             self.target_net.load_state_dict(checkpoint['target_net'])
             self.optimizer.load_state_dict(checkpoint['optimizer'])
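Note on the hunk above: torch.serialization.safe_globals is the PyTorch (2.4+) context manager for allowlisting specific globals under weights_only=True, and the final weights_only=False fallback executes arbitrary pickled code, so it is only appropriate for checkpoint files you trust. On NumPy 2.x the np.core namespace is deprecated in favor of np._core, so the allowlisted path may need adjusting there. A minimal standalone sketch of the same fallback ladder; the helper name load_checkpoint and the logging setup are assumptions, not part of the commit:

import logging

import numpy as np
import torch
from torch.serialization import safe_globals

logger = logging.getLogger(__name__)

def load_checkpoint(path):
    """Try the safest torch.load settings first, widening only on failure."""
    try:
        # weights_only=True refuses to unpickle arbitrary Python objects
        return torch.load(path, weights_only=True)
    except Exception as exc:
        logger.warning("weights_only=True failed: %s", exc)
    try:
        # Allowlist the one numpy type this checkpoint is known to contain
        with safe_globals([np.core.multiarray.scalar]):
            return torch.load(path, weights_only=True)
    except Exception as exc:
        logger.warning("safe_globals fallback failed: %s", exc)
    # Last resort: runs arbitrary pickled code, so only for trusted files
    return torch.load(path, weights_only=False)

A call like checkpoint = load_checkpoint("models/trading_agent.pt") would then stand in for the inline three-step version above.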
@@ -2154,6 +2169,13 @@ async def live_trading(agent, env, exchange, demo=True):
     # Main trading loop
     step_counter = 0
 
+    # For online learning
+    states = []
+    actions = []
+    rewards = []
+    next_states = []
+    dones = []
+
     while True:
         # Wait for the next candle (1 minute)
         await asyncio.sleep(5)  # Check every 5 seconds
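The five parallel lists added here buffer one transition per loop iteration until the flush in the next hunk. A roughly equivalent single-buffer form, shown as a hypothetical refactor rather than anything in this commit, keeps the five fields from drifting out of sync:

# Hypothetical refactor: one list of transition tuples instead of five
# parallel lists, so appending and clearing touch a single object.
transitions = []

# Per loop iteration (dummy values stand in for env.step output):
state, action, reward, next_state, done = [0.0], 1, 0.5, [0.1], False
transitions.append((state, action, reward, next_state, done))

# At flush time, unpack exactly where the next hunk pushes the five lists:
for s, a, r, ns, d in transitions:
    print(s, a, r, ns, d)  # stand-in for agent.memory.push(s, a, r, ns, d)
transitions.clear()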
@@ -2175,7 +2197,38 @@ async def live_trading(agent, env, exchange, demo=True):
         action = agent.select_action(state, training=False)
 
         # Take action
-        _, reward, _ = env.step(action)
+        next_state, reward, done = env.step(action)
 
+        # Store experience for online learning
+        states.append(state)
+        actions.append(action)
+        rewards.append(reward)
+        next_states.append(next_state)
+        dones.append(done)
+
+        # Online learning - update the model with new experiences
+        if len(states) >= 10:  # Batch size for online learning
+            # Store experiences in replay memory
+            for i in range(len(states)):
+                agent.memory.push(states[i], actions[i], rewards[i], next_states[i], dones[i])
+
+            # Learn from experiences if we have enough samples
+            if len(agent.memory) > 32:
+                loss = agent.learn()
+                if loss is not None:
+                    agent.writer.add_scalar('Live/Loss', loss, step_counter)
+
+            # Clear the temporary storage
+            states = []
+            actions = []
+            rewards = []
+            next_states = []
+            dones = []
+
+        # Save the updated model periodically
+        if step_counter % 100 == 0:
+            agent.save("models/trading_agent_live_updated.pt")
+            logger.info("Updated model saved during live trading")
+
         # Log trading activity
         action_names = ["HOLD", "BUY", "SELL", "CLOSE"]
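The online-learning block assumes agent.memory exposes push() and len(), and that agent.learn() samples a minibatch from it and returns a loss (or None when it skips an update). A minimal replay memory matching that interface, as a sketch only; the class name and capacity are assumptions, since the commit does not show the real implementation:

import random
from collections import deque

class ReplayMemory:
    """Fixed-size FIFO buffer of (state, action, reward, next_state, done)."""

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)  # oldest transitions fall off

    def push(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        # Uniform random minibatch for a DQN-style update
        return random.sample(self.buffer, batch_size)

    def __len__(self):
        return len(self.buffer)

With a capacity-bounded buffer like this, the len(agent.memory) > 32 guard above simply waits until at least one minibatch worth of live transitions has accumulated.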