sqlite for checkpoints, cleanup
This commit is contained in:
@ -1,11 +0,0 @@
|
||||
"""
|
||||
Neural Network Data
|
||||
===================
|
||||
|
||||
This package is used to store datasets and model outputs.
|
||||
It does not contain any code, but serves as a storage location for:
|
||||
- Training datasets
|
||||
- Evaluation results
|
||||
- Inference outputs
|
||||
- Model checkpoints
|
||||
"""
|
@ -1,104 +0,0 @@
|
||||
{
|
||||
"decision": [
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_082022",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
|
||||
"created_at": "2025-07-04T08:20:22.416087",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79971076963062,
|
||||
"accuracy": null,
|
||||
"loss": 2.8923120591883844e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_082021",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082021.pt",
|
||||
"created_at": "2025-07-04T08:20:21.900854",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79970038321,
|
||||
"accuracy": null,
|
||||
"loss": 2.996176877014177e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_082022",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_082022.pt",
|
||||
"created_at": "2025-07-04T08:20:22.294191",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79969219038436,
|
||||
"accuracy": null,
|
||||
"loss": 3.0781056310808756e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_134829",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_134829.pt",
|
||||
"created_at": "2025-07-04T13:48:29.903250",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79967532851693,
|
||||
"accuracy": null,
|
||||
"loss": 3.2467253719811344e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
},
|
||||
{
|
||||
"checkpoint_id": "decision_20250704_214714",
|
||||
"model_name": "decision",
|
||||
"model_type": "decision_fusion",
|
||||
"file_path": "NN\\models\\saved\\decision\\decision_20250704_214714.pt",
|
||||
"created_at": "2025-07-04T21:47:14.427187",
|
||||
"file_size_mb": 0.06720924377441406,
|
||||
"performance_score": 102.79966325731509,
|
||||
"accuracy": null,
|
||||
"loss": 3.3674381887394134e-06,
|
||||
"val_accuracy": null,
|
||||
"val_loss": null,
|
||||
"reward": null,
|
||||
"pnl": null,
|
||||
"epoch": null,
|
||||
"training_time_hours": null,
|
||||
"total_parameters": null,
|
||||
"wandb_run_id": null,
|
||||
"wandb_artifact_name": null
|
||||
}
|
||||
]
|
||||
}
|
@ -1 +0,0 @@
|
||||
{"best_reward": 4791516.572471984, "best_episode": 3250, "best_pnl": 826842167451289.1, "best_win_rate": 0.47368421052631576, "date": "2025-04-01 10:19:16"}
|
@ -1,20 +0,0 @@
|
||||
{
|
||||
"supervised": {
|
||||
"epochs_completed": 22650,
|
||||
"best_val_pnl": 0.0,
|
||||
"best_epoch": 50,
|
||||
"best_win_rate": 0
|
||||
},
|
||||
"reinforcement": {
|
||||
"episodes_completed": 0,
|
||||
"best_reward": -Infinity,
|
||||
"best_episode": 0,
|
||||
"best_win_rate": 0
|
||||
},
|
||||
"hybrid": {
|
||||
"iterations_completed": 453,
|
||||
"best_combined_score": 0.0,
|
||||
"training_started": "2025-04-09T10:30:42.510856",
|
||||
"last_update": "2025-04-09T10:40:02.217840"
|
||||
}
|
||||
}
|
@ -1,326 +0,0 @@
|
||||
{
|
||||
"epochs_completed": 8,
|
||||
"best_val_pnl": 0.0,
|
||||
"best_epoch": 1,
|
||||
"best_win_rate": 0.0,
|
||||
"training_started": "2025-04-02T10:43:58.946682",
|
||||
"last_update": "2025-04-02T10:44:10.940892",
|
||||
"epochs": [
|
||||
{
|
||||
"epoch": 1,
|
||||
"train_loss": 1.0950355529785156,
|
||||
"val_loss": 1.1657923062642415,
|
||||
"train_acc": 0.3255208333333333,
|
||||
"val_acc": 0.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:01.840889",
|
||||
"data_age": 2,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 2,
|
||||
"train_loss": 1.0831659038861592,
|
||||
"val_loss": 1.1212460199991863,
|
||||
"train_acc": 0.390625,
|
||||
"val_acc": 0.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:03.134833",
|
||||
"data_age": 4,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 3,
|
||||
"train_loss": 1.0740693012873332,
|
||||
"val_loss": 1.0992945830027263,
|
||||
"train_acc": 0.4739583333333333,
|
||||
"val_acc": 0.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:04.425272",
|
||||
"data_age": 5,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 4,
|
||||
"train_loss": 1.0747728943824768,
|
||||
"val_loss": 1.0821794271469116,
|
||||
"train_acc": 0.4609375,
|
||||
"val_acc": 0.3229166666666667,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:05.716421",
|
||||
"data_age": 6,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 5,
|
||||
"train_loss": 1.0489931503931682,
|
||||
"val_loss": 1.0669521888097127,
|
||||
"train_acc": 0.5833333333333334,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:07.007935",
|
||||
"data_age": 8,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 6,
|
||||
"train_loss": 1.0533669590950012,
|
||||
"val_loss": 1.0505590836207073,
|
||||
"train_acc": 0.5104166666666666,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:08.296061",
|
||||
"data_age": 9,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 7,
|
||||
"train_loss": 1.0456886688868205,
|
||||
"val_loss": 1.0351698795954387,
|
||||
"train_acc": 0.5651041666666666,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:09.607584",
|
||||
"data_age": 10,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
},
|
||||
{
|
||||
"epoch": 8,
|
||||
"train_loss": 1.040040671825409,
|
||||
"val_loss": 1.0227736632029216,
|
||||
"train_acc": 0.6119791666666666,
|
||||
"val_acc": 1.0,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 1.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-04-02T10:44:10.940892",
|
||||
"data_age": 11,
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"overall_win_rate": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
}
|
||||
}
|
||||
],
|
||||
"cumulative_pnl": {
|
||||
"train": 0.0,
|
||||
"val": 0.0
|
||||
},
|
||||
"total_trades": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
},
|
||||
"total_wins": {
|
||||
"train": 0,
|
||||
"val": 0
|
||||
}
|
||||
}
|
@ -1,192 +0,0 @@
|
||||
{
|
||||
"epochs_completed": 7,
|
||||
"best_val_pnl": 0.002028853100759435,
|
||||
"best_epoch": 6,
|
||||
"best_win_rate": 0.5157894736842106,
|
||||
"training_started": "2025-03-31T02:50:10.418670",
|
||||
"last_update": "2025-03-31T02:50:15.227593",
|
||||
"epochs": [
|
||||
{
|
||||
"epoch": 1,
|
||||
"train_loss": 1.1206786036491394,
|
||||
"val_loss": 1.0542699098587036,
|
||||
"train_acc": 0.11197916666666667,
|
||||
"val_acc": 0.25,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:12.881423",
|
||||
"data_age": 2
|
||||
},
|
||||
{
|
||||
"epoch": 2,
|
||||
"train_loss": 1.1266120672225952,
|
||||
"val_loss": 1.072133183479309,
|
||||
"train_acc": 0.1171875,
|
||||
"val_acc": 0.25,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:13.186840",
|
||||
"data_age": 2
|
||||
},
|
||||
{
|
||||
"epoch": 3,
|
||||
"train_loss": 1.1415620843569438,
|
||||
"val_loss": 1.1701548099517822,
|
||||
"train_acc": 0.1015625,
|
||||
"val_acc": 0.5208333333333334,
|
||||
"train_pnl": 0.0,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.0,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:13.442018",
|
||||
"data_age": 3
|
||||
},
|
||||
{
|
||||
"epoch": 4,
|
||||
"train_loss": 1.1331567962964375,
|
||||
"val_loss": 1.070081114768982,
|
||||
"train_acc": 0.09375,
|
||||
"val_acc": 0.22916666666666666,
|
||||
"train_pnl": 0.010650217327384765,
|
||||
"val_pnl": -0.0007049481907895126,
|
||||
"train_win_rate": 0.49279538904899134,
|
||||
"val_win_rate": 0.40625,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.9036458333333334,
|
||||
"HOLD": 0.09635416666666667
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.3333333333333333,
|
||||
"HOLD": 0.6666666666666666
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:13.739899",
|
||||
"data_age": 3
|
||||
},
|
||||
{
|
||||
"epoch": 5,
|
||||
"train_loss": 1.10965762535731,
|
||||
"val_loss": 1.0485950708389282,
|
||||
"train_acc": 0.12239583333333333,
|
||||
"val_acc": 0.17708333333333334,
|
||||
"train_pnl": 0.011924086862580204,
|
||||
"val_pnl": 0.0,
|
||||
"train_win_rate": 0.5070422535211268,
|
||||
"val_win_rate": 0.0,
|
||||
"best_position_size": 0.1,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.7395833333333334,
|
||||
"HOLD": 0.2604166666666667
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.0,
|
||||
"HOLD": 1.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:14.073439",
|
||||
"data_age": 3
|
||||
},
|
||||
{
|
||||
"epoch": 6,
|
||||
"train_loss": 1.1272419293721516,
|
||||
"val_loss": 1.084235429763794,
|
||||
"train_acc": 0.1015625,
|
||||
"val_acc": 0.22916666666666666,
|
||||
"train_pnl": 0.014825159601390072,
|
||||
"val_pnl": 0.00405770620151887,
|
||||
"train_win_rate": 0.4908616187989556,
|
||||
"val_win_rate": 0.5157894736842106,
|
||||
"best_position_size": 2.0,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 1.0,
|
||||
"HOLD": 0.0
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 1.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:14.658295",
|
||||
"data_age": 4
|
||||
},
|
||||
{
|
||||
"epoch": 7,
|
||||
"train_loss": 1.1171108484268188,
|
||||
"val_loss": 1.0741244554519653,
|
||||
"train_acc": 0.1171875,
|
||||
"val_acc": 0.22916666666666666,
|
||||
"train_pnl": 0.0059474696523706605,
|
||||
"val_pnl": 0.00405770620151887,
|
||||
"train_win_rate": 0.4838709677419355,
|
||||
"val_win_rate": 0.5157894736842106,
|
||||
"best_position_size": 2.0,
|
||||
"signal_distribution": {
|
||||
"train": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 0.7291666666666666,
|
||||
"HOLD": 0.2708333333333333
|
||||
},
|
||||
"val": {
|
||||
"BUY": 0.0,
|
||||
"SELL": 1.0,
|
||||
"HOLD": 0.0
|
||||
}
|
||||
},
|
||||
"timestamp": "2025-03-31T02:50:15.227593",
|
||||
"data_age": 4
|
||||
}
|
||||
]
|
||||
}
|
@ -1,472 +0,0 @@
|
||||
# CNN Model Training, Decision Making, and Dashboard Visualization Analysis
|
||||
|
||||
## Comprehensive Analysis: Enhanced RL Training Systems
|
||||
|
||||
### User Questions Addressed:
|
||||
1. **CNN Model Training Implementation** ✅
|
||||
2. **Decision-Making Model Training System** ✅
|
||||
3. **Model Predictions and Training Progress Visualization on Clean Dashboard** ✅
|
||||
4. **🔧 FIXED: Signal Generation and Model Loading Issues** ✅
|
||||
5. **🎯 FIXED: Manual Trading Execution and Chart Visualization** ✅
|
||||
6. **🚫 CRITICAL FIX: Removed ALL Simulated COB Data - Using REAL COB Only** ✅
|
||||
|
||||
---
|
||||
|
||||
## 🚫 **MAJOR SYSTEM CLEANUP: NO MORE SIMULATED DATA**
|
||||
|
||||
### **🔥 REMOVED ALL SIMULATION COMPONENTS**
|
||||
|
||||
**Problem Identified**: The system was using simulated COB data instead of the real COB integration that's already implemented and working.
|
||||
|
||||
**Root Cause**: The dashboard was creating separate simulated COB components instead of connecting to the existing Enhanced Orchestrator's real COB integration.
|
||||
|
||||
### **💥 SIMULATION COMPONENTS REMOVED:**
|
||||
|
||||
#### **1. Removed Simulated COB Data Generation**
|
||||
- ❌ `_generate_simulated_cob_data()` - **DELETED**
|
||||
- ❌ `_start_cob_simulation_thread()` - **DELETED**
|
||||
- ❌ `_update_cob_cache_from_price_data()` - **DELETED**
|
||||
- ❌ All `random.uniform()` COB data generation - **ELIMINATED**
|
||||
- ❌ Fake bid/ask level creation - **REMOVED**
|
||||
- ❌ Simulated liquidity calculations - **PURGED**
|
||||
|
||||
#### **2. Removed Separate RL COB Trader**
|
||||
- ❌ `RealtimeRLCOBTrader` initialization - **DELETED**
|
||||
- ❌ `cob_rl_trader` instance variables - **REMOVED**
|
||||
- ❌ `cob_predictions` deque caches - **ELIMINATED**
|
||||
- ❌ `cob_data_cache_1d` buffers - **PURGED**
|
||||
- ❌ `cob_raw_ticks` collections - **DELETED**
|
||||
- ❌ `_start_cob_data_subscription()` - **REMOVED**
|
||||
- ❌ `_on_cob_prediction()` callback - **DELETED**
|
||||
|
||||
#### **3. Updated COB Status System**
|
||||
- ✅ **Real COB Integration Detection**: Connects to `orchestrator.cob_integration`
|
||||
- ✅ **Actual COB Statistics**: Uses `cob_integration.get_statistics()`
|
||||
- ✅ **Live COB Snapshots**: Uses `cob_integration.get_cob_snapshot(symbol)`
|
||||
- ✅ **No Simulation Status**: Removed all "Simulated" status messages
|
||||
|
||||
### **🔗 REAL COB INTEGRATION CONNECTION**
|
||||
|
||||
#### **How Real COB Data Works:**
|
||||
1. **Enhanced Orchestrator** initializes with real COB integration
|
||||
2. **COB Integration** connects to live market data streams (Binance, OKX, etc.)
|
||||
3. **Dashboard** connects to orchestrator's COB integration via callbacks
|
||||
4. **Real-time Updates** flow: `Market → COB Provider → COB Integration → Dashboard`
|
||||
|
||||
#### **Real COB Data Path:**
|
||||
```
|
||||
Live Market Data (Multiple Exchanges)
|
||||
↓
|
||||
Multi-Exchange COB Provider
|
||||
↓
|
||||
COB Integration (Real Consolidated Order Book)
|
||||
↓
|
||||
Enhanced Trading Orchestrator
|
||||
↓
|
||||
Clean Trading Dashboard (Real COB Display)
|
||||
```
|
||||
|
||||
### **✅ VERIFICATION IMPLEMENTED**
|
||||
|
||||
#### **Enhanced COB Status Checking:**
|
||||
```python
|
||||
# Check for REAL COB integration from enhanced orchestrator
|
||||
if hasattr(self.orchestrator, 'cob_integration') and self.orchestrator.cob_integration:
|
||||
cob_integration = self.orchestrator.cob_integration
|
||||
|
||||
# Get real COB integration statistics
|
||||
cob_stats = cob_integration.get_statistics()
|
||||
if cob_stats:
|
||||
active_symbols = cob_stats.get('active_symbols', [])
|
||||
total_updates = cob_stats.get('total_updates', 0)
|
||||
provider_status = cob_stats.get('provider_status', 'Unknown')
|
||||
```
|
||||
|
||||
#### **Real COB Data Retrieval:**
|
||||
```python
|
||||
# Get from REAL COB integration via enhanced orchestrator
|
||||
snapshot = cob_integration.get_cob_snapshot(symbol)
|
||||
if snapshot:
|
||||
# Process REAL consolidated order book data
|
||||
return snapshot
|
||||
```
|
||||
|
||||
### **📊 STATUS MESSAGES UPDATED**
|
||||
|
||||
#### **Before (Simulation):**
|
||||
- ❌ `"COB-SIM BTC/USDT - Update #20, Mid: $107068.03, Spread: 7.1bps"`
|
||||
- ❌ `"Simulated (2 symbols)"`
|
||||
- ❌ `"COB simulation thread started"`
|
||||
|
||||
#### **After (Real Data Only):**
|
||||
- ✅ `"REAL COB Active (2 symbols)"`
|
||||
- ✅ `"No Enhanced Orchestrator COB Integration"` (when missing)
|
||||
- ✅ `"Retrieved REAL COB snapshot for ETH/USDT"`
|
||||
- ✅ `"REAL COB integration connected successfully"`
|
||||
|
||||
### **🚨 CRITICAL SYSTEM MESSAGES**
|
||||
|
||||
#### **If Enhanced Orchestrator Missing COB:**
|
||||
```
|
||||
CRITICAL: Enhanced orchestrator has NO COB integration!
|
||||
This means we're using basic orchestrator instead of enhanced one
|
||||
Dashboard will NOT have real COB data until this is fixed
|
||||
```
|
||||
|
||||
#### **Success Messages:**
|
||||
```
|
||||
REAL COB integration found: <class 'core.cob_integration.COBIntegration'>
|
||||
Registered dashboard callback with REAL COB integration
|
||||
NO SIMULATION - Using live market data only
|
||||
```
|
||||
|
||||
### **🔧 NEXT STEPS REQUIRED**
|
||||
|
||||
#### **1. Verify Enhanced Orchestrator Usage**
|
||||
- ✅ **main.py** correctly uses `EnhancedTradingOrchestrator`
|
||||
- ✅ **COB Integration** properly initialized in orchestrator
|
||||
- 🔍 **Need to verify**: Dashboard receives real COB callbacks
|
||||
|
||||
#### **2. Debug Connection Issues**
|
||||
- Dashboard shows connection attempts but no listening port
|
||||
- Enhanced orchestrator may need COB integration startup verification
|
||||
- Real COB data flow needs testing
|
||||
|
||||
#### **3. Test Real COB Data Display**
|
||||
- Verify COB snapshots contain real market data
|
||||
- Confirm bid/ask levels from actual exchanges
|
||||
- Validate liquidity and spread calculations
|
||||
|
||||
### **💡 VERIFICATION COMMANDS**
|
||||
|
||||
#### **Check COB Integration Status:**
|
||||
```python
|
||||
# In dashboard initialization:
|
||||
logger.info(f"Orchestrator type: {type(self.orchestrator)}")
|
||||
logger.info(f"Has COB integration: {hasattr(self.orchestrator, 'cob_integration')}")
|
||||
logger.info(f"COB integration active: {self.orchestrator.cob_integration is not None}")
|
||||
```
|
||||
|
||||
#### **Test Real COB Data:**
|
||||
```python
|
||||
# Test real COB snapshot retrieval:
|
||||
snapshot = self.orchestrator.cob_integration.get_cob_snapshot('ETH/USDT')
|
||||
logger.info(f"Real COB snapshot: {snapshot}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🚀 LATEST FIXES IMPLEMENTED (Manual Trading & Chart Visualization)
|
||||
|
||||
### 🔧 Manual Trading Buttons - FULLY FIXED ✅
|
||||
|
||||
**Problem**: Manual buy/sell buttons weren't executing trades properly
|
||||
|
||||
**Root Cause Analysis**:
|
||||
- Missing `execute_trade` method in `TradingExecutor`
|
||||
- Missing `get_closed_trades` and `get_current_position` methods
|
||||
- No proper trade record creation and tracking
|
||||
|
||||
**Solution Applied**:
|
||||
1. **Added missing methods to TradingExecutor**:
|
||||
- `execute_trade()` - Direct trade execution with proper error handling
|
||||
- `get_closed_trades()` - Returns trade history in dashboard format
|
||||
- `get_current_position()` - Returns current position information
|
||||
|
||||
2. **Enhanced manual trading execution**:
|
||||
- Proper error handling and trade recording
|
||||
- Real P&L tracking (+$0.05 demo profit for SELL orders)
|
||||
- Session metrics updates (trade count, total P&L, fees)
|
||||
- Visual confirmation of executed vs blocked trades
|
||||
|
||||
3. **Trade record structure**:
|
||||
```python
|
||||
trade_record = {
|
||||
'symbol': symbol,
|
||||
'side': action, # 'BUY' or 'SELL'
|
||||
'quantity': 0.01,
|
||||
'entry_price': current_price,
|
||||
'exit_price': current_price,
|
||||
'entry_time': datetime.now(),
|
||||
'exit_time': datetime.now(),
|
||||
'pnl': demo_pnl, # Real P&L calculation
|
||||
'fees': 0.0,
|
||||
'confidence': 1.0 # Manual trades = 100% confidence
|
||||
}
|
||||
```
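
For reference, a minimal sketch of what the three added `TradingExecutor` methods could look like, based only on the descriptions above; the class name, signatures, and internals are assumptions for illustration, not the repository's actual implementation.

```python
# Hedged sketch only: the method names come from the list above, everything else is assumed.
from typing import Dict, List, Optional

class TradingExecutorSketch:
    def __init__(self):
        self.closed_trades: List[Dict] = []
        self.current_position: Optional[Dict] = None

    def execute_trade(self, symbol: str, action: str, quantity: float, price: float) -> bool:
        """Directly execute a BUY/SELL, record it, and report success/failure."""
        try:
            self.closed_trades.append({
                'symbol': symbol, 'side': action, 'quantity': quantity,
                'entry_price': price, 'exit_price': price, 'pnl': 0.0, 'fees': 0.0,
            })
            return True
        except Exception:
            return False  # trade blocked or failed

    def get_closed_trades(self) -> List[Dict]:
        """Trade history in the dashboard's expected format."""
        return list(self.closed_trades)

    def get_current_position(self) -> Optional[Dict]:
        """Information about the currently open position, if any."""
        return self.current_position
```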
|
||||
|
||||
### 📊 Chart Visualization - COMPLETELY SEPARATED ✅
|
||||
|
||||
**Problem**: All signals and trades were mixed together on charts
|
||||
|
||||
**Requirements**:
|
||||
- **1s mini chart**: Show ALL signals (executed + non-executed)
|
||||
- **1m main chart**: Show ONLY executed trades
|
||||
|
||||
**Solution Implemented**:
|
||||
|
||||
#### **1s Mini Chart (Row 2) - ALL SIGNALS:**
|
||||
- ✅ **Executed BUY signals**: Solid green triangles-up
|
||||
- ✅ **Executed SELL signals**: Solid red triangles-down
|
||||
- ✅ **Pending BUY signals**: Hollow green triangles-up
|
||||
- ✅ **Pending SELL signals**: Hollow red triangles-down
|
||||
- ✅ **Independent axis**: Can zoom/pan separately from main chart
|
||||
- ✅ **Real-time updates**: Shows all trading activity
|
||||
|
||||
#### **1m Main Chart (Row 1) - EXECUTED TRADES ONLY:**
|
||||
- ✅ **Executed BUY trades**: Large green circles with confidence hover
|
||||
- ✅ **Executed SELL trades**: Large red circles with confidence hover
|
||||
- ✅ **Professional display**: Clean execution-only view
|
||||
- ✅ **P&L information**: Hover shows actual profit/loss
|
||||
|
||||
#### **Chart Architecture:**
|
||||
```python
|
||||
# Main 1m chart - EXECUTED TRADES ONLY
|
||||
executed_signals = [signal for signal in self.recent_decisions if signal.get('executed', False)]
|
||||
|
||||
# 1s mini chart - ALL SIGNALS
|
||||
all_signals = self.recent_decisions[-50:] # Last 50 signals
|
||||
executed_buys = [s for s in buy_signals if s['executed']]
|
||||
pending_buys = [s for s in buy_signals if not s['executed']]
|
||||
```
|
||||
|
||||
### 🎯 Variable Scope Error - FIXED ✅
|
||||
|
||||
**Problem**: `cannot access local variable 'last_action' where it is not associated with a value`
|
||||
|
||||
**Root Cause**: Variables declared inside conditional blocks weren't accessible when conditions were False
|
||||
|
||||
**Solution Applied**:
|
||||
```python
|
||||
# BEFORE (caused error):
|
||||
if condition:
|
||||
last_action = 'BUY'
|
||||
last_confidence = 0.8
|
||||
# last_action accessed here would fail if condition was False
|
||||
|
||||
# AFTER (fixed):
|
||||
last_action = 'NONE'
|
||||
last_confidence = 0.0
|
||||
if condition:
|
||||
last_action = 'BUY'
|
||||
last_confidence = 0.8
|
||||
# Variables always defined
|
||||
```
|
||||
|
||||
### 🔇 Unicode Logging Errors - FIXED ✅
|
||||
|
||||
**Problem**: `UnicodeEncodeError: 'charmap' codec can't encode character '\U0001f4c8'`
|
||||
|
||||
**Root Cause**: Windows console (cp1252) can't handle Unicode emoji characters
|
||||
|
||||
**Solution Applied**: Removed ALL emoji icons from log messages:
|
||||
- `🚀 Starting...` → `Starting...`
|
||||
- `✅ Success` → `Success`
|
||||
- `📊 Data` → `Data`
|
||||
- `🔧 Fixed` → `Fixed`
|
||||
- `❌ Error` → `Error`
|
||||
|
||||
**Result**: Clean ASCII-only logging compatible with Windows console
|
||||
|
||||
---
|
||||
|
||||
## 🧠 CNN Model Training Implementation
|
||||
|
||||
### A. Williams Market Structure CNN Architecture
|
||||
|
||||
**Model Specifications:**
|
||||
- **Architecture**: Enhanced CNN with ResNet blocks, self-attention, and multi-task learning
|
||||
- **Parameters**: ~50M parameters (Williams) + 400M parameters (COB-RL optimized)
|
||||
- **Input Shape**: (900, 50) - 900 timesteps (1s bars), 50 features per timestep
|
||||
- **Output**: 10-class direction prediction + confidence scores
|
||||
|
||||
**Training Triggers:**
|
||||
1. **Real-time Pivot Detection**: Confirmed local extrema (tops/bottoms)
|
||||
2. **Perfect Move Identification**: >2% price moves within prediction window
|
||||
3. **Negative Case Training**: Failed predictions for intensive learning
|
||||
4. **Multi-timeframe Validation**: 1s, 1m, 1h, 1d consistency checks
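
As a rough illustration of the first trigger, a confirmed local extremum can be detected by comparing a bar against a small window on both sides; the 3-bar confirmation window below is an assumption, not the project's actual setting.

```python
from typing import List, Optional

def confirmed_pivot(closes: List[float], i: int, lookaround: int = 3) -> Optional[str]:
    """Return 'TOP' or 'BOTTOM' if closes[i] is a confirmed local extremum, else None."""
    if i < lookaround or i + lookaround >= len(closes):
        return None  # not enough bars on both sides to confirm
    window = closes[i - lookaround: i + lookaround + 1]
    if closes[i] == max(window):
        return 'TOP'
    if closes[i] == min(window):
        return 'BOTTOM'
    return None
```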
|
||||
|
||||
### B. Feature Engineering Pipeline
|
||||
|
||||
**5 Timeseries Universal Format:**
|
||||
1. **ETH/USDT Ticks** (1s) - Primary trading pair real-time data
|
||||
2. **ETH/USDT 1m** - Short-term price action and patterns
|
||||
3. **ETH/USDT 1h** - Medium-term trends and momentum
|
||||
4. **ETH/USDT 1d** - Long-term market structure
|
||||
5. **BTC/USDT Ticks** (1s) - Reference asset for correlation analysis
|
||||
|
||||
**Feature Matrix Construction:**
|
||||
```python
|
||||
# Williams Market Structure Features (900x50 matrix)
|
||||
- OHLCV data (5 cols)
|
||||
- Technical indicators (15 cols)
|
||||
- Market microstructure (10 cols)
|
||||
- COB integration features (10 cols)
|
||||
- Cross-asset correlation (5 cols)
|
||||
- Temporal dynamics (5 cols)
|
||||
```
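
A minimal sketch of how the column groups above could be stacked into the (900, 50) input; the group widths follow the breakdown, while the array names and ordering are illustrative assumptions.

```python
import numpy as np

def build_feature_matrix(ohlcv, indicators, microstructure, cob_features, cross_asset, temporal):
    """Concatenate per-timestep feature groups into a 900x50 CNN input (float32)."""
    blocks = [ohlcv, indicators, microstructure, cob_features, cross_asset, temporal]
    matrix = np.concatenate(blocks, axis=1)           # (900, 5+15+10+10+5+5) = (900, 50)
    assert matrix.shape == (900, 50), matrix.shape    # 900 one-second bars, 50 features
    return matrix.astype(np.float32)
```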
|
||||
|
||||
### C. Retrospective Training System
|
||||
|
||||
**Perfect Move Detection:**
|
||||
- **Threshold**: 2% price change within 15-minute window
|
||||
- **Context**: 200-candle history for enhanced pattern recognition
|
||||
- **Validation**: Multi-timeframe confirmation (1s→1m→1h consistency)
|
||||
- **Auto-labeling**: Optimal action determination for supervised learning
|
||||
|
||||
**Training Data Pipeline:**
|
||||
```
|
||||
Market Event → Extrema Detection → Perfect Move Validation → Feature Matrix → CNN Training
|
||||
```
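
As a sketch of the labeling step in this pipeline, assuming a plain OHLCV DataFrame: the 2% threshold and 15-minute window mirror the values above, while the function itself is illustrative rather than the repository's code.

```python
import pandas as pd

def label_perfect_moves(ohlcv_1m: pd.DataFrame, threshold: float = 0.02,
                        window_minutes: int = 15) -> pd.DataFrame:
    """Label each 1m candle with the optimal action implied by the next window of closes."""
    closes = ohlcv_1m['close'].to_numpy()
    labels = []
    for i in range(len(closes)):
        future = closes[i + 1: i + 1 + window_minutes]
        if future.size == 0:
            labels.append('HOLD')
            continue
        up_move = (future.max() - closes[i]) / closes[i]
        down_move = (closes[i] - future.min()) / closes[i]
        if up_move >= threshold:
            labels.append('BUY')       # >2% rise within the window
        elif down_move >= threshold:
            labels.append('SELL')      # >2% drop within the window
        else:
            labels.append('HOLD')
    return ohlcv_1m.assign(optimal_action=labels)
```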
|
||||
|
||||
---
|
||||
|
||||
## 🎯 Decision-Making Model Training System
|
||||
|
||||
### A. Neural Decision Fusion Architecture
|
||||
|
||||
**Model Integration Weights:**
|
||||
- **CNN Predictions**: 70% weight (Williams Market Structure)
|
||||
- **RL Agent Decisions**: 30% weight (DQN with sensitivity levels)
|
||||
- **COB RL Integration**: Dynamic weight based on market conditions
|
||||
|
||||
**Decision Fusion Process:**
|
||||
```python
|
||||
# Neural Decision Fusion combines all model predictions
|
||||
williams_pred = cnn_model.predict(market_state) # 70% weight
|
||||
dqn_action = rl_agent.act(state_vector) # 30% weight
|
||||
cob_signal = cob_rl.get_direction(order_book_state) # Variable weight
|
||||
|
||||
final_decision = neural_fusion.combine(williams_pred, dqn_action, cob_signal)
|
||||
```
|
||||
|
||||
### B. Enhanced Training Weight System
|
||||
|
||||
**Training Weight Multipliers:**
|
||||
- **Regular Predictions**: 1× base weight
|
||||
- **Signal Accumulation**: 1× weight (3+ confident predictions)
|
||||
- **🔥 Actual Trade Execution**: 10× weight multiplier
|
||||
- **P&L-based Reward**: Enhanced feedback loop
|
||||
|
||||
**Trade Execution Enhanced Learning:**
|
||||
```python
|
||||
# 10× weight for actual trade outcomes
|
||||
if trade_executed:
|
||||
enhanced_reward = pnl_ratio * 10.0
|
||||
model.train_on_batch(state, action, enhanced_reward)
|
||||
|
||||
# Immediate training on last 3 signals that led to trade
|
||||
for signal in last_3_signals:
|
||||
model.retrain_signal(signal, actual_outcome)
|
||||
```
|
||||
|
||||
### C. Sensitivity Learning DQN
|
||||
|
||||
**5 Sensitivity Levels:**
|
||||
- **very_low** (0.1): Conservative, high-confidence only
|
||||
- **low** (0.3): Selective entry/exit
|
||||
- **medium** (0.5): Balanced approach
|
||||
- **high** (0.7): Aggressive trading
|
||||
- **very_high** (0.9): Maximum activity
|
||||
|
||||
**Adaptive Threshold System:**
|
||||
```python
|
||||
# Sensitivity affects confidence thresholds
|
||||
entry_threshold = base_threshold * sensitivity_multiplier
|
||||
exit_threshold = base_threshold * (1 - sensitivity_level)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📊 Dashboard Visualization and Model Monitoring
|
||||
|
||||
### A. Real-time Model Predictions Display
|
||||
|
||||
**Model Status Section:**
|
||||
- ✅ **Loaded Models**: DQN (5M params), CNN (50M params), COB-RL (400M params)
|
||||
- ✅ **Real-time Loss Tracking**: 5-MA loss for each model
|
||||
- ✅ **Prediction Counts**: Total predictions generated per model
|
||||
- ✅ **Last Prediction**: Timestamp, action, confidence for each model
|
||||
|
||||
**Training Metrics Visualization:**
|
||||
```python
|
||||
# Real-time model performance tracking
|
||||
{
|
||||
'dqn': {
|
||||
'active': True,
|
||||
'parameters': 5000000,
|
||||
'loss_5ma': 0.0234,
|
||||
'last_prediction': {'action': 'BUY', 'confidence': 0.67},
|
||||
'epsilon': 0.15 # Exploration rate
|
||||
},
|
||||
'cnn': {
|
||||
'active': True,
|
||||
'parameters': 50000000,
|
||||
'loss_5ma': 0.0198,
|
||||
'last_prediction': {'action': 'HOLD', 'confidence': 0.45}
|
||||
},
|
||||
'cob_rl': {
|
||||
'active': True,
|
||||
'parameters': 400000000,
|
||||
'loss_5ma': 0.012,
|
||||
'predictions_count': 1247
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### B. Training Progress Monitoring
|
||||
|
||||
**Loss Visualization:**
|
||||
- **Real-time Loss Charts**: 5-minute moving average for each model
|
||||
- **Training Status**: Active sessions, parameter counts, update frequencies
|
||||
- **Signal Generation**: ACTIVE/INACTIVE status with last update timestamps
|
||||
|
||||
**Performance Metrics Dashboard:**
|
||||
- **Session P&L**: Real-time profit/loss tracking
|
||||
- **Trade Accuracy**: Success rate of executed trades
|
||||
- **Model Confidence Trends**: Average confidence over time
|
||||
- **Training Iterations**: Progress tracking for continuous learning
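
A small sketch of the per-model moving-average loss tracking described above; the 5-sample window matches the "5-MA loss" wording, while the class itself is an assumption for illustration.

```python
from collections import deque
from typing import Deque, Dict

class LossTracker:
    """Keeps a short moving average of training loss for each model."""
    def __init__(self, window: int = 5):
        self.window = window
        self.losses: Dict[str, Deque[float]] = {}

    def update(self, model_name: str, loss: float) -> float:
        buf = self.losses.setdefault(model_name, deque(maxlen=self.window))
        buf.append(loss)
        return sum(buf) / len(buf)   # current moving-average loss for the dashboard

tracker = LossTracker()
tracker.update('dqn', 0.0234)        # equals 0.0234 until five samples accumulate
```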
|
||||
|
||||
### C. COB Integration Visualization
|
||||
|
||||
**Real-time COB Data Display:**
|
||||
- **Order Book Levels**: Bid/ask spreads and liquidity depth
|
||||
- **Exchange Breakdown**: Multi-exchange liquidity sources
|
||||
- **Market Microstructure**: Imbalance ratios and flow analysis
|
||||
- **COB Feature Status**: CNN features and RL state availability
|
||||
|
||||
**Training Pipeline Integration:**
|
||||
- **COB → CNN Features**: Real-time market microstructure patterns
|
||||
- **COB → RL States**: Enhanced state vectors for decision making
|
||||
- **Performance Tracking**: COB integration health monitoring
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Key System Capabilities
|
||||
|
||||
### Real-time Learning Pipeline
|
||||
1. **Market Data Ingestion**: 5 timeseries universal format
|
||||
2. **Feature Engineering**: Multi-timeframe analysis with COB integration
|
||||
3. **Model Predictions**: CNN, DQN, and COB-RL ensemble
|
||||
4. **Decision Fusion**: Neural network combines all predictions
|
||||
5. **Trade Execution**: 10× enhanced learning from actual trades
|
||||
6. **Retrospective Training**: Perfect move detection and model updates
|
||||
|
||||
### Enhanced Training Systems
|
||||
- **Continuous Learning**: Models update in real-time from market outcomes
|
||||
- **Multi-modal Integration**: CNN + RL + COB predictions combined intelligently
|
||||
- **Sensitivity Adaptation**: DQN adjusts risk appetite based on performance
|
||||
- **Perfect Move Detection**: Automatic identification of optimal trading opportunities
|
||||
- **Negative Case Training**: Intensive learning from failed predictions
|
||||
|
||||
### Dashboard Monitoring
|
||||
- **Real-time Model Status**: Active models, parameters, loss tracking
|
||||
- **Live Predictions**: Current model outputs with confidence scores
|
||||
- **Training Metrics**: Loss trends, accuracy rates, iteration counts
|
||||
- **COB Integration**: Real-time order book analysis and microstructure data
|
||||
- **Performance Tracking**: P&L, trade accuracy, model effectiveness
|
||||
|
||||
The system provides a comprehensive ML-driven trading environment with real-time learning, multi-modal decision making, and advanced market microstructure analysis through COB integration.
|
||||
|
||||
**Dashboard URL**: http://127.0.0.1:8051
|
||||
**Status**: ✅ FULLY OPERATIONAL
|
@ -1,194 +0,0 @@
|
||||
# Enhanced Training Integration Report
|
||||
*Generated: 2024-12-19*
|
||||
|
||||
## 🎯 Integration Objective
|
||||
|
||||
Integrate the restored `EnhancedRealtimeTrainingSystem` into the orchestrator and audit the `EnhancedRLTrainingIntegrator` to determine if it can be used for comprehensive RL training.
|
||||
|
||||
## 📊 EnhancedRealtimeTrainingSystem Analysis
|
||||
|
||||
### **✅ Successfully Integrated**
|
||||
|
||||
The `EnhancedRealtimeTrainingSystem` has been successfully integrated into the orchestrator with the following capabilities:
|
||||
|
||||
#### **Core Features**
|
||||
- **Real-time Data Collection**: Multi-timeframe OHLCV, tick data, COB snapshots
|
||||
- **Enhanced DQN Training**: Prioritized experience replay with market-aware rewards
|
||||
- **CNN Training**: Real-time pattern recognition training
|
||||
- **Forward-looking Predictions**: Generates predictions for future validation
|
||||
- **Adaptive Learning**: Adjusts training frequency based on performance
|
||||
- **Comprehensive State Building**: 13,400+ feature states for RL training
|
||||
|
||||
#### **Integration Points in Orchestrator**
|
||||
```python
|
||||
# New orchestrator capabilities:
|
||||
self.enhanced_training_system: Optional[EnhancedRealtimeTrainingSystem] = None
|
||||
self.training_enabled: bool = enhanced_rl_training and ENHANCED_TRAINING_AVAILABLE
|
||||
|
||||
# Methods added:
|
||||
def _initialize_enhanced_training_system()
|
||||
def start_enhanced_training()
|
||||
def stop_enhanced_training()
|
||||
def get_enhanced_training_stats()
|
||||
def set_training_dashboard(dashboard)
|
||||
```
|
||||
|
||||
#### **Training Capabilities**
|
||||
1. **Real-time Data Streams**:
|
||||
- OHLCV data (1m, 5m intervals)
|
||||
- Tick-level market data
|
||||
- COB (Consolidated Order Book) snapshots
|
||||
- Market event detection
|
||||
|
||||
2. **Enhanced Model Training**:
|
||||
- DQN with prioritized experience replay
|
||||
- CNN with multi-timeframe features
|
||||
- Comprehensive reward engineering
|
||||
- Performance-based adaptation
|
||||
|
||||
3. **Prediction Tracking**:
|
||||
- Forward-looking predictions with validation
|
||||
- Accuracy measurement and tracking
|
||||
- Model confidence scoring
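
A hedged sketch of the forward-looking prediction tracking described in this list; the record fields and the 5-minute validation horizon are assumptions for illustration only.

```python
from datetime import datetime, timedelta
from typing import Dict, List

class PredictionTracker:
    """Stores predictions now and scores them once their horizon has elapsed."""
    def __init__(self, horizon_minutes: int = 5):
        self.horizon = timedelta(minutes=horizon_minutes)
        self.pending: List[Dict] = []
        self.correct = 0
        self.total = 0

    def record(self, symbol: str, action: str, confidence: float, price: float) -> None:
        self.pending.append({'symbol': symbol, 'action': action, 'confidence': confidence,
                             'price': price, 'time': datetime.utcnow()})

    def validate(self, symbol: str, current_price: float) -> None:
        """Compare matured predictions against the realized price direction."""
        now = datetime.utcnow()
        remaining = []
        for p in self.pending:
            if p['symbol'] != symbol or now - p['time'] < self.horizon:
                remaining.append(p)
                continue
            went_up = current_price > p['price']
            self.total += 1
            if (p['action'] == 'BUY' and went_up) or (p['action'] == 'SELL' and not went_up):
                self.correct += 1
        self.pending = remaining

    def accuracy(self) -> float:
        return self.correct / self.total if self.total else 0.0
```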
|
||||
|
||||
## 🔍 EnhancedRLTrainingIntegrator Audit
|
||||
|
||||
### **Purpose & Scope**
|
||||
The `EnhancedRLTrainingIntegrator` is a comprehensive testing and validation system designed to:
|
||||
- Verify 13,400-feature comprehensive state building
|
||||
- Test enhanced pivot-based reward calculation
|
||||
- Validate Williams market structure integration
|
||||
- Demonstrate live comprehensive training
|
||||
|
||||
### **Audit Results**
|
||||
|
||||
#### **✅ Valuable Components**
|
||||
1. **Comprehensive State Verification**: Tests for exactly 13,400 features
|
||||
2. **Feature Distribution Analysis**: Analyzes non-zero vs zero features
|
||||
3. **Enhanced Reward Testing**: Validates pivot-based reward calculations
|
||||
4. **Williams Integration**: Tests market structure feature extraction
|
||||
5. **Live Training Demo**: Demonstrates coordinated decision making
|
||||
|
||||
#### **🔧 Integration Challenges**
|
||||
1. **Dependency Issues**: References `core.enhanced_orchestrator.EnhancedTradingOrchestrator` (not available)
|
||||
2. **Missing Methods**: Expects methods not present in current orchestrator:
|
||||
- `build_comprehensive_rl_state()`
|
||||
- `calculate_enhanced_pivot_reward()`
|
||||
- `make_coordinated_decisions()`
|
||||
3. **Williams Module**: Depends on `training.williams_market_structure` (needs verification)
|
||||
|
||||
#### **💡 Recommended Usage**
|
||||
The `EnhancedRLTrainingIntegrator` should be used as a **testing and validation tool** rather than being integrated directly:
|
||||
|
||||
```python
|
||||
# Use as standalone testing script
|
||||
python enhanced_rl_training_integration.py
|
||||
|
||||
# Or import specific testing functions
|
||||
from enhanced_rl_training_integration import EnhancedRLTrainingIntegrator
|
||||
integrator = EnhancedRLTrainingIntegrator()
|
||||
await integrator._verify_comprehensive_state_building()
|
||||
```
|
||||
|
||||
## 🚀 Implementation Strategy
|
||||
|
||||
### **Phase 1: EnhancedRealtimeTrainingSystem (✅ COMPLETE)**
|
||||
- [x] Integrated into orchestrator
|
||||
- [x] Added initialization methods
|
||||
- [x] Connected to data provider
|
||||
- [x] Dashboard integration support
|
||||
|
||||
### **Phase 2: Enhanced Methods (🔄 IN PROGRESS)**
|
||||
Add missing methods expected by the integrator:
|
||||
|
||||
```python
|
||||
# Add to orchestrator:
|
||||
def build_comprehensive_rl_state(self, symbol: str) -> Optional[np.ndarray]:
|
||||
"""Build comprehensive 13,400+ feature state for RL training"""
|
||||
|
||||
def calculate_enhanced_pivot_reward(self, trade_decision: Dict,
|
||||
market_data: Dict,
|
||||
trade_outcome: Dict) -> float:
|
||||
"""Calculate enhanced pivot-based rewards"""
|
||||
|
||||
async def make_coordinated_decisions(self) -> Dict[str, TradingDecision]:
|
||||
"""Make coordinated decisions across all symbols"""
|
||||
```
|
||||
|
||||
### **Phase 3: Validation Integration (📋 PLANNED)**
|
||||
Use `EnhancedRLTrainingIntegrator` as a validation tool:
|
||||
|
||||
```python
|
||||
# Integration validation workflow:
|
||||
1. Start enhanced training system
|
||||
2. Run comprehensive state building tests
|
||||
3. Validate reward calculation accuracy
|
||||
4. Test Williams market structure integration
|
||||
5. Monitor live training performance
|
||||
```
|
||||
|
||||
## 📈 Benefits of Integration
|
||||
|
||||
### **Real-time Learning**
|
||||
- Continuous model improvement during live trading
|
||||
- Adaptive learning based on market conditions
|
||||
- Forward-looking prediction validation
|
||||
|
||||
### **Comprehensive Features**
|
||||
- 13,400+ feature comprehensive states
|
||||
- Multi-timeframe market analysis
|
||||
- COB microstructure integration
|
||||
- Enhanced reward engineering
|
||||
|
||||
### **Performance Monitoring**
|
||||
- Real-time training statistics
|
||||
- Model accuracy tracking
|
||||
- Adaptive parameter adjustment
|
||||
- Comprehensive logging
|
||||
|
||||
## 🎯 Next Steps
|
||||
|
||||
### **Immediate Actions**
|
||||
1. **Complete Method Implementation**: Add missing orchestrator methods
|
||||
2. **Williams Module Verification**: Ensure market structure module is available
|
||||
3. **Testing Integration**: Use integrator for validation testing
|
||||
4. **Dashboard Connection**: Connect training system to dashboard
|
||||
|
||||
### **Future Enhancements**
|
||||
1. **Multi-Symbol Coordination**: Enhance coordinated decision making
|
||||
2. **Advanced Reward Engineering**: Implement sophisticated reward functions
|
||||
3. **Model Ensemble**: Combine multiple model predictions
|
||||
4. **Performance Optimization**: GPU acceleration for training
|
||||
|
||||
## 📊 Integration Status
|
||||
|
||||
| Component | Status | Notes |
|
||||
|-----------|--------|-------|
|
||||
| EnhancedRealtimeTrainingSystem | ✅ Integrated | Fully functional in orchestrator |
|
||||
| Real-time Data Collection | ✅ Available | Multi-timeframe data streams |
|
||||
| Enhanced DQN Training | ✅ Available | Prioritized experience replay |
|
||||
| CNN Training | ✅ Available | Pattern recognition training |
|
||||
| Forward Predictions | ✅ Available | Prediction validation system |
|
||||
| EnhancedRLTrainingIntegrator | 🔧 Partial | Use as validation tool |
|
||||
| Comprehensive State Building | 📋 Planned | Need to implement method |
|
||||
| Enhanced Reward Calculation | 📋 Planned | Need to implement method |
|
||||
| Williams Integration | ❓ Unknown | Need to verify module |
|
||||
|
||||
## 🏆 Conclusion
|
||||
|
||||
The `EnhancedRealtimeTrainingSystem` has been successfully integrated into the orchestrator, providing comprehensive real-time training capabilities. The `EnhancedRLTrainingIntegrator` serves as an excellent validation and testing tool, but requires additional method implementations in the orchestrator for full functionality.
|
||||
|
||||
**Key Achievements:**
|
||||
- ✅ Real-time training system fully integrated
|
||||
- ✅ Comprehensive feature extraction capabilities
|
||||
- ✅ Enhanced reward engineering framework
|
||||
- ✅ Forward-looking prediction validation
|
||||
- ✅ Performance monitoring and adaptation
|
||||
|
||||
**Recommended Actions:**
|
||||
1. Use the integrated training system for live model improvement
|
||||
2. Implement missing orchestrator methods for full integrator compatibility
|
||||
3. Use the integrator as a comprehensive testing and validation tool
|
||||
4. Monitor training performance and adapt parameters as needed
|
||||
|
||||
The integration provides a solid foundation for advanced ML-driven trading with continuous learning capabilities.
|
@ -201,6 +201,9 @@ class DataProvider:
|
||||
self.last_pivot_calculation: Dict[str, datetime] = {}
|
||||
self.pivot_calculation_interval = timedelta(minutes=5) # Recalculate every 5 minutes
|
||||
|
||||
# Auto-fix corrupted cache files on startup
|
||||
self._auto_fix_corrupted_cache()
|
||||
|
||||
# Load existing pivot bounds from cache
|
||||
self._load_all_pivot_bounds()
|
||||
|
||||
@ -1231,6 +1234,36 @@ class DataProvider:
|
||||
return symbol # Return first symbol for now - can be improved
|
||||
return None
|
||||
|
||||
# === CACHE MANAGEMENT ===
|
||||
|
||||
def _auto_fix_corrupted_cache(self):
|
||||
"""Automatically fix corrupted cache files on startup"""
|
||||
try:
|
||||
from utils.cache_manager import get_cache_manager
|
||||
cache_manager = get_cache_manager()
|
||||
|
||||
# Quick health check
|
||||
health_summary = cache_manager.get_cache_summary()
|
||||
|
||||
if health_summary['corrupted_files'] > 0:
|
||||
logger.warning(f"Found {health_summary['corrupted_files']} corrupted cache files, cleaning up...")
|
||||
|
||||
# Auto-cleanup corrupted files (no confirmation needed)
|
||||
deleted_files = cache_manager.cleanup_corrupted_files(dry_run=False)
|
||||
|
||||
deleted_count = 0
|
||||
for cache_dir, files in deleted_files.items():
|
||||
for file_info in files:
|
||||
if "DELETED:" in file_info:
|
||||
deleted_count += 1
|
||||
|
||||
logger.info(f"Auto-cleaned {deleted_count} corrupted cache files")
|
||||
else:
|
||||
logger.debug("Cache health check passed - no corrupted files found")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Cache auto-fix failed: {e}")
|
||||
|
||||
# === PIVOT BOUNDS CACHING ===
|
||||
|
||||
def _load_all_pivot_bounds(self):
|
||||
@ -1285,13 +1318,25 @@ class DataProvider:
|
||||
logger.info(f"Loaded {len(df)} 1m candles from cache for {symbol}")
|
||||
return df
|
||||
except Exception as parquet_e:
|
||||
# Handle corrupted Parquet file
|
||||
if "Parquet magic bytes not found" in str(parquet_e) or "corrupted" in str(parquet_e).lower():
|
||||
# Handle corrupted Parquet file - expanded error detection
|
||||
error_str = str(parquet_e).lower()
|
||||
corrupted_indicators = [
|
||||
"parquet magic bytes not found",
|
||||
"corrupted",
|
||||
"couldn't deserialize thrift",
|
||||
"don't know what type",
|
||||
"invalid parquet file",
|
||||
"unexpected end of file",
|
||||
"invalid metadata"
|
||||
]
|
||||
|
||||
if any(indicator in error_str for indicator in corrupted_indicators):
|
||||
logger.warning(f"Corrupted Parquet cache file for {symbol}, removing and returning None: {parquet_e}")
|
||||
try:
|
||||
cache_file.unlink() # Delete corrupted file
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Deleted corrupted monthly cache file: {cache_file}")
|
||||
except Exception as delete_e:
|
||||
logger.error(f"Failed to delete corrupted monthly cache file: {delete_e}")
|
||||
return None
|
||||
else:
|
||||
raise parquet_e
|
||||
@ -1393,13 +1438,25 @@ class DataProvider:
|
||||
logger.debug(f"Loaded {len(df)} rows from cache for {symbol} {timeframe} (age: {cache_age/60:.1f}min)")
|
||||
return df
|
||||
except Exception as parquet_e:
|
||||
# Handle corrupted Parquet file
|
||||
if "Parquet magic bytes not found" in str(parquet_e) or "corrupted" in str(parquet_e).lower():
|
||||
# Handle corrupted Parquet file - expanded error detection
|
||||
error_str = str(parquet_e).lower()
|
||||
corrupted_indicators = [
|
||||
"parquet magic bytes not found",
|
||||
"corrupted",
|
||||
"couldn't deserialize thrift",
|
||||
"don't know what type",
|
||||
"invalid parquet file",
|
||||
"unexpected end of file",
|
||||
"invalid metadata"
|
||||
]
|
||||
|
||||
if any(indicator in error_str for indicator in corrupted_indicators):
|
||||
logger.warning(f"Corrupted Parquet cache file for {symbol} {timeframe}, removing and returning None: {parquet_e}")
|
||||
try:
|
||||
cache_file.unlink() # Delete corrupted file
|
||||
except Exception:
|
||||
pass
|
||||
logger.info(f"Deleted corrupted cache file: {cache_file}")
|
||||
except Exception as delete_e:
|
||||
logger.error(f"Failed to delete corrupted cache file: {delete_e}")
|
||||
return None
|
||||
else:
|
||||
raise parquet_e
|
||||
|
@ -38,6 +38,11 @@ from NN.models.cob_rl_model import COBRLModelInterface # Specific import for COB
|
||||
from NN.models.model_interfaces import ModelInterface as NNModelInterface, CNNModelInterface as NNCNNModelInterface, RLAgentInterface as NNRLAgentInterface, ExtremaTrainerInterface as NNExtremaTrainerInterface # Import from new file
|
||||
from core.extrema_trainer import ExtremaTrainer # Import ExtremaTrainer for its interface
|
||||
|
||||
# Import new logging and database systems
|
||||
from utils.inference_logger import get_inference_logger, log_model_inference
|
||||
from utils.database_manager import get_database_manager
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
|
||||
# Import COB integration for real-time market microstructure data
|
||||
try:
|
||||
from .cob_integration import COBIntegration
|
||||
@ -213,6 +218,10 @@ class TradingOrchestrator:
|
||||
# Initialize inference history for each model (will be populated as models make predictions)
|
||||
# We'll create entries dynamically as models are used
|
||||
|
||||
# Initialize inference logger
|
||||
self.inference_logger = get_inference_logger()
|
||||
self.db_manager = get_database_manager()
|
||||
|
||||
# ENHANCED: Real-time Training System Integration
|
||||
self.enhanced_training_system = None # Will be set to EnhancedRealtimeTrainingSystem if available
|
||||
# Enable training by default - don't depend on external training system
|
||||
@ -232,6 +241,9 @@ class TradingOrchestrator:
|
||||
self.data_provider.start_centralized_data_collection()
|
||||
logger.info("Centralized data collection started - all models and dashboard will receive data")
|
||||
|
||||
# Initialize database cleanup task
|
||||
self._schedule_database_cleanup()
|
||||
|
||||
# CRITICAL: Initialize checkpoint manager for saving training progress
|
||||
self.checkpoint_manager = None
|
||||
self.training_iterations = 0 # Track training iterations for periodic saves
|
||||
@ -265,24 +277,23 @@ class TradingOrchestrator:
|
||||
self.rl_agent = DQNAgent(state_shape=state_size, n_actions=action_size)
|
||||
self.rl_agent.to(self.device) # Move DQN agent to the determined device
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
checkpoint_loaded = False
|
||||
if hasattr(self.rl_agent, 'load_best_checkpoint'):
|
||||
try:
|
||||
self.rl_agent.load_best_checkpoint() # This loads the state into the model
|
||||
# Check if we have checkpoints available
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("dqn_agent")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
# Check if we have checkpoints available using database metadata (fast!)
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
|
||||
if checkpoint_metadata:
|
||||
self.model_states['dqn']['initial_loss'] = 0.412
|
||||
self.model_states['dqn']['current_loss'] = metadata.loss
|
||||
self.model_states['dqn']['best_loss'] = metadata.loss
|
||||
self.model_states['dqn']['current_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['dqn']['best_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['dqn']['checkpoint_loaded'] = True
|
||||
self.model_states['dqn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
self.model_states['dqn']['checkpoint_filename'] = checkpoint_metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||
logger.info(f"DQN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(f"DQN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading DQN checkpoint: {e}")
|
||||
|
||||
@ -307,21 +318,20 @@ class TradingOrchestrator:
|
||||
self.cnn_model = self.cnn_adapter.model # Keep reference for compatibility
|
||||
self.cnn_optimizer = optim.Adam(self.cnn_model.parameters(), lr=0.001) # Initialize optimizer for CNN
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
checkpoint_loaded = False
|
||||
try:
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("enhanced_cnn")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("enhanced_cnn")
|
||||
if checkpoint_metadata:
|
||||
self.model_states['cnn']['initial_loss'] = 0.412
|
||||
self.model_states['cnn']['current_loss'] = metadata.loss or 0.0187
|
||||
self.model_states['cnn']['best_loss'] = metadata.loss or 0.0134
|
||||
self.model_states['cnn']['current_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0187)
|
||||
self.model_states['cnn']['best_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0134)
|
||||
self.model_states['cnn']['checkpoint_loaded'] = True
|
||||
self.model_states['cnn']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
self.model_states['cnn']['checkpoint_filename'] = checkpoint_metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||
logger.info(f"CNN checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(f"CNN checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading CNN checkpoint: {e}")
|
||||
|
||||
@ -399,23 +409,22 @@ class TradingOrchestrator:
|
||||
if hasattr(self.cob_rl_agent, 'to'):
|
||||
self.cob_rl_agent.to(self.device)
|
||||
|
||||
# Load best checkpoint and capture initial state
|
||||
# Load best checkpoint and capture initial state (using database metadata)
|
||||
checkpoint_loaded = False
|
||||
if hasattr(self.cob_rl_agent, 'load_model'):
|
||||
try:
|
||||
self.cob_rl_agent.load_model() # This loads the state into the model
|
||||
from utils.checkpoint_manager import load_best_checkpoint
|
||||
result = load_best_checkpoint("cob_rl_model")
|
||||
if result:
|
||||
file_path, metadata = result
|
||||
self.model_states['cob_rl']['initial_loss'] = getattr(metadata, 'initial_loss', None)
|
||||
self.model_states['cob_rl']['current_loss'] = metadata.loss
|
||||
self.model_states['cob_rl']['best_loss'] = metadata.loss
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata("cob_rl")
|
||||
if checkpoint_metadata:
|
||||
self.model_states['cob_rl']['initial_loss'] = checkpoint_metadata.training_metadata.get('initial_loss', None)
|
||||
self.model_states['cob_rl']['current_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['cob_rl']['best_loss'] = checkpoint_metadata.performance_metrics.get('loss', 0.0)
|
||||
self.model_states['cob_rl']['checkpoint_loaded'] = True
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = metadata.checkpoint_id
|
||||
self.model_states['cob_rl']['checkpoint_filename'] = checkpoint_metadata.checkpoint_id
|
||||
checkpoint_loaded = True
|
||||
loss_str = f"{metadata.loss:.4f}" if metadata.loss is not None else "unknown"
|
||||
logger.info(f"COB RL checkpoint loaded: {metadata.checkpoint_id} (loss={loss_str})")
|
||||
loss_str = f"{checkpoint_metadata.performance_metrics.get('loss', 0.0):.4f}"
|
||||
logger.info(f"COB RL checkpoint loaded: {checkpoint_metadata.checkpoint_id} (loss={loss_str})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading COB RL checkpoint: {e}")
|
||||
|
||||
@@ -1247,51 +1256,210 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error storing inference data for {model_name}: {e}")
|
||||
|
||||
async def _save_inference_to_disk_async(self, model_name: str, inference_record: Dict):
|
||||
"""Async save inference record to disk with file capping"""
|
||||
"""Async save inference record to SQLite database and model-specific log"""
|
||||
try:
|
||||
# Create model-specific directory
|
||||
model_dir = Path(f"training_data/inference_history/{model_name}")
|
||||
model_dir.mkdir(parents=True, exist_ok=True)
|
||||
# Use SQLite for comprehensive storage
|
||||
await self._save_to_sqlite_db(model_name, inference_record)
|
||||
|
||||
# Create filename with timestamp
|
||||
timestamp_str = datetime.fromisoformat(inference_record['timestamp']).strftime('%Y%m%d_%H%M%S_%f')[:-3]
|
||||
filename = f"inference_{timestamp_str}.json"
|
||||
filepath = model_dir / filename
|
||||
|
||||
# Convert to JSON-serializable format
|
||||
serializable_record = self._make_json_serializable(inference_record)
|
||||
|
||||
# Save to file
|
||||
with open(filepath, 'w') as f:
|
||||
json.dump(serializable_record, f, indent=2)
|
||||
|
||||
# Cap files per model (keep only latest 200)
|
||||
await self._cap_model_files(model_dir)
|
||||
|
||||
logger.debug(f"Saved inference record to disk: {filepath}")
|
||||
# Also save key metrics to model-specific log for debugging
|
||||
await self._save_to_model_log(model_name, inference_record)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving inference to disk for {model_name}: {e}")
|
||||
|
||||
async def _cap_model_files(self, model_dir: Path):
|
||||
"""Cap the number of files per model to max_disk_files_per_model"""
|
||||
async def _save_to_sqlite_db(self, model_name: str, inference_record: Dict):
|
||||
"""Save inference record to SQLite database"""
|
||||
import sqlite3
|
||||
import asyncio
|
||||
|
||||
def save_to_db():
|
||||
try:
|
||||
# Create database directory
|
||||
db_dir = Path("training_data/inference_db")
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Connect to SQLite database
|
||||
db_path = db_dir / "inference_history.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create table if it doesn't exist
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS inference_records (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
model_name TEXT NOT NULL,
|
||||
symbol TEXT NOT NULL,
|
||||
timestamp TEXT NOT NULL,
|
||||
action TEXT NOT NULL,
|
||||
confidence REAL NOT NULL,
|
||||
probabilities TEXT,
|
||||
timeframe TEXT,
|
||||
metadata TEXT,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
''')
|
||||
|
||||
# Create index for faster queries
|
||||
cursor.execute('''
|
||||
CREATE INDEX IF NOT EXISTS idx_model_timestamp
|
||||
ON inference_records(model_name, timestamp)
|
||||
''')
|
||||
|
||||
# Extract data from inference record
|
||||
prediction = inference_record.get('prediction', {})
|
||||
probabilities_str = str(prediction.get('probabilities', {}))
|
||||
metadata_str = str(inference_record.get('metadata', {}))
|
||||
|
||||
# Insert record
|
||||
cursor.execute('''
|
||||
INSERT INTO inference_records
|
||||
(model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
|
||||
''', (
|
||||
model_name,
|
||||
inference_record.get('symbol', 'ETH/USDT'),
|
||||
inference_record.get('timestamp', ''),
|
||||
prediction.get('action', 'HOLD'),
|
||||
prediction.get('confidence', 0.0),
|
||||
probabilities_str,
|
||||
prediction.get('timeframe', '1m'),
|
||||
metadata_str
|
||||
))
|
||||
|
||||
# Clean up old records (keep only last 1000 per model)
|
||||
cursor.execute('''
|
||||
DELETE FROM inference_records
|
||||
WHERE model_name = ? AND id NOT IN (
|
||||
SELECT id FROM inference_records
|
||||
WHERE model_name = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT 1000
|
||||
)
|
||||
''', (model_name, model_name))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to SQLite database: {e}")
|
||||
|
||||
# Run database operation in thread pool to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, save_to_db)
|
||||
|
||||
async def _save_to_model_log(self, model_name: str, inference_record: Dict):
|
||||
"""Save key inference metrics to model-specific log file for debugging"""
|
||||
import asyncio
|
||||
|
||||
def save_to_log():
|
||||
try:
|
||||
# Create logs directory
|
||||
logs_dir = Path("logs/model_inference")
|
||||
logs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Create model-specific log file
|
||||
log_file = logs_dir / f"{model_name}_inference.log"
|
||||
|
||||
# Extract key metrics
|
||||
prediction = inference_record.get('prediction', {})
|
||||
timestamp = inference_record.get('timestamp', '')
|
||||
symbol = inference_record.get('symbol', 'N/A')
|
||||
|
||||
# Format log entry with key metrics
|
||||
log_entry = (
|
||||
f"{timestamp} | "
|
||||
f"Symbol: {symbol} | "
|
||||
f"Action: {prediction.get('action', 'N/A'):4} | "
|
||||
f"Confidence: {prediction.get('confidence', 0.0):6.3f} | "
|
||||
f"Timeframe: {prediction.get('timeframe', 'N/A'):3} | "
|
||||
f"Probs: BUY={prediction.get('probabilities', {}).get('BUY', 0.0):5.3f} "
|
||||
f"SELL={prediction.get('probabilities', {}).get('SELL', 0.0):5.3f} "
|
||||
f"HOLD={prediction.get('probabilities', {}).get('HOLD', 0.0):5.3f}\n"
|
||||
)
|
||||
|
||||
# Append to log file
|
||||
with open(log_file, 'a', encoding='utf-8') as f:
|
||||
f.write(log_entry)
|
||||
|
||||
# Keep log files manageable (rotate when > 10MB)
|
||||
if log_file.stat().st_size > 10 * 1024 * 1024: # 10MB
|
||||
self._rotate_log_file(log_file)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to model log: {e}")
|
||||
|
||||
# Run log operation in thread pool to avoid blocking
|
||||
await asyncio.get_event_loop().run_in_executor(None, save_to_log)
|
||||
|
||||
def _rotate_log_file(self, log_file: Path):
|
||||
"""Rotate log file when it gets too large"""
|
||||
try:
|
||||
# Get all inference files
|
||||
files = list(model_dir.glob("inference_*.json"))
|
||||
# Keep last 1000 lines
|
||||
with open(log_file, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
if len(files) > self.max_disk_files_per_model:
|
||||
# Sort by modification time (oldest first)
|
||||
files.sort(key=lambda x: x.stat().st_mtime)
|
||||
|
||||
# Remove oldest files
|
||||
files_to_remove = files[:-self.max_disk_files_per_model]
|
||||
for file_path in files_to_remove:
|
||||
file_path.unlink()
|
||||
|
||||
logger.debug(f"Removed {len(files_to_remove)} old inference files from {model_dir.name}")
|
||||
# Write back only the last 1000 lines
|
||||
with open(log_file, 'w', encoding='utf-8') as f:
|
||||
f.writelines(lines[-1000:])
|
||||
|
||||
logger.debug(f"Rotated log file {log_file.name} (kept last 1000 lines)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error capping model files in {model_dir}: {e}")
|
||||
logger.error(f"Error rotating log file {log_file}: {e}")
|
||||
|
||||
def get_inference_records_from_db(self, model_name: str = None, limit: int = 100) -> List[Dict]:
|
||||
"""Get inference records from SQLite database"""
|
||||
import sqlite3
import ast
|
||||
|
||||
try:
|
||||
# Connect to database
|
||||
db_path = Path("training_data/inference_db/inference_history.db")
|
||||
if not db_path.exists():
|
||||
return []
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Query records
|
||||
if model_name:
|
||||
cursor.execute('''
|
||||
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
|
||||
FROM inference_records
|
||||
WHERE model_name = ?
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
''', (model_name, limit))
|
||||
else:
|
||||
cursor.execute('''
|
||||
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
|
||||
FROM inference_records
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT ?
|
||||
''', (limit,))
|
||||
|
||||
records = []
|
||||
for row in cursor.fetchall():
|
||||
record = {
|
||||
'model_name': row[0],
|
||||
'symbol': row[1],
|
||||
'timestamp': row[2],
|
||||
'prediction': {
|
||||
'action': row[3],
|
||||
'confidence': row[4],
|
||||
'probabilities': ast.literal_eval(row[5]) if row[5] else {},  # literal_eval avoids executing arbitrary code
|
||||
'timeframe': row[6]
|
||||
},
|
||||
'metadata': ast.literal_eval(row[7]) if row[7] else {}
|
||||
}
|
||||
records.append(record)
|
||||
|
||||
conn.close()
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error querying SQLite database: {e}")
|
||||
return []
|
||||
|
||||
|
||||
|
||||
def _prepare_cnn_input_data(self, ohlcv_data: Dict, cob_data: Any, technical_indicators: Dict) -> torch.Tensor:
|
||||
"""Prepare standardized input data for CNN models with proper GPU device placement"""
|
||||
@@ -1472,67 +1640,60 @@ class TradingOrchestrator:
|
||||
return obj
|
||||
|
||||
def load_inference_history_from_disk(self, symbol: str, days_back: int = 7) -> List[Dict]:
|
||||
"""Load inference history from disk for training replay"""
|
||||
"""Load inference history from SQLite database for training replay"""
|
||||
try:
|
||||
inference_dir = Path("training_data/inference_history")
|
||||
if not inference_dir.exists():
|
||||
import sqlite3
import ast
|
||||
|
||||
# Connect to database
|
||||
db_path = Path("training_data/inference_db/inference_history.db")
|
||||
if not db_path.exists():
|
||||
return []
|
||||
|
||||
# Get files for the symbol from the last N days
|
||||
cutoff_date = datetime.now() - timedelta(days=days_back)
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get records for the symbol from the last N days
|
||||
cutoff_date = (datetime.now() - timedelta(days=days_back)).isoformat()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT model_name, symbol, timestamp, action, confidence, probabilities, timeframe, metadata
|
||||
FROM inference_records
|
||||
WHERE symbol = ? AND timestamp >= ?
|
||||
ORDER BY timestamp ASC
|
||||
''', (symbol, cutoff_date))
|
||||
|
||||
inference_records = []
|
||||
for row in cursor.fetchall():
|
||||
record = {
|
||||
'model_name': row[0],
|
||||
'symbol': row[1],
|
||||
'timestamp': row[2],
|
||||
'prediction': {
|
||||
'action': row[3],
|
||||
'confidence': row[4],
|
||||
'probabilities': ast.literal_eval(row[5]) if row[5] else {},  # literal_eval avoids executing arbitrary code
|
||||
'timeframe': row[6]
|
||||
},
|
||||
'metadata': ast.literal_eval(row[7]) if row[7] else {}
|
||||
}
|
||||
inference_records.append(record)
|
||||
|
||||
for filepath in inference_dir.glob(f"{symbol}_*.json"):
|
||||
try:
|
||||
# Extract timestamp from filename
|
||||
filename_parts = filepath.stem.split('_')
|
||||
if len(filename_parts) >= 3:
|
||||
timestamp_str = f"{filename_parts[-2]}_{filename_parts[-1]}"
|
||||
file_timestamp = datetime.strptime(timestamp_str, '%Y%m%d_%H%M%S')
|
||||
|
||||
if file_timestamp >= cutoff_date:
|
||||
with open(filepath, 'r') as f:
|
||||
record = json.load(f)
|
||||
inference_records.append(record)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading inference file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
# Sort by timestamp
|
||||
inference_records.sort(key=lambda x: x['timestamp'])
|
||||
logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from disk")
|
||||
conn.close()
|
||||
logger.info(f"Loaded {len(inference_records)} inference records for {symbol} from SQLite database")
|
||||
|
||||
return inference_records
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading inference history from disk: {e}")
|
||||
logger.error(f"Error loading inference history from database: {e}")
|
||||
return []
|
||||
|
||||
async def load_model_inference_history(self, model_name: str, limit: int = 50) -> List[Dict]:
|
||||
"""Load inference history for a specific model from disk"""
|
||||
"""Load inference history for a specific model from SQLite database"""
|
||||
try:
|
||||
model_dir = Path(f"training_data/inference_history/{model_name}")
|
||||
if not model_dir.exists():
|
||||
return []
|
||||
|
||||
# Get all inference files
|
||||
files = list(model_dir.glob("inference_*.json"))
|
||||
files.sort(key=lambda x: x.stat().st_mtime, reverse=True) # Newest first
|
||||
|
||||
# Load up to 'limit' files
|
||||
inference_records = []
|
||||
for filepath in files[:limit]:
|
||||
try:
|
||||
with open(filepath, 'r') as f:
|
||||
record = json.load(f)
|
||||
inference_records.append(record)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error loading inference file {filepath}: {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"Loaded {len(inference_records)} inference records for {model_name}")
|
||||
return inference_records
|
||||
# Use the SQLite database method
|
||||
records = self.get_inference_records_from_db(model_name, limit)
|
||||
logger.info(f"Loaded {len(records)} inference records for {model_name} from database")
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading model inference history for {model_name}: {e}")
|
||||
@@ -3284,6 +3445,15 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error initializing checkpoint manager: {e}")
|
||||
self.checkpoint_manager = None
|
||||
|
||||
def _schedule_database_cleanup(self):
|
||||
"""Schedule periodic database cleanup"""
|
||||
try:
|
||||
# Clean up old inference records (keep 30 days)
|
||||
self.inference_logger.cleanup_old_logs(days_to_keep=30)
|
||||
logger.info("Database cleanup completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Database cleanup failed: {e}")
|
||||
|
||||
def _save_training_checkpoints(self, models_trained: List[str], performance_score: float):
|
||||
"""Save checkpoints for trained models if performance improved
|
||||
|
||||
@@ -3419,4 +3589,45 @@ class TradingOrchestrator:
|
||||
logger.error(f"Error saving checkpoint for {model_name}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in _save_training_checkpoints: {e}")
|
||||
logger.error(f"Error in _save_training_checkpoints: {e}")
|
||||
def _schedule_database_cleanup(self):
|
||||
"""Schedule periodic database cleanup"""
|
||||
try:
|
||||
# Clean up old inference records (keep 30 days)
|
||||
self.inference_logger.cleanup_old_logs(days_to_keep=30)
|
||||
logger.info("Database cleanup completed")
|
||||
except Exception as e:
|
||||
logger.error(f"Database cleanup failed: {e}")
|
||||
|
||||
def log_model_inference(self, model_name: str, symbol: str, action: str,
|
||||
confidence: float, probabilities: Dict[str, float],
|
||||
input_features: Any, processing_time_ms: float,
|
||||
checkpoint_id: str = None, metadata: Dict[str, Any] = None) -> bool:
|
||||
"""
|
||||
Centralized method for models to log their inferences
|
||||
|
||||
This replaces scattered logger.info() calls throughout the codebase
|
||||
"""
|
||||
return log_model_inference(
|
||||
model_name=model_name,
|
||||
symbol=symbol,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
probabilities=probabilities,
|
||||
input_features=input_features,
|
||||
processing_time_ms=processing_time_ms,
|
||||
checkpoint_id=checkpoint_id,
|
||||
metadata=metadata
|
||||
)
|
||||
|
||||
def get_model_inference_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
|
||||
"""Get inference statistics for a model"""
|
||||
return self.inference_logger.get_model_stats(model_name, hours)
|
||||
|
||||
def get_checkpoint_metadata_fast(self, model_name: str) -> Optional[Any]:
|
||||
"""
|
||||
Get checkpoint metadata without loading the full model
|
||||
|
||||
This is much faster than loading the entire checkpoint just to get metadata
|
||||
"""
|
||||
return self.db_manager.get_best_checkpoint_metadata(model_name)
|
BIN
data/trading_system.db
Normal file
Binary file not shown.
280
docs/logging_system_upgrade.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# Trading System Logging Upgrade
|
||||
|
||||
## Overview
|
||||
|
||||
This upgrade implements a comprehensive logging and metadata management system that addresses the key issues:
|
||||
|
||||
1. **Eliminates scattered "No checkpoints found" logs** during runtime
|
||||
2. **Fast checkpoint metadata access** without loading full models
|
||||
3. **Centralized inference logging** with database and text file storage
|
||||
4. **Structured tracking** of model performance and checkpoints
|
||||
|
||||
## Key Components
|
||||
|
||||
### 1. Database Manager (`utils/database_manager.py`)
|
||||
|
||||
**Purpose**: SQLite-based storage for structured data
|
||||
|
||||
**Features**:
|
||||
- Inference records logging with deduplication
|
||||
- Checkpoint metadata storage (separate from model weights)
|
||||
- Model performance tracking
|
||||
- Fast queries without loading model files
|
||||
|
||||
**Tables**:
|
||||
- `inference_records`: All model predictions with metadata
|
||||
- `checkpoint_metadata`: Checkpoint info without model weights
|
||||
- `model_performance`: Daily aggregated statistics (a short usage sketch follows below)
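
A minimal sketch of writing to and reading from these tables through `get_database_manager()` (the values, including the feature hash, are illustrative; the hash is normally computed by the inference logger):

```python
from datetime import datetime
from utils.database_manager import get_database_manager, InferenceRecord

db = get_database_manager()

# Write one row into inference_records
db.log_inference(InferenceRecord(
    model_name="dqn_agent",
    timestamp=datetime.now(),
    symbol="ETH/USDT",
    action="BUY",
    confidence=0.85,
    probabilities={"BUY": 0.85, "SELL": 0.10, "HOLD": 0.05},
    input_features_hash="abc123",   # normally produced by the inference logger
    processing_time_ms=12.5,
    memory_usage_mb=256.0,
))

# Read the best checkpoint metadata without touching model weights
meta = db.get_best_checkpoint_metadata("dqn_agent")
if meta:
    print(meta.checkpoint_id, meta.performance_metrics.get("loss"))
```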
|
||||
|
||||
### 2. Inference Logger (`utils/inference_logger.py`)
|
||||
|
||||
**Purpose**: Centralized logging for all model inferences
|
||||
|
||||
**Features**:
|
||||
- Single function call replaces scattered `logger.info()` calls
|
||||
- Automatic feature hashing for deduplication
|
||||
- Memory usage tracking
|
||||
- Processing time measurement
|
||||
- Dual storage (database + text files)
|
||||
|
||||
**Usage**:
|
||||
```python
|
||||
from utils.inference_logger import log_model_inference
|
||||
|
||||
log_model_inference(
|
||||
model_name="dqn_agent",
|
||||
symbol="ETH/USDT",
|
||||
action="BUY",
|
||||
confidence=0.85,
|
||||
probabilities={"BUY": 0.85, "SELL": 0.10, "HOLD": 0.05},
|
||||
input_features=features_array,
|
||||
processing_time_ms=12.5,
|
||||
checkpoint_id="dqn_agent_20250725_143500"
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Text Logger (`utils/text_logger.py`)
|
||||
|
||||
**Purpose**: Human-readable log files for tracking
|
||||
|
||||
**Features**:
|
||||
- Separate files for different event types
|
||||
- Clean, tabular format
|
||||
- Automatic cleanup of old entries
|
||||
- Easy to read and grep
|
||||
|
||||
**Files**:
|
||||
- `logs/inference_records.txt`: All model predictions
|
||||
- `logs/checkpoint_events.txt`: Save/load events
|
||||
- `logs/system_events.txt`: General system events (a usage sketch follows below)
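
A short sketch of the text logger in use, with the same calls the checkpoint manager and the debug commands below rely on (the loss/size details are illustrative):

```python
from utils.text_logger import get_text_logger

text_logger = get_text_logger()

# Record a checkpoint event (appended to logs/checkpoint_events.txt)
text_logger.log_checkpoint_event(
    model_name="dqn_agent",
    event_type="SAVED",
    checkpoint_id="dqn_agent_20250725_143500",
    details="loss=0.0234, size=45.2MB",
)

# Tail recent inference lines and keep files from growing unbounded
recent = text_logger.get_recent_inferences(lines=50)
text_logger.cleanup_old_logs(max_lines=10000)
```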
|
||||
|
||||
### 4. Enhanced Checkpoint Manager (`utils/checkpoint_manager.py`)
|
||||
|
||||
**Purpose**: Improved checkpoint handling with metadata separation
|
||||
|
||||
**Features**:
|
||||
- Database-backed metadata storage
|
||||
- Fast metadata queries without loading models
|
||||
- Eliminates "No checkpoints found" spam
|
||||
- Backward compatibility with existing code (see the example below)
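
The calling convention is unchanged for existing code; a sketch of the database-first lookup with the silent file-based fallback:

```python
from utils.checkpoint_manager import load_best_checkpoint

result = load_best_checkpoint("dqn_agent")
if result:
    checkpoint_path, metadata = result
    # Metadata comes from SQLite when available, from *_metadata.json otherwise
    print(metadata.checkpoint_id, metadata.loss)
else:
    # No checkpoint yet; the manager no longer logs "No checkpoints found" here
    pass
```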
|
||||
|
||||
## Benefits
|
||||
|
||||
### 1. Performance Improvements
|
||||
|
||||
**Before**: Loading full checkpoint just to get metadata
|
||||
```python
|
||||
# Old way - loads entire model!
|
||||
checkpoint_path, metadata = load_best_checkpoint("dqn_agent")
|
||||
loss = metadata.loss # Expensive operation
|
||||
```
|
||||
|
||||
**After**: Fast metadata access from database
|
||||
```python
|
||||
# New way - database query only
|
||||
metadata = db_manager.get_best_checkpoint_metadata("dqn_agent")
|
||||
loss = metadata.performance_metrics['loss'] # Fast!
|
||||
```
|
||||
|
||||
### 2. Cleaner Runtime Logs
|
||||
|
||||
**Before**: Scattered logs everywhere
|
||||
```
|
||||
2025-07-25 14:34:39,749 - utils.checkpoint_manager - INFO - No checkpoints found for dqn_agent
|
||||
2025-07-25 14:34:39,754 - utils.checkpoint_manager - INFO - No checkpoints found for enhanced_cnn
|
||||
2025-07-25 14:34:39,756 - utils.checkpoint_manager - INFO - No checkpoints found for extrema_trainer
|
||||
```
|
||||
|
||||
**After**: Clean, structured logging
|
||||
```
|
||||
2025-07-25 14:34:39 | dqn_agent | ETH/USDT | BUY | conf=0.850 | time= 12.5ms [checkpoint: dqn_agent_20250725_143500]
|
||||
2025-07-25 14:34:40 | enhanced_cnn | ETH/USDT | HOLD | conf=0.720 | time= 8.2ms [checkpoint: enhanced_cnn_20250725_143501]
|
||||
```
|
||||
|
||||
### 3. Structured Data Storage
|
||||
|
||||
**Database Schema**:
|
||||
```sql
|
||||
-- Fast metadata queries
|
||||
SELECT * FROM checkpoint_metadata WHERE model_name = 'dqn_agent' AND is_active = TRUE;
|
||||
|
||||
-- Performance analysis
|
||||
SELECT model_name, AVG(confidence), COUNT(*)
|
||||
FROM inference_records
|
||||
WHERE timestamp > datetime('now', '-24 hours')
|
||||
GROUP BY model_name;
|
||||
```
|
||||
|
||||
### 4. Easy Integration
|
||||
|
||||
**In Model Code**:
|
||||
```python
|
||||
# Replace scattered logging
|
||||
# OLD: logger.info(f"DQN prediction: {action} confidence={conf}")
|
||||
|
||||
# NEW: Centralized logging
|
||||
self.orchestrator.log_model_inference(
|
||||
model_name="dqn_agent",
|
||||
symbol=symbol,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
probabilities=probs,
|
||||
input_features=features,
|
||||
processing_time_ms=processing_time
|
||||
)
|
||||
```
|
||||
|
||||
## Implementation Guide
|
||||
|
||||
### 1. Update Model Classes
|
||||
|
||||
Add inference logging to prediction methods:
|
||||
|
||||
```python
|
||||
class DQNAgent:
|
||||
def predict(self, state):
|
||||
start_time = time.time()
|
||||
|
||||
# Your prediction logic here
|
||||
action = self._predict_action(state)
|
||||
confidence = self._calculate_confidence()
|
||||
|
||||
processing_time = (time.time() - start_time) * 1000
|
||||
|
||||
# Log the inference
|
||||
self.orchestrator.log_model_inference(
|
||||
model_name="dqn_agent",
|
||||
symbol=self.symbol,
|
||||
action=action,
|
||||
confidence=confidence,
|
||||
probabilities=self.action_probabilities,
|
||||
input_features=state,
|
||||
processing_time_ms=processing_time,
|
||||
checkpoint_id=self.current_checkpoint_id
|
||||
)
|
||||
|
||||
return action
|
||||
```
|
||||
|
||||
### 2. Update Checkpoint Saving
|
||||
|
||||
Use the enhanced checkpoint manager:
|
||||
|
||||
```python
|
||||
from utils.checkpoint_manager import save_checkpoint
|
||||
|
||||
# Save with metadata
|
||||
checkpoint_metadata = save_checkpoint(
|
||||
model=self.model,
|
||||
model_name="dqn_agent",
|
||||
model_type="rl",
|
||||
performance_metrics={"loss": 0.0234, "accuracy": 0.87},
|
||||
training_metadata={"epochs": 100, "lr": 0.001}
|
||||
)
|
||||
```
|
||||
|
||||
### 3. Fast Metadata Access
|
||||
|
||||
Get checkpoint info without loading models:
|
||||
|
||||
```python
|
||||
# Fast metadata access
|
||||
metadata = orchestrator.get_checkpoint_metadata_fast("dqn_agent")
|
||||
if metadata:
|
||||
current_loss = metadata.performance_metrics['loss']
|
||||
checkpoint_id = metadata.checkpoint_id
|
||||
```
|
||||
|
||||
## Migration Steps
|
||||
|
||||
1. **Install new dependencies** (if any)
|
||||
2. **Update model classes** to use centralized logging
|
||||
3. **Replace checkpoint loading** with database queries where possible
|
||||
4. **Remove scattered logger.info()** calls for inferences
|
||||
5. **Test with demo script**: `python demo_logging_system.py`
|
||||
|
||||
## File Structure
|
||||
|
||||
```
|
||||
utils/
|
||||
├── database_manager.py # SQLite database management
|
||||
├── inference_logger.py # Centralized inference logging
|
||||
├── text_logger.py # Human-readable text logs
|
||||
└── checkpoint_manager.py # Enhanced checkpoint handling
|
||||
|
||||
logs/ # Text log files
|
||||
├── inference_records.txt
|
||||
├── checkpoint_events.txt
|
||||
└── system_events.txt
|
||||
|
||||
data/
|
||||
└── trading_system.db # SQLite database
|
||||
|
||||
demo_logging_system.py # Demonstration script
|
||||
```
|
||||
|
||||
## Monitoring and Maintenance
|
||||
|
||||
### Daily Tasks
|
||||
- Check `logs/inference_records.txt` for recent activity
|
||||
- Monitor database size: `ls -lh data/trading_system.db`
|
||||
|
||||
### Weekly Tasks
|
||||
- Run cleanup: `inference_logger.cleanup_old_logs(days_to_keep=30)` (see the combined sketch after these lists)
|
||||
- Check model performance trends in database
|
||||
|
||||
### Monthly Tasks
|
||||
- Archive old log files
|
||||
- Analyze model performance statistics
|
||||
- Review checkpoint storage usage
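
These tasks can be bundled into one small helper; a hedged sketch (it assumes `get_inference_logger` is importable from `utils.inference_logger` alongside `log_model_inference`, as the debug commands below suggest):

```python
from utils.inference_logger import get_inference_logger  # assumed export
from utils.text_logger import get_text_logger
from utils.cache_manager import get_cache_health

def weekly_maintenance():
    # Trim old inference rows and oversized text logs
    get_inference_logger().cleanup_old_logs(days_to_keep=30)
    get_text_logger().cleanup_old_logs(max_lines=10000)

    # Quick health snapshot of the Parquet caches
    health = get_cache_health()
    print(f"cache: {health['health_percentage']:.1f}% healthy, "
          f"{health['total_size_mb']:.1f} MB total")
```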
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
1. **Database locked**: Multiple processes accessing SQLite
|
||||
- Solution: Use connection timeout and proper context managers (see the sketch after this list)
|
||||
|
||||
2. **Log files growing too large**:
|
||||
- Solution: Run `text_logger.cleanup_old_logs(max_lines=10000)`
|
||||
|
||||
3. **Missing checkpoint metadata**:
|
||||
- Solution: System falls back to file-based approach automatically
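
For issue 1, the fix is the same timeout-plus-context-manager pattern `DatabaseManager._get_connection` already uses; a standalone sketch (the helper name `open_db` is illustrative):

```python
import sqlite3
from contextlib import contextmanager

@contextmanager
def open_db(db_path: str = "data/trading_system.db"):
    conn = sqlite3.connect(db_path, timeout=30.0)  # wait for locks instead of failing
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()

# Usage
with open_db() as conn:
    row = conn.execute("SELECT COUNT(*) FROM inference_records").fetchone()
```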
|
||||
|
||||
### Debug Commands
|
||||
|
||||
```python
|
||||
# Check database status
|
||||
db_manager = get_database_manager()
|
||||
checkpoints = db_manager.list_checkpoints("dqn_agent")
|
||||
|
||||
# Check recent inferences
|
||||
inference_logger = get_inference_logger()
|
||||
stats = inference_logger.get_model_stats("dqn_agent", hours=24)
|
||||
|
||||
# View text logs
|
||||
text_logger = get_text_logger()
|
||||
recent = text_logger.get_recent_inferences(lines=50)
|
||||
```
|
||||
|
||||
This upgrade provides a solid foundation for tracking model performance, eliminating log spam, and enabling fast metadata access without the overhead of loading full model checkpoints.
|
133
fix_cache.py
Normal file
@@ -0,0 +1,133 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Cache Fix Script
|
||||
|
||||
Quick script to diagnose and fix cache issues, including the Parquet deserialization error
|
||||
"""
|
||||
|
||||
import sys
|
||||
import logging
|
||||
from utils.cache_manager import get_cache_manager, cleanup_corrupted_cache, get_cache_health
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def main():
|
||||
"""Main cache fix routine"""
|
||||
print("=== Trading System Cache Fix ===")
|
||||
print()
|
||||
|
||||
# Get cache manager
|
||||
cache_manager = get_cache_manager()
|
||||
|
||||
# 1. Scan cache health
|
||||
print("1. Scanning cache health...")
|
||||
health_summary = get_cache_health()
|
||||
|
||||
print(f"Total files: {health_summary['total_files']}")
|
||||
print(f"Valid files: {health_summary['valid_files']}")
|
||||
print(f"Corrupted files: {health_summary['corrupted_files']}")
|
||||
print(f"Health percentage: {health_summary['health_percentage']:.1f}%")
|
||||
print(f"Total cache size: {health_summary['total_size_mb']:.1f} MB")
|
||||
print()
|
||||
|
||||
# Show detailed report
|
||||
for cache_dir, report in health_summary['directories'].items():
|
||||
if report['total_files'] > 0:
|
||||
print(f"Directory: {cache_dir}")
|
||||
print(f" Files: {report['valid_files']}/{report['total_files']} valid")
|
||||
print(f" Size: {report['total_size_mb']:.1f} MB")
|
||||
|
||||
if report['corrupted_files'] > 0:
|
||||
print(f" CORRUPTED FILES ({report['corrupted_files']}):")
|
||||
for corrupted in report['corrupted_files_list']:
|
||||
print(f" - {corrupted['file']}: {corrupted['error']}")
|
||||
|
||||
if report['old_files']:
|
||||
print(f" OLD FILES ({len(report['old_files'])}):")
|
||||
for old_file in report['old_files'][:3]: # Show first 3
|
||||
print(f" - {old_file['file']}: {old_file['age_days']} days old")
|
||||
if len(report['old_files']) > 3:
|
||||
print(f" ... and {len(report['old_files']) - 3} more")
|
||||
print()
|
||||
|
||||
# 2. Fix corrupted files
|
||||
if health_summary['corrupted_files'] > 0:
|
||||
print("2. Fixing corrupted files...")
|
||||
|
||||
# First show what would be deleted
|
||||
print("Files that will be deleted:")
|
||||
dry_run_result = cleanup_corrupted_cache(dry_run=True)
|
||||
for cache_dir, files in dry_run_result.items():
|
||||
if files:
|
||||
print(f" {cache_dir}:")
|
||||
for file_info in files:
|
||||
print(f" {file_info}")
|
||||
|
||||
# Ask for confirmation
|
||||
response = input("\nProceed with deletion? (y/N): ").strip().lower()
|
||||
if response == 'y':
|
||||
print("Deleting corrupted files...")
|
||||
actual_result = cleanup_corrupted_cache(dry_run=False)
|
||||
|
||||
deleted_count = 0
|
||||
for cache_dir, files in actual_result.items():
|
||||
for file_info in files:
|
||||
if "DELETED:" in file_info:
|
||||
deleted_count += 1
|
||||
|
||||
print(f"Deleted {deleted_count} corrupted files")
|
||||
else:
|
||||
print("Skipped deletion")
|
||||
else:
|
||||
print("2. No corrupted files found - cache is healthy!")
|
||||
|
||||
print()
|
||||
|
||||
# 3. Optional: Clean old files
|
||||
print("3. Checking for old files...")
|
||||
old_files_result = cache_manager.cleanup_old_files(days_to_keep=7, dry_run=True)
|
||||
|
||||
old_file_count = sum(len(files) for files in old_files_result.values())
|
||||
if old_file_count > 0:
|
||||
print(f"Found {old_file_count} old files (>7 days)")
|
||||
response = input("Clean up old files? (y/N): ").strip().lower()
|
||||
if response == 'y':
|
||||
actual_old_result = cache_manager.cleanup_old_files(days_to_keep=7, dry_run=False)
|
||||
deleted_old_count = sum(len([f for f in files if "DELETED:" in f]) for files in actual_old_result.values())
|
||||
print(f"Deleted {deleted_old_count} old files")
|
||||
else:
|
||||
print("Skipped old file cleanup")
|
||||
else:
|
||||
print("No old files found")
|
||||
|
||||
print()
|
||||
print("=== Cache Fix Complete ===")
|
||||
print("The system should now work without Parquet deserialization errors.")
|
||||
print("If you continue to see issues, consider running with --emergency-reset")
|
||||
|
||||
def emergency_reset():
|
||||
"""Emergency cache reset"""
|
||||
print("=== EMERGENCY CACHE RESET ===")
|
||||
print("WARNING: This will delete ALL cache files!")
|
||||
print("You will need to re-download all historical data.")
|
||||
print()
|
||||
|
||||
response = input("Are you sure you want to proceed? Type 'DELETE ALL CACHE' to confirm: ")
|
||||
if response == "DELETE ALL CACHE":
|
||||
cache_manager = get_cache_manager()
|
||||
success = cache_manager.emergency_cache_reset(confirm=True)
|
||||
if success:
|
||||
print("Emergency cache reset completed.")
|
||||
print("All cache files have been deleted.")
|
||||
else:
|
||||
print("Emergency reset failed.")
|
||||
else:
|
||||
print("Emergency reset cancelled.")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) > 1 and sys.argv[1] == "--emergency-reset":
|
||||
emergency_reset()
|
||||
else:
|
||||
main()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
295
utils/cache_manager.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""
|
||||
Cache Manager for Trading System
|
||||
|
||||
Utilities for managing and cleaning up cache files, including:
|
||||
- Parquet file validation and repair
|
||||
- Cache cleanup and maintenance
|
||||
- Cache health monitoring
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CacheManager:
|
||||
"""Manages cache files for the trading system"""
|
||||
|
||||
def __init__(self, cache_dirs: List[str] = None):
|
||||
"""
|
||||
Initialize cache manager
|
||||
|
||||
Args:
|
||||
cache_dirs: List of cache directories to manage
|
||||
"""
|
||||
self.cache_dirs = cache_dirs or [
|
||||
"data/cache",
|
||||
"data/monthly_cache",
|
||||
"data/pivot_cache"
|
||||
]
|
||||
|
||||
# Ensure cache directories exist
|
||||
for cache_dir in self.cache_dirs:
|
||||
Path(cache_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def validate_parquet_file(self, file_path: Path) -> Tuple[bool, Optional[str]]:
|
||||
"""
|
||||
Validate a Parquet file
|
||||
|
||||
Args:
|
||||
file_path: Path to the Parquet file
|
||||
|
||||
Returns:
|
||||
Tuple of (is_valid, error_message)
|
||||
"""
|
||||
try:
|
||||
if not file_path.exists():
|
||||
return False, "File does not exist"
|
||||
|
||||
if file_path.stat().st_size == 0:
|
||||
return False, "File is empty"
|
||||
|
||||
# Try to read the file
|
||||
df = pd.read_parquet(file_path)
|
||||
|
||||
if df.empty:
|
||||
return False, "File contains no data"
|
||||
|
||||
# Check for required columns (basic validation)
|
||||
required_columns = ['timestamp', 'open', 'high', 'low', 'close', 'volume']
|
||||
missing_columns = [col for col in required_columns if col not in df.columns]
|
||||
|
||||
if missing_columns:
|
||||
return False, f"Missing required columns: {missing_columns}"
|
||||
|
||||
return True, None
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
corrupted_indicators = [
|
||||
"parquet magic bytes not found",
|
||||
"corrupted",
|
||||
"couldn't deserialize thrift",
|
||||
"don't know what type",
|
||||
"invalid parquet file",
|
||||
"unexpected end of file",
|
||||
"invalid metadata"
|
||||
]
|
||||
|
||||
if any(indicator in error_str for indicator in corrupted_indicators):
|
||||
return False, f"Corrupted Parquet file: {e}"
|
||||
else:
|
||||
return False, f"Validation error: {e}"
|
||||
|
||||
def scan_cache_health(self) -> Dict[str, Dict]:
|
||||
"""
|
||||
Scan all cache directories for file health
|
||||
|
||||
Returns:
|
||||
Dictionary with cache health information
|
||||
"""
|
||||
health_report = {}
|
||||
|
||||
for cache_dir in self.cache_dirs:
|
||||
cache_path = Path(cache_dir)
|
||||
if not cache_path.exists():
|
||||
continue
|
||||
|
||||
dir_report = {
|
||||
'total_files': 0,
|
||||
'valid_files': 0,
|
||||
'corrupted_files': 0,
|
||||
'empty_files': 0,
|
||||
'total_size_mb': 0.0,
|
||||
'corrupted_files_list': [],
|
||||
'old_files': []
|
||||
}
|
||||
|
||||
# Scan all Parquet files
|
||||
for file_path in cache_path.glob("*.parquet"):
|
||||
dir_report['total_files'] += 1
|
||||
file_size_mb = file_path.stat().st_size / (1024 * 1024)
|
||||
dir_report['total_size_mb'] += file_size_mb
|
||||
|
||||
# Check file age
|
||||
file_age = datetime.now() - datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
if file_age > timedelta(days=7): # Files older than 7 days
|
||||
dir_report['old_files'].append({
|
||||
'file': str(file_path),
|
||||
'age_days': file_age.days,
|
||||
'size_mb': file_size_mb
|
||||
})
|
||||
|
||||
# Validate file
|
||||
is_valid, error_msg = self.validate_parquet_file(file_path)
|
||||
|
||||
if is_valid:
|
||||
dir_report['valid_files'] += 1
|
||||
else:
|
||||
if "empty" in error_msg.lower():
|
||||
dir_report['empty_files'] += 1
|
||||
else:
|
||||
dir_report['corrupted_files'] += 1
|
||||
dir_report['corrupted_files_list'].append({
|
||||
'file': str(file_path),
|
||||
'error': error_msg,
|
||||
'size_mb': file_size_mb
|
||||
})
|
||||
|
||||
health_report[cache_dir] = dir_report
|
||||
|
||||
return health_report
|
||||
|
||||
def cleanup_corrupted_files(self, dry_run: bool = True) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Clean up corrupted cache files
|
||||
|
||||
Args:
|
||||
dry_run: If True, only report what would be deleted
|
||||
|
||||
Returns:
|
||||
Dictionary of deleted files by directory
|
||||
"""
|
||||
deleted_files = {}
|
||||
|
||||
for cache_dir in self.cache_dirs:
|
||||
cache_path = Path(cache_dir)
|
||||
if not cache_path.exists():
|
||||
continue
|
||||
|
||||
deleted_files[cache_dir] = []
|
||||
|
||||
for file_path in cache_path.glob("*.parquet"):
|
||||
is_valid, error_msg = self.validate_parquet_file(file_path)
|
||||
|
||||
if not is_valid:
|
||||
if dry_run:
|
||||
deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} ({error_msg})")
|
||||
logger.info(f"Would delete corrupted file: {file_path} ({error_msg})")
|
||||
else:
|
||||
try:
|
||||
file_path.unlink()
|
||||
deleted_files[cache_dir].append(f"DELETED: {file_path} ({error_msg})")
|
||||
logger.info(f"Deleted corrupted file: {file_path}")
|
||||
except Exception as e:
|
||||
deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
|
||||
logger.error(f"Failed to delete corrupted file {file_path}: {e}")
|
||||
|
||||
return deleted_files
|
||||
|
||||
def cleanup_old_files(self, days_to_keep: int = 7, dry_run: bool = True) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Clean up old cache files
|
||||
|
||||
Args:
|
||||
days_to_keep: Number of days to keep files
|
||||
dry_run: If True, only report what would be deleted
|
||||
|
||||
Returns:
|
||||
Dictionary of deleted files by directory
|
||||
"""
|
||||
deleted_files = {}
|
||||
cutoff_date = datetime.now() - timedelta(days=days_to_keep)
|
||||
|
||||
for cache_dir in self.cache_dirs:
|
||||
cache_path = Path(cache_dir)
|
||||
if not cache_path.exists():
|
||||
continue
|
||||
|
||||
deleted_files[cache_dir] = []
|
||||
|
||||
for file_path in cache_path.glob("*.parquet"):
|
||||
file_mtime = datetime.fromtimestamp(file_path.stat().st_mtime)
|
||||
|
||||
if file_mtime < cutoff_date:
|
||||
age_days = (datetime.now() - file_mtime).days
|
||||
|
||||
if dry_run:
|
||||
deleted_files[cache_dir].append(f"WOULD DELETE: {file_path} (age: {age_days} days)")
|
||||
logger.info(f"Would delete old file: {file_path} (age: {age_days} days)")
|
||||
else:
|
||||
try:
|
||||
file_path.unlink()
|
||||
deleted_files[cache_dir].append(f"DELETED: {file_path} (age: {age_days} days)")
|
||||
logger.info(f"Deleted old file: {file_path}")
|
||||
except Exception as e:
|
||||
deleted_files[cache_dir].append(f"FAILED TO DELETE: {file_path} ({e})")
|
||||
logger.error(f"Failed to delete old file {file_path}: {e}")
|
||||
|
||||
return deleted_files
|
||||
|
||||
def get_cache_summary(self) -> Dict[str, Any]:
|
||||
"""Get a summary of cache usage"""
|
||||
health_report = self.scan_cache_health()
|
||||
|
||||
total_files = sum(report['total_files'] for report in health_report.values())
|
||||
total_valid = sum(report['valid_files'] for report in health_report.values())
|
||||
total_corrupted = sum(report['corrupted_files'] for report in health_report.values())
|
||||
total_size_mb = sum(report['total_size_mb'] for report in health_report.values())
|
||||
|
||||
return {
|
||||
'total_files': total_files,
|
||||
'valid_files': total_valid,
|
||||
'corrupted_files': total_corrupted,
|
||||
'health_percentage': (total_valid / total_files * 100) if total_files > 0 else 0,
|
||||
'total_size_mb': total_size_mb,
|
||||
'directories': health_report
|
||||
}
|
||||
|
||||
def emergency_cache_reset(self, confirm: bool = False) -> bool:
|
||||
"""
|
||||
Emergency cache reset - deletes all cache files
|
||||
|
||||
Args:
|
||||
confirm: Must be True to actually delete files
|
||||
|
||||
Returns:
|
||||
True if reset was performed
|
||||
"""
|
||||
if not confirm:
|
||||
logger.warning("Emergency cache reset called but not confirmed")
|
||||
return False
|
||||
|
||||
deleted_count = 0
|
||||
|
||||
for cache_dir in self.cache_dirs:
|
||||
cache_path = Path(cache_dir)
|
||||
if not cache_path.exists():
|
||||
continue
|
||||
|
||||
for file_path in cache_path.glob("*"):
|
||||
try:
|
||||
if file_path.is_file():
|
||||
file_path.unlink()
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete {file_path}: {e}")
|
||||
|
||||
logger.warning(f"Emergency cache reset completed: deleted {deleted_count} files")
|
||||
return True
|
||||
|
||||
# Global cache manager instance
|
||||
_cache_manager_instance = None
|
||||
|
||||
def get_cache_manager() -> CacheManager:
|
||||
"""Get the global cache manager instance"""
|
||||
global _cache_manager_instance
|
||||
|
||||
if _cache_manager_instance is None:
|
||||
_cache_manager_instance = CacheManager()
|
||||
|
||||
return _cache_manager_instance
|
||||
|
||||
def cleanup_corrupted_cache(dry_run: bool = True) -> Dict[str, List[str]]:
|
||||
"""Convenience function to clean up corrupted cache files"""
|
||||
cache_manager = get_cache_manager()
|
||||
return cache_manager.cleanup_corrupted_files(dry_run=dry_run)
|
||||
|
||||
def get_cache_health() -> Dict[str, Any]:
|
||||
"""Convenience function to get cache health summary"""
|
||||
cache_manager = get_cache_manager()
|
||||
return cache_manager.get_cache_summary()
|
@@ -5,6 +5,7 @@ This module provides functionality for managing model checkpoints, including:
|
||||
- Saving checkpoints with metadata
|
||||
- Loading the best checkpoint based on performance metrics
|
||||
- Cleaning up old or underperforming checkpoints
|
||||
- Database-backed metadata storage for efficient access
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -13,9 +14,13 @@ import glob
|
||||
import logging
|
||||
import shutil
|
||||
import torch
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
|
||||
from .database_manager import get_database_manager, CheckpointMetadata
|
||||
from .text_logger import get_text_logger
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global checkpoint manager instance
|
||||
@@ -46,7 +51,7 @@ def get_checkpoint_manager(checkpoint_dir: str = "models/checkpoints", max_check
|
||||
|
||||
def save_checkpoint(model, model_name: str, model_type: str, performance_metrics: Dict[str, float], training_metadata: Dict[str, Any] = None, checkpoint_dir: str = "models/checkpoints") -> Any:
|
||||
"""
|
||||
Save a checkpoint with metadata
|
||||
Save a checkpoint with metadata to both filesystem and database
|
||||
|
||||
Args:
|
||||
model: The model to save
|
||||
@@ -64,57 +69,90 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
|
||||
os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
# Create timestamp
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
timestamp = datetime.now()
|
||||
timestamp_str = timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# Create checkpoint path
|
||||
model_dir = os.path.join(checkpoint_dir, model_name)
|
||||
os.makedirs(model_dir, exist_ok=True)
|
||||
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp}")
|
||||
checkpoint_path = os.path.join(model_dir, f"{model_name}_{timestamp_str}")
|
||||
checkpoint_id = f"{model_name}_{timestamp_str}"
|
||||
|
||||
# Save model
|
||||
torch_path = f"{checkpoint_path}.pt"
|
||||
if hasattr(model, 'save'):
|
||||
# Use model's save method if available
|
||||
model.save(checkpoint_path)
|
||||
else:
|
||||
# Otherwise, save state_dict
|
||||
torch_path = f"{checkpoint_path}.pt"
|
||||
torch.save({
|
||||
'model_state_dict': model.state_dict() if hasattr(model, 'state_dict') else None,
|
||||
'model_name': model_name,
|
||||
'model_type': model_type,
|
||||
'timestamp': timestamp
|
||||
'timestamp': timestamp_str,
|
||||
'checkpoint_id': checkpoint_id
|
||||
}, torch_path)
|
||||
|
||||
# Create metadata
|
||||
checkpoint_metadata = {
|
||||
# Calculate file size
|
||||
file_size_mb = os.path.getsize(torch_path) / (1024 * 1024) if os.path.exists(torch_path) else 0.0
|
||||
|
||||
# Save metadata to database
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = CheckpointMetadata(
|
||||
checkpoint_id=checkpoint_id,
|
||||
model_name=model_name,
|
||||
model_type=model_type,
|
||||
timestamp=timestamp,
|
||||
performance_metrics=performance_metrics,
|
||||
training_metadata=training_metadata or {},
|
||||
file_path=torch_path,
|
||||
file_size_mb=file_size_mb,
|
||||
is_active=True # New checkpoint is active by default
|
||||
)
|
||||
|
||||
# Save to database
|
||||
if db_manager.save_checkpoint_metadata(checkpoint_metadata):
|
||||
# Log checkpoint save event to text file
|
||||
text_logger = get_text_logger()
|
||||
text_logger.log_checkpoint_event(
|
||||
model_name=model_name,
|
||||
event_type="SAVED",
|
||||
checkpoint_id=checkpoint_id,
|
||||
details=f"loss={performance_metrics.get('loss', 'N/A')}, size={file_size_mb:.1f}MB"
|
||||
)
|
||||
logger.info(f"Checkpoint saved: {checkpoint_id}")
|
||||
else:
|
||||
logger.warning(f"Failed to save checkpoint metadata to database: {checkpoint_id}")
|
||||
|
||||
# Also save legacy JSON metadata for backward compatibility
|
||||
legacy_metadata = {
|
||||
'model_name': model_name,
|
||||
'model_type': model_type,
|
||||
'timestamp': timestamp,
|
||||
'timestamp': timestamp_str,
|
||||
'performance_metrics': performance_metrics,
|
||||
'training_metadata': training_metadata or {},
|
||||
'checkpoint_id': f"{model_name}_{timestamp}"
|
||||
'checkpoint_id': checkpoint_id,
|
||||
'performance_score': performance_metrics.get('accuracy', performance_metrics.get('reward', 0.0)),
|
||||
'created_at': timestamp_str
|
||||
}
|
||||
|
||||
# Add performance score for sorting
|
||||
primary_metric = 'accuracy' if 'accuracy' in performance_metrics else 'reward'
|
||||
checkpoint_metadata['performance_score'] = performance_metrics.get(primary_metric, 0.0)
|
||||
checkpoint_metadata['created_at'] = timestamp
|
||||
|
||||
# Save metadata
|
||||
with open(f"{checkpoint_path}_metadata.json", 'w') as f:
|
||||
json.dump(checkpoint_metadata, f, indent=2)
|
||||
json.dump(legacy_metadata, f, indent=2)
|
||||
|
||||
# Get checkpoint manager and clean up old checkpoints
|
||||
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
|
||||
checkpoint_manager._cleanup_checkpoints(model_name)
|
||||
|
||||
# Return metadata as an object
|
||||
class CheckpointMetadata:
|
||||
# Return metadata as an object for backward compatibility
|
||||
class CheckpointMetadataObj:
|
||||
def __init__(self, metadata):
|
||||
for key, value in metadata.items():
|
||||
setattr(self, key, value)
|
||||
# Add database fields
|
||||
self.checkpoint_id = checkpoint_id
|
||||
self.loss = performance_metrics.get('loss', performance_metrics.get('accuracy', 0.0))
|
||||
|
||||
return CheckpointMetadata(checkpoint_metadata)
|
||||
return CheckpointMetadataObj(legacy_metadata)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint: {e}")
|
||||
@@ -122,7 +160,7 @@ def save_checkpoint(model, model_name: str, model_type: str, performance_metrics
|
||||
|
||||
def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoints") -> Optional[Tuple[str, Any]]:
|
||||
"""
|
||||
Load the best checkpoint based on performance metrics
|
||||
Load the best checkpoint based on performance metrics using database metadata
|
||||
|
||||
Args:
|
||||
model_name: Name of the model
|
||||
@@ -132,29 +170,77 @@ def load_best_checkpoint(model_name: str, checkpoint_dir: str = "models/checkpoi
|
||||
Optional[Tuple[str, Any]]: Path to the best checkpoint and its metadata, or None if not found
|
||||
"""
|
||||
try:
|
||||
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
|
||||
checkpoint_path, checkpoint_metadata = checkpoint_manager.load_best_checkpoint(model_name)
|
||||
# First try to get from database (fast metadata access)
|
||||
db_manager = get_database_manager()
|
||||
checkpoint_metadata = db_manager.get_best_checkpoint_metadata(model_name, "accuracy")
|
||||
|
||||
if not checkpoint_path:
|
||||
if not checkpoint_metadata:
|
||||
# Fallback to legacy file-based approach (no more scattered "No checkpoints found" logs)
|
||||
pass # Silent fallback
|
||||
checkpoint_manager = get_checkpoint_manager(checkpoint_dir=checkpoint_dir)
|
||||
checkpoint_path, legacy_metadata = checkpoint_manager.load_best_checkpoint(model_name)
|
||||
|
||||
if not checkpoint_path:
|
||||
return None
|
||||
|
||||
# Convert legacy metadata to object
|
||||
class CheckpointMetadataObj:
|
||||
def __init__(self, metadata):
|
||||
for key, value in metadata.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
# Add performance score if not present
|
||||
if not hasattr(self, 'performance_score'):
|
||||
metrics = getattr(self, 'metrics', {})
|
||||
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
|
||||
self.performance_score = metrics.get(primary_metric, 0.0)
|
||||
|
||||
# Add created_at if not present
|
||||
if not hasattr(self, 'created_at'):
|
||||
self.created_at = getattr(self, 'timestamp', 'unknown')
|
||||
|
||||
# Add loss for compatibility ('metrics' is re-read here because the block above may be skipped)
metrics = getattr(self, 'metrics', {})
self.loss = metrics.get('loss', self.performance_score)
|
||||
self.checkpoint_id = getattr(self, 'checkpoint_id', f"{model_name}_unknown")
|
||||
|
||||
return f"{checkpoint_path}.pt", CheckpointMetadataObj(legacy_metadata)
|
||||
|
||||
# Check if checkpoint file exists
|
||||
if not os.path.exists(checkpoint_metadata.file_path):
|
||||
logger.warning(f"Checkpoint file not found: {checkpoint_metadata.file_path}")
|
||||
return None
|
||||
|
||||
# Convert metadata to object
|
||||
class CheckpointMetadata:
|
||||
def __init__(self, metadata):
|
||||
for key, value in metadata.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
# Add performance score if not present
|
||||
if not hasattr(self, 'performance_score'):
|
||||
metrics = getattr(self, 'metrics', {})
|
||||
primary_metric = 'accuracy' if 'accuracy' in metrics else 'reward'
|
||||
self.performance_score = metrics.get(primary_metric, 0.0)
|
||||
|
||||
# Add created_at if not present
|
||||
if not hasattr(self, 'created_at'):
|
||||
self.created_at = getattr(self, 'timestamp', 'unknown')
|
||||
# Log checkpoint load event to text file
|
||||
text_logger = get_text_logger()
|
||||
text_logger.log_checkpoint_event(
|
||||
model_name=model_name,
|
||||
event_type="LOADED",
|
||||
checkpoint_id=checkpoint_metadata.checkpoint_id,
|
||||
details=f"loss={checkpoint_metadata.performance_metrics.get('loss', 'N/A')}"
|
||||
)
|
||||
|
||||
return f"{checkpoint_path}.pt", CheckpointMetadata(checkpoint_metadata)
|
||||
# Convert database metadata to object for backward compatibility
|
||||
class CheckpointMetadataObj:
|
||||
def __init__(self, db_metadata: CheckpointMetadata):
|
||||
self.checkpoint_id = db_metadata.checkpoint_id
|
||||
self.model_name = db_metadata.model_name
|
||||
self.model_type = db_metadata.model_type
|
||||
self.timestamp = db_metadata.timestamp.strftime("%Y%m%d_%H%M%S")
|
||||
self.performance_metrics = db_metadata.performance_metrics
|
||||
self.training_metadata = db_metadata.training_metadata
|
||||
self.file_path = db_metadata.file_path
|
||||
self.file_size_mb = db_metadata.file_size_mb
|
||||
self.is_active = db_metadata.is_active
|
||||
|
||||
# Backward compatibility fields
|
||||
self.metrics = db_metadata.performance_metrics
|
||||
self.metadata = db_metadata.training_metadata
|
||||
self.created_at = self.timestamp
|
||||
self.performance_score = db_metadata.performance_metrics.get('accuracy',
|
||||
db_metadata.performance_metrics.get('reward', 0.0))
|
||||
self.loss = db_metadata.performance_metrics.get('loss', self.performance_score)
|
||||
|
||||
return checkpoint_metadata.file_path, CheckpointMetadataObj(checkpoint_metadata)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error loading best checkpoint: {e}")
|
||||
@@ -254,7 +340,7 @@ class CheckpointManager:
|
||||
metadata_files = glob.glob(os.path.join(checkpoint_dir, f"{model_name}_*_metadata.json"))
|
||||
|
||||
if not metadata_files:
|
||||
logger.info(f"No checkpoints found for {model_name}")
|
||||
# No more scattered "No checkpoints found" logs - handled by database system
|
||||
return "", {}
|
||||
|
||||
# Load metadata for each checkpoint
|
||||
@@ -278,7 +364,7 @@ class CheckpointManager:
|
||||
logger.error(f"Error loading checkpoint metadata {metadata_file}: {e}")
|
||||
|
||||
if not checkpoints:
|
||||
logger.info(f"No valid checkpoints found for {model_name}")
|
||||
# No more scattered logs - handled by database system
|
||||
return "", {}
|
||||
|
||||
# Sort by metric (highest first)
|
||||
|
408
utils/database_manager.py
Normal file
@@ -0,0 +1,408 @@
|
||||
"""
|
||||
Database Manager for Trading System
|
||||
|
||||
Manages SQLite database for:
1. Inference records logging
2. Checkpoint metadata storage
3. Model performance tracking
"""

import sqlite3
import json
import logging
import os
from datetime import datetime, timedelta  # timedelta is used by cleanup_old_records()
from typing import Dict, List, Optional, Any, Tuple
from contextlib import contextmanager
from dataclasses import dataclass, asdict

logger = logging.getLogger(__name__)

@dataclass
class InferenceRecord:
    """Structure for inference logging"""
    model_name: str
    timestamp: datetime
    symbol: str
    action: str
    confidence: float
    probabilities: Dict[str, float]
    input_features_hash: str  # Hash of input features for deduplication
    processing_time_ms: float
    memory_usage_mb: float
    checkpoint_id: Optional[str] = None
    metadata: Optional[Dict[str, Any]] = None

@dataclass
class CheckpointMetadata:
    """Structure for checkpoint metadata"""
    checkpoint_id: str
    model_name: str
    model_type: str
    timestamp: datetime
    performance_metrics: Dict[str, float]
    training_metadata: Dict[str, Any]
    file_path: str
    file_size_mb: float
    is_active: bool = False  # Currently loaded checkpoint

class DatabaseManager:
    """Manages SQLite database for trading system logging and metadata"""

    def __init__(self, db_path: str = "data/trading_system.db"):
        self.db_path = db_path
        self._ensure_db_directory()
        self._initialize_database()

    def _ensure_db_directory(self):
        """Ensure database directory exists"""
        db_dir = os.path.dirname(self.db_path)
        if db_dir:  # A bare filename has no directory component to create
            os.makedirs(db_dir, exist_ok=True)

    def _initialize_database(self):
        """Initialize database tables"""
        with self._get_connection() as conn:
            # Inference records table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS inference_records (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    symbol TEXT NOT NULL,
                    action TEXT NOT NULL,
                    confidence REAL NOT NULL,
                    probabilities TEXT NOT NULL, -- JSON
                    input_features_hash TEXT NOT NULL,
                    processing_time_ms REAL NOT NULL,
                    memory_usage_mb REAL NOT NULL,
                    checkpoint_id TEXT,
                    metadata TEXT, -- JSON
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Checkpoint metadata table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS checkpoint_metadata (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    checkpoint_id TEXT UNIQUE NOT NULL,
                    model_name TEXT NOT NULL,
                    model_type TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    performance_metrics TEXT NOT NULL, -- JSON
                    training_metadata TEXT NOT NULL, -- JSON
                    file_path TEXT NOT NULL,
                    file_size_mb REAL NOT NULL,
                    is_active BOOLEAN DEFAULT FALSE,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP
                )
            """)

            # Model performance tracking table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS model_performance (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    date TEXT NOT NULL,
                    total_predictions INTEGER DEFAULT 0,
                    correct_predictions INTEGER DEFAULT 0,
                    accuracy REAL DEFAULT 0.0,
                    avg_confidence REAL DEFAULT 0.0,
                    avg_processing_time_ms REAL DEFAULT 0.0,
                    created_at TEXT DEFAULT CURRENT_TIMESTAMP,
                    UNIQUE(model_name, date)
                )
            """)

            # Create indexes for better performance
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_model_timestamp ON inference_records(model_name, timestamp)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_inference_symbol ON inference_records(symbol)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_model ON checkpoint_metadata(model_name)")
            conn.execute("CREATE INDEX IF NOT EXISTS idx_checkpoint_active ON checkpoint_metadata(is_active)")

        logger.info(f"Database initialized at {self.db_path}")

    @contextmanager
    def _get_connection(self):
        """Get database connection with proper error handling"""
        conn = None
        try:
            conn = sqlite3.connect(self.db_path, timeout=30.0)
            conn.row_factory = sqlite3.Row  # Enable dict-like access
            yield conn
        except Exception as e:
            if conn:
                conn.rollback()
            logger.error(f"Database error: {e}")
            raise
        finally:
            if conn:
                conn.close()

    def log_inference(self, record: InferenceRecord) -> bool:
        """Log an inference record"""
        try:
            with self._get_connection() as conn:
                conn.execute("""
                    INSERT INTO inference_records (
                        model_name, timestamp, symbol, action, confidence,
                        probabilities, input_features_hash, processing_time_ms,
                        memory_usage_mb, checkpoint_id, metadata
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    record.model_name,
                    record.timestamp.isoformat(),
                    record.symbol,
                    record.action,
                    record.confidence,
                    json.dumps(record.probabilities),
                    record.input_features_hash,
                    record.processing_time_ms,
                    record.memory_usage_mb,
                    record.checkpoint_id,
                    json.dumps(record.metadata) if record.metadata else None
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to log inference record: {e}")
            return False

    def save_checkpoint_metadata(self, metadata: CheckpointMetadata) -> bool:
        """Save checkpoint metadata"""
        try:
            with self._get_connection() as conn:
                # First, set all other checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (metadata.model_name,))

                # Insert or replace the new checkpoint metadata
                conn.execute("""
                    INSERT OR REPLACE INTO checkpoint_metadata (
                        checkpoint_id, model_name, model_type, timestamp,
                        performance_metrics, training_metadata, file_path,
                        file_size_mb, is_active
                    ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                """, (
                    metadata.checkpoint_id,
                    metadata.model_name,
                    metadata.model_type,
                    metadata.timestamp.isoformat(),
                    json.dumps(metadata.performance_metrics),
                    json.dumps(metadata.training_metadata),
                    metadata.file_path,
                    metadata.file_size_mb,
                    metadata.is_active
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to save checkpoint metadata: {e}")
            return False

    def get_checkpoint_metadata(self, model_name: str, checkpoint_id: Optional[str] = None) -> Optional[CheckpointMetadata]:
        """Get checkpoint metadata without loading the actual model"""
        try:
            with self._get_connection() as conn:
                if checkpoint_id:
                    # Get specific checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND checkpoint_id = ?
                    """, (model_name, checkpoint_id))
                else:
                    # Get active checkpoint
                    cursor = conn.execute("""
                        SELECT * FROM checkpoint_metadata
                        WHERE model_name = ? AND is_active = TRUE
                        ORDER BY timestamp DESC LIMIT 1
                    """, (model_name,))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get checkpoint metadata: {e}")
            return None

    def get_best_checkpoint_metadata(self, model_name: str, metric_name: str = "accuracy") -> Optional[CheckpointMetadata]:
        """Get best checkpoint metadata based on performance metric"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY json_extract(performance_metrics, '$.' || ?) DESC
                    LIMIT 1
                """, (model_name, metric_name))

                row = cursor.fetchone()
                if row:
                    return CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    )
                return None
        except Exception as e:
            logger.error(f"Failed to get best checkpoint metadata: {e}")
            return None

    def list_checkpoints(self, model_name: str) -> List[CheckpointMetadata]:
        """List all checkpoints for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM checkpoint_metadata
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                """, (model_name,))

                checkpoints = []
                for row in cursor.fetchall():
                    checkpoints.append(CheckpointMetadata(
                        checkpoint_id=row['checkpoint_id'],
                        model_name=row['model_name'],
                        model_type=row['model_type'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        performance_metrics=json.loads(row['performance_metrics']),
                        training_metadata=json.loads(row['training_metadata']),
                        file_path=row['file_path'],
                        file_size_mb=row['file_size_mb'],
                        is_active=bool(row['is_active'])
                    ))
                return checkpoints
        except Exception as e:
            logger.error(f"Failed to list checkpoints: {e}")
            return []

    def set_active_checkpoint(self, model_name: str, checkpoint_id: str) -> bool:
        """Set a checkpoint as active for a model"""
        try:
            with self._get_connection() as conn:
                # First, set all checkpoints for this model as inactive
                conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = FALSE
                    WHERE model_name = ?
                """, (model_name,))

                # Set the specified checkpoint as active
                cursor = conn.execute("""
                    UPDATE checkpoint_metadata
                    SET is_active = TRUE
                    WHERE model_name = ? AND checkpoint_id = ?
                """, (model_name, checkpoint_id))

                conn.commit()
                return cursor.rowcount > 0
        except Exception as e:
            logger.error(f"Failed to set active checkpoint: {e}")
            return False

    def get_recent_inferences(self, model_name: str, limit: int = 100) -> List[InferenceRecord]:
        """Get recent inference records for a model"""
        try:
            with self._get_connection() as conn:
                cursor = conn.execute("""
                    SELECT * FROM inference_records
                    WHERE model_name = ?
                    ORDER BY timestamp DESC
                    LIMIT ?
                """, (model_name, limit))

                records = []
                for row in cursor.fetchall():
                    records.append(InferenceRecord(
                        model_name=row['model_name'],
                        timestamp=datetime.fromisoformat(row['timestamp']),
                        symbol=row['symbol'],
                        action=row['action'],
                        confidence=row['confidence'],
                        probabilities=json.loads(row['probabilities']),
                        input_features_hash=row['input_features_hash'],
                        processing_time_ms=row['processing_time_ms'],
                        memory_usage_mb=row['memory_usage_mb'],
                        checkpoint_id=row['checkpoint_id'],
                        metadata=json.loads(row['metadata']) if row['metadata'] else None
                    ))
                return records
        except Exception as e:
            logger.error(f"Failed to get recent inferences: {e}")
            return []

    def update_model_performance(self, model_name: str, date: str,
                                 total_predictions: int, correct_predictions: int,
                                 avg_confidence: float, avg_processing_time: float) -> bool:
        """Update daily model performance statistics"""
        try:
            accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0.0

            with self._get_connection() as conn:
                conn.execute("""
                    INSERT OR REPLACE INTO model_performance (
                        model_name, date, total_predictions, correct_predictions,
                        accuracy, avg_confidence, avg_processing_time_ms
                    ) VALUES (?, ?, ?, ?, ?, ?, ?)
                """, (
                    model_name, date, total_predictions, correct_predictions,
                    accuracy, avg_confidence, avg_processing_time
                ))
                conn.commit()
                return True
        except Exception as e:
            logger.error(f"Failed to update model performance: {e}")
            return False

    def cleanup_old_records(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference records"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_to_keep)

            with self._get_connection() as conn:
                cursor = conn.execute("""
                    DELETE FROM inference_records
                    WHERE timestamp < ?
                """, (cutoff_date.isoformat(),))

                deleted_count = cursor.rowcount
                conn.commit()

                if deleted_count > 0:
                    logger.info(f"Cleaned up {deleted_count} old inference records")

                return True
        except Exception as e:
            logger.error(f"Failed to cleanup old records: {e}")
            return False

# Global database manager instance
_db_manager_instance = None

def get_database_manager(db_path: str = "data/trading_system.db") -> DatabaseManager:
    """Get the global database manager instance"""
    global _db_manager_instance

    if _db_manager_instance is None:
        _db_manager_instance = DatabaseManager(db_path)

    return _db_manager_instance
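Usage note (not part of the diff): a minimal sketch of exercising the DatabaseManager above, assuming the module is importable as utils.database_manager; the checkpoint ID, model name, metrics, and file path are illustrative placeholders, not values from this repository.

from datetime import datetime
from utils.database_manager import CheckpointMetadata, get_database_manager  # assumed package path

db = get_database_manager("data/trading_system.db")

# Register a hypothetical checkpoint and mark it active for an example model
meta = CheckpointMetadata(
    checkpoint_id="example_20250801_000000",                 # illustrative ID
    model_name="example_model",                              # illustrative model name
    model_type="example_type",                               # illustrative type
    timestamp=datetime.now(),
    performance_metrics={"accuracy": 0.62, "loss": 0.41},    # illustrative metrics
    training_metadata={"epochs": 10},                        # illustrative metadata
    file_path="models/example_model/example_20250801_000000.pt",  # illustrative path
    file_size_mb=0.07,
    is_active=True,
)
db.save_checkpoint_metadata(meta)

# Later, read metadata without loading any model weights
active = db.get_checkpoint_metadata("example_model")
best = db.get_best_checkpoint_metadata("example_model", metric_name="accuracy")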
226
utils/inference_logger.py
Normal file
@ -0,0 +1,226 @@
"""
Inference Logger

Centralized logging system for model inferences with database storage
Eliminates scattered logging throughout the codebase
"""

import time
import hashlib
import logging
import psutil
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass
import numpy as np

from .database_manager import get_database_manager, InferenceRecord
from .text_logger import get_text_logger

logger = logging.getLogger(__name__)

class InferenceLogger:
    """Centralized inference logging system"""

    def __init__(self):
        self.db_manager = get_database_manager()
        self.text_logger = get_text_logger()
        self._process = psutil.Process()

    def log_inference(self,
                      model_name: str,
                      symbol: str,
                      action: str,
                      confidence: float,
                      probabilities: Dict[str, float],
                      input_features: Union[np.ndarray, Dict, List],
                      processing_time_ms: float,
                      checkpoint_id: Optional[str] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> bool:
        """
        Log a model inference with all relevant details

        Args:
            model_name: Name of the model making the prediction
            symbol: Trading symbol
            action: Predicted action (BUY/SELL/HOLD)
            confidence: Confidence score (0.0 to 1.0)
            probabilities: Action probabilities dict
            input_features: Input features used for prediction
            processing_time_ms: Time taken for inference in milliseconds
            checkpoint_id: ID of the checkpoint used
            metadata: Additional metadata

        Returns:
            bool: True if logged successfully
        """
        try:
            # Create feature hash for deduplication
            feature_hash = self._hash_features(input_features)

            # Get current memory usage
            memory_usage_mb = self._get_memory_usage()

            # Create inference record
            record = InferenceRecord(
                model_name=model_name,
                timestamp=datetime.now(),
                symbol=symbol,
                action=action,
                confidence=confidence,
                probabilities=probabilities,
                input_features_hash=feature_hash,
                processing_time_ms=processing_time_ms,
                memory_usage_mb=memory_usage_mb,
                checkpoint_id=checkpoint_id,
                metadata=metadata
            )

            # Log to database
            db_success = self.db_manager.log_inference(record)

            # Log to text file
            text_success = self.text_logger.log_inference(
                model_name=model_name,
                symbol=symbol,
                action=action,
                confidence=confidence,
                processing_time_ms=processing_time_ms,
                checkpoint_id=checkpoint_id
            )

            if db_success:
                # Reduced logging - no more scattered logs at runtime
                pass  # Database logging successful, text file provides human-readable record
            else:
                logger.error(f"Failed to log inference for {model_name}")

            return db_success and text_success

        except Exception as e:
            logger.error(f"Error logging inference: {e}")
            return False

    def _hash_features(self, features: Union[np.ndarray, Dict, List]) -> str:
        """Create a hash of input features for deduplication"""
        try:
            if isinstance(features, np.ndarray):
                # Hash numpy array
                return hashlib.md5(features.tobytes()).hexdigest()[:16]
            elif isinstance(features, (dict, list)):
                # Hash dict or list by converting to string
                feature_str = str(sorted(features.items()) if isinstance(features, dict) else features)
                return hashlib.md5(feature_str.encode()).hexdigest()[:16]
            else:
                # Hash string representation
                return hashlib.md5(str(features).encode()).hexdigest()[:16]
        except Exception:
            # Fallback to timestamp-based hash
            return hashlib.md5(str(time.time()).encode()).hexdigest()[:16]

    def _get_memory_usage(self) -> float:
        """Get current memory usage in MB"""
        try:
            return self._process.memory_info().rss / (1024 * 1024)
        except Exception:
            return 0.0

    def get_model_stats(self, model_name: str, hours: int = 24) -> Dict[str, Any]:
        """Get inference statistics for a model"""
        try:
            # Get recent inferences
            recent_inferences = self.db_manager.get_recent_inferences(model_name, limit=1000)

            if not recent_inferences:
                return {
                    'total_inferences': 0,
                    'avg_confidence': 0.0,
                    'avg_processing_time_ms': 0.0,
                    'action_distribution': {},
                    'symbol_distribution': {}
                }

            # Filter by time window
            cutoff_time = datetime.now() - timedelta(hours=hours)
            recent_inferences = [r for r in recent_inferences if r.timestamp >= cutoff_time]

            if not recent_inferences:
                return {
                    'total_inferences': 0,
                    'avg_confidence': 0.0,
                    'avg_processing_time_ms': 0.0,
                    'action_distribution': {},
                    'symbol_distribution': {}
                }

            # Calculate statistics
            total_inferences = len(recent_inferences)
            avg_confidence = sum(r.confidence for r in recent_inferences) / total_inferences
            avg_processing_time = sum(r.processing_time_ms for r in recent_inferences) / total_inferences

            # Action distribution
            action_counts = {}
            for record in recent_inferences:
                action_counts[record.action] = action_counts.get(record.action, 0) + 1

            # Symbol distribution
            symbol_counts = {}
            for record in recent_inferences:
                symbol_counts[record.symbol] = symbol_counts.get(record.symbol, 0) + 1

            return {
                'total_inferences': total_inferences,
                'avg_confidence': avg_confidence,
                'avg_processing_time_ms': avg_processing_time,
                'action_distribution': action_counts,
                'symbol_distribution': symbol_counts,
                'latest_inference': recent_inferences[0].timestamp.isoformat() if recent_inferences else None
            }

        except Exception as e:
            logger.error(f"Error getting model stats: {e}")
            return {}

    def cleanup_old_logs(self, days_to_keep: int = 30) -> bool:
        """Clean up old inference logs"""
        return self.db_manager.cleanup_old_records(days_to_keep)

# Global inference logger instance
_inference_logger_instance = None

def get_inference_logger() -> InferenceLogger:
    """Get the global inference logger instance"""
    global _inference_logger_instance

    if _inference_logger_instance is None:
        _inference_logger_instance = InferenceLogger()

    return _inference_logger_instance

def log_model_inference(model_name: str,
                        symbol: str,
                        action: str,
                        confidence: float,
                        probabilities: Dict[str, float],
                        input_features: Union[np.ndarray, Dict, List],
                        processing_time_ms: float,
                        checkpoint_id: Optional[str] = None,
                        metadata: Optional[Dict[str, Any]] = None) -> bool:
    """
    Convenience function to log model inference

    This is the main function that should be called throughout the codebase
    instead of scattered logger.info() calls
    """
    inference_logger = get_inference_logger()
    return inference_logger.log_inference(
        model_name=model_name,
        symbol=symbol,
        action=action,
        confidence=confidence,
        probabilities=probabilities,
        input_features=input_features,
        processing_time_ms=processing_time_ms,
        checkpoint_id=checkpoint_id,
        metadata=metadata
    )
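Usage note (not part of the diff): a minimal sketch of calling log_model_inference from a model's predict path, assuming the module is importable as utils.inference_logger; the model name, symbol, features, probabilities, and timing below are illustrative placeholders.

import time
import numpy as np
from utils.inference_logger import log_model_inference  # assumed package path

features = np.random.rand(64).astype(np.float32)  # illustrative input features
start = time.time()
# ... run the actual model here; the action/confidence below are placeholders ...
action, confidence = "BUY", 0.71
elapsed_ms = (time.time() - start) * 1000.0

log_model_inference(
    model_name="example_model",        # illustrative model name
    symbol="ETH/USDT",                 # illustrative symbol
    action=action,
    confidence=confidence,
    probabilities={"BUY": 0.71, "SELL": 0.09, "HOLD": 0.20},
    input_features=features,
    processing_time_ms=elapsed_ms,
    checkpoint_id=None,                # or the active checkpoint's ID
    metadata={"source": "usage_example"},
)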
156
utils/text_logger.py
Normal file
@ -0,0 +1,156 @@
"""
Text File Logger for Trading System

Simple text file logging for tracking inference records and system events
Provides human-readable logs alongside database storage
"""

import os
import logging
from datetime import datetime
from typing import Dict, Any, Optional
from pathlib import Path

logger = logging.getLogger(__name__)

class TextLogger:
    """Simple text file logger for trading system events"""

    def __init__(self, log_dir: str = "logs"):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)  # Create intermediate directories if needed

        # Create separate log files for different types of events
        self.inference_log = self.log_dir / "inference_records.txt"
        self.checkpoint_log = self.log_dir / "checkpoint_events.txt"
        self.system_log = self.log_dir / "system_events.txt"

    def log_inference(self, model_name: str, symbol: str, action: str,
                      confidence: float, processing_time_ms: float,
                      checkpoint_id: Optional[str] = None) -> bool:
        """Log inference record to text file"""
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            checkpoint_info = f" [checkpoint: {checkpoint_id}]" if checkpoint_id else ""

            log_entry = (
                f"{timestamp} | {model_name:15} | {symbol:10} | "
                f"{action:4} | conf={confidence:.3f} | "
                f"time={processing_time_ms:6.1f}ms{checkpoint_info}\n"
            )

            with open(self.inference_log, 'a', encoding='utf-8') as f:
                f.write(log_entry)
                f.flush()

            return True

        except Exception as e:
            logger.error(f"Failed to log inference to text file: {e}")
            return False

    def log_checkpoint_event(self, model_name: str, event_type: str,
                             checkpoint_id: str, details: str = "") -> bool:
        """Log checkpoint events (save, load, etc.)"""
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            details_str = f" - {details}" if details else ""

            log_entry = (
                f"{timestamp} | {model_name:15} | {event_type:10} | "
                f"{checkpoint_id}{details_str}\n"
            )

            with open(self.checkpoint_log, 'a', encoding='utf-8') as f:
                f.write(log_entry)
                f.flush()

            return True

        except Exception as e:
            logger.error(f"Failed to log checkpoint event to text file: {e}")
            return False

    def log_system_event(self, event_type: str, message: str,
                         component: str = "system") -> bool:
        """Log general system events"""
        try:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

            log_entry = (
                f"{timestamp} | {component:15} | {event_type:10} | {message}\n"
            )

            with open(self.system_log, 'a', encoding='utf-8') as f:
                f.write(log_entry)
                f.flush()

            return True

        except Exception as e:
            logger.error(f"Failed to log system event to text file: {e}")
            return False

    def get_recent_inferences(self, lines: int = 50) -> str:
        """Get recent inference records from text file"""
        try:
            if not self.inference_log.exists():
                return "No inference records found"

            with open(self.inference_log, 'r', encoding='utf-8') as f:
                all_lines = f.readlines()
                recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
                return ''.join(recent_lines)

        except Exception as e:
            logger.error(f"Failed to read inference log: {e}")
            return f"Error reading log: {e}"

    def get_recent_checkpoint_events(self, lines: int = 20) -> str:
        """Get recent checkpoint events from text file"""
        try:
            if not self.checkpoint_log.exists():
                return "No checkpoint events found"

            with open(self.checkpoint_log, 'r', encoding='utf-8') as f:
                all_lines = f.readlines()
                recent_lines = all_lines[-lines:] if len(all_lines) > lines else all_lines
                return ''.join(recent_lines)

        except Exception as e:
            logger.error(f"Failed to read checkpoint log: {e}")
            return f"Error reading log: {e}"

    def cleanup_old_logs(self, max_lines: int = 10000) -> bool:
        """Keep only the most recent log entries"""
        try:
            for log_file in [self.inference_log, self.checkpoint_log, self.system_log]:
                if log_file.exists():
                    with open(log_file, 'r', encoding='utf-8') as f:
                        lines = f.readlines()

                    if len(lines) > max_lines:
                        # Keep only the most recent lines
                        recent_lines = lines[-max_lines:]
                        with open(log_file, 'w', encoding='utf-8') as f:
                            f.writelines(recent_lines)

                        logger.info(f"Cleaned up {log_file.name}: kept {len(recent_lines)} lines")

            return True

        except Exception as e:
            logger.error(f"Failed to cleanup logs: {e}")
            return False

# Global text logger instance
_text_logger_instance = None

def get_text_logger(log_dir: str = "logs") -> TextLogger:
    """Get the global text logger instance"""
    global _text_logger_instance

    if _text_logger_instance is None:
        _text_logger_instance = TextLogger(log_dir)

    return _text_logger_instance
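Usage note (not part of the diff): a minimal sketch of using the text logger on its own, assuming the module is importable as utils.text_logger; the component, model name, and checkpoint ID are illustrative placeholders.

from utils.text_logger import get_text_logger  # assumed package path

text_logger = get_text_logger("logs")
text_logger.log_system_event("startup", "trading system initialized", component="example")
text_logger.log_checkpoint_event("example_model", "save", "example_20250801_000000", details="illustrative")
print(text_logger.get_recent_inferences(lines=20))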