working on GPU
This commit is contained in:
parent
643bc154a2
commit
2b1f00cbfc
1
crypto/gogo2/checkpoints/best_metrics.json
Normal file
1
crypto/gogo2/checkpoints/best_metrics.json
Normal file
@ -0,0 +1 @@
|
|||||||
|
{"best_reward": 202.7441047517104, "best_pnl": -10.072078721366783, "best_win_rate": 30.864197530864196, "last_episode": 10, "timestamp": "2025-03-10T12:45:27.247997"}
|
@ -124,16 +124,24 @@ class DQN(nn.Module):
|
|||||||
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
|
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
|
||||||
|
|
||||||
# Transformer encoder for more complex pattern recognition
|
# Transformer encoder for more complex pattern recognition
|
||||||
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
|
encoder_layer = nn.TransformerEncoderLayer(
|
||||||
|
d_model=hidden_size,
|
||||||
|
nhead=attention_heads,
|
||||||
|
dropout=0.1,
|
||||||
|
batch_first=True # Add this parameter
|
||||||
|
)
|
||||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
batch_size = x.size(0) if x.dim() > 1 else 1
|
batch_size = x.size(0) if x.dim() > 1 else 1
|
||||||
|
|
||||||
# Ensure input has correct shape
|
# Ensure input has correct shape and type
|
||||||
if x.dim() == 1:
|
if x.dim() == 1:
|
||||||
x = x.unsqueeze(0) # Add batch dimension
|
x = x.unsqueeze(0) # Add batch dimension
|
||||||
|
|
||||||
|
# Ensure float32 type
|
||||||
|
x = x.float()
|
||||||
|
|
||||||
# Check if state size matches expected input size
|
# Check if state size matches expected input size
|
||||||
if x.size(1) != self.state_size:
|
if x.size(1) != self.state_size:
|
||||||
# Handle mismatched input by either truncating or padding
|
# Handle mismatched input by either truncating or padding
|
||||||
@ -158,8 +166,8 @@ class DQN(nn.Module):
|
|||||||
|
|
||||||
# Process through transformer for more complex patterns
|
# Process through transformer for more complex patterns
|
||||||
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
|
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
|
||||||
transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
|
transformer_out = self.transformer_encoder(transformer_input)
|
||||||
transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
|
transformer_out = transformer_out.mean(dim=1) # Average across sequence dimension
|
||||||
|
|
||||||
# Combine LSTM and transformer outputs
|
# Combine LSTM and transformer outputs
|
||||||
x = lstm_out + transformer_out
|
x = lstm_out + transformer_out
|
||||||
@ -1309,48 +1317,62 @@ class TradingEnvironment:
|
|||||||
return fee
|
return fee
|
||||||
|
|
||||||
# Ensure GPU usage if available
|
# Ensure GPU usage if available
|
||||||
def get_device():
|
def get_device(device_preference='gpu'):
|
||||||
"""Get the device to use (GPU or CPU)"""
|
"""Get the device to use (GPU or CPU) based on preference and availability"""
|
||||||
if torch.cuda.is_available():
|
if device_preference.lower() == 'gpu' and torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
# Set default tensor type to float32 for CUDA
|
# Set default tensor type to float32 for CUDA
|
||||||
torch.set_default_tensor_type(torch.FloatTensor)
|
torch.set_default_tensor_type(torch.FloatTensor)
|
||||||
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||||
else:
|
else:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
logger.info("Using CPU")
|
if device_preference.lower() == 'gpu':
|
||||||
|
logger.info("GPU requested but not available, using CPU instead")
|
||||||
|
else:
|
||||||
|
logger.info("Using CPU as requested")
|
||||||
return device
|
return device
|
||||||
|
|
||||||
# Update Agent class to use GPU properly
|
# Update Agent class to use GPU properly
|
||||||
class Agent:
|
class Agent:
|
||||||
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
|
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
|
||||||
device=None):
|
"""Initialize the agent with the policy and target networks"""
|
||||||
if device is None:
|
|
||||||
self.device = get_device()
|
|
||||||
else:
|
|
||||||
self.device = device
|
|
||||||
|
|
||||||
self.state_size = state_size
|
self.state_size = state_size
|
||||||
self.action_size = action_size
|
self.action_size = action_size
|
||||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
|
||||||
self.steps_done = 0
|
|
||||||
|
|
||||||
# Initialize policy and target networks
|
# Set device (GPU or CPU)
|
||||||
|
if device is None:
|
||||||
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
else:
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
# Initialize networks
|
||||||
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||||
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||||
ensure_float32(self.policy_net)
|
ensure_model_float32(self.policy_net)
|
||||||
ensure_float32(self.target_net)
|
ensure_model_float32(self.target_net)
|
||||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||||
self.target_net.eval()
|
self.target_net.eval()
|
||||||
|
|
||||||
# Initialize optimizer with weight decay for regularization
|
# Initialize optimizer with weight decay for regularization
|
||||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
|
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
|
||||||
|
|
||||||
# Initialize gradient scaler for mixed precision training
|
# Initialize experience replay memory
|
||||||
|
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||||
|
|
||||||
|
# Initialize steps counter
|
||||||
|
self.steps_done = 0
|
||||||
|
|
||||||
|
# Initialize epsilon for exploration
|
||||||
|
self.epsilon = EPSILON_START
|
||||||
|
self.epsilon_start = EPSILON_START
|
||||||
|
self.epsilon_end = EPSILON_END
|
||||||
|
self.epsilon_decay = EPSILON_DECAY
|
||||||
|
|
||||||
|
# Initialize mixed precision scaler
|
||||||
self.scaler = amp.GradScaler()
|
self.scaler = amp.GradScaler()
|
||||||
|
|
||||||
# TensorBoard writer
|
# Initialize TensorBoard writer
|
||||||
self.writer = SummaryWriter()
|
self.writer = SummaryWriter(f'runs/trading_agent_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
||||||
|
|
||||||
# Create models directory if it doesn't exist
|
# Create models directory if it doesn't exist
|
||||||
os.makedirs("models", exist_ok=True)
|
os.makedirs("models", exist_ok=True)
|
||||||
@ -2210,10 +2232,12 @@ async def main():
|
|||||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
||||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
||||||
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
|
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
|
||||||
|
parser.add_argument('--device', type=str, default='gpu', choices=['gpu', 'cpu'],
|
||||||
|
help='Device to use for training (gpu or cpu)')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# Get device (GPU or CPU)
|
# Get device based on argument and availability
|
||||||
device = get_device()
|
device = get_device(args.device)
|
||||||
|
|
||||||
exchange = None
|
exchange = None
|
||||||
|
|
||||||
@ -2270,10 +2294,172 @@ async def main():
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Could not properly close exchange connection: {e}")
|
logger.warning(f"Could not properly close exchange connection: {e}")
|
||||||
|
|
||||||
def ensure_float32(model):
|
def ensure_model_float32(model):
|
||||||
"""Ensure all model parameters are float32"""
|
"""Ensure all model parameters are float32"""
|
||||||
for param in model.parameters():
|
for param in model.parameters():
|
||||||
param.data = param.data.float() # Convert to float32
|
param.data = param.data.float() # Convert to float32
|
||||||
|
return model
|
||||||
|
|
||||||
|
def ensure_float32(tensor_or_array):
|
||||||
|
"""Ensure the input is a float32 tensor or numpy array"""
|
||||||
|
if isinstance(tensor_or_array, torch.Tensor):
|
||||||
|
return tensor_or_array.float()
|
||||||
|
elif isinstance(tensor_or_array, np.ndarray):
|
||||||
|
return tensor_or_array.astype(np.float32)
|
||||||
|
else:
|
||||||
|
return np.array(tensor_or_array, dtype=np.float32)
|
||||||
|
|
||||||
|
def visualize_training_results(env, agent, episode_num):
|
||||||
|
"""Visualize the training results with OHLCV data, actions, and predictions"""
|
||||||
|
try:
|
||||||
|
# Create directory for visualizations if it doesn't exist
|
||||||
|
os.makedirs("visualizations", exist_ok=True)
|
||||||
|
|
||||||
|
# Get the data for visualization
|
||||||
|
if len(env.data) < 100:
|
||||||
|
logger.warning("Not enough data for visualization")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use the last 100 candles for visualization
|
||||||
|
data = env.data[-100:]
|
||||||
|
|
||||||
|
# Create a pandas DataFrame for easier plotting
|
||||||
|
df = pd.DataFrame([{
|
||||||
|
'timestamp': candle['timestamp'],
|
||||||
|
'open': candle['open'],
|
||||||
|
'high': candle['high'],
|
||||||
|
'low': candle['low'],
|
||||||
|
'close': candle['close'],
|
||||||
|
'volume': candle['volume']
|
||||||
|
} for candle in data])
|
||||||
|
|
||||||
|
# Convert timestamp to datetime
|
||||||
|
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||||
|
df.set_index('timestamp', inplace=True)
|
||||||
|
|
||||||
|
# Create the plot
|
||||||
|
plt.figure(figsize=(16, 12))
|
||||||
|
|
||||||
|
# Plot OHLC data
|
||||||
|
ax1 = plt.subplot(3, 1, 1)
|
||||||
|
ax1.set_title(f'Training Visualization - Episode {episode_num}')
|
||||||
|
|
||||||
|
# Plot candlestick chart
|
||||||
|
from mplfinance.original_flavor import candlestick_ohlc
|
||||||
|
import matplotlib.dates as mdates
|
||||||
|
|
||||||
|
# Convert date to numerical format for candlestick
|
||||||
|
df_ohlc = df.reset_index()
|
||||||
|
df_ohlc['timestamp'] = df_ohlc['timestamp'].map(mdates.date2num)
|
||||||
|
|
||||||
|
candlestick_ohlc(ax1, df_ohlc[['timestamp', 'open', 'high', 'low', 'close']].values,
|
||||||
|
width=0.6, colorup='green', colordown='red')
|
||||||
|
|
||||||
|
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||||
|
ax1.set_ylabel('Price (USD)')
|
||||||
|
|
||||||
|
# Plot buy/sell actions if available
|
||||||
|
if hasattr(env, 'trades') and env.trades:
|
||||||
|
# Filter trades that occurred in the visualization window
|
||||||
|
recent_trades = [t for t in env.trades if t.get('timestamp', 0) >= df.index[0].timestamp() * 1000]
|
||||||
|
|
||||||
|
for trade in recent_trades:
|
||||||
|
if trade['type'] == 'long':
|
||||||
|
# Buy point
|
||||||
|
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||||
|
ax1.scatter(mdates.date2num(entry_time), trade['entry'],
|
||||||
|
marker='^', color='green', s=100, label='Buy')
|
||||||
|
|
||||||
|
# Sell point if closed
|
||||||
|
if 'exit' in trade and trade['exit'] > 0:
|
||||||
|
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||||
|
ax1.scatter(mdates.date2num(exit_time), trade['exit'],
|
||||||
|
marker='v', color='blue', s=100, label='Sell Long')
|
||||||
|
|
||||||
|
# Draw line connecting entry and exit
|
||||||
|
ax1.plot([mdates.date2num(entry_time), mdates.date2num(exit_time)],
|
||||||
|
[trade['entry'], trade['exit']], 'g--', alpha=0.5)
|
||||||
|
|
||||||
|
elif trade['type'] == 'short':
|
||||||
|
# Sell short point
|
||||||
|
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||||
|
ax1.scatter(mdates.date2num(entry_time), trade['entry'],
|
||||||
|
marker='v', color='red', s=100, label='Sell Short')
|
||||||
|
|
||||||
|
# Buy to cover point if closed
|
||||||
|
if 'exit' in trade and trade['exit'] > 0:
|
||||||
|
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||||
|
ax1.scatter(mdates.date2num(exit_time), trade['exit'],
|
||||||
|
marker='^', color='orange', s=100, label='Buy to Cover')
|
||||||
|
|
||||||
|
# Draw line connecting entry and exit
|
||||||
|
ax1.plot([mdates.date2num(entry_time), mdates.date2num(exit_time)],
|
||||||
|
[trade['entry'], trade['exit']], 'r--', alpha=0.5)
|
||||||
|
|
||||||
|
# Plot predicted prices if available
|
||||||
|
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
|
||||||
|
ax2 = plt.subplot(3, 1, 2, sharex=ax1)
|
||||||
|
ax2.set_title('Price Predictions vs Actual')
|
||||||
|
|
||||||
|
# Plot actual prices
|
||||||
|
ax2.plot(df.index[-len(env.predicted_prices):], df['close'][-len(env.predicted_prices):],
|
||||||
|
label='Actual Price', color='blue')
|
||||||
|
|
||||||
|
# Plot predicted prices
|
||||||
|
# Align predictions with their corresponding timestamps
|
||||||
|
prediction_dates = df.index[-len(env.predicted_prices)-1:-1]
|
||||||
|
if len(prediction_dates) == len(env.predicted_prices):
|
||||||
|
ax2.plot(prediction_dates, env.predicted_prices,
|
||||||
|
label='Predicted Price', color='orange', linestyle='--')
|
||||||
|
|
||||||
|
# Calculate prediction error
|
||||||
|
actual = df['close'][-len(env.predicted_prices)-1:-1].values
|
||||||
|
predicted = env.predicted_prices
|
||||||
|
mape = np.mean(np.abs((actual - predicted) / actual)) * 100
|
||||||
|
ax2.set_ylabel('Price (USD)')
|
||||||
|
ax2.set_title(f'Price Predictions vs Actual (MAPE: {mape:.2f}%)')
|
||||||
|
ax2.legend()
|
||||||
|
|
||||||
|
# Plot technical indicators
|
||||||
|
ax3 = plt.subplot(3, 1, 3, sharex=ax1)
|
||||||
|
ax3.set_title('Technical Indicators')
|
||||||
|
|
||||||
|
# Plot RSI if available
|
||||||
|
if 'rsi' in env.features and len(env.features['rsi']) > 0:
|
||||||
|
rsi_values = env.features['rsi'][-len(df):]
|
||||||
|
if len(rsi_values) == len(df):
|
||||||
|
ax3.plot(df.index, rsi_values, label='RSI', color='purple')
|
||||||
|
|
||||||
|
# Add RSI overbought/oversold lines
|
||||||
|
ax3.axhline(y=70, color='r', linestyle='-', alpha=0.3)
|
||||||
|
ax3.axhline(y=30, color='g', linestyle='-', alpha=0.3)
|
||||||
|
|
||||||
|
# Plot MACD if available
|
||||||
|
if 'macd' in env.features and 'macd_signal' in env.features:
|
||||||
|
macd_values = env.features['macd'][-len(df):]
|
||||||
|
signal_values = env.features['macd_signal'][-len(df):]
|
||||||
|
|
||||||
|
if len(macd_values) == len(df) and len(signal_values) == len(df):
|
||||||
|
ax3.plot(df.index, macd_values, label='MACD', color='blue')
|
||||||
|
ax3.plot(df.index, signal_values, label='Signal', color='red')
|
||||||
|
|
||||||
|
ax3.set_ylabel('Indicator Value')
|
||||||
|
ax3.legend()
|
||||||
|
|
||||||
|
# Format x-axis
|
||||||
|
plt.xticks(rotation=45)
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# Save the figure
|
||||||
|
plt.savefig(f"visualizations/training_episode_{episode_num}.png")
|
||||||
|
logger.info(f"Visualization saved for episode {episode_num}")
|
||||||
|
|
||||||
|
# Close the figure to free memory
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating visualization: {e}")
|
||||||
|
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
try:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user