fix merge
This commit is contained in:
@@ -21,15 +21,6 @@ Key Features:
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
<<<<<<< HEAD
|
||||
from core.reward_calculator import RewardCalculator
|
||||
=======
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Tuple, Any, Callable
|
||||
from dataclasses import dataclass
|
||||
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
@@ -186,48 +177,6 @@ class TrainingIntegration:
|
||||
collection_time = time.time() - start_time
|
||||
self._update_collection_stats(collection_time)
|
||||
|
||||
<<<<<<< HEAD
|
||||
# Get the model's device to ensure tensors are on the same device
|
||||
model_device = next(cnn_model.parameters()).device
|
||||
|
||||
# Create tensors
|
||||
features_tensor = torch.FloatTensor(features).unsqueeze(0).to(model_device)
|
||||
target_tensor = torch.LongTensor([target]).to(model_device)
|
||||
|
||||
# Training step
|
||||
cnn_model.train()
|
||||
cnn_model.optimizer.zero_grad()
|
||||
|
||||
outputs = cnn_model(features_tensor)
|
||||
|
||||
# Handle different output formats
|
||||
if isinstance(outputs, dict):
|
||||
if 'main_output' in outputs:
|
||||
logits = outputs['main_output']
|
||||
elif 'action_logits' in outputs:
|
||||
logits = outputs['action_logits']
|
||||
else:
|
||||
logits = list(outputs.values())[0]
|
||||
else:
|
||||
logits = outputs
|
||||
|
||||
# Calculate loss with reward weighting
|
||||
loss_fn = torch.nn.CrossEntropyLoss()
|
||||
loss = loss_fn(logits, target_tensor)
|
||||
|
||||
# Weight loss by reward magnitude
|
||||
weighted_loss = loss * abs(reward)
|
||||
|
||||
# Backward pass
|
||||
weighted_loss.backward()
|
||||
cnn_model.optimizer.step()
|
||||
|
||||
logger.info(f"CNN trained on trade outcome: P&L=${pnl:.2f}, loss={loss.item():.4f}")
|
||||
return True
|
||||
=======
|
||||
# Wait for next collection cycle
|
||||
time.sleep(self.config.collection_interval)
|
||||
>>>>>>> d49a473ed6f4aef55bfdd47d6370e53582be6b7b
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in data collection worker: {e}")
|
||||
|
||||
Reference in New Issue
Block a user