added continious mode. fixed errors
This commit is contained in:
parent
cfddc996d7
commit
e884f0c9e6
11
crypto/gogo2/.vscode/launch.json
vendored
11
crypto/gogo2/.vscode/launch.json
vendored
@ -6,7 +6,7 @@
|
|||||||
"type": "python",
|
"type": "python",
|
||||||
"request": "launch",
|
"request": "launch",
|
||||||
"program": "main.py",
|
"program": "main.py",
|
||||||
"args": ["--mode", "train", "--episodes", "1000"],
|
"args": ["--mode", "train", "--episodes", "100"],
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true
|
"justMyCode": true
|
||||||
},
|
},
|
||||||
@ -36,6 +36,15 @@
|
|||||||
"args": ["--mode", "live"],
|
"args": ["--mode", "live"],
|
||||||
"console": "integratedTerminal",
|
"console": "integratedTerminal",
|
||||||
"justMyCode": true
|
"justMyCode": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Continuous Training",
|
||||||
|
"type": "python",
|
||||||
|
"request": "launch",
|
||||||
|
"program": "main.py",
|
||||||
|
"args": ["--mode", "continuous", "--refresh-data"],
|
||||||
|
"console": "integratedTerminal",
|
||||||
|
"justMyCode": true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
@ -1679,7 +1679,7 @@ async def get_live_prices(symbol="ETH/USDT", timeframe="1m"):
|
|||||||
await asyncio.sleep(5)
|
await asyncio.sleep(5)
|
||||||
break
|
break
|
||||||
|
|
||||||
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None):
|
async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000, exchange=None, args=None):
|
||||||
"""Train the agent using historical and live data with GPU acceleration"""
|
"""Train the agent using historical and live data with GPU acceleration"""
|
||||||
logger.info(f"Starting training on device: {agent.device}")
|
logger.info(f"Starting training on device: {agent.device}")
|
||||||
|
|
||||||
@ -2404,39 +2404,26 @@ async def main():
|
|||||||
exchange = None
|
exchange = None
|
||||||
try:
|
try:
|
||||||
# Initialize exchange
|
# Initialize exchange
|
||||||
exchange_id = 'mexc'
|
exchange = await initialize_exchange()
|
||||||
exchange_class = getattr(ccxt, exchange_id)
|
|
||||||
exchange = exchange_class({
|
# Create environment with the correct parameters
|
||||||
'apiKey': MEXC_API_KEY,
|
env = TradingEnvironment(
|
||||||
'secret': MEXC_SECRET_KEY,
|
initial_balance=INITIAL_BALANCE,
|
||||||
'enableRateLimit': True,
|
window_size=30,
|
||||||
'options': {
|
demo=args.demo or args.mode != 'live'
|
||||||
'defaultType': 'future',
|
)
|
||||||
}
|
|
||||||
})
|
|
||||||
logger.info(f"Exchange initialized with standard CCXT: {exchange.id}")
|
|
||||||
|
|
||||||
# Fetch initial data
|
# Fetch initial data
|
||||||
logger.info("Fetching initial data for ETH/USDT")
|
logger.info("Fetching initial data for ETH/USDT")
|
||||||
data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
|
await env.fetch_initial_data(exchange, "ETH/USDT", "1m", 500)
|
||||||
|
|
||||||
# Initialize environment
|
|
||||||
env = TradingEnvironment(
|
|
||||||
data=data,
|
|
||||||
symbol="ETH/USDT",
|
|
||||||
timeframe="1m",
|
|
||||||
leverage=MAX_LEVERAGE,
|
|
||||||
initial_balance=INITIAL_BALANCE,
|
|
||||||
is_demo=args.demo or args.mode != 'live'
|
|
||||||
)
|
|
||||||
logger.info(f"Initialized environment with {len(data)} candles")
|
|
||||||
|
|
||||||
# Initialize agent
|
# Initialize agent
|
||||||
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
agent = Agent(STATE_SIZE, 4, hidden_size=384, lstm_layers=2, attention_heads=4, device=device)
|
||||||
|
|
||||||
if args.mode == 'train':
|
if args.mode == 'train':
|
||||||
# Train the agent
|
# Train the agent
|
||||||
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange)
|
logger.info(f"Starting training for {args.episodes} episodes...")
|
||||||
|
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
|
||||||
|
|
||||||
elif args.mode == 'continuous':
|
elif args.mode == 'continuous':
|
||||||
# Run in continuous mode - train indefinitely
|
# Run in continuous mode - train indefinitely
|
||||||
@ -2449,17 +2436,13 @@ async def main():
|
|||||||
logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
|
logger.info(f"Starting training batch {episode_counter // batch_size + 1}")
|
||||||
|
|
||||||
# Refresh data at the start of each batch
|
# Refresh data at the start of each batch
|
||||||
if exchange:
|
if exchange and args.refresh_data:
|
||||||
logger.info("Refreshing data for new training batch")
|
logger.info("Refreshing data for new training batch")
|
||||||
new_data = await fetch_ohlcv_data(exchange, "ETH/USDT", "1m", 500)
|
await env.fetch_new_data(exchange, "ETH/USDT", "1m", 500)
|
||||||
if new_data:
|
logger.info(f"Updated environment with fresh candles")
|
||||||
# Replace environment data with fresh data
|
|
||||||
env.data = new_data
|
|
||||||
env.reset()
|
|
||||||
logger.info(f"Updated environment with {len(new_data)} fresh candles")
|
|
||||||
|
|
||||||
# Train for a batch of episodes
|
# Train for a batch of episodes
|
||||||
stats = await train_agent(agent, env, num_episodes=batch_size, exchange=exchange)
|
stats = await train_agent(agent, env, num_episodes=args.episodes, exchange=exchange, args=args)
|
||||||
|
|
||||||
# Save model after each batch
|
# Save model after each batch
|
||||||
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
|
agent.save(f"models/trading_agent_continuous_{episode_counter}.pt")
|
||||||
@ -2481,6 +2464,7 @@ async def main():
|
|||||||
agent.load("models/trading_agent_best_pnl.pt")
|
agent.load("models/trading_agent_best_pnl.pt")
|
||||||
|
|
||||||
# Evaluate the agent
|
# Evaluate the agent
|
||||||
|
logger.info("Evaluating agent...")
|
||||||
results = evaluate_agent(agent, env, num_episodes=10)
|
results = evaluate_agent(agent, env, num_episodes=10)
|
||||||
logger.info(f"Evaluation results: {results}")
|
logger.info(f"Evaluation results: {results}")
|
||||||
|
|
||||||
@ -2490,7 +2474,7 @@ async def main():
|
|||||||
|
|
||||||
# Run live trading
|
# Run live trading
|
||||||
logger.info("Starting live trading...")
|
logger.info("Starting live trading...")
|
||||||
await live_trading(agent, env, exchange)
|
await live_trading(agent, env, exchange, demo=args.demo)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error: {e}")
|
logger.error(f"Error: {e}")
|
||||||
@ -2499,7 +2483,11 @@ async def main():
|
|||||||
# Close exchange connection
|
# Close exchange connection
|
||||||
if exchange:
|
if exchange:
|
||||||
try:
|
try:
|
||||||
await exchange.client.close()
|
# Some CCXT exchanges have close method, others don't
|
||||||
|
if hasattr(exchange, 'close'):
|
||||||
|
await exchange.close()
|
||||||
|
elif hasattr(exchange, 'client') and hasattr(exchange.client, 'close'):
|
||||||
|
await exchange.client.close()
|
||||||
logger.info("Exchange connection closed")
|
logger.info("Exchange connection closed")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not properly close exchange connection: {e}")
|
logger.warning(f"Could not properly close exchange connection: {e}")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user