fix realtime training
This commit is contained in:
@@ -2256,10 +2256,18 @@ class RealTrainingAdapter:
|
||||
if not os.path.exists(checkpoint_dir):
|
||||
return
|
||||
|
||||
import time
|
||||
# Add small delay to ensure files are fully written
|
||||
time.sleep(0.5)
|
||||
|
||||
checkpoints = []
|
||||
for filename in os.listdir(checkpoint_dir):
|
||||
if filename.endswith('.pt'):
|
||||
filepath = os.path.join(checkpoint_dir, filename)
|
||||
# Check if file exists and is not being written
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(filepath, map_location='cpu')
|
||||
checkpoints.append({
|
||||
@@ -2276,10 +2284,12 @@ class RealTrainingAdapter:
|
||||
# Delete checkpoints beyond keep_best
|
||||
for checkpoint in checkpoints[keep_best:]:
|
||||
try:
|
||||
os.remove(checkpoint['path'])
|
||||
logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
|
||||
# Double-check file still exists before deleting
|
||||
if os.path.exists(checkpoint['path']):
|
||||
os.remove(checkpoint['path'])
|
||||
logger.debug(f"Removed old checkpoint: {checkpoint['path']}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not remove checkpoint: {e}")
|
||||
logger.debug(f"Could not remove checkpoint: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error cleaning up checkpoints: {e}")
|
||||
@@ -3541,6 +3551,13 @@ class RealTrainingAdapter:
|
||||
logger.warning(f"Per-candle training failed: Could not convert sample to batch")
|
||||
return
|
||||
|
||||
# Validate batch has required keys
|
||||
required_keys = ['actions', 'price_data_1m', 'price_data_1h', 'price_data_1d']
|
||||
missing_keys = [k for k in required_keys if k not in batch or batch[k] is None]
|
||||
if missing_keys:
|
||||
logger.warning(f"Per-candle training skipped: Missing required keys: {missing_keys}")
|
||||
return
|
||||
|
||||
# Train on this batch
|
||||
import torch
|
||||
with torch.enable_grad():
|
||||
@@ -3691,11 +3708,19 @@ class RealTrainingAdapter:
|
||||
return
|
||||
|
||||
import torch
|
||||
import time
|
||||
|
||||
# Add small delay to ensure files are fully written
|
||||
time.sleep(0.5)
|
||||
|
||||
checkpoints = []
|
||||
for filename in os.listdir(checkpoint_dir):
|
||||
if filename.endswith('.pt') and filename.startswith('realtime_'):
|
||||
filepath = os.path.join(checkpoint_dir, filename)
|
||||
# Check if file exists and is not being written
|
||||
if not os.path.exists(filepath):
|
||||
continue
|
||||
|
||||
try:
|
||||
checkpoint = torch.load(filepath, map_location='cpu')
|
||||
checkpoints.append({
|
||||
@@ -3714,8 +3739,10 @@ class RealTrainingAdapter:
|
||||
# Keep best N checkpoints
|
||||
for checkpoint in checkpoints[keep_best:]:
|
||||
try:
|
||||
os.remove(checkpoint['path'])
|
||||
logger.debug(f"Removed old realtime checkpoint: {os.path.basename(checkpoint['path'])}")
|
||||
# Double-check file still exists before deleting
|
||||
if os.path.exists(checkpoint['path']):
|
||||
os.remove(checkpoint['path'])
|
||||
logger.debug(f"Removed old realtime checkpoint: {os.path.basename(checkpoint['path'])}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not remove checkpoint: {e}")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user