BIG CLEANUP
@@ -1,3 +0,0 @@
"""
Utils package for the multi-modal trading system
"""
@@ -1,232 +0,0 @@
"""
Async Task Manager - Handles async tasks with comprehensive error handling
Prevents silent failures in async operations
"""

import asyncio
import logging
import functools
import traceback
from typing import Any, Callable, Optional, Dict, List
from datetime import datetime

logger = logging.getLogger(__name__)

class AsyncTaskManager:
    """Manage async tasks with error handling and monitoring"""

    def __init__(self):
        self.active_tasks: Dict[str, asyncio.Task] = {}
        self.completed_tasks: List[Dict[str, Any]] = []
        self.failed_tasks: List[Dict[str, Any]] = []
        self.max_history = 100

    def create_task_with_error_handling(self,
                                        coro: Any,
                                        name: str,
                                        error_callback: Optional[Callable] = None,
                                        success_callback: Optional[Callable] = None) -> asyncio.Task:
        """
        Create an async task with comprehensive error handling

        Args:
            coro: Coroutine to run
            name: Task name for identification
            error_callback: Called on error with (name, exception)
            success_callback: Called on success with (name, result)
        """

        async def wrapped_coro():
            """Wrapper coroutine with error handling"""
            start_time = datetime.now()
            try:
                logger.debug(f"Starting async task: {name}")
                result = await coro

                # Log success
                duration = (datetime.now() - start_time).total_seconds()
                logger.debug(f"Async task '{name}' completed successfully in {duration:.2f}s")

                # Store completion info
                completion_info = {
                    'name': name,
                    'status': 'completed',
                    'start_time': start_time,
                    'end_time': datetime.now(),
                    'duration': duration,
                    'result': str(result)[:200] if result else None  # Truncate long results
                }
                self.completed_tasks.append(completion_info)

                # Trim history
                if len(self.completed_tasks) > self.max_history:
                    self.completed_tasks.pop(0)

                # Call success callback
                if success_callback:
                    try:
                        success_callback(name, result)
                    except Exception as cb_error:
                        logger.error(f"Error in success callback for task '{name}': {cb_error}")

                return result

            except asyncio.CancelledError:
                logger.info(f"Async task '{name}' was cancelled")
                raise

            except Exception as e:
                # Log error with full traceback
                duration = (datetime.now() - start_time).total_seconds()
                error_msg = f"Async task '{name}' failed after {duration:.2f}s: {e}"
                logger.error(error_msg)
                logger.error(f"Task '{name}' traceback: {traceback.format_exc()}")

                # Store failure info
                failure_info = {
                    'name': name,
                    'status': 'failed',
                    'start_time': start_time,
                    'end_time': datetime.now(),
                    'duration': duration,
                    'error': str(e),
                    'traceback': traceback.format_exc()
                }
                self.failed_tasks.append(failure_info)

                # Trim history
                if len(self.failed_tasks) > self.max_history:
                    self.failed_tasks.pop(0)

                # Call error callback
                if error_callback:
                    try:
                        error_callback(name, e)
                    except Exception as cb_error:
                        logger.error(f"Error in error callback for task '{name}': {cb_error}")

                # Don't re-raise to prevent task from crashing the event loop
                # Instead, return None to indicate failure
                return None

            finally:
                # Remove from active tasks
                if name in self.active_tasks:
                    del self.active_tasks[name]

        # Create and store task
        task = asyncio.create_task(wrapped_coro(), name=name)
        self.active_tasks[name] = task

        return task

    def cancel_task(self, name: str) -> bool:
        """Cancel a specific task"""
        if name in self.active_tasks:
            task = self.active_tasks[name]
            if not task.done():
                task.cancel()
                logger.info(f"Cancelled async task: {name}")
                return True
        return False

    def cancel_all_tasks(self):
        """Cancel all active tasks"""
        for name, task in list(self.active_tasks.items()):
            if not task.done():
                task.cancel()
                logger.info(f"Cancelled async task: {name}")

    def get_task_status(self) -> Dict[str, Any]:
        """Get status of all tasks"""
        active_count = len(self.active_tasks)
        completed_count = len(self.completed_tasks)
        failed_count = len(self.failed_tasks)

        # Get recent failures
        recent_failures = self.failed_tasks[-5:] if self.failed_tasks else []

        return {
            'active_tasks': active_count,
            'completed_tasks': completed_count,
            'failed_tasks': failed_count,
            'active_task_names': list(self.active_tasks.keys()),
            'recent_failures': [
                {
                    'name': f['name'],
                    'error': f['error'],
                    'duration': f['duration'],
                    'time': f['end_time'].strftime('%H:%M:%S')
                }
                for f in recent_failures
            ]
        }

    def get_failure_summary(self) -> Dict[str, Any]:
        """Get summary of task failures"""
        if not self.failed_tasks:
            return {'total_failures': 0, 'failure_patterns': {}}

        # Count failures by error type
        error_counts = {}
        for failure in self.failed_tasks:
            error_type = type(failure.get('error', 'Unknown')).__name__
            error_counts[error_type] = error_counts.get(error_type, 0) + 1

        # Recent failure rate
        recent_failures = [f for f in self.failed_tasks if
                           (datetime.now() - f['end_time']).total_seconds() < 3600]  # Last hour

        return {
            'total_failures': len(self.failed_tasks),
            'recent_failures_1h': len(recent_failures),
            'failure_patterns': error_counts,
            'most_common_error': max(error_counts.items(), key=lambda x: x[1])[0] if error_counts else None
        }

# Global instance
_task_manager = None

def get_async_task_manager() -> AsyncTaskManager:
    """Get global async task manager instance"""
    global _task_manager
    if _task_manager is None:
        _task_manager = AsyncTaskManager()
    return _task_manager

def create_safe_task(coro: Any,
                     name: str,
                     error_callback: Optional[Callable] = None,
                     success_callback: Optional[Callable] = None) -> asyncio.Task:
    """
    Create a safe async task with error handling

    Args:
        coro: Coroutine to run
        name: Task name for identification
        error_callback: Called on error with (name, exception)
        success_callback: Called on success with (name, result)
    """
    manager = get_async_task_manager()
    return manager.create_task_with_error_handling(coro, name, error_callback, success_callback)

def safe_async_wrapper(name: str,
                       error_callback: Optional[Callable] = None,
                       success_callback: Optional[Callable] = None):
    """
    Decorator for creating safe async functions

    Usage:
        @safe_async_wrapper("my_task")
        async def my_async_function():
            # Your async code here
            pass
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            coro = func(*args, **kwargs)
            task = create_safe_task(coro, name, error_callback, success_callback)
            return await task
        return wrapper
    return decorator
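
For reference, a minimal usage sketch of the removed task-manager API. The import path `utils.async_task_manager` is an assumption (the commit view does not show the file name), and the coroutine and callback below are illustrative only:

import asyncio
from utils.async_task_manager import create_safe_task, get_async_task_manager  # assumed module path

async def fetch_prices():
    # Illustrative coroutine standing in for real work
    await asyncio.sleep(0.1)
    return {"ETH/USDT": 2500.0}

def on_error(name, exc):
    print(f"task {name} failed: {exc}")

async def main():
    task = create_safe_task(fetch_prices(), "fetch_prices", error_callback=on_error)
    result = await task  # returns None instead of raising if the wrapped coroutine failed
    print(result)
    print(get_async_task_manager().get_task_status())

asyncio.run(main())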
@@ -1,164 +0,0 @@
#!/usr/bin/env python3
"""
TensorBoard Launcher with Automatic Port Management

This script launches TensorBoard with automatic port fallback if the preferred port is in use.
It also kills any stale debug instances that might be running.

Usage:
    python launch_tensorboard.py --logdir=path/to/logs --preferred-port=6007 --port-range=6000-7000
"""

import os
import sys
import subprocess
import argparse
import logging
from pathlib import Path

# Add project root to path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if project_root not in sys.path:
    sys.path.append(project_root)

from utils.port_manager import get_port_with_fallback, kill_stale_debug_instances

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('tensorboard_launcher')

def launch_tensorboard(logdir, port, host='localhost', open_browser=True):
    """
    Launch TensorBoard on the specified port

    Args:
        logdir (str): Path to log directory
        port (int): Port to use
        host (str): Host to bind to
        open_browser (bool): Whether to open browser automatically

    Returns:
        subprocess.Popen: Process object
    """
    cmd = [
        sys.executable, "-m", "tensorboard.main",
        f"--logdir={logdir}",
        f"--port={port}",
        f"--host={host}"
    ]

    # Add --load_fast=false to improve startup times
    cmd.append("--load_fast=false")

    # Control whether to open browser
    if not open_browser:
        cmd.append("--window_title=TensorBoard")

    logger.info(f"Launching TensorBoard: {' '.join(cmd)}")

    # Use subprocess.Popen to start TensorBoard without waiting for it to finish
    process = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        bufsize=1
    )

    # Log the first few lines of output to confirm it's starting correctly
    line_count = 0
    for line in process.stdout:
        logger.info(f"TensorBoard: {line.strip()}")
        line_count += 1

        # Check if TensorBoard has started successfully
        if "TensorBoard" in line and "http://" in line:
            url = line.strip().split("http://")[1].split(" ")[0]
            logger.info(f"TensorBoard available at: http://{url}")

        # Only log the first few lines
        if line_count >= 10:
            break

    # Continue reading output in background to prevent pipe from filling
    def read_output():
        for line in process.stdout:
            pass

    import threading
    threading.Thread(target=read_output, daemon=True).start()

    return process

def main():
    parser = argparse.ArgumentParser(description='Launch TensorBoard with automatic port management')
    parser.add_argument('--logdir', type=str, default='NN/models/saved/logs',
                        help='Directory containing TensorBoard event files')
    parser.add_argument('--preferred-port', type=int, default=6007,
                        help='Preferred port to use')
    parser.add_argument('--port-range', type=str, default='6000-7000',
                        help='Port range to try if preferred port is unavailable (format: min-max)')
    parser.add_argument('--host', type=str, default='localhost',
                        help='Host to bind to')
    parser.add_argument('--no-browser', action='store_true',
                        help='Do not open browser automatically')
    parser.add_argument('--kill-stale', action='store_true',
                        help='Kill stale debug instances before starting')

    args = parser.parse_args()

    # Parse port range
    try:
        min_port, max_port = map(int, args.port_range.split('-'))
    except ValueError:
        logger.error(f"Invalid port range format: {args.port_range}. Use format: min-max")
        return 1

    # Kill stale instances if requested
    if args.kill_stale:
        logger.info("Killing stale debug instances...")
        count, _ = kill_stale_debug_instances()
        logger.info(f"Killed {count} stale instances")

    # Get an available port
    try:
        port = get_port_with_fallback(args.preferred_port, min_port, max_port)
        logger.info(f"Using port {port} for TensorBoard")
    except RuntimeError as e:
        logger.error(str(e))
        return 1

    # Ensure log directory exists
    logdir = os.path.abspath(args.logdir)
    os.makedirs(logdir, exist_ok=True)

    # Launch TensorBoard
    process = launch_tensorboard(
        logdir=logdir,
        port=port,
        host=args.host,
        open_browser=not args.no_browser
    )

    # Wait for process to end (it shouldn't unless there's an error or user kills it)
    try:
        return_code = process.wait()
        if return_code != 0:
            logger.error(f"TensorBoard exited with code {return_code}")
            return return_code
    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down TensorBoard...")
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            logger.warning("TensorBoard didn't terminate gracefully, forcing kill")
            process.kill()

    return 0

if __name__ == "__main__":
    sys.exit(main())
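
The same helpers can also be driven programmatically instead of via the CLI shown in the docstring. A hedged sketch, assuming the script is importable as `launch_tensorboard` from the project root and that `utils.port_manager` is on sys.path:

from utils.port_manager import get_port_with_fallback
from launch_tensorboard import launch_tensorboard  # importing the CLI script as a module is an assumption

port = get_port_with_fallback(6007, 6000, 7000)  # preferred port, then fallback range
proc = launch_tensorboard(logdir="NN/models/saved/logs", port=port, open_browser=False)
print(f"TensorBoard running with PID {proc.pid} on port {port}")
proc.terminate()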
@@ -1,241 +0,0 @@
#!/usr/bin/env python
"""
Model utilities for robust saving and loading of PyTorch models
"""

import os
import logging
import torch
import shutil
import gc
import json
from typing import Any, Dict, Optional, Union

logger = logging.getLogger(__name__)

def robust_save(model: Any, path: str, include_optimizer: bool = True) -> bool:
    """
    Robust model saving with multiple fallback approaches

    Args:
        model: The model object to save (should have policy_net, target_net, optimizer, epsilon attributes)
        path: Path to save the model
        include_optimizer: Whether to include optimizer state in the save

    Returns:
        bool: True if successful, False otherwise
    """
    # Create directory if it doesn't exist
    os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)

    # Backup path in case the main save fails
    backup_path = f"{path}.backup"

    # Clean up GPU memory before saving
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    # Prepare checkpoint data
    checkpoint = {
        'policy_net': model.policy_net.state_dict(),
        'target_net': model.target_net.state_dict(),
        'epsilon': getattr(model, 'epsilon', 0.0),
        'state_size': getattr(model, 'state_size', None),
        'action_size': getattr(model, 'action_size', None),
        'hidden_size': getattr(model, 'hidden_size', None),
    }

    # Add optimizer state if requested and available
    if include_optimizer and hasattr(model, 'optimizer') and model.optimizer is not None:
        checkpoint['optimizer'] = model.optimizer.state_dict()

    # Attempt 1: Try with default settings in a separate file first
    try:
        logger.info(f"Saving model to {backup_path} (attempt 1)")
        torch.save(checkpoint, backup_path)
        logger.info(f"Successfully saved to {backup_path}")

        # If backup worked, copy to the actual path
        if os.path.exists(backup_path):
            shutil.copy(backup_path, path)
            logger.info(f"Copied backup to {path}")
            return True
    except Exception as e:
        logger.warning(f"First save attempt failed: {e}")

    # Attempt 2: Try with pickle protocol 2 (more compatible)
    try:
        logger.info(f"Saving model to {path} (attempt 2 - pickle protocol 2)")
        torch.save(checkpoint, path, pickle_protocol=2)
        logger.info(f"Successfully saved to {path} with pickle_protocol=2")
        return True
    except Exception as e:
        logger.warning(f"Second save attempt failed: {e}")

    # Attempt 3: Try without optimizer state (which can be large and cause issues)
    try:
        logger.info(f"Saving model to {path} (attempt 3 - without optimizer)")
        checkpoint_no_opt = {k: v for k, v in checkpoint.items() if k != 'optimizer'}
        torch.save(checkpoint_no_opt, path)
        logger.info(f"Successfully saved to {path} without optimizer state")
        return True
    except Exception as e:
        logger.warning(f"Third save attempt failed: {e}")

    # Attempt 4: Try with torch.jit.save instead
    try:
        logger.info(f"Saving model to {path} (attempt 4 - with jit.save)")
        # Save policy network using jit
        scripted_policy = torch.jit.script(model.policy_net)
        torch.jit.save(scripted_policy, f"{path}.policy.jit")

        # Save target network using jit
        scripted_target = torch.jit.script(model.target_net)
        torch.jit.save(scripted_target, f"{path}.target.jit")

        # Save parameters separately as JSON
        params = {
            'epsilon': float(getattr(model, 'epsilon', 0.0)),
            'state_size': int(getattr(model, 'state_size', 0)),
            'action_size': int(getattr(model, 'action_size', 0)),
            'hidden_size': int(getattr(model, 'hidden_size', 0))
        }
        with open(f"{path}.params.json", "w") as f:
            json.dump(params, f)

        logger.info(f"Successfully saved model components with jit.save")
        return True
    except Exception as e:
        logger.error(f"All save attempts failed: {e}")
        return False

def robust_load(model: Any, path: str, device: Optional[torch.device] = None) -> bool:
    """
    Robust model loading with fallback approaches

    Args:
        model: The model object to load into
        path: Path to load the model from
        device: Device to load the model on

    Returns:
        bool: True if successful, False otherwise
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Try regular PyTorch load first
    try:
        logger.info(f"Loading model from {path}")
        if os.path.exists(path):
            checkpoint = torch.load(path, map_location=device)

            # Load network states
            if 'policy_net' in checkpoint:
                model.policy_net.load_state_dict(checkpoint['policy_net'])
            if 'target_net' in checkpoint:
                model.target_net.load_state_dict(checkpoint['target_net'])

            # Load other attributes
            if 'epsilon' in checkpoint:
                model.epsilon = checkpoint['epsilon']
            if 'optimizer' in checkpoint and hasattr(model, 'optimizer') and model.optimizer is not None:
                try:
                    model.optimizer.load_state_dict(checkpoint['optimizer'])
                except Exception as e:
                    logger.warning(f"Failed to load optimizer state: {e}")

            logger.info("Successfully loaded model")
            return True
    except Exception as e:
        logger.warning(f"Regular load failed: {e}")

    # Try loading JIT saved components
    try:
        policy_path = f"{path}.policy.jit"
        target_path = f"{path}.target.jit"
        params_path = f"{path}.params.json"

        if all(os.path.exists(p) for p in [policy_path, target_path, params_path]):
            logger.info(f"Loading JIT model components")

            # Load JIT models (this is more complex and may need model reconstruction)
            # For now, just log that we found JIT files
            logger.info("Found JIT model files, but loading them requires special handling")
            with open(params_path, 'r') as f:
                params = json.load(f)
            logger.info(f"Model parameters: {params}")

            # Note: Actually loading JIT models would require recreating the model architecture
            # This is a placeholder for future implementation
            return False
    except Exception as e:
        logger.error(f"JIT load failed: {e}")

    logger.error(f"All load attempts failed for {path}")
    return False

def get_model_info(path: str) -> Dict[str, Any]:
    """
    Get information about a saved model

    Args:
        path: Path to the model file

    Returns:
        dict: Model information
    """
    info = {
        'exists': False,
        'size_bytes': 0,
        'has_optimizer': False,
        'parameters': {}
    }

    try:
        if os.path.exists(path):
            info['exists'] = True
            info['size_bytes'] = os.path.getsize(path)

            # Try to load and inspect
            checkpoint = torch.load(path, map_location='cpu')
            info['has_optimizer'] = 'optimizer' in checkpoint

            # Extract parameter info
            for key in ['epsilon', 'state_size', 'action_size', 'hidden_size']:
                if key in checkpoint:
                    info['parameters'][key] = checkpoint[key]
    except Exception as e:
        logger.warning(f"Failed to get model info for {path}: {e}")

    return info

def verify_save_load_cycle(model: Any, test_path: str) -> bool:
    """
    Test that a model can be saved and loaded correctly

    Args:
        model: Model to test
        test_path: Path for test file

    Returns:
        bool: True if save/load cycle successful
    """
    try:
        # Save the model
        if not robust_save(model, test_path):
            return False

        # Create a new model instance (this would need model creation logic)
        # For now, just verify the file exists and has content
        if os.path.exists(test_path) and os.path.getsize(test_path) > 0:
            logger.info("Save/load cycle verification successful")
            # Clean up test file
            os.remove(test_path)
            return True
        else:
            return False
    except Exception as e:
        logger.error(f"Save/load cycle verification failed: {e}")
        return False
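
A minimal sketch of the save/load cycle these helpers expect, assuming a DQN-style agent that exposes policy_net, target_net, optimizer and epsilon. The TinyAgent class and the utils.model_utils import path are stand-ins, not part of the original code:

import torch
import torch.nn as nn
from utils.model_utils import robust_save, robust_load, get_model_info  # assumed module path

class TinyAgent:
    """Stand-in agent exposing the attributes robust_save() reads."""
    def __init__(self, state_size=8, action_size=3, hidden_size=16):
        self.state_size, self.action_size, self.hidden_size = state_size, action_size, hidden_size
        self.policy_net = nn.Linear(state_size, action_size)
        self.target_net = nn.Linear(state_size, action_size)
        self.optimizer = torch.optim.Adam(self.policy_net.parameters())
        self.epsilon = 0.1

agent = TinyAgent()
if robust_save(agent, "checkpoints/test_agent.pt"):
    print(get_model_info("checkpoints/test_agent.pt"))
    robust_load(TinyAgent(), "checkpoints/test_agent.pt")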
@@ -1,238 +0,0 @@
#!/usr/bin/env python3
"""
Port Management Utility

This script provides utilities to:
1. Find available ports in a specified range
2. Kill stale processes running on specific ports
3. Kill all debug/training instances

Usage:
- As a module: import port_manager and use its functions
- Directly: python port_manager.py --kill-stale --min-port 6000 --max-port 7000
"""

import os
import sys
import socket
import argparse
import psutil
import logging
import time
import signal
from typing import List, Tuple, Optional, Set

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('port_manager')

# Define process names to look for when killing stale instances
DEBUG_PROCESS_KEYWORDS = [
    'tensorboard',
    'python train_',
    'realtime.py',
    'train_rl_with_realtime.py'
]

def is_port_in_use(port: int) -> bool:
    """
    Check if a port is in use

    Args:
        port (int): Port number to check

    Returns:
        bool: True if port is in use, False otherwise
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(('localhost', port)) == 0

def find_available_port(start_port: int, end_port: int) -> Optional[int]:
    """
    Find an available port in the specified range

    Args:
        start_port (int): Lower bound of port range
        end_port (int): Upper bound of port range

    Returns:
        Optional[int]: Available port number or None if no ports available
    """
    for port in range(start_port, end_port + 1):
        if not is_port_in_use(port):
            return port
    return None

def get_process_by_port(port: int) -> List[psutil.Process]:
    """
    Get processes using a specific port

    Args:
        port (int): Port number to check

    Returns:
        List[psutil.Process]: List of processes using the port
    """
    processes = []
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            for conn in proc.connections(kind='inet'):
                if conn.laddr.port == port:
                    processes.append(proc)
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass
    return processes

def kill_process_by_port(port: int) -> Tuple[int, List[str]]:
    """
    Kill processes using a specific port

    Args:
        port (int): Port number to check

    Returns:
        Tuple[int, List[str]]: Count of killed processes and their names
    """
    processes = get_process_by_port(port)
    killed = []

    for proc in processes:
        try:
            proc_name = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()
            logger.info(f"Terminating process {proc.pid}: {proc_name}")
            proc.terminate()
            killed.append(proc_name)
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass

    # Give processes time to terminate gracefully
    if processes:
        time.sleep(0.5)

    # Force kill any remaining processes
    for proc in processes:
        try:
            if proc.is_running():
                logger.info(f"Force killing process {proc.pid}")
                proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            pass

    return len(killed), killed

def kill_stale_debug_instances() -> Tuple[int, Set[str]]:
    """
    Kill all stale debug and training instances based on process names

    Returns:
        Tuple[int, Set[str]]: Count of killed processes and their names
    """
    killed_count = 0
    killed_procs = set()

    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()

            # Check if this is a debug/training process we should kill
            if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS):
                logger.info(f"Terminating stale process {proc.pid}: {cmd}")
                proc.terminate()
                killed_count += 1
                killed_procs.add(cmd)
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass

    # Give processes time to terminate
    if killed_count > 0:
        time.sleep(1)

    # Force kill any remaining processes
    for proc in psutil.process_iter(['pid', 'name', 'cmdline']):
        try:
            cmd = " ".join(proc.cmdline()) if proc.cmdline() else proc.name()

            if any(keyword in cmd for keyword in DEBUG_PROCESS_KEYWORDS) and proc.is_running():
                logger.info(f"Force killing stale process {proc.pid}")
                proc.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
            pass

    return killed_count, killed_procs

def get_port_with_fallback(preferred_port: int, min_port: int, max_port: int) -> int:
    """
    Try to use preferred port, fall back to any available port in range

    Args:
        preferred_port (int): Preferred port to use
        min_port (int): Minimum port in fallback range
        max_port (int): Maximum port in fallback range

    Returns:
        int: Available port number
    """
    # First try the preferred port
    if not is_port_in_use(preferred_port):
        return preferred_port

    # If preferred port is in use, try to free it
    logger.info(f"Preferred port {preferred_port} is in use, attempting to free it")
    kill_count, _ = kill_process_by_port(preferred_port)

    if kill_count > 0 and not is_port_in_use(preferred_port):
        logger.info(f"Successfully freed port {preferred_port}")
        return preferred_port

    # If we couldn't free the preferred port, find another available port
    logger.info(f"Looking for available port in range {min_port}-{max_port}")
    available_port = find_available_port(min_port, max_port)

    if available_port:
        logger.info(f"Using alternative port: {available_port}")
        return available_port
    else:
        # If no ports are available, force kill processes in the entire range
        logger.warning(f"No available ports in range {min_port}-{max_port}, freeing ports")
        for port in range(min_port, max_port + 1):
            kill_process_by_port(port)

        # Try again
        available_port = find_available_port(min_port, max_port)
        if available_port:
            logger.info(f"Using port {available_port} after freeing")
            return available_port
        else:
            logger.error(f"Could not find available port even after freeing range {min_port}-{max_port}")
            raise RuntimeError(f"No available ports in range {min_port}-{max_port}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Port management utility')
    parser.add_argument('--kill-stale', action='store_true', help='Kill all stale debug instances')
    parser.add_argument('--free-port', type=int, help='Free a specific port')
    parser.add_argument('--find-port', action='store_true', help='Find an available port')
    parser.add_argument('--min-port', type=int, default=6000, help='Minimum port in range')
    parser.add_argument('--max-port', type=int, default=7000, help='Maximum port in range')
    parser.add_argument('--preferred-port', type=int, help='Preferred port to use')

    args = parser.parse_args()

    if args.kill_stale:
        count, procs = kill_stale_debug_instances()
        logger.info(f"Killed {count} stale processes")
        for proc in procs:
            logger.info(f" - {proc}")

    if args.free_port:
        count, killed = kill_process_by_port(args.free_port)
        logger.info(f"Killed {count} processes using port {args.free_port}")
        for proc in killed:
            logger.info(f" - {proc}")

    if args.find_port or args.preferred_port:
        preferred = args.preferred_port if args.preferred_port else args.min_port
        port = get_port_with_fallback(preferred, args.min_port, args.max_port)
        print(port)  # Print only the port number for easy capture in scripts
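
As the docstring notes, the module is also meant to be imported. A short sketch using the functions above with the module's own default port range (utils.port_manager is the import path the launcher script uses):

from utils.port_manager import get_port_with_fallback, is_port_in_use, kill_stale_debug_instances

killed, names = kill_stale_debug_instances()  # clear leftovers from earlier debug runs
print(f"killed {killed} stale processes")

port = get_port_with_fallback(preferred_port=6007, min_port=6000, max_port=7000)
print(f"port {port} is free: {not is_port_in_use(port)}")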
@@ -1,340 +0,0 @@
"""
Process Supervisor - Handles process monitoring, restarts, and supervision
Prevents silent failures by monitoring process health and restarting on crashes
"""

import subprocess
import threading
import time
import logging
import signal
import os
import sys
from typing import Dict, Any, Optional, Callable, List
from datetime import datetime, timedelta
from pathlib import Path

logger = logging.getLogger(__name__)

class ProcessSupervisor:
    """Supervise processes and restart them on failure"""

    def __init__(self, max_restarts: int = 5, restart_delay: int = 10):
        """
        Initialize process supervisor

        Args:
            max_restarts: Maximum number of restarts before giving up
            restart_delay: Delay in seconds between restarts
        """
        self.max_restarts = max_restarts
        self.restart_delay = restart_delay

        self.processes: Dict[str, Dict[str, Any]] = {}
        self.monitoring = False
        self.monitor_thread = None

        # Callbacks
        self.process_started_callback: Optional[Callable] = None
        self.process_failed_callback: Optional[Callable] = None
        self.process_restarted_callback: Optional[Callable] = None

    def add_process(self, name: str, command: List[str],
                    working_dir: Optional[str] = None,
                    env: Optional[Dict[str, str]] = None,
                    auto_restart: bool = True):
        """
        Add a process to supervise

        Args:
            name: Process name
            command: Command to run as list
            working_dir: Working directory
            env: Environment variables
            auto_restart: Whether to auto-restart on failure
        """
        self.processes[name] = {
            'command': command,
            'working_dir': working_dir,
            'env': env,
            'auto_restart': auto_restart,
            'process': None,
            'restart_count': 0,
            'last_start': None,
            'last_failure': None,
            'status': 'stopped'
        }
        logger.info(f"Added process '{name}' to supervisor")

    def start_process(self, name: str) -> bool:
        """Start a specific process"""
        if name not in self.processes:
            logger.error(f"Process '{name}' not found")
            return False

        proc_info = self.processes[name]

        if proc_info['process'] and proc_info['process'].poll() is None:
            logger.warning(f"Process '{name}' is already running")
            return True

        try:
            # Prepare environment
            env = os.environ.copy()
            if proc_info['env']:
                env.update(proc_info['env'])

            # Start process
            process = subprocess.Popen(
                proc_info['command'],
                cwd=proc_info['working_dir'],
                env=env,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )

            proc_info['process'] = process
            proc_info['last_start'] = datetime.now()
            proc_info['status'] = 'running'

            logger.info(f"Started process '{name}' (PID: {process.pid})")

            if self.process_started_callback:
                try:
                    self.process_started_callback(name, process.pid)
                except Exception as e:
                    logger.error(f"Error in process started callback: {e}")

            return True

        except Exception as e:
            logger.error(f"Failed to start process '{name}': {e}")
            proc_info['status'] = 'failed'
            proc_info['last_failure'] = datetime.now()
            return False

    def stop_process(self, name: str, timeout: int = 10) -> bool:
        """Stop a specific process"""
        if name not in self.processes:
            logger.error(f"Process '{name}' not found")
            return False

        proc_info = self.processes[name]
        process = proc_info['process']

        if not process or process.poll() is not None:
            logger.info(f"Process '{name}' is not running")
            proc_info['status'] = 'stopped'
            return True

        try:
            # Try graceful shutdown first
            process.terminate()

            # Wait for graceful shutdown
            try:
                process.wait(timeout=timeout)
                logger.info(f"Process '{name}' terminated gracefully")
            except subprocess.TimeoutExpired:
                # Force kill if graceful shutdown fails
                logger.warning(f"Process '{name}' did not terminate gracefully, force killing")
                process.kill()
                process.wait()
                logger.info(f"Process '{name}' force killed")

            proc_info['status'] = 'stopped'
            return True

        except Exception as e:
            logger.error(f"Error stopping process '{name}': {e}")
            return False

    def restart_process(self, name: str) -> bool:
        """Restart a specific process"""
        logger.info(f"Restarting process '{name}'")

        if name not in self.processes:
            logger.error(f"Process '{name}' not found")
            return False

        proc_info = self.processes[name]

        # Stop if running
        if proc_info['process'] and proc_info['process'].poll() is None:
            self.stop_process(name)

        # Wait restart delay
        time.sleep(self.restart_delay)

        # Increment restart count
        proc_info['restart_count'] += 1

        # Check restart limit
        if proc_info['restart_count'] > self.max_restarts:
            logger.error(f"Process '{name}' exceeded max restarts ({self.max_restarts})")
            proc_info['status'] = 'failed_max_restarts'
            return False

        # Start process
        success = self.start_process(name)

        if success and self.process_restarted_callback:
            try:
                self.process_restarted_callback(name, proc_info['restart_count'])
            except Exception as e:
                logger.error(f"Error in process restarted callback: {e}")

        return success

    def start_monitoring(self):
        """Start process monitoring"""
        if self.monitoring:
            logger.warning("Process monitoring already started")
            return

        self.monitoring = True
        self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
        self.monitor_thread.start()
        logger.info("Process monitoring started")

    def stop_monitoring(self):
        """Stop process monitoring"""
        self.monitoring = False
        if self.monitor_thread:
            self.monitor_thread.join(timeout=5)
        logger.info("Process monitoring stopped")

    def _monitor_loop(self):
        """Main monitoring loop"""
        logger.info("Process monitoring loop started")

        while self.monitoring:
            try:
                for name, proc_info in self.processes.items():
                    self._check_process_health(name, proc_info)

                time.sleep(5)  # Check every 5 seconds

            except Exception as e:
                logger.error(f"Error in process monitoring loop: {e}")
                time.sleep(5)

        logger.info("Process monitoring loop stopped")

    def _check_process_health(self, name: str, proc_info: Dict[str, Any]):
        """Check health of a specific process"""
        process = proc_info['process']

        if not process:
            return

        # Check if process is still running
        return_code = process.poll()

        if return_code is not None:
            # Process has exited
            proc_info['status'] = 'exited'
            proc_info['last_failure'] = datetime.now()

            logger.warning(f"Process '{name}' exited with code {return_code}")

            # Read stdout/stderr for debugging
            try:
                stdout, stderr = process.communicate(timeout=1)
                if stdout:
                    logger.info(f"Process '{name}' stdout: {stdout[-500:]}")  # Last 500 chars
                if stderr:
                    logger.error(f"Process '{name}' stderr: {stderr[-500:]}")  # Last 500 chars
            except Exception as e:
                logger.warning(f"Could not read process output: {e}")

            if self.process_failed_callback:
                try:
                    self.process_failed_callback(name, return_code)
                except Exception as e:
                    logger.error(f"Error in process failed callback: {e}")

            # Auto-restart if enabled
            if proc_info['auto_restart'] and proc_info['restart_count'] < self.max_restarts:
                logger.info(f"Auto-restarting process '{name}'")
                threading.Thread(target=self.restart_process, args=(name,), daemon=True).start()

    def get_process_status(self, name: str) -> Optional[Dict[str, Any]]:
        """Get status of a specific process"""
        if name not in self.processes:
            return None

        proc_info = self.processes[name]
        process = proc_info['process']

        status = {
            'name': name,
            'status': proc_info['status'],
            'restart_count': proc_info['restart_count'],
            'last_start': proc_info['last_start'],
            'last_failure': proc_info['last_failure'],
            'auto_restart': proc_info['auto_restart'],
            'pid': process.pid if process and process.poll() is None else None,
            'running': process is not None and process.poll() is None
        }

        return status

    def get_all_status(self) -> Dict[str, Dict[str, Any]]:
        """Get status of all processes"""
        return {name: self.get_process_status(name) for name in self.processes}

    def set_callbacks(self,
                      process_started: Optional[Callable] = None,
                      process_failed: Optional[Callable] = None,
                      process_restarted: Optional[Callable] = None):
        """Set callback functions for process events"""
        self.process_started_callback = process_started
        self.process_failed_callback = process_failed
        self.process_restarted_callback = process_restarted

    def shutdown_all(self):
        """Shutdown all processes"""
        logger.info("Shutting down all supervised processes")

        for name in list(self.processes.keys()):
            self.stop_process(name)

        self.stop_monitoring()

# Global instance
_process_supervisor = None

def get_process_supervisor() -> ProcessSupervisor:
    """Get global process supervisor instance"""
    global _process_supervisor
    if _process_supervisor is None:
        _process_supervisor = ProcessSupervisor()
    return _process_supervisor

def create_supervised_dashboard_runner():
    """Create a supervised version of the dashboard runner"""
    supervisor = get_process_supervisor()

    # Add dashboard process
    supervisor.add_process(
        name="clean_dashboard",
        command=[sys.executable, "run_clean_dashboard.py"],
        working_dir=os.getcwd(),
        auto_restart=True
    )

    # Set up callbacks
    def on_process_failed(name: str, return_code: int):
        logger.error(f"Dashboard process failed with code {return_code}")

    def on_process_restarted(name: str, restart_count: int):
        logger.info(f"Dashboard restarted (attempt {restart_count})")

    supervisor.set_callbacks(
        process_failed=on_process_failed,
        process_restarted=on_process_restarted
    )

    return supervisor
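
A hedged usage sketch of the supervisor outside create_supervised_dashboard_runner(); the import path and the supervised command below are placeholders, not from the original code:

import sys
import time
from utils.process_supervisor import ProcessSupervisor  # assumed module path

supervisor = ProcessSupervisor(max_restarts=3, restart_delay=5)
supervisor.add_process(
    name="heartbeat",
    command=[sys.executable, "-c", "import time; time.sleep(60)"],  # placeholder workload
    auto_restart=True,
)
supervisor.set_callbacks(process_failed=lambda name, code: print(f"{name} exited with {code}"))
supervisor.start_process("heartbeat")
supervisor.start_monitoring()
time.sleep(10)
print(supervisor.get_all_status())
supervisor.shutdown_all()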
@@ -1,220 +0,0 @@
"""
Improved Reward Function for RL Trading Agent

This module provides a more sophisticated reward function for the RL trading agent
that incorporates realistic trading fees, penalties for excessive trading, and
rewards for successful holding of positions.
"""

import numpy as np
from datetime import datetime, timedelta
from collections import deque
import logging

logger = logging.getLogger(__name__)

class RewardCalculator:
    def __init__(self, base_fee_rate=0.001, reward_scaling=10.0, risk_aversion=0.1):
        self.base_fee_rate = base_fee_rate
        self.reward_scaling = reward_scaling
        self.risk_aversion = risk_aversion
        self.trade_pnls = []
        self.returns = []
        self.trade_timestamps = []
        self.frequency_threshold = 10  # Trades per minute threshold for penalty
        self.max_frequency_penalty = 0.05

    def record_pnl(self, pnl):
        """Record P&L for risk adjustment calculations"""
        self.trade_pnls.append(pnl)
        if len(self.trade_pnls) > 100:
            self.trade_pnls.pop(0)

    def record_trade(self, action):
        """Record trade action for frequency penalty calculations"""
        from time import time
        self.trade_timestamps.append(time())
        if len(self.trade_timestamps) > 100:
            self.trade_timestamps.pop(0)

    def _calculate_frequency_penalty(self):
        """Calculate penalty for high-frequency trading"""
        if len(self.trade_timestamps) < 2:
            return 0.0
        time_span = self.trade_timestamps[-1] - self.trade_timestamps[0]
        if time_span <= 0:
            return 0.0
        trades_per_minute = (len(self.trade_timestamps) / time_span) * 60
        if trades_per_minute > self.frequency_threshold:
            penalty = min(self.max_frequency_penalty, (trades_per_minute - self.frequency_threshold) * 0.001)
            return penalty
        return 0.0

    def _calculate_risk_adjustment(self, reward):
        """Adjust rewards based on risk (simple Sharpe ratio implementation)"""
        if len(self.trade_pnls) < 5:
            return reward
        pnl_array = np.array(self.trade_pnls)
        mean_return = np.mean(pnl_array)
        std_return = np.std(pnl_array)
        if std_return == 0:
            return reward
        sharpe = mean_return / std_return
        adjustment_factor = np.clip(1.0 + 0.5 * sharpe, 0.5, 2.0)
        return reward * adjustment_factor

    def _calculate_holding_reward(self, position_held_time, price_change):
        """Calculate reward for holding a position"""
        base_holding_reward = 0.0005 * (position_held_time / 60.0)
        if price_change > 0:
            return base_holding_reward * 2
        elif price_change < 0:
            return base_holding_reward * 0.5
        return base_holding_reward

    def calculate_basic_reward(self, pnl, confidence):
        """Calculate basic training reward based on P&L and confidence"""
        try:
            base_reward = pnl
            if pnl < 0 and confidence > 0.7:
                confidence_adjustment = -confidence * 2
            elif pnl > 0 and confidence > 0.7:
                confidence_adjustment = confidence * 1.5
            else:
                confidence_adjustment = 0
            final_reward = base_reward + confidence_adjustment
            normalized_reward = np.tanh(final_reward / 10.0)
            logger.debug(f"Basic reward calculation: P&L={pnl:.4f}, confidence={confidence:.2f}, reward={normalized_reward:.4f}")
            return float(normalized_reward)
        except Exception as e:
            logger.error(f"Error calculating basic reward: {e}")
            return 0.0

    def calculate_enhanced_reward(self, action, price_change, position_held_time=0, volatility=None, is_profitable=False, confidence=0.0, predicted_change=0.0, actual_change=0.0, current_pnl=0.0, symbol='UNKNOWN'):
        """Calculate enhanced reward for trading actions with shifted neutral point

        Neutral reward is shifted to require profits that exceed double the fees,
        which penalizes small profit trades and encourages holding for larger moves.
        Current PnL is given more weight in the decision-making process.
        """
        fee = self.base_fee_rate
        double_fee = fee * 4  # Double the fees (2x open + 2x close = 4x base fee)
        frequency_penalty = self._calculate_frequency_penalty()

        if action == 0:  # Buy
            # Penalize buying more when already in profit
            reward = -fee - frequency_penalty
            if current_pnl > 0:
                # Reduce incentive to close profitable positions
                reward -= current_pnl * 0.2
        elif action == 1:  # Sell
            profit_pct = price_change

            # Shift neutral point - require profit > double fees to be considered positive
            net_profit = profit_pct - double_fee

            # Scale reward based on profit size
            if net_profit > 0:
                # Exponential reward for larger profits
                reward = (net_profit ** 1.5) * self.reward_scaling
            else:
                # Linear penalty for losses
                reward = net_profit * self.reward_scaling

            reward -= frequency_penalty
            self.record_pnl(net_profit)

            # Add extra penalty for very small profits (less than 3x fees)
            if 0 < profit_pct < (fee * 6):
                reward -= 0.5  # Discourage tiny profit-taking
        else:  # Hold
            if is_profitable:
                # Increase reward for holding profitable positions
                profit_factor = min(5.0, current_pnl * 20)  # Cap at 5x
                reward = self._calculate_holding_reward(position_held_time, price_change) * (1.0 + profit_factor)

                # Add bonus for holding through volatility when profitable
                if volatility is not None and volatility > 0.001:
                    reward += 0.1 * volatility * 100
            else:
                # Small penalty for holding losing positions
                loss_factor = min(1.0, abs(current_pnl) * 10)
                reward = -0.0001 * (1.0 + loss_factor)

                # But reduce penalty for very recent positions (give them time)
                if position_held_time < 30:  # Less than 30 seconds
                    reward *= 0.5

        # Prediction accuracy reward component
        if action in [0, 1] and predicted_change != 0:
            if (action == 0 and actual_change > 0) or (action == 1 and actual_change < 0):
                reward += abs(actual_change) * 5.0
            else:
                reward -= abs(predicted_change) * 2.0

        # Increase weight of current PnL in decision making (3x more than before)
        reward += current_pnl * 0.3

        # Volatility penalty
        if volatility is not None:
            reward -= abs(volatility) * 100

        # Risk adjustment
        if self.risk_aversion > 0 and len(self.returns) > 1:
            returns_std = np.std(self.returns)
            reward -= returns_std * self.risk_aversion

        self.record_trade(action)
        return reward

    def calculate_prediction_reward(self, symbol, predicted_direction, actual_direction, confidence, predicted_change, actual_change, current_pnl=0.0, position_duration=0.0):
        """Calculate reward for prediction accuracy"""
        reward = 0.0
        if predicted_direction == actual_direction:
            reward += 1.0 * confidence
        else:
            reward -= 0.5
        if predicted_direction == actual_direction and abs(predicted_change) > 0.001:
            reward += abs(actual_change) * 5.0
        if predicted_direction != actual_direction and abs(predicted_change) > 0.001:
            reward -= abs(predicted_change) * 2.0
        reward += current_pnl * 0.1
        # Dynamic adjustment based on recent PnL (loss cutting incentive)
        if hasattr(self, 'pnl_history') and symbol in self.pnl_history and self.pnl_history[symbol]:
            latest_pnl_entry = self.pnl_history[symbol][-1]
            latest_pnl_value = latest_pnl_entry.get('pnl', 0.0) if isinstance(latest_pnl_entry, dict) else 0.0
            if latest_pnl_value < 0 and position_duration > 60:
                reward -= (abs(latest_pnl_value) * 0.2)
            pnl_values = [entry.get('pnl', 0.0) for entry in self.pnl_history[symbol] if isinstance(entry, dict)]
            best_pnl = max(pnl_values) if pnl_values else 0.0
            if best_pnl < 0.0:
                reward -= 0.1
        return reward

# Example usage:
if __name__ == "__main__":
    # Create calculator instance
    reward_calc = RewardCalculator()

    # Example reward for a buy action
    buy_reward = reward_calc.calculate_enhanced_reward(action=0, price_change=0)
    print(f"Buy action reward: {buy_reward:.5f}")

    # Record a trade for frequency tracking
    reward_calc.record_trade(0)

    # Wait a bit and make another trade to test frequency penalty
    import time
    time.sleep(0.1)

    # Example reward for a sell action with profit
    sell_reward = reward_calc.calculate_enhanced_reward(action=1, price_change=0.015, position_held_time=60)
    print(f"Sell action reward (with profit): {sell_reward:.5f}")

    # Example reward for a hold action on profitable position
    hold_reward = reward_calc.calculate_enhanced_reward(action=2, price_change=0.01, position_held_time=30, is_profitable=True)
    print(f"Hold action reward (profitable): {hold_reward:.5f}")

    # Example reward for a hold action on unprofitable position
    hold_reward_neg = reward_calc.calculate_enhanced_reward(action=2, price_change=-0.01, position_held_time=30, is_profitable=False)
    print(f"Hold action reward (unprofitable): {hold_reward_neg:.5f}")
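
To make the shifted neutral point concrete: with the default base_fee_rate of 0.001, a sell only scores as a net win once price_change exceeds double_fee = 0.004. A worked check of the sell branch for the 0.015 move used above, ignoring the frequency penalty and the later PnL/volatility terms:

fee, scaling = 0.001, 10.0
net_profit = 0.015 - fee * 4           # price_change minus "double the fees"
print((net_profit ** 1.5) * scaling)   # ~0.0115, the profit term before penalties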
@@ -1,288 +0,0 @@
|
||||
"""
|
||||
System Resource Monitor - Prevents resource exhaustion and silent failures
|
||||
Monitors memory, CPU, and disk usage to prevent system crashes
|
||||
"""
|
||||
|
||||
import psutil
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
import gc
|
||||
import os
|
||||
from typing import Dict, Any, Optional, Callable
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SystemResourceMonitor:
|
||||
"""Monitor system resources and prevent exhaustion"""
|
||||
|
||||
def __init__(self,
|
||||
memory_threshold_mb: int = 7000, # 7GB threshold for 8GB system
|
||||
cpu_threshold_percent: float = 90.0,
|
||||
disk_threshold_percent: float = 95.0,
|
||||
check_interval_seconds: int = 30):
|
||||
"""
|
||||
Initialize system resource monitor
|
||||
|
||||
Args:
|
||||
memory_threshold_mb: Memory threshold in MB before cleanup
|
||||
cpu_threshold_percent: CPU threshold percentage before warning
|
||||
disk_threshold_percent: Disk usage threshold before warning
|
||||
check_interval_seconds: How often to check resources
|
||||
"""
|
||||
self.memory_threshold_mb = memory_threshold_mb
|
||||
self.cpu_threshold_percent = cpu_threshold_percent
|
||||
self.disk_threshold_percent = disk_threshold_percent
|
||||
self.check_interval = check_interval_seconds
|
||||
|
||||
self.monitoring = False
|
||||
self.monitor_thread = None
|
||||
|
||||
# Callbacks for resource events
|
||||
self.memory_warning_callback: Optional[Callable] = None
|
||||
self.cpu_warning_callback: Optional[Callable] = None
|
||||
self.disk_warning_callback: Optional[Callable] = None
|
||||
self.cleanup_callback: Optional[Callable] = None
|
||||
|
||||
# Resource history for trending
|
||||
self.resource_history = []
|
||||
self.max_history_entries = 100
|
||||
|
||||
# Last warning times to prevent spam
|
||||
self.last_memory_warning = datetime.min
|
||||
self.last_cpu_warning = datetime.min
|
||||
self.last_disk_warning = datetime.min
|
||||
self.warning_cooldown = timedelta(minutes=5)
|
||||
|
||||
def start_monitoring(self):
|
||||
"""Start resource monitoring in background thread"""
|
||||
if self.monitoring:
|
||||
logger.warning("Resource monitoring already started")
|
||||
return
|
||||
|
||||
self.monitoring = True
|
||||
self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
|
||||
self.monitor_thread.start()
|
||||
logger.info(f"System resource monitoring started (memory threshold: {self.memory_threshold_mb}MB)")
|
||||
|
||||
def stop_monitoring(self):
|
||||
"""Stop resource monitoring"""
|
||||
self.monitoring = False
|
||||
if self.monitor_thread:
|
||||
self.monitor_thread.join(timeout=5)
|
||||
logger.info("System resource monitoring stopped")
|
||||
|
||||
def set_callbacks(self,
|
||||
memory_warning: Optional[Callable] = None,
|
||||
cpu_warning: Optional[Callable] = None,
|
||||
disk_warning: Optional[Callable] = None,
|
||||
cleanup: Optional[Callable] = None):
|
||||
"""Set callback functions for resource events"""
|
||||
self.memory_warning_callback = memory_warning
|
||||
self.cpu_warning_callback = cpu_warning
|
||||
self.disk_warning_callback = disk_warning
|
||||
self.cleanup_callback = cleanup
|
||||
|
||||
def get_current_usage(self) -> Dict[str, Any]:
|
||||
"""Get current system resource usage"""
|
||||
try:
|
||||
# Memory usage
|
||||
memory = psutil.virtual_memory()
|
||||
memory_mb = memory.used / (1024 * 1024)
|
||||
memory_percent = memory.percent
|
||||
|
||||
# CPU usage
|
||||
cpu_percent = psutil.cpu_percent(interval=1)
|
||||
|
||||
# Disk usage (current directory)
|
||||
disk = psutil.disk_usage('.')
|
||||
disk_percent = (disk.used / disk.total) * 100
|
||||
|
||||
# Process-specific info
|
||||
process = psutil.Process()
|
||||
process_memory_mb = process.memory_info().rss / (1024 * 1024)
|
||||
|
||||
return {
|
||||
'timestamp': datetime.now(),
|
||||
'memory': {
|
||||
'total_mb': memory.total / (1024 * 1024),
|
||||
'used_mb': memory_mb,
|
||||
'percent': memory_percent,
|
||||
'available_mb': memory.available / (1024 * 1024)
|
||||
},
|
||||
'process_memory_mb': process_memory_mb,
|
||||
'cpu_percent': cpu_percent,
|
||||
'disk': {
|
||||
'total_gb': disk.total / (1024 * 1024 * 1024),
|
||||
'used_gb': disk.used / (1024 * 1024 * 1024),
|
||||
'percent': disk_percent
|
||||
}
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting system usage: {e}")
|
||||
return {}
|
||||
|
||||
    def _monitor_loop(self):
        """Main monitoring loop"""
        logger.info("Resource monitoring loop started")

        samples_collected = 0
        while self.monitoring:
            try:
                usage = self.get_current_usage()
                if not usage:
                    time.sleep(self.check_interval)
                    continue

                # Store in history
                self.resource_history.append(usage)
                if len(self.resource_history) > self.max_history_entries:
                    self.resource_history.pop(0)

                # Check thresholds
                self._check_memory_threshold(usage)
                self._check_cpu_threshold(usage)
                self._check_disk_threshold(usage)

                # Log periodic status (every 10 minutes). A dedicated counter is used
                # because len(self.resource_history) stops changing once the history
                # has been trimmed to max_history_entries.
                samples_collected += 1
                if samples_collected % 20 == 0:  # 20 * 30s = 10 minutes
                    self._log_resource_status(usage)

            except Exception as e:
                logger.error(f"Error in resource monitoring loop: {e}")

            time.sleep(self.check_interval)

        logger.info("Resource monitoring loop stopped")

    def _check_memory_threshold(self, usage: Dict[str, Any]):
        """Check memory usage threshold"""
        memory_mb = usage.get('memory', {}).get('used_mb', 0)

        if memory_mb > self.memory_threshold_mb:
            now = datetime.now()
            if now - self.last_memory_warning > self.warning_cooldown:
                logger.warning(f"HIGH MEMORY USAGE: {memory_mb:.1f}MB / {self.memory_threshold_mb}MB threshold")
                self.last_memory_warning = now

                # Trigger cleanup
                self._trigger_memory_cleanup()

                # Call callback if set
                if self.memory_warning_callback:
                    try:
                        self.memory_warning_callback(memory_mb, self.memory_threshold_mb)
                    except Exception as e:
                        logger.error(f"Error in memory warning callback: {e}")

    def _check_cpu_threshold(self, usage: Dict[str, Any]):
        """Check CPU usage threshold"""
        cpu_percent = usage.get('cpu_percent', 0)

        if cpu_percent > self.cpu_threshold_percent:
            now = datetime.now()
            if now - self.last_cpu_warning > self.warning_cooldown:
                logger.warning(f"HIGH CPU USAGE: {cpu_percent:.1f}% / {self.cpu_threshold_percent}% threshold")
                self.last_cpu_warning = now

                if self.cpu_warning_callback:
                    try:
                        self.cpu_warning_callback(cpu_percent, self.cpu_threshold_percent)
                    except Exception as e:
                        logger.error(f"Error in CPU warning callback: {e}")

    def _check_disk_threshold(self, usage: Dict[str, Any]):
        """Check disk usage threshold"""
        disk_percent = usage.get('disk', {}).get('percent', 0)

        if disk_percent > self.disk_threshold_percent:
            now = datetime.now()
            if now - self.last_disk_warning > self.warning_cooldown:
                logger.warning(f"HIGH DISK USAGE: {disk_percent:.1f}% / {self.disk_threshold_percent}% threshold")
                self.last_disk_warning = now

                if self.disk_warning_callback:
                    try:
                        self.disk_warning_callback(disk_percent, self.disk_threshold_percent)
                    except Exception as e:
                        logger.error(f"Error in disk warning callback: {e}")

    def _trigger_memory_cleanup(self):
        """Trigger memory cleanup procedures"""
        logger.info("Triggering memory cleanup...")

        # Force garbage collection
        collected = gc.collect()
        logger.info(f"Garbage collection freed {collected} objects")

        # Call custom cleanup callback if set
        if self.cleanup_callback:
            try:
                self.cleanup_callback()
                logger.info("Custom cleanup callback executed")
            except Exception as e:
                logger.error(f"Error in cleanup callback: {e}")

        # Log memory after cleanup
        try:
            usage_after = self.get_current_usage()
            memory_after = usage_after.get('memory', {}).get('used_mb', 0)
            logger.info(f"Memory after cleanup: {memory_after:.1f}MB")
        except Exception as e:
            logger.error(f"Error checking memory after cleanup: {e}")

    def _log_resource_status(self, usage: Dict[str, Any]):
        """Log current resource status"""
        memory = usage.get('memory', {})
        cpu = usage.get('cpu_percent', 0)
        disk = usage.get('disk', {})
        process_memory = usage.get('process_memory_mb', 0)

        logger.info(f"RESOURCE STATUS - Memory: {memory.get('used_mb', 0):.1f}MB ({memory.get('percent', 0):.1f}%), "
                    f"Process: {process_memory:.1f}MB, CPU: {cpu:.1f}%, Disk: {disk.get('percent', 0):.1f}%")

    def get_resource_summary(self) -> Dict[str, Any]:
        """Get resource usage summary"""
        if not self.resource_history:
            return {}

        recent_usage = self.resource_history[-10:]  # Last 10 entries

        # Calculate averages
        avg_memory = sum(u.get('memory', {}).get('used_mb', 0) for u in recent_usage) / len(recent_usage)
        avg_cpu = sum(u.get('cpu_percent', 0) for u in recent_usage) / len(recent_usage)
        avg_disk = sum(u.get('disk', {}).get('percent', 0) for u in recent_usage) / len(recent_usage)

        current = self.resource_history[-1] if self.resource_history else {}

        return {
            'current': current,
            'averages': {
                'memory_mb': avg_memory,
                'cpu_percent': avg_cpu,
                'disk_percent': avg_disk
            },
            'thresholds': {
                'memory_mb': self.memory_threshold_mb,
                'cpu_percent': self.cpu_threshold_percent,
                'disk_percent': self.disk_threshold_percent
            },
            'monitoring': self.monitoring,
            'history_entries': len(self.resource_history)
        }

# Global instance
_system_monitor = None

def get_system_monitor() -> SystemResourceMonitor:
    """Get global system monitor instance"""
    global _system_monitor
    if _system_monitor is None:
        _system_monitor = SystemResourceMonitor()
    return _system_monitor

def start_system_monitoring():
    """Start system monitoring with default settings"""
    monitor = get_system_monitor()
    monitor.start_monitoring()
    return monitor
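
# Usage sketch (added for illustration, not part of the original module): shows how a
# caller might wire the monitor together. The callback signatures follow set_callbacks()
# and the _check_*_threshold() calls above (current value, threshold); the 5-second
# sleep and the printed keys are arbitrary demo choices.
if __name__ == "__main__":
    import time as _time

    def on_memory_warning(used_mb: float, threshold_mb: float):
        print(f"memory warning: {used_mb:.1f}MB exceeds {threshold_mb:.0f}MB")

    monitor = start_system_monitoring()
    monitor.set_callbacks(memory_warning=on_memory_warning,
                          cleanup=lambda: print("cleanup requested"))

    print(monitor.get_current_usage().get('cpu_percent'))   # immediate snapshot
    _time.sleep(5)                                           # let the background thread sample
    print(monitor.get_resource_summary().get('averages', {}))
    monitor.stop_monitoring()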
@@ -1,219 +0,0 @@
#!/usr/bin/env python3
"""
TensorBoard Logger Utility

This module provides a centralized way to log training metrics to TensorBoard.
It ensures consistent logging across different training components.
"""

import os
import logging
from pathlib import Path
from datetime import datetime
from typing import Dict, Any, Optional, Union, List

# Import conditionally to handle missing dependencies gracefully
try:
    from torch.utils.tensorboard import SummaryWriter
    TENSORBOARD_AVAILABLE = True
except ImportError:
    TENSORBOARD_AVAILABLE = False

logger = logging.getLogger(__name__)

class TensorBoardLogger:
    """
    Centralized TensorBoard logging utility for training metrics

    This class provides a consistent interface for logging metrics to TensorBoard
    across different training components.
    """

    def __init__(self,
                 log_dir: Optional[str] = None,
                 experiment_name: Optional[str] = None,
                 enabled: bool = True):
        """
        Initialize TensorBoard logger

        Args:
            log_dir: Base directory for TensorBoard logs (default: 'runs')
            experiment_name: Name of the experiment (default: timestamp)
            enabled: Whether TensorBoard logging is enabled
        """
        self.enabled = enabled and TENSORBOARD_AVAILABLE
        self.writer = None

        if not self.enabled:
            if not TENSORBOARD_AVAILABLE:
                logger.warning("TensorBoard not available. Install with: pip install tensorboard")
            return

        # Set up log directory
        if log_dir is None:
            log_dir = "runs"

        # Create experiment name if not provided
        if experiment_name is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            experiment_name = f"training_{timestamp}"

        # Create full log path
        self.log_dir = os.path.join(log_dir, experiment_name)

        # Create writer
        try:
            self.writer = SummaryWriter(log_dir=self.log_dir)
            logger.info(f"TensorBoard logging enabled at: {self.log_dir}")
        except Exception as e:
            logger.error(f"Failed to initialize TensorBoard: {e}")
            self.enabled = False

    def log_scalar(self, tag: str, value: float, step: int) -> None:
        """
        Log a scalar value to TensorBoard

        Args:
            tag: Metric name
            value: Metric value
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return

        try:
            self.writer.add_scalar(tag, value, step)
        except Exception as e:
            logger.warning(f"Failed to log scalar {tag}: {e}")

    def log_scalars(self, main_tag: str, tag_value_dict: Dict[str, float], step: int) -> None:
        """
        Log multiple scalar values with the same main tag

        Args:
            main_tag: Main tag for the metrics
            tag_value_dict: Dictionary of tag names to values
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return

        try:
            self.writer.add_scalars(main_tag, tag_value_dict, step)
        except Exception as e:
            logger.warning(f"Failed to log scalars for {main_tag}: {e}")

    def log_histogram(self, tag: str, values, step: int) -> None:
        """
        Log a histogram to TensorBoard

        Args:
            tag: Histogram name
            values: Values to create histogram from
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return

        try:
            self.writer.add_histogram(tag, values, step)
        except Exception as e:
            logger.warning(f"Failed to log histogram {tag}: {e}")

    def log_training_metrics(self,
                             metrics: Dict[str, Any],
                             step: int,
                             prefix: str = "Training") -> None:
        """
        Log training metrics to TensorBoard

        Args:
            metrics: Dictionary of metric names to values
            step: Training step
            prefix: Prefix for metric names
        """
        if not self.enabled or self.writer is None:
            return

        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                self.log_scalar(f"{prefix}/{name}", value, step)
            elif hasattr(value, "shape"):  # For numpy arrays or tensors
                try:
                    self.log_histogram(f"{prefix}/{name}", value, step)
                except Exception:
                    # Skip values that cannot be rendered as a histogram
                    pass

    def log_model_metrics(self,
                          model_name: str,
                          metrics: Dict[str, Any],
                          step: int) -> None:
        """
        Log model-specific metrics to TensorBoard

        Args:
            model_name: Name of the model
            metrics: Dictionary of metric names to values
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return

        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                self.log_scalar(f"Model/{model_name}/{name}", value, step)

    def log_reward_metrics(self,
                           symbol: str,
                           metrics: Dict[str, float],
                           step: int) -> None:
        """
        Log reward-related metrics to TensorBoard

        Args:
            symbol: Trading symbol
            metrics: Dictionary of metric names to values
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return

        for name, value in metrics.items():
            self.log_scalar(f"Rewards/{symbol}/{name}", value, step)

    def log_state_metrics(self,
                          symbol: str,
                          state_info: Dict[str, Any],
                          step: int) -> None:
        """
        Log state-related metrics to TensorBoard

        Args:
            symbol: Trading symbol
            state_info: Dictionary of state information
            step: Training step
        """
        if not self.enabled or self.writer is None:
            return

        # Log state size
        if "size" in state_info:
            self.log_scalar(f"State/{symbol}/Size", state_info["size"], step)

        # Log state quality
        if "quality" in state_info:
            self.log_scalar(f"State/{symbol}/Quality", state_info["quality"], step)

        # Log feature counts
        if "feature_counts" in state_info:
            for feature_type, count in state_info["feature_counts"].items():
                self.log_scalar(f"State/{symbol}/Features/{feature_type}", count, step)

    def close(self) -> None:
        """Close the TensorBoard writer"""
        if self.enabled and self.writer is not None:
            try:
                self.writer.close()
                logger.info("TensorBoard writer closed")
            except Exception as e:
                logger.warning(f"Error closing TensorBoard writer: {e}")