filter anotations by symbol
This commit is contained in:
@@ -338,8 +338,17 @@ class AnnotationManager:
|
||||
logger.info(f"Generated lightweight test case: {test_case['test_case_id']} (OHLCV data will be fetched dynamically)")
|
||||
return test_case
|
||||
|
||||
def get_all_test_cases(self) -> List[Dict]:
|
||||
"""Load all test cases from disk"""
|
||||
def get_all_test_cases(self, symbol: Optional[str] = None) -> List[Dict]:
|
||||
"""
|
||||
Load all test cases from disk
|
||||
|
||||
Args:
|
||||
symbol: Optional symbol filter (e.g., 'ETH/USDT'). If provided, only returns
|
||||
test cases for that symbol. Critical for avoiding cross-symbol training.
|
||||
|
||||
Returns:
|
||||
List of test case dictionaries
|
||||
"""
|
||||
test_cases = []
|
||||
|
||||
if not self.test_cases_dir.exists():
|
||||
@@ -349,11 +358,22 @@ class AnnotationManager:
|
||||
try:
|
||||
with open(test_case_file, 'r') as f:
|
||||
test_case = json.load(f)
|
||||
|
||||
# CRITICAL: Filter by symbol to avoid training on wrong symbol
|
||||
if symbol:
|
||||
test_case_symbol = test_case.get('symbol', '')
|
||||
if test_case_symbol != symbol:
|
||||
logger.debug(f"Skipping {test_case_file.name}: symbol {test_case_symbol} != {symbol}")
|
||||
continue
|
||||
|
||||
test_cases.append(test_case)
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading test case {test_case_file}: {e}")
|
||||
|
||||
logger.info(f"Loaded {len(test_cases)} test cases from disk")
|
||||
if symbol:
|
||||
logger.info(f"Loaded {len(test_cases)} test cases for symbol {symbol}")
|
||||
else:
|
||||
logger.info(f"Loaded {len(test_cases)} test cases (all symbols)")
|
||||
return test_cases
|
||||
|
||||
def _calculate_holding_period(self, annotation: TradeAnnotation) -> float:
|
||||
|
||||
@@ -1740,13 +1740,17 @@ class RealTrainingAdapter:
|
||||
logger.info(f"Using orchestrator's TradingTransformerTrainer")
|
||||
logger.info(f" Trainer type: {type(trainer).__name__}")
|
||||
|
||||
# Import torch at function level (not inside try block)
|
||||
import torch
|
||||
import gc
|
||||
|
||||
# Load best checkpoint if available to continue training
|
||||
try:
|
||||
checkpoint_dir = "models/checkpoints/transformer"
|
||||
best_checkpoint_path = self._find_best_checkpoint(checkpoint_dir, metric='accuracy')
|
||||
|
||||
if best_checkpoint_path and os.path.exists(best_checkpoint_path):
|
||||
checkpoint = torch.load(best_checkpoint_path)
|
||||
checkpoint = torch.load(best_checkpoint_path, map_location='cpu')
|
||||
trainer.model.load_state_dict(checkpoint['model_state_dict'])
|
||||
trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
||||
trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
|
||||
|
||||
Reference in New Issue
Block a user