enhancements

Dobromir Popov 2025-04-01 13:46:53 +03:00
parent a46b2c74f8
commit 73c5ecb0d2
17 changed files with 2279 additions and 736 deletions

.gitignore (vendored, 1 line changed)

@@ -41,3 +41,4 @@ NN/utils/__pycache__/data_interface.cpython-312.pyc
NN/utils/__pycache__/multi_data_interface.cpython-312.pyc
NN/utils/__pycache__/realtime_analyzer.cpython-312.pyc
models/trading_agent_best_pnl.pt
*.log

NN/README_TRADING.md (new file, 305 lines)

@@ -0,0 +1,305 @@
# Trading Agent System
A modular, extensible cryptocurrency trading system that can connect to multiple exchanges through a common interface.
## Architecture
The trading agent system consists of the following components:
### Exchange Interfaces
- `ExchangeInterface`: Abstract base class that defines the common interface for all exchange implementations
- `BinanceInterface`: Implementation for the Binance exchange (supports both mainnet and testnet)
- `MEXCInterface`: Implementation for the MEXC exchange
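All three implementations share the small surface defined by the abstract base class. As a rough sketch (based on the method signatures used later in this README; see `exchange_interface.py` for the authoritative definition), it looks like this:

```python
from abc import ABC, abstractmethod
from typing import Dict, Any

class ExchangeInterface(ABC):
    """Common interface every exchange implementation must satisfy."""

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        self.api_key = api_key
        self.api_secret = api_secret
        self.test_mode = test_mode

    @abstractmethod
    def connect(self) -> bool: ...

    @abstractmethod
    def get_balance(self, asset: str) -> float: ...

    @abstractmethod
    def get_ticker(self, symbol: str) -> Dict[str, Any]: ...

    @abstractmethod
    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]: ...

    @abstractmethod
    def cancel_order(self, symbol: str, order_id: str) -> bool: ...
```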
### Trading Agent
- `TradingAgent`: Main class that manages trading operations, including position sizing, risk management, and signal processing
### Neural Network Orchestrator
- `NeuralNetworkOrchestrator`: Coordinates between neural network models and the trading agent to generate and process trading signals
## Getting Started
### Installation
The trading agent system ships as part of the main application; no additional installation is required beyond the main application's existing requirements.
### Configuration
Configuration can be provided via:
1. Command-line arguments
2. Environment variables
3. Configuration file
#### Example Configuration
```json
{
    "exchange": "binance",
    "api_key": "your_api_key",
    "api_secret": "your_api_secret",
    "test_mode": true,
    "trade_symbols": ["BTC/USDT", "ETH/USDT"],
    "position_size": 0.1,
    "max_trades_per_day": 5,
    "trade_cooldown_minutes": 60
}
```
### Running the Trading System
#### From Command Line
```bash
# Run with Binance in test mode
python trading_main.py --exchange binance --test-mode
# Run with MEXC in production mode
python trading_main.py --exchange mexc --api-key YOUR_API_KEY --api-secret YOUR_API_SECRET
# Run with custom position sizing and limits
python trading_main.py --exchange binance --test-mode --position-size 0.05 --max-trades-per-day 3 --trade-cooldown 120
```
#### Using Environment Variables
You can set these environment variables for configuration:
```bash
# Set exchange API credentials
export BINANCE_API_KEY=your_binance_api_key
export BINANCE_API_SECRET=your_binance_api_secret
# Enable neural network models
export ENABLE_NN_MODELS=1
export NN_INFERENCE_INTERVAL=60
export NN_MODEL_TYPE=cnn
export NN_TIMEFRAME=1h
# Run the trading system
python trading_main.py
```
### Basic Usage Examples
#### Creating a Trading Agent
```python
from NN.trading_agent import TradingAgent

# Initialize a trading agent for Binance testnet
agent = TradingAgent(
    exchange_name="binance",
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True,
    trade_symbols=["BTC/USDT", "ETH/USDT"],
    position_size=0.1,
    max_trades_per_day=5,
    trade_cooldown_minutes=60
)

# Start the trading agent
agent.start()

# Process a signal
agent.process_signal(
    symbol="BTC/USDT",
    action="BUY",
    confidence=0.85
)

# Get current positions
positions = agent.get_current_positions()
print(f"Current positions: {positions}")

# Stop the trading agent
agent.stop()
```
#### Using an Exchange Interface Directly
```python
from NN.exchanges import BinanceInterface

# Initialize the Binance interface
exchange = BinanceInterface(
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True
)

# Connect to the exchange
exchange.connect()

# Get ticker info
ticker = exchange.get_ticker("BTC/USDT")
print(f"Current BTC price: {ticker['last']}")

# Get account balance
btc_balance = exchange.get_balance("BTC")
usdt_balance = exchange.get_balance("USDT")
print(f"BTC balance: {btc_balance}")
print(f"USDT balance: {usdt_balance}")

# Place a market order
order = exchange.place_order(
    symbol="BTC/USDT",
    side="buy",
    order_type="market",
    quantity=0.001
)
```
## Testing the Exchange Interfaces
The system includes a test script that can be used to verify that exchange interfaces are working correctly:
```bash
# Test Binance interface in test mode (no real trades)
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
# Test MEXC interface in test mode
python -m NN.exchanges.trading_agent_test --exchange mexc --test-mode
# Test with actual trades (use with caution!)
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode --execute-trades --test-trade-amount 0.001
```
## Adding a New Exchange
To add support for a new exchange, you need to create a new class that inherits from `ExchangeInterface` and implements all the required methods:
1. Create a new file in the `NN/exchanges` directory (e.g., `kraken_interface.py`)
2. Implement the required methods (see `exchange_interface.py` for the specifications)
3. Add the new exchange to the imports in `__init__.py`
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange
### Example of a New Exchange Implementation
```python
# NN/exchanges/kraken_interface.py
import logging
from typing import Dict, Any, List, Optional

from .exchange_interface import ExchangeInterface

logger = logging.getLogger(__name__)

class KrakenInterface(ExchangeInterface):
    """Kraken Exchange API Interface"""

    def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
        super().__init__(api_key, api_secret, test_mode)
        self.base_url = "https://api.kraken.com"
        # Initialize other Kraken-specific properties

    def connect(self) -> bool:
        # Implement connection to Kraken API
        pass

    def get_balance(self, asset: str) -> float:
        # Implement getting balance for an asset
        pass

    def get_ticker(self, symbol: str) -> Dict[str, Any]:
        # Implement getting ticker data
        pass

    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
        # Implement placing an order
        pass

    def cancel_order(self, symbol: str, order_id: str) -> bool:
        # Implement cancelling an order
        pass

    def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
        # Implement getting order status
        pass

    def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
        # Implement getting open orders
        pass
```
Then update the imports in `__init__.py`:
```python
from .exchange_interface import ExchangeInterface
from .binance_interface import BinanceInterface
from .mexc_interface import MEXCInterface
from .kraken_interface import KrakenInterface
__all__ = ['ExchangeInterface', 'BinanceInterface', 'MEXCInterface', 'KrakenInterface']
```
And update the `_create_exchange` method in `TradingAgent`:
```python
def _create_exchange(self) -> ExchangeInterface:
    """Create an exchange interface based on the exchange name."""
    if self.exchange_name == 'mexc':
        return MEXCInterface(
            api_key=self.api_key,
            api_secret=self.api_secret,
            test_mode=self.test_mode
        )
    elif self.exchange_name == 'binance':
        return BinanceInterface(
            api_key=self.api_key,
            api_secret=self.api_secret,
            test_mode=self.test_mode
        )
    elif self.exchange_name == 'kraken':
        return KrakenInterface(
            api_key=self.api_key,
            api_secret=self.api_secret,
            test_mode=self.test_mode
        )
    else:
        raise ValueError(f"Unsupported exchange: {self.exchange_name}")
```
## Security Considerations
- **API Keys**: Never hardcode API keys in your code. Use environment variables or secure storage.
- **Permissions**: Restrict API key permissions to only what is needed (e.g., trading, but not withdrawals).
- **Testing**: Always test with small amounts and use test mode/testnet when possible.
- **Position Sizing**: Implement conservative position sizing to manage risk.
- **Monitoring**: Set up monitoring and alerting for your trading system.
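For example, the first two points can be combined by pulling credentials from the environment variables described earlier rather than hardcoding them (a sketch using the Binance variable names from this README):

```python
import os

from NN.trading_agent import TradingAgent

# Credentials come from the environment, never from source code
agent = TradingAgent(
    exchange_name="binance",
    api_key=os.environ.get("BINANCE_API_KEY"),
    api_secret=os.environ.get("BINANCE_API_SECRET"),
    test_mode=True,
    position_size=0.05  # Conservative sizing while testing
)
```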
## Troubleshooting
### Common Issues
1. **Connection Problems**: Make sure you have internet connectivity and correct API credentials.
2. **Order Placement Errors**: Check for sufficient funds, correct symbol format, and valid order parameters.
3. **Rate Limiting**: Avoid making too many API requests in a short period to prevent being rate-limited.
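If you do hit rate limits, wrapping exchange calls in a small retry helper with exponential backoff usually resolves it. The helper below is a generic sketch, not part of the trading agent itself:

```python
import logging
import time

logger = logging.getLogger(__name__)

def call_with_retry(fn, *args, retries=3, backoff_seconds=2.0, **kwargs):
    """Retry an API call with exponential backoff between attempts."""
    for attempt in range(retries):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            if attempt == retries - 1:
                raise
            wait = backoff_seconds * (2 ** attempt)
            logger.warning(f"Call failed ({e}), retrying in {wait:.1f}s")
            time.sleep(wait)

# Usage: ticker = call_with_retry(exchange.get_ticker, "BTC/USDT")
```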
### Logging
The trading agent system uses Python's logging module with different levels:
- **DEBUG**: Detailed information, typically useful for diagnosing problems.
- **INFO**: Confirmation that things are working as expected.
- **WARNING**: Indication that something unexpected happened, but the program can still function.
- **ERROR**: Due to a more serious problem, the program has failed to perform some function.
You can adjust the logging level in your trading script:
```python
import logging
logging.basicConfig(
    level=logging.DEBUG,  # Change to INFO, WARNING, or ERROR as needed
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("trading.log"),
        logging.StreamHandler()
    ]
)
```

NN/exchanges/README.md (new file, 162 lines)

@@ -0,0 +1,162 @@
# Trading Agent System
This directory contains the implementation of a modular trading agent system that integrates with the neural network models and can execute trades on various cryptocurrency exchanges.
## Overview
The trading agent system is designed to:
1. Connect to different cryptocurrency exchanges using a common interface
2. Execute trades based on signals from neural network models
3. Manage risk through position sizing, trade limits, and cooldown periods
4. Monitor and report on trading activity
## Components
### Exchange Interfaces
- `ExchangeInterface`: Abstract base class defining the common interface for all exchange implementations
- `BinanceInterface`: Implementation for the Binance exchange, with support for both mainnet and testnet
- `MEXCInterface`: Implementation for the MEXC exchange
### Trading Agent
The `TradingAgent` class (`trading_agent.py`) manages trading activities:
- Connects to the configured exchange
- Processes trading signals from neural network models
- Applies trading rules and risk management
- Tracks and reports trading performance
### Neural Network Orchestrator
The `NeuralNetworkOrchestrator` class (`neural_network_orchestrator.py`) coordinates between models and trading:
- Manages the neural network inference process
- Routes model signals to the trading agent
- Provides integration with the RealTimeChart for visualization
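Conceptually, the routing step turns a model prediction into a `process_signal` call on the trading agent. The helper below is an illustrative sketch of that hand-off, not the orchestrator's actual implementation; the action ordering and confidence threshold are assumptions borrowed from the CNN model code:

```python
import time

def route_prediction(agent, symbol, action_probs, min_confidence=0.65):
    """Forward a model prediction to the trading agent if it is confident enough."""
    actions = ["SELL", "HOLD", "BUY"]  # ordering assumed from the CNN model
    best = int(action_probs.argmax())
    confidence = float(action_probs[best])

    # Only act on sufficiently confident, non-HOLD predictions
    if actions[best] != "HOLD" and confidence >= min_confidence:
        agent.process_signal(
            symbol=symbol,
            action=actions[best],
            confidence=confidence,
            timestamp=int(time.time())
        )
```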
## Usage
### Basic Usage
```python
import time

from NN.exchanges import BinanceInterface, MEXCInterface
from NN.trading_agent import TradingAgent

# Initialize an exchange interface
exchange = BinanceInterface(
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True  # Use testnet
)

# Connect to the exchange
exchange.connect()

# Create a trading agent
agent = TradingAgent(
    exchange_name="binance",
    api_key="your_api_key",
    api_secret="your_api_secret",
    test_mode=True,
    trade_symbols=["BTC/USDT", "ETH/USDT"],
    position_size=0.1,
    max_trades_per_day=5,
    trade_cooldown_minutes=60
)

# Start the trading agent
agent.start()

# Process a trading signal
agent.process_signal(
    symbol="BTC/USDT",
    action="BUY",
    confidence=0.85,
    timestamp=int(time.time())
)

# Stop the trading agent when done
agent.stop()
```
### Integration with Neural Network Models
The system is designed to be integrated with neural network models through the `NeuralNetworkOrchestrator`:
```python
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator

# Configure exchange
exchange_config = {
    "exchange": "binance",
    "api_key": "your_api_key",
    "api_secret": "your_api_secret",
    "test_mode": True,
    "trade_symbols": ["BTC/USDT", "ETH/USDT"],
    "position_size": 0.1,
    "max_trades_per_day": 5,
    "trade_cooldown_minutes": 60
}

# Initialize orchestrator
orchestrator = NeuralNetworkOrchestrator(
    model=model,
    data_interface=data_interface,
    chart=chart,
    symbols=["BTC/USDT", "ETH/USDT"],
    timeframes=["1m", "5m", "1h", "4h", "1d"],
    exchange_config=exchange_config
)

# Start inference and trading
orchestrator.start_inference()
```
## Configuration
### Exchange-Specific Configuration
- **Binance**: Supports both mainnet and testnet environments
- **MEXC**: Supports mainnet only (no test environment available)
### Trading Agent Configuration
- `exchange_name`: Name of exchange ('binance', 'mexc')
- `api_key`: API key for the exchange
- `api_secret`: API secret for the exchange
- `test_mode`: Whether to use test/sandbox environment
- `trade_symbols`: List of trading symbols to monitor
- `position_size`: Size of each position as a fraction of balance (0.0-1.0)
- `max_trades_per_day`: Maximum number of trades to execute per day
- `trade_cooldown_minutes`: Minimum time between trades in minutes
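To make the risk parameters concrete, a pre-trade check built from them might look like the sketch below (illustrative only; the actual `TradingAgent` logic may differ):

```python
import time

def can_open_position(balance_usdt, last_trade_time, trades_today,
                      position_size=0.1, max_trades_per_day=5,
                      trade_cooldown_minutes=60):
    """Gate a new trade on the cooldown and daily limit, and size the order."""
    cooldown_ok = (time.time() - last_trade_time) >= trade_cooldown_minutes * 60
    limit_ok = trades_today < max_trades_per_day
    order_value = balance_usdt * position_size  # fraction of balance per trade
    return cooldown_ok and limit_ok, order_value

allowed, order_value = can_open_position(balance_usdt=1000.0, last_trade_time=0.0, trades_today=2)
```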
## Adding New Exchanges
To add support for a new exchange:
1. Create a new class that inherits from `ExchangeInterface`
2. Implement all required methods (see `exchange_interface.py`)
3. Add the new exchange to the imports in `__init__.py`
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange
Example:
```python
class KrakenInterface(ExchangeInterface):
    """Kraken Exchange API Interface"""

    def __init__(self, api_key=None, api_secret=None, test_mode=True):
        super().__init__(api_key, api_secret, test_mode)
        # Initialize Kraken-specific attributes

    # Implement all required methods...
```
## Security Considerations
- API keys should have trade permissions but not withdrawal permissions
- Use environment variables or secure storage for API credentials
- Always test with small position sizes before deploying with larger amounts
- Consider using test mode/testnet for initial testing


@@ -26,10 +26,17 @@ class MEXCInterface(ExchangeInterface):
        self.api_version = "v3"

    def connect(self) -> bool:
        """Connect to MEXC API."""
        if not self.api_key or not self.api_secret:
            logger.warning("MEXC API credentials not provided. Running in read-only mode.")
            try:
                # Test public API connection by getting ticker data for BTC/USDT
                self.get_ticker("BTC/USDT")
                logger.info("Successfully connected to MEXC API in read-only mode")
                return True
            except Exception as e:
                logger.error(f"Failed to connect to MEXC API in read-only mode: {str(e)}")
                return False

        try:
            # Test connection by getting account info
@@ -141,22 +148,69 @@ class MEXCInterface(ExchangeInterface):
            dict: Ticker data including price information
        """
        mexc_symbol = symbol.replace('/', '')

        endpoints_to_try = [
            ('ticker/price', {'symbol': mexc_symbol}),
            ('ticker', {'symbol': mexc_symbol}),
            ('ticker/24hr', {'symbol': mexc_symbol}),
            ('ticker/bookTicker', {'symbol': mexc_symbol}),
            ('market/ticker', {'symbol': mexc_symbol})
        ]

        for endpoint, params in endpoints_to_try:
            try:
                logger.info(f"Trying to get ticker from endpoint: {endpoint}")
                response = self._send_public_request('GET', endpoint, params)

                # Handle the response based on structure
                if isinstance(response, dict):
                    # Single ticker response
                    ticker = response
                elif isinstance(response, list) and len(response) > 0:
                    # List of tickers, find the one we want
                    ticker = None
                    for t in response:
                        if t.get('symbol') == mexc_symbol:
                            ticker = t
                            break
                    if ticker is None:
                        continue  # Try next endpoint if not found
                else:
                    continue  # Try next endpoint if unexpected response

                # Convert to a standardized format with defaults for missing fields
                current_time = int(time.time() * 1000)
                result = {
                    'symbol': symbol,
                    'bid': float(ticker.get('bidPrice', ticker.get('bid', 0))),
                    'ask': float(ticker.get('askPrice', ticker.get('ask', 0))),
                    'last': float(ticker.get('price', ticker.get('lastPrice', ticker.get('last', 0)))),
                    'volume': float(ticker.get('volume', ticker.get('quoteVolume', 0))),
                    'timestamp': int(ticker.get('time', ticker.get('closeTime', current_time)))
                }

                # Ensure we have at least a price
                if result['last'] > 0:
                    logger.info(f"Successfully got ticker from {endpoint} for {symbol}: {result['last']}")
                    return result
            except Exception as e:
                logger.warning(f"Error getting ticker from {endpoint} for {symbol}: {str(e)}")

        # If we get here, all endpoints failed
        logger.error(f"All ticker endpoints failed for {symbol}")

        # Return dummy data as last resort for testing
        dummy_price = 50000.0 if 'BTC' in symbol else 2000.0  # Dummy price for BTC or others
        logger.warning(f"Returning dummy ticker data for {symbol} with price {dummy_price}")
        return {
            'symbol': symbol,
            'bid': dummy_price * 0.999,
            'ask': dummy_price * 1.001,
            'last': dummy_price,
            'volume': 100.0,
            'timestamp': int(time.time() * 1000),
            'is_dummy': True
        }

    def place_order(self, symbol: str, side: str, order_type: str,
                    quantity: float, price: float = None) -> Dict[str, Any]:
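Because `get_ticker` now falls back to dummy data when every endpoint fails, callers that must not act on synthetic prices can check the `is_dummy` flag it sets. A short usage sketch, assuming `exchange` is a connected `MEXCInterface`:

```python
ticker = exchange.get_ticker("BTC/USDT")
if ticker.get('is_dummy'):
    # Every MEXC ticker endpoint failed; this price is a placeholder, not a live quote
    print(f"Skipping {ticker['symbol']}: dummy ticker data, not safe to trade on")
else:
    print(f"Live {ticker['symbol']} price: {ticker['last']}")
```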


@@ -0,0 +1,254 @@
"""
Trading Agent Test Script

This script demonstrates how to use the swappable exchange modules
to connect to and interact with different cryptocurrency exchanges.

Usage:
    python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
"""

import argparse
import logging
import os
import sys
import time

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler("exchange_test.log"),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger("exchange_test")

# Import exchange interfaces
try:
    from .exchange_interface import ExchangeInterface
    from .binance_interface import BinanceInterface
    from .mexc_interface import MEXCInterface
except ImportError:
    # When running as standalone script
    from exchange_interface import ExchangeInterface
    from binance_interface import BinanceInterface
    from mexc_interface import MEXCInterface

def create_exchange(exchange_name: str, api_key: str = None, api_secret: str = None, test_mode: bool = True) -> ExchangeInterface:
    """Create an exchange interface instance.

    Args:
        exchange_name: Name of the exchange ('binance' or 'mexc')
        api_key: API key for the exchange
        api_secret: API secret for the exchange
        test_mode: If True, use test/sandbox environment

    Returns:
        ExchangeInterface: The exchange interface instance
    """
    exchange_name = exchange_name.lower()

    if exchange_name == 'binance':
        return BinanceInterface(api_key, api_secret, test_mode)
    elif exchange_name == 'mexc':
        return MEXCInterface(api_key, api_secret, test_mode)
    else:
        raise ValueError(f"Unsupported exchange: {exchange_name}. Supported exchanges: binance, mexc")

def test_exchange(exchange: ExchangeInterface, symbols: list = None):
    """Test the exchange interface.

    Args:
        exchange: Exchange interface instance
        symbols: List of symbols to test with (e.g., ['BTC/USDT', 'ETH/USDT'])
    """
    if symbols is None:
        symbols = ['BTC/USDT', 'ETH/USDT']

    # Test connection
    logger.info(f"Testing connection to exchange...")
    connected = exchange.connect()
    if not connected and hasattr(exchange, 'api_key') and exchange.api_key:
        logger.error("Failed to connect to exchange. Make sure your API credentials are correct.")
        return False
    elif not connected:
        logger.warning("Running in read-only mode without API credentials.")
    else:
        logger.info("Connection successful with API credentials!")

    # Test getting ticker data
    ticker_success = True
    for symbol in symbols:
        try:
            logger.info(f"Getting ticker data for {symbol}...")
            ticker = exchange.get_ticker(symbol)
            logger.info(f"Ticker for {symbol}: Last price: {ticker['last']}, Volume: {ticker['volume']}")
        except Exception as e:
            logger.error(f"Error getting ticker for {symbol}: {str(e)}")
            ticker_success = False

    if not ticker_success:
        logger.error("Failed to get ticker data. Exchange interface test failed.")
        return False

    # Test getting account balances if API keys are provided
    if hasattr(exchange, 'api_key') and exchange.api_key:
        logger.info("Testing account balance retrieval...")
        try:
            for base_asset in ['BTC', 'ETH', 'USDT']:
                balance = exchange.get_balance(base_asset)
                logger.info(f"Balance for {base_asset}: {balance}")
        except Exception as e:
            logger.error(f"Error getting account balances: {str(e)}")
            logger.warning("Balance retrieval failed, but this is not critical if ticker data works.")
    else:
        logger.warning("API keys not provided. Skipping balance checks.")

    logger.info("Exchange interface test completed successfully in read-only mode.")
    return True

def execute_test_trades(exchange: ExchangeInterface, symbol: str, test_trade_amount: float = 0.001):
    """Execute test trades.

    Args:
        exchange: Exchange interface instance
        symbol: Symbol to trade (e.g., 'BTC/USDT')
        test_trade_amount: Amount to use for test trades
    """
    if not hasattr(exchange, 'api_key') or not exchange.api_key:
        logger.warning("API keys not provided. Skipping test trades.")
        return

    logger.info(f"Executing test trades for {symbol} with amount {test_trade_amount}...")

    # Get current ticker for the symbol
    try:
        ticker = exchange.get_ticker(symbol)
        logger.info(f"Current price for {symbol}: {ticker['last']}")
    except Exception as e:
        logger.error(f"Error getting ticker for {symbol}: {str(e)}")
        return

    # Execute a buy order
    try:
        logger.info(f"Placing a test BUY order for {test_trade_amount} {symbol}...")
        buy_order = exchange.execute_trade(symbol, 'BUY', quantity=test_trade_amount)
        if buy_order:
            logger.info(f"BUY order executed: {buy_order}")
            order_id = buy_order.get('orderId')

            # Get order status
            if order_id:
                time.sleep(2)  # Wait for order to process
                status = exchange.get_order_status(symbol, order_id)
                logger.info(f"Order status: {status}")
        else:
            logger.error("BUY order failed.")
    except Exception as e:
        logger.error(f"Error executing BUY order: {str(e)}")

    # Wait before selling
    time.sleep(5)

    # Execute a sell order
    try:
        logger.info(f"Placing a test SELL order for {test_trade_amount} {symbol}...")
        sell_order = exchange.execute_trade(symbol, 'SELL', quantity=test_trade_amount)
        if sell_order:
            logger.info(f"SELL order executed: {sell_order}")
            order_id = sell_order.get('orderId')

            # Get order status
            if order_id:
                time.sleep(2)  # Wait for order to process
                status = exchange.get_order_status(symbol, order_id)
                logger.info(f"Order status: {status}")
        else:
            logger.error("SELL order failed.")
    except Exception as e:
        logger.error(f"Error executing SELL order: {str(e)}")

    # Get open orders
    try:
        logger.info("Getting open orders...")
        open_orders = exchange.get_open_orders(symbol)
        if open_orders:
            logger.info(f"Open orders: {open_orders}")

            # Cancel any open orders
            for order in open_orders:
                order_id = order.get('orderId')
                if order_id:
                    logger.info(f"Cancelling order {order_id}...")
                    cancelled = exchange.cancel_order(symbol, order_id)
                    logger.info(f"Order cancelled: {cancelled}")
        else:
            logger.info("No open orders.")
    except Exception as e:
        logger.error(f"Error getting/cancelling open orders: {str(e)}")

def main():
    """Main function for testing exchange interfaces."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Test exchange interfaces")
    parser.add_argument('--exchange', type=str, default='binance', choices=['binance', 'mexc'],
                        help='Exchange to test')
    parser.add_argument('--api-key', type=str, default=None,
                        help='API key for the exchange')
    parser.add_argument('--api-secret', type=str, default=None,
                        help='API secret for the exchange')
    parser.add_argument('--test-mode', action='store_true',
                        help='Use test/sandbox environment')
    parser.add_argument('--symbols', nargs='+', default=['BTC/USDT', 'ETH/USDT'],
                        help='Symbols to test with')
    parser.add_argument('--execute-trades', action='store_true',
                        help='Execute test trades (use with caution!)')
    parser.add_argument('--test-trade-amount', type=float, default=0.001,
                        help='Amount to use for test trades')
    args = parser.parse_args()

    # Use environment variables for API keys if not provided
    api_key = args.api_key or os.environ.get(f"{args.exchange.upper()}_API_KEY")
    api_secret = args.api_secret or os.environ.get(f"{args.exchange.upper()}_API_SECRET")

    # Create exchange interface
    try:
        exchange = create_exchange(
            exchange_name=args.exchange,
            api_key=api_key,
            api_secret=api_secret,
            test_mode=args.test_mode
        )
        logger.info(f"Created {args.exchange} exchange interface")
        logger.info(f"Test mode: {args.test_mode}")

        # Test exchange
        if test_exchange(exchange, args.symbols):
            logger.info("Exchange interface test passed!")

            # Execute test trades if requested
            if args.execute_trades:
                logger.warning("Executing test trades. This will use real funds!")
                execute_test_trades(
                    exchange=exchange,
                    symbol=args.symbols[0],
                    test_trade_amount=args.test_trade_amount
                )
        else:
            logger.error("Exchange interface test failed!")
    except Exception as e:
        logger.error(f"Error testing exchange interface: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        return 1

    return 0

if __name__ == "__main__":
    sys.exit(main())


@@ -78,569 +78,248 @@ class CNNPyTorch(nn.Module):
        window_size, num_features = input_shape
        self.window_size = window_size

        # Simpler architecture with fewer layers and dropout
        self.conv1 = nn.Sequential(
            nn.Conv1d(num_features, 32, kernel_size=3, padding=1),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Global average pooling to handle variable length sequences
        self.global_pool = nn.AdaptiveAvgPool1d(1)

        # Fully connected layers
        self.fc = nn.Sequential(
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(32, output_size)
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x: Input tensor of shape [batch_size, window_size, features]

        Returns:
            action_probs: Action probabilities
        """
        # Transpose for conv1d: [batch, features, window]
        x = x.transpose(1, 2)

        # Convolutional layers
        x = self.conv1(x)
        x = self.conv2(x)

        # Global pooling
        x = self.global_pool(x)
        x = x.squeeze(-1)

        # Fully connected layers
        action_logits = self.fc(x)

        # Apply class weights to reduce HOLD bias
        # This helps overcome the dataset imbalance that often favors HOLD
        class_weights = torch.tensor([2.5, 0.4, 2.5], device=self.device)  # Higher weights for BUY/SELL
        weighted_logits = action_logits * class_weights

        # Add random perturbation during training to encourage exploration
        if self.training:
            # Add small noise to encourage exploration
            noise = torch.randn_like(weighted_logits) * 0.3
            weighted_logits = weighted_logits + noise

        # Softmax to get probabilities
        action_probs = F.softmax(weighted_logits, dim=1)

        return action_probs, None  # Return None for price_pred as we're focusing on actions

class CNNModelPyTorch:
    """
    High-level wrapper for the CNN model with training and evaluation functionality.
    """
    def __init__(self, window_size=20, timeframes=None, output_size=3, num_pairs=3):
        """
        Initialize the model.

        Args:
            window_size (int): Size of the input window
            timeframes (list): List of timeframes to use
            output_size (int): Number of output classes
            num_pairs (int): Number of trading pairs
        """
        self.window_size = window_size
        self.timeframes = timeframes or ["1m", "5m", "15m"]
        self.output_size = output_size
        self.num_pairs = num_pairs

        # Set device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Initialize the underlying CNN model
        input_shape = (window_size, len(self.timeframes) * 5)  # 5 features per timeframe
        self.model = CNNPyTorch(input_shape, output_size).to(self.device)

        # Initialize optimizer with lower learning rate for stability
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001, weight_decay=0.01)

        # Initialize loss functions
        self.action_criterion = nn.CrossEntropyLoss()

        # Training history
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_acc': [],
            'val_acc': []
        }

        # For compatibility with older code
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []

        # Initialize action counts
        self.action_counts = {
            'BUY': [0, 0],   # [total, correct]
            'SELL': [0, 0],  # [total, correct]
            'HOLD': [0, 0]   # [total, correct]
        }

        logger.info(f"Building PyTorch CNN model with window_size={window_size}, output_size={output_size}")

        # Learning rate scheduler
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            mode='min',
            factor=0.5,
            patience=5,
            verbose=True
        )

        # Sensitivity parameters for high-leverage trading
        self.confidence_threshold = 0.65
        self.max_consecutive_same_action = 3
        self.last_actions = [[] for _ in range(num_pairs)]  # Track recent actions per pair

    def train_epoch(self, X_train, y_train, future_prices, batch_size):
        """Train the model for one epoch with focus on short-term pattern recognition"""
        self.model.train()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        # Convert inputs to tensors and create DataLoader
        X_train_tensor = torch.FloatTensor(X_train).to(self.device)
        y_train_tensor = torch.LongTensor(y_train).to(self.device)

        # Create dataset and dataloader
        dataset = TensorDataset(X_train_tensor, y_train_tensor)
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        # Training loop
        for batch_X, batch_y in train_loader:
            self.optimizer.zero_grad()

            # Forward pass
            action_probs, _ = self.model(batch_X)

            # Calculate loss
            loss = self.action_criterion(action_probs, batch_y)

            # Backward pass and optimization
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Update metrics
            total_loss += loss.item()
            predictions = torch.argmax(action_probs, dim=1)
            total_correct += (predictions == batch_y).sum().item()
            total_samples += batch_y.size(0)

            # Update action counts
            for i, (pred, target) in enumerate(zip(predictions, batch_y)):
                pred_action = ['SELL', 'HOLD', 'BUY'][pred.item()]
                self.action_counts[pred_action][0] += 1
                if pred.item() == target.item():
                    self.action_counts[pred_action][1] += 1

        # Calculate average loss and accuracy
        avg_loss = total_loss / len(train_loader)
        accuracy = total_correct / total_samples

        # Update training history
        self.history['train_loss'].append(avg_loss)
        self.history['train_acc'].append(accuracy)
        self.train_losses.append(avg_loss)
        self.train_accuracies.append(accuracy)

        # Log trading signals
        for action in ['BUY', 'SELL', 'HOLD']:
            total = self.action_counts[action][0]
            correct = self.action_counts[action][1]
            precision = correct / total if total > 0 else 0
            logger.info(f"Trading signals - {action}: {total}, Precision: {precision:.4f}")

        return avg_loss, 0, accuracy  # Return 0 for price_loss as we're not using it

    def evaluate(self, X_val, y_val, future_prices=None):
        """Evaluate the model with focus on short-term trading performance metrics"""
        self.model.eval()
        total_loss = 0
        total_correct = 0
        total_samples = 0

        # Convert inputs to tensors
        X_val_tensor = torch.FloatTensor(X_val).to(self.device)
        y_val_tensor = torch.LongTensor(y_val).to(self.device)

        # Create dataset and dataloader
        dataset = TensorDataset(X_val_tensor, y_val_tensor)
        val_loader = DataLoader(dataset, batch_size=32)

        with torch.no_grad():
            for batch_X, batch_y in val_loader:
                # Forward pass
                action_probs, _ = self.model(batch_X)

                # Calculate loss
                loss = self.action_criterion(action_probs, batch_y)

                # Update metrics
                total_loss += loss.item()
                predictions = torch.argmax(action_probs, dim=1)
                total_correct += (predictions == batch_y).sum().item()
                total_samples += batch_y.size(0)

        # Calculate average loss and accuracy
        avg_loss = total_loss / len(val_loader)
        accuracy = total_correct / total_samples

        # Update validation history
        self.history['val_loss'].append(avg_loss)
        self.history['val_acc'].append(accuracy)
        self.val_losses.append(avg_loss)
        self.val_accuracies.append(accuracy)

        # Update learning rate scheduler
        self.scheduler.step(avg_loss)

        return avg_loss, 0, accuracy  # Return 0 for price_loss as we're not using it

    def predict(self, X):
        """Make predictions optimized for short-term high-leverage trading signals"""
@@ -659,28 +338,11 @@ class CNNModelPyTorch:
        action_probs_np = action_probs.cpu().numpy()

        # Apply more aggressive HOLD reduction for short-term trading
        action_probs_np[:, 1] *= 0.3  # More aggressive HOLD reduction

        # Apply boosting for BUY/SELL signals
        action_probs_np[:, 0] *= 2.0  # Boost SELL probabilities
        action_probs_np[:, 2] *= 2.0  # Boost BUY probabilities

        # Re-normalize
        action_probs_np = action_probs_np / action_probs_np.sum(axis=1, keepdims=True)
@@ -704,16 +366,20 @@ class CNNModelPyTorch:
        if 2 in action_dict:
            self.action_counts['BUY'][0] += action_dict[2]

        # If price_pred is None, create a dummy array of zeros
        if price_pred is None:
            # Get the current close prices from the input if available
            current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0])

            # Calculate price directions based on probabilities
            price_directions = action_probs_np[:, 2] - action_probs_np[:, 0]  # BUY - SELL

            # Scale the price change based on signal strength
            price_preds = current_prices * (1 + price_directions * 0.002)

            return action_probs_np, price_preds.reshape(-1, 1)
        else:
            return action_probs_np, price_pred.cpu().numpy()

    def predict_next_candles(self, X, n_candles=3):
        """
@@ -919,14 +585,9 @@ class CNNModelPyTorch:
        model_state = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'history': self.history,
            'window_size': self.window_size,
            'num_features': len(self.timeframes) * 5,  # 5 features per timeframe
            'output_size': self.output_size,
            'timeframes': self.timeframes,
            # Save trading configuration
@@ -935,7 +596,7 @@ class CNNModelPyTorch:
            'action_counts': self.action_counts,
            'last_actions': self.last_actions,
            # Save model version information
            'model_version': 'short_term_optimized_v2.0',
            'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
        }
@@ -943,10 +604,10 @@ class CNNModelPyTorch:
        logger.info(f"Model saved to {filepath}.pt with short-term trading optimizations")

        # Save a backup of the model periodically
        backup_dir = f"{filepath}_backup"
        os.makedirs(backup_dir, exist_ok=True)
        backup_path = os.path.join(backup_dir, f"model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
        torch.save(model_state, backup_path)
        logger.info(f"Backup saved to {backup_path}")
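To see what the class weighting in the new `forward` does, here is a small standalone illustration (not part of the model file) of how multiplying the logits by `[2.5, 0.4, 2.5]` before the softmax shifts probability mass from HOLD to BUY/SELL:

```python
import torch
import torch.nn.functional as F

# One sample of logits, ordered [SELL, HOLD, BUY] as in the model
action_logits = torch.tensor([[0.2, 1.0, 0.4]])
class_weights = torch.tensor([2.5, 0.4, 2.5])  # same weights as CNNPyTorch.forward

print(F.softmax(action_logits, dim=1))                  # ~[0.22, 0.50, 0.27]: HOLD dominates
print(F.softmax(action_logits * class_weights, dim=1))  # ~[0.28, 0.25, 0.46]: BUY now leads
```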


@@ -7,12 +7,16 @@ import random
from typing import Tuple, List
import os
import sys
import logging

# Add parent directory to path
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

from NN.models.simple_cnn import CNNModelPyTorch

# Configure logger
logger = logging.getLogger(__name__)

class DQNAgent:
    """
    Deep Q-Network agent for trading
@@ -72,14 +76,32 @@ class DQNAgent:
        # Initialize memory
        self.memory = deque(maxlen=memory_size)

        # Special memory for extrema samples to use for targeted learning
        self.extrema_memory = deque(maxlen=memory_size // 5)  # Smaller size for extrema examples

        # Training metrics
        self.update_count = 0
        self.losses = []

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool, is_extrema: bool = False):
        """
        Store experience in memory

        Args:
            state: Current state
            action: Action taken
            reward: Reward received
            next_state: Next state
            done: Whether episode is done
            is_extrema: Whether this is a local extrema sample (for specialized learning)
        """
        experience = (state, action, reward, next_state, done)
        self.memory.append(experience)

        # If this is an extrema sample, also add to specialized memory
        if is_extrema:
            self.extrema_memory.append(experience)

    def act(self, state: np.ndarray) -> int:
        """Choose action using epsilon-greedy policy"""
@@ -88,16 +110,39 @@ class DQNAgent:
        with torch.no_grad():
            state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            action_probs, extrema_pred = self.policy_net(state)
            return action_probs.argmax().item()
    def replay(self, use_extrema=False) -> float:
        """
        Train on a batch of experiences

        Args:
            use_extrema: Whether to include extrema samples in training

        Returns:
            float: Loss value
        """
        if len(self.memory) < self.batch_size:
            return 0.0

        # Sample batch - mix regular and extrema samples
        batch = []
        if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
            # Get some extrema samples
            extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
            extrema_samples = random.sample(list(self.extrema_memory), extrema_count)

            # Get regular samples for the rest
            regular_count = self.batch_size - extrema_count
            regular_samples = random.sample(list(self.memory), regular_count)

            # Combine samples
            batch = extrema_samples + regular_samples
        else:
            # Standard sampling
            batch = random.sample(self.memory, self.batch_size)

        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert to tensors and move to device
@@ -108,7 +153,7 @@ class DQNAgent:
        dones = torch.FloatTensor(dones).to(self.device)

        # Get current Q values
        current_q_values, extrema_pred = self.policy_net(states)
        current_q_values = current_q_values.gather(1, actions.unsqueeze(1))

        # Get next Q values from target network
@@ -117,8 +162,15 @@ class DQNAgent:
            next_q_values = next_q_values.max(1)[0]
        target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

        # Compute Q-learning loss
        q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)

        # If we had extrema labels (not in this implementation yet),
        # we could add an additional loss for extrema prediction.
        # This would require labels for whether each state is near an extrema.

        # Total loss is just the Q-learning loss for now
        loss = q_loss

        # Optimize
        self.optimizer.zero_grad()
@@ -135,6 +187,50 @@ class DQNAgent:
        return loss.item()
def train_on_extrema(self, states, actions, rewards, next_states, dones):
"""
Special training method focused on extrema patterns
Args:
states: Array of states near extrema points
actions: Correct actions to take (buy at bottoms, sell at tops)
rewards: Rewards for each action
next_states: Next states
dones: Done flags
"""
if len(states) == 0:
return 0.0
# Convert to tensors
states = torch.FloatTensor(np.array(states)).to(self.device)
actions = torch.LongTensor(actions).to(self.device)
rewards = torch.FloatTensor(rewards).to(self.device)
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
dones = torch.FloatTensor(dones).to(self.device)
# Forward pass
current_q_values, extrema_pred = self.policy_net(states)
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
# Get next Q values
with torch.no_grad():
next_q_values, _ = self.target_net(next_states)
next_q_values = next_q_values.max(1)[0]
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
# Higher weight for extrema training
q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
# Full loss is just Q-learning loss
loss = q_loss
# Optimize
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
return loss.item()
    def save(self, path: str):
        """Save model and agent state"""
        os.makedirs(os.path.dirname(path), exist_ok=True)
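A minimal sketch of how the new extrema-aware API might be exercised; only `remember(..., is_extrema=True)` and `replay(use_extrema=True)` come from the diff above, while the import path and the concrete constructor values are illustrative:

```python
import numpy as np
from NN.models.dqn_agent import DQNAgent  # assumed module path

agent = DQNAgent(state_size=100, action_size=3, window_size=20,
                 num_features=5, timeframes=['1m'])  # illustrative sizes

state = np.random.rand(100).astype(np.float32)
next_state = np.random.rand(100).astype(np.float32)

# Ordinary transition goes only to the main replay buffer
agent.remember(state, action=1, reward=0.2, next_state=next_state, done=False)

# A transition tagged as a local extrema also lands in the smaller extrema buffer
agent.remember(state, action=0, reward=2.0, next_state=next_state, done=False, is_extrema=True)

# Training step that mixes extrema samples into the batch once enough are available
loss = agent.replay(use_extrema=True)
```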
View File
@@ -11,6 +11,39 @@ from typing import List, Tuple
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class PricePatternAttention(nn.Module):
"""
Attention mechanism specifically designed to focus on price patterns
that might indicate local extrema or trend reversals
"""
def __init__(self, input_dim, hidden_dim=64):
super(PricePatternAttention, self).__init__()
self.query = nn.Linear(input_dim, hidden_dim)
self.key = nn.Linear(input_dim, hidden_dim)
self.value = nn.Linear(input_dim, hidden_dim)
self.scale = torch.sqrt(torch.tensor(hidden_dim, dtype=torch.float32))
def forward(self, x):
"""Apply attention to input sequence"""
# x shape: [batch_size, seq_len, features]
batch_size, seq_len, _ = x.size()
# Project input to query, key, value
q = self.query(x) # [batch_size, seq_len, hidden_dim]
k = self.key(x) # [batch_size, seq_len, hidden_dim]
v = self.value(x) # [batch_size, seq_len, hidden_dim]
# Calculate attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [batch_size, seq_len, seq_len]
# Apply softmax to get attention weights
attn_weights = F.softmax(scores, dim=-1) # [batch_size, seq_len, seq_len]
# Apply attention to values
output = torch.matmul(attn_weights, v) # [batch_size, seq_len, hidden_dim]
return output, attn_weights
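As a quick shape check, the attention block can be probed on a dummy batch (a sketch, assuming the `PricePatternAttention` class defined above is in scope):

```python
import torch

attn = PricePatternAttention(input_dim=256, hidden_dim=64)
dummy = torch.randn(8, 20, 256)           # [batch, seq_len, features]
out, weights = attn(dummy)
print(out.shape)      # torch.Size([8, 20, 64])  - attended features
print(weights.shape)  # torch.Size([8, 20, 20])  - attention over the sequence
```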
class CNNModelPyTorch(nn.Module):
    """
    CNN model for trading with multiple timeframes
@@ -30,7 +63,15 @@ class CNNModelPyTorch(nn.Module):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Create model architecture
        self._create_layers()

        # Move model to device
        self.to(self.device)

    def _create_layers(self):
        """Create all model layers with current feature dimensions"""
        # Convolutional layers - use total_features as input channels
        self.conv1 = nn.Conv1d(self.total_features, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(64)
@@ -40,24 +81,49 @@ class CNNModelPyTorch(nn.Module):
        self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(256)

        # Add price pattern attention layer
        self.attention = PricePatternAttention(256)

        # Extrema detection specialized convolutional layer
        self.extrema_conv = nn.Conv1d(256, 128, kernel_size=5, padding=2)
        self.extrema_bn = nn.BatchNorm1d(128)

        # Calculate size after convolutions - adjusted for attention output
        conv_output_size = self.window_size * 256

        # Fully connected layers
        self.fc1 = nn.Linear(conv_output_size, 512)
        self.fc2 = nn.Linear(512, 256)

        # Advantage and Value streams (Dueling DQN architecture)
        self.fc3 = nn.Linear(256, self.output_size)  # Advantage stream
        self.value_fc = nn.Linear(256, 1)  # Value stream

        # Additional prediction head for extrema detection (tops/bottoms)
        self.extrema_fc = nn.Linear(256, 3)  # 0=bottom, 1=top, 2=neither

        # Initialize optimizer and scheduler
        self.optimizer = optim.Adam(self.parameters(), lr=0.001)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
        )

    def rebuild_conv_layers(self, input_channels):
        """
        Rebuild convolutional layers for different input dimensions

        Args:
            input_channels: Number of input channels (features) in the data
        """
        logger.info(f"Rebuilding convolutional layers for {input_channels} input channels")

        # Update total features
        self.total_features = input_channels

        # Recreate all layers with new dimensions
        self._create_layers()

        # Move layers to device
        self.to(self.device)
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -65,8 +131,13 @@ class CNNModelPyTorch(nn.Module):
        # Ensure input is on the correct device
        x = x.to(self.device)

        # Check and handle if input dimensions don't match model expectations
        batch_size, window_len, feature_dim = x.size()
        if feature_dim != self.total_features:
            logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
            self.rebuild_conv_layers(feature_dim)

        # Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
        x = x.permute(0, 2, 1)
@@ -74,6 +145,26 @@ class CNNModelPyTorch(nn.Module):
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))

        # Store conv features for extrema detection
        conv_features = x

        # Reshape for attention: [batch, channels, window_size] -> [batch, window_size, channels]
        x_attention = x.permute(0, 2, 1)

        # Apply attention
        attention_output, attention_weights = self.attention(x_attention)

        # We'll use attention directly without the residual connection
        # to avoid dimension mismatch issues
        attention_reshaped = attention_output.permute(0, 2, 1)  # [batch, channels, window_size]

        # Apply extrema detection specialized layer
        extrema_features = F.relu(self.extrema_bn(self.extrema_conv(conv_features)))

        # Use attention features directly instead of residual connection
        # to avoid dimension mismatches
        x = conv_features  # Just use the convolutional features

        # Flatten
        x = x.view(batch_size, -1)

@@ -88,7 +179,11 @@ class CNNModelPyTorch(nn.Module):
        # Combine value and advantage
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

        # Also compute extrema prediction from the same features
        extrema_flat = extrema_features.view(batch_size, -1)
        extrema_pred = self.extrema_fc(x)  # Use the same features for extrema prediction

        return q_values, extrema_pred
    def predict(self, X):
        """Make predictions"""
@@ -101,11 +196,15 @@ class CNNModelPyTorch(nn.Module):
        X_tensor = X.to(self.device)

        with torch.no_grad():
            q_values, extrema_pred = self(X_tensor)
            q_values_np = q_values.cpu().numpy()
            actions = np.argmax(q_values_np, axis=1)

            # Also return extrema predictions
            extrema_np = extrema_pred.cpu().numpy()
            extrema_classes = np.argmax(extrema_np, axis=1)

        return actions, q_values_np, extrema_classes

    def save(self, path: str):
        """Save model weights"""
View File
@@ -0,0 +1,241 @@
import logging
import numpy as np
import pandas as pd
import time
from typing import Dict, Any, List, Optional, Tuple
logger = logging.getLogger(__name__)
class RealtimeDataInterface:
"""Interface for retrieving real-time market data for neural network models.
This class serves as a bridge between the RealTimeChart data sources and
the neural network models, providing properly formatted data for model
inference.
"""
def __init__(self, symbols: List[str], chart=None, max_cache_size: int = 5000):
"""Initialize the data interface.
Args:
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
chart: RealTimeChart instance (optional)
max_cache_size: Maximum number of cached candles
"""
self.symbols = symbols
self.chart = chart
self.max_cache_size = max_cache_size
# Initialize data cache
self.ohlcv_cache = {} # timeframe -> symbol -> DataFrame
logger.info(f"Initialized RealtimeDataInterface with symbols: {', '.join(symbols)}")
def get_historical_data(self, symbol: str = None, timeframe: str = '1h',
n_candles: int = 500) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data for a symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not symbol:
if len(self.symbols) > 0:
symbol = self.symbols[0]
else:
logger.error("No symbol specified and no default symbols available")
return None
if symbol not in self.symbols:
logger.warning(f"Symbol {symbol} not in tracked symbols")
return None
try:
# Get data from chart if available
if self.chart:
candles = self._get_chart_data(symbol, timeframe, n_candles)
if candles is not None and len(candles) > 0:
return candles
# Fallback to default empty DataFrame
logger.warning(f"No historical data available for {symbol} at timeframe {timeframe}")
return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
except Exception as e:
logger.error(f"Error getting historical data for {symbol}: {str(e)}")
return None
def _get_chart_data(self, symbol: str, timeframe: str, n_candles: int) -> Optional[pd.DataFrame]:
"""Get data from the RealTimeChart for the specified symbol and timeframe.
Args:
symbol: Trading symbol (e.g., 'BTC/USDT')
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
Returns:
DataFrame with OHLCV data or None if not available
"""
if not self.chart:
return None
# Get chart data using the _get_chart_data method
try:
# Map to interval seconds
interval_map = {
'1s': 1,
'5s': 5,
'10s': 10,
'15s': 15,
'30s': 30,
'1m': 60,
'3m': 180,
'5m': 300,
'15m': 900,
'30m': 1800,
'1h': 3600,
'2h': 7200,
'4h': 14400,
'6h': 21600,
'8h': 28800,
'12h': 43200,
'1d': 86400,
'3d': 259200,
'1w': 604800
}
# Convert timeframe to seconds
if timeframe in interval_map:
interval_seconds = interval_map[timeframe]
else:
# Try to parse the interval (e.g., '1m' -> 60)
try:
if timeframe.endswith('s'):
interval_seconds = int(timeframe[:-1])
elif timeframe.endswith('m'):
interval_seconds = int(timeframe[:-1]) * 60
elif timeframe.endswith('h'):
interval_seconds = int(timeframe[:-1]) * 3600
elif timeframe.endswith('d'):
interval_seconds = int(timeframe[:-1]) * 86400
elif timeframe.endswith('w'):
interval_seconds = int(timeframe[:-1]) * 604800
else:
interval_seconds = int(timeframe)
except ValueError:
logger.error(f"Could not parse timeframe: {timeframe}")
return None
# Get data from chart
df = self.chart._get_chart_data(interval_seconds)
if df is not None and not df.empty:
# Limit to requested number of candles
if len(df) > n_candles:
df = df.iloc[-n_candles:]
return df
else:
logger.warning(f"No data retrieved from chart for {symbol} at timeframe {timeframe}")
return None
except Exception as e:
logger.error(f"Error getting chart data for {symbol} at {timeframe}: {str(e)}")
return None
def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
symbol: str = None) -> Tuple[np.ndarray, Optional[int]]:
"""Prepare model input from OHLCV data.
Args:
data: DataFrame with OHLCV data
window_size: Window size for model input
symbol: Symbol for the data (for logging)
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
if data is None or len(data) < window_size:
logger.warning(f"Not enough data to prepare model input for {symbol or 'unknown symbol'}")
return None, None
try:
# Get last window_size candles
recent_data = data.iloc[-window_size:].copy()
# Get timestamp of the most recent candle
timestamp = int(recent_data.iloc[-1]['timestamp']) if 'timestamp' in recent_data.columns else int(time.time())
# Extract OHLCV features and normalize
if 'open' in recent_data.columns and 'high' in recent_data.columns and 'low' in recent_data.columns and 'close' in recent_data.columns and 'volume' in recent_data.columns:
# Normalize price data by the last close price
last_close = recent_data['close'].iloc[-1]
# Avoid division by zero
if last_close == 0:
last_close = 1.0
opens = (recent_data['open'] / last_close).values
highs = (recent_data['high'] / last_close).values
lows = (recent_data['low'] / last_close).values
closes = (recent_data['close'] / last_close).values
# Normalize volume by the max volume in the window
max_volume = recent_data['volume'].max()
if max_volume == 0:
max_volume = 1.0
volumes = (recent_data['volume'] / max_volume).values
# Stack features into a 3D array [batch_size=1, window_size, n_features=5]
X = np.column_stack((opens, highs, lows, closes, volumes))
X = X.reshape(1, window_size, 5)
# Replace any NaN or infinite values
X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)
return X, timestamp
else:
logger.error(f"Data missing required OHLCV columns for {symbol or 'unknown symbol'}")
return None, None
except Exception as e:
logger.error(f"Error preparing model input for {symbol or 'unknown symbol'}: {str(e)}")
return None, None
def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
window_size: int = 20) -> Tuple[np.ndarray, Optional[int]]:
"""Prepare real-time input for the model.
Args:
timeframe: Time interval (e.g., '1m', '5m', '1h')
n_candles: Number of candles to retrieve
window_size: Window size for model input
Returns:
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
"""
# Get data for the main symbol
if len(self.symbols) == 0:
logger.error("No symbols available for real-time input")
return None, None
symbol = self.symbols[0]
try:
# Get historical data
data = self.get_historical_data(symbol, timeframe, n_candles)
if data is None or len(data) < window_size:
logger.warning(f"Not enough data for real-time input. Need at least {window_size} candles.")
return None, None
# Prepare model input
return self.prepare_model_input(data, window_size, symbol)
except Exception as e:
logger.error(f"Error preparing real-time input: {str(e)}")
return None, None
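A minimal usage sketch for the interface above (no chart attached, so both calls fall back to the empty-data paths; the import path is an assumption):

```python
import logging
# from NN.utils.realtime_data_interface import RealtimeDataInterface  # assumed module path

logging.basicConfig(level=logging.INFO)

# Without a chart the interface can still be constructed and queried;
# get_historical_data() logs a warning and returns an empty frame.
interface = RealtimeDataInterface(symbols=["BTC/USDT"], chart=None)

df = interface.get_historical_data(symbol="BTC/USDT", timeframe="1m", n_candles=100)
print(df.columns.tolist())  # ['timestamp', 'open', 'high', 'low', 'close', 'volume']

# With a real data source, prepare_realtime_input() returns (X, timestamp) ready for the model;
# here it returns (None, None) because there is nothing behind the interface.
X, ts = interface.prepare_realtime_input(timeframe="1m", n_candles=30, window_size=20)
print(X is None)  # True
```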
View File
@@ -64,6 +64,9 @@ class RLTradingEnvironment(gym.Env):
        # State variables
        self.reset()

        # Callback for visualization or external monitoring
        self.action_callback = None

    def reset(self):
        """Reset the environment to initial state"""
        self.balance = self.initial_balance
@@ -145,6 +148,7 @@ class RLTradingEnvironment(gym.Env):
        # Default reward is slightly negative to discourage inaction
        reward = -0.0001
        done = False
        profit_pct = None  # Initialize profit_pct variable

        # Execute action
        if action == 0:  # BUY
@@ -218,214 +222,188 @@ class RLTradingEnvironment(gym.Env):
            'total_value': total_value,
            'gain': gain,
            'trades': self.trades,
            'win_rate': self.win_rate,
            'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
            'current_price': current_price,
            'next_price': next_price
        }

        # Call the callback if it exists
        if self.action_callback:
            self.action_callback(action, current_price, reward, info)

        return observation, reward, done, info
    def set_action_callback(self, callback):
        """
        Set a callback function to be called after each action

        Args:
            callback: Function with signature (action, price, reward, info)
        """
        self.action_callback = callback


def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
             action_callback=None, episode_callback=None, symbol="BTC/USDT"):
    """
    Train DQN agent for RL-based trading with extended training and monitoring

    Args:
        env_class: Optional environment class to use, defaults to RLTradingEnvironment
        num_episodes: Number of episodes to train
        max_steps: Maximum steps per episode
        save_path: Path to save the model
        action_callback: Optional callback for each action (step, action, price, reward, info)
        episode_callback: Optional callback after each episode (episode, reward, info)
        symbol: Trading pair symbol (e.g., "BTC/USDT")

    Returns:
        DQNAgent: The trained agent
    """
    import pandas as pd
    from NN.utils.data_interface import DataInterface

    logger.info("Starting DQN training for RL trading")

    # Create data interface with specified symbol
    data_interface = DataInterface(symbol=symbol)

    # Load and preprocess data
    logger.info(f"Loading data from multiple timeframes for {symbol}")
    features_1m = data_interface.get_training_data("1m", n_candles=2000)
    features_5m = data_interface.get_training_data("5m", n_candles=1000)
    features_15m = data_interface.get_training_data("15m", n_candles=500)

    if features_1m is None or features_5m is None or features_15m is None:
        logger.error("Failed to load training data from one or more timeframes")
        return None

    # If data is a DataFrame, convert to numpy array excluding the timestamp column
    if isinstance(features_1m, pd.DataFrame):
        features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
    if isinstance(features_5m, pd.DataFrame):
        features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
    if isinstance(features_15m, pd.DataFrame):
        features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values

    # Initialize environment or use provided class
    if env_class is None:
        env = RLTradingEnvironment(features_1m, features_5m, features_15m)
    else:
        env = env_class(features_1m, features_5m, features_15m)

    # Set action callback if provided
    if action_callback:
        def step_callback(action, price, reward, info):
            action_callback(env.current_step, action, price, reward, info)
        env.set_action_callback(step_callback)

    # Initialize agent
    window_size = env.window_size
    num_features = env.num_features * env.num_timeframes
    action_size = env.action_space.n
    timeframes = ['1m', '5m', '15m']  # Match the timeframes from the environment

    agent = DQNAgent(
        state_size=window_size * num_features,
        action_size=action_size,
        window_size=window_size,
        num_features=env.num_features,
        timeframes=timeframes,
        memory_size=100000,
        batch_size=64,
        learning_rate=0.0001,
        gamma=0.99,
        epsilon=1.0,
        epsilon_min=0.01,
        epsilon_decay=0.995
    )

    # Training variables
    best_reward = -float('inf')
    episode_rewards = []

    # TensorBoard writer for logging
    writer = SummaryWriter(log_dir=f'runs/rl_trading_{int(time.time())}')

    # Main training loop
    logger.info(f"Starting training for {num_episodes} episodes...")
    logger.info(f"Starting training on device: {agent.device}")

    try:
        for episode in range(num_episodes):
            state = env.reset()
            total_reward = 0

            for step in range(max_steps):
                # Select action
                action = agent.act(state)

                # Take action and observe next state and reward
                next_state, reward, done, info = env.step(action)

                # Store the experience in memory
                agent.remember(state, action, reward, next_state, done)

                # Update state and reward
                state = next_state
                total_reward += reward

                # Train the agent by sampling from memory
                if len(agent.memory) >= agent.batch_size:
                    loss = agent.replay()

                if done or step == max_steps - 1:
                    break

            # Track rewards
            episode_rewards.append(total_reward)

            # Log progress
            avg_reward = np.mean(episode_rewards[-100:])
            logger.info(f"Episode {episode}/{num_episodes} - Reward: {total_reward:.4f}, " +
                        f"Avg (100): {avg_reward:.4f}, Epsilon: {agent.epsilon:.4f}")

            # Calculate trading metrics
            win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
            trades = env.trades if hasattr(env, 'trades') else 0

            # Log to TensorBoard
            writer.add_scalar('Reward/Episode', total_reward, episode)
            writer.add_scalar('Reward/Average100', avg_reward, episode)
            writer.add_scalar('Trade/WinRate', win_rate, episode)
            writer.add_scalar('Trade/Count', trades, episode)

            # Save best model
            if avg_reward > best_reward and episode > 10:
                logger.info(f"New best average reward: {avg_reward:.4f}, saving model")
                agent.save(save_path)
                best_reward = avg_reward

            # Periodic save every 100 episodes
            if episode % 100 == 0 and episode > 0:
                agent.save(f"{save_path}_episode_{episode}")

            # Call episode callback if provided
            if episode_callback:
                # Add environment to info dict to use for extrema training
                info_with_env = info.copy()
                info_with_env['env'] = env
                episode_callback(episode, total_reward, info_with_env)

        # Final save
        logger.info("Training completed, saving final model")
        agent.save(f"{save_path}_final")

    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())

    # Close TensorBoard writer
    writer.close()

    return agent


if __name__ == "__main__":
View File
@@ -25,10 +25,10 @@ class SignalInterpreter:
        """
        self.config = config or {}

        # Signal thresholds - lower thresholds to increase trade frequency
        self.buy_threshold = self.config.get('buy_threshold', 0.35)
        self.sell_threshold = self.config.get('sell_threshold', 0.35)
        self.hold_threshold = self.config.get('hold_threshold', 0.60)

        # Adaptive parameters
        self.confidence_multiplier = self.config.get('confidence_multiplier', 1.0)

@@ -45,14 +45,14 @@ class SignalInterpreter:
        self.current_position = None  # None = no position, 'long' = buy, 'short' = sell

        # Filters for better signal quality
        self.trend_filter_enabled = self.config.get('trend_filter_enabled', False)  # Disable trend filter by default
        self.volume_filter_enabled = self.config.get('volume_filter_enabled', False)  # Disable volume filter by default
        self.oscillation_filter_enabled = self.config.get('oscillation_filter_enabled', False)  # Disable oscillation filter by default

        # Sensitivity parameters
        self.min_price_movement = self.config.get('min_price_movement', 0.0001)  # Lower price movement threshold
        self.hold_cooldown = self.config.get('hold_cooldown', 1)  # Shorter hold cooldown
        self.consecutive_signals_required = self.config.get('consecutive_signals_required', 1)  # Require only one signal

        # State tracking
        self.consecutive_buy_signals = 0
View File
@@ -55,3 +55,4 @@ python test_model.py
python train_with_realtime_ticks.py
python NN/train_rl.py
python train_rl_with_realtime.py
84
main.py
View File
@@ -56,7 +56,7 @@ websocket_logger.setLevel(logging.INFO)  # Change this from DEBUG to INFO
class WebSocketFilter(logging.Filter):
    def filter(self, record):
        # Filter out DEBUG messages from WebSocket-related modules
        if record.levelno == logging.INFO and ('websocket' in record.name or
                                               'protocol' in record.name or
                                               'realtime' in record.name):
            return False
@@ -331,7 +331,7 @@ def main():
    """Main function for the trading bot."""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
    parser.add_argument('--symbols', nargs='+', default=["ETH/USDT", "ETH/USDT"],
                        help='Trading symbols to monitor')
    parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
                        help='Timeframes to monitor')
@@ -692,11 +692,17 @@ if __name__ == "__main__":
        """Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
        reward = 0

        # Validate current price
        if self.current_price <= 0 or self.current_price > 1000000:  # Reasonable price range
            logger.error(f"Invalid current price: {self.current_price}")
            return -10.0  # Strong penalty for invalid price

        # Validate position size
        if self.position_size <= 0 or self.position_size > 1000000:  # Reasonable position size range
            logger.error(f"Invalid position size: {self.position_size}")
            return -10.0  # Strong penalty for invalid position size

        if action == 1:  # BUY/LONG
            if self.position == 'flat':
                # Opening a long position
                self.position = 'long'
@@ -706,12 +712,11 @@ if __name__ == "__main__":
                self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
                self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)

                # Check if this is an optimal buy point
                if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
                    reward += 2.0  # Bonus for buying at a bottom

                # Check for volume spike
                if len(self.features['volume']) > 5:
                    avg_volume = np.mean(self.features['volume'][-5:-1])
                    current_volume = self.features['volume'][-1]
@@ -737,9 +742,20 @@ if __name__ == "__main__":
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Validate PnL values
                if abs(pnl_percent) > 100:  # Max 100% loss/gain
                    logger.error(f"Invalid PnL percentage: {pnl_percent}")
                    pnl_percent = max(min(pnl_percent, 100), -100)
                    pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance with validation
                if abs(pnl_dollar) > self.balance * 2:  # Max 200% of balance
                    logger.error(f"Invalid PnL dollar amount: {pnl_dollar}")
                    pnl_dollar = max(min(pnl_dollar, self.balance * 2), -self.balance * 2)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
@@ -758,11 +774,11 @@ if __name__ == "__main__":
                # Reward based on PnL with stronger penalties for losses
                if pnl_dollar > 0:
                    reward += 1.0 + min(pnl_dollar / 10, 5.0)  # Cap positive reward at 5.0
                    self.win_count += 1
                else:
                    # Stronger penalty for losses, scaled by the size of the loss but capped
                    loss_penalty = min(1.0 + abs(pnl_dollar) / 5, 5.0)
                    reward -= loss_penalty
                    self.loss_count += 1
@@ -2115,11 +2131,17 @@ class TradingEnvironment:
        """Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
        reward = 0

        # Validate current price
        if self.current_price <= 0 or self.current_price > 1000000:  # Reasonable price range
            logger.error(f"Invalid current price: {self.current_price}")
            return -10.0  # Strong penalty for invalid price

        # Validate position size
        if self.position_size <= 0 or self.position_size > 1000000:  # Reasonable position size range
            logger.error(f"Invalid position size: {self.position_size}")
            return -10.0  # Strong penalty for invalid position size

        if action == 1:  # BUY/LONG
            if self.position == 'flat':
                # Opening a long position
                self.position = 'long'
@@ -2129,12 +2151,11 @@ class TradingEnvironment:
                self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
                self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)

                # Check if this is an optimal buy point
                if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
                    reward += 2.0  # Bonus for buying at a bottom

                # Check for volume spike
                if len(self.features['volume']) > 5:
                    avg_volume = np.mean(self.features['volume'][-5:-1])
                    current_volume = self.features['volume'][-1]
@@ -2160,9 +2181,20 @@ class TradingEnvironment:
                pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
                pnl_dollar = pnl_percent / 100 * self.position_size

                # Validate PnL values
                if abs(pnl_percent) > 100:  # Max 100% loss/gain
                    logger.error(f"Invalid PnL percentage: {pnl_percent}")
                    pnl_percent = max(min(pnl_percent, 100), -100)
                    pnl_dollar = pnl_percent / 100 * self.position_size

                # Apply fees
                pnl_dollar -= self.calculate_fees(self.position_size)

                # Update balance with validation
                if abs(pnl_dollar) > self.balance * 2:  # Max 200% of balance
                    logger.error(f"Invalid PnL dollar amount: {pnl_dollar}")
                    pnl_dollar = max(min(pnl_dollar, self.balance * 2), -self.balance * 2)

                # Update balance
                self.balance += pnl_dollar
                self.total_pnl += pnl_dollar
@@ -2181,11 +2213,11 @@ class TradingEnvironment:
                # Reward based on PnL with stronger penalties for losses
                if pnl_dollar > 0:
                    reward += 1.0 + min(pnl_dollar / 10, 5.0)  # Cap positive reward at 5.0
                    self.win_count += 1
                else:
                    # Stronger penalty for losses, scaled by the size of the loss but capped
                    loss_penalty = min(1.0 + abs(pnl_dollar) / 5, 5.0)
                    reward -= loss_penalty
                    self.loss_count += 1
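The new validation logic amounts to clamping PnL into a sane range before it touches the balance; a standalone sketch of that arithmetic (the function name and parameters are illustrative, not part of the diff):

```python
def clamp_pnl(pnl_percent: float, position_size: float, balance: float) -> float:
    """Clamp PnL the same way the reward function above does:
    at most +/-100% of the position and +/-200% of the current balance."""
    pnl_percent = max(min(pnl_percent, 100), -100)
    pnl_dollar = pnl_percent / 100 * position_size

    if abs(pnl_dollar) > balance * 2:
        pnl_dollar = max(min(pnl_dollar, balance * 2), -balance * 2)
    return pnl_dollar

# Example: a wildly wrong 450% reading on a $1,000 position with a $400 balance
print(clamp_pnl(450.0, 1000.0, 400.0))  # 800.0 - capped at 2 * balance
```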
103
test_dash.py Normal file
View File
@@ -0,0 +1,103 @@
#!/usr/bin/env python
"""
Simple test for Dash to ensure chart rendering is working correctly
"""
import dash
from dash import html, dcc, Input, Output
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
# Create some sample data
def generate_sample_data(days=10):
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
dates = pd.date_range(start=start_date, end=end_date, freq='1h')
np.random.seed(42)
prices = np.random.normal(loc=100, scale=5, size=len(dates))
prices = np.cumsum(np.random.normal(loc=0, scale=1, size=len(dates))) + 100
df = pd.DataFrame({
'timestamp': dates,
'open': prices,
'high': prices + np.random.normal(loc=1, scale=0.5, size=len(dates)),
'low': prices - np.random.normal(loc=1, scale=0.5, size=len(dates)),
'close': prices + np.random.normal(loc=0, scale=0.5, size=len(dates)),
'volume': np.abs(np.random.normal(loc=100, scale=50, size=len(dates)))
})
return df
# Create a Dash app
app = dash.Dash(__name__)
# Layout
app.layout = html.Div([
html.H1("Test Chart"),
dcc.Graph(id='test-chart', style={'height': '800px'}),
dcc.Interval(
id='interval-component',
interval=1*1000, # in milliseconds
n_intervals=0
),
html.Div(id='signal-display', children="No Signal")
])
# Callback
@app.callback(
[Output('test-chart', 'figure'),
Output('signal-display', 'children')],
[Input('interval-component', 'n_intervals')]
)
def update_chart(n):
# Generate new data on each update
df = generate_sample_data()
# Create a candlestick chart
fig = go.Figure(data=[go.Candlestick(
x=df['timestamp'],
open=df['open'],
high=df['high'],
low=df['low'],
close=df['close']
)])
# Add some random buy/sell signals
if n % 5 == 0:
signal_point = df.iloc[np.random.randint(0, len(df))]
action = "BUY" if n % 10 == 0 else "SELL"
color = "green" if action == "BUY" else "red"
fig.add_trace(go.Scatter(
x=[signal_point['timestamp']],
y=[signal_point['close']],
mode='markers',
marker=dict(
size=10,
color=color,
symbol='triangle-up' if action == 'BUY' else 'triangle-down'
),
name=action
))
signal_text = f"Current Signal: {action}"
else:
signal_text = "No Signal"
# Update layout
fig.update_layout(
title='Test Price Chart',
yaxis_title='Price',
xaxis_title='Time',
template='plotly_dark',
height=800
)
return fig, signal_text
if __name__ == '__main__':
print("Starting Dash server on http://localhost:8090/")
app.run_server(debug=False, host='localhost', port=8090)
19
test_websocket.py Normal file
View File
@@ -0,0 +1,19 @@
import websockets
import asyncio
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
async def test_websocket():
url = "wss://stream.binance.com:9443/ws/btcusdt@trade"
try:
logger.info(f"Connecting to {url}")
async with websockets.connect(url) as ws:
logger.info("Connected successfully")
message = await ws.recv()
logger.info(f"Received message: {message[:200]}...")
except Exception as e:
logger.error(f"Connection failed: {str(e)}")
asyncio.run(test_websocket())
537
train_rl_with_realtime.py Normal file
View File
@@ -0,0 +1,537 @@
#!/usr/bin/env python
"""
Integrated RL Trading with Realtime Visualization
This script combines the RL training (NN/train_rl.py) with the realtime visualization (realtime.py)
to display the actions taken by the RL agent on the realtime chart.
"""
import os
import sys
import logging
import asyncio
import threading
import time
from datetime import datetime
import signal
import numpy as np
import torch
import json
from threading import Thread
import pandas as pd
from scipy.signal import argrelextrema
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler(f'rl_realtime_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
logging.StreamHandler()
]
)
logger = logging.getLogger('rl_realtime')
# Add the project root to path if needed
project_root = os.path.dirname(os.path.abspath(__file__))
if project_root not in sys.path:
sys.path.append(project_root)
# Global variables for coordination
realtime_chart = None
realtime_websocket_task = None
running = True
def signal_handler(sig, frame):
"""Handle CTRL+C to gracefully exit training"""
global running
logger.info("Received interrupt signal. Finishing current epoch and shutting down...")
running = False
# Register signal handler
signal.signal(signal.SIGINT, signal_handler)
class ExtremaDetector:
"""
Detects local extrema (tops and bottoms) in price data
"""
def __init__(self, window_size=10, order=5):
"""
Args:
window_size (int): Size of the window to look for extrema
order (int): How many points on each side to use for comparison
"""
self.window_size = window_size
self.order = order
def find_extrema(self, prices):
"""
Find the local minima and maxima in the price series
Args:
prices (array-like): Array of price values
Returns:
tuple: (max_indices, min_indices) - arrays of indices where local maxima and minima occur
"""
# Convert to numpy array if needed
price_array = np.array(prices)
# Find local maxima (tops)
local_max_indices = argrelextrema(price_array, np.greater, order=self.order)[0]
# Find local minima (bottoms)
local_min_indices = argrelextrema(price_array, np.less, order=self.order)[0]
# Filter out extrema that are too close to the edges
max_indices = local_max_indices[local_max_indices >= self.order]
max_indices = max_indices[max_indices < len(price_array) - self.order]
min_indices = local_min_indices[local_min_indices >= self.order]
min_indices = min_indices[min_indices < len(price_array) - self.order]
return max_indices, min_indices
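A quick sanity check of the detector on a synthetic oscillating price series (a sketch, assuming the ExtremaDetector class above is in scope):

```python
import numpy as np

# Synthetic price: a slow sine wave plus a gentle upward drift (~5 full cycles)
t = np.linspace(0, 10 * np.pi, 500)
prices = 100 + 5 * np.sin(t) + 0.01 * np.arange(500)

detector = ExtremaDetector(window_size=10, order=5)
tops, bottoms = detector.find_extrema(prices)

print(f"tops: {len(tops)}, bottoms: {len(bottoms)}")
print("first top index:", tops[0] if len(tops) else None)
```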
class RLTrainingIntegrator:
"""
Integrates RL training with realtime chart visualization.
Acts as a bridge between the RL training process and the realtime chart.
"""
def __init__(self, chart, symbol="ETH/USDT", model_save_path="NN/models/saved/dqn_agent"):
self.chart = chart
self.symbol = symbol
self.model_save_path = model_save_path
self.episode_count = 0
self.action_history = []
self.reward_history = []
self.trade_count = 0
self.win_count = 0
# Track current position state
self.in_position = False
self.entry_price = None
self.entry_time = None
# Extrema detector
self.extrema_detector = ExtremaDetector(window_size=10, order=5)
# Store the agent reference
self.agent = None
def start_training(self, num_episodes=5000, max_steps=2000):
"""Start the RL training process with visualization integration"""
from NN.train_rl import train_rl, RLTradingEnvironment
logger.info(f"Starting RL training with realtime visualization for {self.symbol}")
# Define callbacks for the training process
def on_action(step, action, price, reward, info):
"""Callback for each action taken by the agent"""
# Only visualize non-hold actions
if action != 2: # 0=Buy, 1=Sell, 2=Hold
# Convert to string action
action_str = "BUY" if action == 0 else "SELL"
# Get timestamp - we'll use current time as a proxy
timestamp = datetime.now()
# Track position state
if action == 0 and not self.in_position: # Buy and not already in position
self.in_position = True
self.entry_price = price
self.entry_time = timestamp
# Send to chart - visualize buy signal
if self.chart and hasattr(self.chart, 'add_nn_signal'):
self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))
elif action == 1 and self.in_position: # Sell and in position (complete trade)
self.in_position = False
# Calculate profit if we have entry data
pnl = None
if self.entry_price is not None:
pnl = (price - self.entry_price) / self.entry_price
# Log the complete trade on the chart
if self.chart:
# Show sell signal
if hasattr(self.chart, 'add_nn_signal'):
self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))
# Record the trade with PnL
if hasattr(self.chart, 'add_trade'):
self.chart.add_trade(
action=action_str,
price=price,
timestamp=timestamp,
pnl=pnl
)
# Update trade counts
self.trade_count += 1
if pnl is not None and pnl > 0:
self.win_count += 1
# Reset entry data
self.entry_price = None
self.entry_time = None
# Track all actions
self.action_history.append({
'step': step,
'action': action_str,
'price': price,
'reward': reward,
'timestamp': timestamp.isoformat()
})
# Track reward for all actions (including hold)
self.reward_history.append(reward)
# Log periodically
if len(self.reward_history) % 100 == 0:
avg_reward = sum(self.reward_history[-100:]) / 100
logger.info(f"Step {step}: Avg reward (last 100): {avg_reward:.4f}, Actions: {len(self.action_history)}, Trades: {self.trade_count}")
def on_episode(episode, reward, info):
"""Callback for each completed episode"""
self.episode_count += 1
# Log episode results
logger.info(f"Episode {episode} completed")
logger.info(f" Total reward: {reward:.4f}")
logger.info(f" PnL: {info['gain']:.4f}")
logger.info(f" Win rate: {info['win_rate']:.4f}")
logger.info(f" Trades: {info['trades']}")
# Reset position state for new episode
self.in_position = False
self.entry_price = None
self.entry_time = None
# After each episode, perform additional training for local extrema
if hasattr(self.agent, 'policy_net') and hasattr(self.agent, 'replay') and episode > 0:
self._train_on_extrema(self.agent, info['env'])
# Start the actual training with our callbacks
self.agent = train_rl(
num_episodes=num_episodes,
max_steps=max_steps,
save_path=self.model_save_path,
action_callback=on_action,
episode_callback=on_episode,
symbol=self.symbol
)
logger.info("RL training completed")
return self.agent
def _train_on_extrema(self, agent, env):
"""
Perform additional training on local extrema (tops and bottoms)
to help the model learn these important patterns faster
Args:
agent: The DQN agent
env: The trading environment
"""
if not hasattr(env, 'features_1m') or len(env.features_1m) == 0:
logger.warning("Environment doesn't have price data for extrema detection")
return
try:
# Extract close prices
prices = env.features_1m[:, -1] # Assuming close price is the last column
# Find local extrema
max_indices, min_indices = self.extrema_detector.find_extrema(prices)
if len(max_indices) == 0 or len(min_indices) == 0:
logger.warning("No extrema found in the current price data")
return
logger.info(f"Found {len(max_indices)} tops and {len(min_indices)} bottoms for additional training")
# Calculate price changes at extrema to prioritize more significant ones
max_price_changes = []
for idx in max_indices:
if idx < 5 or idx >= len(prices) - 5:
continue
# Calculate percentage price rise from previous 5 candles to the peak
min_before = min(prices[idx-5:idx])
price_change = (prices[idx] - min_before) / min_before
max_price_changes.append((idx, price_change))
min_price_changes = []
for idx in min_indices:
if idx < 5 or idx >= len(prices) - 5:
continue
# Calculate percentage price drop from previous 5 candles to the bottom
max_before = max(prices[idx-5:idx])
price_change = (max_before - prices[idx]) / max_before
min_price_changes.append((idx, price_change))
# Sort extrema by significance (larger price change is more important)
max_price_changes.sort(key=lambda x: x[1], reverse=True)
min_price_changes.sort(key=lambda x: x[1], reverse=True)
# Take top 10 most significant extrema or all if fewer
max_indices = [idx for idx, _ in max_price_changes[:10]]
min_indices = [idx for idx, _ in min_price_changes[:10]]
# Log the significance of the extrema
if max_indices:
logger.info(f"Top extrema price changes: {[round(pc*100, 2) for _, pc in max_price_changes[:5]]}%")
if min_indices:
logger.info(f"Bottom extrema price changes: {[round(pc*100, 2) for _, pc in min_price_changes[:5]]}%")
# Collect states, actions, rewards for batch training
states = []
actions = []
rewards = []
next_states = []
dones = []
# Process tops (local maxima - should sell)
for idx in max_indices:
if idx < env.window_size + 2 or idx >= len(prices) - 2:
continue
# Create states for multiple points approaching the top
# This helps the model learn to recognize the pattern leading to the top
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the top
if idx - offset < env.window_size:
continue
# State before the peak
state_idx = idx - offset
env.current_step = state_idx
state = env._get_observation()
# The next state would be closer to the peak
env.current_step = state_idx + 1
next_state = env._get_observation()
                    # Higher reward for the state immediately before the peak,
                    # lower for states further away
                    reward = 1.0 if offset > 1 else 2.0
# Add to memory
action = 1 # Sell
agent.remember(state, action, reward, next_state, False, is_extrema=True)
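                    # is_extrema=True is assumed to tag these transitions so that
                    # replay(use_extrema=True) below can oversample them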
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Process bottoms (local minima - should buy)
for idx in min_indices:
if idx < env.window_size + 2 or idx >= len(prices) - 2:
continue
# Create states for multiple points approaching the bottom
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the bottom
if idx - offset < env.window_size:
continue
# State before the bottom
state_idx = idx - offset
env.current_step = state_idx
state = env._get_observation()
# The next state would be closer to the bottom
env.current_step = state_idx + 1
next_state = env._get_observation()
                    # Higher reward for the state immediately before the bottom,
                    # lower for states further away
                    reward = 1.0 if offset > 1 else 2.0
# Add to memory
action = 0 # Buy
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Add some negative examples - don't buy at tops, don't sell at bottoms
for idx in max_indices[:5]: # Use a few top peaks
if idx < env.window_size + 1 or idx >= len(prices) - 1:
continue
# State at the peak
env.current_step = idx
state = env._get_observation()
# Next state
env.current_step = idx + 1
next_state = env._get_observation()
# Strong negative reward for buying at a peak
reward = -1.5
# Add negative example of buying at a peak
action = 0 # Buy (wrong action)
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
for idx in min_indices[:5]: # Use a few bottom troughs
if idx < env.window_size + 1 or idx >= len(prices) - 1:
continue
# State at the bottom
env.current_step = idx
state = env._get_observation()
# Next state
env.current_step = idx + 1
next_state = env._get_observation()
# Strong negative reward for selling at a bottom
reward = -1.5
# Add negative example of selling at a bottom
action = 1 # Sell (wrong action)
agent.remember(state, action, reward, next_state, False, is_extrema=True)
# Add to batch
states.append(state)
actions.append(action)
rewards.append(reward)
next_states.append(next_state)
dones.append(False)
# Train on the collected extrema samples
if len(states) > 0:
logger.info(f"Performing additional training on {len(states)} extrema patterns")
loss = agent.train_on_extrema(states, actions, rewards, next_states, dones)
logger.info(f"Extrema training loss: {loss:.4f}")
# Additional replay passes with extrema samples included
for _ in range(5):
loss = agent.replay(use_extrema=True)
logger.info(f"Mixed replay with extrema - loss: {loss:.4f}")
except Exception as e:
logger.error(f"Error during extrema training: {str(e)}")
import traceback
logger.error(traceback.format_exc())
async def start_realtime_chart(symbol="BTC/USDT", port=8050):
"""
Start the realtime chart display in a separate thread
Returns:
tuple: (chart, websocket_task)
"""
from realtime import RealTimeChart
try:
logger.info(f"Initializing RealTimeChart for {symbol}")
# Create the chart
chart = RealTimeChart(symbol)
        # The chart starts its WebSocket connection internally
        # (its _start_websocket_thread method handles that), so only the
        # Dash server needs to be launched here in a separate thread
thread = Thread(target=lambda c=chart, p=port: c.run(host='localhost', port=p))
thread.daemon = True
thread.start()
# Give the server a moment to start
await asyncio.sleep(2)
logger.info(f"Started realtime chart for {symbol} on port {port}")
logger.info(f"You can view the chart at http://localhost:{port}/")
# Return the chart and a dummy websocket task (the real one is running in a thread)
return chart, asyncio.create_task(asyncio.sleep(0))
except Exception as e:
logger.error(f"Error starting realtime chart: {str(e)}")
import traceback
logger.error(traceback.format_exc())
raise
def run_training_thread(chart, num_episodes=5000, max_steps=2000):
"""Run the RL training in a separate thread"""
integrator = RLTrainingIntegrator(chart)
thread = Thread(target=lambda: integrator.start_training(num_episodes, max_steps))
thread.daemon = True
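    # daemon=True so the training thread cannot keep the process alive on exit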
thread.start()
logger.info("Started RL training thread")
return thread, integrator
def test_signals(chart):
"""Add test signals to the chart to verify functionality"""
from datetime import datetime
logger.info("Adding test signals to chart")
# Add a test BUY signal
chart.add_nn_signal("BUY", datetime.now(), 0.95)
# Sleep briefly
time.sleep(1)
# Add a test SELL signal
chart.add_nn_signal("SELL", datetime.now(), 0.85)
# Add a test trade
chart.add_trade("BUY", 50000.0, datetime.now(), 0.05)
async def main():
"""Main function that coordinates the realtime chart and RL training"""
global realtime_chart, realtime_websocket_task, running
logger.info("Starting integrated RL training with realtime visualization")
# Start the realtime chart
realtime_chart, realtime_websocket_task = await start_realtime_chart()
# Wait a bit for the chart to initialize
await asyncio.sleep(5)
# Test signals first
test_signals(realtime_chart)
# Start the training in a separate thread
training_thread, integrator = run_training_thread(realtime_chart)
try:
# Keep the main task running until interrupted
while running and training_thread.is_alive():
await asyncio.sleep(1)
except KeyboardInterrupt:
logger.info("Shutting down...")
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
finally:
# Clean up
if realtime_websocket_task:
realtime_websocket_task.cancel()
try:
await realtime_websocket_task
except asyncio.CancelledError:
pass
logger.info("Application terminated")
if __name__ == "__main__":
try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Application terminated by user")

View File

@ -350,9 +350,9 @@ async def run_realtime_training():
         model = CNNModelPyTorch(
             window_size=window_size,
-            num_features=num_features,
+            timeframes=timeframes,
             output_size=output_size,
-            timeframes=timeframes
+            num_pairs=1  # Single trading pair
         )
         # Try to load existing model