added continious mode. fixed errors

This commit is contained in:
Dobromir Popov 2025-03-10 15:37:02 +02:00
parent cfddc996d7
commit e884f0c9e6
2 changed files with 33 additions and 36 deletions

View File

@ -6,7 +6,7 @@
"type": "python",
"request": "launch",
"program": "main.py",
"args": ["--mode", "train", "--episodes", "1000"],
"args": ["--mode", "train", "--episodes", "100"],
"console": "integratedTerminal",
"justMyCode": true
},
@ -36,6 +36,15 @@
"args": ["--mode", "live"],
"console": "integratedTerminal",
"justMyCode": true
},
{
"name": "Continuous Training",
"type": "python",
"request": "launch",
"program": "main.py",
"args": ["--mode", "continuous", "--refresh-data"],
"console": "integratedTerminal",
"justMyCode": true
}
]
}

View File

@ -1679,7 +1679,7 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
await asyncio.sleep(5)
break
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None):
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None, args=None):
"""Train the agent using historical and live data with GPU acceleration"""
logger.info(f"Starting training on device: {agent.device}")
@ -2404,39 +2404,26 @@ async def main():
exchange = None
try:
# Initialize exchange
exchange_id = 'mexc'
exchange_class = getattr(ccxt, exchange_id)
exchange = exchange_class({
'apiKey': MEXC_API_KEY,
'secret': MEXC_SECRET_KEY,
'enableRateLimit': True,
'options': {
'defaultType': 'future',
}
})
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
exchange = await initialize_exchange()
# Create environment with the correct parameters
env = TradingEnvironment(
initial_balance=INITIAL_BALANCE,
window_size=30,
demo=args.demo or args.mode != 'live'
)
# Fetch initial data
logger.info("Fetching initial data for ETH/USDT")
data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
# Initialize environment
env = TradingEnvironment(
data=data,
symbol="ETH/USDT",
timeframe="1m",
leverage=MAX_LEVERAGE,
initial_balance=INITIAL_BALANCE,
is_demo=args.demo or args.mode != 'live'
)
logger.info(f"Initialized environment with {len(data)} candles")
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 500)
# Initialize agent
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
if args.mode == 'train':
# Train the agent
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
logger.info(f"Starting training for {args.episodes} episodes...")
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
elif args.mode == 'continuous':
# Run in continuous mode - train indefinitely
@ -2449,17 +2436,13 @@ async def main():
logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
# Refresh data at the start of each batch
if exchange:
if exchange and args.refresh_data:
logger.info("Refreshing data for new training batch")
new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
if new_data:
# Replace environment data with fresh data
env.data = new_data
env.reset()
logger.info(f"Updated environment with {len(new_data)} fresh candles")
await env.fetch_new_data(exchange, "ETH/USDT", "1m", 500)
logger.info(f"Updated environment with fresh candles")
# Train for a batch of episodes
stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange)
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
# Save model after each batch
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
@ -2481,6 +2464,7 @@ async def main():
agent.load("models/trading_agent_best_pnl.pt")
# Evaluate the agent
logger.info("Evaluating agent...")
results = evaluate_agent(agent, env, num_episodes=10)
logger.info(f"Evaluation results: {results}")
@ -2490,7 +2474,7 @@ async def main():
# Run live trading
logger.info("Starting live trading...")
await live_trading(agent, env, exchange)
await live_trading(agent, env, exchange, demo=args.demo)
except Exception as e:
logger.error(f"Error: {e}")
@ -2499,6 +2483,10 @@ async def main():
# Close exchange connection
if exchange:
try:
# Some CCXT exchanges have close method, others don't
if hasattr(exchange, 'close'):
await exchange.close()
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
await exchange.client.close()
logger.info("Exchange connection closed")
except Exception as e: