big cleanup
This commit is contained in:
parent 7d8eca995e
commit 0331bbfa7c
@@ -1,359 +0,0 @@
[
  {
    "trade_id": 1,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:25:58.643819+00:00",
    "exit_time": "2025-05-30T17:26:39.729472+00:00",
    "entry_price": 2550.7,
    "exit_price": 2546.59,
    "size": 0.003724,
    "gross_pnl": 0.015305639999998781,
    "fees": 0.00949115398,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": 0.005814486019998783,
    "duration": "0:00:41.085653",
    "symbol": "ETH/USDC",
    "mexc_executed": true
  },
  {
    "trade_id": 2,
    "side": "LONG",
    "entry_time": "2025-05-30T17:26:39.729472+00:00",
    "exit_time": "2025-05-30T17:26:40.742643+00:00",
    "entry_price": 2546.59,
    "exit_price": 2546.58,
    "size": 0.003456,
    "gross_pnl": -3.456000000075437e-05,
    "fees": 0.008800997759999998,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.008835557760000754,
    "duration": "0:00:01.013171",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 3,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:26:40.742643+00:00",
    "exit_time": "2025-05-30T17:26:44.783909+00:00",
    "entry_price": 2546.58,
    "exit_price": 2546.69,
    "size": 0.003155,
    "gross_pnl": -0.0003470500000004017,
    "fees": 0.008034633425,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.008381683425000402,
    "duration": "0:00:04.041266",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 4,
    "side": "LONG",
    "entry_time": "2025-05-30T17:26:44.783909+00:00",
    "exit_time": "2025-05-30T17:26:56.903098+00:00",
    "entry_price": 2546.69,
    "exit_price": 2546.9,
    "size": 0.003374,
    "gross_pnl": 0.0007085400000001227,
    "fees": 0.00859288633,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.007884346329999877,
    "duration": "0:00:12.119189",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 5,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:26:56.903098+00:00",
    "exit_time": "2025-05-30T17:27:03.971971+00:00",
    "entry_price": 2546.9,
    "exit_price": 2547.78,
    "size": 0.003309,
    "gross_pnl": -0.002911920000000361,
    "fees": 0.00842914806,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.011341068060000362,
    "duration": "0:00:07.068873",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 6,
    "side": "LONG",
    "entry_time": "2025-05-30T17:27:03.971971+00:00",
    "exit_time": "2025-05-30T17:27:24.185714+00:00",
    "entry_price": 2547.78,
    "exit_price": 2548.0,
    "size": 0.003704,
    "gross_pnl": 0.0008148799999992589,
    "fees": 0.009437384560000001,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.008622504560000742,
    "duration": "0:00:20.213743",
    "symbol": "ETH/USDC",
    "mexc_executed": true
  },
  {
    "trade_id": 7,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:27:24.185714+00:00",
    "exit_time": "2025-05-30T17:27:35.315014+00:00",
    "entry_price": 2548.0,
    "exit_price": 2547.67,
    "size": 0.003304,
    "gross_pnl": 0.0010903199999997596,
    "fees": 0.008418046840000002,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.007327726840000242,
    "duration": "0:00:11.129300",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 8,
    "side": "LONG",
    "entry_time": "2025-05-30T17:27:35.315014+00:00",
    "exit_time": "2025-05-30T17:27:48.488282+00:00",
    "entry_price": 2547.67,
    "exit_price": 2547.5,
    "size": 0.003442,
    "gross_pnl": -0.0005851400000002505,
    "fees": 0.00876878757,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.00935392757000025,
    "duration": "0:00:13.173268",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 9,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:27:48.488282+00:00",
    "exit_time": "2025-05-30T17:28:09.641167+00:00",
    "entry_price": 2547.5,
    "exit_price": 2547.2,
    "size": 0.003729,
    "gross_pnl": 0.0011187000000006783,
    "fees": 0.009499068150000001,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.008380368149999321,
    "duration": "0:00:21.152885",
    "symbol": "ETH/USDC",
    "mexc_executed": true
  },
  {
    "trade_id": 10,
    "side": "LONG",
    "entry_time": "2025-05-30T17:28:09.641167+00:00",
    "exit_time": "2025-05-30T17:29:03.116674+00:00",
    "entry_price": 2547.2,
    "exit_price": 2549.4,
    "size": 0.0034,
    "gross_pnl": 0.007480000000000927,
    "fees": 0.00866422,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.0011842199999990725,
    "duration": "0:00:53.475507",
    "symbol": "ETH/USDC",
    "mexc_executed": true
  },
  {
    "trade_id": 11,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:29:03.116674+00:00",
    "exit_time": "2025-05-30T17:29:10.180571+00:00",
    "entry_price": 2549.4,
    "exit_price": 2549.79,
    "size": 0.003408,
    "gross_pnl": -0.0013291199999995661,
    "fees": 0.00868901976,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.010018139759999566,
    "duration": "0:00:07.063897",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 12,
    "side": "LONG",
    "entry_time": "2025-05-30T17:29:10.180571+00:00",
    "exit_time": "2025-05-30T17:29:19.404003+00:00",
    "entry_price": 2549.79,
    "exit_price": 2548.9,
    "size": 0.003552,
    "gross_pnl": -0.003161279999999548,
    "fees": 0.00905527344,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.012216553439999549,
    "duration": "0:00:09.223432",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 13,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:29:19.404003+00:00",
    "exit_time": "2025-05-30T17:29:40.434581+00:00",
    "entry_price": 2548.9,
    "exit_price": 2547.8,
    "size": 0.003692,
    "gross_pnl": 0.004061199999999664,
    "fees": 0.0094085082,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.005347308200000336,
    "duration": "0:00:21.030578",
    "symbol": "ETH/USDC",
    "mexc_executed": true
  },
  {
    "trade_id": 14,
    "side": "LONG",
    "entry_time": "2025-05-30T17:29:40.434581+00:00",
    "exit_time": "2025-05-30T17:29:41.445058+00:00",
    "entry_price": 2547.8,
    "exit_price": 2547.8,
    "size": 0.003729,
    "gross_pnl": 0.0,
    "fees": 0.009500746200000002,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.009500746200000002,
    "duration": "0:00:01.010477",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 15,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:29:41.445058+00:00",
    "exit_time": "2025-05-30T17:29:45.488994+00:00",
    "entry_price": 2547.8,
    "exit_price": 2547.88,
    "size": 0.003215,
    "gross_pnl": -0.0002571999999997661,
    "fees": 0.0081913056,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.008448505599999765,
    "duration": "0:00:04.043936",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 16,
    "side": "LONG",
    "entry_time": "2025-05-30T17:29:45.488994+00:00",
    "exit_time": "2025-05-30T17:30:11.732339+00:00",
    "entry_price": 2547.88,
    "exit_price": 2549.3,
    "size": 0.003189,
    "gross_pnl": 0.004528380000000232,
    "fees": 0.00812745351,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.0035990735099997685,
    "duration": "0:00:26.243345",
    "symbol": "ETH/USDC",
    "mexc_executed": true
  },
  {
    "trade_id": 17,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:30:11.732339+00:00",
    "exit_time": "2025-05-30T17:30:25.893383+00:00",
    "entry_price": 2549.3,
    "exit_price": 2548.76,
    "size": 0.003013,
    "gross_pnl": 0.0016270199999998904,
    "fees": 0.007680227390000001,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.00605320739000011,
    "duration": "0:00:14.161044",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 18,
    "side": "LONG",
    "entry_time": "2025-05-30T17:30:25.893383+00:00",
    "exit_time": "2025-05-30T17:30:40.053758+00:00",
    "entry_price": 2548.76,
    "exit_price": 2549.4,
    "size": 0.002905,
    "gross_pnl": 0.0018591999999996302,
    "fees": 0.007405077400000001,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.005545877400000371,
    "duration": "0:00:14.160375",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 19,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:30:40.053758+00:00",
    "exit_time": "2025-05-30T17:30:46.111367+00:00",
    "entry_price": 2549.4,
    "exit_price": 2549.8,
    "size": 0.003726,
    "gross_pnl": -0.001490400000000339,
    "fees": 0.0094998096,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.01099020960000034,
    "duration": "0:00:06.057609",
    "symbol": "ETH/USDC",
    "mexc_executed": true
  },
  {
    "trade_id": 20,
    "side": "LONG",
    "entry_time": "2025-05-30T17:30:46.111367+00:00",
    "exit_time": "2025-05-30T17:30:48.166894+00:00",
    "entry_price": 2549.8,
    "exit_price": 2549.21,
    "size": 0.003652,
    "gross_pnl": -0.0021546800000005312,
    "fees": 0.009310792259999999,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.011465472260000532,
    "duration": "0:00:02.055527",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  },
  {
    "trade_id": 21,
    "side": "SHORT",
    "entry_time": "2025-05-30T17:30:48.166894+00:00",
    "exit_time": "2025-05-30T17:31:12.387130+00:00",
    "entry_price": 2549.21,
    "exit_price": 2547.77,
    "size": 0.003313,
    "gross_pnl": 0.00477072000000018,
    "fees": 0.008443147370000001,
    "fee_type": "taker",
    "fee_rate": 0.0005,
    "net_pnl": -0.0036724273699998197,
    "duration": "0:00:24.220236",
    "symbol": "ETH/USDC",
    "mexc_executed": false
  }
]
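For a quick sanity check of a trade log in this format, the fields can be summed directly. A minimal sketch (the filename is an assumption; this diff does not show the deleted file's path):

import json

# Summarize a trade log in the format above; 'trade_log.json' is a
# placeholder name, since the deleted file's path is not shown here.
with open('trade_log.json') as f:
    trades = json.load(f)

net = sum(t['net_pnl'] for t in trades)
executed = sum(t['mexc_executed'] for t in trades)
print(f"{len(trades)} trades, net PnL {net:.6f}, {executed} executed on MEXC")
# For this particular log: 21 trades, fees dominate; only trade 1 is net positive.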
@@ -62,7 +62,11 @@ I see we're always invested. adjust the training, reward functions use the orche

I see we're always invested. Adjust the training and the reward functions so the orchestrator learns to make that decision when it gets uncertain signals from the expert models. The models should learn to effectively spot setups in the market with a high risk/reward profile and act on those.

If that does not work, I think we can make it simpler and easier to train with just two model actions: buy/sell. We don't need a hold signal, since we hold until there is an action. When we are long and get a sell signal, we close, and we enter short on a consecutive sell signal. Also, we will have different thresholds for entering and exiting, learning to enter only when we are more certain.

This will also help us simplify the training and our codebase to keep it easy to develop.
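For reference, a minimal sketch of that two-action scheme with asymmetric entry/exit thresholds (function names and threshold values are illustrative assumptions, not the trained model):

# Two-action (buy/sell) scheme with asymmetric thresholds. All names and
# numbers here are illustrative assumptions.
ENTRY_THRESHOLD = 0.65  # open a position only on confident signals
EXIT_THRESHOLD = 0.35   # close (or flip) more readily than we open

def next_position(position: str, signal: str, confidence: float) -> str:
    """position is 'LONG', 'SHORT' or 'FLAT'; signal is 'BUY' or 'SELL'."""
    if position == 'FLAT':
        # No hold action: we stay flat until a confident enough signal arrives
        if confidence >= ENTRY_THRESHOLD:
            return {'BUY': 'LONG', 'SELL': 'SHORT'}[signal]
        return 'FLAT'
    if position == 'LONG' and signal == 'SELL':
        # First opposing signal closes; a consecutive one (next call, from
        # FLAT) opens the short side
        return 'FLAT' if confidence >= EXIT_THRESHOLD else 'LONG'
    if position == 'SHORT' and signal == 'BUY':
        return 'FLAT' if confidence >= EXIT_THRESHOLD else 'SHORT'
    return position  # signal agrees with the current position: keep holding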
As our models are chained, it no longer makes sense to train them separately, so remove all modes from main_clean and all referenced code. We use only web mode, where the flow is: we collect data, calculate indicators and pivot points -> CNN -> RL -> orchestrator -> broker/web.

#######

We use UnifiedDataStream to collect data and pass it to the models.

The orchestrator model should also be an appropriate MoE (mixture-of-experts) model that learns to make decisions based on the signals from the expert models. It should be able to include more models in the future.
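A rough sketch of what that MoE gateway could look like; the dimensions, expert names, and gating design here are assumptions, not the actual orchestrator:

import torch
import torch.nn as nn

class MoEOrchestrator(nn.Module):
    # Sketch only: experts feed fixed-size signal vectors, a learned gate
    # weights them, and a head emits buy/sell logits. Assumes expert_dims
    # and expert_outputs use the same insertion order.
    def __init__(self, expert_dims: dict, n_actions: int = 2):
        super().__init__()
        total = sum(expert_dims.values())
        self.gate = nn.Sequential(nn.Linear(total, len(expert_dims)), nn.Softmax(dim=-1))
        self.proj = nn.ModuleDict({name: nn.Linear(dim, 64) for name, dim in expert_dims.items()})
        self.head = nn.Linear(64, n_actions)  # BUY / SELL logits

    def forward(self, expert_outputs: dict) -> torch.Tensor:
        flat = torch.cat(list(expert_outputs.values()), dim=-1)
        weights = self.gate(flat)  # one weight per expert, per batch row
        mixed = sum(w.unsqueeze(-1) * self.proj[name](x)
                    for (name, x), w in zip(expert_outputs.items(), weights.unbind(-1)))
        return self.head(mixed)

# Adding a new expert later only means adding an entry to expert_dims, e.g.:
moe = MoEOrchestrator({'cnn': 516, 'rl': 8})  # 512 hidden + 4 preds from CNN; RL dim assumed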
@@ -74,3 +78,23 @@ make all dashboard processes run on the server without need of dashboard page to

All models/training/inference should run on the server. The dashboard should be used only for displaying the data and controlling the processes. Let's add a start/stop button to the dashboard to control the processes, plus a slider to adjust the buy/sell thresholds for the orchestrator model and therefore bias how aggressive its actions are.

Add a row of small charts to the dashboard showing all the data we feed to the models: the 1m, 1h, 1d, and reference (BTC) OHLCV.
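Assuming the dashboard stays on Plotly Dash, the controls could look roughly like this (component ids and the orchestrator hook are illustrative assumptions):

from dash import Dash, dcc, html, Input, Output

app = Dash(__name__)
app.layout = html.Div([
    html.Button('Start', id='start-btn'),   # would trigger server-side process control
    html.Button('Stop', id='stop-btn'),
    dcc.Slider(id='threshold-slider', min=0.0, max=1.0, step=0.05, value=0.5),
    html.Div(id='status'),
])

@app.callback(Output('status', 'children'), Input('threshold-slider', 'value'))
def update_threshold(value):
    # Would push the new buy/sell threshold to the server-side orchestrator;
    # lower values bias the model toward more aggressive actions.
    return f"Orchestrator threshold set to {value:.2f}"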
# PROBLEMS

Also, tell me which CNN model is used in the /web/dashboard.py training pipeline right now, and what are its inputs/outputs?

The CNN model should predict the next pivot point and the timestamp it will happen at, for each of the pivot point levels that we feed. Do we do that now, do we train the model on it, and what is the current loss?
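One way that pivot head could be shaped, as a sketch (the level count and input width are assumptions):

import torch
import torch.nn as nn

class PivotPredictionHead(nn.Module):
    # Sketch of the requested output: for each pivot level, predict the next
    # pivot's price and a positive time offset (seconds until it occurs).
    # hidden_dim and n_levels are illustrative assumptions.
    def __init__(self, hidden_dim: int = 256, n_levels: int = 5):
        super().__init__()
        self.price_head = nn.Linear(hidden_dim, n_levels)  # next pivot price per level
        self.time_head = nn.Sequential(
            nn.Linear(hidden_dim, n_levels), nn.Softplus())  # positive time-to-pivot per level

    def forward(self, h: torch.Tensor) -> dict:
        return {'pivot_price': self.price_head(h),
                'pivot_time_s': self.time_head(h)}

# Training target per sample would be the (price, seconds-until) of the next
# pivot at each level, e.g. MSE on price plus Huber loss on the time offset.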
# overview/overhaul

But why do the classes in the training folder define their own models??? They should use the models defined in the NN folder. No wonder I see no progress in training. Audit the whole project and remove redundant implementations.

As described, we should have a single point where data is prepared: the data provider class. It also calculates indicators and pivot points, and caches different timeframes of OHLCV data to reduce load and external API calls.

The web UI and the CNN model then consume that data in inference mode, but when a pivot is detected we run a training round on the CNN.

Then the CNN outputs, along with part of the hidden layer state, are passed to the RL model, which generates buy/sell signals.

Then the orchestrator (a MoE gateway of sorts) gets the data from both the CNN and the RL model and generates its own output. Actions are shown on the dash and executed via the brokerage API.
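A minimal sketch of that single-point, cached OHLCV access (method names and TTLs are assumptions; the real class is the data provider in core/):

import time

class CachedOHLCV:
    # Sketch of the caching idea: one fetch path, one cache entry per
    # (symbol, timeframe), so indicators are computed once and external
    # API calls are reduced. fetch_fn and the TTLs are assumptions.
    def __init__(self, fetch_fn, ttl_seconds: dict):
        self.fetch_fn = fetch_fn  # e.g. wraps the exchange API call
        self.ttl = ttl_seconds    # e.g. {'1m': 30, '1h': 300, '1d': 3600}
        self._cache = {}          # (symbol, tf) -> (fetched_at, candles)

    def get(self, symbol: str, timeframe: str):
        key = (symbol, timeframe)
        now = time.time()
        hit = self._cache.get(key)
        if hit and now - hit[0] < self.ttl[timeframe]:
            return hit[1]  # serve cached candles, no API call
        candles = self.fetch_fn(symbol, timeframe)
        self._cache[key] = (now, candles)
        return candles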
@@ -1,308 +0,0 @@
"""
Enhanced Multi-Modal Trading System - Main Application

This is the main launcher for the sophisticated trading system featuring:
1. Enhanced orchestrator coordinating CNN and RL modules
2. Multi-timeframe, multi-symbol (ETH, BTC) trading decisions
3. Perfect move marking for CNN training with known outcomes
4. Continuous RL learning from trading action evaluations
5. Market environment adaptation and coordinated decision making
"""

import asyncio
import logging
import signal
import sys
from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional
import argparse

# Core components
from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator
from models import get_model_registry

# Training components
from training.enhanced_cnn_trainer import EnhancedCNNTrainer, EnhancedCNNModel
from training.enhanced_rl_trainer import EnhancedRLTrainer, EnhancedDQNAgent

# Utilities
import torch

# Setup logging (create the log directory first so the FileHandler does not
# fail when this module is imported before the __main__ guard runs)
Path('logs').mkdir(exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('logs/enhanced_trading.log')
    ]
)
logger = logging.getLogger(__name__)

class EnhancedTradingSystem:
    """Main enhanced trading system coordinator"""

    def __init__(self, config_path: Optional[str] = None):
        """Initialize the enhanced trading system"""
        self.config = get_config(config_path)

        # Initialize core components
        self.data_provider = DataProvider(self.config)
        self.orchestrator = EnhancedTradingOrchestrator(self.data_provider)

        # Initialize training components
        self.cnn_trainer = EnhancedCNNTrainer(self.config, self.orchestrator)
        self.rl_trainer = EnhancedRLTrainer(self.config, self.orchestrator)

        # Performance tracking
        self.performance_metrics = {
            'total_decisions': 0,
            'profitable_decisions': 0,
            'perfect_moves_marked': 0,
            'cnn_training_sessions': 0,
            'rl_training_steps': 0,
            'start_time': datetime.now()
        }

        # System state
        self.running = False
        self.tasks = []

        logger.info("Enhanced Trading System initialized")
        logger.info(f"Symbols: {self.config.symbols}")
        logger.info(f"Timeframes: {self.config.timeframes}")
        logger.info("LEARNING SYSTEMS ACTIVE:")
        logger.info("- RL agents learning from every trading decision")
        logger.info("- CNN training on perfect moves with known outcomes")
        logger.info("- Continuous pattern recognition and adaptation")

    async def start(self):
        """Start the enhanced trading system"""
        logger.info("Starting Enhanced Multi-Modal Trading System...")
        self.running = True

        try:
            # Start all system components
            trading_task = asyncio.create_task(self.start_trading_loop())
            training_tasks = await self.start_training_loops()
            monitoring_task = asyncio.create_task(self.start_monitoring_loop())

            # Store tasks for cleanup
            self.tasks = [trading_task, monitoring_task] + list(training_tasks)

            # Wait for all tasks
            await asyncio.gather(*self.tasks)

        except KeyboardInterrupt:
            logger.info("Shutdown signal received...")
            await self.shutdown()
        except Exception as e:
            logger.error(f"System error: {e}")
            await self.shutdown()

    async def start_trading_loop(self):
        """Start the main trading decision loop"""
        logger.info("Starting enhanced trading decision loop...")
        decision_count = 0

        while self.running:
            try:
                # Get coordinated decisions for all symbols
                decisions = await self.orchestrator.make_coordinated_decisions()

                for decision in decisions:
                    decision_count += 1
                    self.performance_metrics['total_decisions'] = decision_count

                    logger.info(f"DECISION #{decision_count}: {decision.action} {decision.symbol} "
                                f"@ ${decision.price:.2f} (Confidence: {decision.confidence:.1%})")

                    # Execute decision (this would connect to broker in live trading)
                    await self._execute_decision(decision)

                    # Add to RL evaluation queue for future learning
                    await self.orchestrator.queue_action_for_evaluation(decision)

                # Check for perfect moves to train CNN
                perfect_moves = self.orchestrator.get_recent_perfect_moves()
                if perfect_moves:
                    self.performance_metrics['perfect_moves_marked'] = len(perfect_moves)
                    logger.info(f"CNN LEARNING: {len(perfect_moves)} perfect moves identified for training")

                # Log performance metrics every 10 decisions
                if decision_count % 10 == 0 and decision_count > 0:
                    await self._log_performance_metrics()

                # Wait before next decision cycle
                await asyncio.sleep(self.orchestrator.decision_frequency)

            except Exception as e:
                logger.error(f"Error in trading loop: {e}")
                await asyncio.sleep(30)  # Wait 30 seconds on error

    async def start_training_loops(self):
        """Start continuous training loops"""
        logger.info("Starting continuous learning systems...")

        # Start RL continuous learning
        logger.info("STARTING RL CONTINUOUS LEARNING:")
        logger.info("- Learning from every trading decision outcome")
        logger.info("- Adapting to market regime changes")
        logger.info("- Prioritized experience replay")
        rl_task = asyncio.create_task(self.rl_trainer.continuous_learning_loop())

        # Start periodic CNN training
        logger.info("STARTING CNN PATTERN LEARNING:")
        logger.info("- Training on perfect moves with known outcomes")
        logger.info("- Multi-timeframe pattern recognition")
        logger.info("- Retrospective learning from market data")
        cnn_task = asyncio.create_task(self._periodic_cnn_training())

        return rl_task, cnn_task

    async def _periodic_cnn_training(self):
        """Periodically train CNN on perfect moves"""
        training_interval = self.config.training.get('cnn_training_interval', 21600)  # 6 hours
        min_perfect_moves = self.config.training.get('min_perfect_moves', 200)

        while self.running:
            try:
                # Check if we have enough perfect moves for training
                perfect_moves = self.orchestrator.get_perfect_moves_for_training()

                if len(perfect_moves) >= min_perfect_moves:
                    logger.info(f"CNN TRAINING: Starting with {len(perfect_moves)} perfect moves")

                    # Train CNN on perfect moves
                    training_results = self.cnn_trainer.train_on_perfect_moves(min_samples=min_perfect_moves)

                    if 'error' not in training_results:
                        self.performance_metrics['cnn_training_sessions'] += 1
                        logger.info(f"CNN TRAINING COMPLETED: Session #{self.performance_metrics['cnn_training_sessions']}")
                        logger.info(f"Training accuracy: {training_results.get('final_accuracy', 'N/A')}")
                        logger.info(f"Confidence accuracy: {training_results.get('confidence_accuracy', 'N/A')}")
                    else:
                        logger.warning(f"CNN training failed: {training_results['error']}")
                else:
                    logger.info(f"CNN WAITING: Need {min_perfect_moves - len(perfect_moves)} more perfect moves for training")

                # Wait for next training cycle
                await asyncio.sleep(training_interval)

            except Exception as e:
                logger.error(f"Error in CNN training loop: {e}")
                await asyncio.sleep(3600)  # Wait 1 hour on error

    async def start_monitoring_loop(self):
        """Monitor system performance and health"""
        while self.running:
            try:
                # Monitor memory usage
                if torch.cuda.is_available():
                    gpu_memory = torch.cuda.memory_allocated() / (1024**3)  # GB
                    logger.info(f"SYSTEM HEALTH: GPU Memory: {gpu_memory:.2f}GB")

                # Monitor model performance
                model_registry = get_model_registry()
                for model_name, model in model_registry.models.items():
                    if hasattr(model, 'get_memory_usage'):
                        memory_mb = model.get_memory_usage()
                        logger.info(f"MODEL MEMORY: {model_name}: {memory_mb}MB")

                # Monitor RL training progress
                for symbol, agent in self.rl_trainer.agents.items():
                    buffer_size = len(agent.replay_buffer)
                    epsilon = agent.epsilon
                    logger.info(f"RL AGENT {symbol}: Buffer={buffer_size}, Epsilon={epsilon:.3f}")

                await asyncio.sleep(300)  # Monitor every 5 minutes

            except Exception as e:
                logger.error(f"Error in monitoring loop: {e}")
                await asyncio.sleep(60)

    async def _execute_decision(self, decision):
        """Execute trading decision (placeholder for broker integration)"""
        # This is where we would connect to a real broker API
        # For now, we just log the decision
        logger.info(f"EXECUTING: {decision.action} {decision.symbol} @ ${decision.price:.2f}")

        # Simulate execution delay
        await asyncio.sleep(0.1)

        # Mark as profitable for demo (in real trading, this would be
        # determined by the actual outcome)
        if decision.confidence > 0.7:
            self.performance_metrics['profitable_decisions'] += 1

    async def _log_performance_metrics(self):
        """Log comprehensive performance metrics"""
        runtime = datetime.now() - self.performance_metrics['start_time']

        logger.info("PERFORMANCE METRICS:")
        logger.info(f"Runtime: {runtime}")
        logger.info(f"Total Decisions: {self.performance_metrics['total_decisions']}")
        logger.info(f"Profitable Decisions: {self.performance_metrics['profitable_decisions']}")
        logger.info(f"Perfect Moves Marked: {self.performance_metrics['perfect_moves_marked']}")
        logger.info(f"CNN Training Sessions: {self.performance_metrics['cnn_training_sessions']}")

        # Calculate success rate
        if self.performance_metrics['total_decisions'] > 0:
            success_rate = self.performance_metrics['profitable_decisions'] / self.performance_metrics['total_decisions']
            logger.info(f"Success Rate: {success_rate:.1%}")

    async def shutdown(self):
        """Gracefully shutdown the system"""
        logger.info("Shutting down Enhanced Trading System...")
        self.running = False

        # Cancel all tasks
        for task in self.tasks:
            if not task.done():
                task.cancel()

        # Save models
        try:
            self.cnn_trainer._save_model('shutdown_model.pt')
            self.rl_trainer._save_all_models()
            logger.info("Models saved successfully")
        except Exception as e:
            logger.error(f"Error saving models: {e}")

        # Final performance report
        await self._log_performance_metrics()
        logger.info("Enhanced Trading System shutdown complete")

async def main():
    """Main entry point"""
    parser = argparse.ArgumentParser(description='Enhanced Multi-Modal Trading System')
    parser.add_argument('--config', type=str, help='Path to configuration file')
    parser.add_argument('--symbols', nargs='+', default=['ETH/USDT', 'BTC/USDT'],
                        help='Trading symbols')
    parser.add_argument('--timeframes', nargs='+', default=['1s', '1m', '1h', '1d'],
                        help='Trading timeframes')

    args = parser.parse_args()

    # Create and start the enhanced trading system
    system = EnhancedTradingSystem(args.config)

    # Setup signal handlers for graceful shutdown
    def signal_handler(signum, frame):
        logger.info(f"Received signal {signum}")
        asyncio.create_task(system.shutdown())

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Start the system
    await system.start()

if __name__ == "__main__":
    # Ensure logs directory exists
    Path('logs').mkdir(exist_ok=True)

    # Run the enhanced trading system
    asyncio.run(main())
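For completeness, the launcher above could also be started programmatically; a short sketch (the module name is assumed from the notes, and the config path is a placeholder):

import asyncio
from main_clean import EnhancedTradingSystem  # module name assumed from the notes

async def run():
    system = EnhancedTradingSystem(config_path='config.yaml')  # path is an example
    await system.start()

asyncio.run(run())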
@@ -1,219 +0,0 @@
"""
CNN-RL Bridge Module

This module provides the interface between CNN models and RL training,
extracting hidden features and predictions from CNN models for use in RL state building.
"""

import logging
import numpy as np
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta

logger = logging.getLogger(__name__)

class CNNRLBridge:
    """Bridge between CNN models and RL training for feature extraction"""

    def __init__(self, config: Dict):
        """Initialize CNN-RL bridge"""
        self.config = config
        self.cnn_models = {}
        self.feature_cache = {}
        self.cache_timeout = 60  # Cache features for 60 seconds

        # Initialize CNN model registry if available
        self._initialize_cnn_models()

        logger.info("CNN-RL Bridge initialized")

    def _initialize_cnn_models(self):
        """Initialize CNN models from config or model registry"""
        try:
            # Try to load CNN models from config
            if hasattr(self.config, 'cnn_models') and self.config.cnn_models:
                for model_name, model_config in self.config.cnn_models.items():
                    try:
                        # Load CNN model (implementation would depend on your CNN architecture)
                        model = self._load_cnn_model(model_name, model_config)
                        if model:
                            self.cnn_models[model_name] = model
                            logger.info(f"Loaded CNN model: {model_name}")
                    except Exception as e:
                        logger.warning(f"Failed to load CNN model {model_name}: {e}")

            if not self.cnn_models:
                logger.info("No CNN models available - RL will train without CNN features")

        except Exception as e:
            logger.warning(f"Error initializing CNN models: {e}")

    def _load_cnn_model(self, model_name: str, model_config: Dict) -> Optional[nn.Module]:
        """Load a CNN model from configuration"""
        try:
            # This would implement actual CNN model loading
            # For now, return None to indicate no models available
            # In your implementation, this would load your specific CNN architecture

            logger.info(f"CNN model loading framework ready for {model_name}")
            return None

        except Exception as e:
            logger.error(f"Error loading CNN model {model_name}: {e}")
            return None

    def get_latest_features_for_symbol(self, symbol: str) -> Optional[Dict[str, Any]]:
        """Get latest CNN features and predictions for a symbol"""
        try:
            # Check cache first
            cache_key = f"{symbol}_{datetime.now().strftime('%Y%m%d_%H%M')}"
            if cache_key in self.feature_cache:
                cached_data = self.feature_cache[cache_key]
                # Use total_seconds() rather than .seconds so the age is
                # computed correctly regardless of the timedelta's sign or size
                if (datetime.now() - cached_data['timestamp']).total_seconds() < self.cache_timeout:
                    return cached_data['features']

            # Generate new features if models available
            if self.cnn_models:
                features = self._extract_cnn_features_for_symbol(symbol)

                # Cache the features
                self.feature_cache[cache_key] = {
                    'timestamp': datetime.now(),
                    'features': features
                }

                # Clean old cache entries
                self._cleanup_cache()

                return features

            return None

        except Exception as e:
            logger.warning(f"Error getting CNN features for {symbol}: {e}")
            return None

    def _extract_cnn_features_for_symbol(self, symbol: str) -> Dict[str, Any]:
        """Extract CNN hidden features and predictions for a symbol"""
        try:
            extracted_features = {
                'hidden_features': {},
                'predictions': {}
            }

            for model_name, model in self.cnn_models.items():
                try:
                    # Extract features from each CNN model
                    hidden_features, predictions = self._extract_model_features(model, symbol)

                    if hidden_features is not None:
                        extracted_features['hidden_features'][model_name] = hidden_features

                    if predictions is not None:
                        extracted_features['predictions'][model_name] = predictions

                except Exception as e:
                    logger.warning(f"Error extracting features from {model_name}: {e}")

            return extracted_features

        except Exception as e:
            logger.error(f"Error extracting CNN features for {symbol}: {e}")
            return {'hidden_features': {}, 'predictions': {}}

    def _extract_model_features(self, model: nn.Module, symbol: str) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Extract hidden features and predictions from a specific CNN model"""
        try:
            # This would implement the actual feature extraction from your CNN models
            # The implementation depends on your specific CNN architecture

            # For now, return mock data to show the structure
            # In real implementation, this would:
            # 1. Get market data for the model
            # 2. Run forward pass through CNN
            # 3. Extract hidden layer activations
            # 4. Get model predictions

            # Mock hidden features (last hidden layer of CNN)
            hidden_features = np.random.random(512).astype(np.float32)

            # Mock predictions for different timeframes
            # [1s_pred, 1m_pred, 1h_pred, 1d_pred] for each timeframe
            predictions = np.array([
                0.45,  # 1s prediction (probability of up move)
                0.52,  # 1m prediction
                0.38,  # 1h prediction
                0.61   # 1d prediction
            ]).astype(np.float32)

            logger.debug(f"Extracted CNN features for {symbol}: {len(hidden_features)} hidden, {len(predictions)} predictions")

            return hidden_features, predictions

        except Exception as e:
            logger.warning(f"Error extracting features from model: {e}")
            return None, None

    def _cleanup_cache(self):
        """Clean up old cache entries"""
        try:
            current_time = datetime.now()
            expired_keys = []

            for key, data in self.feature_cache.items():
                if (current_time - data['timestamp']).total_seconds() > self.cache_timeout * 2:
                    expired_keys.append(key)

            for key in expired_keys:
                del self.feature_cache[key]

        except Exception as e:
            logger.warning(f"Error cleaning up feature cache: {e}")

    def register_cnn_model(self, model_name: str, model: nn.Module):
        """Register a CNN model for feature extraction"""
        try:
            self.cnn_models[model_name] = model
            logger.info(f"Registered CNN model: {model_name}")
        except Exception as e:
            logger.error(f"Error registering CNN model {model_name}: {e}")

    def unregister_cnn_model(self, model_name: str):
        """Unregister a CNN model"""
        try:
            if model_name in self.cnn_models:
                del self.cnn_models[model_name]
                logger.info(f"Unregistered CNN model: {model_name}")
        except Exception as e:
            logger.error(f"Error unregistering CNN model {model_name}: {e}")

    def get_available_models(self) -> List[str]:
        """Get list of available CNN models"""
        return list(self.cnn_models.keys())

    def is_model_available(self, model_name: str) -> bool:
        """Check if a specific CNN model is available"""
        return model_name in self.cnn_models

    def get_feature_dimensions(self) -> Dict[str, int]:
        """Get the dimensions of features extracted from CNN models"""
        return {
            'hidden_features_per_model': 512,
            'predictions_per_model': 4,  # 1s, 1m, 1h, 1d
            'total_models': len(self.cnn_models)
        }

    def validate_cnn_integration(self) -> Dict[str, Any]:
        """Validate CNN integration status"""
        status = {
            'models_available': len(self.cnn_models),
            'models_list': list(self.cnn_models.keys()),
            'cache_entries': len(self.feature_cache),
            'integration_ready': len(self.cnn_models) > 0,
            'expected_feature_size': len(self.cnn_models) * 512,  # hidden features
            'expected_prediction_size': len(self.cnn_models) * 4  # predictions
        }

        return status
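A short usage example for the bridge above, using only its own API (the registered model is a stand-in, since _load_cnn_model returns None by design):

import torch.nn as nn

# Register a placeholder module and read back mock features.
bridge = CNNRLBridge(config={})
bridge.register_cnn_model('scalping_cnn', nn.Identity())  # stand-in for a real CNN

features = bridge.get_latest_features_for_symbol('ETH/USDC')
if features:
    state_extras = features['hidden_features']['scalping_cnn']  # 512-dim mock vector
print(bridge.validate_cnn_integration())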
@@ -1,491 +0,0 @@
"""
CNN Training Pipeline

This module handles training of the CNN model using ONLY real market data.
All training metrics are logged to TensorBoard for real-time monitoring.
"""

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import pandas as pd
import logging
from typing import Dict, List, Tuple, Optional
from pathlib import Path
import time
from sklearn.metrics import classification_report, confusion_matrix
import json

from core.config import get_config
from core.data_provider import DataProvider
from models.cnn.scalping_cnn import MultiTimeframeCNN, ScalpingDataGenerator

logger = logging.getLogger(__name__)

class CNNDataset(Dataset):
    """Dataset for CNN training with real market data"""

    def __init__(self, features: np.ndarray, labels: np.ndarray):
        self.features = torch.FloatTensor(features)
        self.labels = torch.LongTensor(np.argmax(labels, axis=1))  # Convert one-hot to class indices

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

class CNNTrainer:
    """CNN Trainer using ONLY real market data with TensorBoard monitoring"""

    def __init__(self, config: Optional[Dict] = None):
        """Initialize CNN trainer"""
        self.config = config or get_config()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training parameters
        self.learning_rate = self.config.training.get('learning_rate', 0.001)
        self.batch_size = self.config.training.get('batch_size', 32)
        self.epochs = self.config.training.get('epochs', 100)
        self.validation_split = self.config.training.get('validation_split', 0.2)
        self.early_stopping_patience = self.config.training.get('early_stopping_patience', 10)

        # Model parameters - will be updated based on real data
        self.n_timeframes = len(self.config.timeframes)
        self.window_size = self.config.cnn.get('window_size', 20)
        self.n_features = self.config.cnn.get('features', 26)  # Will be dynamically updated
        self.n_classes = 3  # BUY, SELL, HOLD

        # Initialize components
        self.data_provider = DataProvider(self.config)
        self.data_generator = ScalpingDataGenerator(self.data_provider, self.window_size)
        self.model = None

        # TensorBoard setup
        self.setup_tensorboard()

        logger.info(f"CNNTrainer initialized with {self.n_timeframes} timeframes, {self.n_features} features")
        logger.info("Will use ONLY real market data for training")

    def setup_tensorboard(self):
        """Setup TensorBoard logging"""
        # Create tensorboard logs directory
        log_dir = Path("runs") / f"cnn_training_{int(time.time())}"
        log_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(log_dir))
        self.tensorboard_dir = log_dir

        logger.info(f"TensorBoard logging to: {log_dir}")
        logger.info("Run: tensorboard --logdir=runs")

    def log_model_architecture(self):
        """Log model architecture to TensorBoard"""
        if self.model is not None:
            # Log model graph (requires a dummy input)
            dummy_input = torch.randn(1, self.n_timeframes, self.window_size, self.n_features).to(self.device)
            try:
                self.writer.add_graph(self.model, dummy_input)
                logger.info("Model architecture logged to TensorBoard")
            except Exception as e:
                logger.warning(f"Could not log model graph: {e}")

            # Log model parameters count
            total_params = sum(p.numel() for p in self.model.parameters())
            trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

            self.writer.add_scalar('Model/TotalParameters', total_params, 0)
            self.writer.add_scalar('Model/TrainableParameters', trainable_params, 0)

    def create_model(self) -> MultiTimeframeCNN:
        """Create CNN model"""
        model = MultiTimeframeCNN(
            n_timeframes=self.n_timeframes,
            window_size=self.window_size,
            n_features=self.n_features,
            n_classes=self.n_classes,
            dropout_rate=self.config.cnn.get('dropout', 0.2)
        )

        model = model.to(self.device)

        # Log model info
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        memory_usage = model.get_memory_usage()

        logger.info(f"Model created with {total_params:,} total parameters")
        logger.info(f"Trainable parameters: {trainable_params:,}")
        logger.info(f"Estimated memory usage: {memory_usage}MB")

        return model

    def prepare_data(self, symbols: List[str], num_samples: int = 10000) -> Tuple[np.ndarray, np.ndarray, Dict]:
        """Prepare training data from REAL market data"""
        logger.info("Preparing training data...")
        logger.info("Data source: REAL market data from exchange APIs")

        all_features = []
        all_labels = []
        all_metadata = []

        for symbol in symbols:
            logger.info(f"Generating data for {symbol}...")

            features, labels, metadata = self.data_generator.generate_training_cases(
                symbol=symbol,
                timeframes=self.config.timeframes,
                num_samples=num_samples
            )

            if features is not None:
                all_features.append(features)
                all_labels.append(labels)
                all_metadata.append(metadata)

                logger.info(f"Generated {len(features)} samples for {symbol}")

                # Update feature count if needed
                actual_features = features.shape[-1]
                if actual_features != self.n_features:
                    logger.info(f"Updating feature count from {self.n_features} to {actual_features}")
                    self.n_features = actual_features

        if not all_features:
            raise ValueError("No training data generated from real market data")

        # Combine all data
        features = np.concatenate(all_features, axis=0)
        labels = np.concatenate(all_labels, axis=0)

        # Log data statistics to TensorBoard
        self.log_data_statistics(features, labels)

        return features, labels, all_metadata

    def log_data_statistics(self, features: np.ndarray, labels: np.ndarray):
        """Log data statistics to TensorBoard"""
        # Dataset size
        self.writer.add_scalar('Data/TotalSamples', len(features), 0)
        self.writer.add_scalar('Data/Features', features.shape[-1], 0)
        self.writer.add_scalar('Data/Timeframes', features.shape[1], 0)
        self.writer.add_scalar('Data/WindowSize', features.shape[2], 0)

        # Class distribution
        class_counts = np.bincount(np.argmax(labels, axis=1))
        for i, count in enumerate(class_counts):
            self.writer.add_scalar(f'Data/Class_{i}_Count', count, 0)

        # Feature statistics
        feature_means = features.mean(axis=(0, 1, 2))
        feature_stds = features.std(axis=(0, 1, 2))

        for i in range(min(10, len(feature_means))):  # Log first 10 features
            self.writer.add_scalar(f'Data/Feature_{i}_Mean', feature_means[i], 0)
            self.writer.add_scalar(f'Data/Feature_{i}_Std', feature_stds[i], 0)

    def train_epoch(self, model: nn.Module, train_loader: DataLoader,
                    optimizer: torch.optim.Optimizer, criterion: nn.Module, epoch: int) -> Tuple[float, float]:
        """Train for one epoch with TensorBoard logging"""
        model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (features, labels) in enumerate(train_loader):
            features, labels = features.to(self.device), labels.to(self.device)

            optimizer.zero_grad()
            predictions = model(features)
            loss = criterion(predictions['action'], labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(predictions['action'].data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Log batch metrics
            step = epoch * len(train_loader) + batch_idx
            self.writer.add_scalar('Training/BatchLoss', loss.item(), step)

            if batch_idx % 50 == 0:  # Log every 50 batches
                batch_acc = 100. * (predicted == labels).sum().item() / labels.size(0)
                self.writer.add_scalar('Training/BatchAccuracy', batch_acc, step)

                # Log confidence scores
                avg_confidence = predictions['confidence'].mean().item()
                self.writer.add_scalar('Training/BatchConfidence', avg_confidence, step)

        epoch_loss = total_loss / len(train_loader)
        epoch_accuracy = correct / total

        return epoch_loss, epoch_accuracy

    def validate_epoch(self, model: nn.Module, val_loader: DataLoader,
                       criterion: nn.Module, epoch: int) -> Tuple[float, float, Dict]:
        """Validate for one epoch with TensorBoard logging"""
        model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        all_predictions = []
        all_labels = []
        all_confidences = []

        with torch.no_grad():
            for features, labels in val_loader:
                features, labels = features.to(self.device), labels.to(self.device)

                predictions = model(features)
                loss = criterion(predictions['action'], labels)

                total_loss += loss.item()
                _, predicted = torch.max(predictions['action'].data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_confidences.extend(predictions['confidence'].cpu().numpy())

        epoch_loss = total_loss / len(val_loader)
        epoch_accuracy = correct / total

        # Calculate detailed metrics
        metrics = self.calculate_detailed_metrics(all_predictions, all_labels, all_confidences)

        # Log validation metrics to TensorBoard
        self.writer.add_scalar('Validation/Loss', epoch_loss, epoch)
        self.writer.add_scalar('Validation/Accuracy', epoch_accuracy, epoch)
        self.writer.add_scalar('Validation/AvgConfidence', metrics['avg_confidence'], epoch)

        for class_idx, acc in metrics['class_accuracies'].items():
            self.writer.add_scalar(f'Validation/Class_{class_idx}_Accuracy', acc, epoch)

        return epoch_loss, epoch_accuracy, metrics

    def calculate_detailed_metrics(self, predictions: List, labels: List, confidences: List) -> Dict:
        """Calculate detailed training metrics"""
        predictions = np.array(predictions)
        labels = np.array(labels)
        confidences = np.array(confidences)

        # Class-wise accuracies
        class_accuracies = {}
        for class_idx in range(self.n_classes):
            class_mask = labels == class_idx
            if class_mask.sum() > 0:
                class_acc = (predictions[class_mask] == labels[class_mask]).mean()
                class_accuracies[class_idx] = class_acc

        return {
            'class_accuracies': class_accuracies,
            'avg_confidence': confidences.mean(),
            'confusion_matrix': confusion_matrix(labels, predictions)
        }

    def train(self, symbols: List[str], save_path: str = 'models/cnn/scalping_cnn_trained.pt',
              num_samples: int = 10000) -> Dict:
        """Train CNN model with TensorBoard monitoring"""
        logger.info("Starting CNN training...")
        logger.info("Using ONLY real market data from exchange APIs")

        # Prepare data
        features, labels, metadata = self.prepare_data(symbols, num_samples)

        # Log training configuration
        self.writer.add_text('Config/Symbols', str(symbols), 0)
        self.writer.add_text('Config/Timeframes', str(self.config.timeframes), 0)
        self.writer.add_scalar('Config/LearningRate', self.learning_rate, 0)
        self.writer.add_scalar('Config/BatchSize', self.batch_size, 0)
        self.writer.add_scalar('Config/MaxEpochs', self.epochs, 0)

        # Create datasets
        dataset = CNNDataset(features, labels)

        # Split data
        val_size = int(len(dataset) * self.validation_split)
        train_size = len(dataset) - val_size
        train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

        logger.info(f"Total dataset: {len(dataset)} samples")
        logger.info(f"Features shape: {features.shape}")
        logger.info(f"Labels shape: {labels.shape}")
        logger.info(f"Train samples: {train_size}")
        logger.info(f"Validation samples: {val_size}")

        # Log class distributions
        train_labels = [dataset[i][1].item() for i in train_dataset.indices]
        val_labels = [dataset[i][1].item() for i in val_dataset.indices]

        logger.info(f"Train label distribution: {np.bincount(train_labels)}")
        logger.info(f"Val label distribution: {np.bincount(val_labels)}")

        # Create model
        self.model = self.create_model()
        self.log_model_architecture()

        # Setup training
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=5, verbose=True)

        # Training loop
        best_val_loss = float('inf')
        best_val_accuracy = 0.0
        patience_counter = 0
        start_time = time.time()

        for epoch in range(self.epochs):
            epoch_start = time.time()

            # Train
            train_loss, train_accuracy = self.train_epoch(self.model, train_loader, optimizer, criterion, epoch)

            # Validate
            val_loss, val_accuracy, val_metrics = self.validate_epoch(self.model, val_loader, criterion, epoch)

            # Update learning rate
            scheduler.step(val_loss)
            current_lr = optimizer.param_groups[0]['lr']

            # Log epoch metrics
            self.writer.add_scalar('Training/EpochLoss', train_loss, epoch)
            self.writer.add_scalar('Training/EpochAccuracy', train_accuracy, epoch)
            self.writer.add_scalar('Training/LearningRate', current_lr, epoch)

            epoch_time = time.time() - epoch_start
            self.writer.add_scalar('Training/EpochTime', epoch_time, epoch)

            # Save best model
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_accuracy = val_accuracy
                patience_counter = 0

                best_path = save_path.replace('.pt', '_best.pt')
                self.model.save(best_path)
                logger.info(f"New best model saved: {best_path}")

                # Log best metrics
                self.writer.add_scalar('Best/ValidationLoss', best_val_loss, epoch)
                self.writer.add_scalar('Best/ValidationAccuracy', best_val_accuracy, epoch)
            else:
                patience_counter += 1

            logger.info(f"Epoch {epoch+1}/{self.epochs} - "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f} - "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f} - "
                        f"Time: {epoch_time:.2f}s")

            # Log detailed metrics every 10 epochs
            if (epoch + 1) % 10 == 0:
                logger.info(f"Class accuracies: {val_metrics['class_accuracies']}")
                logger.info(f"Average confidence: {val_metrics['avg_confidence']:.4f}")

            # Early stopping
            if patience_counter >= self.early_stopping_patience:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

        # Training completed
        total_time = time.time() - start_time
        logger.info(f"Training completed in {total_time:.2f} seconds")
        logger.info(f"Best validation loss: {best_val_loss:.4f}")
        logger.info(f"Best validation accuracy: {best_val_accuracy:.4f}")

        # Log final metrics
        self.writer.add_scalar('Final/TotalTrainingTime', total_time, 0)
        self.writer.add_scalar('Final/TotalEpochs', epoch + 1, 0)

        # Save final model
        self.model.save(save_path)
        logger.info(f"Final model saved: {save_path}")

        # Log training summary
        self.writer.add_text('Training/Summary',
                             f"Completed training with {len(features)} real market samples. "
                             f"Best validation accuracy: {best_val_accuracy:.4f}", 0)

        return {
            'best_val_loss': best_val_loss,
            'best_val_accuracy': best_val_accuracy,
            'total_epochs': epoch + 1,
            'training_time': total_time,
            'tensorboard_dir': str(self.tensorboard_dir)
        }

    def evaluate(self, symbols: List[str], num_samples: int = 5000) -> Dict:
        """Evaluate trained model on test data"""
        if self.model is None:
            raise ValueError("Model not trained yet")

        logger.info("Evaluating model...")

        # Generate test data from real market data
        features, labels, metadata = self.prepare_data(symbols, num_samples)

        # Create test dataset and loader
        test_dataset = CNNDataset(features, labels)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        # Evaluate
        criterion = nn.CrossEntropyLoss()
        test_loss, test_accuracy, test_metrics = self.validate_epoch(
            self.model, test_loader, criterion, epoch=0
        )

        # Generate detailed classification report
        class_names = ['BUY', 'SELL', 'HOLD']
        all_predictions = []
        all_labels = []

        with torch.no_grad():
            for features_batch, labels_batch in test_loader:
                features_batch = features_batch.to(self.device)
                predictions = self.model(features_batch)
                _, predicted = torch.max(predictions['action'].data, 1)
                all_predictions.extend(predicted.cpu().numpy())
                all_labels.extend(labels_batch.numpy())

        classification_rep = classification_report(
            all_labels, all_predictions, target_names=class_names, output_dict=True
        )

        evaluation_results = {
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            'classification_report': classification_rep,
            'class_accuracies': test_metrics['class_accuracies'],
            'avg_confidence': test_metrics['avg_confidence'],
            'confusion_matrix': test_metrics['confusion_matrix']
        }

        logger.info(f"Test accuracy: {test_accuracy:.4f}")
        logger.info(f"Test loss: {test_loss:.4f}")

        return evaluation_results

    def close_tensorboard(self):
        """Close TensorBoard writer"""
        if hasattr(self, 'writer'):
            self.writer.close()
            logger.info("TensorBoard writer closed")

    def __del__(self):
        """Cleanup"""
        self.close_tensorboard()

# Export
__all__ = ['CNNTrainer', 'CNNDataset']
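Typical use of the trainer above, within this repo (the symbol and sample counts are just examples; the defaults come from the class itself):

# Train, inspect, and evaluate using the CNNTrainer API defined above.
trainer = CNNTrainer()
results = trainer.train(['ETH/USDT'], num_samples=10000)
print(results['best_val_accuracy'], results['tensorboard_dir'])
evaluation = trainer.evaluate(['ETH/USDT'], num_samples=5000)
trainer.close_tensorboard()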
@ -1,811 +0,0 @@
|
||||
"""
|
||||
Enhanced CNN Trainer with Perfect Move Learning
|
||||
|
||||
This trainer implements:
|
||||
1. Training on marked perfect moves with known outcomes
|
||||
2. Multi-timeframe CNN model training with confidence scoring
|
||||
3. Backpropagation on optimal moves when future outcomes are known
|
||||
4. Progressive learning from real trading experience
|
||||
5. Symbol-specific and timeframe-specific model fine-tuning
|
||||
"""
|
||||
|
||||
import logging
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
from core.config import get_config
|
||||
from core.data_provider import DataProvider
|
||||
from core.enhanced_orchestrator import PerfectMove, EnhancedTradingOrchestrator
|
||||
from models import CNNModelInterface
|
||||
import models
|
||||
|
||||
logger = logging.getLogger(__name__)

class PerfectMoveDataset(Dataset):
    """Dataset for training on perfect moves with known outcomes"""

    def __init__(self, perfect_moves: List[PerfectMove], data_provider: DataProvider):
        """
        Initialize dataset from perfect moves

        Args:
            perfect_moves: List of perfect moves with known outcomes
            data_provider: Data provider to fetch additional context
        """
        self.perfect_moves = perfect_moves
        self.data_provider = data_provider
        self.samples = []
        self._prepare_samples()

    def _prepare_samples(self):
        """Prepare training samples from perfect moves"""
        logger.info(f"Preparing {len(self.perfect_moves)} perfect move samples")

        for move in self.perfect_moves:
            try:
                # Get feature matrix at the time of the decision
                feature_matrix = self.data_provider.get_feature_matrix(
                    symbol=move.symbol,
                    timeframes=[move.timeframe],
                    window_size=20,
                    end_time=move.timestamp
                )

                if feature_matrix is not None:
                    # Convert optimal action to label
                    action_to_label = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
                    label = action_to_label.get(move.optimal_action, 1)

                    # Create confidence target (what confidence should have been)
                    confidence_target = move.confidence_should_have_been

                    sample = {
                        'features': feature_matrix,
                        'action_label': label,
                        'confidence_target': confidence_target,
                        'symbol': move.symbol,
                        'timeframe': move.timeframe,
                        'outcome': move.actual_outcome,
                        'timestamp': move.timestamp
                    }
                    self.samples.append(sample)

            except Exception as e:
                logger.warning(f"Error preparing sample for perfect move: {e}")

        logger.info(f"Prepared {len(self.samples)} valid training samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]

        # Convert to tensors
        features = torch.FloatTensor(sample['features'])
        action_label = torch.LongTensor([sample['action_label']])
        confidence_target = torch.FloatTensor([sample['confidence_target']])

        return {
            'features': features,
            'action_label': action_label,
            'confidence_target': confidence_target,
            'metadata': {
                'symbol': sample['symbol'],
                'timeframe': sample['timeframe'],
                'outcome': sample['outcome'],
                'timestamp': sample['timestamp']
            }
        }
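
# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Shows how PerfectMoveDataset is meant to be consumed; assumes a populated
# list of PerfectMove records and a configured DataProvider.
def _example_perfect_move_loader(perfect_moves, data_provider):
    """Inspect one sample, then wrap the dataset in a DataLoader."""
    dataset = PerfectMoveDataset(perfect_moves, data_provider)
    sample = dataset[0]
    print(sample['features'].shape, sample['metadata']['symbol'])
    # NOTE: 'metadata' carries datetime objects, which the default collate
    # function cannot batch; a custom collate_fn may be needed in practice.
    return DataLoader(dataset, batch_size=32, shuffle=True)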

class EnhancedCNNModel(nn.Module, CNNModelInterface):
    """Enhanced CNN model with timeframe-specific predictions and confidence scoring"""

    def __init__(self, config: Dict[str, Any]):
        nn.Module.__init__(self)
        CNNModelInterface.__init__(self, config)

        self.timeframes = config.get('timeframes', ['1h', '4h', '1d'])
        self.n_features = len(config.get('features', ['open', 'high', 'low', 'close', 'volume']))
        self.window_size = config.get('window_size', 20)

        # Build the neural network
        self._build_network()

        # Initialize device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)

        # Training components
        self.optimizer = optim.Adam(self.parameters(), lr=config.get('learning_rate', 0.001))
        self.action_criterion = nn.CrossEntropyLoss()
        self.confidence_criterion = nn.MSELoss()

        logger.info(f"Enhanced CNN model initialized for {len(self.timeframes)} timeframes")

    def _build_network(self):
        """Build the CNN architecture"""
        # Convolutional feature extraction
        self.conv_layers = nn.Sequential(
            # First conv block
            nn.Conv1d(self.n_features, 64, kernel_size=3, padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Second conv block
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Third conv block
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),

            # Global average pooling
            nn.AdaptiveAvgPool1d(1)
        )

        # Timeframe-specific heads
        self.timeframe_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.timeframe_heads[timeframe] = nn.Sequential(
                nn.Linear(256, 128),
                nn.ReLU(),
                nn.Dropout(0.3),
                nn.Linear(128, 64),
                nn.ReLU(),
                nn.Dropout(0.3)
            )

        # Action prediction heads (one per timeframe)
        self.action_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.action_heads[timeframe] = nn.Linear(64, 3)  # BUY, HOLD, SELL

        # Confidence prediction heads (one per timeframe)
        self.confidence_heads = nn.ModuleDict()
        for timeframe in self.timeframes:
            self.confidence_heads[timeframe] = nn.Sequential(
                nn.Linear(64, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
                nn.Sigmoid()  # Output between 0 and 1
            )

    def forward(self, x, timeframe: str = None):
        """
        Forward pass through the network

        Args:
            x: Input tensor [batch_size, window_size, features]
            timeframe: Specific timeframe to predict for

        Returns:
            action_probs: Action probabilities
            confidence: Confidence score
        """
        # Reshape for conv1d: [batch, features, sequence]
        x = x.transpose(1, 2)

        # Extract features
        features = self.conv_layers(x)  # [batch, 256, 1]
        features = features.squeeze(-1)  # [batch, 256]

        if timeframe and timeframe in self.timeframe_heads:
            # Timeframe-specific prediction
            tf_features = self.timeframe_heads[timeframe](features)
            action_logits = self.action_heads[timeframe](tf_features)
            confidence = self.confidence_heads[timeframe](tf_features)

            action_probs = torch.softmax(action_logits, dim=1)
            return action_probs, confidence.squeeze(-1)
        else:
            # Multi-timeframe prediction (average across timeframes)
            all_action_probs = []
            all_confidences = []

            for tf in self.timeframes:
                tf_features = self.timeframe_heads[tf](features)
                action_logits = self.action_heads[tf](tf_features)
                confidence = self.confidence_heads[tf](tf_features)

                action_probs = torch.softmax(action_logits, dim=1)
                all_action_probs.append(action_probs)
                all_confidences.append(confidence.squeeze(-1))

            # Average predictions across timeframes
            avg_action_probs = torch.stack(all_action_probs).mean(dim=0)
            avg_confidence = torch.stack(all_confidences).mean(dim=0)

            return avg_action_probs, avg_confidence

    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Predict action probabilities and confidence"""
        self.eval()
        with torch.no_grad():
            x = torch.FloatTensor(features).to(self.device)
            if len(x.shape) == 2:
                x = x.unsqueeze(0)  # Add batch dimension

            action_probs, confidence = self.forward(x)

            return action_probs[0].cpu().numpy(), confidence[0].cpu().item()

    def predict_timeframe(self, features: np.ndarray, timeframe: str) -> Tuple[np.ndarray, float]:
        """Predict for specific timeframe"""
        self.eval()
        with torch.no_grad():
            x = torch.FloatTensor(features).to(self.device)
            if len(x.shape) == 2:
                x = x.unsqueeze(0)  # Add batch dimension

            action_probs, confidence = self.forward(x, timeframe)

            return action_probs[0].cpu().numpy(), confidence[0].cpu().item()

    def get_memory_usage(self) -> int:
        """Get memory usage in MB"""
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
        else:
            # Rough estimate for CPU
            param_count = sum(p.numel() for p in self.parameters())
            return (param_count * 4) // (1024 * 1024)  # 4 bytes per float32

    def train(self, mode_or_data: Any = True):
        """Train hook kept for CNNModelInterface compatibility.

        NOTE (fix): the original signature `train(self, training_data)` shadowed
        nn.Module.train(), so internal calls like `model.train()` and
        `model.eval()` (which calls `train(False)`) would raise. Booleans are
        therefore routed to nn.Module.train(); anything else is treated as
        interface training data.
        """
        if isinstance(mode_or_data, bool):
            return nn.Module.train(self, mode_or_data)
        return {}  # placeholder: real training happens in EnhancedCNNTrainer
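
# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Exercises predict_timeframe() with random inputs. The config keys mirror the
# .get() defaults in __init__ above; whether CNNModelInterface accepts a plain
# dict is an assumption here.
def _example_model_prediction():
    """Run a single-timeframe prediction on random features."""
    config = {
        'timeframes': ['1h', '4h', '1d'],
        'features': ['open', 'high', 'low', 'close', 'volume'],
        'window_size': 20,
        'learning_rate': 0.001,
    }
    model = EnhancedCNNModel(config)
    features = np.random.randn(20, 5).astype(np.float32)  # [window, features]
    action_probs, confidence = model.predict_timeframe(features, '1h')
    print(f"action probs: {action_probs}, confidence: {confidence:.3f}")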

class EnhancedCNNTrainer:
    """Enhanced CNN trainer using perfect moves and real market outcomes"""

    def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
        """Initialize the enhanced trainer"""
        self.config = config or get_config()
        self.orchestrator = orchestrator
        self.data_provider = DataProvider(self.config)

        # Training parameters
        self.learning_rate = self.config.training.get('learning_rate', 0.001)
        self.batch_size = self.config.training.get('batch_size', 32)
        self.epochs = self.config.training.get('epochs', 100)
        self.patience = self.config.training.get('early_stopping_patience', 10)

        # Model
        self.model = EnhancedCNNModel(self.config.cnn)

        # Training history
        self.training_history = {
            'train_loss': [],
            'val_loss': [],
            'train_accuracy': [],
            'val_accuracy': [],
            'confidence_accuracy': []
        }

        # Create save directory
        models_path = self.config.cnn.get('model_dir', "models/enhanced_cnn")
        self.save_dir = Path(models_path)
        self.save_dir.mkdir(parents=True, exist_ok=True)

        logger.info("Enhanced CNN trainer initialized")

    def train_on_perfect_moves(self, min_samples: int = 100) -> Dict[str, Any]:
        """Train the model on perfect moves from the orchestrator"""
        if not self.orchestrator:
            raise ValueError("Orchestrator required for perfect move training")

        # Get perfect moves from orchestrator
        perfect_moves = []
        for symbol in self.config.symbols:
            symbol_moves = self.orchestrator.get_perfect_moves_for_training(symbol=symbol)
            perfect_moves.extend(symbol_moves)

        if len(perfect_moves) < min_samples:
            logger.warning(f"Not enough perfect moves for training: {len(perfect_moves)} < {min_samples}")
            return {'error': 'insufficient_data', 'samples': len(perfect_moves)}

        logger.info(f"Training on {len(perfect_moves)} perfect moves")

        # Create dataset
        dataset = PerfectMoveDataset(perfect_moves, self.data_provider)

        # Split into train/validation
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

        # Create data loaders
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

        # Training loop
        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(self.epochs):
            # Training phase
            train_loss, train_acc = self._train_epoch(train_loader)

            # Validation phase
            val_loss, val_acc, conf_acc = self._validate_epoch(val_loader)

            # Update history
            self.training_history['train_loss'].append(train_loss)
            self.training_history['val_loss'].append(val_loss)
            self.training_history['train_accuracy'].append(train_acc)
            self.training_history['val_accuracy'].append(val_acc)
            self.training_history['confidence_accuracy'].append(conf_acc)

            # Log progress
            logger.info(f"Epoch {epoch+1}/{self.epochs}: "
                        f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                        f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                        f"Conf Acc: {conf_acc:.4f}")

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                self._save_model('best_model.pt')
            else:
                patience_counter += 1
                if patience_counter >= self.patience:
                    logger.info(f"Early stopping at epoch {epoch+1}")
                    break

        # Save final model
        self._save_model('final_model.pt')

        # Generate training report
        return self._generate_training_report()

    def _train_epoch(self, train_loader: DataLoader) -> Tuple[float, float]:
        """Train for one epoch"""
        self.model.train()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0

        for batch in train_loader:
            features = batch['features'].to(self.model.device)
            action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
            confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

            # Zero gradients
            self.model.optimizer.zero_grad()

            # Forward pass
            action_probs, confidence_pred = self.model(features)

            # Calculate losses
            # NOTE: forward() returns softmax probabilities while CrossEntropyLoss
            # expects raw logits, so log-softmax is effectively applied twice here.
            # Kept as in the original, but worth flagging.
            action_loss = self.model.action_criterion(action_probs, action_labels)
            confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)

            # Combined loss
            total_loss_batch = action_loss + 0.5 * confidence_loss

            # Backward pass
            total_loss_batch.backward()
            self.model.optimizer.step()

            # Track metrics
            total_loss += total_loss_batch.item()
            predicted_actions = torch.argmax(action_probs, dim=1)
            correct_predictions += (predicted_actions == action_labels).sum().item()
            total_predictions += action_labels.size(0)

        avg_loss = total_loss / len(train_loader)
        accuracy = correct_predictions / total_predictions

        return avg_loss, accuracy

    def _validate_epoch(self, val_loader: DataLoader) -> Tuple[float, float, float]:
        """Validate for one epoch"""
        self.model.eval()
        total_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        confidence_errors = []

        with torch.no_grad():
            for batch in val_loader:
                features = batch['features'].to(self.model.device)
                action_labels = batch['action_label'].to(self.model.device).squeeze(-1)
                confidence_targets = batch['confidence_target'].to(self.model.device).squeeze(-1)

                # Forward pass
                action_probs, confidence_pred = self.model(features)

                # Calculate losses
                action_loss = self.model.action_criterion(action_probs, action_labels)
                confidence_loss = self.model.confidence_criterion(confidence_pred, confidence_targets)
                total_loss_batch = action_loss + 0.5 * confidence_loss

                # Track metrics
                total_loss += total_loss_batch.item()
                predicted_actions = torch.argmax(action_probs, dim=1)
                correct_predictions += (predicted_actions == action_labels).sum().item()
                total_predictions += action_labels.size(0)

                # Track confidence accuracy
                conf_errors = torch.abs(confidence_pred - confidence_targets)
                confidence_errors.extend(conf_errors.cpu().numpy())

        avg_loss = total_loss / len(val_loader)
        accuracy = correct_predictions / total_predictions
        confidence_accuracy = 1.0 - np.mean(confidence_errors)  # 1 - mean absolute error

        return avg_loss, accuracy, confidence_accuracy

    def _save_model(self, filename: str):
        """Save the model"""
        save_path = self.save_dir / filename
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.model.optimizer.state_dict(),
            'config': self.config.cnn,
            'training_history': self.training_history
        }, save_path)
        logger.info(f"Model saved to {save_path}")

    def load_model(self, filename: str) -> bool:
        """Load a saved model"""
        load_path = self.save_dir / filename
        if not load_path.exists():
            logger.error(f"Model file not found: {load_path}")
            return False

        try:
            checkpoint = torch.load(load_path, map_location=self.model.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.training_history = checkpoint.get('training_history', {})
            logger.info(f"Model loaded from {load_path}")
            return True
        except Exception as e:
            logger.error(f"Error loading model: {e}")
            return False

    def _generate_training_report(self) -> Dict[str, Any]:
        """Generate comprehensive training report"""
        if not self.training_history['train_loss']:
            return {'error': 'no_training_data'}

        # Calculate final metrics
        final_train_loss = self.training_history['train_loss'][-1]
        final_val_loss = self.training_history['val_loss'][-1]
        final_train_acc = self.training_history['train_accuracy'][-1]
        final_val_acc = self.training_history['val_accuracy'][-1]
        final_conf_acc = self.training_history['confidence_accuracy'][-1]

        # Best metrics
        best_val_loss = min(self.training_history['val_loss'])
        best_val_acc = max(self.training_history['val_accuracy'])
        best_conf_acc = max(self.training_history['confidence_accuracy'])

        report = {
            'training_completed': True,
            'epochs_trained': len(self.training_history['train_loss']),
            'final_metrics': {
                'train_loss': final_train_loss,
                'val_loss': final_val_loss,
                'train_accuracy': final_train_acc,
                'val_accuracy': final_val_acc,
                'confidence_accuracy': final_conf_acc
            },
            'best_metrics': {
                'val_loss': best_val_loss,
                'val_accuracy': best_val_acc,
                'confidence_accuracy': best_conf_acc
            },
            'model_info': {
                'timeframes': self.model.timeframes,
                'memory_usage_mb': self.model.get_memory_usage(),
                'device': str(self.model.device)
            }
        }

        # Generate plots
        self._plot_training_history()

        logger.info("Training completed successfully")
        logger.info(f"Final validation accuracy: {final_val_acc:.4f}")
        logger.info(f"Final confidence accuracy: {final_conf_acc:.4f}")

        return report

    def _plot_training_history(self):
        """Plot training history"""
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        fig.suptitle('Enhanced CNN Training History')

        # Loss plot
        axes[0, 0].plot(self.training_history['train_loss'], label='Train Loss')
        axes[0, 0].plot(self.training_history['val_loss'], label='Val Loss')
        axes[0, 0].set_title('Loss')
        axes[0, 0].set_xlabel('Epoch')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].legend()

        # Accuracy plot
        axes[0, 1].plot(self.training_history['train_accuracy'], label='Train Accuracy')
        axes[0, 1].plot(self.training_history['val_accuracy'], label='Val Accuracy')
        axes[0, 1].set_title('Action Accuracy')
        axes[0, 1].set_xlabel('Epoch')
        axes[0, 1].set_ylabel('Accuracy')
        axes[0, 1].legend()

        # Confidence accuracy plot
        axes[1, 0].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
        axes[1, 0].set_title('Confidence Prediction Accuracy')
        axes[1, 0].set_xlabel('Epoch')
        axes[1, 0].set_ylabel('Accuracy')
        axes[1, 0].legend()

        # Learning curves comparison
        axes[1, 1].plot(self.training_history['val_loss'], label='Validation Loss')
        axes[1, 1].plot(self.training_history['confidence_accuracy'], label='Confidence Accuracy')
        axes[1, 1].set_title('Model Performance Overview')
        axes[1, 1].set_xlabel('Epoch')
        axes[1, 1].legend()

        plt.tight_layout()
        plt.savefig(self.save_dir / 'training_history.png', dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f"Training plots saved to {self.save_dir / 'training_history.png'}")

    def get_model(self) -> EnhancedCNNModel:
        """Get the trained model"""
        return self.model

    def close_tensorboard(self):
        """Close TensorBoard writer if it exists"""
        if hasattr(self, 'writer') and self.writer:
            try:
                self.writer.close()
            except Exception:
                pass

    def __del__(self):
        """Cleanup when object is destroyed"""
        self.close_tensorboard()
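
# --- Illustrative usage (editor's sketch, not part of the original module) ---
# End-to-end flow for the trainer above; assumes an orchestrator that has
# already accumulated perfect moves via get_perfect_moves_for_training().
def _example_train_on_perfect_moves(orchestrator):
    """Train on accumulated perfect moves and report the best metrics."""
    trainer = EnhancedCNNTrainer(orchestrator=orchestrator)
    report = trainer.train_on_perfect_moves(min_samples=100)
    if 'error' in report:
        print(f"Not enough data: {report}")
    else:
        print(f"Best val accuracy: {report['best_metrics']['val_accuracy']:.4f}")
    return report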

def main():
    """Main function for standalone CNN live training with backtesting and analysis"""
    import argparse
    import sys
    from pathlib import Path

    # Add project root to path
    project_root = Path(__file__).parent.parent
    sys.path.insert(0, str(project_root))

    parser = argparse.ArgumentParser(description='Enhanced CNN Live Training with Backtesting and Analysis')
    parser.add_argument('--symbols', type=str, nargs='+', default=['ETH/USDT', 'BTC/USDT'],
                        help='Trading symbols to train on')
    parser.add_argument('--timeframes', type=str, nargs='+', default=['1m', '5m', '15m', '1h'],
                        help='Timeframes to use for training')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of training epochs')
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Training batch size')
    parser.add_argument('--learning-rate', type=float, default=0.001,
                        help='Learning rate')
    parser.add_argument('--save-path', type=str, default='models/enhanced_cnn/live_trained_model.pt',
                        help='Path to save the trained model')
    # NOTE: store_true with default=True means these three flags are always on;
    # kept as in the original, but a --disable-* form would be needed to turn them off.
    parser.add_argument('--enable-backtesting', action='store_true', default=True,
                        help='Enable backtesting after training')
    parser.add_argument('--enable-analysis', action='store_true', default=True,
                        help='Enable detailed analysis and reporting')
    parser.add_argument('--enable-live-validation', action='store_true', default=True,
                        help='Enable live validation during training')

    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    logger.info("="*80)
    logger.info("ENHANCED CNN LIVE TRAINING WITH BACKTESTING & ANALYSIS")
    logger.info("="*80)
    logger.info(f"Symbols: {args.symbols}")
    logger.info(f"Timeframes: {args.timeframes}")
    logger.info(f"Epochs: {args.epochs}")
    logger.info(f"Batch Size: {args.batch_size}")
    logger.info(f"Learning Rate: {args.learning_rate}")
    logger.info(f"Save Path: {args.save_path}")
    logger.info(f"Backtesting: {'Enabled' if args.enable_backtesting else 'Disabled'}")
    logger.info(f"Analysis: {'Enabled' if args.enable_analysis else 'Disabled'}")
    logger.info(f"Live Validation: {'Enabled' if args.enable_live_validation else 'Disabled'}")
    logger.info("="*80)

    try:
        # Update config with command line arguments
        config = get_config()
        config.update('symbols', args.symbols)
        config.update('timeframes', args.timeframes)
        config.update('training', {
            **config.training,
            'epochs': args.epochs,
            'batch_size': args.batch_size,
            'learning_rate': args.learning_rate
        })

        # Initialize enhanced trainer
        from core.enhanced_orchestrator import EnhancedTradingOrchestrator
        from core.data_provider import DataProvider

        data_provider = DataProvider(config)
        orchestrator = EnhancedTradingOrchestrator(data_provider)
        trainer = EnhancedCNNTrainer(config, orchestrator)

        # Phase 1: Data Collection and Preparation
        logger.info("📊 Phase 1: Collecting and preparing training data...")
        training_data = trainer.collect_training_data(args.symbols, lookback_days=30)
        logger.info(f"   Collected {len(training_data)} training samples")

        # Phase 2: Model Training
        logger.info("Phase 2: Training Enhanced CNN Model...")
        training_results = trainer.train_on_perfect_moves(min_samples=1000)

        if 'error' in training_results:
            logger.error(f"Training aborted: {training_results}")
            return 1

        # Report keys follow _generate_training_report() (fixed from the
        # original, which read keys the report never produced)
        logger.info("Training Results:")
        logger.info(f"  Best Validation Accuracy: {training_results['best_metrics']['val_accuracy']:.4f}")
        logger.info(f"  Best Validation Loss: {training_results['best_metrics']['val_loss']:.4f}")
        logger.info(f"  Total Epochs: {training_results['epochs_trained']}")

        # Phase 3: Model Evaluation
        logger.info("📈 Phase 3: Model Evaluation...")
        evaluation_results = trainer.evaluate_model(args.symbols[:1])  # Use first symbol for evaluation

        logger.info("Evaluation Results:")
        logger.info(f"  Test Accuracy: {evaluation_results['test_accuracy']:.4f}")
        logger.info(f"  Test Loss: {evaluation_results['test_loss']:.4f}")
        logger.info(f"  Confidence Score: {evaluation_results['avg_confidence']:.4f}")

        # Phase 4: Backtesting (if enabled)
        if args.enable_backtesting:
            logger.info("📊 Phase 4: Backtesting...")

            # Create backtest environment
            from trading.backtest_environment import BacktestEnvironment
            backtest_env = BacktestEnvironment(
                symbols=args.symbols,
                timeframes=args.timeframes,
                initial_balance=10000.0,
                data_provider=data_provider
            )

            # Run backtest
            backtest_results = backtest_env.run_backtest_with_model(
                model=trainer.model,
                lookback_days=7,  # Test on last 7 days
                max_trades_per_day=50
            )

            logger.info("Backtesting Results:")
            logger.info(f"  Total Returns: {backtest_results['total_return']:.2f}%")
            logger.info(f"  Win Rate: {backtest_results['win_rate']:.2f}%")
            logger.info(f"  Sharpe Ratio: {backtest_results['sharpe_ratio']:.4f}")
            logger.info(f"  Max Drawdown: {backtest_results['max_drawdown']:.2f}%")
            logger.info(f"  Total Trades: {backtest_results['total_trades']}")
            logger.info(f"  Profit Factor: {backtest_results['profit_factor']:.4f}")

        # Phase 5: Analysis and Reporting (if enabled)
        if args.enable_analysis:
            logger.info("📋 Phase 5: Analysis and Reporting...")

            # Generate comprehensive analysis report
            analysis_report = trainer.generate_analysis_report(
                training_results=training_results,
                evaluation_results=evaluation_results,
                backtest_results=backtest_results if args.enable_backtesting else None
            )

            # Save analysis report
            report_path = Path(args.save_path).parent / "analysis_report.json"
            report_path.parent.mkdir(parents=True, exist_ok=True)

            with open(report_path, 'w') as f:
                json.dump(analysis_report, f, indent=2, default=str)

            logger.info(f"   Analysis report saved: {report_path}")

            # Generate performance plots
            plots_dir = Path(args.save_path).parent / "plots"
            plots_dir.mkdir(parents=True, exist_ok=True)

            trainer.generate_performance_plots(
                training_results=training_results,
                evaluation_results=evaluation_results,
                save_dir=plots_dir
            )

            logger.info(f"   Performance plots saved: {plots_dir}")

        # Phase 6: Model Saving
        logger.info("💾 Phase 6: Saving trained model...")
        model_path = Path(args.save_path)
        model_path.parent.mkdir(parents=True, exist_ok=True)

        # Fixed: EnhancedCNNModel defines no save() method, so persist the
        # state dict directly instead of calling trainer.model.save().
        torch.save(trainer.model.state_dict(), str(model_path))
        logger.info(f"   Model saved: {model_path}")

        # Save training metadata
        metadata = {
            'training_config': {
                'symbols': args.symbols,
                'timeframes': args.timeframes,
                'epochs': args.epochs,
                'batch_size': args.batch_size,
                'learning_rate': args.learning_rate
            },
            'training_results': training_results,
            'evaluation_results': evaluation_results
        }

        if args.enable_backtesting:
            metadata['backtest_results'] = backtest_results

        metadata_path = model_path.with_suffix('.json')
        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2, default=str)

        logger.info(f"   Training metadata saved: {metadata_path}")

        # Phase 7: Live Validation (if enabled)
        if args.enable_live_validation:
            logger.info("🔄 Phase 7: Live Validation...")

            # Test model on recent live data
            live_validation_results = trainer.run_live_validation(
                symbols=args.symbols[:1],  # Use first symbol
                validation_hours=2  # Validate on last 2 hours
            )

            logger.info("Live Validation Results:")
            logger.info(f"  Prediction Accuracy: {live_validation_results['accuracy']:.2f}%")
            logger.info(f"  Average Confidence: {live_validation_results['avg_confidence']:.4f}")
            logger.info(f"  Predictions Made: {live_validation_results['total_predictions']}")

        logger.info("="*80)
        logger.info("🎉 ENHANCED CNN LIVE TRAINING COMPLETED SUCCESSFULLY!")
        logger.info("="*80)
        logger.info(f"📊 Model Path: {model_path}")
        logger.info(f"📋 Metadata: {metadata_path}")
        if args.enable_analysis:
            logger.info(f"📈 Analysis Report: {report_path}")
            logger.info(f"📊 Performance Plots: {plots_dir}")
        logger.info("="*80)

    except KeyboardInterrupt:
        logger.info("Training interrupted by user")
        return 1
    except Exception as e:
        logger.error(f"Training failed: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return 1

    return 0


if __name__ == "__main__":
    import sys  # fixed: sys was only imported inside main(), so this guard would fail
    sys.exit(main())
@ -1,584 +0,0 @@
"""
Enhanced Pivot-Based RL Trainer

Integrates Williams Market Structure pivot points with CNN predictions
for improved trading decisions and training rewards.

Key Features:
- Train RL model to buy/sell at local pivot points
- CNN predicts next pivot to avoid late signals
- Different thresholds for entry vs exit
- Rewards for staying uninvested when uncertain
- Uncertainty-based confidence adjustment
"""

import asyncio
import logging
import time
import numpy as np
import pandas as pd
from collections import deque, namedtuple
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union, TYPE_CHECKING
import matplotlib.pyplot as plt
from pathlib import Path

from core.config import get_config
from core.data_provider import DataProvider
from training.williams_market_structure import WilliamsMarketStructure, SwingType, SwingPoint

# Use TYPE_CHECKING to avoid circular import
if TYPE_CHECKING:
    from core.enhanced_orchestrator import EnhancedTradingOrchestrator

logger = logging.getLogger(__name__)

class PivotReward:
    """Reward structure for pivot-based trading decisions"""

    def __init__(self):
        # Pivot-based reward weights
        self.pivot_hit_bonus = 2.0            # Bonus for trading at actual pivot points
        self.pivot_anticipation_bonus = 1.5   # Bonus for trading before pivot (CNN prediction)
        self.wrong_direction_penalty = -1.0   # Penalty for trading opposite to pivot direction
        self.late_entry_penalty = -0.5        # Penalty for entering after pivot is confirmed

        # Stay uninvested rewards
        self.uninvested_reward = 0.1          # Small positive reward for staying out of poor setups
        self.avoid_false_signal_bonus = 0.5   # Bonus for avoiding false signals

        # Uncertainty penalties
        self.overconfidence_penalty = -0.3    # Penalty for being overconfident on losses
        self.underconfidence_penalty = -0.1   # Small penalty for being underconfident on wins

class EnhancedPivotRLTrainer:
    """Enhanced RL trainer focused on Williams pivot points and CNN predictions"""

    def __init__(self,
                 data_provider: DataProvider = None,
                 orchestrator: Optional["EnhancedTradingOrchestrator"] = None):

        self.config = get_config()
        self.data_provider = data_provider or DataProvider()
        self.orchestrator = orchestrator

        # Initialize Williams Market Structure with CNN
        self.williams = WilliamsMarketStructure(
            swing_strengths=[2, 4, 6, 8, 10],  # Multiple strengths for better detection
            enable_cnn_feature=True,
            training_data_provider=data_provider
        )

        # Pivot tracking
        self.recent_pivots = deque(maxlen=50)
        self.pivot_predictions = deque(maxlen=20)
        self.trade_outcomes = deque(maxlen=100)

        # Threshold management - different for entry vs exit
        self.entry_threshold = 0.65  # Higher threshold for entering positions
        self.exit_threshold = 0.35   # Lower threshold for exiting positions
        self.max_uninvested_reward_threshold = 0.60  # Stay out if confidence below this

        # Confidence learning parameters
        self.confidence_history = deque(maxlen=200)
        self.mistake_severity_tracker = deque(maxlen=50)

        # Reward calculator
        self.pivot_reward = PivotReward()

        logger.info("Enhanced Pivot RL Trainer initialized")
        logger.info(f"Entry threshold: {self.entry_threshold:.2%}")
        logger.info(f"Exit threshold: {self.exit_threshold:.2%}")
        logger.info(f"Uninvested reward threshold: {self.max_uninvested_reward_threshold:.2%}")

    def calculate_pivot_based_reward(self,
                                     trade_decision: Dict[str, Any],
                                     market_data: pd.DataFrame,
                                     trade_outcome: Dict[str, Any]) -> float:
        """
        Calculate enhanced reward based on pivot points and CNN predictions

        Args:
            trade_decision: The trading decision made by the model
            market_data: Market data context
            trade_outcome: Actual trade outcome

        Returns:
            Enhanced reward score
        """
        try:
            base_pnl = trade_outcome.get('net_pnl', 0.0)
            confidence = trade_decision.get('confidence', 0.5)
            action = trade_decision.get('action', 'HOLD')
            entry_price = trade_decision.get('price', 0.0)
            exit_price = trade_outcome.get('exit_price', entry_price)
            duration = trade_outcome.get('duration', timedelta(0))

            # Base PnL reward
            base_reward = base_pnl / 5.0

            # 1. Pivot Point Analysis Rewards
            pivot_reward = self._calculate_pivot_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 2. CNN Prediction Accuracy Rewards
            cnn_reward = self._calculate_cnn_prediction_rewards(
                trade_decision, market_data, trade_outcome
            )

            # 3. Uninvested Period Rewards
            uninvested_reward = self._calculate_uninvested_rewards(
                trade_decision, confidence
            )

            # 4. Uncertainty-based Confidence Adjustment
            confidence_adjustment = self._calculate_confidence_adjustment(
                trade_decision, trade_outcome
            )

            # 5. Time efficiency with pivot context
            time_reward = self._calculate_time_efficiency_reward(
                duration, base_pnl, market_data
            )

            # Combine all rewards
            total_reward = (
                base_reward +
                pivot_reward +
                cnn_reward +
                uninvested_reward +
                confidence_adjustment +
                time_reward
            )

            # Log detailed reward breakdown
            self._log_reward_breakdown(
                trade_decision, trade_outcome, {
                    'base': base_reward,
                    'pivot': pivot_reward,
                    'cnn': cnn_reward,
                    'uninvested': uninvested_reward,
                    'confidence': confidence_adjustment,
                    'time': time_reward,
                    'total': total_reward
                }
            )

            # Track for learning
            self._track_reward_outcome(trade_decision, trade_outcome, total_reward)

            return np.clip(total_reward, -15.0, 10.0)

        except Exception as e:
            logger.error(f"Error calculating pivot-based reward: {e}")
            return 0.0
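
    # Worked example (editor's illustration, hypothetical numbers): a BUY filled
    # within 0.3% and 20 minutes of a confirmed swing low, closed at +$2.50 net
    # after 15 minutes with confidence 0.70, scores roughly:
    #   base       = 2.50 / 5.0                        = +0.50
    #   pivot      = hit bonus 2.0 + anticipation 1.5  = +3.50
    #   cnn        = 1.0 * 0.70 (correct follow)       = +0.70
    #   uninvested = 0.0 (a trade was taken)
    #   confidence = 0.0 (no over/underconfidence case)
    #   time       = +0.30 (profit in under 30 minutes)
    #   total      = 0.50 + 3.50 + 0.70 + 0.30         = +5.00 (inside the [-15, 10] clip)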

    def _calculate_pivot_rewards(self,
                                 trade_decision: Dict[str, Any],
                                 market_data: pd.DataFrame,
                                 trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on proximity to pivot points"""
        try:
            entry_price = trade_decision.get('price', 0.0)
            action = trade_decision.get('action', 'HOLD')
            entry_time = trade_decision.get('timestamp', datetime.now())
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Find recent pivot points from Williams analysis
            ohlcv_array = self._convert_dataframe_to_ohlcv_array(market_data)
            if ohlcv_array is None or len(ohlcv_array) < 20:
                return 0.0

            # Get pivot points from Williams structure
            structure_levels = self.williams.calculate_recursive_pivot_points(ohlcv_array)
            if not structure_levels or 'level_0' not in structure_levels:
                return 0.0

            level_0_pivots = structure_levels['level_0'].swing_points
            if not level_0_pivots:
                return 0.0

            # Find closest pivot to entry
            closest_pivot = self._find_closest_pivot(entry_price, entry_time, level_0_pivots)
            if not closest_pivot:
                return 0.0

            # Calculate distance to pivot (price and time)
            price_distance = abs(entry_price - closest_pivot.price) / closest_pivot.price
            time_distance = abs((entry_time - closest_pivot.timestamp).total_seconds()) / 3600.0  # hours

            pivot_reward = 0.0

            # Reward trading at or near pivot points
            if price_distance < 0.005:  # Within 0.5% of pivot
                if time_distance < 0.5:  # Within 30 minutes
                    pivot_reward += self.pivot_reward.pivot_hit_bonus
                    logger.debug(f"PIVOT HIT BONUS: {self.pivot_reward.pivot_hit_bonus:.2f}")

                # Check if trade direction aligns with pivot
                if self._trade_aligns_with_pivot(action, closest_pivot, net_pnl):
                    pivot_reward += self.pivot_reward.pivot_anticipation_bonus
                    logger.debug(f"PIVOT DIRECTION BONUS: {self.pivot_reward.pivot_anticipation_bonus:.2f}")
                else:
                    pivot_reward += self.pivot_reward.wrong_direction_penalty
                    logger.debug(f"WRONG DIRECTION PENALTY: {self.pivot_reward.wrong_direction_penalty:.2f}")

            # Penalty for late entry after pivot confirmation
            if time_distance > 2.0:  # More than 2 hours after pivot
                pivot_reward += self.pivot_reward.late_entry_penalty
                logger.debug(f"LATE ENTRY PENALTY: {self.pivot_reward.late_entry_penalty:.2f}")

            return pivot_reward

        except Exception as e:
            logger.error(f"Error calculating pivot rewards: {e}")
            return 0.0

    def _calculate_cnn_prediction_rewards(self,
                                          trade_decision: Dict[str, Any],
                                          market_data: pd.DataFrame,
                                          trade_outcome: Dict[str, Any]) -> float:
        """Calculate rewards based on CNN pivot predictions"""
        try:
            # Check if we have CNN predictions available
            if not hasattr(self.williams, 'cnn_model') or not self.williams.cnn_model:
                return 0.0

            action = trade_decision.get('action', 'HOLD')
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            # Get latest CNN prediction if available
            # This would be the prediction made before the trade
            cnn_prediction = self._get_latest_cnn_prediction()
            if cnn_prediction is None:  # fixed: truth-testing an ndarray raises ValueError
                return 0.0

            cnn_reward = 0.0

            # Reward for following CNN predictions that turn out correct
            predicted_direction = self._interpret_cnn_prediction(cnn_prediction)

            if predicted_direction == action and net_pnl > 0:
                # CNN prediction was correct and we followed it
                cnn_reward += 1.0 * confidence  # Scale by confidence
                logger.debug(f"CNN CORRECT FOLLOW: +{1.0 * confidence:.2f}")

            elif predicted_direction != action and net_pnl < 0:
                # We didn't follow CNN and it was right (we were wrong)
                cnn_reward -= 0.5
                logger.debug("CNN IGNORE PENALTY: -0.5")

            elif predicted_direction == action and net_pnl < 0:
                # We followed CNN but it was wrong
                cnn_reward -= 0.2  # Small penalty, CNN predictions can be wrong
                logger.debug("CNN WRONG FOLLOW: -0.2")

            return cnn_reward

        except Exception as e:
            logger.error(f"Error calculating CNN prediction rewards: {e}")
            return 0.0

    def _calculate_uninvested_rewards(self,
                                      trade_decision: Dict[str, Any],
                                      confidence: float) -> float:
        """Calculate rewards for staying uninvested when uncertain"""
        try:
            action = trade_decision.get('action', 'HOLD')

            # Reward staying out when confidence is low
            if action == 'HOLD' and confidence < self.max_uninvested_reward_threshold:
                uninvested_reward = self.pivot_reward.uninvested_reward

                # Bonus for avoiding very uncertain setups
                if confidence < 0.4:
                    uninvested_reward += self.pivot_reward.avoid_false_signal_bonus
                    logger.debug(f"AVOID FALSE SIGNAL BONUS: +{self.pivot_reward.avoid_false_signal_bonus:.2f}")

                logger.debug(f"UNINVESTED REWARD: +{uninvested_reward:.2f}")
                return uninvested_reward

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating uninvested rewards: {e}")
            return 0.0

    def _calculate_confidence_adjustment(self,
                                         trade_decision: Dict[str, Any],
                                         trade_outcome: Dict[str, Any]) -> float:
        """Adjust rewards based on confidence vs outcome to reduce overconfidence"""
        try:
            confidence = trade_decision.get('confidence', 0.5)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            confidence_adjustment = 0.0

            # Track mistake severity
            mistake_severity = abs(net_pnl) if net_pnl < 0 else 0.0
            self.mistake_severity_tracker.append(mistake_severity)

            # Penalize overconfidence on losses
            if net_pnl < 0 and confidence > 0.7:
                # High confidence but loss - penalize overconfidence
                overconfidence_factor = (confidence - 0.7) / 0.3  # 0-1 scale
                severity_factor = min(mistake_severity / 2.0, 1.0)  # Scale by loss size

                penalty = self.pivot_reward.overconfidence_penalty * overconfidence_factor * severity_factor
                confidence_adjustment += penalty

                logger.debug(f"OVERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, loss: ${net_pnl:.2f})")

            # Small penalty for underconfidence on wins
            elif net_pnl > 0 and confidence < 0.4:
                underconfidence_factor = (0.4 - confidence) / 0.4  # 0-1 scale
                penalty = self.pivot_reward.underconfidence_penalty * underconfidence_factor
                confidence_adjustment += penalty

                logger.debug(f"UNDERCONFIDENCE PENALTY: {penalty:.2f} (conf: {confidence:.2f}, profit: ${net_pnl:.2f})")

            # Update confidence learning
            self._update_confidence_learning(confidence, net_pnl, mistake_severity)

            return confidence_adjustment

        except Exception as e:
            logger.error(f"Error calculating confidence adjustment: {e}")
            return 0.0
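
    # Worked example (editor's illustration): a -$1.00 loss taken at
    # confidence 0.85 yields
    #   overconfidence_factor = (0.85 - 0.70) / 0.30 = 0.50
    #   severity_factor       = min(1.00 / 2.0, 1.0) = 0.50
    #   penalty               = -0.3 * 0.50 * 0.50   = -0.075
    # so high-confidence losses are penalized in proportion to both how
    # overconfident the model was and how large the loss turned out to be.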

    def _calculate_time_efficiency_reward(self,
                                          duration: timedelta,
                                          net_pnl: float,
                                          market_data: pd.DataFrame) -> float:
        """Calculate time-based rewards considering market context"""
        try:
            duration_hours = duration.total_seconds() / 3600.0

            # Quick profitable trades get bonus
            if net_pnl > 0 and duration_hours < 0.5:  # Less than 30 minutes
                return 0.3

            # Holding losses too long gets penalty
            elif net_pnl < 0 and duration_hours > 2.0:  # More than 2 hours
                return -0.5

            return 0.0

        except Exception as e:
            logger.error(f"Error calculating time efficiency reward: {e}")
            return 0.0

    def update_thresholds_based_on_performance(self):
        """Dynamically adjust entry/exit thresholds based on recent performance"""
        try:
            if len(self.trade_outcomes) < 20:
                return

            recent_outcomes = list(self.trade_outcomes)[-20:]

            # Calculate win rate and average PnL
            wins = sum(1 for outcome in recent_outcomes if outcome['net_pnl'] > 0)
            win_rate = wins / len(recent_outcomes)
            avg_pnl = np.mean([outcome['net_pnl'] for outcome in recent_outcomes])

            # Adjust thresholds based on performance
            if win_rate < 0.4:  # Low win rate - be more selective
                self.entry_threshold = min(self.entry_threshold + 0.02, 0.80)
                logger.info(f"Low win rate ({win_rate:.2%}) - increased entry threshold to {self.entry_threshold:.2%}")

            elif win_rate > 0.6 and avg_pnl > 0:  # High win rate - can be more aggressive
                self.entry_threshold = max(self.entry_threshold - 0.01, 0.50)
                logger.info(f"High win rate ({win_rate:.2%}) - decreased entry threshold to {self.entry_threshold:.2%}")

            # Adjust exit threshold based on loss severity
            avg_loss_severity = np.mean(list(self.mistake_severity_tracker)) if self.mistake_severity_tracker else 0

            if avg_loss_severity > 1.0:  # Large average losses
                self.exit_threshold = max(self.exit_threshold - 0.01, 0.20)
                logger.info(f"High loss severity - decreased exit threshold to {self.exit_threshold:.2%}")

        except Exception as e:
            logger.error(f"Error updating thresholds: {e}")

    def get_current_thresholds(self) -> Dict[str, float]:
        """Get current entry and exit thresholds"""
        return {
            'entry_threshold': self.entry_threshold,
            'exit_threshold': self.exit_threshold,
            'uninvested_threshold': self.max_uninvested_reward_threshold
        }

    # Helper methods

    def _convert_dataframe_to_ohlcv_array(self, df: pd.DataFrame) -> Optional[np.ndarray]:
        """Convert pandas DataFrame to numpy array for Williams analysis"""
        try:
            if df.empty:
                return None

            # Ensure we have required columns
            required_cols = ['open', 'high', 'low', 'close', 'volume']
            if not all(col in df.columns for col in required_cols):
                return None

            # Convert to numpy array
            timestamps = df.index.astype(np.int64) // 10**9  # Convert to Unix timestamp
            ohlcv_array = np.column_stack([
                timestamps,
                df['open'].values,
                df['high'].values,
                df['low'].values,
                df['close'].values,
                df['volume'].values
            ])

            return ohlcv_array.astype(np.float64)

        except Exception as e:
            logger.error(f"Error converting DataFrame to OHLCV array: {e}")
            return None
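
    # Example (editor's illustration): a DataFrame with a DatetimeIndex and
    # columns ['open', 'high', 'low', 'close', 'volume'] over N rows becomes an
    # (N, 6) float64 array whose first column is the Unix timestamp in seconds:
    #   df = pd.DataFrame({...}, index=pd.date_range('2025-05-30', periods=N, freq='1min'))
    #   arr = self._convert_dataframe_to_ohlcv_array(df)  # arr.shape == (N, 6)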

    def _find_closest_pivot(self,
                            entry_price: float,
                            entry_time: datetime,
                            pivots: List[SwingPoint]) -> Optional[SwingPoint]:
        """Find the closest pivot point to the trade entry"""
        try:
            if not pivots:
                return None

            # Find pivot closest in time and price
            best_pivot = None
            best_score = float('inf')

            for pivot in pivots:
                time_diff = abs((entry_time - pivot.timestamp).total_seconds()) / 3600.0
                price_diff = abs(entry_price - pivot.price) / pivot.price

                # Combined score (weighted by time and price proximity)
                score = time_diff * 0.3 + price_diff * 100  # Weight price difference more heavily

                if score < best_score:
                    best_score = score
                    best_pivot = pivot

            return best_pivot

        except Exception as e:
            logger.error(f"Error finding closest pivot: {e}")
            return None

    def _trade_aligns_with_pivot(self,
                                 action: str,
                                 pivot: SwingPoint,
                                 net_pnl: float) -> bool:
        """Check if trade direction aligns with pivot type and was profitable"""
        try:
            if net_pnl <= 0:  # Only consider profitable trades as aligned
                return False

            if action == 'BUY' and pivot.swing_type == SwingType.SWING_LOW:
                return True  # Bought at/near swing low
            elif action == 'SELL' and pivot.swing_type == SwingType.SWING_HIGH:
                return True  # Sold at/near swing high

            return False

        except Exception as e:
            logger.error(f"Error checking trade alignment: {e}")
            return False

    def _get_latest_cnn_prediction(self) -> Optional[np.ndarray]:
        """Get the latest CNN prediction from Williams structure"""
        try:
            # This would access the Williams CNN model's latest prediction
            # For now, return None if not available
            if hasattr(self.williams, 'latest_cnn_prediction'):
                return self.williams.latest_cnn_prediction
            return None

        except Exception as e:
            logger.error(f"Error getting CNN prediction: {e}")
            return None

    def _interpret_cnn_prediction(self, prediction: np.ndarray) -> str:
        """Interpret CNN prediction array to trading action"""
        try:
            if len(prediction) < 2:
                return 'HOLD'

            # Assuming prediction format: [type, price] for level 0
            predicted_type = prediction[0]  # 0 = LOW, 1 = HIGH

            if predicted_type > 0.5:
                return 'SELL'  # Expecting swing high - sell
            else:
                return 'BUY'  # Expecting swing low - buy

        except Exception as e:
            logger.error(f"Error interpreting CNN prediction: {e}")
            return 'HOLD'
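
    # Example (editor's illustration) of the assumed [type, price] layout:
    #   np.array([0.9, 2551.2]) -> 'SELL'  (type > 0.5: swing high expected)
    #   np.array([0.1, 2498.7]) -> 'BUY'   (type <= 0.5: swing low expected)
    #   np.array([0.9])         -> 'HOLD'  (fewer than 2 elements)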

    def _update_confidence_learning(self,
                                    confidence: float,
                                    net_pnl: float,
                                    mistake_severity: float):
        """Update confidence learning parameters"""
        try:
            self.confidence_history.append({
                'confidence': confidence,
                'net_pnl': net_pnl,
                'mistake_severity': mistake_severity,
                'timestamp': datetime.now()
            })

            # Periodically update thresholds based on confidence patterns
            if len(self.confidence_history) % 10 == 0:
                self.update_thresholds_based_on_performance()

        except Exception as e:
            logger.error(f"Error updating confidence learning: {e}")

    def _track_reward_outcome(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              total_reward: float):
        """Track reward outcomes for analysis"""
        try:
            outcome_record = {
                'timestamp': datetime.now(),
                'action': trade_decision.get('action'),
                'confidence': trade_decision.get('confidence'),
                'net_pnl': trade_outcome.get('net_pnl'),
                'reward': total_reward,
                'duration': trade_outcome.get('duration')
            }

            self.trade_outcomes.append(outcome_record)

        except Exception as e:
            logger.error(f"Error tracking reward outcome: {e}")

    def _log_reward_breakdown(self,
                              trade_decision: Dict[str, Any],
                              trade_outcome: Dict[str, Any],
                              rewards: Dict[str, float]):
        """Log detailed reward breakdown"""
        try:
            action = trade_decision.get('action', 'UNKNOWN')
            confidence = trade_decision.get('confidence', 0.0)
            net_pnl = trade_outcome.get('net_pnl', 0.0)

            logger.info(f"[REWARD] {action} (conf: {confidence:.2%}) PnL: ${net_pnl:.2f} -> Total: {rewards['total']:.2f}")
            logger.debug(f"  Base: {rewards['base']:.2f}, Pivot: {rewards['pivot']:.2f}, CNN: {rewards['cnn']:.2f}")
            logger.debug(f"  Uninvested: {rewards['uninvested']:.2f}, Confidence: {rewards['confidence']:.2f}, Time: {rewards['time']:.2f}")

        except Exception as e:
            logger.error(f"Error logging reward breakdown: {e}")

def create_enhanced_pivot_trainer(data_provider: DataProvider = None,
                                  orchestrator: Optional["EnhancedTradingOrchestrator"] = None) -> EnhancedPivotRLTrainer:
    """Factory function to create enhanced pivot trainer"""
    return EnhancedPivotRLTrainer(data_provider, orchestrator)
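
# --- Illustrative usage (editor's sketch, not part of the original module) ---
# Wires the factory above into a reward calculation. The trade dicts use only
# keys that calculate_pivot_based_reward() reads via .get(); the empty
# DataFrame makes the pivot and CNN terms fall back to 0.0.
def _example_pivot_reward():
    """Score a hypothetical losing trade with the pivot-based reward."""
    trainer = create_enhanced_pivot_trainer()
    decision = {'action': 'BUY', 'confidence': 0.8, 'price': 2550.0,
                'timestamp': datetime.now()}
    outcome = {'net_pnl': -1.2, 'exit_price': 2545.0,
               'duration': timedelta(minutes=45)}
    reward = trainer.calculate_pivot_based_reward(decision, pd.DataFrame(), outcome)
    print(f"reward: {reward:.2f}")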
@ -1,708 +0,0 @@
"""
Enhanced RL State Builder for Comprehensive Market Data Integration

This module implements the specification requirements for RL training with:
- 300s of raw tick data for momentum detection
- Multi-timeframe OHLCV data (1s, 1m, 1h, 1d) for ETH and BTC
- CNN hidden layer features integration
- CNN predictions from all timeframes
- Pivot point predictions using Williams market structure
- Market regime analysis

State Vector Components:
- ETH tick data: ~3000 features (300s * 10 features/tick)
- ETH OHLCV 1s: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1m: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1h: ~2400 features (300 bars * 8 features)
- ETH OHLCV 1d: ~2400 features (300 bars * 8 features)
- BTC reference: ~2400 features (300 bars * 8 features)
- CNN features: ~512 features (hidden layer)
- CNN predictions: ~16 features (4 timeframes * 4 outputs)
- Pivot points: ~250 features (Williams structure)
- Market regime: ~20 features

Total: ~15,800 features (the components above sum to 15,798)
"""

import logging
import numpy as np
import pandas as pd
from typing import Dict, List, Optional, Tuple, Any
from datetime import datetime, timedelta
from dataclasses import dataclass

from core.universal_data_adapter import UniversalDataStream

logger = logging.getLogger(__name__)

# The 'ta' package is optional; fall back to pandas-based indicators when absent.
# (The original message said "TA-Lib", but `import ta` refers to the 'ta' package.)
try:
    import ta
except ImportError:
    logger.warning("'ta' package not available, using pandas for technical indicators")
    ta = None

@dataclass
class TickData:
    """Tick data structure"""
    timestamp: datetime
    price: float
    volume: float
    bid: float = 0.0
    ask: float = 0.0

    @property
    def spread(self) -> float:
        return self.ask - self.bid if self.ask > 0 and self.bid > 0 else 0.0


@dataclass
class OHLCVData:
    """OHLCV data structure"""
    timestamp: datetime
    open: float
    high: float
    low: float
    close: float
    volume: float

    # Technical indicators (optional)
    rsi: Optional[float] = None
    macd: Optional[float] = None
    bb_upper: Optional[float] = None
    bb_lower: Optional[float] = None
    sma_20: Optional[float] = None
    ema_12: Optional[float] = None
    atr: Optional[float] = None

@dataclass
class StateComponentConfig:
    """Configuration for state component sizes"""
    eth_ticks: int = 3000      # 300s * 10 features per tick
    eth_1s_ohlcv: int = 2400   # 300 bars * 8 features (OHLCV + indicators)
    eth_1m_ohlcv: int = 2400   # 300 bars * 8 features
    eth_1h_ohlcv: int = 2400   # 300 bars * 8 features
    eth_1d_ohlcv: int = 2400   # 300 bars * 8 features
    btc_reference: int = 2400  # BTC reference data
    cnn_features: int = 512    # CNN hidden layer features
    cnn_predictions: int = 16  # CNN predictions (4 timeframes * 4 outputs)
    pivot_points: int = 250    # Recursive pivot points (5 levels * 50 points)
    market_regime: int = 20    # Market regime features

    @property
    def total_size(self) -> int:
        """Calculate total state size"""
        return (self.eth_ticks + self.eth_1s_ohlcv + self.eth_1m_ohlcv +
                self.eth_1h_ohlcv + self.eth_1d_ohlcv + self.btc_reference +
                self.cnn_features + self.cnn_predictions + self.pivot_points +
                self.market_regime)
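
# --- Illustrative check (editor's sketch, not part of the original module) ---
# The default component sizes sum to 15,798 features:
#   3000 + 4*2400 (ETH OHLCV) + 2400 (BTC) + 512 + 16 + 250 + 20 = 15798
def _example_state_size():
    """Print the default total state size."""
    print(StateComponentConfig().total_size)  # 15798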

class EnhancedRLStateBuilder:
    """
    Comprehensive RL state builder implementing specification requirements

    Features:
    - 300s tick data processing with momentum detection
    - Multi-timeframe OHLCV integration
    - CNN hidden layer feature extraction
    - Pivot point calculation and integration
    - Market regime analysis
    - BTC reference data processing
    """

    def __init__(self, config: Dict[str, Any]):
        self.config = config

        # Data windows
        self.tick_window_seconds = 300  # 5 minutes of tick data
        self.ohlcv_window_bars = 300    # 300 bars for each timeframe

        # State component sizes
        self.state_components = {
            'eth_ticks': 300 * 10,     # 3000 features: tick data with derived features
            'eth_1s_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'eth_1m_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'eth_1h_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'eth_1d_ohlcv': 300 * 8,   # 2400 features: OHLCV + indicators
            'btc_reference': 300 * 8,  # 2400 features: BTC reference data
            'cnn_features': 512,       # 512 features: CNN hidden layer
            'cnn_predictions': 16,     # 16 features: CNN predictions (4 timeframes * 4 outputs)
            'pivot_points': 250,       # 250 features: Williams market structure
            'market_regime': 20        # 20 features: Market regime indicators
        }

        self.total_state_size = sum(self.state_components.values())

        # Data buffers for maintaining windows
        self.tick_buffers = {}
        self.ohlcv_buffers = {}

        # Normalization parameters
        self.normalization_params = self._initialize_normalization_params()

        # Feature extractors
        self.momentum_detector = TickMomentumDetector()
        self.indicator_calculator = TechnicalIndicatorCalculator()
        self.regime_analyzer = MarketRegimeAnalyzer()

        logger.info("Enhanced RL State Builder initialized")
        logger.info(f"Total state size: {self.total_state_size} features")
        logger.info(f"State components: {self.state_components}")

    def build_rl_state(self,
                       eth_ticks: List[TickData],
                       eth_ohlcv: Dict[str, List[OHLCVData]],
                       btc_ohlcv: Dict[str, List[OHLCVData]],
                       cnn_hidden_features: Optional[Dict[str, np.ndarray]] = None,
                       cnn_predictions: Optional[Dict[str, np.ndarray]] = None,
                       pivot_data: Optional[Dict[str, Any]] = None) -> np.ndarray:
        """
        Build comprehensive RL state vector from all data sources

        Args:
            eth_ticks: List of ETH tick data (last 300s)
            eth_ohlcv: Dict of ETH OHLCV data by timeframe
            btc_ohlcv: Dict of BTC OHLCV data by timeframe
            cnn_hidden_features: CNN hidden layer features by timeframe
            cnn_predictions: CNN predictions by timeframe
            pivot_data: Pivot point data from Williams analysis

        Returns:
            np.ndarray: Comprehensive state vector (~15,800 features)
        """
        try:
            state_vector = []

            # 1. Process ETH tick data (3000 features)
            tick_features = self._process_tick_data(eth_ticks)
            state_vector.extend(tick_features)

            # 2. Process ETH multi-timeframe OHLCV (9600 features total)
            for timeframe in ['1s', '1m', '1h', '1d']:
                if timeframe in eth_ohlcv:
                    ohlcv_features = self._process_ohlcv_data(
                        eth_ohlcv[timeframe], timeframe, symbol='ETH'
                    )
                else:
                    ohlcv_features = np.zeros(self.state_components[f'eth_{timeframe}_ohlcv'])
                state_vector.extend(ohlcv_features)

            # 3. Process BTC reference data (2400 features)
            btc_features = self._process_btc_reference_data(btc_ohlcv)
            state_vector.extend(btc_features)

            # 4. Process CNN hidden layer features (512 features)
            cnn_hidden = self._process_cnn_hidden_features(cnn_hidden_features)
            state_vector.extend(cnn_hidden)

            # 5. Process CNN predictions (16 features)
            cnn_pred = self._process_cnn_predictions(cnn_predictions)
            state_vector.extend(cnn_pred)

            # 6. Process pivot points (250 features)
            pivot_features = self._process_pivot_points(pivot_data, eth_ohlcv)
            state_vector.extend(pivot_features)

            # 7. Process market regime features (20 features)
            regime_features = self._process_market_regime(eth_ohlcv, btc_ohlcv)
            state_vector.extend(regime_features)

            # Convert to numpy array and validate size
            state_array = np.array(state_vector, dtype=np.float32)

            if len(state_array) != self.total_state_size:
                logger.warning(f"State size mismatch: expected {self.total_state_size}, got {len(state_array)}")
                # Pad or truncate to expected size
                if len(state_array) < self.total_state_size:
                    padding = np.zeros(self.total_state_size - len(state_array))
                    state_array = np.concatenate([state_array, padding])
                else:
                    state_array = state_array[:self.total_state_size]

            # Apply normalization
            state_array = self._normalize_state(state_array)

            return state_array

        except Exception as e:
            logger.error(f"Error building RL state: {e}")
            # Return zero state on error
            return np.zeros(self.total_state_size, dtype=np.float32)
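
    # Usage sketch (illustrative; `ticks`, `bars_by_tf` and `btc_bars_by_tf`
    # are hypothetical inputs the caller must supply):
    #
    #   builder = EnhancedRLStateBuilder(config={})
    #   state = builder.build_rl_state(
    #       eth_ticks=ticks,           # List[TickData], last 300s
    #       eth_ohlcv=bars_by_tf,      # {'1s': [...], '1m': [...], ...}
    #       btc_ohlcv=btc_bars_by_tf,
    #   )
    #   assert state.shape == (builder.total_state_size,)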

    def _process_tick_data(self, ticks: List[TickData]) -> List[float]:
        """Process raw tick data into features for momentum detection"""
        features = []

        if not ticks or len(ticks) < 10:
            # Return zeros if insufficient data
            return [0.0] * self.state_components['eth_ticks']

        # Ensure we have exactly 300 data points (pad or sample)
        processed_ticks = self._normalize_tick_window(ticks, 300)

        for i, tick in enumerate(processed_ticks):
            # Basic tick features
            tick_features = [
                tick.price,
                tick.volume,
                tick.bid,
                tick.ask,
                tick.spread
            ]

            # Derived features
            if i > 0:
                prev_tick = processed_ticks[i-1]
                price_change = (tick.price - prev_tick.price) / prev_tick.price if prev_tick.price > 0 else 0
                volume_change = (tick.volume - prev_tick.volume) / prev_tick.volume if prev_tick.volume > 0 else 0

                tick_features.extend([
                    price_change,
                    volume_change,
                    tick.price / prev_tick.price - 1.0 if prev_tick.price > 0 else 0,       # Price ratio
                    np.log(tick.volume / prev_tick.volume) if prev_tick.volume > 0 else 0,  # Log volume ratio
                    self.momentum_detector.calculate_micro_momentum(processed_ticks[max(0, i-5):i+1])
                ])
            else:
                tick_features.extend([0.0, 0.0, 0.0, 0.0, 0.0])

            features.extend(tick_features)

        return features[:self.state_components['eth_ticks']]

    def _process_ohlcv_data(self, ohlcv_data: List[OHLCVData],
                            timeframe: str, symbol: str = 'ETH') -> List[float]:
        """Process OHLCV data with technical indicators"""
        features = []

        if not ohlcv_data or len(ohlcv_data) < 20:
            component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
            return [0.0] * self.state_components[component_key]

        # Convert to DataFrame for indicator calculation
        df = pd.DataFrame([{
            'timestamp': bar.timestamp,
            'open': bar.open,
            'high': bar.high,
            'low': bar.low,
            'close': bar.close,
            'volume': bar.volume
        } for bar in ohlcv_data[-self.ohlcv_window_bars:]])

        # Calculate technical indicators
        df = self.indicator_calculator.add_all_indicators(df)

        # Ensure we have exactly 300 bars
        if len(df) < 300:
            # Pad with last known values
            last_row = df.iloc[-1:].copy()
            padding_rows = []
            for _ in range(300 - len(df)):
                padding_rows.append(last_row)
            if padding_rows:
                df = pd.concat([df] + padding_rows, ignore_index=True)
        else:
            df = df.tail(300)

        # Extract features for each bar
        feature_columns = ['open', 'high', 'low', 'close', 'volume', 'rsi', 'macd', 'bb_middle']

        for _, row in df.iterrows():
            bar_features = []
            for col in feature_columns:
                if col in row and not pd.isna(row[col]):
                    bar_features.append(float(row[col]))
                else:
                    bar_features.append(0.0)
            features.extend(bar_features)

        component_key = f'{symbol.lower()}_{timeframe}_ohlcv' if symbol == 'ETH' else 'btc_reference'
        return features[:self.state_components[component_key]]

    def _process_btc_reference_data(self, btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process BTC reference data (using 1h timeframe as primary)"""
        if '1h' in btc_ohlcv and btc_ohlcv['1h']:
            return self._process_ohlcv_data(btc_ohlcv['1h'], '1h', 'BTC')
        elif '1m' in btc_ohlcv and btc_ohlcv['1m']:
            return self._process_ohlcv_data(btc_ohlcv['1m'], '1m', 'BTC')
        else:
            return [0.0] * self.state_components['btc_reference']

    def _process_cnn_hidden_features(self, cnn_features: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Process CNN hidden layer features"""
        if not cnn_features:
            return [0.0] * self.state_components['cnn_features']

        # Combine features from all timeframes
        combined_features = []
        timeframes = ['1s', '1m', '1h', '1d']
        features_per_timeframe = self.state_components['cnn_features'] // len(timeframes)

        for tf in timeframes:
            if tf in cnn_features and cnn_features[tf] is not None:
                tf_features = cnn_features[tf].flatten()
                # Truncate or pad to fit allocation
                if len(tf_features) >= features_per_timeframe:
                    combined_features.extend(tf_features[:features_per_timeframe])
                else:
                    combined_features.extend(tf_features)
                    combined_features.extend([0.0] * (features_per_timeframe - len(tf_features)))
            else:
                combined_features.extend([0.0] * features_per_timeframe)

        return combined_features[:self.state_components['cnn_features']]

    def _process_cnn_predictions(self, cnn_predictions: Optional[Dict[str, np.ndarray]]) -> List[float]:
        """Process CNN predictions from all timeframes"""
        if not cnn_predictions:
            return [0.0] * self.state_components['cnn_predictions']

        predictions = []
        timeframes = ['1s', '1m', '1h', '1d']

        for tf in timeframes:
            if tf in cnn_predictions and cnn_predictions[tf] is not None:
                pred = cnn_predictions[tf].flatten()
                # Expecting 4 outputs per timeframe (BUY, SELL, HOLD, confidence)
                if len(pred) >= 4:
                    predictions.extend(pred[:4])
                else:
                    predictions.extend(pred)
                    predictions.extend([0.0] * (4 - len(pred)))
            else:
                predictions.extend([0.0, 0.0, 1.0, 0.0])  # Default to HOLD with 0 confidence

        return predictions[:self.state_components['cnn_predictions']]

    def _process_pivot_points(self, pivot_data: Optional[Dict[str, Any]],
                              eth_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process pivot points using Williams market structure"""
        if pivot_data:
            # Use provided pivot data
            return self._extract_pivot_features(pivot_data)
        elif '1m' in eth_ohlcv and eth_ohlcv['1m']:
            # Calculate pivot points from 1m data
            from training.williams_market_structure import WilliamsMarketStructure
            williams = WilliamsMarketStructure()

            # Convert OHLCV to numpy array
            ohlcv_array = self._ohlcv_to_array(eth_ohlcv['1m'])
            pivot_data = williams.calculate_recursive_pivot_points(ohlcv_array)
            return self._extract_pivot_features(pivot_data)
        else:
            return [0.0] * self.state_components['pivot_points']

    def _process_market_regime(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                               btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Process market regime indicators"""
        regime_features = []

        # ETH regime analysis
        if '1h' in eth_ohlcv and eth_ohlcv['1h']:
            eth_regime = self.regime_analyzer.analyze_regime(eth_ohlcv['1h'])
            regime_features.extend([
                eth_regime['volatility'],
                eth_regime['trend_strength'],
                eth_regime['volume_trend'],
                eth_regime['momentum'],
                1.0 if eth_regime['regime'] == 'trending' else 0.0,
                1.0 if eth_regime['regime'] == 'ranging' else 0.0,
                1.0 if eth_regime['regime'] == 'volatile' else 0.0
            ])
        else:
            regime_features.extend([0.0] * 7)

        # BTC regime analysis
        if '1h' in btc_ohlcv and btc_ohlcv['1h']:
            btc_regime = self.regime_analyzer.analyze_regime(btc_ohlcv['1h'])
            regime_features.extend([
                btc_regime['volatility'],
                btc_regime['trend_strength'],
                btc_regime['volume_trend'],
                btc_regime['momentum'],
                1.0 if btc_regime['regime'] == 'trending' else 0.0,
                1.0 if btc_regime['regime'] == 'ranging' else 0.0,
                1.0 if btc_regime['regime'] == 'volatile' else 0.0
            ])
        else:
            regime_features.extend([0.0] * 7)

        # Correlation features (7 ETH + 7 BTC + 6 correlation = 20 regime features)
        correlation_features = self._calculate_btc_eth_correlation(eth_ohlcv, btc_ohlcv)
        regime_features.extend(correlation_features)

        return regime_features[:self.state_components['market_regime']]

    def _normalize_tick_window(self, ticks: List[TickData], target_size: int) -> List[TickData]:
        """Normalize tick window to target size"""
        if len(ticks) == target_size:
            return ticks
        elif len(ticks) > target_size:
            # Sample evenly
            step = len(ticks) / target_size
            indices = [int(i * step) for i in range(target_size)]
            return [ticks[i] for i in indices]
        else:
            # Pad with last tick
            result = ticks.copy()
            last_tick = ticks[-1] if ticks else TickData(datetime.now(), 0, 0)
            while len(result) < target_size:
                result.append(last_tick)
            return result
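
    # Worked example (illustrative): with 600 input ticks and target_size=300,
    # step = 2.0, so indices = [0, 2, 4, ..., 598] keep every second tick while
    # preserving chronological order.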

    def _extract_pivot_features(self, pivot_data: Dict[str, Any]) -> List[float]:
        """Extract features from pivot point data"""
        features = []

        for level in range(5):  # 5 levels of recursion
            level_key = f'level_{level}'
            if level_key in pivot_data:
                level_data = pivot_data[level_key]

                # Swing point features
                swing_points = level_data.get('swing_points', [])
                if swing_points:
                    # Last 10 swing points
                    recent_swings = swing_points[-10:]
                    for swing in recent_swings:
                        features.extend([
                            swing['price'],
                            1.0 if swing['type'] == 'swing_high' else 0.0,
                            swing['index']
                        ])

                    # Pad if fewer than 10 swings
                    while len(recent_swings) < 10:
                        features.extend([0.0, 0.0, 0.0])
                        recent_swings.append({'type': 'none'})
                else:
                    features.extend([0.0] * 30)  # 10 swings * 3 features

                # Trend features
                features.extend([
                    level_data.get('trend_strength', 0.0),
                    1.0 if level_data.get('trend_direction') == 'up' else 0.0,
                    1.0 if level_data.get('trend_direction') == 'down' else 0.0
                ])
            else:
                features.extend([0.0] * 33)  # 30 swing + 3 trend features

        # 5 levels * 33 values = 165 features; pad to the fixed 250-feature
        # allocation so downstream state components stay aligned
        target = self.state_components['pivot_points']
        if len(features) < target:
            features.extend([0.0] * (target - len(features)))
        return features[:target]

    def _ohlcv_to_array(self, ohlcv_data: List[OHLCVData]) -> np.ndarray:
        """Convert OHLCV data to numpy array"""
        return np.array([[
            bar.timestamp.timestamp(),
            bar.open,
            bar.high,
            bar.low,
            bar.close,
            bar.volume
        ] for bar in ohlcv_data])

    def _calculate_btc_eth_correlation(self, eth_ohlcv: Dict[str, List[OHLCVData]],
                                       btc_ohlcv: Dict[str, List[OHLCVData]]) -> List[float]:
        """Calculate BTC-ETH correlation features"""
        try:
            # Use 1h data for correlation
            if '1h' not in eth_ohlcv or '1h' not in btc_ohlcv:
                return [0.0] * 6

            eth_prices = [bar.close for bar in eth_ohlcv['1h'][-50:]]  # Last 50 hours
            btc_prices = [bar.close for bar in btc_ohlcv['1h'][-50:]]

            if len(eth_prices) < 10 or len(btc_prices) < 10:
                return [0.0] * 6

            # Align lengths
            min_len = min(len(eth_prices), len(btc_prices))
            eth_prices = eth_prices[-min_len:]
            btc_prices = btc_prices[-min_len:]

            # Calculate returns
            eth_returns = np.diff(eth_prices) / eth_prices[:-1]
            btc_returns = np.diff(btc_prices) / btc_prices[:-1]

            # Correlation
            correlation = np.corrcoef(eth_returns, btc_returns)[0, 1] if len(eth_returns) > 1 else 0.0

            # Price ratio
            current_ratio = eth_prices[-1] / btc_prices[-1] if btc_prices[-1] > 0 else 0.0
            avg_ratio = np.mean([e/b for e, b in zip(eth_prices, btc_prices) if b > 0])
            ratio_deviation = (current_ratio - avg_ratio) / avg_ratio if avg_ratio > 0 else 0.0

            # Volatility comparison
            eth_vol = np.std(eth_returns) if len(eth_returns) > 1 else 0.0
            btc_vol = np.std(btc_returns) if len(btc_returns) > 1 else 0.0
            vol_ratio = eth_vol / btc_vol if btc_vol > 0 else 1.0

            return [
                correlation,
                current_ratio,
                ratio_deviation,
                vol_ratio,
                eth_vol,
                btc_vol
            ]

        except Exception as e:
            logger.warning(f"Error calculating BTC-ETH correlation: {e}")
            return [0.0] * 6

    def _initialize_normalization_params(self) -> Dict[str, Dict[str, float]]:
        """Initialize normalization parameters for different feature types"""
        return {
            'price_features': {'mean': 0.0, 'std': 1.0, 'min': -10.0, 'max': 10.0},
            'volume_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0},
            'indicator_features': {'mean': 0.0, 'std': 1.0, 'min': -3.0, 'max': 3.0},
            'cnn_features': {'mean': 0.0, 'std': 1.0, 'min': -2.0, 'max': 2.0},
            'pivot_features': {'mean': 0.0, 'std': 1.0, 'min': -5.0, 'max': 5.0}
        }

    def _normalize_state(self, state: np.ndarray) -> np.ndarray:
        """Apply normalization to state vector"""
        try:
            # Simple clipping and scaling for now
            # More sophisticated normalization can be added based on training data
            normalized_state = np.clip(state, -10.0, 10.0)

            # Replace any NaN or inf values
            normalized_state = np.nan_to_num(normalized_state, nan=0.0, posinf=10.0, neginf=-10.0)

            return normalized_state.astype(np.float32)

        except Exception as e:
            logger.error(f"Error normalizing state: {e}")
            return state.astype(np.float32)


class TickMomentumDetector:
    """Detect momentum from tick-level data"""

    def calculate_micro_momentum(self, ticks: List[TickData]) -> float:
        """Calculate micro-momentum from tick sequence"""
        if len(ticks) < 2:
            return 0.0

        # Price momentum
        prices = [tick.price for tick in ticks]
        price_changes = np.diff(prices)
        price_momentum = np.sum(price_changes) / len(price_changes) if len(price_changes) > 0 else 0.0

        # Volume-weighted momentum (guard on volumes[1:], since only the
        # volumes paired with a price change enter the weighted sum)
        volumes = [tick.volume for tick in ticks]
        if sum(volumes[1:]) > 0:
            weighted_changes = [pc * v for pc, v in zip(price_changes, volumes[1:])]
            volume_momentum = sum(weighted_changes) / sum(volumes[1:])
        else:
            volume_momentum = 0.0

        return (price_momentum + volume_momentum) / 2.0
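
    # Worked example (illustrative numbers): prices [100.0, 100.2, 100.1] with
    # volumes [5, 10, 5] give price_changes [0.2, -0.1], price_momentum ~= 0.05,
    # volume_momentum = (0.2*10 - 0.1*5) / 15 ~= 0.1, so the blend returns ~0.075.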


class TechnicalIndicatorCalculator:
    """Calculate technical indicators for OHLCV data"""

    def add_all_indicators(self, df: pd.DataFrame) -> pd.DataFrame:
        """Add all technical indicators to DataFrame"""
        df = df.copy()

        # RSI
        df['rsi'] = self.calculate_rsi(df['close'])

        # MACD
        df['macd'] = self.calculate_macd(df['close'])

        # Bollinger Bands
        df['bb_middle'] = df['close'].rolling(20).mean()
        df['bb_std'] = df['close'].rolling(20).std()
        df['bb_upper'] = df['bb_middle'] + (df['bb_std'] * 2)
        df['bb_lower'] = df['bb_middle'] - (df['bb_std'] * 2)

        # Forward-fill NaN values, then zero anything still missing
        # (pandas expects ffill; fillna(method='forward') raises a ValueError)
        df = df.ffill().fillna(0)

        return df

    def calculate_rsi(self, prices: pd.Series, period: int = 14) -> pd.Series:
        """Calculate RSI"""
        delta = prices.diff()
        gain = (delta.where(delta > 0, 0)).rolling(window=period).mean()
        loss = (-delta.where(delta < 0, 0)).rolling(window=period).mean()
        rs = gain / loss
        rsi = 100 - (100 / (1 + rs))
        return rsi.fillna(50)

    def calculate_macd(self, prices: pd.Series, fast: int = 12, slow: int = 26) -> pd.Series:
        """Calculate MACD"""
        ema_fast = prices.ewm(span=fast).mean()
        ema_slow = prices.ewm(span=slow).mean()
        macd = ema_fast - ema_slow
        return macd.fillna(0)
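
    # Usage sketch (illustrative, with synthetic prices):
    #
    #   calc = TechnicalIndicatorCalculator()
    #   closes = pd.Series(np.linspace(100.0, 110.0, 60))
    #   rsi = calc.calculate_rsi(closes)    # monotonic rise -> RSI near 100
    #   macd = calc.calculate_macd(closes)  # positive once the fast EMA leads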


class MarketRegimeAnalyzer:
    """Analyze market regime from OHLCV data"""

    def analyze_regime(self, ohlcv_data: List[OHLCVData]) -> Dict[str, Any]:
        """Analyze market regime"""
        if len(ohlcv_data) < 20:
            return {
                'regime': 'unknown',
                'volatility': 0.0,
                'trend_strength': 0.0,
                'volume_trend': 0.0,
                'momentum': 0.0
            }

        prices = [bar.close for bar in ohlcv_data[-50:]]  # Last 50 bars
        volumes = [bar.volume for bar in ohlcv_data[-50:]]

        # Calculate volatility
        returns = np.diff(prices) / prices[:-1]
        volatility = np.std(returns) * 100  # Percentage volatility

        # Calculate trend strength
        sma_short = np.mean(prices[-10:])
        sma_long = np.mean(prices[-30:])
        trend_strength = abs(sma_short - sma_long) / sma_long if sma_long > 0 else 0.0

        # Volume trend
        volume_ma_short = np.mean(volumes[-10:])
        volume_ma_long = np.mean(volumes[-30:])
        volume_trend = (volume_ma_short - volume_ma_long) / volume_ma_long if volume_ma_long > 0 else 0.0

        # Momentum
        momentum = (prices[-1] - prices[-10]) / prices[-10] if len(prices) >= 10 and prices[-10] > 0 else 0.0

        # Determine regime
        if volatility > 3.0:        # High volatility
            regime = 'volatile'
        elif abs(momentum) > 0.02:  # Strong momentum
            regime = 'trending'
        else:
            regime = 'ranging'

        return {
            'regime': regime,
            'volatility': volatility,
            'trend_strength': trend_strength,
            'volume_trend': volume_trend,
            'momentum': momentum
        }
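
    # Worked example (illustrative): 50 bars drifting up ~3% over the last 10
    # bars with ~1% return stdev give volatility < 3.0 and |momentum| > 0.02,
    # so the classifier above labels the window 'trending'; flat, quiet prices
    # fall through to 'ranging'.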

    def get_state_info(self) -> Dict[str, Any]:
        """Get information about the state structure

        NOTE: this method reads `self.config` as a StateComponentConfig plus
        `self.tick_window_seconds` / `self.ohlcv_window_bars`, so it presumes
        a builder configured with StateComponentConfig rather than this class.
        """
        return {
            'total_size': self.config.total_size,
            'components': {
                'eth_ticks': self.config.eth_ticks,
                'eth_1s_ohlcv': self.config.eth_1s_ohlcv,
                'eth_1m_ohlcv': self.config.eth_1m_ohlcv,
                'eth_1h_ohlcv': self.config.eth_1h_ohlcv,
                'eth_1d_ohlcv': self.config.eth_1d_ohlcv,
                'btc_reference': self.config.btc_reference,
                'cnn_features': self.config.cnn_features,
                'cnn_predictions': self.config.cnn_predictions,
                'pivot_points': self.config.pivot_points,
                'market_regime': self.config.market_regime,
            },
            'data_windows': {
                'tick_window_seconds': self.tick_window_seconds,
                'ohlcv_window_bars': self.ohlcv_window_bars,
            }
        }
@ -1,821 +0,0 @@
"""
Enhanced RL Trainer with Continuous Learning

This module implements sophisticated RL training with:
- Prioritized experience replay
- Market regime adaptation
- Continuous learning from trading outcomes
- Performance tracking and visualization
"""

import asyncio
import logging
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque, namedtuple
import random
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Tuple, Any, Union
import matplotlib.pyplot as plt
from pathlib import Path

from core.config import get_config
from core.data_provider import DataProvider
from core.enhanced_orchestrator import EnhancedTradingOrchestrator, MarketState, TradingAction
from models import RLAgentInterface
import models
from training.enhanced_rl_state_builder import EnhancedRLStateBuilder, TickData, OHLCVData
from training.williams_market_structure import WilliamsMarketStructure
from training.cnn_rl_bridge import CNNRLBridge

logger = logging.getLogger(__name__)

# Experience tuple for replay buffer
Experience = namedtuple('Experience', ['state', 'action', 'reward', 'next_state', 'done', 'priority'])


class PrioritizedReplayBuffer:
    """Prioritized experience replay buffer for RL training"""

    def __init__(self, capacity: int = 10000, alpha: float = 0.6):
        """
        Initialize prioritized replay buffer

        Args:
            capacity: Maximum number of experiences to store
            alpha: Priority exponent (0 = uniform, 1 = fully prioritized)
        """
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros(capacity, dtype=np.float32)
        self.position = 0
        self.size = 0

    def add(self, experience: Experience):
        """Add experience to buffer with priority"""
        max_priority = self.priorities[:self.size].max() if self.size > 0 else 1.0

        if self.size < self.capacity:
            self.buffer.append(experience)
            self.size += 1
        else:
            self.buffer[self.position] = experience

        self.priorities[self.position] = max_priority
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size: int, beta: float = 0.4) -> Tuple[List[Experience], np.ndarray, np.ndarray]:
        """Sample batch with prioritized sampling"""
        if self.size == 0:
            return [], np.array([]), np.array([])

        # Calculate sampling probabilities
        priorities = self.priorities[:self.size] ** self.alpha
        probabilities = priorities / priorities.sum()

        # Sample indices
        indices = np.random.choice(self.size, batch_size, p=probabilities)
        experiences = [self.buffer[i] for i in indices]

        # Calculate importance sampling weights
        weights = (self.size * probabilities[indices]) ** (-beta)
        weights = weights / weights.max()  # Normalize

        return experiences, indices, weights

    def update_priorities(self, indices: np.ndarray, priorities: np.ndarray):
        """Update priorities for sampled experiences"""
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority + 1e-6  # Small epsilon to avoid zero priority

    def __len__(self):
        return self.size
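
    # Usage sketch (illustrative; `state`, `next_state` and `new_td_errors`
    # are hypothetical values supplied by a training loop):
    #
    #   buffer = PrioritizedReplayBuffer(capacity=1000)
    #   buffer.add(Experience(state, 1, 0.5, next_state, False, 1.0))
    #   batch, idxs, weights = buffer.sample(batch_size=32, beta=0.4)
    #   buffer.update_priorities(idxs, new_td_errors)  # keep priorities fresh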


class EnhancedDQNAgent(nn.Module, RLAgentInterface):
    """Enhanced DQN agent with market environment adaptation"""

    def __init__(self, config: Dict[str, Any]):
        nn.Module.__init__(self)
        RLAgentInterface.__init__(self, config)

        # Network architecture
        self.state_size = config.get('state_size', 100)
        self.action_space = config.get('action_space', 3)
        self.hidden_size = config.get('hidden_size', 256)

        # Build networks
        self._build_networks()

        # Training parameters
        self.learning_rate = config.get('learning_rate', 0.0001)
        self.gamma = config.get('gamma', 0.99)
        self.epsilon = config.get('epsilon', 1.0)
        self.epsilon_decay = config.get('epsilon_decay', 0.995)
        self.epsilon_min = config.get('epsilon_min', 0.01)
        self.target_update_freq = config.get('target_update_freq', 1000)

        # Initialize device and optimizer
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(self.device)
        self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)

        # Experience replay
        self.replay_buffer = PrioritizedReplayBuffer(config.get('buffer_size', 10000))
        self.batch_size = config.get('batch_size', 64)

        # Market adaptation
        self.market_regime_weights = {
            'trending': 1.2,  # Higher confidence in trending markets
            'ranging': 0.8,   # Lower confidence in ranging markets
            'volatile': 0.6   # Much lower confidence in volatile markets
        }

        # Training statistics
        self.training_steps = 0
        self.losses = []
        self.rewards = []
        self.epsilon_history = []

        logger.info(f"Enhanced DQN agent initialized with state size: {self.state_size}")

    def _build_networks(self):
        """Build main and target networks"""
        # Main network
        self.main_network = nn.Sequential(
            nn.Linear(self.state_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Dueling network heads
        self.value_head = nn.Linear(128, 1)
        self.advantage_head = nn.Linear(128, self.action_space)

        # Target network (copy of main network)
        self.target_network = nn.Sequential(
            nn.Linear(self.state_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(self.hidden_size, 128),
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        self.target_value_head = nn.Linear(128, 1)
        self.target_advantage_head = nn.Linear(128, self.action_space)

        # Initialize target network with same weights
        self._update_target_network()

    def forward(self, state, target: bool = False):
        """Forward pass through the network"""
        if target:
            features = self.target_network(state)
            value = self.target_value_head(features)
            advantage = self.target_advantage_head(features)
        else:
            features = self.main_network(state)
            value = self.value_head(features)
            advantage = self.advantage_head(features)

        # Dueling architecture: Q(s,a) = V(s) + A(s,a) - mean(A(s,a))
        q_values = value + (advantage - advantage.mean(dim=1, keepdim=True))

        return q_values
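
    # Worked example of the dueling aggregation (illustrative numbers): with
    # V(s) = 1.0 and A(s,.) = [0.5, -0.5, 0.0], mean(A) = 0.0, so
    # Q(s,.) = [1.5, 0.5, 1.0]; subtracting the mean keeps V identifiable.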

    def act(self, state: np.ndarray) -> int:
        """Choose action using epsilon-greedy policy"""
        if random.random() < self.epsilon:
            return random.randint(0, self.action_space - 1)

        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.forward(state_tensor)
            return q_values.argmax().item()

    def act_with_confidence(self, state: np.ndarray, market_regime: str = 'trending') -> Tuple[int, float]:
        """Choose action with confidence score adapted to market regime"""
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            q_values = self.forward(state_tensor)

            # Convert Q-values to probabilities
            action_probs = torch.softmax(q_values, dim=1)
            action = q_values.argmax().item()
            base_confidence = action_probs[0, action].item()

        # Adapt confidence based on market regime
        regime_weight = self.market_regime_weights.get(market_regime, 1.0)
        adapted_confidence = min(base_confidence * regime_weight, 1.0)

        return action, adapted_confidence
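
    # Illustrative call: softmax over Q-values [2.0, 0.5, 0.1] puts ~0.73 on
    # the argmax action; in a 'volatile' regime that confidence is scaled by
    # 0.6 to ~0.44 before it reaches the caller.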

    def remember(self, state: np.ndarray, action: int, reward: float,
                 next_state: np.ndarray, done: bool):
        """Store experience in replay buffer"""
        # Calculate TD error for priority
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(self.device)

            current_q = self.forward(state_tensor)[0, action]
            next_q = self.forward(next_state_tensor, target=True).max(1)[0]
            target_q = reward + (self.gamma * next_q * (1 - done))

            td_error = abs(current_q.item() - target_q.item())

        experience = Experience(state, action, reward, next_state, done, td_error)
        self.replay_buffer.add(experience)

    def replay(self) -> Optional[float]:
        """Train the network on a batch of experiences"""
        if len(self.replay_buffer) < self.batch_size:
            return None

        # Sample batch
        experiences, indices, weights = self.replay_buffer.sample(self.batch_size)

        if not experiences:
            return None

        # Convert to tensors (stack into a single ndarray first; building a
        # tensor from a list of arrays is slow and warns in recent PyTorch)
        states = torch.FloatTensor(np.array([e.state for e in experiences])).to(self.device)
        actions = torch.LongTensor([e.action for e in experiences]).to(self.device)
        rewards = torch.FloatTensor([e.reward for e in experiences]).to(self.device)
        next_states = torch.FloatTensor(np.array([e.next_state for e in experiences])).to(self.device)
        dones = torch.BoolTensor([e.done for e in experiences]).to(self.device)
        weights_tensor = torch.FloatTensor(weights).to(self.device)

        # Current Q-values
        current_q_values = self.forward(states).gather(1, actions.unsqueeze(1))

        # Target Q-values (Double DQN)
        with torch.no_grad():
            # Use main network to select actions
            next_actions = self.forward(next_states).argmax(1)
            # Use target network to evaluate actions
            next_q_values = self.forward(next_states, target=True).gather(1, next_actions.unsqueeze(1))
            target_q_values = rewards.unsqueeze(1) + (self.gamma * next_q_values * ~dones.unsqueeze(1))

        # Calculate weighted loss (squeeze to shape [batch] so the importance
        # weights broadcast element-wise instead of into a [batch, batch] matrix)
        td_errors = (target_q_values - current_q_values).squeeze(1)
        loss = (weights_tensor * td_errors.pow(2)).mean()

        # Optimize
        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
        self.optimizer.step()

        # Update priorities
        new_priorities = torch.abs(td_errors).detach().cpu().numpy().flatten()
        self.replay_buffer.update_priorities(indices, new_priorities)

        # Update target network
        self.training_steps += 1
        if self.training_steps % self.target_update_freq == 0:
            self._update_target_network()

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        # Track statistics
        self.losses.append(loss.item())
        self.epsilon_history.append(self.epsilon)

        return loss.item()
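
    # Training-step sketch (illustrative; `env` and its step/reset API are
    # hypothetical stand-ins for a concrete environment):
    #
    #   state = env.reset()
    #   action = agent.act(state)
    #   next_state, reward, done, info = env.step(action)
    #   agent.remember(state, action, reward, next_state, done)
    #   loss = agent.replay()  # None until batch_size experiences are stored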

    def _update_target_network(self):
        """Update target network with main network weights"""
        self.target_network.load_state_dict(self.main_network.state_dict())
        self.target_value_head.load_state_dict(self.value_head.state_dict())
        self.target_advantage_head.load_state_dict(self.advantage_head.state_dict())

    def predict(self, features: np.ndarray) -> Tuple[np.ndarray, float]:
        """Predict action probabilities and confidence (required by ModelInterface)"""
        action, confidence = self.act_with_confidence(features)
        # Convert action to a one-hot probability vector
        action_probs = np.zeros(self.action_space)
        action_probs[action] = 1.0
        return action_probs, confidence

    def get_memory_usage(self) -> int:
        """Get memory usage in MB"""
        if torch.cuda.is_available():
            return torch.cuda.memory_allocated(self.device) // (1024 * 1024)
        else:
            param_count = sum(p.numel() for p in self.parameters())
            buffer_size = len(self.replay_buffer) * self.state_size * 4  # Rough estimate
            return (param_count * 4 + buffer_size) // (1024 * 1024)


class EnhancedRLTrainer:
    """Enhanced RL trainer with comprehensive state representation and real data integration"""

    def __init__(self, config: Optional[Dict] = None, orchestrator: EnhancedTradingOrchestrator = None):
        """Initialize enhanced RL trainer with comprehensive state building"""
        self.config = config or get_config()
        self.orchestrator = orchestrator

        # Initialize comprehensive state builder (replaces mock code)
        self.state_builder = EnhancedRLStateBuilder(self.config)
        self.williams_structure = WilliamsMarketStructure()
        self.cnn_rl_bridge = CNNRLBridge(self.config) if hasattr(self.config, 'cnn_models') else None

        # Enhanced RL agents with much larger state space
        self.agents = {}
        self.initialize_agents()

        # Training configuration
        self.symbols = self.config.symbols
        self.save_dir = Path(self.config.rl.get('save_dir', 'models/rl/saved'))
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Performance tracking
        self.training_metrics = {
            'total_episodes': 0,
            'total_rewards': {symbol: [] for symbol in self.symbols},
            'losses': {symbol: [] for symbol in self.symbols},
            'epsilon_values': {symbol: [] for symbol in self.symbols}
        }

        self.performance_history = {symbol: [] for symbol in self.symbols}

        # Real-time learning parameters
        self.learning_active = False
        self.experience_buffer_size = 1000
        self.min_experiences_for_training = 100

        logger.info("Enhanced RL Trainer initialized with comprehensive state representation")
        logger.info(f"State builder total size: {self.state_builder.total_state_size} features")
        logger.info(f"Symbols: {self.symbols}")

    def initialize_agents(self):
        """Initialize RL agents with enhanced state size"""
        for symbol in self.symbols:
            agent_config = {
                'state_size': self.state_builder.total_state_size,  # ~15,800 features
                'action_space': 3,     # BUY, SELL, HOLD
                'hidden_size': 1024,   # Larger hidden layers for the complex state
                'learning_rate': 0.0001,
                'gamma': 0.99,
                'epsilon': 1.0,
                'epsilon_decay': 0.995,
                'epsilon_min': 0.01,
                'buffer_size': 50000,  # Larger replay buffer
                'batch_size': 128,
                'target_update_freq': 1000
            }

            self.agents[symbol] = EnhancedDQNAgent(agent_config)
            logger.info(f"Initialized {symbol} RL agent with state size: {agent_config['state_size']}")

    async def continuous_learning_loop(self):
        """Main continuous learning loop"""
        logger.info("Starting continuous RL learning loop")

        while True:
            try:
                # Train agents with recent experiences
                await self._train_all_agents()

                # Evaluate recent actions
                if self.orchestrator:
                    await self.orchestrator.evaluate_actions_with_rl()

                # Adapt to market regime changes
                await self._adapt_to_market_changes()

                # Update performance metrics
                self._update_performance_metrics()

                # Save models periodically
                if self.training_metrics['total_episodes'] % 100 == 0:
                    self._save_all_models()

                # Wait before next training cycle
                await asyncio.sleep(3600)  # Train every hour

            except Exception as e:
                logger.error(f"Error in continuous learning loop: {e}")
                await asyncio.sleep(60)  # Wait 1 minute on error

    async def _train_all_agents(self):
        """Train all RL agents with their experiences"""
        for symbol, agent in self.agents.items():
            try:
                if len(agent.replay_buffer) >= self.min_experiences_for_training:
                    # Train for multiple steps
                    losses = []
                    for _ in range(10):  # Train 10 steps per cycle
                        loss = agent.replay()
                        if loss is not None:
                            losses.append(loss)

                    if losses:
                        avg_loss = np.mean(losses)
                        self.training_metrics['losses'][symbol].append(avg_loss)
                        self.training_metrics['epsilon_values'][symbol].append(agent.epsilon)

                        logger.info(f"Trained {symbol} RL agent: Loss={avg_loss:.4f}, Epsilon={agent.epsilon:.4f}")

            except Exception as e:
                logger.error(f"Error training {symbol} agent: {e}")

    async def _adapt_to_market_changes(self):
        """Adapt agents to market regime changes"""
        if not self.orchestrator:
            return

        for symbol in self.symbols:
            try:
                # Get recent market states
                recent_states = list(self.orchestrator.market_states[symbol])[-10:]  # Last 10 states

                if len(recent_states) < 5:
                    continue

                # Analyze regime stability (share of distinct regimes; lower = more stable)
                regimes = [state.market_regime for state in recent_states]
                regime_stability = len(set(regimes)) / len(regimes)

                # Adjust learning parameters based on stability
                agent = self.agents[symbol]
                if regime_stability < 0.3:    # Stable regime
                    agent.epsilon *= 0.99     # Faster epsilon decay
                elif regime_stability > 0.7:  # Unstable regime
                    agent.epsilon = min(agent.epsilon * 1.01, 0.5)  # Increase exploration

                logger.debug(f"{symbol} regime stability: {regime_stability:.3f}, epsilon: {agent.epsilon:.3f}")

            except Exception as e:
                logger.error(f"Error adapting {symbol} to market changes: {e}")

    def add_trading_experience(self, symbol: str, action: TradingAction,
                               initial_state: MarketState, final_state: MarketState,
                               reward: float):
        """Add trading experience to the appropriate agent"""
        if symbol not in self.agents:
            logger.warning(f"No agent for symbol {symbol}")
            return

        try:
            # Convert market states to RL state vectors
            initial_rl_state = self._market_state_to_rl_state(initial_state)
            final_rl_state = self._market_state_to_rl_state(final_state)

            # Convert action to RL action index
            action_mapping = {'SELL': 0, 'HOLD': 1, 'BUY': 2}
            action_idx = action_mapping.get(action.action, 1)

            # Store experience
            agent = self.agents[symbol]
            agent.remember(
                state=initial_rl_state,
                action=action_idx,
                reward=reward,
                next_state=final_rl_state,
                done=False
            )

            # Track reward
            self.training_metrics['total_rewards'][symbol].append(reward)

            logger.debug(f"Added experience for {symbol}: action={action.action}, reward={reward:.4f}")

        except Exception as e:
            logger.error(f"Error adding experience for {symbol}: {e}")

    def _market_state_to_rl_state(self, market_state: MarketState) -> np.ndarray:
        """Convert market state to comprehensive RL state vector using real data"""
        try:
            # Extract data from market state and orchestrator
            if not self.orchestrator:
                logger.warning("No orchestrator available for comprehensive state building")
                return self._fallback_state_conversion(market_state)

            # Get real tick data from orchestrator's data provider
            symbol = market_state.symbol
            eth_ticks = self._get_recent_tick_data(symbol, seconds=300)

            # Get multi-timeframe OHLCV data
            eth_ohlcv = self._get_multiframe_ohlcv_data(symbol)
            btc_ohlcv = self._get_multiframe_ohlcv_data('BTC/USDT')

            # Get CNN features if available
            cnn_hidden_features = None
            cnn_predictions = None
            if self.cnn_rl_bridge:
                cnn_data = self.cnn_rl_bridge.get_latest_features_for_symbol(symbol)
                if cnn_data:
                    cnn_hidden_features = cnn_data.get('hidden_features', {})
                    cnn_predictions = cnn_data.get('predictions', {})

            # Get pivot point data
            pivot_data = self._calculate_pivot_points(eth_ohlcv)

            # Build comprehensive state using enhanced state builder
            comprehensive_state = self.state_builder.build_rl_state(
                eth_ticks=eth_ticks,
                eth_ohlcv=eth_ohlcv,
                btc_ohlcv=btc_ohlcv,
                cnn_hidden_features=cnn_hidden_features,
                cnn_predictions=cnn_predictions,
                pivot_data=pivot_data
            )

            logger.debug(f"Built comprehensive RL state: {len(comprehensive_state)} features")
            return comprehensive_state

        except Exception as e:
            logger.error(f"Error building comprehensive RL state: {e}")
            return self._fallback_state_conversion(market_state)

    def _get_recent_tick_data(self, symbol: str, seconds: int = 300) -> List[TickData]:
        """Get recent tick data from orchestrator's data provider"""
        try:
            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
                # Get recent ticks from data provider
                recent_ticks = self.orchestrator.data_provider.get_recent_ticks(symbol, count=seconds*10)

                # Convert to the TickData structure the state builder expects
                # (side/trade_id metadata are not part of the state features)
                tick_data = []
                for tick in recent_ticks[-300:]:  # Last 300 ticks max
                    tick_data.append(TickData(
                        timestamp=tick.timestamp,
                        price=tick.price,
                        volume=tick.volume,
                        bid=getattr(tick, 'bid', 0.0),
                        ask=getattr(tick, 'ask', 0.0)
                    ))

                return tick_data

            return []

        except Exception as e:
            logger.warning(f"Error getting tick data for {symbol}: {e}")
            return []

    def _get_multiframe_ohlcv_data(self, symbol: str) -> Dict[str, List[OHLCVData]]:
        """Get multi-timeframe OHLCV data"""
        try:
            if hasattr(self.orchestrator, 'data_provider') and self.orchestrator.data_provider:
                ohlcv_data = {}
                timeframes = ['1s', '1m', '1h', '1d']

                for tf in timeframes:
                    try:
                        # Get historical data for timeframe
                        df = self.orchestrator.data_provider.get_historical_data(
                            symbol=symbol,
                            timeframe=tf,
                            limit=300,
                            refresh=True
                        )

                        if df is not None and not df.empty:
                            # Convert rows to the OHLCVData structure the state builder expects
                            bars = []
                            for _, row in df.tail(300).iterrows():
                                bars.append(OHLCVData(
                                    timestamp=row.name if hasattr(row, 'name') else datetime.now(),
                                    open=float(row.get('open', 0)),
                                    high=float(row.get('high', 0)),
                                    low=float(row.get('low', 0)),
                                    close=float(row.get('close', 0)),
                                    volume=float(row.get('volume', 0))
                                ))

                            ohlcv_data[tf] = bars
                        else:
                            ohlcv_data[tf] = []

                    except Exception as e:
                        logger.warning(f"Error getting {tf} data for {symbol}: {e}")
                        ohlcv_data[tf] = []

                return ohlcv_data

            return {}

        except Exception as e:
            logger.warning(f"Error getting OHLCV data for {symbol}: {e}")
            return {}

    def _calculate_pivot_points(self, eth_ohlcv: Dict[str, List[OHLCVData]]) -> Dict[str, Any]:
        """Calculate Williams pivot points from OHLCV data"""
        try:
            if '1m' in eth_ohlcv and eth_ohlcv['1m']:
                # Convert to numpy array for Williams calculation
                bars = eth_ohlcv['1m']
                if len(bars) >= 50:  # Need minimum data for pivot calculation
                    ohlc_array = np.array([
                        [bar.timestamp.timestamp() if hasattr(bar.timestamp, 'timestamp') else time.time(),
                         bar.open, bar.high, bar.low, bar.close, bar.volume]
                        for bar in bars[-200:]  # Last 200 bars
                    ])

                    pivot_data = self.williams_structure.calculate_recursive_pivot_points(ohlc_array)
                    return pivot_data

            return {}

        except Exception as e:
            logger.warning(f"Error calculating pivot points: {e}")
            return {}

    def _fallback_state_conversion(self, market_state: MarketState) -> np.ndarray:
        """Fallback to basic state conversion if comprehensive state building fails"""
        logger.warning("Using fallback state conversion - limited features")

        state_components = [
            market_state.volatility,
            market_state.volume,
            market_state.trend_strength
        ]

        # Add price features
        for timeframe in sorted(market_state.prices.keys()):
            state_components.append(market_state.prices[timeframe])

        # Pad to match expected state size
        expected_size = self.state_builder.total_state_size
        if len(state_components) < expected_size:
            state_components.extend([0.0] * (expected_size - len(state_components)))
        else:
            state_components = state_components[:expected_size]

        return np.array(state_components, dtype=np.float32)

    def _update_performance_metrics(self):
        """Update performance tracking metrics"""
        self.training_metrics['total_episodes'] += 1

        # Calculate recent performance for each agent
        for symbol, agent in self.agents.items():
            recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]  # Last 100 rewards
            if recent_rewards:
                avg_reward = np.mean(recent_rewards)
                self.performance_history[symbol].append({
                    'timestamp': datetime.now(),
                    'avg_reward': avg_reward,
                    'epsilon': agent.epsilon,
                    'experiences': len(agent.replay_buffer)
                })

    def _save_all_models(self):
        """Save all RL models"""
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')

        for symbol, agent in self.agents.items():
            # Replace '/' so symbols like 'ETH/USDT' do not create subdirectories
            safe_symbol = symbol.replace('/', '_')
            filename = f"rl_agent_{safe_symbol}_{timestamp}.pt"
            filepath = self.save_dir / filename

            torch.save({
                'model_state_dict': agent.state_dict(),
                'optimizer_state_dict': agent.optimizer.state_dict(),
                'config': self.config.rl,
                'training_metrics': self.training_metrics,
                'symbol': symbol,
                'epsilon': agent.epsilon,
                'training_steps': agent.training_steps
            }, filepath)

            logger.info(f"Saved {symbol} RL agent to {filepath}")

    def load_models(self, timestamp: str = None):
        """Load RL models from files"""
        if timestamp is None:
            # Find most recent models
            model_files = list(self.save_dir.glob("rl_agent_*.pt"))
            if not model_files:
                logger.warning("No saved RL models found")
                return False

            # Group by timestamp (the last two underscore-separated stem parts)
            timestamps = set(f.stem.split('_')[-2] + '_' + f.stem.split('_')[-1] for f in model_files)
            timestamp = max(timestamps)

        loaded_count = 0
        for symbol in self.symbols:
            safe_symbol = symbol.replace('/', '_')
            filename = f"rl_agent_{safe_symbol}_{timestamp}.pt"
            filepath = self.save_dir / filename

            if filepath.exists():
                try:
                    checkpoint = torch.load(filepath, map_location=self.agents[symbol].device)
                    self.agents[symbol].load_state_dict(checkpoint['model_state_dict'])
                    self.agents[symbol].optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                    self.agents[symbol].epsilon = checkpoint.get('epsilon', 0.1)
                    self.agents[symbol].training_steps = checkpoint.get('training_steps', 0)

                    logger.info(f"Loaded {symbol} RL agent from {filepath}")
                    loaded_count += 1

                except Exception as e:
                    logger.error(f"Error loading {symbol} RL agent: {e}")

        return loaded_count > 0

    def get_performance_report(self) -> Dict[str, Any]:
        """Generate performance report for all agents"""
        report = {
            'total_episodes': self.training_metrics['total_episodes'],
            'agents': {}
        }

        for symbol, agent in self.agents.items():
            recent_rewards = self.training_metrics['total_rewards'][symbol][-100:]
            recent_losses = self.training_metrics['losses'][symbol][-10:]

            agent_report = {
                'symbol': symbol,
                'epsilon': agent.epsilon,
                'training_steps': agent.training_steps,
                'experiences_stored': len(agent.replay_buffer),
                'memory_usage_mb': agent.get_memory_usage(),
                'avg_recent_reward': np.mean(recent_rewards) if recent_rewards else 0.0,
                'avg_recent_loss': np.mean(recent_losses) if recent_losses else 0.0,
                'total_rewards': len(self.training_metrics['total_rewards'][symbol])
            }

            report['agents'][symbol] = agent_report

        return report

    def plot_training_metrics(self):
        """Plot training metrics for all agents"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle('Enhanced RL Training Metrics')

        symbols = list(self.agents.keys())
        colors = ['blue', 'red', 'green', 'orange'][:len(symbols)]

        # Rewards plot
        for i, symbol in enumerate(symbols):
            rewards = self.training_metrics['total_rewards'][symbol]
            if rewards:
                # Moving average of rewards
                window = min(100, len(rewards))
                if len(rewards) >= window:
                    moving_avg = np.convolve(rewards, np.ones(window)/window, mode='valid')
                    axes[0, 0].plot(moving_avg, label=f'{symbol}', color=colors[i])

        axes[0, 0].set_title('Average Rewards (Moving Average)')
        axes[0, 0].set_xlabel('Episodes')
        axes[0, 0].set_ylabel('Reward')
        axes[0, 0].legend()

        # Losses plot
        for i, symbol in enumerate(symbols):
            losses = self.training_metrics['losses'][symbol]
            if losses:
                axes[0, 1].plot(losses, label=f'{symbol}', color=colors[i])

        axes[0, 1].set_title('Training Losses')
        axes[0, 1].set_xlabel('Training Steps')
        axes[0, 1].set_ylabel('Loss')
        axes[0, 1].legend()

        # Epsilon values
        for i, symbol in enumerate(symbols):
            epsilon_values = self.training_metrics['epsilon_values'][symbol]
            if epsilon_values:
                axes[1, 0].plot(epsilon_values, label=f'{symbol}', color=colors[i])

        axes[1, 0].set_title('Exploration Rate (Epsilon)')
        axes[1, 0].set_xlabel('Training Steps')
        axes[1, 0].set_ylabel('Epsilon')
        axes[1, 0].legend()

        # Experience buffer sizes
        buffer_sizes = [len(agent.replay_buffer) for agent in self.agents.values()]
        axes[1, 1].bar(symbols, buffer_sizes, color=colors[:len(symbols)])
        axes[1, 1].set_title('Experience Buffer Sizes')
        axes[1, 1].set_ylabel('Number of Experiences')

        plt.tight_layout()
        plt.savefig(self.save_dir / 'rl_training_metrics.png', dpi=300, bbox_inches='tight')
        plt.close()

        logger.info(f"RL training plots saved to {self.save_dir / 'rl_training_metrics.png'}")

    def get_agents(self) -> Dict[str, EnhancedDQNAgent]:
        """Get all RL agents"""
        return self.agents
@ -1,523 +0,0 @@
"""
RL Training Pipeline - Scalping Agent Training

Comprehensive training pipeline for scalping RL agents:
- Environment setup and management
- Agent training with experience replay
- Performance tracking and evaluation
- Memory-efficient training loops
"""

import torch
import numpy as np
import pandas as pd
import logging
from typing import Dict, List, Tuple, Optional, Any
import time
from pathlib import Path
import matplotlib.pyplot as plt
from collections import deque
import random
from torch.utils.tensorboard import SummaryWriter

# Add project root to the import path
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.config import get_config
from core.data_provider import DataProvider
from models.rl.scalping_agent import ScalpingEnvironment, ScalpingRLAgent
from utils.model_utils import robust_save, robust_load

logger = logging.getLogger(__name__)


class RLTrainer:
    """
    RL Training Pipeline for Scalping
    """

    def __init__(self, data_provider: DataProvider, config: Optional[Dict] = None):
        self.data_provider = data_provider
        self.config = config or get_config()

        # Training parameters
        self.num_episodes = 1000
        self.max_steps_per_episode = 1000
        self.training_frequency = 4     # Train every N steps
        self.evaluation_frequency = 50  # Evaluate every N episodes
        self.save_frequency = 100       # Save model every N episodes

        # Environment parameters
        self.symbols = ['ETH/USDT']
        self.initial_balance = 1000.0
        self.max_position_size = 0.1

        # Agent parameters (set once the state dimension is known)
        self.state_dim = None
        self.action_dim = 3  # BUY, SELL, HOLD
        self.learning_rate = 1e-4
        self.memory_size = 50000

        # Device
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training state
        self.environment = None
        self.agent = None
        self.episode_rewards = []
        self.episode_lengths = []
        self.episode_balances = []
        self.episode_trades = []
        self.training_losses = []

        # Performance tracking
        self.best_reward = -float('inf')
        self.best_balance = 0.0
        self.win_rates = []
        self.avg_rewards = []

        # TensorBoard setup
        self.setup_tensorboard()

        logger.info(f"RLTrainer initialized for symbols: {self.symbols}")
|
||||
|
||||
    def setup_tensorboard(self):
        """Set up TensorBoard logging"""
        # Create the TensorBoard log directory
        log_dir = Path("runs") / f"rl_training_{int(time.time())}"
        log_dir.mkdir(parents=True, exist_ok=True)

        self.writer = SummaryWriter(log_dir=str(log_dir))
        self.tensorboard_dir = log_dir

        logger.info(f"TensorBoard logging to: {log_dir}")
        logger.info("Run: tensorboard --logdir=runs")
    def setup_environment_and_agent(self) -> Tuple[ScalpingEnvironment, ScalpingRLAgent]:
        """Set up the trading environment and RL agent"""
        logger.info("Setting up environment and agent...")

        # Create environment
        environment = ScalpingEnvironment(
            data_provider=self.data_provider,
            symbol=self.symbols[0],
            initial_balance=self.initial_balance,
            max_position_size=self.max_position_size
        )

        # Get the state dimension by resetting the environment once
        initial_state = environment.reset()
        if initial_state is None:
            raise ValueError("Could not get initial state from environment")

        self.state_dim = len(initial_state)
        logger.info(f"State dimension: {self.state_dim}")

        # Create an agent sized to match the observed state
        agent = ScalpingRLAgent(
            state_dim=self.state_dim,
            action_dim=self.action_dim,
            learning_rate=self.learning_rate,
            memory_size=self.memory_size
        )

        return environment, agent
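    # A quick wiring check (a sketch; reset()/act() signatures are taken
    # from their use in run_episode below, the rest is hypothetical):
    #
    #   env, agent = trainer.setup_environment_and_agent()
    #   state = env.reset()
    #   assert agent.act(state, training=False) in range(trainer.action_dim)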
    def run_episode(self, episode_num: int, training: bool = True) -> Dict:
        """Run a single episode"""
        state = self.environment.reset()
        if state is None:
            return {'error': 'Could not reset environment'}

        episode_reward = 0.0
        episode_loss = 0.0
        step_count = 0
        trades_made = 0
        info = {}  # Default so the summary below is safe if the loop exits early

        # Episode loop
        for step in range(self.max_steps_per_episode):
            # Select action
            action = self.agent.act(state, training=training)

            # Execute action in environment
            next_state, reward, done, info = self.environment.step(action, step)

            if next_state is None:
                break

            # Store experience if training
            if training:
                # High-priority experiences: large rewards or executed trades
                priority = (abs(reward) > 0.1 or
                            info.get('trade_info', {}).get('executed', False))

                self.agent.remember(state, action, reward, next_state, done, priority)

                # Train agent
                if step % self.training_frequency == 0 and len(self.agent.memory) > self.agent.batch_size:
                    loss = self.agent.replay()
                    if loss is not None:
                        episode_loss += loss

            # Update state
            state = next_state
            episode_reward += reward
            step_count += 1

            # Track trades
            if info.get('trade_info', {}).get('executed', False):
                trades_made += 1

            if done:
                break

        # Episode results
        final_balance = info.get('balance', self.initial_balance)
        total_fees = info.get('total_fees', 0.0)

        episode_results = {
            'episode': episode_num,
            'reward': episode_reward,
            'steps': step_count,
            'balance': final_balance,
            'trades': trades_made,
            'fees': total_fees,
            'pnl': final_balance - self.initial_balance,
            'pnl_percentage': (final_balance - self.initial_balance) / self.initial_balance * 100,
            'avg_loss': episode_loss / max(step_count // self.training_frequency, 1) if training else 0
        }

        return episode_results
    def evaluate_agent(self, num_episodes: int = 10) -> Dict:
        """Evaluate agent performance"""
        logger.info(f"Evaluating agent over {num_episodes} episodes...")

        evaluation_results = []
        total_reward = 0.0
        total_balance = 0.0
        total_trades = 0
        winning_episodes = 0

        # Set agent to evaluation mode: greedy policy, no exploration
        original_epsilon = self.agent.epsilon
        self.agent.epsilon = 0.0

        for episode in range(num_episodes):
            results = self.run_episode(episode, training=False)
            evaluation_results.append(results)

            total_reward += results['reward']
            total_balance += results['balance']
            total_trades += results['trades']

            if results['pnl'] > 0:
                winning_episodes += 1

        # Restore original epsilon
        self.agent.epsilon = original_epsilon

        # Calculate summary statistics
        avg_reward = total_reward / num_episodes
        avg_balance = total_balance / num_episodes
        avg_trades = total_trades / num_episodes
        win_rate = winning_episodes / num_episodes

        evaluation_summary = {
            'num_episodes': num_episodes,
            'avg_reward': avg_reward,
            'avg_balance': avg_balance,
            'avg_pnl': avg_balance - self.initial_balance,
            'avg_pnl_percentage': (avg_balance - self.initial_balance) / self.initial_balance * 100,
            'avg_trades': avg_trades,
            'win_rate': win_rate,
            'results': evaluation_results
        }

        logger.info(f"Evaluation complete - Avg Reward: {avg_reward:.4f}, Win Rate: {win_rate:.2%}")

        return evaluation_summary
    def train(self, save_path: Optional[str] = None) -> Dict:
        """Train the RL agent"""
        logger.info("Starting RL agent training...")

        # Setup environment and agent
        self.environment, self.agent = self.setup_environment_and_agent()

        # Training state
        start_time = time.time()
        best_eval_reward = -float('inf')

        # Training loop
        for episode in range(self.num_episodes):
            episode_start_time = time.time()

            # Run training episode
            results = self.run_episode(episode, training=True)

            # Track metrics
            self.episode_rewards.append(results['reward'])
            self.episode_lengths.append(results['steps'])
            self.episode_balances.append(results['balance'])
            self.episode_trades.append(results['trades'])

            if results.get('avg_loss', 0) > 0:
                self.training_losses.append(results['avg_loss'])

            # Update best metrics
            if results['reward'] > self.best_reward:
                self.best_reward = results['reward']

            if results['balance'] > self.best_balance:
                self.best_balance = results['balance']

            # Running averages over the last 100 episodes
            recent_rewards = self.episode_rewards[-100:]
            recent_balances = self.episode_balances[-100:]

            avg_reward = np.mean(recent_rewards)
            avg_balance = np.mean(recent_balances)

            self.avg_rewards.append(avg_reward)

            # Log progress
            episode_time = time.time() - episode_start_time

            if episode % 10 == 0:
                logger.info(
                    f"Episode {episode}/{self.num_episodes} - "
                    f"Reward: {results['reward']:.4f}, Balance: ${results['balance']:.2f}, "
                    f"Trades: {results['trades']}, PnL: {results['pnl_percentage']:.2f}%, "
                    f"Epsilon: {self.agent.epsilon:.3f}, Time: {episode_time:.2f}s"
                )

            # Periodic evaluation
            if episode % self.evaluation_frequency == 0 and episode > 0:
                eval_results = self.evaluate_agent(num_episodes=5)

                # Track win rate
                self.win_rates.append(eval_results['win_rate'])

                logger.info(
                    f"Evaluation - Avg Reward: {eval_results['avg_reward']:.4f}, "
                    f"Win Rate: {eval_results['win_rate']:.2%}, "
                    f"Avg PnL: {eval_results['avg_pnl_percentage']:.2f}%"
                )

                # Save best model
                if eval_results['avg_reward'] > best_eval_reward:
                    best_eval_reward = eval_results['avg_reward']
                    if save_path:
                        best_path = save_path.replace('.pt', '_best.pt')
                        self.agent.save(best_path)
                        logger.info(f"New best model saved: {best_path}")

            # Save checkpoint
            if episode % self.save_frequency == 0 and episode > 0 and save_path:
                checkpoint_path = save_path.replace('.pt', f'_checkpoint_{episode}.pt')
                self.agent.save(checkpoint_path)
                logger.info(f"Checkpoint saved: {checkpoint_path}")

        # Training complete
        total_time = time.time() - start_time
        logger.info(f"Training completed in {total_time:.2f} seconds")

        # Final evaluation
        final_eval = self.evaluate_agent(num_episodes=20)

        # Save final model
        if save_path:
            self.agent.save(save_path)
            logger.info(f"Final model saved: {save_path}")

        # Prepare training results
        training_results = {
            'total_episodes': self.num_episodes,
            'total_time': total_time,
            'best_reward': self.best_reward,
            'best_balance': self.best_balance,
            'final_evaluation': final_eval,
            'episode_rewards': self.episode_rewards,
            'episode_balances': self.episode_balances,
            'episode_trades': self.episode_trades,
            'training_losses': self.training_losses,
            'avg_rewards': self.avg_rewards,
            'win_rates': self.win_rates,
            'agent_config': {
                'state_dim': self.state_dim,
                'action_dim': self.action_dim,
                'learning_rate': self.learning_rate,
                'epsilon_final': self.agent.epsilon
            }
        }

        return training_results
    def backtest_agent(self, agent_path: str, test_episodes: int = 50) -> Dict:
        """Backtest a trained agent"""
        logger.info(f"Backtesting agent from {agent_path}...")

        # Setup environment and agent
        self.environment, self.agent = self.setup_environment_and_agent()

        # Load trained agent
        self.agent.load(agent_path)

        # Run backtest
        backtest_results = self.evaluate_agent(test_episodes)

        # Additional analysis
        results = backtest_results['results']
        pnls = [r['pnl_percentage'] for r in results]
        rewards = [r['reward'] for r in results]
        trades = [r['trades'] for r in results]

        analysis = {
            'total_episodes': test_episodes,
            'avg_pnl': np.mean(pnls),
            'std_pnl': np.std(pnls),
            'max_pnl': np.max(pnls),
            'min_pnl': np.min(pnls),
            'avg_reward': np.mean(rewards),
            'avg_trades': np.mean(trades),
            'win_rate': backtest_results['win_rate'],
            'profit_factor': (np.sum([p for p in pnls if p > 0]) / abs(np.sum([p for p in pnls if p < 0]))
                              if any(p < 0 for p in pnls) else float('inf')),
            'sharpe_ratio': np.mean(pnls) / np.std(pnls) if np.std(pnls) > 0 else 0,
            'max_drawdown': self._calculate_max_drawdown(pnls)
        }

        logger.info(f"Backtest complete - Win Rate: {analysis['win_rate']:.2%}, Avg PnL: {analysis['avg_pnl']:.2f}%")

        return {
            'backtest_results': backtest_results,
            'analysis': analysis
        }
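    # Worked example for the ratios above (hypothetical per-episode PnLs,
    # in percent): pnls = [1.5, -0.5, 2.0, -1.0]
    #   profit_factor = (1.5 + 2.0) / |(-0.5) + (-1.0)| = 3.5 / 1.5 ≈ 2.33
    #   sharpe_ratio  = mean(pnls) / std(pnls) ≈ 0.5 / 1.275 ≈ 0.39
    # Note this "Sharpe" is per-episode and unannualized, with no
    # risk-free rate subtracted.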
    def _calculate_max_drawdown(self, pnls: List[float]) -> float:
        """Calculate maximum drawdown of the cumulative PnL curve"""
        cumulative = np.cumsum(pnls)
        running_max = np.maximum.accumulate(cumulative)
        drawdowns = running_max - cumulative
        return np.max(drawdowns) if len(drawdowns) > 0 else 0.0
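    # Worked example (hypothetical numbers): pnls = [2.0, -1.0, -2.0, 3.0]
    #   cumulative  = [2, 1, -1, 2]
    #   running_max = [2, 2,  2, 2]
    #   drawdowns   = [0, 1,  3, 0]  ->  max drawdown = 3.0
    # i.e. the deepest peak-to-trough drop of the cumulative PnL curve, in
    # the same percentage-point units as the inputs.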
    def plot_training_progress(self, save_path: Optional[str] = None):
        """Plot training progress"""
        if not self.episode_rewards:
            logger.warning("No training data to plot")
            return

        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))

        episodes = range(1, len(self.episode_rewards) + 1)

        # Episode rewards
        ax1.plot(episodes, self.episode_rewards, alpha=0.6, label='Episode Reward')
        if self.avg_rewards:
            ax1.plot(episodes, self.avg_rewards, 'r-', label='Avg Reward (100 episodes)')
        ax1.set_title('Training Rewards')
        ax1.set_xlabel('Episode')
        ax1.set_ylabel('Reward')
        ax1.legend()
        ax1.grid(True)

        # Episode balances
        ax2.plot(episodes, self.episode_balances, alpha=0.6, label='Episode Balance')
        ax2.axhline(y=self.initial_balance, color='r', linestyle='--', label='Initial Balance')
        ax2.set_title('Portfolio Balance')
        ax2.set_xlabel('Episode')
        ax2.set_ylabel('Balance ($)')
        ax2.legend()
        ax2.grid(True)

        # Training losses
        if self.training_losses:
            loss_episodes = np.linspace(1, len(self.episode_rewards), len(self.training_losses))
            ax3.plot(loss_episodes, self.training_losses, 'g-', alpha=0.8)
        ax3.set_title('Training Loss')
        ax3.set_xlabel('Episode')
        ax3.set_ylabel('Loss')
        ax3.grid(True)

        # Win rates
        if self.win_rates:
            eval_episodes = np.arange(self.evaluation_frequency,
                                      len(self.episode_rewards) + 1,
                                      self.evaluation_frequency)[:len(self.win_rates)]
            ax4.plot(eval_episodes, self.win_rates, 'purple', marker='o')
        ax4.set_title('Win Rate')
        ax4.set_xlabel('Episode')
        ax4.set_ylabel('Win Rate')
        ax4.grid(True)
        ax4.set_ylim(0, 1)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Training progress plot saved: {save_path}")

        plt.show()
    def log_episode_metrics(self, episode: int, metrics: Dict):
        """Log episode metrics to TensorBoard"""
        # Main performance metrics
        self.writer.add_scalar('Episode/TotalReward', metrics['total_reward'], episode)
        self.writer.add_scalar('Episode/FinalBalance', metrics['final_balance'], episode)
        self.writer.add_scalar('Episode/TotalReturn', metrics['total_return'], episode)
        self.writer.add_scalar('Episode/Steps', metrics['steps'], episode)

        # Trading metrics
        self.writer.add_scalar('Trading/TotalTrades', metrics['total_trades'], episode)
        self.writer.add_scalar('Trading/WinRate', metrics['win_rate'], episode)
        self.writer.add_scalar('Trading/ProfitFactor', metrics.get('profit_factor', 0), episode)
        self.writer.add_scalar('Trading/MaxDrawdown', metrics.get('max_drawdown', 0), episode)

        # Agent metrics
        self.writer.add_scalar('Agent/Epsilon', metrics['epsilon'], episode)
        self.writer.add_scalar('Agent/LearningRate', metrics.get('learning_rate', self.learning_rate), episode)
        self.writer.add_scalar('Agent/MemorySize', metrics.get('memory_size', 0), episode)

        # Loss metrics (if available)
        if 'loss' in metrics:
            self.writer.add_scalar('Agent/Loss', metrics['loss'], episode)
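    # Note: this logger expects keys ('total_reward', 'final_balance',
    # 'total_return', 'total_trades', 'win_rate', 'epsilon') that differ
    # from the dict run_episode() returns ('reward', 'balance', 'trades',
    # ...), so callers must build the metrics dict themselves. A
    # hypothetical mapping from an episode result r:
    #
    #   trainer.log_episode_metrics(episode, {
    #       'total_reward': r['reward'], 'final_balance': r['balance'],
    #       'total_return': r['pnl_percentage'], 'steps': r['steps'],
    #       'total_trades': r['trades'], 'win_rate': 0.0,
    #       'epsilon': trainer.agent.epsilon,
    #   })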
class HybridTrainer:
    """
    Hybrid training pipeline combining CNN and RL
    """

    def __init__(self, data_provider: DataProvider):
        self.data_provider = data_provider
        self.cnn_trainer = None
        self.rl_trainer = None

    def train_hybrid(self, symbols: List[str], cnn_save_path: str, rl_save_path: str) -> Dict:
        """Train the CNN first, then RL on top of the CNN features"""
        logger.info("Starting hybrid CNN + RL training...")

        # Phase 1: Train CNN
        logger.info("Phase 1: Training CNN...")
        from training.cnn_trainer import CNNTrainer

        self.cnn_trainer = CNNTrainer(self.data_provider)
        cnn_results = self.cnn_trainer.train(symbols, cnn_save_path)

        # Phase 2: Train RL
        logger.info("Phase 2: Training RL...")
        self.rl_trainer = RLTrainer(self.data_provider)
        rl_results = self.rl_trainer.train(rl_save_path)

        # Combine results
        hybrid_results = {
            'cnn_results': cnn_results,
            'rl_results': rl_results,
            'total_time': cnn_results['total_time'] + rl_results['total_time']
        }

        logger.info("Hybrid training completed!")
        return hybrid_results


# Export
__all__ = ['RLTrainer', 'HybridTrainer']
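# Usage sketch for the hybrid pipeline (paths and symbols are illustrative;
# CNNTrainer.train()'s signature is taken from its call in train_hybrid):
#
#   trainer = HybridTrainer(DataProvider())
#   results = trainer.train_hybrid(
#       symbols=['ETH/USDT'],
#       cnn_save_path='models/cnn/scalping_cnn.pt',
#       rl_save_path='models/rl/scalping_agent.pt',
#   )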
@ -57,7 +57,6 @@ try:
    from core.enhanced_orchestrator import EnhancedTradingOrchestrator
    from core.universal_data_adapter import UniversalDataAdapter
    from core.unified_data_stream import UnifiedDataStream, TrainingDataPacket, UIDataPacket
    from training.enhanced_pivot_rl_trainer import EnhancedPivotRLTrainer, create_enhanced_pivot_trainer
    ENHANCED_RL_AVAILABLE = True
    logger.info("Enhanced RL training components available")
except ImportError as e: