training wip
This commit is contained in:
@ -80,6 +80,9 @@ class EnhancedCNN(nn.Module):
|
||||
self.n_actions = n_actions
|
||||
self.confidence_threshold = confidence_threshold
|
||||
|
||||
# Training data storage
|
||||
self.training_data = []
|
||||
|
||||
# Calculate input dimensions
|
||||
if isinstance(input_shape, (list, tuple)):
|
||||
if len(input_shape) == 3: # [channels, height, width]
|
||||
@ -648,6 +651,30 @@ class EnhancedCNN(nn.Module):
|
||||
'strength': 0.0,
|
||||
'weighted_strength': 0.0
|
||||
}
|
||||
|
||||
def add_training_data(self, state, action, reward):
|
||||
"""
|
||||
Add training data to the model's training buffer
|
||||
|
||||
Args:
|
||||
state: Input state
|
||||
action: Action taken
|
||||
reward: Reward received
|
||||
"""
|
||||
try:
|
||||
self.training_data.append({
|
||||
'state': state,
|
||||
'action': action,
|
||||
'reward': reward,
|
||||
'timestamp': time.time()
|
||||
})
|
||||
|
||||
# Keep only the last 1000 training samples
|
||||
if len(self.training_data) > 1000:
|
||||
self.training_data = self.training_data[-1000:]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding training data: {e}")
|
||||
|
||||
def save(self, path):
|
||||
"""Save model weights and architecture"""
|
||||
|
Reference in New Issue
Block a user