wip wip wip
@@ -582,6 +582,65 @@ class DataProvider:
                logger.error(f"Error loading initial data for {symbol} {timeframe}: {e}")

        logger.info("Initial data load completed")

        # Catch up on missing candles if needed
        self._catch_up_missing_candles()

    def _catch_up_missing_candles(self):
        """
        Catch up on missing candles at startup
        Fetches up to 1500 candles per timeframe if we're missing data
        """
        logger.info("Checking for missing candles to catch up...")

        target_candles = 1500  # Target number of candles per timeframe

        for symbol in self.symbols:
            for timeframe in self.timeframes:
                try:
                    # Check current candle count
                    current_df = self.cached_data[symbol][timeframe]
                    current_count = len(current_df) if not current_df.empty else 0

                    if current_count >= target_candles:
                        logger.debug(f"{symbol} {timeframe}: Already have {current_count} candles (target: {target_candles})")
                        continue

                    # Calculate how many candles we need
                    needed = target_candles - current_count
                    logger.info(f"{symbol} {timeframe}: Need {needed} more candles (have {current_count}/{target_candles})")

                    # Fetch missing candles
                    # Try Binance first (usually has better historical data)
                    df = self._fetch_from_binance(symbol, timeframe, needed)

                    if df is None or df.empty:
                        # Fallback to MEXC
                        logger.debug(f"Binance fetch failed for {symbol} {timeframe}, trying MEXC...")
                        df = self._fetch_from_mexc(symbol, timeframe, needed)

                    if df is not None and not df.empty:
                        # Ensure proper datetime index
                        df = self._ensure_datetime_index(df)

                        # Merge with existing data
                        if not current_df.empty:
                            combined_df = pd.concat([current_df, df], ignore_index=False)
                            combined_df = combined_df[~combined_df.index.duplicated(keep='last')]
                            combined_df = combined_df.sort_index()
                            self.cached_data[symbol][timeframe] = combined_df.tail(target_candles)
                        else:
                            self.cached_data[symbol][timeframe] = df.tail(target_candles)

                        final_count = len(self.cached_data[symbol][timeframe])
                        logger.info(f"✅ {symbol} {timeframe}: Caught up! Now have {final_count} candles")
                    else:
                        logger.warning(f"❌ {symbol} {timeframe}: Could not fetch historical data from any exchange")

                except Exception as e:
                    logger.error(f"Error catching up candles for {symbol} {timeframe}: {e}")

        logger.info("Candle catch-up completed")

    def _update_cached_data(self, symbol: str, timeframe: str):
        """Update cached data by fetching last 2 candles"""
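The merge step above uses a standard pandas idiom: concatenate on the datetime index, drop duplicate timestamps keeping the most recently fetched row, sort, then trim to the target length. A minimal standalone sketch of that idiom (the frames and values below are illustrative placeholders, not exchange data and not part of the commit):

import pandas as pd

# Two overlapping frames indexed by timestamp, as returned by two fetches
idx_a = pd.date_range("2024-01-01 00:00", periods=3, freq="1min")
idx_b = pd.date_range("2024-01-01 00:02", periods=3, freq="1min")
a = pd.DataFrame({"close": [100.0, 101.0, 102.0]}, index=idx_a)
b = pd.DataFrame({"close": [102.5, 103.0, 104.0]}, index=idx_b)

combined = pd.concat([a, b])                                   # may contain duplicate timestamps
combined = combined[~combined.index.duplicated(keep="last")]   # newer fetch wins on overlap
combined = combined.sort_index().tail(1500)                    # keep at most target_candles rows
print(combined)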
@@ -142,11 +142,11 @@ class EnhancedRewardCalculator:
                       symbol: str,
                       timeframe: TimeFrame,
                       predicted_price: float,
-                      predicted_return: Optional[float] = None,
                       predicted_direction: int,
                       confidence: float,
                       current_price: float,
                       model_name: str,
+                      predicted_return: Optional[float] = None,
                       state_vector: Optional[list] = None) -> str:
        """
        Add a new prediction to track
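Moving predicted_return behind the required parameters makes the signature valid Python again (a defaulted argument may not precede non-default ones). A rough call sketch, assuming the enclosing method is the calculator's add-prediction entry point (its name sits outside this hunk) and that a calculator instance is already in hand; all values are illustrative:

from core.enhanced_reward_calculator import TimeFrame

# 'calculator' stands for an existing EnhancedRewardCalculator instance
prediction_id = calculator.add_prediction(
    symbol="ETH/USDT",
    timeframe=TimeFrame.MINUTES_1,
    predicted_price=2500.0,
    predicted_direction=1,
    confidence=0.7,
    current_price=2495.0,
    model_name="dqn_agent",
    predicted_return=0.002,   # optional, now passed by keyword
)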
@@ -17,7 +17,7 @@ import asyncio
import logging
import time
from datetime import datetime, timedelta
-from typing import Dict, List, Optional, Any, Union
+from typing import Dict, List, Optional, Any, Union, Tuple
from dataclasses import dataclass
import numpy as np
import threading
core/timescale_storage.py (new file)
@@ -0,0 +1,371 @@
"""
TimescaleDB Storage for OHLCV Candle Data

Provides long-term storage for all candle data without limits.
Replaces capped deques with unlimited database storage.

CRITICAL POLICY: NO SYNTHETIC DATA ALLOWED
This module MUST ONLY store real market data from exchanges.
"""

import logging
import pandas as pd
from datetime import datetime, timedelta
from typing import Optional, List
import psycopg2
from psycopg2.extras import execute_values
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class TimescaleDBStorage:
    """
    TimescaleDB storage for OHLCV candle data

    Features:
    - Unlimited storage (no caps)
    - Fast time-range queries
    - Automatic compression
    - Multi-symbol, multi-timeframe support
    """

    def __init__(self, connection_string: str = None):
        """
        Initialize TimescaleDB storage

        Args:
            connection_string: PostgreSQL connection string
                Default: postgresql://postgres:password@localhost:5432/trading_data
        """
        self.connection_string = connection_string or \
            "postgresql://postgres:password@localhost:5432/trading_data"

        # Test connection
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cur:
                    cur.execute("SELECT version();")
                    version = cur.fetchone()
                    logger.info(f"Connected to TimescaleDB: {version[0]}")
        except Exception as e:
            logger.error(f"Failed to connect to TimescaleDB: {e}")
            logger.warning("TimescaleDB storage will not be available")
            raise

    @contextmanager
    def get_connection(self):
        """Get database connection with automatic cleanup"""
        conn = psycopg2.connect(self.connection_string)
        try:
            yield conn
            conn.commit()
        except Exception as e:
            conn.rollback()
            raise e
        finally:
            conn.close()

    def create_tables(self):
        """Create TimescaleDB tables and hypertables"""
        with self.get_connection() as conn:
            with conn.cursor() as cur:
                # Create extension if not exists
                cur.execute("CREATE EXTENSION IF NOT EXISTS timescaledb;")

                # Create ohlcv_candles table
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS ohlcv_candles (
                        time TIMESTAMPTZ NOT NULL,
                        symbol TEXT NOT NULL,
                        timeframe TEXT NOT NULL,
                        open DOUBLE PRECISION NOT NULL,
                        high DOUBLE PRECISION NOT NULL,
                        low DOUBLE PRECISION NOT NULL,
                        close DOUBLE PRECISION NOT NULL,
                        volume DOUBLE PRECISION NOT NULL,
                        PRIMARY KEY (time, symbol, timeframe)
                    );
                """)

                # Convert to hypertable (if not already)
                try:
                    cur.execute("""
                        SELECT create_hypertable('ohlcv_candles', 'time',
                            if_not_exists => TRUE);
                    """)
                    logger.info("Created hypertable: ohlcv_candles")
                except Exception as e:
                    logger.debug(f"Hypertable may already exist: {e}")

                # Create indexes for fast queries
                cur.execute("""
                    CREATE INDEX IF NOT EXISTS idx_symbol_timeframe_time
                    ON ohlcv_candles (symbol, timeframe, time DESC);
                """)

                # Enable compression (saves 10-20x space)
                try:
                    cur.execute("""
                        ALTER TABLE ohlcv_candles SET (
                            timescaledb.compress,
                            timescaledb.compress_segmentby = 'symbol,timeframe'
                        );
                    """)
                    logger.info("Enabled compression on ohlcv_candles")
                except Exception as e:
                    logger.debug(f"Compression may already be enabled: {e}")

                # Add compression policy (compress data older than 7 days)
                try:
                    cur.execute("""
                        SELECT add_compression_policy('ohlcv_candles', INTERVAL '7 days');
                    """)
                    logger.info("Added compression policy (7 days)")
                except Exception as e:
                    logger.debug(f"Compression policy may already exist: {e}")

                logger.info("TimescaleDB tables created successfully")

    def store_candles(self, symbol: str, timeframe: str, df: pd.DataFrame):
        """
        Store OHLCV candles in TimescaleDB

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe (e.g., '1s', '1m', '1h', '1d')
            df: DataFrame with columns: open, high, low, close, volume
                Index must be DatetimeIndex (timestamps)

        Returns:
            int: Number of candles stored
        """
        if df is None or df.empty:
            logger.warning(f"No data to store for {symbol} {timeframe}")
            return 0

        try:
            # Prepare data for insertion
            data = []
            for timestamp, row in df.iterrows():
                data.append((
                    timestamp,
                    symbol,
                    timeframe,
                    float(row['open']),
                    float(row['high']),
                    float(row['low']),
                    float(row['close']),
                    float(row['volume'])
                ))

            # Insert data (ON CONFLICT DO NOTHING to avoid duplicates)
            with self.get_connection() as conn:
                with conn.cursor() as cur:
                    execute_values(
                        cur,
                        """
                        INSERT INTO ohlcv_candles
                        (time, symbol, timeframe, open, high, low, close, volume)
                        VALUES %s
                        ON CONFLICT (time, symbol, timeframe) DO NOTHING
                        """,
                        data
                    )

            logger.info(f"Stored {len(data)} candles for {symbol} {timeframe}")
            return len(data)

        except Exception as e:
            logger.error(f"Error storing candles for {symbol} {timeframe}: {e}")
            return 0

    def get_candles(self, symbol: str, timeframe: str,
                    start_time: datetime = None, end_time: datetime = None,
                    limit: int = None) -> Optional[pd.DataFrame]:
        """
        Retrieve OHLCV candles from TimescaleDB

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            start_time: Start of time range (optional)
            end_time: End of time range (optional)
            limit: Maximum number of candles to return (optional)

        Returns:
            DataFrame with OHLCV data, indexed by timestamp
        """
        try:
            # Build query
            query = """
                SELECT time, open, high, low, close, volume
                FROM ohlcv_candles
                WHERE symbol = %s AND timeframe = %s
            """
            params = [symbol, timeframe]

            # Add time range filter
            if start_time:
                query += " AND time >= %s"
                params.append(start_time)
            if end_time:
                query += " AND time <= %s"
                params.append(end_time)

            # Order by time
            query += " ORDER BY time DESC"

            # Add limit
            if limit:
                query += " LIMIT %s"
                params.append(limit)

            # Execute query
            with self.get_connection() as conn:
                df = pd.read_sql(query, conn, params=params, index_col='time')

            # Sort by time ascending (oldest first)
            if not df.empty:
                df = df.sort_index()

            logger.debug(f"Retrieved {len(df)} candles for {symbol} {timeframe}")
            return df

        except Exception as e:
            logger.error(f"Error retrieving candles for {symbol} {timeframe}: {e}")
            return None

    def get_recent_candles(self, symbol: str, timeframe: str,
                           limit: int = 1000) -> Optional[pd.DataFrame]:
        """
        Get most recent candles

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            limit: Number of recent candles to retrieve

        Returns:
            DataFrame with recent OHLCV data
        """
        return self.get_candles(symbol, timeframe, limit=limit)

    def get_candles_count(self, symbol: str = None, timeframe: str = None) -> int:
        """
        Get count of stored candles

        Args:
            symbol: Optional symbol filter
            timeframe: Optional timeframe filter

        Returns:
            Number of candles stored
        """
        try:
            query = "SELECT COUNT(*) FROM ohlcv_candles WHERE 1=1"
            params = []

            if symbol:
                query += " AND symbol = %s"
                params.append(symbol)
            if timeframe:
                query += " AND timeframe = %s"
                params.append(timeframe)

            with self.get_connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(query, params)
                    count = cur.fetchone()[0]

            return count

        except Exception as e:
            logger.error(f"Error getting candles count: {e}")
            return 0

    def get_storage_stats(self) -> dict:
        """
        Get storage statistics

        Returns:
            Dictionary with storage stats
        """
        try:
            with self.get_connection() as conn:
                with conn.cursor() as cur:
                    # Total candles
                    cur.execute("SELECT COUNT(*) FROM ohlcv_candles")
                    total_candles = cur.fetchone()[0]

                    # Candles by symbol
                    cur.execute("""
                        SELECT symbol, COUNT(*) as count
                        FROM ohlcv_candles
                        GROUP BY symbol
                        ORDER BY count DESC
                    """)
                    by_symbol = dict(cur.fetchall())

                    # Candles by timeframe
                    cur.execute("""
                        SELECT timeframe, COUNT(*) as count
                        FROM ohlcv_candles
                        GROUP BY timeframe
                        ORDER BY count DESC
                    """)
                    by_timeframe = dict(cur.fetchall())

                    # Time range
                    cur.execute("""
                        SELECT MIN(time) as oldest, MAX(time) as newest
                        FROM ohlcv_candles
                    """)
                    oldest, newest = cur.fetchone()

                    # Table size
                    cur.execute("""
                        SELECT pg_size_pretty(pg_total_relation_size('ohlcv_candles'))
                    """)
                    table_size = cur.fetchone()[0]

                    return {
                        'total_candles': total_candles,
                        'by_symbol': by_symbol,
                        'by_timeframe': by_timeframe,
                        'oldest_candle': oldest,
                        'newest_candle': newest,
                        'table_size': table_size
                    }

        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {}


# Global instance
_timescale_storage = None


def get_timescale_storage(connection_string: str = None) -> Optional[TimescaleDBStorage]:
    """
    Get global TimescaleDB storage instance

    Args:
        connection_string: PostgreSQL connection string (optional)

    Returns:
        TimescaleDBStorage instance or None if unavailable
    """
    global _timescale_storage

    if _timescale_storage is None:
        try:
            _timescale_storage = TimescaleDBStorage(connection_string)
            _timescale_storage.create_tables()
            logger.info("TimescaleDB storage initialized successfully")
        except Exception as e:
            logger.warning(f"TimescaleDB storage not available: {e}")
            _timescale_storage = None

    return _timescale_storage
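A minimal usage sketch for the new module, assuming a reachable TimescaleDB instance (or the default local connection string); the symbol and time window are illustrative and the helper returns None when the database is unavailable:

from datetime import datetime, timedelta
from core.timescale_storage import get_timescale_storage

storage = get_timescale_storage()  # falls back to the default local connection string
if storage is not None:
    # Query the last 24 hours of 1m candles for ETH/USDT
    df = storage.get_candles(
        "ETH/USDT", "1m",
        start_time=datetime.utcnow() - timedelta(hours=24),
        end_time=datetime.utcnow(),
    )
    print(storage.get_candles_count("ETH/USDT", "1m"))
    print(storage.get_storage_stats())
else:
    print("TimescaleDB not available")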
core/unified_queryable_storage.py (new file)
@@ -0,0 +1,561 @@
"""
Unified Queryable Storage Manager

Provides a unified interface for queryable data storage with automatic fallback:
1. TimescaleDB (preferred) - for production with time-series optimization
2. SQLite (fallback) - for development/testing without TimescaleDB

This avoids data duplication with parquet/cache by providing a single queryable layer
that can be reused across multiple training setups.

Key Features:
- Automatic detection and fallback
- Unified query interface
- Time-series optimized queries
- Efficient storage for training data
- No duplication with existing cache implementations
"""

import logging
import sqlite3
import pandas as pd
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any, Union
from pathlib import Path
import json

logger = logging.getLogger(__name__)


class UnifiedQueryableStorage:
    """
    Unified storage manager with TimescaleDB/SQLite fallback

    Provides queryable storage for:
    - OHLCV candle data
    - Prediction records
    - Training data
    - Model metrics

    Automatically uses TimescaleDB when available, falls back to SQLite otherwise.
    """

    def __init__(self,
                 timescale_connection_string: Optional[str] = None,
                 sqlite_path: str = "data/queryable_storage.db"):
        """
        Initialize unified storage with automatic fallback

        Args:
            timescale_connection_string: PostgreSQL/TimescaleDB connection string
            sqlite_path: Path to SQLite database file (fallback)
        """
        self.backend = None
        self.backend_type = None

        # Try TimescaleDB first
        if timescale_connection_string:
            try:
                from core.timescale_storage import get_timescale_storage
                self.backend = get_timescale_storage(timescale_connection_string)
                if self.backend:
                    self.backend_type = "timescale"
                    logger.info("✅ Using TimescaleDB for queryable storage")
            except Exception as e:
                logger.warning(f"TimescaleDB not available: {e}")

        # Fallback to SQLite
        if self.backend is None:
            try:
                self.backend = SQLiteQueryableStorage(sqlite_path)
                self.backend_type = "sqlite"
                logger.info("✅ Using SQLite for queryable storage (TimescaleDB fallback)")
            except Exception as e:
                logger.error(f"Failed to initialize SQLite storage: {e}")
                raise Exception("No queryable storage backend available")

    def store_candles(self, symbol: str, timeframe: str, df: pd.DataFrame) -> bool:
        """
        Store OHLCV candles

        Args:
            symbol: Trading symbol (e.g., 'ETH/USDT')
            timeframe: Timeframe (e.g., '1m', '1h', '1d')
            df: DataFrame with OHLCV data

        Returns:
            True if successful
        """
        try:
            if self.backend_type == "timescale":
                self.backend.store_candles(symbol, timeframe, df)
            else:
                self.backend.store_candles(symbol, timeframe, df)
            return True
        except Exception as e:
            logger.error(f"Error storing candles: {e}")
            return False

    def get_candles(self,
                    symbol: str,
                    timeframe: str,
                    start_time: Optional[datetime] = None,
                    end_time: Optional[datetime] = None,
                    limit: Optional[int] = None) -> Optional[pd.DataFrame]:
        """
        Retrieve OHLCV candles with time range filtering

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            start_time: Start of time range (optional)
            end_time: End of time range (optional)
            limit: Maximum number of candles (optional)

        Returns:
            DataFrame with OHLCV data or None
        """
        try:
            if self.backend_type == "timescale":
                return self.backend.get_candles(symbol, timeframe, start_time, end_time, limit)
            else:
                return self.backend.get_candles(symbol, timeframe, start_time, end_time, limit)
        except Exception as e:
            logger.error(f"Error retrieving candles: {e}")
            return None

    def store_prediction(self, prediction_data: Dict[str, Any]) -> bool:
        """
        Store prediction record for training

        Args:
            prediction_data: Dictionary with prediction information

        Returns:
            True if successful
        """
        try:
            return self.backend.store_prediction(prediction_data)
        except Exception as e:
            logger.error(f"Error storing prediction: {e}")
            return False

    def get_predictions(self,
                        symbol: Optional[str] = None,
                        model_name: Optional[str] = None,
                        start_time: Optional[datetime] = None,
                        end_time: Optional[datetime] = None,
                        limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Query predictions with filtering

        Args:
            symbol: Filter by symbol (optional)
            model_name: Filter by model (optional)
            start_time: Start of time range (optional)
            end_time: End of time range (optional)
            limit: Maximum number of records (optional)

        Returns:
            List of prediction records
        """
        try:
            return self.backend.get_predictions(symbol, model_name, start_time, end_time, limit)
        except Exception as e:
            logger.error(f"Error retrieving predictions: {e}")
            return []

    def get_storage_stats(self) -> Dict[str, Any]:
        """
        Get storage statistics

        Returns:
            Dictionary with storage stats
        """
        try:
            stats = self.backend.get_storage_stats()
            stats['backend_type'] = self.backend_type
            return stats
        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {'backend_type': self.backend_type, 'error': str(e)}


class SQLiteQueryableStorage:
    """
    SQLite-based queryable storage (fallback when TimescaleDB unavailable)

    Provides similar functionality to TimescaleDB but using SQLite.
    Optimized for time-series queries with proper indexing.
    """

    def __init__(self, db_path: str = "data/queryable_storage.db"):
        """
        Initialize SQLite storage

        Args:
            db_path: Path to SQLite database file
        """
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        # Initialize database
        self._create_tables()
        logger.info(f"SQLite queryable storage initialized: {self.db_path}")

    def _create_tables(self):
        """Create SQLite tables with proper indexing"""
        with sqlite3.connect(self.db_path) as conn:
            cursor = conn.cursor()

            # OHLCV candles table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS ohlcv_candles (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    symbol TEXT NOT NULL,
                    timeframe TEXT NOT NULL,
                    timestamp INTEGER NOT NULL,
                    open REAL NOT NULL,
                    high REAL NOT NULL,
                    low REAL NOT NULL,
                    close REAL NOT NULL,
                    volume REAL NOT NULL,
                    created_at INTEGER DEFAULT (strftime('%s', 'now')),
                    UNIQUE(symbol, timeframe, timestamp)
                )
            """)

            # Indexes for efficient time-series queries
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe_timestamp
                ON ohlcv_candles(symbol, timeframe, timestamp DESC)
            """)

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_ohlcv_timestamp
                ON ohlcv_candles(timestamp DESC)
            """)

            # Predictions table
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS predictions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    prediction_id TEXT UNIQUE NOT NULL,
                    symbol TEXT NOT NULL,
                    model_name TEXT NOT NULL,
                    timestamp INTEGER NOT NULL,
                    predicted_price REAL,
                    current_price REAL,
                    predicted_direction INTEGER,
                    confidence REAL,
                    timeframe TEXT,
                    outcome_price REAL,
                    outcome_timestamp INTEGER,
                    reward REAL,
                    metadata TEXT,
                    created_at INTEGER DEFAULT (strftime('%s', 'now'))
                )
            """)

            # Indexes for prediction queries
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_predictions_symbol_timestamp
                ON predictions(symbol, timestamp DESC)
            """)

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_predictions_model_timestamp
                ON predictions(model_name, timestamp DESC)
            """)

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_predictions_timestamp
                ON predictions(timestamp DESC)
            """)

            conn.commit()
            logger.debug("SQLite tables created successfully")

    def store_candles(self, symbol: str, timeframe: str, df: pd.DataFrame):
        """
        Store OHLCV candles in SQLite

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            df: DataFrame with OHLCV data
        """
        if df is None or df.empty:
            return

        with sqlite3.connect(self.db_path) as conn:
            # Prepare data
            df_copy = df.copy()
            df_copy['symbol'] = symbol
            df_copy['timeframe'] = timeframe

            # Convert timestamp to Unix timestamp if it's a datetime
            if pd.api.types.is_datetime64_any_dtype(df_copy.index):
                df_copy['timestamp'] = df_copy.index.astype('int64') // 10**9
            else:
                df_copy['timestamp'] = df_copy.index

            # Reset index to make timestamp a column
            df_copy = df_copy.reset_index(drop=True)

            # Select only required columns
            columns = ['symbol', 'timeframe', 'timestamp', 'open', 'high', 'low', 'close', 'volume']
            df_insert = df_copy[columns]

            # Insert with REPLACE to handle duplicates
            df_insert.to_sql('ohlcv_candles', conn, if_exists='append', index=False, method='multi')

            logger.debug(f"Stored {len(df_insert)} candles for {symbol} {timeframe}")

    def get_candles(self,
                    symbol: str,
                    timeframe: str,
                    start_time: Optional[datetime] = None,
                    end_time: Optional[datetime] = None,
                    limit: Optional[int] = None) -> Optional[pd.DataFrame]:
        """
        Retrieve OHLCV candles from SQLite

        Args:
            symbol: Trading symbol
            timeframe: Timeframe
            start_time: Start of time range
            end_time: End of time range
            limit: Maximum number of candles

        Returns:
            DataFrame with OHLCV data
        """
        with sqlite3.connect(self.db_path) as conn:
            # Build query
            query = """
                SELECT timestamp, open, high, low, close, volume
                FROM ohlcv_candles
                WHERE symbol = ? AND timeframe = ?
            """
            params = [symbol, timeframe]

            # Add time range filters
            if start_time:
                query += " AND timestamp >= ?"
                params.append(int(start_time.timestamp()))

            if end_time:
                query += " AND timestamp <= ?"
                params.append(int(end_time.timestamp()))

            # Order by timestamp
            query += " ORDER BY timestamp DESC"

            # Add limit
            if limit:
                query += " LIMIT ?"
                params.append(limit)

            # Execute query
            df = pd.read_sql_query(query, conn, params=params)

            if df.empty:
                return None

            # Convert timestamp to datetime and set as index
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='s')
            df.set_index('timestamp', inplace=True)
            df.sort_index(inplace=True)

            return df

    def store_prediction(self, prediction_data: Dict[str, Any]) -> bool:
        """
        Store prediction record

        Args:
            prediction_data: Dictionary with prediction information

        Returns:
            True if successful
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Extract fields
                prediction_id = prediction_data.get('prediction_id')
                symbol = prediction_data.get('symbol')
                model_name = prediction_data.get('model_name')
                timestamp = prediction_data.get('timestamp')

                # Convert datetime to Unix timestamp
                if isinstance(timestamp, datetime):
                    timestamp = int(timestamp.timestamp())

                # Prepare metadata
                metadata = {k: v for k, v in prediction_data.items()
                            if k not in ['prediction_id', 'symbol', 'model_name', 'timestamp',
                                         'predicted_price', 'current_price', 'predicted_direction',
                                         'confidence', 'timeframe', 'outcome_price',
                                         'outcome_timestamp', 'reward']}

                # Insert prediction
                cursor.execute("""
                    INSERT OR REPLACE INTO predictions
                    (prediction_id, symbol, model_name, timestamp, predicted_price,
                     current_price, predicted_direction, confidence, timeframe,
                     outcome_price, outcome_timestamp, reward, metadata)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    prediction_id,
                    symbol,
                    model_name,
                    timestamp,
                    prediction_data.get('predicted_price'),
                    prediction_data.get('current_price'),
                    prediction_data.get('predicted_direction'),
                    prediction_data.get('confidence'),
                    prediction_data.get('timeframe'),
                    prediction_data.get('outcome_price'),
                    prediction_data.get('outcome_timestamp'),
                    prediction_data.get('reward'),
                    json.dumps(metadata)
                ))

                conn.commit()
                return True

        except Exception as e:
            logger.error(f"Error storing prediction: {e}")
            return False

    def get_predictions(self,
                        symbol: Optional[str] = None,
                        model_name: Optional[str] = None,
                        start_time: Optional[datetime] = None,
                        end_time: Optional[datetime] = None,
                        limit: Optional[int] = None) -> List[Dict[str, Any]]:
        """
        Query predictions with filtering

        Args:
            symbol: Filter by symbol
            model_name: Filter by model
            start_time: Start of time range
            end_time: End of time range
            limit: Maximum number of records

        Returns:
            List of prediction records
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()

                # Build query
                query = "SELECT * FROM predictions WHERE 1=1"
                params = []

                if symbol:
                    query += " AND symbol = ?"
                    params.append(symbol)

                if model_name:
                    query += " AND model_name = ?"
                    params.append(model_name)

                if start_time:
                    query += " AND timestamp >= ?"
                    params.append(int(start_time.timestamp()))

                if end_time:
                    query += " AND timestamp <= ?"
                    params.append(int(end_time.timestamp()))

                query += " ORDER BY timestamp DESC"

                if limit:
                    query += " LIMIT ?"
                    params.append(limit)

                cursor.execute(query, params)
                rows = cursor.fetchall()

                # Convert to list of dicts
                predictions = []
                for row in rows:
                    pred = dict(row)
                    # Parse metadata JSON
                    if pred.get('metadata'):
                        try:
                            pred['metadata'] = json.loads(pred['metadata'])
                        except:
                            pass
                    predictions.append(pred)

                return predictions

        except Exception as e:
            logger.error(f"Error querying predictions: {e}")
            return []

    def get_storage_stats(self) -> Dict[str, Any]:
        """
        Get storage statistics

        Returns:
            Dictionary with storage stats
        """
        try:
            with sqlite3.connect(self.db_path) as conn:
                cursor = conn.cursor()

                # Get table sizes
                cursor.execute("SELECT COUNT(*) FROM ohlcv_candles")
                candles_count = cursor.fetchone()[0]

                cursor.execute("SELECT COUNT(*) FROM predictions")
                predictions_count = cursor.fetchone()[0]

                # Get database file size
                db_size = self.db_path.stat().st_size if self.db_path.exists() else 0

                return {
                    'candles_count': candles_count,
                    'predictions_count': predictions_count,
                    'database_size_bytes': db_size,
                    'database_size_mb': db_size / (1024 * 1024),
                    'database_path': str(self.db_path)
                }

        except Exception as e:
            logger.error(f"Error getting storage stats: {e}")
            return {'error': str(e)}


# Global instance
_unified_storage = None


def get_unified_storage(timescale_connection_string: Optional[str] = None,
                        sqlite_path: str = "data/queryable_storage.db") -> UnifiedQueryableStorage:
    """
    Get global unified storage instance

    Args:
        timescale_connection_string: PostgreSQL/TimescaleDB connection string
        sqlite_path: Path to SQLite database file (fallback)

    Returns:
        UnifiedQueryableStorage instance
    """
    global _unified_storage

    if _unified_storage is None:
        _unified_storage = UnifiedQueryableStorage(timescale_connection_string, sqlite_path)
        logger.info(f"Unified queryable storage initialized: {_unified_storage.backend_type}")

    return _unified_storage
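A short sketch of exercising the fallback path: with no TimescaleDB connection string the manager should drop to the SQLite backend, so this runs without a database server. The prediction fields and path below are illustrative placeholders:

from datetime import datetime
from core.unified_queryable_storage import get_unified_storage

storage = get_unified_storage(sqlite_path="data/queryable_storage.db")
print(storage.backend_type)  # "timescale" or "sqlite"

# Record a prediction for later reward evaluation (values are illustrative)
storage.store_prediction({
    "prediction_id": "example-123",
    "symbol": "ETH/USDT",
    "model_name": "dqn_agent",
    "timestamp": datetime.utcnow(),
    "predicted_direction": 1,
    "confidence": 0.62,
    "timeframe": "1m",
})

latest = storage.get_predictions(symbol="ETH/USDT", limit=10)
print(len(latest))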
core/unified_training_manager_v2.py (new file)
@@ -0,0 +1,486 @@
"""
Unified Training Manager V2 (Refactored)

Combines UnifiedTrainingManager and EnhancedRLTrainingAdapter into a single,
comprehensive training system that handles:
- Periodic training loops (DQN, COB RL, CNN)
- Reward-driven training with EnhancedRewardCalculator
- Multi-timeframe training coordination
- Batch processing and statistics tracking
- Inference coordination (optional)

This eliminates duplication and provides a single entry point for all training.
"""

import asyncio
import logging
import time
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union, Tuple
from dataclasses import dataclass
import numpy as np
import threading

logger = logging.getLogger(__name__)


@dataclass
class TrainingBatch:
    """Training batch for RL models with enhanced reward data"""
    model_name: str
    symbol: str
    timeframe: str
    states: List[np.ndarray]
    actions: List[int]
    rewards: List[float]
    next_states: List[np.ndarray]
    dones: List[bool]
    confidences: List[float]
    metadata: Dict[str, Any]
    batch_timestamp: datetime


class UnifiedTrainingManager:
    """
    Unified training controller that combines periodic and reward-driven training

    Features:
    - Periodic training loops for DQN, COB RL, CNN
    - Reward-driven training with EnhancedRewardCalculator
    - Multi-timeframe training coordination
    - Batch processing and statistics
    - Inference coordination (optional)
    """

    def __init__(
        self,
        orchestrator: Any,
        reward_system: Any = None,
        inference_coordinator: Any = None,
        # Periodic training intervals
        dqn_interval_s: int = 5,
        cob_rl_interval_s: int = 1,
        cnn_interval_s: int = 10,
        # Batch configuration
        min_dqn_experiences: int = 16,
        min_batch_size: int = 8,
        max_batch_size: int = 64,
        # Reward-driven training
        reward_training_interval_s: int = 2,
    ):
        """
        Initialize unified training manager

        Args:
            orchestrator: Trading orchestrator with models
            reward_system: Enhanced reward system (optional)
            inference_coordinator: Timeframe inference coordinator (optional)
            dqn_interval_s: DQN training interval
            cob_rl_interval_s: COB RL training interval
            cnn_interval_s: CNN training interval
            min_dqn_experiences: Minimum experiences before DQN training
            min_batch_size: Minimum batch size for reward-driven training
            max_batch_size: Maximum batch size for reward-driven training
            reward_training_interval_s: Reward-driven training check interval
        """
        self.orchestrator = orchestrator
        self.reward_system = reward_system
        self.inference_coordinator = inference_coordinator

        # Training intervals
        self.dqn_interval_s = dqn_interval_s
        self.cob_rl_interval_s = cob_rl_interval_s
        self.cnn_interval_s = cnn_interval_s
        self.reward_training_interval_s = reward_training_interval_s

        # Batch configuration
        self.min_dqn_experiences = min_dqn_experiences
        self.min_batch_size = min_batch_size
        self.max_batch_size = max_batch_size

        # Training statistics
        self.training_stats = {
            'total_training_batches': 0,
            'successful_training_calls': 0,
            'failed_training_calls': 0,
            'last_training_time': None,
            'training_times_per_model': {},
            'average_batch_sizes': {},
            'periodic_training_counts': {
                'dqn': 0,
                'cob_rl': 0,
                'cnn': 0
            },
            'reward_driven_training_count': 0
        }

        # Thread safety
        self.lock = threading.RLock()

        # Running state
        self.running = False
        self._tasks: List[asyncio.Task] = []

        logger.info("UnifiedTrainingManager V2 initialized")

        # Register inference wrappers if coordinator available
        if self.inference_coordinator:
            self._register_inference_wrappers()

    def _register_inference_wrappers(self):
        """Register inference wrappers with coordinator"""
        try:
            # Register model inference functions
            self.inference_coordinator.register_model_inference_function(
                'dqn_agent', self._dqn_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'cob_rl', self._cob_rl_inference_wrapper
            )
            self.inference_coordinator.register_model_inference_function(
                'enhanced_cnn', self._cnn_inference_wrapper
            )
            logger.info("Inference wrappers registered with coordinator")
        except Exception as e:
            logger.warning(f"Could not register inference wrappers: {e}")

    async def start(self):
        """Start all training loops"""
        if self.running:
            logger.warning("UnifiedTrainingManager already running")
            return

        self.running = True
        logger.info("UnifiedTrainingManager started")

        # Start periodic training loops
        self._tasks.append(asyncio.create_task(self._dqn_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cob_rl_trainer_loop()))
        self._tasks.append(asyncio.create_task(self._cnn_trainer_loop()))

        # Start reward-driven training if reward system available
        if self.reward_system is not None:
            self._tasks.append(asyncio.create_task(self._reward_driven_training_loop()))
            logger.info("Reward-driven training enabled")

    async def stop(self):
        """Stop all training loops"""
        if not self.running:
            return

        self.running = False

        # Cancel all tasks
        for t in self._tasks:
            t.cancel()

        # Wait for tasks to complete
        await asyncio.gather(*self._tasks, return_exceptions=True)
        self._tasks.clear()

        logger.info("UnifiedTrainingManager stopped")

    # ========================================================================
    # PERIODIC TRAINING LOOPS
    # ========================================================================

    async def _dqn_trainer_loop(self):
        """Periodic DQN training loop"""
        while self.running:
            try:
                rl_agent = getattr(self.orchestrator, 'rl_agent', None)
                if rl_agent and hasattr(rl_agent, 'memory') and hasattr(rl_agent, 'replay'):
                    if len(rl_agent.memory) >= self.min_dqn_experiences:
                        loss = rl_agent.replay()
                        if loss is not None:
                            logger.debug(f"DQN periodic training loss: {loss:.6f}")
                            self._update_periodic_training_stats('dqn', loss)

                await asyncio.sleep(self.dqn_interval_s)
            except Exception as e:
                logger.error(f"DQN trainer loop error: {e}")
                await asyncio.sleep(self.dqn_interval_s)

    async def _cob_rl_trainer_loop(self):
        """Periodic COB RL training loop"""
        while self.running:
            try:
                cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
                if cob_agent and hasattr(cob_agent, 'replay') and hasattr(cob_agent, 'memory'):
                    if len(getattr(cob_agent, 'memory', [])) >= 8:
                        loss = cob_agent.replay()
                        if loss is not None:
                            logger.debug(f"COB RL periodic training loss: {loss:.6f}")
                            self._update_periodic_training_stats('cob_rl', loss)

                await asyncio.sleep(self.cob_rl_interval_s)
            except Exception as e:
                logger.error(f"COB RL trainer loop error: {e}")
                await asyncio.sleep(self.cob_rl_interval_s)

    async def _cnn_trainer_loop(self):
        """Periodic CNN training loop"""
        while self.running:
            try:
                # Hook to CNN trainer if available
                cnn_model = getattr(self.orchestrator, 'cnn_model', None)
                if cnn_model and hasattr(cnn_model, 'train_step'):
                    # CNN training would go here
                    pass

                await asyncio.sleep(self.cnn_interval_s)
            except Exception as e:
                logger.error(f"CNN trainer loop error: {e}")
                await asyncio.sleep(self.cnn_interval_s)

    # ========================================================================
    # REWARD-DRIVEN TRAINING
    # ========================================================================

    async def _reward_driven_training_loop(self):
        """Reward-driven training loop using EnhancedRewardCalculator"""
        while self.running:
            try:
                # Get reward calculator
                reward_calculator = getattr(self.reward_system, 'reward_calculator', None)
                if not reward_calculator:
                    await asyncio.sleep(self.reward_training_interval_s)
                    continue

                # Get symbols to train on
                symbols = getattr(reward_calculator, 'symbols', [])

                # Import TimeFrame enum
                try:
                    from core.enhanced_reward_calculator import TimeFrame
                    timeframes = [TimeFrame.SECONDS_1, TimeFrame.MINUTES_1,
                                  TimeFrame.HOURS_1, TimeFrame.DAYS_1]
                except ImportError:
                    timeframes = ['1s', '1m', '1h', '1d']

                # Process each symbol and timeframe
                for symbol in symbols:
                    for timeframe in timeframes:
                        # Get training data
                        training_data = reward_calculator.get_training_data(
                            symbol, timeframe, self.max_batch_size
                        )

                        if len(training_data) >= self.min_batch_size:
                            await self._process_reward_training_batch(
                                symbol, timeframe, training_data
                            )

                await asyncio.sleep(self.reward_training_interval_s)

            except Exception as e:
                logger.error(f"Reward-driven training loop error: {e}")
                await asyncio.sleep(5)

    async def _process_reward_training_batch(self, symbol: str, timeframe: Any,
                                             training_data: List[Tuple[Any, float]]):
        """Process reward-driven training batch"""
        try:
            # Group by model
            model_batches = {}

            for prediction_record, reward in training_data:
                model_name = getattr(prediction_record, 'model_name', 'unknown')
                if model_name not in model_batches:
                    model_batches[model_name] = []
                model_batches[model_name].append((prediction_record, reward))

            # Train each model
            for model_name, model_data in model_batches.items():
                if len(model_data) >= self.min_batch_size:
                    await self._train_model_with_rewards(
                        model_name, symbol, timeframe, model_data
                    )

        except Exception as e:
            logger.error(f"Error processing reward training batch: {e}")

    async def _train_model_with_rewards(self, model_name: str, symbol: str,
                                        timeframe: Any, training_data: List[Tuple[Any, float]]):
        """Train model with reward-evaluated data"""
        try:
            training_start = time.time()

            # Route to appropriate model
            if 'dqn' in model_name.lower():
                success = await self._train_dqn_with_rewards(training_data)
            elif 'cob' in model_name.lower():
                success = await self._train_cob_rl_with_rewards(training_data)
            elif 'cnn' in model_name.lower():
                success = await self._train_cnn_with_rewards(training_data)
            else:
                logger.warning(f"Unknown model type: {model_name}")
                return

            training_time = time.time() - training_start

            if success:
                with self.lock:
                    self.training_stats['reward_driven_training_count'] += 1
                logger.info(f"Reward-driven training: {model_name} on {symbol} "
                            f"with {len(training_data)} samples in {training_time:.3f}s")

        except Exception as e:
            logger.error(f"Error in reward-driven training for {model_name}: {e}")

    async def _train_dqn_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
        """Train DQN with reward-evaluated data"""
        try:
            rl_agent = getattr(self.orchestrator, 'rl_agent', None)
            if not rl_agent or not hasattr(rl_agent, 'remember'):
                return False

            # Add experiences to memory
            for prediction_record, reward in training_data:
                # Get state vector from prediction record
                state = getattr(prediction_record, 'state_vector', None)
                if not state:
                    continue

                # Convert direction to action
                direction = getattr(prediction_record, 'predicted_direction', 0)
                action = direction + 1  # Convert -1,0,1 to 0,1,2

                # Add to memory
                rl_agent.remember(state, action, reward, state, True)

            return True

        except Exception as e:
            logger.error(f"Error training DQN with rewards: {e}")
            return False

    async def _train_cob_rl_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
        """Train COB RL with reward-evaluated data"""
        try:
            cob_agent = getattr(self.orchestrator, 'cob_rl_agent', None)
            if not cob_agent or not hasattr(cob_agent, 'remember'):
                return False

            # Similar to DQN training
            for prediction_record, reward in training_data:
                state = getattr(prediction_record, 'state_vector', None)
                if not state:
                    continue

                direction = getattr(prediction_record, 'predicted_direction', 0)
                action = direction + 1

                cob_agent.remember(state, action, reward, state, True)

            return True

        except Exception as e:
            logger.error(f"Error training COB RL with rewards: {e}")
            return False

    async def _train_cnn_with_rewards(self, training_data: List[Tuple[Any, float]]) -> bool:
        """Train CNN with reward-evaluated data"""
        try:
            # CNN training with rewards would go here
            # This depends on CNN's training interface
            return True

        except Exception as e:
            logger.error(f"Error training CNN with rewards: {e}")
            return False

    # ========================================================================
    # INFERENCE WRAPPERS (Optional - for TimeframeInferenceCoordinator)
    # ========================================================================

    async def _dqn_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
        """Wrapper for DQN model inference"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'rl_agent'):
                # Get base data
                base_data = await self._get_base_data(context.symbol)
                if base_data is None:
                    return None

                # Convert to state
                state = self._convert_to_dqn_state(base_data, context)

                # Run prediction
                if hasattr(self.orchestrator.rl_agent, 'act'):
                    action_idx = self.orchestrator.rl_agent.act(state)
                    confidence = getattr(self.orchestrator.rl_agent, 'last_confidence', 0.5)

                    action_names = ['SELL', 'HOLD', 'BUY']
                    direction = action_idx - 1

                    current_price = self._safe_get_current_price(context.symbol)

                    return {
                        'predicted_price': current_price,
                        'current_price': current_price,
                        'direction': direction,
                        'confidence': float(confidence),
                        'action': action_names[action_idx],
                        'model_state': (state.tolist() if hasattr(state, 'tolist') else state),
                        'context': context
                    }
        except Exception as e:
            logger.error(f"Error in DQN inference wrapper: {e}")

        return None

    async def _cob_rl_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
        """Wrapper for COB RL model inference"""
        # Implementation similar to EnhancedRLTrainingAdapter
        return None

    async def _cnn_inference_wrapper(self, context: Any) -> Optional[Dict[str, Any]]:
        """Wrapper for CNN model inference"""
        # Implementation similar to EnhancedRLTrainingAdapter
        return None

    # ========================================================================
    # HELPER METHODS
    # ========================================================================

    async def _get_base_data(self, symbol: str) -> Optional[Any]:
        """Get base data for a symbol"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, '_build_base_data'):
                return await self.orchestrator._build_base_data(symbol)
        except Exception as e:
            logger.debug(f"Error getting base data: {e}")
        return None

    def _safe_get_current_price(self, symbol: str) -> float:
        """Get current price safely"""
        try:
            if self.orchestrator and hasattr(self.orchestrator, 'data_provider'):
                price = self.orchestrator.data_provider.get_current_price(symbol)
                return float(price) if price is not None else 0.0
        except Exception as e:
            logger.debug(f"Error getting current price: {e}")
        return 0.0

    def _convert_to_dqn_state(self, base_data: Any, context: Any) -> np.ndarray:
        """Convert base data to DQN state"""
        try:
            feature_vector = base_data.get_feature_vector() if hasattr(base_data, 'get_feature_vector') else []
            if feature_vector:
                return np.array(feature_vector, dtype=np.float32)
            return np.zeros(100, dtype=np.float32)
        except Exception as e:
            logger.error(f"Error converting to DQN state: {e}")
            return np.zeros(100, dtype=np.float32)

    def _update_periodic_training_stats(self, model_type: str, loss: float):
        """Update periodic training statistics"""
        with self.lock:
            self.training_stats['periodic_training_counts'][model_type] += 1
            self.training_stats['last_training_time'] = datetime.now().isoformat()

    def get_training_statistics(self) -> Dict[str, Any]:
        """Get training statistics"""
        with self.lock:
            return self.training_stats.copy()
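A rough sketch of driving the new manager from an asyncio context; `orchestrator` stands in for the project's trading orchestrator (not defined in this commit's excerpt), and the run duration is arbitrary:

import asyncio
from core.unified_training_manager_v2 import UnifiedTrainingManager

async def run_training(orchestrator):
    manager = UnifiedTrainingManager(
        orchestrator,
        dqn_interval_s=5,
        cob_rl_interval_s=1,
        cnn_interval_s=10,
    )
    await manager.start()
    try:
        await asyncio.sleep(60)  # let the periodic loops run for a while
        print(manager.get_training_statistics())
    finally:
        await manager.stop()

# asyncio.run(run_training(orchestrator))  # orchestrator supplied by the application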