forced actions
parent fbff9c37a2
commit 10ff22eb42

.gitignore (vendored)
@@ -31,3 +31,4 @@ app_data.db
 crypto/sol/.vs/*
 crypto/brian/models/best/*
 crypto/brian/models/last/*
+crypto/brian/live_chart.html

@@ -186,7 +186,6 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B
    if len(best_models) < 10:
        add_to_best = True
    else:
        # The worst saved checkpoint will have the highest loss.
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if loss < worst_loss:
            add_to_best = True
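For context on the rotation above: save_checkpoint keeps only the ten lowest-loss checkpoints and evicts the worst one when a better checkpoint arrives. Below is a minimal sketch of how such a best_models list could be built and pruned; the filename pattern and the os-based eviction are illustrative assumptions, not code from this commit.

import os

def scan_best_models(best_dir):
    # Sketch only: collect (loss, filename) pairs from files named like "..._loss<value>.pt".
    best_models = []
    for fname in os.listdir(best_dir):
        if fname.endswith(".pt") and "_loss" in fname:
            loss_val = float(fname.rsplit("_loss", 1)[1][:-3])  # strip trailing ".pt"
            best_models.append((loss_val, fname))
    return best_models

def maybe_evict_worst(best_models, new_loss, best_dir):
    # Mirrors the logic in the diff: with 10 checkpoints already saved, remove the
    # worst (highest-loss) file only if the new checkpoint beats it; the caller then
    # saves the new checkpoint into best_dir.
    if len(best_models) >= 10:
        worst_loss, worst_file = max(best_models, key=lambda x: x[0])
        if new_loss < worst_loss:
            os.remove(os.path.join(best_dir, worst_file))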

@@ -219,7 +218,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
def update_live_html(candles, trade_history, epoch):
    """
    Generate a chart image with buy/sell markers and a dotted line between open and close,
-    then embed it in a simple HTML page that auto-refreshes.
+    then embed it in a simple HTML page that auto-refreshes every 10 seconds.
    """
    from io import BytesIO
    import base64
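The docstring change pins the refresh interval at 10 seconds. A minimal sketch of the base64-embed-and-refresh approach is below, assuming a matplotlib Figure named fig; the helper name and output path are illustrative and not necessarily what update_live_html() does internally.

from io import BytesIO
import base64

def write_live_html(fig, epoch, path="live_chart.html"):
    # Sketch only: render the figure to an in-memory PNG, base64-encode it,
    # and wrap it in an HTML page that reloads itself every 10 seconds.
    buf = BytesIO()
    fig.savefig(buf, format="png")
    img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    html = (
        "<html><head><meta http-equiv=\"refresh\" content=\"10\"></head>"
        f"<body><h3>Epoch {epoch}</h3>"
        f"<img src=\"data:image/png;base64,{img_b64}\"/></body></html>"
    )
    with open(path, "w") as f:
        f.write(html)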

@@ -301,6 +300,25 @@ def update_live_chart(ax, candles, trade_history):
    ax.legend()
    ax.grid(True)

+# --- Forced Action Helper ---
+def get_forced_action(env):
+    """
+    Force at least one trade per episode:
+      - At the very first step, force a BUY (action 2) if no position is open.
+      - At the penultimate step, if a position is open, force a SELL (action 0).
+      - Otherwise, default to HOLD (action 1).
+    """
+    total = len(env)
+    if env.current_index == 0:
+        return 2  # BUY
+    elif env.current_index >= total - 2:
+        if env.position is not None:
+            return 0  # SELL
+        else:
+            return 1  # HOLD
+    else:
+        return 1  # HOLD
+
# --- Backtest Environment ---
class BacktestEnvironment:
    def __init__(self, candles_dict, base_tf, timeframes):
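A quick usage sketch of the new helper against a BacktestEnvironment, matching the step() and reset() signatures used later in this diff; the episode loop itself is an illustration, not part of the commit.

state = env.reset()
done = False
while not done:
    # 2 = BUY at the first index, 0 = SELL near the end if a position is open, else 1 = HOLD,
    # so every episode contains at least one entry and one exit.
    action = get_forced_action(env)
    state, reward, next_state, done, _, _ = env.step(action)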

@@ -343,11 +361,11 @@ class BacktestEnvironment:

        # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
        if self.position is None:
-            if action == 2:  # BUY: enter at next candle's open.
+            if action == 2:  # BUY signal: enter at next candle's open.
                entry_price = next_candle["open"]
                self.position = {"entry_price": entry_price, "entry_index": self.current_index}
        else:
-            if action == 0:  # SELL: exit at next candle's open.
+            if action == 0:  # SELL signal: exit at next candle's open.
                exit_price = next_candle["open"]
                reward = exit_price - self.position["entry_price"]
                trade = {
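For clarity on the fill convention above: both entries and exits are filled at the next candle's open, and the reward is the raw price difference for a single long unit. A tiny worked example with invented numbers:

entry_price = 100.0   # next candle's open after the BUY signal
exit_price = 103.5    # next candle's open after the SELL signal
reward = exit_price - entry_price   # 3.5 price units of PnL for one long unit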

@@ -376,11 +394,13 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
        total_loss = 0
        model.train()
        while True:
+            # Use forced action policy to guarantee at least one trade per episode
+            action = get_forced_action(env)
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Get targets from environment (dummy high/low from next candle)
-            _, _, next_state, done, actual_high, actual_low = env.step(None)
+            # Use the forced action in the environment step.
+            _, reward, next_state, done, actual_high, actual_low = env.step(action)
            target_high = torch.FloatTensor([actual_high]).to(device)
            target_low = torch.FloatTensor([actual_low]).to(device)
            high_loss = torch.abs(pred_high - target_high) * 2
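The hunk is cut off after the high-side loss term. A hedged sketch of how the rest of the step plausibly completes is below; the low-side weighting, reduction, optimizer update, and loop exit are assumptions, not shown in this diff.

            low_loss = torch.abs(pred_low - target_low) * 2   # assumed symmetric weighting
            loss = (high_loss + low_loss).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            state = next_state
            if done:
                break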

@@ -398,7 +418,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
        epoch_loss = total_loss / len(env)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
        save_checkpoint(model, optimizer, epoch, total_loss)
-        # Update the live HTML file with the current epoch chart
+        # Update the live HTML file with the current epoch chart.
        update_live_html(base_candles, env.trade_history, epoch+1)

# --- Live Plotting Functions (For live mode) ---

@@ -465,7 +485,7 @@ async def main():
            print("Loaded optimizer state from checkpoint.")
        else:
            print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
-        # Pass base candles from the base timeframe for HTML chart updates.
+        # Pass the base timeframe candles for the live HTML chart update.
        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf])

    elif args.mode == 'live':

@@ -477,9 +497,11 @@ async def main():
        env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
        preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
        preview_thread.start()
-        print("Starting live trading loop. (Using random actions for simulation.)")
+        print("Starting live trading loop. (Using forced action policy for simulation.)")
+        # Here we use the forced-action policy as in training.
        while True:
-            state, reward, next_state, done, _, _ = env.step(random_action())
+            action = get_forced_action(env)
+            state, reward, next_state, done, _, _ = env.step(action)
            if done:
                print("Reached end of simulated data, resetting environment.")
                state = env.reset()

@@ -487,7 +509,7 @@ async def main():
    elif args.mode == 'inference':
        load_best_checkpoint(model)
        print("Running inference...")
-        # Your inference logic goes here.
+        # Here you can apply a similar forced-action policy or use a learned policy.
    else:
        print("Invalid mode specified.")
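A hedged sketch of how the inference stub could be filled in, reusing the tensor setup from the training loop; it assumes an environment is constructed as in the other modes, and the single-shot prediction plus printout are illustrative, not part of this commit.

        model.eval()
        state = env.reset()
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            timeframe_ids = torch.arange(state.shape[0]).to(device)
            pred_high, pred_low = model(state_tensor, timeframe_ids)
        print(f"Predicted next high/low: {pred_high.item():.2f} / {pred_low.item():.2f}")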
File diff suppressed because one or more lines are too long