forced actions
This commit is contained in:
parent fbff9c37a2
commit 10ff22eb42
.gitignore (vendored): 1 line changed
@@ -31,3 +31,4 @@ app_data.db
 crypto/sol/.vs/*
 crypto/brian/models/best/*
 crypto/brian/models/last/*
+crypto/brian/live_chart.html
@@ -186,7 +186,6 @@ def save_checkpoint(model, optimizer, epoch, loss, last_dir=LAST_DIR, best_dir=B
     if len(best_models) < 10:
         add_to_best = True
     else:
-        # The worst saved checkpoint will have the highest loss.
         worst_loss, worst_file = max(best_models, key=lambda x: x[0])
         if loss < worst_loss:
             add_to_best = True
@@ -219,7 +218,7 @@ def load_best_checkpoint(model, best_dir=BEST_DIR):
 def update_live_html(candles, trade_history, epoch):
     """
     Generate a chart image with buy/sell markers and a dotted line between open and close,
-    then embed it in a simple HTML page that auto-refreshes.
+    then embed it in a simple HTML page that auto-refreshes every 10 seconds.
     """
     from io import BytesIO
     import base64
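The BytesIO and base64 imports hint at the approach; as a minimal sketch (illustrative assumptions, not this commit's update_live_html), a Matplotlib figure can be embedded in a page that reloads itself every 10 seconds like this:

# Sketch only: render a figure to an in-memory PNG, base64-encode it, and
# wrap it in HTML with a 10-second meta refresh. Function name and details
# are hypothetical.
from io import BytesIO
import base64
import matplotlib.pyplot as plt

def write_live_html(fig, path="live_chart.html"):
    buf = BytesIO()
    fig.savefig(buf, format="png")
    img_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    html = (
        "<html><head><meta http-equiv='refresh' content='10'></head>"
        f"<body><img src='data:image/png;base64,{img_b64}'/></body></html>"
    )
    with open(path, "w") as f:
        f.write(html)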
@@ -301,6 +300,25 @@ def update_live_chart(ax, candles, trade_history):
     ax.legend()
     ax.grid(True)
 
+# --- Forced Action Helper ---
+def get_forced_action(env):
+    """
+    Force at least one trade per episode:
+      - At the very first step, force a BUY (action 2) if no position is open.
+      - At the penultimate step, if a position is open, force a SELL (action 0).
+      - Otherwise, default to HOLD (action 1).
+    """
+    total = len(env)
+    if env.current_index == 0:
+        return 2  # BUY
+    elif env.current_index >= total - 2:
+        if env.position is not None:
+            return 0  # SELL
+        else:
+            return 1  # HOLD
+    else:
+        return 1  # HOLD
+
 # --- Backtest Environment ---
 class BacktestEnvironment:
     def __init__(self, candles_dict, base_tf, timeframes):
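A minimal sketch of how the new helper can drive a full backtest episode (illustrative only; it assumes the BacktestEnvironment interface visible in this diff: reset(), a six-tuple return from step(), current_index, position, and __len__):

# Sketch only: one episode driven by the forced-action policy.
def run_forced_episode(env):
    state = env.reset()
    trades = 0
    done = False
    while not done:
        action = get_forced_action(env)                 # 2 = BUY, 1 = HOLD, 0 = SELL
        _, reward, state, done, _, _ = env.step(action)
        if action in (0, 2):
            trades += 1                                 # count BUY/SELL signals
    return trades  # the forced BUY at step 0 guarantees at least one trade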
@@ -343,11 +361,11 @@ class BacktestEnvironment:
 
         # Action mapping: 0 -> SELL, 1 -> HOLD, 2 -> BUY.
         if self.position is None:
-            if action == 2:  # BUY: enter at next candle's open.
+            if action == 2:  # BUY signal: enter at next candle's open.
                 entry_price = next_candle["open"]
                 self.position = {"entry_price": entry_price, "entry_index": self.current_index}
         else:
-            if action == 0:  # SELL: exit at next candle's open.
+            if action == 0:  # SELL signal: exit at next candle's open.
                 exit_price = next_candle["open"]
                 reward = exit_price - self.position["entry_price"]
                 trade = {
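A toy calculation of the reward rule above (illustrative numbers; the diff applies no fees or position sizing):

# Toy example of the step() reward: raw price difference between the SELL
# fill and the BUY fill, both taken at the next candle's open.
entry_price = 100.0   # hypothetical BUY fill
exit_price = 102.5    # hypothetical SELL fill
reward = exit_price - entry_price
print(reward)         # 2.5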
@@ -376,11 +394,13 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
         total_loss = 0
         model.train()
         while True:
+            # Use forced action policy to guarantee at least one trade per episode
+            action = get_forced_action(env)
             state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
             timeframe_ids = torch.arange(state.shape[0]).to(device)
             pred_high, pred_low = model(state_tensor, timeframe_ids)
-            # Get targets from environment (dummy high/low from next candle)
-            _, _, next_state, done, actual_high, actual_low = env.step(None)
+            # Use the forced action in the environment step.
+            _, reward, next_state, done, actual_high, actual_low = env.step(action)
             target_high = torch.FloatTensor([actual_high]).to(device)
             target_low = torch.FloatTensor([actual_low]).to(device)
             high_loss = torch.abs(pred_high - target_high) * 2
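The hunk cuts off after the high-side loss; the rest of the per-step update it implies might look like the sketch below (the low-side loss, backward pass, and return value are assumptions that mirror the visible code, not lines from the commit):

import torch

# Sketch only: plausible completion of one training step. pred_high/pred_low
# come from the model, actual_high/actual_low from env.step().
def training_step(pred_high, pred_low, actual_high, actual_low, optimizer, device):
    target_high = torch.FloatTensor([actual_high]).to(device)
    target_low = torch.FloatTensor([actual_low]).to(device)
    high_loss = torch.abs(pred_high - target_high) * 2   # as in the diff
    low_loss = torch.abs(pred_low - target_low) * 2      # assumed symmetric
    loss = (high_loss + low_loss).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()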
@@ -398,7 +418,7 @@ def train_on_historical_data(env, model, device, args, start_epoch, optimizer, s
         epoch_loss = total_loss / len(env)
         print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")
         save_checkpoint(model, optimizer, epoch, total_loss)
-        # Update the live HTML file with the current epoch chart
+        # Update the live HTML file with the current epoch chart.
         update_live_html(base_candles, env.trade_history, epoch+1)
 
 # --- Live Plotting Functions (For live mode) ---
@@ -465,7 +485,7 @@ async def main():
             print("Loaded optimizer state from checkpoint.")
         else:
             print("No valid optimizer state found in checkpoint; starting fresh optimizer state.")
-        # Pass base candles from the base timeframe for HTML chart updates.
+        # Pass the base timeframe candles for the live HTML chart update.
        train_on_historical_data(env, model, device, args, start_epoch, optimizer, scheduler, candles_dict[base_tf])
 
     elif args.mode == 'live':
@@ -477,9 +497,11 @@ async def main():
         env = BacktestEnvironment(candles_dict, base_tf="1m", timeframes=timeframes)
         preview_thread = threading.Thread(target=live_preview_loop, args=(candles_dict["1m"], env), daemon=True)
         preview_thread.start()
-        print("Starting live trading loop. (Using random actions for simulation.)")
+        print("Starting live trading loop. (Using forced action policy for simulation.)")
+        # Here we use the forced-action policy as in training.
         while True:
-            state, reward, next_state, done, _, _ = env.step(random_action())
+            action = get_forced_action(env)
+            state, reward, next_state, done, _, _ = env.step(action)
             if done:
                 print("Reached end of simulated data, resetting environment.")
                 state = env.reset()
@@ -487,7 +509,7 @@ async def main():
     elif args.mode == 'inference':
         load_best_checkpoint(model)
         print("Running inference...")
-        # Your inference logic goes here.
+        # Here you can apply a similar forced-action policy or use a learned policy.
     else:
         print("Invalid mode specified.")
 
File diff suppressed because one or more lines are too long