increased model size
This commit is contained in:
parent
4be322622e
commit
783e411242
@ -63,48 +63,77 @@ class ReplayMemory:
|
|||||||
return len(self.memory)
|
return len(self.memory)
|
||||||
|
|
||||||
class DQN(nn.Module):
|
class DQN(nn.Module):
|
||||||
def __init__(self, state_size, action_size):
|
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4):
|
||||||
super(DQN, self).__init__()
|
super(DQN, self).__init__()
|
||||||
|
|
||||||
# Larger architecture for more complex pattern recognition
|
self.state_size = state_size
|
||||||
self.fc1 = nn.Linear(state_size, 256)
|
self.hidden_size = hidden_size
|
||||||
self.bn1 = nn.BatchNorm1d(256)
|
self.lstm_layers = lstm_layers
|
||||||
|
|
||||||
# LSTM layer for sequential data
|
|
||||||
self.lstm = nn.LSTM(256, 256, num_layers=2, batch_first=True)
|
|
||||||
|
|
||||||
# Attention mechanism
|
|
||||||
self.attention = nn.MultiheadAttention(256, 4)
|
|
||||||
|
|
||||||
# Output layers
|
|
||||||
self.fc2 = nn.Linear(256, 128)
|
|
||||||
self.bn2 = nn.BatchNorm1d(128)
|
|
||||||
self.fc3 = nn.Linear(128, 64)
|
|
||||||
self.fc4 = nn.Linear(64, action_size)
|
|
||||||
|
|
||||||
# Dueling DQN architecture
|
|
||||||
self.value_stream = nn.Linear(64, 1)
|
|
||||||
self.advantage_stream = nn.Linear(64, action_size)
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
if x.dim() == 1:
|
|
||||||
x = x.unsqueeze(0) # Add batch dimension if needed
|
|
||||||
|
|
||||||
# Initial feature extraction
|
# Initial feature extraction
|
||||||
x = F.relu(self.bn1(self.fc1(x)))
|
self.fc1 = nn.Linear(state_size, hidden_size)
|
||||||
|
self.bn1 = nn.BatchNorm1d(hidden_size)
|
||||||
|
|
||||||
# Process sequential data through LSTM
|
# LSTM layer for sequential data
|
||||||
x = x.unsqueeze(0) if x.dim() == 2 else x # Add sequence dimension if needed
|
self.lstm = nn.LSTM(hidden_size, hidden_size, num_layers=lstm_layers, batch_first=True)
|
||||||
x, _ = self.lstm(x)
|
|
||||||
x = x.squeeze(0) if x.dim() == 3 else x # Remove sequence dimension if only one item
|
|
||||||
|
|
||||||
# Self-attention
|
# Attention mechanism
|
||||||
x_reshaped = x.unsqueeze(1) if x.dim() == 2 else x
|
self.attention = nn.MultiheadAttention(hidden_size, attention_heads)
|
||||||
attn_output, _ = self.attention(x_reshaped, x_reshaped, x_reshaped)
|
|
||||||
x = attn_output.squeeze(1) if x.dim() == 3 else attn_output
|
# Output layers with increased capacity
|
||||||
|
self.fc2 = nn.Linear(hidden_size, hidden_size)
|
||||||
|
self.bn2 = nn.BatchNorm1d(hidden_size)
|
||||||
|
self.fc3 = nn.Linear(hidden_size, hidden_size // 2)
|
||||||
|
|
||||||
|
# Dueling DQN architecture
|
||||||
|
self.value_stream = nn.Linear(hidden_size // 2, 1)
|
||||||
|
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
|
||||||
|
|
||||||
|
# Transformer encoder for more complex pattern recognition
|
||||||
|
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads)
|
||||||
|
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
batch_size = x.size(0) if x.dim() > 1 else 1
|
||||||
|
|
||||||
|
# Ensure input has correct shape
|
||||||
|
if x.dim() == 1:
|
||||||
|
x = x.unsqueeze(0) # Add batch dimension
|
||||||
|
|
||||||
|
# Check if state size matches expected input size
|
||||||
|
if x.size(1) != self.state_size:
|
||||||
|
# Handle mismatched input by either truncating or padding
|
||||||
|
if x.size(1) > self.state_size:
|
||||||
|
x = x[:, :self.state_size] # Truncate
|
||||||
|
print(f"Warning: Input truncated from {x.size(1)} to {self.state_size}")
|
||||||
|
else:
|
||||||
|
# Pad with zeros
|
||||||
|
padding = torch.zeros(batch_size, self.state_size - x.size(1), device=x.device)
|
||||||
|
x = torch.cat([x, padding], dim=1)
|
||||||
|
print(f"Warning: Input padded from {x.size(1) - padding.size(1)} to {self.state_size}")
|
||||||
|
|
||||||
|
# Initial feature extraction
|
||||||
|
x = self.fc1(x)
|
||||||
|
x = F.relu(self.bn1(x) if batch_size > 1 else self.bn1(x.unsqueeze(0)).squeeze(0))
|
||||||
|
|
||||||
|
# Reshape for LSTM
|
||||||
|
x_lstm = x.unsqueeze(1) if x.dim() == 2 else x
|
||||||
|
|
||||||
|
# Process through LSTM
|
||||||
|
lstm_out, _ = self.lstm(x_lstm)
|
||||||
|
lstm_out = lstm_out.squeeze(1) if lstm_out.size(1) == 1 else lstm_out[:, -1]
|
||||||
|
|
||||||
|
# Process through transformer for more complex patterns
|
||||||
|
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
|
||||||
|
transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
|
||||||
|
transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
|
||||||
|
|
||||||
|
# Combine LSTM and transformer outputs
|
||||||
|
x = lstm_out + transformer_out
|
||||||
|
|
||||||
# Final layers
|
# Final layers
|
||||||
x = F.relu(self.bn2(self.fc2(x)))
|
x = self.fc2(x)
|
||||||
|
x = F.relu(self.bn2(x) if batch_size > 1 else self.bn2(x.unsqueeze(0)).squeeze(0))
|
||||||
x = F.relu(self.fc3(x))
|
x = F.relu(self.fc3(x))
|
||||||
|
|
||||||
# Dueling architecture
|
# Dueling architecture
|
||||||
@ -165,11 +194,37 @@ class TradingEnvironment:
|
|||||||
"""Fetch historical data to initialize the environment"""
|
"""Fetch historical data to initialize the environment"""
|
||||||
logger.info(f"Fetching initial {self.window_size} candles for {self.symbol}...")
|
logger.info(f"Fetching initial {self.window_size} candles for {self.symbol}...")
|
||||||
try:
|
try:
|
||||||
ohlcv = await self.exchange.fetch_ohlcv(
|
# Try to use fetch_ohlcv directly
|
||||||
|
try:
|
||||||
|
# Check if exchange has async methods
|
||||||
|
if hasattr(self.exchange, 'has') and self.exchange.has.get('fetchOHLCVAsync', False):
|
||||||
|
ohlcv = await self.exchange.fetchOHLCVAsync(
|
||||||
self.symbol,
|
self.symbol,
|
||||||
timeframe=self.timeframe,
|
timeframe=self.timeframe,
|
||||||
limit=self.window_size
|
limit=self.window_size
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
# Use synchronous method in an executor
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
ohlcv = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self.exchange.fetch_ohlcv(
|
||||||
|
self.symbol,
|
||||||
|
timeframe=self.timeframe,
|
||||||
|
limit=self.window_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
# Fallback to synchronous method
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
ohlcv = await loop.run_in_executor(
|
||||||
|
None,
|
||||||
|
lambda: self.exchange.fetch_ohlcv(
|
||||||
|
self.symbol,
|
||||||
|
timeframe=self.timeframe,
|
||||||
|
limit=self.window_size
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
for candle in ohlcv:
|
for candle in ohlcv:
|
||||||
timestamp, open_price, high, low, close, volume = candle
|
timestamp, open_price, high, low, close, volume = candle
|
||||||
@ -298,17 +353,21 @@ class TradingEnvironment:
|
|||||||
return np.zeros(STATE_SIZE)
|
return np.zeros(STATE_SIZE)
|
||||||
|
|
||||||
# Create a normalized state vector with recent price action and indicators
|
# Create a normalized state vector with recent price action and indicators
|
||||||
|
state_components = []
|
||||||
|
|
||||||
# Price features (normalize recent prices by the latest price)
|
# Price features (normalize recent prices by the latest price)
|
||||||
latest_price = self.features['price'][-1]
|
latest_price = self.features['price'][-1]
|
||||||
price_features = self.features['price'][-10:] / latest_price - 1.0
|
price_features = self.features['price'][-10:] / latest_price - 1.0
|
||||||
|
state_components.append(price_features)
|
||||||
|
|
||||||
# Volume features (normalize by max volume)
|
# Volume features (normalize by max volume)
|
||||||
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
|
max_vol = max(self.features['volume'][-20:]) if len(self.features['volume']) >= 20 else 1
|
||||||
vol_features = self.features['volume'][-5:] / max_vol
|
vol_features = self.features['volume'][-5:] / max_vol
|
||||||
|
state_components.append(vol_features)
|
||||||
|
|
||||||
# Technical indicators
|
# Technical indicators
|
||||||
rsi = self.features['rsi'][-3:] / 100.0 # Scale to 0-1
|
rsi = self.features['rsi'][-3:] / 100.0 # Scale to 0-1
|
||||||
|
state_components.append(rsi)
|
||||||
|
|
||||||
# MACD (normalize)
|
# MACD (normalize)
|
||||||
macd_vals = self.features['macd'][-3:]
|
macd_vals = self.features['macd'][-3:]
|
||||||
@ -319,6 +378,8 @@ class TradingEnvironment:
|
|||||||
macd_signal_norm = macd_signal / macd_scale
|
macd_signal_norm = macd_signal / macd_scale
|
||||||
macd_hist_norm = macd_hist / macd_scale
|
macd_hist_norm = macd_hist / macd_scale
|
||||||
|
|
||||||
|
state_components.extend([macd_norm, macd_signal_norm, macd_hist_norm])
|
||||||
|
|
||||||
# Bollinger position (where is price relative to bands)
|
# Bollinger position (where is price relative to bands)
|
||||||
bb_upper = self.features['bollinger_upper'][-3:]
|
bb_upper = self.features['bollinger_upper'][-3:]
|
||||||
bb_lower = self.features['bollinger_lower'][-3:]
|
bb_lower = self.features['bollinger_lower'][-3:]
|
||||||
@ -327,6 +388,11 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
# Calculate position of price within Bollinger Bands (0 to 1)
|
# Calculate position of price within Bollinger Bands (0 to 1)
|
||||||
bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
|
bb_pos = [(p - l) / (u - l) if u != l else 0.5 for p, u, l in zip(price, bb_upper, bb_lower)]
|
||||||
|
state_components.append(bb_pos)
|
||||||
|
|
||||||
|
# Stochastic oscillator
|
||||||
|
state_components.append(self.features['stoch_k'][-3:] / 100.0)
|
||||||
|
state_components.append(self.features['stoch_d'][-3:] / 100.0)
|
||||||
|
|
||||||
# Position info
|
# Position info
|
||||||
position_info = np.zeros(5)
|
position_info = np.zeros(5)
|
||||||
@ -343,23 +409,23 @@ class TradingEnvironment:
|
|||||||
position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit %
|
position_info[3] = (self.entry_price - self.take_profit) / self.entry_price # Take profit %
|
||||||
position_info[4] = self.position_size / self.balance # Position size relative to balance
|
position_info[4] = self.position_size / self.balance # Position size relative to balance
|
||||||
|
|
||||||
|
state_components.append(position_info)
|
||||||
|
|
||||||
# Combine all features
|
# Combine all features
|
||||||
state = np.concatenate([
|
state = np.concatenate(state_components)
|
||||||
price_features, # 10 values
|
|
||||||
vol_features, # 5 values
|
|
||||||
rsi, # 3 values
|
|
||||||
macd_norm, # 3 values
|
|
||||||
macd_signal_norm, # 3 values
|
|
||||||
macd_hist_norm, # 3 values
|
|
||||||
bb_pos, # 3 values
|
|
||||||
self.features['stoch_k'][-3:] / 100.0, # 3 values
|
|
||||||
self.features['stoch_d'][-3:] / 100.0, # 3 values
|
|
||||||
position_info # 5 values
|
|
||||||
])
|
|
||||||
|
|
||||||
# Replace any NaN values
|
# Replace any NaN values
|
||||||
state = np.nan_to_num(state, nan=0.0)
|
state = np.nan_to_num(state, nan=0.0)
|
||||||
|
|
||||||
|
# Ensure state has exactly STATE_SIZE elements
|
||||||
|
if len(state) > STATE_SIZE:
|
||||||
|
# Truncate if too long
|
||||||
|
state = state[:STATE_SIZE]
|
||||||
|
elif len(state) < STATE_SIZE:
|
||||||
|
# Pad with zeros if too short
|
||||||
|
padding = np.zeros(STATE_SIZE - len(state))
|
||||||
|
state = np.concatenate([state, padding])
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
def step(self, action):
|
def step(self, action):
|
||||||
@ -568,18 +634,23 @@ class TradingEnvironment:
|
|||||||
return self.get_state()
|
return self.get_state()
|
||||||
|
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, state_size, action_size, device="cuda" if torch.cuda.is_available() else "cpu"):
|
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
|
||||||
|
device="cuda" if torch.cuda.is_available() else "cpu"):
|
||||||
self.state_size = state_size
|
self.state_size = state_size
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
self.device = device
|
self.device = device
|
||||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||||
|
|
||||||
# Q-Networks
|
# Q-Networks with configurable size
|
||||||
self.policy_net = DQN(state_size, action_size).to(device)
|
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(device)
|
||||||
self.target_net = DQN(state_size, action_size).to(device)
|
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(device)
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||||
self.target_net.eval()
|
self.target_net.eval()
|
||||||
|
|
||||||
|
# Print model size
|
||||||
|
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||||
|
logger.info(f"Model size: {total_params:,} parameters")
|
||||||
|
|
||||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
self.epsilon = EPSILON_START
|
self.epsilon = EPSILON_START
|
||||||
@ -588,6 +659,55 @@ class Agent:
|
|||||||
# TensorBoard logging
|
# TensorBoard logging
|
||||||
self.writer = SummaryWriter(log_dir='runs/trading_agent')
|
self.writer = SummaryWriter(log_dir='runs/trading_agent')
|
||||||
|
|
||||||
|
def expand_model(self, new_state_size, new_hidden_size=512, new_lstm_layers=3, new_attention_heads=8):
|
||||||
|
"""Expand the model to handle more features or increase capacity"""
|
||||||
|
logger.info(f"Expanding model: {self.state_size} → {new_state_size}, "
|
||||||
|
f"hidden: {self.policy_net.hidden_size} → {new_hidden_size}")
|
||||||
|
|
||||||
|
# Save old weights
|
||||||
|
old_state_dict = self.policy_net.state_dict()
|
||||||
|
|
||||||
|
# Create new larger networks
|
||||||
|
new_policy_net = DQN(new_state_size, self.action_size,
|
||||||
|
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
|
||||||
|
new_target_net = DQN(new_state_size, self.action_size,
|
||||||
|
new_hidden_size, new_lstm_layers, new_attention_heads).to(self.device)
|
||||||
|
|
||||||
|
# Transfer weights for common layers
|
||||||
|
new_state_dict = new_policy_net.state_dict()
|
||||||
|
for name, param in old_state_dict.items():
|
||||||
|
if name in new_state_dict:
|
||||||
|
# If shapes match, copy directly
|
||||||
|
if new_state_dict[name].shape == param.shape:
|
||||||
|
new_state_dict[name] = param
|
||||||
|
# For first layer, copy weights for the original input dimensions
|
||||||
|
elif name == "fc1.weight":
|
||||||
|
new_state_dict[name][:, :self.state_size] = param
|
||||||
|
# For other layers, initialize with a strategy that preserves scale
|
||||||
|
else:
|
||||||
|
logger.info(f"Layer {name} shapes don't match: {param.shape} vs {new_state_dict[name].shape}")
|
||||||
|
|
||||||
|
# Load transferred weights
|
||||||
|
new_policy_net.load_state_dict(new_state_dict)
|
||||||
|
new_target_net.load_state_dict(new_state_dict)
|
||||||
|
|
||||||
|
# Replace networks
|
||||||
|
self.policy_net = new_policy_net
|
||||||
|
self.target_net = new_target_net
|
||||||
|
self.target_net.eval()
|
||||||
|
|
||||||
|
# Update optimizer
|
||||||
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE)
|
||||||
|
|
||||||
|
# Update state size
|
||||||
|
self.state_size = new_state_size
|
||||||
|
|
||||||
|
# Print new model size
|
||||||
|
total_params = sum(p.numel() for p in self.policy_net.parameters())
|
||||||
|
logger.info(f"New model size: {total_params:,} parameters")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def select_action(self, state, training=True):
|
def select_action(self, state, training=True):
|
||||||
sample = random.random()
|
sample = random.random()
|
||||||
|
|
||||||
@ -725,7 +845,9 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
|
|
||||||
best_reward = -float('inf')
|
best_reward = -float('inf')
|
||||||
|
|
||||||
|
try:
|
||||||
for episode in range(num_episodes):
|
for episode in range(num_episodes):
|
||||||
|
try:
|
||||||
state = env.reset()
|
state = env.reset()
|
||||||
episode_reward = 0
|
episode_reward = 0
|
||||||
|
|
||||||
@ -743,9 +865,12 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
episode_reward += reward
|
episode_reward += reward
|
||||||
|
|
||||||
# Learn from experience
|
# Learn from experience
|
||||||
|
try:
|
||||||
loss = agent.learn()
|
loss = agent.learn()
|
||||||
if loss is not None:
|
if loss is not None:
|
||||||
agent.writer.add_scalar('Loss/train', loss, agent.steps_done)
|
agent.writer.add_scalar('Loss/train', loss, agent.steps_done)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Learning error in episode {episode}, step {step}: {e}")
|
||||||
|
|
||||||
if done:
|
if done:
|
||||||
break
|
break
|
||||||
@ -754,7 +879,7 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
if episode % TARGET_UPDATE == 0:
|
if episode % TARGET_UPDATE == 0:
|
||||||
agent.update_target_network()
|
agent.update_target_network()
|
||||||
|
|
||||||
# Calculate win rate
|
# Calculate statistics
|
||||||
if len(env.trades) > 0:
|
if len(env.trades) > 0:
|
||||||
wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
|
wins = sum(1 for trade in env.trades if trade.get('pnl_percent', 0) > 0)
|
||||||
win_rate = wins / len(env.trades) * 100
|
win_rate = wins / len(env.trades) * 100
|
||||||
@ -784,12 +909,20 @@ async def train_agent(agent, env, num_episodes=1000, max_steps_per_episode=1000)
|
|||||||
if episode % 10 == 0:
|
if episode % 10 == 0:
|
||||||
agent.save(f"models/trading_agent_episode_{episode}.pt")
|
agent.save(f"models/trading_agent_episode_{episode}.pt")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in episode {episode}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
# Save final model
|
# Save final model
|
||||||
agent.save("models/trading_agent_final.pt")
|
agent.save("models/trading_agent_final.pt")
|
||||||
|
|
||||||
# Plot training results
|
# Plot training results
|
||||||
plot_training_results(stats)
|
plot_training_results(stats)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Training failed: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
return stats
|
return stats
|
||||||
|
|
||||||
def plot_training_results(stats):
|
def plot_training_results(stats):
|
||||||
@ -865,20 +998,116 @@ def evaluate_agent(agent, env, num_episodes=10):
|
|||||||
|
|
||||||
return avg_reward, avg_profit, win_rate
|
return avg_reward, avg_profit, win_rate
|
||||||
|
|
||||||
|
async def test_training():
|
||||||
|
"""Test the training process with a small number of episodes"""
|
||||||
|
logger.info("Starting training tests...")
|
||||||
|
|
||||||
|
# Initialize exchange
|
||||||
|
exchange = ccxt.mexc({
|
||||||
|
'apiKey': MEXC_API_KEY,
|
||||||
|
'secret': MEXC_SECRET_KEY,
|
||||||
|
'enableRateLimit': True,
|
||||||
|
})
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create environment with small initial balance for testing
|
||||||
|
env = TradingEnvironment(
|
||||||
|
exchange=exchange,
|
||||||
|
symbol="ETH/USDT",
|
||||||
|
timeframe="1m",
|
||||||
|
leverage=MAX_LEVERAGE,
|
||||||
|
initial_balance=100, # Small balance for testing
|
||||||
|
is_demo=True # Always use demo mode for testing
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fetch initial data
|
||||||
|
await env.fetch_initial_data()
|
||||||
|
|
||||||
|
# Create agent
|
||||||
|
agent = Agent(state_size=STATE_SIZE, action_size=env.action_space)
|
||||||
|
|
||||||
|
# Run a few test episodes
|
||||||
|
test_episodes = 3
|
||||||
|
logger.info(f"Running {test_episodes} test episodes...")
|
||||||
|
|
||||||
|
for episode in range(test_episodes):
|
||||||
|
state = env.reset()
|
||||||
|
episode_reward = 0
|
||||||
|
done = False
|
||||||
|
step = 0
|
||||||
|
|
||||||
|
while not done and step < 100: # Limit steps for testing
|
||||||
|
# Select action
|
||||||
|
action = agent.select_action(state)
|
||||||
|
|
||||||
|
# Take action
|
||||||
|
next_state, reward, done = env.step(action)
|
||||||
|
|
||||||
|
# Store experience
|
||||||
|
agent.memory.push(state, action, reward, next_state, done)
|
||||||
|
|
||||||
|
# Learn
|
||||||
|
loss = agent.learn()
|
||||||
|
|
||||||
|
state = next_state
|
||||||
|
episode_reward += reward
|
||||||
|
step += 1
|
||||||
|
|
||||||
|
# Print progress
|
||||||
|
if step % 10 == 0:
|
||||||
|
logger.info(f"Episode {episode + 1}, Step {step}, Reward: {episode_reward:.2f}")
|
||||||
|
|
||||||
|
logger.info(f"Test episode {episode + 1} completed with reward: {episode_reward:.2f}")
|
||||||
|
|
||||||
|
# Test model saving
|
||||||
|
try:
|
||||||
|
agent.save("models/test_model.pt")
|
||||||
|
logger.info("Successfully saved model")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving model: {e}")
|
||||||
|
|
||||||
|
logger.info("Training tests completed successfully")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Training test failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await exchange.close()
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
# Parse command line arguments
|
# Parse command line arguments
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser(description='ETH/USD Trading Bot with RL')
|
parser = argparse.ArgumentParser(description='ETH/USD Trading Bot with RL')
|
||||||
parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval', 'live'],
|
parser.add_argument('--mode', type=str, default='train', choices=['train', 'eval', 'live', 'test'],
|
||||||
help='Operation mode: train, eval, or live')
|
help='Operation mode: train, eval, live, or test')
|
||||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes for training/evaluation')
|
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes for training/evaluation')
|
||||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trading)')
|
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trading)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Initialize exchange with async support
|
if args.mode == 'test':
|
||||||
exchange_id = 'mexc'
|
# Run training tests
|
||||||
exchange_class = getattr(ccxt.async_support, exchange_id)
|
success = await test_training()
|
||||||
exchange = exchange_class({
|
if success:
|
||||||
|
logger.info("All tests passed!")
|
||||||
|
else:
|
||||||
|
logger.error("Tests failed!")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Initialize exchange with async capabilities
|
||||||
|
try:
|
||||||
|
# Try the newer CCXT approach first
|
||||||
|
exchange = ccxt.mexc({
|
||||||
|
'apiKey': MEXC_API_KEY,
|
||||||
|
'secret': MEXC_SECRET_KEY,
|
||||||
|
'enableRateLimit': True,
|
||||||
|
'asyncio_loop': asyncio.get_event_loop()
|
||||||
|
})
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Could not initialize exchange with asyncio_loop: {e}")
|
||||||
|
# Fallback to standard exchange
|
||||||
|
exchange = ccxt.mexc({
|
||||||
'apiKey': MEXC_API_KEY,
|
'apiKey': MEXC_API_KEY,
|
||||||
'secret': MEXC_SECRET_KEY,
|
'secret': MEXC_SECRET_KEY,
|
||||||
'enableRateLimit': True,
|
'enableRateLimit': True,
|
||||||
|
@ -163,3 +163,20 @@ Added proper bash syntax highlighting for command examples
|
|||||||
The README.md now provides a complete guide for setting up and using the trading bot, with clear sections for installation, usage, configuration, and safety considerations.
|
The README.md now provides a complete guide for setting up and using the trading bot, with clear sections for installation, usage, configuration, and safety considerations.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# Edits/improvements
|
||||||
|
|
||||||
|
Fixes the shape mismatch by ensuring the state vector is exactly STATE_SIZE elements
|
||||||
|
Adds robust error handling in the model's forward pass to handle mismatched inputs
|
||||||
|
Adds a transformer encoder for more sophisticated pattern recognition
|
||||||
|
Provides an expand_model method to increase model capacity while preserving learned weights
|
||||||
|
Adds detailed logging about model size and shape mismatches
|
||||||
|
The model now has:
|
||||||
|
Configurable hidden layer sizes
|
||||||
|
Transformer layers for complex pattern recognition
|
||||||
|
LSTM layers for temporal patterns
|
||||||
|
Attention mechanisms for focusing on important features
|
||||||
|
Dueling architecture for better Q-value estimation
|
||||||
|
With hidden_size=256, this model has about 1-2 million parameters. By increasing hidden_size to 512 or 1024, you can easily scale to 5-20 million parameters. For even larger models (billions of parameters), you would need to implement a more distributed architecture with multiple GPUs, which would require significant changes to the training loop.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user