fix emojies
This commit is contained in:
@@ -109,9 +109,9 @@ def test_model_outputs():
|
||||
print(f" Mean value: {candle_pred.mean().item():.6f}")
|
||||
|
||||
if candle_pred.min() >= 0.0 and candle_pred.max() <= 1.0:
|
||||
print(" ✅ PASS: Values in [0, 1] range (Sigmoid working!)")
|
||||
print(" PASS: Values in [0, 1] range (Sigmoid working!)")
|
||||
else:
|
||||
print(" ❌ FAIL: Values outside [0, 1] range!")
|
||||
print(" FAIL: Values outside [0, 1] range!")
|
||||
|
||||
# Check price prediction is in [-1, 1] range (thanks to Tanh)
|
||||
if 'price_prediction' in outputs:
|
||||
@@ -121,9 +121,9 @@ def test_model_outputs():
|
||||
print(f" Value: {price_pred.item():.6f}")
|
||||
|
||||
if price_pred.min() >= -1.0 and price_pred.max() <= 1.0:
|
||||
print(" ✅ PASS: Values in [-1, 1] range (Tanh working!)")
|
||||
print(" PASS: Values in [-1, 1] range (Tanh working!)")
|
||||
else:
|
||||
print(" ❌ FAIL: Values outside [-1, 1] range!")
|
||||
print(" FAIL: Values outside [-1, 1] range!")
|
||||
|
||||
# Check action probabilities sum to 1
|
||||
if 'action_probs' in outputs:
|
||||
@@ -135,9 +135,9 @@ def test_model_outputs():
|
||||
print(f" Sum: {action_probs[0].sum().item():.6f}")
|
||||
|
||||
if abs(action_probs[0].sum().item() - 1.0) < 0.001:
|
||||
print(" ✅ PASS: Probabilities sum to 1.0")
|
||||
print(" PASS: Probabilities sum to 1.0")
|
||||
else:
|
||||
print(" ❌ FAIL: Probabilities don't sum to 1.0!")
|
||||
print(" FAIL: Probabilities don't sum to 1.0!")
|
||||
|
||||
return outputs
|
||||
|
||||
@@ -180,18 +180,18 @@ def test_denormalization():
|
||||
for i, name in enumerate(['Open', 'High', 'Low', 'Close']):
|
||||
value = denorm_candle[0, i].item()
|
||||
if value < expected_min_price or value > expected_max_price:
|
||||
print(f" ❌ FAIL: {name} price ${value:.2f} outside expected range!")
|
||||
print(f" FAIL: {name} price ${value:.2f} outside expected range!")
|
||||
prices_ok = False
|
||||
|
||||
if prices_ok:
|
||||
print(f" ✅ PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")
|
||||
print(f" PASS: All prices in expected range [${expected_min_price}, ${expected_max_price}]")
|
||||
|
||||
# Verify volume
|
||||
volume = denorm_candle[0, 4].item()
|
||||
if norm_params['volume_min'] <= volume <= norm_params['volume_max']:
|
||||
print(f" ✅ PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
|
||||
print(f" PASS: Volume {volume:.2f} in expected range [{norm_params['volume_min']}, {norm_params['volume_max']}]")
|
||||
else:
|
||||
print(f" ❌ FAIL: Volume {volume:.2f} outside expected range!")
|
||||
print(f" FAIL: Volume {volume:.2f} outside expected range!")
|
||||
|
||||
def test_loss_magnitude():
|
||||
"""Test that losses are in reasonable ranges"""
|
||||
@@ -227,15 +227,15 @@ def test_loss_magnitude():
|
||||
all_ok = True
|
||||
|
||||
if result['total_loss'] < 100.0:
|
||||
print(f" ✅ PASS: Total loss < 100 (was {result['total_loss']:.6f})")
|
||||
print(f" PASS: Total loss < 100 (was {result['total_loss']:.6f})")
|
||||
else:
|
||||
print(f" ❌ FAIL: Total loss too high! ({result['total_loss']:.6f})")
|
||||
print(f" FAIL: Total loss too high! ({result['total_loss']:.6f})")
|
||||
all_ok = False
|
||||
|
||||
if result['candle_loss'] < 10.0:
|
||||
print(f" ✅ PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
|
||||
print(f" PASS: Candle loss < 10 (was {result['candle_loss']:.6f})")
|
||||
else:
|
||||
print(f" ❌ FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
|
||||
print(f" FAIL: Candle loss too high! ({result['candle_loss']:.6f})")
|
||||
all_ok = False
|
||||
|
||||
# Check denormalized losses if available
|
||||
@@ -244,15 +244,15 @@ def test_loss_magnitude():
|
||||
for tf, loss in result['candle_loss_denorm'].items():
|
||||
print(f" {tf}: ${loss:.2f}")
|
||||
if loss < 1000.0:
|
||||
print(f" ✅ PASS: Real price error < $1000")
|
||||
print(f" PASS: Real price error < $1000")
|
||||
else:
|
||||
print(f" ❌ FAIL: Real price error too high!")
|
||||
print(f" FAIL: Real price error too high!")
|
||||
all_ok = False
|
||||
|
||||
if all_ok:
|
||||
print("\n ✅ ALL TESTS PASSED: Losses in reasonable ranges!")
|
||||
print("\n ALL TESTS PASSED: Losses in reasonable ranges!")
|
||||
else:
|
||||
print("\n ❌ SOME TESTS FAILED: Check model/normalization!")
|
||||
print("\n SOME TESTS FAILED: Check model/normalization!")
|
||||
|
||||
return result
|
||||
|
||||
@@ -275,7 +275,7 @@ def main():
|
||||
print("\n" + "=" * 80)
|
||||
print("TEST SUMMARY")
|
||||
print("=" * 80)
|
||||
print("\nIf all tests passed (✅), the normalization fix is working correctly!")
|
||||
print("\nIf all tests passed, the normalization fix is working correctly!")
|
||||
print("You should now see reasonable losses in training logs:")
|
||||
print(" - Total loss: ~0.5-1.0 (not billions!)")
|
||||
print(" - Candle loss: ~0.1-0.3")
|
||||
|
||||
Reference in New Issue
Block a user