dash and training wip
Changed files:

.vscode/launch.json (vendored): 4 changes
@@ -47,6 +47,9 @@
             "env": {
                 "PYTHONUNBUFFERED": "1",
                 "ENABLE_REALTIME_CHARTS": "1"
+            },
+            "linux": {
+                "python": "${workspaceFolder}/venv/bin/python"
             }
         },
         {
@@ -156,6 +159,7 @@
             "type": "python",
             "request": "launch",
             "program": "run_clean_dashboard.py",
+            "python": "${workspaceFolder}/venv/bin/python",
             "console": "integratedTerminal",
             "justMyCode": false,
             "env": {
@@ -1,104 +1,3 @@
 {
-  "decision": [
-    {
-      "checkpoint_id": "decision_20250704_082022",
-      "model_name": "decision",
-      "model_type": "decision_fusion",
-      "file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
-      "created_at": "2025-07-04T08:20:22.416087",
-      "file_size_mb": 0.06720924377441406,
-      "performance_score": 102.79971076963062,
-      "accuracy": null,
-      "loss": 2.8923120591883844e-06,
-      "val_accuracy": null,
-      "val_loss": null,
-      "reward": null,
-      "pnl": null,
-      "epoch": null,
-      "training_time_hours": null,
-      "total_parameters": null,
-      "wandb_run_id": null,
-      "wandb_artifact_name": null
-    },
-    {
-      "checkpoint_id": "decision_20250704_082021",
-      "model_name": "decision",
-      "model_type": "decision_fusion",
-      "file_path": "NN\\models\\saved\\decision\\decision_20250704_082021.pt",
-      "created_at": "2025-07-04T08:20:21.900854",
-      "file_size_mb": 0.06720924377441406,
-      "performance_score": 102.79970038321,
-      "accuracy": null,
-      "loss": 2.996176877014177e-06,
-      "val_accuracy": null,
-      "val_loss": null,
-      "reward": null,
-      "pnl": null,
-      "epoch": null,
-      "training_time_hours": null,
-      "total_parameters": null,
-      "wandb_run_id": null,
-      "wandb_artifact_name": null
-    },
-    {
-      "checkpoint_id": "decision_20250704_082022",
-      "model_name": "decision",
-      "model_type": "decision_fusion",
-      "file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
-      "created_at": "2025-07-04T08:20:22.294191",
-      "file_size_mb": 0.06720924377441406,
-      "performance_score": 102.79969219038436,
-      "accuracy": null,
-      "loss": 3.0781056310808756e-06,
-      "val_accuracy": null,
-      "val_loss": null,
-      "reward": null,
-      "pnl": null,
-      "epoch": null,
-      "training_time_hours": null,
-      "total_parameters": null,
-      "wandb_run_id": null,
-      "wandb_artifact_name": null
-    },
-    {
-      "checkpoint_id": "decision_20250704_134829",
-      "model_name": "decision",
-      "model_type": "decision_fusion",
-      "file_path": "NN\\models\\saved\\decision\\decision_20250704_134829.pt",
-      "created_at": "2025-07-04T13:48:29.903250",
-      "file_size_mb": 0.06720924377441406,
-      "performance_score": 102.79967532851693,
-      "accuracy": null,
-      "loss": 3.2467253719811344e-06,
-      "val_accuracy": null,
-      "val_loss": null,
-      "reward": null,
-      "pnl": null,
-      "epoch": null,
-      "training_time_hours": null,
-      "total_parameters": null,
-      "wandb_run_id": null,
-      "wandb_artifact_name": null
-    },
-    {
-      "checkpoint_id": "decision_20250704_214714",
-      "model_name": "decision",
-      "model_type": "decision_fusion",
-      "file_path": "NN\\models\\saved\\decision\\decision_20250704_214714.pt",
-      "created_at": "2025-07-04T21:47:14.427187",
-      "file_size_mb": 0.06720924377441406,
-      "performance_score": 102.79966325731509,
-      "accuracy": null,
-      "loss": 3.3674381887394134e-06,
-      "val_accuracy": null,
-      "val_loss": null,
-      "reward": null,
-      "pnl": null,
-      "epoch": null,
-      "training_time_hours": null,
-      "total_parameters": null,
-      "wandb_run_id": null,
-      "wandb_artifact_name": null
-    }
-  ]
+  "decision": []
 }
@@ -1969,7 +1969,17 @@ class EnhancedRealtimeTrainingSystem:
 
                     self.last_prediction_time[symbol] = int(current_time)
 
-                    logger.info(f"Forward DQN prediction: {symbol} action={['BUY','SELL','HOLD'][action]} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
+                    # Robust action labeling
+                    if action is None:
+                        action_label = 'HOLD'
+                    elif action == 0:
+                        action_label = 'SELL'
+                    elif action == 1:
+                        action_label = 'BUY'
+                    else:
+                        action_label = 'UNKNOWN'
+
+                    logger.info(f"Forward DQN prediction: {symbol} action={action_label} confidence={confidence:.2f} target={target_time.strftime('%H:%M:%S')}")
 
         except Exception as e:
             logger.error(f"Error generating forward DQN prediction: {e}")
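Note: the replaced log line indexed a hard-coded ['BUY','SELL','HOLD'] list with the raw action, which mislabels actions when the agent's encoding is 0=SELL, 1=BUY (as the new branches assume) and raises a TypeError when action is None. A minimal sketch of the same fix expressed as a lookup; the names below are illustrative, not from the codebase:

# Illustrative only; assumes the DQN action encoding used by the new code (0=SELL, 1=BUY).
ACTION_LABELS = {0: 'SELL', 1: 'BUY'}

def label_action(action):
    """Map a raw DQN action index to a readable label, tolerating None."""
    if action is None:
        return 'HOLD'
    return ACTION_LABELS.get(action, 'UNKNOWN')

# Old behavior for comparison: ['BUY','SELL','HOLD'][None] raises TypeError,
# and index 0 would have been labeled 'BUY' even though 0 now means SELL.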
@@ -81,8 +81,8 @@ orchestrator:
   # Model weights for decision combination
   cnn_weight: 0.7  # Weight for CNN predictions
   rl_weight: 0.3  # Weight for RL decisions
-  confidence_threshold: 0.15
-  confidence_threshold_close: 0.08
+  confidence_threshold: 0.45
+  confidence_threshold_close: 0.30
   decision_frequency: 30
 
   # Multi-symbol coordination
@@ -349,7 +349,8 @@ class TradingOrchestrator:
             try:
                 self.cob_rl_agent.load_model()  # This loads the state into the model
                 from utils.checkpoint_manager import load_best_checkpoint
-                result = load_best_checkpoint("cob_rl_model")
+                # Use consistent model name with checkpoint manager and get_model_states
+                result = load_best_checkpoint("cob_rl")
                 if result:
                     file_path, metadata = result
                     self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
@@ -1592,13 +1593,16 @@ class TradingOrchestrator:
                 logger.warning("EnhancedRealtimeTrainingSystem not available - training disabled")
                 self.training_enabled = False
                 return
 
-            # Initialize the enhanced training system
-            self.enhanced_training_system = EnhancedRealtimeTrainingSystem(
+            # Initialize unified training manager
+            from utils.training_integration import get_unified_training_manager
+            self.training_manager = get_unified_training_manager(
                 orchestrator=self,
                 data_provider=self.data_provider,
-                dashboard=None  # Will be set by dashboard when available
+                dashboard=None
             )
+            self.training_manager.initialize()
+            # Keep backward-compatible attribute
+            self.enhanced_training_system = getattr(self.training_manager, 'training_system', None)
 
             logger.info("Enhanced real-time training system initialized")
             logger.info(" - Real-time model training: ENABLED")
@@ -1614,11 +1618,11 @@ class TradingOrchestrator:
     def start_enhanced_training(self):
         """Start the enhanced real-time training system"""
         try:
-            if not self.training_enabled or not self.enhanced_training_system:
+            if not self.training_enabled or not getattr(self, 'training_manager', None):
                 logger.warning("Enhanced training system not available")
                 return False
 
-            self.enhanced_training_system.start_training()
+            self.training_manager.start()
             logger.info("Enhanced real-time training started")
             return True
 
@@ -1629,8 +1633,8 @@ class TradingOrchestrator:
     def stop_enhanced_training(self):
         """Stop the enhanced real-time training system"""
         try:
-            if self.enhanced_training_system:
-                self.enhanced_training_system.stop_training()
+            if getattr(self, 'training_manager', None):
+                self.training_manager.stop()
                 logger.info("Enhanced real-time training stopped")
                 return True
             return False
@@ -731,7 +731,8 @@ class RealtimeRLCOBTrader:
             with self.training_lock:
                 # Check if we have enough data for training
                 predictions = list(self.prediction_history[symbol])
-                if len(predictions) < 10:
+                # Train with fewer samples to kickstart learning
+                if len(predictions) < 6:
                     return
 
                 # Calculate rewards for recent predictions
@@ -739,11 +740,11 @@ class RealtimeRLCOBTrader:
 
                 # Filter predictions with calculated rewards
                 training_predictions = [p for p in predictions if p.reward is not None]
-                if len(training_predictions) < 5:
+                if len(training_predictions) < 3:
                     return
 
                 # Prepare training batch
-                batch_size = min(32, len(training_predictions))
+                batch_size = min(16, len(training_predictions))
                 batch_predictions = training_predictions[-batch_size:]
 
                 # Train model
enhanced_realtime_training.py (new file): 8 additions
@@ -0,0 +1,8 @@
+"""
+Shim module to expose EnhancedRealtimeTrainingSystem at project root.
+This avoids import issues when modules do `from enhanced_realtime_training import EnhancedRealtimeTrainingSystem`.
+"""
+
+from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem
+
+__all__ = ["EnhancedRealtimeTrainingSystem"]
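A quick sanity check for the shim (a sketch; it assumes the NN.training package is importable from the project root): both import paths should yield the same class object, since the shim only re-exports it.

# Sketch: the shim re-exports the class, so both paths resolve to the identical object.
from enhanced_realtime_training import EnhancedRealtimeTrainingSystem as via_shim
from NN.training.enhanced_realtime_training import EnhancedRealtimeTrainingSystem as via_package

assert via_shim is via_package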
@@ -10,7 +10,6 @@ tensorboard>=2.15.0
 scikit-learn>=1.3.0
 matplotlib>=3.7.0
 seaborn>=0.12.0
-wandb>=0.16.0
 
 ta>=0.11.0
 ccxt>=4.0.0
@@ -3,6 +3,34 @@
 Clean Trading Dashboard Runner with Enhanced Stability and Error Handling
 """
 
+# Ensure we run with the project's virtual environment Python
+try:
+    import os
+    import sys
+    from pathlib import Path
+    import platform
+
+    def _ensure_project_venv():
+        try:
+            project_root = Path(__file__).resolve().parent
+            if platform.system().lower().startswith('win'):
+                venv_python = project_root / 'venv' / 'Scripts' / 'python.exe'
+            else:
+                venv_python = project_root / 'venv' / 'bin' / 'python'
+
+            if venv_python.exists():
+                current = Path(sys.executable).resolve()
+                target = venv_python.resolve()
+                if current != target:
+                    os.execv(str(target), [str(target), *sys.argv])
+        except Exception:
+            # If anything goes wrong, continue with current interpreter
+            pass
+
+    _ensure_project_venv()
+except Exception:
+    pass
+
 import sys
 import logging
 import traceback
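Because os.execv replaces the current process image, nothing after a successful re-exec runs under the original interpreter. A small check, as a sketch assuming the venv layout used above, to confirm which interpreter the dashboard ended up on:

# Sketch: print the active interpreter; after a successful re-exec it should live under ./venv.
import sys
from pathlib import Path

venv_dir = Path(__file__).resolve().parent / 'venv'
print(f"interpreter={sys.executable} in_project_venv={str(venv_dir) in sys.executable}")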
@@ -14,11 +14,7 @@ from collections import defaultdict
 import torch
 import random
 
-try:
-    import wandb
-    WANDB_AVAILABLE = True
-except ImportError:
-    WANDB_AVAILABLE = False
+WANDB_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
@@ -58,13 +54,13 @@ class CheckpointManager:
                  base_checkpoint_dir: str = "NN/models/saved",
                  max_checkpoints_per_model: int = 5,
                  metadata_file: str = "checkpoint_metadata.json",
-                 enable_wandb: bool = True):
+                 enable_wandb: bool = False):
         self.base_dir = Path(base_checkpoint_dir)
         self.base_dir.mkdir(parents=True, exist_ok=True)
 
         self.max_checkpoints = max_checkpoints_per_model
         self.metadata_file = self.base_dir / metadata_file
-        self.enable_wandb = enable_wandb and WANDB_AVAILABLE
+        self.enable_wandb = False
 
         self.checkpoints: Dict[str, List[CheckpointMetadata]] = defaultdict(list)
         self._load_metadata()
@@ -115,10 +111,7 @@ class CheckpointManager:
                 total_parameters=training_metadata.get('total_parameters') if training_metadata else None
             )
 
-            if self.enable_wandb and wandb.run is not None:
-                artifact_name = self._upload_to_wandb(checkpoint_path, metadata)
-                metadata.wandb_run_id = wandb.run.id
-                metadata.wandb_artifact_name = artifact_name
+            # W&B disabled
 
             self.checkpoints[model_name].append(metadata)
             self._rotate_checkpoints(model_name)
@@ -273,18 +266,6 @@ class CheckpointManager:
                 logger.error(f"Error removing rotated checkpoint {checkpoint.checkpoint_id}: {e}")
 
     def _upload_to_wandb(self, file_path: Path, metadata: CheckpointMetadata) -> Optional[str]:
-        try:
-            if not self.enable_wandb or wandb.run is None:
-                return None
-
-            artifact_name = f"{metadata.model_name}_checkpoint"
-            artifact = wandb.Artifact(artifact_name, type="model")
-            artifact.add_file(str(file_path))
-            wandb.log_artifact(artifact)
-
-            return artifact_name
-        except Exception as e:
-            logger.error(f"Error uploading to W&B: {e}")
         return None
 
     def _load_metadata(self):
@@ -404,6 +385,56 @@ class CheckpointManager:
                     logger.debug(f"Found legacy model file in subdirectory: {candidate_path}")
                     return candidate_path
 
+        # Extended search: scan common project model directories for best checkpoints
+        try:
+            # Attempt to infer project root from base_dir (NN/models/saved -> root)
+            project_root = base_dir.resolve().parent.parent.parent
+        except Exception:
+            project_root = Path(".").resolve()
+        additional_dirs = [
+            project_root / "models",
+            project_root / "models" / "archive",
+            project_root / "models" / "backtest",
+        ]
+
+        def _match_legacy_name(candidate: Path, model: str) -> bool:
+            name = candidate.name.lower()
+            model_keys = {
+                'dqn_agent': ['dqn', 'agent', 'policy'],
+                'enhanced_cnn': ['cnn', 'optimized_short_term'],
+                'extrema_trainer': ['supervised', 'extrema'],
+                'cob_rl': ['cob', 'rl', 'policy'],
+                'decision': ['decision', 'transformer']
+            }.get(model, [model])
+            return any(k in name for k in model_keys)
+
+        candidates: List[Path] = []
+        for adir in additional_dirs:
+            if not adir.exists():
+                continue
+            try:
+                for pt in adir.rglob('*.pt'):
+                    # Prefer files that indicate "best" and match model hints
+                    lname = pt.name.lower()
+                    if 'best' in lname and _match_legacy_name(pt, model_name):
+                        candidates.append(pt)
+                # Do not add generic fallbacks to avoid mismatched model types
+            except Exception:
+                # Ignore directory traversal issues
+                pass
+
+        if candidates:
+            # Pick the most recently modified candidate
+            try:
+                best = max(candidates, key=lambda p: p.stat().st_mtime)
+                logger.debug(f"Found legacy model file in project models dir: {best}")
+                return best
+            except Exception:
+                # If stat fails, just return the first one deterministically
+                candidates.sort()
+                logger.debug(f"Found legacy model file in project models dir: {candidates[0]}")
+                return candidates[0]
+
         return None
 
     def _create_legacy_metadata(self, model_name: str, file_path: Path) -> CheckpointMetadata:
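To illustrate the filter added above: a candidate is accepted only if its filename contains 'best' and at least one model-specific keyword. A standalone sketch with hypothetical filenames (not real checkpoints):

# Hypothetical filenames; mirrors the 'best' + keyword filter from the extended search above.
from pathlib import Path

def match_legacy_name(candidate: Path, model: str) -> bool:
    name = candidate.name.lower()
    keys = {'dqn_agent': ['dqn', 'agent', 'policy']}.get(model, [model])
    return any(k in name for k in keys)

for p in [Path('models/best_dqn_agent_2025.pt'), Path('models/cnn_latest.pt')]:
    accepted = 'best' in p.name.lower() and match_legacy_name(p, 'dqn_agent')
    print(p.name, accepted)  # best_dqn_agent_2025.pt True, cnn_latest.pt False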
@@ -75,15 +75,18 @@ class RewardCalculator:
     def calculate_basic_reward(self, pnl, confidence):
         """Calculate basic training reward based on P&L and confidence"""
         try:
+            # Reward based on net PnL after fees and confidence alignment
             base_reward = pnl
-            if pnl < 0 and confidence > 0.7:
-                confidence_adjustment = -confidence * 2
-            elif pnl > 0 and confidence > 0.7:
-                confidence_adjustment = confidence * 1.5
+            # Stronger penalty for confident wrong decisions
+            if pnl < 0 and confidence >= 0.6:
+                confidence_adjustment = -confidence * 3.0
+            elif pnl > 0 and confidence >= 0.6:
+                confidence_adjustment = confidence * 1.0
             else:
-                confidence_adjustment = 0
+                confidence_adjustment = 0.0
             final_reward = base_reward + confidence_adjustment
-            normalized_reward = np.tanh(final_reward / 10.0)
+            # Reduce tanh compression so small PnL changes are not flattened
+            normalized_reward = np.tanh(final_reward / 2.5)
             logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
             return float(normalized_reward)
         except Exception as e:
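The divisor change (10.0 to 2.5) keeps rewards in (-1, 1) but compresses small PnL values far less; near zero the reward slope is roughly four times steeper. A quick numeric check (values rounded):

# Quick check of the new normalization; printed values are approximate.
import numpy as np

final_reward = 0.5                                     # small positive PnL, no confidence adjustment
print(round(float(np.tanh(final_reward / 10.0)), 3))   # ~0.05  (old scaling)
print(round(float(np.tanh(final_reward / 2.5)), 3))    # ~0.197 (new scaling)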
@@ -14,7 +14,7 @@ from .checkpoint_manager import get_checkpoint_manager, save_checkpoint, load_be
 logger = logging.getLogger(__name__)
 
 class TrainingIntegration:
-    def __init__(self, enable_wandb: bool = True):
+    def __init__(self, enable_wandb: bool = False):
         self.checkpoint_manager = get_checkpoint_manager()
         self.enable_wandb = enable_wandb
 
@@ -22,24 +22,8 @@ class TrainingIntegration:
             self._init_wandb()
 
     def _init_wandb(self):
-        try:
-            import wandb
-
-            if wandb.run is None:
-                wandb.init(
-                    project="gogo2-trading",
-                    name=f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
-                    config={
-                        "max_checkpoints_per_model": self.checkpoint_manager.max_checkpoints,
-                        "checkpoint_dir": str(self.checkpoint_manager.base_dir)
-                    }
-                )
-                logger.info(f"Initialized W&B run: {wandb.run.id}")
-
-        except ImportError:
-            logger.warning("W&B not available - checkpoint management will work without it")
-        except Exception as e:
-            logger.error(f"Error initializing W&B: {e}")
+        # Disabled by default to avoid CLI prompts
+        pass
 
     def save_cnn_checkpoint(self,
                             cnn_model,
@@ -64,19 +48,7 @@ class TrainingIntegration:
                 'total_parameters': self._count_parameters(cnn_model)
             }
 
-            if self.enable_wandb:
-                try:
-                    import wandb
-                    if wandb.run is not None:
-                        wandb.log({
-                            f"{model_name}/train_accuracy": train_accuracy,
-                            f"{model_name}/val_accuracy": val_accuracy,
-                            f"{model_name}/train_loss": train_loss,
-                            f"{model_name}/val_loss": val_loss,
-                            f"{model_name}/epoch": epoch
-                        })
-                except Exception as e:
-                    logger.warning(f"Error logging to W&B: {e}")
+            # W&B disabled
 
             metadata = save_checkpoint(
                 model=cnn_model,
@@ -120,22 +92,7 @@ class TrainingIntegration:
                 'total_parameters': self._count_parameters(rl_agent)
             }
 
-            if self.enable_wandb:
-                try:
-                    import wandb
-                    if wandb.run is not None:
-                        wandb.log({
-                            f"{model_name}/avg_reward": avg_reward,
-                            f"{model_name}/best_reward": best_reward,
-                            f"{model_name}/epsilon": epsilon,
-                            f"{model_name}/episode": episode
-                        })
-
-                        if total_pnl is not None:
-                            wandb.log({f"{model_name}/total_pnl": total_pnl})
-
-                except Exception as e:
-                    logger.warning(f"Error logging to W&B: {e}")
+            # W&B disabled
 
             metadata = save_checkpoint(
                 model=rl_agent,
@@ -202,3 +159,75 @@ def get_training_integration() -> TrainingIntegration:
     if _training_integration is None:
         _training_integration = TrainingIntegration()
     return _training_integration
+
+
+# ---------------- Unified Training Manager ----------------
+
+class UnifiedTrainingManager:
+    """Single entry point to manage all training in the system.
+
+    Coordinates EnhancedRealtimeTrainingSystem and provides start/stop/status.
+    """
+
+    def __init__(self, orchestrator, data_provider, dashboard=None):
+        self.orchestrator = orchestrator
+        self.data_provider = data_provider
+        self.dashboard = dashboard
+        self.training_system = None
+        self.started = False
+
+    def initialize(self) -> bool:
+        try:
+            # Import via project root shim to avoid path issues
+            from enhanced_realtime_training import EnhancedRealtimeTrainingSystem
+            self.training_system = EnhancedRealtimeTrainingSystem(
+                orchestrator=self.orchestrator,
+                data_provider=self.data_provider,
+                dashboard=self.dashboard
+            )
+            return True
+        except Exception as e:
+            logger.error(f"UnifiedTrainingManager: failed to initialize training system: {e}")
+            self.training_system = None
+            return False
+
+    def start(self) -> bool:
+        try:
+            if self.training_system is None:
+                if not self.initialize():
+                    return False
+            self.training_system.start_training()
+            self.started = True
+            logger.info("UnifiedTrainingManager: training started")
+            return True
+        except Exception as e:
+            logger.error(f"UnifiedTrainingManager: error starting training: {e}")
+            return False
+
+    def stop(self) -> bool:
+        try:
+            if self.training_system and self.started:
+                self.training_system.stop_training()
+                self.started = False
+                logger.info("UnifiedTrainingManager: training stopped")
+                return True
+            return False
+        except Exception as e:
+            logger.error(f"UnifiedTrainingManager: error stopping training: {e}")
+            return False
+
+    def get_stats(self) -> Dict[str, Any]:
+        try:
+            if self.training_system and hasattr(self.training_system, 'get_training_stats'):
+                return self.training_system.get_training_stats()
+            return {}
+        except Exception:
+            return {}
+
+
+_unified_training_manager = None
+
+
+def get_unified_training_manager(orchestrator=None, data_provider=None, dashboard=None) -> UnifiedTrainingManager:
+    global _unified_training_manager
+    if _unified_training_manager is None:
+        if orchestrator is None or data_provider is None:
+            raise ValueError("orchestrator and data_provider are required for first-time initialization")
+        _unified_training_manager = UnifiedTrainingManager(orchestrator, data_provider, dashboard)
+    return _unified_training_manager
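Usage sketch for the new singleton accessor (the stand-in objects below are placeholders; in this commit the orchestrator and dashboard supply the real instances): the first call must pass orchestrator and data_provider, and later calls return the same manager.

# Sketch: singleton contract of get_unified_training_manager (uses stand-in objects).
from utils.training_integration import get_unified_training_manager

class _Stub:  # placeholders standing in for the real orchestrator / data provider
    pass

manager = get_unified_training_manager(orchestrator=_Stub(), data_provider=_Stub())
assert manager is get_unified_training_manager()  # subsequent calls reuse the same instance
print(manager.get_stats())                         # {} until initialize()/start() wires a training system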
@@ -642,6 +642,35 @@ class CleanTradingDashboard:
                 logger.error(f"Error updating trades table: {e}")
                 return html.P(f"Error: {str(e)}", className="text-danger")
 
+        @self.app.callback(
+            Output('training-status', 'children'),
+            [Input('start-training-btn', 'n_clicks'),
+             Input('stop-training-btn', 'n_clicks')],
+            prevent_initial_call=True
+        )
+        def control_training(start_clicks, stop_clicks):
+            try:
+                from utils.training_integration import get_unified_training_manager
+                manager = get_unified_training_manager(
+                    orchestrator=self.orchestrator,
+                    data_provider=self.data_provider,
+                    dashboard=self
+                )
+                ctx = dash.callback_context
+                if not ctx.triggered:
+                    raise PreventUpdate
+                trigger_id = ctx.triggered[0]['prop_id'].split('.')[0]
+                if trigger_id == 'start-training-btn':
+                    ok = manager.start()
+                    return 'Running' if ok else 'Error'
+                elif trigger_id == 'stop-training-btn':
+                    ok = manager.stop()
+                    return 'Stopped' if ok else 'Error'
+                return 'Idle'
+            except Exception as e:
+                logger.error(f"Training control error: {e}")
+                return 'Error'
+
         @self.app.callback(
             [Output('eth-cob-content', 'children'),
              Output('btc-cob-content', 'children')],
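The callback above uses dash.callback_context and raises PreventUpdate; that presumes the dashboard module already has the corresponding imports at module level (not visible in this hunk, so treat this as an assumption):

# Assumed module-level imports for the training-control callback; both are standard Dash APIs.
import dash
from dash.exceptions import PreventUpdate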
@@ -5215,7 +5244,12 @@ class CleanTradingDashboard:
         """Start the Dash server"""
         try:
             logger.info(f"TRADING: Starting Clean Dashboard at http://{host}:{port}")
-            self.app.run(host=host, port=port, debug=debug)
+            # Run the Dash app normally; launch/activation is handled by the runner
+            if hasattr(self, 'app') and self.app is not None:
+                # Dash 3.x: use app.run
+                self.app.run(host=host, port=port, debug=debug)
+            else:
+                logger.error("Dash app is not initialized")
         except Exception as e:
             logger.error(f"Error starting dashboard server: {e}")
             raise
@@ -153,6 +153,29 @@ class DashboardLayoutManager:
                     tooltip={"placement": "bottom", "always_visible": False}
                 )
             ], className="mb-2"),
+            # Training Controls
+            html.Div([
+                html.Label([
+                    html.I(className="fas fa-play me-1"),
+                    "Training Controls"
+                ], className="form-label small mb-1"),
+                html.Div([
+                    html.Button([
+                        html.I(className="fas fa-play me-1"),
+                        "Start Training"
+                    ], id="start-training-btn", className="btn btn-success btn-sm me-2",
+                       style={"fontSize": "10px", "padding": "2px 8px"}),
+                    html.Button([
+                        html.I(className="fas fa-stop me-1"),
+                        "Stop Training"
+                    ], id="stop-training-btn", className="btn btn-danger btn-sm",
+                       style={"fontSize": "10px", "padding": "2px 8px"})
+                ], className="d-flex align-items-center mb-1"),
+                html.Div([
+                    html.Span("Training:", className="small me-1"),
+                    html.Span(id="training-status", children="Idle", className="badge bg-secondary small")
+                ])
+            ], className="mb-2"),
+
             # Entry Aggressiveness Control
             html.Div([