working on GPU
This commit is contained in:
parent
643bc154a2
commit
2b1f00cbfc
1
crypto/gogo2/checkpoints/best_metrics.json
Normal file
1
crypto/gogo2/checkpoints/best_metrics.json
Normal file
@ -0,0 +1 @@
|
||||
{"best_reward": 202.7441047517104, "best_pnl": -10.072078721366783, "best_win_rate": 30.864197530864196, "last_episode": 10, "timestamp": "2025-03-10T12:45:27.247997"}
|
@ -124,16 +124,24 @@ class DQN(nn.Module):
|
||||
self.advantage_stream = nn.Linear(hidden_size // 2, action_size)
|
||||
|
||||
# Transformer encoder for more complex pattern recognition
|
||||
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=attention_heads, dropout=0.1)
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_size,
|
||||
nhead=attention_heads,
|
||||
dropout=0.1,
|
||||
batch_first=True # Add this parameter
|
||||
)
|
||||
self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size = x.size(0) if x.dim() > 1 else 1
|
||||
|
||||
# Ensure input has correct shape
|
||||
# Ensure input has correct shape and type
|
||||
if x.dim() == 1:
|
||||
x = x.unsqueeze(0) # Add batch dimension
|
||||
|
||||
|
||||
# Ensure float32 type
|
||||
x = x.float()
|
||||
|
||||
# Check if state size matches expected input size
|
||||
if x.size(1) != self.state_size:
|
||||
# Handle mismatched input by either truncating or padding
|
||||
@ -158,8 +166,8 @@ class DQN(nn.Module):
|
||||
|
||||
# Process through transformer for more complex patterns
|
||||
transformer_input = x.unsqueeze(1) if x.dim() == 2 else x
|
||||
transformer_out = self.transformer_encoder(transformer_input.transpose(0, 1))
|
||||
transformer_out = transformer_out.transpose(0, 1).mean(dim=1)
|
||||
transformer_out = self.transformer_encoder(transformer_input)
|
||||
transformer_out = transformer_out.mean(dim=1) # Average across sequence dimension
|
||||
|
||||
# Combine LSTM and transformer outputs
|
||||
x = lstm_out + transformer_out
|
||||
@ -1309,48 +1317,62 @@ class TradingEnvironment:
|
||||
return fee
|
||||
|
||||
# Ensure GPU usage if available
|
||||
def get_device():
|
||||
"""Get the device to use (GPU or CPU)"""
|
||||
if torch.cuda.is_available():
|
||||
def get_device(device_preference='gpu'):
|
||||
"""Get the device to use (GPU or CPU) based on preference and availability"""
|
||||
if device_preference.lower() == 'gpu' and torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
# Set default tensor type to float32 for CUDA
|
||||
torch.set_default_tensor_type(torch.FloatTensor)
|
||||
logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
|
||||
else:
|
||||
device = torch.device("cpu")
|
||||
logger.info("Using CPU")
|
||||
if device_preference.lower() == 'gpu':
|
||||
logger.info("GPU requested but not available, using CPU instead")
|
||||
else:
|
||||
logger.info("Using CPU as requested")
|
||||
return device
|
||||
|
||||
# Update Agent class to use GPU properly
|
||||
class Agent:
|
||||
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4,
|
||||
device=None):
|
||||
if device is None:
|
||||
self.device = get_device()
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
def __init__(self, state_size, action_size, hidden_size=256, lstm_layers=2, attention_heads=4, device=None):
|
||||
"""Initialize the agent with the policy and target networks"""
|
||||
self.state_size = state_size
|
||||
self.action_size = action_size
|
||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||
self.steps_done = 0
|
||||
|
||||
# Initialize policy and target networks
|
||||
# Set device (GPU or CPU)
|
||||
if device is None:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
else:
|
||||
self.device = device
|
||||
|
||||
# Initialize networks
|
||||
self.policy_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||
self.target_net = DQN(state_size, action_size, hidden_size, lstm_layers, attention_heads).to(self.device)
|
||||
ensure_float32(self.policy_net)
|
||||
ensure_float32(self.target_net)
|
||||
ensure_model_float32(self.policy_net)
|
||||
ensure_model_float32(self.target_net)
|
||||
self.target_net.load_state_dict(self.policy_net.state_dict())
|
||||
self.target_net.eval()
|
||||
|
||||
# Initialize optimizer with weight decay for regularization
|
||||
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
|
||||
|
||||
# Initialize gradient scaler for mixed precision training
|
||||
# Initialize experience replay memory
|
||||
self.memory = ReplayMemory(MEMORY_SIZE)
|
||||
|
||||
# Initialize steps counter
|
||||
self.steps_done = 0
|
||||
|
||||
# Initialize epsilon for exploration
|
||||
self.epsilon = EPSILON_START
|
||||
self.epsilon_start = EPSILON_START
|
||||
self.epsilon_end = EPSILON_END
|
||||
self.epsilon_decay = EPSILON_DECAY
|
||||
|
||||
# Initialize mixed precision scaler
|
||||
self.scaler = amp.GradScaler()
|
||||
|
||||
# TensorBoard writer
|
||||
self.writer = SummaryWriter()
|
||||
# Initialize TensorBoard writer
|
||||
self.writer = SummaryWriter(f'runs/trading_agent_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
||||
|
||||
# Create models directory if it doesn't exist
|
||||
os.makedirs("models", exist_ok=True)
|
||||
@ -2210,10 +2232,12 @@ async def main():
|
||||
parser.add_argument('--episodes', type=int, default=1000, help='Number of episodes to train')
|
||||
parser.add_argument('--demo', action='store_true', help='Run in demo mode (no real trades)')
|
||||
parser.add_argument('--refresh-data', action='store_true', help='Refresh data during training')
|
||||
parser.add_argument('--device', type=str, default='gpu', choices=['gpu', 'cpu'],
|
||||
help='Device to use for training (gpu or cpu)')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get device (GPU or CPU)
|
||||
device = get_device()
|
||||
# Get device based on argument and availability
|
||||
device = get_device(args.device)
|
||||
|
||||
exchange = None
|
||||
|
||||
@ -2270,10 +2294,172 @@ async def main():
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not properly close exchange connection: {e}")
|
||||
|
||||
def ensure_float32(model):
|
||||
def ensure_model_float32(model):
|
||||
"""Ensure all model parameters are float32"""
|
||||
for param in model.parameters():
|
||||
param.data = param.data.float() # Convert to float32
|
||||
return model
|
||||
|
||||
def ensure_float32(tensor_or_array):
|
||||
"""Ensure the input is a float32 tensor or numpy array"""
|
||||
if isinstance(tensor_or_array, torch.Tensor):
|
||||
return tensor_or_array.float()
|
||||
elif isinstance(tensor_or_array, np.ndarray):
|
||||
return tensor_or_array.astype(np.float32)
|
||||
else:
|
||||
return np.array(tensor_or_array, dtype=np.float32)
|
||||
|
||||
def visualize_training_results(env, agent, episode_num):
|
||||
"""Visualize the training results with OHLCV data, actions, and predictions"""
|
||||
try:
|
||||
# Create directory for visualizations if it doesn't exist
|
||||
os.makedirs("visualizations", exist_ok=True)
|
||||
|
||||
# Get the data for visualization
|
||||
if len(env.data) < 100:
|
||||
logger.warning("Not enough data for visualization")
|
||||
return
|
||||
|
||||
# Use the last 100 candles for visualization
|
||||
data = env.data[-100:]
|
||||
|
||||
# Create a pandas DataFrame for easier plotting
|
||||
df = pd.DataFrame([{
|
||||
'timestamp': candle['timestamp'],
|
||||
'open': candle['open'],
|
||||
'high': candle['high'],
|
||||
'low': candle['low'],
|
||||
'close': candle['close'],
|
||||
'volume': candle['volume']
|
||||
} for candle in data])
|
||||
|
||||
# Convert timestamp to datetime
|
||||
df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms')
|
||||
df.set_index('timestamp', inplace=True)
|
||||
|
||||
# Create the plot
|
||||
plt.figure(figsize=(16, 12))
|
||||
|
||||
# Plot OHLC data
|
||||
ax1 = plt.subplot(3, 1, 1)
|
||||
ax1.set_title(f'Training Visualization - Episode {episode_num}')
|
||||
|
||||
# Plot candlestick chart
|
||||
from mplfinance.original_flavor import candlestick_ohlc
|
||||
import matplotlib.dates as mdates
|
||||
|
||||
# Convert date to numerical format for candlestick
|
||||
df_ohlc = df.reset_index()
|
||||
df_ohlc['timestamp'] = df_ohlc['timestamp'].map(mdates.date2num)
|
||||
|
||||
candlestick_ohlc(ax1, df_ohlc[['timestamp', 'open', 'high', 'low', 'close']].values,
|
||||
width=0.6, colorup='green', colordown='red')
|
||||
|
||||
ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d %H:%M'))
|
||||
ax1.set_ylabel('Price (USD)')
|
||||
|
||||
# Plot buy/sell actions if available
|
||||
if hasattr(env, 'trades') and env.trades:
|
||||
# Filter trades that occurred in the visualization window
|
||||
recent_trades = [t for t in env.trades if t.get('timestamp', 0) >= df.index[0].timestamp() * 1000]
|
||||
|
||||
for trade in recent_trades:
|
||||
if trade['type'] == 'long':
|
||||
# Buy point
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
ax1.scatter(mdates.date2num(entry_time), trade['entry'],
|
||||
marker='^', color='green', s=100, label='Buy')
|
||||
|
||||
# Sell point if closed
|
||||
if 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
ax1.scatter(mdates.date2num(exit_time), trade['exit'],
|
||||
marker='v', color='blue', s=100, label='Sell Long')
|
||||
|
||||
# Draw line connecting entry and exit
|
||||
ax1.plot([mdates.date2num(entry_time), mdates.date2num(exit_time)],
|
||||
[trade['entry'], trade['exit']], 'g--', alpha=0.5)
|
||||
|
||||
elif trade['type'] == 'short':
|
||||
# Sell short point
|
||||
entry_time = pd.to_datetime(trade['entry_time'], unit='ms')
|
||||
ax1.scatter(mdates.date2num(entry_time), trade['entry'],
|
||||
marker='v', color='red', s=100, label='Sell Short')
|
||||
|
||||
# Buy to cover point if closed
|
||||
if 'exit' in trade and trade['exit'] > 0:
|
||||
exit_time = pd.to_datetime(trade['exit_time'], unit='ms')
|
||||
ax1.scatter(mdates.date2num(exit_time), trade['exit'],
|
||||
marker='^', color='orange', s=100, label='Buy to Cover')
|
||||
|
||||
# Draw line connecting entry and exit
|
||||
ax1.plot([mdates.date2num(entry_time), mdates.date2num(exit_time)],
|
||||
[trade['entry'], trade['exit']], 'r--', alpha=0.5)
|
||||
|
||||
# Plot predicted prices if available
|
||||
if hasattr(env, 'predicted_prices') and len(env.predicted_prices) > 0:
|
||||
ax2 = plt.subplot(3, 1, 2, sharex=ax1)
|
||||
ax2.set_title('Price Predictions vs Actual')
|
||||
|
||||
# Plot actual prices
|
||||
ax2.plot(df.index[-len(env.predicted_prices):], df['close'][-len(env.predicted_prices):],
|
||||
label='Actual Price', color='blue')
|
||||
|
||||
# Plot predicted prices
|
||||
# Align predictions with their corresponding timestamps
|
||||
prediction_dates = df.index[-len(env.predicted_prices)-1:-1]
|
||||
if len(prediction_dates) == len(env.predicted_prices):
|
||||
ax2.plot(prediction_dates, env.predicted_prices,
|
||||
label='Predicted Price', color='orange', linestyle='--')
|
||||
|
||||
# Calculate prediction error
|
||||
actual = df['close'][-len(env.predicted_prices)-1:-1].values
|
||||
predicted = env.predicted_prices
|
||||
mape = np.mean(np.abs((actual - predicted) / actual)) * 100
|
||||
ax2.set_ylabel('Price (USD)')
|
||||
ax2.set_title(f'Price Predictions vs Actual (MAPE: {mape:.2f}%)')
|
||||
ax2.legend()
|
||||
|
||||
# Plot technical indicators
|
||||
ax3 = plt.subplot(3, 1, 3, sharex=ax1)
|
||||
ax3.set_title('Technical Indicators')
|
||||
|
||||
# Plot RSI if available
|
||||
if 'rsi' in env.features and len(env.features['rsi']) > 0:
|
||||
rsi_values = env.features['rsi'][-len(df):]
|
||||
if len(rsi_values) == len(df):
|
||||
ax3.plot(df.index, rsi_values, label='RSI', color='purple')
|
||||
|
||||
# Add RSI overbought/oversold lines
|
||||
ax3.axhline(y=70, color='r', linestyle='-', alpha=0.3)
|
||||
ax3.axhline(y=30, color='g', linestyle='-', alpha=0.3)
|
||||
|
||||
# Plot MACD if available
|
||||
if 'macd' in env.features and 'macd_signal' in env.features:
|
||||
macd_values = env.features['macd'][-len(df):]
|
||||
signal_values = env.features['macd_signal'][-len(df):]
|
||||
|
||||
if len(macd_values) == len(df) and len(signal_values) == len(df):
|
||||
ax3.plot(df.index, macd_values, label='MACD', color='blue')
|
||||
ax3.plot(df.index, signal_values, label='Signal', color='red')
|
||||
|
||||
ax3.set_ylabel('Indicator Value')
|
||||
ax3.legend()
|
||||
|
||||
# Format x-axis
|
||||
plt.xticks(rotation=45)
|
||||
plt.tight_layout()
|
||||
|
||||
# Save the figure
|
||||
plt.savefig(f"visualizations/training_episode_{episode_num}.png")
|
||||
logger.info(f"Visualization saved for episode {episode_num}")
|
||||
|
||||
# Close the figure to free memory
|
||||
plt.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating visualization: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
|
Loading…
x
Reference in New Issue
Block a user