Dobromir Popov
2025-03-29 04:09:03 +02:00
parent 43803caaf1
commit 8b3db10a85
3 changed files with 307 additions and 267 deletions


@@ -61,6 +61,10 @@ class DataInterface:
"""
cache_file = os.path.join(self.data_dir, f"{self.symbol.replace('/', '_')}_{timeframe}.csv")
# For 1s timeframe, always fetch fresh data
if timeframe == '1s':
use_cache = False
# Check if cached data exists and is recent
if use_cache and os.path.exists(cache_file):
try:
@@ -74,18 +78,18 @@ class DataInterface:
logger.error(f"Error reading cached data: {str(e)}")
# If we get here, we need to fetch data
# For now, we'll use a placeholder for fetching data from an exchange
try:
# In a real implementation, we would fetch data from an exchange or API here
# For this example, we'll create dummy data if we can't load from cache
logger.info(f"Fetching historical data for {self.symbol} {timeframe}")
# For 1s timeframe, we need more data points
if timeframe == '1s':
n_candles = min(n_candles * 60, 10000) # Up to 10k ticks
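# e.g. a request for n_candles=500 becomes min(500 * 60, 10000) = 10000 one-second candles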
# Placeholder for real data fetching
# In a real implementation, this would be replaced with API calls
self._fetch_data_from_exchange(timeframe, n_candles)
# Save to cache
if self.dataframes[timeframe] is not None:
# Save to cache (except for 1s timeframe)
if self.dataframes[timeframe] is not None and timeframe != '1s':
self.dataframes[timeframe].to_csv(cache_file, index=False)
return self.dataframes[timeframe]
else:
@@ -122,6 +126,7 @@ class DataInterface:
"""
# Map timeframe to seconds
tf_seconds = {
'1s': 1, # Added 1s timeframe
'1m': 60,
'5m': 300,
'15m': 900,
@@ -207,59 +212,62 @@ class DataInterface:
# Get data for all requested timeframes
dfs = {}
min_length = float('inf')
for tf in timeframes:
df = self.get_historical_data(timeframe=tf, n_candles=n_candles)
# For 1s timeframe, we need more data points
tf_candles = n_candles * 60 if tf == '1s' else n_candles
df = self.get_historical_data(timeframe=tf, n_candles=tf_candles)
if df is not None and not df.empty:
dfs[tf] = df
# Keep track of minimum length across all timeframes
min_length = min(min_length, len(df))
if not dfs:
logger.error("No data available for feature creation")
return None, None, None
# Align all dataframes to the same length
for tf in dfs:
dfs[tf] = dfs[tf].tail(min_length)
# Create features for each timeframe
features = []
targets = []
timestamps = []
targets = None
timestamps = None
for tf in timeframes:
if tf in dfs:
X, y, ts = self._create_features(dfs[tf], window_size)
if X is not None and y is not None:
features.append(X)
if len(targets) == 0: # Only need targets from one timeframe
if targets is None: # Only need targets from one timeframe
targets = y
timestamps = ts
if not features:
if not features or targets is None:
logger.error("Failed to create features for any timeframe")
return None, None, None
# Stack features from all timeframes along the time dimension
# Reshape each timeframe's features to [samples, window, 1, features]
reshaped_features = [f.reshape(f.shape[0], f.shape[1], 1, f.shape[2])
for f in features]
# Concatenate along the channel dimension
X = np.concatenate(reshaped_features, axis=2)
# Reshape to [samples, window, features*timeframes]
X = X.reshape(X.shape[0], X.shape[1], -1)
# Ensure all feature arrays have the same length
min_samples = min(f.shape[0] for f in features)
features = [f[-min_samples:] for f in features]
targets = targets[-min_samples:]
timestamps = timestamps[-min_samples:]
# Stack features from all timeframes
X = np.concatenate([f.reshape(min_samples, window_size, -1) for f in features], axis=2)
# Validate data
if np.any(np.isnan(X)) or np.any(np.isinf(X)):
logger.error("Generated features contain NaN or infinite values")
return None, None, None
# Ensure all values are finite and normalized
X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=-1.0)
X = np.clip(X, -1e6, 1e6) # Clip extreme values
# Log data shapes for debugging
logger.info(f"Prepared input data - X shape: {X.shape}, y shape: {np.array(targets).shape}")
logger.info(f"Prepared input data - X shape: {X.shape}, y shape: {targets.shape}")
return X, targets, timestamps
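As a shape sketch of the stacking above (illustrative only; the sample count and timeframe list are arbitrary, not taken from this commit):
import numpy as np
window_size, n_timeframes = 20, 3                     # e.g. ['1m', '5m', '1h']
features = [np.random.rand(100, window_size, 5) for _ in range(n_timeframes)]
min_samples = min(f.shape[0] for f in features)
X = np.concatenate([f[-min_samples:].reshape(min_samples, window_size, -1)
                    for f in features], axis=2)
assert X.shape == (100, 20, 15)                       # 5 OHLCV features * 3 timeframes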
def _create_features(self, df, window_size):
"""
Create features from OHLCV data using a sliding window approach.
Create features from OHLCV data using a sliding window.
Args:
df (pd.DataFrame): DataFrame with OHLCV data
@@ -267,76 +275,34 @@ class DataInterface:
Returns:
tuple: (X, y, timestamps) where:
X is the input features array
X is the feature array
y is the target array
timestamps is an array of timestamps for each sample
timestamps is the array of timestamps
"""
# Extract OHLCV columns
ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values
# Validate data before scaling
if np.any(np.isnan(ohlcv)) or np.any(np.isinf(ohlcv)):
logger.error("Input data contains NaN or infinite values")
return None, None, None
# Handle potential constant columns (avoid division by zero in scaler)
ohlcv = np.nan_to_num(ohlcv, nan=0.0)
ranges = np.ptp(ohlcv, axis=0)
for i in range(len(ranges)):
if ranges[i] == 0: # Constant column
ohlcv[:, i] = 1 if i == 3 else 0 # Set close to 1, others to 0
# Scale the data with safety checks
try:
scaler = MinMaxScaler()
ohlcv_scaled = scaler.fit_transform(ohlcv)
if np.any(np.isnan(ohlcv_scaled)) or np.any(np.isinf(ohlcv_scaled)):
logger.error("Scaling produced invalid values")
return None, None, None
except Exception as e:
logger.error(f"Scaling failed: {str(e)}")
if len(df) < window_size + 1:
logger.error(f"Not enough data for feature creation (need {window_size + 1}, got {len(df)})")
return None, None, None
# Store the scaler for later use
timeframe = next((tf for tf in self.timeframes if self.dataframes.get(tf) is not None and
self.dataframes[tf].equals(df)), 'unknown')
self.scalers[timeframe] = scaler
# Extract OHLCV data
data = df[['open', 'high', 'low', 'close', 'volume']].values
timestamps = df['timestamp'].values
# Create sliding windows
X = []
y = []
timestamps = []
X = np.array([data[i:i+window_size] for i in range(len(data)-window_size)])
for i in range(len(ohlcv_scaled) - window_size):
# Input: window_size candles of OHLCV data
window = ohlcv_scaled[i:i+window_size]
# Validate window data
if np.any(np.isnan(window)) or np.any(np.isinf(window)):
continue
X.append(window)
# Target: binary classification - price goes up (1) or down (0)
# 1 if close price increases in the next candle, 0 otherwise
price_change = ohlcv[i+window_size, 3] - ohlcv[i+window_size-1, 3]
y.append(1 if price_change > 0 else 0)
# Store timestamp for reference
timestamps.append(df['timestamp'].iloc[i+window_size])
# Create targets (next candle's movement: 0=down, 1=neutral, 2=up)
next_close = data[window_size:, 3] # Close prices
curr_close = data[window_size-1:-1, 3]
price_changes = (next_close - curr_close) / curr_close
if not X:
logger.error("No valid windows created")
return None, None, None
X = np.array(X)
y = np.array(y)
timestamps = np.array(timestamps)
# Define thresholds for price movement classification
threshold = 0.001 # 0.1% threshold
y = np.zeros(len(price_changes), dtype=int)
y[price_changes > threshold] = 2 # Up
y[(price_changes >= -threshold) & (price_changes <= threshold)] = 1 # Neutral
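# e.g. price_changes = [0.003, -0.0005, -0.004] -> y = [2, 1, 0] (up, neutral, down)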
# Log shapes for debugging
logger.info(f"Created features - X shape: {X.shape}, y shape: {y.shape}")
return X, y, timestamps
return X, y, timestamps[window_size:]
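A quick self-contained check of the window construction above (toy data, not part of the diff):
import numpy as np
data = np.arange(50, dtype=float).reshape(10, 5)      # 10 fake OHLCV rows
window_size = 3
X = np.array([data[i:i+window_size] for i in range(len(data) - window_size)])
assert X.shape == (7, 3, 5)                           # (samples, window, OHLCV features)
# y and timestamps[window_size:] both end up with length 7, one label per window.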
def generate_training_dataset(self, timeframes=None, n_candles=1000, window_size=20):
"""
@@ -406,21 +372,28 @@ class DataInterface:
return dataset_info
def get_feature_count(self):
"""Get the number of features per input sample"""
# OHLCV (5 features) per timeframe
return 5 * len(self.timeframes)
"""
Calculate total number of features across all timeframes.
Returns:
int: Total number of features (5 features per timeframe)
"""
return len(self.timeframes) * 5 # OHLCV features for each timeframe
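A concrete tie-in with prepare_nn_input (constructor arguments here are hypothetical):
# di = DataInterface(symbol='BTC/USDT', timeframes=['1m', '5m', '1h'])  # hypothetical args
# di.get_feature_count()  # -> 15, matching X.shape[-1] from prepare_nn_input()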
def calculate_pnl(self, predictions, actual_prices, position_size=1.0):
"""
Calculate PnL based on predictions and actual price movements.
Calculate PnL and win rates based on predictions and actual price movements.
Args:
predictions (np.array): Model predictions (0: sell, 1: hold, 2: buy)
actual_prices (np.array): Actual price data
position_size (float): Size of the position to trade
predictions: Array of predicted actions (0=SELL, 1=HOLD, 2=BUY) or probabilities
actual_prices: Array of actual close prices
position_size: Position size for each trade
Returns:
tuple: (total_pnl, win_rate, trade_history)
tuple: (pnl, win_rate, trades) where:
pnl is the total profit and loss
win_rate is the ratio of winning trades
trades is a list of trade dictionaries
"""
if len(predictions) != len(actual_prices) - 1:
logger.error("Predictions and prices length mismatch")
@@ -468,36 +441,85 @@ class DataInterface:
win_rate = wins / trades if trades > 0 else 0.0
return pnl, win_rate, trade_history
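Most of calculate_pnl's body sits outside this hunk; the following is only a minimal sketch of a signal-driven PnL loop consistent with the docstring (names and the next-candle settlement rule are assumptions, not the committed implementation):
def pnl_sketch(predictions, prices, position_size=1.0):
    """Toy PnL: act on each prediction, settle on the next candle's close."""
    pnl, wins, trades, history = 0.0, 0, 0, []
    for i, action in enumerate(predictions):            # len(predictions) == len(prices) - 1
        change = (prices[i + 1] - prices[i]) / prices[i]
        if action == 2:                                  # BUY -> gain when price rises
            trade_pnl = position_size * change
        elif action == 0:                                # SELL -> gain when price falls
            trade_pnl = -position_size * change
        else:                                            # HOLD -> no trade
            continue
        pnl += trade_pnl
        trades += 1
        wins += trade_pnl > 0
        history.append({'action': action, 'pnl': trade_pnl})
    return pnl, (wins / trades if trades else 0.0), history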
def prepare_training_data(self, refresh=False, refresh_interval=300):
def get_future_prices(self, prices, n_candles=3):
"""
Prepare training and validation data with optional refresh.
Extract future prices for use in retrospective training.
Args:
refresh (bool): Whether to force refresh data
refresh_interval (int): Minimum seconds between refreshes
prices: Array of close prices
n_candles: Number of future candles to predict
Returns:
tuple: (X_train, y_train, X_val, y_val, prices) numpy arrays
numpy.ndarray: Array of future prices for each sample
"""
if prices is None or len(prices) <= n_candles:
logger.warning(f"Not enough price data for future prediction: {len(prices) if prices is not None else 0} prices")
# Return zeros if not enough data
return np.zeros((len(prices) if prices is not None else 0, 1))
# For each price point i, get the price at i+n_candles
future_prices = np.zeros((len(prices), 1))
for i in range(len(prices) - n_candles):
future_prices[i, 0] = prices[i + n_candles]
# For the last n_candles positions, we don't have future data
# We'll use the last known price as a placeholder
for i in range(len(prices) - n_candles, len(prices)):
future_prices[i, 0] = prices[-1]
return future_prices
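Illustrative call (values made up):
# prices = np.array([10.0, 11.0, 12.0, 13.0, 14.0])
# get_future_prices(prices, n_candles=3)
# -> [[13.], [14.], [14.], [14.], [14.]]
#    rows 0-1 look 3 candles ahead; the last 3 rows fall back to the final price.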
def prepare_training_data(self, refresh=False, refresh_interval=300):
"""
Prepare data for training, including splitting into train/validation sets.
Args:
refresh (bool): Whether to refresh the data cache
refresh_interval (int): Interval in seconds to refresh data
Returns:
tuple: (X_train, y_train, X_val, y_val, train_prices, val_prices)
"""
current_time = datetime.now()
if refresh or (current_time - getattr(self, 'last_refresh', datetime.min)).total_seconds() > refresh_interval:
# Check if we should refresh the data
if refresh or not hasattr(self, 'last_refresh_time') or \
(current_time - self.last_refresh_time).total_seconds() > refresh_interval:
logger.info("Refreshing training data...")
for tf in self.timeframes:
self.get_historical_data(timeframe=tf, n_candles=1000, use_cache=False)
self.last_refresh = current_time
# Get all data
self.last_refresh_time = current_time
else:
# Use cached data
if hasattr(self, 'cached_train_data'):
return self.cached_train_data
# Prepare input data
X, y, _ = self.prepare_nn_input()
if X is None:
return None, None, None, None, None
return None, None, None, None, None, None
# Get price data for PnL calculation
prices = self.dataframes[self.timeframes[0]]['close'].values
# Split into train/validation (80/20)
raw_prices = []
for tf in self.timeframes:
if tf in self.dataframes and self.dataframes[tf] is not None:
# Get the close prices for the same period as X
prices = self.dataframes[tf]['close'].values[-len(X):]
if len(prices) == len(X):
raw_prices = prices
break
if len(raw_prices) != len(X):
raw_prices = np.zeros(len(X)) # Fallback if no prices available
# Split data into training and validation sets (80/20)
split_idx = int(len(X) * 0.8)
return (X[:split_idx], y[:split_idx], X[split_idx:], y[split_idx:],
prices[:split_idx], prices[split_idx:])
X_train, X_val = X[:split_idx], X[split_idx:]
y_train, y_val = y[:split_idx], y[split_idx:]
train_prices, val_prices = raw_prices[:split_idx], raw_prices[split_idx:]
# Cache the data
self.cached_train_data = (X_train, y_train, X_val, y_val, train_prices, val_prices)
return X_train, y_train, X_val, y_val, train_prices, val_prices
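Hedged usage sketch of the refreshed return signature (the model object and its methods are placeholders, not part of this commit):
# X_train, y_train, X_val, y_val, train_prices, val_prices = di.prepare_training_data(refresh=True)
# if X_train is not None:
#     model.fit(X_train, y_train, validation_data=(X_val, y_val))   # placeholder model
#     preds = model.predict(X_val).argmax(axis=1)                   # -> 0/1/2 actions
#     pnl, win_rate, trades = di.calculate_pnl(preds[:-1], val_prices)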
def prepare_realtime_input(self, timeframe='1h', n_candles=30, window_size=20):
"""