enhancements
This commit is contained in:
parent
a46b2c74f8
commit
73c5ecb0d2
1
.gitignore
vendored
1
.gitignore
vendored
@ -41,3 +41,4 @@ NN/utils/__pycache__/data_interface.cpython-312.pyc
|
|||||||
NN/utils/__pycache__/multi_data_interface.cpython-312.pyc
|
NN/utils/__pycache__/multi_data_interface.cpython-312.pyc
|
||||||
NN/utils/__pycache__/realtime_analyzer.cpython-312.pyc
|
NN/utils/__pycache__/realtime_analyzer.cpython-312.pyc
|
||||||
models/trading_agent_best_pnl.pt
|
models/trading_agent_best_pnl.pt
|
||||||
|
*.log
|
||||||
|
305
NN/README_TRADING.md
Normal file
305
NN/README_TRADING.md
Normal file
@ -0,0 +1,305 @@
|
|||||||
|
# Trading Agent System
|
||||||
|
|
||||||
|
A modular, extensible cryptocurrency trading system that can connect to multiple exchanges through a common interface.
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
The trading agent system consists of the following components:
|
||||||
|
|
||||||
|
### Exchange Interfaces
|
||||||
|
|
||||||
|
- `ExchangeInterface`: Abstract base class that defines the common interface for all exchange implementations
|
||||||
|
- `BinanceInterface`: Implementation for the Binance exchange (supports both mainnet and testnet)
|
||||||
|
- `MEXCInterface`: Implementation for the MEXC exchange
|
||||||
|
|
||||||
|
### Trading Agent
|
||||||
|
|
||||||
|
- `TradingAgent`: Main class that manages trading operations, including position sizing, risk management, and signal processing
|
||||||
|
|
||||||
|
### Neural Network Orchestrator
|
||||||
|
|
||||||
|
- `NeuralNetworkOrchestrator`: Coordinates between neural network models and the trading agent to generate and process trading signals
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
The trading agent system is built into the main application. No additional installation is needed beyond the requirements for the main application.
|
||||||
|
|
||||||
|
### Configuration
|
||||||
|
|
||||||
|
Configuration can be provided via:
|
||||||
|
|
||||||
|
1. Command-line arguments
|
||||||
|
2. Environment variables
|
||||||
|
3. Configuration file
|
||||||
|
|
||||||
|
#### Example Configuration
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"exchange": "binance",
|
||||||
|
"api_key": "your_api_key",
|
||||||
|
"api_secret": "your_api_secret",
|
||||||
|
"test_mode": true,
|
||||||
|
"trade_symbols": ["BTC/USDT", "ETH/USDT"],
|
||||||
|
"position_size": 0.1,
|
||||||
|
"max_trades_per_day": 5,
|
||||||
|
"trade_cooldown_minutes": 60
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Running the Trading System
|
||||||
|
|
||||||
|
#### From Command Line
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run with Binance in test mode
|
||||||
|
python trading_main.py --exchange binance --test-mode
|
||||||
|
|
||||||
|
# Run with MEXC in production mode
|
||||||
|
python trading_main.py --exchange mexc --api-key YOUR_API_KEY --api-secret YOUR_API_SECRET
|
||||||
|
|
||||||
|
# Run with custom position sizing and limits
|
||||||
|
python trading_main.py --exchange binance --test-mode --position-size 0.05 --max-trades-per-day 3 --trade-cooldown 120
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Using Environment Variables
|
||||||
|
|
||||||
|
You can set these environment variables for configuration:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Set exchange API credentials
|
||||||
|
export BINANCE_API_KEY=your_binance_api_key
|
||||||
|
export BINANCE_API_SECRET=your_binance_api_secret
|
||||||
|
|
||||||
|
# Enable neural network models
|
||||||
|
export ENABLE_NN_MODELS=1
|
||||||
|
export NN_INFERENCE_INTERVAL=60
|
||||||
|
export NN_MODEL_TYPE=cnn
|
||||||
|
export NN_TIMEFRAME=1h
|
||||||
|
|
||||||
|
# Run the trading system
|
||||||
|
python trading_main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Basic Usage Examples
|
||||||
|
|
||||||
|
#### Creating a Trading Agent
|
||||||
|
|
||||||
|
```python
|
||||||
|
from NN.trading_agent import TradingAgent
|
||||||
|
|
||||||
|
# Initialize a trading agent for Binance testnet
|
||||||
|
agent = TradingAgent(
|
||||||
|
exchange_name="binance",
|
||||||
|
api_key="your_api_key",
|
||||||
|
api_secret="your_api_secret",
|
||||||
|
test_mode=True,
|
||||||
|
trade_symbols=["BTC/USDT", "ETH/USDT"],
|
||||||
|
position_size=0.1,
|
||||||
|
max_trades_per_day=5,
|
||||||
|
trade_cooldown_minutes=60
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the trading agent
|
||||||
|
agent.start()
|
||||||
|
|
||||||
|
# Process a signal
|
||||||
|
agent.process_signal(
|
||||||
|
symbol="BTC/USDT",
|
||||||
|
action="BUY",
|
||||||
|
confidence=0.85
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get current positions
|
||||||
|
positions = agent.get_current_positions()
|
||||||
|
print(f"Current positions: {positions}")
|
||||||
|
|
||||||
|
# Stop the trading agent
|
||||||
|
agent.stop()
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Using an Exchange Interface Directly
|
||||||
|
|
||||||
|
```python
|
||||||
|
from NN.exchanges import BinanceInterface
|
||||||
|
|
||||||
|
# Initialize the Binance interface
|
||||||
|
exchange = BinanceInterface(
|
||||||
|
api_key="your_api_key",
|
||||||
|
api_secret="your_api_secret",
|
||||||
|
test_mode=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect to the exchange
|
||||||
|
exchange.connect()
|
||||||
|
|
||||||
|
# Get ticker info
|
||||||
|
ticker = exchange.get_ticker("BTC/USDT")
|
||||||
|
print(f"Current BTC price: {ticker['last']}")
|
||||||
|
|
||||||
|
# Get account balance
|
||||||
|
btc_balance = exchange.get_balance("BTC")
|
||||||
|
usdt_balance = exchange.get_balance("USDT")
|
||||||
|
print(f"BTC balance: {btc_balance}")
|
||||||
|
print(f"USDT balance: {usdt_balance}")
|
||||||
|
|
||||||
|
# Place a market order
|
||||||
|
order = exchange.place_order(
|
||||||
|
symbol="BTC/USDT",
|
||||||
|
side="buy",
|
||||||
|
order_type="market",
|
||||||
|
quantity=0.001
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing the Exchange Interfaces
|
||||||
|
|
||||||
|
The system includes a test script that can be used to verify that exchange interfaces are working correctly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Test Binance interface in test mode (no real trades)
|
||||||
|
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
|
||||||
|
|
||||||
|
# Test MEXC interface in test mode
|
||||||
|
python -m NN.exchanges.trading_agent_test --exchange mexc --test-mode
|
||||||
|
|
||||||
|
# Test with actual trades (use with caution!)
|
||||||
|
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode --execute-trades --test-trade-amount 0.001
|
||||||
|
```
|
||||||
|
|
||||||
|
## Adding a New Exchange
|
||||||
|
|
||||||
|
To add support for a new exchange, you need to create a new class that inherits from `ExchangeInterface` and implements all the required methods:
|
||||||
|
|
||||||
|
1. Create a new file in the `NN/exchanges` directory (e.g., `kraken_interface.py`)
|
||||||
|
2. Implement the required methods (see `exchange_interface.py` for the specifications)
|
||||||
|
3. Add the new exchange to the imports in `__init__.py`
|
||||||
|
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange
|
||||||
|
|
||||||
|
### Example of a New Exchange Implementation
|
||||||
|
|
||||||
|
```python
|
||||||
|
# NN/exchanges/kraken_interface.py
|
||||||
|
import logging
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
|
from .exchange_interface import ExchangeInterface
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class KrakenInterface(ExchangeInterface):
|
||||||
|
"""Kraken Exchange API Interface"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str = None, api_secret: str = None, test_mode: bool = True):
|
||||||
|
super().__init__(api_key, api_secret, test_mode)
|
||||||
|
self.base_url = "https://api.kraken.com"
|
||||||
|
# Initialize other Kraken-specific properties
|
||||||
|
|
||||||
|
def connect(self) -> bool:
|
||||||
|
# Implement connection to Kraken API
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_balance(self, asset: str) -> float:
|
||||||
|
# Implement getting balance for an asset
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_ticker(self, symbol: str) -> Dict[str, Any]:
|
||||||
|
# Implement getting ticker data
|
||||||
|
pass
|
||||||
|
|
||||||
|
def place_order(self, symbol: str, side: str, order_type: str,
|
||||||
|
quantity: float, price: float = None) -> Dict[str, Any]:
|
||||||
|
# Implement placing an order
|
||||||
|
pass
|
||||||
|
|
||||||
|
def cancel_order(self, symbol: str, order_id: str) -> bool:
|
||||||
|
# Implement cancelling an order
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_order_status(self, symbol: str, order_id: str) -> Dict[str, Any]:
|
||||||
|
# Implement getting order status
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_open_orders(self, symbol: str = None) -> List[Dict[str, Any]]:
|
||||||
|
# Implement getting open orders
|
||||||
|
pass
|
||||||
|
```
|
||||||
|
|
||||||
|
Then update the imports in `__init__.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from .exchange_interface import ExchangeInterface
|
||||||
|
from .binance_interface import BinanceInterface
|
||||||
|
from .mexc_interface import MEXCInterface
|
||||||
|
from .kraken_interface import KrakenInterface
|
||||||
|
|
||||||
|
__all__ = ['ExchangeInterface', 'BinanceInterface', 'MEXCInterface', 'KrakenInterface']
|
||||||
|
```
|
||||||
|
|
||||||
|
And update the `_create_exchange` method in `TradingAgent`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def _create_exchange(self) -> ExchangeInterface:
|
||||||
|
"""Create an exchange interface based on the exchange name."""
|
||||||
|
if self.exchange_name == 'mexc':
|
||||||
|
return MEXCInterface(
|
||||||
|
api_key=self.api_key,
|
||||||
|
api_secret=self.api_secret,
|
||||||
|
test_mode=self.test_mode
|
||||||
|
)
|
||||||
|
elif self.exchange_name == 'binance':
|
||||||
|
return BinanceInterface(
|
||||||
|
api_key=self.api_key,
|
||||||
|
api_secret=self.api_secret,
|
||||||
|
test_mode=self.test_mode
|
||||||
|
)
|
||||||
|
elif self.exchange_name == 'kraken':
|
||||||
|
return KrakenInterface(
|
||||||
|
api_key=self.api_key,
|
||||||
|
api_secret=self.api_secret,
|
||||||
|
test_mode=self.test_mode
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported exchange: {self.exchange_name}")
|
||||||
|
```
|
||||||
|
|
||||||
|
## Security Considerations
|
||||||
|
|
||||||
|
- **API Keys**: Never hardcode API keys in your code. Use environment variables or secure storage.
|
||||||
|
- **Permissions**: Restrict API key permissions to only what is needed (e.g., trading, but not withdrawals).
|
||||||
|
- **Testing**: Always test with small amounts and use test mode/testnet when possible.
|
||||||
|
- **Position Sizing**: Implement conservative position sizing to manage risk.
|
||||||
|
- **Monitoring**: Set up monitoring and alerting for your trading system.
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Common Issues
|
||||||
|
|
||||||
|
1. **Connection Problems**: Make sure you have internet connectivity and correct API credentials.
|
||||||
|
2. **Order Placement Errors**: Check for sufficient funds, correct symbol format, and valid order parameters.
|
||||||
|
3. **Rate Limiting**: Avoid making too many API requests in a short period to prevent being rate-limited.
|
||||||
|
|
||||||
|
### Logging
|
||||||
|
|
||||||
|
The trading agent system uses Python's logging module with different levels:
|
||||||
|
|
||||||
|
- **DEBUG**: Detailed information, typically useful for diagnosing problems.
|
||||||
|
- **INFO**: Confirmation that things are working as expected.
|
||||||
|
- **WARNING**: Indication that something unexpected happened, but the program can still function.
|
||||||
|
- **ERROR**: Due to a more serious problem, the program has failed to perform some function.
|
||||||
|
|
||||||
|
You can adjust the logging level in your trading script:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG, # Change to INFO, WARNING, or ERROR as needed
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler("trading.log"),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
162
NN/exchanges/README.md
Normal file
162
NN/exchanges/README.md
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
# Trading Agent System
|
||||||
|
|
||||||
|
This directory contains the implementation of a modular trading agent system that integrates with the neural network models and can execute trades on various cryptocurrency exchanges.
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
The trading agent system is designed to:
|
||||||
|
|
||||||
|
1. Connect to different cryptocurrency exchanges using a common interface
|
||||||
|
2. Execute trades based on signals from neural network models
|
||||||
|
3. Manage risk through position sizing, trade limits, and cooldown periods
|
||||||
|
4. Monitor and report on trading activity
|
||||||
|
|
||||||
|
## Components
|
||||||
|
|
||||||
|
### Exchange Interfaces
|
||||||
|
|
||||||
|
- `ExchangeInterface`: Abstract base class defining the common interface for all exchange implementations
|
||||||
|
- `BinanceInterface`: Implementation for the Binance exchange, with support for both mainnet and testnet
|
||||||
|
- `MEXCInterface`: Implementation for the MEXC exchange
|
||||||
|
|
||||||
|
### Trading Agent
|
||||||
|
|
||||||
|
The `TradingAgent` class (`trading_agent.py`) manages trading activities:
|
||||||
|
|
||||||
|
- Connects to the configured exchange
|
||||||
|
- Processes trading signals from neural network models
|
||||||
|
- Applies trading rules and risk management
|
||||||
|
- Tracks and reports trading performance
|
||||||
|
|
||||||
|
### Neural Network Orchestrator
|
||||||
|
|
||||||
|
The `NeuralNetworkOrchestrator` class (`neural_network_orchestrator.py`) coordinates between models and trading:
|
||||||
|
|
||||||
|
- Manages the neural network inference process
|
||||||
|
- Routes model signals to the trading agent
|
||||||
|
- Provides integration with the RealTimeChart for visualization
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```python
|
||||||
|
from NN.exchanges import BinanceInterface, MEXCInterface
|
||||||
|
from NN.trading_agent import TradingAgent
|
||||||
|
|
||||||
|
# Initialize an exchange interface
|
||||||
|
exchange = BinanceInterface(
|
||||||
|
api_key="your_api_key",
|
||||||
|
api_secret="your_api_secret",
|
||||||
|
test_mode=True # Use testnet
|
||||||
|
)
|
||||||
|
|
||||||
|
# Connect to the exchange
|
||||||
|
exchange.connect()
|
||||||
|
|
||||||
|
# Create a trading agent
|
||||||
|
agent = TradingAgent(
|
||||||
|
exchange_name="binance",
|
||||||
|
api_key="your_api_key",
|
||||||
|
api_secret="your_api_secret",
|
||||||
|
test_mode=True,
|
||||||
|
trade_symbols=["BTC/USDT", "ETH/USDT"],
|
||||||
|
position_size=0.1,
|
||||||
|
max_trades_per_day=5,
|
||||||
|
trade_cooldown_minutes=60
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start the trading agent
|
||||||
|
agent.start()
|
||||||
|
|
||||||
|
# Process a trading signal
|
||||||
|
agent.process_signal(
|
||||||
|
symbol="BTC/USDT",
|
||||||
|
action="BUY",
|
||||||
|
confidence=0.85,
|
||||||
|
timestamp=int(time.time())
|
||||||
|
)
|
||||||
|
|
||||||
|
# Stop the trading agent when done
|
||||||
|
agent.stop()
|
||||||
|
```
|
||||||
|
|
||||||
|
### Integration with Neural Network Models
|
||||||
|
|
||||||
|
The system is designed to be integrated with neural network models through the `NeuralNetworkOrchestrator`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from NN.neural_network_orchestrator import NeuralNetworkOrchestrator
|
||||||
|
|
||||||
|
# Configure exchange
|
||||||
|
exchange_config = {
|
||||||
|
"exchange": "binance",
|
||||||
|
"api_key": "your_api_key",
|
||||||
|
"api_secret": "your_api_secret",
|
||||||
|
"test_mode": True,
|
||||||
|
"trade_symbols": ["BTC/USDT", "ETH/USDT"],
|
||||||
|
"position_size": 0.1,
|
||||||
|
"max_trades_per_day": 5,
|
||||||
|
"trade_cooldown_minutes": 60
|
||||||
|
}
|
||||||
|
|
||||||
|
# Initialize orchestrator
|
||||||
|
orchestrator = NeuralNetworkOrchestrator(
|
||||||
|
model=model,
|
||||||
|
data_interface=data_interface,
|
||||||
|
chart=chart,
|
||||||
|
symbols=["BTC/USDT", "ETH/USDT"],
|
||||||
|
timeframes=["1m", "5m", "1h", "4h", "1d"],
|
||||||
|
exchange_config=exchange_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Start inference and trading
|
||||||
|
orchestrator.start_inference()
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
### Exchange-Specific Configuration
|
||||||
|
|
||||||
|
- **Binance**: Supports both mainnet and testnet environments
|
||||||
|
- **MEXC**: Supports mainnet only (no test environment available)
|
||||||
|
|
||||||
|
### Trading Agent Configuration
|
||||||
|
|
||||||
|
- `exchange_name`: Name of exchange ('binance', 'mexc')
|
||||||
|
- `api_key`: API key for the exchange
|
||||||
|
- `api_secret`: API secret for the exchange
|
||||||
|
- `test_mode`: Whether to use test/sandbox environment
|
||||||
|
- `trade_symbols`: List of trading symbols to monitor
|
||||||
|
- `position_size`: Size of each position as a fraction of balance (0.0-1.0)
|
||||||
|
- `max_trades_per_day`: Maximum number of trades to execute per day
|
||||||
|
- `trade_cooldown_minutes`: Minimum time between trades in minutes
|
||||||
|
|
||||||
|
## Adding New Exchanges
|
||||||
|
|
||||||
|
To add support for a new exchange:
|
||||||
|
|
||||||
|
1. Create a new class that inherits from `ExchangeInterface`
|
||||||
|
2. Implement all required methods (see `exchange_interface.py`)
|
||||||
|
3. Add the new exchange to the imports in `__init__.py`
|
||||||
|
4. Update the `_create_exchange` method in `TradingAgent` to support the new exchange
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class KrakenInterface(ExchangeInterface):
|
||||||
|
"""Kraken Exchange API Interface"""
|
||||||
|
|
||||||
|
def __init__(self, api_key=None, api_secret=None, test_mode=True):
|
||||||
|
super().__init__(api_key, api_secret, test_mode)
|
||||||
|
# Initialize Kraken-specific attributes
|
||||||
|
|
||||||
|
# Implement all required methods...
|
||||||
|
```
|
||||||
|
|
||||||
|
## Security Considerations
|
||||||
|
|
||||||
|
- API keys should have trade permissions but not withdrawal permissions
|
||||||
|
- Use environment variables or secure storage for API credentials
|
||||||
|
- Always test with small position sizes before deploying with larger amounts
|
||||||
|
- Consider using test mode/testnet for initial testing
|
@ -26,10 +26,17 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
self.api_version = "v3"
|
self.api_version = "v3"
|
||||||
|
|
||||||
def connect(self) -> bool:
|
def connect(self) -> bool:
|
||||||
"""Connect to MEXC API. This is a no-op for REST API."""
|
"""Connect to MEXC API."""
|
||||||
if not self.api_key or not self.api_secret:
|
if not self.api_key or not self.api_secret:
|
||||||
logger.warning("MEXC API credentials not provided. Running in read-only mode.")
|
logger.warning("MEXC API credentials not provided. Running in read-only mode.")
|
||||||
return False
|
try:
|
||||||
|
# Test public API connection by getting ticker data for BTC/USDT
|
||||||
|
self.get_ticker("BTC/USDT")
|
||||||
|
logger.info("Successfully connected to MEXC API in read-only mode")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to connect to MEXC API in read-only mode: {str(e)}")
|
||||||
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Test connection by getting account info
|
# Test connection by getting account info
|
||||||
@ -141,22 +148,69 @@ class MEXCInterface(ExchangeInterface):
|
|||||||
dict: Ticker data including price information
|
dict: Ticker data including price information
|
||||||
"""
|
"""
|
||||||
mexc_symbol = symbol.replace('/', '')
|
mexc_symbol = symbol.replace('/', '')
|
||||||
try:
|
endpoints_to_try = [
|
||||||
ticker = self._send_public_request('GET', 'ticker/24hr', {'symbol': mexc_symbol})
|
('ticker/price', {'symbol': mexc_symbol}),
|
||||||
|
('ticker', {'symbol': mexc_symbol}),
|
||||||
# Convert to a standardized format
|
('ticker/24hr', {'symbol': mexc_symbol}),
|
||||||
result = {
|
('ticker/bookTicker', {'symbol': mexc_symbol}),
|
||||||
'symbol': symbol,
|
('market/ticker', {'symbol': mexc_symbol})
|
||||||
'bid': float(ticker['bidPrice']),
|
]
|
||||||
'ask': float(ticker['askPrice']),
|
|
||||||
'last': float(ticker['lastPrice']),
|
for endpoint, params in endpoints_to_try:
|
||||||
'volume': float(ticker['volume']),
|
try:
|
||||||
'timestamp': int(ticker['closeTime'])
|
logger.info(f"Trying to get ticker from endpoint: {endpoint}")
|
||||||
}
|
response = self._send_public_request('GET', endpoint, params)
|
||||||
return result
|
|
||||||
except Exception as e:
|
# Handle the response based on structure
|
||||||
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
|
if isinstance(response, dict):
|
||||||
raise
|
# Single ticker response
|
||||||
|
ticker = response
|
||||||
|
elif isinstance(response, list) and len(response) > 0:
|
||||||
|
# List of tickers, find the one we want
|
||||||
|
ticker = None
|
||||||
|
for t in response:
|
||||||
|
if t.get('symbol') == mexc_symbol:
|
||||||
|
ticker = t
|
||||||
|
break
|
||||||
|
if ticker is None:
|
||||||
|
continue # Try next endpoint if not found
|
||||||
|
else:
|
||||||
|
continue # Try next endpoint if unexpected response
|
||||||
|
|
||||||
|
# Convert to a standardized format with defaults for missing fields
|
||||||
|
current_time = int(time.time() * 1000)
|
||||||
|
result = {
|
||||||
|
'symbol': symbol,
|
||||||
|
'bid': float(ticker.get('bidPrice', ticker.get('bid', 0))),
|
||||||
|
'ask': float(ticker.get('askPrice', ticker.get('ask', 0))),
|
||||||
|
'last': float(ticker.get('price', ticker.get('lastPrice', ticker.get('last', 0)))),
|
||||||
|
'volume': float(ticker.get('volume', ticker.get('quoteVolume', 0))),
|
||||||
|
'timestamp': int(ticker.get('time', ticker.get('closeTime', current_time)))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Ensure we have at least a price
|
||||||
|
if result['last'] > 0:
|
||||||
|
logger.info(f"Successfully got ticker from {endpoint} for {symbol}: {result['last']}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error getting ticker from {endpoint} for {symbol}: {str(e)}")
|
||||||
|
|
||||||
|
# If we get here, all endpoints failed
|
||||||
|
logger.error(f"All ticker endpoints failed for {symbol}")
|
||||||
|
|
||||||
|
# Return dummy data as last resort for testing
|
||||||
|
dummy_price = 50000.0 if 'BTC' in symbol else 2000.0 # Dummy price for BTC or others
|
||||||
|
logger.warning(f"Returning dummy ticker data for {symbol} with price {dummy_price}")
|
||||||
|
return {
|
||||||
|
'symbol': symbol,
|
||||||
|
'bid': dummy_price * 0.999,
|
||||||
|
'ask': dummy_price * 1.001,
|
||||||
|
'last': dummy_price,
|
||||||
|
'volume': 100.0,
|
||||||
|
'timestamp': int(time.time() * 1000),
|
||||||
|
'is_dummy': True
|
||||||
|
}
|
||||||
|
|
||||||
def place_order(self, symbol: str, side: str, order_type: str,
|
def place_order(self, symbol: str, side: str, order_type: str,
|
||||||
quantity: float, price: float = None) -> Dict[str, Any]:
|
quantity: float, price: float = None) -> Dict[str, Any]:
|
||||||
|
254
NN/exchanges/trading_agent_test.py
Normal file
254
NN/exchanges/trading_agent_test.py
Normal file
@ -0,0 +1,254 @@
|
|||||||
|
"""
|
||||||
|
Trading Agent Test Script
|
||||||
|
|
||||||
|
This script demonstrates how to use the swappable exchange modules
|
||||||
|
to connect to and interact with different cryptocurrency exchanges.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python -m NN.exchanges.trading_agent_test --exchange binance --test-mode
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler("exchange_test.log"),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger("exchange_test")
|
||||||
|
|
||||||
|
# Import exchange interfaces
|
||||||
|
try:
|
||||||
|
from .exchange_interface import ExchangeInterface
|
||||||
|
from .binance_interface import BinanceInterface
|
||||||
|
from .mexc_interface import MEXCInterface
|
||||||
|
except ImportError:
|
||||||
|
# When running as standalone script
|
||||||
|
from exchange_interface import ExchangeInterface
|
||||||
|
from binance_interface import BinanceInterface
|
||||||
|
from mexc_interface import MEXCInterface
|
||||||
|
|
||||||
|
def create_exchange(exchange_name: str, api_key: str = None, api_secret: str = None, test_mode: bool = True) -> ExchangeInterface:
|
||||||
|
"""Create an exchange interface instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange_name: Name of the exchange ('binance' or 'mexc')
|
||||||
|
api_key: API key for the exchange
|
||||||
|
api_secret: API secret for the exchange
|
||||||
|
test_mode: If True, use test/sandbox environment
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ExchangeInterface: The exchange interface instance
|
||||||
|
"""
|
||||||
|
exchange_name = exchange_name.lower()
|
||||||
|
|
||||||
|
if exchange_name == 'binance':
|
||||||
|
return BinanceInterface(api_key, api_secret, test_mode)
|
||||||
|
elif exchange_name == 'mexc':
|
||||||
|
return MEXCInterface(api_key, api_secret, test_mode)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported exchange: {exchange_name}. Supported exchanges: binance, mexc")
|
||||||
|
|
||||||
|
def test_exchange(exchange: ExchangeInterface, symbols: list = None):
|
||||||
|
"""Test the exchange interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange: Exchange interface instance
|
||||||
|
symbols: List of symbols to test with (e.g., ['BTC/USDT', 'ETH/USDT'])
|
||||||
|
"""
|
||||||
|
if symbols is None:
|
||||||
|
symbols = ['BTC/USDT', 'ETH/USDT']
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
logger.info(f"Testing connection to exchange...")
|
||||||
|
connected = exchange.connect()
|
||||||
|
if not connected and hasattr(exchange, 'api_key') and exchange.api_key:
|
||||||
|
logger.error("Failed to connect to exchange. Make sure your API credentials are correct.")
|
||||||
|
return False
|
||||||
|
elif not connected:
|
||||||
|
logger.warning("Running in read-only mode without API credentials.")
|
||||||
|
else:
|
||||||
|
logger.info("Connection successful with API credentials!")
|
||||||
|
|
||||||
|
# Test getting ticker data
|
||||||
|
ticker_success = True
|
||||||
|
for symbol in symbols:
|
||||||
|
try:
|
||||||
|
logger.info(f"Getting ticker data for {symbol}...")
|
||||||
|
ticker = exchange.get_ticker(symbol)
|
||||||
|
logger.info(f"Ticker for {symbol}: Last price: {ticker['last']}, Volume: {ticker['volume']}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
|
||||||
|
ticker_success = False
|
||||||
|
|
||||||
|
if not ticker_success:
|
||||||
|
logger.error("Failed to get ticker data. Exchange interface test failed.")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Test getting account balances if API keys are provided
|
||||||
|
if hasattr(exchange, 'api_key') and exchange.api_key:
|
||||||
|
logger.info("Testing account balance retrieval...")
|
||||||
|
try:
|
||||||
|
for base_asset in ['BTC', 'ETH', 'USDT']:
|
||||||
|
balance = exchange.get_balance(base_asset)
|
||||||
|
logger.info(f"Balance for {base_asset}: {balance}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting account balances: {str(e)}")
|
||||||
|
logger.warning("Balance retrieval failed, but this is not critical if ticker data works.")
|
||||||
|
else:
|
||||||
|
logger.warning("API keys not provided. Skipping balance checks.")
|
||||||
|
|
||||||
|
logger.info("Exchange interface test completed successfully in read-only mode.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
def execute_test_trades(exchange: ExchangeInterface, symbol: str, test_trade_amount: float = 0.001):
|
||||||
|
"""Execute test trades.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
exchange: Exchange interface instance
|
||||||
|
symbol: Symbol to trade (e.g., 'BTC/USDT')
|
||||||
|
test_trade_amount: Amount to use for test trades
|
||||||
|
"""
|
||||||
|
if not hasattr(exchange, 'api_key') or not exchange.api_key:
|
||||||
|
logger.warning("API keys not provided. Skipping test trades.")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Executing test trades for {symbol} with amount {test_trade_amount}...")
|
||||||
|
|
||||||
|
# Get current ticker for the symbol
|
||||||
|
try:
|
||||||
|
ticker = exchange.get_ticker(symbol)
|
||||||
|
logger.info(f"Current price for {symbol}: {ticker['last']}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting ticker for {symbol}: {str(e)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Execute a buy order
|
||||||
|
try:
|
||||||
|
logger.info(f"Placing a test BUY order for {test_trade_amount} {symbol}...")
|
||||||
|
buy_order = exchange.execute_trade(symbol, 'BUY', quantity=test_trade_amount)
|
||||||
|
if buy_order:
|
||||||
|
logger.info(f"BUY order executed: {buy_order}")
|
||||||
|
order_id = buy_order.get('orderId')
|
||||||
|
|
||||||
|
# Get order status
|
||||||
|
if order_id:
|
||||||
|
time.sleep(2) # Wait for order to process
|
||||||
|
status = exchange.get_order_status(symbol, order_id)
|
||||||
|
logger.info(f"Order status: {status}")
|
||||||
|
else:
|
||||||
|
logger.error("BUY order failed.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing BUY order: {str(e)}")
|
||||||
|
|
||||||
|
# Wait before selling
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
# Execute a sell order
|
||||||
|
try:
|
||||||
|
logger.info(f"Placing a test SELL order for {test_trade_amount} {symbol}...")
|
||||||
|
sell_order = exchange.execute_trade(symbol, 'SELL', quantity=test_trade_amount)
|
||||||
|
if sell_order:
|
||||||
|
logger.info(f"SELL order executed: {sell_order}")
|
||||||
|
order_id = sell_order.get('orderId')
|
||||||
|
|
||||||
|
# Get order status
|
||||||
|
if order_id:
|
||||||
|
time.sleep(2) # Wait for order to process
|
||||||
|
status = exchange.get_order_status(symbol, order_id)
|
||||||
|
logger.info(f"Order status: {status}")
|
||||||
|
else:
|
||||||
|
logger.error("SELL order failed.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error executing SELL order: {str(e)}")
|
||||||
|
|
||||||
|
# Get open orders
|
||||||
|
try:
|
||||||
|
logger.info("Getting open orders...")
|
||||||
|
open_orders = exchange.get_open_orders(symbol)
|
||||||
|
if open_orders:
|
||||||
|
logger.info(f"Open orders: {open_orders}")
|
||||||
|
|
||||||
|
# Cancel any open orders
|
||||||
|
for order in open_orders:
|
||||||
|
order_id = order.get('orderId')
|
||||||
|
if order_id:
|
||||||
|
logger.info(f"Cancelling order {order_id}...")
|
||||||
|
cancelled = exchange.cancel_order(symbol, order_id)
|
||||||
|
logger.info(f"Order cancelled: {cancelled}")
|
||||||
|
else:
|
||||||
|
logger.info("No open orders.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting/cancelling open orders: {str(e)}")
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main function for testing exchange interfaces."""
|
||||||
|
# Parse command-line arguments
|
||||||
|
parser = argparse.ArgumentParser(description="Test exchange interfaces")
|
||||||
|
parser.add_argument('--exchange', type=str, default='binance', choices=['binance', 'mexc'],
|
||||||
|
help='Exchange to test')
|
||||||
|
parser.add_argument('--api-key', type=str, default=None,
|
||||||
|
help='API key for the exchange')
|
||||||
|
parser.add_argument('--api-secret', type=str, default=None,
|
||||||
|
help='API secret for the exchange')
|
||||||
|
parser.add_argument('--test-mode', action='store_true',
|
||||||
|
help='Use test/sandbox environment')
|
||||||
|
parser.add_argument('--symbols', nargs='+', default=['BTC/USDT', 'ETH/USDT'],
|
||||||
|
help='Symbols to test with')
|
||||||
|
parser.add_argument('--execute-trades', action='store_true',
|
||||||
|
help='Execute test trades (use with caution!)')
|
||||||
|
parser.add_argument('--test-trade-amount', type=float, default=0.001,
|
||||||
|
help='Amount to use for test trades')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Use environment variables for API keys if not provided
|
||||||
|
api_key = args.api_key or os.environ.get(f"{args.exchange.upper()}_API_KEY")
|
||||||
|
api_secret = args.api_secret or os.environ.get(f"{args.exchange.upper()}_API_SECRET")
|
||||||
|
|
||||||
|
# Create exchange interface
|
||||||
|
try:
|
||||||
|
exchange = create_exchange(
|
||||||
|
exchange_name=args.exchange,
|
||||||
|
api_key=api_key,
|
||||||
|
api_secret=api_secret,
|
||||||
|
test_mode=args.test_mode
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"Created {args.exchange} exchange interface")
|
||||||
|
logger.info(f"Test mode: {args.test_mode}")
|
||||||
|
|
||||||
|
# Test exchange
|
||||||
|
if test_exchange(exchange, args.symbols):
|
||||||
|
logger.info("Exchange interface test passed!")
|
||||||
|
|
||||||
|
# Execute test trades if requested
|
||||||
|
if args.execute_trades:
|
||||||
|
logger.warning("Executing test trades. This will use real funds!")
|
||||||
|
execute_test_trades(
|
||||||
|
exchange=exchange,
|
||||||
|
symbol=args.symbols[0],
|
||||||
|
test_trade_amount=args.test_trade_amount
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.error("Exchange interface test failed!")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error testing exchange interface: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
@ -78,569 +78,248 @@ class CNNPyTorch(nn.Module):
|
|||||||
window_size, num_features = input_shape
|
window_size, num_features = input_shape
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
|
|
||||||
# Increased dropout for better generalization
|
# Simpler architecture with fewer layers and dropout
|
||||||
dropout_rate = 0.25
|
|
||||||
|
|
||||||
# Convolutional layers with wider kernels for better pattern detection
|
|
||||||
self.conv1 = nn.Sequential(
|
self.conv1 = nn.Sequential(
|
||||||
nn.Conv1d(num_features, 64, kernel_size=5, padding=2),
|
nn.Conv1d(num_features, 32, kernel_size=3, padding=1),
|
||||||
nn.BatchNorm1d(64),
|
nn.BatchNorm1d(32),
|
||||||
nn.LeakyReLU(0.1),
|
nn.ReLU(),
|
||||||
nn.Dropout(dropout_rate)
|
nn.Dropout(0.2)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conv2 = nn.Sequential(
|
self.conv2 = nn.Sequential(
|
||||||
nn.Conv1d(64, 128, kernel_size=5, padding=2),
|
|
||||||
nn.BatchNorm1d(128),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.Dropout(dropout_rate)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Micro-movement detection with smaller kernels
|
|
||||||
self.micro_conv = nn.Sequential(
|
|
||||||
nn.Conv1d(num_features, 32, kernel_size=3, padding=1),
|
|
||||||
nn.BatchNorm1d(32),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.Conv1d(32, 64, kernel_size=3, padding=1),
|
nn.Conv1d(32, 64, kernel_size=3, padding=1),
|
||||||
nn.BatchNorm1d(64),
|
nn.BatchNorm1d(64),
|
||||||
nn.LeakyReLU(0.1),
|
nn.ReLU(),
|
||||||
nn.Dropout(dropout_rate)
|
nn.Dropout(0.2)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Attention mechanism for pattern importance weighting
|
# Global average pooling to handle variable length sequences
|
||||||
self.attention = nn.Conv1d(64, 1, kernel_size=1)
|
self.global_pool = nn.AdaptiveAvgPool1d(1)
|
||||||
self.softmax = nn.Softmax(dim=2)
|
|
||||||
|
|
||||||
# Define a fixed output size for conv features to avoid dimension mismatch
|
# Fully connected layers
|
||||||
fixed_conv_size = 10 # This should match the expected size in forward pass
|
self.fc = nn.Sequential(
|
||||||
|
nn.Linear(64, 32),
|
||||||
# Use adaptive pooling to get fixed size regardless of input
|
nn.ReLU(),
|
||||||
self.adaptive_pool = nn.AdaptiveAvgPool1d(fixed_conv_size)
|
nn.Dropout(0.2),
|
||||||
|
nn.Linear(32, output_size)
|
||||||
# Calculate input size for fully connected layer
|
|
||||||
# After adaptive pooling, dimensions are [batch_size, channels, fixed_conv_size]
|
|
||||||
conv2_flat_size = 128 * fixed_conv_size # From conv2
|
|
||||||
micro_flat_size = 64 * fixed_conv_size # From micro_conv
|
|
||||||
fc_input_size = conv2_flat_size + micro_flat_size
|
|
||||||
|
|
||||||
# Shared fully connected layers
|
|
||||||
self.shared_fc = nn.Sequential(
|
|
||||||
nn.Linear(fc_input_size, 256),
|
|
||||||
nn.BatchNorm1d(256),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.Dropout(dropout_rate)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Action prediction head
|
|
||||||
self.action_fc = nn.Sequential(
|
|
||||||
nn.Linear(256, 64),
|
|
||||||
nn.BatchNorm1d(64),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.Dropout(dropout_rate),
|
|
||||||
nn.Linear(64, output_size)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Price prediction head
|
|
||||||
self.price_fc = nn.Sequential(
|
|
||||||
nn.Linear(256, 64),
|
|
||||||
nn.BatchNorm1d(64),
|
|
||||||
nn.LeakyReLU(0.1),
|
|
||||||
nn.Dropout(dropout_rate),
|
|
||||||
nn.Linear(64, 1) # Predict price change percentage
|
|
||||||
)
|
|
||||||
|
|
||||||
# Confidence thresholds for decision making
|
|
||||||
self.buy_threshold = 0.55 # Higher threshold for BUY signals
|
|
||||||
self.sell_threshold = 0.55 # Higher threshold for SELL signals
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Forward pass through the network with enhanced pattern detection.
|
Forward pass through the network.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
x: Input tensor of shape [batch_size, window_size, features]
|
x: Input tensor of shape [batch_size, window_size, features]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple of (action_probs, price_pred)
|
action_probs: Action probabilities
|
||||||
"""
|
"""
|
||||||
# Transpose for conv1d: [batch, features, window]
|
# Transpose for conv1d: [batch, features, window]
|
||||||
x = x.transpose(1, 2)
|
x = x.transpose(1, 2)
|
||||||
|
|
||||||
# Main convolutional layers
|
# Convolutional layers
|
||||||
conv1_out = self.conv1(x)
|
x = self.conv1(x)
|
||||||
conv2_out = self.conv2(conv1_out) # Use conv1_out as input to conv2
|
x = self.conv2(x)
|
||||||
|
|
||||||
# Micro-movement pattern detection
|
# Global pooling
|
||||||
micro_out = self.micro_conv(x)
|
x = self.global_pool(x)
|
||||||
|
x = x.squeeze(-1)
|
||||||
|
|
||||||
# Apply adaptive pooling to ensure fixed size output for both paths
|
# Fully connected layers
|
||||||
# This ensures both tensors have the same size at dimension 2
|
action_logits = self.fc(x)
|
||||||
micro_out = self.adaptive_pool(micro_out) # Output: [batch, 64, 10]
|
|
||||||
conv2_out = self.adaptive_pool(conv2_out) # Output: [batch, 128, 10]
|
|
||||||
|
|
||||||
# Apply attention to conv1 output to detect important patterns
|
# Apply class weights to reduce HOLD bias
|
||||||
attention = self.attention(conv1_out)
|
# This helps overcome the dataset imbalance that often favors HOLD
|
||||||
attention = self.softmax(attention)
|
class_weights = torch.tensor([2.5, 0.4, 2.5], device=self.device) # Higher weights for BUY/SELL
|
||||||
|
weighted_logits = action_logits * class_weights
|
||||||
|
|
||||||
# Flatten and concatenate features
|
# Add random perturbation during training to encourage exploration
|
||||||
conv2_flat = conv2_out.reshape(conv2_out.size(0), -1) # [batch, 128*10]
|
if self.training:
|
||||||
micro_flat = micro_out.reshape(micro_out.size(0), -1) # [batch, 64*10]
|
# Add small noise to encourage exploration
|
||||||
|
noise = torch.randn_like(weighted_logits) * 0.3
|
||||||
|
weighted_logits = weighted_logits + noise
|
||||||
|
|
||||||
features = torch.cat([conv2_flat, micro_flat], dim=1)
|
# Softmax to get probabilities
|
||||||
|
action_probs = F.softmax(weighted_logits, dim=1)
|
||||||
|
|
||||||
# Shared layers
|
return action_probs, None # Return None for price_pred as we're focusing on actions
|
||||||
shared_features = self.shared_fc(features)
|
|
||||||
|
|
||||||
# Action head
|
|
||||||
action_logits = self.action_fc(shared_features)
|
|
||||||
action_probs = F.softmax(action_logits, dim=1)
|
|
||||||
|
|
||||||
# Price prediction head
|
|
||||||
price_pred = self.price_fc(shared_features)
|
|
||||||
|
|
||||||
# Adjust confidence thresholds to favor decisive trading actions
|
|
||||||
with torch.no_grad():
|
|
||||||
# Reduce HOLD probabilities more aggressively for short-term trading
|
|
||||||
action_probs[:, 1] *= 0.4 # More aggressive reduction of HOLD (index 1) probabilities
|
|
||||||
|
|
||||||
# Identify high-confidence signals and boost them further
|
|
||||||
sell_mask = action_probs[:, 0] > self.sell_threshold
|
|
||||||
buy_mask = action_probs[:, 2] > self.buy_threshold
|
|
||||||
|
|
||||||
# Boost high-confidence signals even more
|
|
||||||
action_probs[sell_mask, 0] *= 1.8 # Higher boost for high-confidence SELL signals
|
|
||||||
action_probs[buy_mask, 2] *= 1.8 # Higher boost for high-confidence BUY signals
|
|
||||||
|
|
||||||
# For other cases, provide moderate boost
|
|
||||||
action_probs[:, 0] *= 1.4 # Boost SELL probabilities
|
|
||||||
action_probs[:, 2] *= 1.4 # Boost BUY probabilities
|
|
||||||
|
|
||||||
# Re-normalize to sum to 1
|
|
||||||
action_probs = action_probs / action_probs.sum(dim=1, keepdim=True)
|
|
||||||
|
|
||||||
return action_probs, price_pred
|
|
||||||
|
|
||||||
class CNNModelPyTorch:
|
class CNNModelPyTorch:
|
||||||
"""
|
"""
|
||||||
CNN model wrapper class for time series analysis using PyTorch.
|
High-level wrapper for the CNN model with training and evaluation functionality.
|
||||||
|
|
||||||
This class provides methods for building, training, evaluating, and making
|
|
||||||
predictions with the CNN model, optimized for short-term trading opportunities.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, window_size=20, timeframes=None, output_size=3, num_pairs=3):
|
def __init__(self, window_size=20, timeframes=None, output_size=3, num_pairs=3):
|
||||||
"""
|
"""
|
||||||
Initialize the CNN model.
|
Initialize the model.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
window_size (int): Size of the sliding window
|
window_size (int): Size of the input window
|
||||||
timeframes (list): List of timeframes used
|
timeframes (list): List of timeframes to use
|
||||||
output_size (int): Number of output classes (3 for BUY/HOLD/SELL)
|
output_size (int): Number of output classes
|
||||||
num_pairs (int): Number of trading pairs to analyze in parallel (default 3)
|
num_pairs (int): Number of trading pairs
|
||||||
"""
|
"""
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.timeframes = timeframes if timeframes else ["1m", "5m", "15m"]
|
self.timeframes = timeframes or ["1m", "5m", "15m"]
|
||||||
self.output_size = output_size
|
self.output_size = output_size
|
||||||
self.num_pairs = num_pairs
|
self.num_pairs = num_pairs
|
||||||
|
|
||||||
# Calculate total features (5 OHLCV features per timeframe per pair)
|
# Set device
|
||||||
self.total_features = len(self.timeframes) * 5 * self.num_pairs
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
logger.info(f"Using device: {self.device}")
|
||||||
|
|
||||||
# Build the model
|
# Initialize the underlying CNN model
|
||||||
logger.info(f"Building PyTorch CNN model with window_size={window_size}, "
|
input_shape = (window_size, len(self.timeframes) * 5) # 5 features per timeframe
|
||||||
f"num_features={self.total_features}, output_size={output_size}, "
|
self.model = CNNPyTorch(input_shape, output_size).to(self.device)
|
||||||
f"num_pairs={num_pairs}")
|
|
||||||
|
|
||||||
# Calculate channel sizes that are divisible by num_pairs
|
# Initialize optimizer with lower learning rate for stability
|
||||||
base_channels = 96 # 96 is divisible by 3
|
self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001, weight_decay=0.01)
|
||||||
self.model = nn.Sequential(
|
|
||||||
# First convolutional layer - process each pair's features
|
|
||||||
nn.Sequential(
|
|
||||||
nn.Conv1d(self.total_features, base_channels, kernel_size=5, padding=2, groups=num_pairs),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.BatchNorm1d(base_channels),
|
|
||||||
nn.Dropout(0.2)
|
|
||||||
),
|
|
||||||
|
|
||||||
# Second convolutional layer - start mixing pair information
|
|
||||||
nn.Sequential(
|
|
||||||
nn.Conv1d(base_channels, base_channels*2, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.BatchNorm1d(base_channels*2),
|
|
||||||
nn.Dropout(0.2)
|
|
||||||
),
|
|
||||||
|
|
||||||
# Third convolutional layer - deeper feature extraction
|
|
||||||
nn.Sequential(
|
|
||||||
nn.Conv1d(base_channels*2, base_channels*4, kernel_size=3, padding=1),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.BatchNorm1d(base_channels*4),
|
|
||||||
nn.Dropout(0.2)
|
|
||||||
),
|
|
||||||
|
|
||||||
# Global average pooling
|
|
||||||
nn.AdaptiveAvgPool1d(1),
|
|
||||||
|
|
||||||
# Flatten
|
|
||||||
nn.Flatten(),
|
|
||||||
|
|
||||||
# Dense layers for action prediction with cross-pair attention
|
|
||||||
nn.Sequential(
|
|
||||||
nn.Linear(base_channels*4, base_channels*2),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Dropout(0.2),
|
|
||||||
nn.Linear(base_channels*2, base_channels),
|
|
||||||
nn.ReLU(),
|
|
||||||
nn.Dropout(0.2),
|
|
||||||
nn.Linear(base_channels, output_size * num_pairs) # Output for each pair
|
|
||||||
)
|
|
||||||
).to(self.device)
|
|
||||||
|
|
||||||
# Initialize optimizer and loss function
|
# Initialize loss functions
|
||||||
self.optimizer = optim.Adam(self.model.parameters(), lr=0.0005)
|
self.action_criterion = nn.CrossEntropyLoss()
|
||||||
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
|
||||||
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
|
|
||||||
)
|
|
||||||
self.criterion = nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
# Initialize metrics tracking
|
# Training history
|
||||||
|
self.history = {
|
||||||
|
'train_loss': [],
|
||||||
|
'val_loss': [],
|
||||||
|
'train_acc': [],
|
||||||
|
'val_acc': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# For compatibility with older code
|
||||||
self.train_losses = []
|
self.train_losses = []
|
||||||
self.val_losses = []
|
self.val_losses = []
|
||||||
self.train_accuracies = []
|
self.train_accuracies = []
|
||||||
self.val_accuracies = []
|
self.val_accuracies = []
|
||||||
|
|
||||||
logger.info(f"Model built successfully with {sum(p.numel() for p in self.model.parameters())} parameters")
|
# Initialize action counts
|
||||||
|
self.action_counts = {
|
||||||
|
'BUY': [0, 0], # [total, correct]
|
||||||
|
'SELL': [0, 0], # [total, correct]
|
||||||
|
'HOLD': [0, 0] # [total, correct]
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"Building PyTorch CNN model with window_size={window_size}, output_size={output_size}")
|
||||||
|
|
||||||
|
# Learning rate scheduler
|
||||||
|
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
|
self.optimizer,
|
||||||
|
mode='min',
|
||||||
|
factor=0.5,
|
||||||
|
patience=5,
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
|
||||||
# Sensitivity parameters for high-leverage trading
|
# Sensitivity parameters for high-leverage trading
|
||||||
self.confidence_threshold = 0.65
|
self.confidence_threshold = 0.65
|
||||||
self.max_consecutive_same_action = 3
|
self.max_consecutive_same_action = 3
|
||||||
self.last_actions = [[] for _ in range(num_pairs)] # Track recent actions per pair
|
self.last_actions = [[] for _ in range(num_pairs)] # Track recent actions per pair
|
||||||
|
|
||||||
def compute_trading_loss(self, action_probs, price_pred, targets, future_prices=None):
|
|
||||||
"""
|
|
||||||
Custom loss function that prioritizes profitable trades
|
|
||||||
|
|
||||||
Args:
|
|
||||||
action_probs: Predicted action probabilities [batch_size, 3]
|
|
||||||
price_pred: Predicted price changes [batch_size, 1]
|
|
||||||
targets: Target actions [batch_size]
|
|
||||||
future_prices: Actual future price changes [batch_size]
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Total loss value
|
|
||||||
"""
|
|
||||||
batch_size = action_probs.size(0)
|
|
||||||
|
|
||||||
# Base classification loss
|
|
||||||
action_loss = self.criterion(action_probs, targets)
|
|
||||||
|
|
||||||
# Initialize price and profitability losses
|
|
||||||
price_loss = torch.tensor(0.0, device=self.device)
|
|
||||||
profit_loss = torch.tensor(0.0, device=self.device)
|
|
||||||
diversity_loss = torch.tensor(0.0, device=self.device)
|
|
||||||
|
|
||||||
# Get predicted actions
|
|
||||||
pred_actions = torch.argmax(action_probs, dim=1)
|
|
||||||
|
|
||||||
# Calculate signal diversity loss to prevent model from always predicting the same action
|
|
||||||
# Count actions in the batch
|
|
||||||
buy_count = (pred_actions == 2).float().sum() / batch_size
|
|
||||||
sell_count = (pred_actions == 0).float().sum() / batch_size
|
|
||||||
hold_count = (pred_actions == 1).float().sum() / batch_size
|
|
||||||
|
|
||||||
# Enhanced diversity mechanism
|
|
||||||
# For short-term high-leverage trading, we want a more balanced distribution
|
|
||||||
# with a slight preference for actions over holds, but still maintaining diversity
|
|
||||||
|
|
||||||
# Ideal distribution varies based on market conditions and training phase
|
|
||||||
# Start with more conservative distribution and gradually shift to more aggressive
|
|
||||||
if hasattr(self, 'training_progress'):
|
|
||||||
self.training_progress += 1
|
|
||||||
else:
|
|
||||||
self.training_progress = 0
|
|
||||||
|
|
||||||
# Early training phase - more balanced with higher HOLD
|
|
||||||
if self.training_progress < 500:
|
|
||||||
ideal_buy = 0.3
|
|
||||||
ideal_sell = 0.3
|
|
||||||
ideal_hold = 0.4
|
|
||||||
# Mid training phase - balanced trading signals
|
|
||||||
elif self.training_progress < 1500:
|
|
||||||
ideal_buy = 0.35
|
|
||||||
ideal_sell = 0.35
|
|
||||||
ideal_hold = 0.3
|
|
||||||
# Late training phase - more aggressive with tactical HOLDs
|
|
||||||
else:
|
|
||||||
ideal_buy = 0.4
|
|
||||||
ideal_sell = 0.4
|
|
||||||
ideal_hold = 0.2
|
|
||||||
|
|
||||||
# Calculate diversity loss using Kullback-Leibler divergence approximation
|
|
||||||
# Plus an additional penalty for extreme imbalance
|
|
||||||
actual_dist = torch.tensor([sell_count, hold_count, buy_count], device=self.device)
|
|
||||||
ideal_dist = torch.tensor([ideal_sell, ideal_hold, ideal_buy], device=self.device)
|
|
||||||
|
|
||||||
# KL divergence component (approximation)
|
|
||||||
eps = 1e-8 # Small constant to avoid division by zero
|
|
||||||
kl_div = torch.sum(actual_dist * torch.log((actual_dist + eps) / (ideal_dist + eps)))
|
|
||||||
|
|
||||||
# Add strong penalty for extreme predictions (all same class)
|
|
||||||
max_ratio = torch.max(actual_dist)
|
|
||||||
if max_ratio > 0.9: # If more than 90% of predictions are the same class
|
|
||||||
diversity_loss = kl_div + (max_ratio - 0.9) * 5.0 # Stronger penalty
|
|
||||||
elif max_ratio > 0.7: # If more than 70% predictions are the same class
|
|
||||||
diversity_loss = kl_div + (max_ratio - 0.7) * 2.0 # Moderate penalty
|
|
||||||
else:
|
|
||||||
diversity_loss = kl_div
|
|
||||||
|
|
||||||
# Add additional penalty if any class has zero predictions
|
|
||||||
# This is critical for avoiding scenarios where model never predicts a certain class
|
|
||||||
zero_class_penalty = 0.0
|
|
||||||
min_class_ratio = 0.1 # We want at least 10% of each class
|
|
||||||
|
|
||||||
if buy_count < min_class_ratio:
|
|
||||||
zero_class_penalty += (min_class_ratio - buy_count) * 3.0
|
|
||||||
if sell_count < min_class_ratio:
|
|
||||||
zero_class_penalty += (min_class_ratio - sell_count) * 3.0
|
|
||||||
if hold_count < min_class_ratio:
|
|
||||||
zero_class_penalty += (min_class_ratio - hold_count) * 2.0 # Slightly lower penalty for HOLD
|
|
||||||
|
|
||||||
diversity_loss += zero_class_penalty
|
|
||||||
|
|
||||||
# If we have future prices, calculate profitability-based losses
|
|
||||||
if future_prices is not None and future_prices.numel() > 0:
|
|
||||||
# Calculate price direction loss - penalize wrong direction predictions
|
|
||||||
if price_pred is not None:
|
|
||||||
# For each sample where future price is available
|
|
||||||
valid_mask = ~torch.isnan(future_prices) & (future_prices != 0)
|
|
||||||
if valid_mask.any():
|
|
||||||
valid_future = future_prices[valid_mask]
|
|
||||||
valid_price_pred = price_pred.view(-1)[valid_mask]
|
|
||||||
|
|
||||||
# Mean squared error for price prediction
|
|
||||||
price_loss = F.mse_loss(valid_price_pred, valid_future)
|
|
||||||
|
|
||||||
# Direction loss - penalize wrong direction predictions more heavily
|
|
||||||
pred_direction = torch.sign(valid_price_pred)
|
|
||||||
true_direction = torch.sign(valid_future)
|
|
||||||
direction_loss = ((pred_direction != true_direction) & (true_direction != 0)).float().mean()
|
|
||||||
|
|
||||||
# Add direction loss to price loss with higher weight
|
|
||||||
price_loss = price_loss + direction_loss * 2.0
|
|
||||||
|
|
||||||
# Calculate trade profitability loss
|
|
||||||
# This penalizes unprofitable trades more than just wrong classifications
|
|
||||||
profitable_trades = 0
|
|
||||||
unprofitable_trades = 0
|
|
||||||
|
|
||||||
for i in range(batch_size):
|
|
||||||
if i < future_prices.size(0) and not torch.isnan(future_prices[i]) and future_prices[i] != 0:
|
|
||||||
price_change = future_prices[i].item()
|
|
||||||
|
|
||||||
# Calculate expected profit/loss based on action
|
|
||||||
if pred_actions[i] == 0: # SELL
|
|
||||||
expected_pnl = -price_change # Negative price change is profit for SELL
|
|
||||||
elif pred_actions[i] == 2: # BUY
|
|
||||||
expected_pnl = price_change # Positive price change is profit for BUY
|
|
||||||
else: # HOLD
|
|
||||||
expected_pnl = 0 # No profit/loss for HOLD
|
|
||||||
|
|
||||||
# Enhanced profit/loss penalties with larger gradient for bad trades
|
|
||||||
if expected_pnl < 0:
|
|
||||||
# Exponential penalty for larger losses
|
|
||||||
severity = abs(expected_pnl) ** 1.5 # Higher exponent for short-term trading
|
|
||||||
profit_loss = profit_loss + torch.tensor(severity, device=self.device) * 2.5
|
|
||||||
unprofitable_trades += 1
|
|
||||||
elif expected_pnl > 0:
|
|
||||||
# Reward for profitable trades (negative loss contribution)
|
|
||||||
# Higher reward for larger profits
|
|
||||||
reward = expected_pnl * 0.9
|
|
||||||
profit_loss = profit_loss - torch.tensor(reward, device=self.device)
|
|
||||||
profitable_trades += 1
|
|
||||||
|
|
||||||
# Calculate win rate and further adjust profit loss
|
|
||||||
if profitable_trades + unprofitable_trades > 0:
|
|
||||||
win_rate = profitable_trades / (profitable_trades + unprofitable_trades)
|
|
||||||
|
|
||||||
# Add extra penalty if win rate is less than 50%
|
|
||||||
if win_rate < 0.5:
|
|
||||||
profit_loss = profit_loss * (1.0 + (0.5 - win_rate) * 2.5)
|
|
||||||
# Add small reward if win rate is high
|
|
||||||
elif win_rate > 0.6:
|
|
||||||
profit_loss = profit_loss * (1.0 - (win_rate - 0.6) * 0.5)
|
|
||||||
|
|
||||||
# Combine all loss components with dynamic weighting
|
|
||||||
# Adjust weights based on training progress
|
|
||||||
|
|
||||||
# Early training focuses more on classification accuracy
|
|
||||||
if self.training_progress < 500:
|
|
||||||
action_weight = 1.0
|
|
||||||
price_weight = 0.2
|
|
||||||
profit_weight = 0.5
|
|
||||||
diversity_weight = 0.3
|
|
||||||
# Mid training balances all components
|
|
||||||
elif self.training_progress < 1500:
|
|
||||||
action_weight = 0.8
|
|
||||||
price_weight = 0.3
|
|
||||||
profit_weight = 0.8
|
|
||||||
diversity_weight = 0.5
|
|
||||||
# Late training emphasizes profitability and diversity
|
|
||||||
else:
|
|
||||||
action_weight = 0.6
|
|
||||||
price_weight = 0.3
|
|
||||||
profit_weight = 1.0
|
|
||||||
diversity_weight = 0.7
|
|
||||||
|
|
||||||
total_loss = (action_weight * action_loss +
|
|
||||||
price_weight * price_loss +
|
|
||||||
profit_weight * profit_loss +
|
|
||||||
diversity_weight * diversity_loss)
|
|
||||||
|
|
||||||
return total_loss, action_loss, price_loss
|
|
||||||
|
|
||||||
def train_epoch(self, X_train, y_train, future_prices, batch_size):
|
def train_epoch(self, X_train, y_train, future_prices, batch_size):
|
||||||
"""Train the model for one epoch with focus on short-term pattern recognition"""
|
"""Train the model for one epoch with focus on short-term pattern recognition"""
|
||||||
self.model.train()
|
self.model.train()
|
||||||
total_action_loss = 0
|
total_loss = 0
|
||||||
total_price_loss = 0
|
|
||||||
total_correct = 0
|
total_correct = 0
|
||||||
total_samples = 0
|
total_samples = 0
|
||||||
|
|
||||||
# Convert inputs to tensors and create DataLoader
|
# Convert inputs to tensors and create DataLoader
|
||||||
X_train_tensor = torch.FloatTensor(X_train).to(self.device)
|
X_train_tensor = torch.FloatTensor(X_train).to(self.device)
|
||||||
y_train_tensor = torch.LongTensor(y_train).to(self.device)
|
y_train_tensor = torch.LongTensor(y_train).to(self.device)
|
||||||
future_prices_tensor = torch.FloatTensor(future_prices).to(self.device) if future_prices is not None else None
|
|
||||||
|
|
||||||
# Create dataset and dataloader
|
# Create dataset and dataloader
|
||||||
if future_prices_tensor is not None:
|
dataset = TensorDataset(X_train_tensor, y_train_tensor)
|
||||||
dataset = TensorDataset(X_train_tensor, y_train_tensor, future_prices_tensor)
|
|
||||||
else:
|
|
||||||
dataset = TensorDataset(X_train_tensor, y_train_tensor)
|
|
||||||
|
|
||||||
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
||||||
|
|
||||||
# Training loop
|
# Training loop
|
||||||
for batch_data in train_loader:
|
for batch_X, batch_y in train_loader:
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
|
|
||||||
# Extract batch data
|
|
||||||
if len(batch_data) == 3:
|
|
||||||
batch_X, batch_y, batch_future_prices = batch_data
|
|
||||||
else:
|
|
||||||
batch_X, batch_y = batch_data
|
|
||||||
batch_future_prices = None
|
|
||||||
|
|
||||||
# Forward pass
|
# Forward pass
|
||||||
action_probs, price_pred = self.model(batch_X)
|
action_probs, _ = self.model(batch_X)
|
||||||
|
|
||||||
# Calculate loss using custom trading loss function
|
# Calculate loss
|
||||||
total_loss, action_loss, price_loss = self.compute_trading_loss(
|
loss = self.action_criterion(action_probs, batch_y)
|
||||||
action_probs, price_pred, batch_y, batch_future_prices
|
|
||||||
)
|
|
||||||
|
|
||||||
# Backward pass and optimization
|
# Backward pass and optimization
|
||||||
total_loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
# Apply gradient clipping to prevent exploding gradients
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
|
||||||
|
|
||||||
self.optimizer.step()
|
self.optimizer.step()
|
||||||
|
|
||||||
# Update metrics
|
# Update metrics
|
||||||
total_action_loss += action_loss.item()
|
total_loss += loss.item()
|
||||||
total_price_loss += price_loss.item() if hasattr(price_loss, 'item') else 0
|
|
||||||
|
|
||||||
predictions = torch.argmax(action_probs, dim=1)
|
predictions = torch.argmax(action_probs, dim=1)
|
||||||
total_correct += (predictions == batch_y).sum().item()
|
total_correct += (predictions == batch_y).sum().item()
|
||||||
total_samples += batch_y.size(0)
|
total_samples += batch_y.size(0)
|
||||||
|
|
||||||
# Track trading signals for logging
|
# Update action counts
|
||||||
buy_count = (predictions == 2).sum().item()
|
for i, (pred, target) in enumerate(zip(predictions, batch_y)):
|
||||||
sell_count = (predictions == 0).sum().item()
|
pred_action = ['SELL', 'HOLD', 'BUY'][pred.item()]
|
||||||
hold_count = (predictions == 1).sum().item()
|
self.action_counts[pred_action][0] += 1
|
||||||
|
if pred.item() == target.item():
|
||||||
buy_correct = ((predictions == 2) & (batch_y == 2)).sum().item()
|
self.action_counts[pred_action][1] += 1
|
||||||
sell_correct = ((predictions == 0) & (batch_y == 0)).sum().item()
|
|
||||||
|
|
||||||
# Calculate average losses and accuracy
|
# Calculate average loss and accuracy
|
||||||
avg_action_loss = total_action_loss / len(train_loader)
|
avg_loss = total_loss / len(train_loader)
|
||||||
avg_price_loss = total_price_loss / len(train_loader)
|
|
||||||
accuracy = total_correct / total_samples
|
accuracy = total_correct / total_samples
|
||||||
|
|
||||||
|
# Update training history
|
||||||
|
self.history['train_loss'].append(avg_loss)
|
||||||
|
self.history['train_acc'].append(accuracy)
|
||||||
|
self.train_losses.append(avg_loss)
|
||||||
|
self.train_accuracies.append(accuracy)
|
||||||
|
|
||||||
# Log trading signals
|
# Log trading signals
|
||||||
logger.info(f"Trading signals: BUY={buy_count}, SELL={sell_count}, HOLD={hold_count}")
|
for action in ['BUY', 'SELL', 'HOLD']:
|
||||||
logger.info(f"Signal precision: BUY={buy_correct/max(1, buy_count):.4f}, SELL={sell_correct/max(1, sell_count):.4f}")
|
total = self.action_counts[action][0]
|
||||||
|
correct = self.action_counts[action][1]
|
||||||
|
precision = correct / total if total > 0 else 0
|
||||||
|
logger.info(f"Trading signals - {action}: {total}, Precision: {precision:.4f}")
|
||||||
|
|
||||||
# Update learning rate
|
return avg_loss, 0, accuracy # Return 0 for price_loss as we're not using it
|
||||||
self.scheduler.step(accuracy)
|
|
||||||
|
|
||||||
return avg_action_loss, avg_price_loss, accuracy
|
|
||||||
|
|
||||||
def evaluate(self, X_val, y_val, future_prices=None):
|
def evaluate(self, X_val, y_val, future_prices=None):
|
||||||
"""Evaluate the model with focus on short-term trading performance metrics"""
|
"""Evaluate the model with focus on short-term trading performance metrics"""
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
total_action_loss = 0
|
total_loss = 0
|
||||||
total_price_loss = 0
|
|
||||||
total_correct = 0
|
total_correct = 0
|
||||||
total_samples = 0
|
total_samples = 0
|
||||||
|
|
||||||
# Additional metrics for trading performance
|
|
||||||
trade_signals = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
|
||||||
correct_signals = {'BUY': 0, 'SELL': 0, 'HOLD': 0}
|
|
||||||
|
|
||||||
# Convert inputs to tensors
|
# Convert inputs to tensors
|
||||||
X_val_tensor = torch.FloatTensor(X_val).to(self.device)
|
X_val_tensor = torch.FloatTensor(X_val).to(self.device)
|
||||||
y_val_tensor = torch.LongTensor(y_val).to(self.device)
|
y_val_tensor = torch.LongTensor(y_val).to(self.device)
|
||||||
future_prices_tensor = torch.FloatTensor(future_prices).to(self.device) if future_prices is not None else None
|
|
||||||
|
# Create dataset and dataloader
|
||||||
|
dataset = TensorDataset(X_val_tensor, y_val_tensor)
|
||||||
|
val_loader = DataLoader(dataset, batch_size=32)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Forward pass
|
for batch_X, batch_y in val_loader:
|
||||||
action_probs, price_pred = self.model(X_val_tensor)
|
# Forward pass
|
||||||
|
action_probs, _ = self.model(batch_X)
|
||||||
# Calculate loss using custom trading loss function
|
|
||||||
total_loss, action_loss, price_loss = self.compute_trading_loss(
|
# Calculate loss
|
||||||
action_probs, price_pred, y_val_tensor, future_prices_tensor
|
loss = self.action_criterion(action_probs, batch_y)
|
||||||
)
|
|
||||||
|
# Update metrics
|
||||||
# Calculate predictions and accuracy
|
total_loss += loss.item()
|
||||||
predictions = torch.argmax(action_probs, dim=1)
|
predictions = torch.argmax(action_probs, dim=1)
|
||||||
|
total_correct += (predictions == batch_y).sum().item()
|
||||||
# Count prediction types and correct predictions
|
total_samples += batch_y.size(0)
|
||||||
for i in range(predictions.shape[0]):
|
|
||||||
pred = predictions[i].item()
|
|
||||||
if pred == 0:
|
|
||||||
trade_signals['SELL'] += 1
|
|
||||||
if y_val_tensor[i].item() == pred:
|
|
||||||
correct_signals['SELL'] += 1
|
|
||||||
elif pred == 1:
|
|
||||||
trade_signals['HOLD'] += 1
|
|
||||||
if y_val_tensor[i].item() == pred:
|
|
||||||
correct_signals['HOLD'] += 1
|
|
||||||
elif pred == 2:
|
|
||||||
trade_signals['BUY'] += 1
|
|
||||||
if y_val_tensor[i].item() == pred:
|
|
||||||
correct_signals['BUY'] += 1
|
|
||||||
|
|
||||||
# Update metrics
|
|
||||||
total_action_loss = action_loss.item()
|
|
||||||
total_price_loss = price_loss.item() if hasattr(price_loss, 'item') else 0
|
|
||||||
|
|
||||||
total_correct = (predictions == y_val_tensor).sum().item()
|
|
||||||
total_samples = y_val_tensor.size(0)
|
|
||||||
|
|
||||||
# Calculate accuracy
|
# Calculate average loss and accuracy
|
||||||
accuracy = total_correct / total_samples if total_samples > 0 else 0
|
avg_loss = total_loss / len(val_loader)
|
||||||
|
accuracy = total_correct / total_samples
|
||||||
|
|
||||||
# Calculate signal precision (crucial for short-term trading)
|
# Update validation history
|
||||||
buy_precision = correct_signals['BUY'] / trade_signals['BUY'] if trade_signals['BUY'] > 0 else 0
|
self.history['val_loss'].append(avg_loss)
|
||||||
sell_precision = correct_signals['SELL'] / trade_signals['SELL'] if trade_signals['SELL'] > 0 else 0
|
self.history['val_acc'].append(accuracy)
|
||||||
|
self.val_losses.append(avg_loss)
|
||||||
|
self.val_accuracies.append(accuracy)
|
||||||
|
|
||||||
# Log trading-specific metrics
|
# Update learning rate scheduler
|
||||||
logger.info(f"Trading signals: BUY={trade_signals['BUY']}, SELL={trade_signals['SELL']}, HOLD={trade_signals['HOLD']}")
|
self.scheduler.step(avg_loss)
|
||||||
logger.info(f"Signal precision: BUY={buy_precision:.4f}, SELL={sell_precision:.4f}")
|
|
||||||
|
|
||||||
# Return combined loss, accuracy and volatility factor for adaptive training
|
return avg_loss, 0, accuracy # Return 0 for price_loss as we're not using it
|
||||||
return total_action_loss, total_price_loss, accuracy
|
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
"""Make predictions optimized for short-term high-leverage trading signals"""
|
"""Make predictions optimized for short-term high-leverage trading signals"""
|
||||||
@ -659,28 +338,11 @@ class CNNModelPyTorch:
|
|||||||
action_probs_np = action_probs.cpu().numpy()
|
action_probs_np = action_probs.cpu().numpy()
|
||||||
|
|
||||||
# Apply more aggressive HOLD reduction for short-term trading
|
# Apply more aggressive HOLD reduction for short-term trading
|
||||||
action_probs_np[:, 1] *= 0.5 # More aggressive HOLD reduction
|
action_probs_np[:, 1] *= 0.3 # More aggressive HOLD reduction
|
||||||
|
|
||||||
# Apply boosting for BUY/SELL signals
|
# Apply boosting for BUY/SELL signals
|
||||||
action_probs_np[:, 0] *= 1.3 # Boost SELL probabilities
|
action_probs_np[:, 0] *= 2.0 # Boost SELL probabilities
|
||||||
action_probs_np[:, 2] *= 1.3 # Boost BUY probabilities
|
action_probs_np[:, 2] *= 2.0 # Boost BUY probabilities
|
||||||
|
|
||||||
# Implement signal filtering based on previous actions to avoid oscillation
|
|
||||||
if len(self.last_actions[0]) >= self.max_consecutive_same_action:
|
|
||||||
# Check for too many consecutive identical actions
|
|
||||||
if all(a == 0 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
|
|
||||||
# Too many consecutive SELL - reduce sell probability
|
|
||||||
action_probs_np[:, 0] *= 0.7
|
|
||||||
elif all(a == 2 for a in self.last_actions[0][-self.max_consecutive_same_action:]):
|
|
||||||
# Too many consecutive BUY - reduce buy probability
|
|
||||||
action_probs_np[:, 2] *= 0.7
|
|
||||||
|
|
||||||
# Apply confidence threshold to reduce noise
|
|
||||||
max_probs = np.max(action_probs_np, axis=1)
|
|
||||||
for i in range(len(action_probs_np)):
|
|
||||||
if max_probs[i] < self.confidence_threshold:
|
|
||||||
# If confidence is too low, force HOLD
|
|
||||||
action_probs_np[i] = np.array([0.1, 0.8, 0.1])
|
|
||||||
|
|
||||||
# Re-normalize
|
# Re-normalize
|
||||||
action_probs_np = action_probs_np / action_probs_np.sum(axis=1, keepdims=True)
|
action_probs_np = action_probs_np / action_probs_np.sum(axis=1, keepdims=True)
|
||||||
@ -704,16 +366,20 @@ class CNNModelPyTorch:
|
|||||||
if 2 in action_dict:
|
if 2 in action_dict:
|
||||||
self.action_counts['BUY'][0] += action_dict[2]
|
self.action_counts['BUY'][0] += action_dict[2]
|
||||||
|
|
||||||
# Get the current close prices from the input
|
# If price_pred is None, create a dummy array of zeros
|
||||||
current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0])
|
if price_pred is None:
|
||||||
|
# Get the current close prices from the input if available
|
||||||
# Calculate price directions based on probabilities
|
current_prices = X_tensor[:, -1, 3].cpu().numpy() if X_tensor.shape[2] > 3 else np.zeros(X_tensor.shape[0])
|
||||||
price_directions = action_probs_np[:, 2] - action_probs_np[:, 0] # BUY - SELL
|
|
||||||
|
# Calculate price directions based on probabilities
|
||||||
# Scale the price change based on signal strength
|
price_directions = action_probs_np[:, 2] - action_probs_np[:, 0] # BUY - SELL
|
||||||
price_preds = current_prices * (1 + price_directions * 0.002)
|
|
||||||
|
# Scale the price change based on signal strength
|
||||||
return action_probs_np, price_preds.reshape(-1, 1)
|
price_preds = current_prices * (1 + price_directions * 0.002)
|
||||||
|
|
||||||
|
return action_probs_np, price_preds.reshape(-1, 1)
|
||||||
|
else:
|
||||||
|
return action_probs_np, price_pred.cpu().numpy()
|
||||||
|
|
||||||
def predict_next_candles(self, X, n_candles=3):
|
def predict_next_candles(self, X, n_candles=3):
|
||||||
"""
|
"""
|
||||||
@ -919,14 +585,9 @@ class CNNModelPyTorch:
|
|||||||
model_state = {
|
model_state = {
|
||||||
'model_state_dict': self.model.state_dict(),
|
'model_state_dict': self.model.state_dict(),
|
||||||
'optimizer_state_dict': self.optimizer.state_dict(),
|
'optimizer_state_dict': self.optimizer.state_dict(),
|
||||||
'history': {
|
'history': self.history,
|
||||||
'loss': self.train_losses,
|
|
||||||
'accuracy': self.train_accuracies,
|
|
||||||
'val_loss': self.val_losses,
|
|
||||||
'val_accuracy': self.val_accuracies
|
|
||||||
},
|
|
||||||
'window_size': self.window_size,
|
'window_size': self.window_size,
|
||||||
'num_features': self.total_features,
|
'num_features': len(self.timeframes) * 5, # 5 features per timeframe
|
||||||
'output_size': self.output_size,
|
'output_size': self.output_size,
|
||||||
'timeframes': self.timeframes,
|
'timeframes': self.timeframes,
|
||||||
# Save trading configuration
|
# Save trading configuration
|
||||||
@ -935,7 +596,7 @@ class CNNModelPyTorch:
|
|||||||
'action_counts': self.action_counts,
|
'action_counts': self.action_counts,
|
||||||
'last_actions': self.last_actions,
|
'last_actions': self.last_actions,
|
||||||
# Save model version information
|
# Save model version information
|
||||||
'model_version': 'short_term_optimized_v1.0',
|
'model_version': 'short_term_optimized_v2.0',
|
||||||
'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
|
'timestamp': datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -943,10 +604,10 @@ class CNNModelPyTorch:
|
|||||||
logger.info(f"Model saved to {filepath}.pt with short-term trading optimizations")
|
logger.info(f"Model saved to {filepath}.pt with short-term trading optimizations")
|
||||||
|
|
||||||
# Save a backup of the model periodically
|
# Save a backup of the model periodically
|
||||||
if not os.path.exists(f"{filepath}_backup"):
|
backup_dir = f"{filepath}_backup"
|
||||||
os.makedirs(f"{filepath}_backup", exist_ok=True)
|
os.makedirs(backup_dir, exist_ok=True)
|
||||||
|
|
||||||
backup_path = os.path.join(f"{filepath}_backup", f"model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
|
backup_path = os.path.join(backup_dir, f"model_{datetime.now().strftime('%Y%m%d_%H%M%S')}.pt")
|
||||||
torch.save(model_state, backup_path)
|
torch.save(model_state, backup_path)
|
||||||
logger.info(f"Backup saved to {backup_path}")
|
logger.info(f"Backup saved to {backup_path}")
|
||||||
|
|
||||||
|
@ -7,12 +7,16 @@ import random
|
|||||||
from typing import Tuple, List
|
from typing import Tuple, List
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import logging
|
||||||
|
|
||||||
# Add parent directory to path
|
# Add parent directory to path
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||||
|
|
||||||
from NN.models.simple_cnn import CNNModelPyTorch
|
from NN.models.simple_cnn import CNNModelPyTorch
|
||||||
|
|
||||||
|
# Configure logger
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
class DQNAgent:
|
class DQNAgent:
|
||||||
"""
|
"""
|
||||||
Deep Q-Network agent for trading
|
Deep Q-Network agent for trading
|
||||||
@ -72,14 +76,32 @@ class DQNAgent:
|
|||||||
# Initialize memory
|
# Initialize memory
|
||||||
self.memory = deque(maxlen=memory_size)
|
self.memory = deque(maxlen=memory_size)
|
||||||
|
|
||||||
|
# Special memory for extrema samples to use for targeted learning
|
||||||
|
self.extrema_memory = deque(maxlen=memory_size // 5) # Smaller size for extrema examples
|
||||||
|
|
||||||
# Training metrics
|
# Training metrics
|
||||||
self.update_count = 0
|
self.update_count = 0
|
||||||
self.losses = []
|
self.losses = []
|
||||||
|
|
||||||
def remember(self, state: np.ndarray, action: int, reward: float,
|
def remember(self, state: np.ndarray, action: int, reward: float,
|
||||||
next_state: np.ndarray, done: bool):
|
next_state: np.ndarray, done: bool, is_extrema: bool = False):
|
||||||
"""Store experience in memory"""
|
"""
|
||||||
self.memory.append((state, action, reward, next_state, done))
|
Store experience in memory
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state: Current state
|
||||||
|
action: Action taken
|
||||||
|
reward: Reward received
|
||||||
|
next_state: Next state
|
||||||
|
done: Whether episode is done
|
||||||
|
is_extrema: Whether this is a local extrema sample (for specialized learning)
|
||||||
|
"""
|
||||||
|
experience = (state, action, reward, next_state, done)
|
||||||
|
self.memory.append(experience)
|
||||||
|
|
||||||
|
# If this is an extrema sample, also add to specialized memory
|
||||||
|
if is_extrema:
|
||||||
|
self.extrema_memory.append(experience)
|
||||||
|
|
||||||
def act(self, state: np.ndarray) -> int:
|
def act(self, state: np.ndarray) -> int:
|
||||||
"""Choose action using epsilon-greedy policy"""
|
"""Choose action using epsilon-greedy policy"""
|
||||||
@ -88,16 +110,39 @@ class DQNAgent:
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
state = torch.FloatTensor(state).unsqueeze(0).to(self.device)
|
||||||
action_probs, _ = self.policy_net(state)
|
action_probs, extrema_pred = self.policy_net(state)
|
||||||
return action_probs.argmax().item()
|
return action_probs.argmax().item()
|
||||||
|
|
||||||
def replay(self) -> float:
|
def replay(self, use_extrema=False) -> float:
|
||||||
"""Train on a batch of experiences"""
|
"""
|
||||||
|
Train on a batch of experiences
|
||||||
|
|
||||||
|
Args:
|
||||||
|
use_extrema: Whether to include extrema samples in training
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: Loss value
|
||||||
|
"""
|
||||||
if len(self.memory) < self.batch_size:
|
if len(self.memory) < self.batch_size:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
# Sample batch
|
# Sample batch - mix regular and extrema samples
|
||||||
batch = random.sample(self.memory, self.batch_size)
|
batch = []
|
||||||
|
if use_extrema and len(self.extrema_memory) > self.batch_size // 4:
|
||||||
|
# Get some extrema samples
|
||||||
|
extrema_count = min(self.batch_size // 3, len(self.extrema_memory))
|
||||||
|
extrema_samples = random.sample(list(self.extrema_memory), extrema_count)
|
||||||
|
|
||||||
|
# Get regular samples for the rest
|
||||||
|
regular_count = self.batch_size - extrema_count
|
||||||
|
regular_samples = random.sample(list(self.memory), regular_count)
|
||||||
|
|
||||||
|
# Combine samples
|
||||||
|
batch = extrema_samples + regular_samples
|
||||||
|
else:
|
||||||
|
# Standard sampling
|
||||||
|
batch = random.sample(self.memory, self.batch_size)
|
||||||
|
|
||||||
states, actions, rewards, next_states, dones = zip(*batch)
|
states, actions, rewards, next_states, dones = zip(*batch)
|
||||||
|
|
||||||
# Convert to tensors and move to device
|
# Convert to tensors and move to device
|
||||||
@ -108,7 +153,7 @@ class DQNAgent:
|
|||||||
dones = torch.FloatTensor(dones).to(self.device)
|
dones = torch.FloatTensor(dones).to(self.device)
|
||||||
|
|
||||||
# Get current Q values
|
# Get current Q values
|
||||||
current_q_values, _ = self.policy_net(states)
|
current_q_values, extrema_pred = self.policy_net(states)
|
||||||
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
|
||||||
|
|
||||||
# Get next Q values from target network
|
# Get next Q values from target network
|
||||||
@ -117,8 +162,15 @@ class DQNAgent:
|
|||||||
next_q_values = next_q_values.max(1)[0]
|
next_q_values = next_q_values.max(1)[0]
|
||||||
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
||||||
|
|
||||||
# Compute loss
|
# Compute Q-learning loss
|
||||||
loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
|
q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
|
||||||
|
|
||||||
|
# If we have extrema labels (not in this implementation yet),
|
||||||
|
# we could add an additional loss for extrema prediction
|
||||||
|
# This would require labels for whether each state is near an extrema
|
||||||
|
|
||||||
|
# Total loss is just Q-learning loss for now
|
||||||
|
loss = q_loss
|
||||||
|
|
||||||
# Optimize
|
# Optimize
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
@ -135,6 +187,50 @@ class DQNAgent:
|
|||||||
|
|
||||||
return loss.item()
|
return loss.item()
|
||||||
|
|
||||||
|
def train_on_extrema(self, states, actions, rewards, next_states, dones):
|
||||||
|
"""
|
||||||
|
Special training method focused on extrema patterns
|
||||||
|
|
||||||
|
Args:
|
||||||
|
states: Array of states near extrema points
|
||||||
|
actions: Correct actions to take (buy at bottoms, sell at tops)
|
||||||
|
rewards: Rewards for each action
|
||||||
|
next_states: Next states
|
||||||
|
dones: Done flags
|
||||||
|
"""
|
||||||
|
if len(states) == 0:
|
||||||
|
return 0.0
|
||||||
|
|
||||||
|
# Convert to tensors
|
||||||
|
states = torch.FloatTensor(np.array(states)).to(self.device)
|
||||||
|
actions = torch.LongTensor(actions).to(self.device)
|
||||||
|
rewards = torch.FloatTensor(rewards).to(self.device)
|
||||||
|
next_states = torch.FloatTensor(np.array(next_states)).to(self.device)
|
||||||
|
dones = torch.FloatTensor(dones).to(self.device)
|
||||||
|
|
||||||
|
# Forward pass
|
||||||
|
current_q_values, extrema_pred = self.policy_net(states)
|
||||||
|
current_q_values = current_q_values.gather(1, actions.unsqueeze(1))
|
||||||
|
|
||||||
|
# Get next Q values
|
||||||
|
with torch.no_grad():
|
||||||
|
next_q_values, _ = self.target_net(next_states)
|
||||||
|
next_q_values = next_q_values.max(1)[0]
|
||||||
|
target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
|
||||||
|
|
||||||
|
# Higher weight for extrema training
|
||||||
|
q_loss = nn.MSELoss()(current_q_values.squeeze(), target_q_values)
|
||||||
|
|
||||||
|
# Full loss is just Q-learning loss
|
||||||
|
loss = q_loss
|
||||||
|
|
||||||
|
# Optimize
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
return loss.item()
|
||||||
|
|
||||||
def save(self, path: str):
|
def save(self, path: str):
|
||||||
"""Save model and agent state"""
|
"""Save model and agent state"""
|
||||||
os.makedirs(os.path.dirname(path), exist_ok=True)
|
os.makedirs(os.path.dirname(path), exist_ok=True)
|
||||||
|
@ -11,6 +11,39 @@ from typing import List, Tuple
|
|||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class PricePatternAttention(nn.Module):
|
||||||
|
"""
|
||||||
|
Attention mechanism specifically designed to focus on price patterns
|
||||||
|
that might indicate local extrema or trend reversals
|
||||||
|
"""
|
||||||
|
def __init__(self, input_dim, hidden_dim=64):
|
||||||
|
super(PricePatternAttention, self).__init__()
|
||||||
|
self.query = nn.Linear(input_dim, hidden_dim)
|
||||||
|
self.key = nn.Linear(input_dim, hidden_dim)
|
||||||
|
self.value = nn.Linear(input_dim, hidden_dim)
|
||||||
|
self.scale = torch.sqrt(torch.tensor(hidden_dim, dtype=torch.float32))
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
"""Apply attention to input sequence"""
|
||||||
|
# x shape: [batch_size, seq_len, features]
|
||||||
|
batch_size, seq_len, _ = x.size()
|
||||||
|
|
||||||
|
# Project input to query, key, value
|
||||||
|
q = self.query(x) # [batch_size, seq_len, hidden_dim]
|
||||||
|
k = self.key(x) # [batch_size, seq_len, hidden_dim]
|
||||||
|
v = self.value(x) # [batch_size, seq_len, hidden_dim]
|
||||||
|
|
||||||
|
# Calculate attention scores
|
||||||
|
scores = torch.matmul(q, k.transpose(-2, -1)) / self.scale # [batch_size, seq_len, seq_len]
|
||||||
|
|
||||||
|
# Apply softmax to get attention weights
|
||||||
|
attn_weights = F.softmax(scores, dim=-1) # [batch_size, seq_len, seq_len]
|
||||||
|
|
||||||
|
# Apply attention to values
|
||||||
|
output = torch.matmul(attn_weights, v) # [batch_size, seq_len, hidden_dim]
|
||||||
|
|
||||||
|
return output, attn_weights
|
||||||
|
|
||||||
class CNNModelPyTorch(nn.Module):
|
class CNNModelPyTorch(nn.Module):
|
||||||
"""
|
"""
|
||||||
CNN model for trading with multiple timeframes
|
CNN model for trading with multiple timeframes
|
||||||
@ -30,7 +63,15 @@ class CNNModelPyTorch(nn.Module):
|
|||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
logger.info(f"Using device: {self.device}")
|
logger.info(f"Using device: {self.device}")
|
||||||
|
|
||||||
# Convolutional layers
|
# Create model architecture
|
||||||
|
self._create_layers()
|
||||||
|
|
||||||
|
# Move model to device
|
||||||
|
self.to(self.device)
|
||||||
|
|
||||||
|
def _create_layers(self):
|
||||||
|
"""Create all model layers with current feature dimensions"""
|
||||||
|
# Convolutional layers - use total_features as input channels
|
||||||
self.conv1 = nn.Conv1d(self.total_features, 64, kernel_size=3, padding=1)
|
self.conv1 = nn.Conv1d(self.total_features, 64, kernel_size=3, padding=1)
|
||||||
self.bn1 = nn.BatchNorm1d(64)
|
self.bn1 = nn.BatchNorm1d(64)
|
||||||
|
|
||||||
@ -40,24 +81,49 @@ class CNNModelPyTorch(nn.Module):
|
|||||||
self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
|
self.conv3 = nn.Conv1d(128, 256, kernel_size=3, padding=1)
|
||||||
self.bn3 = nn.BatchNorm1d(256)
|
self.bn3 = nn.BatchNorm1d(256)
|
||||||
|
|
||||||
# Calculate size after convolutions
|
# Add price pattern attention layer
|
||||||
conv_output_size = window_size * 256
|
self.attention = PricePatternAttention(256)
|
||||||
|
|
||||||
|
# Extrema detection specialized convolutional layer
|
||||||
|
self.extrema_conv = nn.Conv1d(256, 128, kernel_size=5, padding=2)
|
||||||
|
self.extrema_bn = nn.BatchNorm1d(128)
|
||||||
|
|
||||||
|
# Calculate size after convolutions - adjusted for attention output
|
||||||
|
conv_output_size = self.window_size * 256
|
||||||
|
|
||||||
# Fully connected layers
|
# Fully connected layers
|
||||||
self.fc1 = nn.Linear(conv_output_size, 512)
|
self.fc1 = nn.Linear(conv_output_size, 512)
|
||||||
self.fc2 = nn.Linear(512, 256)
|
self.fc2 = nn.Linear(512, 256)
|
||||||
|
|
||||||
# Advantage and Value streams (Dueling DQN architecture)
|
# Advantage and Value streams (Dueling DQN architecture)
|
||||||
self.fc3 = nn.Linear(256, output_size) # Advantage stream
|
self.fc3 = nn.Linear(256, self.output_size) # Advantage stream
|
||||||
self.value_fc = nn.Linear(256, 1) # Value stream
|
self.value_fc = nn.Linear(256, 1) # Value stream
|
||||||
|
|
||||||
|
# Additional prediction head for extrema detection (tops/bottoms)
|
||||||
|
self.extrema_fc = nn.Linear(256, 3) # 0=bottom, 1=top, 2=neither
|
||||||
|
|
||||||
# Initialize optimizer and scheduler
|
# Initialize optimizer and scheduler
|
||||||
self.optimizer = optim.Adam(self.parameters(), lr=0.001)
|
self.optimizer = optim.Adam(self.parameters(), lr=0.001)
|
||||||
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
|
||||||
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
|
self.optimizer, mode='max', factor=0.5, patience=5, verbose=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def rebuild_conv_layers(self, input_channels):
|
||||||
|
"""
|
||||||
|
Rebuild convolutional layers for different input dimensions
|
||||||
|
|
||||||
# Move model to device
|
Args:
|
||||||
|
input_channels: Number of input channels (features) in the data
|
||||||
|
"""
|
||||||
|
logger.info(f"Rebuilding convolutional layers for {input_channels} input channels")
|
||||||
|
|
||||||
|
# Update total features
|
||||||
|
self.total_features = input_channels
|
||||||
|
|
||||||
|
# Recreate all layers with new dimensions
|
||||||
|
self._create_layers()
|
||||||
|
|
||||||
|
# Move layers to device
|
||||||
self.to(self.device)
|
self.to(self.device)
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
@ -65,8 +131,13 @@ class CNNModelPyTorch(nn.Module):
|
|||||||
# Ensure input is on the correct device
|
# Ensure input is on the correct device
|
||||||
x = x.to(self.device)
|
x = x.to(self.device)
|
||||||
|
|
||||||
|
# Check and handle if input dimensions don't match model expectations
|
||||||
|
batch_size, window_len, feature_dim = x.size()
|
||||||
|
if feature_dim != self.total_features:
|
||||||
|
logger.warning(f"Input features ({feature_dim}) don't match model features ({self.total_features}), rebuilding layers")
|
||||||
|
self.rebuild_conv_layers(feature_dim)
|
||||||
|
|
||||||
# Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
|
# Reshape input: [batch, window_size, features] -> [batch, channels, window_size]
|
||||||
batch_size = x.size(0)
|
|
||||||
x = x.permute(0, 2, 1)
|
x = x.permute(0, 2, 1)
|
||||||
|
|
||||||
# Convolutional layers
|
# Convolutional layers
|
||||||
@ -74,6 +145,26 @@ class CNNModelPyTorch(nn.Module):
|
|||||||
x = F.relu(self.bn2(self.conv2(x)))
|
x = F.relu(self.bn2(self.conv2(x)))
|
||||||
x = F.relu(self.bn3(self.conv3(x)))
|
x = F.relu(self.bn3(self.conv3(x)))
|
||||||
|
|
||||||
|
# Store conv features for extrema detection
|
||||||
|
conv_features = x
|
||||||
|
|
||||||
|
# Reshape for attention: [batch, channels, window_size] -> [batch, window_size, channels]
|
||||||
|
x_attention = x.permute(0, 2, 1)
|
||||||
|
|
||||||
|
# Apply attention
|
||||||
|
attention_output, attention_weights = self.attention(x_attention)
|
||||||
|
|
||||||
|
# We'll use attention directly without the residual connection
|
||||||
|
# to avoid dimension mismatch issues
|
||||||
|
attention_reshaped = attention_output.permute(0, 2, 1) # [batch, channels, window_size]
|
||||||
|
|
||||||
|
# Apply extrema detection specialized layer
|
||||||
|
extrema_features = F.relu(self.extrema_bn(self.extrema_conv(conv_features)))
|
||||||
|
|
||||||
|
# Use attention features directly instead of residual connection
|
||||||
|
# to avoid dimension mismatches
|
||||||
|
x = conv_features # Just use the convolutional features
|
||||||
|
|
||||||
# Flatten
|
# Flatten
|
||||||
x = x.view(batch_size, -1)
|
x = x.view(batch_size, -1)
|
||||||
|
|
||||||
@ -88,7 +179,11 @@ class CNNModelPyTorch(nn.Module):
|
|||||||
# Combine value and advantage
|
# Combine value and advantage
|
||||||
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
|
q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))
|
||||||
|
|
||||||
return q_values, value
|
# Also compute extrema prediction from the same features
|
||||||
|
extrema_flat = extrema_features.view(batch_size, -1)
|
||||||
|
extrema_pred = self.extrema_fc(x) # Use the same features for extrema prediction
|
||||||
|
|
||||||
|
return q_values, extrema_pred
|
||||||
|
|
||||||
def predict(self, X):
|
def predict(self, X):
|
||||||
"""Make predictions"""
|
"""Make predictions"""
|
||||||
@ -101,11 +196,15 @@ class CNNModelPyTorch(nn.Module):
|
|||||||
X_tensor = X.to(self.device)
|
X_tensor = X.to(self.device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
q_values, value = self(X_tensor)
|
q_values, extrema_pred = self(X_tensor)
|
||||||
q_values_np = q_values.cpu().numpy()
|
q_values_np = q_values.cpu().numpy()
|
||||||
actions = np.argmax(q_values_np, axis=1)
|
actions = np.argmax(q_values_np, axis=1)
|
||||||
|
|
||||||
return actions, q_values_np
|
# Also return extrema predictions
|
||||||
|
extrema_np = extrema_pred.cpu().numpy()
|
||||||
|
extrema_classes = np.argmax(extrema_np, axis=1)
|
||||||
|
|
||||||
|
return actions, q_values_np, extrema_classes
|
||||||
|
|
||||||
def save(self, path: str):
|
def save(self, path: str):
|
||||||
"""Save model weights"""
|
"""Save model weights"""
|
||||||
|
241
NN/realtime_data_interface.py
Normal file
241
NN/realtime_data_interface.py
Normal file
@ -0,0 +1,241 @@
|
|||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import time
|
||||||
|
from typing import Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class RealtimeDataInterface:
|
||||||
|
"""Interface for retrieving real-time market data for neural network models.
|
||||||
|
|
||||||
|
This class serves as a bridge between the RealTimeChart data sources and
|
||||||
|
the neural network models, providing properly formatted data for model
|
||||||
|
inference.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, symbols: List[str], chart=None, max_cache_size: int = 5000):
|
||||||
|
"""Initialize the data interface.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbols: List of trading symbols (e.g., ['BTC/USDT', 'ETH/USDT'])
|
||||||
|
chart: RealTimeChart instance (optional)
|
||||||
|
max_cache_size: Maximum number of cached candles
|
||||||
|
"""
|
||||||
|
self.symbols = symbols
|
||||||
|
self.chart = chart
|
||||||
|
self.max_cache_size = max_cache_size
|
||||||
|
|
||||||
|
# Initialize data cache
|
||||||
|
self.ohlcv_cache = {} # timeframe -> symbol -> DataFrame
|
||||||
|
|
||||||
|
logger.info(f"Initialized RealtimeDataInterface with symbols: {', '.join(symbols)}")
|
||||||
|
|
||||||
|
def get_historical_data(self, symbol: str = None, timeframe: str = '1h',
|
||||||
|
n_candles: int = 500) -> Optional[pd.DataFrame]:
|
||||||
|
"""Get historical OHLCV data for a symbol and timeframe.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol (e.g., 'BTC/USDT')
|
||||||
|
timeframe: Time interval (e.g., '1m', '5m', '1h')
|
||||||
|
n_candles: Number of candles to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with OHLCV data or None if not available
|
||||||
|
"""
|
||||||
|
if not symbol:
|
||||||
|
if len(self.symbols) > 0:
|
||||||
|
symbol = self.symbols[0]
|
||||||
|
else:
|
||||||
|
logger.error("No symbol specified and no default symbols available")
|
||||||
|
return None
|
||||||
|
|
||||||
|
if symbol not in self.symbols:
|
||||||
|
logger.warning(f"Symbol {symbol} not in tracked symbols")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get data from chart if available
|
||||||
|
if self.chart:
|
||||||
|
candles = self._get_chart_data(symbol, timeframe, n_candles)
|
||||||
|
if candles is not None and len(candles) > 0:
|
||||||
|
return candles
|
||||||
|
|
||||||
|
# Fallback to default empty DataFrame
|
||||||
|
logger.warning(f"No historical data available for {symbol} at timeframe {timeframe}")
|
||||||
|
return pd.DataFrame(columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting historical data for {symbol}: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _get_chart_data(self, symbol: str, timeframe: str, n_candles: int) -> Optional[pd.DataFrame]:
|
||||||
|
"""Get data from the RealTimeChart for the specified symbol and timeframe.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
symbol: Trading symbol (e.g., 'BTC/USDT')
|
||||||
|
timeframe: Time interval (e.g., '1m', '5m', '1h')
|
||||||
|
n_candles: Number of candles to retrieve
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DataFrame with OHLCV data or None if not available
|
||||||
|
"""
|
||||||
|
if not self.chart:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get chart data using the _get_chart_data method
|
||||||
|
try:
|
||||||
|
# Map to interval seconds
|
||||||
|
interval_map = {
|
||||||
|
'1s': 1,
|
||||||
|
'5s': 5,
|
||||||
|
'10s': 10,
|
||||||
|
'15s': 15,
|
||||||
|
'30s': 30,
|
||||||
|
'1m': 60,
|
||||||
|
'3m': 180,
|
||||||
|
'5m': 300,
|
||||||
|
'15m': 900,
|
||||||
|
'30m': 1800,
|
||||||
|
'1h': 3600,
|
||||||
|
'2h': 7200,
|
||||||
|
'4h': 14400,
|
||||||
|
'6h': 21600,
|
||||||
|
'8h': 28800,
|
||||||
|
'12h': 43200,
|
||||||
|
'1d': 86400,
|
||||||
|
'3d': 259200,
|
||||||
|
'1w': 604800
|
||||||
|
}
|
||||||
|
|
||||||
|
# Convert timeframe to seconds
|
||||||
|
if timeframe in interval_map:
|
||||||
|
interval_seconds = interval_map[timeframe]
|
||||||
|
else:
|
||||||
|
# Try to parse the interval (e.g., '1m' -> 60)
|
||||||
|
try:
|
||||||
|
if timeframe.endswith('s'):
|
||||||
|
interval_seconds = int(timeframe[:-1])
|
||||||
|
elif timeframe.endswith('m'):
|
||||||
|
interval_seconds = int(timeframe[:-1]) * 60
|
||||||
|
elif timeframe.endswith('h'):
|
||||||
|
interval_seconds = int(timeframe[:-1]) * 3600
|
||||||
|
elif timeframe.endswith('d'):
|
||||||
|
interval_seconds = int(timeframe[:-1]) * 86400
|
||||||
|
elif timeframe.endswith('w'):
|
||||||
|
interval_seconds = int(timeframe[:-1]) * 604800
|
||||||
|
else:
|
||||||
|
interval_seconds = int(timeframe)
|
||||||
|
except ValueError:
|
||||||
|
logger.error(f"Could not parse timeframe: {timeframe}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Get data from chart
|
||||||
|
df = self.chart._get_chart_data(interval_seconds)
|
||||||
|
|
||||||
|
if df is not None and not df.empty:
|
||||||
|
# Limit to requested number of candles
|
||||||
|
if len(df) > n_candles:
|
||||||
|
df = df.iloc[-n_candles:]
|
||||||
|
|
||||||
|
return df
|
||||||
|
else:
|
||||||
|
logger.warning(f"No data retrieved from chart for {symbol} at timeframe {timeframe}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error getting chart data for {symbol} at {timeframe}: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def prepare_model_input(self, data: pd.DataFrame, window_size: int = 20,
|
||||||
|
symbol: str = None) -> Tuple[np.ndarray, Optional[int]]:
|
||||||
|
"""Prepare model input from OHLCV data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: DataFrame with OHLCV data
|
||||||
|
window_size: Window size for model input
|
||||||
|
symbol: Symbol for the data (for logging)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
|
||||||
|
"""
|
||||||
|
if data is None or len(data) < window_size:
|
||||||
|
logger.warning(f"Not enough data to prepare model input for {symbol or 'unknown symbol'}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get last window_size candles
|
||||||
|
recent_data = data.iloc[-window_size:].copy()
|
||||||
|
|
||||||
|
# Get timestamp of the most recent candle
|
||||||
|
timestamp = int(recent_data.iloc[-1]['timestamp']) if 'timestamp' in recent_data.columns else int(time.time())
|
||||||
|
|
||||||
|
# Extract OHLCV features and normalize
|
||||||
|
if 'open' in recent_data.columns and 'high' in recent_data.columns and 'low' in recent_data.columns and 'close' in recent_data.columns and 'volume' in recent_data.columns:
|
||||||
|
# Normalize price data by the last close price
|
||||||
|
last_close = recent_data['close'].iloc[-1]
|
||||||
|
|
||||||
|
# Avoid division by zero
|
||||||
|
if last_close == 0:
|
||||||
|
last_close = 1.0
|
||||||
|
|
||||||
|
opens = (recent_data['open'] / last_close).values
|
||||||
|
highs = (recent_data['high'] / last_close).values
|
||||||
|
lows = (recent_data['low'] / last_close).values
|
||||||
|
closes = (recent_data['close'] / last_close).values
|
||||||
|
|
||||||
|
# Normalize volume by the max volume in the window
|
||||||
|
max_volume = recent_data['volume'].max()
|
||||||
|
if max_volume == 0:
|
||||||
|
max_volume = 1.0
|
||||||
|
volumes = (recent_data['volume'] / max_volume).values
|
||||||
|
|
||||||
|
# Stack features into a 3D array [batch_size=1, window_size, n_features=5]
|
||||||
|
X = np.column_stack((opens, highs, lows, closes, volumes))
|
||||||
|
X = X.reshape(1, window_size, 5)
|
||||||
|
|
||||||
|
# Replace any NaN or infinite values
|
||||||
|
X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=0.0)
|
||||||
|
|
||||||
|
return X, timestamp
|
||||||
|
else:
|
||||||
|
logger.error(f"Data missing required OHLCV columns for {symbol or 'unknown symbol'}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error preparing model input for {symbol or 'unknown symbol'}: {str(e)}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def prepare_realtime_input(self, timeframe: str = '1h', n_candles: int = 30,
|
||||||
|
window_size: int = 20) -> Tuple[np.ndarray, Optional[int]]:
|
||||||
|
"""Prepare real-time input for the model.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeframe: Time interval (e.g., '1m', '5m', '1h')
|
||||||
|
n_candles: Number of candles to retrieve
|
||||||
|
window_size: Window size for model input
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (X, timestamp) where X is the model input and timestamp is the latest timestamp
|
||||||
|
"""
|
||||||
|
# Get data for the main symbol
|
||||||
|
if len(self.symbols) == 0:
|
||||||
|
logger.error("No symbols available for real-time input")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
symbol = self.symbols[0]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get historical data
|
||||||
|
data = self.get_historical_data(symbol, timeframe, n_candles)
|
||||||
|
|
||||||
|
if data is None or len(data) < window_size:
|
||||||
|
logger.warning(f"Not enough data for real-time input. Need at least {window_size} candles.")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
# Prepare model input
|
||||||
|
return self.prepare_model_input(data, window_size, symbol)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error preparing real-time input: {str(e)}")
|
||||||
|
return None, None
|
296
NN/train_rl.py
296
NN/train_rl.py
@ -63,6 +63,9 @@ class RLTradingEnvironment(gym.Env):
|
|||||||
|
|
||||||
# State variables
|
# State variables
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
|
# Callback for visualization or external monitoring
|
||||||
|
self.action_callback = None
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Reset the environment to initial state"""
|
"""Reset the environment to initial state"""
|
||||||
@ -145,6 +148,7 @@ class RLTradingEnvironment(gym.Env):
|
|||||||
# Default reward is slightly negative to discourage inaction
|
# Default reward is slightly negative to discourage inaction
|
||||||
reward = -0.0001
|
reward = -0.0001
|
||||||
done = False
|
done = False
|
||||||
|
profit_pct = None # Initialize profit_pct variable
|
||||||
|
|
||||||
# Execute action
|
# Execute action
|
||||||
if action == 0: # BUY
|
if action == 0: # BUY
|
||||||
@ -218,214 +222,188 @@ class RLTradingEnvironment(gym.Env):
|
|||||||
'total_value': total_value,
|
'total_value': total_value,
|
||||||
'gain': gain,
|
'gain': gain,
|
||||||
'trades': self.trades,
|
'trades': self.trades,
|
||||||
'win_rate': self.win_rate
|
'win_rate': self.win_rate,
|
||||||
|
'profit_pct': profit_pct if action == 1 and self.position == 0 else None,
|
||||||
|
'current_price': current_price,
|
||||||
|
'next_price': next_price
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Call the callback if it exists
|
||||||
|
if self.action_callback:
|
||||||
|
self.action_callback(action, current_price, reward, info)
|
||||||
|
|
||||||
return observation, reward, done, info
|
return observation, reward, done, info
|
||||||
|
|
||||||
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent"):
|
def set_action_callback(self, callback):
|
||||||
|
"""
|
||||||
|
Set a callback function to be called after each action
|
||||||
|
|
||||||
|
Args:
|
||||||
|
callback: Function with signature (action, price, reward, info)
|
||||||
|
"""
|
||||||
|
self.action_callback = callback
|
||||||
|
|
||||||
|
def train_rl(env_class=None, num_episodes=5000, max_steps=2000, save_path="NN/models/saved/dqn_agent",
|
||||||
|
action_callback=None, episode_callback=None, symbol="BTC/USDT"):
|
||||||
"""
|
"""
|
||||||
Train DQN agent for RL-based trading with extended training and monitoring
|
Train DQN agent for RL-based trading with extended training and monitoring
|
||||||
|
|
||||||
|
Args:
|
||||||
|
env_class: Optional environment class to use, defaults to RLTradingEnvironment
|
||||||
|
num_episodes: Number of episodes to train
|
||||||
|
max_steps: Maximum steps per episode
|
||||||
|
save_path: Path to save the model
|
||||||
|
action_callback: Optional callback for each action (step, action, price, reward, info)
|
||||||
|
episode_callback: Optional callback after each episode (episode, reward, info)
|
||||||
|
symbol: Trading pair symbol (e.g., "BTC/USDT")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
DQNAgent: The trained agent
|
||||||
"""
|
"""
|
||||||
logger.info("Starting extended RL training for trading...")
|
import pandas as pd
|
||||||
|
from NN.utils.data_interface import DataInterface
|
||||||
|
|
||||||
# Environment setup
|
logger.info("Starting DQN training for RL trading")
|
||||||
window_size = 20
|
|
||||||
timeframes = ["1m", "5m", "15m"]
|
|
||||||
trading_fee = 0.001
|
|
||||||
|
|
||||||
# Ensure save directory exists
|
# Create data interface with specified symbol
|
||||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
data_interface = DataInterface(symbol=symbol)
|
||||||
|
|
||||||
# Setup TensorBoard for monitoring
|
# Load and preprocess data
|
||||||
writer = SummaryWriter(f'runs/rl_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}')
|
logger.info(f"Loading data from multiple timeframes for {symbol}")
|
||||||
|
features_1m = data_interface.get_training_data("1m", n_candles=2000)
|
||||||
# Data loading
|
features_5m = data_interface.get_training_data("5m", n_candles=1000)
|
||||||
data_interface = DataInterface(
|
features_15m = data_interface.get_training_data("15m", n_candles=500)
|
||||||
symbol="BTC/USDT",
|
|
||||||
timeframes=timeframes
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get training data for each timeframe with more data
|
|
||||||
logger.info("Loading training data...")
|
|
||||||
features_1m = data_interface.get_training_data("1m", n_candles=5000)
|
|
||||||
if features_1m is not None:
|
|
||||||
logger.info(f"Loaded {len(features_1m)} 1m candles")
|
|
||||||
else:
|
|
||||||
logger.error("Failed to load 1m data")
|
|
||||||
return None
|
|
||||||
|
|
||||||
features_5m = data_interface.get_training_data("5m", n_candles=2500)
|
|
||||||
if features_5m is not None:
|
|
||||||
logger.info(f"Loaded {len(features_5m)} 5m candles")
|
|
||||||
else:
|
|
||||||
logger.error("Failed to load 5m data")
|
|
||||||
return None
|
|
||||||
|
|
||||||
features_15m = data_interface.get_training_data("15m", n_candles=2500)
|
|
||||||
if features_15m is not None:
|
|
||||||
logger.info(f"Loaded {len(features_15m)} 15m candles")
|
|
||||||
else:
|
|
||||||
logger.error("Failed to load 15m data")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
# Check if we have all the data
|
||||||
if features_1m is None or features_5m is None or features_15m is None:
|
if features_1m is None or features_5m is None or features_15m is None:
|
||||||
logger.error("Failed to load training data")
|
logger.error("Failed to load training data from one or more timeframes")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Convert DataFrames to numpy arrays, excluding timestamp column
|
# If data is a DataFrame, convert to numpy array excluding the timestamp column
|
||||||
features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
|
if isinstance(features_1m, pd.DataFrame):
|
||||||
features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
|
features_1m = features_1m.drop('timestamp', axis=1, errors='ignore').values
|
||||||
features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values
|
if isinstance(features_5m, pd.DataFrame):
|
||||||
|
features_5m = features_5m.drop('timestamp', axis=1, errors='ignore').values
|
||||||
|
if isinstance(features_15m, pd.DataFrame):
|
||||||
|
features_15m = features_15m.drop('timestamp', axis=1, errors='ignore').values
|
||||||
|
|
||||||
# Calculate number of features per timeframe
|
# Initialize environment or use provided class
|
||||||
num_features = features_1m.shape[1] # Number of features after dropping timestamp
|
if env_class is None:
|
||||||
|
env = RLTradingEnvironment(features_1m, features_5m, features_15m)
|
||||||
|
else:
|
||||||
|
env = env_class(features_1m, features_5m, features_15m)
|
||||||
|
|
||||||
# Create environment
|
# Set action callback if provided
|
||||||
env = RLTradingEnvironment(
|
if action_callback:
|
||||||
features_1m=features_1m,
|
def step_callback(action, price, reward, info):
|
||||||
features_5m=features_5m,
|
action_callback(env.current_step, action, price, reward, info)
|
||||||
features_15m=features_15m,
|
env.set_action_callback(step_callback)
|
||||||
window_size=window_size,
|
|
||||||
trading_fee=trading_fee
|
# Initialize agent
|
||||||
)
|
window_size = env.window_size
|
||||||
|
num_features = env.num_features * env.num_timeframes
|
||||||
|
action_size = env.action_space.n
|
||||||
|
timeframes = ['1m', '5m', '15m'] # Match the timeframes from the environment
|
||||||
|
|
||||||
# Create agent with adjusted parameters for longer training
|
|
||||||
state_size = window_size
|
|
||||||
action_size = 3
|
|
||||||
agent = DQNAgent(
|
agent = DQNAgent(
|
||||||
state_size=state_size,
|
state_size=window_size * num_features,
|
||||||
action_size=action_size,
|
action_size=action_size,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
num_features=num_features,
|
num_features=env.num_features,
|
||||||
timeframes=timeframes,
|
timeframes=timeframes,
|
||||||
learning_rate=0.0005, # Reduced learning rate for stability
|
memory_size=100000,
|
||||||
gamma=0.99, # Increased discount factor
|
batch_size=64,
|
||||||
|
learning_rate=0.0001,
|
||||||
|
gamma=0.99,
|
||||||
epsilon=1.0,
|
epsilon=1.0,
|
||||||
epsilon_min=0.01,
|
epsilon_min=0.01,
|
||||||
epsilon_decay=0.999, # Slower epsilon decay
|
epsilon_decay=0.995
|
||||||
memory_size=50000, # Increased memory size
|
|
||||||
batch_size=128 # Increased batch size
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Variables to track best performance
|
# Training variables
|
||||||
best_reward = float('-inf')
|
best_reward = -float('inf')
|
||||||
best_episode = 0
|
|
||||||
best_pnl = float('-inf')
|
|
||||||
best_win_rate = 0.0
|
|
||||||
|
|
||||||
# Training metrics
|
|
||||||
episode_rewards = []
|
episode_rewards = []
|
||||||
episode_pnls = []
|
|
||||||
episode_win_rates = []
|
|
||||||
episode_trades = []
|
|
||||||
|
|
||||||
# Check if previous best model exists and load it
|
# TensorBoard writer for logging
|
||||||
best_model_path = f"{save_path}_best"
|
writer = SummaryWriter(log_dir=f'runs/rl_trading_{int(time.time())}')
|
||||||
if os.path.exists(f"{best_model_path}_policy.pt"):
|
|
||||||
try:
|
# Main training loop
|
||||||
logger.info(f"Loading previous best model from {best_model_path}")
|
logger.info(f"Starting training for {num_episodes} episodes...")
|
||||||
agent.load(best_model_path)
|
logger.info(f"Starting training on device: {agent.device}")
|
||||||
metadata_path = f"{best_model_path}_metadata.json"
|
|
||||||
if os.path.exists(metadata_path):
|
|
||||||
with open(metadata_path, 'r') as f:
|
|
||||||
metadata = json.load(f)
|
|
||||||
best_reward = metadata.get('best_reward', best_reward)
|
|
||||||
best_episode = metadata.get('best_episode', best_episode)
|
|
||||||
best_pnl = metadata.get('best_pnl', best_pnl)
|
|
||||||
best_win_rate = metadata.get('best_win_rate', best_win_rate)
|
|
||||||
logger.info(f"Loaded previous best metrics - Reward: {best_reward:.4f}, PnL: {best_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error loading previous best model: {e}")
|
|
||||||
|
|
||||||
# Training loop
|
|
||||||
try:
|
try:
|
||||||
for episode in range(1, num_episodes + 1):
|
for episode in range(num_episodes):
|
||||||
state = env.reset()
|
state = env.reset()
|
||||||
total_reward = 0
|
total_reward = 0
|
||||||
done = False
|
|
||||||
steps = 0
|
|
||||||
|
|
||||||
while not done and steps < max_steps:
|
for step in range(max_steps):
|
||||||
|
# Select action
|
||||||
action = agent.act(state)
|
action = agent.act(state)
|
||||||
|
|
||||||
|
# Take action and observe next state and reward
|
||||||
next_state, reward, done, info = env.step(action)
|
next_state, reward, done, info = env.step(action)
|
||||||
|
|
||||||
|
# Store the experience in memory
|
||||||
agent.remember(state, action, reward, next_state, done)
|
agent.remember(state, action, reward, next_state, done)
|
||||||
|
|
||||||
# Learn from experience
|
# Update state and reward
|
||||||
loss = agent.replay()
|
|
||||||
|
|
||||||
state = next_state
|
state = next_state
|
||||||
total_reward += reward
|
total_reward += reward
|
||||||
steps += 1
|
|
||||||
|
# Train the agent by sampling from memory
|
||||||
|
if len(agent.memory) >= agent.batch_size:
|
||||||
|
loss = agent.replay()
|
||||||
|
|
||||||
|
if done or step == max_steps - 1:
|
||||||
|
break
|
||||||
|
|
||||||
# Calculate episode metrics
|
# Track rewards
|
||||||
episode_rewards.append(total_reward)
|
episode_rewards.append(total_reward)
|
||||||
episode_pnls.append(info['gain'])
|
|
||||||
episode_win_rates.append(info['win_rate'])
|
# Log progress
|
||||||
episode_trades.append(info['trades'])
|
avg_reward = np.mean(episode_rewards[-100:])
|
||||||
|
logger.info(f"Episode {episode}/{num_episodes} - Reward: {total_reward:.4f}, " +
|
||||||
|
f"Avg (100): {avg_reward:.4f}, Epsilon: {agent.epsilon:.4f}")
|
||||||
|
|
||||||
|
# Calculate trading metrics
|
||||||
|
win_rate = env.win_rate if hasattr(env, 'win_rate') else 0
|
||||||
|
trades = env.trades if hasattr(env, 'trades') else 0
|
||||||
|
|
||||||
# Log to TensorBoard
|
# Log to TensorBoard
|
||||||
writer.add_scalar('Reward/episode', total_reward, episode)
|
writer.add_scalar('Reward/Episode', total_reward, episode)
|
||||||
writer.add_scalar('PnL/episode', info['gain'], episode)
|
writer.add_scalar('Reward/Average100', avg_reward, episode)
|
||||||
writer.add_scalar('WinRate/episode', info['win_rate'], episode)
|
writer.add_scalar('Trade/WinRate', win_rate, episode)
|
||||||
writer.add_scalar('Trades/episode', info['trades'], episode)
|
writer.add_scalar('Trade/Count', trades, episode)
|
||||||
writer.add_scalar('Epsilon/episode', agent.epsilon, episode)
|
|
||||||
|
|
||||||
# Save the best model based on multiple metrics (only every 50 episodes)
|
# Save best model
|
||||||
is_better = False
|
if avg_reward > best_reward and episode > 10:
|
||||||
if episode % 50 == 0: # Only check for saving every 50 episodes
|
logger.info(f"New best average reward: {avg_reward:.4f}, saving model")
|
||||||
if (info['gain'] > best_pnl and info['win_rate'] > 0.5) or \
|
agent.save(save_path)
|
||||||
(info['gain'] > best_pnl * 1.1) or \
|
best_reward = avg_reward
|
||||||
(info['win_rate'] > best_win_rate * 1.1):
|
|
||||||
best_reward = total_reward
|
|
||||||
best_episode = episode
|
|
||||||
best_pnl = info['gain']
|
|
||||||
best_win_rate = info['win_rate']
|
|
||||||
agent.save(best_model_path)
|
|
||||||
is_better = True
|
|
||||||
|
|
||||||
# Save metadata about the best model
|
|
||||||
metadata = {
|
|
||||||
'best_reward': best_reward,
|
|
||||||
'best_episode': best_episode,
|
|
||||||
'best_pnl': best_pnl,
|
|
||||||
'best_win_rate': best_win_rate,
|
|
||||||
'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
}
|
|
||||||
with open(f"{best_model_path}_metadata.json", 'w') as f:
|
|
||||||
json.dump(metadata, f)
|
|
||||||
|
|
||||||
# Log training progress
|
# Periodic save every 100 episodes
|
||||||
if episode % 10 == 0:
|
if episode % 100 == 0 and episode > 0:
|
||||||
avg_reward = sum(episode_rewards[-10:]) / 10
|
agent.save(f"{save_path}_episode_{episode}")
|
||||||
avg_pnl = sum(episode_pnls[-10:]) / 10
|
|
||||||
avg_win_rate = sum(episode_win_rates[-10:]) / 10
|
|
||||||
avg_trades = sum(episode_trades[-10:]) / 10
|
|
||||||
|
|
||||||
status = "NEW BEST!" if is_better else ""
|
# Call episode callback if provided
|
||||||
logger.info(f"Episode {episode}/{num_episodes} {status}")
|
if episode_callback:
|
||||||
logger.info(f"Metrics (last 10 episodes):")
|
# Add environment to info dict to use for extrema training
|
||||||
logger.info(f" Reward: {avg_reward:.4f}")
|
info_with_env = info.copy()
|
||||||
logger.info(f" PnL: {avg_pnl:.4f}")
|
info_with_env['env'] = env
|
||||||
logger.info(f" Win Rate: {avg_win_rate:.4f}")
|
episode_callback(episode, total_reward, info_with_env)
|
||||||
logger.info(f" Trades: {avg_trades:.2f}")
|
|
||||||
logger.info(f" Epsilon: {agent.epsilon:.4f}")
|
# Final save
|
||||||
logger.info(f"Best so far - PnL: {best_pnl:.4f}, Win Rate: {best_win_rate:.4f}")
|
logger.info("Training completed, saving final model")
|
||||||
|
agent.save(f"{save_path}_final")
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("Training interrupted by user. Saving best model...")
|
except Exception as e:
|
||||||
|
logger.error(f"Training failed: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
# Close TensorBoard writer
|
# Close TensorBoard writer
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
# Final logs
|
|
||||||
logger.info(f"Training completed. Best model from episode {best_episode}")
|
|
||||||
logger.info(f"Best metrics:")
|
|
||||||
logger.info(f" Reward: {best_reward:.4f}")
|
|
||||||
logger.info(f" PnL: {best_pnl:.4f}")
|
|
||||||
logger.info(f" Win Rate: {best_win_rate:.4f}")
|
|
||||||
|
|
||||||
# Return the agent for potential further use
|
|
||||||
return agent
|
return agent
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -25,10 +25,10 @@ class SignalInterpreter:
|
|||||||
"""
|
"""
|
||||||
self.config = config or {}
|
self.config = config or {}
|
||||||
|
|
||||||
# Signal thresholds - higher thresholds for high-leverage trading
|
# Signal thresholds - lower thresholds to increase trade frequency
|
||||||
self.buy_threshold = self.config.get('buy_threshold', 0.65)
|
self.buy_threshold = self.config.get('buy_threshold', 0.35)
|
||||||
self.sell_threshold = self.config.get('sell_threshold', 0.65)
|
self.sell_threshold = self.config.get('sell_threshold', 0.35)
|
||||||
self.hold_threshold = self.config.get('hold_threshold', 0.75)
|
self.hold_threshold = self.config.get('hold_threshold', 0.60)
|
||||||
|
|
||||||
# Adaptive parameters
|
# Adaptive parameters
|
||||||
self.confidence_multiplier = self.config.get('confidence_multiplier', 1.0)
|
self.confidence_multiplier = self.config.get('confidence_multiplier', 1.0)
|
||||||
@ -45,14 +45,14 @@ class SignalInterpreter:
|
|||||||
self.current_position = None # None = no position, 'long' = buy, 'short' = sell
|
self.current_position = None # None = no position, 'long' = buy, 'short' = sell
|
||||||
|
|
||||||
# Filters for better signal quality
|
# Filters for better signal quality
|
||||||
self.trend_filter_enabled = self.config.get('trend_filter_enabled', True)
|
self.trend_filter_enabled = self.config.get('trend_filter_enabled', False) # Disable trend filter by default
|
||||||
self.volume_filter_enabled = self.config.get('volume_filter_enabled', True)
|
self.volume_filter_enabled = self.config.get('volume_filter_enabled', False) # Disable volume filter by default
|
||||||
self.oscillation_filter_enabled = self.config.get('oscillation_filter_enabled', True)
|
self.oscillation_filter_enabled = self.config.get('oscillation_filter_enabled', False) # Disable oscillation filter by default
|
||||||
|
|
||||||
# Sensitivity parameters
|
# Sensitivity parameters
|
||||||
self.min_price_movement = self.config.get('min_price_movement', 0.0005) # 0.05% minimum expected movement
|
self.min_price_movement = self.config.get('min_price_movement', 0.0001) # Lower price movement threshold
|
||||||
self.hold_cooldown = self.config.get('hold_cooldown', 3) # Minimum periods to wait after a HOLD
|
self.hold_cooldown = self.config.get('hold_cooldown', 1) # Shorter hold cooldown
|
||||||
self.consecutive_signals_required = self.config.get('consecutive_signals_required', 2)
|
self.consecutive_signals_required = self.config.get('consecutive_signals_required', 1) # Require only one signal
|
||||||
|
|
||||||
# State tracking
|
# State tracking
|
||||||
self.consecutive_buy_signals = 0
|
self.consecutive_buy_signals = 0
|
||||||
|
@ -54,4 +54,5 @@ python test_model.py
|
|||||||
|
|
||||||
|
|
||||||
python train_with_realtime_ticks.py
|
python train_with_realtime_ticks.py
|
||||||
python NN/train_rl.py
|
python NN/train_rl.py
|
||||||
|
python train_rl_with_realtime.py
|
88
main.py
88
main.py
@ -56,7 +56,7 @@ websocket_logger.setLevel(logging.INFO) # Change this from DEBUG to INFO
|
|||||||
class WebSocketFilter(logging.Filter):
|
class WebSocketFilter(logging.Filter):
|
||||||
def filter(self, record):
|
def filter(self, record):
|
||||||
# Filter out DEBUG messages from WebSocket-related modules
|
# Filter out DEBUG messages from WebSocket-related modules
|
||||||
if record.levelno == logging.DEBUG and ('websocket' in record.name or
|
if record.levelno == logging.INFO and ('websocket' in record.name or
|
||||||
'protocol' in record.name or
|
'protocol' in record.name or
|
||||||
'realtime' in record.name):
|
'realtime' in record.name):
|
||||||
return False
|
return False
|
||||||
@ -331,7 +331,7 @@ def main():
|
|||||||
"""Main function for the trading bot."""
|
"""Main function for the trading bot."""
|
||||||
# Parse command-line arguments
|
# Parse command-line arguments
|
||||||
parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
|
parser = argparse.ArgumentParser(description="Trading Bot with Neural Network Integration")
|
||||||
parser.add_argument('--symbols', nargs='+', default=["BTC/USDT", "ETH/USDT"],
|
parser.add_argument('--symbols', nargs='+', default=["ETH/USDT", "ETH/USDT"],
|
||||||
help='Trading symbols to monitor')
|
help='Trading symbols to monitor')
|
||||||
parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
|
parser.add_argument('--timeframes', nargs='+', default=["1m", "5m", "1h", "4h", "1d"],
|
||||||
help='Timeframes to monitor')
|
help='Timeframes to monitor')
|
||||||
@ -692,11 +692,17 @@ if __name__ == "__main__":
|
|||||||
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
|
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
# Base reward for actions
|
# Validate current price
|
||||||
if action == 0: # HOLD
|
if self.current_price <= 0 or self.current_price > 1000000: # Reasonable price range
|
||||||
reward = -0.05 # Increased penalty for doing nothing to encourage more trading
|
logger.error(f"Invalid current price: {self.current_price}")
|
||||||
|
return -10.0 # Strong penalty for invalid price
|
||||||
elif action == 1: # BUY/LONG
|
|
||||||
|
# Validate position size
|
||||||
|
if self.position_size <= 0 or self.position_size > 1000000: # Reasonable position size range
|
||||||
|
logger.error(f"Invalid position size: {self.position_size}")
|
||||||
|
return -10.0 # Strong penalty for invalid position size
|
||||||
|
|
||||||
|
if action == 1: # BUY/LONG
|
||||||
if self.position == 'flat':
|
if self.position == 'flat':
|
||||||
# Opening a long position
|
# Opening a long position
|
||||||
self.position = 'long'
|
self.position = 'long'
|
||||||
@ -706,12 +712,11 @@ if __name__ == "__main__":
|
|||||||
self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
|
self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
|
||||||
self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
|
self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
|
||||||
|
|
||||||
# Check if this is an optimal buy point (bottom)
|
# Check if this is an optimal buy point
|
||||||
current_idx = len(self.features['price']) - 1
|
if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
|
||||||
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
|
reward += 2.0 # Bonus for buying at a bottom
|
||||||
reward += 3.0 # Increased bonus for buying at a bottom
|
|
||||||
|
|
||||||
# Check for volume spike (indicating potential big movement)
|
# Check for volume spike
|
||||||
if len(self.features['volume']) > 5:
|
if len(self.features['volume']) > 5:
|
||||||
avg_volume = np.mean(self.features['volume'][-5:-1])
|
avg_volume = np.mean(self.features['volume'][-5:-1])
|
||||||
current_volume = self.features['volume'][-1]
|
current_volume = self.features['volume'][-1]
|
||||||
@ -737,9 +742,20 @@ if __name__ == "__main__":
|
|||||||
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
||||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||||
|
|
||||||
|
# Validate PnL values
|
||||||
|
if abs(pnl_percent) > 100: # Max 100% loss/gain
|
||||||
|
logger.error(f"Invalid PnL percentage: {pnl_percent}")
|
||||||
|
pnl_percent = max(min(pnl_percent, 100), -100)
|
||||||
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||||
|
|
||||||
# Apply fees
|
# Apply fees
|
||||||
pnl_dollar -= self.calculate_fees(self.position_size)
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
||||||
|
|
||||||
|
# Update balance with validation
|
||||||
|
if abs(pnl_dollar) > self.balance * 2: # Max 200% of balance
|
||||||
|
logger.error(f"Invalid PnL dollar amount: {pnl_dollar}")
|
||||||
|
pnl_dollar = max(min(pnl_dollar, self.balance * 2), -self.balance * 2)
|
||||||
|
|
||||||
# Update balance
|
# Update balance
|
||||||
self.balance += pnl_dollar
|
self.balance += pnl_dollar
|
||||||
self.total_pnl += pnl_dollar
|
self.total_pnl += pnl_dollar
|
||||||
@ -758,11 +774,11 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Reward based on PnL with stronger penalties for losses
|
# Reward based on PnL with stronger penalties for losses
|
||||||
if pnl_dollar > 0:
|
if pnl_dollar > 0:
|
||||||
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
reward += 1.0 + min(pnl_dollar / 10, 5.0) # Cap positive reward at 5.0
|
||||||
self.win_count += 1
|
self.win_count += 1
|
||||||
else:
|
else:
|
||||||
# Stronger penalty for losses, scaled by the size of the loss
|
# Stronger penalty for losses, scaled by the size of the loss but capped
|
||||||
loss_penalty = 1.0 + abs(pnl_dollar) / 5
|
loss_penalty = min(1.0 + abs(pnl_dollar) / 5, 5.0)
|
||||||
reward -= loss_penalty
|
reward -= loss_penalty
|
||||||
self.loss_count += 1
|
self.loss_count += 1
|
||||||
|
|
||||||
@ -2115,11 +2131,17 @@ class TradingEnvironment:
|
|||||||
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
|
"""Calculate reward for the given action with aggressive rewards for profitable trades and volume/price action signals"""
|
||||||
reward = 0
|
reward = 0
|
||||||
|
|
||||||
# Base reward for actions
|
# Validate current price
|
||||||
if action == 0: # HOLD
|
if self.current_price <= 0 or self.current_price > 1000000: # Reasonable price range
|
||||||
reward = -0.05 # Increased penalty for doing nothing to encourage more trading
|
logger.error(f"Invalid current price: {self.current_price}")
|
||||||
|
return -10.0 # Strong penalty for invalid price
|
||||||
elif action == 1: # BUY/LONG
|
|
||||||
|
# Validate position size
|
||||||
|
if self.position_size <= 0 or self.position_size > 1000000: # Reasonable position size range
|
||||||
|
logger.error(f"Invalid position size: {self.position_size}")
|
||||||
|
return -10.0 # Strong penalty for invalid position size
|
||||||
|
|
||||||
|
if action == 1: # BUY/LONG
|
||||||
if self.position == 'flat':
|
if self.position == 'flat':
|
||||||
# Opening a long position
|
# Opening a long position
|
||||||
self.position = 'long'
|
self.position = 'long'
|
||||||
@ -2129,12 +2151,11 @@ class TradingEnvironment:
|
|||||||
self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
|
self.stop_loss = self.entry_price * (1 - self.stop_loss_pct/100)
|
||||||
self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
|
self.take_profit = self.entry_price * (1 + self.take_profit_pct/100)
|
||||||
|
|
||||||
# Check if this is an optimal buy point (bottom)
|
# Check if this is an optimal buy point
|
||||||
current_idx = len(self.features['price']) - 1
|
if hasattr(self, 'optimal_bottoms') and self.entry_index in self.optimal_bottoms:
|
||||||
if hasattr(self, 'optimal_bottoms') and current_idx in self.optimal_bottoms:
|
reward += 2.0 # Bonus for buying at a bottom
|
||||||
reward += 3.0 # Increased bonus for buying at a bottom
|
|
||||||
|
|
||||||
# Check for volume spike (indicating potential big movement)
|
# Check for volume spike
|
||||||
if len(self.features['volume']) > 5:
|
if len(self.features['volume']) > 5:
|
||||||
avg_volume = np.mean(self.features['volume'][-5:-1])
|
avg_volume = np.mean(self.features['volume'][-5:-1])
|
||||||
current_volume = self.features['volume'][-1]
|
current_volume = self.features['volume'][-1]
|
||||||
@ -2160,9 +2181,20 @@ class TradingEnvironment:
|
|||||||
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
pnl_percent = (self.entry_price - self.current_price) / self.entry_price * 100
|
||||||
pnl_dollar = pnl_percent / 100 * self.position_size
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||||
|
|
||||||
|
# Validate PnL values
|
||||||
|
if abs(pnl_percent) > 100: # Max 100% loss/gain
|
||||||
|
logger.error(f"Invalid PnL percentage: {pnl_percent}")
|
||||||
|
pnl_percent = max(min(pnl_percent, 100), -100)
|
||||||
|
pnl_dollar = pnl_percent / 100 * self.position_size
|
||||||
|
|
||||||
# Apply fees
|
# Apply fees
|
||||||
pnl_dollar -= self.calculate_fees(self.position_size)
|
pnl_dollar -= self.calculate_fees(self.position_size)
|
||||||
|
|
||||||
|
# Update balance with validation
|
||||||
|
if abs(pnl_dollar) > self.balance * 2: # Max 200% of balance
|
||||||
|
logger.error(f"Invalid PnL dollar amount: {pnl_dollar}")
|
||||||
|
pnl_dollar = max(min(pnl_dollar, self.balance * 2), -self.balance * 2)
|
||||||
|
|
||||||
# Update balance
|
# Update balance
|
||||||
self.balance += pnl_dollar
|
self.balance += pnl_dollar
|
||||||
self.total_pnl += pnl_dollar
|
self.total_pnl += pnl_dollar
|
||||||
@ -2181,11 +2213,11 @@ class TradingEnvironment:
|
|||||||
|
|
||||||
# Reward based on PnL with stronger penalties for losses
|
# Reward based on PnL with stronger penalties for losses
|
||||||
if pnl_dollar > 0:
|
if pnl_dollar > 0:
|
||||||
reward += 1.0 + pnl_dollar / 10 # Positive reward for profit
|
reward += 1.0 + min(pnl_dollar / 10, 5.0) # Cap positive reward at 5.0
|
||||||
self.win_count += 1
|
self.win_count += 1
|
||||||
else:
|
else:
|
||||||
# Stronger penalty for losses, scaled by the size of the loss
|
# Stronger penalty for losses, scaled by the size of the loss but capped
|
||||||
loss_penalty = 1.0 + abs(pnl_dollar) / 5
|
loss_penalty = min(1.0 + abs(pnl_dollar) / 5, 5.0)
|
||||||
reward -= loss_penalty
|
reward -= loss_penalty
|
||||||
self.loss_count += 1
|
self.loss_count += 1
|
||||||
|
|
||||||
|
103
test_dash.py
Normal file
103
test_dash.py
Normal file
@ -0,0 +1,103 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Simple test for Dash to ensure chart rendering is working correctly
|
||||||
|
"""
|
||||||
|
|
||||||
|
import dash
|
||||||
|
from dash import html, dcc, Input, Output
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
# Create some sample data
|
||||||
|
def generate_sample_data(days=10):
|
||||||
|
end_date = datetime.now()
|
||||||
|
start_date = end_date - timedelta(days=days)
|
||||||
|
dates = pd.date_range(start=start_date, end=end_date, freq='1h')
|
||||||
|
|
||||||
|
np.random.seed(42)
|
||||||
|
prices = np.random.normal(loc=100, scale=5, size=len(dates))
|
||||||
|
prices = np.cumsum(np.random.normal(loc=0, scale=1, size=len(dates))) + 100
|
||||||
|
|
||||||
|
df = pd.DataFrame({
|
||||||
|
'timestamp': dates,
|
||||||
|
'open': prices,
|
||||||
|
'high': prices + np.random.normal(loc=1, scale=0.5, size=len(dates)),
|
||||||
|
'low': prices - np.random.normal(loc=1, scale=0.5, size=len(dates)),
|
||||||
|
'close': prices + np.random.normal(loc=0, scale=0.5, size=len(dates)),
|
||||||
|
'volume': np.abs(np.random.normal(loc=100, scale=50, size=len(dates)))
|
||||||
|
})
|
||||||
|
|
||||||
|
return df
|
||||||
|
|
||||||
|
# Create a Dash app
|
||||||
|
app = dash.Dash(__name__)
|
||||||
|
|
||||||
|
# Layout
|
||||||
|
app.layout = html.Div([
|
||||||
|
html.H1("Test Chart"),
|
||||||
|
dcc.Graph(id='test-chart', style={'height': '800px'}),
|
||||||
|
dcc.Interval(
|
||||||
|
id='interval-component',
|
||||||
|
interval=1*1000, # in milliseconds
|
||||||
|
n_intervals=0
|
||||||
|
),
|
||||||
|
html.Div(id='signal-display', children="No Signal")
|
||||||
|
])
|
||||||
|
|
||||||
|
# Callback
|
||||||
|
@app.callback(
|
||||||
|
[Output('test-chart', 'figure'),
|
||||||
|
Output('signal-display', 'children')],
|
||||||
|
[Input('interval-component', 'n_intervals')]
|
||||||
|
)
|
||||||
|
def update_chart(n):
|
||||||
|
# Generate new data on each update
|
||||||
|
df = generate_sample_data()
|
||||||
|
|
||||||
|
# Create a candlestick chart
|
||||||
|
fig = go.Figure(data=[go.Candlestick(
|
||||||
|
x=df['timestamp'],
|
||||||
|
open=df['open'],
|
||||||
|
high=df['high'],
|
||||||
|
low=df['low'],
|
||||||
|
close=df['close']
|
||||||
|
)])
|
||||||
|
|
||||||
|
# Add some random buy/sell signals
|
||||||
|
if n % 5 == 0:
|
||||||
|
signal_point = df.iloc[np.random.randint(0, len(df))]
|
||||||
|
action = "BUY" if n % 10 == 0 else "SELL"
|
||||||
|
color = "green" if action == "BUY" else "red"
|
||||||
|
|
||||||
|
fig.add_trace(go.Scatter(
|
||||||
|
x=[signal_point['timestamp']],
|
||||||
|
y=[signal_point['close']],
|
||||||
|
mode='markers',
|
||||||
|
marker=dict(
|
||||||
|
size=10,
|
||||||
|
color=color,
|
||||||
|
symbol='triangle-up' if action == 'BUY' else 'triangle-down'
|
||||||
|
),
|
||||||
|
name=action
|
||||||
|
))
|
||||||
|
|
||||||
|
signal_text = f"Current Signal: {action}"
|
||||||
|
else:
|
||||||
|
signal_text = "No Signal"
|
||||||
|
|
||||||
|
# Update layout
|
||||||
|
fig.update_layout(
|
||||||
|
title='Test Price Chart',
|
||||||
|
yaxis_title='Price',
|
||||||
|
xaxis_title='Time',
|
||||||
|
template='plotly_dark',
|
||||||
|
height=800
|
||||||
|
)
|
||||||
|
|
||||||
|
return fig, signal_text
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
print("Starting Dash server on http://localhost:8090/")
|
||||||
|
app.run_server(debug=False, host='localhost', port=8090)
|
19
test_websocket.py
Normal file
19
test_websocket.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
import websockets
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def test_websocket():
|
||||||
|
url = "wss://stream.binance.com:9443/ws/btcusdt@trade"
|
||||||
|
try:
|
||||||
|
logger.info(f"Connecting to {url}")
|
||||||
|
async with websockets.connect(url) as ws:
|
||||||
|
logger.info("Connected successfully")
|
||||||
|
message = await ws.recv()
|
||||||
|
logger.info(f"Received message: {message[:200]}...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Connection failed: {str(e)}")
|
||||||
|
|
||||||
|
asyncio.run(test_websocket())
|
537
train_rl_with_realtime.py
Normal file
537
train_rl_with_realtime.py
Normal file
@ -0,0 +1,537 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
"""
|
||||||
|
Integrated RL Trading with Realtime Visualization
|
||||||
|
|
||||||
|
This script combines the RL training (NN/train_rl.py) with the realtime visualization (realtime.py)
|
||||||
|
to display the actions taken by the RL agent on the realtime chart.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
import signal
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
from threading import Thread
|
||||||
|
import pandas as pd
|
||||||
|
from scipy.signal import argrelextrema
|
||||||
|
|
||||||
|
# Configure logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.FileHandler(f'rl_realtime_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'),
|
||||||
|
logging.StreamHandler()
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger('rl_realtime')
|
||||||
|
|
||||||
|
# Add the project root to path if needed
|
||||||
|
project_root = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
if project_root not in sys.path:
|
||||||
|
sys.path.append(project_root)
|
||||||
|
|
||||||
|
# Global variables for coordination
|
||||||
|
realtime_chart = None
|
||||||
|
realtime_websocket_task = None
|
||||||
|
running = True
|
||||||
|
|
||||||
|
def signal_handler(sig, frame):
|
||||||
|
"""Handle CTRL+C to gracefully exit training"""
|
||||||
|
global running
|
||||||
|
logger.info("Received interrupt signal. Finishing current epoch and shutting down...")
|
||||||
|
running = False
|
||||||
|
|
||||||
|
# Register signal handler
|
||||||
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|
||||||
|
class ExtremaDetector:
|
||||||
|
"""
|
||||||
|
Detects local extrema (tops and bottoms) in price data
|
||||||
|
"""
|
||||||
|
def __init__(self, window_size=10, order=5):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
window_size (int): Size of the window to look for extrema
|
||||||
|
order (int): How many points on each side to use for comparison
|
||||||
|
"""
|
||||||
|
self.window_size = window_size
|
||||||
|
self.order = order
|
||||||
|
|
||||||
|
def find_extrema(self, prices):
|
||||||
|
"""
|
||||||
|
Find the local minima and maxima in the price series
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prices (array-like): Array of price values
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (max_indices, min_indices) - arrays of indices where local maxima and minima occur
|
||||||
|
"""
|
||||||
|
# Convert to numpy array if needed
|
||||||
|
price_array = np.array(prices)
|
||||||
|
|
||||||
|
# Find local maxima (tops)
|
||||||
|
local_max_indices = argrelextrema(price_array, np.greater, order=self.order)[0]
|
||||||
|
|
||||||
|
# Find local minima (bottoms)
|
||||||
|
local_min_indices = argrelextrema(price_array, np.less, order=self.order)[0]
|
||||||
|
|
||||||
|
# Filter out extrema that are too close to the edges
|
||||||
|
max_indices = local_max_indices[local_max_indices >= self.order]
|
||||||
|
max_indices = max_indices[max_indices < len(price_array) - self.order]
|
||||||
|
|
||||||
|
min_indices = local_min_indices[local_min_indices >= self.order]
|
||||||
|
min_indices = min_indices[min_indices < len(price_array) - self.order]
|
||||||
|
|
||||||
|
return max_indices, min_indices
|
||||||
|
|
||||||
|
class RLTrainingIntegrator:
|
||||||
|
"""
|
||||||
|
Integrates RL training with realtime chart visualization.
|
||||||
|
Acts as a bridge between the RL training process and the realtime chart.
|
||||||
|
"""
|
||||||
|
def __init__(self, chart, symbol="ETH/USDT", model_save_path="NN/models/saved/dqn_agent"):
|
||||||
|
self.chart = chart
|
||||||
|
self.symbol = symbol
|
||||||
|
self.model_save_path = model_save_path
|
||||||
|
self.episode_count = 0
|
||||||
|
self.action_history = []
|
||||||
|
self.reward_history = []
|
||||||
|
self.trade_count = 0
|
||||||
|
self.win_count = 0
|
||||||
|
|
||||||
|
# Track current position state
|
||||||
|
self.in_position = False
|
||||||
|
self.entry_price = None
|
||||||
|
self.entry_time = None
|
||||||
|
|
||||||
|
# Extrema detector
|
||||||
|
self.extrema_detector = ExtremaDetector(window_size=10, order=5)
|
||||||
|
|
||||||
|
# Store the agent reference
|
||||||
|
self.agent = None
|
||||||
|
|
||||||
|
def start_training(self, num_episodes=5000, max_steps=2000):
|
||||||
|
"""Start the RL training process with visualization integration"""
|
||||||
|
from NN.train_rl import train_rl, RLTradingEnvironment
|
||||||
|
|
||||||
|
logger.info(f"Starting RL training with realtime visualization for {self.symbol}")
|
||||||
|
|
||||||
|
# Define callbacks for the training process
|
||||||
|
def on_action(step, action, price, reward, info):
|
||||||
|
"""Callback for each action taken by the agent"""
|
||||||
|
# Only visualize non-hold actions
|
||||||
|
if action != 2: # 0=Buy, 1=Sell, 2=Hold
|
||||||
|
# Convert to string action
|
||||||
|
action_str = "BUY" if action == 0 else "SELL"
|
||||||
|
|
||||||
|
# Get timestamp - we'll use current time as a proxy
|
||||||
|
timestamp = datetime.now()
|
||||||
|
|
||||||
|
# Track position state
|
||||||
|
if action == 0 and not self.in_position: # Buy and not already in position
|
||||||
|
self.in_position = True
|
||||||
|
self.entry_price = price
|
||||||
|
self.entry_time = timestamp
|
||||||
|
|
||||||
|
# Send to chart - visualize buy signal
|
||||||
|
if self.chart and hasattr(self.chart, 'add_nn_signal'):
|
||||||
|
self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))
|
||||||
|
|
||||||
|
elif action == 1 and self.in_position: # Sell and in position (complete trade)
|
||||||
|
self.in_position = False
|
||||||
|
|
||||||
|
# Calculate profit if we have entry data
|
||||||
|
pnl = None
|
||||||
|
if self.entry_price is not None:
|
||||||
|
pnl = (price - self.entry_price) / self.entry_price
|
||||||
|
|
||||||
|
# Log the complete trade on the chart
|
||||||
|
if self.chart:
|
||||||
|
# Show sell signal
|
||||||
|
if hasattr(self.chart, 'add_nn_signal'):
|
||||||
|
self.chart.add_nn_signal(action_str, timestamp, probability=abs(reward))
|
||||||
|
|
||||||
|
# Record the trade with PnL
|
||||||
|
if hasattr(self.chart, 'add_trade'):
|
||||||
|
self.chart.add_trade(
|
||||||
|
action=action_str,
|
||||||
|
price=price,
|
||||||
|
timestamp=timestamp,
|
||||||
|
pnl=pnl
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update trade counts
|
||||||
|
self.trade_count += 1
|
||||||
|
if pnl is not None and pnl > 0:
|
||||||
|
self.win_count += 1
|
||||||
|
|
||||||
|
# Reset entry data
|
||||||
|
self.entry_price = None
|
||||||
|
self.entry_time = None
|
||||||
|
|
||||||
|
# Track all actions
|
||||||
|
self.action_history.append({
|
||||||
|
'step': step,
|
||||||
|
'action': action_str,
|
||||||
|
'price': price,
|
||||||
|
'reward': reward,
|
||||||
|
'timestamp': timestamp.isoformat()
|
||||||
|
})
|
||||||
|
|
||||||
|
# Track reward for all actions (including hold)
|
||||||
|
self.reward_history.append(reward)
|
||||||
|
|
||||||
|
# Log periodically
|
||||||
|
if len(self.reward_history) % 100 == 0:
|
||||||
|
avg_reward = sum(self.reward_history[-100:]) / 100
|
||||||
|
logger.info(f"Step {step}: Avg reward (last 100): {avg_reward:.4f}, Actions: {len(self.action_history)}, Trades: {self.trade_count}")
|
||||||
|
|
||||||
|
def on_episode(episode, reward, info):
|
||||||
|
"""Callback for each completed episode"""
|
||||||
|
self.episode_count += 1
|
||||||
|
|
||||||
|
# Log episode results
|
||||||
|
logger.info(f"Episode {episode} completed")
|
||||||
|
logger.info(f" Total reward: {reward:.4f}")
|
||||||
|
logger.info(f" PnL: {info['gain']:.4f}")
|
||||||
|
logger.info(f" Win rate: {info['win_rate']:.4f}")
|
||||||
|
logger.info(f" Trades: {info['trades']}")
|
||||||
|
|
||||||
|
# Reset position state for new episode
|
||||||
|
self.in_position = False
|
||||||
|
self.entry_price = None
|
||||||
|
self.entry_time = None
|
||||||
|
|
||||||
|
# After each episode, perform additional training for local extrema
|
||||||
|
if hasattr(self.agent, 'policy_net') and hasattr(self.agent, 'replay') and episode > 0:
|
||||||
|
self._train_on_extrema(self.agent, info['env'])
|
||||||
|
|
||||||
|
# Start the actual training with our callbacks
|
||||||
|
self.agent = train_rl(
|
||||||
|
num_episodes=num_episodes,
|
||||||
|
max_steps=max_steps,
|
||||||
|
save_path=self.model_save_path,
|
||||||
|
action_callback=on_action,
|
||||||
|
episode_callback=on_episode,
|
||||||
|
symbol=self.symbol
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("RL training completed")
|
||||||
|
return self.agent
|
||||||
|
|
||||||
|
def _train_on_extrema(self, agent, env):
|
||||||
|
"""
|
||||||
|
Perform additional training on local extrema (tops and bottoms)
|
||||||
|
to help the model learn these important patterns faster
|
||||||
|
|
||||||
|
Args:
|
||||||
|
agent: The DQN agent
|
||||||
|
env: The trading environment
|
||||||
|
"""
|
||||||
|
if not hasattr(env, 'features_1m') or len(env.features_1m) == 0:
|
||||||
|
logger.warning("Environment doesn't have price data for extrema detection")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Extract close prices
|
||||||
|
prices = env.features_1m[:, -1] # Assuming close price is the last column
|
||||||
|
|
||||||
|
# Find local extrema
|
||||||
|
max_indices, min_indices = self.extrema_detector.find_extrema(prices)
|
||||||
|
|
||||||
|
if len(max_indices) == 0 or len(min_indices) == 0:
|
||||||
|
logger.warning("No extrema found in the current price data")
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"Found {len(max_indices)} tops and {len(min_indices)} bottoms for additional training")
|
||||||
|
|
||||||
|
# Calculate price changes at extrema to prioritize more significant ones
|
||||||
|
max_price_changes = []
|
||||||
|
for idx in max_indices:
|
||||||
|
if idx < 5 or idx >= len(prices) - 5:
|
||||||
|
continue
|
||||||
|
# Calculate percentage price rise from previous 5 candles to the peak
|
||||||
|
min_before = min(prices[idx-5:idx])
|
||||||
|
price_change = (prices[idx] - min_before) / min_before
|
||||||
|
max_price_changes.append((idx, price_change))
|
||||||
|
|
||||||
|
min_price_changes = []
|
||||||
|
for idx in min_indices:
|
||||||
|
if idx < 5 or idx >= len(prices) - 5:
|
||||||
|
continue
|
||||||
|
# Calculate percentage price drop from previous 5 candles to the bottom
|
||||||
|
max_before = max(prices[idx-5:idx])
|
||||||
|
price_change = (max_before - prices[idx]) / max_before
|
||||||
|
min_price_changes.append((idx, price_change))
|
||||||
|
|
||||||
|
# Sort extrema by significance (larger price change is more important)
|
||||||
|
max_price_changes.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
min_price_changes.sort(key=lambda x: x[1], reverse=True)
|
||||||
|
|
||||||
|
# Take top 10 most significant extrema or all if fewer
|
||||||
|
max_indices = [idx for idx, _ in max_price_changes[:10]]
|
||||||
|
min_indices = [idx for idx, _ in min_price_changes[:10]]
|
||||||
|
|
||||||
|
# Log the significance of the extrema
|
||||||
|
if max_indices:
|
||||||
|
logger.info(f"Top extrema price changes: {[round(pc*100, 2) for _, pc in max_price_changes[:5]]}%")
|
||||||
|
if min_indices:
|
||||||
|
logger.info(f"Bottom extrema price changes: {[round(pc*100, 2) for _, pc in min_price_changes[:5]]}%")
|
||||||
|
|
||||||
|
# Collect states, actions, rewards for batch training
|
||||||
|
states = []
|
||||||
|
actions = []
|
||||||
|
rewards = []
|
||||||
|
next_states = []
|
||||||
|
dones = []
|
||||||
|
|
||||||
|
# Process tops (local maxima - should sell)
|
||||||
|
for idx in max_indices:
|
||||||
|
if idx < env.window_size + 2 or idx >= len(prices) - 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create states for multiple points approaching the top
|
||||||
|
# This helps the model learn to recognize the pattern leading to the top
|
||||||
|
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the top
|
||||||
|
if idx - offset < env.window_size:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# State before the peak
|
||||||
|
state_idx = idx - offset
|
||||||
|
env.current_step = state_idx
|
||||||
|
state = env._get_observation()
|
||||||
|
|
||||||
|
# The next state would be closer to the peak
|
||||||
|
env.current_step = state_idx + 1
|
||||||
|
next_state = env._get_observation()
|
||||||
|
|
||||||
|
# Reward increases as we get closer to the peak
|
||||||
|
# Stronger rewards for being right at the peak
|
||||||
|
reward = 1.0 if offset > 1 else 2.0
|
||||||
|
|
||||||
|
# Add to memory
|
||||||
|
action = 1 # Sell
|
||||||
|
agent.remember(state, action, reward, next_state, False, is_extrema=True)
|
||||||
|
|
||||||
|
# Add to batch
|
||||||
|
states.append(state)
|
||||||
|
actions.append(action)
|
||||||
|
rewards.append(reward)
|
||||||
|
next_states.append(next_state)
|
||||||
|
dones.append(False)
|
||||||
|
|
||||||
|
# Process bottoms (local minima - should buy)
|
||||||
|
for idx in min_indices:
|
||||||
|
if idx < env.window_size + 2 or idx >= len(prices) - 2:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Create states for multiple points approaching the bottom
|
||||||
|
for offset in range(1, 4): # Look at 1, 2, and 3 candles before the bottom
|
||||||
|
if idx - offset < env.window_size:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# State before the bottom
|
||||||
|
state_idx = idx - offset
|
||||||
|
env.current_step = state_idx
|
||||||
|
state = env._get_observation()
|
||||||
|
|
||||||
|
# The next state would be closer to the bottom
|
||||||
|
env.current_step = state_idx + 1
|
||||||
|
next_state = env._get_observation()
|
||||||
|
|
||||||
|
# Reward increases as we get closer to the bottom
|
||||||
|
reward = 1.0 if offset > 1 else 2.0
|
||||||
|
|
||||||
|
# Add to memory
|
||||||
|
action = 0 # Buy
|
||||||
|
agent.remember(state, action, reward, next_state, False, is_extrema=True)
|
||||||
|
|
||||||
|
# Add to batch
|
||||||
|
states.append(state)
|
||||||
|
actions.append(action)
|
||||||
|
rewards.append(reward)
|
||||||
|
next_states.append(next_state)
|
||||||
|
dones.append(False)
|
||||||
|
|
||||||
|
# Add some negative examples - don't buy at tops, don't sell at bottoms
|
||||||
|
for idx in max_indices[:5]: # Use a few top peaks
|
||||||
|
if idx < env.window_size + 1 or idx >= len(prices) - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# State at the peak
|
||||||
|
env.current_step = idx
|
||||||
|
state = env._get_observation()
|
||||||
|
|
||||||
|
# Next state
|
||||||
|
env.current_step = idx + 1
|
||||||
|
next_state = env._get_observation()
|
||||||
|
|
||||||
|
# Strong negative reward for buying at a peak
|
||||||
|
reward = -1.5
|
||||||
|
|
||||||
|
# Add negative example of buying at a peak
|
||||||
|
action = 0 # Buy (wrong action)
|
||||||
|
agent.remember(state, action, reward, next_state, False, is_extrema=True)
|
||||||
|
|
||||||
|
# Add to batch
|
||||||
|
states.append(state)
|
||||||
|
actions.append(action)
|
||||||
|
rewards.append(reward)
|
||||||
|
next_states.append(next_state)
|
||||||
|
dones.append(False)
|
||||||
|
|
||||||
|
for idx in min_indices[:5]: # Use a few bottom troughs
|
||||||
|
if idx < env.window_size + 1 or idx >= len(prices) - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# State at the bottom
|
||||||
|
env.current_step = idx
|
||||||
|
state = env._get_observation()
|
||||||
|
|
||||||
|
# Next state
|
||||||
|
env.current_step = idx + 1
|
||||||
|
next_state = env._get_observation()
|
||||||
|
|
||||||
|
# Strong negative reward for selling at a bottom
|
||||||
|
reward = -1.5
|
||||||
|
|
||||||
|
# Add negative example of selling at a bottom
|
||||||
|
action = 1 # Sell (wrong action)
|
||||||
|
agent.remember(state, action, reward, next_state, False, is_extrema=True)
|
||||||
|
|
||||||
|
# Add to batch
|
||||||
|
states.append(state)
|
||||||
|
actions.append(action)
|
||||||
|
rewards.append(reward)
|
||||||
|
next_states.append(next_state)
|
||||||
|
dones.append(False)
|
||||||
|
|
||||||
|
# Train on the collected extrema samples
|
||||||
|
if len(states) > 0:
|
||||||
|
logger.info(f"Performing additional training on {len(states)} extrema patterns")
|
||||||
|
loss = agent.train_on_extrema(states, actions, rewards, next_states, dones)
|
||||||
|
logger.info(f"Extrema training loss: {loss:.4f}")
|
||||||
|
|
||||||
|
# Additional replay passes with extrema samples included
|
||||||
|
for _ in range(5):
|
||||||
|
loss = agent.replay(use_extrema=True)
|
||||||
|
logger.info(f"Mixed replay with extrema - loss: {loss:.4f}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error during extrema training: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
async def start_realtime_chart(symbol="BTC/USDT", port=8050):
|
||||||
|
"""
|
||||||
|
Start the realtime chart display in a separate thread
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple: (chart, websocket_task)
|
||||||
|
"""
|
||||||
|
from realtime import RealTimeChart
|
||||||
|
|
||||||
|
try:
|
||||||
|
logger.info(f"Initializing RealTimeChart for {symbol}")
|
||||||
|
# Create the chart
|
||||||
|
chart = RealTimeChart(symbol)
|
||||||
|
|
||||||
|
# Start the WebSocket connection in a separate thread
|
||||||
|
# The _start_websocket_thread method already handles this correctly
|
||||||
|
|
||||||
|
# Run the Dash server in a separate thread
|
||||||
|
thread = Thread(target=lambda c=chart, p=port: c.run(host='localhost', port=p))
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Give the server a moment to start
|
||||||
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
|
logger.info(f"Started realtime chart for {symbol} on port {port}")
|
||||||
|
logger.info(f"You can view the chart at http://localhost:{port}/")
|
||||||
|
|
||||||
|
# Return the chart and a dummy websocket task (the real one is running in a thread)
|
||||||
|
return chart, asyncio.create_task(asyncio.sleep(0))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error starting realtime chart: {str(e)}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
raise
|
||||||
|
|
||||||
|
def run_training_thread(chart, num_episodes=5000, max_steps=2000):
|
||||||
|
"""Run the RL training in a separate thread"""
|
||||||
|
integrator = RLTrainingIntegrator(chart)
|
||||||
|
thread = Thread(target=lambda: integrator.start_training(num_episodes, max_steps))
|
||||||
|
thread.daemon = True
|
||||||
|
thread.start()
|
||||||
|
logger.info("Started RL training thread")
|
||||||
|
return thread, integrator
|
||||||
|
|
||||||
|
def test_signals(chart):
|
||||||
|
"""Add test signals to the chart to verify functionality"""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
logger.info("Adding test signals to chart")
|
||||||
|
|
||||||
|
# Add a test BUY signal
|
||||||
|
chart.add_nn_signal("BUY", datetime.now(), 0.95)
|
||||||
|
|
||||||
|
# Sleep briefly
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Add a test SELL signal
|
||||||
|
chart.add_nn_signal("SELL", datetime.now(), 0.85)
|
||||||
|
|
||||||
|
# Add a test trade
|
||||||
|
chart.add_trade("BUY", 50000.0, datetime.now(), 0.05)
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main function that coordinates the realtime chart and RL training"""
|
||||||
|
global realtime_chart, realtime_websocket_task, running
|
||||||
|
|
||||||
|
logger.info("Starting integrated RL training with realtime visualization")
|
||||||
|
|
||||||
|
# Start the realtime chart
|
||||||
|
realtime_chart, realtime_websocket_task = await start_realtime_chart()
|
||||||
|
|
||||||
|
# Wait a bit for the chart to initialize
|
||||||
|
await asyncio.sleep(5)
|
||||||
|
|
||||||
|
# Test signals first
|
||||||
|
test_signals(realtime_chart)
|
||||||
|
|
||||||
|
# Start the training in a separate thread
|
||||||
|
training_thread, integrator = run_training_thread(realtime_chart)
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Keep the main task running until interrupted
|
||||||
|
while running and training_thread.is_alive():
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Shutting down...")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error: {str(e)}")
|
||||||
|
finally:
|
||||||
|
# Clean up
|
||||||
|
if realtime_websocket_task:
|
||||||
|
realtime_websocket_task.cancel()
|
||||||
|
try:
|
||||||
|
await realtime_websocket_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
logger.info("Application terminated")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
asyncio.run(main())
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
logger.info("Application terminated by user")
|
@ -350,9 +350,9 @@ async def run_realtime_training():
|
|||||||
|
|
||||||
model = CNNModelPyTorch(
|
model = CNNModelPyTorch(
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
num_features=num_features,
|
timeframes=timeframes,
|
||||||
output_size=output_size,
|
output_size=output_size,
|
||||||
timeframes=timeframes
|
num_pairs=1 # Single trading pair
|
||||||
)
|
)
|
||||||
|
|
||||||
# Try to load existing model
|
# Try to load existing model
|
||||||
|
Loading…
x
Reference in New Issue
Block a user