training wip

This commit is contained in:
Dobromir Popov
2025-07-27 23:45:57 +03:00
parent 39267697f3
commit b4076241c9
4 changed files with 283 additions and 66 deletions

View File

@ -80,6 +80,9 @@ class EnhancedCNN(nn.Module):
self.n_actions = n_actions
self.confidence_threshold = confidence_threshold
# Training data storage
self.training_data = []
# Calculate input dimensions
if isinstance(input_shape, (list, tuple)):
if len(input_shape) == 3: # [channels, height, width]
@ -648,6 +651,30 @@ class EnhancedCNN(nn.Module):
'strength': 0.0,
'weighted_strength': 0.0
}
def add_training_data(self, state, action, reward):
"""
Add training data to the model's training buffer
Args:
state: Input state
action: Action taken
reward: Reward received
"""
try:
self.training_data.append({
'state': state,
'action': action,
'reward': reward,
'timestamp': time.time()
})
# Keep only the last 1000 training samples
if len(self.training_data) > 1000:
self.training_data = self.training_data[-1000:]
except Exception as e:
logger.error(f"Error adding training data: {e}")
def save(self, path):
"""Save model weights and architecture"""