dash and training wip

This commit is contained in:
Dobromir Popov
2025-09-02 15:30:05 +03:00
parent 443e8e746f
commit 1b54438082
14 changed files with 270 additions and 197 deletions

View File

@@ -14,11 +14,7 @@ from collections import defaultdict
import torch
import random
try:
import wandb
WANDB_AVAILABLE = True
except ImportError:
WANDB_AVAILABLE = False
WANDB_AVAILABLE = False
logger = logging.getLogger(__name__)
@@ -58,13 +54,13 @@ class CheckpointManager:
base_checkpoint_dir: str = "NN/models/saved",
max_checkpoints_per_model: int = 5,
metadata_file: str = "checkpoint_metadata.json",
enable_wandb: bool = True):
enable_wandb: bool = False):
self.base_dir = Path(base_checkpoint_dir)
self.base_dir.mkdir(parents=True, exist_ok=True)
self.max_checkpoints = max_checkpoints_per_model
self.metadata_file = self.base_dir / metadata_file
self.enable_wandb = enable_wandb and WANDB_AVAILABLE
self.enable_wandb = False
self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
self._load_metadata()
@@ -115,10 +111,7 @@ class CheckpointManager:
total_parameters=training_metadata.get('total_parameters') if training_metadata else None
)
if self.enable_wandb and wandb.run is not None:
artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
metadata.wandb_run_id = wandb.run.id
metadata.wandb_artifact_name = artifact_name
# W&B disabled
self.checkpoints[model_name].append(metadata)
self._rotate_checkpoints(model_name)
@@ -273,19 +266,7 @@ class CheckpointManager:
logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
try:
if not self.enable_wandb or wandb.run is None:
return None
artifact_name = f"{metadata.model_name}_checkpoint"
artifact = wandb.Artifact(artifact_name, type="model")
artifact.add_file(str(file_path))
wandb.log_artifact(artifact)
return artifact_name
except Exception as e:
logger.error(f"Error uploading to W&B: {e}")
return None
return None
def _load_metadata(self):
try:
@@ -404,6 +385,56 @@ class CheckpointManager:
logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
return candidate_path
# Extended search: scan common project model directories for best checkpoints
try:
# Attempt to infer project root from base_dir (NN/models/saved -> root)
project_root = base_dir.resolve().parent.parent.parent
except Exception:
project_root = Path(".").resolve()
additional_dirs = [
project_root / "models",
project_root / "models" / "archive",
project_root / "models" / "backtest",
]
def _match_legacy_name(candidate: Path, model: str) -> bool:
name = candidate.name.lower()
model_keys = {
'dqn_agent': ['dqn', 'agent', 'policy'],
'enhanced_cnn': ['cnn', 'optimized_short_term'],
'extrema_trainer': ['supervised', 'extrema'],
'cob_rl': ['cob', 'rl', 'policy'],
'decision': ['decision', 'transformer']
}.get(model, [model])
return any(k in name for k in model_keys)
candidates: List[Path] = []
for adir in additional_dirs:
if not adir.exists():
continue
try:
for pt in adir.rglob('*.pt'):
# Prefer files that indicate "best" and match model hints
lname = pt.name.lower()
if 'best' in lname and _match_legacy_name(pt, model_name):
candidates.append(pt)
# Do not add generic fallbacks to avoid mismatched model types
except Exception:
# Ignore directory traversal issues
pass
if candidates:
# Pick the most recently modified candidate
try:
best = max(candidates, key=lambda p: p.stat().st_mtime)
logger.debug(f"Found legacy model file in project models dir: {best}")
return best
except Exception:
# If stat fails, just return the first one deterministically
candidates.sort()
logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
return candidates[0]
return None
def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:

View File

@@ -75,15 +75,18 @@ class RewardCalculator:
def calculate_basic_reward(self, pnl, confidence):
"""Calculate basic training reward based on P&L and confidence"""
try:
# Reward based on net PnL after fees and confidence alignment
base_reward = pnl
if pnl < 0 and confidence > 0.7:
confidence_adjustment = -confidence * 2
elif pnl > 0 and confidence > 0.7:
confidence_adjustment = confidence * 1.5
# Stronger penalty for confident wrong decisions
if pnl < 0 and confidence >= 0.6:
confidence_adjustment = -confidence * 3.0
elif pnl > 0 and confidence >= 0.6:
confidence_adjustment = confidence * 1.0
else:
confidence_adjustment = 0
confidence_adjustment = 0.0
final_reward = base_reward + confidence_adjustment
normalized_reward = np.tanh(final_reward / 10.0)
# Reduce tanh compression so small PnL changes are not flattened
normalized_reward = np.tanh(final_reward / 2.5)
logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
return float(normalized_reward)
except Exception as e:

View File

@@ -14,7 +14,7 @@ from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_be
logger = logging.getLogger(__name__)
class TrainingIntegration:
def __init__(self, enable_wandb: bool = True):
def __init__(self, enable_wandb: bool = False):
self.checkpoint_manager = get_checkpoint_manager()
self.enable_wandb = enable_wandb
@@ -22,24 +22,8 @@ class TrainingIntegration:
self._init_wandb()
def _init_wandb(self):
try:
import wandb
if wandb.run is None:
wandb.init(
project="gogo2-trading",
name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
config={
"max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints,
"checkpoint_dir": str(self.checkpoint_manager.base_dir)
}
)
logger.info(f"Initialized W&B run: {wandb.run.id}")
except ImportError:
logger.warning("W&B not available - checkpoint management will work without it")
except Exception as e:
logger.error(f"Error initializing W&B: {e}")
# Disabled by default to avoid CLI prompts
pass
def save_cnn_checkpoint(self,
cnn_model,
@@ -64,19 +48,7 @@ class TrainingIntegration:
'total_parameters': self._count_parameters(cnn_model)
}
if self.enable_wandb:
try:
import wandb
if wandb.run is not None:
wandb.log({
f"{model_name}/train_accuracy": train_accuracy,
f"{model_name}/val_accuracy": val_accuracy,
f"{model_name}/train_loss": train_loss,
f"{model_name}/val_loss": val_loss,
f"{model_name}/epoch": epoch
})
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
# W&B disabled
metadata = save_checkpoint(
model=cnn_model,
@@ -120,22 +92,7 @@ class TrainingIntegration:
'total_parameters': self._count_parameters(rl_agent)
}
if self.enable_wandb:
try:
import wandb
if wandb.run is not None:
wandb.log({
f"{model_name}/avg_reward": avg_reward,
f"{model_name}/best_reward": best_reward,
f"{model_name}/epsilon": epsilon,
f"{model_name}/episode": episode
})
if total_pnl is not None:
wandb.log({f"{model_name}/total_pnl": total_pnl})
except Exception as e:
logger.warning(f"Error logging to W&B: {e}")
# W&B disabled
metadata = save_checkpoint(
model=rl_agent,
@@ -202,3 +159,75 @@ def get_training_integration() -> TrainingIntegration:
if _training_integration is None:
_training_integration = TrainingIntegration()
return _training_integration
# ---------------- Unified Training Manager ----------------
class UnifiedTrainingManager:
"""Single entry point to manage all training in the system.
Coordinates EnhancedRealtimeTrainingSystem and provides start/stop/status.
"""
def __init__(self, orchestrator, data_provider, dashboard=None):
self.orchestrator = orchestrator
self.data_provider = data_provider
self.dashboard = dashboard
self.training_system = None
self.started = False
def initialize(self) -> bool:
try:
# Import via project root shim to avoid path issues
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
self.training_system = EnhancedRealtimeTrainingSystem(
orchestrator=self.orchestrator,
data_provider=self.data_provider,
dashboard=self.dashboard
)
return True
except Exception as e:
logger.error(f"UnifiedTrainingManager: failed to initialize training system: {e}")
self.training_system = None
return False
def start(self) -> bool:
try:
if self.training_system is None:
if not self.initialize():
return False
self.training_system.start_training()
self.started = True
logger.info("UnifiedTrainingManager: training started")
return True
except Exception as e:
logger.error(f"UnifiedTrainingManager: error starting training: {e}")
return False
def stop(self) -> bool:
try:
if self.training_system and self.started:
self.training_system.stop_training()
self.started = False
logger.info("UnifiedTrainingManager: training stopped")
return True
except Exception as e:
logger.error(f"UnifiedTrainingManager: error stopping training: {e}")
return False
def get_stats(self) -> Dict[str, Any]:
try:
if self.training_system and hasattr(self.training_system, 'get_training_stats'):
return self.training_system.get_training_stats()
return {}
except Exception:
return {}
_unified_training_manager = None
def get_unified_training_manager(orchestrator=None, data_provider=None, dashboard=None) -> UnifiedTrainingManager:
global _unified_training_manager
if _unified_training_manager is None:
if orchestrator is None or data_provider is None:
raise ValueError("orchestrator and data_provider are required for first-time initialization")
_unified_training_manager = UnifiedTrainingManager(orchestrator, data_provider, dashboard)
return _unified_training_manager