forced actions

Dobromir Popov 2025-02-04 21:52:11 +02:00
parent fbff9c37a2
commit 10ff22eb42
3 changed files with 37 additions and 14 deletions

.gitignore

@@ -31,3 +31,4 @@ app_data.db
crypto/sol/.vs/*
crypto/brian/models/best/*
crypto/brian/models/last/*
+crypto/brian/live_chart.html


@@ -186,7 +186,6 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B
    if len(best_models) < 10:
        add_to_best = True
    else:
-        # The worst saved checkpoint will have the highest loss.
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if loss < worst_loss:
            add_to_best = True
@@ -219,7 +218,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
def update_live_html(candles, trade_history, epoch):
    """
    Generate a chart image with buy/sell markers and a dotted line between open and close,
-    then embed it in a simple HTML page that auto-refreshes.
+    then embed it in a simple HTML page that auto-refreshes every 10 seconds.
    """
    from io import BytesIO
    import base64
@@ -301,6 +300,25 @@ def update_live_chart(ax, candles, trade_history):
    ax.legend()
    ax.grid(True)

+# --- Forced Action Helper ---
+def get_forced_action(env):
+    """
+    Force at least one trade per episode:
+      - At the very first step, force a BUY (action 2) if no position is open.
+      - At the penultimate step, if a position is open, force a SELL (action 0).
+      - Otherwise, default to HOLD (action 1).
+    """
+    total = len(env)
+    if env.current_index == 0:
+        return 2  # BUY
+    elif env.current_index >= total - 2:
+        if env.position is not None:
+            return 0  # SELL
+        else:
+            return 1  # HOLD
+    else:
+        return 1  # HOLD
+
# --- Backtest Environment ---
class BacktestEnvironment:
    def __init__(self, candles_dict, base_tf, timeframes):
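
The forced-action schedule can be sanity-checked in isolation with the get_forced_action helper added above. DummyEnv below is a hypothetical stand-in that exposes only the current_index, position, and __len__ members the helper reads; in practice the real BacktestEnvironment takes its place.

class DummyEnv:
    # Hypothetical stand-in: only the members read by get_forced_action.
    def __init__(self, n_candles):
        self.n_candles = n_candles
        self.current_index = 0
        self.position = None

    def __len__(self):
        return self.n_candles

env = DummyEnv(n_candles=5)
for step in range(len(env) - 1):
    env.current_index = step
    action = get_forced_action(env)
    if action == 2:
        env.position = {"entry_index": step}  # BUY opens a position
    elif action == 0:
        env.position = None                   # SELL closes it
    print(step, action)
# Prints: 0 2 (forced BUY), 1 1 and 2 1 (HOLD), 3 0 (forced SELL at the penultimate step)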
@@ -343,11 +361,11 @@ class BacktestEnvironment:
        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
-            if action == 2:  # BUY: enter at next candle's open.
+            if action == 2:  # BUY signal: enter at next candle's open.
                entry_price = next_candle["open"]
                self.position = {"entry_price": entry_price, "entry_index": self.current_index}
        else:
-            if action == 0:  # SELL: exit at next candle's open.
+            if action == 0:  # SELL signal: exit at next candle's open.
                exit_price = next_candle["open"]
                reward = exit_price - self.position["entry_price"]
                trade = {
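
The reward in this branch is simply the raw price difference between the exit fill and the entry fill, both taken at the next candle's open. A compact, self-contained illustration follows; the trade-record fields beyond entry_price and entry_index are assumptions, since the dict is truncated in the diff.

position = {"entry_price": 100.0, "entry_index": 10}  # opened earlier by a BUY
next_candle = {"open": 103.5}

exit_price = next_candle["open"]
reward = exit_price - position["entry_price"]          # 3.5 price units of PnL
trade = {
    "entry_price": position["entry_price"],
    "exit_price": exit_price,   # field name assumed; truncated in the diff
    "pnl": reward,              # field name assumed; truncated in the diff
}
print(trade)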
@@ -376,11 +394,13 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
        total_loss = 0
        model.train()
        while True:
+            # Use forced action policy to guarantee at least one trade per episode.
+            action = get_forced_action(env)
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Get targets from environment (dummy high/low from next candle)
-            _, _, next_state, done, actual_high, actual_low = env.step(None)
+            # Use the forced action in the environment step.
+            _, reward, next_state, done, actual_high, actual_low = env.step(action)
            target_high = torch.FloatTensor([actual_high]).to(device)
            target_low = torch.FloatTensor([actual_low]).to(device)
            high_loss = torch.abs(pred_high - target_high) * 2
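
Putting the pieces of this hunk together, a single forced-action training step looks roughly like the sketch below. The model(state_tensor, timeframe_ids) signature and the *2 weighting on the high-side error come from the diff; the symmetric low-side weighting and the optimizer update are assumptions, since the hunk is cut off before them.

import torch

# One forced-action training step; model, env, device and optimizer come from
# the surrounding training code. The done/reset handling is omitted here.
state = env.reset()
action = get_forced_action(env)  # forced BUY / HOLD / SELL
state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
timeframe_ids = torch.arange(state.shape[0]).to(device)

pred_high, pred_low = model(state_tensor, timeframe_ids)
_, reward, next_state, done, actual_high, actual_low = env.step(action)

target_high = torch.FloatTensor([actual_high]).to(device)
target_low = torch.FloatTensor([actual_low]).to(device)

# High-side *2 weighting is from the diff; the symmetric low-side term is assumed.
loss = (torch.abs(pred_high - target_high) * 2 +
        torch.abs(pred_low - target_low) * 2).mean()

optimizer.zero_grad()
loss.backward()
optimizer.step()
state = next_state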
@@ -398,7 +418,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
        epoch_loss = total_loss / len(env)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
        save_checkpoint(model, optimizer, epoch, total_loss)
-        # Update the live HTML file with the current epoch chart
+        # Update the live HTML file with the current epoch chart.
        update_live_html(base_candles, env.trade_history, epoch+1)

# --- Live Plotting Functions (For live mode) ---
@@ -465,7 +485,7 @@ async def main():
            print("Loaded optimizer state from checkpoint.")
        else:
            print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
-        # Pass base candles from the base timeframe for HTML chart updates.
+        # Pass the base timeframe candles for the live HTML chart update.
        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf])

    elif args.mode == 'live':
@@ -477,9 +497,11 @@ async def main():
        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
        preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
        preview_thread.start()
-        print("Starting live trading loop. (Using random actions for simulation.)")
+        print("Starting live trading loop. (Using forced action policy for simulation.)")
+        # Here we use the forced-action policy as in training.
        while True:
-            state, reward, next_state, done, _, _ = env.step(random_action())
+            action = get_forced_action(env)
+            state, reward, next_state, done, _, _ = env.step(action)
            if done:
                print("Reached end of simulated data, resetting environment.")
                state = env.reset()
@@ -487,7 +509,7 @@ async def main():
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
-        # Your inference logic goes here.
+        # Here you can apply a similar forced-action policy or use a learned policy.
    else:
        print("Invalid mode specified.")

File diff suppressed because one or more lines are too long