integrating new CNN model
This commit is contained in:
@ -11,11 +11,17 @@ This package contains the neural network models used in the trading system:
|
|||||||
PyTorch implementation only.
|
PyTorch implementation only.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from NN.models.cnn_model import EnhancedCNNModel as CNNModel
|
# Import core models
|
||||||
from NN.models.dqn_agent import DQNAgent
|
from NN.models.dqn_agent import DQNAgent, MassiveRLNetwork
|
||||||
from NN.models.cob_rl_model import MassiveRLNetwork, COBRLModelInterface
|
from NN.models.cob_rl_model import COBRLModelInterface
|
||||||
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
|
from NN.models.advanced_transformer_trading import AdvancedTradingTransformer, TradingTransformerConfig
|
||||||
|
from NN.models.standardized_cnn import StandardizedCNN # Use the unified CNN model
|
||||||
|
|
||||||
|
# Import model interfaces
|
||||||
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
|
from NN.models.model_interfaces import ModelInterface, CNNModelInterface, RLAgentInterface, ExtremaTrainerInterface
|
||||||
|
|
||||||
__all__ = ['CNNModel', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
|
# Export the unified StandardizedCNN as CNNModel for compatibility
|
||||||
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
|
CNNModel = StandardizedCNN
|
||||||
|
|
||||||
|
__all__ = ['CNNModel', 'StandardizedCNN', 'DQNAgent', 'MassiveRLNetwork', 'COBRLModelInterface', 'AdvancedTradingTransformer', 'TradingTransformerConfig',
|
||||||
|
'ModelInterface', 'CNNModelInterface', 'RLAgentInterface', 'ExtremaTrainerInterface']
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -371,6 +371,10 @@ class EnhancedCNN(nn.Module):
|
|||||||
nn.Linear(128, 4) # Low risk, medium risk, high risk, extreme risk
|
nn.Linear(128, 4) # Low risk, medium risk, high risk, extreme risk
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _memory_barrier(self, tensor: torch.Tensor) -> torch.Tensor:
|
||||||
|
"""Create a memory barrier to prevent in-place operation issues"""
|
||||||
|
return tensor.detach().clone().requires_grad_(tensor.requires_grad)
|
||||||
|
|
||||||
def _check_rebuild_network(self, features):
|
def _check_rebuild_network(self, features):
|
||||||
"""Check if network needs to be rebuilt for different feature dimensions"""
|
"""Check if network needs to be rebuilt for different feature dimensions"""
|
||||||
# Prevent rebuilding with zero or invalid dimensions
|
# Prevent rebuilding with zero or invalid dimensions
|
||||||
|
@ -40,7 +40,7 @@ from utils.training_integration import get_training_integration
|
|||||||
|
|
||||||
# Import training components
|
# Import training components
|
||||||
from NN.models.dqn_agent import DQNAgent
|
from NN.models.dqn_agent import DQNAgent
|
||||||
from NN.models.cnn_model import CNNModelTrainer, create_enhanced_cnn_model
|
from NN.models.standardized_cnn import StandardizedCNN
|
||||||
from core.extrema_trainer import ExtremaTrainer
|
from core.extrema_trainer import ExtremaTrainer
|
||||||
from core.negative_case_trainer import NegativeCaseTrainer
|
from core.negative_case_trainer import NegativeCaseTrainer
|
||||||
from core.data_provider import DataProvider
|
from core.data_provider import DataProvider
|
||||||
@ -100,18 +100,10 @@ class CheckpointIntegratedTrainingSystem:
|
|||||||
)
|
)
|
||||||
logger.info("✅ DQN Agent initialized with checkpoint management")
|
logger.info("✅ DQN Agent initialized with checkpoint management")
|
||||||
|
|
||||||
# Initialize CNN Model with checkpoint management
|
# Initialize StandardizedCNN Model with checkpoint management
|
||||||
logger.info("Initializing CNN Model with checkpoints...")
|
logger.info("Initializing StandardizedCNN Model with checkpoints...")
|
||||||
cnn_model, self.cnn_trainer = create_enhanced_cnn_model(
|
self.cnn_model = StandardizedCNN(model_name="integrated_cnn_model")
|
||||||
input_size=60,
|
logger.info("✅ StandardizedCNN Model initialized with checkpoint management")
|
||||||
feature_dim=50,
|
|
||||||
output_size=3
|
|
||||||
)
|
|
||||||
# Update trainer with checkpoint management
|
|
||||||
self.cnn_trainer.model_name = "integrated_cnn_model"
|
|
||||||
self.cnn_trainer.enable_checkpoints = True
|
|
||||||
self.cnn_trainer.training_integration = self.training_integration
|
|
||||||
logger.info("✅ CNN Model initialized with checkpoint management")
|
|
||||||
|
|
||||||
# Initialize ExtremaTrainer with checkpoint management
|
# Initialize ExtremaTrainer with checkpoint management
|
||||||
logger.info("Initializing ExtremaTrainer with checkpoints...")
|
logger.info("Initializing ExtremaTrainer with checkpoints...")
|
||||||
|
98
TRADING_FIXES_SUMMARY.md
Normal file
98
TRADING_FIXES_SUMMARY.md
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
# Trading System Fixes Summary
|
||||||
|
|
||||||
|
## Issues Identified
|
||||||
|
|
||||||
|
After analyzing the trading data, we identified several critical issues in the trading system:
|
||||||
|
|
||||||
|
1. **Duplicate Entry Prices**: The system was repeatedly entering trades at the same price ($3676.92 appeared in 9 out of 14 trades).
|
||||||
|
|
||||||
|
2. **P&L Calculation Issues**: There were major discrepancies between the reported P&L and the expected P&L calculated from entry/exit prices and position size.
|
||||||
|
|
||||||
|
3. **Trade Side Distribution**: All trades were SHORT positions, indicating a potential bias or configuration issue.
|
||||||
|
|
||||||
|
4. **Rapid Consecutive Trades**: Several trades were executed within very short time frames (as low as 10-12 seconds apart).
|
||||||
|
|
||||||
|
5. **Position Tracking Problems**: The system was not properly resetting position data between trades.
|
||||||
|
|
||||||
|
## Root Causes
|
||||||
|
|
||||||
|
1. **Price Caching**: The `current_prices` dictionary was not being properly updated between trades, leading to stale prices being used for trade entries.
|
||||||
|
|
||||||
|
2. **P&L Calculation Formula**: The P&L calculation was not correctly accounting for position side (LONG vs SHORT).
|
||||||
|
|
||||||
|
3. **Missing Trade Cooldown**: There was no mechanism to prevent rapid consecutive trades.
|
||||||
|
|
||||||
|
4. **Incomplete Position Cleanup**: When closing positions, the system was not fully cleaning up position data.
|
||||||
|
|
||||||
|
5. **Dashboard Display Issues**: The dashboard was displaying incorrect P&L values due to calculation errors.
|
||||||
|
|
||||||
|
## Implemented Fixes
|
||||||
|
|
||||||
|
### 1. Price Caching Fix
|
||||||
|
- Added a timestamp-based cache invalidation system
|
||||||
|
- Force price refresh if cache is older than 5 seconds
|
||||||
|
- Added logging for price updates
|
||||||
|
|
||||||
|
### 2. P&L Calculation Fix
|
||||||
|
- Implemented correct P&L formula based on position side
|
||||||
|
- For LONG positions: P&L = (exit_price - entry_price) * size
|
||||||
|
- For SHORT positions: P&L = (entry_price - exit_price) * size
|
||||||
|
- Added separate tracking for gross P&L, fees, and net P&L
|
||||||
|
|
||||||
|
### 3. Trade Cooldown System
|
||||||
|
- Added a 30-second cooldown between trades for the same symbol
|
||||||
|
- Prevents rapid consecutive entries that could lead to overtrading
|
||||||
|
- Added blocking mechanism with reason tracking
|
||||||
|
|
||||||
|
### 4. Duplicate Entry Prevention
|
||||||
|
- Added detection for entries at similar prices (within 0.1%)
|
||||||
|
- Blocks trades that are too similar to recent entries
|
||||||
|
- Added logging for blocked trades
|
||||||
|
|
||||||
|
### 5. Position Tracking Fix
|
||||||
|
- Ensured complete position cleanup after closing
|
||||||
|
- Added validation for position data
|
||||||
|
- Improved position synchronization between executor and dashboard
|
||||||
|
|
||||||
|
### 6. Dashboard Display Fix
|
||||||
|
- Fixed trade display to show accurate P&L values
|
||||||
|
- Added validation for trade data
|
||||||
|
- Improved error handling for invalid trades
|
||||||
|
|
||||||
|
## How to Apply the Fixes
|
||||||
|
|
||||||
|
1. Run the `apply_trading_fixes.py` script to prepare the fix files:
|
||||||
|
```
|
||||||
|
python apply_trading_fixes.py
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run the `apply_trading_fixes_to_main.py` script to apply the fixes to the main.py file:
|
||||||
|
```
|
||||||
|
python apply_trading_fixes_to_main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the trading system with the fixes applied:
|
||||||
|
```
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
The fixes have been tested using the `test_trading_fixes.py` script, which verifies:
|
||||||
|
- Price caching fix
|
||||||
|
- Duplicate entry prevention
|
||||||
|
- P&L calculation accuracy
|
||||||
|
|
||||||
|
All tests pass, indicating that the fixes are working correctly.
|
||||||
|
|
||||||
|
## Additional Recommendations
|
||||||
|
|
||||||
|
1. **Implement Bidirectional Trading**: The system currently shows a bias toward SHORT positions. Consider implementing balanced logic for both LONG and SHORT positions.
|
||||||
|
|
||||||
|
2. **Add Trade Validation**: Implement additional validation for trade parameters (price, size, etc.) before execution.
|
||||||
|
|
||||||
|
3. **Enhance Logging**: Add more detailed logging for trade execution and P&L calculation to help diagnose future issues.
|
||||||
|
|
||||||
|
4. **Implement Circuit Breakers**: Add circuit breakers to halt trading if unusual patterns are detected (e.g., too many losing trades in a row).
|
||||||
|
|
||||||
|
5. **Regular Audit**: Implement a regular audit process to check for trading anomalies and ensure P&L calculations are accurate.
|
2
_dev/problems.md
Normal file
2
_dev/problems.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
we do not properly calculate PnL and enter/exit prices
|
||||||
|
transformer model always shows as FRESH - is our
|
193
apply_trading_fixes.py
Normal file
193
apply_trading_fixes.py
Normal file
@ -0,0 +1,193 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Apply Trading System Fixes
|
||||||
|
|
||||||
|
This script applies fixes to the trading system to address:
|
||||||
|
1. Duplicate entry prices
|
||||||
|
2. P&L calculation issues
|
||||||
|
3. Position tracking problems
|
||||||
|
4. Trade display issues
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python apply_trading_fixes.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler('logs/trading_fixes.log')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def apply_fixes():
|
||||||
|
"""Apply all fixes to the trading system"""
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("APPLYING TRADING SYSTEM FIXES")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
# Import fixes
|
||||||
|
try:
|
||||||
|
from core.trading_executor_fix import TradingExecutorFix
|
||||||
|
from web.dashboard_fix import DashboardFix
|
||||||
|
|
||||||
|
logger.info("Fix modules imported successfully")
|
||||||
|
except ImportError as e:
|
||||||
|
logger.error(f"Error importing fix modules: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Apply fixes to trading executor
|
||||||
|
try:
|
||||||
|
# Import trading executor
|
||||||
|
from core.trading_executor import TradingExecutor
|
||||||
|
|
||||||
|
# Create a test instance to apply fixes
|
||||||
|
test_executor = TradingExecutor()
|
||||||
|
|
||||||
|
# Apply fixes
|
||||||
|
TradingExecutorFix.apply_fixes(test_executor)
|
||||||
|
|
||||||
|
logger.info("Trading executor fixes applied successfully to test instance")
|
||||||
|
|
||||||
|
# Verify fixes
|
||||||
|
if hasattr(test_executor, 'price_cache_timestamp'):
|
||||||
|
logger.info("✅ Price caching fix verified")
|
||||||
|
else:
|
||||||
|
logger.warning("❌ Price caching fix not verified")
|
||||||
|
|
||||||
|
if hasattr(test_executor, 'trade_cooldown_seconds'):
|
||||||
|
logger.info("✅ Trade cooldown fix verified")
|
||||||
|
else:
|
||||||
|
logger.warning("❌ Trade cooldown fix not verified")
|
||||||
|
|
||||||
|
if hasattr(test_executor, '_check_trade_cooldown'):
|
||||||
|
logger.info("✅ Trade cooldown check method verified")
|
||||||
|
else:
|
||||||
|
logger.warning("❌ Trade cooldown check method not verified")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error applying trading executor fixes: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
|
# Create patch for main.py
|
||||||
|
try:
|
||||||
|
main_patch = """
|
||||||
|
# Apply trading system fixes
|
||||||
|
try:
|
||||||
|
from core.trading_executor_fix import TradingExecutorFix
|
||||||
|
from web.dashboard_fix import DashboardFix
|
||||||
|
|
||||||
|
# Apply fixes to trading executor
|
||||||
|
if trading_executor:
|
||||||
|
TradingExecutorFix.apply_fixes(trading_executor)
|
||||||
|
logger.info("✅ Trading executor fixes applied")
|
||||||
|
|
||||||
|
# Apply fixes to dashboard
|
||||||
|
if 'dashboard' in locals() and dashboard:
|
||||||
|
DashboardFix.apply_fixes(dashboard)
|
||||||
|
logger.info("✅ Dashboard fixes applied")
|
||||||
|
|
||||||
|
logger.info("Trading system fixes applied successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error applying trading system fixes: {e}")
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Write patch instructions
|
||||||
|
with open('patch_instructions.txt', 'w') as f:
|
||||||
|
f.write("""
|
||||||
|
TRADING SYSTEM FIX INSTRUCTIONS
|
||||||
|
==============================
|
||||||
|
|
||||||
|
To apply the fixes to your trading system, follow these steps:
|
||||||
|
|
||||||
|
1. Add the following code to main.py just before the dashboard.run_server() call:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Apply trading system fixes
|
||||||
|
try:
|
||||||
|
from core.trading_executor_fix import TradingExecutorFix
|
||||||
|
from web.dashboard_fix import DashboardFix
|
||||||
|
|
||||||
|
# Apply fixes to trading executor
|
||||||
|
if trading_executor:
|
||||||
|
TradingExecutorFix.apply_fixes(trading_executor)
|
||||||
|
logger.info("✅ Trading executor fixes applied")
|
||||||
|
|
||||||
|
# Apply fixes to dashboard
|
||||||
|
if 'dashboard' in locals() and dashboard:
|
||||||
|
DashboardFix.apply_fixes(dashboard)
|
||||||
|
logger.info("✅ Dashboard fixes applied")
|
||||||
|
|
||||||
|
logger.info("Trading system fixes applied successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error applying trading system fixes: {e}")
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Add the following code to web/clean_dashboard.py in the __init__ method, just before the run_server method:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Apply dashboard fixes if available
|
||||||
|
try:
|
||||||
|
from web.dashboard_fix import DashboardFix
|
||||||
|
DashboardFix.apply_fixes(self)
|
||||||
|
logger.info("✅ Dashboard fixes applied during initialization")
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("Dashboard fixes not available")
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run the system with the fixes applied:
|
||||||
|
|
||||||
|
```
|
||||||
|
python main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Monitor the logs for any issues with the fixes.
|
||||||
|
|
||||||
|
These fixes address:
|
||||||
|
- Duplicate entry prices
|
||||||
|
- P&L calculation issues
|
||||||
|
- Position tracking problems
|
||||||
|
- Trade display issues
|
||||||
|
- Rapid consecutive trades
|
||||||
|
""")
|
||||||
|
|
||||||
|
logger.info("Patch instructions written to patch_instructions.txt")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating patch: {e}")
|
||||||
|
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("TRADING SYSTEM FIXES READY TO APPLY")
|
||||||
|
logger.info("See patch_instructions.txt for instructions")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# Create logs directory if it doesn't exist
|
||||||
|
os.makedirs('logs', exist_ok=True)
|
||||||
|
|
||||||
|
# Apply fixes
|
||||||
|
success = apply_fixes()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print("\nTrading system fixes ready to apply!")
|
||||||
|
print("See patch_instructions.txt for instructions")
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
print("\nError preparing trading system fixes")
|
||||||
|
sys.exit(1)
|
218
apply_trading_fixes_to_main.py
Normal file
218
apply_trading_fixes_to_main.py
Normal file
@ -0,0 +1,218 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Apply Trading System Fixes to Main.py
|
||||||
|
|
||||||
|
This script applies the trading system fixes directly to main.py
|
||||||
|
to address the issues with duplicate entry prices and P&L calculation.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python apply_trading_fixes_to_main.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler('logs/apply_fixes.log')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
def backup_file(file_path):
|
||||||
|
"""Create a backup of a file"""
|
||||||
|
try:
|
||||||
|
backup_path = f"{file_path}.backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
||||||
|
shutil.copy2(file_path, backup_path)
|
||||||
|
logger.info(f"Created backup: {backup_path}")
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error creating backup of {file_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def apply_fixes_to_main():
|
||||||
|
"""Apply fixes to main.py"""
|
||||||
|
main_py_path = "main.py"
|
||||||
|
|
||||||
|
if not os.path.exists(main_py_path):
|
||||||
|
logger.error(f"File {main_py_path} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
if not backup_file(main_py_path):
|
||||||
|
logger.error("Failed to create backup, aborting")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read main.py
|
||||||
|
with open(main_py_path, 'r') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# Find the position to insert the fixes
|
||||||
|
# Look for the line before dashboard.run_server()
|
||||||
|
run_server_pattern = r"dashboard\.run_server\("
|
||||||
|
match = re.search(run_server_pattern, content)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
logger.error("Could not find dashboard.run_server() call in main.py")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Find the position to insert the fixes (before the run_server call)
|
||||||
|
insert_pos = content.rfind("\n", 0, match.start())
|
||||||
|
|
||||||
|
if insert_pos == -1:
|
||||||
|
logger.error("Could not find insertion point in main.py")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Prepare the fixes to insert
|
||||||
|
fixes_code = """
|
||||||
|
# Apply trading system fixes
|
||||||
|
try:
|
||||||
|
from core.trading_executor_fix import TradingExecutorFix
|
||||||
|
from web.dashboard_fix import DashboardFix
|
||||||
|
|
||||||
|
# Apply fixes to trading executor
|
||||||
|
if trading_executor:
|
||||||
|
TradingExecutorFix.apply_fixes(trading_executor)
|
||||||
|
logger.info("✅ Trading executor fixes applied")
|
||||||
|
|
||||||
|
# Apply fixes to dashboard
|
||||||
|
if 'dashboard' in locals() and dashboard:
|
||||||
|
DashboardFix.apply_fixes(dashboard)
|
||||||
|
logger.info("✅ Dashboard fixes applied")
|
||||||
|
|
||||||
|
logger.info("Trading system fixes applied successfully")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Error applying trading system fixes: {e}")
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Insert the fixes
|
||||||
|
new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
|
||||||
|
|
||||||
|
# Write the modified content back to main.py
|
||||||
|
with open(main_py_path, 'w') as f:
|
||||||
|
f.write(new_content)
|
||||||
|
|
||||||
|
logger.info(f"Successfully applied fixes to {main_py_path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error applying fixes to {main_py_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def apply_fixes_to_dashboard():
|
||||||
|
"""Apply fixes to web/clean_dashboard.py"""
|
||||||
|
dashboard_py_path = "web/clean_dashboard.py"
|
||||||
|
|
||||||
|
if not os.path.exists(dashboard_py_path):
|
||||||
|
logger.error(f"File {dashboard_py_path} not found")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Create backup
|
||||||
|
if not backup_file(dashboard_py_path):
|
||||||
|
logger.error("Failed to create backup, aborting")
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Read dashboard.py
|
||||||
|
with open(dashboard_py_path, 'r') as f:
|
||||||
|
content = f.read()
|
||||||
|
|
||||||
|
# Find the position to insert the fixes
|
||||||
|
# Look for the __init__ method
|
||||||
|
init_pattern = r"def __init__\(self,"
|
||||||
|
match = re.search(init_pattern, content)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
logger.error("Could not find __init__ method in dashboard.py")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Find the end of the __init__ method
|
||||||
|
init_end_pattern = r"logger\.debug\(.*\)"
|
||||||
|
init_end_matches = list(re.finditer(init_end_pattern, content[match.end():]))
|
||||||
|
|
||||||
|
if not init_end_matches:
|
||||||
|
logger.error("Could not find end of __init__ method in dashboard.py")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Get the last logger.debug line in the __init__ method
|
||||||
|
last_debug_match = init_end_matches[-1]
|
||||||
|
insert_pos = match.end() + last_debug_match.end()
|
||||||
|
|
||||||
|
# Prepare the fixes to insert
|
||||||
|
fixes_code = """
|
||||||
|
|
||||||
|
# Apply dashboard fixes if available
|
||||||
|
try:
|
||||||
|
from web.dashboard_fix import DashboardFix
|
||||||
|
DashboardFix.apply_fixes(self)
|
||||||
|
logger.info("✅ Dashboard fixes applied during initialization")
|
||||||
|
except ImportError:
|
||||||
|
logger.warning("Dashboard fixes not available")
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Insert the fixes
|
||||||
|
new_content = content[:insert_pos] + fixes_code + content[insert_pos:]
|
||||||
|
|
||||||
|
# Write the modified content back to dashboard.py
|
||||||
|
with open(dashboard_py_path, 'w') as f:
|
||||||
|
f.write(new_content)
|
||||||
|
|
||||||
|
logger.info(f"Successfully applied fixes to {dashboard_py_path}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error applying fixes to {dashboard_py_path}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""Main entry point"""
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("APPLYING TRADING SYSTEM FIXES TO MAIN.PY")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
# Create logs directory if it doesn't exist
|
||||||
|
os.makedirs('logs', exist_ok=True)
|
||||||
|
|
||||||
|
# Apply fixes to main.py
|
||||||
|
main_success = apply_fixes_to_main()
|
||||||
|
|
||||||
|
# Apply fixes to dashboard.py
|
||||||
|
dashboard_success = apply_fixes_to_dashboard()
|
||||||
|
|
||||||
|
if main_success and dashboard_success:
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("TRADING SYSTEM FIXES APPLIED SUCCESSFULLY")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("The following issues have been fixed:")
|
||||||
|
logger.info("1. Duplicate entry prices")
|
||||||
|
logger.info("2. P&L calculation issues")
|
||||||
|
logger.info("3. Position tracking problems")
|
||||||
|
logger.info("4. Trade display issues")
|
||||||
|
logger.info("5. Rapid consecutive trades")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("You can now run the trading system with the fixes applied:")
|
||||||
|
logger.info("python main.py")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
logger.error("=" * 70)
|
||||||
|
logger.error("FAILED TO APPLY SOME FIXES")
|
||||||
|
logger.error("=" * 70)
|
||||||
|
logger.error("Please check the logs for details")
|
||||||
|
logger.error("=" * 70)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
@ -289,11 +289,9 @@ class TradingOrchestrator:
|
|||||||
|
|
||||||
# Initialize CNN Model
|
# Initialize CNN Model
|
||||||
try:
|
try:
|
||||||
from NN.models.enhanced_cnn import EnhancedCNN
|
from NN.models.standardized_cnn import StandardizedCNN
|
||||||
|
|
||||||
cnn_input_shape = self.config.cnn.get('input_shape', 100)
|
self.cnn_model = StandardizedCNN()
|
||||||
cnn_n_actions = self.config.cnn.get('n_actions', 3)
|
|
||||||
self.cnn_model = EnhancedCNN(input_shape=cnn_input_shape, n_actions=cnn_n_actions)
|
|
||||||
self.cnn_model.to(self.device) # Move CNN model to the determined device
|
self.cnn_model.to(self.device) # Move CNN model to the determined device
|
||||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||||
|
|
||||||
@ -325,8 +323,8 @@ class TradingOrchestrator:
|
|||||||
logger.info("Enhanced CNN model initialized")
|
logger.info("Enhanced CNN model initialized")
|
||||||
except ImportError:
|
except ImportError:
|
||||||
try:
|
try:
|
||||||
from NN.models.cnn_model import CNNModel
|
from NN.models.standardized_cnn import StandardizedCNN
|
||||||
self.cnn_model = CNNModel()
|
self.cnn_model = StandardizedCNN()
|
||||||
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
|
self.cnn_model.to(self.device) # Move basic CNN model to the determined device
|
||||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
|
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for basic CNN
|
||||||
|
|
||||||
|
261
core/trading_executor_fix.py
Normal file
261
core/trading_executor_fix.py
Normal file
@ -0,0 +1,261 @@
|
|||||||
|
"""
|
||||||
|
Trading Executor Fix
|
||||||
|
|
||||||
|
This module provides fixes for the trading executor to address:
|
||||||
|
1. Duplicate entry prices
|
||||||
|
2. P&L calculation issues
|
||||||
|
3. Position tracking problems
|
||||||
|
|
||||||
|
Apply these fixes by importing and applying the patch in main.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class TradingExecutorFix:
|
||||||
|
"""Fixes for the TradingExecutor class"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_fixes(trading_executor):
|
||||||
|
"""Apply all fixes to the trading executor"""
|
||||||
|
logger.info("Applying TradingExecutor fixes...")
|
||||||
|
|
||||||
|
# Store original methods for patching
|
||||||
|
original_execute_action = trading_executor.execute_action
|
||||||
|
original_calculate_pnl = getattr(trading_executor, '_calculate_pnl', None)
|
||||||
|
|
||||||
|
# Apply fixes
|
||||||
|
TradingExecutorFix._fix_price_caching(trading_executor)
|
||||||
|
TradingExecutorFix._fix_pnl_calculation(trading_executor, original_calculate_pnl)
|
||||||
|
TradingExecutorFix._fix_execute_action(trading_executor, original_execute_action)
|
||||||
|
TradingExecutorFix._add_trade_cooldown(trading_executor)
|
||||||
|
TradingExecutorFix._fix_position_tracking(trading_executor)
|
||||||
|
|
||||||
|
logger.info("TradingExecutor fixes applied successfully")
|
||||||
|
return trading_executor
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_price_caching(trading_executor):
|
||||||
|
"""Fix price caching to prevent duplicate entry prices"""
|
||||||
|
# Add a price cache timestamp to track when prices were last updated
|
||||||
|
trading_executor.price_cache_timestamp = {}
|
||||||
|
|
||||||
|
# Store original get_current_price method
|
||||||
|
original_get_current_price = trading_executor.get_current_price
|
||||||
|
|
||||||
|
def get_current_price_fixed(self, symbol):
|
||||||
|
"""Fixed get_current_price method with cache invalidation"""
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
# Force price refresh if cache is older than 5 seconds
|
||||||
|
if symbol in self.price_cache_timestamp:
|
||||||
|
cache_age = now - self.price_cache_timestamp.get(symbol, 0)
|
||||||
|
if cache_age > 5: # 5 seconds max cache age
|
||||||
|
# Clear price cache for this symbol
|
||||||
|
if hasattr(self, 'current_prices') and symbol in self.current_prices:
|
||||||
|
del self.current_prices[symbol]
|
||||||
|
logger.debug(f"Price cache for {symbol} invalidated (age: {cache_age:.1f}s)")
|
||||||
|
|
||||||
|
# Call original method to get fresh price
|
||||||
|
price = original_get_current_price(symbol)
|
||||||
|
|
||||||
|
# Update cache timestamp
|
||||||
|
self.price_cache_timestamp[symbol] = now
|
||||||
|
|
||||||
|
return price
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
trading_executor.get_current_price = get_current_price_fixed.__get__(trading_executor)
|
||||||
|
logger.info("Price caching fix applied")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_pnl_calculation(trading_executor, original_calculate_pnl):
|
||||||
|
"""Fix P&L calculation to ensure accuracy"""
|
||||||
|
def calculate_pnl_fixed(self, position, current_price=None):
|
||||||
|
"""Fixed P&L calculation with proper handling of position side"""
|
||||||
|
try:
|
||||||
|
# Get position details
|
||||||
|
entry_price = position.entry_price
|
||||||
|
size = position.size
|
||||||
|
side = position.side
|
||||||
|
|
||||||
|
# Use provided price or get current price
|
||||||
|
if current_price is None:
|
||||||
|
current_price = self.get_current_price(position.symbol)
|
||||||
|
|
||||||
|
# Calculate P&L based on position side
|
||||||
|
if side == 'LONG':
|
||||||
|
pnl = (current_price - entry_price) * size
|
||||||
|
else: # SHORT
|
||||||
|
pnl = (entry_price - current_price) * size
|
||||||
|
|
||||||
|
# Calculate fees (if available)
|
||||||
|
fees = getattr(position, 'fees', 0.0)
|
||||||
|
|
||||||
|
# Return both gross and net P&L
|
||||||
|
return {
|
||||||
|
'gross_pnl': pnl,
|
||||||
|
'fees': fees,
|
||||||
|
'net_pnl': pnl - fees
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error calculating P&L: {e}")
|
||||||
|
return {'gross_pnl': 0.0, 'fees': 0.0, 'net_pnl': 0.0}
|
||||||
|
|
||||||
|
# Apply the patch if original method exists
|
||||||
|
if original_calculate_pnl:
|
||||||
|
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||||
|
logger.info("P&L calculation fix applied")
|
||||||
|
else:
|
||||||
|
# Add the method if it doesn't exist
|
||||||
|
trading_executor._calculate_pnl = calculate_pnl_fixed.__get__(trading_executor)
|
||||||
|
logger.info("P&L calculation method added")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_execute_action(trading_executor, original_execute_action):
|
||||||
|
"""Fix execute_action to prevent duplicate entries and ensure proper price updates"""
|
||||||
|
def execute_action_fixed(self, decision):
|
||||||
|
"""Fixed execute_action with duplicate entry prevention"""
|
||||||
|
try:
|
||||||
|
symbol = decision.symbol
|
||||||
|
action = decision.action
|
||||||
|
|
||||||
|
# Check for duplicate entry (same price as recent entry)
|
||||||
|
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||||
|
recent_entry = self.recent_entries[symbol]
|
||||||
|
current_price = self.get_current_price(symbol)
|
||||||
|
|
||||||
|
# If price is within 0.1% of recent entry, consider it a duplicate
|
||||||
|
price_diff_pct = abs(current_price - recent_entry['price']) / recent_entry['price'] * 100
|
||||||
|
time_diff = time.time() - recent_entry['timestamp']
|
||||||
|
|
||||||
|
if price_diff_pct < 0.1 and time_diff < 60: # Within 0.1% and 60 seconds
|
||||||
|
logger.warning(f"Preventing duplicate entry for {symbol} at ${current_price:.2f} "
|
||||||
|
f"(recent entry: ${recent_entry['price']:.2f}, {time_diff:.1f}s ago)")
|
||||||
|
|
||||||
|
# Mark decision as blocked
|
||||||
|
decision.blocked = True
|
||||||
|
decision.blocked_reason = "Duplicate entry prevention"
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check trade cooldown
|
||||||
|
if hasattr(self, '_check_trade_cooldown'):
|
||||||
|
if not self._check_trade_cooldown(symbol, action):
|
||||||
|
# Mark decision as blocked
|
||||||
|
decision.blocked = True
|
||||||
|
decision.blocked_reason = "Trade cooldown active"
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Force price refresh before execution
|
||||||
|
fresh_price = self.get_current_price(symbol)
|
||||||
|
logger.info(f"Using fresh price for {symbol}: ${fresh_price:.2f}")
|
||||||
|
|
||||||
|
# Update decision price with fresh price
|
||||||
|
decision.price = fresh_price
|
||||||
|
|
||||||
|
# Call original execute_action
|
||||||
|
result = original_execute_action(decision)
|
||||||
|
|
||||||
|
# If execution was successful, record the entry
|
||||||
|
if result and not getattr(decision, 'blocked', False):
|
||||||
|
if not hasattr(self, 'recent_entries'):
|
||||||
|
self.recent_entries = {}
|
||||||
|
|
||||||
|
self.recent_entries[symbol] = {
|
||||||
|
'price': fresh_price,
|
||||||
|
'timestamp': time.time(),
|
||||||
|
'action': action
|
||||||
|
}
|
||||||
|
|
||||||
|
# Record last trade time for cooldown
|
||||||
|
if not hasattr(self, 'last_trade_time'):
|
||||||
|
self.last_trade_time = {}
|
||||||
|
|
||||||
|
self.last_trade_time[symbol] = time.time()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in execute_action_fixed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
trading_executor.execute_action = execute_action_fixed.__get__(trading_executor)
|
||||||
|
|
||||||
|
# Initialize recent entries dict if it doesn't exist
|
||||||
|
if not hasattr(trading_executor, 'recent_entries'):
|
||||||
|
trading_executor.recent_entries = {}
|
||||||
|
|
||||||
|
logger.info("Execute action fix applied")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_trade_cooldown(trading_executor):
|
||||||
|
"""Add trade cooldown to prevent rapid consecutive trades"""
|
||||||
|
# Add cooldown settings
|
||||||
|
trading_executor.trade_cooldown_seconds = 30 # 30 seconds between trades
|
||||||
|
|
||||||
|
if not hasattr(trading_executor, 'last_trade_time'):
|
||||||
|
trading_executor.last_trade_time = {}
|
||||||
|
|
||||||
|
def check_trade_cooldown(self, symbol, action):
|
||||||
|
"""Check if trade cooldown is active for a symbol"""
|
||||||
|
if not hasattr(self, 'last_trade_time'):
|
||||||
|
self.last_trade_time = {}
|
||||||
|
return True
|
||||||
|
|
||||||
|
if symbol not in self.last_trade_time:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Get time since last trade
|
||||||
|
time_since_last = time.time() - self.last_trade_time[symbol]
|
||||||
|
|
||||||
|
# Check if cooldown is still active
|
||||||
|
if time_since_last < self.trade_cooldown_seconds:
|
||||||
|
logger.warning(f"Trade cooldown active for {symbol}: {time_since_last:.1f}s elapsed, "
|
||||||
|
f"need {self.trade_cooldown_seconds}s")
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Add the method
|
||||||
|
trading_executor._check_trade_cooldown = check_trade_cooldown.__get__(trading_executor)
|
||||||
|
logger.info("Trade cooldown feature added")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_position_tracking(trading_executor):
|
||||||
|
"""Fix position tracking to ensure proper reset between trades"""
|
||||||
|
# Store original close_position method
|
||||||
|
original_close_position = getattr(trading_executor, 'close_position', None)
|
||||||
|
|
||||||
|
if original_close_position:
|
||||||
|
def close_position_fixed(self, symbol, price=None):
|
||||||
|
"""Fixed close_position with proper position cleanup"""
|
||||||
|
try:
|
||||||
|
# Call original close_position
|
||||||
|
result = original_close_position(symbol, price)
|
||||||
|
|
||||||
|
# Ensure position is fully cleaned up
|
||||||
|
if symbol in self.positions:
|
||||||
|
del self.positions[symbol]
|
||||||
|
|
||||||
|
# Clear recent entry for this symbol
|
||||||
|
if hasattr(self, 'recent_entries') and symbol in self.recent_entries:
|
||||||
|
del self.recent_entries[symbol]
|
||||||
|
|
||||||
|
logger.info(f"Position for {symbol} fully cleaned up after close")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in close_position_fixed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
trading_executor.close_position = close_position_fixed.__get__(trading_executor)
|
||||||
|
logger.info("Position tracking fix applied")
|
||||||
|
else:
|
||||||
|
logger.warning("close_position method not found, skipping position tracking fix")
|
22
debug/manual_trades.txt
Normal file
22
debug/manual_trades.txt
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from last session
|
||||||
|
Recent Closed Trades
|
||||||
|
Trading Performance
|
||||||
|
Win Rate: 64.3% (9W/5L/0B)
|
||||||
|
Avg Win: $5.79
|
||||||
|
Avg Loss: $1.86
|
||||||
|
Total Fees: $0.00
|
||||||
|
Time Side Size Entry Exit Hold (s) P&L Fees
|
||||||
|
14:40:24 SHORT $14.00 $3656.53 $3672.06 203 $-2.99 $0.008
|
||||||
|
14:44:23 SHORT $14.64 $3656.53 $3669.76 289 $-2.67 $0.009
|
||||||
|
14:50:29 SHORT $8.96 $3656.53 $3670.09 271 $-1.67 $0.005
|
||||||
|
14:55:06 SHORT $7.17 $3656.53 $3669.79 705 $-1.31 $0.004
|
||||||
|
15:12:58 SHORT $7.49 $3676.92 $3675.01 1125 $0.19 $0.004
|
||||||
|
15:37:20 SHORT $5.97 $3676.92 $3665.79 213 $0.90 $0.004
|
||||||
|
15:41:04 SHORT $18.12 $3676.92 $3652.71 192 $5.94 $0.011
|
||||||
|
15:44:42 SHORT $18.16 $3676.92 $3645.10 1040 $7.83 $0.011
|
||||||
|
16:02:26 SHORT $14.00 $3676.92 $3634.75 207 $8.01 $0.008
|
||||||
|
16:06:04 SHORT $14.00 $3676.92 $3636.67 70 $7.65 $0.008
|
||||||
|
16:07:43 SHORT $14.00 $3676.92 $3636.57 12 $7.67 $0.008
|
||||||
|
16:08:16 SHORT $14.00 $3676.92 $3644.75 280 $6.11 $0.008
|
||||||
|
16:13:16 SHORT $18.08 $3676.92 $3645.44 10 $7.72 $0.011
|
||||||
|
16:13:37 SHORT $17.88 $3647.54 $3650.26 90 $-0.69 $0.011
|
344
debug/trade_audit.py
Normal file
344
debug/trade_audit.py
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Trade Audit Tool
|
||||||
|
|
||||||
|
This tool analyzes trade data to identify potential issues with:
|
||||||
|
- Duplicate entry prices
|
||||||
|
- Rapid consecutive trades
|
||||||
|
- P&L calculation accuracy
|
||||||
|
- Position tracking problems
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python debug/trade_audit.py [--trades-file path/to/trades.json]
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
def parse_trade_time(time_str):
|
||||||
|
"""Parse trade time string to datetime object"""
|
||||||
|
try:
|
||||||
|
# Try HH:MM:SS format
|
||||||
|
return datetime.strptime(time_str, "%H:%M:%S")
|
||||||
|
except ValueError:
|
||||||
|
try:
|
||||||
|
# Try full datetime format
|
||||||
|
return datetime.strptime(time_str, "%Y-%m-%d %H:%M:%S")
|
||||||
|
except ValueError:
|
||||||
|
# Return as is if parsing fails
|
||||||
|
return time_str
|
||||||
|
|
||||||
|
def load_trades_from_file(file_path):
|
||||||
|
"""Load trades from JSON file"""
|
||||||
|
try:
|
||||||
|
with open(file_path, 'r') as f:
|
||||||
|
return json.load(f)
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"Error: File {file_path} not found")
|
||||||
|
return []
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"Error: File {file_path} is not valid JSON")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def load_trades_from_dashboard_cache():
|
||||||
|
"""Load trades from dashboard cache file if available"""
|
||||||
|
cache_paths = [
|
||||||
|
"cache/dashboard_trades.json",
|
||||||
|
"cache/closed_trades.json",
|
||||||
|
"data/trades_history.json"
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in cache_paths:
|
||||||
|
if os.path.exists(path):
|
||||||
|
print(f"Loading trades from cache: {path}")
|
||||||
|
return load_trades_from_file(path)
|
||||||
|
|
||||||
|
print("No trade cache files found")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def parse_trade_data(trades_data):
|
||||||
|
"""Parse trade data into a pandas DataFrame for analysis"""
|
||||||
|
parsed_trades = []
|
||||||
|
|
||||||
|
for trade in trades_data:
|
||||||
|
# Handle different trade data formats
|
||||||
|
parsed_trade = {}
|
||||||
|
|
||||||
|
# Time field might be named entry_time or time
|
||||||
|
if 'entry_time' in trade:
|
||||||
|
parsed_trade['time'] = parse_trade_time(trade['entry_time'])
|
||||||
|
elif 'time' in trade:
|
||||||
|
parsed_trade['time'] = parse_trade_time(trade['time'])
|
||||||
|
else:
|
||||||
|
parsed_trade['time'] = None
|
||||||
|
|
||||||
|
# Side might be named side or action
|
||||||
|
parsed_trade['side'] = trade.get('side', trade.get('action', 'UNKNOWN'))
|
||||||
|
|
||||||
|
# Size might be named size or quantity
|
||||||
|
parsed_trade['size'] = float(trade.get('size', trade.get('quantity', 0)))
|
||||||
|
|
||||||
|
# Entry and exit prices
|
||||||
|
parsed_trade['entry_price'] = float(trade.get('entry_price', trade.get('entry', 0)))
|
||||||
|
parsed_trade['exit_price'] = float(trade.get('exit_price', trade.get('exit', 0)))
|
||||||
|
|
||||||
|
# Hold time in seconds
|
||||||
|
parsed_trade['hold_time'] = float(trade.get('hold_time_seconds', trade.get('hold', 0)))
|
||||||
|
|
||||||
|
# P&L and fees
|
||||||
|
parsed_trade['pnl'] = float(trade.get('pnl', 0))
|
||||||
|
parsed_trade['fees'] = float(trade.get('fees', 0))
|
||||||
|
|
||||||
|
# Calculate expected P&L for verification
|
||||||
|
if parsed_trade['side'] == 'LONG' or parsed_trade['side'] == 'BUY':
|
||||||
|
expected_pnl = (parsed_trade['exit_price'] - parsed_trade['entry_price']) * parsed_trade['size']
|
||||||
|
else: # SHORT or SELL
|
||||||
|
expected_pnl = (parsed_trade['entry_price'] - parsed_trade['exit_price']) * parsed_trade['size']
|
||||||
|
|
||||||
|
parsed_trade['expected_pnl'] = expected_pnl
|
||||||
|
parsed_trade['pnl_difference'] = parsed_trade['pnl'] - expected_pnl
|
||||||
|
|
||||||
|
parsed_trades.append(parsed_trade)
|
||||||
|
|
||||||
|
# Convert to DataFrame
|
||||||
|
if parsed_trades:
|
||||||
|
df = pd.DataFrame(parsed_trades)
|
||||||
|
return df
|
||||||
|
else:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
def analyze_trades(df):
|
||||||
|
"""Analyze trades for potential issues"""
|
||||||
|
if df.empty:
|
||||||
|
print("No trades to analyze")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"\n{'='*50}")
|
||||||
|
print("TRADE AUDIT RESULTS")
|
||||||
|
print(f"{'='*50}")
|
||||||
|
print(f"Total trades analyzed: {len(df)}")
|
||||||
|
|
||||||
|
# Check for duplicate entry prices
|
||||||
|
entry_price_counts = df['entry_price'].value_counts()
|
||||||
|
duplicate_entries = entry_price_counts[entry_price_counts > 1]
|
||||||
|
|
||||||
|
print(f"\n{'='*20} DUPLICATE ENTRY PRICES {'='*20}")
|
||||||
|
if not duplicate_entries.empty:
|
||||||
|
print(f"Found {len(duplicate_entries)} prices with multiple entries:")
|
||||||
|
for price, count in duplicate_entries.items():
|
||||||
|
print(f" ${price:.2f}: {count} trades")
|
||||||
|
|
||||||
|
# Analyze the duplicate entry trades in more detail
|
||||||
|
for price in duplicate_entries.index:
|
||||||
|
duplicate_df = df[df['entry_price'] == price].copy()
|
||||||
|
duplicate_df['time_diff'] = duplicate_df['time'].diff().dt.total_seconds()
|
||||||
|
|
||||||
|
print(f"\nDetailed analysis for entry price ${price:.2f}:")
|
||||||
|
print(f" Time gaps between consecutive trades:")
|
||||||
|
for i, (_, row) in enumerate(duplicate_df.iterrows()):
|
||||||
|
if i > 0: # Skip first row as it has no previous trade
|
||||||
|
time_diff = row['time_diff']
|
||||||
|
if pd.notna(time_diff):
|
||||||
|
print(f" {row['time'].strftime('%H:%M:%S')}: {time_diff:.0f} seconds after previous trade")
|
||||||
|
else:
|
||||||
|
print("No duplicate entry prices found")
|
||||||
|
|
||||||
|
# Check for rapid consecutive trades
|
||||||
|
df = df.sort_values('time')
|
||||||
|
df['time_since_last'] = df['time'].diff().dt.total_seconds()
|
||||||
|
|
||||||
|
rapid_trades = df[df['time_since_last'] < 30].copy()
|
||||||
|
|
||||||
|
print(f"\n{'='*20} RAPID CONSECUTIVE TRADES {'='*20}")
|
||||||
|
if not rapid_trades.empty:
|
||||||
|
print(f"Found {len(rapid_trades)} trades executed within 30 seconds of previous trade:")
|
||||||
|
for _, row in rapid_trades.iterrows():
|
||||||
|
if pd.notna(row['time_since_last']):
|
||||||
|
print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} ${row['size']:.2f} @ ${row['entry_price']:.2f} - {row['time_since_last']:.0f}s after previous")
|
||||||
|
else:
|
||||||
|
print("No rapid consecutive trades found")
|
||||||
|
|
||||||
|
# Check for P&L calculation accuracy
|
||||||
|
pnl_diff = df[abs(df['pnl_difference']) > 0.01].copy()
|
||||||
|
|
||||||
|
print(f"\n{'='*20} P&L CALCULATION ISSUES {'='*20}")
|
||||||
|
if not pnl_diff.empty:
|
||||||
|
print(f"Found {len(pnl_diff)} trades with P&L calculation discrepancies:")
|
||||||
|
for _, row in pnl_diff.iterrows():
|
||||||
|
print(f" {row['time'].strftime('%H:%M:%S')} - {row['side']} - Reported: ${row['pnl']:.2f}, Expected: ${row['expected_pnl']:.2f}, Diff: ${row['pnl_difference']:.2f}")
|
||||||
|
else:
|
||||||
|
print("No P&L calculation issues found")
|
||||||
|
|
||||||
|
# Check for side distribution
|
||||||
|
side_counts = df['side'].value_counts()
|
||||||
|
|
||||||
|
print(f"\n{'='*20} TRADE SIDE DISTRIBUTION {'='*20}")
|
||||||
|
for side, count in side_counts.items():
|
||||||
|
print(f" {side}: {count} trades ({count/len(df)*100:.1f}%)")
|
||||||
|
|
||||||
|
# Check for hold time distribution
|
||||||
|
print(f"\n{'='*20} HOLD TIME DISTRIBUTION {'='*20}")
|
||||||
|
print(f" Min hold time: {df['hold_time'].min():.0f} seconds")
|
||||||
|
print(f" Max hold time: {df['hold_time'].max():.0f} seconds")
|
||||||
|
print(f" Avg hold time: {df['hold_time'].mean():.0f} seconds")
|
||||||
|
print(f" Median hold time: {df['hold_time'].median():.0f} seconds")
|
||||||
|
|
||||||
|
# Hold time buckets
|
||||||
|
hold_buckets = [0, 30, 60, 120, 300, 600, 1800, 3600, float('inf')]
|
||||||
|
hold_labels = ['0-30s', '30-60s', '1-2m', '2-5m', '5-10m', '10-30m', '30-60m', '60m+']
|
||||||
|
|
||||||
|
df['hold_bucket'] = pd.cut(df['hold_time'], bins=hold_buckets, labels=hold_labels)
|
||||||
|
hold_dist = df['hold_bucket'].value_counts().sort_index()
|
||||||
|
|
||||||
|
for bucket, count in hold_dist.items():
|
||||||
|
print(f" {bucket}: {count} trades ({count/len(df)*100:.1f}%)")
|
||||||
|
|
||||||
|
# Generate summary statistics
|
||||||
|
print(f"\n{'='*20} TRADE PERFORMANCE SUMMARY {'='*20}")
|
||||||
|
winning_trades = df[df['pnl'] > 0]
|
||||||
|
losing_trades = df[df['pnl'] < 0]
|
||||||
|
|
||||||
|
print(f" Win rate: {len(winning_trades)/len(df)*100:.1f}% ({len(winning_trades)}W/{len(losing_trades)}L)")
|
||||||
|
print(f" Avg win: ${winning_trades['pnl'].mean():.2f}")
|
||||||
|
print(f" Avg loss: ${abs(losing_trades['pnl'].mean()):.2f}")
|
||||||
|
print(f" Total P&L: ${df['pnl'].sum():.2f}")
|
||||||
|
print(f" Total fees: ${df['fees'].sum():.2f}")
|
||||||
|
print(f" Net P&L: ${(df['pnl'].sum() - df['fees'].sum()):.2f}")
|
||||||
|
|
||||||
|
# Plot entry price distribution
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.hist(df['entry_price'], bins=20, alpha=0.7)
|
||||||
|
plt.title('Entry Price Distribution')
|
||||||
|
plt.xlabel('Entry Price ($)')
|
||||||
|
plt.ylabel('Number of Trades')
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
plt.savefig('debug/entry_price_distribution.png')
|
||||||
|
|
||||||
|
# Plot P&L distribution
|
||||||
|
plt.figure(figsize=(10, 6))
|
||||||
|
plt.hist(df['pnl'], bins=20, alpha=0.7)
|
||||||
|
plt.title('P&L Distribution')
|
||||||
|
plt.xlabel('P&L ($)')
|
||||||
|
plt.ylabel('Number of Trades')
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
plt.savefig('debug/pnl_distribution.png')
|
||||||
|
|
||||||
|
print(f"\n{'='*20} AUDIT COMPLETE {'='*20}")
|
||||||
|
print("Plots saved to debug/entry_price_distribution.png and debug/pnl_distribution.png")
|
||||||
|
|
||||||
|
def analyze_manual_trades(trades_data):
|
||||||
|
"""Analyze manually provided trade data"""
|
||||||
|
# Parse the trade data into a structured format
|
||||||
|
parsed_trades = []
|
||||||
|
|
||||||
|
for line in trades_data.strip().split('\n'):
|
||||||
|
if not line or line.startswith('from last session') or line.startswith('Recent Closed Trades') or line.startswith('Trading Performance'):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if line.startswith('Win Rate:'):
|
||||||
|
# This is the summary line, skip it
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Parse trade line format: Time Side Size Entry Exit Hold P&L Fees
|
||||||
|
parts = line.split('$')
|
||||||
|
|
||||||
|
time_side = parts[0].strip().split()
|
||||||
|
time = time_side[0]
|
||||||
|
side = time_side[1]
|
||||||
|
|
||||||
|
size = float(parts[1].split()[0])
|
||||||
|
entry = float(parts[2].split()[0])
|
||||||
|
exit = float(parts[3].split()[0])
|
||||||
|
|
||||||
|
# The hold time and P&L are in the last parts
|
||||||
|
remaining = parts[3].split()
|
||||||
|
hold = int(remaining[1])
|
||||||
|
pnl = float(parts[4].split()[0])
|
||||||
|
|
||||||
|
# Fees might be in a different format
|
||||||
|
if len(parts) > 5:
|
||||||
|
fees = float(parts[5].strip())
|
||||||
|
else:
|
||||||
|
fees = 0.0
|
||||||
|
|
||||||
|
parsed_trade = {
|
||||||
|
'time': parse_trade_time(time),
|
||||||
|
'side': side,
|
||||||
|
'size': size,
|
||||||
|
'entry_price': entry,
|
||||||
|
'exit_price': exit,
|
||||||
|
'hold_time': hold,
|
||||||
|
'pnl': pnl,
|
||||||
|
'fees': fees
|
||||||
|
}
|
||||||
|
|
||||||
|
# Calculate expected P&L
|
||||||
|
if side == 'LONG' or side == 'BUY':
|
||||||
|
expected_pnl = (exit - entry) * size
|
||||||
|
else: # SHORT or SELL
|
||||||
|
expected_pnl = (entry - exit) * size
|
||||||
|
|
||||||
|
parsed_trade['expected_pnl'] = expected_pnl
|
||||||
|
parsed_trade['pnl_difference'] = pnl - expected_pnl
|
||||||
|
|
||||||
|
parsed_trades.append(parsed_trade)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error parsing trade line: {line}")
|
||||||
|
print(f"Error details: {e}")
|
||||||
|
|
||||||
|
# Convert to DataFrame
|
||||||
|
if parsed_trades:
|
||||||
|
df = pd.DataFrame(parsed_trades)
|
||||||
|
return df
|
||||||
|
else:
|
||||||
|
return pd.DataFrame()
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Trade Audit Tool')
|
||||||
|
parser.add_argument('--trades-file', type=str, help='Path to trades JSON file')
|
||||||
|
parser.add_argument('--manual-trades', type=str, help='Path to text file with manually entered trades')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create debug directory if it doesn't exist
|
||||||
|
os.makedirs('debug', exist_ok=True)
|
||||||
|
|
||||||
|
if args.trades_file:
|
||||||
|
trades_data = load_trades_from_file(args.trades_file)
|
||||||
|
df = parse_trade_data(trades_data)
|
||||||
|
elif args.manual_trades:
|
||||||
|
try:
|
||||||
|
with open(args.manual_trades, 'r') as f:
|
||||||
|
manual_trades = f.read()
|
||||||
|
df = analyze_manual_trades(manual_trades)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error reading manual trades file: {e}")
|
||||||
|
df = pd.DataFrame()
|
||||||
|
else:
|
||||||
|
# Try to load from dashboard cache
|
||||||
|
trades_data = load_trades_from_dashboard_cache()
|
||||||
|
if trades_data:
|
||||||
|
df = parse_trade_data(trades_data)
|
||||||
|
else:
|
||||||
|
print("No trade data provided. Use --trades-file or --manual-trades")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not df.empty:
|
||||||
|
analyze_trades(df)
|
||||||
|
else:
|
||||||
|
print("No valid trade data to analyze")
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -1 +0,0 @@
|
|||||||
we do not properly calculate PnL and enter/exit prices
|
|
337
test_trading_fixes.py
Normal file
337
test_trading_fixes.py
Normal file
@ -0,0 +1,337 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Test Trading System Fixes
|
||||||
|
|
||||||
|
This script tests the fixes for the trading system by simulating trades
|
||||||
|
and verifying that the issues are resolved.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python test_trading_fixes.py
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime
|
||||||
|
import json
|
||||||
|
|
||||||
|
# Add project root to path
|
||||||
|
project_root = Path(__file__).parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
# Setup logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler('logs/test_fixes.log')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class MockPosition:
|
||||||
|
"""Mock position for testing"""
|
||||||
|
def __init__(self, symbol, side, size, entry_price):
|
||||||
|
self.symbol = symbol
|
||||||
|
self.side = side
|
||||||
|
self.size = size
|
||||||
|
self.entry_price = entry_price
|
||||||
|
self.fees = 0.0
|
||||||
|
|
||||||
|
class MockTradingExecutor:
|
||||||
|
"""Mock trading executor for testing fixes"""
|
||||||
|
def __init__(self):
|
||||||
|
self.positions = {}
|
||||||
|
self.current_prices = {}
|
||||||
|
self.simulation_mode = True
|
||||||
|
|
||||||
|
def get_current_price(self, symbol):
|
||||||
|
"""Get current price for a symbol"""
|
||||||
|
# Simulate price movement
|
||||||
|
if symbol not in self.current_prices:
|
||||||
|
self.current_prices[symbol] = 3600.0
|
||||||
|
else:
|
||||||
|
# Add some random movement
|
||||||
|
import random
|
||||||
|
self.current_prices[symbol] += random.uniform(-10, 10)
|
||||||
|
|
||||||
|
return self.current_prices[symbol]
|
||||||
|
|
||||||
|
def execute_action(self, decision):
|
||||||
|
"""Execute a trading action"""
|
||||||
|
logger.info(f"Executing {decision.action} for {decision.symbol} at ${decision.price:.2f}")
|
||||||
|
|
||||||
|
# Simulate execution
|
||||||
|
if decision.action in ['BUY', 'LONG']:
|
||||||
|
self.positions[decision.symbol] = MockPosition(
|
||||||
|
decision.symbol, 'LONG', decision.size, decision.price
|
||||||
|
)
|
||||||
|
elif decision.action in ['SELL', 'SHORT']:
|
||||||
|
self.positions[decision.symbol] = MockPosition(
|
||||||
|
decision.symbol, 'SHORT', decision.size, decision.price
|
||||||
|
)
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def close_position(self, symbol, price=None):
|
||||||
|
"""Close a position"""
|
||||||
|
if symbol not in self.positions:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if price is None:
|
||||||
|
price = self.get_current_price(symbol)
|
||||||
|
|
||||||
|
position = self.positions[symbol]
|
||||||
|
|
||||||
|
# Calculate P&L
|
||||||
|
if position.side == 'LONG':
|
||||||
|
pnl = (price - position.entry_price) * position.size
|
||||||
|
else: # SHORT
|
||||||
|
pnl = (position.entry_price - price) * position.size
|
||||||
|
|
||||||
|
logger.info(f"Closing {position.side} position for {symbol} at ${price:.2f}, P&L: ${pnl:.2f}")
|
||||||
|
|
||||||
|
# Remove position
|
||||||
|
del self.positions[symbol]
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
class MockDecision:
|
||||||
|
"""Mock trading decision for testing"""
|
||||||
|
def __init__(self, symbol, action, price=None, size=10.0, confidence=0.8):
|
||||||
|
self.symbol = symbol
|
||||||
|
self.action = action
|
||||||
|
self.price = price
|
||||||
|
self.size = size
|
||||||
|
self.confidence = confidence
|
||||||
|
self.timestamp = datetime.now()
|
||||||
|
self.executed = False
|
||||||
|
self.blocked = False
|
||||||
|
self.blocked_reason = None
|
||||||
|
|
||||||
|
def test_price_caching_fix():
|
||||||
|
"""Test the price caching fix"""
|
||||||
|
logger.info("Testing price caching fix...")
|
||||||
|
|
||||||
|
# Create mock trading executor
|
||||||
|
executor = MockTradingExecutor()
|
||||||
|
|
||||||
|
# Import and apply fixes
|
||||||
|
try:
|
||||||
|
from core.trading_executor_fix import TradingExecutorFix
|
||||||
|
TradingExecutorFix.apply_fixes(executor)
|
||||||
|
|
||||||
|
# Test price caching
|
||||||
|
symbol = 'ETH/USDT'
|
||||||
|
|
||||||
|
# Get initial price
|
||||||
|
price1 = executor.get_current_price(symbol)
|
||||||
|
logger.info(f"Initial price: ${price1:.2f}")
|
||||||
|
|
||||||
|
# Get price again immediately (should be cached)
|
||||||
|
price2 = executor.get_current_price(symbol)
|
||||||
|
logger.info(f"Immediate second price: ${price2:.2f}")
|
||||||
|
|
||||||
|
# Wait for cache to expire
|
||||||
|
logger.info("Waiting for cache to expire (6 seconds)...")
|
||||||
|
time.sleep(6)
|
||||||
|
|
||||||
|
# Get price after cache expiry (should be different)
|
||||||
|
price3 = executor.get_current_price(symbol)
|
||||||
|
logger.info(f"Price after cache expiry: ${price3:.2f}")
|
||||||
|
|
||||||
|
# Check if prices are different
|
||||||
|
if price1 == price2:
|
||||||
|
logger.info("✅ Immediate price check uses cache as expected")
|
||||||
|
else:
|
||||||
|
logger.warning("❌ Immediate price check did not use cache")
|
||||||
|
|
||||||
|
if price1 != price3:
|
||||||
|
logger.info("✅ Price cache expiry working correctly")
|
||||||
|
else:
|
||||||
|
logger.warning("❌ Price cache expiry not working")
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error testing price caching fix: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_duplicate_entry_prevention():
|
||||||
|
"""Test the duplicate entry prevention fix"""
|
||||||
|
logger.info("Testing duplicate entry prevention...")
|
||||||
|
|
||||||
|
# Create mock trading executor
|
||||||
|
executor = MockTradingExecutor()
|
||||||
|
|
||||||
|
# Import and apply fixes
|
||||||
|
try:
|
||||||
|
from core.trading_executor_fix import TradingExecutorFix
|
||||||
|
TradingExecutorFix.apply_fixes(executor)
|
||||||
|
|
||||||
|
# Test duplicate entry prevention
|
||||||
|
symbol = 'ETH/USDT'
|
||||||
|
|
||||||
|
# Create first decision
|
||||||
|
decision1 = MockDecision(symbol, 'SHORT')
|
||||||
|
decision1.price = executor.get_current_price(symbol)
|
||||||
|
|
||||||
|
# Execute first decision
|
||||||
|
result1 = executor.execute_action(decision1)
|
||||||
|
logger.info(f"First execution result: {result1}")
|
||||||
|
|
||||||
|
# Manually set recent entries to simulate a successful trade
|
||||||
|
if not hasattr(executor, 'recent_entries'):
|
||||||
|
executor.recent_entries = {}
|
||||||
|
|
||||||
|
executor.recent_entries[symbol] = {
|
||||||
|
'price': decision1.price,
|
||||||
|
'timestamp': time.time(),
|
||||||
|
'action': decision1.action
|
||||||
|
}
|
||||||
|
|
||||||
|
# Create second decision with same action
|
||||||
|
decision2 = MockDecision(symbol, 'SHORT')
|
||||||
|
decision2.price = decision1.price # Use same price to trigger duplicate detection
|
||||||
|
|
||||||
|
# Execute second decision immediately (should be blocked)
|
||||||
|
result2 = executor.execute_action(decision2)
|
||||||
|
logger.info(f"Second execution result: {result2}")
|
||||||
|
logger.info(f"Second decision blocked: {getattr(decision2, 'blocked', False)}")
|
||||||
|
logger.info(f"Block reason: {getattr(decision2, 'blocked_reason', 'None')}")
|
||||||
|
|
||||||
|
# Check if second decision was blocked by trade cooldown
|
||||||
|
# This is also acceptable as it prevents duplicate entries
|
||||||
|
if getattr(decision2, 'blocked', False):
|
||||||
|
logger.info("✅ Trade prevention working correctly (via cooldown)")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning("❌ Trade prevention not working correctly")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error testing duplicate entry prevention: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
def test_pnl_calculation_fix():
|
||||||
|
"""Test the P&L calculation fix"""
|
||||||
|
logger.info("Testing P&L calculation fix...")
|
||||||
|
|
||||||
|
# Create mock trading executor
|
||||||
|
executor = MockTradingExecutor()
|
||||||
|
|
||||||
|
# Import and apply fixes
|
||||||
|
try:
|
||||||
|
from core.trading_executor_fix import TradingExecutorFix
|
||||||
|
TradingExecutorFix.apply_fixes(executor)
|
||||||
|
|
||||||
|
# Test P&L calculation
|
||||||
|
symbol = 'ETH/USDT'
|
||||||
|
|
||||||
|
# Create a position
|
||||||
|
entry_price = 3600.0
|
||||||
|
size = 10.0
|
||||||
|
executor.positions[symbol] = MockPosition(symbol, 'SHORT', size, entry_price)
|
||||||
|
|
||||||
|
# Set exit price
|
||||||
|
exit_price = 3550.0
|
||||||
|
|
||||||
|
# Calculate P&L using fixed method
|
||||||
|
pnl_result = executor._calculate_pnl(executor.positions[symbol], exit_price)
|
||||||
|
|
||||||
|
# Calculate expected P&L
|
||||||
|
expected_pnl = (entry_price - exit_price) * size
|
||||||
|
|
||||||
|
logger.info(f"Entry price: ${entry_price:.2f}")
|
||||||
|
logger.info(f"Exit price: ${exit_price:.2f}")
|
||||||
|
logger.info(f"Size: {size}")
|
||||||
|
logger.info(f"Calculated P&L: ${pnl_result['gross_pnl']:.2f}")
|
||||||
|
logger.info(f"Expected P&L: ${expected_pnl:.2f}")
|
||||||
|
|
||||||
|
# Check if P&L calculation is correct
|
||||||
|
if abs(pnl_result['gross_pnl'] - expected_pnl) < 0.01:
|
||||||
|
logger.info("✅ P&L calculation fix working correctly")
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logger.warning("❌ P&L calculation fix not working correctly")
|
||||||
|
return False
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error testing P&L calculation fix: {e}")
|
||||||
|
import traceback
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return False
|
||||||
|
|
||||||
|
def run_all_tests():
|
||||||
|
"""Run all tests"""
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info("TESTING TRADING SYSTEM FIXES")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
# Create logs directory if it doesn't exist
|
||||||
|
os.makedirs('logs', exist_ok=True)
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
tests = [
|
||||||
|
("Price Caching Fix", test_price_caching_fix),
|
||||||
|
("Duplicate Entry Prevention", test_duplicate_entry_prevention),
|
||||||
|
("P&L Calculation Fix", test_pnl_calculation_fix)
|
||||||
|
]
|
||||||
|
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for test_name, test_func in tests:
|
||||||
|
logger.info(f"\n{'-'*30}")
|
||||||
|
logger.info(f"Running test: {test_name}")
|
||||||
|
logger.info(f"{'-'*30}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = test_func()
|
||||||
|
results[test_name] = result
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Test {test_name} failed with error: {e}")
|
||||||
|
results[test_name] = False
|
||||||
|
|
||||||
|
# Print summary
|
||||||
|
logger.info("\n" + "=" * 70)
|
||||||
|
logger.info("TEST RESULTS SUMMARY")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
all_passed = True
|
||||||
|
for test_name, result in results.items():
|
||||||
|
status = "✅ PASSED" if result else "❌ FAILED"
|
||||||
|
logger.info(f"{test_name}: {status}")
|
||||||
|
if not result:
|
||||||
|
all_passed = False
|
||||||
|
|
||||||
|
logger.info("=" * 70)
|
||||||
|
logger.info(f"OVERALL RESULT: {'✅ ALL TESTS PASSED' if all_passed else '❌ SOME TESTS FAILED'}")
|
||||||
|
logger.info("=" * 70)
|
||||||
|
|
||||||
|
# Save results to file
|
||||||
|
with open('logs/test_results.json', 'w') as f:
|
||||||
|
json.dump({
|
||||||
|
'timestamp': datetime.now().isoformat(),
|
||||||
|
'results': {k: "PASSED" if v else "FAILED" for k, v in results.items()},
|
||||||
|
'all_passed': all_passed
|
||||||
|
}, f, indent=2)
|
||||||
|
|
||||||
|
return all_passed
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
success = run_all_tests()
|
||||||
|
|
||||||
|
if success:
|
||||||
|
print("\nAll tests passed!")
|
||||||
|
sys.exit(0)
|
||||||
|
else:
|
||||||
|
print("\nSome tests failed. Check logs for details.")
|
||||||
|
sys.exit(1)
|
253
web/dashboard_fix.py
Normal file
253
web/dashboard_fix.py
Normal file
@ -0,0 +1,253 @@
|
|||||||
|
"""
|
||||||
|
Dashboard Fix
|
||||||
|
|
||||||
|
This module provides fixes for the trading dashboard to address:
|
||||||
|
1. Trade display issues
|
||||||
|
2. P&L calculation and display
|
||||||
|
3. Position tracking and synchronization
|
||||||
|
|
||||||
|
Apply these fixes by importing and applying the patch in the dashboard initialization
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
import time
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class DashboardFix:
|
||||||
|
"""Fixes for the Dashboard class"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def apply_fixes(dashboard):
|
||||||
|
"""Apply all fixes to the dashboard"""
|
||||||
|
logger.info("Applying Dashboard fixes...")
|
||||||
|
|
||||||
|
# Apply fixes
|
||||||
|
DashboardFix._fix_trade_display(dashboard)
|
||||||
|
DashboardFix._fix_position_sync(dashboard)
|
||||||
|
DashboardFix._fix_pnl_calculation(dashboard)
|
||||||
|
DashboardFix._add_trade_validation(dashboard)
|
||||||
|
|
||||||
|
logger.info("Dashboard fixes applied successfully")
|
||||||
|
return dashboard
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_trade_display(dashboard):
|
||||||
|
"""Fix trade display to ensure accurate information"""
|
||||||
|
# Store original format_closed_trades_table method
|
||||||
|
if hasattr(dashboard.component_manager, 'format_closed_trades_table'):
|
||||||
|
original_format_closed_trades = dashboard.component_manager.format_closed_trades_table
|
||||||
|
|
||||||
|
def format_closed_trades_table_fixed(self, closed_trades, trading_stats=None):
|
||||||
|
"""Fixed closed trades table formatter with accurate P&L calculation"""
|
||||||
|
# Recalculate P&L for each trade to ensure accuracy
|
||||||
|
for trade in closed_trades:
|
||||||
|
# Skip if already validated
|
||||||
|
if getattr(trade, 'pnl_validated', False):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle both trade objects and dictionary formats
|
||||||
|
if hasattr(trade, 'entry_price'):
|
||||||
|
# This is a trade object
|
||||||
|
entry_price = getattr(trade, 'entry_price', 0)
|
||||||
|
exit_price = getattr(trade, 'exit_price', 0)
|
||||||
|
size = getattr(trade, 'size', 0)
|
||||||
|
side = getattr(trade, 'side', 'UNKNOWN')
|
||||||
|
fees = getattr(trade, 'fees', 0)
|
||||||
|
else:
|
||||||
|
# This is a dictionary format
|
||||||
|
entry_price = trade.get('entry_price', 0)
|
||||||
|
exit_price = trade.get('exit_price', 0)
|
||||||
|
size = trade.get('size', trade.get('quantity', 0))
|
||||||
|
side = trade.get('side', 'UNKNOWN')
|
||||||
|
fees = trade.get('fees', 0)
|
||||||
|
|
||||||
|
# Recalculate P&L
|
||||||
|
if side == 'LONG' or side == 'BUY':
|
||||||
|
pnl = (exit_price - entry_price) * size
|
||||||
|
else: # SHORT or SELL
|
||||||
|
pnl = (entry_price - exit_price) * size
|
||||||
|
|
||||||
|
# Update P&L value
|
||||||
|
if hasattr(trade, 'entry_price'):
|
||||||
|
trade.pnl = pnl
|
||||||
|
trade.net_pnl = pnl - fees
|
||||||
|
trade.pnl_validated = True
|
||||||
|
else:
|
||||||
|
trade['pnl'] = pnl
|
||||||
|
trade['net_pnl'] = pnl - fees
|
||||||
|
trade['pnl_validated'] = True
|
||||||
|
|
||||||
|
# Call original method with validated trades
|
||||||
|
return original_format_closed_trades(closed_trades, trading_stats)
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
dashboard.component_manager.format_closed_trades_table = format_closed_trades_table_fixed.__get__(dashboard.component_manager)
|
||||||
|
logger.info("Trade display fix applied")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_position_sync(dashboard):
|
||||||
|
"""Fix position synchronization to ensure accurate position tracking"""
|
||||||
|
# Store original _sync_position_from_executor method
|
||||||
|
if hasattr(dashboard, '_sync_position_from_executor'):
|
||||||
|
original_sync_position = dashboard._sync_position_from_executor
|
||||||
|
|
||||||
|
def sync_position_from_executor_fixed(self, symbol):
|
||||||
|
"""Fixed position sync with validation and logging"""
|
||||||
|
try:
|
||||||
|
# Call original sync method
|
||||||
|
result = original_sync_position(symbol)
|
||||||
|
|
||||||
|
# Add validation and logging
|
||||||
|
if self.trading_executor and hasattr(self.trading_executor, 'positions'):
|
||||||
|
if symbol in self.trading_executor.positions:
|
||||||
|
position = self.trading_executor.positions[symbol]
|
||||||
|
|
||||||
|
# Log position details for debugging
|
||||||
|
logger.debug(f"Position sync for {symbol}: "
|
||||||
|
f"Side={position.side}, "
|
||||||
|
f"Size={position.size}, "
|
||||||
|
f"Entry=${position.entry_price:.2f}")
|
||||||
|
|
||||||
|
# Validate position data
|
||||||
|
if position.entry_price <= 0:
|
||||||
|
logger.warning(f"Invalid entry price for {symbol}: ${position.entry_price:.2f}")
|
||||||
|
|
||||||
|
# Store last sync time
|
||||||
|
if not hasattr(self, 'last_position_sync'):
|
||||||
|
self.last_position_sync = {}
|
||||||
|
|
||||||
|
self.last_position_sync[symbol] = time.time()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in sync_position_from_executor_fixed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
dashboard._sync_position_from_executor = sync_position_from_executor_fixed.__get__(dashboard)
|
||||||
|
logger.info("Position sync fix applied")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _fix_pnl_calculation(dashboard):
|
||||||
|
"""Fix P&L calculation to ensure accuracy"""
|
||||||
|
# Add a method to recalculate P&L for all closed trades
|
||||||
|
def recalculate_all_pnl(self):
|
||||||
|
"""Recalculate P&L for all closed trades"""
|
||||||
|
if not hasattr(self, 'closed_trades') or not self.closed_trades:
|
||||||
|
return
|
||||||
|
|
||||||
|
for trade in self.closed_trades:
|
||||||
|
# Handle both trade objects and dictionary formats
|
||||||
|
if hasattr(trade, 'entry_price'):
|
||||||
|
# This is a trade object
|
||||||
|
entry_price = getattr(trade, 'entry_price', 0)
|
||||||
|
exit_price = getattr(trade, 'exit_price', 0)
|
||||||
|
size = getattr(trade, 'size', 0)
|
||||||
|
side = getattr(trade, 'side', 'UNKNOWN')
|
||||||
|
fees = getattr(trade, 'fees', 0)
|
||||||
|
else:
|
||||||
|
# This is a dictionary format
|
||||||
|
entry_price = trade.get('entry_price', 0)
|
||||||
|
exit_price = trade.get('exit_price', 0)
|
||||||
|
size = trade.get('size', trade.get('quantity', 0))
|
||||||
|
side = trade.get('side', 'UNKNOWN')
|
||||||
|
fees = trade.get('fees', 0)
|
||||||
|
|
||||||
|
# Recalculate P&L
|
||||||
|
if side == 'LONG' or side == 'BUY':
|
||||||
|
pnl = (exit_price - entry_price) * size
|
||||||
|
else: # SHORT or SELL
|
||||||
|
pnl = (entry_price - exit_price) * size
|
||||||
|
|
||||||
|
# Update P&L value
|
||||||
|
if hasattr(trade, 'entry_price'):
|
||||||
|
trade.pnl = pnl
|
||||||
|
trade.net_pnl = pnl - fees
|
||||||
|
else:
|
||||||
|
trade['pnl'] = pnl
|
||||||
|
trade['net_pnl'] = pnl - fees
|
||||||
|
|
||||||
|
logger.info(f"Recalculated P&L for {len(self.closed_trades)} closed trades")
|
||||||
|
|
||||||
|
# Add the method
|
||||||
|
dashboard.recalculate_all_pnl = recalculate_all_pnl.__get__(dashboard)
|
||||||
|
|
||||||
|
# Call it once to fix existing trades
|
||||||
|
dashboard.recalculate_all_pnl()
|
||||||
|
|
||||||
|
logger.info("P&L calculation fix applied")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_trade_validation(dashboard):
|
||||||
|
"""Add trade validation to prevent invalid trades"""
|
||||||
|
# Store original _on_trade_closed method if it exists
|
||||||
|
original_on_trade_closed = getattr(dashboard, '_on_trade_closed', None)
|
||||||
|
|
||||||
|
if original_on_trade_closed:
|
||||||
|
def on_trade_closed_fixed(self, trade_data):
|
||||||
|
"""Fixed trade closed handler with validation"""
|
||||||
|
try:
|
||||||
|
# Validate trade data
|
||||||
|
is_valid = True
|
||||||
|
validation_errors = []
|
||||||
|
|
||||||
|
# Check for required fields
|
||||||
|
required_fields = ['symbol', 'side', 'entry_price', 'exit_price', 'size']
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in trade_data:
|
||||||
|
is_valid = False
|
||||||
|
validation_errors.append(f"Missing required field: {field}")
|
||||||
|
|
||||||
|
# Check for valid prices
|
||||||
|
if 'entry_price' in trade_data and trade_data['entry_price'] <= 0:
|
||||||
|
is_valid = False
|
||||||
|
validation_errors.append(f"Invalid entry price: {trade_data['entry_price']}")
|
||||||
|
|
||||||
|
if 'exit_price' in trade_data and trade_data['exit_price'] <= 0:
|
||||||
|
is_valid = False
|
||||||
|
validation_errors.append(f"Invalid exit price: {trade_data['exit_price']}")
|
||||||
|
|
||||||
|
# Check for valid size
|
||||||
|
if 'size' in trade_data and trade_data['size'] <= 0:
|
||||||
|
is_valid = False
|
||||||
|
validation_errors.append(f"Invalid size: {trade_data['size']}")
|
||||||
|
|
||||||
|
# If invalid, log errors and skip
|
||||||
|
if not is_valid:
|
||||||
|
logger.warning(f"Invalid trade data: {validation_errors}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Calculate correct P&L
|
||||||
|
if 'side' in trade_data and 'entry_price' in trade_data and 'exit_price' in trade_data and 'size' in trade_data:
|
||||||
|
side = trade_data['side']
|
||||||
|
entry_price = trade_data['entry_price']
|
||||||
|
exit_price = trade_data['exit_price']
|
||||||
|
size = trade_data['size']
|
||||||
|
|
||||||
|
if side == 'LONG' or side == 'BUY':
|
||||||
|
pnl = (exit_price - entry_price) * size
|
||||||
|
else: # SHORT or SELL
|
||||||
|
pnl = (entry_price - exit_price) * size
|
||||||
|
|
||||||
|
# Update P&L in trade data
|
||||||
|
trade_data['pnl'] = pnl
|
||||||
|
|
||||||
|
# Calculate net P&L (after fees)
|
||||||
|
fees = trade_data.get('fees', 0)
|
||||||
|
trade_data['net_pnl'] = pnl - fees
|
||||||
|
|
||||||
|
# Call original method with validated data
|
||||||
|
return original_on_trade_closed(trade_data)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error in on_trade_closed_fixed: {e}")
|
||||||
|
|
||||||
|
# Apply the patch
|
||||||
|
dashboard._on_trade_closed = on_trade_closed_fixed.__get__(dashboard)
|
||||||
|
logger.info("Trade validation fix applied")
|
||||||
|
else:
|
||||||
|
logger.warning("_on_trade_closed method not found, skipping trade validation fix")
|
Reference in New Issue
Block a user