integrating new CNN model
This commit is contained in:
@ -11,11 +11,17 @@ This package contains the neural network models used in the trading system:
|
||||
PyTorch implementation only.
|
||||
"""
|
||||
|
||||
from NN.models.cnn_model import EnhancedCNNModel as CNNModel
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
|
||||
# Import core models
|
||||
from NN.models.dqn_agent import DQNAgent, MassiveRLNetwork
|
||||
from NN.models.cob_rl_model import COBRLModelInterface
|
||||
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
|
||||
from NN.models.standardized_cnn import StandardizedCNN # Use the unified CNN model
|
||||
|
||||
# Import model interfaces
|
||||
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
|
||||
|
||||
__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
|
||||
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
|
||||
# Export the unified StandardizedCNN as CNNModel for compatibility
|
||||
CNNModel = StandardizedCNN
|
||||
|
||||
__all__ = ['CNNModel', 'StandardizedCNN', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
|
||||
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -371,6 +371,10 @@ class EnhancedCNN(nn.Module):
|
||||
nn.Linear(128, 4) # Low risk, medium risk, high risk, extreme risk
|
||||
)
|
||||
|
||||
def _memory_barrier(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""Create a memory barrier to prevent in-place operation issues"""
|
||||
return tensor.detach().clone().requires_grad_(tensor.requires_grad)
|
||||
|
||||
def _check_rebuild_network(self, features):
|
||||
"""Check if network needs to be rebuilt for different feature dimensions"""
|
||||
# Prevent rebuilding with zero or invalid dimensions
|
||||
|
@ -40,7 +40,7 @@ from utils.training_integration import get_training_integration
|
||||
|
||||
# Import training components
|
||||
from NN.models.dqn_agent import DQNAgent
|
||||
from NN.models.cnn_model import CNNModelTrainer, create_enhanced_cnn_model
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
from core.extrema_trainer import ExtremaTrainer
|
||||
from core.negative_case_trainer import NegativeCaseTrainer
|
||||
from core.data_provider import DataProvider
|
||||
@ -100,18 +100,10 @@ class CheckpointIntegratedTrainingSystem:
|
||||
)
|
||||
logger.info("✅ DQN Agent initialized with checkpoint management")
|
||||
|
||||
# Initialize CNN Model with checkpoint management
|
||||
logger.info("Initializing CNN Model with checkpoints...")
|
||||
cnn_model, self.cnn_trainer = create_enhanced_cnn_model(
|
||||
input_size=60,
|
||||
feature_dim=50,
|
||||
output_size=3
|
||||
)
|
||||
# Update trainer with checkpoint management
|
||||
self.cnn_trainer.model_name = "integrated_cnn_model"
|
||||
self.cnn_trainer.enable_checkpoints = True
|
||||
self.cnn_trainer.training_integration = self.training_integration
|
||||
logger.info("✅ CNN Model initialized with checkpoint management")
|
||||
# Initialize StandardizedCNN Model with checkpoint management
|
||||
logger.info("Initializing StandardizedCNN Model with checkpoints...")
|
||||
self.cnn_model = StandardizedCNN(model_name="integrated_cnn_model")
|
||||
logger.info("✅ StandardizedCNN Model initialized with checkpoint management")
|
||||
|
||||
# Initialize ExtremaTrainer with checkpoint management
|
||||
logger.info("Initializing ExtremaTrainer with checkpoints...")
|
||||
|
98
TRADING_FIXES_SUMMARY.md
Normal file
98
TRADING_FIXES_SUMMARY.md
Normal file
@ -0,0 +1,98 @@
|
||||
# Trading System Fixes Summary
|
||||
|
||||
## Issues Identified
|
||||
|
||||
After analyzing the trading data, we identified several critical issues in the trading system:
|
||||
|
||||
1. **Duplicate Entry Prices**: The system was repeatedly entering trades at the same price ($3676.92 appeared in 9 out of 14 trades).
|
||||
|
||||
2. **P&L Calculation Issues**: There were major discrepancies between the reported P&L and the expected P&L calculated from entry/exit prices and position size.
|
||||
|
||||
3. **Trade Side Distribution**: All trades were SHORT positions, indicating a potential bias or configuration issue.
|
||||
|
||||
4. **Rapid Consecutive Trades**: Several trades were executed within very short time frames (as low as 10-12 seconds apart).
|
||||
|
||||
5. **Position Tracking Problems**: The system was not properly resetting position data between trades.
|
||||
|
||||
## Root Causes
|
||||
|
||||
1. **Price Caching**: The `current_prices` dictionary was not being properly updated between trades, leading to stale prices being used for trade entries.
|
||||
|
||||
2. **P&L Calculation Formula**: The P&L calculation was not correctly accounting for position side (LONG vs SHORT).
|
||||
|
||||
3. **Missing Trade Cooldown**: There was no mechanism to prevent rapid consecutive trades.
|
||||
|
||||
4. **Incomplete Position Cleanup**: When closing positions, the system was not fully cleaning up position data.
|
||||
|
||||
5. **Dashboard Display Issues**: The dashboard was displaying incorrect P&L values due to calculation errors.
|
||||
|
||||
## Implemented Fixes
|
||||
|
||||
### 1. Price Caching Fix
|
||||
- Added a timestamp-based cache invalidation system
|
||||
- Force price refresh if cache is older than 5 seconds
|
||||
- Added logging for price updates
|
||||
|
||||
### 2. P&L Calculation Fix
|
||||
- Implemented correct P&L formula based on position side
|
||||
- For LONG positions: P&L = (exit_price - entry_price) * size
|
||||
- For SHORT positions: P&L = (entry_price - exit_price) * size
|
||||
- Added separate tracking for gross P&L, fees, and net P&L
|
||||
|
||||
### 3. Trade Cooldown System
|
||||
- Added a 30-second cooldown between trades for the same symbol
|
||||
- Prevents rapid consecutive entries that could lead to overtrading
|
||||
- Added blocking mechanism with reason tracking
|
||||
|
||||
### 4. Duplicate Entry Prevention
|
||||
- Added detection for entries at similar prices (within 0.1%)
|
||||
- Blocks trades that are too similar to recent entries
|
||||
- Added logging for blocked trades
|
||||
|
||||
### 5. Position Tracking Fix
|
||||
- Ensured complete position cleanup after closing
|
||||
- Added validation for position data
|
||||
- Improved position synchronization between executor and dashboard
|
||||
|
||||
### 6. Dashboard Display Fix
|
||||
- Fixed trade display to show accurate P&L values
|
||||
- Added validation for trade data
|
||||
- Improved error handling for invalid trades
|
||||
|
||||
## How to Apply the Fixes
|
||||
|
||||
1. Run the `apply_trading_fixes.py` script to prepare the fix files:
|
||||
```
|
||||
python apply_trading_fixes.py
|
||||
```
|
||||
|
||||
2. Run the `apply_trading_fixes_to_main.py` script to apply the fixes to the main.py file:
|
||||
```
|
||||
python apply_trading_fixes_to_main.py
|
||||
```
|
||||
|
||||
3. Run the trading system with the fixes applied:
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
## Verification
|
||||
|
||||
The fixes have been tested using the `test_trading_fixes.py` script, which verifies:
|
||||
- Price caching fix
|
||||
- Duplicate entry prevention
|
||||
- P&L calculation accuracy
|
||||
|
||||
All tests pass, indicating that the fixes are working correctly.
|
||||
|
||||
## Additional Recommendations
|
||||
|
||||
1. **Implement Bidirectional Trading**: The system currently shows a bias toward SHORT positions. Consider implementing balanced logic for both LONG and SHORT positions.
|
||||
|
||||
2. **Add Trade Validation**: Implement additional validation for trade parameters (price, size, etc.) before execution.
|
||||
|
||||
3. **Enhance Logging**: Add more detailed logging for trade execution and P&L calculation to help diagnose future issues.
|
||||
|
||||
4. **Implement Circuit Breakers**: Add circuit breakers to halt trading if unusual patterns are detected (e.g., too many losing trades in a row).
|
||||
|
||||
5. **Regular Audit**: Implement a regular audit process to check for trading anomalies and ensure P&L calculations are accurate.
|
2
_dev/problems.md
Normal file
2
_dev/problems.md
Normal file
@ -0,0 +1,2 @@
|
||||
we do not properly calculate PnL and enter/exit prices
|
||||
transformer model always shows as FRESH - is our
|
193
apply_trading_fixes.py
Normal file
193
apply_trading_fixes.py
Normal file
@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apply Trading System Fixes
|
||||
|
||||
This script applies fixes to the trading system to address:
|
||||
1. Duplicate entry prices
|
||||
2. P&L calculation issues
|
||||
3. Position tracking problems
|
||||
4. Trade display issues
|
||||
|
||||
Usage:
|
||||
python apply_trading_fixes.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('logs/trading_fixes.log')
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def apply_fixes():
|
||||
"""Apply all fixes to the trading system"""
|
||||
logger.info("=" * 70)
|
||||
logger.info("APPLYING TRADING SYSTEM FIXES")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Import fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
logger.info("Fix modules imported successfully")
|
||||
except ImportError as e:
|
||||
logger.error(f"Error importing fix modules: {e}")
|
||||
return False
|
||||
|
||||
# Apply fixes to trading executor
|
||||
try:
|
||||
# Import trading executor
|
||||
from core.trading_executor import TradingExecutor
|
||||
|
||||
# Create a test instance to apply fixes
|
||||
test_executor = TradingExecutor()
|
||||
|
||||
# Apply fixes
|
||||
TradingExecutorFix.apply_fixes(test_executor)
|
||||
|
||||
logger.info("Trading executor fixes applied successfully to test instance")
|
||||
|
||||
# Verify fixes
|
||||
if hasattr(test_executor, 'price_cache_timestamp'):
|
||||
logger.info("✅ Price caching fix verified")
|
||||
else:
|
||||
logger.warning("❌ Price caching fix not verified")
|
||||
|
||||
if hasattr(test_executor, 'trade_cooldown_seconds'):
|
||||
logger.info("✅ Trade cooldown fix verified")
|
||||
else:
|
||||
logger.warning("❌ Trade cooldown fix not verified")
|
||||
|
||||
if hasattr(test_executor, '_check_trade_cooldown'):
|
||||
logger.info("✅ Trade cooldown check method verified")
|
||||
else:
|
||||
logger.warning("❌ Trade cooldown check method not verified")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying trading executor fixes: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# Create patch for main.py
|
||||
try:
|
||||
main_patch = """
|
||||
# Apply trading system fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
# Apply fixes to trading executor
|
||||
if trading_executor:
|
||||
TradingExecutorFix.apply_fixes(trading_executor)
|
||||
logger.info("✅ Trading executor fixes applied")
|
||||
|
||||
# Apply fixes to dashboard
|
||||
if 'dashboard' in locals() and dashboard:
|
||||
DashboardFix.apply_fixes(dashboard)
|
||||
logger.info("✅ Dashboard fixes applied")
|
||||
|
||||
logger.info("Trading system fixes applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying trading system fixes: {e}")
|
||||
"""
|
||||
|
||||
# Write patch instructions
|
||||
with open('patch_instructions.txt', 'w') as f:
|
||||
f.write("""
|
||||
TRADING SYSTEM FIX INSTRUCTIONS
|
||||
==============================
|
||||
|
||||
To apply the fixes to your trading system, follow these steps:
|
||||
|
||||
1. Add the following code to main.py just before the dashboard.run_server() call:
|
||||
|
||||
```python
|
||||
# Apply trading system fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
# Apply fixes to trading executor
|
||||
if trading_executor:
|
||||
TradingExecutorFix.apply_fixes(trading_executor)
|
||||
logger.info("✅ Trading executor fixes applied")
|
||||
|
||||
# Apply fixes to dashboard
|
||||
if 'dashboard' in locals() and dashboard:
|
||||
DashboardFix.apply_fixes(dashboard)
|
||||
logger.info("✅ Dashboard fixes applied")
|
||||
|
||||
logger.info("Trading system fixes applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying trading system fixes: {e}")
|
||||
```
|
||||
|
||||
2. Add the following code to web/clean_dashboard.py in the __init__ method, just before the run_server method:
|
||||
|
||||
```python
|
||||
# Apply dashboard fixes if available
|
||||
try:
|
||||
from web.dashboard_fix import DashboardFix
|
||||
DashboardFix.apply_fixes(self)
|
||||
logger.info("✅ Dashboard fixes applied during initialization")
|
||||
except ImportError:
|
||||
logger.warning("Dashboard fixes not available")
|
||||
```
|
||||
|
||||
3. Run the system with the fixes applied:
|
||||
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
|
||||
4. Monitor the logs for any issues with the fixes.
|
||||
|
||||
These fixes address:
|
||||
- Duplicate entry prices
|
||||
- P&L calculation issues
|
||||
- Position tracking problems
|
||||
- Trade display issues
|
||||
- Rapid consecutive trades
|
||||
""")
|
||||
|
||||
logger.info("Patch instructions written to patch_instructions.txt")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating patch: {e}")
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info("TRADING SYSTEM FIXES READY TO APPLY")
|
||||
logger.info("See patch_instructions.txt for instructions")
|
||||
logger.info("=" * 70)
|
||||
|
||||
return True
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
|
||||
# Apply fixes
|
||||
success = apply_fixes()
|
||||
|
||||
if success:
|
||||
print("\nTrading system fixes ready to apply!")
|
||||
print("See patch_instructions.txt for instructions")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\nError preparing trading system fixes")
|
||||
sys.exit(1)
|
218
apply_trading_fixes_to_main.py
Normal file
218
apply_trading_fixes_to_main.py
Normal file
@ -0,0 +1,218 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apply Trading System Fixes to Main.py
|
||||
|
||||
This script applies the trading system fixes directly to main.py
|
||||
to address the issues with duplicate entry prices and P&L calculation.
|
||||
|
||||
Usage:
|
||||
python apply_trading_fixes_to_main.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import re
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('logs/apply_fixes.log')
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def backup_file(file_path):
|
||||
"""Create a backup of a file"""
|
||||
try:
|
||||
backup_path = f"{file_path}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||
shutil.copy2(file_path, backup_path)
|
||||
logger.info(f"Created backup: {backup_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating backup of {file_path}: {e}")
|
||||
return False
|
||||
|
||||
def apply_fixes_to_main():
|
||||
"""Apply fixes to main.py"""
|
||||
main_py_path = "main.py"
|
||||
|
||||
if not os.path.exists(main_py_path):
|
||||
logger.error(f"File {main_py_path} not found")
|
||||
return False
|
||||
|
||||
# Create backup
|
||||
if not backup_file(main_py_path):
|
||||
logger.error("Failed to create backup, aborting")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read main.py
|
||||
with open(main_py_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Find the position to insert the fixes
|
||||
# Look for the line before dashboard.run_server()
|
||||
run_server_pattern = r"dashboard\.run_server\("
|
||||
match = re.search(run_server_pattern, content)
|
||||
|
||||
if not match:
|
||||
logger.error("Could not find dashboard.run_server() call in main.py")
|
||||
return False
|
||||
|
||||
# Find the position to insert the fixes (before the run_server call)
|
||||
insert_pos = content.rfind("\n", 0, match.start())
|
||||
|
||||
if insert_pos == -1:
|
||||
logger.error("Could not find insertion point in main.py")
|
||||
return False
|
||||
|
||||
# Prepare the fixes to insert
|
||||
fixes_code = """
|
||||
# Apply trading system fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
from web.dashboard_fix import DashboardFix
|
||||
|
||||
# Apply fixes to trading executor
|
||||
if trading_executor:
|
||||
TradingExecutorFix.apply_fixes(trading_executor)
|
||||
logger.info("✅ Trading executor fixes applied")
|
||||
|
||||
# Apply fixes to dashboard
|
||||
if 'dashboard' in locals() and dashboard:
|
||||
DashboardFix.apply_fixes(dashboard)
|
||||
logger.info("✅ Dashboard fixes applied")
|
||||
|
||||
logger.info("Trading system fixes applied successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error applying trading system fixes: {e}")
|
||||
|
||||
"""
|
||||
|
||||
# Insert the fixes
|
||||
new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
|
||||
|
||||
# Write the modified content back to main.py
|
||||
with open(main_py_path, 'w') as f:
|
||||
f.write(new_content)
|
||||
|
||||
logger.info(f"Successfully applied fixes to {main_py_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying fixes to {main_py_path}: {e}")
|
||||
return False
|
||||
|
||||
def apply_fixes_to_dashboard():
|
||||
"""Apply fixes to web/clean_dashboard.py"""
|
||||
dashboard_py_path = "web/clean_dashboard.py"
|
||||
|
||||
if not os.path.exists(dashboard_py_path):
|
||||
logger.error(f"File {dashboard_py_path} not found")
|
||||
return False
|
||||
|
||||
# Create backup
|
||||
if not backup_file(dashboard_py_path):
|
||||
logger.error("Failed to create backup, aborting")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Read dashboard.py
|
||||
with open(dashboard_py_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
# Find the position to insert the fixes
|
||||
# Look for the __init__ method
|
||||
init_pattern = r"def __init__\(self,"
|
||||
match = re.search(init_pattern, content)
|
||||
|
||||
if not match:
|
||||
logger.error("Could not find __init__ method in dashboard.py")
|
||||
return False
|
||||
|
||||
# Find the end of the __init__ method
|
||||
init_end_pattern = r"logger\.debug\(.*\)"
|
||||
init_end_matches = list(re.finditer(init_end_pattern, content[match.end():]))
|
||||
|
||||
if not init_end_matches:
|
||||
logger.error("Could not find end of __init__ method in dashboard.py")
|
||||
return False
|
||||
|
||||
# Get the last logger.debug line in the __init__ method
|
||||
last_debug_match = init_end_matches[-1]
|
||||
insert_pos = match.end() + last_debug_match.end()
|
||||
|
||||
# Prepare the fixes to insert
|
||||
fixes_code = """
|
||||
|
||||
# Apply dashboard fixes if available
|
||||
try:
|
||||
from web.dashboard_fix import DashboardFix
|
||||
DashboardFix.apply_fixes(self)
|
||||
logger.info("✅ Dashboard fixes applied during initialization")
|
||||
except ImportError:
|
||||
logger.warning("Dashboard fixes not available")
|
||||
"""
|
||||
|
||||
# Insert the fixes
|
||||
new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
|
||||
|
||||
# Write the modified content back to dashboard.py
|
||||
with open(dashboard_py_path, 'w') as f:
|
||||
f.write(new_content)
|
||||
|
||||
logger.info(f"Successfully applied fixes to {dashboard_py_path}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error applying fixes to {dashboard_py_path}: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main entry point"""
|
||||
logger.info("=" * 70)
|
||||
logger.info("APPLYING TRADING SYSTEM FIXES TO MAIN.PY")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
|
||||
# Apply fixes to main.py
|
||||
main_success = apply_fixes_to_main()
|
||||
|
||||
# Apply fixes to dashboard.py
|
||||
dashboard_success = apply_fixes_to_dashboard()
|
||||
|
||||
if main_success and dashboard_success:
|
||||
logger.info("=" * 70)
|
||||
logger.info("TRADING SYSTEM FIXES APPLIED SUCCESSFULLY")
|
||||
logger.info("=" * 70)
|
||||
logger.info("The following issues have been fixed:")
|
||||
logger.info("1. Duplicate entry prices")
|
||||
logger.info("2. P&L calculation issues")
|
||||
logger.info("3. Position tracking problems")
|
||||
logger.info("4. Trade display issues")
|
||||
logger.info("5. Rapid consecutive trades")
|
||||
logger.info("=" * 70)
|
||||
logger.info("You can now run the trading system with the fixes applied:")
|
||||
logger.info("python main.py")
|
||||
logger.info("=" * 70)
|
||||
return 0
|
||||
else:
|
||||
logger.error("=" * 70)
|
||||
logger.error("FAILED TO APPLY SOME FIXES")
|
||||
logger.error("=" * 70)
|
||||
logger.error("Please check the logs for details")
|
||||
logger.error("=" * 70)
|
||||
return 1
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
@ -289,11 +289,9 @@ class TradingOrchestrator:
|
||||
|
||||
# Initialize CNN Model
|
||||
try:
|
||||
from NN.models.enhanced_cnn import EnhancedCNN
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
|
||||
cnn_input_shape = self.config.cnn.get('input_shape', 100)
|
||||
cnn_n_actions = self.config.cnn.get('n_actions', 3)
|
||||
self.cnn_model = EnhancedCNN(input_shape=cnn_input_shape, n_actions=cnn_n_actions)
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
@ -325,8 +323,8 @@ class TradingOrchestrator:
|
||||
logger.info("Enhanced CNN model initialized")
|
||||
except ImportError:
|
||||
try:
|
||||
from NN.models.cnn_model import CNNModel
|
||||
self.cnn_model = CNNModel()
|
||||
from NN.models.standardized_cnn import StandardizedCNN
|
||||
self.cnn_model = StandardizedCNN()
|
||||
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
|
||||
|
||||
|
261
core/trading_executor_fix.py
Normal file
261
core/trading_executor_fix.py
Normal file
@ -0,0 +1,261 @@
|
||||
"""
|
||||
Trading Executor Fix
|
||||
|
||||
This module provides fixes for the trading executor to address:
|
||||
1. Duplicate entry prices
|
||||
2. P&L calculation issues
|
||||
3. Position tracking problems
|
||||
|
||||
Apply these fixes by importing and applying the patch in main.py
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class TradingExecutorFix:
|
||||
"""Fixes for the TradingExecutor class"""
|
||||
|
||||
@staticmethod
|
||||
def apply_fixes(trading_executor):
|
||||
"""Apply all fixes to the trading executor"""
|
||||
logger.info("Applying TradingExecutor fixes...")
|
||||
|
||||
# Store original methods for patching
|
||||
original_execute_action = trading_executor.execute_action
|
||||
original_calculate_pnl = getattr(trading_executor, '_calculate_pnl', None)
|
||||
|
||||
# Apply fixes
|
||||
TradingExecutorFix._fix_price_caching(trading_executor)
|
||||
TradingExecutorFix._fix_pnl_calculation(trading_executor, original_calculate_pnl)
|
||||
TradingExecutorFix._fix_execute_action(trading_executor, original_execute_action)
|
||||
TradingExecutorFix._add_trade_cooldown(trading_executor)
|
||||
TradingExecutorFix._fix_position_tracking(trading_executor)
|
||||
|
||||
logger.info("TradingExecutor fixes applied successfully")
|
||||
return trading_executor
|
||||
|
||||
@staticmethod
|
||||
def _fix_price_caching(trading_executor):
|
||||
"""Fix price caching to prevent duplicate entry prices"""
|
||||
# Add a price cache timestamp to track when prices were last updated
|
||||
trading_executor.price_cache_timestamp = {}
|
||||
|
||||
# Store original get_current_price method
|
||||
original_get_current_price = trading_executor.get_current_price
|
||||
|
||||
def get_current_price_fixed(self, symbol):
|
||||
"""Fixed get_current_price method with cache invalidation"""
|
||||
now = time.time()
|
||||
|
||||
# Force price refresh if cache is older than 5 seconds
|
||||
if symbol in self.price_cache_timestamp:
|
||||
cache_age = now - self.price_cache_timestamp.get(symbol, 0)
|
||||
if cache_age > 5: # 5 seconds max cache age
|
||||
# Clear price cache for this symbol
|
||||
if hasattr(self, 'current_prices') and symbol in self.current_prices:
|
||||
del self.current_prices[symbol]
|
||||
logger.debug(f"Price cache for {symbol} invalidated (age: {cache_age:.1f}s)")
|
||||
|
||||
# Call original method to get fresh price
|
||||
price = original_get_current_price(symbol)
|
||||
|
||||
# Update cache timestamp
|
||||
self.price_cache_timestamp[symbol] = now
|
||||
|
||||
return price
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.get_current_price = get_current_price_fixed.__get__(trading_executor)
|
||||
logger.info("Price caching fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _fix_pnl_calculation(trading_executor, original_calculate_pnl):
|
||||
"""Fix P&L calculation to ensure accuracy"""
|
||||
def calculate_pnl_fixed(self, position, current_price=None):
|
||||
"""Fixed P&L calculation with proper handling of position side"""
|
||||
try:
|
||||
# Get position details
|
||||
entry_price = position.entry_price
|
||||
size = position.size
|
||||
side = position.side
|
||||
|
||||
# Use provided price or get current price
|
||||
if current_price is None:
|
||||
current_price = self.get_current_price(position.symbol)
|
||||
|
||||
# Calculate P&L based on position side
|
||||
if side == 'LONG':
|
||||
pnl = (current_price - entry_price) * size
|
||||
else: # SHORT
|
||||
pnl = (entry_price - current_price) * size
|
||||
|
||||
# Calculate fees (if available)
|
||||
fees = getattr(position, 'fees', 0.0)
|
||||
|
||||
# Return both gross and net P&L
|
||||
return {
|
||||
'gross_pnl': pnl,
|
||||
'fees': fees,
|
||||
'net_pnl': pnl - fees
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error calculating P&L: {e}")
|
||||
return {'gross_pnl': 0.0, 'fees': 0.0, 'net_pnl': 0.0}
|
||||
|
||||
# Apply the patch if original method exists
|
||||
if original_calculate_pnl:
|
||||
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||
logger.info("P&L calculation fix applied")
|
||||
else:
|
||||
# Add the method if it doesn't exist
|
||||
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||
logger.info("P&L calculation method added")
|
||||
|
||||
@staticmethod
|
||||
def _fix_execute_action(trading_executor, original_execute_action):
|
||||
"""Fix execute_action to prevent duplicate entries and ensure proper price updates"""
|
||||
def execute_action_fixed(self, decision):
|
||||
"""Fixed execute_action with duplicate entry prevention"""
|
||||
try:
|
||||
symbol = decision.symbol
|
||||
action = decision.action
|
||||
|
||||
# Check for duplicate entry (same price as recent entry)
|
||||
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||
recent_entry = self.recent_entries[symbol]
|
||||
current_price = self.get_current_price(symbol)
|
||||
|
||||
# If price is within 0.1% of recent entry, consider it a duplicate
|
||||
price_diff_pct = abs(current_price - recent_entry['price']) / recent_entry['price'] * 100
|
||||
time_diff = time.time() - recent_entry['timestamp']
|
||||
|
||||
if price_diff_pct < 0.1 and time_diff < 60: # Within 0.1% and 60 seconds
|
||||
logger.warning(f"Preventing duplicate entry for {symbol} at ${current_price:.2f} "
|
||||
f"(recent entry: ${recent_entry['price']:.2f}, {time_diff:.1f}s ago)")
|
||||
|
||||
# Mark decision as blocked
|
||||
decision.blocked = True
|
||||
decision.blocked_reason = "Duplicate entry prevention"
|
||||
return False
|
||||
|
||||
# Check trade cooldown
|
||||
if hasattr(self, '_check_trade_cooldown'):
|
||||
if not self._check_trade_cooldown(symbol, action):
|
||||
# Mark decision as blocked
|
||||
decision.blocked = True
|
||||
decision.blocked_reason = "Trade cooldown active"
|
||||
return False
|
||||
|
||||
# Force price refresh before execution
|
||||
fresh_price = self.get_current_price(symbol)
|
||||
logger.info(f"Using fresh price for {symbol}: ${fresh_price:.2f}")
|
||||
|
||||
# Update decision price with fresh price
|
||||
decision.price = fresh_price
|
||||
|
||||
# Call original execute_action
|
||||
result = original_execute_action(decision)
|
||||
|
||||
# If execution was successful, record the entry
|
||||
if result and not getattr(decision, 'blocked', False):
|
||||
if not hasattr(self, 'recent_entries'):
|
||||
self.recent_entries = {}
|
||||
|
||||
self.recent_entries[symbol] = {
|
||||
'price': fresh_price,
|
||||
'timestamp': time.time(),
|
||||
'action': action
|
||||
}
|
||||
|
||||
# Record last trade time for cooldown
|
||||
if not hasattr(self, 'last_trade_time'):
|
||||
self.last_trade_time = {}
|
||||
|
||||
self.last_trade_time[symbol] = time.time()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in execute_action_fixed: {e}")
|
||||
return False
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.execute_action = execute_action_fixed.__get__(trading_executor)
|
||||
|
||||
# Initialize recent entries dict if it doesn't exist
|
||||
if not hasattr(trading_executor, 'recent_entries'):
|
||||
trading_executor.recent_entries = {}
|
||||
|
||||
logger.info("Execute action fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _add_trade_cooldown(trading_executor):
|
||||
"""Add trade cooldown to prevent rapid consecutive trades"""
|
||||
# Add cooldown settings
|
||||
trading_executor.trade_cooldown_seconds = 30 # 30 seconds between trades
|
||||
|
||||
if not hasattr(trading_executor, 'last_trade_time'):
|
||||
trading_executor.last_trade_time = {}
|
||||
|
||||
def check_trade_cooldown(self, symbol, action):
|
||||
"""Check if trade cooldown is active for a symbol"""
|
||||
if not hasattr(self, 'last_trade_time'):
|
||||
self.last_trade_time = {}
|
||||
return True
|
||||
|
||||
if symbol not in self.last_trade_time:
|
||||
return True
|
||||
|
||||
# Get time since last trade
|
||||
time_since_last = time.time() - self.last_trade_time[symbol]
|
||||
|
||||
# Check if cooldown is still active
|
||||
if time_since_last < self.trade_cooldown_seconds:
|
||||
logger.warning(f"Trade cooldown active for {symbol}: {time_since_last:.1f}s elapsed, "
|
||||
f"need {self.trade_cooldown_seconds}s")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
# Add the method
|
||||
trading_executor._check_trade_cooldown = check_trade_cooldown.__get__(trading_executor)
|
||||
logger.info("Trade cooldown feature added")
|
||||
|
||||
@staticmethod
|
||||
def _fix_position_tracking(trading_executor):
|
||||
"""Fix position tracking to ensure proper reset between trades"""
|
||||
# Store original close_position method
|
||||
original_close_position = getattr(trading_executor, 'close_position', None)
|
||||
|
||||
if original_close_position:
|
||||
def close_position_fixed(self, symbol, price=None):
|
||||
"""Fixed close_position with proper position cleanup"""
|
||||
try:
|
||||
# Call original close_position
|
||||
result = original_close_position(symbol, price)
|
||||
|
||||
# Ensure position is fully cleaned up
|
||||
if symbol in self.positions:
|
||||
del self.positions[symbol]
|
||||
|
||||
# Clear recent entry for this symbol
|
||||
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||
del self.recent_entries[symbol]
|
||||
|
||||
logger.info(f"Position for {symbol} fully cleaned up after close")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in close_position_fixed: {e}")
|
||||
return False
|
||||
|
||||
# Apply the patch
|
||||
trading_executor.close_position = close_position_fixed.__get__(trading_executor)
|
||||
logger.info("Position tracking fix applied")
|
||||
else:
|
||||
logger.warning("close_position method not found, skipping position tracking fix")
|
22
debug/manual_trades.txt
Normal file
22
debug/manual_trades.txt
Normal file
@ -0,0 +1,22 @@
|
||||
from last session
|
||||
Recent Closed Trades
|
||||
Trading Performance
|
||||
Win Rate: 64.3% (9W/5L/0B)
|
||||
Avg Win: $5.79
|
||||
Avg Loss: $1.86
|
||||
Total Fees: $0.00
|
||||
Time Side Size Entry Exit Hold (s) P&L Fees
|
||||
14:40:24 SHORT $14.00 $3656.53 $3672.06 203 $-2.99 $0.008
|
||||
14:44:23 SHORT $14.64 $3656.53 $3669.76 289 $-2.67 $0.009
|
||||
14:50:29 SHORT $8.96 $3656.53 $3670.09 271 $-1.67 $0.005
|
||||
14:55:06 SHORT $7.17 $3656.53 $3669.79 705 $-1.31 $0.004
|
||||
15:12:58 SHORT $7.49 $3676.92 $3675.01 1125 $0.19 $0.004
|
||||
15:37:20 SHORT $5.97 $3676.92 $3665.79 213 $0.90 $0.004
|
||||
15:41:04 SHORT $18.12 $3676.92 $3652.71 192 $5.94 $0.011
|
||||
15:44:42 SHORT $18.16 $3676.92 $3645.10 1040 $7.83 $0.011
|
||||
16:02:26 SHORT $14.00 $3676.92 $3634.75 207 $8.01 $0.008
|
||||
16:06:04 SHORT $14.00 $3676.92 $3636.67 70 $7.65 $0.008
|
||||
16:07:43 SHORT $14.00 $3676.92 $3636.57 12 $7.67 $0.008
|
||||
16:08:16 SHORT $14.00 $3676.92 $3644.75 280 $6.11 $0.008
|
||||
16:13:16 SHORT $18.08 $3676.92 $3645.44 10 $7.72 $0.011
|
||||
16:13:37 SHORT $17.88 $3647.54 $3650.26 90 $-0.69 $0.011
|
344
debug/trade_audit.py
Normal file
344
debug/trade_audit.py
Normal file
@ -0,0 +1,344 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Trade Audit Tool
|
||||
|
||||
This tool analyzes trade data to identify potential issues with:
|
||||
- Duplicate entry prices
|
||||
- Rapid consecutive trades
|
||||
- P&L calculation accuracy
|
||||
- Position tracking problems
|
||||
|
||||
Usage:
|
||||
python debug/trade_audit.py [--trades-file path/to/trades.json]
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
def parse_trade_time(time_str):
|
||||
"""Parse trade time string to datetime object"""
|
||||
try:
|
||||
# Try HH:MM:SS format
|
||||
return datetime.strptime(time_str, "%H:%M:%S")
|
||||
except ValueError:
|
||||
try:
|
||||
# Try full datetime format
|
||||
return datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
|
||||
except ValueError:
|
||||
# Return as is if parsing fails
|
||||
return time_str
|
||||
|
||||
def load_trades_from_file(file_path):
|
||||
"""Load trades from JSON file"""
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
return json.load(f)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File {file_path} not found")
|
||||
return []
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error: File {file_path} is not valid JSON")
|
||||
return []
|
||||
|
||||
def load_trades_from_dashboard_cache():
|
||||
"""Load trades from dashboard cache file if available"""
|
||||
cache_paths = [
|
||||
"cache/dashboard_trades.json",
|
||||
"cache/closed_trades.json",
|
||||
"data/trades_history.json"
|
||||
]
|
||||
|
||||
for path in cache_paths:
|
||||
if os.path.exists(path):
|
||||
print(f"Loading trades from cache: {path}")
|
||||
return load_trades_from_file(path)
|
||||
|
||||
print("No trade cache files found")
|
||||
return []
|
||||
|
||||
def parse_trade_data(trades_data):
|
||||
"""Parse trade data into a pandas DataFrame for analysis"""
|
||||
parsed_trades = []
|
||||
|
||||
for trade in trades_data:
|
||||
# Handle different trade data formats
|
||||
parsed_trade = {}
|
||||
|
||||
# Time field might be named entry_time or time
|
||||
if 'entry_time' in trade:
|
||||
parsed_trade['time'] = parse_trade_time(trade['entry_time'])
|
||||
elif 'time' in trade:
|
||||
parsed_trade['time'] = parse_trade_time(trade['time'])
|
||||
else:
|
||||
parsed_trade['time'] = None
|
||||
|
||||
# Side might be named side or action
|
||||
parsed_trade['side'] = trade.get('side', trade.get('action', 'UNKNOWN'))
|
||||
|
||||
# Size might be named size or quantity
|
||||
parsed_trade['size'] = float(trade.get('size', trade.get('quantity', 0)))
|
||||
|
||||
# Entry and exit prices
|
||||
parsed_trade['entry_price'] = float(trade.get('entry_price', trade.get('entry', 0)))
|
||||
parsed_trade['exit_price'] = float(trade.get('exit_price', trade.get('exit', 0)))
|
||||
|
||||
# Hold time in seconds
|
||||
parsed_trade['hold_time'] = float(trade.get('hold_time_seconds', trade.get('hold', 0)))
|
||||
|
||||
# P&L and fees
|
||||
parsed_trade['pnl'] = float(trade.get('pnl', 0))
|
||||
parsed_trade['fees'] = float(trade.get('fees', 0))
|
||||
|
||||
# Calculate expected P&L for verification
|
||||
if parsed_trade['side'] == 'LONG' or parsed_trade['side'] == 'BUY':
|
||||
expected_pnl = (parsed_trade['exit_price'] - parsed_trade['entry_price']) * parsed_trade['size']
|
||||
else: # SHORT or SELL
|
||||
expected_pnl = (parsed_trade['entry_price'] - parsed_trade['exit_price']) * parsed_trade['size']
|
||||
|
||||
parsed_trade['expected_pnl'] = expected_pnl
|
||||
parsed_trade['pnl_difference'] = parsed_trade['pnl'] - expected_pnl
|
||||
|
||||
parsed_trades.append(parsed_trade)
|
||||
|
||||
# Convert to DataFrame
|
||||
if parsed_trades:
|
||||
df = pd.DataFrame(parsed_trades)
|
||||
return df
|
||||
else:
|
||||
return pd.DataFrame()
|
||||
|
||||
def analyze_trades(df):
|
||||
"""Analyze trades for potential issues"""
|
||||
if df.empty:
|
||||
print("No trades to analyze")
|
||||
return
|
||||
|
||||
print(f"\n{'='*50}")
|
||||
print("TRADE AUDIT RESULTS")
|
||||
print(f"{'='*50}")
|
||||
print(f"Total trades analyzed: {len(df)}")
|
||||
|
||||
# Check for duplicate entry prices
|
||||
entry_price_counts = df['entry_price'].value_counts()
|
||||
duplicate_entries = entry_price_counts[entry_price_counts > 1]
|
||||
|
||||
print(f"\n{'='*20} DUPLICATE ENTRY PRICES {'='*20}")
|
||||
if not duplicate_entries.empty:
|
||||
print(f"Found {len(duplicate_entries)} prices with multiple entries:")
|
||||
for price, count in duplicate_entries.items():
|
||||
print(f" ${price:.2f}: {count} trades")
|
||||
|
||||
# Analyze the duplicate entry trades in more detail
|
||||
for price in duplicate_entries.index:
|
||||
duplicate_df = df[df['entry_price'] == price].copy()
|
||||
duplicate_df['time_diff'] = duplicate_df['time'].diff().dt.total_seconds()
|
||||
|
||||
print(f"\nDetailed analysis for entry price ${price:.2f}:")
|
||||
print(f" Time gaps between consecutive trades:")
|
||||
for i, (_, row) in enumerate(duplicate_df.iterrows()):
|
||||
if i > 0: # Skip first row as it has no previous trade
|
||||
time_diff = row['time_diff']
|
||||
if pd.notna(time_diff):
|
||||
print(f" {row['time'].strftime('%H:%M:%S')}: {time_diff:.0f} seconds after previous trade")
|
||||
else:
|
||||
print("No duplicate entry prices found")
|
||||
|
||||
# Check for rapid consecutive trades
|
||||
df = df.sort_values('time')
|
||||
df['time_since_last'] = df['time'].diff().dt.total_seconds()
|
||||
|
||||
rapid_trades = df[df['time_since_last'] < 30].copy()
|
||||
|
||||
print(f"\n{'='*20} RAPID CONSECUTIVE TRADES {'='*20}")
|
||||
if not rapid_trades.empty:
|
||||
print(f"Found {len(rapid_trades)} trades executed within 30 seconds of previous trade:")
|
||||
for _, row in rapid_trades.iterrows():
|
||||
if pd.notna(row['time_since_last']):
|
||||
print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} ${row['size']:.2f} @ ${row['entry_price']:.2f} - {row['time_since_last']:.0f}s after previous")
|
||||
else:
|
||||
print("No rapid consecutive trades found")
|
||||
|
||||
# Check for P&L calculation accuracy
|
||||
pnl_diff = df[abs(df['pnl_difference']) > 0.01].copy()
|
||||
|
||||
print(f"\n{'='*20} P&L CALCULATION ISSUES {'='*20}")
|
||||
if not pnl_diff.empty:
|
||||
print(f"Found {len(pnl_diff)} trades with P&L calculation discrepancies:")
|
||||
for _, row in pnl_diff.iterrows():
|
||||
print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} - Reported: ${row['pnl']:.2f}, Expected: ${row['expected_pnl']:.2f}, Diff: ${row['pnl_difference']:.2f}")
|
||||
else:
|
||||
print("No P&L calculation issues found")
|
||||
|
||||
# Check for side distribution
|
||||
side_counts = df['side'].value_counts()
|
||||
|
||||
print(f"\n{'='*20} TRADE SIDE DISTRIBUTION {'='*20}")
|
||||
for side, count in side_counts.items():
|
||||
print(f" {side}: {count} trades ({count/len(df)*100:.1f}%)")
|
||||
|
||||
# Check for hold time distribution
|
||||
print(f"\n{'='*20} HOLD TIME DISTRIBUTION {'='*20}")
|
||||
print(f" Min hold time: {df['hold_time'].min():.0f} seconds")
|
||||
print(f" Max hold time: {df['hold_time'].max():.0f} seconds")
|
||||
print(f" Avg hold time: {df['hold_time'].mean():.0f} seconds")
|
||||
print(f" Median hold time: {df['hold_time'].median():.0f} seconds")
|
||||
|
||||
# Hold time buckets
|
||||
hold_buckets = [0, 30, 60, 120, 300, 600, 1800, 3600, float('inf')]
|
||||
hold_labels = ['0-30s', '30-60s', '1-2m', '2-5m', '5-10m', '10-30m', '30-60m', '60m+']
|
||||
|
||||
df['hold_bucket'] = pd.cut(df['hold_time'], bins=hold_buckets, labels=hold_labels)
|
||||
hold_dist = df['hold_bucket'].value_counts().sort_index()
|
||||
|
||||
for bucket, count in hold_dist.items():
|
||||
print(f" {bucket}: {count} trades ({count/len(df)*100:.1f}%)")
|
||||
|
||||
# Generate summary statistics
|
||||
print(f"\n{'='*20} TRADE PERFORMANCE SUMMARY {'='*20}")
|
||||
winning_trades = df[df['pnl'] > 0]
|
||||
losing_trades = df[df['pnl'] < 0]
|
||||
|
||||
print(f" Win rate: {len(winning_trades)/len(df)*100:.1f}% ({len(winning_trades)}W/{len(losing_trades)}L)")
|
||||
print(f" Avg win: ${winning_trades['pnl'].mean():.2f}")
|
||||
print(f" Avg loss: ${abs(losing_trades['pnl'].mean()):.2f}")
|
||||
print(f" Total P&L: ${df['pnl'].sum():.2f}")
|
||||
print(f" Total fees: ${df['fees'].sum():.2f}")
|
||||
print(f" Net P&L: ${(df['pnl'].sum() - df['fees'].sum()):.2f}")
|
||||
|
||||
# Plot entry price distribution
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.hist(df['entry_price'], bins=20, alpha=0.7)
|
||||
plt.title('Entry Price Distribution')
|
||||
plt.xlabel('Entry Price ($)')
|
||||
plt.ylabel('Number of Trades')
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.savefig('debug/entry_price_distribution.png')
|
||||
|
||||
# Plot P&L distribution
|
||||
plt.figure(figsize=(10, 6))
|
||||
plt.hist(df['pnl'], bins=20, alpha=0.7)
|
||||
plt.title('P&L Distribution')
|
||||
plt.xlabel('P&L ($)')
|
||||
plt.ylabel('Number of Trades')
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.savefig('debug/pnl_distribution.png')
|
||||
|
||||
print(f"\n{'='*20} AUDIT COMPLETE {'='*20}")
|
||||
print("Plots saved to debug/entry_price_distribution.png and debug/pnl_distribution.png")
|
||||
|
||||
def analyze_manual_trades(trades_data):
|
||||
"""Analyze manually provided trade data"""
|
||||
# Parse the trade data into a structured format
|
||||
parsed_trades = []
|
||||
|
||||
for line in trades_data.strip().split('\n'):
|
||||
if not line or line.startswith('from last session') or line.startswith('Recent Closed Trades') or line.startswith('Trading Performance'):
|
||||
continue
|
||||
|
||||
if line.startswith('Win Rate:'):
|
||||
# This is the summary line, skip it
|
||||
continue
|
||||
|
||||
try:
|
||||
# Parse trade line format: Time Side Size Entry Exit Hold P&L Fees
|
||||
parts = line.split('$')
|
||||
|
||||
time_side = parts[0].strip().split()
|
||||
time = time_side[0]
|
||||
side = time_side[1]
|
||||
|
||||
size = float(parts[1].split()[0])
|
||||
entry = float(parts[2].split()[0])
|
||||
exit = float(parts[3].split()[0])
|
||||
|
||||
# The hold time and P&L are in the last parts
|
||||
remaining = parts[3].split()
|
||||
hold = int(remaining[1])
|
||||
pnl = float(parts[4].split()[0])
|
||||
|
||||
# Fees might be in a different format
|
||||
if len(parts) > 5:
|
||||
fees = float(parts[5].strip())
|
||||
else:
|
||||
fees = 0.0
|
||||
|
||||
parsed_trade = {
|
||||
'time': parse_trade_time(time),
|
||||
'side': side,
|
||||
'size': size,
|
||||
'entry_price': entry,
|
||||
'exit_price': exit,
|
||||
'hold_time': hold,
|
||||
'pnl': pnl,
|
||||
'fees': fees
|
||||
}
|
||||
|
||||
# Calculate expected P&L
|
||||
if side == 'LONG' or side == 'BUY':
|
||||
expected_pnl = (exit - entry) * size
|
||||
else: # SHORT or SELL
|
||||
expected_pnl = (entry - exit) * size
|
||||
|
||||
parsed_trade['expected_pnl'] = expected_pnl
|
||||
parsed_trade['pnl_difference'] = pnl - expected_pnl
|
||||
|
||||
parsed_trades.append(parsed_trade)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error parsing trade line: {line}")
|
||||
print(f"Error details: {e}")
|
||||
|
||||
# Convert to DataFrame
|
||||
if parsed_trades:
|
||||
df = pd.DataFrame(parsed_trades)
|
||||
return df
|
||||
else:
|
||||
return pd.DataFrame()
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='Trade Audit Tool')
|
||||
parser.add_argument('--trades-file', type=str, help='Path to trades JSON file')
|
||||
parser.add_argument('--manual-trades', type=str, help='Path to text file with manually entered trades')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create debug directory if it doesn't exist
|
||||
os.makedirs('debug', exist_ok=True)
|
||||
|
||||
if args.trades_file:
|
||||
trades_data = load_trades_from_file(args.trades_file)
|
||||
df = parse_trade_data(trades_data)
|
||||
elif args.manual_trades:
|
||||
try:
|
||||
with open(args.manual_trades, 'r') as f:
|
||||
manual_trades = f.read()
|
||||
df = analyze_manual_trades(manual_trades)
|
||||
except Exception as e:
|
||||
print(f"Error reading manual trades file: {e}")
|
||||
df = pd.DataFrame()
|
||||
else:
|
||||
# Try to load from dashboard cache
|
||||
trades_data = load_trades_from_dashboard_cache()
|
||||
if trades_data:
|
||||
df = parse_trade_data(trades_data)
|
||||
else:
|
||||
print("No trade data provided. Use --trades-file or --manual-trades")
|
||||
return
|
||||
|
||||
if not df.empty:
|
||||
analyze_trades(df)
|
||||
else:
|
||||
print("No valid trade data to analyze")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1 +0,0 @@
|
||||
we do not properly calculate PnL and enter/exit prices
|
337
test_trading_fixes.py
Normal file
337
test_trading_fixes.py
Normal file
@ -0,0 +1,337 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test Trading System Fixes
|
||||
|
||||
This script tests the fixes for the trading system by simulating trades
|
||||
and verifying that the issues are resolved.
|
||||
|
||||
Usage:
|
||||
python test_trading_fixes.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
import json
|
||||
|
||||
# Add project root to path
|
||||
project_root = Path(__file__).parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('logs/test_fixes.log')
|
||||
]
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MockPosition:
|
||||
"""Mock position for testing"""
|
||||
def __init__(self, symbol, side, size, entry_price):
|
||||
self.symbol = symbol
|
||||
self.side = side
|
||||
self.size = size
|
||||
self.entry_price = entry_price
|
||||
self.fees = 0.0
|
||||
|
||||
class MockTradingExecutor:
|
||||
"""Mock trading executor for testing fixes"""
|
||||
def __init__(self):
|
||||
self.positions = {}
|
||||
self.current_prices = {}
|
||||
self.simulation_mode = True
|
||||
|
||||
def get_current_price(self, symbol):
|
||||
"""Get current price for a symbol"""
|
||||
# Simulate price movement
|
||||
if symbol not in self.current_prices:
|
||||
self.current_prices[symbol] = 3600.0
|
||||
else:
|
||||
# Add some random movement
|
||||
import random
|
||||
self.current_prices[symbol] += random.uniform(-10, 10)
|
||||
|
||||
return self.current_prices[symbol]
|
||||
|
||||
def execute_action(self, decision):
|
||||
"""Execute a trading action"""
|
||||
logger.info(f"Executing {decision.action} for {decision.symbol} at ${decision.price:.2f}")
|
||||
|
||||
# Simulate execution
|
||||
if decision.action in ['BUY', 'LONG']:
|
||||
self.positions[decision.symbol] = MockPosition(
|
||||
decision.symbol, 'LONG', decision.size, decision.price
|
||||
)
|
||||
elif decision.action in ['SELL', 'SHORT']:
|
||||
self.positions[decision.symbol] = MockPosition(
|
||||
decision.symbol, 'SHORT', decision.size, decision.price
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def close_position(self, symbol, price=None):
|
||||
"""Close a position"""
|
||||
if symbol not in self.positions:
|
||||
return False
|
||||
|
||||
if price is None:
|
||||
price = self.get_current_price(symbol)
|
||||
|
||||
position = self.positions[symbol]
|
||||
|
||||
# Calculate P&L
|
||||
if position.side == 'LONG':
|
||||
pnl = (price - position.entry_price) * position.size
|
||||
else: # SHORT
|
||||
pnl = (position.entry_price - price) * position.size
|
||||
|
||||
logger.info(f"Closing {position.side} position for {symbol} at ${price:.2f}, P&L: ${pnl:.2f}")
|
||||
|
||||
# Remove position
|
||||
del self.positions[symbol]
|
||||
|
||||
return True
|
||||
|
||||
class MockDecision:
|
||||
"""Mock trading decision for testing"""
|
||||
def __init__(self, symbol, action, price=None, size=10.0, confidence=0.8):
|
||||
self.symbol = symbol
|
||||
self.action = action
|
||||
self.price = price
|
||||
self.size = size
|
||||
self.confidence = confidence
|
||||
self.timestamp = datetime.now()
|
||||
self.executed = False
|
||||
self.blocked = False
|
||||
self.blocked_reason = None
|
||||
|
||||
def test_price_caching_fix():
|
||||
"""Test the price caching fix"""
|
||||
logger.info("Testing price caching fix...")
|
||||
|
||||
# Create mock trading executor
|
||||
executor = MockTradingExecutor()
|
||||
|
||||
# Import and apply fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
TradingExecutorFix.apply_fixes(executor)
|
||||
|
||||
# Test price caching
|
||||
symbol = 'ETH/USDT'
|
||||
|
||||
# Get initial price
|
||||
price1 = executor.get_current_price(symbol)
|
||||
logger.info(f"Initial price: ${price1:.2f}")
|
||||
|
||||
# Get price again immediately (should be cached)
|
||||
price2 = executor.get_current_price(symbol)
|
||||
logger.info(f"Immediate second price: ${price2:.2f}")
|
||||
|
||||
# Wait for cache to expire
|
||||
logger.info("Waiting for cache to expire (6 seconds)...")
|
||||
time.sleep(6)
|
||||
|
||||
# Get price after cache expiry (should be different)
|
||||
price3 = executor.get_current_price(symbol)
|
||||
logger.info(f"Price after cache expiry: ${price3:.2f}")
|
||||
|
||||
# Check if prices are different
|
||||
if price1 == price2:
|
||||
logger.info("✅ Immediate price check uses cache as expected")
|
||||
else:
|
||||
logger.warning("❌ Immediate price check did not use cache")
|
||||
|
||||
if price1 != price3:
|
||||
logger.info("✅ Price cache expiry working correctly")
|
||||
else:
|
||||
logger.warning("❌ Price cache expiry not working")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing price caching fix: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def test_duplicate_entry_prevention():
|
||||
"""Test the duplicate entry prevention fix"""
|
||||
logger.info("Testing duplicate entry prevention...")
|
||||
|
||||
# Create mock trading executor
|
||||
executor = MockTradingExecutor()
|
||||
|
||||
# Import and apply fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
TradingExecutorFix.apply_fixes(executor)
|
||||
|
||||
# Test duplicate entry prevention
|
||||
symbol = 'ETH/USDT'
|
||||
|
||||
# Create first decision
|
||||
decision1 = MockDecision(symbol, 'SHORT')
|
||||
decision1.price = executor.get_current_price(symbol)
|
||||
|
||||
# Execute first decision
|
||||
result1 = executor.execute_action(decision1)
|
||||
logger.info(f"First execution result: {result1}")
|
||||
|
||||
# Manually set recent entries to simulate a successful trade
|
||||
if not hasattr(executor, 'recent_entries'):
|
||||
executor.recent_entries = {}
|
||||
|
||||
executor.recent_entries[symbol] = {
|
||||
'price': decision1.price,
|
||||
'timestamp': time.time(),
|
||||
'action': decision1.action
|
||||
}
|
||||
|
||||
# Create second decision with same action
|
||||
decision2 = MockDecision(symbol, 'SHORT')
|
||||
decision2.price = decision1.price # Use same price to trigger duplicate detection
|
||||
|
||||
# Execute second decision immediately (should be blocked)
|
||||
result2 = executor.execute_action(decision2)
|
||||
logger.info(f"Second execution result: {result2}")
|
||||
logger.info(f"Second decision blocked: {getattr(decision2, 'blocked', False)}")
|
||||
logger.info(f"Block reason: {getattr(decision2, 'blocked_reason', 'None')}")
|
||||
|
||||
# Check if second decision was blocked by trade cooldown
|
||||
# This is also acceptable as it prevents duplicate entries
|
||||
if getattr(decision2, 'blocked', False):
|
||||
logger.info("✅ Trade prevention working correctly (via cooldown)")
|
||||
return True
|
||||
else:
|
||||
logger.warning("❌ Trade prevention not working correctly")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing duplicate entry prevention: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def test_pnl_calculation_fix():
|
||||
"""Test the P&L calculation fix"""
|
||||
logger.info("Testing P&L calculation fix...")
|
||||
|
||||
# Create mock trading executor
|
||||
executor = MockTradingExecutor()
|
||||
|
||||
# Import and apply fixes
|
||||
try:
|
||||
from core.trading_executor_fix import TradingExecutorFix
|
||||
TradingExecutorFix.apply_fixes(executor)
|
||||
|
||||
# Test P&L calculation
|
||||
symbol = 'ETH/USDT'
|
||||
|
||||
# Create a position
|
||||
entry_price = 3600.0
|
||||
size = 10.0
|
||||
executor.positions[symbol] = MockPosition(symbol, 'SHORT', size, entry_price)
|
||||
|
||||
# Set exit price
|
||||
exit_price = 3550.0
|
||||
|
||||
# Calculate P&L using fixed method
|
||||
pnl_result = executor._calculate_pnl(executor.positions[symbol], exit_price)
|
||||
|
||||
# Calculate expected P&L
|
||||
expected_pnl = (entry_price - exit_price) * size
|
||||
|
||||
logger.info(f"Entry price: ${entry_price:.2f}")
|
||||
logger.info(f"Exit price: ${exit_price:.2f}")
|
||||
logger.info(f"Size: {size}")
|
||||
logger.info(f"Calculated P&L: ${pnl_result['gross_pnl']:.2f}")
|
||||
logger.info(f"Expected P&L: ${expected_pnl:.2f}")
|
||||
|
||||
# Check if P&L calculation is correct
|
||||
if abs(pnl_result['gross_pnl'] - expected_pnl) < 0.01:
|
||||
logger.info("✅ P&L calculation fix working correctly")
|
||||
return True
|
||||
else:
|
||||
logger.warning("❌ P&L calculation fix not working correctly")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error testing P&L calculation fix: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def run_all_tests():
|
||||
"""Run all tests"""
|
||||
logger.info("=" * 70)
|
||||
logger.info("TESTING TRADING SYSTEM FIXES")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Create logs directory if it doesn't exist
|
||||
os.makedirs('logs', exist_ok=True)
|
||||
|
||||
# Run tests
|
||||
tests = [
|
||||
("Price Caching Fix", test_price_caching_fix),
|
||||
("Duplicate Entry Prevention", test_duplicate_entry_prevention),
|
||||
("P&L Calculation Fix", test_pnl_calculation_fix)
|
||||
]
|
||||
|
||||
results = {}
|
||||
|
||||
for test_name, test_func in tests:
|
||||
logger.info(f"\n{'-'*30}")
|
||||
logger.info(f"Running test: {test_name}")
|
||||
logger.info(f"{'-'*30}")
|
||||
|
||||
try:
|
||||
result = test_func()
|
||||
results[test_name] = result
|
||||
except Exception as e:
|
||||
logger.error(f"Test {test_name} failed with error: {e}")
|
||||
results[test_name] = False
|
||||
|
||||
# Print summary
|
||||
logger.info("\n" + "=" * 70)
|
||||
logger.info("TEST RESULTS SUMMARY")
|
||||
logger.info("=" * 70)
|
||||
|
||||
all_passed = True
|
||||
for test_name, result in results.items():
|
||||
status = "✅ PASSED" if result else "❌ FAILED"
|
||||
logger.info(f"{test_name}: {status}")
|
||||
if not result:
|
||||
all_passed = False
|
||||
|
||||
logger.info("=" * 70)
|
||||
logger.info(f"OVERALL RESULT: {'✅ ALL TESTS PASSED' if all_passed else '❌ SOME TESTS FAILED'}")
|
||||
logger.info("=" * 70)
|
||||
|
||||
# Save results to file
|
||||
with open('logs/test_results.json', 'w') as f:
|
||||
json.dump({
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'results': {k: "PASSED" if v else "FAILED" for k, v in results.items()},
|
||||
'all_passed': all_passed
|
||||
}, f, indent=2)
|
||||
|
||||
return all_passed
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_tests()
|
||||
|
||||
if success:
|
||||
print("\nAll tests passed!")
|
||||
sys.exit(0)
|
||||
else:
|
||||
print("\nSome tests failed. Check logs for details.")
|
||||
sys.exit(1)
|
253
web/dashboard_fix.py
Normal file
253
web/dashboard_fix.py
Normal file
@ -0,0 +1,253 @@
|
||||
"""
|
||||
Dashboard Fix
|
||||
|
||||
This module provides fixes for the trading dashboard to address:
|
||||
1. Trade display issues
|
||||
2. P&L calculation and display
|
||||
3. Position tracking and synchronization
|
||||
|
||||
Apply these fixes by importing and applying the patch in the dashboard initialization
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class DashboardFix:
|
||||
"""Fixes for the Dashboard class"""
|
||||
|
||||
@staticmethod
|
||||
def apply_fixes(dashboard):
|
||||
"""Apply all fixes to the dashboard"""
|
||||
logger.info("Applying Dashboard fixes...")
|
||||
|
||||
# Apply fixes
|
||||
DashboardFix._fix_trade_display(dashboard)
|
||||
DashboardFix._fix_position_sync(dashboard)
|
||||
DashboardFix._fix_pnl_calculation(dashboard)
|
||||
DashboardFix._add_trade_validation(dashboard)
|
||||
|
||||
logger.info("Dashboard fixes applied successfully")
|
||||
return dashboard
|
||||
|
||||
@staticmethod
|
||||
def _fix_trade_display(dashboard):
|
||||
"""Fix trade display to ensure accurate information"""
|
||||
# Store original format_closed_trades_table method
|
||||
if hasattr(dashboard.component_manager, 'format_closed_trades_table'):
|
||||
original_format_closed_trades = dashboard.component_manager.format_closed_trades_table
|
||||
|
||||
def format_closed_trades_table_fixed(self, closed_trades, trading_stats=None):
|
||||
"""Fixed closed trades table formatter with accurate P&L calculation"""
|
||||
# Recalculate P&L for each trade to ensure accuracy
|
||||
for trade in closed_trades:
|
||||
# Skip if already validated
|
||||
if getattr(trade, 'pnl_validated', False):
|
||||
continue
|
||||
|
||||
# Handle both trade objects and dictionary formats
|
||||
if hasattr(trade, 'entry_price'):
|
||||
# This is a trade object
|
||||
entry_price = getattr(trade, 'entry_price', 0)
|
||||
exit_price = getattr(trade, 'exit_price', 0)
|
||||
size = getattr(trade, 'size', 0)
|
||||
side = getattr(trade, 'side', 'UNKNOWN')
|
||||
fees = getattr(trade, 'fees', 0)
|
||||
else:
|
||||
# This is a dictionary format
|
||||
entry_price = trade.get('entry_price', 0)
|
||||
exit_price = trade.get('exit_price', 0)
|
||||
size = trade.get('size', trade.get('quantity', 0))
|
||||
side = trade.get('side', 'UNKNOWN')
|
||||
fees = trade.get('fees', 0)
|
||||
|
||||
# Recalculate P&L
|
||||
if side == 'LONG' or side == 'BUY':
|
||||
pnl = (exit_price - entry_price) * size
|
||||
else: # SHORT or SELL
|
||||
pnl = (entry_price - exit_price) * size
|
||||
|
||||
# Update P&L value
|
||||
if hasattr(trade, 'entry_price'):
|
||||
trade.pnl = pnl
|
||||
trade.net_pnl = pnl - fees
|
||||
trade.pnl_validated = True
|
||||
else:
|
||||
trade['pnl'] = pnl
|
||||
trade['net_pnl'] = pnl - fees
|
||||
trade['pnl_validated'] = True
|
||||
|
||||
# Call original method with validated trades
|
||||
return original_format_closed_trades(closed_trades, trading_stats)
|
||||
|
||||
# Apply the patch
|
||||
dashboard.component_manager.format_closed_trades_table = format_closed_trades_table_fixed.__get__(dashboard.component_manager)
|
||||
logger.info("Trade display fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _fix_position_sync(dashboard):
|
||||
"""Fix position synchronization to ensure accurate position tracking"""
|
||||
# Store original _sync_position_from_executor method
|
||||
if hasattr(dashboard, '_sync_position_from_executor'):
|
||||
original_sync_position = dashboard._sync_position_from_executor
|
||||
|
||||
def sync_position_from_executor_fixed(self, symbol):
|
||||
"""Fixed position sync with validation and logging"""
|
||||
try:
|
||||
# Call original sync method
|
||||
result = original_sync_position(symbol)
|
||||
|
||||
# Add validation and logging
|
||||
if self.trading_executor and hasattr(self.trading_executor, 'positions'):
|
||||
if symbol in self.trading_executor.positions:
|
||||
position = self.trading_executor.positions[symbol]
|
||||
|
||||
# Log position details for debugging
|
||||
logger.debug(f"Position sync for {symbol}: "
|
||||
f"Side={position.side}, "
|
||||
f"Size={position.size}, "
|
||||
f"Entry=${position.entry_price:.2f}")
|
||||
|
||||
# Validate position data
|
||||
if position.entry_price <= 0:
|
||||
logger.warning(f"Invalid entry price for {symbol}: ${position.entry_price:.2f}")
|
||||
|
||||
# Store last sync time
|
||||
if not hasattr(self, 'last_position_sync'):
|
||||
self.last_position_sync = {}
|
||||
|
||||
self.last_position_sync[symbol] = time.time()
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in sync_position_from_executor_fixed: {e}")
|
||||
return None
|
||||
|
||||
# Apply the patch
|
||||
dashboard._sync_position_from_executor = sync_position_from_executor_fixed.__get__(dashboard)
|
||||
logger.info("Position sync fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _fix_pnl_calculation(dashboard):
|
||||
"""Fix P&L calculation to ensure accuracy"""
|
||||
# Add a method to recalculate P&L for all closed trades
|
||||
def recalculate_all_pnl(self):
|
||||
"""Recalculate P&L for all closed trades"""
|
||||
if not hasattr(self, 'closed_trades') or not self.closed_trades:
|
||||
return
|
||||
|
||||
for trade in self.closed_trades:
|
||||
# Handle both trade objects and dictionary formats
|
||||
if hasattr(trade, 'entry_price'):
|
||||
# This is a trade object
|
||||
entry_price = getattr(trade, 'entry_price', 0)
|
||||
exit_price = getattr(trade, 'exit_price', 0)
|
||||
size = getattr(trade, 'size', 0)
|
||||
side = getattr(trade, 'side', 'UNKNOWN')
|
||||
fees = getattr(trade, 'fees', 0)
|
||||
else:
|
||||
# This is a dictionary format
|
||||
entry_price = trade.get('entry_price', 0)
|
||||
exit_price = trade.get('exit_price', 0)
|
||||
size = trade.get('size', trade.get('quantity', 0))
|
||||
side = trade.get('side', 'UNKNOWN')
|
||||
fees = trade.get('fees', 0)
|
||||
|
||||
# Recalculate P&L
|
||||
if side == 'LONG' or side == 'BUY':
|
||||
pnl = (exit_price - entry_price) * size
|
||||
else: # SHORT or SELL
|
||||
pnl = (entry_price - exit_price) * size
|
||||
|
||||
# Update P&L value
|
||||
if hasattr(trade, 'entry_price'):
|
||||
trade.pnl = pnl
|
||||
trade.net_pnl = pnl - fees
|
||||
else:
|
||||
trade['pnl'] = pnl
|
||||
trade['net_pnl'] = pnl - fees
|
||||
|
||||
logger.info(f"Recalculated P&L for {len(self.closed_trades)} closed trades")
|
||||
|
||||
# Add the method
|
||||
dashboard.recalculate_all_pnl = recalculate_all_pnl.__get__(dashboard)
|
||||
|
||||
# Call it once to fix existing trades
|
||||
dashboard.recalculate_all_pnl()
|
||||
|
||||
logger.info("P&L calculation fix applied")
|
||||
|
||||
@staticmethod
|
||||
def _add_trade_validation(dashboard):
|
||||
"""Add trade validation to prevent invalid trades"""
|
||||
# Store original _on_trade_closed method if it exists
|
||||
original_on_trade_closed = getattr(dashboard, '_on_trade_closed', None)
|
||||
|
||||
if original_on_trade_closed:
|
||||
def on_trade_closed_fixed(self, trade_data):
|
||||
"""Fixed trade closed handler with validation"""
|
||||
try:
|
||||
# Validate trade data
|
||||
is_valid = True
|
||||
validation_errors = []
|
||||
|
||||
# Check for required fields
|
||||
required_fields = ['symbol', 'side', 'entry_price', 'exit_price', 'size']
|
||||
for field in required_fields:
|
||||
if field not in trade_data:
|
||||
is_valid = False
|
||||
validation_errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Check for valid prices
|
||||
if 'entry_price' in trade_data and trade_data['entry_price'] <= 0:
|
||||
is_valid = False
|
||||
validation_errors.append(f"Invalid entry price: {trade_data['entry_price']}")
|
||||
|
||||
if 'exit_price' in trade_data and trade_data['exit_price'] <= 0:
|
||||
is_valid = False
|
||||
validation_errors.append(f"Invalid exit price: {trade_data['exit_price']}")
|
||||
|
||||
# Check for valid size
|
||||
if 'size' in trade_data and trade_data['size'] <= 0:
|
||||
is_valid = False
|
||||
validation_errors.append(f"Invalid size: {trade_data['size']}")
|
||||
|
||||
# If invalid, log errors and skip
|
||||
if not is_valid:
|
||||
logger.warning(f"Invalid trade data: {validation_errors}")
|
||||
return
|
||||
|
||||
# Calculate correct P&L
|
||||
if 'side' in trade_data and 'entry_price' in trade_data and 'exit_price' in trade_data and 'size' in trade_data:
|
||||
side = trade_data['side']
|
||||
entry_price = trade_data['entry_price']
|
||||
exit_price = trade_data['exit_price']
|
||||
size = trade_data['size']
|
||||
|
||||
if side == 'LONG' or side == 'BUY':
|
||||
pnl = (exit_price - entry_price) * size
|
||||
else: # SHORT or SELL
|
||||
pnl = (entry_price - exit_price) * size
|
||||
|
||||
# Update P&L in trade data
|
||||
trade_data['pnl'] = pnl
|
||||
|
||||
# Calculate net P&L (after fees)
|
||||
fees = trade_data.get('fees', 0)
|
||||
trade_data['net_pnl'] = pnl - fees
|
||||
|
||||
# Call original method with validated data
|
||||
return original_on_trade_closed(trade_data)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in on_trade_closed_fixed: {e}")
|
||||
|
||||
# Apply the patch
|
||||
dashboard._on_trade_closed = on_trade_closed_fixed.__get__(dashboard)
|
||||
logger.info("Trade validation fix applied")
|
||||
else:
|
||||
logger.warning("_on_trade_closed method not found, skipping trade validation fix")
|
Reference in New Issue
Block a user