working on GPU

This commit is contained in:
Dobromir Popov 2025-03-10 13:15:30 +02:00
parent 643bc154a2
commit 2b1f00cbfc
2 changed files with 214 additions and 27 deletions

View File

@ -0,0 +1 @@
{"best_reward": 202.7441047517104, "best_pnl": -10.072078721366783, "best_win_rate": 30.864197530864196, "last_episode": 10, "timestamp": "2025-03-10T12:45:27.247997"}

View File

@ -124,16 +124,24 @@ class DQN(nn.Module):
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
# Transformer encoder for more complex pattern recognition
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
encoder_layer = nn.TransformerEncoderLayer(
d_model=hidden_size,
nhead=attention_heads,
dropout=0.1,
batch_first=True # Add this parameter
)
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
def forward(self, x):
batch_size = x.size(0) if x.dim() > 1 else 1
# Ensure input has correct shape
# Ensure input has correct shape and type
if x.dim() == 1:
x = x.unsqueeze(0) # Add batch dimension
# Ensure float32 type
x = x.float()
# Check if state size matches expected input size
if x.size(1) != self.state_size:
# Handle mismatched input by either truncating or padding
@ -158,8 +166,8 @@ class DQN(nn.Module):
# Process through transformer for more complex patterns
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
transformer_out = self.transformer_encoder(transformer_input)
transformer_out = transformer_out.mean(dim=1) # Average across sequence dimension
# Combine LSTM and transformer outputs
x = lstm_out + transformer_out
@ -1309,48 +1317,62 @@ class TradingEnvironment:
return fee
# Ensure GPU usage if available
def get_device():
"""Get the device to use (GPU or CPU)"""
if torch.cuda.is_available():
def get_device(device_preference='gpu'):
"""Get the device to use (GPU or CPU) based on preference and availability"""
if device_preference.lower() == 'gpu' and torch.cuda.is_available():
device = torch.device("cuda")
# Set default tensor type to float32 for CUDA
torch.set_default_tensor_type(torch.FloatTensor)
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
device = torch.device("cpu")
logger.info("Using CPU")
if device_preference.lower() == 'gpu':
logger.info("GPU requested but not available, using CPU instead")
else:
logger.info("Using CPU as requested")
return device
# Update Agent class to use GPU properly
class Agent:
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
device=None):
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
"""Initialize the agent with the policy and target networks"""
self.state_size = state_size
self.action_size = action_size
# Set device (GPU or CPU)
if device is None:
self.device = get_device()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device
self.state_size = state_size
self.action_size = action_size
self.memory = ReplayMemory(MEMORY_SIZE)
self.steps_done = 0
# Initialize policy and target networks
# Initialize networks
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
ensure_float32(self.policy_net)
ensure_float32(self.target_net)
ensure_model_float32(self.policy_net)
ensure_model_float32(self.target_net)
self.target_net.load_state_dict(self.policy_net.state_dict())
self.target_net.eval()
# Initialize optimizer with weight decay for regularization
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
# Initialize gradient scaler for mixed precision training
# Initialize experience replay memory
self.memory = ReplayMemory(MEMORY_SIZE)
# Initialize steps counter
self.steps_done = 0
# Initialize epsilon for exploration
self.epsilon = EPSILON_START
self.epsilon_start = EPSILON_START
self.epsilon_end = EPSILON_END
self.epsilon_decay = EPSILON_DECAY
# Initialize mixed precision scaler
self.scaler = amp.GradScaler()
# TensorBoard writer
self.writer = SummaryWriter()
# Initialize TensorBoard writer
self.writer = SummaryWriter(f'runs/trading_agent_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}')
# Create models directory if it doesn't exist
os.makedirs("models", exist_ok=True)
@ -2210,10 +2232,12 @@ async def main():
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
parser.add_argument('--device', type=str, default='gpu', choices=['gpu', 'cpu'],
help='Device to use for training (gpu or cpu)')
args = parser.parse_args()
# Get device (GPU or CPU)
device = get_device()
# Get device based on argument and availability
device = get_device(args.device)
exchange = None
@ -2270,10 +2294,172 @@ async def main():
except Exception as e:
logger.warning(f"Could not properly close exchange connection: {e}")
def ensure_float32(model):
def ensure_model_float32(model):
"""Ensure all model parameters are float32"""
for param in model.parameters():
param.data = param.data.float() # Convert to float32
return model
def ensure_float32(tensor_or_array):
"""Ensure the input is a float32 tensor or numpy array"""
if isinstance(tensor_or_array, torch.Tensor):
return tensor_or_array.float()
elif isinstance(tensor_or_array, np.ndarray):
return tensor_or_array.astype(np.float32)
else:
return np.array(tensor_or_array, dtype=np.float32)
def visualize_training_results(env, agent, episode_num):
"""Visualize the training results with OHLCV data, actions, and predictions"""
try:
# Create directory for visualizations if it doesn't exist
os.makedirs("visualizations", exist_ok=True)
# Get the data for visualization
if len(env.data) < 100:
logger.warning("Not enough data for visualization")
return
# Use the last 100 candles for visualization
data = env.data[-100:]
# Create a pandas DataFrame for easier plotting
df = pd.DataFrame([{
'timestamp': candle['timestamp'],
'open': candle['open'],
'high': candle['high'],
'low': candle['low'],
'close': candle['close'],
'volume': candle['volume']
} for candle in data])
# Convert timestamp to datetime
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
df.set_index('timestamp', inplace=True)
# Create the plot
plt.figure(figsize=(16, 12))
# Plot OHLC data
ax1 = plt.subplot(3, 1, 1)
ax1.set_title(f'Training Visualization - Episode {episode_num}')
# Plot candlestick chart
from mplfinance.original_flavor import candlestick_ohlc
import matplotlib.dates as mdates
# Convert date to numerical format for candlestick
df_ohlc = df.reset_index()
df_ohlc['timestamp'] = df_ohlc['timestamp'].map(mdates.date2num)
candlestick_ohlc(ax1, df_ohlc[['timestamp', 'open', 'high', 'low', 'close']].values,
width=0.6, colorup='green', colordown='red')
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
ax1.set_ylabel('Price (USD)')
# Plot buy/sell actions if available
if hasattr(env, 'trades') and env.trades:
# Filter trades that occurred in the visualization window
recent_trades = [t for t in env.trades if t.get('timestamp', 0) >= df.index[0].timestamp() * 1000]
for trade in recent_trades:
if trade['type'] == 'long':
# Buy point
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
ax1.scatter(mdates.date2num(entry_time), trade['entry'],
marker='^', color='green', s=100, label='Buy')
# Sell point if closed
if 'exit' in trade and trade['exit'] > 0:
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
ax1.scatter(mdates.date2num(exit_time), trade['exit'],
marker='v', color='blue', s=100, label='Sell Long')
# Draw line connecting entry and exit
ax1.plot([mdates.date2num(entry_time), mdates.date2num(exit_time)],
[trade['entry'], trade['exit']], 'g--', alpha=0.5)
elif trade['type'] == 'short':
# Sell short point
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
ax1.scatter(mdates.date2num(entry_time), trade['entry'],
marker='v', color='red', s=100, label='Sell Short')
# Buy to cover point if closed
if 'exit' in trade and trade['exit'] > 0:
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
ax1.scatter(mdates.date2num(exit_time), trade['exit'],
marker='^', color='orange', s=100, label='Buy to Cover')
# Draw line connecting entry and exit
ax1.plot([mdates.date2num(entry_time), mdates.date2num(exit_time)],
[trade['entry'], trade['exit']], 'r--', alpha=0.5)
# Plot predicted prices if available
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
ax2 = plt.subplot(3, 1, 2, sharex=ax1)
ax2.set_title('Price Predictions vs Actual')
# Plot actual prices
ax2.plot(df.index[-len(env.predicted_prices):], df['close'][-len(env.predicted_prices):],
label='Actual Price', color='blue')
# Plot predicted prices
# Align predictions with their corresponding timestamps
prediction_dates = df.index[-len(env.predicted_prices)-1:-1]
if len(prediction_dates) == len(env.predicted_prices):
ax2.plot(prediction_dates, env.predicted_prices,
label='Predicted Price', color='orange', linestyle='--')
# Calculate prediction error
actual = df['close'][-len(env.predicted_prices)-1:-1].values
predicted = env.predicted_prices
mape = np.mean(np.abs((actual - predicted) / actual)) * 100
ax2.set_ylabel('Price (USD)')
ax2.set_title(f'Price Predictions vs Actual (MAPE: {mape:.2f}%)')
ax2.legend()
# Plot technical indicators
ax3 = plt.subplot(3, 1, 3, sharex=ax1)
ax3.set_title('Technical Indicators')
# Plot RSI if available
if 'rsi' in env.features and len(env.features['rsi']) > 0:
rsi_values = env.features['rsi'][-len(df):]
if len(rsi_values) == len(df):
ax3.plot(df.index, rsi_values, label='RSI', color='purple')
# Add RSI overbought/oversold lines
ax3.axhline(y=70, color='r', linestyle='-', alpha=0.3)
ax3.axhline(y=30, color='g', linestyle='-', alpha=0.3)
# Plot MACD if available
if 'macd' in env.features and 'macd_signal' in env.features:
macd_values = env.features['macd'][-len(df):]
signal_values = env.features['macd_signal'][-len(df):]
if len(macd_values) == len(df) and len(signal_values) == len(df):
ax3.plot(df.index, macd_values, label='MACD', color='blue')
ax3.plot(df.index, signal_values, label='Signal', color='red')
ax3.set_ylabel('Indicator Value')
ax3.legend()
# Format x-axis
plt.xticks(rotation=45)
plt.tight_layout()
# Save the figure
plt.savefig(f"visualizations/training_episode_{episode_num}.png")
logger.info(f"Visualization saved for episode {episode_num}")
# Close the figure to free memory
plt.close()
except Exception as e:
logger.error(f"Error creating visualization: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
if __name__ == "__main__":
try: