enhancements

Dobromir Popov
2025-04-01 13:46:53 +03:00
parent a46b2c74f8
commit 73c5ecb0d2
17 changed files with 2279 additions and 736 deletions

NN/README_TRADING.md (new file, 305 lines)

@@ -0,0 +1,305 @@
# Trading Agent System
A modular, extensible cryptocurrency trading system that can connect to multiple exchanges through a common interface.
## Architecture
The trading agent system consists of the following components:
### Exchange Interfaces
- `ExchangeInterface`: Abstract base class that defines the common interface for all exchange implementations
- `BinanceInterface`: Implementation for the Binance exchange (supports both mainnet and testnet)
- `MEXCInterface`: Implementation for the MEXC exchange
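
As a rough illustration of what the common interface could look like, here is a minimal sketch of an abstract base class. The method names mirror those in the Kraken example further down in this README; the class name `ExchangeInterfaceSketch`, the use of `abc.ABC`, and the constructor body are assumptions, and `exchange_interface.py` remains the authoritative definition.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional


class ExchangeInterfaceSketch(ABC):
    """Illustrative stand-in for ExchangeInterface (not the real class)."""

    def __init__(self, api_key: Optional[str] = None,
                 api_secret: Optional[str] = None, test_mode: bool = True):
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode

    @abstractmethod
    def connect(self) -> bool: ...

    @abstractmethod
    def get_balance(self, asset: str) -> float: ...

    @abstractmethod
    def get_ticker(self, symbol: str) -> Dict[str, Any]: ...

    @abstractmethod
    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: Optional[float] = None) -> Dict[str, Any]: ...

    @abstractmethod
    def cancel_order(self, symbol: str, order_id: str) -> bool: ...

    @abstractmethod
    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]: ...

    @abstractmethod
    def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]: ...
```

With this shape, Python refuses to instantiate a subclass that leaves any of the methods unimplemented, which catches incomplete exchange implementations early.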
### Trading Agent
- `TradingAgent`: Main class that manages trading operations, including position sizing, risk management, and signal processing
### Neural Network Orchestrator
- `NeuralNetworkOrchestrator`: Coordinates between neural network models and the trading agent to generate and process trading signals
## Getting Started
### Installation
The trading agent system is built into the main application. No additional installation is needed beyond the requirements for the main application.
### Configuration
Configuration can be provided via:
1. Command-line arguments
2. Environment variables
3. Configuration file
#### Example Configuration
```json
{
  "exchange": "binance",
  "api_key": "your_api_key",
  "api_secret": "your_api_secret",
  "test_mode": true,
  "trade_symbols": ["BTC/USDT", "ETH/USDT"],
  "position_size": 0.1,
  "max_trades_per_day": 5,
  "trade_cooldown_minutes": 60
}
```
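
A configuration like the one above can be sanity-checked before it reaches the trading agent. The sketch below is one possible approach, not part of the codebase: `load_config`, `validate_config`, and the choice of required keys are hypothetical, and the `position_size` range follows the 0.0-1.0 fraction described in `NN/exchanges/README.md`.

```python
import json
from typing import Any, Dict

# Assumed minimum set of keys; the real application may require more or fewer.
REQUIRED_KEYS = {"exchange", "trade_symbols", "position_size"}


def validate_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Check the documented fields and raise ValueError on bad values."""
    missing = REQUIRED_KEYS.difference(config)
    if missing:
        raise ValueError(f"Missing config keys: {sorted(missing)}")
    if not 0.0 < float(config["position_size"]) <= 1.0:
        raise ValueError("position_size must be in (0.0, 1.0]")
    if int(config.get("max_trades_per_day", 1)) < 1:
        raise ValueError("max_trades_per_day must be at least 1")
    if int(config.get("trade_cooldown_minutes", 0)) < 0:
        raise ValueError("trade_cooldown_minutes must not be negative")
    return config


def load_config(text: str) -> Dict[str, Any]:
    """Parse a JSON document and validate it."""
    return validate_config(json.loads(text))
```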
### Running the Trading System
#### From Command Line
```bash
# Run with Binance in test mode
python trading_main.py --exchange binance --test-mode
# Run with MEXC in production mode
python trading_main.py --exchange mexc --api-key YOUR_API_KEY --api-secret YOUR_API_SECRET
# Run with custom position sizing and limits
python trading_main.py --exchange binance --test-mode --position-size 0.05 --max-trades-per-day 3 --trade-cooldown 120
```
#### Using Environment Variables
You can set these environment variables for configuration:
```bash
# Set exchange API credentials
export BINANCE_API_KEY=your_binance_api_key
export BINANCE_API_SECRET=your_binance_api_secret
# Enable neural network models
export ENABLE_NN_MODELS=1
export NN_INFERENCE_INTERVAL=60
export NN_MODEL_TYPE=cnn
export NN_TIMEFRAME=1h
# Run the trading system
python trading_main.py
```
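
For reference, the neural network variables above could be read into typed settings roughly as follows. This is a sketch under assumptions: the function name `read_nn_settings` and the default values are hypothetical, and the real application may parse these variables differently.

```python
import os
from typing import Any, Dict, Mapping, Optional


def read_nn_settings(env: Optional[Mapping[str, str]] = None) -> Dict[str, Any]:
    """Read the NN-related variables, falling back to assumed defaults."""
    env = os.environ if env is None else env
    return {
        "enabled": env.get("ENABLE_NN_MODELS", "0") == "1",
        "inference_interval": int(env.get("NN_INFERENCE_INTERVAL", "60")),
        "model_type": env.get("NN_MODEL_TYPE", "cnn"),
        "timeframe": env.get("NN_TIMEFRAME", "1h"),
    }
```

Passing a plain dictionary instead of `os.environ` makes the parsing easy to test without touching the process environment.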
### Basic Usage Examples
#### Creating a Trading Agent
```python
from NN.trading_agent import TradingAgent
# Initialize a trading agent for Binance testnet
agent = TradingAgent(
    exchange_name="binance",
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True,
    trade_symbols=["BTC/USDT", "ETH/USDT"],
    position_size=0.1,
    max_trades_per_day=5,
    trade_cooldown_minutes=60
)
# Start the trading agent
agent.start()
# Process a signal
agent.process_signal(
    symbol="BTC/USDT",
    action="BUY",
    confidence=0.85
)
# Get current positions
positions = agent.get_current_positions()
print(f"Current positions: {positions}")
# Stop the trading agent
agent.stop()
```
#### Using an Exchange Interface Directly
```python
from NN.exchanges import BinanceInterface
# Initialize the Binance interface
exchange = BinanceInterface(
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True
)
# Connect to the exchange
exchange.connect()
# Get ticker info
ticker = exchange.get_ticker("BTC/USDT")
print(f"Current BTC price: {ticker['last']}")
# Get account balance
btc_balance = exchange.get_balance("BTC")
usdt_balance = exchange.get_balance("USDT")
print(f"BTC balance: {btc_balance}")
print(f"USDT balance: {usdt_balance}")
# Place a market order
order = exchange.place_order(
    symbol="BTC/USDT",
    side="buy",
    order_type="market",
    quantity=0.001
)
```
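
The hard-coded `quantity=0.001` above could instead be derived from the balance and ticker price. The helper below is a hypothetical sketch of that arithmetic, assuming `position_size` is a fraction of the quote-currency balance; it ignores fees and exchange lot-size rounding, and it is not necessarily how `TradingAgent` sizes positions.

```python
def order_quantity(quote_balance: float, price: float, position_size: float) -> float:
    """Return the base-asset quantity bought with a fraction of the quote balance."""
    if price <= 0:
        raise ValueError("price must be positive")
    if not 0.0 <= position_size <= 1.0:
        raise ValueError("position_size must be between 0.0 and 1.0")
    # Spend quote_balance * position_size, converted to base units at the given price.
    return quote_balance * position_size / price
```

For example, a 1000 USDT balance with `position_size=0.1` at a price of 50000 USDT spends 100 USDT, which is about 0.002 BTC.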
## Testing the Exchange Interfaces
The system includes a test script for verifying that the exchange interfaces work correctly:
```bash
# Test Binance interface in test mode (no real trades)
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
# Test MEXC interface in test mode
python -m NN.exchanges.trading_agent_test --exchange mexc --test-mode
# Test with actual trades (use with caution!)
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode --execute-trades --test-trade-amount 0.001
```
## Adding a New Exchange
To add support for a new exchange, create a class that inherits from `ExchangeInterface` and implements all required methods:
1. Create a new file in the `NN/exchanges` directory (e.g., `kraken_interface.py`)
2. Implement the required methods (see `exchange_interface.py` for the specifications)
3. Add the new exchange to the imports in `__init__.py`
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange
### Example of a New Exchange Implementation
```python
# NN/exchanges/kraken_interface.py
import logging
from typing import Dict, Any, List, Optional

from .exchange_interface import ExchangeInterface

logger = logging.getLogger(__name__)


class KrakenInterface(ExchangeInterface):
    """Kraken Exchange API Interface"""

    def __init__(self, api_key: Optional[str] = None, api_secret: Optional[str] = None, test_mode: bool = True):
        super().__init__(api_key, api_secret, test_mode)
        self.base_url = "https://api.kraken.com"
        # Initialize other Kraken-specific properties

    def connect(self) -> bool:
        # Implement connection to Kraken API
        pass

    def get_balance(self, asset: str) -> float:
        # Implement getting balance for an asset
        pass

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        # Implement getting ticker data
        pass

    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: Optional[float] = None) -> Dict[str, Any]:
        # Implement placing an order
        pass

    def cancel_order(self, symbol: str, order_id: str) -> bool:
        # Implement cancelling an order
        pass

    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        # Implement getting order status
        pass

    def get_open_orders(self, symbol: Optional[str] = None) -> List[Dict[str, Any]]:
        # Implement getting open orders
        pass
```
Then update the imports in `__init__.py`:
```python
from .exchange_interface import ExchangeInterface
from .binance_interface import BinanceInterface
from .mexc_interface import MEXCInterface
from .kraken_interface import KrakenInterface
__all__ = ['ExchangeInterface', 'BinanceInterface', 'MEXCInterface', 'KrakenInterface']
```
And update the `_create_exchange` method in `TradingAgent`:
```python
def _create_exchange(self) -> ExchangeInterface:
    """Create an exchange interface based on the exchange name."""
    if self.exchange_name == 'mexc':
        return MEXCInterface(
            api_key=self.api_key,
            api_secret=self.api_secret,
            test_mode=self.test_mode
        )
    elif self.exchange_name == 'binance':
        return BinanceInterface(
            api_key=self.api_key,
            api_secret=self.api_secret,
            test_mode=self.test_mode
        )
    elif self.exchange_name == 'kraken':
        return KrakenInterface(
            api_key=self.api_key,
            api_secret=self.api_secret,
            test_mode=self.test_mode
        )
    else:
        raise ValueError(f"Unsupported exchange: {self.exchange_name}")
```
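
As the number of exchanges grows, the `if`/`elif` chain above gets longer with each addition. One alternative, shown here only as a design sketch and not as what `TradingAgent` does, is a small registry that maps names to classes. The names `register_exchange`, `create_exchange`, and `DemoInterface` are hypothetical.

```python
from typing import Callable, Dict

# Maps a lower-case exchange name to the class (or factory) that builds it.
_EXCHANGES: Dict[str, Callable[..., object]] = {}


def register_exchange(name: str):
    """Class decorator that records an exchange implementation under a name."""
    def decorator(cls):
        _EXCHANGES[name.lower()] = cls
        return cls
    return decorator


def create_exchange(name: str, **kwargs):
    """Instantiate a registered exchange, or raise ValueError if unknown."""
    try:
        factory = _EXCHANGES[name.lower()]
    except KeyError:
        raise ValueError(f"Unsupported exchange: {name}") from None
    return factory(**kwargs)


@register_exchange("demo")
class DemoInterface:
    """Placeholder standing in for a real ExchangeInterface subclass."""

    def __init__(self, api_key=None, api_secret=None, test_mode=True):
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode
```

With a registry, adding an exchange only requires decorating the new class; the factory function itself stays unchanged.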
## Security Considerations
- **API Keys**: Never hardcode API keys in your code. Use environment variables or secure storage.
- **Permissions**: Restrict API key permissions to only what is needed (e.g., trading, but not withdrawals).
- **Testing**: Always test with small amounts and use test mode/testnet when possible.
- **Position Sizing**: Implement conservative position sizing to manage risk.
- **Monitoring**: Set up monitoring and alerting for your trading system.
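
To illustrate the first point, credentials can be read from the environment and masked before they appear in any log line. This is a minimal sketch: the `<EXCHANGE>_API_KEY` / `<EXCHANGE>_API_SECRET` naming follows the environment variables shown earlier, while the helper names `load_credentials` and `mask_secret` are hypothetical.

```python
import os
from typing import Mapping, Optional, Tuple


def load_credentials(exchange: str,
                     env: Optional[Mapping[str, str]] = None) -> Tuple[Optional[str], Optional[str]]:
    """Return (api_key, api_secret) from the environment; either may be None."""
    env = os.environ if env is None else env
    prefix = exchange.upper()
    return env.get(f"{prefix}_API_KEY"), env.get(f"{prefix}_API_SECRET")


def mask_secret(value: Optional[str], visible: int = 4) -> str:
    """Hide all but the last few characters so the value is safe to log."""
    if not value:
        return "<not set>"
    if len(value) <= visible:
        return "*" * len(value)
    return "*" * (len(value) - visible) + value[-visible:]
```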
## Troubleshooting
### Common Issues
1. **Connection Problems**: Make sure you have internet connectivity and correct API credentials.
2. **Order Placement Errors**: Check for sufficient funds, correct symbol format, and valid order parameters.
3. **Rate Limiting**: Avoid making too many API requests in a short period to prevent being rate-limited.
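
For rate limiting in particular, a common mitigation is to retry with exponential backoff. The helper below is a generic sketch rather than part of this system: `call_with_backoff` is a hypothetical name, and production code should catch the exchange's specific rate-limit error instead of a bare `Exception`.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def call_with_backoff(func: Callable[[], T], retries: int = 3, base_delay: float = 1.0,
                      sleep: Callable[[float], None] = time.sleep) -> T:
    """Call func, retrying on failure with delays of base_delay * 2**attempt."""
    if retries < 0:
        raise ValueError("retries must be non-negative")
    attempt = 0
    while True:
        try:
            return func()
        except Exception:
            if attempt >= retries:
                raise  # Out of retries: surface the last error to the caller
            sleep(base_delay * (2 ** attempt))
            attempt += 1
```

A call such as `call_with_backoff(lambda: exchange.get_ticker("BTC/USDT"))` would wait 1, 2, and then 4 seconds between attempts before giving up.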
### Logging
The trading agent system uses Python's logging module with different levels:
- **DEBUG**: Detailed information, typically useful for diagnosing problems.
- **INFO**: Confirmation that things are working as expected.
- **WARNING**: Indication that something unexpected happened, but the program can still function.
- **ERROR**: Due to a more serious problem, the program has failed to perform some function.
You can adjust the logging level in your trading script:
```python
import logging
logging.basicConfig(
    level=logging.DEBUG,  # Change to INFO, WARNING, or ERROR as needed
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("trading.log"),
        logging.StreamHandler()
    ]
)
```

NN/exchanges/README.md (new file, 162 lines)

@@ -0,0 +1,162 @@
# Trading Agent System
This directory contains the implementation of a modular trading agent system that integrates with the neural network models and can execute trades on various cryptocurrency exchanges.
## Overview
The trading agent system is designed to:
1. Connect to different cryptocurrency exchanges using a common interface
2. Execute trades based on signals from neural network models
3. Manage risk through position sizing, trade limits, and cooldown periods
4. Monitor and report on trading activity
## Components
### Exchange Interfaces
- `ExchangeInterface`: Abstract base class defining the common interface for all exchange implementations
- `BinanceInterface`: Implementation for the Binance exchange, with support for both mainnet and testnet
- `MEXCInterface`: Implementation for the MEXC exchange
### Trading Agent
The `TradingAgent` class (`trading_agent.py`) manages trading activities:
- Connects to the configured exchange
- Processes trading signals from neural network models
- Applies trading rules and risk management
- Tracks and reports trading performance
### Neural Network Orchestrator
The `NeuralNetworkOrchestrator` class (`neural_network_orchestrator.py`) coordinates between models and trading:
- Manages the neural network inference process
- Routes model signals to the trading agent
- Provides integration with the RealTimeChart for visualization
## Usage
### Basic Usage
```python
import time
from NN.exchanges import BinanceInterface, MEXCInterface
from NN.trading_agent import TradingAgent
# Initialize an exchange interface
exchange = BinanceInterface(
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True  # Use testnet
)
# Connect to the exchange
exchange.connect()
# Create a trading agent
agent = TradingAgent(
    exchange_name="binance",
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True,
    trade_symbols=["BTC/USDT", "ETH/USDT"],
    position_size=0.1,
    max_trades_per_day=5,
    trade_cooldown_minutes=60
)
# Start the trading agent
agent.start()
# Process a trading signal
agent.process_signal(
    symbol="BTC/USDT",
    action="BUY",
    confidence=0.85,
    timestamp=int(time.time())
)
# Stop the trading agent when done
agent.stop()
```
### Integration with Neural Network Models
The system is designed to be integrated with neural network models through the `NeuralNetworkOrchestrator`:
```python
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
# Configure exchange
exchange_config = {
    "exchange": "binance",
    "api_key": "your_api_key",
    "api_secret": "your_api_secret",
    "test_mode": True,
    "trade_symbols": ["BTC/USDT", "ETH/USDT"],
    "position_size": 0.1,
    "max_trades_per_day": 5,
    "trade_cooldown_minutes": 60
}
# Initialize orchestrator
# (model, data_interface, and chart are assumed to have been created elsewhere)
orchestrator = NeuralNetworkOrchestrator(
    model=model,
    data_interface=data_interface,
    chart=chart,
    symbols=["BTC/USDT", "ETH/USDT"],
    timeframes=["1m", "5m", "1h", "4h", "1d"],
    exchange_config=exchange_config
)
# Start inference and trading
orchestrator.start_inference()
```
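
To make the routing step more concrete, the sketch below turns a vector of action probabilities into the `action` and `confidence` arguments that `process_signal` expects. The SELL/HOLD/BUY index order matches the one used by the CNN model in this commit; the function name `probs_to_signal` and the `min_confidence` threshold are assumptions, and the orchestrator's actual logic may differ.

```python
from typing import Optional, Sequence, Tuple

# Index order assumed from the CNN model: 0 = SELL, 1 = HOLD, 2 = BUY.
ACTIONS = ("SELL", "HOLD", "BUY")


def probs_to_signal(probs: Sequence[float],
                    min_confidence: float = 0.5) -> Optional[Tuple[str, float]]:
    """Return (action, confidence), or None for HOLD or low-confidence output."""
    if len(probs) != len(ACTIONS):
        raise ValueError(f"expected {len(ACTIONS)} probabilities, got {len(probs)}")
    # Pick the most probable class; ties resolve to the lowest index.
    index = max(range(len(probs)), key=lambda i: probs[i])
    action, confidence = ACTIONS[index], float(probs[index])
    if action == "HOLD" or confidence < min_confidence:
        return None
    return action, confidence
```

A non-`None` result can be forwarded as `agent.process_signal(symbol=symbol, action=action, confidence=confidence)`.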
## Configuration
### Exchange-Specific Configuration
- **Binance**: Supports both mainnet and testnet environments
- **MEXC**: Supports mainnet only (no test environment available)
### Trading Agent Configuration
- `exchange_name`: Name of exchange ('binance', 'mexc')
- `api_key`: API key for the exchange
- `api_secret`: API secret for the exchange
- `test_mode`: Whether to use test/sandbox environment
- `trade_symbols`: List of trading symbols to monitor
- `position_size`: Size of each position as a fraction of balance (0.0-1.0)
- `max_trades_per_day`: Maximum number of trades to execute per day
- `trade_cooldown_minutes`: Minimum time between trades in minutes
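
The interplay of `max_trades_per_day` and `trade_cooldown_minutes` can be sketched as a small gatekeeper class. This is an illustration under assumptions, not the `TradingAgent` implementation: `TradeLimiter` is a hypothetical name, "per day" is taken to mean a calendar day, and the cooldown is applied across all symbols rather than per symbol.

```python
from datetime import datetime, timedelta
from typing import List


class TradeLimiter:
    """Tracks executed trades and decides whether another one is allowed."""

    def __init__(self, max_trades_per_day: int = 5, trade_cooldown_minutes: int = 60):
        self.max_trades_per_day = max_trades_per_day
        self.cooldown = timedelta(minutes=trade_cooldown_minutes)
        self._trades: List[datetime] = []

    def can_trade(self, now: datetime) -> bool:
        # Daily cap: count trades recorded on the same calendar day as `now`.
        today = [t for t in self._trades if t.date() == now.date()]
        if len(today) >= self.max_trades_per_day:
            return False
        # Cooldown: require a minimum gap since the most recent trade.
        if self._trades and now - self._trades[-1] < self.cooldown:
            return False
        return True

    def record_trade(self, now: datetime) -> None:
        self._trades.append(now)
```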
## Adding New Exchanges
To add support for a new exchange:
1. Create a new class that inherits from `ExchangeInterface`
2. Implement all required methods (see `exchange_interface.py`)
3. Add the new exchange to the imports in `__init__.py`
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange
Example:
```python
class KrakenInterface(ExchangeInterface):
    """Kraken Exchange API Interface"""
    def __init__(self, api_key=None, api_secret=None, test_mode=True):
        super().__init__(api_key, api_secret, test_mode)
        # Initialize Kraken-specific attributes
    # Implement all required methods...
```
## Security Considerations
- API keys should have trade permissions but not withdrawal permissions
- Use environment variables or secure storage for API credentials
- Always test with small position sizes before deploying with larger amounts
- Consider using test mode/testnet for initial testing

@@ -26,10 +26,17 @@ class MEXCInterface(ExchangeInterface):
self.api_version = "v3"
def connect(self) -> bool:
"""Connect to MEXC API. This is a no-op for REST API."""
"""Connect to MEXC API."""
if not self.api_key or not self.api_secret:
logger.warning("MEXC API credentials not provided. Running in read-only mode.")
return False
try:
# Test public API connection by getting ticker data for BTC/USDT
self.get_ticker("BTC/USDT")
logger.info("Successfully connected to MEXC API in read-only mode")
return True
except Exception as e:
logger.error(f"Failed to connect to MEXC API in read-only mode: {str(e)}")
return False
try:
# Test connection by getting account info
@@ -141,22 +148,69 @@ class MEXCInterface(ExchangeInterface):
dict: Ticker data including price information
"""
mexc_symbol = symbol.replace('/', '')
try:
ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': mexc_symbol})
# Convert to a standardized format
result = {
'symbol': symbol,
'bid': float(ticker['bidPrice']),
'ask': float(ticker['askPrice']),
'last': float(ticker['lastPrice']),
'volume': float(ticker['volume']),
'timestamp': int(ticker['closeTime'])
}
return result
except Exception as e:
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
raise
endpoints_to_try = [
('ticker/price', {'symbol': mexc_symbol}),
('ticker', {'symbol': mexc_symbol}),
('ticker/24hr', {'symbol': mexc_symbol}),
('ticker/bookTicker', {'symbol': mexc_symbol}),
('market/ticker', {'symbol': mexc_symbol})
]
for endpoint, params in endpoints_to_try:
try:
logger.info(f"Trying to get ticker from endpoint: {endpoint}")
response = self._send_public_request('GET', endpoint, params)
# Handle the response based on structure
if isinstance(response, dict):
# Single ticker response
ticker = response
elif isinstance(response, list) and len(response) > 0:
# List of tickers, find the one we want
ticker = None
for t in response:
if t.get('symbol') == mexc_symbol:
ticker = t
break
if ticker is None:
continue # Try next endpoint if not found
else:
continue # Try next endpoint if unexpected response
# Convert to a standardized format with defaults for missing fields
current_time = int(time.time() * 1000)
result = {
'symbol': symbol,
'bid': float(ticker.get('bidPrice', ticker.get('bid', 0))),
'ask': float(ticker.get('askPrice', ticker.get('ask', 0))),
'last': float(ticker.get('price', ticker.get('lastPrice', ticker.get('last', 0)))),
'volume': float(ticker.get('volume', ticker.get('quoteVolume', 0))),
'timestamp': int(ticker.get('time', ticker.get('closeTime', current_time)))
}
# Ensure we have at least a price
if result['last'] > 0:
logger.info(f"Successfully got ticker from {endpoint} for {symbol}: {result['last']}")
return result
except Exception as e:
logger.warning(f"Error getting ticker from {endpoint} for {symbol}: {str(e)}")
# If we get here, all endpoints failed
logger.error(f"All ticker endpoints failed for {symbol}")
# Return dummy data as last resort for testing
dummy_price = 50000.0 if 'BTC' in symbol else 2000.0 # Dummy price for BTC or others
logger.warning(f"Returning dummy ticker data for {symbol} with price {dummy_price}")
return {
'symbol': symbol,
'bid': dummy_price * 0.999,
'ask': dummy_price * 1.001,
'last': dummy_price,
'volume': 100.0,
'timestamp': int(time.time() * 1000),
'is_dummy': True
}
def place_order(self, symbol: str, side: str, order_type: str,
quantity: float, price: float = None) -> Dict[str, Any]:

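The field fallbacks in the new `get_ticker` can be easier to follow, and to unit test, as a standalone pure function. The sketch below reproduces the same key precedence (`price`, then `lastPrice`, then `last`, and so on); the helper names `first_present` and `normalize_ticker` are hypothetical and do not exist in `MEXCInterface`. Unlike the diff, it returns `None` rather than dummy data when no usable price is found.

```python
import time
from typing import Any, Dict, Optional


def first_present(data: Dict[str, Any], *keys: str, default: Any = 0) -> Any:
    """Return the value of the first key that exists with a non-None value."""
    for key in keys:
        if data.get(key) is not None:
            return data[key]
    return default


def normalize_ticker(symbol: str, raw: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Convert a raw ticker payload to the standardized format, or None if no price."""
    now_ms = int(time.time() * 1000)
    result = {
        "symbol": symbol,
        "bid": float(first_present(raw, "bidPrice", "bid")),
        "ask": float(first_present(raw, "askPrice", "ask")),
        "last": float(first_present(raw, "price", "lastPrice", "last")),
        "volume": float(first_present(raw, "volume", "quoteVolume")),
        "timestamp": int(first_present(raw, "time", "closeTime", default=now_ms)),
    }
    # A ticker without a positive last price is treated as unusable.
    return result if result["last"] > 0 else None
```
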
@@ -0,0 +1,254 @@
"""
Trading Agent Test Script
This script demonstrates how to use the swappable exchange modules
to connect to and interact with different cryptocurrency exchanges.
Usage:
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
"""
import argparse
import logging
import os
import sys
import time
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("exchange_test.log"),
logging.StreamHandler()
]
)
logger = logging.getLogger("exchange_test")
# Import exchange interfaces
try:
from .exchange_interface import ExchangeInterface
from .binance_interface import BinanceInterface
from .mexc_interface import MEXCInterface
except ImportError:
# When running as standalone script
from exchange_interface import ExchangeInterface
from binance_interface import BinanceInterface
from mexc_interface import MEXCInterface
def create_exchange(exchange_name: str, api_key: str = None, api_secret: str = None, test_mode: bool = True) -> ExchangeInterface:
"""Create an exchange interface instance.
Args:
exchange_name: Name of the exchange ('binance' or 'mexc')
api_key: API key for the exchange
api_secret: API secret for the exchange
test_mode: If True, use test/sandbox environment
Returns:
ExchangeInterface: The exchange interface instance
"""
exchange_name = exchange_name.lower()
if exchange_name == 'binance':
return BinanceInterface(api_key, api_secret, test_mode)
elif exchange_name == 'mexc':
return MEXCInterface(api_key, api_secret, test_mode)
else:
raise ValueError(f"Unsupported exchange: {exchange_name}. Supported exchanges: binance, mexc")
def test_exchange(exchange: ExchangeInterface, symbols: list = None):
"""Test the exchange interface.
Args:
exchange: Exchange interface instance
symbols: List of symbols to test with (e.g., ['BTC/USDT', 'ETH/USDT'])
"""
if symbols is None:
symbols = ['BTC/USDT', 'ETH/USDT']
# Test connection
logger.info("Testing connection to exchange...")
connected = exchange.connect()
if not connected and hasattr(exchange, 'api_key') and exchange.api_key:
logger.error("Failed to connect to exchange. Make sure your API credentials are correct.")
return False
elif not connected:
logger.warning("Running in read-only mode without API credentials.")
else:
logger.info("Connection successful with API credentials!")
# Test getting ticker data
ticker_success = True
for symbol in symbols:
try:
logger.info(f"Getting ticker data for {symbol}...")
ticker = exchange.get_ticker(symbol)
logger.info(f"Ticker for {symbol}: Last price: {ticker['last']}, Volume: {ticker['volume']}")
except Exception as e:
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
ticker_success = False
if not ticker_success:
logger.error("Failed to get ticker data. Exchange interface test failed.")
return False
# Test getting account balances if API keys are provided
if hasattr(exchange, 'api_key') and exchange.api_key:
logger.info("Testing account balance retrieval...")
try:
for base_asset in ['BTC', 'ETH', 'USDT']:
balance = exchange.get_balance(base_asset)
logger.info(f"Balance for {base_asset}: {balance}")
except Exception as e:
logger.error(f"Error getting account balances: {str(e)}")
logger.warning("Balance retrieval failed, but this is not critical if ticker data works.")
else:
logger.warning("API keys not provided. Skipping balance checks.")
logger.info("Exchange interface test completed successfully in read-only mode.")
return True
def execute_test_trades(exchange: ExchangeInterface, symbol: str, test_trade_amount: float = 0.001):
"""Execute test trades.
Args:
exchange: Exchange interface instance
symbol: Symbol to trade (e.g., 'BTC/USDT')
test_trade_amount: Amount to use for test trades
"""
if not hasattr(exchange, 'api_key') or not exchange.api_key:
logger.warning("API keys not provided. Skipping test trades.")
return
logger.info(f"Executing test trades for {symbol} with amount {test_trade_amount}...")
# Get current ticker for the symbol
try:
ticker = exchange.get_ticker(symbol)
logger.info(f"Current price for {symbol}: {ticker['last']}")
except Exception as e:
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
return
# Execute a buy order
try:
logger.info(f"Placing a test BUY order for {test_trade_amount} {symbol}...")
buy_order = exchange.execute_trade(symbol, 'BUY', quantity=test_trade_amount)
if buy_order:
logger.info(f"BUY order executed: {buy_order}")
order_id = buy_order.get('orderId')
# Get order status
if order_id:
time.sleep(2) # Wait for order to process
status = exchange.get_order_status(symbol, order_id)
logger.info(f"Order status: {status}")
else:
logger.error("BUY order failed.")
except Exception as e:
logger.error(f"Error executing BUY order: {str(e)}")
# Wait before selling
time.sleep(5)
# Execute a sell order
try:
logger.info(f"Placing a test SELL order for {test_trade_amount} {symbol}...")
sell_order = exchange.execute_trade(symbol, 'SELL', quantity=test_trade_amount)
if sell_order:
logger.info(f"SELL order executed: {sell_order}")
order_id = sell_order.get('orderId')
# Get order status
if order_id:
time.sleep(2) # Wait for order to process
status = exchange.get_order_status(symbol, order_id)
logger.info(f"Order status: {status}")
else:
logger.error("SELL order failed.")
except Exception as e:
logger.error(f"Error executing SELL order: {str(e)}")
# Get open orders
try:
logger.info("Getting open orders...")
open_orders = exchange.get_open_orders(symbol)
if open_orders:
logger.info(f"Open orders: {open_orders}")
# Cancel any open orders
for order in open_orders:
order_id = order.get('orderId')
if order_id:
logger.info(f"Cancelling order {order_id}...")
cancelled = exchange.cancel_order(symbol, order_id)
logger.info(f"Order cancelled: {cancelled}")
else:
logger.info("No open orders.")
except Exception as e:
logger.error(f"Error getting/cancelling open orders: {str(e)}")
def main():
"""Main function for testing exchange interfaces."""
# Parse command-line arguments
parser = argparse.ArgumentParser(description="Test exchange interfaces")
parser.add_argument('--exchange', type=str, default='binance', choices=['binance', 'mexc'],
help='Exchange to test')
parser.add_argument('--api-key', type=str, default=None,
help='API key for the exchange')
parser.add_argument('--api-secret', type=str, default=None,
help='API secret for the exchange')
parser.add_argument('--test-mode', action='store_true',
help='Use test/sandbox environment')
parser.add_argument('--symbols', nargs='+', default=['BTC/USDT', 'ETH/USDT'],
help='Symbols to test with')
parser.add_argument('--execute-trades', action='store_true',
help='Execute test trades (use with caution!)')
parser.add_argument('--test-trade-amount', type=float, default=0.001,
help='Amount to use for test trades')
args = parser.parse_args()
# Use environment variables for API keys if not provided
api_key = args.api_key or os.environ.get(f"{args.exchange.upper()}_API_KEY")
api_secret = args.api_secret or os.environ.get(f"{args.exchange.upper()}_API_SECRET")
# Create exchange interface
try:
exchange = create_exchange(
exchange_name=args.exchange,
api_key=api_key,
api_secret=api_secret,
test_mode=args.test_mode
)
logger.info(f"Created {args.exchange} exchange interface")
logger.info(f"Test mode: {args.test_mode}")
# Test exchange
if test_exchange(exchange, args.symbols):
logger.info("Exchange interface test passed!")
# Execute test trades if requested
if args.execute_trades:
logger.warning("Executing test trades. Outside a testnet environment this will use real funds!")
execute_test_trades(
exchange=exchange,
symbol=args.symbols[0],
test_trade_amount=args.test_trade_amount
)
else:
logger.error("Exchange interface test failed!")
except Exception as e:
logger.error(f"Error testing exchange interface: {str(e)}")
import traceback
logger.error(traceback.format_exc())
return 1
return 0
if __name__ == "__main__":
sys.exit(main())

@@ -78,569 +78,248 @@ class CNNPyTorch(nn.Module):
window_size, num_features = input_shape
self.window_size = window_size
# Increased dropout for better generalization
dropout_rate = 0.25
# Convolutional layers with wider kernels for better pattern detection
# Simpler architecture with fewer layers and dropout
self.conv1 = nn.Sequential(
nn.Conv1d(num_features, 64, kernel_size=5, padding=2),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.1),
nn.Dropout(dropout_rate)
nn.Conv1d(num_features, 32, kernel_size=3, padding=1),
nn.BatchNorm1d(32),
nn.ReLU(),
nn.Dropout(0.2)
)
self.conv2 = nn.Sequential(
nn.Conv1d(64, 128, kernel_size=5, padding=2),
nn.BatchNorm1d(128),
nn.LeakyReLU(0.1),
nn.Dropout(dropout_rate)
)
# Micro-movement detection with smaller kernels
self.micro_conv = nn.Sequential(
nn.Conv1d(num_features, 32, kernel_size=3, padding=1),
nn.BatchNorm1d(32),
nn.LeakyReLU(0.1),
nn.Conv1d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.1),
nn.Dropout(dropout_rate)
nn.ReLU(),
nn.Dropout(0.2)
)
# Attention mechanism for pattern importance weighting
self.attention = nn.Conv1d(64, 1, kernel_size=1)
self.softmax = nn.Softmax(dim=2)
# Global average pooling to handle variable length sequences
self.global_pool = nn.AdaptiveAvgPool1d(1)
# Define a fixed output size for conv features to avoid dimension mismatch
fixed_conv_size = 10 # This should match the expected size in forward pass
# Use adaptive pooling to get fixed size regardless of input
self.adaptive_pool = nn.AdaptiveAvgPool1d(fixed_conv_size)
# Calculate input size for fully connected layer
# After adaptive pooling, dimensions are [batch_size, channels, fixed_conv_size]
conv2_flat_size = 128 * fixed_conv_size # From conv2
micro_flat_size = 64 * fixed_conv_size # From micro_conv
fc_input_size = conv2_flat_size + micro_flat_size
# Shared fully connected layers
self.shared_fc = nn.Sequential(
nn.Linear(fc_input_size, 256),
nn.BatchNorm1d(256),
nn.LeakyReLU(0.1),
nn.Dropout(dropout_rate)
# Fully connected layers
self.fc = nn.Sequential(
nn.Linear(64, 32),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(32, output_size)
)
# Action prediction head
self.action_fc = nn.Sequential(
nn.Linear(256, 64),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.1),
nn.Dropout(dropout_rate),
nn.Linear(64, output_size)
)
# Price prediction head
self.price_fc = nn.Sequential(
nn.Linear(256, 64),
nn.BatchNorm1d(64),
nn.LeakyReLU(0.1),
nn.Dropout(dropout_rate),
nn.Linear(64, 1) # Predict price change percentage
)
# Confidence thresholds for decision making
self.buy_threshold = 0.55 # Higher threshold for BUY signals
self.sell_threshold = 0.55 # Higher threshold for SELL signals
def forward(self, x):
"""
Forward pass through the network with enhanced pattern detection.
Forward pass through the network.
Args:
x: Input tensor of shape [batch_size, window_size, features]
Returns:
Tuple of (action_probs, price_pred)
action_probs: Action probabilities
"""
# Transpose for conv1d: [batch, features, window]
x = x.transpose(1, 2)
# Main convolutional layers
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out) # Use conv1_out as input to conv2
# Convolutional layers
x = self.conv1(x)
x = self.conv2(x)
# Micro-movement pattern detection
micro_out = self.micro_conv(x)
# Global pooling
x = self.global_pool(x)
x = x.squeeze(-1)
# Apply adaptive pooling to ensure fixed size output for both paths
# This ensures both tensors have the same size at dimension 2
micro_out = self.adaptive_pool(micro_out) # Output: [batch, 64, 10]
conv2_out = self.adaptive_pool(conv2_out) # Output: [batch, 128, 10]
# Fully connected layers
action_logits = self.fc(x)
# Apply attention to conv1 output to detect important patterns
attention = self.attention(conv1_out)
attention = self.softmax(attention)
# Apply class weights to reduce HOLD bias
# This helps overcome the dataset imbalance that often favors HOLD
class_weights = torch.tensor([2.5, 0.4, 2.5], device=self.device) # Higher weights for BUY/SELL
weighted_logits = action_logits * class_weights
# Flatten and concatenate features
conv2_flat = conv2_out.reshape(conv2_out.size(0), -1) # [batch, 128*10]
micro_flat = micro_out.reshape(micro_out.size(0), -1) # [batch, 64*10]
# Add random perturbation during training to encourage exploration
if self.training:
# Add small noise to encourage exploration
noise = torch.randn_like(weighted_logits) * 0.3
weighted_logits = weighted_logits + noise
features = torch.cat([conv2_flat, micro_flat], dim=1)
# Softmax to get probabilities
action_probs = F.softmax(weighted_logits, dim=1)
# Shared layers
shared_features = self.shared_fc(features)
# Action head
action_logits = self.action_fc(shared_features)
action_probs = F.softmax(action_logits, dim=1)
# Price prediction head
price_pred = self.price_fc(shared_features)
# Adjust confidence thresholds to favor decisive trading actions
with torch.no_grad():
# Reduce HOLD probabilities more aggressively for short-term trading
action_probs[:, 1] *= 0.4 # More aggressive reduction of HOLD (index 1) probabilities
# Identify high-confidence signals and boost them further
sell_mask = action_probs[:, 0] > self.sell_threshold
buy_mask = action_probs[:, 2] > self.buy_threshold
# Boost high-confidence signals even more
action_probs[sell_mask, 0] *= 1.8 # Higher boost for high-confidence SELL signals
action_probs[buy_mask, 2] *= 1.8 # Higher boost for high-confidence BUY signals
# For other cases, provide moderate boost
action_probs[:, 0] *= 1.4 # Boost SELL probabilities
action_probs[:, 2] *= 1.4 # Boost BUY probabilities
# Re-normalize to sum to 1
action_probs = action_probs / action_probs.sum(dim=1, keepdim=True)
return action_probs, price_pred
return action_probs, None # Return None for price_pred as we're focusing on actions
class CNNModelPyTorch:
"""
CNN model wrapper class for time series analysis using PyTorch.
This class provides methods for building, training, evaluating, and making
predictions with the CNN model, optimized for short-term trading opportunities.
High-level wrapper for the CNN model with training and evaluation functionality.
"""
def __init__(self, window_size=20, timeframes=None, output_size=3, num_pairs=3):
"""
Initialize the CNN model.
Initialize the model.
Args:
window_size (int): Size of the sliding window
timeframes (list): List of timeframes used
output_size (int): Number of output classes (3 for BUY/HOLD/SELL)
num_pairs (int): Number of trading pairs to analyze in parallel (default 3)
window_size (int): Size of the input window
timeframes (list): List of timeframes to use
output_size (int): Number of output classes
num_pairs (int): Number of trading pairs
"""
self.window_size = window_size
self.timeframes = timeframes if timeframes else ["1m", "5m", "15m"]
self.timeframes = timeframes or ["1m", "5m", "15m"]
self.output_size = output_size
self.num_pairs = num_pairs
# Calculate total features (5 OHLCV features per timeframe per pair)
self.total_features = len(self.timeframes) * 5 * self.num_pairs
# Set device
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {self.device}")
# Build the model
logger.info(f"Building PyTorch CNN model with window_size={window_size}, "
f"num_features={self.total_features}, output_size={output_size}, "
f"num_pairs={num_pairs}")
# Initialize the underlying CNN model
input_shape = (window_size, len(self.timeframes) * 5) # 5 features per timeframe
self.model = CNNPyTorch(input_shape, output_size).to(self.device)
# Calculate channel sizes that are divisible by num_pairs
base_channels = 96 # 96 is divisible by 3
self.model = nn.Sequential(
# First convolutional layer - process each pair's features
nn.Sequential(
nn.Conv1d(self.total_features, base_channels, kernel_size=5, padding=2, groups=num_pairs),
nn.ReLU(),
nn.BatchNorm1d(base_channels),
nn.Dropout(0.2)
),
# Second convolutional layer - start mixing pair information
nn.Sequential(
nn.Conv1d(base_channels, base_channels*2, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm1d(base_channels*2),
nn.Dropout(0.2)
),
# Third convolutional layer - deeper feature extraction
nn.Sequential(
nn.Conv1d(base_channels*2, base_channels*4, kernel_size=3, padding=1),
nn.ReLU(),
nn.BatchNorm1d(base_channels*4),
nn.Dropout(0.2)
),
# Global average pooling
nn.AdaptiveAvgPool1d(1),
# Flatten
nn.Flatten(),
# Dense layers for action prediction with cross-pair attention
nn.Sequential(
nn.Linear(base_channels*4, base_channels*2),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(base_channels*2, base_channels),
nn.ReLU(),
nn.Dropout(0.2),
nn.Linear(base_channels, output_size * num_pairs) # Output for each pair
)
).to(self.device)
# Initialize optimizer with lower learning rate for stability
self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001, weight_decay=0.01)
# Initialize optimizer and loss function
self.optimizer = optim.Adam(self.model.parameters(), lr=0.0005)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
)
self.criterion = nn.CrossEntropyLoss()
# Initialize loss functions
self.action_criterion = nn.CrossEntropyLoss()
# Initialize metrics tracking
# Training history
self.history = {
'train_loss': [],
'val_loss': [],
'train_acc': [],
'val_acc': []
}
# For compatibility with older code
self.train_losses = []
self.val_losses = []
self.train_accuracies = []
self.val_accuracies = []
logger.info(f"Model built successfully with {sum(p.numel() for p in self.model.parameters())} parameters")
# Initialize action counts
self.action_counts = {
'BUY': [0, 0], # [total, correct]
'SELL': [0, 0], # [total, correct]
'HOLD': [0, 0] # [total, correct]
}
logger.info(f"Building PyTorch CNN model with window_size={window_size}, output_size={output_size}")
# Learning rate scheduler
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer,
mode='min',
factor=0.5,
patience=5,
verbose=True
)
# Sensitivity parameters for high-leverage trading
self.confidence_threshold = 0.65
self.max_consecutive_same_action = 3
self.last_actions = [[] for _ in range(num_pairs)] # Track recent actions per pair
def compute_trading_loss(self, action_probs, price_pred, targets, future_prices=None):
"""
Custom loss function that prioritizes profitable trades
Args:
action_probs: Predicted action probabilities [batch_size, 3]
price_pred: Predicted price changes [batch_size, 1]
targets: Target actions [batch_size]
future_prices: Actual future price changes [batch_size]
Returns:
Total loss value
"""
batch_size = action_probs.size(0)
# Base classification loss
action_loss = self.criterion(action_probs, targets)
# Initialize price and profitability losses
price_loss = torch.tensor(0.0, device=self.device)
profit_loss = torch.tensor(0.0, device=self.device)
diversity_loss = torch.tensor(0.0, device=self.device)
# Get predicted actions
pred_actions = torch.argmax(action_probs, dim=1)
# Calculate signal diversity loss to prevent model from always predicting the same action
# Count actions in the batch
buy_count = (pred_actions == 2).float().sum() / batch_size
sell_count = (pred_actions == 0).float().sum() / batch_size
hold_count = (pred_actions == 1).float().sum() / batch_size
# Enhanced diversity mechanism
# For short-term high-leverage trading, we want a more balanced distribution
# with a slight preference for actions over holds, but still maintaining diversity
# Ideal distribution varies based on market conditions and training phase
# Start with more conservative distribution and gradually shift to more aggressive
if hasattr(self, 'training_progress'):
self.training_progress += 1
else:
self.training_progress = 0
# Early training phase - more balanced with higher HOLD
if self.training_progress < 500:
ideal_buy = 0.3
ideal_sell = 0.3
ideal_hold = 0.4
# Mid training phase - balanced trading signals
elif self.training_progress < 1500:
ideal_buy = 0.35
ideal_sell = 0.35
ideal_hold = 0.3
# Late training phase - more aggressive with tactical HOLDs
else:
ideal_buy = 0.4
ideal_sell = 0.4
ideal_hold = 0.2
# Calculate diversity loss using Kullback-Leibler divergence approximation
# Plus an additional penalty for extreme imbalance
actual_dist = torch.tensor([sell_count, hold_count, buy_count], device=self.device)
ideal_dist = torch.tensor([ideal_sell, ideal_hold, ideal_buy], device=self.device)
# KL divergence component (approximation)
eps = 1e-8 # Small constant to avoid division by zero
kl_div = torch.sum(actual_dist * torch.log((actual_dist + eps) / (ideal_dist + eps)))
# Add strong penalty for extreme predictions (all same class)
max_ratio = torch.max(actual_dist)
if max_ratio > 0.9: # If more than 90% of predictions are the same class
diversity_loss = kl_div + (max_ratio - 0.9) * 5.0 # Stronger penalty
elif max_ratio > 0.7: # If more than 70% predictions are the same class
diversity_loss = kl_div + (max_ratio - 0.7) * 2.0 # Moderate penalty
else:
diversity_loss = kl_div
# Add a penalty for any class whose share of predictions falls below min_class_ratio
# This discourages the model from never predicting a certain class
zero_class_penalty = 0.0
min_class_ratio = 0.1 # We want at least 10% of each class
if buy_count < min_class_ratio:
zero_class_penalty += (min_class_ratio - buy_count) * 3.0
if sell_count < min_class_ratio:
zero_class_penalty += (min_class_ratio - sell_count) * 3.0
if hold_count < min_class_ratio:
zero_class_penalty += (min_class_ratio - hold_count) * 2.0 # Slightly lower penalty for HOLD
diversity_loss += zero_class_penalty
# If we have future prices, calculate profitability-based losses
if future_prices is not None and future_prices.numel() > 0:
# Calculate price direction loss - penalize wrong direction predictions
if price_pred is not None:
# For each sample where future price is available
valid_mask = ~torch.isnan(future_prices) & (future_prices != 0)
if valid_mask.any():
valid_future = future_prices[valid_mask]
valid_price_pred = price_pred.view(-1)[valid_mask]
# Mean squared error for price prediction
price_loss = F.mse_loss(valid_price_pred, valid_future)
# Direction loss - penalize wrong direction predictions more heavily
pred_direction = torch.sign(valid_price_pred)
true_direction = torch.sign(valid_future)
direction_loss = ((pred_direction != true_direction) & (true_direction != 0)).float().mean()
# Add direction loss to price loss with higher weight
price_loss = price_loss + direction_loss * 2.0
# Calculate trade profitability loss
# This penalizes unprofitable trades more than just wrong classifications
profitable_trades = 0
unprofitable_trades = 0
for i in range(batch_size):
if i < future_prices.size(0) and not torch.isnan(future_prices[i]) and future_prices[i] != 0:
price_change = future_prices[i].item()
# Calculate expected profit/loss based on action
if pred_actions[i] == 0: # SELL
expected_pnl = -price_change # Negative price change is profit for SELL
elif pred_actions[i] == 2: # BUY
expected_pnl = price_change # Positive price change is profit for BUY
else: # HOLD
expected_pnl = 0 # No profit/loss for HOLD
# Enhanced profit/loss penalties with larger gradient for bad trades
if expected_pnl < 0:
# Superlinear (power-law) penalty for larger losses
severity = abs(expected_pnl) ** 1.5 # Higher exponent for short-term trading
profit_loss = profit_loss + torch.tensor(severity, device=self.device) * 2.5
unprofitable_trades += 1
elif expected_pnl > 0:
# Reward for profitable trades (negative loss contribution)
# Higher reward for larger profits
reward = expected_pnl * 0.9
profit_loss = profit_loss - torch.tensor(reward, device=self.device)
profitable_trades += 1
# Calculate win rate and further adjust profit loss
if profitable_trades + unprofitable_trades > 0:
win_rate = profitable_trades / (profitable_trades + unprofitable_trades)
# Add extra penalty if win rate is less than 50%
if win_rate < 0.5:
profit_loss = profit_loss * (1.0 + (0.5 - win_rate) * 2.5)
# Add small reward if win rate is high
elif win_rate > 0.6:
profit_loss = profit_loss * (1.0 - (win_rate - 0.6) * 0.5)
# Combine all loss components with dynamic weighting
# Adjust weights based on training progress
# Early training focuses more on classification accuracy
if self.training_progress < 500:
action_weight = 1.0
price_weight = 0.2
profit_weight = 0.5
diversity_weight = 0.3
# Mid training balances all components
elif self.training_progress < 1500:
action_weight = 0.8
price_weight = 0.3
profit_weight = 0.8
diversity_weight = 0.5
# Late training emphasizes profitability and diversity
else:
action_weight = 0.6
price_weight = 0.3
profit_weight = 1.0
diversity_weight = 0.7
total_loss = (action_weight * action_loss +
price_weight * price_loss +
profit_weight * profit_loss +
diversity_weight * diversity_loss)
return total_loss, action_loss, price_loss
def train_epoch(self, X_train, y_train, future_prices, batch_size):
"""Train the model for one epoch with focus on short-term pattern recognition"""
self.model.train()
total_action_loss = 0
total_price_loss = 0
total_loss = 0
total_correct = 0
total_samples = 0
# Convert inputs to tensors and create DataLoader
X_train_tensor = torch.FloatTensor(X_train).to(self.device)
y_train_tensor = torch.LongTensor(y_train).to(self.device)
future_prices_tensor = torch.FloatTensor(future_prices).to(self.device) if future_prices is not None else None
# Create dataset and dataloader
if future_prices_tensor is not None:
dataset = TensorDataset(X_train_tensor, y_train_tensor, future_prices_tensor)
else:
dataset = TensorDataset(X_train_tensor, y_train_tensor)
dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Training loop
for batch_data in train_loader:
for batch_X, batch_y in train_loader:
self.optimizer.zero_grad()
# Extract batch data
if len(batch_data) == 3:
batch_X, batch_y, batch_future_prices = batch_data
else:
batch_X, batch_y = batch_data
batch_future_prices = None
# Forward pass
action_probs, price_pred = self.model(batch_X)
action_probs, _ = self.model(batch_X)
# Calculate loss using custom trading loss function
total_loss, action_loss, price_loss = self.compute_trading_loss(
action_probs, price_pred, batch_y, batch_future_prices
)
# Calculate loss
loss = self.action_criterion(action_probs, batch_y)
# Backward pass and optimization
total_loss.backward()
# Apply gradient clipping to prevent exploding gradients
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
self.optimizer.step()
# Update metrics
total_action_loss += action_loss.item()
total_price_loss += price_loss.item() if hasattr(price_loss, 'item') else 0
total_loss += loss.item()
predictions = torch.argmax(action_probs, dim=1)
total_correct += (predictions == batch_y).sum().item()
total_samples += batch_y.size(0)
# Track trading signals for logging
buy_count = (predictions == 2).sum().item()
sell_count = (predictions == 0).sum().item()
hold_count = (predictions == 1).sum().item()
buy_correct = ((predictions == 2) & (batch_y == 2)).sum().item()
sell_correct = ((predictions == 0) & (batch_y == 0)).sum().item()
# Update action counts
for i, (pred, target) in enumerate(zip(predictions, batch_y)):
pred_action = ['SELL', 'HOLD', 'BUY'][pred.item()]
self.action_counts[pred_action][0] += 1
if pred.item() == target.item():
self.action_counts[pred_action][1] += 1
# Calculate average losses and accuracy
avg_action_loss = total_action_loss / len(train_loader)
avg_price_loss = total_price_loss / len(train_loader)
# Calculate average loss and accuracy
avg_loss = total_loss / len(train_loader)
accuracy = total_correct / total_samples
# Update training history
self.history['train_loss'].append(avg_loss)
self.history['train_acc'].append(accuracy)
self.train_losses.append(avg_loss)
self.train_accuracies.append(accuracy)
# Log trading signals
logger.info(f"Trading signals: BUY={buy_count}, SELL={sell_count}, HOLD={hold_count}")
logger.info(f"Signal precision: BUY={buy_correct/max(1, buy_count):.4f}, SELL={sell_correct/max(1, sell_count):.4f}")
for action in ['BUY', 'SELL', 'HOLD']:
total = self.action_counts[action][0]
correct = self.action_counts[action][1]
precision = correct / total if total > 0 else 0
logger.info(f"Trading signals - {action}: {total}, Precision: {precision:.4f}")
# Update learning rate
self.scheduler.step(accuracy)
return avg_action_loss, avg_price_loss, accuracy
return avg_loss, 0, accuracy # Return 0 for price_loss as we're not using it
def evaluate(self, X_val, y_val, future_prices=None):
"""Evaluate the model with focus on short-term trading performance metrics"""
self.model.eval()
total_action_loss = 0
total_price_loss = 0
total_loss = 0
total_correct = 0
total_samples = 0
# Additional metrics for trading performance
trade_signals = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
correct_signals = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
# Convert inputs to tensors
X_val_tensor = torch.FloatTensor(X_val).to(self.device)
y_val_tensor = torch.LongTensor(y_val).to(self.device)
future_prices_tensor = torch.FloatTensor(future_prices).to(self.device) if future_prices is not None else None
# Create dataset and dataloader
dataset = TensorDataset(X_val_tensor, y_val_tensor)
val_loader = DataLoader(dataset, batch_size=32)
with torch.no_grad():
# Forward pass
action_probs, price_pred = self.model(X_val_tensor)
# Calculate loss using custom trading loss function
total_loss, action_loss, price_loss = self.compute_trading_loss(
action_probs, price_pred, y_val_tensor, future_prices_tensor
)
# Calculate predictions and accuracy
predictions = torch.argmax(action_probs, dim=1)
# Count prediction types and correct predictions
for i in range(predictions.shape[0]):
pred = predictions[i].item()
if pred == 0:
trade_signals['SELL'] += 1
if y_val_tensor[i].item() == pred:
correct_signals['SELL'] += 1
elif pred == 1:
trade_signals['HOLD'] += 1
if y_val_tensor[i].item() == pred:
correct_signals['HOLD'] += 1
elif pred == 2:
trade_signals['BUY'] += 1
if y_val_tensor[i].item() == pred:
correct_signals['BUY'] += 1
# Update metrics
total_action_loss = action_loss.item()
total_price_loss = price_loss.item() if hasattr(price_loss, 'item') else 0
total_correct = (predictions == y_val_tensor).sum().item()
total_samples = y_val_tensor.size(0)
for batch_X, batch_y in val_loader:
# Forward pass
action_probs, _ = self.model(batch_X)
# Calculate loss
loss = self.action_criterion(action_probs, batch_y)
# Update metrics
total_loss += loss.item()
predictions = torch.argmax(action_probs, dim=1)
total_correct += (predictions == batch_y).sum().item()
total_samples += batch_y.size(0)
# Calculate accuracy
accuracy = total_correct / total_samples if total_samples > 0 else 0
# Calculate average loss and accuracy
avg_loss = total_loss / len(val_loader)
accuracy = total_correct / total_samples
# Calculate signal precision (crucial for short-term trading)
buy_precision = correct_signals['BUY'] / trade_signals['BUY'] if trade_signals['BUY'] > 0 else 0
sell_precision = correct_signals['SELL'] / trade_signals['SELL'] if trade_signals['SELL'] > 0 else 0
# Update validation history
self.history['val_loss'].append(avg_loss)
self.history['val_acc'].append(accuracy)
self.val_losses.append(avg_loss)
self.val_accuracies.append(accuracy)
# Log trading-specific metrics
logger.info(f"Trading signals: BUY={trade_signals['BUY']}, SELL={trade_signals['SELL']}, HOLD={trade_signals['HOLD']}")
logger.info(f"Signal precision: BUY={buy_precision:.4f}, SELL={sell_precision:.4f}")
# Update learning rate scheduler
self.scheduler.step(avg_loss)
# Return action loss, price loss, and accuracy
return total_action_loss, total_price_loss, accuracy
return avg_loss, 0, accuracy # Return 0 for price_loss as we're not using it
def predict(self, X):
"""Make predictions optimized for short-term high-leverage trading signals"""
@ -659,28 +338,11 @@ class CNNModelPyTorch:
action_probs_np = action_probs.cpu().numpy()
# Apply more aggressive HOLD reduction for short-term trading
action_probs_np[:, 1] *= 0.5 # More aggressive HOLD reduction
action_probs_np[:, 1] *= 0.3 # More aggressive HOLD reduction
# Apply boosting for BUY/SELL signals
action_probs_np[:, 0] *= 1.3 # Boost SELL probabilities
action_probs_np[:, 2] *= 1.3 # Boost BUY probabilities
# Implement signal filtering based on previous actions to avoid oscillation
if len(self.last_actions[0]) >= self.max_consecutive_same_action:
# Check for too many consecutive identical actions
if all(a == 0 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
# Too many consecutive SELL - reduce sell probability
action_probs_np[:, 0] *= 0.7
elif all(a == 2 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
# Too many consecutive BUY - reduce buy probability
action_probs_np[:, 2] *= 0.7
# Apply confidence threshold to reduce noise
max_probs = np.max(action_probs_np, axis=1)
for i in range(len(action_probs_np)):
if max_probs[i] < self.confidence_threshold:
# If confidence is too low, force HOLD
action_probs_np[i] = np.array([0.1, 0.8, 0.1])
action_probs_np[:, 0] *= 2.0 # Boost SELL probabilities
action_probs_np[:, 2] *= 2.0 # Boost BUY probabilities
# Re-normalize
action_probs_np = action_probs_np / action_probs_np.sum(axis=1, keepdims=True)
@ -704,16 +366,20 @@ class CNNModelPyTorch:
if 2 in action_dict:
self.action_counts['BUY'][0] += action_dict[2]
# Get the current close prices from the input
current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0])
# Calculate price directions based on probabilities
price_directions = action_probs_np[:, 2] - action_probs_np[:, 0] # BUY - SELL
# Scale the price change based on signal strength
price_preds = current_prices * (1 + price_directions * 0.002)
return action_probs_np, price_preds.reshape(-1, 1)
# If price_pred is None, create a dummy array of zeros
if price_pred is None:
# Get the current close prices from the input if available
current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0])
# Calculate price directions based on probabilities
price_directions = action_probs_np[:, 2] - action_probs_np[:, 0] # BUY - SELL
# Scale the price change based on signal strength
price_preds = current_prices * (1 + price_directions * 0.002)
return action_probs_np, price_preds.reshape(-1, 1)
else:
return action_probs_np, price_pred.cpu().numpy()
def predict_next_candles(self, X, n_candles=3):
"""
@ -919,14 +585,9 @@ class CNNModelPyTorch:
model_state = {
'model_state_dict': self.model.state_dict(),
'optimizer_state_dict': self.optimizer.state_dict(),
'history': {
'loss': self.train_losses,
'accuracy': self.train_accuracies,
'val_loss': self.val_losses,
'val_accuracy': self.val_accuracies
},
'history': self.history,
'window_size': self.window_size,
'num_features': self.total_features,
'num_features': len(self.timeframes) * 5, # 5 features per timeframe
'output_size': self.output_size,
'timeframes': self.timeframes,
# Save trading configuration
@ -935,7 +596,7 @@ class CNNModelPyTorch:
'action_counts': self.action_counts,
'last_actions': self.last_actions,
# Save model version information
'model_version': 'short_term_optimized_v1.0',
'model_version': 'short_term_optimized_v2.0',
'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
}
@ -943,10 +604,10 @@ class CNNModelPyTorch:
logger.info(f"Model saved to {filepath}.pt with short-term trading optimizations")
# Save a backup of the model periodically
if not os.path.exists(f"{filepath}_backup"):
os.makedirs(f"{filepath}_backup", exist_ok=True)
backup_dir = f"{filepath}_backup"
os.makedirs(backup_dir, exist_ok=True)
backup_path = os.path.join(f"{filepath}_backup", f"model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
backup_path = os.path.join(backup_dir, f"model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
torch.save(model_state, backup_path)
logger.info(f"Backup saved to {backup_path}")
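
Review note on the diversity term in `compute_trading_loss`: the KL component and the imbalance penalty can be checked in isolation. The sketch below is a plain-Python restatement of that arithmetic, not the author's code. The function name `diversity_penalty` and its list arguments are assumptions made for illustration, and the minimum-class-ratio penalty is left out for brevity.

```python
import math


def diversity_penalty(actual, ideal, eps=1e-8):
    """Hypothetical helper: KL(actual || ideal) plus an imbalance penalty.

    `actual` and `ideal` are class shares ordered [SELL, HOLD, BUY],
    mirroring actual_dist and ideal_dist in the diff.
    """
    # KL divergence component, with eps guarding against log(0)
    kl_div = sum(
        a * math.log((a + eps) / (i + eps)) for a, i in zip(actual, ideal)
    )
    # Extra penalty when a single class dominates the batch
    max_ratio = max(actual)
    if max_ratio > 0.9:
        return kl_div + (max_ratio - 0.9) * 5.0
    if max_ratio > 0.7:
        return kl_div + (max_ratio - 0.7) * 2.0
    return kl_div


# A batch that matches the ideal late-training distribution is not penalized
balanced = diversity_penalty([0.4, 0.2, 0.4], [0.4, 0.2, 0.4])
# A batch that predicts only SELL picks up both the KL term and the 90% penalty
collapsed = diversity_penalty([1.0, 0.0, 0.0], [0.4, 0.2, 0.4])
```

Because the class shares come from `argmax` counts, this term most likely acts as a batch-level offset to the loss rather than a source of gradients, which may be worth confirming before relying on it to shape training.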

@ -7,12 +7,16 @@ import random
from typing import Tuple, List
import os
import sys
import logging
# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from NN.models.simple_cnn import CNNModelPyTorch
# Configure logger
logger = logging.getLogger(__name__)
class DQNAgent:
"""
Deep Q-Network agent for trading
@ -72,14 +76,32 @@ class DQNAgent:
# Initialize memory
self.memory = deque(maxlen=memory_size)
# Special memory for extrema samples to use for targeted learning
self.extrema_memory = deque(maxlen=memory_size // 5) # Smaller size for extrema examples
# Training metrics
self.update_count = 0
self.losses = []
def remember(self, state: np.ndarray, action: int, reward: float,
next_state: np.ndarray, done: bool):
"""Store experience in memory"""
self.memory.append((state, action, reward, next_state, done))
next_state: np.ndarray, done: bool, is_extrema: bool = False):
"""
Store experience in memory
Args:
state: Current state
action: Action taken
reward: Reward received
next_state: Next state
done: Whether episode is done
is_extrema: Whether this sample lies at a local extremum (for specialized learning)
"""
experience = (state, action, reward, next_state, done)
self.memory.append(experience)
# If this sample lies at an extremum, also add it to the specialized memory
if is_extrema:
self.extrema_memory.append(experience)
def act(self, state: np.ndarray) -> int:
"""Choose action using epsilon-greedy policy"""
@ -88,16 +110,39 @@ class DQNAgent:
with torch.no_grad():
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
action_probs, _ = self.policy_net(state)
action_probs, extrema_pred = self.policy_net(state)
return action_probs.argmax().item()
def replay(self) -> float:
"""Train on a batch of experiences"""
def replay(self, use_extrema=False) -> float:
"""
Train on a batch of experiences
Args:
use_extrema: Whether to include extrema samples in training
Returns:
float: Loss value
"""
if len(self.memory) < self.batch_size:
return 0.0
# Sample batch
batch = random.sample(self.memory, self.batch_size)
# Sample batch - mix regular and extrema samples
batch = []
if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
# Get some extrema samples
extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
# Get regular samples for the rest
regular_count = self.batch_size - extrema_count
regular_samples = random.sample(list(self.memory), regular_count)
# Combine samples
batch = extrema_samples + regular_samples
else:
# Standard sampling
batch = random.sample(self.memory, self.batch_size)
states, actions, rewards, next_states, dones = zip(*batch)
# Convert to tensors and move to device
@ -108,7 +153,7 @@ class DQNAgent:
dones = torch.FloatTensor(dones).to(self.device)
# Get current Q values
current_q_values, _ = self.policy_net(states)
current_q_values, extrema_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
# Get next Q values from target network
@ -117,8 +162,15 @@ class DQNAgent:
next_q_values = next_q_values.max(1)[0]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Compute loss
loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
# Compute Q-learning loss
q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
# Extrema labels are not available in this implementation yet.
# Once they are, an additional loss for extrema prediction could be added here;
# that requires a label for whether each state is near an extremum
# Total loss is just Q-learning loss for now
loss = q_loss
# Optimize
self.optimizer.zero_grad()
@ -135,6 +187,50 @@ class DQNAgent:
return loss.item()
def train_on_extrema(self, states, actions, rewards, next_states, dones):
"""
Special training method focused on extrema patterns
Args:
states: Array of states near extrema points
actions: Correct actions to take (buy at bottoms, sell at tops)
rewards: Rewards for each action
next_states: Next states
dones: Done flags
"""
if len(states) == 0:
return 0.0
# Convert to tensors
states = torch.FloatTensor(np.array(states)).to(self.device)
actions = torch.LongTensor(actions).to(self.device)
rewards = torch.FloatTensor(rewards).to(self.device)
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
dones = torch.FloatTensor(dones).to(self.device)
# Forward pass
current_q_values, extrema_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
# Get next Q values
with torch.no_grad():
next_q_values, _ = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Q-learning loss; no extra weight is applied to extrema samples yet
q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
# Full loss is just Q-learning loss
loss = q_loss
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
def save(self, path: str):
"""Save model and agent state"""
os.makedirs(os.path.dirname(path), exist_ok=True)
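
Review note on the batch mixing in `replay`: the sampling rule (up to a third of the batch from the extrema memory, the rest from regular memory) can be exercised without the network. The sketch below is a standalone restatement under assumptions: the function name `sample_mixed_batch` and the `rng` parameter are hypothetical, and integers stand in for experience tuples.

```python
import random
from collections import deque


def sample_mixed_batch(memory, extrema_memory, batch_size,
                       use_extrema=True, rng=random):
    """Hypothetical helper mirroring the sampling rule in replay()."""
    if use_extrema and len(extrema_memory) > batch_size // 4:
        # Take at most a third of the batch from the extrema memory
        extrema_count = min(batch_size // 3, len(extrema_memory))
        extrema_samples = rng.sample(list(extrema_memory), extrema_count)
        # Fill the remainder from the regular replay memory
        regular_samples = rng.sample(list(memory), batch_size - extrema_count)
        return extrema_samples + regular_samples
    # Standard uniform sampling when extrema are disabled or too scarce
    return rng.sample(list(memory), batch_size)


# Integers stand in for (state, action, reward, next_state, done) tuples
memory = deque(range(100), maxlen=1000)
extrema_memory = deque(range(1000, 1020), maxlen=200)
batch = sample_mixed_batch(memory, extrema_memory, 32, rng=random.Random(0))
```

In the diff, `remember` appends extrema samples to both memories, so a mixed batch can contain the same experience twice. The sketch keeps the two memories disjoint and therefore does not show that effect.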

@ -11,6 +11,39 @@ from typing import List, Tuple
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PricePatternAttention(nn.Module):
"""
Attention mechanism specifically designed to focus on price patterns
that might indicate local extrema or trend reversals
"""
def __init__(self, input_dim, hidden_dim=64):
super(PricePatternAttention, self).__init__()
self.query = nn.Linear(input_dim, hidden_dim)
self.key = nn.Linear(input_dim, hidden_dim)
self.value = nn.Linear(input_dim, hidden_dim)
self.scale = torch.sqrt(torch.tensor(hidden_dim, dtype=torch.float32))
def forward(self, x):
"""Apply attention to input sequence"""
# x shape: [batch_size, seq_len, features]
batch_size, seq_len, _ = x.size()
# Project input to query, key, value
q = self.query(x) # [batch_size, seq_len, hidden_dim]
k = self.key(x) # [batch_size, seq_len, hidden_dim]
v = self.value(x) # [batch_size, seq_len, hidden_dim]
# Calculate attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [batch_size, seq_len, seq_len]
# Apply softmax to get attention weights
attn_weights = F.softmax(scores, dim=-1) # [batch_size, seq_len, seq_len]
# Apply attention to values
output = torch.matmul(attn_weights, v) # [batch_size, seq_len, hidden_dim]
return output, attn_weights
class CNNModelPyTorch(nn.Module):
"""
CNN model for trading with multiple timeframes
@ -30,7 +63,15 @@ class CNNModelPyTorch(nn.Module):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f"Using device: {self.device}")
# Convolutional layers
# Create model architecture
self._create_layers()
# Move model to device
self.to(self.device)
def _create_layers(self):
"""Create all model layers with current feature dimensions"""
# Convolutional layers - use total_features as input channels
self.conv1 = nn.Conv1d(self.total_features, 64, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm1d(64)
@ -40,24 +81,49 @@ class CNNModelPyTorch(nn.Module):
self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
self.bn3 = nn.BatchNorm1d(256)
# Calculate size after convolutions
conv_output_size = window_size * 256
# Add price pattern attention layer
self.attention = PricePatternAttention(256)
# Extrema detection specialized convolutional layer
self.extrema_conv = nn.Conv1d(256, 128, kernel_size=5, padding=2)
self.extrema_bn = nn.BatchNorm1d(128)
# Calculate flattened size after convolutions (window_size positions x 256 channels)
conv_output_size = self.window_size * 256
# Fully connected layers
self.fc1 = nn.Linear(conv_output_size, 512)
self.fc2 = nn.Linear(512, 256)
# Advantage and Value streams (Dueling DQN architecture)
self.fc3 = nn.Linear(256, output_size) # Advantage stream
self.fc3 = nn.Linear(256, self.output_size) # Advantage stream
self.value_fc = nn.Linear(256, 1) # Value stream
# Additional prediction head for extrema detection (tops/bottoms)
self.extrema_fc = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
# Initialize optimizer and scheduler
self.optimizer = optim.Adam(self.parameters(), lr=0.001)
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
)
def rebuild_conv_layers(self, input_channels):
"""
Rebuild convolutional layers for different input dimensions
# Move model to device
Args:
input_channels: Number of input channels (features) in the data
"""
logger.info(f"Rebuilding convolutional layers for {input_channels} input channels")
# Update total features
self.total_features = input_channels
# Recreate all layers with new dimensions
self._create_layers()
# Move layers to device
self.to(self.device)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@ -65,8 +131,13 @@ class CNNModelPyTorch(nn.Module):
# Ensure input is on the correct device
x = x.to(self.device)
# If the input feature dimension doesn't match the model, rebuild the layers.
# Note: rebuilding re-creates the layers, so previously learned weights are reinitialized
batch_size, window_len, feature_dim = x.size()
if feature_dim != self.total_features:
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
self.rebuild_conv_layers(feature_dim)
# Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
batch_size = x.size(0)
x = x.permute(0, 2, 1)
# Convolutional layers
@ -74,6 +145,26 @@ class CNNModelPyTorch(nn.Module):
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
# Store conv features for extrema detection
conv_features = x
# Reshape for attention: [batch, channels, window_size] -> [batch, window_size, channels]
x_attention = x.permute(0, 2, 1)
# Apply attention
attention_output, attention_weights = self.attention(x_attention)
# The attention output is not added back as a residual connection, because its
# hidden dimension (64 by default) differs from the 256 convolutional channels
attention_reshaped = attention_output.permute(0, 2, 1) # [batch, hidden_dim, window_size]
# Apply extrema detection specialized layer
extrema_features = F.relu(self.extrema_bn(self.extrema_conv(conv_features)))
# For now only the convolutional features are passed on;
# the attention output computed above is not fed forward
x = conv_features
# Flatten
x = x.view(batch_size, -1)
@ -88,7 +179,11 @@ class CNNModelPyTorch(nn.Module):
# Combine value and advantage
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
return q_values, value
# Compute the extrema prediction from the shared fully connected features (x);
# extrema_flat is computed here but not used yet
extrema_flat = extrema_features.view(batch_size, -1)
extrema_pred = self.extrema_fc(x)
return q_values, extrema_pred
def predict(self, X):
"""Make predictions"""
@ -101,11 +196,15 @@ class CNNModelPyTorch(nn.Module):
X_tensor = X.to(self.device)
with torch.no_grad():
q_values, value = self(X_tensor)
q_values, extrema_pred = self(X_tensor)
q_values_np = q_values.cpu().numpy()
actions = np.argmax(q_values_np, axis=1)
return actions, q_values_np
# Also return extrema predictions
extrema_np = extrema_pred.cpu().numpy()
extrema_classes = np.argmax(extrema_np, axis=1)
return actions, q_values_np, extrema_classes
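
Review note on the dueling head in `forward`: the line `q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))` combines the value and advantage streams so that the mean Q-value of a state equals its value estimate. The sketch below shows the same arithmetic for a single state in plain Python. The function name `dueling_q_values` is an assumption for illustration and is not part of the diff.

```python
def dueling_q_values(value, advantages):
    """Hypothetical helper: combine one state's value and per-action advantages."""
    # Centering the advantages keeps the value/advantage split identifiable
    mean_adv = sum(advantages) / len(advantages)
    return [value + (a - mean_adv) for a in advantages]


# One state with value 1.0 and advantages for SELL, HOLD, BUY
q = dueling_q_values(1.0, [1.0, 2.0, 3.0])
```

Subtracting the mean advantage does not change which action has the highest Q-value, so greedy action selection in `predict` is unaffected by the centering.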
def save(self, path: str):
"""Save model weights"""

@ -0,0 +1,241 @@
import logging
import numpy as np
import pandas as pd
import time
from typing import Dict, Any, List, Optional, Tuple
logger = logging.getLogger(__name__)
class RealtimeDataInterface:
"""Interface for retrieving real-time market data for neural network models.
This class serves as a bridge between the RealTimeChart data sources and
the neural network models, providing properly formatted data for model
inference.
"""
def __init__(self, symbols: List[str], chart=None, max_cache_size: int = 5000):
"""Initialize the data interface.
Args:
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
chart: RealTimeChart instance (optional)
max_cache_size: Maximum number of cached candles
"""
self.symbols = symbols
self.chart = chart
self.max_cache_size = max_cache_size
# Initialize data cache
self.ohlcv_cache = {} # timeframe -> symbol -> DataFrame
logger.info(f"Initialized RealtimeDataInterface with symbols: {', '.join(symbols)}")
def get_historical_data(self, symbol: str = None, timeframe: str = '1h',
n_candles: int = 500) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data for a symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not symbol:
if len(self.symbols) > 0:
symbol = self.symbols[0]
else:
logger.error("No symbol specified and no default symbols available")
return None
if symbol not in self.symbols:
logger.warning(f"Symbol {symbol} not in tracked symbols")
return None
try:
# Get data from chart if available
if self.chart:
candles = self._get_chart_data(symbol, timeframe, n_candles)
if candles is not None and len(candles) > 0:
return candles
# Fallback to default empty DataFrame
logger.warning(f"No historical data available for {symbol} at timeframe {timeframe}")
return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
except Exception as e:
logger.error(f"Error getting historical data for {symbol}: {str(e)}")
return None
def _get_chart_data(self, symbol: str, timeframe: str, n_candles: int) -> Optional[pd.DataFrame]:
"""Get data from the RealTimeChart for the specified symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not self.chart:
return None
# Get chart data using the _get_chart_data method
try:
# Map to interval seconds
interval_map = {
'1s': 1,
'5s': 5,
'10s': 10,
'15s': 15,
'30s': 30,
'1m': 60,
'3m': 180,
'5m': 300,
'15m': 900,
'30m': 1800,
'1h': 3600,
'2h': 7200,
'4h': 14400,
'6h': 21600,
'8h': 28800,
'12h': 43200,
'1d': 86400,
'3d': 259200,
'1w': 604800
}
# Convert timeframe to seconds
if timeframe in interval_map:
interval_seconds = interval_map[timeframe]
else:
# Try to parse the interval (e.g., '1m' -> 60)
try:
if timeframe.endswith('s'):
interval_seconds = int(timeframe[:-1])
elif timeframe.endswith('m'):
interval_seconds = int(timeframe[:-1]) * 60
elif timeframe.endswith('h'):
interval_seconds = int(timeframe[:-1]) * 3600
elif timeframe.endswith('d'):
interval_seconds = int(timeframe[:-1]) * 86400
elif timeframe.endswith('w'):
interval_seconds = int(timeframe[:-1]) * 604800
else:
interval_seconds = int(timeframe)
except ValueError:
logger.error(f"Could not parse timeframe: {timeframe}")
return None
# Get data from chart
df = self.chart._get_chart_data(interval_seconds)
if df is not None and not df.empty:
# Limit to requested number of candles
if len(df) > n_candles:
df = df.iloc[-n_candles:]
return df
else:
logger.warning(f"No data retrieved from chart for {symbol} at timeframe {timeframe}")
return None
except Exception as e:
logger.error(f"Error getting chart data for {symbol} at {timeframe}: {str(e)}")
return None
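Review note: the lookup table and the suffix-parsing fallback in `_get_chart_data` yield the same values for every listed timeframe, so they could probably be folded into one helper. A sketch under that assumption (`UNIT_SECONDS` and `timeframe_to_seconds` are hypothetical names, not part of this commit):

```python
from typing import Optional

# Seconds per unit suffix; matches the multipliers used in the fallback branch
UNIT_SECONDS = {"s": 1, "m": 60, "h": 3600, "d": 86400, "w": 604800}


def timeframe_to_seconds(timeframe: str) -> Optional[int]:
    """Convert strings such as '5m', '1h' or '90' to seconds.

    Returns None when the string cannot be parsed, which corresponds to the
    "Could not parse timeframe" branch in the diff.
    """
    try:
        unit = timeframe[-1:]
        if unit in UNIT_SECONDS:
            return int(timeframe[:-1]) * UNIT_SECONDS[unit]
        # A bare number is treated as a value already expressed in seconds
        return int(timeframe)
    except ValueError:
        return None


one_hour = timeframe_to_seconds("1h")
three_days = timeframe_to_seconds("3d")
invalid = timeframe_to_seconds("soon")
```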
def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
symbol: Optional[str] = None) -> Tuple[Optional[np.ndarray], Optional[int]]:
"""Prepare model input from OHLCV data.
Args:
data: DataFrame with OHLCV data
window_size: Window size for model input
symbol: Symbol for the data (for logging)
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
if data is None or len(data) < window_size:
logger.warning(f"Not enough data to prepare model input for {symbol or 'unknown symbol'}")
return None, None
try:
# Get last window_size candles
recent_data = data.iloc[-window_size:].copy()
# Get timestamp of the most recent candle
timestamp = int(recent_data.iloc[-1]['timestamp']) if 'timestamp' in recent_data.columns else int(time.time())
# Extract OHLCV features and normalize
if 'open' in recent_data.columns and 'high' in recent_data.columns and 'low' in recent_data.columns and 'close' in recent_data.columns and 'volume' in recent_data.columns:
# Normalize price data by the last close price
last_close = recent_data['close'].iloc[-1]
# Avoid division by zero
if last_close == 0:
last_close = 1.0
opens = (recent_data['open'] / last_close).values
highs = (recent_data['high'] / last_close).values
lows = (recent_data['low'] / last_close).values
closes = (recent_data['close'] / last_close).values
# Normalize volume by the max volume in the window
max_volume = recent_data['volume'].max()
if max_volume == 0:
max_volume = 1.0
volumes = (recent_data['volume'] / max_volume).values
# Stack features into a 3D array [batch_size=1, window_size, n_features=5]
X = np.column_stack((opens, highs, lows, closes, volumes))
X = X.reshape(1, window_size, 5)
# Replace any NaN or infinite values
X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)
return X, timestamp
else:
logger.error(f"Data missing required OHLCV columns for {symbol or 'unknown symbol'}")
return None, None
except Exception as e:
logger.error(f"Error preparing model input for {symbol or 'unknown symbol'}: {str(e)}")
return None, None
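Review note: the normalization in `prepare_model_input` divides prices by the last close and volumes by the window maximum. A simplified sketch of the same arithmetic, where plain lists of dicts stand in for the DataFrame (the `normalize_window` helper and the `Candle` alias are assumptions made for illustration):

```python
from typing import Dict, List, Optional

Candle = Dict[str, float]


def normalize_window(candles: List[Candle],
                     window_size: int = 20) -> Optional[List[List[float]]]:
    """Return window_size rows of [open, high, low, close, volume], normalized.

    Prices are divided by the last close in the window and volumes by the
    largest volume in the window; a zero divisor falls back to 1.0.
    """
    if len(candles) < window_size:
        return None
    window = candles[-window_size:]
    last_close = window[-1]["close"] or 1.0
    max_volume = max(c["volume"] for c in window) or 1.0
    return [
        [c["open"] / last_close, c["high"] / last_close,
         c["low"] / last_close, c["close"] / last_close,
         c["volume"] / max_volume]
        for c in window
    ]


candles = [
    {"open": 90.0, "high": 110.0, "low": 80.0, "close": 100.0, "volume": 5.0},
    {"open": 100.0, "high": 220.0, "low": 100.0, "close": 200.0, "volume": 10.0},
]
rows = normalize_window(candles, window_size=2)
```

With this scheme the most recent close is always 1.0, so the model sees prices relative to the current level rather than absolute values.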
def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
window_size: int = 20) -> Tuple[Optional[np.ndarray], Optional[int]]:
"""Prepare real-time input for the model.
Args:
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
window_size: Window size for model input
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
# Get data for the main symbol
if len(self.symbols) == 0:
logger.error("No symbols available for real-time input")
return None, None
symbol = self.symbols[0]
try:
# Get historical data
data = self.get_historical_data(symbol, timeframe, n_candles)
if data is None or len(data) < window_size:
logger.warning(f"Not enough data for real-time input. Need at least {window_size} candles.")
return None, None
# Prepare model input
return self.prepare_model_input(data, window_size, symbol)
except Exception as e:
logger.error(f"Error preparing real-time input: {str(e)}")
return None, None


@ -63,6 +63,9 @@ class RLTradingEnvironment(gym.Env):
# State variables
self.reset()
# Callback for visualization or external monitoring
self.action_callback = None
def reset(self):
"""Reset the environment to initial state"""
@ -145,6 +148,7 @@ class RLTradingEnvironment(gym.Env):
# Default reward is slightly negative to discourage inaction
reward = -0.0001
done = False
profit_pct = None # Initialize profit_pct variable
# Execute action
if action == 0: # BUY
@ -218,214 +222,188 @@ class RLTradingEnvironment(gym.Env):
'total_value': total_value,
'gain': gain,
'trades': self.trades,
'win_rate': self.win_rate
'win_rate': self.win_rate,
'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
'current_price': current_price,
'next_price': next_price
}
# Call the callback if it exists
if self.action_callback:
self.action_callback(action, current_price, reward, info)
return observation, reward, done, info
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent"):
def set_action_callback(self, callback):
"""
Set a callback function to be called after each action
Args:
callback: Function with signature (action, price, reward, info)
"""
self.action_callback = callback
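Review note: a sketch of how the callback hook might be used. `CallbackEnv` is a hypothetical stand-in for `RLTradingEnvironment` that shows only the callback plumbing; the wrapper that prepends the step index follows the pattern `train_rl` uses later in this diff.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

ActionCallback = Callable[[int, float, float, Dict[str, Any]], None]


class CallbackEnv:
    """Hypothetical minimal environment; not the real RLTradingEnvironment."""

    def __init__(self) -> None:
        self.current_step = 0
        self.action_callback: Optional[ActionCallback] = None

    def set_action_callback(self, callback: ActionCallback) -> None:
        self.action_callback = callback

    def step(self, action: int, price: float) -> None:
        self.current_step += 1
        reward = -0.0001  # same default reward as in the diff
        info = {"current_price": price}
        if self.action_callback:
            self.action_callback(action, price, reward, info)


env = CallbackEnv()
log: List[Tuple[int, int, float]] = []


def on_action(step: int, action: int, price: float,
              reward: float, info: Dict[str, Any]) -> None:
    """Five-argument callback, as accepted by train_rl."""
    log.append((step, action, price))


# Adapt the five-argument callback to the four-argument environment hook
env.set_action_callback(
    lambda action, price, reward, info: on_action(
        env.current_step, action, price, reward, info))

env.step(0, 100.0)
env.step(1, 101.5)
```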
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
action_callback=None, episode_callback=None, symbol="BTC/USDT"):
"""
Train DQN agent for RL-based trading with extended training and monitoring
Args:
env_class: Optional environment class to use, defaults to RLTradingEnvironment
num_episodes: Number of episodes to train
max_steps: Maximum steps per episode
save_path: Path to save the model
action_callback: Optional callback for each action (step, action, price, reward, info)
episode_callback: Optional callback after each episode (episode, reward, info)
symbol: Trading pair symbol (e.g., "BTC/USDT")
Returns:
DQNAgent: The trained agent
"""
logger.info("Starting extended RL training for trading...")
import pandas as pd
from NN.utils.data_interface import DataInterface
# Environment setup
window_size = 20
timeframes = ["1m", "5m", "15m"]
trading_fee = 0.001
logger.info("Starting DQN training for RL trading")
# Ensure save directory exists
os.makedirs(os.path.dirname(save_path), exist_ok=True)
# Create data interface with specified symbol
data_interface = DataInterface(symbol=symbol)
# Setup TensorBoard for monitoring
writer = SummaryWriter(f'runs/rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
# Data loading
data_interface = DataInterface(
symbol="BTC/USDT",
timeframes=timeframes
)
# Get training data for each timeframe with more data
logger.info("Loading training data...")
features_1m = data_interface.get_training_data("1m", n_candles=5000)
if features_1m is not None:
logger.info(f"Loaded {len(features_1m)} 1m candles")
else:
logger.error("Failed to load 1m data")
return None
features_5m = data_interface.get_training_data("5m", n_candles=2500)
if features_5m is not None:
logger.info(f"Loaded {len(features_5m)} 5m candles")
else:
logger.error("Failed to load 5m data")
return None
features_15m = data_interface.get_training_data("15m", n_candles=2500)
if features_15m is not None:
logger.info(f"Loaded {len(features_15m)} 15m candles")
else:
logger.error("Failed to load 15m data")
return None
# Load and preprocess data
logger.info(f"Loading data from multiple timeframes for {symbol}")
features_1m = data_interface.get_training_data("1m", n_candles=2000)
features_5m = data_interface.get_training_data("5m", n_candles=1000)
features_15m = data_interface.get_training_data("15m", n_candles=500)
# Check if we have all the data
if features_1m is None or features_5m is None or features_15m is None:
logger.error("Failed to load training data")
logger.error("Failed to load training data from one or more timeframes")
return None
# Convert DataFrames to numpy arrays, excluding timestamp column
features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values
# If data is a DataFrame, convert to numpy array excluding the timestamp column
if isinstance(features_1m, pd.DataFrame):
features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
if isinstance(features_5m, pd.DataFrame):
features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
if isinstance(features_15m, pd.DataFrame):
features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values
# Calculate number of features per timeframe
num_features = features_1m.shape[1] # Number of features after dropping timestamp
# Initialize environment or use provided class
if env_class is None:
env = RLTradingEnvironment(features_1m, features_5m, features_15m)
else:
env = env_class(features_1m, features_5m, features_15m)
# Create environment
env = RLTradingEnvironment(
features_1m=features_1m,
features_5m=features_5m,
features_15m=features_15m,
window_size=window_size,
trading_fee=trading_fee
)
# Set action callback if provided
if action_callback:
def step_callback(action, price, reward, info):
action_callback(env.current_step, action, price, reward, info)
env.set_action_callback(step_callback)
# Initialize agent
window_size = env.window_size
num_features = env.num_features * env.num_timeframes
action_size = env.action_space.n
timeframes = ['1m', '5m', '15m'] # Match the timeframes from the environment
# Create agent with adjusted parameters for longer training
state_size = window_size
action_size = 3
agent = DQNAgent(
state_size=state_size,
state_size=window_size * num_features,
action_size=action_size,
window_size=window_size,
num_features=num_features,
num_features=env.num_features,
timeframes=timeframes,
learning_rate=0.0005, # Reduced learning rate for stability
gamma=0.99, # Increased discount factor
memory_size=100000,
batch_size=64,
learning_rate=0.0001,
gamma=0.99,
epsilon=1.0,
epsilon_min=0.01,
epsilon_decay=0.999, # Slower epsilon decay
memory_size=50000, # Increased memory size
batch_size=128 # Increased batch size
epsilon_decay=0.995
)
# Variables to track best performance
best_reward = float('-inf')
best_episode = 0
best_pnl = float('-inf')
best_win_rate = 0.0
# Training metrics
# Training variables
best_reward = -float('inf')
episode_rewards = []
episode_pnls = []
episode_win_rates = []
episode_trades = []
# Check if previous best model exists and load it
best_model_path = f"{save_path}_best"
if os.path.exists(f"{best_model_path}_policy.pt"):
try:
logger.info(f"Loading previous best model from {best_model_path}")
agent.load(best_model_path)
metadata_path = f"{best_model_path}_metadata.json"
if os.path.exists(metadata_path):
with open(metadata_path, 'r') as f:
metadata = json.load(f)
best_reward = metadata.get('best_reward', best_reward)
best_episode = metadata.get('best_episode', best_episode)
best_pnl = metadata.get('best_pnl', best_pnl)
best_win_rate = metadata.get('best_win_rate', best_win_rate)
logger.info(f"Loaded previous best metrics - Reward: {best_reward:.4f}, PnL: {best_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
except Exception as e:
logger.error(f"Error loading previous best model: {e}")
# TensorBoard writer for logging
writer = SummaryWriter(log_dir=f'runs/rl_trading_{int(time.time())}')
# Main training loop
logger.info(f"Starting training for {num_episodes} episodes...")
logger.info(f"Starting training on device: {agent.device}")
# Training loop
try:
for episode in range(1, num_episodes + 1):
for episode in range(num_episodes):
state = env.reset()
total_reward = 0
done = False
steps = 0
while not done and steps < max_steps:
for step in range(max_steps):
# Select action
action = agent.act(state)
# Take action and observe next state and reward
next_state, reward, done, info = env.step(action)
# Store the experience in memory
agent.remember(state, action, reward, next_state, done)
# Learn from experience
loss = agent.replay()
# Update state and reward
state = next_state
total_reward += reward
steps += 1
# Train the agent by sampling from memory
if len(agent.memory) >= agent.batch_size:
loss = agent.replay()
if done or step == max_steps - 1:
break
# Calculate episode metrics
# Track rewards
episode_rewards.append(total_reward)
episode_pnls.append(info['gain'])
episode_win_rates.append(info['win_rate'])
episode_trades.append(info['trades'])
# Log progress
avg_reward = np.mean(episode_rewards[-100:])
logger.info(f"Episode {episode}/{num_episodes} - Reward: {total_reward:.4f}, " +
f"Avg (100): {avg_reward:.4f}, Epsilon: {agent.epsilon:.4f}")
# Calculate trading metrics
win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
trades = env.trades if hasattr(env, 'trades') else 0
# Log to TensorBoard
writer.add_scalar('Reward/episode', total_reward, episode)
writer.add_scalar('PnL/episode', info['gain'], episode)
writer.add_scalar('WinRate/episode', info['win_rate'], episode)
writer.add_scalar('Trades/episode', info['trades'], episode)
writer.add_scalar('Epsilon/episode', agent.epsilon, episode)
writer.add_scalar('Reward/Episode', total_reward, episode)
writer.add_scalar('Reward/Average100', avg_reward, episode)
writer.add_scalar('Trade/WinRate', win_rate, episode)
writer.add_scalar('Trade/Count', trades, episode)
# Save the best model based on multiple metrics (only every 50 episodes)
is_better = False
if episode % 50 == 0: # Only check for saving every 50 episodes
if (info['gain'] > best_pnl and info['win_rate'] > 0.5) or \
(info['gain'] > best_pnl * 1.1) or \
(info['win_rate'] > best_win_rate * 1.1):
best_reward = total_reward
best_episode = episode
best_pnl = info['gain']
best_win_rate = info['win_rate']
agent.save(best_model_path)
is_better = True
# Save metadata about the best model
metadata = {
'best_reward': best_reward,
'best_episode': best_episode,
'best_pnl': best_pnl,
'best_win_rate': best_win_rate,
'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}
with open(f"{best_model_path}_metadata.json", 'w') as f:
json.dump(metadata, f)
# Save best model
if avg_reward > best_reward and episode > 10:
logger.info(f"New best average reward: {avg_reward:.4f}, saving model")
agent.save(save_path)
best_reward = avg_reward
# Log training progress
if episode % 10 == 0:
avg_reward = sum(episode_rewards[-10:]) / 10
avg_pnl = sum(episode_pnls[-10:]) / 10
avg_win_rate = sum(episode_win_rates[-10:]) / 10
avg_trades = sum(episode_trades[-10:]) / 10
# Periodic save every 100 episodes
if episode % 100 == 0 and episode > 0:
agent.save(f"{save_path}_episode_{episode}")
status = "NEW BEST!" if is_better else ""
logger.info(f"Episode {episode}/{num_episodes} {status}")
logger.info(f"Metrics (last 10 episodes):")
logger.info(f" Reward: {avg_reward:.4f}")
logger.info(f" PnL: {avg_pnl:.4f}")
logger.info(f" Win Rate: {avg_win_rate:.4f}")
logger.info(f" Trades: {avg_trades:.2f}")
logger.info(f" Epsilon: {agent.epsilon:.4f}")
logger.info(f"Best so far - PnL: {best_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
except KeyboardInterrupt:
logger.info("Training interrupted by user. Saving best model...")
# Call episode callback if provided
if episode_callback:
# Add environment to info dict to use for extrema training
info_with_env = info.copy()
info_with_env['env'] = env
episode_callback(episode, total_reward, info_with_env)
# Final save
logger.info("Training completed, saving final model")
agent.save(f"{save_path}_final")
except Exception as e:
logger.error(f"Training failed: {str(e)}")
import traceback
logger.error(traceback.format_exc())
# Close TensorBoard writer
writer.close()
# Final logs
logger.info(f"Training completed. Best model from episode {best_episode}")
logger.info(f"Best metrics:")
logger.info(f" Reward: {best_reward:.4f}")
logger.info(f" PnL: {best_pnl:.4f}")
logger.info(f" Win Rate: {best_win_rate:.4f}")
# Return the agent for potential further use
return agent
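Review note: the new checkpoint rule saves the model whenever the 100-episode average reward improves, after a short warm-up. A sketch of that rule in isolation (`best_checkpoint_episodes` is a hypothetical helper; the window of 100 and the `episode > 10` guard are taken from the diff):

```python
from collections import deque
from typing import Deque, List


def best_checkpoint_episodes(rewards: List[float], window: int = 100,
                             warmup: int = 10) -> List[int]:
    """Return the episode indices at which the model would be saved."""
    recent: Deque[float] = deque(maxlen=window)
    best_average = float("-inf")
    saved: List[int] = []
    for episode, reward in enumerate(rewards):
        recent.append(reward)
        average = sum(recent) / len(recent)
        # Save only when the rolling average improves and warm-up has passed
        if average > best_average and episode > warmup:
            best_average = average
            saved.append(episode)
    return saved


rewards = [1.0] * 12 + [5.0, 0.0]
saved_at = best_checkpoint_episodes(rewards)
```

Note that the best average is only updated when a save happens, so averages seen during warm-up never count as the value to beat.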
if __name__ == "__main__":


@ -25,10 +25,10 @@ class SignalInterpreter:
"""
self.config = config or {}
# Signal thresholds - higher thresholds for high-leverage trading
self.buy_threshold = self.config.get('buy_threshold', 0.65)
self.sell_threshold = self.config.get('sell_threshold', 0.65)
self.hold_threshold = self.config.get('hold_threshold', 0.75)
# Signal thresholds - lower thresholds to increase trade frequency
self.buy_threshold = self.config.get('buy_threshold', 0.35)
self.sell_threshold = self.config.get('sell_threshold', 0.35)
self.hold_threshold = self.config.get('hold_threshold', 0.60)
# Adaptive parameters
self.confidence_multiplier = self.config.get('confidence_multiplier', 1.0)
@ -45,14 +45,14 @@ class SignalInterpreter:
self.current_position = None # None = no position, 'long' = buy, 'short' = sell
# Filters for better signal quality
self.trend_filter_enabled = self.config.get('trend_filter_enabled', True)
self.volume_filter_enabled = self.config.get('volume_filter_enabled', True)
self.oscillation_filter_enabled = self.config.get('oscillation_filter_enabled', True)
self.trend_filter_enabled = self.config.get('trend_filter_enabled', False) # Disable trend filter by default
self.volume_filter_enabled = self.config.get('volume_filter_enabled', False) # Disable volume filter by default
self.oscillation_filter_enabled = self.config.get('oscillation_filter_enabled', False) # Disable oscillation filter by default
# Sensitivity parameters
self.min_price_movement = self.config.get('min_price_movement', 0.0005) # 0.05% minimum expected movement
self.hold_cooldown = self.config.get('hold_cooldown', 3) # Minimum periods to wait after a HOLD
self.consecutive_signals_required = self.config.get('consecutive_signals_required', 2)
self.min_price_movement = self.config.get('min_price_movement', 0.0001) # Lower price movement threshold
self.hold_cooldown = self.config.get('hold_cooldown', 1) # Shorter hold cooldown
self.consecutive_signals_required = self.config.get('consecutive_signals_required', 1) # Require only one signal
# State tracking
self.consecutive_buy_signals = 0