training
@@ -61,6 +61,10 @@ class DataInterface:
        """
        cache_file = os.path.join(self.data_dir, f"{self.symbol.replace('/', '_')}_{timeframe}.csv")

        # For 1s timeframe, always fetch fresh data
        if timeframe == '1s':
            use_cache = False

        # Check if cached data exists and is recent
        if use_cache and os.path.exists(cache_file):
            try:
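For illustration only (the symbol and data_dir values below are assumptions, not shown in this hunk), the cache path resolves like this:

import os

symbol, timeframe, data_dir = 'BTC/USDT', '1m', 'cache'  # hypothetical values
cache_file = os.path.join(data_dir, f"{symbol.replace('/', '_')}_{timeframe}.csv")
print(cache_file)  # cache/BTC_USDT_1m.csv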
@@ -74,18 +78,18 @@ class DataInterface:
                logger.error(f"Error reading cached data: {str(e)}")

        # If we get here, we need to fetch data
        # For now, we'll use a placeholder for fetching data from an exchange
        try:
            # In a real implementation, we would fetch data from an exchange or API here
            # For this example, we'll create dummy data if we can't load from cache
            logger.info(f"Fetching historical data for {self.symbol} {timeframe}")

            # For 1s timeframe, we need more data points
            if timeframe == '1s':
                n_candles = min(n_candles * 60, 10000) # Up to 10k ticks

            # Placeholder for real data fetching
            # In a real implementation, this would be replaced with API calls
            self._fetch_data_from_exchange(timeframe, n_candles)

            # Save to cache
            if self.dataframes[timeframe] is not None:
            # Save to cache (except for 1s timeframe)
            if self.dataframes[timeframe] is not None and timeframe != '1s':
                self.dataframes[timeframe].to_csv(cache_file, index=False)
                return self.dataframes[timeframe]
            else:
@@ -122,6 +126,7 @@ class DataInterface:
        """
        # Map timeframe to seconds
        tf_seconds = {
            '1s': 1,  # Added 1s timeframe
            '1m': 60,
            '5m': 300,
            '15m': 900,
@@ -207,59 +212,62 @@ class DataInterface:

        # Get data for all requested timeframes
        dfs = {}
        min_length = float('inf')
        for tf in timeframes:
            df = self.get_historical_data(timeframe=tf, n_candles=n_candles)
            # For 1s timeframe, we need more data points
            tf_candles = n_candles * 60 if tf == '1s' else n_candles
            df = self.get_historical_data(timeframe=tf, n_candles=tf_candles)
            if df is not None and not df.empty:
                dfs[tf] = df
                # Keep track of minimum length across all timeframes
                min_length = min(min_length, len(df))

        if not dfs:
            logger.error("No data available for feature creation")
            return None, None, None

        # Align all dataframes to the same length
        for tf in dfs:
            dfs[tf] = dfs[tf].tail(min_length)

        # Create features for each timeframe
        features = []
        targets = []
        timestamps = []
        targets = None
        timestamps = None

        for tf in timeframes:
            if tf in dfs:
                X, y, ts = self._create_features(dfs[tf], window_size)
                if X is not None and y is not None:
                    features.append(X)
                    if len(targets) == 0:  # Only need targets from one timeframe
                    if targets is None:  # Only need targets from one timeframe
                        targets = y
                        timestamps = ts

        if not features:
        if not features or targets is None:
            logger.error("Failed to create features for any timeframe")
            return None, None, None

        # Stack features from all timeframes along the time dimension
        # Reshape each timeframe's features to [samples, window, 1, features]
        reshaped_features = [f.reshape(f.shape[0], f.shape[1], 1, f.shape[2])
                             for f in features]
        # Concatenate along the channel dimension
        X = np.concatenate(reshaped_features, axis=2)
        # Reshape to [samples, window, features*timeframes]
        X = X.reshape(X.shape[0], X.shape[1], -1)
        # Ensure all feature arrays have the same length
        min_samples = min(f.shape[0] for f in features)
        features = [f[-min_samples:] for f in features]
        targets = targets[-min_samples:]
        timestamps = timestamps[-min_samples:]

        # Stack features from all timeframes
        X = np.concatenate([f.reshape(min_samples, window_size, -1) for f in features], axis=2)

        # Validate data
        if np.any(np.isnan(X)) or np.any(np.isinf(X)):
            logger.error("Generated features contain NaN or infinite values")
            return None, None, None

        # Ensure all values are finite and normalized
        X = np.nan_to_num(X, nan=0.0, posinf=1.0, neginf=-1.0)
        X = np.clip(X, -1e6, 1e6)  # Clip extreme values

        # Log data shapes for debugging
        logger.info(f"Prepared input data - X shape: {X.shape}, y shape: {np.array(targets).shape}")
        logger.info(f"Prepared input data - X shape: {X.shape}, y shape: {targets.shape}")
        return X, targets, timestamps
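As a rough shape check of the stacking step above (array sizes below are illustrative, not taken from this commit): with three timeframes, a 20-candle window and 5 OHLCV features per timeframe, the per-timeframe arrays are trimmed to a common sample count and concatenated along the feature axis.

import numpy as np

window_size = 20
# Hypothetical per-timeframe feature arrays of shape (samples, window, 5)
features = [np.random.rand(n, window_size, 5) for n in (100, 98, 97)]
min_samples = min(f.shape[0] for f in features)                  # 97
features = [f[-min_samples:] for f in features]                  # align lengths
X = np.concatenate([f.reshape(min_samples, window_size, -1) for f in features], axis=2)
print(X.shape)  # (97, 20, 15) -> 5 features * 3 timeframes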

    def _create_features(self, df, window_size):
        """
        Create features from OHLCV data using a sliding window approach.
        Create features from OHLCV data using a sliding window.

        Args:
            df (pd.DataFrame): DataFrame with OHLCV data
@@ -267,76 +275,34 @@ class DataInterface:

        Returns:
            tuple: (X, y, timestamps) where:
                X is the input features array
                X is the feature array
                y is the target array
                timestamps is an array of timestamps for each sample
                timestamps is the array of timestamps
        """
        # Extract OHLCV columns
        ohlcv = df[['open', 'high', 'low', 'close', 'volume']].values

        # Validate data before scaling
        if np.any(np.isnan(ohlcv)) or np.any(np.isinf(ohlcv)):
            logger.error("Input data contains NaN or infinite values")
            return None, None, None

        # Handle potential constant columns (avoid division by zero in scaler)
        ohlcv = np.nan_to_num(ohlcv, nan=0.0)
        ranges = np.ptp(ohlcv, axis=0)
        for i in range(len(ranges)):
            if ranges[i] == 0:  # Constant column
                ohlcv[:, i] = 1 if i == 3 else 0  # Set close to 1, others to 0

        # Scale the data with safety checks
        try:
            scaler = MinMaxScaler()
            ohlcv_scaled = scaler.fit_transform(ohlcv)
            if np.any(np.isnan(ohlcv_scaled)) or np.any(np.isinf(ohlcv_scaled)):
                logger.error("Scaling produced invalid values")
                return None, None, None
        except Exception as e:
            logger.error(f"Scaling failed: {str(e)}")
        if len(df) < window_size + 1:
            logger.error(f"Not enough data for feature creation (need {window_size + 1}, got {len(df)})")
            return None, None, None

        # Store the scaler for later use
        timeframe = next((tf for tf in self.timeframes if self.dataframes.get(tf) is not None and
                          self.dataframes[tf].equals(df)), 'unknown')
        self.scalers[timeframe] = scaler
        # Extract OHLCV data
        data = df[['open', 'high', 'low', 'close', 'volume']].values
        timestamps = df['timestamp'].values

        # Create sliding windows
        X = []
        y = []
        timestamps = []
        X = np.array([data[i:i+window_size] for i in range(len(data)-window_size)])

        for i in range(len(ohlcv_scaled) - window_size):
            # Input: window_size candles of OHLCV data
            window = ohlcv_scaled[i:i+window_size]

            # Validate window data
            if np.any(np.isnan(window)) or np.any(np.isinf(window)):
                continue

            X.append(window)

            # Target: binary classification - price goes up (1) or down (0)
            # 1 if close price increases in the next candle, 0 otherwise
            price_change = ohlcv[i+window_size, 3] - ohlcv[i+window_size-1, 3]
            y.append(1 if price_change > 0 else 0)

            # Store timestamp for reference
            timestamps.append(df['timestamp'].iloc[i+window_size])
        # Create targets (next candle's movement: 0=down, 1=neutral, 2=up)
        next_close = data[window_size:, 3]  # Close prices
        curr_close = data[window_size-1:-1, 3]
        price_changes = (next_close - curr_close) / curr_close

        if not X:
            logger.error("No valid windows created")
            return None, None, None

        X = np.array(X)
        y = np.array(y)
        timestamps = np.array(timestamps)
        # Define thresholds for price movement classification
        threshold = 0.001  # 0.1% threshold
        y = np.zeros(len(price_changes), dtype=int)
        y[price_changes > threshold] = 2  # Up
        y[(price_changes >= -threshold) & (price_changes <= threshold)] = 1  # Neutral

        # Log shapes for debugging
        logger.info(f"Created features - X shape: {X.shape}, y shape: {y.shape}")

        return X, y, timestamps
        return X, y, timestamps[window_size:]
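A small worked example of the threshold labelling above (values chosen for illustration): with the 0.1% threshold, relative changes of +0.2%, -0.05% and -0.3% map to up (2), neutral (1) and down (0).

import numpy as np

threshold = 0.001
price_changes = np.array([0.002, -0.0005, -0.003])  # hypothetical relative changes
y = np.zeros(len(price_changes), dtype=int)          # default label: down (0)
y[price_changes > threshold] = 2                     # up
y[(price_changes >= -threshold) & (price_changes <= threshold)] = 1  # neutral
print(y)  # [2 1 0]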

    def generate_training_dataset(self, timeframes=None, n_candles=1000, window_size=20):
        """
@@ -406,21 +372,28 @@ class DataInterface:
        return dataset_info

    def get_feature_count(self):
        """Get the number of features per input sample"""
        # OHLCV (5 features) per timeframe
        return 5 * len(self.timeframes)
        """
        Calculate total number of features across all timeframes.

        Returns:
            int: Total number of features (5 features per timeframe)
        """
        return len(self.timeframes) * 5  # OHLCV features for each timeframe

    def calculate_pnl(self, predictions, actual_prices, position_size=1.0):
        """
        Calculate PnL based on predictions and actual price movements.
        Calculate PnL and win rates based on predictions and actual price movements.

        Args:
            predictions (np.array): Model predictions (0: sell, 1: hold, 2: buy)
            actual_prices (np.array): Actual price data
            position_size (float): Size of the position to trade
            predictions: Array of predicted actions (0=SELL, 1=HOLD, 2=BUY) or probabilities
            actual_prices: Array of actual close prices
            position_size: Position size for each trade

        Returns:
            tuple: (total_pnl, win_rate, trade_history)
            tuple: (pnl, win_rate, trades) where:
                pnl is the total profit and loss
                win_rate is the ratio of winning trades
                trades is a list of trade dictionaries
        """
        if len(predictions) != len(actual_prices) - 1:
            logger.error("Predictions and prices length mismatch")
@@ -468,36 +441,85 @@ class DataInterface:
        win_rate = wins / trades if trades > 0 else 0.0
        return pnl, win_rate, trade_history
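The body of calculate_pnl is largely outside this hunk. A minimal sketch of the kind of long/short accounting the docstring describes might look like the following; this is an assumed illustration, not the code in this commit, and it expects len(predictions) == len(actual_prices) - 1 as the method's own length check does.

def simple_pnl(predictions, actual_prices, position_size=1.0):
    # Hypothetical PnL: 2=BUY profits when price rises, 0=SELL profits when it falls, 1=HOLD skips.
    pnl, wins, trades = 0.0, 0, 0
    for i, action in enumerate(predictions):
        if action == 1:  # HOLD: no trade
            continue
        change = actual_prices[i + 1] - actual_prices[i]
        trade_pnl = position_size * (change if action == 2 else -change)
        pnl += trade_pnl
        wins += trade_pnl > 0
        trades += 1
    win_rate = wins / trades if trades > 0 else 0.0
    return pnl, win_rate

print(simple_pnl([2, 1, 0], [100.0, 101.0, 101.5, 100.5]))  # (2.0, 1.0)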

    def prepare_training_data(self, refresh=False, refresh_interval=300):
    def get_future_prices(self, prices, n_candles=3):
        """
        Prepare training and validation data with optional refresh.
        Extract future prices for use in retrospective training.

        Args:
            refresh (bool): Whether to force refresh data
            refresh_interval (int): Minimum seconds between refreshes
            prices: Array of close prices
            n_candles: Number of future candles to predict

        Returns:
            tuple: (X_train, y_train, X_val, y_val, prices) numpy arrays
            numpy.ndarray: Array of future prices for each sample
        """
        if prices is None or len(prices) <= n_candles:
            logger.warning(f"Not enough price data for future prediction: {len(prices) if prices is not None else 0} prices")
            # Return zeros if not enough data
            return np.zeros((len(prices) if prices is not None else 0, 1))

        # For each price point i, get the price at i+n_candles
        future_prices = np.zeros((len(prices), 1))
        for i in range(len(prices) - n_candles):
            future_prices[i, 0] = prices[i + n_candles]

        # For the last n_candles positions, we don't have future data
        # We'll use the last known price as a placeholder
        for i in range(len(prices) - n_candles, len(prices)):
            future_prices[i, 0] = prices[-1]

        return future_prices
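For example (illustrative input): with six closes and n_candles=3, each row holds the close three candles ahead, and the final three rows fall back to the last known price, mirroring the loop above.

import numpy as np

prices = np.array([100.0, 101.0, 102.0, 103.0, 104.0, 105.0])  # hypothetical closes
n_candles = 3
future = np.full((len(prices), 1), prices[-1])   # placeholder: last known price
future[:-n_candles, 0] = prices[n_candles:]      # shift closes forward by n_candles
print(future.ravel())  # [103. 104. 105. 105. 105. 105.]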

    def prepare_training_data(self, refresh=False, refresh_interval=300):
        """
        Prepare data for training, including splitting into train/validation sets.

        Args:
            refresh (bool): Whether to refresh the data cache
            refresh_interval (int): Interval in seconds to refresh data

        Returns:
            tuple: (X_train, y_train, X_val, y_val, train_prices, val_prices)
        """
        current_time = datetime.now()
        if refresh or (current_time - getattr(self, 'last_refresh', datetime.min)).total_seconds() > refresh_interval:

        # Check if we should refresh the data
        if refresh or not hasattr(self, 'last_refresh_time') or \
                (current_time - self.last_refresh_time).total_seconds() > refresh_interval:
            logger.info("Refreshing training data...")
            for tf in self.timeframes:
                self.get_historical_data(timeframe=tf, n_candles=1000, use_cache=False)
            self.last_refresh = current_time

        # Get all data
            self.last_refresh_time = current_time
        else:
            # Use cached data
            if hasattr(self, 'cached_train_data'):
                return self.cached_train_data

        # Prepare input data
        X, y, _ = self.prepare_nn_input()
        if X is None:
            return None, None, None, None, None
            return None, None, None, None, None, None

        # Get price data for PnL calculation
        prices = self.dataframes[self.timeframes[0]]['close'].values

        # Split into train/validation (80/20)
        raw_prices = []
        for tf in self.timeframes:
            if tf in self.dataframes and self.dataframes[tf] is not None:
                # Get the close prices for the same period as X
                prices = self.dataframes[tf]['close'].values[-len(X):]
                if len(prices) == len(X):
                    raw_prices = prices
                    break

        if len(raw_prices) != len(X):
            raw_prices = np.zeros(len(X))  # Fallback if no prices available

        # Split data into training and validation sets (80/20)
        split_idx = int(len(X) * 0.8)
        return (X[:split_idx], y[:split_idx], X[split_idx:], y[split_idx:],
                prices[:split_idx], prices[split_idx:])
        X_train, X_val = X[:split_idx], X[split_idx:]
        y_train, y_val = y[:split_idx], y[split_idx:]
        train_prices, val_prices = raw_prices[:split_idx], raw_prices[split_idx:]

        # Cache the data
        self.cached_train_data = (X_train, y_train, X_val, y_val, train_prices, val_prices)

        return X_train, y_train, X_val, y_val, train_prices, val_prices
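A hedged usage sketch of the method above; the DataInterface constructor arguments shown here are assumptions (the diff does not show the constructor), and the shapes in the comments follow from prepare_nn_input's stacking of 5 OHLCV features per timeframe.

# Hypothetical usage -- constructor signature is assumed, not shown in this diff.
di = DataInterface(symbol='BTC/USDT', timeframes=['1m', '5m', '1h'])
result = di.prepare_training_data(refresh=True)
if result[0] is not None:
    X_train, y_train, X_val, y_val, train_prices, val_prices = result
    print(X_train.shape, y_train.shape)   # e.g. (samples, window, 5 * n_timeframes), (samples,)
    print(len(train_prices), len(val_prices))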

    def prepare_realtime_input(self, timeframe='1h', n_candles=30, window_size=20):
        """