# Model Inputs & Outputs Reference

Quick reference for all trading models in the system.

---

## 1. Transformer (AdvancedTradingTransformer)

- **Type**: Sequence-to-sequence transformer for multi-timeframe analysis
- **Size**: 46M parameters
- **Architecture**: 12 layers, 16 attention heads, 1024 model dimension
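
As a rough cross-check on the size and architecture figures, the sketch below counts the parameters in one standard transformer encoder layer, using the same layout as `torch.nn.TransformerEncoderLayer` (the number of attention heads does not change the count). The feed-forward width `d_ff` is not listed here, so `d_ff = 4 * d_model` is an assumption, and embedding and output-head parameters are ignored.

```python
def encoder_layer_params(d_model: int, d_ff: int) -> int:
    """Parameters in one standard transformer encoder layer (with biases).

    Self-attention: Q, K, V and output projections, each d_model x d_model + bias.
    Feed-forward:   d_model -> d_ff -> d_model, each linear layer with a bias.
    LayerNorm:      two norms, each with a d_model scale and a d_model shift.
    """
    attention = 4 * (d_model * d_model + d_model)
    feed_forward = (d_model * d_ff + d_ff) + (d_ff * d_model + d_model)
    layer_norms = 2 * (2 * d_model)
    return attention + feed_forward + layer_norms


# 12 layers at d_model=1024, assuming the common d_ff = 4 * d_model
total = 12 * encoder_layer_params(1024, 4 * 1024)
# Narrowest plausible case, assuming d_ff = d_model
total_narrow = 12 * encoder_layer_params(1024, 1024)
```

Under these assumptions the 12 layers alone come to about 151M parameters, and even `d_ff = d_model` gives about 76M, so the 46M figure and the listed architecture likely describe different configurations. Running `sum(p.numel() for p in model.parameters())` on the instantiated model gives the exact count.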

### Inputs

```python
price_data:  [batch, 150, 5]    # OHLCV sequences (150 candles)
cob_data:    [batch, 150, 100]  # Change of Bid features
tech_data:   [batch, 40]        # Technical indicators (SMA, returns, volatility)
market_data: [batch, 30]        # Market context (volume, pivots, support/resistance)
```

### Outputs

```python
action_logits:         [batch, 3]  # Raw logits for HOLD(0), BUY(1), SELL(2)
action_probs:          [batch, 3]  # Softmax probabilities
confidence:            [batch, 1]  # Trade confidence (0-1)
price_prediction:      [batch, 1]  # Future price target
volatility_prediction: [batch, 1]  # Expected volatility
trend_strength:        [batch, 1]  # Trend strength (-1 to 1)

# Next candle predictions for each timeframe
next_candles: {
    '1s': [batch, 5],  # [open, high, low, close, volume]
    '1m': [batch, 5],
    '1h': [batch, 5],
    '1d': [batch, 5]
}

# Pivot point predictions (L1-L5)
next_pivots: {
    'L1': {
        'price':          [batch, 1],
        'type_prob_high': [batch, 1],  # Probability of high pivot
        'type_prob_low':  [batch, 1],  # Probability of low pivot
        'confidence':     [batch, 1]
    },
    # ... L2, L3, L4, L5 (same structure)
}

# Trend vector analysis
trend_analysis: {
    'angle_radians': [batch, 1],  # Trend angle
    'steepness':     [batch, 1],  # Trend steepness
    'direction':     [batch, 1]   # Direction (-1 to 1)
}
```
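
The relationship between `action_logits` and `action_probs`, and how a single action is read off them, can be sketched in plain Python for one batch row. This assumes `action_probs` is a standard softmax over the last dimension (in PyTorch, `torch.softmax(action_logits, dim=-1)` followed by `.argmax(dim=-1)`); `ACTION_NAMES`, `softmax`, and `decode_action` are illustrative names, not part of the codebase.

```python
import math
from typing import List, Sequence

# Action encoding used by the Transformer, CNN, and DQN
ACTION_NAMES = {0: "HOLD", 1: "BUY", 2: "SELL"}


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode_action(logits: Sequence[float]) -> str:
    """Map one row of action_logits to an action name via argmax of the probabilities."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return ACTION_NAMES[best]
```

Because softmax is monotonic, taking the argmax of the logits directly gives the same action; the probabilities are only needed when the scores themselves are used downstream.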

### Training Targets

```python
actions:       [batch]     # Action labels (0=HOLD, 1=BUY, 2=SELL)
future_prices: [batch]     # Price targets
trade_success: [batch, 1]  # Success labels (0.0 or 1.0)
```

---

## 2. CNN (StandardizedCNN / EnhancedCNN)

- **Type**: Convolutional neural network for pattern recognition
- **Size**: ~5-10M parameters
- **Architecture**: Multi-scale convolutions with attention

### Inputs

```python
# Via BaseDataInput.get_feature_vector()
feature_vector: [batch, 7834]  # Flattened features containing:
#   - OHLCV ETH: 300 frames × 4 timeframes × 5 = 6000
#   - OHLCV BTC: 300 frames × 5 = 1500
#   - COB features: 184 (±20 buckets + MA imbalance)
#   - Technical indicators: 100 (padded)
#   - Last predictions: 50 (padded)
#   Total: 6000 + 1500 + 184 + 100 + 50 = 7834
```
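
When debugging the CNN input, it can help to know which index range of the 7834-element vector holds each feature group. The sketch below derives those ranges from the sizes listed above, assuming the groups are concatenated in the listed order (the actual order inside `get_feature_vector()` is not specified here); `FEATURE_SEGMENTS`, `segment_slices`, and `SLICES` are hypothetical names.

```python
from typing import Dict, List, Tuple

# (name, size) pairs; the concatenation order is an assumption
FEATURE_SEGMENTS: List[Tuple[str, int]] = [
    ("ohlcv_eth", 300 * 4 * 5),   # 300 frames x 4 timeframes x OHLCV
    ("ohlcv_btc", 300 * 5),       # 300 frames x OHLCV
    ("cob", 184),                 # ±20 buckets + MA imbalance
    ("technical", 100),           # padded
    ("last_predictions", 50),     # padded
]


def segment_slices(segments: List[Tuple[str, int]]) -> Dict[str, slice]:
    """Return {name: slice} covering consecutive ranges of the flat vector."""
    slices, start = {}, 0
    for name, size in segments:
        slices[name] = slice(start, start + size)
        start += size
    return slices


SLICES = segment_slices(FEATURE_SEGMENTS)
TOTAL = sum(size for _, size in FEATURE_SEGMENTS)  # should equal 7834
```

For a feature vector `vec`, `vec[SLICES["cob"]]` would then select the 184 COB features, provided the ordering assumption holds.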

### Outputs

```python
action_logits:     [batch, 3]     # Logits for HOLD(0), BUY(1), SELL(2)
action_probs:      [batch, 3]     # Softmax probabilities
confidence:        [batch, 1]     # Prediction confidence
hidden_states:     [batch, 1024]  # Feature embeddings (for cross-model feeding)
predicted_returns: [batch, 4]     # [return_1s, return_1m, return_1h, return_1d]
```

### Training Targets

```python
actions: [batch]     # Action labels (0=HOLD, 1=BUY, 2=SELL)
returns: [batch, 4]  # Actual returns per timeframe
```

---

## 3. DQN (Deep Q-Network Agent)

- **Type**: Reinforcement learning agent for sequential decision making
- **Size**: ~15M parameters
- **Architecture**: Deep Q-Network with dueling architecture

### Inputs

```python
# Via BaseDataInput.get_feature_vector()
state: [batch, 7850]  # Full feature vector including:
#   - Multi-timeframe OHLCV data
#   - COB features
#   - Technical indicators
#   - Market regime indicators
#   - Previous predictions
```

### Outputs

```python
q_values:   [batch, 3]  # Q-values for HOLD(0), BUY(1), SELL(2)
action:     int         # Selected action (0=HOLD, 1=BUY, 2=SELL)
confidence: float       # Action confidence (0-1)

# Auxiliary outputs
regime_probs:    [batch, 4]  # [trending, ranging, volatile, mixed]
price_direction: [batch, 3]  # [down, neutral, up]
volatility:      [batch, 1]  # Predicted volatility
value:           [batch, 1]  # State value V(s)
advantage:       [batch, 3]  # Action advantages A(s, a)
```
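
The `value` and `advantage` outputs are the two streams of the dueling architecture, which are combined into `q_values`. Below is a minimal sketch of the common mean-subtracted aggregation, Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a'), for a single state. Whether this agent subtracts the mean or the max advantage is an assumption, and `dueling_q_values` and `greedy_action` are illustrative names.

```python
from typing import List, Sequence


def dueling_q_values(value: float, advantages: Sequence[float]) -> List[float]:
    """Combine V(s) and A(s, a) into Q(s, a) using mean-subtracted aggregation."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]


def greedy_action(q_values: Sequence[float]) -> int:
    """Index of the highest Q-value (0=HOLD, 1=BUY, 2=SELL)."""
    return max(range(len(q_values)), key=q_values.__getitem__)


# One state: V = 0.5, advantages for HOLD, BUY, SELL
q = dueling_q_values(0.5, [0.1, 0.4, -0.5])
```

With batched tensors of shape `[batch, 1]` and `[batch, 3]`, the PyTorch equivalent is `value + advantage - advantage.mean(dim=1, keepdim=True)`. Subtracting the mean makes the two streams identifiable: the mean of the resulting Q-values always equals V(s).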

### Training Targets

```python
# RL uses experience replay
experience: {
    'state':      [7850],
    'action':     int,
    'reward':     float,
    'next_state': [7850],
    'done':       bool
}
```
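
A minimal sketch of how such experiences are typically stored and turned into Q-learning targets, assuming a plain one-step DQN target (the agent may use a variant such as Double DQN). `Experience`, `ReplayBuffer`, `td_target`, and the `gamma=0.99` default are hypothetical names and values, not taken from the codebase.

```python
import random
from collections import deque
from typing import Deque, List, NamedTuple, Sequence


class Experience(NamedTuple):
    state: Sequence[float]       # feature vector (7850 values in this document)
    action: int                  # 0=HOLD, 1=BUY, 2=SELL
    reward: float
    next_state: Sequence[float]
    done: bool                   # True if the episode ended after this step


class ReplayBuffer:
    """Fixed-capacity buffer; once full, the oldest experience is evicted."""

    def __init__(self, capacity: int) -> None:
        self.buffer: Deque[Experience] = deque(maxlen=capacity)

    def push(self, exp: Experience) -> None:
        self.buffer.append(exp)

    def sample(self, batch_size: int) -> List[Experience]:
        # Uniform sampling without replacement breaks up temporal correlation
        return random.sample(list(self.buffer), batch_size)

    def __len__(self) -> int:
        return len(self.buffer)


def td_target(reward: float, next_q_values: Sequence[float], done: bool,
              gamma: float = 0.99) -> float:
    """One-step target: r + gamma * max_a' Q(s', a'), or just r at episode end."""
    if done:
        return reward
    return reward + gamma * max(next_q_values)
```

The Q-network is then trained to move `q_values[action]` for each sampled state toward `td_target(...)`, with `next_q_values` usually computed by a separate, periodically updated target network.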

---

## 4. COB RL Model (MassiveRLNetwork)

- **Type**: Specialized RL for Change of Bid (COB) data
- **Size**: ~3M parameters
- **Architecture**: Deep network focused on order book dynamics

### Inputs

```python
cob_features: [batch, input_size]  # COB-specific features:
#   - Bid/ask imbalance
#   - Order book depth
#   - Price level changes
#   - Volume at price levels
#   - Moving averages of imbalance
```
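
To make the imbalance features concrete, here is one common definition of bid/ask imbalance plus a trailing moving average over a series of imbalance values. The exact formulas and averaging windows used for the COB features are not given in this document, so both functions are assumptions for illustration.

```python
from typing import List, Sequence


def order_book_imbalance(bid_volume: float, ask_volume: float) -> float:
    """Imbalance in [-1, 1]: +1 means all volume on the bid side, -1 all on the ask."""
    total = bid_volume + ask_volume
    if total == 0:
        return 0.0  # empty book: treat as balanced
    return (bid_volume - ask_volume) / total


def trailing_moving_average(values: Sequence[float], window: int) -> List[float]:
    """Simple trailing average; the first window-1 entries average fewer points."""
    out = []
    for i in range(len(values)):
        chunk = values[max(0, i - window + 1): i + 1]
        out.append(sum(chunk) / len(chunk))
    return out
```

Applying `order_book_imbalance` to each order book snapshot and smoothing the result with `trailing_moving_average` yields a feature of the "moving averages of imbalance" kind listed above.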

### Outputs

```python
price_logits:        [batch, 3]  # Direction logits [DOWN, SIDEWAYS, UP]
price_probs:         [batch, 3]  # Direction probabilities
confidence:          [batch, 1]  # Prediction confidence
value:               [batch, 1]  # State value estimate
predicted_direction: int         # 0=DOWN, 1=SIDEWAYS, 2=UP
```

### Training Targets

```python
targets: {
    'direction':  [batch],  # Direction labels (0, 1, 2)
    'value':      [batch],  # Value targets
    'confidence': [batch]   # Confidence targets
}
```

---

## 5. Extrema Trainer

- **Type**: Pivot point detection and prediction
- **Size**: ~1M parameters (lightweight)
- **Architecture**: Statistical + ML hybrid

### Inputs

```python
# Context data (200 candles)
context: {
    'symbol':      str,
    'candles':     deque,     # Recent OHLCV candles (last 200)
    'features':    array,     # Extracted features
    'last_update': datetime
}

# For prediction
current_price: float
now:           datetime
```

### Outputs

```python
# Detected extrema
extrema: {
    'type':        str,       # 'high' or 'low'
    'price':       float,
    'timestamp':   datetime,
    'confidence':  float,     # 0-1
    'window_size': int
}

# Predicted pivot
predicted_pivot: {
    'type':            str,       # 'high' or 'low'
    'price':           float,     # Predicted price level
    'timestamp':       datetime,  # Predicted time
    'confidence':      float,     # 0-1
    'horizon_seconds': int        # Time until pivot (30-300s)
}
```
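
To illustrate what `window_size` means for a detected extremum, here is a minimal fixed-window rule: a candle is a high (low) extremum if its high (low) strictly exceeds (undercuts) every other candle within `window_size` candles on each side. This is a common textbook rule, not necessarily the Extrema Trainer's method, and it produces no confidence or timestamps; `detect_extrema` and the `index` field are hypothetical.

```python
from typing import Dict, List, Sequence


def detect_extrema(highs: Sequence[float], lows: Sequence[float],
                   window_size: int = 5) -> List[Dict]:
    """Fixed-window pivot rule over aligned high/low series (one entry per candle)."""
    if window_size < 1:
        raise ValueError("window_size must be at least 1")
    extrema = []
    # Only candles with a full window on both sides can be confirmed
    for i in range(window_size, len(highs) - window_size):
        left, right = i - window_size, i + window_size + 1
        other_highs = list(highs[left:i]) + list(highs[i + 1:right])
        other_lows = list(lows[left:i]) + list(lows[i + 1:right])
        if highs[i] > max(other_highs):
            extrema.append({'type': 'high', 'price': highs[i], 'index': i,
                            'window_size': window_size})
        if lows[i] < min(other_lows):
            extrema.append({'type': 'low', 'price': lows[i], 'index': i,
                            'window_size': window_size})
    return extrema
```

A larger `window_size` yields fewer, more significant pivots but confirms each one later, since `window_size` candles must close after the pivot before it can be detected.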

### Training Data

```python
# Historical extrema for validation
historical_extrema: List[{
    'price':     float,
    'timestamp': datetime,
    'type':      str,
    'detected':  bool
}]
```

---

## Common Patterns

### Action Encoding

The Transformer, CNN, and DQN share this encoding. The COB RL model predicts price direction instead, with its own encoding (0=DOWN, 1=SIDEWAYS, 2=UP).

```python
HOLD = 0  # No action / maintain position
BUY = 1   # Enter long / close short
SELL = 2  # Enter short / close long
```

### Confidence Scores

- Range: `0.0` to `1.0`
- Typical threshold: `0.6` (60%)
- High confidence: `> 0.8`
- Low confidence: `< 0.4`
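
The thresholds above can be turned into a small helper. The `'medium'` label for the range between the low and high cutoffs, treating a score of exactly `0.6` as actionable, and the function names are all assumptions for illustration.

```python
def confidence_band(confidence: float) -> str:
    """Classify a 0-1 confidence score using the documented cutoffs."""
    if not 0.0 <= confidence <= 1.0:
        raise ValueError(f"confidence must be in [0, 1], got {confidence}")
    if confidence > 0.8:
        return "high"
    if confidence < 0.4:
        return "low"
    return "medium"


def should_act(confidence: float, threshold: float = 0.6) -> bool:
    """Act on a prediction only when its confidence reaches the threshold."""
    return confidence >= threshold
```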

### Batch Sizes

- **Training**: Usually `1` (annotation-based) or `32-128` (batch training)
- **Inference**: Usually `1` (real-time prediction)

### Device Management

All models support:

- CPU: `torch.device('cpu')`
- CUDA: `torch.device('cuda')`
- Automatic device selection based on availability (e.g. `torch.device('cuda' if torch.cuda.is_available() else 'cpu')`)

---

## Model Selection Guide

| Use Case | Recommended Model | Why |
|----------|------------------|-----|
| Multi-timeframe analysis | **Transformer** | Handles 150-candle sequences across timeframes |
| Pattern recognition | **CNN** | Multi-scale convolutions detect local patterns in the flattened feature vector |
| Sequential decisions | **DQN** | Learns action values (Q-values) from rewards via RL |
| Order book dynamics | **COB RL** | Specialized for bid/ask imbalance |
| Pivot detection | **Extrema** | Lightweight, fast pivot predictions |

---

## Integration Example

```python
import torch

# Get base data input
base_input = data_provider.get_base_data_input(symbol, timestamp)

# CNN prediction
# Assumes get_feature_vector() returns a 1-D array; add a batch dimension of 1
cnn_features = torch.as_tensor(base_input.get_feature_vector(), dtype=torch.float32).unsqueeze(0)
cnn_output = cnn_model(cnn_features)
cnn_action = cnn_output['action_probs'].argmax(dim=-1).item()  # plain int: 0=HOLD, 1=BUY, 2=SELL

# Transformer prediction
transformer_batch = prepare_transformer_batch(base_input)
transformer_output = transformer_model(**transformer_batch)
transformer_action = transformer_output['action_probs'].argmax(dim=-1).item()

# DQN prediction
dqn_state = base_input.get_feature_vector()
dqn_output = dqn_agent.select_action(dqn_state)
dqn_action = dqn_output['action']

# Ensemble decision (majority_vote is a helper not defined in this document)
final_action = majority_vote([cnn_action, transformer_action, dqn_action])
```
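
`majority_vote` is called but not defined in the example above. A minimal sketch, assuming it receives plain integer actions (0=HOLD, 1=BUY, 2=SELL), requires a strict majority, and falls back to HOLD when the models disagree; all three behaviors are assumptions, not documented system behavior.

```python
from collections import Counter
from typing import Sequence

HOLD = 0  # action encoding: 0=HOLD, 1=BUY, 2=SELL


def majority_vote(actions: Sequence[int], default: int = HOLD) -> int:
    """Return the action chosen by a strict majority of models, else `default`."""
    if not actions:
        return default
    action, count = Counter(actions).most_common(1)[0]
    return action if count > len(actions) / 2 else default
```

Falling back to HOLD on disagreement is a conservative choice: with three models, a trade is only placed when at least two of them agree on its direction.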

---

## Notes

1. **Shape Conventions**: `[batch, ...]` indicates that the batch dimension comes first
2. **Dtype**: Tensors use `torch.float32` unless specified otherwise; class labels such as `actions` and `direction` are integer indices (`torch.long`), which is what class-index cross-entropy losses like `torch.nn.CrossEntropyLoss` expect
3. **Gradients**: Only model parameters require gradients; input features and training targets do not
4. **Normalization**: Features are typically normalized to `[-1, 1]` or `[0, 1]`
5. **Missing Data**: Padded with zeros or last known values
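
The normalization and missing-data conventions in notes 4 and 5 can be sketched as follows. `min_max_scale` rescales using the series' own minimum and maximum (a production pipeline might instead fix the range from training data), and `pad_to_length` fills gaps and pads to a fixed size with either zeros or the last known value. Both function names and the truncation behavior are assumptions for illustration.

```python
from typing import List, Optional, Sequence


def min_max_scale(values: Sequence[float], symmetric: bool = False) -> List[float]:
    """Scale values to [0, 1], or to [-1, 1] when symmetric=True."""
    if not values:
        return []
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0] * len(values)  # flat input: avoid division by zero
    scaled = [(v - lo) / (hi - lo) for v in values]
    return [2.0 * s - 1.0 for s in scaled] if symmetric else scaled


def pad_to_length(values: Sequence[Optional[float]], length: int,
                  mode: str = "zero") -> List[float]:
    """Fill gaps (None) and right-pad to `length`; keeps only the first `length` entries.

    mode="zero":  missing entries and padding become 0.0
    mode="ffill": missing entries and padding repeat the last known value
                  (0.0 if no value has been seen yet)
    """
    if mode not in ("zero", "ffill"):
        raise ValueError(f"unknown mode: {mode}")
    out: List[float] = []
    last = 0.0
    for v in list(values)[:length]:
        if v is None:
            out.append(last if mode == "ffill" else 0.0)
        else:
            out.append(float(v))
            last = float(v)
    fill = last if mode == "ffill" else 0.0
    out.extend([fill] * (length - len(out)))
    return out
```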