added continious mode. fixed errors
This commit is contained in:
parent
cfddc996d7
commit
e884f0c9e6
11
crypto/gogo2/.vscode/launch.json
vendored
11
crypto/gogo2/.vscode/launch.json
vendored
@ -6,7 +6,7 @@
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "train", "--episodes", "1000"],
|
||||
"args": ["--mode", "train", "--episodes", "100"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
@ -36,6 +36,15 @@
|
||||
"args": ["--mode", "live"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
},
|
||||
{
|
||||
"name": "Continuous Training",
|
||||
"type": "python",
|
||||
"request": "launch",
|
||||
"program": "main.py",
|
||||
"args": ["--mode", "continuous", "--refresh-data"],
|
||||
"console": "integratedTerminal",
|
||||
"justMyCode": true
|
||||
}
|
||||
]
|
||||
}
|
@ -1679,7 +1679,7 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
||||
await asyncio.sleep(5)
|
||||
break
|
||||
|
||||
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None):
|
||||
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None, args=None):
|
||||
"""Train the agent using historical and live data with GPU acceleration"""
|
||||
logger.info(f"Starting training on device: {agent.device}")
|
||||
|
||||
@ -2404,39 +2404,26 @@ async def main():
|
||||
exchange = None
|
||||
try:
|
||||
# Initialize exchange
|
||||
exchange_id = 'mexc'
|
||||
exchange_class = getattr(ccxt, exchange_id)
|
||||
exchange = exchange_class({
|
||||
'apiKey': MEXC_API_KEY,
|
||||
'secret': MEXC_SECRET_KEY,
|
||||
'enableRateLimit': True,
|
||||
'options': {
|
||||
'defaultType': 'future',
|
||||
}
|
||||
})
|
||||
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
|
||||
exchange = await initialize_exchange()
|
||||
|
||||
# Create environment with the correct parameters
|
||||
env = TradingEnvironment(
|
||||
initial_balance=INITIAL_BALANCE,
|
||||
window_size=30,
|
||||
demo=args.demo or args.mode != 'live'
|
||||
)
|
||||
|
||||
# Fetch initial data
|
||||
logger.info("Fetching initial data for ETH/USDT")
|
||||
data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
|
||||
|
||||
# Initialize environment
|
||||
env = TradingEnvironment(
|
||||
data=data,
|
||||
symbol="ETH/USDT",
|
||||
timeframe="1m",
|
||||
leverage=MAX_LEVERAGE,
|
||||
initial_balance=INITIAL_BALANCE,
|
||||
is_demo=args.demo or args.mode != 'live'
|
||||
)
|
||||
logger.info(f"Initialized environment with {len(data)} candles")
|
||||
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 500)
|
||||
|
||||
# Initialize agent
|
||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||
|
||||
if args.mode == 'train':
|
||||
# Train the agent
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
|
||||
logger.info(f"Starting training for {args.episodes} episodes...")
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
|
||||
|
||||
elif args.mode == 'continuous':
|
||||
# Run in continuous mode - train indefinitely
|
||||
@ -2449,17 +2436,13 @@ async def main():
|
||||
logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
|
||||
|
||||
# Refresh data at the start of each batch
|
||||
if exchange:
|
||||
if exchange and args.refresh_data:
|
||||
logger.info("Refreshing data for new training batch")
|
||||
new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
|
||||
if new_data:
|
||||
# Replace environment data with fresh data
|
||||
env.data = new_data
|
||||
env.reset()
|
||||
logger.info(f"Updated environment with {len(new_data)} fresh candles")
|
||||
await env.fetch_new_data(exchange, "ETH/USDT", "1m", 500)
|
||||
logger.info(f"Updated environment with fresh candles")
|
||||
|
||||
# Train for a batch of episodes
|
||||
stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange)
|
||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
|
||||
|
||||
# Save model after each batch
|
||||
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
|
||||
@ -2481,6 +2464,7 @@ async def main():
|
||||
agent.load("models/trading_agent_best_pnl.pt")
|
||||
|
||||
# Evaluate the agent
|
||||
logger.info("Evaluating agent...")
|
||||
results = evaluate_agent(agent, env, num_episodes=10)
|
||||
logger.info(f"Evaluation results: {results}")
|
||||
|
||||
@ -2490,7 +2474,7 @@ async def main():
|
||||
|
||||
# Run live trading
|
||||
logger.info("Starting live trading...")
|
||||
await live_trading(agent, env, exchange)
|
||||
await live_trading(agent, env, exchange, demo=args.demo)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error: {e}")
|
||||
@ -2499,6 +2483,10 @@ async def main():
|
||||
# Close exchange connection
|
||||
if exchange:
|
||||
try:
|
||||
# Some CCXT exchanges have close method, others don't
|
||||
if hasattr(exchange, 'close'):
|
||||
await exchange.close()
|
||||
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
|
||||
await exchange.client.close()
|
||||
logger.info("Exchange connection closed")
|
||||
except Exception as e:
|
||||
|
Loading…
x
Reference in New Issue
Block a user