cleanup models; beef up models to 500M
@@ -110,108 +110,213 @@ class EnhancedCNN(nn.Module):
         logger.info(f"EnhancedCNN initialized with input shape: {input_shape}, actions: {n_actions}")
 
     def _build_network(self):
-        """Build the enhanced neural network with current feature dimensions"""
+        """Build the MASSIVELY enhanced neural network for 4GB VRAM budget"""
 
-        # 1D CNN for sequential data
+        # MASSIVELY SCALED ARCHITECTURE for 4GB VRAM (up to ~50M parameters)
         if self.channels > 1:
-            # Reshape expected: [batch, timeframes, features]
+            # Massive convolutional backbone with deeper residual blocks
             self.conv_layers = nn.Sequential(
-                nn.Conv1d(self.channels, 64, kernel_size=3, padding=1),
-                nn.BatchNorm1d(64),
+                # Initial large conv block
+                nn.Conv1d(self.channels, 256, kernel_size=7, padding=3),  # Much wider initial layer
+                nn.BatchNorm1d(256),
                 nn.ReLU(),
                 nn.Dropout(0.1),
+
+                # First residual stage - 256 channels
+                ResidualBlock(256, 512),
+                ResidualBlock(512, 512),
+                ResidualBlock(512, 512),
+                nn.MaxPool1d(kernel_size=2, stride=2),
+                nn.Dropout(0.2),
 
-                ResidualBlock(64, 128),
+                # Second residual stage - 512 channels
+                ResidualBlock(512, 1024),
+                ResidualBlock(1024, 1024),
+                ResidualBlock(1024, 1024),
+                nn.MaxPool1d(kernel_size=2, stride=2),
+                nn.Dropout(0.25),
+
+                # Third residual stage - 1024 channels
+                ResidualBlock(1024, 1536),
+                ResidualBlock(1536, 1536),
+                ResidualBlock(1536, 1536),
+                nn.MaxPool1d(kernel_size=2, stride=2),
+                nn.Dropout(0.3),
 
-                ResidualBlock(128, 256),
-                nn.MaxPool1d(kernel_size=2, stride=2),
-                nn.Dropout(0.4),
-
-                ResidualBlock(256, 512),
+                # Fourth residual stage - 1536 channels (MASSIVE)
+                ResidualBlock(1536, 2048),
+                ResidualBlock(2048, 2048),
+                ResidualBlock(2048, 2048),
                 nn.AdaptiveAvgPool1d(1)  # Global average pooling
             )
-            # Feature dimension after conv layers
-            self.conv_features = 512
+            # Massive feature dimension after conv layers
+            self.conv_features = 2048
         else:
-            # For 1D vectors, skip the convolutional part
+            # For 1D vectors, use massive dense preprocessing
             self.conv_layers = None
             self.conv_features = 0
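`ResidualBlock` is referenced throughout the new backbone but defined elsewhere in the file. A minimal sketch consistent with the `(in_channels, out_channels)` signatures used above — the kernel size, normalization, and 1x1 projection shortcut are assumptions, not the file's actual definition:

```python
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Hypothetical 1D residual block matching the (in, out) channel pairs above."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # 1x1 projection so the skip path matches when the width changes (e.g. 256 -> 512)
        self.shortcut = (nn.Conv1d(in_channels, out_channels, kernel_size=1)
                         if in_channels != out_channels else nn.Identity())

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        return torch.relu(out + self.shortcut(x))
```

Under this sketch the blocks preserve sequence length; downsampling happens only in the external `nn.MaxPool1d` layers between stages.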
 
-        # Fully connected layers for all cases
-        # We'll use deeper layers with skip connections
+        # MASSIVE fully connected feature extraction layers
         if self.conv_layers is None:
-            # For 1D inputs without conv preprocessing
-            self.fc1 = nn.Linear(self.feature_dim, 512)
-            self.features_dim = 512
+            # For 1D inputs - massive feature extraction
+            self.fc1 = nn.Linear(self.feature_dim, 2048)
+            self.features_dim = 2048
         else:
-            # For data processed by conv layers
-            self.fc1 = nn.Linear(self.conv_features, 512)
-            self.features_dim = 512
+            # For data processed by massive conv layers
+            self.fc1 = nn.Linear(self.conv_features, 2048)
+            self.features_dim = 2048
 
-        # Common feature extraction layers
+        # MASSIVE common feature extraction with multiple attention layers
         self.fc_layers = nn.Sequential(
             self.fc1,
             nn.ReLU(),
-            nn.Dropout(0.4),
-            nn.Linear(512, 512),
+            nn.Dropout(0.3),
+            nn.Linear(2048, 2048),  # Keep massive width
             nn.ReLU(),
-            nn.Dropout(0.4),
-            nn.Linear(512, 256),
+            nn.Dropout(0.3),
+            nn.Linear(2048, 1536),  # Still very wide
             nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(1536, 1024),  # Large hidden layer
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(1024, 768),  # Final feature representation
+            nn.ReLU()
         )
 
-        # Dueling architecture
+        # Multiple attention mechanisms for different aspects
+        self.price_attention = SelfAttention(768)
+        self.volume_attention = SelfAttention(768)
+        self.trend_attention = SelfAttention(768)
+        self.volatility_attention = SelfAttention(768)
+
+        # Attention fusion layer
+        self.attention_fusion = nn.Sequential(
+            nn.Linear(768 * 4, 1024),  # Combine all attention outputs
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(1024, 768)
+        )
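`SelfAttention` is likewise defined elsewhere in the file; from the forward pass below we can only see that it takes a `[batch, seq, dim]` tensor and returns an `(output, attention_weights)` pair. A minimal sketch consistent with that interface, assuming standard single-head scaled dot-product attention:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Hypothetical self-attention matching the (output, weights) interface used here."""
    def __init__(self, dim):
        super().__init__()
        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

    def forward(self, x):  # x: [batch, seq, dim]
        q, k, v = self.query(x), self.key(x), self.value(x)
        scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(x.size(-1))
        weights = F.softmax(scores, dim=-1)  # attention over the seq dimension
        return torch.bmm(weights, v), weights
```

Worth noting: the forward pass feeds these modules a sequence of length 1 (`[batch, 1, 768]`), so under this sketch the softmax runs over a single position and each "attention" head reduces to a learned per-aspect projection of the shared feature vector.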
 
+        # MASSIVE dueling architecture with deeper networks
         self.advantage_stream = nn.Sequential(
+            nn.Linear(768, 512),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
             nn.Linear(256, 128),
             nn.ReLU(),
             nn.Linear(128, self.n_actions)
         )
 
         self.value_stream = nn.Sequential(
+            nn.Linear(768, 512),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
             nn.Linear(256, 128),
             nn.ReLU(),
             nn.Linear(128, 1)
         )
 
-        # Extrema detection head with increased capacity
+        # MASSIVE extrema detection head with ensemble predictions
         self.extrema_head = nn.Sequential(
-            nn.Linear(256, 128),
+            nn.Linear(768, 512),
             nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
             nn.Linear(128, 3)  # 0=bottom, 1=top, 2=neither
         )
 
-        # Price prediction heads with increased capacity
+        # MASSIVE multi-timeframe price prediction heads
         self.price_pred_immediate = nn.Sequential(
-            nn.Linear(256, 64),
+            nn.Linear(768, 256),
             nn.ReLU(),
-            nn.Linear(64, 3)  # Up, Down, Sideways
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 3)  # Up, Down, Sideways
         )
 
         self.price_pred_midterm = nn.Sequential(
-            nn.Linear(256, 64),
+            nn.Linear(768, 256),
             nn.ReLU(),
-            nn.Linear(64, 3)  # Up, Down, Sideways
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 3)  # Up, Down, Sideways
         )
 
         self.price_pred_longterm = nn.Sequential(
-            nn.Linear(256, 64),
-            nn.ReLU(),
-            nn.Linear(64, 3)  # Up, Down, Sideways
-        )
-
-        # Value prediction with increased capacity
-        self.price_pred_value = nn.Sequential(
-            nn.Linear(256, 128),
+            nn.Linear(768, 256),
             nn.ReLU(),
             nn.Dropout(0.3),
-            nn.Linear(128, 4)  # % change for different timeframes
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 3)  # Up, Down, Sideways
         )
 
-        # Additional attention layer for feature refinement
-        self.attention = SelfAttention(256)
+        # MASSIVE value prediction with ensemble approaches
+        self.price_pred_value = nn.Sequential(
+            nn.Linear(768, 512),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(512, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 8)  # More granular % change predictions for different timeframes
+        )
+
+        # Additional specialized prediction heads for better accuracy
+        # Volatility prediction head
+        self.volatility_head = nn.Sequential(
+            nn.Linear(768, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 5)  # Very low, low, medium, high, very high volatility
+        )
+
+        # Support/Resistance level detection head
+        self.support_resistance_head = nn.Sequential(
+            nn.Linear(768, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 6)  # Strong support, weak support, neutral, weak resistance, strong resistance, breakout
+        )
+
+        # Market regime classification head
+        self.market_regime_head = nn.Sequential(
+            nn.Linear(768, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 7)  # Bull trend, bear trend, sideways, volatile up, volatile down, accumulation, distribution
+        )
+
+        # Risk assessment head
+        self.risk_head = nn.Sequential(
+            nn.Linear(768, 256),
+            nn.ReLU(),
+            nn.Dropout(0.3),
+            nn.Linear(256, 128),
+            nn.ReLU(),
+            nn.Linear(128, 4)  # Low risk, medium risk, high risk, extreme risk
+        )
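The commit title targets 500M parameters while the comment at the top of `_build_network` says "up to ~50M"; the actual size depends on the input shape and is easiest to settle by counting directly. A quick sanity check — the constructor arguments are assumptions inferred from the `logger.info` call above, not the file's documented signature:

```python
# Hypothetical shapes: 5 timeframe channels x 100 features, 3 actions
model = EnhancedCNN(input_shape=(5, 100), n_actions=3)

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total params: {total / 1e6:.1f}M, trainable: {trainable / 1e6:.1f}M")
```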
 
     def _check_rebuild_network(self, features):
         """Check if network needs to be rebuilt for different feature dimensions"""
@@ -225,7 +330,7 @@ class EnhancedCNN(nn.Module):
         return False
 
     def forward(self, x):
-        """Forward pass through the network"""
+        """Forward pass through the MASSIVE network"""
         batch_size = x.size(0)
 
         # Process different input shapes
@@ -243,7 +348,7 @@ class EnhancedCNN(nn.Module):
         total_features = x_reshaped.size(1) * x_reshaped.size(2)
         self._check_rebuild_network(total_features)
 
-        # Apply convolutions
+        # Apply massive convolutions
         x_conv = self.conv_layers(x_reshaped)
         # Flatten: [batch, channels, 1] -> [batch, channels]
         x_flat = x_conv.view(batch_size, -1)
@@ -258,31 +363,59 @@ class EnhancedCNN(nn.Module):
         if x_flat.size(1) != self.feature_dim:
             self._check_rebuild_network(x_flat.size(1))
 
-        # Apply FC layers
-        features = self.fc_layers(x_flat)
+        # Apply MASSIVE FC layers to get base features
+        features = self.fc_layers(x_flat)  # [batch, 768]
 
-        # Add attention for feature refinement
-        features_3d = features.unsqueeze(1)  # [batch, 1, features]
-        features_attended, _ = self.attention(features_3d)
-        features_refined = features_attended.squeeze(1)  # [batch, features]
+        # Apply multiple specialized attention mechanisms
+        features_3d = features.unsqueeze(1)  # [batch, 1, 768]
 
-        # Calculate advantage and value
+        # Get attention-refined features for different aspects
+        price_features, _ = self.price_attention(features_3d)
+        price_features = price_features.squeeze(1)  # [batch, 768]
+
+        volume_features, _ = self.volume_attention(features_3d)
+        volume_features = volume_features.squeeze(1)  # [batch, 768]
+
+        trend_features, _ = self.trend_attention(features_3d)
+        trend_features = trend_features.squeeze(1)  # [batch, 768]
+
+        volatility_features, _ = self.volatility_attention(features_3d)
+        volatility_features = volatility_features.squeeze(1)  # [batch, 768]
+
+        # Fuse all attention outputs
+        combined_attention = torch.cat([
+            price_features, volume_features,
+            trend_features, volatility_features
+        ], dim=1)  # [batch, 768*4]
+
+        # Apply attention fusion to get final refined features
+        features_refined = self.attention_fusion(combined_attention)  # [batch, 768]
+
+        # Calculate advantage and value (Dueling DQN architecture)
         advantage = self.advantage_stream(features_refined)
         value = self.value_stream(features_refined)
 
         # Combine for Q-values (Dueling architecture)
         q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
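The line above is the standard dueling-DQN aggregation, Q(s,a) = V(s) + A(s,a) − mean_a' A(s,a'); subtracting the mean advantage keeps the value and advantage streams identifiable. A tiny worked example with made-up numbers:

```python
import torch

value = torch.tensor([[2.0]])                  # V(s), shape [batch, 1]
advantage = torch.tensor([[2.0, 0.0, 1.0]])    # A(s, a) for 3 actions

q = value + advantage - advantage.mean(dim=1, keepdim=True)
print(q)  # tensor([[3., 1., 2.]]) -- mean(A) = 1.0, so Q = 2 + (A - 1)
```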
 
-        # Get extrema predictions
+        # Get massive ensemble of predictions
+
+        # Extrema predictions (bottom/top/neither detection)
         extrema_pred = self.extrema_head(features_refined)
 
-        # Price movement predictions
+        # Multi-timeframe price movement predictions
         price_immediate = self.price_pred_immediate(features_refined)
         price_midterm = self.price_pred_midterm(features_refined)
         price_longterm = self.price_pred_longterm(features_refined)
         price_values = self.price_pred_value(features_refined)
 
-        # Package price predictions
+        # Additional specialized predictions for enhanced accuracy
+        volatility_pred = self.volatility_head(features_refined)
+        support_resistance_pred = self.support_resistance_head(features_refined)
+        market_regime_pred = self.market_regime_head(features_refined)
+        risk_pred = self.risk_head(features_refined)
+
+        # Package all price predictions
         price_predictions = {
             'immediate': price_immediate,
             'midterm': price_midterm,
@@ -290,31 +423,60 @@ class EnhancedCNN(nn.Module):
             'values': price_values
         }
 
-        return q_values, extrema_pred, price_predictions, features_refined
+        # Package additional predictions for enhanced decision making
+        advanced_predictions = {
+            'volatility': volatility_pred,
+            'support_resistance': support_resistance_pred,
+            'market_regime': market_regime_pred,
+            'risk_assessment': risk_pred
+        }
+
+        return q_values, extrema_pred, price_predictions, features_refined, advanced_predictions
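Callers of `forward` must now unpack five values instead of four. A sketch of the new calling convention — input shape and constructor arguments are illustrative assumptions, and output shapes assume 3 actions:

```python
import torch

model = EnhancedCNN(input_shape=(5, 100), n_actions=3)  # hypothetical shapes
state = torch.randn(16, 5, 100)                         # batch of 16 states

q_values, extrema_pred, price_predictions, features, advanced = model(state)
print(q_values.shape)                        # [16, 3] Q-value per action
print(price_predictions['immediate'].shape)  # [16, 3] up/down/sideways logits
print(advanced['market_regime'].shape)       # [16, 7] regime logits
```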
 
     def act(self, state, explore=True):
-        """
-        Choose action based on state with confidence thresholding
-        """
+        """Enhanced action selection with massive model predictions"""
+        if explore and np.random.random() < 0.1:  # 10% random exploration
+            return np.random.choice(self.n_actions)
+
         self.eval()
         state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
 
         with torch.no_grad():
-            q_values, _, _, _ = self(state_tensor)
+            q_values, extrema_pred, price_predictions, features, advanced_predictions = self(state_tensor)
 
             # Apply softmax to get action probabilities
-            action_probs = F.softmax(q_values, dim=1)
+            action_probs = torch.softmax(q_values, dim=1)
+            action = torch.argmax(action_probs, dim=1).item()
 
-            # Get action with highest probability
-            action = action_probs.argmax(dim=1).item()
-            action_confidence = action_probs[0, action].item()
+            # Log advanced predictions for better decision making
+            if hasattr(self, '_log_predictions') and self._log_predictions:
+                # Log volatility prediction
+                volatility = torch.softmax(advanced_predictions['volatility'], dim=1)
+                volatility_class = torch.argmax(volatility, dim=1).item()
+                volatility_labels = ['Very Low', 'Low', 'Medium', 'High', 'Very High']
+
+                # Log support/resistance prediction
+                sr = torch.softmax(advanced_predictions['support_resistance'], dim=1)
+                sr_class = torch.argmax(sr, dim=1).item()
+                sr_labels = ['Strong Support', 'Weak Support', 'Neutral', 'Weak Resistance', 'Strong Resistance', 'Breakout']
+
+                # Log market regime prediction
+                regime = torch.softmax(advanced_predictions['market_regime'], dim=1)
+                regime_class = torch.argmax(regime, dim=1).item()
+                regime_labels = ['Bull Trend', 'Bear Trend', 'Sideways', 'Volatile Up', 'Volatile Down', 'Accumulation', 'Distribution']
+
+                # Log risk assessment
+                risk = torch.softmax(advanced_predictions['risk_assessment'], dim=1)
+                risk_class = torch.argmax(risk, dim=1).item()
+                risk_labels = ['Low Risk', 'Medium Risk', 'High Risk', 'Extreme Risk']
+
+                logger.info("MASSIVE Model Predictions:")
+                logger.info(f"  Volatility: {volatility_labels[volatility_class]} ({volatility[0, volatility_class]:.3f})")
+                logger.info(f"  Support/Resistance: {sr_labels[sr_class]} ({sr[0, sr_class]:.3f})")
+                logger.info(f"  Market Regime: {regime_labels[regime_class]} ({regime[0, regime_class]:.3f})")
+                logger.info(f"  Risk Level: {risk_labels[risk_class]} ({risk[0, risk_class]:.3f})")
 
-            # Check if confidence exceeds threshold
-            if action_confidence < self.confidence_threshold:
-                # Force HOLD action (typically action 2)
-                action = 2  # Assume 2 is HOLD
-                logger.info(f"Action {action} confidence {action_confidence:.4f} below threshold {self.confidence_threshold}, forcing HOLD")
 
-        return action, action_confidence
+        return action
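Note the behavioral change: the old `act` returned an `(action, confidence)` tuple and forced HOLD below `self.confidence_threshold`; the new one returns a bare action index, so any caller unpacking two values needs updating. A usage sketch under the same assumed constructor as earlier; `_log_predictions` is the opt-in flag checked in the method above:

```python
import numpy as np

agent = EnhancedCNN(input_shape=(5, 100), n_actions=3)  # hypothetical shapes
agent._log_predictions = True  # enable the advanced-prediction logging above

state = np.random.randn(5, 100).astype(np.float32)
action = agent.act(state, explore=False)  # int index; no confidence returned anymore
```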
 
     def save(self, path):
         """Save model weights and architecture"""