190 Commits

Author SHA1 Message Date
d0b678015e mark tasks 2025-08-04 23:50:37 +03:00
3c7d13416f Binance (completed in earlier tasks)
 Coinbase Pro (completed in task 12)
 Kraken (completed in task 12)
 Bybit (completed in task 13)
 OKX (completed in task 13)
 Huobi (completed in task 13)
 KuCoin (completed in this task)
 Gate.io (completed in this task)
 Bitfinex (completed in this task)
 MEXC (completed in this task)
2025-08-04 23:49:08 +03:00
7339972eab Binance (completed previously)
 Coinbase Pro (completed in task 12)
 Kraken (completed in task 12)
 Bybit (completed in this task)
 OKX (completed in this task)
 Huobi (completed in this task)
2025-08-04 23:41:42 +03:00
4170553cf3 Binance (completed previously)
 Coinbase Pro (completed in this task)
 Kraken (completed in this task)
2025-08-04 23:21:21 +03:00
68a556e09c orchestrator adaptor 2025-08-04 23:13:44 +03:00
1479ac1624 replay system 2025-08-04 22:46:11 +03:00
db61f3c3bf storage manager 2025-08-04 21:50:11 +03:00
42cf02cf3a web dash 2025-08-04 20:28:48 +03:00
fd6ec4eb40 api 2025-08-04 18:38:51 +03:00
ff75af566c caching 2025-08-04 17:55:00 +03:00
8ee9b7a90c wip 2025-08-04 17:40:30 +03:00
de77b0afa8 bucket aggregation 2025-08-04 17:28:55 +03:00
504736c0f7 cob integration scaffold 2025-08-04 17:12:26 +03:00
de9fa4a421 COBY : specs + task 1 2025-08-04 15:50:54 +03:00
e223bc90e9 inference_enabled, cleanup 2025-08-04 14:24:39 +03:00
29382ac0db price vector predictions 2025-07-29 23:45:57 +03:00
3fad2caeb8 decision model card 2025-07-29 23:42:46 +03:00
a204362df2 model cards back 2025-07-29 23:14:00 +03:00
ab5784b890 normalize by unified price range 2025-07-29 22:05:28 +03:00
aa2a1bf7ee fixed CNN training 2025-07-29 20:11:22 +03:00
b1ae557843 models overhaul 2025-07-29 19:22:04 +03:00
0b5fa07498 ui fixes 2025-07-29 19:02:44 +03:00
ac4068c168 suppress_callback_exceptions 2025-07-29 18:20:07 +03:00
5f7032937e UI dash fix 2025-07-29 17:49:25 +03:00
3a532a1220 PnL in reward, show leveraged power in dash (broken) 2025-07-29 17:42:00 +03:00
d35530a9e9 win uni toggle 2025-07-29 16:10:45 +03:00
ecbbabc0c1 inf/trn toggles UI 2025-07-29 15:51:18 +03:00
ff41f0a278 training wip 2025-07-29 15:25:36 +03:00
b3e3a7673f TZ wip, UI model stats fix 2025-07-29 15:12:48 +03:00
afde58bc40 wip model CP storage/loading,
models are aware of current position
fix kill stale procc task
2025-07-29 14:51:40 +03:00
f34b2a46a2 better decision details 2025-07-29 09:49:09 +03:00
e2ededcdf0 fuse decision fusion 2025-07-29 09:09:11 +03:00
f4ac504963 fix model toggle 2025-07-29 00:52:58 +03:00
b44216ae1e UI: fix models info 2025-07-29 00:46:16 +03:00
aefc460082 wip dqn state 2025-07-29 00:25:31 +03:00
ea4db519de more info at signals 2025-07-29 00:20:07 +03:00
e1e453c204 dqn model data fix 2025-07-29 00:09:13 +03:00
548c0d5e0f ui state, models toggle 2025-07-28 23:49:47 +03:00
a341fade80 wip 2025-07-28 22:09:15 +03:00
bc4b72c6de add decision fusion. training but not enabled.
reports cleanup
2025-07-28 18:22:13 +03:00
233bb9935c fixed trading and leverage 2025-07-28 16:57:02 +03:00
db23ad10da trading risk management 2025-07-28 16:42:11 +03:00
44821b2a89 UI and stability 2025-07-28 14:05:37 +03:00
25b2d3840a ui fix 2025-07-28 12:15:26 +03:00
fb72c93743 stability 2025-07-28 12:10:52 +03:00
9219b78241 UI 2025-07-28 11:44:01 +03:00
7c508ab536 cob 2025-07-28 11:12:42 +03:00
1084b7f5b5 cob buffered 2025-07-28 10:31:24 +03:00
619e39ac9b binance WS api enhanced 2025-07-28 10:26:47 +03:00
f5416c4f1e cob update fix 2025-07-28 09:46:49 +03:00
240d2b7877 stats, standartized data provider 2025-07-28 08:35:08 +03:00
6efaa27c33 dix price ccalls 2025-07-28 00:14:03 +03:00
b4076241c9 training wip 2025-07-27 23:45:57 +03:00
39267697f3 predict price direction 2025-07-27 23:20:47 +03:00
dfa18035f1 untrack sqlite 2025-07-27 22:46:19 +03:00
368c49df50 device fix , TZ fix 2025-07-27 22:13:28 +03:00
9e1684f9f8 cb ws 2025-07-27 20:56:37 +03:00
bd986f4534 beef up DQN model, fix training issues 2025-07-27 20:48:44 +03:00
1894d453c9 timezones 2025-07-27 20:43:28 +03:00
1636082ba3 CNN adapter retired 2025-07-27 20:38:04 +03:00
d333681447 wip train 2025-07-27 20:34:51 +03:00
ff66cb8b79 fix TA warning 2025-07-27 20:11:37 +03:00
64dbfa3780 training fix 2025-07-27 20:08:33 +03:00
86373fd5a7 training 2025-07-27 19:45:16 +03:00
87c0dc8ac4 wip training and inference stats 2025-07-27 19:20:23 +03:00
2a21878ed5 wip training 2025-07-27 19:07:34 +03:00
e2c495d83c cleanup 2025-07-27 18:31:30 +03:00
a94b80c1f4 decouple external API and local data consumption 2025-07-27 17:28:07 +03:00
fec6acb783 wip UI clear session 2025-07-27 17:21:16 +03:00
74e98709ad stats 2025-07-27 00:31:50 +03:00
13155197f8 inference works 2025-07-27 00:24:32 +03:00
36a8e256a8 fix DQN RL inference, rebuild model 2025-07-26 23:57:03 +03:00
87942d3807 cleanup and removed dummy data 2025-07-26 23:35:14 +03:00
3eb6335169 inrefence predictions fix 2025-07-26 23:34:36 +03:00
7c61c12b70 stability fixes, lower updates 2025-07-26 22:32:45 +03:00
9576c52039 optimize updates, remove fifo for simple cache 2025-07-26 22:17:29 +03:00
c349ff6f30 fifo n1 que 2025-07-26 21:34:16 +03:00
a3828c708c fix netwrk rebuild 2025-07-25 23:59:51 +03:00
43ed694917 fix checkpoints wip 2025-07-25 23:59:28 +03:00
50c6dae485 UI 2025-07-25 23:37:34 +03:00
22524b0389 cache fix 2025-07-25 22:46:23 +03:00
dd9f4b63ba sqlite for checkpoints, cleanup 2025-07-25 22:34:13 +03:00
130a52fb9b improved reward/penalty 2025-07-25 14:15:43 +03:00
26eeb9b35b ACTUAL TRAINING WORKING (WIP) 2025-07-25 14:08:25 +03:00
1f60c80d67 device tensor fix 2025-07-25 13:59:33 +03:00
78b4bb0f06 wip, training still disabled 2025-07-24 16:20:37 +03:00
045780758a wip symbols tidy up 2025-07-24 16:08:58 +03:00
d17af5ca4b inference data storage 2025-07-24 15:31:57 +03:00
fa07265a16 wip training 2025-07-24 15:27:32 +03:00
b3edd21f1b cnn training stats on dash 2025-07-24 14:28:28 +03:00
5437495003 wip cnn training and cob 2025-07-23 23:33:36 +03:00
8677c4c01c cob wip 2025-07-23 23:10:54 +03:00
8ba52640bd wip cob test 2025-07-23 22:56:28 +03:00
4765b1b1e1 cob data providers tests 2025-07-23 22:49:54 +03:00
c30267bf0b COB tests and data analysis 2025-07-23 22:39:10 +03:00
94ee7389c4 CNN training first working 2025-07-23 22:39:00 +03:00
26e6ba2e1d integrate CNN, fix COB data 2025-07-23 22:12:10 +03:00
45a62443a0 checkpoint manager 2025-07-23 22:11:19 +03:00
bab39fa68f dash inference fixes 2025-07-23 17:37:11 +03:00
2a0f8f5199 integratoin fixes - COB and CNN 2025-07-23 17:33:43 +03:00
f1d63f9da6 integrating new CNN model 2025-07-23 16:59:35 +03:00
1be270cc5c using new data probider and StandardizedCNN 2025-07-23 16:27:16 +03:00
735ee255bc new cnn model 2025-07-23 16:13:41 +03:00
dbb918ea92 wip 2025-07-23 15:52:40 +03:00
2b3c6abdeb refine design 2025-07-23 15:00:08 +03:00
55ea3bce93 feat: Добавяне на подобрена реализация на оркестратора съгласно изискванията в дизайнерския документ
Co-authored-by: aider (openai/Qwen/Qwen3-Coder-480B-A35B-Instruct) <aider@aider.chat>
2025-07-23 14:08:27 +03:00
56b35bd362 more design 2025-07-23 13:48:31 +03:00
f759eac04b updated design 2025-07-23 13:39:50 +03:00
df17a99247 wip 2025-07-23 13:39:41 +03:00
944a7b79e6 aider 2025-07-23 13:09:19 +03:00
8ad153aab5 aider 2025-07-23 11:23:15 +03:00
f515035ea0 use hyperbolic direactly instead of openrouter 2025-07-23 11:15:31 +03:00
3914ba40cf aider openrouter 2025-07-23 11:08:41 +03:00
7c8f52c07a aider 2025-07-23 10:28:19 +03:00
b0bc6c2a65 misc 2025-07-23 10:17:09 +03:00
630bc644fa wip 2025-07-22 20:23:17 +03:00
9b72b18eb7 references 2025-07-22 16:53:36 +03:00
1d224e5b8c references 2025-07-22 16:28:16 +03:00
a68df64b83 code structure 2025-07-22 16:23:13 +03:00
cc0c783411 cp man 2025-07-22 16:13:42 +03:00
c63dc11c14 cleanup 2025-07-22 16:08:58 +03:00
1a54fb1d56 fix model mappings,dash updates, trading 2025-07-22 15:44:59 +03:00
3e35b9cddb leverage calc fix 2025-07-20 22:41:37 +03:00
0838a828ce refactoring cob ws 2025-07-20 21:23:27 +03:00
330f0de053 COB WS fix 2025-07-20 20:38:42 +03:00
9c56ea238e dynamic profitabiliy reward 2025-07-20 18:08:37 +03:00
a2c07a1f3e dash working 2025-07-20 14:27:11 +03:00
0bb4409c30 fix syntax 2025-07-20 12:39:34 +03:00
12865fd3ef replay system 2025-07-20 12:37:02 +03:00
469269e809 working with errors 2025-07-20 01:52:36 +03:00
92919cb1ef adjust weights 2025-07-17 21:50:27 +03:00
23f0caea74 safety measures - 5 consequtive losses 2025-07-17 21:06:49 +03:00
26d440f772 artificially doule fees to promote more profitable trades 2025-07-17 19:22:35 +03:00
6d55061e86 wip training 2025-07-17 02:51:20 +03:00
c3010a6737 dash fixes 2025-07-17 02:25:52 +03:00
6b9482d2be pivots 2025-07-17 02:15:24 +03:00
b4e592b406 kiro tasks 2025-07-17 01:02:16 +03:00
f73cd17dfc kiro design and requirements 2025-07-17 00:57:50 +03:00
8023dae18f wip 2025-07-15 11:12:30 +03:00
e586d850f1 trading sim agin while training 2025-07-15 03:04:34 +03:00
0b07825be0 limit max positions 2025-07-15 02:27:33 +03:00
439611cf88 trading works! 2025-07-15 01:10:37 +03:00
24230f7f79 leverae tweak 2025-07-15 00:51:42 +03:00
154fa75c93 revert broken changes - indentations 2025-07-15 00:39:26 +03:00
a7905ce4e9 test bybit opening/closing orders 2025-07-15 00:03:59 +03:00
5b2dd3b0b8 bybit ballance working 2025-07-14 23:20:01 +03:00
02804ee64f bybit REST api 2025-07-14 22:57:02 +03:00
ee2e6478d8 bybit 2025-07-14 22:23:27 +03:00
4a55c5ff03 deribit 2025-07-14 17:56:09 +03:00
d53a2ba75d live position sync for LIMIT orders 2025-07-14 14:50:30 +03:00
f861559319 work with order execution - we are forced to do limit orders over the API 2025-07-14 13:36:07 +03:00
d7205a9745 lock with timeout 2025-07-14 13:03:42 +03:00
ab232a1262 in the bussiness -but wip 2025-07-14 12:58:16 +03:00
c651ae585a mexc debug files 2025-07-14 12:32:06 +03:00
0c54899fef MEXC INTEGRATION WORKS!!! 2025-07-14 11:23:13 +03:00
d42c9ada8c mexc interface integrations REST API fixes 2025-07-14 11:15:11 +03:00
e74f1393c4 training fixes and enhancements wip 2025-07-14 10:00:42 +03:00
e76b1b16dc training fixes 2025-07-14 00:47:44 +03:00
ebf65494a8 try to fix input dimentions 2025-07-13 23:41:47 +03:00
bcc13a5db3 training wip 2025-07-13 11:29:01 +03:00
2d8f763eeb improve training and model data 2025-07-07 15:48:25 +03:00
271e7d59b5 fixed cob 2025-07-07 01:44:16 +03:00
c2c0e12a4b behaviour/agressiveness sliders, fix cob data using provider 2025-07-07 01:37:04 +03:00
9101448e78 cleanup, cob ladder still broken 2025-07-07 01:07:48 +03:00
97d9bc97ee ETS integration and UI 2025-07-05 00:33:32 +03:00
d260e73f9a integration of (legacy) training systems, initialize, train, show on the UI 2025-07-05 00:33:03 +03:00
5ca7493708 cleanup, CNN fixes 2025-07-05 00:12:40 +03:00
ce8c00a9d1 remove dummy data, improve training , follow architecture 2025-07-04 23:51:35 +03:00
e8b9c05148 risk managment 2025-07-04 20:52:40 +03:00
ed42e7c238 execution and training fixes 2025-07-04 20:45:39 +03:00
0c4c682498 improve orchestrator 2025-07-04 02:26:38 +03:00
d0cf04536c fix dash actions 2025-07-04 02:24:18 +03:00
cf91e090c8 i think we fixed mexc interface at the end!!! 2025-07-04 02:14:29 +03:00
978cecf0c5 fix indentations 2025-07-03 03:03:35 +03:00
8bacf3c537 capcha and credentials stored in json. test intgration 2025-07-03 02:59:21 +03:00
ab73f95a3f capturing capcha tokens 2025-07-03 02:31:01 +03:00
09ed86c8ae capture more capcha info 2025-07-03 02:20:21 +03:00
e4a611a0cc selenium session, captcha 2025-07-03 02:06:09 +03:00
936ccf10e6 try to improve captcha support 2025-07-03 01:23:00 +03:00
5bd5c9f14d mexc webclient captcha debug 2025-07-03 01:20:38 +03:00
118c34b990 mexc API failed, working on futures API as it what i we need anyway 2025-07-03 00:56:02 +03:00
568ec049db Best checkpoint file not found 2025-07-03 00:44:31 +03:00
d15ebf54ca improve training on signals, add save session button to store all progress 2025-07-02 10:59:13 +03:00
488fbacf67 show each model's prediction (last inference) and store T model checkpoint 2025-07-02 09:52:45 +03:00
b47805dafc cob signas 2025-07-02 03:31:37 +03:00
11718bf92f loss /performance display 2025-07-02 03:29:38 +03:00
29e4076638 template dash using real integrations (wip) 2025-07-02 03:05:11 +03:00
03573cfb56 Fix templated dashboard Dash compatibility and change port to 8052\n\n- Fixed html.Style compatibility issue by removing custom CSS for now\n- Fixed app.run_server() deprecation by changing to app.run()\n- Changed default port from 8051 to 8052 to avoid conflicts\n- Templated dashboard now starts successfully on port 8052\n- Template-based MVC architecture is fully functional\n- Demonstrates clean separation of HTML templates and Python logic 2025-07-02 02:09:49 +03:00
083c1272ae Fix templated dashboard Dash import compatibility\n\n- Fixed obsolete dash_html_components import in template_renderer.py\n- Changed from 'import dash_html_components as html' to 'from dash import html, dcc'\n- Templated dashboard now starts successfully on port 8051\n- Compatible with modern Dash versions where html/dcc components are in dash package\n- Template-based MVC architecture is now fully functional 2025-07-02 02:04:45 +03:00
b9159690ef Fix COB ladder bucket sizes: ETH uses buckets, BTC uses buckets
- Fixed hardcoded bucket_size = 10 in component_manager.py
- Now uses symbol-specific bucket sizes: ETH = , BTC =
- Matches the COB provider configuration and launch.json settings
- ETH/USDT will now show proper  price granularity in dashboard
- BTC/USDT continues to use  buckets as intended
2025-07-02 01:59:54 +03:00
478 changed files with 95890 additions and 54410 deletions

25
.aider.conf.yml Normal file
View File

@ -0,0 +1,25 @@
# Aider configuration file
# For more information, see: https://aider.chat/docs/config/aider_conf.html
# Configure for Hyperbolic API (OpenAI-compatible endpoint)
# hyperbolic
model: openai/Qwen/Qwen3-Coder-480B-A35B-Instruct
openai-api-base: https://api.hyperbolic.xyz/v1
openai-api-key: "eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJkb2Jyb21pci5wb3BvdkB5YWhvby5jb20iLCJpYXQiOjE3NTMyMzE0MjZ9.fCbv2pUmDO9xxjVqfSKru4yz1vtrNvuGIXHibWZWInE"
# setx OPENAI_API_BASE https://api.hyperbolic.xyz/v1
# setx OPENAI_API_KEY eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJkb2Jyb21pci5wb3BvdkB5YWhvby5jb20iLCJpYXQiOjE3NTMyMzE0MjZ9.fCbv2pUmDO9xxjVqfSKru4yz1vtrNvuGIXHibWZWInE
# Environment variables for litellm to recognize Hyperbolic provider
set-env:
#setx HYPERBOLIC_API_KEY eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJkb2Jyb21pci5wb3BvdkB5YWhvby5jb20iLCJpYXQiOjE3NTMyMzE0MjZ9.fCbv2pUmDO9xxjVqfSKru4yz1vtrNvuGIXHibWZWInE
- HYPERBOLIC_API_KEY=eyJhbGciOiJIUzI1NiJ9.eyJzdWIiOiJkb2Jyb21pci5wb3BvdkB5YWhvby5jb20iLCJpYXQiOjE3NTMyMzE0MjZ9.fCbv2pUmDO9xxjVqfSKru4yz1vtrNvuGIXHibWZWInE
# - HYPERBOLIC_API_BASE=https://api.hyperbolic.xyz/v1
# Set encoding to UTF-8 (default)
encoding: utf-8
gitignore: false
# The metadata file is still needed to inform aider about the
# context window and costs for this custom model.
model-metadata-file: .aider.model.metadata.json

View File

@ -0,0 +1,7 @@
{
"hyperbolic/Qwen/Qwen3-Coder-480B-A35B-Instruct": {
"context_window": 262144,
"input_cost_per_token": 0.000002,
"output_cost_per_token": 0.000002
}
}

7
.env
View File

@ -1,6 +1,11 @@
# MEXC API Configuration (Spot Trading)
MEXC_API_KEY=mx0vglhVPZeIJ32Qw1
MEXC_SECRET_KEY=3bfe4bd99d5541e4a1bca87ab257cc7e
MEXC_SECRET_KEY=3bfe4bd99d5541e4a1bca87ab257cc7e
DERBIT_API_CLIENTID=me1yf6K0
DERBIT_API_SECRET=PxdvEHmJ59FrguNVIt45-iUBj3lPXbmlA7OQUeINE9s
BYBIT_API_KEY=GQ50IkgZKkR3ljlbPx
BYBIT_API_SECRET=0GWpva5lYrhzsUqZCidQpO5TxYwaEmdiEDyc
#3bfe4bd99d5541e4a1bca87ab257cc7e 45d0b3c26f2644f19bfb98b07741b2f5
# BASE ENDPOINTS: https://api.mexc.com wss://wbs-api.mexc.com/ws !!! DO NOT CHANGE THIS

12
.gitignore vendored
View File

@ -16,7 +16,7 @@ models/trading_agent_final.pt.backup
*.pt
*.backup
logs/
trade_logs/
# trade_logs/
*.csv
cache/
realtime_chart.log
@ -41,3 +41,13 @@ closed_trades_history.json
data/cnn_training/cnn_training_data*
testcases/*
testcases/negative/case_index.json
chrome_user_data/*
.aider*
!.aider.conf.yml
!.aider.model.metadata.json
.env
.env
training_data/*
data/trading_system.db
/data/trading_system.db

View File

@ -0,0 +1,448 @@
# Design Document
## Overview
The Multi-Exchange Data Aggregation System is a comprehensive data collection and processing subsystem designed to serve as the foundational data layer for the trading orchestrator. The system will collect real-time order book and OHLCV data from the top 10 cryptocurrency exchanges, aggregate it into standardized formats, store it in a TimescaleDB time-series database, and provide both live data feeds and historical replay capabilities.
The system follows a microservices architecture with containerized components, ensuring scalability, maintainability, and seamless integration with the existing trading infrastructure.
We implement it in the `.\COBY` subfolder for easy integration with the existing system
## Architecture
### High-Level Architecture
```mermaid
graph TB
subgraph "Exchange Connectors"
E1[Binance WebSocket]
E2[Coinbase WebSocket]
E3[Kraken WebSocket]
E4[Bybit WebSocket]
E5[OKX WebSocket]
E6[Huobi WebSocket]
E7[KuCoin WebSocket]
E8[Gate.io WebSocket]
E9[Bitfinex WebSocket]
E10[MEXC WebSocket]
end
subgraph "Data Processing Layer"
DP[Data Processor]
AGG[Aggregation Engine]
NORM[Data Normalizer]
end
subgraph "Storage Layer"
TSDB[(TimescaleDB)]
CACHE[Redis Cache]
end
subgraph "API Layer"
LIVE[Live Data API]
REPLAY[Replay API]
WEB[Web Dashboard]
end
subgraph "Integration Layer"
ORCH[Orchestrator Interface]
ADAPTER[Data Adapter]
end
E1 --> DP
E2 --> DP
E3 --> DP
E4 --> DP
E5 --> DP
E6 --> DP
E7 --> DP
E8 --> DP
E9 --> DP
E10 --> DP
DP --> NORM
NORM --> AGG
AGG --> TSDB
AGG --> CACHE
CACHE --> LIVE
TSDB --> REPLAY
LIVE --> WEB
REPLAY --> WEB
LIVE --> ADAPTER
REPLAY --> ADAPTER
ADAPTER --> ORCH
```
### Component Architecture
The system is organized into several key components:
1. **Exchange Connectors**: WebSocket clients for each exchange
2. **Data Processing Engine**: Normalizes and validates incoming data
3. **Aggregation Engine**: Creates price buckets and heatmaps
4. **Storage Layer**: TimescaleDB for persistence, Redis for caching
5. **API Layer**: REST and WebSocket APIs for data access
6. **Web Dashboard**: Real-time visualization interface
7. **Integration Layer**: Orchestrator-compatible interface
## Components and Interfaces
### Exchange Connector Interface
```python
class ExchangeConnector:
"""Base interface for exchange WebSocket connectors"""
async def connect(self) -> bool
async def disconnect(self) -> None
async def subscribe_orderbook(self, symbol: str) -> None
async def subscribe_trades(self, symbol: str) -> None
def get_connection_status(self) -> ConnectionStatus
def add_data_callback(self, callback: Callable) -> None
```
### Data Processing Interface
```python
class DataProcessor:
"""Processes and normalizes raw exchange data"""
def normalize_orderbook(self, raw_data: Dict, exchange: str) -> OrderBookSnapshot
def normalize_trade(self, raw_data: Dict, exchange: str) -> TradeEvent
def validate_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> bool
def calculate_metrics(self, orderbook: OrderBookSnapshot) -> OrderBookMetrics
```
### Aggregation Engine Interface
```python
class AggregationEngine:
"""Aggregates data into price buckets and heatmaps"""
def create_price_buckets(self, orderbook: OrderBookSnapshot, bucket_size: float) -> PriceBuckets
def update_heatmap(self, symbol: str, buckets: PriceBuckets) -> HeatmapData
def calculate_imbalances(self, orderbook: OrderBookSnapshot) -> ImbalanceMetrics
def aggregate_across_exchanges(self, symbol: str) -> ConsolidatedOrderBook
```
### Storage Interface
```python
class StorageManager:
"""Manages data persistence and retrieval"""
async def store_orderbook(self, data: OrderBookSnapshot) -> bool
async def store_trade(self, data: TradeEvent) -> bool
async def get_historical_data(self, symbol: str, start: datetime, end: datetime) -> List[Dict]
async def get_latest_data(self, symbol: str) -> Dict
def setup_database_schema(self) -> None
```
### Replay Interface
```python
class ReplayManager:
"""Provides historical data replay functionality"""
def create_replay_session(self, start_time: datetime, end_time: datetime, speed: float) -> str
async def start_replay(self, session_id: str) -> None
async def pause_replay(self, session_id: str) -> None
async def stop_replay(self, session_id: str) -> None
def get_replay_status(self, session_id: str) -> ReplayStatus
```
## Data Models
### Core Data Structures
```python
@dataclass
class OrderBookSnapshot:
"""Standardized order book snapshot"""
symbol: str
exchange: str
timestamp: datetime
bids: List[PriceLevel]
asks: List[PriceLevel]
sequence_id: Optional[int] = None
@dataclass
class PriceLevel:
"""Individual price level in order book"""
price: float
size: float
count: Optional[int] = None
@dataclass
class TradeEvent:
"""Standardized trade event"""
symbol: str
exchange: str
timestamp: datetime
price: float
size: float
side: str # 'buy' or 'sell'
trade_id: str
@dataclass
class PriceBuckets:
"""Aggregated price buckets for heatmap"""
symbol: str
timestamp: datetime
bucket_size: float
bid_buckets: Dict[float, float] # price -> volume
ask_buckets: Dict[float, float] # price -> volume
@dataclass
class HeatmapData:
"""Heatmap visualization data"""
symbol: str
timestamp: datetime
bucket_size: float
data: List[HeatmapPoint]
@dataclass
class HeatmapPoint:
"""Individual heatmap data point"""
price: float
volume: float
intensity: float # 0.0 to 1.0
side: str # 'bid' or 'ask'
```
### Database Schema
#### TimescaleDB Tables
```sql
-- Order book snapshots table
CREATE TABLE order_book_snapshots (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bids JSONB NOT NULL,
asks JSONB NOT NULL,
sequence_id BIGINT,
mid_price DECIMAL(20,8),
spread DECIMAL(20,8),
bid_volume DECIMAL(30,8),
ask_volume DECIMAL(30,8),
PRIMARY KEY (timestamp, symbol, exchange)
);
-- Convert to hypertable
SELECT create_hypertable('order_book_snapshots', 'timestamp');
-- Trade events table
CREATE TABLE trade_events (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
price DECIMAL(20,8) NOT NULL,
size DECIMAL(30,8) NOT NULL,
side VARCHAR(4) NOT NULL,
trade_id VARCHAR(100) NOT NULL,
PRIMARY KEY (timestamp, symbol, exchange, trade_id)
);
-- Convert to hypertable
SELECT create_hypertable('trade_events', 'timestamp');
-- Aggregated heatmap data table
CREATE TABLE heatmap_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bucket_size DECIMAL(10,2) NOT NULL,
price_bucket DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
side VARCHAR(3) NOT NULL,
exchange_count INTEGER NOT NULL,
PRIMARY KEY (timestamp, symbol, bucket_size, price_bucket, side)
);
-- Convert to hypertable
SELECT create_hypertable('heatmap_data', 'timestamp');
-- OHLCV data table
CREATE TABLE ohlcv_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
timeframe VARCHAR(10) NOT NULL,
open_price DECIMAL(20,8) NOT NULL,
high_price DECIMAL(20,8) NOT NULL,
low_price DECIMAL(20,8) NOT NULL,
close_price DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
trade_count INTEGER,
PRIMARY KEY (timestamp, symbol, timeframe)
);
-- Convert to hypertable
SELECT create_hypertable('ohlcv_data', 'timestamp');
```
## Error Handling
### Connection Management
The system implements robust error handling for exchange connections:
1. **Exponential Backoff**: Failed connections retry with increasing delays
2. **Circuit Breaker**: Temporarily disable problematic exchanges
3. **Graceful Degradation**: Continue operation with available exchanges
4. **Health Monitoring**: Continuous monitoring of connection status
### Data Validation
All incoming data undergoes validation:
1. **Schema Validation**: Ensure data structure compliance
2. **Range Validation**: Check price and volume ranges
3. **Timestamp Validation**: Verify temporal consistency
4. **Duplicate Detection**: Prevent duplicate data storage
### Database Resilience
Database operations include comprehensive error handling:
1. **Connection Pooling**: Maintain multiple database connections
2. **Transaction Management**: Ensure data consistency
3. **Retry Logic**: Automatic retry for transient failures
4. **Backup Strategies**: Regular data backups and recovery procedures
## Testing Strategy
### Unit Testing
Each component will have comprehensive unit tests:
1. **Exchange Connectors**: Mock WebSocket responses
2. **Data Processing**: Test normalization and validation
3. **Aggregation Engine**: Verify bucket calculations
4. **Storage Layer**: Test database operations
5. **API Layer**: Test endpoint responses
### Integration Testing
End-to-end testing scenarios:
1. **Multi-Exchange Data Flow**: Test complete data pipeline
2. **Database Integration**: Verify TimescaleDB operations
3. **API Integration**: Test orchestrator interface compatibility
4. **Performance Testing**: Load testing with high-frequency data
### Performance Testing
Performance benchmarks and testing:
1. **Throughput Testing**: Measure data processing capacity
2. **Latency Testing**: Measure end-to-end data latency
3. **Memory Usage**: Monitor memory consumption patterns
4. **Database Performance**: Query performance optimization
### Monitoring and Observability
Comprehensive monitoring system:
1. **Metrics Collection**: Prometheus-compatible metrics
2. **Logging**: Structured logging with correlation IDs
3. **Alerting**: Real-time alerts for system issues
4. **Dashboards**: Grafana dashboards for system monitoring
## Deployment Architecture
### Docker Containerization
The system will be deployed using Docker containers:
```yaml
# docker-compose.yml
version: '3.8'
services:
timescaledb:
image: timescale/timescaledb:latest-pg14
environment:
POSTGRES_DB: market_data
POSTGRES_USER: market_user
POSTGRES_PASSWORD: ${DB_PASSWORD}
volumes:
- timescale_data:/var/lib/postgresql/data
ports:
- "5432:5432"
redis:
image: redis:7-alpine
ports:
- "6379:6379"
volumes:
- redis_data:/data
data-aggregator:
build: ./data-aggregator
environment:
- DB_HOST=timescaledb
- REDIS_HOST=redis
- LOG_LEVEL=INFO
depends_on:
- timescaledb
- redis
web-dashboard:
build: ./web-dashboard
ports:
- "8080:8080"
environment:
- API_HOST=data-aggregator
depends_on:
- data-aggregator
volumes:
timescale_data:
redis_data:
```
### Configuration Management
Environment-based configuration:
```python
# config.py
@dataclass
class Config:
# Database settings
db_host: str = os.getenv('DB_HOST', 'localhost')
db_port: int = int(os.getenv('DB_PORT', '5432'))
db_name: str = os.getenv('DB_NAME', 'market_data')
db_user: str = os.getenv('DB_USER', 'market_user')
db_password: str = os.getenv('DB_PASSWORD', '')
# Redis settings
redis_host: str = os.getenv('REDIS_HOST', 'localhost')
redis_port: int = int(os.getenv('REDIS_PORT', '6379'))
# Exchange settings
exchanges: List[str] = field(default_factory=lambda: [
'binance', 'coinbase', 'kraken', 'bybit', 'okx',
'huobi', 'kucoin', 'gateio', 'bitfinex', 'mexc'
])
# Aggregation settings
btc_bucket_size: float = 10.0 # $10 USD buckets for BTC
eth_bucket_size: float = 1.0 # $1 USD buckets for ETH
# Performance settings
max_connections_per_exchange: int = 5
data_buffer_size: int = 10000
batch_write_size: int = 1000
# API settings
api_host: str = os.getenv('API_HOST', '0.0.0.0')
api_port: int = int(os.getenv('API_PORT', '8080'))
websocket_port: int = int(os.getenv('WS_PORT', '8081'))
```
This design provides a robust, scalable foundation for multi-exchange data aggregation that seamlessly integrates with the existing trading orchestrator while providing the flexibility for future enhancements and additional exchange integrations.

View File

@ -0,0 +1,103 @@
# Requirements Document
## Introduction
This document outlines the requirements for a comprehensive data collection and aggregation subsystem that will serve as a foundational component for the trading orchestrator. The system will collect, aggregate, and store real-time order book and OHLCV data from multiple cryptocurrency exchanges, providing both live data feeds and historical replay capabilities for model training and backtesting.
## Requirements
### Requirement 1
**User Story:** As a trading system developer, I want to collect real-time order book data from top 10 cryptocurrency exchanges, so that I can have comprehensive market data for analysis and trading decisions.
#### Acceptance Criteria
1. WHEN the system starts THEN it SHALL establish WebSocket connections to up to 10 major cryptocurrency exchanges
2. WHEN order book updates are received THEN the system SHALL process and store raw order book events in real-time
3. WHEN processing order book data THEN the system SHALL handle connection failures gracefully and automatically reconnect
4. WHEN multiple exchanges provide data THEN the system SHALL normalize data formats to a consistent structure
5. IF an exchange connection fails THEN the system SHALL log the failure and attempt reconnection with exponential backoff
### Requirement 2
**User Story:** As a trading analyst, I want order book data aggregated into price buckets with heatmap visualization, so that I can quickly identify market depth and liquidity patterns.
#### Acceptance Criteria
1. WHEN processing BTC order book data THEN the system SHALL aggregate orders into $10 USD price range buckets
2. WHEN processing ETH order book data THEN the system SHALL aggregate orders into $1 USD price range buckets
3. WHEN aggregating order data THEN the system SHALL maintain separate bid and ask heatmaps
4. WHEN building heatmaps THEN the system SHALL update distribution data at high frequency (sub-second)
5. WHEN displaying heatmaps THEN the system SHALL show volume intensity using color gradients or progress bars
### Requirement 3
**User Story:** As a system architect, I want all market data stored in a TimescaleDB database, so that I can efficiently query time-series data and maintain historical records.
#### Acceptance Criteria
1. WHEN the system initializes THEN it SHALL connect to a TimescaleDB instance running in a Docker container
2. WHEN storing order book events THEN the system SHALL use TimescaleDB's time-series optimized storage
3. WHEN storing OHLCV data THEN the system SHALL create appropriate time-series tables with proper indexing
4. WHEN writing to database THEN the system SHALL batch writes for optimal performance
5. IF database connection fails THEN the system SHALL queue data in memory and retry with backoff strategy
### Requirement 4
**User Story:** As a trading system operator, I want a web-based dashboard to monitor real-time order book heatmaps, so that I can visualize market conditions across multiple exchanges.
#### Acceptance Criteria
1. WHEN accessing the web dashboard THEN it SHALL display real-time order book heatmaps for BTC and ETH
2. WHEN viewing heatmaps THEN the dashboard SHALL show aggregated data from all connected exchanges
3. WHEN displaying progress bars THEN they SHALL always show aggregated values across price buckets
4. WHEN updating the display THEN the dashboard SHALL refresh data at least once per second
5. WHEN an exchange goes offline THEN the dashboard SHALL indicate the status change visually
### Requirement 5
**User Story:** As a model trainer, I want a replay interface that can provide historical data in the same format as live data, so that I can train models on past market events.
#### Acceptance Criteria
1. WHEN requesting historical data THEN the replay interface SHALL provide data in the same structure as live feeds
2. WHEN replaying data THEN the system SHALL maintain original timing relationships between events
3. WHEN using replay mode THEN the interface SHALL support configurable playback speeds
4. WHEN switching between live and replay modes THEN the orchestrator SHALL receive data through the same interface
5. IF replay data is requested for unavailable time periods THEN the system SHALL return appropriate error messages
### Requirement 6
**User Story:** As a trading system integrator, I want the data aggregation system to follow the same interface as the current orchestrator data provider, so that I can seamlessly integrate it into existing workflows.
#### Acceptance Criteria
1. WHEN the orchestrator requests data THEN the aggregation system SHALL provide data in the expected format
2. WHEN integrating with existing systems THEN the interface SHALL be compatible with current data provider contracts
3. WHEN providing aggregated data THEN the system SHALL include metadata about data sources and quality
4. WHEN the orchestrator switches data sources THEN it SHALL work without code changes
5. IF data quality issues are detected THEN the system SHALL provide quality indicators in the response
### Requirement 7
**User Story:** As a system administrator, I want the data collection system to be containerized and easily deployable, so that I can manage it alongside other system components.
#### Acceptance Criteria
1. WHEN deploying the system THEN it SHALL run in Docker containers with proper resource allocation
2. WHEN starting services THEN TimescaleDB SHALL be automatically provisioned in its own container
3. WHEN configuring the system THEN all settings SHALL be externalized through environment variables or config files
4. WHEN monitoring the system THEN it SHALL provide health check endpoints for container orchestration
5. IF containers need to be restarted THEN the system SHALL recover gracefully without data loss
### Requirement 8
**User Story:** As a performance engineer, I want the system to handle high-frequency data efficiently, so that it can process order book updates from multiple exchanges without latency issues.
#### Acceptance Criteria
1. WHEN processing order book updates THEN the system SHALL handle at least 10 updates per second per exchange
2. WHEN aggregating data THEN processing latency SHALL be less than 10 milliseconds per update
3. WHEN storing data THEN the system SHALL use efficient batching to minimize database overhead
4. WHEN memory usage grows THEN the system SHALL implement appropriate cleanup and garbage collection
5. IF processing falls behind THEN the system SHALL prioritize recent data and log performance warnings

View File

@ -0,0 +1,210 @@
# Implementation Plan
- [x] 1. Set up project structure and core interfaces
- Create directory structure in `.\COBY` subfolder for the multi-exchange data aggregation system
- Define base interfaces and data models for exchange connectors, data processing, and storage
- Implement configuration management system with environment variable support
- _Requirements: 1.1, 6.1, 7.3_
- [x] 2. Implement TimescaleDB integration and database schema
- Create TimescaleDB connection manager with connection pooling
- Implement database schema creation with hypertables for time-series optimization
- Write database operations for storing order book snapshots and trade events
- Create database migration system for schema updates
- _Requirements: 3.1, 3.2, 3.3, 3.4_
- [x] 3. Create base exchange connector framework
- Implement abstract base class for exchange WebSocket connectors
- Create connection management with exponential backoff and circuit breaker patterns
- Implement WebSocket message handling with proper error recovery
- Add connection status monitoring and health checks
- _Requirements: 1.1, 1.3, 1.4, 8.5_
- [x] 4. Implement Binance exchange connector
- Create Binance-specific WebSocket connector extending the base framework
- Implement order book depth stream subscription and processing
- Add trade stream subscription for volume analysis
- Implement data normalization from Binance format to standard format
- Write unit tests for Binance connector functionality
- _Requirements: 1.1, 1.2, 1.4, 6.2_
- [x] 5. Create data processing and normalization engine
- Implement data processor for normalizing raw exchange data
- Create validation logic for order book and trade data
- Implement data quality checks and filtering
- Add metrics calculation for order book statistics
- Write comprehensive unit tests for data processing logic
- _Requirements: 1.4, 6.3, 8.1_
- [x] 6. Implement price bucket aggregation system
- Create aggregation engine for converting order book data to price buckets
- Implement configurable bucket sizes ($10 for BTC, $1 for ETH)
- Create heatmap data structure generation from price buckets
- Implement real-time aggregation with high-frequency updates
- Add volume-weighted aggregation calculations
- _Requirements: 2.1, 2.2, 2.3, 2.4, 8.1, 8.2_
- [x] 7. Build Redis caching layer
- Implement Redis connection manager with connection pooling
- Create caching strategies for latest order book data and heatmaps
- Implement cache invalidation and TTL management
- Add cache performance monitoring and metrics
- Write tests for caching functionality
- _Requirements: 8.2, 8.3_
- [x] 8. Create live data API endpoints
- Implement REST API for accessing current order book data
- Create WebSocket API for real-time data streaming
- Add endpoints for heatmap data retrieval
- Implement API rate limiting and authentication
- Create comprehensive API documentation
- _Requirements: 4.1, 4.2, 4.4, 6.3_
- [x] 9. Implement web dashboard for visualization
- Create HTML/CSS/JavaScript dashboard for real-time heatmap visualization
- Implement WebSocket client for receiving real-time updates
- Create progress bar visualization for aggregated price buckets
- Add exchange status indicators and connection monitoring
- Implement responsive design for different screen sizes
- _Requirements: 4.1, 4.2, 4.3, 4.5_
- [x] 10. Build historical data replay system
- Create replay manager for historical data playback
- Implement configurable playback speeds and time range selection
- Create replay session management with start/pause/stop controls
- Implement data streaming interface compatible with live data format
- Add replay status monitoring and progress tracking
- _Requirements: 5.1, 5.2, 5.3, 5.4, 5.5_
- [x] 11. Create orchestrator integration interface
- Implement data adapter that matches existing orchestrator interface
- Create compatibility layer for seamless integration with current data provider
- Add data quality indicators and metadata in responses
- Implement switching mechanism between live and replay modes
- Write integration tests with existing orchestrator code
- _Requirements: 6.1, 6.2, 6.3, 6.4, 6.5_
- [x] 12. Add additional exchange connectors (Coinbase, Kraken)
- Implement Coinbase Pro WebSocket connector with proper authentication
- Create Kraken WebSocket connector with their specific message format
- Add exchange-specific data normalization for both exchanges
- Implement proper error handling for each exchange's quirks
- Write unit tests for both new exchange connectors
- _Requirements: 1.1, 1.2, 1.4_
- [x] 13. Implement remaining exchange connectors (Bybit, OKX, Huobi)
- Create Bybit WebSocket connector with unified trading account support
- Implement OKX connector with their V5 API WebSocket streams
- Add Huobi Global connector with proper symbol mapping
- Ensure all connectors follow the same interface and error handling patterns
- Write comprehensive tests for all three exchange connectors
- _Requirements: 1.1, 1.2, 1.4_
- [x] 14. Complete exchange connector suite (KuCoin, Gate.io, Bitfinex, MEXC)
- Implement KuCoin connector with proper token-based authentication
- Create Gate.io connector with their WebSocket v4 API
- Add Bitfinex connector with proper channel subscription management
- Implement MEXC connector with their WebSocket streams
- Ensure all 10 exchanges are properly integrated and tested
- _Requirements: 1.1, 1.2, 1.4_
- [ ] 15. Implement cross-exchange data consolidation
- Create consolidation engine that merges order book data from multiple exchanges
- Implement weighted aggregation based on exchange liquidity and reliability
- Add conflict resolution for price discrepancies between exchanges
- Create consolidated heatmap that shows combined market depth
- Write tests for multi-exchange aggregation scenarios
- _Requirements: 2.5, 4.2_
- [ ] 16. Add performance monitoring and optimization
- Implement comprehensive metrics collection for all system components
- Create performance monitoring dashboard with key system metrics
- Add latency tracking for end-to-end data processing
- Implement memory usage monitoring and garbage collection optimization
- Create alerting system for performance degradation
- _Requirements: 8.1, 8.2, 8.3, 8.4, 8.5_
- [ ] 17. Create Docker containerization and deployment
- Write Dockerfiles for all system components
- Create docker-compose configuration for local development
- Implement health check endpoints for container orchestration
- Add environment variable configuration for all services
- Create deployment scripts and documentation
- _Requirements: 7.1, 7.2, 7.3, 7.4, 7.5_
- [ ] 18. Implement comprehensive testing suite
- Create integration tests for complete data pipeline from exchanges to storage
- Implement load testing for high-frequency data scenarios
- Add end-to-end tests for web dashboard functionality
- Create performance benchmarks and regression tests
- Write documentation for running and maintaining tests
- _Requirements: 8.1, 8.2, 8.3, 8.4_
- [ ] 19. Add system monitoring and alerting
- Implement structured logging with correlation IDs across all components
- Create Prometheus metrics exporters for system monitoring
- Add Grafana dashboards for system visualization
- Implement alerting rules for system failures and performance issues
- Create runbook documentation for common operational scenarios
- _Requirements: 7.4, 8.5_
- [ ] 20. Final integration and system testing
- Integrate the complete system with existing trading orchestrator
- Perform end-to-end testing with real market data
- Validate replay functionality with historical data scenarios
- Test failover scenarios and system resilience
- Create user documentation and operational guides
- _Requirements: 6.1, 6.2, 6.4, 5.1, 5.2_

View File

@ -0,0 +1,713 @@
# Multi-Modal Trading System Design Document
## Overview
The Multi-Modal Trading System is designed as an advanced algorithmic trading platform that combines Convolutional Neural Networks (CNN) and Reinforcement Learning (RL) models orchestrated by a decision-making module. The system processes multi-timeframe and multi-symbol market data (primarily ETH and BTC) to generate trading actions.
This design document outlines the architecture, components, data flow, and implementation details for the system based on the requirements and existing codebase.
## Architecture
The system follows a modular architecture with clear separation of concerns:
```mermaid
graph TD
A[Data Provider] --> B[Data Processor] (calculates pivot points)
B --> C[CNN Model]
B --> D[RL(DQN) Model]
C --> E[Orchestrator]
D --> E
E --> F[Trading Executor]
E --> G[Dashboard]
F --> G
H[Risk Manager] --> F
H --> G
```
### Key Components
1. **Data Provider**: Centralized component responsible for collecting, processing, and distributing market data from multiple sources.
2. **Data Processor**: Processes raw market data, calculates technical indicators, and identifies pivot points.
3. **CNN Model**: Analyzes patterns in market data and predicts pivot points across multiple timeframes.
4. **RL Model**: Learns optimal trading strategies based on market data and CNN predictions.
5. **Orchestrator**: Makes final trading decisions based on inputs from both CNN and RL models.
6. **Trading Executor**: Executes trading actions through brokerage APIs.
7. **Risk Manager**: Implements risk management features like stop-loss and position sizing.
8. **Dashboard**: Provides a user interface for monitoring and controlling the system.
## Components and Interfaces
### 1. Data Provider
The Data Provider is the foundation of the system, responsible for collecting, processing, and distributing market data to all other components.
#### Key Classes and Interfaces
- **DataProvider**: Central class that manages data collection, processing, and distribution.
- **MarketTick**: Data structure for standardized market tick data.
- **DataSubscriber**: Interface for components that subscribe to market data.
- **PivotBounds**: Data structure for pivot-based normalization bounds.
#### Implementation Details
The DataProvider class will:
- Collect data from multiple sources (Binance, MEXC)
- Support multiple timeframes (1s, 1m, 1h, 1d)
- Support multiple symbols (ETH, BTC)
- Calculate technical indicators
- Identify pivot points
- Normalize data
- Distribute data to subscribers
- Calculate any other algoritmic manipulations/calculations on the data
- Cache up to 3x the model inputs (300 ticks OHLCV, etc) data so we can do a proper backtesting in up to 2x time in the future
Based on the existing implementation in `core/data_provider.py`, we'll enhance it to:
- Improve pivot point calculation using reccursive Williams Market Structure
- Optimize data caching for better performance
- Enhance real-time data streaming
- Implement better error handling and fallback mechanisms
### BASE FOR ALL MODELS ###
- ***INPUTS***: COB+OHCLV data frame as described:
- OHCLV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
- COB: for each 1s OHCLV we have +- 20 buckets of COB ammounts in USD
- 1,5,15 and 60s MA of the COB imbalance counting +- 5 COB buckets
- ***OUTPUTS***:
- suggested trade action (BUY/SELL/HOLD). Paired with confidence
- immediate price movement drection vector (-1: vertical down, 1: vertical up, 0: horizontal) - linear; with it's own confidence
# Standardized input for all models:
{
'primary_symbol': 'ETH/USDT',
'reference_symbol': 'BTC/USDT',
'eth_data': {'ETH_1s': df, 'ETH_1m': df, 'ETH_1h': df, 'ETH_1d': df},
'btc_data': {'BTC_1s': df},
'current_prices': {'ETH': price, 'BTC': price},
'data_completeness': {...}
}
### 2. CNN Model
The CNN Model is responsible for analyzing patterns in market data and predicting pivot points across multiple timeframes.
#### Key Classes and Interfaces
- **CNNModel**: Main class for the CNN model.
- **PivotPointPredictor**: Interface for predicting pivot points.
- **CNNTrainer**: Class for training the CNN model.
- ***INPUTS***: COB+OHCLV+Old Pivots (5 levels of pivots)
- ***OUTPUTS***: next pivot point for each level as price-time vector. (can be plotted as trend line) + suggested trade action (BUY/SELL)
#### Implementation Details
The CNN Model will:
- Accept multi-timeframe and multi-symbol data as input
- Output predicted pivot points for each timeframe (1s, 1m, 1h, 1d)
- Provide confidence scores for each prediction
- Make hidden layer states available for the RL model
Architecture:
- Input layer: Multi-channel input for different timeframes and symbols
- Convolutional layers: Extract patterns from time series data
- LSTM/GRU layers: Capture temporal dependencies
- Attention mechanism: Focus on relevant parts of the input
- Output layer: Predict pivot points and confidence scores
Training:
- Use programmatically calculated pivot points as ground truth
- Train on historical data
- Update model when new pivot points are detected
- Use backpropagation to optimize weights
### 3. RL Model
The RL Model is responsible for learning optimal trading strategies based on market data and CNN predictions.
#### Key Classes and Interfaces
- **RLModel**: Main class for the RL model.
- **TradingActionGenerator**: Interface for generating trading actions.
- **RLTrainer**: Class for training the RL model.
#### Implementation Details
The RL Model will:
- Accept market data, CNN model predictions (output), and CNN hidden layer states as input
- Output trading action recommendations (buy/sell)
- Provide confidence scores for each action
- Learn from past experiences to adapt to the current market environment
Architecture:
- State representation: Market data, CNN model predictions (output), CNN hidden layer states
- Action space: Buy, Sell
- Reward function: PnL, risk-adjusted returns
- Policy network: Deep neural network
- Value network: Estimate expected returns
Training:
- Use reinforcement learning algorithms (DQN, PPO, A3C)
- Train on historical data
- Update model based on trading outcomes
- Use experience replay to improve sample efficiency
### 4. Orchestrator
The Orchestrator serves as the central coordination hub of the multi-modal trading system, responsible for data subscription management, model inference coordination, output storage, training pipeline orchestration, and inference-training feedback loop management.
#### Key Classes and Interfaces
- **Orchestrator**: Main class for the orchestrator.
- **DataSubscriptionManager**: Manages subscriptions to multiple data streams with different refresh rates.
- **ModelInferenceCoordinator**: Coordinates inference across all models.
- **ModelOutputStore**: Stores and manages model outputs for cross-model feeding.
- **TrainingPipelineManager**: Manages training pipelines for all models.
- **DecisionMaker**: Interface for making trading decisions.
- **MoEGateway**: Mixture of Experts gateway for model integration.
#### Core Responsibilities
##### 1. Data Subscription and Management
The Orchestrator subscribes to the Data Provider and manages multiple data streams with varying refresh rates:
- **10Hz COB (Cumulative Order Book) Data**: High-frequency order book updates for real-time market depth analysis
- **OHLCV Data**: Traditional candlestick data at multiple timeframes (1s, 1m, 1h, 1d)
- **Market Tick Data**: Individual trade executions and price movements
- **Technical Indicators**: Calculated indicators that update at different frequencies
- **Pivot Points**: Market structure analysis data
**Data Stream Management**:
- Maintains separate buffers for each data type with appropriate retention policies
- Ensures thread-safe access to data streams from multiple models
- Implements intelligent caching to serve "last updated" data efficiently
- Maintains full base dataframe that stays current for any model requesting data
- Handles data synchronization across different refresh rates
**Enhanced 1s Timeseries Data Combination**:
- Combines OHLCV data with COB (Cumulative Order Book) data for 1s timeframes
- Implements price bucket aggregation: ±20 buckets around current price
- ETH: $1 bucket size (e.g., $3000-$3040 range = 40 buckets) when current price is 3020
- BTC: $10 bucket size (e.g., $50000-$50400 range = 40 buckets) when price is 50200
- Creates unified base data input that includes:
- Traditional OHLCV metrics (Open, High, Low, Close, Volume)
- Order book depth and liquidity at each price level
- Bid/ask imbalances for the +-5 buckets with Moving Averages for 5,15, and 60s
- Volume-weighted average prices within buckets
- Order flow dynamics and market microstructure data
##### 2. Model Inference Coordination
The Orchestrator coordinates inference across all models in the system:
**Inference Pipeline**:
- Triggers model inference when relevant data updates occur
- Manages inference scheduling based on data availability and model requirements
- Coordinates parallel inference execution for independent models
- Handles model dependencies (e.g., RL model waiting for CNN hidden states)
**Model Input Management**:
- Assembles appropriate input data for each model based on their requirements
- Ensures models receive the most current data available at inference time
- Manages feature engineering and data preprocessing for each model
- Handles different input formats and requirements across models
##### 3. Model Output Storage and Cross-Feeding
The Orchestrator maintains a centralized store for all model outputs and manages cross-model data feeding:
**Output Storage**:
- Stores CNN predictions, confidence scores, and hidden layer states
- Stores RL action recommendations and value estimates
- Stores outputs from all models in extensible format supporting future models (LSTM, Transformer, etc.)
- Maintains historical output sequences for temporal analysis
- Implements efficient retrieval mechanisms for real-time access
- Uses standardized ModelOutput format for easy extension and cross-model compatibility
**Cross-Model Feeding**:
- Feeds CNN hidden layer states into RL model inputs
- Provides CNN predictions as context for RL decision-making
- Includes "last predictions" from each available model as part of base data input
- Stores model outputs that become inputs for subsequent inference cycles
- Manages circular dependencies and feedback loops between models
- Supports dynamic model addition without requiring system architecture changes
##### 4. Training Pipeline Management
The Orchestrator coordinates training for all models by managing the prediction-result feedback loop:
**Training Coordination**:
- Calls each model's training pipeline when new inference results are available
- Provides previous predictions alongside new results for supervised learning
- Manages training data collection and labeling
- Coordinates online learning updates based on real-time performance
**Training Data Management**:
- Maintains training datasets with prediction-result pairs
- Implements data quality checks and filtering
- Manages training data retention and archival policies
- Provides training data statistics and monitoring
**Performance Tracking**:
- Tracks prediction accuracy for each model over time
- Monitors model performance degradation and triggers retraining
- Maintains performance metrics for model comparison and selection
**Training progress and checkpoints persistance**
- it uses the checkpoint manager to store check points of each model over time as training progresses and we have improvements
- checkpoint manager has capability to ensure only top 5 to 10 best checkpoints are stored for each model deleting the least performant ones. it stores metadata along the CPs to decide the performance
- we automatically load the best CP at startup if we have stored ones
##### 5. Inference Data Validation and Storage
The Orchestrator implements comprehensive inference data validation and persistent storage:
**Input Data Validation**:
- Validates complete OHLCV dataframes for all required timeframes before inference
- Checks input data dimensions against model requirements
- Logs missing components and prevents prediction on incomplete data
- Raises validation errors with specific details about expected vs actual dimensions
**Inference History Storage**:
- Stores complete input data packages with each prediction in persistent storage
- Includes timestamp, symbol, input features, prediction outputs, confidence scores, and model internal states
- Maintains compressed storage to minimize footprint while preserving accessibility
- Implements efficient query mechanisms by symbol, timeframe, and date range
**Storage Management**:
- Applies configurable retention policies to manage storage limits
- Archives or removes oldest entries when limits are reached
- Prioritizes keeping most recent and valuable training examples during storage pressure
- Provides data completeness metrics and validation results in logs
##### 6. Inference-Training Feedback Loop
The Orchestrator manages the continuous learning cycle through inference-training feedback:
**Prediction Outcome Evaluation**:
- Evaluates prediction accuracy against actual price movements after sufficient time has passed
- Creates training examples using stored inference data paired with actual market outcomes
- Feeds prediction-result pairs back to respective models for learning
**Adaptive Learning Signals**:
- Provides positive reinforcement signals for accurate predictions
- Delivers corrective training signals for inaccurate predictions to help models learn from mistakes
- Retrieves last inference data for each model to compare predictions against actual outcomes
**Continuous Improvement Tracking**:
- Tracks and reports accuracy improvements or degradations over time
- Monitors model learning progress through the feedback loop
- Alerts administrators when data flow issues are detected with specific error details and remediation suggestions
##### 5. Decision Making and Trading Actions
Beyond coordination, the Orchestrator makes final trading decisions:
**Decision Integration**:
- Combines outputs from CNN and RL models using Mixture of Experts approach
- Applies confidence-based filtering to avoid uncertain trades
- Implements configurable thresholds for buy/sell decisions
- Considers market conditions and risk parameters
#### Implementation Details
**Architecture**:
```python
class Orchestrator:
def __init__(self):
self.data_subscription_manager = DataSubscriptionManager()
self.model_inference_coordinator = ModelInferenceCoordinator()
self.model_output_store = ModelOutputStore()
self.training_pipeline_manager = TrainingPipelineManager()
self.decision_maker = DecisionMaker()
self.moe_gateway = MoEGateway()
async def run(self):
# Subscribe to data streams
await self.data_subscription_manager.subscribe_to_data_provider()
# Start inference coordination loop
await self.model_inference_coordinator.start()
# Start training pipeline management
await self.training_pipeline_manager.start()
```
**Data Flow Management**:
- Implements event-driven architecture for data updates
- Uses async/await patterns for non-blocking operations
- Maintains data freshness timestamps for each stream
- Implements backpressure handling for high-frequency data
**Model Coordination**:
- Manages model lifecycle (loading, inference, training, updating)
- Implements model versioning and rollback capabilities
- Handles model failures and fallback mechanisms
- Provides model performance monitoring and alerting
**Training Integration**:
- Implements incremental learning strategies
- Manages training batch composition and scheduling
- Provides training progress monitoring and control
- Handles training failures and recovery
### 5. Trading Executor
The Trading Executor is responsible for executing trading actions through brokerage APIs.
#### Key Classes and Interfaces
- **TradingExecutor**: Main class for the trading executor.
- **BrokerageAPI**: Interface for interacting with brokerages.
- **OrderManager**: Class for managing orders.
#### Implementation Details
The Trading Executor will:
- Accept trading actions from the orchestrator
- Execute orders through brokerage APIs
- Manage order lifecycle
- Handle errors and retries
- Provide feedback on order execution
Supported brokerages:
- MEXC
- Binance
- Bybit (future extension)
Order types:
- Market orders
- Limit orders
- Stop-loss orders
### 6. Risk Manager
The Risk Manager is responsible for implementing risk management features like stop-loss and position sizing.
#### Key Classes and Interfaces
- **RiskManager**: Main class for the risk manager.
- **StopLossManager**: Class for managing stop-loss orders.
- **PositionSizer**: Class for determining position sizes.
#### Implementation Details
The Risk Manager will:
- Implement configurable stop-loss functionality
- Implement configurable position sizing based on risk parameters
- Implement configurable maximum drawdown limits
- Provide real-time risk metrics
- Provide alerts for high-risk situations
Risk parameters:
- Maximum position size
- Maximum drawdown
- Risk per trade
- Maximum leverage
### 7. Dashboard
The Dashboard provides a user interface for monitoring and controlling the system.
#### Key Classes and Interfaces
- **Dashboard**: Main class for the dashboard.
- **ChartManager**: Class for managing charts.
- **ControlPanel**: Class for managing controls.
#### Implementation Details
The Dashboard will:
- Display real-time market data for all symbols and timeframes
- Display OHLCV charts for all timeframes
- Display CNN pivot point predictions and confidence levels
- Display RL and orchestrator trading actions and confidence levels
- Display system status and model performance metrics
- Provide start/stop toggles for all system processes
- Provide sliders to adjust buy/sell thresholds for the orchestrator
Implementation:
- Web-based dashboard using Flask/Dash
- Real-time updates using WebSockets
- Interactive charts using Plotly
- Server-side processing for all models
## Data Models
### Market Data
```python
@dataclass
class MarketTick:
symbol: str
timestamp: datetime
price: float
volume: float
quantity: float
side: str # 'buy' or 'sell'
trade_id: str
is_buyer_maker: bool
raw_data: Dict[str, Any] = field(default_factory=dict)
```
### OHLCV Data
```python
@dataclass
class OHLCVBar:
symbol: str
timestamp: datetime
open: float
high: float
low: float
close: float
volume: float
timeframe: str
indicators: Dict[str, float] = field(default_factory=dict)
```
### Pivot Points
```python
@dataclass
class PivotPoint:
symbol: str
timestamp: datetime
price: float
type: str # 'high' or 'low'
level: int # Pivot level (1, 2, 3, etc.)
confidence: float = 1.0
```
### Trading Actions
```python
@dataclass
class TradingAction:
symbol: str
timestamp: datetime
action: str # 'buy' or 'sell'
confidence: float
source: str # 'rl', 'cnn', 'orchestrator'
price: Optional[float] = None
quantity: Optional[float] = None
reason: Optional[str] = None
```
### Model Predictions
```python
@dataclass
class ModelOutput:
"""Extensible model output format supporting all model types"""
model_type: str # 'cnn', 'rl', 'lstm', 'transformer', 'orchestrator'
model_name: str # Specific model identifier
symbol: str
timestamp: datetime
confidence: float
predictions: Dict[str, Any] # Model-specific predictions
hidden_states: Optional[Dict[str, Any]] = None # For cross-model feeding
metadata: Dict[str, Any] = field(default_factory=dict) # Additional info
```
```python
@dataclass
class CNNPrediction:
symbol: str
timestamp: datetime
pivot_points: List[PivotPoint]
hidden_states: Dict[str, Any]
confidence: float
```
```python
@dataclass
class RLPrediction:
symbol: str
timestamp: datetime
action: str # 'buy' or 'sell'
confidence: float
expected_reward: float
```
### Enhanced Base Data Input
```python
@dataclass
class BaseDataInput:
"""Unified base data input for all models"""
symbol: str
timestamp: datetime
ohlcv_data: Dict[str, OHLCVBar] # Multi-timeframe OHLCV
cob_data: Optional[Dict[str, float]] = None # COB buckets for 1s timeframe
technical_indicators: Dict[str, float] = field(default_factory=dict)
pivot_points: List[PivotPoint] = field(default_factory=list)
last_predictions: Dict[str, ModelOutput] = field(default_factory=dict) # From all models
market_microstructure: Dict[str, Any] = field(default_factory=dict) # Order flow, etc.
```
### COB Data Structure
```python
@dataclass
class COBData:
"""Cumulative Order Book data for price buckets"""
symbol: str
timestamp: datetime
current_price: float
bucket_size: float # $1 for ETH, $10 for BTC
price_buckets: Dict[float, Dict[str, float]] # price -> {bid_volume, ask_volume, etc.}
bid_ask_imbalance: Dict[float, float] # price -> imbalance ratio
volume_weighted_prices: Dict[float, float] # price -> VWAP within bucket
order_flow_metrics: Dict[str, float] # Various order flow indicators
```
### Data Collection Errors
- Implement retry mechanisms for API failures
- Use fallback data sources when primary sources are unavailable
- Log all errors with detailed information
- Notify users through the dashboard
### Model Errors
- Implement model validation before deployment
- Use fallback models when primary models fail
- Log all errors with detailed information
- Notify users through the dashboard
### Trading Errors
- Implement order validation before submission
- Use retry mechanisms for order failures
- Implement circuit breakers for extreme market conditions
- Log all errors with detailed information
- Notify users through the dashboard
## Testing Strategy
### Unit Testing
- Test individual components in isolation
- Use mock objects for dependencies
- Focus on edge cases and error handling
### Integration Testing
- Test interactions between components
- Use real data for testing
- Focus on data flow and error propagation
### System Testing
- Test the entire system end-to-end
- Use real data for testing
- Focus on performance and reliability
### Backtesting
- Test trading strategies on historical data
- Measure performance metrics (PnL, Sharpe ratio, etc.)
- Compare against benchmarks
### Live Testing
- Test the system in a live environment with small position sizes
- Monitor performance and stability
- Gradually increase position sizes as confidence grows
## Implementation Plan
The implementation will follow a phased approach:
1. **Phase 1: Data Provider**
- Implement the enhanced data provider
- Implement pivot point calculation
- Implement technical indicator calculation
- Implement data normalization
2. **Phase 2: CNN Model**
- Implement the CNN model architecture
- Implement the training pipeline
- Implement the inference pipeline
- Implement the pivot point prediction
3. **Phase 3: RL Model**
- Implement the RL model architecture
- Implement the training pipeline
- Implement the inference pipeline
- Implement the trading action generation
4. **Phase 4: Orchestrator**
- Implement the orchestrator architecture
- Implement the decision-making logic
- Implement the MoE gateway
- Implement the confidence-based filtering
5. **Phase 5: Trading Executor**
- Implement the trading executor
- Implement the brokerage API integrations
- Implement the order management
- Implement the error handling
6. **Phase 6: Risk Manager**
- Implement the risk manager
- Implement the stop-loss functionality
- Implement the position sizing
- Implement the risk metrics
7. **Phase 7: Dashboard**
- Implement the dashboard UI
- Implement the chart management
- Implement the control panel
- Implement the real-time updates
8. **Phase 8: Integration and Testing**
- Integrate all components
- Implement comprehensive testing
- Fix bugs and optimize performance
- Deploy to production
## Monitoring and Visualization
### TensorBoard Integration (Future Enhancement)
A comprehensive TensorBoard integration has been designed to provide detailed training visualization and monitoring capabilities:
#### Features
- **Training Metrics Visualization**: Real-time tracking of model losses, rewards, and performance metrics
- **Feature Distribution Analysis**: Histograms and statistics of input features to validate data quality
- **State Quality Monitoring**: Tracking of comprehensive state building (13,400 features) success rates
- **Reward Component Analysis**: Detailed breakdown of reward calculations including PnL, confidence, volatility, and order flow
- **Model Performance Comparison**: Side-by-side comparison of CNN, RL, and orchestrator performance
#### Implementation Status
- **Completed**: TensorBoardLogger utility class with comprehensive logging methods
- **Completed**: Integration points in enhanced_rl_training_integration.py
- **Completed**: Enhanced run_tensorboard.py with improved visualization options
- **Status**: Ready for deployment when system stability is achieved
#### Usage
```bash
# Start TensorBoard dashboard
python run_tensorboard.py
# Access at http://localhost:6006
# View training metrics, feature distributions, and model performance
```
#### Benefits
- Real-time validation of training process
- Early detection of training issues
- Feature importance analysis
- Model performance comparison
- Historical training progress tracking
**Note**: TensorBoard integration is currently deprioritized in favor of system stability and core model improvements. It will be activated once the core training system is stable and performing optimally.
## Conclusion
This design document outlines the architecture, components, data flow, and implementation details for the Multi-Modal Trading System. The system is designed to be modular, extensible, and robust, with a focus on performance, reliability, and user experience.
The implementation will follow a phased approach, with each phase building on the previous one. The system will be thoroughly tested at each phase to ensure that it meets the requirements and performs as expected.
The final system will provide traders with a powerful tool for analyzing market data, identifying trading opportunities, and executing trades with confidence.

View File

@ -0,0 +1,175 @@
# Requirements Document
## Introduction
The Multi-Modal Trading System is an advanced algorithmic trading platform that combines Convolutional Neural Networks (CNN) and Reinforcement Learning (RL) models orchestrated by a decision-making module. The system processes multi-timeframe and multi-symbol market data (primarily ETH and BTC) to generate trading actions. The system is designed to adapt to current market conditions through continuous learning from past experiences, with the CNN module trained on historical data to predict pivot points and the RL module optimizing trading decisions based on these predictions and market data.
## Requirements
### Requirement 1: Data Collection and Processing
**User Story:** As a trader, I want the system to collect and process multi-timeframe and multi-symbol market data, so that the models have comprehensive market information for making accurate trading decisions.
#### Acceptance Criteria
0. NEVER USE GENERATED/SYNTHETIC DATA or mock implementations and UI. If somethings is not implemented yet, it should be obvious.
1. WHEN the system starts THEN it SHALL collect and process data for both ETH and BTC symbols.
2. WHEN collecting data THEN the system SHALL store the following for the primary symbol (ETH):
- 300 seconds of raw tick data - price and COB snapshot for all prices +- 1% on fine reslolution buckets (1$ for ETH, 10$ for BTC)
- 300 seconds of 1-second OHLCV data + 1s aggregated COB data
- 300 bars of OHLCV + indicators for each timeframe (1s, 1m, 1h, 1d)
3. WHEN collecting data THEN the system SHALL store similar data for the reference symbol (BTC).
4. WHEN processing data THEN the system SHALL calculate standard technical indicators for all timeframes.
5. WHEN processing data THEN the system SHALL calculate pivot points for all timeframes according to the specified methodology.
6. WHEN new data arrives THEN the system SHALL update its data cache in real-time.
7. IF tick data is not available THEN the system SHALL substitute with the lowest available timeframe data.
8. WHEN normalizing data THEN the system SHALL normalize to the max and min of the highest timeframe to maintain relationships between different timeframes.
9. data is cached for longer (let's start with double the model inputs so 600 bars) to support performing backtesting when we know the current predictions outcomes so we can generate test cases.
10. In general all models have access to the whole data we collect in a central data provider implementation. only some are specialized. All models should also take as input the last output of evey other model (also cached in the data provider). there should be a room for adding more models in the other models data input so we can extend the system without having to loose existing models and trained W&B
### Requirement 2: CNN Model Implementation
**User Story:** As a trader, I want the system to implement a CNN model that can identify patterns and predict pivot points across multiple timeframes, so that I can anticipate market direction changes.
#### Acceptance Criteria
1. WHEN the CNN model is initialized THEN it SHALL accept multi-timeframe and multi-symbol data as input.
2. WHEN processing input data THEN the CNN model SHALL output predicted pivot points for each timeframe (1s, 1m, 1h, 1d).
3. WHEN predicting pivot points THEN the CNN model SHALL provide both the predicted pivot point value and the timestamp when it is expected to occur.
4. WHEN a pivot point is detected THEN the system SHALL trigger a training round for the CNN model using historical data.
5. WHEN training the CNN model THEN the system SHALL use programmatically calculated pivot points from historical data as ground truth.
6. WHEN outputting predictions THEN the CNN model SHALL include a confidence score for each prediction.
7. WHEN calculating pivot points THEN the system SHALL implement both standard pivot points and the recursive Williams market structure pivot points as described.
8. WHEN processing data THEN the CNN model SHALL make available its hidden layer states for use by the RL model.
### Requirement 3: RL Model Implementation
**User Story:** As a trader, I want the system to implement an RL model that can learn optimal trading strategies based on market data and CNN predictions, so that the system can adapt to changing market conditions.
#### Acceptance Criteria
1. WHEN the RL model is initialized THEN it SHALL accept market data, CNN predictions, and CNN hidden layer states as input.
2. WHEN processing input data THEN the RL model SHALL output trading action recommendations (buy/sell).
3. WHEN evaluating trading actions THEN the RL model SHALL learn from past experiences to adapt to the current market environment.
4. WHEN making decisions THEN the RL model SHALL consider the confidence levels of CNN predictions.
5. WHEN uncertain about market direction THEN the RL model SHALL learn to avoid entering positions.
6. WHEN training the RL model THEN the system SHALL use a reward function that incentivizes high risk/reward setups.
7. WHEN outputting trading actions THEN the RL model SHALL provide a confidence score for each action.
8. WHEN a trading action is executed THEN the system SHALL store the input data for future training.
### Requirement 4: Orchestrator Implementation
**User Story:** As a trader, I want the system to implement an orchestrator that can make final trading decisions based on inputs from both CNN and RL models, so that the system can make more balanced and informed trading decisions.
#### Acceptance Criteria
1. WHEN the orchestrator is initialized THEN it SHALL accept inputs from both CNN and RL models.
2. WHEN processing model inputs THEN the orchestrator SHALL output final trading actions (buy/sell).
3. WHEN making decisions THEN the orchestrator SHALL consider the confidence levels of both CNN and RL models.
4. WHEN uncertain about market direction THEN the orchestrator SHALL learn to avoid entering positions.
5. WHEN implementing the orchestrator THEN the system SHALL use a Mixture of Experts (MoE) approach to allow for future model integration.
6. WHEN outputting trading actions THEN the orchestrator SHALL provide a confidence score for each action.
7. WHEN a trading action is executed THEN the system SHALL store the input data for future training.
8. WHEN implementing the orchestrator THEN the system SHALL allow for configurable thresholds for entering and exiting positions.
### Requirement 5: Training Pipeline
**User Story:** As a developer, I want the system to implement a unified training pipeline for both CNN and RL models, so that the models can be trained efficiently and consistently.
#### Acceptance Criteria
1. WHEN training models THEN the system SHALL use a unified data provider to prepare data for all models.
2. WHEN a pivot point is detected THEN the system SHALL trigger a training round for the CNN model.
3. WHEN training the CNN model THEN the system SHALL use programmatically calculated pivot points from historical data as ground truth.
4. WHEN training the RL model THEN the system SHALL use a reward function that incentivizes high risk/reward setups.
5. WHEN training models THEN the system SHALL run the training process on the server without requiring the dashboard to be open.
6. WHEN training models THEN the system SHALL provide real-time feedback on training progress through the dashboard.
7. WHEN training models THEN the system SHALL store model checkpoints for future use.
8. WHEN training models THEN the system SHALL provide metrics on model performance.
### Requirement 6: Dashboard Implementation
**User Story:** As a trader, I want the system to implement a comprehensive dashboard that displays real-time data, model predictions, and trading actions, so that I can monitor the system's performance and make informed decisions.
#### Acceptance Criteria
1. WHEN the dashboard is initialized THEN it SHALL display real-time market data for all symbols and timeframes.
2. WHEN displaying market data THEN the dashboard SHALL show OHLCV charts for all timeframes.
3. WHEN displaying model predictions THEN the dashboard SHALL show CNN pivot point predictions and confidence levels.
4. WHEN displaying trading actions THEN the dashboard SHALL show RL and orchestrator trading actions and confidence levels.
5. WHEN displaying system status THEN the dashboard SHALL show training progress and model performance metrics.
6. WHEN implementing controls THEN the dashboard SHALL provide start/stop toggles for all system processes.
7. WHEN implementing controls THEN the dashboard SHALL provide sliders to adjust buy/sell thresholds for the orchestrator.
8. WHEN implementing the dashboard THEN the system SHALL ensure all processes run on the server without requiring the dashboard to be open.
### Requirement 7: Risk Management
**User Story:** As a trader, I want the system to implement risk management features, so that I can protect my capital from significant losses.
#### Acceptance Criteria
1. WHEN implementing risk management THEN the system SHALL provide configurable stop-loss functionality.
2. WHEN a stop-loss is triggered THEN the system SHALL automatically close the position.
3. WHEN implementing risk management THEN the system SHALL provide configurable position sizing based on risk parameters.
4. WHEN implementing risk management THEN the system SHALL provide configurable maximum drawdown limits.
5. WHEN maximum drawdown limits are reached THEN the system SHALL automatically stop trading.
6. WHEN implementing risk management THEN the system SHALL provide real-time risk metrics through the dashboard.
7. WHEN implementing risk management THEN the system SHALL allow for different risk parameters for different market conditions.
8. WHEN implementing risk management THEN the system SHALL provide alerts for high-risk situations.
### Requirement 8: System Architecture and Integration
**User Story:** As a developer, I want the system to implement a clean and modular architecture, so that the system is easy to maintain and extend.
#### Acceptance Criteria
1. WHEN implementing the system architecture THEN the system SHALL use a unified data provider to prepare data for all models.
2. WHEN implementing the system architecture THEN the system SHALL use a modular approach to allow for easy extension.
3. WHEN implementing the system architecture THEN the system SHALL use a clean separation of concerns between data collection, model training, and trading execution.
4. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all models.
5. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all data providers.
6. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all trading executors.
7. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all risk management components.
8. WHEN implementing the system architecture THEN the system SHALL use a unified interface for all dashboard components.
### Requirement 9: Model Inference Data Validation and Storage
**User Story:** As a trading system developer, I want to ensure that all model predictions include complete input data validation and persistent storage, so that I can verify models receive correct inputs and track their performance over time.
#### Acceptance Criteria
1. WHEN a model makes a prediction THEN the system SHALL validate that the input data contains complete OHLCV dataframes for all required timeframes
2. WHEN input data is incomplete THEN the system SHALL log the missing components and SHALL NOT proceed with prediction
3. WHEN input validation passes THEN the system SHALL store the complete input data package with the prediction in persistent storage
4. IF input data dimensions are incorrect THEN the system SHALL raise a validation error with specific details about expected vs actual dimensions
5. WHEN a model completes inference THEN the system SHALL store the complete input data, model outputs, confidence scores, and metadata in a persistent inference history
6. WHEN storing inference data THEN the system SHALL include timestamp, symbol, input features, prediction outputs, and model internal states
7. IF inference history storage fails THEN the system SHALL log the error and continue operation without breaking the prediction flow
### Requirement 10: Inference-Training Feedback Loop
**User Story:** As a machine learning engineer, I want the system to automatically train models using their previous inference data compared to actual market outcomes, so that models continuously improve their accuracy through real-world feedback.
#### Acceptance Criteria
1. WHEN sufficient time has passed after a prediction THEN the system SHALL evaluate the prediction accuracy against actual price movements
2. WHEN a prediction outcome is determined THEN the system SHALL create a training example using the stored inference data and actual outcome
3. WHEN training examples are created THEN the system SHALL feed them back to the respective models for learning
4. IF the prediction was accurate THEN the system SHALL reinforce the model's decision pathway through positive training signals
5. IF the prediction was inaccurate THEN the system SHALL provide corrective training signals to help the model learn from mistakes
6. WHEN the system needs training data THEN it SHALL retrieve the last inference data for each model to compare predictions against actual market outcomes
7. WHEN models are trained on inference feedback THEN the system SHALL track and report accuracy improvements or degradations over time
### Requirement 11: Inference History Management and Monitoring
**User Story:** As a system administrator, I want comprehensive logging and monitoring of the inference-training feedback loop with configurable retention policies, so that I can track model learning progress and manage storage efficiently.
#### Acceptance Criteria
1. WHEN inference data is stored THEN the system SHALL log the storage operation with data completeness metrics and validation results
2. WHEN training occurs based on previous inference THEN the system SHALL log the training outcome and model performance changes
3. WHEN the system detects data flow issues THEN it SHALL alert administrators with specific error details and suggested remediation
4. WHEN inference history reaches configured limits THEN the system SHALL archive or remove oldest entries based on retention policy
5. WHEN storing inference data THEN the system SHALL compress data to minimize storage footprint while maintaining accessibility
6. WHEN retrieving historical inference data THEN the system SHALL provide efficient query mechanisms by symbol, timeframe, and date range
7. IF storage space is critically low THEN the system SHALL prioritize keeping the most recent and most valuable training examples

View File

@ -0,0 +1,382 @@
# Implementation Plan
## Enhanced Data Provider and COB Integration
- [ ] 1. Enhance the existing DataProvider class with standardized model inputs
- Extend the current implementation in core/data_provider.py
- Implement standardized COB+OHLCV data frame for all models
- Create unified input format: 300 frames OHLCV (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
- Integrate with existing multi_exchange_cob_provider.py for COB data
- _Requirements: 1.1, 1.2, 1.3, 1.6_
- [ ] 1.1. Implement standardized COB+OHLCV data frame for all models
- Create BaseDataInput class with standardized format for all models
- Implement OHLCV: 300 frames of (1s, 1m, 1h, 1d) ETH + 300s of 1s BTC
- Add COB: ±20 buckets of COB amounts in USD for each 1s OHLCV
- Include 1s, 5s, 15s, and 60s MA of COB imbalance counting ±5 COB buckets
- Ensure all models receive identical input format for consistency
- _Requirements: 1.2, 1.3, 8.1_
- [ ] 1.2. Implement extensible model output storage
- Create standardized ModelOutput data structure
- Support CNN, RL, LSTM, Transformer, and future model types
- Include model-specific predictions and cross-model hidden states
- Add metadata support for extensible model information
- _Requirements: 1.10, 8.2_
- [ ] 1.3. Enhance Williams Market Structure pivot point calculation
- Extend existing williams_market_structure.py implementation
- Improve recursive pivot point calculation accuracy
- Add unit tests to verify pivot point detection
- Integrate with COB data for enhanced pivot detection
- _Requirements: 1.5, 2.7_
- [-] 1.4. Optimize real-time data streaming with COB integration
- Enhance existing WebSocket connections in enhanced_cob_websocket.py
- Implement 10Hz COB data streaming alongside OHLCV data
- Add data synchronization across different refresh rates
- Ensure thread-safe access to multi-rate data streams
- _Requirements: 1.6, 8.5_
- [ ] 1.5. Fix WebSocket COB data processing errors
- Fix 'NoneType' object has no attribute 'append' errors in COB data processing
- Ensure proper initialization of data structures in MultiExchangeCOBProvider
- Add validation and defensive checks before accessing data structures
- Implement proper error handling for WebSocket data processing
- _Requirements: 1.1, 1.6, 8.5_
- [ ] 1.6. Enhance error handling in COB data processing
- Add validation for incoming WebSocket data
- Implement reconnection logic with exponential backoff
- Add detailed logging for debugging COB data issues
- Ensure system continues operation with last valid data during failures
- _Requirements: 1.6, 8.5_
## Enhanced CNN Model Implementation
- [ ] 2. Enhance the existing CNN model with standardized inputs/outputs
- Extend the current implementation in NN/models/enhanced_cnn.py
- Accept standardized COB+OHLCV data frame: 300 frames (1s,1m,1h,1d) ETH + 300s 1s BTC
- Include COB ±20 buckets and MA (1s,5s,15s,60s) of COB imbalance ±5 buckets
- Output BUY/SELL trading action with confidence scores - _Requirements: 2.1, 2.2, 2.8, 1.10_
- [x] 2.1. Implement CNN inference with standardized input format
- Accept BaseDataInput with standardized COB+OHLCV format
- Process 300 frames of multi-timeframe data with COB buckets
- Output BUY/SELL recommendations with confidence scores
- Make hidden layer states available for cross-model feeding
- Optimize inference performance for real-time processing
- _Requirements: 2.2, 2.6, 2.8, 4.3_
- [x] 2.2. Enhance CNN training pipeline with checkpoint management
- Integrate with checkpoint manager for training progress persistence
- Store top 5-10 best checkpoints based on performance metrics
- Automatically load best checkpoint at startup
- Implement training triggers based on orchestrator feedback
- Store metadata with checkpoints for performance tracking
- _Requirements: 2.4, 2.5, 5.2, 5.3, 5.7_
- [ ] 2.3. Implement CNN model evaluation and checkpoint optimization
- Create evaluation methods using standardized input/output format
- Implement performance metrics for checkpoint ranking
- Add validation against historical trading outcomes
- Support automatic checkpoint cleanup (keep only top performers)
- Track model improvement over time through checkpoint metadata
- _Requirements: 2.5, 5.8, 4.4_
## Enhanced RL Model Implementation
- [ ] 3. Enhance the existing RL model with standardized inputs/outputs
- Extend the current implementation in NN/models/dqn_agent.py
- Accept standardized COB+OHLCV data frame: 300 frames (1s,1m,1h,1d) ETH + 300s 1s BTC
- Include COB ±20 buckets and MA (1s,5s,15s,60s) of COB imbalance ±5 buckets
- Output BUY/SELL trading action with confidence scores
- _Requirements: 3.1, 3.2, 3.7, 1.10_
- [ ] 3.1. Implement RL inference with standardized input format
- Accept BaseDataInput with standardized COB+OHLCV format
- Process CNN hidden states and predictions as part of state input
- Output BUY/SELL recommendations with confidence scores
- Include expected rewards and value estimates in output
- Optimize inference performance for real-time processing
- _Requirements: 3.2, 3.7, 4.3_
- [ ] 3.2. Enhance RL training pipeline with checkpoint management
- Integrate with checkpoint manager for training progress persistence
- Store top 5-10 best checkpoints based on trading performance metrics
- Automatically load best checkpoint at startup
- Implement experience replay with profitability-based prioritization
- Store metadata with checkpoints for performance tracking
- _Requirements: 3.3, 3.5, 5.4, 5.7, 4.4_
- [ ] 3.3. Implement RL model evaluation and checkpoint optimization
- Create evaluation methods using standardized input/output format
- Implement trading performance metrics for checkpoint ranking
- Add validation against historical trading opportunities
- Support automatic checkpoint cleanup (keep only top performers)
- Track model improvement over time through checkpoint metadata
- _Requirements: 3.3, 5.8, 4.4_
## Enhanced Orchestrator Implementation
- [ ] 4. Enhance the existing orchestrator with centralized coordination
- Extend the current implementation in core/orchestrator.py
- Implement DataSubscriptionManager for multi-rate data streams
- Add ModelInferenceCoordinator for cross-model coordination
- Create ModelOutputStore for extensible model output management
- Add TrainingPipelineManager for continuous learning coordination
- _Requirements: 4.1, 4.2, 4.5, 8.1_
- [ ] 4.1. Implement data subscription and management system
- Create DataSubscriptionManager class
- Subscribe to 10Hz COB data, OHLCV, market ticks, and technical indicators
- Implement intelligent caching for "last updated" data serving
- Maintain synchronized base dataframe across different refresh rates
- Add thread-safe access to multi-rate data streams
- _Requirements: 4.1, 1.6, 8.5_
- [ ] 4.2. Implement model inference coordination
- Create ModelInferenceCoordinator class
- Trigger model inference based on data availability and requirements
- Coordinate parallel inference execution for independent models
- Handle model dependencies (e.g., RL waiting for CNN hidden states)
- Assemble appropriate input data for each model type
- _Requirements: 4.2, 3.1, 2.1_
- [ ] 4.3. Implement model output storage and cross-feeding
- Create ModelOutputStore class using standardized ModelOutput format
- Store CNN predictions, confidence scores, and hidden layer states
- Store RL action recommendations and value estimates
- Support extensible storage for LSTM, Transformer, and future models
- Implement cross-model feeding of hidden states and predictions
- Include "last predictions" from all models in base data input
- _Requirements: 4.3, 1.10, 8.2_
- [ ] 4.4. Implement training pipeline management
- Create TrainingPipelineManager class
- Call each model's training pipeline with prediction-result pairs
- Manage training data collection and labeling
- Coordinate online learning updates based on real-time performance
- Track prediction accuracy and trigger retraining when needed
- _Requirements: 4.4, 5.2, 5.4, 5.7_
- [ ] 4.5. Implement enhanced decision-making with MoE
- Create enhanced DecisionMaker class
- Implement Mixture of Experts approach for model integration
- Apply confidence-based filtering to avoid uncertain trades
- Support configurable thresholds for buy/sell decisions
- Consider market conditions and risk parameters in decisions
- _Requirements: 4.5, 4.8, 6.7_
- [ ] 4.6. Implement extensible model integration architecture
- Create MoEGateway class supporting dynamic model addition
- Support CNN, RL, LSTM, Transformer model types without architecture changes
- Implement model versioning and rollback capabilities
- Handle model failures and fallback mechanisms
- Provide model performance monitoring and alerting
- _Requirements: 4.6, 8.2, 8.3_
## Model Inference Data Validation and Storage
- [x] 5. Implement comprehensive inference data validation system
- Create InferenceDataValidator class for input validation
- Validate complete OHLCV dataframes for all required timeframes
- Check input data dimensions against model requirements
- Log missing components and prevent prediction on incomplete data
- _Requirements: 9.1, 9.2, 9.3, 9.4_
- [ ] 5.1. Implement input data validation for all models
- Create validation methods for CNN, RL, and future model inputs
- Validate OHLCV data completeness (300 frames for 1s, 1m, 1h, 1d)
- Validate COB data structure (±20 buckets, MA calculations)
- Raise specific validation errors with expected vs actual dimensions
- Ensure validation occurs before any model inference
- _Requirements: 9.1, 9.4_
- [x] 5.2. Implement persistent inference history storage
- Create InferenceHistoryStore class for persistent storage
- Store complete input data packages with each prediction
- Include timestamp, symbol, input features, prediction outputs, confidence scores
- Store model internal states for cross-model feeding
- Implement compressed storage to minimize footprint
- _Requirements: 9.5, 9.6_
- [x] 5.3. Implement inference history query and retrieval system
- Create efficient query mechanisms by symbol, timeframe, and date range
- Implement data retrieval for training pipeline consumption
- Add data completeness metrics and validation results in storage
- Handle storage failures gracefully without breaking prediction flow
- _Requirements: 9.7, 11.6_
## Inference-Training Feedback Loop Implementation
- [ ] 6. Implement prediction outcome evaluation system
- Create PredictionOutcomeEvaluator class
- Evaluate prediction accuracy against actual price movements
- Create training examples using stored inference data and actual outcomes
- Feed prediction-result pairs back to respective models
- _Requirements: 10.1, 10.2, 10.3_
- [ ] 6.1. Implement adaptive learning signal generation
- Create positive reinforcement signals for accurate predictions
- Generate corrective training signals for inaccurate predictions
- Retrieve last inference data for each model for outcome comparison
- Implement model-specific learning signal formats
- _Requirements: 10.4, 10.5, 10.6_
- [ ] 6.2. Implement continuous improvement tracking
- Track and report accuracy improvements/degradations over time
- Monitor model learning progress through feedback loop
- Create performance metrics for inference-training effectiveness
- Generate alerts for learning regression or stagnation
- _Requirements: 10.7_
## Inference History Management and Monitoring
- [ ] 7. Implement comprehensive inference logging and monitoring
- Create InferenceMonitor class for logging and alerting
- Log inference data storage operations with completeness metrics
- Log training outcomes and model performance changes
- Alert administrators on data flow issues with specific error details
- _Requirements: 11.1, 11.2, 11.3_
- [ ] 7.1. Implement configurable retention policies
- Create RetentionPolicyManager class
- Archive or remove oldest entries when limits are reached
- Prioritize keeping most recent and valuable training examples
- Implement storage space monitoring and alerts
- _Requirements: 11.4, 11.7_
- [ ] 7.2. Implement efficient historical data management
- Compress inference data to minimize storage footprint
- Maintain accessibility for training and analysis
- Implement efficient query mechanisms for historical analysis
- Add data archival and restoration capabilities
- _Requirements: 11.5, 11.6_
## Trading Executor Implementation
- [ ] 5. Design and implement the trading executor
- Create a TradingExecutor class that accepts trading actions from the orchestrator
- Implement order execution through brokerage APIs
- Add order lifecycle management
- _Requirements: 7.1, 7.2, 8.6_
- [ ] 5.1. Implement brokerage API integrations
- Create a BrokerageAPI interface
- Implement concrete classes for MEXC and Binance
- Add error handling and retry mechanisms
- _Requirements: 7.1, 7.2, 8.6_
- [ ] 5.2. Implement order management
- Create an OrderManager class
- Implement methods for creating, updating, and canceling orders
- Add order tracking and status updates
- _Requirements: 7.1, 7.2, 8.6_
- [ ] 5.3. Implement error handling
- Add comprehensive error handling for API failures
- Implement circuit breakers for extreme market conditions
- Add logging and notification mechanisms
- _Requirements: 7.1, 7.2, 8.6_
## Risk Manager Implementation
- [ ] 6. Design and implement the risk manager
- Create a RiskManager class
- Implement risk parameter management
- Add risk metric calculation
- _Requirements: 7.1, 7.3, 7.4_
- [ ] 6.1. Implement stop-loss functionality
- Create a StopLossManager class
- Implement methods for creating and managing stop-loss orders
- Add mechanisms to automatically close positions when stop-loss is triggered
- _Requirements: 7.1, 7.2_
- [ ] 6.2. Implement position sizing
- Create a PositionSizer class
- Implement methods for calculating position sizes based on risk parameters
- Add validation to ensure position sizes are within limits
- _Requirements: 7.3, 7.7_
- [ ] 6.3. Implement risk metrics
- Add methods to calculate risk metrics (drawdown, VaR, etc.)
- Implement real-time risk monitoring
- Add alerts for high-risk situations
- _Requirements: 7.4, 7.5, 7.6, 7.8_
## Dashboard Implementation
- [ ] 7. Design and implement the dashboard UI
- Create a Dashboard class
- Implement the web-based UI using Flask/Dash
- Add real-time updates using WebSockets
- _Requirements: 6.1, 6.8_
- [ ] 7.1. Implement chart management
- Create a ChartManager class
- Implement methods for creating and updating charts
- Add interactive features (zoom, pan, etc.)
- _Requirements: 6.1, 6.2_
- [ ] 7.2. Implement control panel
- Create a ControlPanel class
- Implement start/stop toggles for system processes
- Add sliders for adjusting buy/sell thresholds
- _Requirements: 6.6, 6.7_
- [ ] 7.3. Implement system status display
- Add methods to display training progress
- Implement model performance metrics visualization
- Add real-time system status updates
- _Requirements: 6.5, 5.6_
- [ ] 7.4. Implement server-side processing
- Ensure all processes run on the server without requiring the dashboard to be open
- Implement background tasks for model training and inference
- Add mechanisms to persist system state
- _Requirements: 6.8, 5.5_
## Integration and Testing
- [ ] 8. Integrate all components
- Connect the data provider to the CNN and RL models
- Connect the CNN and RL models to the orchestrator
- Connect the orchestrator to the trading executor
- _Requirements: 8.1, 8.2, 8.3_
- [ ] 8.1. Implement comprehensive unit tests
- Create unit tests for each component
- Implement test fixtures and mocks
- Add test coverage reporting
- _Requirements: 8.1, 8.2, 8.3_
- [ ] 8.2. Implement integration tests
- Create tests for component interactions
- Implement end-to-end tests
- Add performance benchmarks
- _Requirements: 8.1, 8.2, 8.3_
- [ ] 8.3. Implement backtesting framework
- Create a backtesting environment
- Implement methods to replay historical data
- Add performance metrics calculation
- _Requirements: 5.8, 8.1_
- [ ] 8.4. Optimize performance
- Profile the system to identify bottlenecks
- Implement optimizations for critical paths
- Add caching and parallelization where appropriate
- _Requirements: 8.1, 8.2, 8.3_

View File

@ -0,0 +1,350 @@
# Design Document
## Overview
The UI Stability Fix implements a comprehensive solution to resolve critical stability issues between the dashboard UI and training processes. The design focuses on complete process isolation, proper async/await handling, resource conflict resolution, and robust error handling. The solution ensures that the dashboard can operate independently without affecting training system stability.
## Architecture
### High-Level Architecture
```mermaid
graph TB
subgraph "Training Process"
TP[Training Process]
TM[Training Models]
TD[Training Data]
TL[Training Logs]
end
subgraph "Dashboard Process"
DP[Dashboard Process]
DU[Dashboard UI]
DC[Dashboard Cache]
DL[Dashboard Logs]
end
subgraph "Shared Resources"
SF[Shared Files]
SC[Shared Config]
SM[Shared Models]
SD[Shared Data]
end
TP --> SF
DP --> SF
TP --> SC
DP --> SC
TP --> SM
DP --> SM
TP --> SD
DP --> SD
TP -.->|No Direct Connection| DP
```
### Process Isolation Design
The system will implement complete process isolation using:
1. **Separate Python Processes**: Dashboard and training run as independent processes
2. **Inter-Process Communication**: File-based communication for status and data sharing
3. **Resource Partitioning**: Separate resource allocation for each process
4. **Independent Lifecycle Management**: Each process can start, stop, and restart independently
### Async/Await Error Resolution
The design addresses async issues through:
1. **Proper Event Loop Management**: Single event loop per process with proper lifecycle
2. **Async Context Isolation**: Separate async contexts for different components
3. **Coroutine Handling**: Proper awaiting of all async operations
4. **Exception Propagation**: Proper async exception handling and propagation
## Components and Interfaces
### 1. Process Manager
**Purpose**: Manages the lifecycle of both dashboard and training processes
**Interface**:
```python
class ProcessManager:
def start_training_process(self) -> bool
def start_dashboard_process(self, port: int = 8050) -> bool
def stop_training_process(self) -> bool
def stop_dashboard_process(self) -> bool
def get_process_status(self) -> Dict[str, str]
def restart_process(self, process_name: str) -> bool
```
**Implementation Details**:
- Uses subprocess.Popen for process creation
- Monitors process health with periodic checks
- Handles process output logging and error capture
- Implements graceful shutdown with timeout handling
### 2. Isolated Dashboard
**Purpose**: Provides a completely isolated dashboard that doesn't interfere with training
**Interface**:
```python
class IsolatedDashboard:
def __init__(self, config: Dict[str, Any])
def start_server(self, host: str, port: int) -> None
def stop_server(self) -> None
def update_data_from_files(self) -> None
def get_training_status(self) -> Dict[str, Any]
```
**Implementation Details**:
- Runs in separate process with own event loop
- Reads data from shared files instead of direct memory access
- Uses file-based communication for training status
- Implements proper async/await patterns for all operations
### 3. Isolated Training Process
**Purpose**: Runs training completely isolated from UI components
**Interface**:
```python
class IsolatedTrainingProcess:
def __init__(self, config: Dict[str, Any])
def start_training(self) -> None
def stop_training(self) -> None
def get_training_metrics(self) -> Dict[str, Any]
def save_status_to_file(self) -> None
```
**Implementation Details**:
- No UI dependencies or imports
- Writes status and metrics to shared files
- Implements proper resource cleanup
- Uses separate logging configuration
### 4. Shared Data Manager
**Purpose**: Manages data sharing between processes through files
**Interface**:
```python
class SharedDataManager:
def write_training_status(self, status: Dict[str, Any]) -> None
def read_training_status(self) -> Dict[str, Any]
def write_market_data(self, data: Dict[str, Any]) -> None
def read_market_data(self) -> Dict[str, Any]
def write_model_metrics(self, metrics: Dict[str, Any]) -> None
def read_model_metrics(self) -> Dict[str, Any]
```
**Implementation Details**:
- Uses JSON files for structured data
- Implements file locking to prevent corruption
- Provides atomic write operations
- Includes data validation and error handling
### 5. Resource Manager
**Purpose**: Manages resource allocation and prevents conflicts
**Interface**:
```python
class ResourceManager:
def allocate_gpu_resources(self, process_name: str) -> bool
def release_gpu_resources(self, process_name: str) -> None
def check_memory_usage(self) -> Dict[str, float]
def enforce_resource_limits(self) -> None
```
**Implementation Details**:
- Monitors GPU memory usage per process
- Implements resource quotas and limits
- Provides resource conflict detection
- Includes automatic resource cleanup
### 6. Async Handler
**Purpose**: Properly handles all async operations in the dashboard
**Interface**:
```python
class AsyncHandler:
def __init__(self, loop: asyncio.AbstractEventLoop)
async def handle_orchestrator_connection(self) -> None
async def handle_cob_integration(self) -> None
async def handle_trading_decisions(self, decision: Dict) -> None
def run_async_safely(self, coro: Coroutine) -> Any
```
**Implementation Details**:
- Manages single event loop per process
- Provides proper exception handling for async operations
- Implements timeout handling for long-running operations
- Includes async context management
## Data Models
### Process Status Model
```python
@dataclass
class ProcessStatus:
name: str
pid: int
status: str # 'running', 'stopped', 'error'
start_time: datetime
last_heartbeat: datetime
memory_usage: float
cpu_usage: float
error_message: Optional[str] = None
```
### Training Status Model
```python
@dataclass
class TrainingStatus:
is_running: bool
current_epoch: int
total_epochs: int
loss: float
accuracy: float
last_update: datetime
model_path: str
error_message: Optional[str] = None
```
### Dashboard State Model
```python
@dataclass
class DashboardState:
is_connected: bool
last_data_update: datetime
active_connections: int
error_count: int
performance_metrics: Dict[str, float]
```
## Error Handling
### Exception Hierarchy
```python
class UIStabilityError(Exception):
"""Base exception for UI stability issues"""
pass
class ProcessCommunicationError(UIStabilityError):
"""Error in inter-process communication"""
pass
class AsyncOperationError(UIStabilityError):
"""Error in async operation handling"""
pass
class ResourceConflictError(UIStabilityError):
"""Error due to resource conflicts"""
pass
```
### Error Recovery Strategies
1. **Automatic Retry**: For transient network and file I/O errors
2. **Graceful Degradation**: Fallback to basic functionality when components fail
3. **Process Restart**: Automatic restart of failed processes
4. **Circuit Breaker**: Temporary disable of failing components
5. **Rollback**: Revert to last known good state
### Error Monitoring
- Centralized error logging with structured format
- Real-time error rate monitoring
- Automatic alerting for critical errors
- Error trend analysis and reporting
## Testing Strategy
### Unit Tests
- Test each component in isolation
- Mock external dependencies
- Verify error handling paths
- Test async operation handling
### Integration Tests
- Test inter-process communication
- Verify resource sharing mechanisms
- Test process lifecycle management
- Validate error recovery scenarios
### System Tests
- End-to-end stability testing
- Load testing with concurrent processes
- Failure injection testing
- Performance regression testing
### Monitoring Tests
- Health check endpoint testing
- Metrics collection validation
- Alert system testing
- Dashboard functionality testing
## Performance Considerations
### Resource Optimization
- Minimize memory footprint of each process
- Optimize file I/O operations for data sharing
- Implement efficient data serialization
- Use connection pooling for external services
### Scalability
- Support multiple dashboard instances
- Handle increased data volume gracefully
- Implement efficient caching strategies
- Optimize for high-frequency updates
### Monitoring
- Real-time performance metrics collection
- Resource usage tracking per process
- Response time monitoring
- Throughput measurement
## Security Considerations
### Process Isolation
- Separate user contexts for processes
- Limited file system access permissions
- Network access restrictions
- Resource usage limits
### Data Protection
- Secure file sharing mechanisms
- Data validation and sanitization
- Access control for shared resources
- Audit logging for sensitive operations
### Communication Security
- Encrypted inter-process communication
- Authentication for API endpoints
- Input validation for all interfaces
- Rate limiting for external requests
## Deployment Strategy
### Development Environment
- Local process management scripts
- Development-specific configuration
- Enhanced logging and debugging
- Hot-reload capabilities
### Production Environment
- Systemd service management
- Production configuration templates
- Log rotation and archiving
- Monitoring and alerting setup
### Migration Plan
1. Deploy new process management components
2. Update configuration files
3. Test process isolation functionality
4. Gradually migrate existing deployments
5. Monitor stability improvements
6. Remove legacy components

View File

@ -0,0 +1,111 @@
# Requirements Document
## Introduction
The UI Stability Fix addresses critical issues where loading the dashboard UI crashes the training process and causes unhandled exceptions. The system currently suffers from async/await handling problems, threading conflicts, resource contention, and improper separation of concerns between the UI and training processes. This fix will ensure the dashboard can run independently without affecting the training system's stability.
## Requirements
### Requirement 1: Async/Await Error Resolution
**User Story:** As a developer, I want the dashboard to properly handle async operations, so that unhandled exceptions don't crash the entire system.
#### Acceptance Criteria
1. WHEN the dashboard initializes THEN it SHALL properly handle all async operations without throwing "An asyncio.Future, a coroutine or an awaitable is required" errors.
2. WHEN connecting to the orchestrator THEN the system SHALL use proper async/await patterns for all coroutine calls.
3. WHEN starting COB integration THEN the system SHALL properly manage event loops without conflicts.
4. WHEN handling trading decisions THEN async callbacks SHALL be properly awaited and handled.
5. WHEN the dashboard starts THEN it SHALL not create multiple conflicting event loops.
6. WHEN async operations fail THEN the system SHALL handle exceptions gracefully without crashing.
### Requirement 2: Process Isolation
**User Story:** As a user, I want the dashboard and training processes to run independently, so that UI issues don't affect training stability.
#### Acceptance Criteria
1. WHEN the dashboard starts THEN it SHALL run in a completely separate process from the training system.
2. WHEN the dashboard crashes THEN the training process SHALL continue running unaffected.
3. WHEN the training process encounters issues THEN the dashboard SHALL remain functional.
4. WHEN both processes are running THEN they SHALL communicate only through well-defined interfaces (files, APIs, or message queues).
5. WHEN either process restarts THEN the other process SHALL continue operating normally.
6. WHEN resources are accessed THEN there SHALL be no direct shared memory or threading conflicts between processes.
### Requirement 3: Resource Contention Resolution
**User Story:** As a system administrator, I want to eliminate resource conflicts between UI and training, so that both can operate efficiently without interference.
#### Acceptance Criteria
1. WHEN both dashboard and training are running THEN they SHALL not compete for the same GPU resources.
2. WHEN accessing data files THEN proper file locking SHALL prevent corruption or access conflicts.
3. WHEN using network resources THEN rate limiting SHALL prevent API conflicts between processes.
4. WHEN accessing model files THEN proper synchronization SHALL prevent read/write conflicts.
5. WHEN logging THEN separate log files SHALL be used to prevent write conflicts.
6. WHEN using temporary files THEN separate directories SHALL be used for each process.
### Requirement 4: Threading Safety
**User Story:** As a developer, I want all threading operations to be safe and properly managed, so that race conditions and deadlocks don't occur.
#### Acceptance Criteria
1. WHEN the dashboard uses threads THEN all shared data SHALL be properly synchronized.
2. WHEN background updates run THEN they SHALL not interfere with main UI thread operations.
3. WHEN stopping threads THEN proper cleanup SHALL occur without hanging or deadlocks.
4. WHEN accessing shared resources THEN proper locking mechanisms SHALL be used.
5. WHEN threads encounter exceptions THEN they SHALL be handled without crashing the main process.
6. WHEN the dashboard shuts down THEN all threads SHALL be properly terminated.
### Requirement 5: Error Handling and Recovery
**User Story:** As a user, I want the system to handle errors gracefully and recover automatically, so that temporary issues don't cause permanent failures.
#### Acceptance Criteria
1. WHEN unhandled exceptions occur THEN they SHALL be caught and logged without crashing the process.
2. WHEN network connections fail THEN the system SHALL retry with exponential backoff.
3. WHEN data sources are unavailable THEN fallback mechanisms SHALL provide basic functionality.
4. WHEN memory issues occur THEN the system SHALL free resources and continue operating.
5. WHEN critical errors happen THEN the system SHALL attempt automatic recovery.
6. WHEN recovery fails THEN the system SHALL provide clear error messages and graceful degradation.
### Requirement 6: Monitoring and Diagnostics
**User Story:** As a developer, I want comprehensive monitoring and diagnostics, so that I can quickly identify and resolve stability issues.
#### Acceptance Criteria
1. WHEN the system runs THEN it SHALL provide real-time health monitoring for all components.
2. WHEN errors occur THEN detailed diagnostic information SHALL be logged with timestamps and context.
3. WHEN performance issues arise THEN resource usage metrics SHALL be available.
4. WHEN processes communicate THEN message flow SHALL be traceable for debugging.
5. WHEN the system starts THEN startup diagnostics SHALL verify all components are working correctly.
6. WHEN stability issues occur THEN automated alerts SHALL notify administrators.
### Requirement 7: Configuration and Control
**User Story:** As a system administrator, I want flexible configuration options, so that I can optimize system behavior for different environments.
#### Acceptance Criteria
1. WHEN configuring the system THEN separate configuration files SHALL be used for dashboard and training processes.
2. WHEN adjusting resource limits THEN configuration SHALL allow tuning memory, CPU, and GPU usage.
3. WHEN setting update intervals THEN dashboard refresh rates SHALL be configurable.
4. WHEN enabling features THEN individual components SHALL be independently controllable.
5. WHEN debugging THEN log levels SHALL be adjustable without restarting processes.
6. WHEN deploying THEN environment-specific configurations SHALL be supported.
### Requirement 8: Backward Compatibility
**User Story:** As a user, I want the stability fixes to maintain existing functionality, so that current workflows continue to work.
#### Acceptance Criteria
1. WHEN the fixes are applied THEN all existing dashboard features SHALL continue to work.
2. WHEN training processes run THEN they SHALL maintain the same interfaces and outputs.
3. WHEN data is accessed THEN existing data formats SHALL remain compatible.
4. WHEN APIs are used THEN existing endpoints SHALL continue to function.
5. WHEN configurations are loaded THEN existing config files SHALL remain valid.
6. WHEN the system upgrades THEN migration paths SHALL preserve user settings and data.

View File

@ -0,0 +1,79 @@
# Implementation Plan
- [x] 1. Create Shared Data Manager for inter-process communication
- Implement JSON-based file sharing with atomic writes and file locking
- Create data models for training status, dashboard state, and process status
- Add validation and error handling for all data operations
- _Requirements: 2.4, 3.4, 5.2_
- [ ] 2. Implement Async Handler for proper async/await management
- Create centralized async operation handler with single event loop management
- Fix all async/await patterns in dashboard code
- Add proper exception handling for async operations with timeout support
- _Requirements: 1.1, 1.2, 1.3, 1.6_
- [ ] 3. Create Isolated Training Process
- Extract training logic into standalone process without UI dependencies
- Implement file-based status reporting and metrics sharing
- Add proper resource cleanup and error handling
- _Requirements: 2.1, 2.2, 3.1, 4.5_
- [ ] 4. Create Isolated Dashboard Process
- Refactor dashboard to run independently with file-based data access
- Remove direct memory sharing and threading conflicts with training
- Implement proper process lifecycle management
- _Requirements: 2.1, 2.3, 4.1, 4.2_
- [ ] 5. Implement Process Manager
- Create process lifecycle management with subprocess handling
- Add process monitoring, health checks, and automatic restart capabilities
- Implement graceful shutdown with proper cleanup
- _Requirements: 2.5, 5.5, 6.1, 6.6_
- [ ] 6. Create Resource Manager
- Implement GPU resource allocation and conflict prevention
- Add memory usage monitoring and resource limits enforcement
- Create separate logging and temporary file management
- _Requirements: 3.1, 3.2, 3.5, 3.6_
- [ ] 7. Fix Threading Safety Issues
- Audit and fix all shared data access with proper synchronization
- Implement proper thread cleanup and exception handling
- Remove race conditions and deadlock potential
- _Requirements: 4.1, 4.2, 4.3, 4.6_
- [ ] 8. Implement Error Handling and Recovery
- Add comprehensive exception handling with proper logging
- Create automatic retry mechanisms with exponential backoff
- Implement fallback mechanisms and graceful degradation
- _Requirements: 5.1, 5.2, 5.3, 5.6_
- [ ] 9. Create System Launcher and Configuration
- Build unified launcher script for both processes
- Create separate configuration files for dashboard and training
- Add environment-specific configuration support
- _Requirements: 7.1, 7.2, 7.4, 7.6_
- [ ] 10. Add Monitoring and Diagnostics
- Implement real-time health monitoring for all components
- Create detailed diagnostic logging with structured format
- Add performance metrics collection and resource usage tracking
- _Requirements: 6.1, 6.2, 6.3, 6.5_
- [ ] 11. Create Integration Tests
- Write tests for inter-process communication and data sharing
- Test process lifecycle management and error recovery
- Validate resource conflict resolution and stability improvements
- _Requirements: 5.4, 5.5, 6.4, 8.1_
- [ ] 12. Update Documentation and Migration Guide
- Document new architecture and deployment procedures
- Create migration guide from existing system
- Add troubleshooting guide for common stability issues
- _Requirements: 8.2, 8.5, 8.6_

View File

@ -0,0 +1,293 @@
# WebSocket COB Data Fix Design Document
## Overview
This design document outlines the approach to fix the WebSocket COB (Change of Basis) data processing issue in the trading system. The current implementation is failing with `'NoneType' object has no attribute 'append'` errors for both BTC/USDT and ETH/USDT pairs, which indicates that a data structure expected to be a list is actually None. This issue is preventing the dashboard from functioning properly and needs to be addressed to ensure reliable real-time market data processing.
## Architecture
The COB data processing pipeline involves several components:
1. **MultiExchangeCOBProvider**: Collects order book data from exchanges via WebSockets
2. **StandardizedDataProvider**: Extends DataProvider with standardized BaseDataInput functionality
3. **Dashboard Components**: Display COB data in the UI
The error occurs during WebSocket data processing, specifically when trying to append data to a collection that hasn't been properly initialized. The fix will focus on ensuring proper initialization of data structures and implementing robust error handling.
## Components and Interfaces
### 1. MultiExchangeCOBProvider
The `MultiExchangeCOBProvider` class is responsible for collecting order book data from exchanges and distributing it to subscribers. The issue appears to be in the WebSocket data processing logic, where data structures may not be properly initialized before use.
#### Key Issues to Address
1. **Data Structure Initialization**: Ensure all data structures (particularly collections that will have `append` called on them) are properly initialized during object creation.
2. **Subscriber Notification**: Fix the `_notify_cob_subscribers` method to handle edge cases and ensure data is properly formatted before notification.
3. **WebSocket Processing**: Enhance error handling in WebSocket processing methods to prevent cascading failures.
#### Implementation Details
```python
class MultiExchangeCOBProvider:
def __init__(self, symbols: List[str], exchange_configs: Dict[str, ExchangeConfig]):
# Existing initialization code...
# Ensure all data structures are properly initialized
self.cob_data_cache = {} # Cache for COB data
self.cob_subscribers = [] # List of callback functions
self.exchange_order_books = {}
self.session_trades = {}
self.svp_cache = {}
# Initialize data structures for each symbol
for symbol in symbols:
self.cob_data_cache[symbol] = {}
self.exchange_order_books[symbol] = {}
self.session_trades[symbol] = []
self.svp_cache[symbol] = {}
# Initialize exchange-specific data structures
for exchange_name in self.active_exchanges:
self.exchange_order_books[symbol][exchange_name] = {
'bids': {},
'asks': {},
'deep_bids': {},
'deep_asks': {},
'timestamp': datetime.now(),
'deep_timestamp': datetime.now(),
'connected': False,
'last_update_id': 0
}
logger.info(f"Multi-exchange COB provider initialized for symbols: {symbols}")
async def _notify_cob_subscribers(self, symbol: str, cob_snapshot: Dict):
"""Notify all subscribers of COB data updates with improved error handling"""
try:
if not cob_snapshot:
logger.warning(f"Attempted to notify subscribers with empty COB snapshot for {symbol}")
return
for callback in self.cob_subscribers:
try:
if asyncio.iscoroutinefunction(callback):
await callback(symbol, cob_snapshot)
else:
callback(symbol, cob_snapshot)
except Exception as e:
logger.error(f"Error in COB subscriber callback: {e}", exc_info=True)
except Exception as e:
logger.error(f"Error notifying COB subscribers: {e}", exc_info=True)
```
### 2. StandardizedDataProvider
The `StandardizedDataProvider` class extends the base `DataProvider` with standardized data input functionality. It needs to properly handle COB data and ensure all data structures are initialized.
#### Key Issues to Address
1. **COB Data Handling**: Ensure proper initialization and validation of COB data structures.
2. **Error Handling**: Improve error handling when processing COB data.
3. **Data Structure Consistency**: Maintain consistent data structures throughout the processing pipeline.
#### Implementation Details
```python
class StandardizedDataProvider(DataProvider):
def __init__(self, symbols: List[str] = None, timeframes: List[str] = None):
"""Initialize the standardized data provider with proper data structure initialization"""
super().__init__(symbols, timeframes)
# Standardized data storage
self.base_data_cache = {} # {symbol: BaseDataInput}
self.cob_data_cache = {} # {symbol: COBData}
# Model output management with extensible storage
self.model_output_manager = ModelOutputManager(
cache_dir=str(self.cache_dir / "model_outputs"),
max_history=1000
)
# COB moving averages calculation
self.cob_imbalance_history = {} # {symbol: deque of (timestamp, imbalance_data)}
self.ma_calculation_lock = Lock()
# Initialize caches for each symbol
for symbol in self.symbols:
self.base_data_cache[symbol] = None
self.cob_data_cache[symbol] = None
self.cob_imbalance_history[symbol] = deque(maxlen=300) # 5 minutes of 1s data
# COB provider integration
self.cob_provider = None
self._initialize_cob_provider()
logger.info("StandardizedDataProvider initialized with BaseDataInput support")
def _process_cob_data(self, symbol: str, cob_snapshot: Dict):
"""Process COB data with improved error handling"""
try:
if not cob_snapshot:
logger.warning(f"Received empty COB snapshot for {symbol}")
return
# Process COB data and update caches
# ...
except Exception as e:
logger.error(f"Error processing COB data for {symbol}: {e}", exc_info=True)
```
### 3. WebSocket COB Data Processing
The WebSocket COB data processing logic needs to be enhanced to handle edge cases and ensure proper data structure initialization.
#### Key Issues to Address
1. **WebSocket Connection Management**: Improve connection management to handle disconnections gracefully.
2. **Data Processing**: Ensure data is properly validated before processing.
3. **Error Recovery**: Implement recovery mechanisms for WebSocket failures.
#### Implementation Details
```python
async def _stream_binance_orderbook(self, symbol: str, config: ExchangeConfig):
"""Stream order book data from Binance with improved error handling"""
reconnect_delay = 1 # Start with 1 second delay
max_reconnect_delay = 60 # Maximum delay of 60 seconds
while self.is_streaming:
try:
ws_url = f"{config.websocket_url}{config.symbols_mapping[symbol].lower()}@depth20@100ms"
logger.info(f"Connecting to Binance WebSocket: {ws_url}")
if websockets is None or websockets_connect is None:
raise ImportError("websockets module not available")
async with websockets_connect(ws_url) as websocket:
# Ensure data structures are initialized
if symbol not in self.exchange_order_books:
self.exchange_order_books[symbol] = {}
if 'binance' not in self.exchange_order_books[symbol]:
self.exchange_order_books[symbol]['binance'] = {
'bids': {},
'asks': {},
'deep_bids': {},
'deep_asks': {},
'timestamp': datetime.now(),
'deep_timestamp': datetime.now(),
'connected': False,
'last_update_id': 0
}
self.exchange_order_books[symbol]['binance']['connected'] = True
logger.info(f"Connected to Binance order book stream for {symbol}")
# Reset reconnect delay on successful connection
reconnect_delay = 1
async for message in websocket:
if not self.is_streaming:
break
try:
data = json.loads(message)
await self._process_binance_orderbook(symbol, data)
except json.JSONDecodeError as e:
logger.error(f"Error parsing Binance message: {e}")
except Exception as e:
logger.error(f"Error processing Binance data: {e}", exc_info=True)
except Exception as e:
logger.error(f"Binance WebSocket error for {symbol}: {e}", exc_info=True)
# Mark as disconnected
if symbol in self.exchange_order_books and 'binance' in self.exchange_order_books[symbol]:
self.exchange_order_books[symbol]['binance']['connected'] = False
# Implement exponential backoff for reconnection
logger.info(f"Reconnecting to Binance WebSocket for {symbol} in {reconnect_delay}s")
await asyncio.sleep(reconnect_delay)
reconnect_delay = min(reconnect_delay * 2, max_reconnect_delay)
```
## Data Models
The data models remain unchanged, but we need to ensure they are properly initialized and validated throughout the system.
### COBSnapshot
```python
@dataclass
class COBSnapshot:
"""Complete Consolidated Order Book snapshot"""
symbol: str
timestamp: datetime
consolidated_bids: List[ConsolidatedOrderBookLevel]
consolidated_asks: List[ConsolidatedOrderBookLevel]
exchanges_active: List[str]
volume_weighted_mid: float
total_bid_liquidity: float
total_ask_liquidity: float
spread_bps: float
liquidity_imbalance: float
price_buckets: Dict[str, Dict[str, float]] # Fine-grain volume buckets
```
## Error Handling
### WebSocket Connection Errors
- Implement exponential backoff for reconnection attempts
- Log detailed error information
- Maintain system operation with last valid data
### Data Processing Errors
- Validate data before processing
- Handle edge cases gracefully
- Log detailed error information
- Continue operation with last valid data
### Subscriber Notification Errors
- Catch and log errors in subscriber callbacks
- Prevent errors in one subscriber from affecting others
- Ensure data is properly formatted before notification
## Testing Strategy
### Unit Testing
- Test data structure initialization
- Test error handling in WebSocket processing
- Test subscriber notification with various edge cases
### Integration Testing
- Test end-to-end COB data flow
- Test recovery from WebSocket disconnections
- Test handling of malformed data
### System Testing
- Test dashboard operation with COB data
- Test system stability under high load
- Test recovery from various failure scenarios
## Implementation Plan
1. Fix data structure initialization in `MultiExchangeCOBProvider`
2. Enhance error handling in WebSocket processing
3. Improve subscriber notification logic
4. Update `StandardizedDataProvider` to properly handle COB data
5. Add comprehensive logging for debugging
6. Implement recovery mechanisms for WebSocket failures
7. Test all changes thoroughly
## Conclusion
This design addresses the WebSocket COB data processing issue by ensuring proper initialization of data structures, implementing robust error handling, and adding recovery mechanisms for WebSocket failures. These changes will improve the reliability and stability of the trading system, allowing traders to monitor market data in real-time without interruptions.

View File

@ -0,0 +1,43 @@
# Requirements Document
## Introduction
The WebSocket COB Data Fix is needed to address a critical issue in the trading system where WebSocket COB (Change of Basis) data processing is failing with the error `'NoneType' object has no attribute 'append'`. This error is occurring for both BTC/USDT and ETH/USDT pairs and is preventing the dashboard from functioning properly. The fix will ensure proper initialization and handling of data structures in the COB data processing pipeline.
## Requirements
### Requirement 1: Fix WebSocket COB Data Processing
**User Story:** As a trader, I want the WebSocket COB data processing to work reliably without errors, so that I can monitor market data in real-time and make informed trading decisions.
#### Acceptance Criteria
1. WHEN WebSocket COB data is received for any trading pair THEN the system SHALL process it without throwing 'NoneType' object has no attribute 'append' errors
2. WHEN the dashboard is started THEN all data structures for COB processing SHALL be properly initialized
3. WHEN COB data is processed THEN the system SHALL handle edge cases such as missing or incomplete data gracefully
4. WHEN a WebSocket connection is established THEN the system SHALL verify that all required data structures are initialized before processing data
5. WHEN COB data is being processed THEN the system SHALL log appropriate debug information to help diagnose any issues
### Requirement 2: Ensure Data Structure Consistency
**User Story:** As a system administrator, I want consistent data structures throughout the COB processing pipeline, so that data can flow smoothly between components without errors.
#### Acceptance Criteria
1. WHEN the multi_exchange_cob_provider initializes THEN it SHALL properly initialize all required data structures
2. WHEN the standardized_data_provider receives COB data THEN it SHALL validate the data structure before processing
3. WHEN COB data is passed between components THEN the system SHALL ensure type consistency
4. WHEN new COB data arrives THEN the system SHALL update the data structures atomically to prevent race conditions
5. WHEN a component subscribes to COB updates THEN the system SHALL verify the subscriber can handle the data format
### Requirement 3: Improve Error Handling and Recovery
**User Story:** As a system operator, I want robust error handling and recovery mechanisms in the COB data processing pipeline, so that temporary failures don't cause the entire system to crash.
#### Acceptance Criteria
1. WHEN an error occurs in COB data processing THEN the system SHALL log detailed error information
2. WHEN a WebSocket connection fails THEN the system SHALL attempt to reconnect automatically
3. WHEN data processing fails THEN the system SHALL continue operation with the last valid data
4. WHEN the system recovers from an error THEN it SHALL restore normal operation without manual intervention
5. WHEN multiple consecutive errors occur THEN the system SHALL implement exponential backoff to prevent overwhelming the system

View File

@ -0,0 +1,115 @@
# Implementation Plan
- [ ] 1. Fix data structure initialization in MultiExchangeCOBProvider
- Ensure all collections are properly initialized during object creation
- Add defensive checks before accessing data structures
- Implement proper initialization for symbol-specific data structures
- _Requirements: 1.1, 1.2, 2.1_
- [ ] 1.1. Update MultiExchangeCOBProvider constructor
- Modify __init__ method to properly initialize all data structures
- Ensure exchange_order_books is initialized for each symbol and exchange
- Initialize session_trades and svp_cache for each symbol
- Add defensive checks to prevent NoneType errors
- _Requirements: 1.2, 2.1_
- [ ] 1.2. Fix _notify_cob_subscribers method
- Add validation to ensure cob_snapshot is not None before processing
- Add defensive checks before accessing cob_snapshot attributes
- Improve error handling for subscriber callbacks
- Add detailed logging for debugging
- _Requirements: 1.1, 1.5, 2.3_
- [ ] 2. Enhance WebSocket data processing in MultiExchangeCOBProvider
- Improve error handling in WebSocket connection methods
- Add validation for incoming data
- Implement reconnection logic with exponential backoff
- _Requirements: 1.3, 1.4, 3.1, 3.2_
- [ ] 2.1. Update _stream_binance_orderbook method
- Add data structure initialization checks
- Implement exponential backoff for reconnection attempts
- Add detailed error logging
- Ensure proper cleanup on disconnection
- _Requirements: 1.4, 3.2, 3.4_
- [ ] 2.2. Fix _process_binance_orderbook method
- Add validation for incoming data
- Ensure data structures exist before updating
- Add defensive checks to prevent NoneType errors
- Improve error handling and logging
- _Requirements: 1.1, 1.3, 3.1_
- [ ] 3. Update StandardizedDataProvider to handle COB data properly
- Improve initialization of COB-related data structures
- Add validation for COB data
- Enhance error handling for COB data processing
- _Requirements: 1.3, 2.2, 2.3_
- [ ] 3.1. Fix _get_cob_data method
- Add validation for COB provider availability
- Ensure proper initialization of COB data structures
- Add defensive checks to prevent NoneType errors
- Improve error handling and logging
- _Requirements: 1.3, 2.2, 3.3_
- [ ] 3.2. Update _calculate_cob_moving_averages method
- Add validation for input data
- Ensure proper initialization of moving average data structures
- Add defensive checks to prevent NoneType errors
- Improve error handling for edge cases
- _Requirements: 1.3, 2.2, 3.3_
- [ ] 4. Implement recovery mechanisms for WebSocket failures
- Add state tracking for WebSocket connections
- Implement automatic reconnection with exponential backoff
- Add fallback mechanisms for temporary failures
- _Requirements: 3.2, 3.3, 3.4_
- [ ] 4.1. Add connection state management
- Track connection state for each WebSocket
- Implement health check mechanism
- Add reconnection logic based on connection state
- _Requirements: 3.2, 3.4_
- [ ] 4.2. Implement data recovery mechanisms
- Add caching for last valid data
- Implement fallback to cached data during connection issues
- Add mechanism to rebuild state after reconnection
- _Requirements: 3.3, 3.4_
- [ ] 5. Add comprehensive logging for debugging
- Add detailed logging throughout the COB processing pipeline
- Include context information in log messages
- Add performance metrics logging
- _Requirements: 1.5, 3.1_
- [ ] 5.1. Enhance logging in MultiExchangeCOBProvider
- Add detailed logging for WebSocket connections
- Log data processing steps and outcomes
- Add performance metrics for data processing
- _Requirements: 1.5, 3.1_
- [ ] 5.2. Add logging in StandardizedDataProvider
- Log COB data processing steps
- Add validation logging
- Include performance metrics for data processing
- _Requirements: 1.5, 3.1_
- [ ] 6. Test all changes thoroughly
- Write unit tests for fixed components
- Test integration between components
- Verify dashboard operation with COB data
- _Requirements: 1.1, 2.3, 3.4_
- [ ] 6.1. Write unit tests for MultiExchangeCOBProvider
- Test data structure initialization
- Test WebSocket processing with mock data
- Test error handling and recovery
- _Requirements: 1.1, 1.3, 3.1_
- [ ] 6.2. Test integration with dashboard
- Verify COB data display in dashboard
- Test system stability under load
- Verify recovery from failures
- _Requirements: 1.1, 3.3, 3.4_

2
.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,2 @@
{
}

8
.vscode/tasks.json vendored
View File

@ -6,8 +6,10 @@
"type": "shell",
"command": "powershell",
"args": [
"-Command",
"Get-Process python | Where-Object {$_.ProcessName -eq 'python' -and $_.MainWindowTitle -like '*dashboard*'} | Stop-Process -Force; Start-Sleep -Seconds 1"
"-ExecutionPolicy",
"Bypass",
"-File",
"scripts/kill_stale_processes.ps1"
],
"group": "build",
"presentation": {
@ -108,4 +110,4 @@
"problemMatcher": []
}
]
}
}

231
COBY/README.md Normal file
View File

@ -0,0 +1,231 @@
# COBY - Multi-Exchange Data Aggregation System
COBY (Cryptocurrency Order Book Yielder) is a comprehensive data collection and aggregation subsystem designed to serve as the foundational data layer for trading systems. It collects real-time order book and OHLCV data from multiple cryptocurrency exchanges, aggregates it into standardized formats, and provides both live data feeds and historical replay capabilities.
## 🏗️ Architecture
The system follows a modular architecture with clear separation of concerns:
```
COBY/
├── config.py # Configuration management
├── models/ # Data models and structures
│ ├── __init__.py
│ └── core.py # Core data models
├── interfaces/ # Abstract interfaces
│ ├── __init__.py
│ ├── exchange_connector.py
│ ├── data_processor.py
│ ├── aggregation_engine.py
│ ├── storage_manager.py
│ └── replay_manager.py
├── utils/ # Utility functions
│ ├── __init__.py
│ ├── exceptions.py
│ ├── logging.py
│ ├── validation.py
│ └── timing.py
└── README.md
```
## 🚀 Features
- **Multi-Exchange Support**: Connect to 10+ major cryptocurrency exchanges
- **Real-Time Data**: High-frequency order book and trade data collection
- **Price Bucket Aggregation**: Configurable price buckets ($10 for BTC, $1 for ETH)
- **Heatmap Visualization**: Real-time market depth heatmaps
- **Historical Replay**: Replay past market events for model training
- **TimescaleDB Storage**: Optimized time-series data storage
- **Redis Caching**: High-performance data caching layer
- **Orchestrator Integration**: Compatible with existing trading systems
## 📊 Data Models
### Core Models
- **OrderBookSnapshot**: Standardized order book data
- **TradeEvent**: Individual trade events
- **PriceBuckets**: Aggregated price bucket data
- **HeatmapData**: Visualization-ready heatmap data
- **ConnectionStatus**: Exchange connection monitoring
- **ReplaySession**: Historical data replay management
### Key Features
- Automatic data validation and normalization
- Configurable price bucket sizes per symbol
- Real-time metrics calculation
- Cross-exchange data consolidation
- Quality scoring and anomaly detection
## ⚙️ Configuration
The system uses environment variables for configuration:
```python
# Database settings
DB_HOST=192.168.0.10
DB_PORT=5432
DB_NAME=market_data
DB_USER=market_user
DB_PASSWORD=your_password
# Redis settings
REDIS_HOST=192.168.0.10
REDIS_PORT=6379
REDIS_PASSWORD=your_password
# Aggregation settings
BTC_BUCKET_SIZE=10.0
ETH_BUCKET_SIZE=1.0
HEATMAP_DEPTH=50
UPDATE_FREQUENCY=0.5
# Performance settings
DATA_BUFFER_SIZE=10000
BATCH_WRITE_SIZE=1000
MAX_MEMORY_USAGE=2048
```
## 🔌 Interfaces
### ExchangeConnector
Abstract base class for exchange WebSocket connectors with:
- Connection management with auto-reconnect
- Order book and trade subscriptions
- Data normalization callbacks
- Health monitoring
### DataProcessor
Interface for data processing and validation:
- Raw data normalization
- Quality validation
- Metrics calculation
- Anomaly detection
### AggregationEngine
Interface for data aggregation:
- Price bucket creation
- Heatmap generation
- Cross-exchange consolidation
- Imbalance calculations
### StorageManager
Interface for data persistence:
- TimescaleDB operations
- Batch processing
- Historical data retrieval
- Storage optimization
### ReplayManager
Interface for historical data replay:
- Session management
- Configurable playback speeds
- Time-based seeking
- Real-time compatibility
## 🛠️ Utilities
### Logging
- Structured logging with correlation IDs
- Configurable log levels and outputs
- Rotating file handlers
- Context-aware logging
### Validation
- Symbol format validation
- Price and volume validation
- Configuration validation
- Data quality checks
### Timing
- UTC timestamp handling
- Performance measurement
- Time-based operations
- Interval calculations
### Exceptions
- Custom exception hierarchy
- Error code management
- Detailed error context
- Structured error responses
## 🔧 Usage
### Basic Configuration
```python
from COBY.config import config
# Access configuration
db_url = config.get_database_url()
bucket_size = config.get_bucket_size('BTCUSDT')
```
### Data Models
```python
from COBY.models import OrderBookSnapshot, PriceLevel
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol='BTCUSDT',
exchange='binance',
timestamp=datetime.now(timezone.utc),
bids=[PriceLevel(50000.0, 1.5)],
asks=[PriceLevel(50100.0, 2.0)]
)
# Access calculated properties
mid_price = orderbook.mid_price
spread = orderbook.spread
```
### Logging
```python
from COBY.utils import setup_logging, get_logger, set_correlation_id
# Setup logging
setup_logging(level='INFO', log_file='logs/coby.log')
# Get logger
logger = get_logger(__name__)
# Use correlation ID
set_correlation_id('req-123')
logger.info("Processing order book data")
```
## 🏃 Next Steps
This is the foundational structure for the COBY system. The next implementation tasks will build upon these interfaces and models to create:
1. TimescaleDB integration
2. Exchange connector implementations
3. Data processing engines
4. Aggregation algorithms
5. Web dashboard
6. API endpoints
7. Replay functionality
Each component will implement the defined interfaces, ensuring consistency and maintainability across the entire system.
## 📝 Development Guidelines
- All components must implement the defined interfaces
- Use the provided data models for consistency
- Follow the logging and error handling patterns
- Validate all input data using the utility functions
- Maintain backward compatibility with the orchestrator interface
- Write comprehensive tests for all functionality
## 🔍 Monitoring
The system provides comprehensive monitoring through:
- Structured logging with correlation IDs
- Performance metrics collection
- Health check endpoints
- Connection status monitoring
- Data quality indicators
- System resource tracking

9
COBY/__init__.py Normal file
View File

@ -0,0 +1,9 @@
"""
Multi-Exchange Data Aggregation System (COBY)
A comprehensive data collection and aggregation subsystem for cryptocurrency exchanges.
Provides real-time order book data, heatmap visualization, and historical replay capabilities.
"""
__version__ = "1.0.0"
__author__ = "Trading System Team"

View File

@ -0,0 +1,15 @@
"""
Data aggregation components for the COBY system.
"""
from .aggregation_engine import StandardAggregationEngine
from .price_bucketer import PriceBucketer
from .heatmap_generator import HeatmapGenerator
from .cross_exchange_aggregator import CrossExchangeAggregator
__all__ = [
'StandardAggregationEngine',
'PriceBucketer',
'HeatmapGenerator',
'CrossExchangeAggregator'
]

View File

@ -0,0 +1,338 @@
"""
Main aggregation engine implementation.
"""
from typing import Dict, List
from ..interfaces.aggregation_engine import AggregationEngine
from ..models.core import (
OrderBookSnapshot, PriceBuckets, HeatmapData,
ImbalanceMetrics, ConsolidatedOrderBook
)
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import AggregationError
from .price_bucketer import PriceBucketer
from .heatmap_generator import HeatmapGenerator
from .cross_exchange_aggregator import CrossExchangeAggregator
from ..processing.metrics_calculator import MetricsCalculator
logger = get_logger(__name__)
class StandardAggregationEngine(AggregationEngine):
"""
Standard implementation of aggregation engine interface.
Provides:
- Price bucket creation with $1 USD buckets
- Heatmap generation
- Cross-exchange aggregation
- Imbalance calculations
- Support/resistance detection
"""
def __init__(self):
"""Initialize aggregation engine with components"""
self.price_bucketer = PriceBucketer()
self.heatmap_generator = HeatmapGenerator()
self.cross_exchange_aggregator = CrossExchangeAggregator()
self.metrics_calculator = MetricsCalculator()
# Processing statistics
self.buckets_created = 0
self.heatmaps_generated = 0
self.consolidations_performed = 0
logger.info("Standard aggregation engine initialized")
def create_price_buckets(self, orderbook: OrderBookSnapshot,
bucket_size: float = None) -> PriceBuckets:
"""
Convert order book data to price buckets.
Args:
orderbook: Order book snapshot
bucket_size: Size of each price bucket (uses $1 default)
Returns:
PriceBuckets: Aggregated price bucket data
"""
try:
set_correlation_id()
# Use provided bucket size or default $1
if bucket_size:
bucketer = PriceBucketer(bucket_size)
else:
bucketer = self.price_bucketer
buckets = bucketer.create_price_buckets(orderbook)
self.buckets_created += 1
logger.debug(f"Created price buckets for {orderbook.symbol}@{orderbook.exchange}")
return buckets
except Exception as e:
logger.error(f"Error creating price buckets: {e}")
raise AggregationError(f"Price bucket creation failed: {e}", "BUCKET_ERROR")
def update_heatmap(self, symbol: str, buckets: PriceBuckets) -> HeatmapData:
"""
Update heatmap data with new price buckets.
Args:
symbol: Trading symbol
buckets: Price bucket data
Returns:
HeatmapData: Updated heatmap visualization data
"""
try:
set_correlation_id()
heatmap = self.heatmap_generator.generate_heatmap(buckets)
self.heatmaps_generated += 1
logger.debug(f"Generated heatmap for {symbol}: {len(heatmap.data)} points")
return heatmap
except Exception as e:
logger.error(f"Error updating heatmap: {e}")
raise AggregationError(f"Heatmap update failed: {e}", "HEATMAP_ERROR")
def calculate_imbalances(self, orderbook: OrderBookSnapshot) -> ImbalanceMetrics:
"""
Calculate order book imbalance metrics.
Args:
orderbook: Order book snapshot
Returns:
ImbalanceMetrics: Calculated imbalance metrics
"""
try:
set_correlation_id()
return self.metrics_calculator.calculate_imbalance_metrics(orderbook)
except Exception as e:
logger.error(f"Error calculating imbalances: {e}")
raise AggregationError(f"Imbalance calculation failed: {e}", "IMBALANCE_ERROR")
def aggregate_across_exchanges(self, symbol: str,
orderbooks: List[OrderBookSnapshot]) -> ConsolidatedOrderBook:
"""
Aggregate order book data from multiple exchanges.
Args:
symbol: Trading symbol
orderbooks: List of order book snapshots from different exchanges
Returns:
ConsolidatedOrderBook: Consolidated order book data
"""
try:
set_correlation_id()
consolidated = self.cross_exchange_aggregator.aggregate_across_exchanges(
symbol, orderbooks
)
self.consolidations_performed += 1
logger.debug(f"Consolidated {len(orderbooks)} order books for {symbol}")
return consolidated
except Exception as e:
logger.error(f"Error aggregating across exchanges: {e}")
raise AggregationError(f"Cross-exchange aggregation failed: {e}", "CONSOLIDATION_ERROR")
def calculate_volume_weighted_price(self, orderbooks: List[OrderBookSnapshot]) -> float:
"""
Calculate volume-weighted average price across exchanges.
Args:
orderbooks: List of order book snapshots
Returns:
float: Volume-weighted average price
"""
try:
set_correlation_id()
return self.cross_exchange_aggregator._calculate_weighted_mid_price(orderbooks)
except Exception as e:
logger.error(f"Error calculating volume weighted price: {e}")
raise AggregationError(f"VWAP calculation failed: {e}", "VWAP_ERROR")
def get_market_depth(self, orderbook: OrderBookSnapshot,
depth_levels: List[float]) -> Dict[float, Dict[str, float]]:
"""
Calculate market depth at different price levels.
Args:
orderbook: Order book snapshot
depth_levels: List of depth percentages (e.g., [0.1, 0.5, 1.0])
Returns:
Dict: Market depth data {level: {'bid_volume': x, 'ask_volume': y}}
"""
try:
set_correlation_id()
depth_data = {}
if not orderbook.mid_price:
return depth_data
for level_pct in depth_levels:
# Calculate price range for this depth level
price_range = orderbook.mid_price * (level_pct / 100.0)
min_bid_price = orderbook.mid_price - price_range
max_ask_price = orderbook.mid_price + price_range
# Calculate volumes within this range
bid_volume = sum(
bid.size for bid in orderbook.bids
if bid.price >= min_bid_price
)
ask_volume = sum(
ask.size for ask in orderbook.asks
if ask.price <= max_ask_price
)
depth_data[level_pct] = {
'bid_volume': bid_volume,
'ask_volume': ask_volume,
'total_volume': bid_volume + ask_volume
}
logger.debug(f"Calculated market depth for {len(depth_levels)} levels")
return depth_data
except Exception as e:
logger.error(f"Error calculating market depth: {e}")
return {}
def smooth_heatmap(self, heatmap: HeatmapData, smoothing_factor: float) -> HeatmapData:
"""
Apply smoothing to heatmap data to reduce noise.
Args:
heatmap: Raw heatmap data
smoothing_factor: Smoothing factor (0.0 to 1.0)
Returns:
HeatmapData: Smoothed heatmap data
"""
try:
set_correlation_id()
return self.heatmap_generator.apply_smoothing(heatmap, smoothing_factor)
except Exception as e:
logger.error(f"Error smoothing heatmap: {e}")
return heatmap # Return original on error
def calculate_liquidity_score(self, orderbook: OrderBookSnapshot) -> float:
"""
Calculate liquidity score for an order book.
Args:
orderbook: Order book snapshot
Returns:
float: Liquidity score (0.0 to 1.0)
"""
try:
set_correlation_id()
return self.metrics_calculator.calculate_liquidity_score(orderbook)
except Exception as e:
logger.error(f"Error calculating liquidity score: {e}")
return 0.0
def detect_support_resistance(self, heatmap: HeatmapData) -> Dict[str, List[float]]:
"""
Detect support and resistance levels from heatmap data.
Args:
heatmap: Heatmap data
Returns:
Dict: {'support': [prices], 'resistance': [prices]}
"""
try:
set_correlation_id()
return self.heatmap_generator.calculate_support_resistance(heatmap)
except Exception as e:
logger.error(f"Error detecting support/resistance: {e}")
return {'support': [], 'resistance': []}
def create_consolidated_heatmap(self, symbol: str,
orderbooks: List[OrderBookSnapshot]) -> HeatmapData:
"""
Create consolidated heatmap from multiple exchanges.
Args:
symbol: Trading symbol
orderbooks: List of order book snapshots
Returns:
HeatmapData: Consolidated heatmap data
"""
try:
set_correlation_id()
return self.cross_exchange_aggregator.create_consolidated_heatmap(
symbol, orderbooks
)
except Exception as e:
logger.error(f"Error creating consolidated heatmap: {e}")
raise AggregationError(f"Consolidated heatmap creation failed: {e}", "CONSOLIDATED_HEATMAP_ERROR")
def detect_arbitrage_opportunities(self, orderbooks: List[OrderBookSnapshot]) -> List[Dict]:
"""
Detect arbitrage opportunities between exchanges.
Args:
orderbooks: List of order book snapshots
Returns:
List[Dict]: Arbitrage opportunities
"""
try:
set_correlation_id()
return self.cross_exchange_aggregator.detect_arbitrage_opportunities(orderbooks)
except Exception as e:
logger.error(f"Error detecting arbitrage opportunities: {e}")
return []
def get_processing_stats(self) -> Dict[str, any]:
"""Get processing statistics"""
return {
'buckets_created': self.buckets_created,
'heatmaps_generated': self.heatmaps_generated,
'consolidations_performed': self.consolidations_performed,
'price_bucketer_stats': self.price_bucketer.get_processing_stats(),
'heatmap_generator_stats': self.heatmap_generator.get_processing_stats(),
'cross_exchange_stats': self.cross_exchange_aggregator.get_processing_stats()
}
def reset_stats(self) -> None:
"""Reset processing statistics"""
self.buckets_created = 0
self.heatmaps_generated = 0
self.consolidations_performed = 0
self.price_bucketer.reset_stats()
self.heatmap_generator.reset_stats()
self.cross_exchange_aggregator.reset_stats()
logger.info("Aggregation engine statistics reset")

View File

@ -0,0 +1,390 @@
"""
Cross-exchange data aggregation and consolidation.
"""
from typing import List, Dict, Optional
from collections import defaultdict
from datetime import datetime
from ..models.core import (
OrderBookSnapshot, ConsolidatedOrderBook, PriceLevel,
PriceBuckets, HeatmapData, HeatmapPoint
)
from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
from .price_bucketer import PriceBucketer
from .heatmap_generator import HeatmapGenerator
logger = get_logger(__name__)
class CrossExchangeAggregator:
"""
Aggregates data across multiple exchanges.
Provides consolidated order books and cross-exchange heatmaps.
"""
def __init__(self):
"""Initialize cross-exchange aggregator"""
self.price_bucketer = PriceBucketer()
self.heatmap_generator = HeatmapGenerator()
# Exchange weights for aggregation
self.exchange_weights = {
'binance': 1.0,
'coinbase': 0.9,
'kraken': 0.8,
'bybit': 0.7,
'okx': 0.7,
'huobi': 0.6,
'kucoin': 0.6,
'gateio': 0.5,
'bitfinex': 0.5,
'mexc': 0.4
}
# Statistics
self.consolidations_performed = 0
self.exchanges_processed = set()
logger.info("Cross-exchange aggregator initialized")
def aggregate_across_exchanges(self, symbol: str,
orderbooks: List[OrderBookSnapshot]) -> ConsolidatedOrderBook:
"""
Aggregate order book data from multiple exchanges.
Args:
symbol: Trading symbol
orderbooks: List of order book snapshots from different exchanges
Returns:
ConsolidatedOrderBook: Consolidated order book data
"""
if not orderbooks:
raise ValueError("Cannot aggregate empty orderbook list")
try:
# Track exchanges
exchanges = [ob.exchange for ob in orderbooks]
self.exchanges_processed.update(exchanges)
# Calculate weighted mid price
weighted_mid_price = self._calculate_weighted_mid_price(orderbooks)
# Consolidate bids and asks
consolidated_bids = self._consolidate_price_levels(
[ob.bids for ob in orderbooks],
[ob.exchange for ob in orderbooks],
'bid'
)
consolidated_asks = self._consolidate_price_levels(
[ob.asks for ob in orderbooks],
[ob.exchange for ob in orderbooks],
'ask'
)
# Calculate total volumes
total_bid_volume = sum(level.size for level in consolidated_bids)
total_ask_volume = sum(level.size for level in consolidated_asks)
# Create consolidated order book
consolidated = ConsolidatedOrderBook(
symbol=symbol,
timestamp=get_current_timestamp(),
exchanges=exchanges,
bids=consolidated_bids,
asks=consolidated_asks,
weighted_mid_price=weighted_mid_price,
total_bid_volume=total_bid_volume,
total_ask_volume=total_ask_volume,
exchange_weights={ex: self.exchange_weights.get(ex, 0.5) for ex in exchanges}
)
self.consolidations_performed += 1
logger.debug(
f"Consolidated {len(orderbooks)} order books for {symbol}: "
f"{len(consolidated_bids)} bids, {len(consolidated_asks)} asks"
)
return consolidated
except Exception as e:
logger.error(f"Error aggregating across exchanges: {e}")
raise
def create_consolidated_heatmap(self, symbol: str,
orderbooks: List[OrderBookSnapshot]) -> HeatmapData:
"""
Create consolidated heatmap from multiple exchanges.
Args:
symbol: Trading symbol
orderbooks: List of order book snapshots
Returns:
HeatmapData: Consolidated heatmap data
"""
try:
# Create price buckets for each exchange
all_buckets = []
for orderbook in orderbooks:
buckets = self.price_bucketer.create_price_buckets(orderbook)
all_buckets.append(buckets)
# Aggregate all buckets
if len(all_buckets) == 1:
consolidated_buckets = all_buckets[0]
else:
consolidated_buckets = self.price_bucketer.aggregate_buckets(all_buckets)
# Generate heatmap from consolidated buckets
heatmap = self.heatmap_generator.generate_heatmap(consolidated_buckets)
# Add exchange metadata to heatmap points
self._add_exchange_metadata(heatmap, orderbooks)
logger.debug(f"Created consolidated heatmap for {symbol} from {len(orderbooks)} exchanges")
return heatmap
except Exception as e:
logger.error(f"Error creating consolidated heatmap: {e}")
raise
def _calculate_weighted_mid_price(self, orderbooks: List[OrderBookSnapshot]) -> float:
"""Calculate volume-weighted mid price across exchanges"""
total_weight = 0.0
weighted_sum = 0.0
for orderbook in orderbooks:
if orderbook.mid_price:
# Use total volume as weight
volume_weight = orderbook.bid_volume + orderbook.ask_volume
exchange_weight = self.exchange_weights.get(orderbook.exchange, 0.5)
# Combined weight
weight = volume_weight * exchange_weight
weighted_sum += orderbook.mid_price * weight
total_weight += weight
return weighted_sum / total_weight if total_weight > 0 else 0.0
def _consolidate_price_levels(self, level_lists: List[List[PriceLevel]],
exchanges: List[str], side: str) -> List[PriceLevel]:
"""Consolidate price levels from multiple exchanges"""
# Group levels by price bucket
price_groups = defaultdict(lambda: {'size': 0.0, 'count': 0, 'exchanges': set()})
for levels, exchange in zip(level_lists, exchanges):
exchange_weight = self.exchange_weights.get(exchange, 0.5)
for level in levels:
# Round price to bucket
bucket_price = self.price_bucketer.get_bucket_price(level.price)
# Add weighted volume
weighted_size = level.size * exchange_weight
price_groups[bucket_price]['size'] += weighted_size
price_groups[bucket_price]['count'] += level.count or 1
price_groups[bucket_price]['exchanges'].add(exchange)
# Create consolidated price levels
consolidated_levels = []
for price, data in price_groups.items():
if data['size'] > 0: # Only include non-zero volumes
level = PriceLevel(
price=price,
size=data['size'],
count=data['count']
)
consolidated_levels.append(level)
# Sort levels appropriately
if side == 'bid':
consolidated_levels.sort(key=lambda x: x.price, reverse=True)
else:
consolidated_levels.sort(key=lambda x: x.price)
return consolidated_levels
def _add_exchange_metadata(self, heatmap: HeatmapData,
orderbooks: List[OrderBookSnapshot]) -> None:
"""Add exchange metadata to heatmap points"""
# Create exchange mapping by price bucket
exchange_map = defaultdict(set)
for orderbook in orderbooks:
# Map bid prices to exchanges
for bid in orderbook.bids:
bucket_price = self.price_bucketer.get_bucket_price(bid.price)
exchange_map[bucket_price].add(orderbook.exchange)
# Map ask prices to exchanges
for ask in orderbook.asks:
bucket_price = self.price_bucketer.get_bucket_price(ask.price)
exchange_map[bucket_price].add(orderbook.exchange)
# Add exchange information to heatmap points
for point in heatmap.data:
bucket_price = self.price_bucketer.get_bucket_price(point.price)
# Store exchange info in a custom attribute (would need to extend HeatmapPoint)
# For now, we'll log it
exchanges_at_price = exchange_map.get(bucket_price, set())
if len(exchanges_at_price) > 1:
logger.debug(f"Price {point.price} has data from {len(exchanges_at_price)} exchanges")
def calculate_exchange_dominance(self, orderbooks: List[OrderBookSnapshot]) -> Dict[str, float]:
"""
Calculate which exchanges dominate at different price levels.
Args:
orderbooks: List of order book snapshots
Returns:
Dict[str, float]: Exchange dominance scores
"""
exchange_volumes = defaultdict(float)
total_volume = 0.0
for orderbook in orderbooks:
volume = orderbook.bid_volume + orderbook.ask_volume
exchange_volumes[orderbook.exchange] += volume
total_volume += volume
# Calculate dominance percentages
dominance = {}
for exchange, volume in exchange_volumes.items():
dominance[exchange] = (volume / total_volume * 100) if total_volume > 0 else 0.0
return dominance
def detect_arbitrage_opportunities(self, orderbooks: List[OrderBookSnapshot],
min_spread_pct: float = 0.1) -> List[Dict]:
"""
Detect potential arbitrage opportunities between exchanges.
Args:
orderbooks: List of order book snapshots
min_spread_pct: Minimum spread percentage to consider
Returns:
List[Dict]: Arbitrage opportunities
"""
opportunities = []
if len(orderbooks) < 2:
return opportunities
try:
# Find best bid and ask across exchanges
best_bids = []
best_asks = []
for orderbook in orderbooks:
if orderbook.bids and orderbook.asks:
best_bids.append({
'exchange': orderbook.exchange,
'price': orderbook.bids[0].price,
'size': orderbook.bids[0].size
})
best_asks.append({
'exchange': orderbook.exchange,
'price': orderbook.asks[0].price,
'size': orderbook.asks[0].size
})
# Sort to find best opportunities
best_bids.sort(key=lambda x: x['price'], reverse=True)
best_asks.sort(key=lambda x: x['price'])
# Check for arbitrage opportunities
for bid in best_bids:
for ask in best_asks:
if bid['exchange'] != ask['exchange'] and bid['price'] > ask['price']:
spread = bid['price'] - ask['price']
spread_pct = (spread / ask['price']) * 100
if spread_pct >= min_spread_pct:
opportunities.append({
'buy_exchange': ask['exchange'],
'sell_exchange': bid['exchange'],
'buy_price': ask['price'],
'sell_price': bid['price'],
'spread': spread,
'spread_percentage': spread_pct,
'max_size': min(bid['size'], ask['size'])
})
# Sort by spread percentage
opportunities.sort(key=lambda x: x['spread_percentage'], reverse=True)
if opportunities:
logger.info(f"Found {len(opportunities)} arbitrage opportunities")
return opportunities
except Exception as e:
logger.error(f"Error detecting arbitrage opportunities: {e}")
return []
def get_exchange_correlation(self, orderbooks: List[OrderBookSnapshot]) -> Dict[str, Dict[str, float]]:
"""
Calculate price correlation between exchanges.
Args:
orderbooks: List of order book snapshots
Returns:
Dict: Correlation matrix between exchanges
"""
correlations = {}
# Extract mid prices by exchange
exchange_prices = {}
for orderbook in orderbooks:
if orderbook.mid_price:
exchange_prices[orderbook.exchange] = orderbook.mid_price
# Calculate simple correlation (would need historical data for proper correlation)
exchanges = list(exchange_prices.keys())
for i, exchange1 in enumerate(exchanges):
correlations[exchange1] = {}
for j, exchange2 in enumerate(exchanges):
if i == j:
correlations[exchange1][exchange2] = 1.0
else:
# Simple price difference as correlation proxy
price1 = exchange_prices[exchange1]
price2 = exchange_prices[exchange2]
diff_pct = abs(price1 - price2) / max(price1, price2) * 100
# Convert to correlation-like score (lower difference = higher correlation)
correlation = max(0.0, 1.0 - (diff_pct / 10.0))
correlations[exchange1][exchange2] = correlation
return correlations
def get_processing_stats(self) -> Dict[str, int]:
"""Get processing statistics"""
return {
'consolidations_performed': self.consolidations_performed,
'unique_exchanges_processed': len(self.exchanges_processed),
'exchanges_processed': list(self.exchanges_processed),
'bucketer_stats': self.price_bucketer.get_processing_stats(),
'heatmap_stats': self.heatmap_generator.get_processing_stats()
}
def update_exchange_weights(self, new_weights: Dict[str, float]) -> None:
"""Update exchange weights for aggregation"""
self.exchange_weights.update(new_weights)
logger.info(f"Updated exchange weights: {new_weights}")
def reset_stats(self) -> None:
"""Reset processing statistics"""
self.consolidations_performed = 0
self.exchanges_processed.clear()
self.price_bucketer.reset_stats()
self.heatmap_generator.reset_stats()
logger.info("Cross-exchange aggregator statistics reset")

View File

@ -0,0 +1,376 @@
"""
Heatmap data generation from price buckets.
"""
from typing import List, Dict, Optional, Tuple
from ..models.core import PriceBuckets, HeatmapData, HeatmapPoint
from ..config import config
from ..utils.logging import get_logger
logger = get_logger(__name__)
class HeatmapGenerator:
"""
Generates heatmap visualization data from price buckets.
Creates intensity-based heatmap points for visualization.
"""
def __init__(self):
"""Initialize heatmap generator"""
self.heatmaps_generated = 0
self.total_points_created = 0
logger.info("Heatmap generator initialized")
def generate_heatmap(self, buckets: PriceBuckets,
max_points: Optional[int] = None) -> HeatmapData:
"""
Generate heatmap data from price buckets.
Args:
buckets: Price buckets to convert
max_points: Maximum number of points to include (None = all)
Returns:
HeatmapData: Heatmap visualization data
"""
try:
heatmap = HeatmapData(
symbol=buckets.symbol,
timestamp=buckets.timestamp,
bucket_size=buckets.bucket_size
)
# Calculate maximum volume for intensity normalization
all_volumes = list(buckets.bid_buckets.values()) + list(buckets.ask_buckets.values())
max_volume = max(all_volumes) if all_volumes else 1.0
# Generate bid points
bid_points = self._create_heatmap_points(
buckets.bid_buckets, 'bid', max_volume
)
# Generate ask points
ask_points = self._create_heatmap_points(
buckets.ask_buckets, 'ask', max_volume
)
# Combine all points
all_points = bid_points + ask_points
# Limit points if requested
if max_points and len(all_points) > max_points:
# Sort by volume and take top points
all_points.sort(key=lambda p: p.volume, reverse=True)
all_points = all_points[:max_points]
heatmap.data = all_points
self.heatmaps_generated += 1
self.total_points_created += len(all_points)
logger.debug(
f"Generated heatmap for {buckets.symbol}: {len(all_points)} points "
f"(max_volume: {max_volume:.6f})"
)
return heatmap
except Exception as e:
logger.error(f"Error generating heatmap: {e}")
raise
def _create_heatmap_points(self, bucket_dict: Dict[float, float],
side: str, max_volume: float) -> List[HeatmapPoint]:
"""
Create heatmap points from bucket dictionary.
Args:
bucket_dict: Dictionary of price -> volume
side: 'bid' or 'ask'
max_volume: Maximum volume for intensity calculation
Returns:
List[HeatmapPoint]: List of heatmap points
"""
points = []
for price, volume in bucket_dict.items():
if volume > 0: # Only include non-zero volumes
intensity = min(volume / max_volume, 1.0) if max_volume > 0 else 0.0
point = HeatmapPoint(
price=price,
volume=volume,
intensity=intensity,
side=side
)
points.append(point)
return points
def apply_smoothing(self, heatmap: HeatmapData,
smoothing_factor: float = 0.3) -> HeatmapData:
"""
Apply smoothing to heatmap data to reduce noise.
Args:
heatmap: Original heatmap data
smoothing_factor: Smoothing factor (0.0 = no smoothing, 1.0 = maximum)
Returns:
HeatmapData: Smoothed heatmap data
"""
if smoothing_factor <= 0:
return heatmap
try:
smoothed = HeatmapData(
symbol=heatmap.symbol,
timestamp=heatmap.timestamp,
bucket_size=heatmap.bucket_size
)
# Separate bids and asks
bids = [p for p in heatmap.data if p.side == 'bid']
asks = [p for p in heatmap.data if p.side == 'ask']
# Apply smoothing to each side
smoothed_bids = self._smooth_points(bids, smoothing_factor)
smoothed_asks = self._smooth_points(asks, smoothing_factor)
smoothed.data = smoothed_bids + smoothed_asks
logger.debug(f"Applied smoothing with factor {smoothing_factor}")
return smoothed
except Exception as e:
logger.error(f"Error applying smoothing: {e}")
return heatmap # Return original on error
def _smooth_points(self, points: List[HeatmapPoint],
smoothing_factor: float) -> List[HeatmapPoint]:
"""
Apply smoothing to a list of heatmap points.
Args:
points: Points to smooth
smoothing_factor: Smoothing factor
Returns:
List[HeatmapPoint]: Smoothed points
"""
if len(points) < 3:
return points
# Sort points by price
sorted_points = sorted(points, key=lambda p: p.price)
smoothed_points = []
for i, point in enumerate(sorted_points):
# Calculate weighted average with neighbors
total_weight = 1.0
weighted_volume = point.volume
weighted_intensity = point.intensity
# Add left neighbor
if i > 0:
left_point = sorted_points[i - 1]
weight = smoothing_factor
total_weight += weight
weighted_volume += left_point.volume * weight
weighted_intensity += left_point.intensity * weight
# Add right neighbor
if i < len(sorted_points) - 1:
right_point = sorted_points[i + 1]
weight = smoothing_factor
total_weight += weight
weighted_volume += right_point.volume * weight
weighted_intensity += right_point.intensity * weight
# Create smoothed point
smoothed_point = HeatmapPoint(
price=point.price,
volume=weighted_volume / total_weight,
intensity=min(weighted_intensity / total_weight, 1.0),
side=point.side
)
smoothed_points.append(smoothed_point)
return smoothed_points
def filter_by_intensity(self, heatmap: HeatmapData,
min_intensity: float = 0.1) -> HeatmapData:
"""
Filter heatmap points by minimum intensity.
Args:
heatmap: Original heatmap data
min_intensity: Minimum intensity threshold
Returns:
HeatmapData: Filtered heatmap data
"""
filtered = HeatmapData(
symbol=heatmap.symbol,
timestamp=heatmap.timestamp,
bucket_size=heatmap.bucket_size
)
# Filter points by intensity
filtered.data = [
point for point in heatmap.data
if point.intensity >= min_intensity
]
logger.debug(
f"Filtered heatmap: {len(heatmap.data)} -> {len(filtered.data)} points "
f"(min_intensity: {min_intensity})"
)
return filtered
def get_price_levels(self, heatmap: HeatmapData,
side: str = None) -> List[float]:
"""
Get sorted list of price levels from heatmap.
Args:
heatmap: Heatmap data
side: 'bid', 'ask', or None for both
Returns:
List[float]: Sorted price levels
"""
if side:
points = [p for p in heatmap.data if p.side == side]
else:
points = heatmap.data
prices = [p.price for p in points]
return sorted(prices)
def get_volume_profile(self, heatmap: HeatmapData) -> Dict[str, List[Tuple[float, float]]]:
"""
Get volume profile from heatmap data.
Args:
heatmap: Heatmap data
Returns:
Dict: Volume profile with 'bids' and 'asks' as (price, volume) tuples
"""
profile = {'bids': [], 'asks': []}
# Extract bid profile
bid_points = [p for p in heatmap.data if p.side == 'bid']
profile['bids'] = [(p.price, p.volume) for p in bid_points]
profile['bids'].sort(key=lambda x: x[0], reverse=True) # Highest price first
# Extract ask profile
ask_points = [p for p in heatmap.data if p.side == 'ask']
profile['asks'] = [(p.price, p.volume) for p in ask_points]
profile['asks'].sort(key=lambda x: x[0]) # Lowest price first
return profile
def calculate_support_resistance(self, heatmap: HeatmapData,
threshold: float = 0.7) -> Dict[str, List[float]]:
"""
Identify potential support and resistance levels from heatmap.
Args:
heatmap: Heatmap data
threshold: Intensity threshold for significant levels
Returns:
Dict: Support and resistance levels
"""
levels = {'support': [], 'resistance': []}
# Find high-intensity bid levels (potential support)
bid_points = [p for p in heatmap.data if p.side == 'bid' and p.intensity >= threshold]
levels['support'] = sorted([p.price for p in bid_points], reverse=True)
# Find high-intensity ask levels (potential resistance)
ask_points = [p for p in heatmap.data if p.side == 'ask' and p.intensity >= threshold]
levels['resistance'] = sorted([p.price for p in ask_points])
logger.debug(
f"Identified {len(levels['support'])} support and "
f"{len(levels['resistance'])} resistance levels"
)
return levels
def get_heatmap_summary(self, heatmap: HeatmapData) -> Dict[str, float]:
"""
Get summary statistics for heatmap data.
Args:
heatmap: Heatmap data
Returns:
Dict: Summary statistics
"""
if not heatmap.data:
return {}
# Separate bids and asks
bids = [p for p in heatmap.data if p.side == 'bid']
asks = [p for p in heatmap.data if p.side == 'ask']
summary = {
'total_points': len(heatmap.data),
'bid_points': len(bids),
'ask_points': len(asks),
'total_volume': sum(p.volume for p in heatmap.data),
'bid_volume': sum(p.volume for p in bids),
'ask_volume': sum(p.volume for p in asks),
'max_intensity': max(p.intensity for p in heatmap.data),
'avg_intensity': sum(p.intensity for p in heatmap.data) / len(heatmap.data),
'price_range': 0.0,
'best_bid': 0.0,
'best_ask': 0.0
}
# Calculate price range
all_prices = [p.price for p in heatmap.data]
if all_prices:
summary['price_range'] = max(all_prices) - min(all_prices)
# Calculate best bid and ask
if bids:
summary['best_bid'] = max(p.price for p in bids)
if asks:
summary['best_ask'] = min(p.price for p in asks)
# Calculate volume imbalance
total_volume = summary['total_volume']
if total_volume > 0:
summary['volume_imbalance'] = (
(summary['bid_volume'] - summary['ask_volume']) / total_volume
)
else:
summary['volume_imbalance'] = 0.0
return summary
def get_processing_stats(self) -> Dict[str, int]:
"""Get processing statistics"""
return {
'heatmaps_generated': self.heatmaps_generated,
'total_points_created': self.total_points_created,
'avg_points_per_heatmap': (
self.total_points_created // max(self.heatmaps_generated, 1)
)
}
def reset_stats(self) -> None:
"""Reset processing statistics"""
self.heatmaps_generated = 0
self.total_points_created = 0
logger.info("Heatmap generator statistics reset")

View File

@ -0,0 +1,341 @@
"""
Price bucketing system for order book aggregation.
"""
import math
from typing import Dict, List, Tuple, Optional
from collections import defaultdict
from ..models.core import OrderBookSnapshot, PriceBuckets, PriceLevel
from ..config import config
from ..utils.logging import get_logger
from ..utils.validation import validate_price, validate_volume
logger = get_logger(__name__)
class PriceBucketer:
"""
Converts order book data into price buckets for heatmap visualization.
Uses universal $1 USD buckets for all symbols to simplify logic.
"""
def __init__(self, bucket_size: float = None):
"""
Initialize price bucketer.
Args:
bucket_size: Size of price buckets in USD (defaults to config value)
"""
self.bucket_size = bucket_size or config.get_bucket_size()
# Statistics
self.buckets_created = 0
self.total_volume_processed = 0.0
logger.info(f"Price bucketer initialized with ${self.bucket_size} buckets")
def create_price_buckets(self, orderbook: OrderBookSnapshot) -> PriceBuckets:
"""
Convert order book data to price buckets.
Args:
orderbook: Order book snapshot
Returns:
PriceBuckets: Aggregated price bucket data
"""
try:
# Create price buckets object
buckets = PriceBuckets(
symbol=orderbook.symbol,
timestamp=orderbook.timestamp,
bucket_size=self.bucket_size
)
# Process bids (aggregate into buckets)
for bid in orderbook.bids:
if validate_price(bid.price) and validate_volume(bid.size):
buckets.add_bid(bid.price, bid.size)
self.total_volume_processed += bid.size
# Process asks (aggregate into buckets)
for ask in orderbook.asks:
if validate_price(ask.price) and validate_volume(ask.size):
buckets.add_ask(ask.price, ask.size)
self.total_volume_processed += ask.size
self.buckets_created += 1
logger.debug(
f"Created price buckets for {orderbook.symbol}: "
f"{len(buckets.bid_buckets)} bid buckets, {len(buckets.ask_buckets)} ask buckets"
)
return buckets
except Exception as e:
logger.error(f"Error creating price buckets: {e}")
raise
def aggregate_buckets(self, bucket_list: List[PriceBuckets]) -> PriceBuckets:
"""
Aggregate multiple price buckets into a single bucket set.
Args:
bucket_list: List of price buckets to aggregate
Returns:
PriceBuckets: Aggregated buckets
"""
if not bucket_list:
raise ValueError("Cannot aggregate empty bucket list")
# Use first bucket as template
first_bucket = bucket_list[0]
aggregated = PriceBuckets(
symbol=first_bucket.symbol,
timestamp=first_bucket.timestamp,
bucket_size=self.bucket_size
)
# Aggregate all bid buckets
for buckets in bucket_list:
for price, volume in buckets.bid_buckets.items():
bucket_price = aggregated.get_bucket_price(price)
aggregated.bid_buckets[bucket_price] = (
aggregated.bid_buckets.get(bucket_price, 0) + volume
)
# Aggregate all ask buckets
for buckets in bucket_list:
for price, volume in buckets.ask_buckets.items():
bucket_price = aggregated.get_bucket_price(price)
aggregated.ask_buckets[bucket_price] = (
aggregated.ask_buckets.get(bucket_price, 0) + volume
)
logger.debug(f"Aggregated {len(bucket_list)} bucket sets")
return aggregated
def get_bucket_range(self, center_price: float, depth: int) -> Tuple[float, float]:
"""
Get price range for buckets around a center price.
Args:
center_price: Center price for the range
depth: Number of buckets on each side
Returns:
Tuple[float, float]: (min_price, max_price)
"""
half_range = depth * self.bucket_size
min_price = center_price - half_range
max_price = center_price + half_range
return (max(0, min_price), max_price)
def filter_buckets_by_range(self, buckets: PriceBuckets,
min_price: float, max_price: float) -> PriceBuckets:
"""
Filter buckets to only include those within a price range.
Args:
buckets: Original price buckets
min_price: Minimum price to include
max_price: Maximum price to include
Returns:
PriceBuckets: Filtered buckets
"""
filtered = PriceBuckets(
symbol=buckets.symbol,
timestamp=buckets.timestamp,
bucket_size=buckets.bucket_size
)
# Filter bid buckets
for price, volume in buckets.bid_buckets.items():
if min_price <= price <= max_price:
filtered.bid_buckets[price] = volume
# Filter ask buckets
for price, volume in buckets.ask_buckets.items():
if min_price <= price <= max_price:
filtered.ask_buckets[price] = volume
return filtered
def get_top_buckets(self, buckets: PriceBuckets, count: int) -> PriceBuckets:
"""
Get top N buckets by volume.
Args:
buckets: Original price buckets
count: Number of top buckets to return
Returns:
PriceBuckets: Top buckets by volume
"""
top_buckets = PriceBuckets(
symbol=buckets.symbol,
timestamp=buckets.timestamp,
bucket_size=buckets.bucket_size
)
# Get top bid buckets
top_bids = sorted(
buckets.bid_buckets.items(),
key=lambda x: x[1], # Sort by volume
reverse=True
)[:count]
for price, volume in top_bids:
top_buckets.bid_buckets[price] = volume
# Get top ask buckets
top_asks = sorted(
buckets.ask_buckets.items(),
key=lambda x: x[1], # Sort by volume
reverse=True
)[:count]
for price, volume in top_asks:
top_buckets.ask_buckets[price] = volume
return top_buckets
def calculate_bucket_statistics(self, buckets: PriceBuckets) -> Dict[str, float]:
"""
Calculate statistics for price buckets.
Args:
buckets: Price buckets to analyze
Returns:
Dict[str, float]: Bucket statistics
"""
stats = {
'total_bid_buckets': len(buckets.bid_buckets),
'total_ask_buckets': len(buckets.ask_buckets),
'total_bid_volume': sum(buckets.bid_buckets.values()),
'total_ask_volume': sum(buckets.ask_buckets.values()),
'bid_price_range': 0.0,
'ask_price_range': 0.0,
'max_bid_volume': 0.0,
'max_ask_volume': 0.0,
'avg_bid_volume': 0.0,
'avg_ask_volume': 0.0
}
# Calculate bid statistics
if buckets.bid_buckets:
bid_prices = list(buckets.bid_buckets.keys())
bid_volumes = list(buckets.bid_buckets.values())
stats['bid_price_range'] = max(bid_prices) - min(bid_prices)
stats['max_bid_volume'] = max(bid_volumes)
stats['avg_bid_volume'] = sum(bid_volumes) / len(bid_volumes)
# Calculate ask statistics
if buckets.ask_buckets:
ask_prices = list(buckets.ask_buckets.keys())
ask_volumes = list(buckets.ask_buckets.values())
stats['ask_price_range'] = max(ask_prices) - min(ask_prices)
stats['max_ask_volume'] = max(ask_volumes)
stats['avg_ask_volume'] = sum(ask_volumes) / len(ask_volumes)
# Calculate combined statistics
stats['total_volume'] = stats['total_bid_volume'] + stats['total_ask_volume']
stats['volume_imbalance'] = (
(stats['total_bid_volume'] - stats['total_ask_volume']) /
max(stats['total_volume'], 1e-10)
)
return stats
def merge_adjacent_buckets(self, buckets: PriceBuckets, merge_factor: int = 2) -> PriceBuckets:
"""
Merge adjacent buckets to create larger bucket sizes.
Args:
buckets: Original price buckets
merge_factor: Number of adjacent buckets to merge
Returns:
PriceBuckets: Merged buckets with larger bucket size
"""
merged = PriceBuckets(
symbol=buckets.symbol,
timestamp=buckets.timestamp,
bucket_size=buckets.bucket_size * merge_factor
)
# Merge bid buckets
bid_groups = defaultdict(float)
for price, volume in buckets.bid_buckets.items():
# Calculate new bucket price
new_bucket_price = merged.get_bucket_price(price)
bid_groups[new_bucket_price] += volume
merged.bid_buckets = dict(bid_groups)
# Merge ask buckets
ask_groups = defaultdict(float)
for price, volume in buckets.ask_buckets.items():
# Calculate new bucket price
new_bucket_price = merged.get_bucket_price(price)
ask_groups[new_bucket_price] += volume
merged.ask_buckets = dict(ask_groups)
logger.debug(f"Merged buckets with factor {merge_factor}")
return merged
def get_bucket_depth_profile(self, buckets: PriceBuckets,
center_price: float) -> Dict[str, List[Tuple[float, float]]]:
"""
Get depth profile showing volume at different distances from center price.
Args:
buckets: Price buckets
center_price: Center price for depth calculation
Returns:
Dict: Depth profile with 'bids' and 'asks' lists of (distance, volume) tuples
"""
profile = {'bids': [], 'asks': []}
# Calculate bid depth profile
for price, volume in buckets.bid_buckets.items():
distance = abs(center_price - price)
profile['bids'].append((distance, volume))
# Calculate ask depth profile
for price, volume in buckets.ask_buckets.items():
distance = abs(price - center_price)
profile['asks'].append((distance, volume))
# Sort by distance
profile['bids'].sort(key=lambda x: x[0])
profile['asks'].sort(key=lambda x: x[0])
return profile
def get_processing_stats(self) -> Dict[str, float]:
"""Get processing statistics"""
return {
'bucket_size': self.bucket_size,
'buckets_created': self.buckets_created,
'total_volume_processed': self.total_volume_processed,
'avg_volume_per_bucket': (
self.total_volume_processed / max(self.buckets_created, 1)
)
}
def reset_stats(self) -> None:
"""Reset processing statistics"""
self.buckets_created = 0
self.total_volume_processed = 0.0
logger.info("Price bucketer statistics reset")

15
COBY/api/__init__.py Normal file
View File

@ -0,0 +1,15 @@
"""
API layer for the COBY system.
"""
from .rest_api import create_app
from .websocket_server import WebSocketServer
from .rate_limiter import RateLimiter
from .response_formatter import ResponseFormatter
__all__ = [
'create_app',
'WebSocketServer',
'RateLimiter',
'ResponseFormatter'
]

183
COBY/api/rate_limiter.py Normal file
View File

@ -0,0 +1,183 @@
"""
Rate limiting for API endpoints.
"""
import time
from typing import Dict, Optional
from collections import defaultdict, deque
from ..utils.logging import get_logger
logger = get_logger(__name__)
class RateLimiter:
"""
Token bucket rate limiter for API endpoints.
Provides per-client rate limiting with configurable limits.
"""
def __init__(self, requests_per_minute: int = 100, burst_size: int = 20):
"""
Initialize rate limiter.
Args:
requests_per_minute: Maximum requests per minute
burst_size: Maximum burst requests
"""
self.requests_per_minute = requests_per_minute
self.burst_size = burst_size
self.refill_rate = requests_per_minute / 60.0 # tokens per second
# Client buckets: client_id -> {'tokens': float, 'last_refill': float}
self.buckets: Dict[str, Dict] = defaultdict(lambda: {
'tokens': float(burst_size),
'last_refill': time.time()
})
# Request history for monitoring
self.request_history: Dict[str, deque] = defaultdict(lambda: deque(maxlen=1000))
logger.info(f"Rate limiter initialized: {requests_per_minute} req/min, burst: {burst_size}")
def is_allowed(self, client_id: str, tokens_requested: int = 1) -> bool:
"""
Check if request is allowed for client.
Args:
client_id: Client identifier (IP, user ID, etc.)
tokens_requested: Number of tokens requested
Returns:
bool: True if request is allowed, False otherwise
"""
current_time = time.time()
bucket = self.buckets[client_id]
# Refill tokens based on time elapsed
time_elapsed = current_time - bucket['last_refill']
tokens_to_add = time_elapsed * self.refill_rate
# Update bucket
bucket['tokens'] = min(self.burst_size, bucket['tokens'] + tokens_to_add)
bucket['last_refill'] = current_time
# Check if enough tokens available
if bucket['tokens'] >= tokens_requested:
bucket['tokens'] -= tokens_requested
# Record successful request
self.request_history[client_id].append(current_time)
return True
else:
logger.debug(f"Rate limit exceeded for client {client_id}")
return False
def get_remaining_tokens(self, client_id: str) -> float:
"""
Get remaining tokens for client.
Args:
client_id: Client identifier
Returns:
float: Number of remaining tokens
"""
current_time = time.time()
bucket = self.buckets[client_id]
# Calculate current tokens (with refill)
time_elapsed = current_time - bucket['last_refill']
tokens_to_add = time_elapsed * self.refill_rate
current_tokens = min(self.burst_size, bucket['tokens'] + tokens_to_add)
return current_tokens
def get_reset_time(self, client_id: str) -> float:
"""
Get time until bucket is fully refilled.
Args:
client_id: Client identifier
Returns:
float: Seconds until full refill
"""
remaining_tokens = self.get_remaining_tokens(client_id)
tokens_needed = self.burst_size - remaining_tokens
if tokens_needed <= 0:
return 0.0
return tokens_needed / self.refill_rate
def get_client_stats(self, client_id: str) -> Dict[str, float]:
"""
Get statistics for a client.
Args:
client_id: Client identifier
Returns:
Dict: Client statistics
"""
current_time = time.time()
history = self.request_history[client_id]
# Count requests in last minute
minute_ago = current_time - 60
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
return {
'remaining_tokens': self.get_remaining_tokens(client_id),
'reset_time': self.get_reset_time(client_id),
'requests_last_minute': recent_requests,
'total_requests': len(history)
}
def cleanup_old_data(self, max_age_hours: int = 24) -> None:
"""
Clean up old client data.
Args:
max_age_hours: Maximum age of data to keep
"""
current_time = time.time()
cutoff_time = current_time - (max_age_hours * 3600)
# Clean up buckets for inactive clients
inactive_clients = []
for client_id, bucket in self.buckets.items():
if bucket['last_refill'] < cutoff_time:
inactive_clients.append(client_id)
for client_id in inactive_clients:
del self.buckets[client_id]
if client_id in self.request_history:
del self.request_history[client_id]
logger.debug(f"Cleaned up {len(inactive_clients)} inactive clients")
def get_global_stats(self) -> Dict[str, int]:
"""Get global rate limiter statistics"""
current_time = time.time()
minute_ago = current_time - 60
total_clients = len(self.buckets)
active_clients = 0
total_requests_last_minute = 0
for client_id, history in self.request_history.items():
recent_requests = sum(1 for req_time in history if req_time > minute_ago)
if recent_requests > 0:
active_clients += 1
total_requests_last_minute += recent_requests
return {
'total_clients': total_clients,
'active_clients': active_clients,
'requests_per_minute_limit': self.requests_per_minute,
'burst_size': self.burst_size,
'total_requests_last_minute': total_requests_last_minute
}

306
COBY/api/replay_api.py Normal file
View File

@ -0,0 +1,306 @@
"""
REST API endpoints for historical data replay functionality.
"""
from fastapi import APIRouter, HTTPException, Query, Path
from typing import Optional, List, Dict, Any
from datetime import datetime
from pydantic import BaseModel, Field
from ..replay.replay_manager import HistoricalReplayManager
from ..models.core import ReplayStatus
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ReplayError, ValidationError
logger = get_logger(__name__)
class CreateReplayRequest(BaseModel):
"""Request model for creating replay session"""
start_time: datetime = Field(..., description="Replay start time")
end_time: datetime = Field(..., description="Replay end time")
speed: float = Field(1.0, gt=0, le=100, description="Playback speed multiplier")
symbols: Optional[List[str]] = Field(None, description="Symbols to replay")
exchanges: Optional[List[str]] = Field(None, description="Exchanges to replay")
class ReplayControlRequest(BaseModel):
"""Request model for replay control operations"""
action: str = Field(..., description="Control action: start, pause, resume, stop")
class SeekRequest(BaseModel):
"""Request model for seeking in replay"""
timestamp: datetime = Field(..., description="Target timestamp")
class SpeedRequest(BaseModel):
"""Request model for changing replay speed"""
speed: float = Field(..., gt=0, le=100, description="New playback speed")
def create_replay_router(replay_manager: HistoricalReplayManager) -> APIRouter:
"""Create replay API router with endpoints"""
router = APIRouter(prefix="/replay", tags=["replay"])
@router.post("/sessions", response_model=Dict[str, str])
async def create_replay_session(request: CreateReplayRequest):
"""Create a new replay session"""
try:
set_correlation_id()
session_id = replay_manager.create_replay_session(
start_time=request.start_time,
end_time=request.end_time,
speed=request.speed,
symbols=request.symbols,
exchanges=request.exchanges
)
logger.info(f"Created replay session {session_id}")
return {
"session_id": session_id,
"status": "created",
"message": "Replay session created successfully"
}
except ValidationError as e:
logger.warning(f"Invalid replay request: {e}")
raise HTTPException(status_code=400, detail=str(e))
except ReplayError as e:
logger.error(f"Replay creation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
except Exception as e:
logger.error(f"Unexpected error creating replay session: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/sessions", response_model=List[Dict[str, Any]])
async def list_replay_sessions():
"""List all replay sessions"""
try:
sessions = replay_manager.list_replay_sessions()
return [
{
"session_id": session.session_id,
"start_time": session.start_time.isoformat(),
"end_time": session.end_time.isoformat(),
"current_time": session.current_time.isoformat(),
"speed": session.speed,
"status": session.status.value,
"symbols": session.symbols,
"exchanges": session.exchanges,
"progress": session.progress,
"events_replayed": session.events_replayed,
"total_events": session.total_events,
"created_at": session.created_at.isoformat(),
"started_at": session.started_at.isoformat() if session.started_at else None,
"error_message": getattr(session, 'error_message', None)
}
for session in sessions
]
except Exception as e:
logger.error(f"Error listing replay sessions: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/sessions/{session_id}", response_model=Dict[str, Any])
async def get_replay_session(session_id: str = Path(..., description="Session ID")):
"""Get replay session details"""
try:
session = replay_manager.get_replay_status(session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
return {
"session_id": session.session_id,
"start_time": session.start_time.isoformat(),
"end_time": session.end_time.isoformat(),
"current_time": session.current_time.isoformat(),
"speed": session.speed,
"status": session.status.value,
"symbols": session.symbols,
"exchanges": session.exchanges,
"progress": session.progress,
"events_replayed": session.events_replayed,
"total_events": session.total_events,
"created_at": session.created_at.isoformat(),
"started_at": session.started_at.isoformat() if session.started_at else None,
"paused_at": session.paused_at.isoformat() if session.paused_at else None,
"stopped_at": session.stopped_at.isoformat() if session.stopped_at else None,
"completed_at": session.completed_at.isoformat() if session.completed_at else None,
"error_message": getattr(session, 'error_message', None)
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting replay session {session_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/sessions/{session_id}/control", response_model=Dict[str, str])
async def control_replay_session(
session_id: str = Path(..., description="Session ID"),
request: ReplayControlRequest = None
):
"""Control replay session (start, pause, resume, stop)"""
try:
set_correlation_id()
if not request:
raise HTTPException(status_code=400, detail="Control action required")
action = request.action.lower()
if action == "start":
await replay_manager.start_replay(session_id)
message = "Replay started"
elif action == "pause":
await replay_manager.pause_replay(session_id)
message = "Replay paused"
elif action == "resume":
await replay_manager.resume_replay(session_id)
message = "Replay resumed"
elif action == "stop":
await replay_manager.stop_replay(session_id)
message = "Replay stopped"
else:
raise HTTPException(status_code=400, detail="Invalid action")
logger.info(f"Replay session {session_id} action: {action}")
return {
"session_id": session_id,
"action": action,
"message": message
}
except ReplayError as e:
logger.error(f"Replay control failed for {session_id}: {e}")
raise HTTPException(status_code=400, detail=str(e))
except HTTPException:
raise
except Exception as e:
logger.error(f"Unexpected error controlling replay {session_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/sessions/{session_id}/seek", response_model=Dict[str, str])
async def seek_replay_session(
session_id: str = Path(..., description="Session ID"),
request: SeekRequest = None
):
"""Seek to specific timestamp in replay"""
try:
if not request:
raise HTTPException(status_code=400, detail="Timestamp required")
success = replay_manager.seek_replay(session_id, request.timestamp)
if not success:
raise HTTPException(status_code=400, detail="Seek failed")
logger.info(f"Seeked replay session {session_id} to {request.timestamp}")
return {
"session_id": session_id,
"timestamp": request.timestamp.isoformat(),
"message": "Seek successful"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error seeking replay session {session_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.post("/sessions/{session_id}/speed", response_model=Dict[str, Any])
async def set_replay_speed(
session_id: str = Path(..., description="Session ID"),
request: SpeedRequest = None
):
"""Change replay speed"""
try:
if not request:
raise HTTPException(status_code=400, detail="Speed required")
success = replay_manager.set_replay_speed(session_id, request.speed)
if not success:
raise HTTPException(status_code=400, detail="Speed change failed")
logger.info(f"Set replay speed to {request.speed}x for session {session_id}")
return {
"session_id": session_id,
"speed": request.speed,
"message": "Speed changed successfully"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error setting replay speed for {session_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.delete("/sessions/{session_id}", response_model=Dict[str, str])
async def delete_replay_session(session_id: str = Path(..., description="Session ID")):
"""Delete replay session"""
try:
success = replay_manager.delete_replay_session(session_id)
if not success:
raise HTTPException(status_code=404, detail="Session not found")
logger.info(f"Deleted replay session {session_id}")
return {
"session_id": session_id,
"message": "Session deleted successfully"
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error deleting replay session {session_id}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/data-range/{symbol}", response_model=Dict[str, Any])
async def get_data_range(
symbol: str = Path(..., description="Trading symbol"),
exchange: Optional[str] = Query(None, description="Exchange name")
):
"""Get available data time range for a symbol"""
try:
data_range = await replay_manager.get_available_data_range(symbol, exchange)
if not data_range:
raise HTTPException(status_code=404, detail="No data available for symbol")
return {
"symbol": symbol,
"exchange": exchange,
"start_time": data_range['start'].isoformat(),
"end_time": data_range['end'].isoformat(),
"duration_days": (data_range['end'] - data_range['start']).days
}
except HTTPException:
raise
except Exception as e:
logger.error(f"Error getting data range for {symbol}: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
@router.get("/stats", response_model=Dict[str, Any])
async def get_replay_stats():
"""Get replay system statistics"""
try:
return replay_manager.get_stats()
except Exception as e:
logger.error(f"Error getting replay stats: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
return router

View File

@ -0,0 +1,435 @@
"""
WebSocket server for real-time replay data streaming.
"""
import asyncio
import json
import logging
from typing import Dict, Set, Optional, Any
from fastapi import WebSocket, WebSocketDisconnect
from datetime import datetime
from ..replay.replay_manager import HistoricalReplayManager
from ..models.core import OrderBookSnapshot, TradeEvent, ReplayStatus
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ReplayError
logger = get_logger(__name__)
class ReplayWebSocketManager:
"""
WebSocket manager for replay data streaming.
Provides:
- Real-time replay data streaming
- Session-based connections
- Automatic cleanup on disconnect
- Status updates
"""
def __init__(self, replay_manager: HistoricalReplayManager):
"""
Initialize WebSocket manager.
Args:
replay_manager: Replay manager instance
"""
self.replay_manager = replay_manager
# Connection management
self.connections: Dict[str, Set[WebSocket]] = {} # session_id -> websockets
self.websocket_sessions: Dict[WebSocket, str] = {} # websocket -> session_id
# Statistics
self.stats = {
'active_connections': 0,
'total_connections': 0,
'messages_sent': 0,
'connection_errors': 0
}
logger.info("Replay WebSocket manager initialized")
async def connect_to_session(self, websocket: WebSocket, session_id: str) -> bool:
"""
Connect WebSocket to a replay session.
Args:
websocket: WebSocket connection
session_id: Replay session ID
Returns:
bool: True if connected successfully, False otherwise
"""
try:
set_correlation_id()
# Check if session exists
session = self.replay_manager.get_replay_status(session_id)
if not session:
await websocket.send_json({
"type": "error",
"message": f"Session {session_id} not found"
})
return False
# Accept WebSocket connection
await websocket.accept()
# Add to connection tracking
if session_id not in self.connections:
self.connections[session_id] = set()
self.connections[session_id].add(websocket)
self.websocket_sessions[websocket] = session_id
# Update statistics
self.stats['active_connections'] += 1
self.stats['total_connections'] += 1
# Add callbacks to replay session
self.replay_manager.add_data_callback(session_id, self._data_callback)
self.replay_manager.add_status_callback(session_id, self._status_callback)
# Send initial session status
await self._send_session_status(websocket, session)
logger.info(f"WebSocket connected to replay session {session_id}")
return True
except Exception as e:
logger.error(f"Failed to connect WebSocket to session {session_id}: {e}")
self.stats['connection_errors'] += 1
return False
async def disconnect(self, websocket: WebSocket) -> None:
"""
Disconnect WebSocket and cleanup.
Args:
websocket: WebSocket connection to disconnect
"""
try:
session_id = self.websocket_sessions.get(websocket)
if session_id:
# Remove from connection tracking
if session_id in self.connections:
self.connections[session_id].discard(websocket)
# Clean up empty session connections
if not self.connections[session_id]:
del self.connections[session_id]
del self.websocket_sessions[websocket]
# Update statistics
self.stats['active_connections'] -= 1
logger.info(f"WebSocket disconnected from replay session {session_id}")
except Exception as e:
logger.error(f"Error during WebSocket disconnect: {e}")
async def handle_websocket_messages(self, websocket: WebSocket) -> None:
"""
Handle incoming WebSocket messages.
Args:
websocket: WebSocket connection
"""
try:
while True:
# Receive message
message = await websocket.receive_json()
# Process message
await self._process_websocket_message(websocket, message)
except WebSocketDisconnect:
logger.info("WebSocket disconnected")
except Exception as e:
logger.error(f"WebSocket message handling error: {e}")
await websocket.send_json({
"type": "error",
"message": "Message processing error"
})
async def _process_websocket_message(self, websocket: WebSocket, message: Dict[str, Any]) -> None:
"""
Process incoming WebSocket message.
Args:
websocket: WebSocket connection
message: Received message
"""
try:
message_type = message.get('type')
session_id = self.websocket_sessions.get(websocket)
if not session_id:
await websocket.send_json({
"type": "error",
"message": "Not connected to any session"
})
return
if message_type == "control":
await self._handle_control_message(websocket, session_id, message)
elif message_type == "seek":
await self._handle_seek_message(websocket, session_id, message)
elif message_type == "speed":
await self._handle_speed_message(websocket, session_id, message)
elif message_type == "status":
await self._handle_status_request(websocket, session_id)
else:
await websocket.send_json({
"type": "error",
"message": f"Unknown message type: {message_type}"
})
except Exception as e:
logger.error(f"Error processing WebSocket message: {e}")
await websocket.send_json({
"type": "error",
"message": "Message processing failed"
})
async def _handle_control_message(self, websocket: WebSocket, session_id: str,
message: Dict[str, Any]) -> None:
"""Handle replay control messages."""
try:
action = message.get('action')
if action == "start":
await self.replay_manager.start_replay(session_id)
elif action == "pause":
await self.replay_manager.pause_replay(session_id)
elif action == "resume":
await self.replay_manager.resume_replay(session_id)
elif action == "stop":
await self.replay_manager.stop_replay(session_id)
else:
await websocket.send_json({
"type": "error",
"message": f"Invalid control action: {action}"
})
return
await websocket.send_json({
"type": "control_response",
"action": action,
"status": "success"
})
except ReplayError as e:
await websocket.send_json({
"type": "error",
"message": str(e)
})
except Exception as e:
logger.error(f"Control message error: {e}")
await websocket.send_json({
"type": "error",
"message": "Control action failed"
})
async def _handle_seek_message(self, websocket: WebSocket, session_id: str,
message: Dict[str, Any]) -> None:
"""Handle seek messages."""
try:
timestamp_str = message.get('timestamp')
if not timestamp_str:
await websocket.send_json({
"type": "error",
"message": "Timestamp required for seek"
})
return
timestamp = datetime.fromisoformat(timestamp_str.replace('Z', '+00:00'))
success = self.replay_manager.seek_replay(session_id, timestamp)
await websocket.send_json({
"type": "seek_response",
"timestamp": timestamp_str,
"status": "success" if success else "failed"
})
except Exception as e:
logger.error(f"Seek message error: {e}")
await websocket.send_json({
"type": "error",
"message": "Seek failed"
})
async def _handle_speed_message(self, websocket: WebSocket, session_id: str,
message: Dict[str, Any]) -> None:
"""Handle speed change messages."""
try:
speed = message.get('speed')
if not speed or speed <= 0:
await websocket.send_json({
"type": "error",
"message": "Valid speed required"
})
return
success = self.replay_manager.set_replay_speed(session_id, speed)
await websocket.send_json({
"type": "speed_response",
"speed": speed,
"status": "success" if success else "failed"
})
except Exception as e:
logger.error(f"Speed message error: {e}")
await websocket.send_json({
"type": "error",
"message": "Speed change failed"
})
async def _handle_status_request(self, websocket: WebSocket, session_id: str) -> None:
"""Handle status request messages."""
try:
session = self.replay_manager.get_replay_status(session_id)
if session:
await self._send_session_status(websocket, session)
else:
await websocket.send_json({
"type": "error",
"message": "Session not found"
})
except Exception as e:
logger.error(f"Status request error: {e}")
await websocket.send_json({
"type": "error",
"message": "Status request failed"
})
async def _data_callback(self, data) -> None:
"""Callback for replay data - broadcasts to all connected WebSockets."""
try:
# Determine which session this data belongs to
# This is a simplified approach - in practice, you'd need to track
# which session generated this callback
# Serialize data
if isinstance(data, OrderBookSnapshot):
message = {
"type": "orderbook",
"data": {
"symbol": data.symbol,
"exchange": data.exchange,
"timestamp": data.timestamp.isoformat(),
"bids": [{"price": b.price, "size": b.size} for b in data.bids[:10]],
"asks": [{"price": a.price, "size": a.size} for a in data.asks[:10]],
"sequence_id": data.sequence_id
}
}
elif isinstance(data, TradeEvent):
message = {
"type": "trade",
"data": {
"symbol": data.symbol,
"exchange": data.exchange,
"timestamp": data.timestamp.isoformat(),
"price": data.price,
"size": data.size,
"side": data.side,
"trade_id": data.trade_id
}
}
else:
return
# Broadcast to all connections
await self._broadcast_message(message)
except Exception as e:
logger.error(f"Data callback error: {e}")
async def _status_callback(self, session_id: str, status: ReplayStatus) -> None:
"""Callback for replay status changes."""
try:
message = {
"type": "status",
"session_id": session_id,
"status": status.value,
"timestamp": datetime.utcnow().isoformat()
}
# Send to connections for this session
if session_id in self.connections:
await self._broadcast_to_session(session_id, message)
except Exception as e:
logger.error(f"Status callback error: {e}")
async def _send_session_status(self, websocket: WebSocket, session) -> None:
"""Send session status to WebSocket."""
try:
message = {
"type": "session_status",
"data": {
"session_id": session.session_id,
"status": session.status.value,
"progress": session.progress,
"current_time": session.current_time.isoformat(),
"speed": session.speed,
"events_replayed": session.events_replayed,
"total_events": session.total_events
}
}
await websocket.send_json(message)
self.stats['messages_sent'] += 1
except Exception as e:
logger.error(f"Error sending session status: {e}")
async def _broadcast_message(self, message: Dict[str, Any]) -> None:
"""Broadcast message to all connected WebSockets."""
disconnected = []
for session_id, websockets in self.connections.items():
for websocket in websockets.copy():
try:
await websocket.send_json(message)
self.stats['messages_sent'] += 1
except Exception as e:
logger.warning(f"Failed to send message to WebSocket: {e}")
disconnected.append((session_id, websocket))
# Clean up disconnected WebSockets
for session_id, websocket in disconnected:
await self.disconnect(websocket)
async def _broadcast_to_session(self, session_id: str, message: Dict[str, Any]) -> None:
"""Broadcast message to WebSockets connected to a specific session."""
if session_id not in self.connections:
return
disconnected = []
for websocket in self.connections[session_id].copy():
try:
await websocket.send_json(message)
self.stats['messages_sent'] += 1
except Exception as e:
logger.warning(f"Failed to send message to WebSocket: {e}")
disconnected.append(websocket)
# Clean up disconnected WebSockets
for websocket in disconnected:
await self.disconnect(websocket)
def get_stats(self) -> Dict[str, Any]:
"""Get WebSocket manager statistics."""
return {
**self.stats,
'sessions_with_connections': len(self.connections),
'total_websockets': sum(len(ws_set) for ws_set in self.connections.values())
}

View File

@ -0,0 +1,326 @@
"""
Response formatting for API endpoints.
"""
import json
from typing import Any, Dict, Optional, List
from datetime import datetime
from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
logger = get_logger(__name__)
class ResponseFormatter:
"""
Formats API responses with consistent structure and metadata.
"""
def __init__(self):
"""Initialize response formatter"""
self.responses_formatted = 0
logger.info("Response formatter initialized")
def success(self, data: Any, message: str = "Success",
metadata: Optional[Dict] = None) -> Dict[str, Any]:
"""
Format successful response.
Args:
data: Response data
message: Success message
metadata: Additional metadata
Returns:
Dict: Formatted response
"""
response = {
'success': True,
'message': message,
'data': data,
'timestamp': get_current_timestamp().isoformat(),
'metadata': metadata or {}
}
self.responses_formatted += 1
return response
def error(self, message: str, error_code: str = "UNKNOWN_ERROR",
details: Optional[Dict] = None, status_code: int = 400) -> Dict[str, Any]:
"""
Format error response.
Args:
message: Error message
error_code: Error code
details: Error details
status_code: HTTP status code
Returns:
Dict: Formatted error response
"""
response = {
'success': False,
'error': {
'message': message,
'code': error_code,
'details': details or {},
'status_code': status_code
},
'timestamp': get_current_timestamp().isoformat()
}
self.responses_formatted += 1
return response
def paginated(self, data: List[Any], page: int, page_size: int,
total_count: int, message: str = "Success") -> Dict[str, Any]:
"""
Format paginated response.
Args:
data: Page data
page: Current page number
page_size: Items per page
total_count: Total number of items
message: Success message
Returns:
Dict: Formatted paginated response
"""
total_pages = (total_count + page_size - 1) // page_size
pagination = {
'page': page,
'page_size': page_size,
'total_count': total_count,
'total_pages': total_pages,
'has_next': page < total_pages,
'has_previous': page > 1
}
return self.success(
data=data,
message=message,
metadata={'pagination': pagination}
)
def heatmap_response(self, heatmap_data, symbol: str,
exchange: Optional[str] = None) -> Dict[str, Any]:
"""
Format heatmap data response.
Args:
heatmap_data: Heatmap data
symbol: Trading symbol
exchange: Exchange name (None for consolidated)
Returns:
Dict: Formatted heatmap response
"""
if not heatmap_data:
return self.error("Heatmap data not found", "HEATMAP_NOT_FOUND", status_code=404)
# Convert heatmap to API format
formatted_data = {
'symbol': heatmap_data.symbol,
'timestamp': heatmap_data.timestamp.isoformat(),
'bucket_size': heatmap_data.bucket_size,
'exchange': exchange,
'points': [
{
'price': point.price,
'volume': point.volume,
'intensity': point.intensity,
'side': point.side
}
for point in heatmap_data.data
]
}
metadata = {
'total_points': len(heatmap_data.data),
'bid_points': len([p for p in heatmap_data.data if p.side == 'bid']),
'ask_points': len([p for p in heatmap_data.data if p.side == 'ask']),
'data_type': 'consolidated' if not exchange else 'exchange_specific'
}
return self.success(
data=formatted_data,
message=f"Heatmap data for {symbol}",
metadata=metadata
)
def orderbook_response(self, orderbook_data, symbol: str, exchange: str) -> Dict[str, Any]:
"""
Format order book response.
Args:
orderbook_data: Order book data
symbol: Trading symbol
exchange: Exchange name
Returns:
Dict: Formatted order book response
"""
if not orderbook_data:
return self.error("Order book not found", "ORDERBOOK_NOT_FOUND", status_code=404)
# Convert order book to API format
formatted_data = {
'symbol': orderbook_data.symbol,
'exchange': orderbook_data.exchange,
'timestamp': orderbook_data.timestamp.isoformat(),
'sequence_id': orderbook_data.sequence_id,
'bids': [
{
'price': bid.price,
'size': bid.size,
'count': bid.count
}
for bid in orderbook_data.bids
],
'asks': [
{
'price': ask.price,
'size': ask.size,
'count': ask.count
}
for ask in orderbook_data.asks
],
'mid_price': orderbook_data.mid_price,
'spread': orderbook_data.spread,
'bid_volume': orderbook_data.bid_volume,
'ask_volume': orderbook_data.ask_volume
}
metadata = {
'bid_levels': len(orderbook_data.bids),
'ask_levels': len(orderbook_data.asks),
'total_bid_volume': orderbook_data.bid_volume,
'total_ask_volume': orderbook_data.ask_volume
}
return self.success(
data=formatted_data,
message=f"Order book for {symbol}@{exchange}",
metadata=metadata
)
def metrics_response(self, metrics_data, symbol: str, exchange: str) -> Dict[str, Any]:
"""
Format metrics response.
Args:
metrics_data: Metrics data
symbol: Trading symbol
exchange: Exchange name
Returns:
Dict: Formatted metrics response
"""
if not metrics_data:
return self.error("Metrics not found", "METRICS_NOT_FOUND", status_code=404)
# Convert metrics to API format
formatted_data = {
'symbol': metrics_data.symbol,
'exchange': metrics_data.exchange,
'timestamp': metrics_data.timestamp.isoformat(),
'mid_price': metrics_data.mid_price,
'spread': metrics_data.spread,
'spread_percentage': metrics_data.spread_percentage,
'bid_volume': metrics_data.bid_volume,
'ask_volume': metrics_data.ask_volume,
'volume_imbalance': metrics_data.volume_imbalance,
'depth_10': metrics_data.depth_10,
'depth_50': metrics_data.depth_50
}
return self.success(
data=formatted_data,
message=f"Metrics for {symbol}@{exchange}"
)
def status_response(self, status_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Format system status response.
Args:
status_data: System status data
Returns:
Dict: Formatted status response
"""
return self.success(
data=status_data,
message="System status",
metadata={'response_count': self.responses_formatted}
)
def rate_limit_error(self, client_stats: Dict[str, float]) -> Dict[str, Any]:
"""
Format rate limit error response.
Args:
client_stats: Client rate limit statistics
Returns:
Dict: Formatted rate limit error
"""
return self.error(
message="Rate limit exceeded",
error_code="RATE_LIMIT_EXCEEDED",
details={
'remaining_tokens': client_stats['remaining_tokens'],
'reset_time': client_stats['reset_time'],
'requests_last_minute': client_stats['requests_last_minute']
},
status_code=429
)
def validation_error(self, field: str, message: str) -> Dict[str, Any]:
"""
Format validation error response.
Args:
field: Field that failed validation
message: Validation error message
Returns:
Dict: Formatted validation error
"""
return self.error(
message=f"Validation error: {message}",
error_code="VALIDATION_ERROR",
details={'field': field, 'message': message},
status_code=400
)
def to_json(self, response: Dict[str, Any], indent: Optional[int] = None) -> str:
"""
Convert response to JSON string.
Args:
response: Response dictionary
indent: JSON indentation (None for compact)
Returns:
str: JSON string
"""
try:
return json.dumps(response, indent=indent, ensure_ascii=False, default=str)
except Exception as e:
logger.error(f"Error converting response to JSON: {e}")
return json.dumps(self.error("JSON serialization failed", "JSON_ERROR"))
def get_stats(self) -> Dict[str, int]:
"""Get formatter statistics"""
return {
'responses_formatted': self.responses_formatted
}
def reset_stats(self) -> None:
"""Reset formatter statistics"""
self.responses_formatted = 0
logger.info("Response formatter statistics reset")

391
COBY/api/rest_api.py Normal file
View File

@ -0,0 +1,391 @@
"""
REST API server for COBY system.
"""
from fastapi import FastAPI, HTTPException, Request, Query, Path
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from typing import Optional, List
import asyncio
from ..config import config
from ..caching.redis_manager import redis_manager
from ..utils.logging import get_logger, set_correlation_id
from ..utils.validation import validate_symbol
from .rate_limiter import RateLimiter
from .response_formatter import ResponseFormatter
logger = get_logger(__name__)
def create_app() -> FastAPI:
"""Create and configure FastAPI application"""
app = FastAPI(
title="COBY Market Data API",
description="Real-time cryptocurrency market data aggregation API",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=config.api.cors_origins,
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
)
# Initialize components
rate_limiter = RateLimiter(
requests_per_minute=config.api.rate_limit,
burst_size=20
)
response_formatter = ResponseFormatter()
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
"""Rate limiting middleware"""
client_ip = request.client.host
if not rate_limiter.is_allowed(client_ip):
client_stats = rate_limiter.get_client_stats(client_ip)
error_response = response_formatter.rate_limit_error(client_stats)
return JSONResponse(
status_code=429,
content=error_response,
headers={
"X-RateLimit-Remaining": str(int(client_stats['remaining_tokens'])),
"X-RateLimit-Reset": str(int(client_stats['reset_time']))
}
)
response = await call_next(request)
# Add rate limit headers
client_stats = rate_limiter.get_client_stats(client_ip)
response.headers["X-RateLimit-Remaining"] = str(int(client_stats['remaining_tokens']))
response.headers["X-RateLimit-Reset"] = str(int(client_stats['reset_time']))
return response
@app.middleware("http")
async def correlation_middleware(request: Request, call_next):
"""Add correlation ID to requests"""
set_correlation_id()
response = await call_next(request)
return response
@app.on_event("startup")
async def startup_event():
"""Initialize services on startup"""
try:
await redis_manager.initialize()
logger.info("API server startup completed")
except Exception as e:
logger.error(f"API server startup failed: {e}")
raise
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on shutdown"""
try:
await redis_manager.close()
logger.info("API server shutdown completed")
except Exception as e:
logger.error(f"API server shutdown error: {e}")
# Health check endpoint
@app.get("/health")
async def health_check():
"""Health check endpoint"""
try:
# Check Redis connection
redis_healthy = await redis_manager.ping()
health_data = {
'status': 'healthy' if redis_healthy else 'degraded',
'redis': 'connected' if redis_healthy else 'disconnected',
'version': '1.0.0'
}
return response_formatter.status_response(health_data)
except Exception as e:
logger.error(f"Health check failed: {e}")
return JSONResponse(
status_code=503,
content=response_formatter.error("Service unavailable", "HEALTH_CHECK_FAILED")
)
# Heatmap endpoints
@app.get("/api/v1/heatmap/{symbol}")
async def get_heatmap(
symbol: str = Path(..., description="Trading symbol (e.g., BTCUSDT)"),
exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)")
):
"""Get heatmap data for a symbol"""
try:
# Validate symbol
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbol", "Invalid symbol format")
)
# Get heatmap from cache
heatmap_data = await redis_manager.get_heatmap(symbol.upper(), exchange)
return response_formatter.heatmap_response(heatmap_data, symbol.upper(), exchange)
except Exception as e:
logger.error(f"Error getting heatmap for {symbol}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "HEATMAP_ERROR")
)
# Order book endpoints
@app.get("/api/v1/orderbook/{symbol}/{exchange}")
async def get_orderbook(
symbol: str = Path(..., description="Trading symbol"),
exchange: str = Path(..., description="Exchange name")
):
"""Get order book data for a symbol on an exchange"""
try:
# Validate symbol
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbol", "Invalid symbol format")
)
# Get order book from cache
orderbook_data = await redis_manager.get_orderbook(symbol.upper(), exchange.lower())
return response_formatter.orderbook_response(orderbook_data, symbol.upper(), exchange.lower())
except Exception as e:
logger.error(f"Error getting order book for {symbol}@{exchange}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "ORDERBOOK_ERROR")
)
# Metrics endpoints
@app.get("/api/v1/metrics/{symbol}/{exchange}")
async def get_metrics(
symbol: str = Path(..., description="Trading symbol"),
exchange: str = Path(..., description="Exchange name")
):
"""Get metrics data for a symbol on an exchange"""
try:
# Validate symbol
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbol", "Invalid symbol format")
)
# Get metrics from cache
metrics_data = await redis_manager.get_metrics(symbol.upper(), exchange.lower())
return response_formatter.metrics_response(metrics_data, symbol.upper(), exchange.lower())
except Exception as e:
logger.error(f"Error getting metrics for {symbol}@{exchange}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "METRICS_ERROR")
)
# Exchange status endpoints
@app.get("/api/v1/status/{exchange}")
async def get_exchange_status(
exchange: str = Path(..., description="Exchange name")
):
"""Get status for an exchange"""
try:
# Get status from cache
status_data = await redis_manager.get_exchange_status(exchange.lower())
if not status_data:
return JSONResponse(
status_code=404,
content=response_formatter.error("Exchange status not found", "STATUS_NOT_FOUND")
)
return response_formatter.success(
data=status_data,
message=f"Status for {exchange}"
)
except Exception as e:
logger.error(f"Error getting status for {exchange}: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "STATUS_ERROR")
)
# List endpoints
@app.get("/api/v1/symbols")
async def list_symbols():
"""List available trading symbols"""
try:
# Get symbols from cache (this would be populated by exchange connectors)
symbols_pattern = "symbols:*"
symbol_keys = await redis_manager.keys(symbols_pattern)
all_symbols = set()
for key in symbol_keys:
symbols_data = await redis_manager.get(key)
if symbols_data and isinstance(symbols_data, list):
all_symbols.update(symbols_data)
return response_formatter.success(
data=sorted(list(all_symbols)),
message="Available trading symbols",
metadata={'total_symbols': len(all_symbols)}
)
except Exception as e:
logger.error(f"Error listing symbols: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "SYMBOLS_ERROR")
)
@app.get("/api/v1/exchanges")
async def list_exchanges():
"""List available exchanges"""
try:
# Get exchange status keys
status_pattern = "st:*"
status_keys = await redis_manager.keys(status_pattern)
exchanges = []
for key in status_keys:
# Extract exchange name from key (st:exchange_name)
exchange_name = key.split(':', 1)[1] if ':' in key else key
exchanges.append(exchange_name)
return response_formatter.success(
data=sorted(exchanges),
message="Available exchanges",
metadata={'total_exchanges': len(exchanges)}
)
except Exception as e:
logger.error(f"Error listing exchanges: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "EXCHANGES_ERROR")
)
# Statistics endpoints
@app.get("/api/v1/stats/cache")
async def get_cache_stats():
"""Get cache statistics"""
try:
cache_stats = redis_manager.get_stats()
redis_health = await redis_manager.health_check()
stats_data = {
'cache_performance': cache_stats,
'redis_health': redis_health
}
return response_formatter.success(
data=stats_data,
message="Cache statistics"
)
except Exception as e:
logger.error(f"Error getting cache stats: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "STATS_ERROR")
)
@app.get("/api/v1/stats/api")
async def get_api_stats():
"""Get API statistics"""
try:
api_stats = {
'rate_limiter': rate_limiter.get_global_stats(),
'response_formatter': response_formatter.get_stats()
}
return response_formatter.success(
data=api_stats,
message="API statistics"
)
except Exception as e:
logger.error(f"Error getting API stats: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "API_STATS_ERROR")
)
# Batch endpoints for efficiency
@app.get("/api/v1/batch/heatmaps")
async def get_batch_heatmaps(
symbols: str = Query(..., description="Comma-separated list of symbols"),
exchange: Optional[str] = Query(None, description="Exchange name (None for consolidated)")
):
"""Get heatmaps for multiple symbols"""
try:
symbol_list = [s.strip().upper() for s in symbols.split(',')]
# Validate all symbols
for symbol in symbol_list:
if not validate_symbol(symbol):
return JSONResponse(
status_code=400,
content=response_formatter.validation_error("symbols", f"Invalid symbol: {symbol}")
)
# Get heatmaps in batch
heatmaps = {}
for symbol in symbol_list:
heatmap_data = await redis_manager.get_heatmap(symbol, exchange)
if heatmap_data:
heatmaps[symbol] = {
'symbol': heatmap_data.symbol,
'timestamp': heatmap_data.timestamp.isoformat(),
'bucket_size': heatmap_data.bucket_size,
'points': [
{
'price': point.price,
'volume': point.volume,
'intensity': point.intensity,
'side': point.side
}
for point in heatmap_data.data
]
}
return response_formatter.success(
data=heatmaps,
message=f"Batch heatmaps for {len(symbol_list)} symbols",
metadata={
'requested_symbols': len(symbol_list),
'found_heatmaps': len(heatmaps),
'exchange': exchange or 'consolidated'
}
)
except Exception as e:
logger.error(f"Error getting batch heatmaps: {e}")
return JSONResponse(
status_code=500,
content=response_formatter.error("Internal server error", "BATCH_HEATMAPS_ERROR")
)
return app
# Create the FastAPI app instance
app = create_app()

View File

@ -0,0 +1,400 @@
"""
WebSocket server for real-time data streaming.
"""
import asyncio
import json
from typing import Dict, Set, Optional, Any
from fastapi import WebSocket, WebSocketDisconnect
from ..utils.logging import get_logger, set_correlation_id
from ..utils.validation import validate_symbol
from ..caching.redis_manager import redis_manager
from .response_formatter import ResponseFormatter
logger = get_logger(__name__)
class WebSocketManager:
"""
Manages WebSocket connections and real-time data streaming.
"""
def __init__(self):
"""Initialize WebSocket manager"""
# Active connections: connection_id -> WebSocket
self.connections: Dict[str, WebSocket] = {}
# Subscriptions: symbol -> set of connection_ids
self.subscriptions: Dict[str, Set[str]] = {}
# Connection metadata: connection_id -> metadata
self.connection_metadata: Dict[str, Dict[str, Any]] = {}
self.response_formatter = ResponseFormatter()
self.connection_counter = 0
logger.info("WebSocket manager initialized")
async def connect(self, websocket: WebSocket, client_ip: str) -> str:
"""
Accept new WebSocket connection.
Args:
websocket: WebSocket connection
client_ip: Client IP address
Returns:
str: Connection ID
"""
await websocket.accept()
# Generate connection ID
self.connection_counter += 1
connection_id = f"ws_{self.connection_counter}_{client_ip}"
# Store connection
self.connections[connection_id] = websocket
self.connection_metadata[connection_id] = {
'client_ip': client_ip,
'connected_at': asyncio.get_event_loop().time(),
'subscriptions': set(),
'messages_sent': 0
}
logger.info(f"WebSocket connected: {connection_id}")
# Send welcome message
welcome_msg = self.response_formatter.success(
data={'connection_id': connection_id},
message="WebSocket connected successfully"
)
await self._send_to_connection(connection_id, welcome_msg)
return connection_id
async def disconnect(self, connection_id: str) -> None:
"""
Handle WebSocket disconnection.
Args:
connection_id: Connection ID to disconnect
"""
if connection_id in self.connections:
# Remove from all subscriptions
metadata = self.connection_metadata.get(connection_id, {})
for symbol in metadata.get('subscriptions', set()):
await self._unsubscribe_connection(connection_id, symbol)
# Remove connection
del self.connections[connection_id]
del self.connection_metadata[connection_id]
logger.info(f"WebSocket disconnected: {connection_id}")
async def subscribe(self, connection_id: str, symbol: str,
data_type: str = "heatmap") -> bool:
"""
Subscribe connection to symbol updates.
Args:
connection_id: Connection ID
symbol: Trading symbol
data_type: Type of data to subscribe to
Returns:
bool: True if subscribed successfully
"""
try:
# Validate symbol
if not validate_symbol(symbol):
error_msg = self.response_formatter.validation_error("symbol", "Invalid symbol format")
await self._send_to_connection(connection_id, error_msg)
return False
symbol = symbol.upper()
subscription_key = f"{symbol}:{data_type}"
# Add to subscriptions
if subscription_key not in self.subscriptions:
self.subscriptions[subscription_key] = set()
self.subscriptions[subscription_key].add(connection_id)
# Update connection metadata
if connection_id in self.connection_metadata:
self.connection_metadata[connection_id]['subscriptions'].add(subscription_key)
logger.info(f"WebSocket {connection_id} subscribed to {subscription_key}")
# Send confirmation
confirm_msg = self.response_formatter.success(
data={'symbol': symbol, 'data_type': data_type},
message=f"Subscribed to {symbol} {data_type} updates"
)
await self._send_to_connection(connection_id, confirm_msg)
# Send initial data if available
await self._send_initial_data(connection_id, symbol, data_type)
return True
except Exception as e:
logger.error(f"Error subscribing {connection_id} to {symbol}: {e}")
error_msg = self.response_formatter.error("Subscription failed", "SUBSCRIBE_ERROR")
await self._send_to_connection(connection_id, error_msg)
return False
async def unsubscribe(self, connection_id: str, symbol: str,
data_type: str = "heatmap") -> bool:
"""
Unsubscribe connection from symbol updates.
Args:
connection_id: Connection ID
symbol: Trading symbol
data_type: Type of data to unsubscribe from
Returns:
bool: True if unsubscribed successfully
"""
try:
symbol = symbol.upper()
subscription_key = f"{symbol}:{data_type}"
await self._unsubscribe_connection(connection_id, subscription_key)
# Send confirmation
confirm_msg = self.response_formatter.success(
data={'symbol': symbol, 'data_type': data_type},
message=f"Unsubscribed from {symbol} {data_type} updates"
)
await self._send_to_connection(connection_id, confirm_msg)
return True
except Exception as e:
logger.error(f"Error unsubscribing {connection_id} from {symbol}: {e}")
return False
async def broadcast_update(self, symbol: str, data_type: str, data: Any) -> int:
"""
Broadcast data update to all subscribers.
Args:
symbol: Trading symbol
data_type: Type of data
data: Data to broadcast
Returns:
int: Number of connections notified
"""
try:
set_correlation_id()
subscription_key = f"{symbol.upper()}:{data_type}"
subscribers = self.subscriptions.get(subscription_key, set())
if not subscribers:
return 0
# Format message based on data type
if data_type == "heatmap":
message = self.response_formatter.heatmap_response(data, symbol)
elif data_type == "orderbook":
message = self.response_formatter.orderbook_response(data, symbol, "consolidated")
else:
message = self.response_formatter.success(data, f"{data_type} update for {symbol}")
# Add update type to message
message['update_type'] = data_type
message['symbol'] = symbol
# Send to all subscribers
sent_count = 0
for connection_id in subscribers.copy(): # Copy to avoid modification during iteration
if await self._send_to_connection(connection_id, message):
sent_count += 1
logger.debug(f"Broadcasted {data_type} update for {symbol} to {sent_count} connections")
return sent_count
except Exception as e:
logger.error(f"Error broadcasting update for {symbol}: {e}")
return 0
async def _send_to_connection(self, connection_id: str, message: Dict[str, Any]) -> bool:
"""
Send message to specific connection.
Args:
connection_id: Connection ID
message: Message to send
Returns:
bool: True if sent successfully
"""
try:
if connection_id not in self.connections:
return False
websocket = self.connections[connection_id]
message_json = json.dumps(message, default=str)
await websocket.send_text(message_json)
# Update statistics
if connection_id in self.connection_metadata:
self.connection_metadata[connection_id]['messages_sent'] += 1
return True
except Exception as e:
logger.warning(f"Error sending message to {connection_id}: {e}")
# Remove broken connection
await self.disconnect(connection_id)
return False
async def _unsubscribe_connection(self, connection_id: str, subscription_key: str) -> None:
"""Remove connection from subscription"""
if subscription_key in self.subscriptions:
self.subscriptions[subscription_key].discard(connection_id)
# Clean up empty subscriptions
if not self.subscriptions[subscription_key]:
del self.subscriptions[subscription_key]
# Update connection metadata
if connection_id in self.connection_metadata:
self.connection_metadata[connection_id]['subscriptions'].discard(subscription_key)
async def _send_initial_data(self, connection_id: str, symbol: str, data_type: str) -> None:
"""Send initial data to newly subscribed connection"""
try:
if data_type == "heatmap":
# Get latest heatmap from cache
heatmap_data = await redis_manager.get_heatmap(symbol)
if heatmap_data:
message = self.response_formatter.heatmap_response(heatmap_data, symbol)
message['update_type'] = 'initial_data'
await self._send_to_connection(connection_id, message)
elif data_type == "orderbook":
# Could get latest order book from cache
# This would require knowing which exchange to get data from
pass
except Exception as e:
logger.warning(f"Error sending initial data to {connection_id}: {e}")
def get_stats(self) -> Dict[str, Any]:
"""Get WebSocket manager statistics"""
total_subscriptions = sum(len(subs) for subs in self.subscriptions.values())
return {
'active_connections': len(self.connections),
'total_subscriptions': total_subscriptions,
'unique_symbols': len(set(key.split(':')[0] for key in self.subscriptions.keys())),
'connection_counter': self.connection_counter
}
# Global WebSocket manager instance
websocket_manager = WebSocketManager()
class WebSocketServer:
"""
WebSocket server for real-time data streaming.
"""
def __init__(self):
"""Initialize WebSocket server"""
self.manager = websocket_manager
logger.info("WebSocket server initialized")
async def handle_connection(self, websocket: WebSocket, client_ip: str) -> None:
"""
Handle WebSocket connection lifecycle.
Args:
websocket: WebSocket connection
client_ip: Client IP address
"""
connection_id = None
try:
# Accept connection
connection_id = await self.manager.connect(websocket, client_ip)
# Handle messages
while True:
try:
# Receive message
message = await websocket.receive_text()
await self._handle_message(connection_id, message)
except WebSocketDisconnect:
logger.info(f"WebSocket client disconnected: {connection_id}")
break
except Exception as e:
logger.error(f"WebSocket connection error: {e}")
finally:
# Clean up connection
if connection_id:
await self.manager.disconnect(connection_id)
async def _handle_message(self, connection_id: str, message: str) -> None:
"""
Handle incoming WebSocket message.
Args:
connection_id: Connection ID
message: Received message
"""
try:
# Parse message
data = json.loads(message)
action = data.get('action')
if action == 'subscribe':
symbol = data.get('symbol')
data_type = data.get('data_type', 'heatmap')
await self.manager.subscribe(connection_id, symbol, data_type)
elif action == 'unsubscribe':
symbol = data.get('symbol')
data_type = data.get('data_type', 'heatmap')
await self.manager.unsubscribe(connection_id, symbol, data_type)
elif action == 'ping':
# Send pong response
pong_msg = self.manager.response_formatter.success(
data={'action': 'pong'},
message="Pong"
)
await self.manager._send_to_connection(connection_id, pong_msg)
else:
# Unknown action
error_msg = self.manager.response_formatter.error(
f"Unknown action: {action}",
"UNKNOWN_ACTION"
)
await self.manager._send_to_connection(connection_id, error_msg)
except json.JSONDecodeError:
error_msg = self.manager.response_formatter.error(
"Invalid JSON message",
"INVALID_JSON"
)
await self.manager._send_to_connection(connection_id, error_msg)
except Exception as e:
logger.error(f"Error handling WebSocket message: {e}")
error_msg = self.manager.response_formatter.error(
"Message processing failed",
"MESSAGE_ERROR"
)
await self.manager._send_to_connection(connection_id, error_msg)

13
COBY/caching/__init__.py Normal file
View File

@ -0,0 +1,13 @@
"""
Caching layer for the COBY system.
"""
from .redis_manager import RedisManager
from .cache_keys import CacheKeys
from .data_serializer import DataSerializer
__all__ = [
'RedisManager',
'CacheKeys',
'DataSerializer'
]

278
COBY/caching/cache_keys.py Normal file
View File

@ -0,0 +1,278 @@
"""
Cache key management for Redis operations.
"""
from typing import Optional
from ..utils.logging import get_logger
logger = get_logger(__name__)
class CacheKeys:
"""
Centralized cache key management for consistent Redis operations.
Provides standardized key patterns for different data types.
"""
# Key prefixes
ORDERBOOK_PREFIX = "ob"
HEATMAP_PREFIX = "hm"
TRADE_PREFIX = "tr"
METRICS_PREFIX = "mt"
STATUS_PREFIX = "st"
STATS_PREFIX = "stats"
# TTL values (seconds)
ORDERBOOK_TTL = 60 # 1 minute
HEATMAP_TTL = 30 # 30 seconds
TRADE_TTL = 300 # 5 minutes
METRICS_TTL = 120 # 2 minutes
STATUS_TTL = 60 # 1 minute
STATS_TTL = 300 # 5 minutes
@classmethod
def orderbook_key(cls, symbol: str, exchange: str) -> str:
"""
Generate cache key for order book data.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
str: Cache key
"""
return f"{cls.ORDERBOOK_PREFIX}:{exchange}:{symbol}"
@classmethod
def heatmap_key(cls, symbol: str, bucket_size: float = 1.0,
exchange: Optional[str] = None) -> str:
"""
Generate cache key for heatmap data.
Args:
symbol: Trading symbol
bucket_size: Price bucket size
exchange: Exchange name (None for consolidated)
Returns:
str: Cache key
"""
if exchange:
return f"{cls.HEATMAP_PREFIX}:{exchange}:{symbol}:{bucket_size}"
else:
return f"{cls.HEATMAP_PREFIX}:consolidated:{symbol}:{bucket_size}"
@classmethod
def trade_key(cls, symbol: str, exchange: str, trade_id: str) -> str:
"""
Generate cache key for trade data.
Args:
symbol: Trading symbol
exchange: Exchange name
trade_id: Trade identifier
Returns:
str: Cache key
"""
return f"{cls.TRADE_PREFIX}:{exchange}:{symbol}:{trade_id}"
@classmethod
def metrics_key(cls, symbol: str, exchange: str) -> str:
"""
Generate cache key for metrics data.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
str: Cache key
"""
return f"{cls.METRICS_PREFIX}:{exchange}:{symbol}"
@classmethod
def status_key(cls, exchange: str) -> str:
"""
Generate cache key for exchange status.
Args:
exchange: Exchange name
Returns:
str: Cache key
"""
return f"{cls.STATUS_PREFIX}:{exchange}"
@classmethod
def stats_key(cls, component: str) -> str:
"""
Generate cache key for component statistics.
Args:
component: Component name
Returns:
str: Cache key
"""
return f"{cls.STATS_PREFIX}:{component}"
@classmethod
def latest_heatmaps_key(cls, symbol: str) -> str:
"""
Generate cache key for latest heatmaps list.
Args:
symbol: Trading symbol
Returns:
str: Cache key
"""
return f"{cls.HEATMAP_PREFIX}:latest:{symbol}"
@classmethod
def symbol_list_key(cls, exchange: str) -> str:
"""
Generate cache key for symbol list.
Args:
exchange: Exchange name
Returns:
str: Cache key
"""
return f"symbols:{exchange}"
@classmethod
def price_bucket_key(cls, symbol: str, exchange: str) -> str:
"""
Generate cache key for price buckets.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
str: Cache key
"""
return f"buckets:{exchange}:{symbol}"
@classmethod
def arbitrage_key(cls, symbol: str) -> str:
"""
Generate cache key for arbitrage opportunities.
Args:
symbol: Trading symbol
Returns:
str: Cache key
"""
return f"arbitrage:{symbol}"
@classmethod
def get_ttl(cls, key: str) -> int:
"""
Get appropriate TTL for a cache key.
Args:
key: Cache key
Returns:
int: TTL in seconds
"""
if key.startswith(cls.ORDERBOOK_PREFIX):
return cls.ORDERBOOK_TTL
elif key.startswith(cls.HEATMAP_PREFIX):
return cls.HEATMAP_TTL
elif key.startswith(cls.TRADE_PREFIX):
return cls.TRADE_TTL
elif key.startswith(cls.METRICS_PREFIX):
return cls.METRICS_TTL
elif key.startswith(cls.STATUS_PREFIX):
return cls.STATUS_TTL
elif key.startswith(cls.STATS_PREFIX):
return cls.STATS_TTL
else:
return 300 # Default 5 minutes
@classmethod
def parse_key(cls, key: str) -> dict:
"""
Parse cache key to extract components.
Args:
key: Cache key to parse
Returns:
dict: Parsed key components
"""
parts = key.split(':')
if len(parts) < 2:
return {'type': 'unknown', 'key': key}
key_type = parts[0]
if key_type == cls.ORDERBOOK_PREFIX and len(parts) >= 3:
return {
'type': 'orderbook',
'exchange': parts[1],
'symbol': parts[2]
}
elif key_type == cls.HEATMAP_PREFIX and len(parts) >= 4:
return {
'type': 'heatmap',
'exchange': parts[1] if parts[1] != 'consolidated' else None,
'symbol': parts[2],
'bucket_size': float(parts[3]) if len(parts) > 3 else 1.0
}
elif key_type == cls.TRADE_PREFIX and len(parts) >= 4:
return {
'type': 'trade',
'exchange': parts[1],
'symbol': parts[2],
'trade_id': parts[3]
}
elif key_type == cls.METRICS_PREFIX and len(parts) >= 3:
return {
'type': 'metrics',
'exchange': parts[1],
'symbol': parts[2]
}
elif key_type == cls.STATUS_PREFIX and len(parts) >= 2:
return {
'type': 'status',
'exchange': parts[1]
}
elif key_type == cls.STATS_PREFIX and len(parts) >= 2:
return {
'type': 'stats',
'component': parts[1]
}
else:
return {'type': 'unknown', 'key': key}
@classmethod
def get_pattern(cls, key_type: str) -> str:
"""
Get Redis pattern for key type.
Args:
key_type: Type of key
Returns:
str: Redis pattern
"""
patterns = {
'orderbook': f"{cls.ORDERBOOK_PREFIX}:*",
'heatmap': f"{cls.HEATMAP_PREFIX}:*",
'trade': f"{cls.TRADE_PREFIX}:*",
'metrics': f"{cls.METRICS_PREFIX}:*",
'status': f"{cls.STATUS_PREFIX}:*",
'stats': f"{cls.STATS_PREFIX}:*"
}
return patterns.get(key_type, "*")

View File

@ -0,0 +1,355 @@
"""
Data serialization for Redis caching.
"""
import json
import pickle
import gzip
from typing import Any, Union, Dict, List
from datetime import datetime
from ..models.core import (
OrderBookSnapshot, TradeEvent, HeatmapData, PriceBuckets,
OrderBookMetrics, ImbalanceMetrics, ConsolidatedOrderBook
)
from ..utils.logging import get_logger
from ..utils.exceptions import ProcessingError
logger = get_logger(__name__)
class DataSerializer:
"""
Handles serialization and deserialization of data for Redis storage.
Supports multiple serialization formats:
- JSON for simple data
- Pickle for complex objects
- Compressed formats for large data
"""
def __init__(self, use_compression: bool = True):
"""
Initialize data serializer.
Args:
use_compression: Whether to use gzip compression
"""
self.use_compression = use_compression
self.serialization_stats = {
'serialized': 0,
'deserialized': 0,
'compression_ratio': 0.0,
'errors': 0
}
logger.info(f"Data serializer initialized (compression: {use_compression})")
def serialize(self, data: Any, format_type: str = 'auto') -> bytes:
"""
Serialize data for Redis storage.
Args:
data: Data to serialize
format_type: Serialization format ('json', 'pickle', 'auto')
Returns:
bytes: Serialized data
"""
try:
# Determine format
if format_type == 'auto':
format_type = self._determine_format(data)
# Serialize based on format
if format_type == 'json':
serialized = self._serialize_json(data)
elif format_type == 'pickle':
serialized = self._serialize_pickle(data)
else:
raise ValueError(f"Unsupported format: {format_type}")
# Apply compression if enabled
if self.use_compression:
original_size = len(serialized)
serialized = gzip.compress(serialized)
compressed_size = len(serialized)
# Update compression ratio
if original_size > 0:
ratio = compressed_size / original_size
self.serialization_stats['compression_ratio'] = (
(self.serialization_stats['compression_ratio'] *
self.serialization_stats['serialized'] + ratio) /
(self.serialization_stats['serialized'] + 1)
)
self.serialization_stats['serialized'] += 1
return serialized
except Exception as e:
self.serialization_stats['errors'] += 1
logger.error(f"Serialization error: {e}")
raise ProcessingError(f"Serialization failed: {e}", "SERIALIZE_ERROR")
def deserialize(self, data: bytes, format_type: str = 'auto') -> Any:
"""
Deserialize data from Redis storage.
Args:
data: Serialized data
format_type: Expected format ('json', 'pickle', 'auto')
Returns:
Any: Deserialized data
"""
try:
# Decompress if needed
if self.use_compression:
try:
data = gzip.decompress(data)
except gzip.BadGzipFile:
# Data might not be compressed
pass
# Determine format if auto
if format_type == 'auto':
format_type = self._detect_format(data)
# Deserialize based on format
if format_type == 'json':
result = self._deserialize_json(data)
elif format_type == 'pickle':
result = self._deserialize_pickle(data)
else:
raise ValueError(f"Unsupported format: {format_type}")
self.serialization_stats['deserialized'] += 1
return result
except Exception as e:
self.serialization_stats['errors'] += 1
logger.error(f"Deserialization error: {e}")
raise ProcessingError(f"Deserialization failed: {e}", "DESERIALIZE_ERROR")
def _determine_format(self, data: Any) -> str:
"""Determine best serialization format for data"""
# Use JSON for simple data types
if isinstance(data, (dict, list, str, int, float, bool)) or data is None:
return 'json'
# Use pickle for complex objects
return 'pickle'
def _detect_format(self, data: bytes) -> str:
"""Detect serialization format from data"""
try:
# Try JSON first
json.loads(data.decode('utf-8'))
return 'json'
except (json.JSONDecodeError, UnicodeDecodeError):
# Assume pickle
return 'pickle'
def _serialize_json(self, data: Any) -> bytes:
"""Serialize data as JSON"""
# Convert complex objects to dictionaries
if hasattr(data, '__dict__'):
data = self._object_to_dict(data)
elif isinstance(data, list):
data = [self._object_to_dict(item) if hasattr(item, '__dict__') else item
for item in data]
json_str = json.dumps(data, default=self._json_serializer, ensure_ascii=False)
return json_str.encode('utf-8')
def _deserialize_json(self, data: bytes) -> Any:
"""Deserialize JSON data"""
json_str = data.decode('utf-8')
return json.loads(json_str, object_hook=self._json_deserializer)
def _serialize_pickle(self, data: Any) -> bytes:
"""Serialize data as pickle"""
return pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
def _deserialize_pickle(self, data: bytes) -> Any:
"""Deserialize pickle data"""
return pickle.loads(data)
def _object_to_dict(self, obj: Any) -> Dict:
"""Convert object to dictionary for JSON serialization"""
if isinstance(obj, (OrderBookSnapshot, TradeEvent, HeatmapData,
PriceBuckets, OrderBookMetrics, ImbalanceMetrics,
ConsolidatedOrderBook)):
result = {
'__type__': obj.__class__.__name__,
'__data__': {}
}
# Convert object attributes
for key, value in obj.__dict__.items():
if isinstance(value, datetime):
result['__data__'][key] = {
'__datetime__': value.isoformat()
}
elif isinstance(value, list):
result['__data__'][key] = [
self._object_to_dict(item) if hasattr(item, '__dict__') else item
for item in value
]
elif hasattr(value, '__dict__'):
result['__data__'][key] = self._object_to_dict(value)
else:
result['__data__'][key] = value
return result
else:
return obj.__dict__ if hasattr(obj, '__dict__') else obj
def _json_serializer(self, obj: Any) -> Any:
"""Custom JSON serializer for special types"""
if isinstance(obj, datetime):
return {'__datetime__': obj.isoformat()}
elif hasattr(obj, '__dict__'):
return self._object_to_dict(obj)
else:
return str(obj)
def _json_deserializer(self, obj: Dict) -> Any:
"""Custom JSON deserializer for special types"""
if '__datetime__' in obj:
return datetime.fromisoformat(obj['__datetime__'])
elif '__type__' in obj and '__data__' in obj:
return self._reconstruct_object(obj['__type__'], obj['__data__'])
else:
return obj
def _reconstruct_object(self, type_name: str, data: Dict) -> Any:
"""Reconstruct object from serialized data"""
# Import required classes
from ..models.core import (
OrderBookSnapshot, TradeEvent, HeatmapData, PriceBuckets,
OrderBookMetrics, ImbalanceMetrics, ConsolidatedOrderBook,
PriceLevel, HeatmapPoint
)
# Map type names to classes
type_map = {
'OrderBookSnapshot': OrderBookSnapshot,
'TradeEvent': TradeEvent,
'HeatmapData': HeatmapData,
'PriceBuckets': PriceBuckets,
'OrderBookMetrics': OrderBookMetrics,
'ImbalanceMetrics': ImbalanceMetrics,
'ConsolidatedOrderBook': ConsolidatedOrderBook,
'PriceLevel': PriceLevel,
'HeatmapPoint': HeatmapPoint
}
if type_name in type_map:
cls = type_map[type_name]
# Recursively deserialize nested objects
processed_data = {}
for key, value in data.items():
if isinstance(value, dict) and '__datetime__' in value:
processed_data[key] = datetime.fromisoformat(value['__datetime__'])
elif isinstance(value, dict) and '__type__' in value:
processed_data[key] = self._reconstruct_object(
value['__type__'], value['__data__']
)
elif isinstance(value, list):
processed_data[key] = [
self._reconstruct_object(item['__type__'], item['__data__'])
if isinstance(item, dict) and '__type__' in item
else item
for item in value
]
else:
processed_data[key] = value
try:
return cls(**processed_data)
except Exception as e:
logger.warning(f"Failed to reconstruct {type_name}: {e}")
return processed_data
else:
logger.warning(f"Unknown type for reconstruction: {type_name}")
return data
def serialize_heatmap(self, heatmap: HeatmapData) -> bytes:
"""Specialized serialization for heatmap data"""
try:
# Create optimized representation
heatmap_dict = {
'symbol': heatmap.symbol,
'timestamp': heatmap.timestamp.isoformat(),
'bucket_size': heatmap.bucket_size,
'points': [
{
'p': point.price, # price
'v': point.volume, # volume
'i': point.intensity, # intensity
's': point.side # side
}
for point in heatmap.data
]
}
return self.serialize(heatmap_dict, 'json')
except Exception as e:
logger.error(f"Heatmap serialization error: {e}")
# Fallback to standard serialization
return self.serialize(heatmap, 'pickle')
def deserialize_heatmap(self, data: bytes) -> HeatmapData:
"""Specialized deserialization for heatmap data"""
try:
# Try optimized format first
heatmap_dict = self.deserialize(data, 'json')
if isinstance(heatmap_dict, dict) and 'points' in heatmap_dict:
from ..models.core import HeatmapData, HeatmapPoint
# Reconstruct heatmap points
points = []
for point_data in heatmap_dict['points']:
point = HeatmapPoint(
price=point_data['p'],
volume=point_data['v'],
intensity=point_data['i'],
side=point_data['s']
)
points.append(point)
# Create heatmap
heatmap = HeatmapData(
symbol=heatmap_dict['symbol'],
timestamp=datetime.fromisoformat(heatmap_dict['timestamp']),
bucket_size=heatmap_dict['bucket_size']
)
heatmap.data = points
return heatmap
else:
# Fallback to standard deserialization
return self.deserialize(data, 'pickle')
except Exception as e:
logger.error(f"Heatmap deserialization error: {e}")
# Final fallback
return self.deserialize(data, 'pickle')
def get_stats(self) -> Dict[str, Any]:
"""Get serialization statistics"""
return self.serialization_stats.copy()
def reset_stats(self) -> None:
"""Reset serialization statistics"""
self.serialization_stats = {
'serialized': 0,
'deserialized': 0,
'compression_ratio': 0.0,
'errors': 0
}
logger.info("Serialization statistics reset")

View File

@ -0,0 +1,691 @@
"""
Redis cache manager for high-performance data access.
"""
import asyncio
import redis.asyncio as redis
from typing import Any, Optional, List, Dict, Union
from datetime import datetime, timedelta
from ..config import config
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import StorageError
from ..utils.timing import get_current_timestamp
from .cache_keys import CacheKeys
from .data_serializer import DataSerializer
logger = get_logger(__name__)
class RedisManager:
"""
High-performance Redis cache manager for market data.
Provides:
- Connection pooling and management
- Data serialization and compression
- TTL management
- Batch operations
- Performance monitoring
"""
def __init__(self):
"""Initialize Redis manager"""
self.redis_pool: Optional[redis.ConnectionPool] = None
self.redis_client: Optional[redis.Redis] = None
self.serializer = DataSerializer(use_compression=True)
self.cache_keys = CacheKeys()
# Performance statistics
self.stats = {
'gets': 0,
'sets': 0,
'deletes': 0,
'hits': 0,
'misses': 0,
'errors': 0,
'total_data_size': 0,
'avg_response_time': 0.0
}
logger.info("Redis manager initialized")
async def initialize(self) -> None:
"""Initialize Redis connection pool"""
try:
# Create connection pool
self.redis_pool = redis.ConnectionPool(
host=config.redis.host,
port=config.redis.port,
password=config.redis.password,
db=config.redis.db,
max_connections=config.redis.max_connections,
socket_timeout=config.redis.socket_timeout,
socket_connect_timeout=config.redis.socket_connect_timeout,
decode_responses=False, # We handle bytes directly
retry_on_timeout=True,
health_check_interval=30
)
# Create Redis client
self.redis_client = redis.Redis(connection_pool=self.redis_pool)
# Test connection
await self.redis_client.ping()
logger.info(f"Redis connection established: {config.redis.host}:{config.redis.port}")
except Exception as e:
logger.error(f"Failed to initialize Redis connection: {e}")
raise StorageError(f"Redis initialization failed: {e}", "REDIS_INIT_ERROR")
async def close(self) -> None:
"""Close Redis connections"""
try:
if self.redis_client:
await self.redis_client.close()
if self.redis_pool:
await self.redis_pool.disconnect()
logger.info("Redis connections closed")
except Exception as e:
logger.warning(f"Error closing Redis connections: {e}")
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> bool:
"""
Set value in cache with optional TTL.
Args:
key: Cache key
value: Value to cache
ttl: Time to live in seconds (None = use default)
Returns:
bool: True if successful, False otherwise
"""
try:
set_correlation_id()
start_time = asyncio.get_event_loop().time()
# Serialize value
serialized_value = self.serializer.serialize(value)
# Determine TTL
if ttl is None:
ttl = self.cache_keys.get_ttl(key)
# Set in Redis
result = await self.redis_client.setex(key, ttl, serialized_value)
# Update statistics
self.stats['sets'] += 1
self.stats['total_data_size'] += len(serialized_value)
# Update response time
response_time = asyncio.get_event_loop().time() - start_time
self._update_avg_response_time(response_time)
logger.debug(f"Cached data: {key} (size: {len(serialized_value)} bytes, ttl: {ttl}s)")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error setting cache key {key}: {e}")
return False
async def get(self, key: str) -> Optional[Any]:
"""
Get value from cache.
Args:
key: Cache key
Returns:
Any: Cached value or None if not found
"""
try:
set_correlation_id()
start_time = asyncio.get_event_loop().time()
# Get from Redis
serialized_value = await self.redis_client.get(key)
# Update statistics
self.stats['gets'] += 1
if serialized_value is None:
self.stats['misses'] += 1
logger.debug(f"Cache miss: {key}")
return None
# Deserialize value
value = self.serializer.deserialize(serialized_value)
# Update statistics
self.stats['hits'] += 1
# Update response time
response_time = asyncio.get_event_loop().time() - start_time
self._update_avg_response_time(response_time)
logger.debug(f"Cache hit: {key} (size: {len(serialized_value)} bytes)")
return value
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error getting cache key {key}: {e}")
return None
async def delete(self, key: str) -> bool:
"""
Delete key from cache.
Args:
key: Cache key to delete
Returns:
bool: True if deleted, False otherwise
"""
try:
set_correlation_id()
result = await self.redis_client.delete(key)
self.stats['deletes'] += 1
logger.debug(f"Deleted cache key: {key}")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error deleting cache key {key}: {e}")
return False
async def exists(self, key: str) -> bool:
"""
Check if key exists in cache.
Args:
key: Cache key to check
Returns:
bool: True if exists, False otherwise
"""
try:
result = await self.redis_client.exists(key)
return bool(result)
except Exception as e:
logger.error(f"Error checking cache key existence {key}: {e}")
return False
async def expire(self, key: str, ttl: int) -> bool:
"""
Set expiration time for key.
Args:
key: Cache key
ttl: Time to live in seconds
Returns:
bool: True if successful, False otherwise
"""
try:
result = await self.redis_client.expire(key, ttl)
return bool(result)
except Exception as e:
logger.error(f"Error setting expiration for key {key}: {e}")
return False
async def mget(self, keys: List[str]) -> List[Optional[Any]]:
"""
Get multiple values from cache.
Args:
keys: List of cache keys
Returns:
List[Optional[Any]]: List of values (None for missing keys)
"""
try:
set_correlation_id()
start_time = asyncio.get_event_loop().time()
# Get from Redis
serialized_values = await self.redis_client.mget(keys)
# Deserialize values
values = []
for serialized_value in serialized_values:
if serialized_value is None:
values.append(None)
self.stats['misses'] += 1
else:
try:
value = self.serializer.deserialize(serialized_value)
values.append(value)
self.stats['hits'] += 1
except Exception as e:
logger.warning(f"Error deserializing value: {e}")
values.append(None)
self.stats['errors'] += 1
# Update statistics
self.stats['gets'] += len(keys)
# Update response time
response_time = asyncio.get_event_loop().time() - start_time
self._update_avg_response_time(response_time)
logger.debug(f"Multi-get: {len(keys)} keys, {sum(1 for v in values if v is not None)} hits")
return values
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error in multi-get: {e}")
return [None] * len(keys)
async def mset(self, key_value_pairs: Dict[str, Any], ttl: Optional[int] = None) -> bool:
"""
Set multiple key-value pairs.
Args:
key_value_pairs: Dictionary of key-value pairs
ttl: Time to live in seconds (None = use default per key)
Returns:
bool: True if successful, False otherwise
"""
try:
set_correlation_id()
# Serialize all values
serialized_pairs = {}
for key, value in key_value_pairs.items():
serialized_value = self.serializer.serialize(value)
serialized_pairs[key] = serialized_value
self.stats['total_data_size'] += len(serialized_value)
# Set in Redis
result = await self.redis_client.mset(serialized_pairs)
# Set TTL for each key if specified
if ttl is not None:
for key in key_value_pairs.keys():
await self.redis_client.expire(key, ttl)
else:
# Use individual TTLs
for key in key_value_pairs.keys():
key_ttl = self.cache_keys.get_ttl(key)
await self.redis_client.expire(key, key_ttl)
self.stats['sets'] += len(key_value_pairs)
logger.debug(f"Multi-set: {len(key_value_pairs)} keys")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error in multi-set: {e}")
return False
async def keys(self, pattern: str) -> List[str]:
"""
Get keys matching pattern.
Args:
pattern: Redis pattern (e.g., "hm:*")
Returns:
List[str]: List of matching keys
"""
try:
keys = await self.redis_client.keys(pattern)
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
except Exception as e:
logger.error(f"Error getting keys with pattern {pattern}: {e}")
return []
async def flushdb(self) -> bool:
"""
Clear all keys in current database.
Returns:
bool: True if successful, False otherwise
"""
try:
result = await self.redis_client.flushdb()
logger.info("Redis database flushed")
return bool(result)
except Exception as e:
logger.error(f"Error flushing Redis database: {e}")
return False
async def info(self) -> Dict[str, Any]:
"""
Get Redis server information.
Returns:
Dict: Redis server info
"""
try:
info = await self.redis_client.info()
return info
except Exception as e:
logger.error(f"Error getting Redis info: {e}")
return {}
async def ping(self) -> bool:
"""
Ping Redis server.
Returns:
bool: True if server responds, False otherwise
"""
try:
result = await self.redis_client.ping()
return bool(result)
except Exception as e:
logger.error(f"Redis ping failed: {e}")
return False
async def set_heatmap(self, symbol: str, heatmap_data,
exchange: Optional[str] = None, ttl: Optional[int] = None) -> bool:
"""
Cache heatmap data with optimized serialization.
Args:
symbol: Trading symbol
heatmap_data: Heatmap data to cache
exchange: Exchange name (None for consolidated)
ttl: Time to live in seconds
Returns:
bool: True if successful, False otherwise
"""
try:
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
# Use specialized heatmap serialization
serialized_value = self.serializer.serialize_heatmap(heatmap_data)
# Determine TTL
if ttl is None:
ttl = self.cache_keys.HEATMAP_TTL
# Set in Redis
result = await self.redis_client.setex(key, ttl, serialized_value)
# Update statistics
self.stats['sets'] += 1
self.stats['total_data_size'] += len(serialized_value)
logger.debug(f"Cached heatmap: {key} (size: {len(serialized_value)} bytes)")
return bool(result)
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error caching heatmap for {symbol}: {e}")
return False
async def get_heatmap(self, symbol: str, exchange: Optional[str] = None):
"""
Get cached heatmap data with optimized deserialization.
Args:
symbol: Trading symbol
exchange: Exchange name (None for consolidated)
Returns:
HeatmapData: Cached heatmap or None if not found
"""
try:
key = self.cache_keys.heatmap_key(symbol, 1.0, exchange)
# Get from Redis
serialized_value = await self.redis_client.get(key)
self.stats['gets'] += 1
if serialized_value is None:
self.stats['misses'] += 1
return None
# Use specialized heatmap deserialization
heatmap_data = self.serializer.deserialize_heatmap(serialized_value)
self.stats['hits'] += 1
logger.debug(f"Retrieved heatmap: {key}")
return heatmap_data
except Exception as e:
self.stats['errors'] += 1
logger.error(f"Error retrieving heatmap for {symbol}: {e}")
return None
async def cache_orderbook(self, orderbook) -> bool:
"""
Cache order book data.
Args:
orderbook: OrderBookSnapshot to cache
Returns:
bool: True if successful, False otherwise
"""
try:
key = self.cache_keys.orderbook_key(orderbook.symbol, orderbook.exchange)
return await self.set(key, orderbook)
except Exception as e:
logger.error(f"Error caching order book: {e}")
return False
async def get_orderbook(self, symbol: str, exchange: str):
"""
Get cached order book data.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
OrderBookSnapshot: Cached order book or None if not found
"""
try:
key = self.cache_keys.orderbook_key(symbol, exchange)
return await self.get(key)
except Exception as e:
logger.error(f"Error retrieving order book: {e}")
return None
async def cache_metrics(self, metrics, symbol: str, exchange: str) -> bool:
"""
Cache metrics data.
Args:
metrics: Metrics data to cache
symbol: Trading symbol
exchange: Exchange name
Returns:
bool: True if successful, False otherwise
"""
try:
key = self.cache_keys.metrics_key(symbol, exchange)
return await self.set(key, metrics)
except Exception as e:
logger.error(f"Error caching metrics: {e}")
return False
async def get_metrics(self, symbol: str, exchange: str):
"""
Get cached metrics data.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
Metrics data or None if not found
"""
try:
key = self.cache_keys.metrics_key(symbol, exchange)
return await self.get(key)
except Exception as e:
logger.error(f"Error retrieving metrics: {e}")
return None
async def cache_exchange_status(self, exchange: str, status_data) -> bool:
"""
Cache exchange status.
Args:
exchange: Exchange name
status_data: Status data to cache
Returns:
bool: True if successful, False otherwise
"""
try:
key = self.cache_keys.status_key(exchange)
return await self.set(key, status_data)
except Exception as e:
logger.error(f"Error caching exchange status: {e}")
return False
async def get_exchange_status(self, exchange: str):
"""
Get cached exchange status.
Args:
exchange: Exchange name
Returns:
Status data or None if not found
"""
try:
key = self.cache_keys.status_key(exchange)
return await self.get(key)
except Exception as e:
logger.error(f"Error retrieving exchange status: {e}")
return None
async def cleanup_expired_keys(self) -> int:
"""
Clean up expired keys (Redis handles this automatically, but we can force it).
Returns:
int: Number of keys cleaned up
"""
try:
# Get all keys
all_keys = await self.keys("*")
# Check which ones are expired
expired_count = 0
for key in all_keys:
ttl = await self.redis_client.ttl(key)
if ttl == -2: # Key doesn't exist (expired)
expired_count += 1
logger.debug(f"Found {expired_count} expired keys")
return expired_count
except Exception as e:
logger.error(f"Error cleaning up expired keys: {e}")
return 0
def _update_avg_response_time(self, response_time: float) -> None:
"""Update average response time"""
total_operations = self.stats['gets'] + self.stats['sets']
if total_operations > 0:
self.stats['avg_response_time'] = (
(self.stats['avg_response_time'] * (total_operations - 1) + response_time) /
total_operations
)
def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics"""
total_operations = self.stats['gets'] + self.stats['sets']
hit_rate = (self.stats['hits'] / max(self.stats['gets'], 1)) * 100
return {
**self.stats,
'total_operations': total_operations,
'hit_rate_percentage': hit_rate,
'serializer_stats': self.serializer.get_stats()
}
def reset_stats(self) -> None:
"""Reset cache statistics"""
self.stats = {
'gets': 0,
'sets': 0,
'deletes': 0,
'hits': 0,
'misses': 0,
'errors': 0,
'total_data_size': 0,
'avg_response_time': 0.0
}
self.serializer.reset_stats()
logger.info("Redis manager statistics reset")
async def health_check(self) -> Dict[str, Any]:
"""
Perform comprehensive health check.
Returns:
Dict: Health check results
"""
health = {
'redis_ping': False,
'connection_pool_size': 0,
'memory_usage': 0,
'connected_clients': 0,
'total_keys': 0,
'hit_rate': 0.0,
'avg_response_time': self.stats['avg_response_time']
}
try:
# Test ping
health['redis_ping'] = await self.ping()
# Get Redis info
info = await self.info()
if info:
health['memory_usage'] = info.get('used_memory', 0)
health['connected_clients'] = info.get('connected_clients', 0)
# Get key count
all_keys = await self.keys("*")
health['total_keys'] = len(all_keys)
# Calculate hit rate
if self.stats['gets'] > 0:
health['hit_rate'] = (self.stats['hits'] / self.stats['gets']) * 100
# Connection pool info
if self.redis_pool:
health['connection_pool_size'] = self.redis_pool.max_connections
except Exception as e:
logger.error(f"Health check error: {e}")
return health
# Global Redis manager instance
redis_manager = RedisManager()

167
COBY/config.py Normal file
View File

@ -0,0 +1,167 @@
"""
Configuration management for the multi-exchange data aggregation system.
"""
import os
from dataclasses import dataclass, field
from typing import List, Dict, Any
from pathlib import Path
@dataclass
class DatabaseConfig:
"""Database configuration settings"""
host: str = os.getenv('DB_HOST', '192.168.0.10')
port: int = int(os.getenv('DB_PORT', '5432'))
name: str = os.getenv('DB_NAME', 'market_data')
user: str = os.getenv('DB_USER', 'market_user')
password: str = os.getenv('DB_PASSWORD', 'market_data_secure_pass_2024')
schema: str = os.getenv('DB_SCHEMA', 'market_data')
pool_size: int = int(os.getenv('DB_POOL_SIZE', '10'))
max_overflow: int = int(os.getenv('DB_MAX_OVERFLOW', '20'))
pool_timeout: int = int(os.getenv('DB_POOL_TIMEOUT', '30'))
@dataclass
class RedisConfig:
"""Redis configuration settings"""
host: str = os.getenv('REDIS_HOST', '192.168.0.10')
port: int = int(os.getenv('REDIS_PORT', '6379'))
password: str = os.getenv('REDIS_PASSWORD', 'market_data_redis_2024')
db: int = int(os.getenv('REDIS_DB', '0'))
max_connections: int = int(os.getenv('REDIS_MAX_CONNECTIONS', '50'))
socket_timeout: int = int(os.getenv('REDIS_SOCKET_TIMEOUT', '5'))
socket_connect_timeout: int = int(os.getenv('REDIS_CONNECT_TIMEOUT', '5'))
@dataclass
class ExchangeConfig:
"""Exchange configuration settings"""
exchanges: List[str] = field(default_factory=lambda: [
'binance', 'coinbase', 'kraken', 'bybit', 'okx',
'huobi', 'kucoin', 'gateio', 'bitfinex', 'mexc'
])
symbols: List[str] = field(default_factory=lambda: ['BTCUSDT', 'ETHUSDT'])
max_connections_per_exchange: int = int(os.getenv('MAX_CONNECTIONS_PER_EXCHANGE', '5'))
reconnect_delay: int = int(os.getenv('RECONNECT_DELAY', '5'))
max_reconnect_attempts: int = int(os.getenv('MAX_RECONNECT_ATTEMPTS', '10'))
heartbeat_interval: int = int(os.getenv('HEARTBEAT_INTERVAL', '30'))
@dataclass
class AggregationConfig:
"""Data aggregation configuration"""
bucket_size: float = float(os.getenv('BUCKET_SIZE', '1.0')) # $1 USD buckets for all symbols
heatmap_depth: int = int(os.getenv('HEATMAP_DEPTH', '50')) # Number of price levels
update_frequency: float = float(os.getenv('UPDATE_FREQUENCY', '0.5')) # Seconds
volume_threshold: float = float(os.getenv('VOLUME_THRESHOLD', '0.01')) # Minimum volume
@dataclass
class PerformanceConfig:
"""Performance and optimization settings"""
data_buffer_size: int = int(os.getenv('DATA_BUFFER_SIZE', '10000'))
batch_write_size: int = int(os.getenv('BATCH_WRITE_SIZE', '1000'))
max_memory_usage: int = int(os.getenv('MAX_MEMORY_USAGE', '2048')) # MB
gc_threshold: float = float(os.getenv('GC_THRESHOLD', '0.8')) # 80% of max memory
processing_timeout: int = int(os.getenv('PROCESSING_TIMEOUT', '10')) # Seconds
max_queue_size: int = int(os.getenv('MAX_QUEUE_SIZE', '50000'))
@dataclass
class APIConfig:
"""API server configuration"""
host: str = os.getenv('API_HOST', '0.0.0.0')
port: int = int(os.getenv('API_PORT', '8080'))
websocket_port: int = int(os.getenv('WS_PORT', '8081'))
cors_origins: List[str] = field(default_factory=lambda: ['*'])
rate_limit: int = int(os.getenv('RATE_LIMIT', '100')) # Requests per minute
max_connections: int = int(os.getenv('MAX_WS_CONNECTIONS', '1000'))
@dataclass
class LoggingConfig:
"""Logging configuration"""
level: str = os.getenv('LOG_LEVEL', 'INFO')
format: str = os.getenv('LOG_FORMAT', '%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_path: str = os.getenv('LOG_FILE', 'logs/coby.log')
max_file_size: int = int(os.getenv('LOG_MAX_SIZE', '100')) # MB
backup_count: int = int(os.getenv('LOG_BACKUP_COUNT', '5'))
enable_correlation_id: bool = os.getenv('ENABLE_CORRELATION_ID', 'true').lower() == 'true'
@dataclass
class Config:
"""Main configuration class"""
database: DatabaseConfig = field(default_factory=DatabaseConfig)
redis: RedisConfig = field(default_factory=RedisConfig)
exchanges: ExchangeConfig = field(default_factory=ExchangeConfig)
aggregation: AggregationConfig = field(default_factory=AggregationConfig)
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
api: APIConfig = field(default_factory=APIConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
# Environment
environment: str = os.getenv('ENVIRONMENT', 'development')
debug: bool = os.getenv('DEBUG', 'false').lower() == 'true'
def __post_init__(self):
"""Post-initialization validation and setup"""
# Create logs directory if it doesn't exist
log_dir = Path(self.logging.file_path).parent
log_dir.mkdir(parents=True, exist_ok=True)
# Validate bucket sizes
if self.aggregation.btc_bucket_size <= 0:
raise ValueError("BTC bucket size must be positive")
if self.aggregation.eth_bucket_size <= 0:
raise ValueError("ETH bucket size must be positive")
def get_bucket_size(self, symbol: str = None) -> float:
"""Get bucket size (now universal $1 for all symbols)"""
return self.aggregation.bucket_size
def get_database_url(self) -> str:
"""Get database connection URL"""
return (f"postgresql://{self.database.user}:{self.database.password}"
f"@{self.database.host}:{self.database.port}/{self.database.name}")
def get_redis_url(self) -> str:
"""Get Redis connection URL"""
auth = f":{self.redis.password}@" if self.redis.password else ""
return f"redis://{auth}{self.redis.host}:{self.redis.port}/{self.redis.db}"
def to_dict(self) -> Dict[str, Any]:
"""Convert configuration to dictionary"""
return {
'database': {
'host': self.database.host,
'port': self.database.port,
'name': self.database.name,
'schema': self.database.schema,
},
'redis': {
'host': self.redis.host,
'port': self.redis.port,
'db': self.redis.db,
},
'exchanges': {
'count': len(self.exchanges.exchanges),
'symbols': self.exchanges.symbols,
},
'aggregation': {
'bucket_size': self.aggregation.bucket_size,
'heatmap_depth': self.aggregation.heatmap_depth,
},
'api': {
'host': self.api.host,
'port': self.api.port,
'websocket_port': self.api.websocket_port,
},
'environment': self.environment,
'debug': self.debug,
}
# Global configuration instance
config = Config()

View File

@ -0,0 +1,33 @@
"""
Exchange connector implementations for the COBY system.
"""
from .base_connector import BaseExchangeConnector
from .binance_connector import BinanceConnector
from .coinbase_connector import CoinbaseConnector
from .kraken_connector import KrakenConnector
from .bybit_connector import BybitConnector
from .okx_connector import OKXConnector
from .huobi_connector import HuobiConnector
from .kucoin_connector import KuCoinConnector
from .gateio_connector import GateIOConnector
from .bitfinex_connector import BitfinexConnector
from .mexc_connector import MEXCConnector
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker
__all__ = [
'BaseExchangeConnector',
'BinanceConnector',
'CoinbaseConnector',
'KrakenConnector',
'BybitConnector',
'OKXConnector',
'HuobiConnector',
'KuCoinConnector',
'GateIOConnector',
'BitfinexConnector',
'MEXCConnector',
'ConnectionManager',
'CircuitBreaker'
]

View File

@ -0,0 +1,383 @@
"""
Base exchange connector with WebSocket connection management, circuit breaker pattern,
and comprehensive error handling.
"""
import asyncio
import logging
import json
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Callable, Any
from datetime import datetime, timedelta
from enmodels.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ConnectionError, ValidationError
from ..utils.timing import get_current_timestamp
from .connection_manager import ConnectionManager
from .circuit_breaker import CircuitBreaker, CircuitBreakerOpenError
logger = get_logger(__name__)
class BaseExchangeConnector(ExchangeConnector):
"""
Base implementation of exchange connector with common functionality.
Provides:
- WebSocket connection management
- Exponential backoff retry logic
- Circuit breaker pattern
- Health monitoring
- Message handling framework
- Subscription management
"""
def __init__(self, exchange_name: str, websocket_url: str):
"""
Initialize base exchange connector.
Args:
exchange_name: Name of the exchange
websocket_url: WebSocket URL for the exchange
"""
super().__init__(exchange_name)
self.websocket_url = websocket_url
self.websocket: Optional[websockets.WebSocketServerProtocol] = None
self.subscriptions: Dict[str, List[str]] = {} # symbol -> [subscription_types]
self.message_handlers: Dict[str, Callable] = {}
# Connection management
self.connection_manager = ConnectionManager(
name=f"{exchange_name}_connector",
max_retries=10,
initial_delay=1.0,
max_delay=300.0,
health_check_interval=30
)
# Circuit breaker
self.circuit_breaker = CircuitBreaker(
failure_threshold=5,
recovery_timeout=60,
expected_exception=Exception,
name=f"{exchange_name}_circuit"
)
# Statistics
self.message_count = 0
self.error_count = 0
self.last_message_time: Optional[datetime] = None
# Setup callbacks
self.connection_manager.on_connect = self._on_connect
self.connection_manager.on_disconnect = self._on_disconnect
self.connection_manager.on_error = self._on_error
self.connection_manager.on_health_check = self._health_check
# Message processing
self._message_queue = asyncio.Queue(maxsize=10000)
self._message_processor_task: Optional[asyncio.Task] = None
logger.info(f"Base connector initialized for {exchange_name}")
async def connect(self) -> bool:
"""Establish connection to the exchange WebSocket"""
try:
set_correlation_id()
logger.info(f"Connecting to {self.exchange_name} at {self.websocket_url}")
return await self.connection_manager.connect(self._establish_websocket_connection)
except Exception as e:
logger.error(f"Failed to connect to {self.exchange_name}: {e}")
self._notify_status_callbacks(ConnectionStatus.ERROR)
return False
async def disconnect(self) -> None:
"""Disconnect from the exchange WebSocket"""
try:
set_correlation_id()
logger.info(f"Disconnecting from {self.exchange_name}")
await self.connection_manager.disconnect(self._close_websocket_connection)
except Exception as e:
logger.error(f"Error during disconnect from {self.exchange_name}: {e}")
async def _establish_websocket_connection(self) -> None:
"""Establish WebSocket connection"""
try:
# Use circuit breaker for connection
self.websocket = await self.circuit_breaker.call_async(
websockets.connect,
self.websocket_url,
ping_interval=20,
ping_timeout=10,
close_timeout=10
)
logger.info(f"WebSocket connected to {self.exchange_name}")
# Start message processing
await self._start_message_processing()
except CircuitBreakerOpenError as e:
logger.error(f"Circuit breaker open for {self.exchange_name}: {e}")
raise ConnectionError(f"Circuit breaker open: {e}", "CIRCUIT_BREAKER_OPEN")
except Exception as e:
logger.error(f"WebSocket connection failed for {self.exchange_name}: {e}")
raise ConnectionError(f"WebSocket connection failed: {e}", "WEBSOCKET_CONNECT_FAILED")
async def _close_websocket_connection(self) -> None:
"""Close WebSocket connection"""
try:
# Stop message processing
await self._stop_message_processing()
# Close WebSocket
if self.websocket:
await self.websocket.close()
self.websocket = None
logger.info(f"WebSocket disconnected from {self.exchange_name}")
except Exception as e:
logger.warning(f"Error closing WebSocket for {self.exchange_name}: {e}")
async def _start_message_processing(self) -> None:
"""Start message processing tasks"""
if self._message_processor_task:
return
# Start message processor
self._message_processor_task = asyncio.create_task(self._message_processor())
# Start message receiver
asyncio.create_task(self._message_receiver())
logger.debug(f"Message processing started for {self.exchange_name}")
async def _stop_message_processing(self) -> None:
"""Stop message processing tasks"""
if self._message_processor_task:
self._message_processor_task.cancel()
try:
await self._message_processor_task
except asyncio.CancelledError:
pass
self._message_processor_task = None
logger.debug(f"Message processing stopped for {self.exchange_name}")
async def _message_receiver(self) -> None:
"""Receive messages from WebSocket"""
try:
while self.websocket and not self.websocket.closed:
try:
message = await asyncio.wait_for(self.websocket.recv(), timeout=30.0)
# Queue message for processing
try:
self._message_queue.put_nowait(message)
except asyncio.QueueFull:
logger.warning(f"Message queue full for {self.exchange_name}, dropping message")
except asyncio.TimeoutError:
# Send ping to keep connection alive
if self.websocket:
await self.websocket.ping()
except websockets.exceptions.ConnectionClosed:
logger.warning(f"WebSocket connection closed for {self.exchange_name}")
break
except Exception as e:
logger.error(f"Error receiving message from {self.exchange_name}: {e}")
self.error_count += 1
break
except Exception as e:
logger.error(f"Message receiver error for {self.exchange_name}: {e}")
finally:
# Mark as disconnected
self.connection_manager.is_connected = False
async def _message_processor(self) -> None:
"""Process messages from the queue"""
while True:
try:
# Get message from queue
message = await self._message_queue.get()
# Process message
await self._process_message(message)
# Update statistics
self.message_count += 1
self.last_message_time = get_current_timestamp()
# Mark task as done
self._message_queue.task_done()
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error processing message for {self.exchange_name}: {e}")
self.error_count += 1
async def _process_message(self, message: str) -> None:
"""
Process incoming WebSocket message.
Args:
message: Raw message string
"""
try:
# Parse JSON message
data = json.loads(message)
# Determine message type and route to appropriate handler
message_type = self._get_message_type(data)
if message_type in self.message_handlers:
await self.message_handlers[message_type](data)
else:
logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
except Exception as e:
logger.error(f"Error processing message from {self.exchange_name}: {e}")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from message data.
Override in subclasses for exchange-specific logic.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Default implementation - override in subclasses
return data.get('type', 'unknown')
async def _send_message(self, message: Dict) -> bool:
"""
Send message to WebSocket.
Args:
message: Message to send
Returns:
bool: True if sent successfully, False otherwise
"""
try:
if not self.websocket or self.websocket.closed:
logger.warning(f"Cannot send message to {self.exchange_name}: not connected")
return False
message_str = json.dumps(message)
await self.websocket.send(message_str)
logger.debug(f"Sent message to {self.exchange_name}: {message_str[:100]}...")
return True
except Exception as e:
logger.error(f"Error sending message to {self.exchange_name}: {e}")
return False
# Callback handlers
async def _on_connect(self) -> None:
"""Handle successful connection"""
self._notify_status_callbacks(ConnectionStatus.CONNECTED)
# Resubscribe to all previous subscriptions
await self._resubscribe_all()
async def _on_disconnect(self) -> None:
"""Handle disconnection"""
self._notify_status_callbacks(ConnectionStatus.DISCONNECTED)
async def _on_error(self, error: Exception) -> None:
"""Handle connection error"""
logger.error(f"Connection error for {self.exchange_name}: {error}")
self._notify_status_callbacks(ConnectionStatus.ERROR)
async def _health_check(self) -> bool:
"""Perform health check"""
try:
if not self.websocket or self.websocket.closed:
return False
# Check if we've received messages recently
if self.last_message_time:
time_since_last_message = (get_current_timestamp() - self.last_message_time).total_seconds()
if time_since_last_message > 60: # No messages for 60 seconds
logger.warning(f"No messages received from {self.exchange_name} for {time_since_last_message}s")
return False
# Send ping
await self.websocket.ping()
return True
except Exception as e:
logger.error(f"Health check failed for {self.exchange_name}: {e}")
return False
async def _resubscribe_all(self) -> None:
"""Resubscribe to all previous subscriptions after reconnection"""
for symbol, subscription_types in self.subscriptions.items():
for sub_type in subscription_types:
try:
if sub_type == 'orderbook':
await self.subscribe_orderbook(symbol)
elif sub_type == 'trades':
await self.subscribe_trades(symbol)
except Exception as e:
logger.error(f"Failed to resubscribe to {sub_type} for {symbol}: {e}")
# Abstract methods that must be implemented by subclasses
async def subscribe_orderbook(self, symbol: str) -> None:
"""Subscribe to order book updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement subscribe_orderbook")
async def subscribe_trades(self, symbol: str) -> None:
"""Subscribe to trade updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement subscribe_trades")
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""Unsubscribe from order book updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement unsubscribe_orderbook")
async def unsubscribe_trades(self, symbol: str) -> None:
"""Unsubscribe from trade updates - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement unsubscribe_trades")
async def get_symbols(self) -> List[str]:
"""Get available symbols - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement get_symbols")
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol format - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement normalize_symbol")
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""Get order book snapshot - must be implemented by subclasses"""
raise NotImplementedError("Subclasses must implement get_orderbook_snapshot")
# Utility methods
def get_stats(self) -> Dict[str, Any]:
"""Get connector statistics"""
return {
'exchange': self.exchange_name,
'connection_status': self.get_connection_status().value,
'is_connected': self.is_connected,
'message_count': self.message_count,
'error_count': self.error_count,
'last_message_time': self.last_message_time.isoformat() if self.last_message_time else None,
'subscriptions': dict(self.subscriptions),
'connection_manager': self.connection_manager.get_stats(),
'circuit_breaker': self.circuit_breaker.get_stats(),
'queue_size': self._message_queue.qsize()
}

View File

@ -0,0 +1,489 @@
"""
Binance exchange connector implementation.
"""
import json
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class BinanceConnector(BaseExchangeConnector):
"""
Binance WebSocket connector implementation.
Supports:
- Order book depth streams
- Trade streams
- Symbol normalization
- Real-time data processing
"""
# Binance WebSocket URLs
WEBSOCKET_URL = "wss://stream.binance.com:9443/ws"
API_URL = "https://api.binance.com/api/v3"
def __init__(self):
"""Initialize Binance connector"""
super().__init__("binance", self.WEBSOCKET_URL)
# Binance-specific message handlers
self.message_handlers.update({
'depthUpdate': self._handle_orderbook_update,
'trade': self._handle_trade_update,
'error': self._handle_error_message
})
# Stream management
self.active_streams: List[str] = []
self.stream_id = 1
logger.info("Binance connector initialized")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from Binance message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Binance uses 'e' field for event type
if 'e' in data:
return data['e']
# Handle error messages
if 'error' in data:
return 'error'
# Handle subscription confirmations
if 'result' in data and 'id' in data:
return 'subscription_response'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to Binance format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: Binance symbol format (e.g., 'BTCUSDT')
"""
# Binance uses uppercase symbols without separators
normalized = symbol.upper().replace('-', '').replace('/', '')
# Validate symbol format
if not validate_symbol(normalized):
raise ValidationError(f"Invalid symbol format: {symbol}", "INVALID_SYMBOL")
return normalized
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book depth updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
normalized_symbol = self.normalize_symbol(symbol)
stream_name = f"{normalized_symbol.lower()}@depth@100ms"
# Create subscription message
subscription_msg = {
"method": "SUBSCRIBE",
"params": [stream_name],
"id": self.stream_id
}
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.active_streams.append(stream_name)
self.stream_id += 1
logger.info(f"Subscribed to order book for {symbol} on Binance")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on Binance")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
normalized_symbol = self.normalize_symbol(symbol)
stream_name = f"{normalized_symbol.lower()}@trade"
# Create subscription message
subscription_msg = {
"method": "SUBSCRIBE",
"params": [stream_name],
"id": self.stream_id
}
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.active_streams.append(stream_name)
self.stream_id += 1
logger.info(f"Subscribed to trades for {symbol} on Binance")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on Binance")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
normalized_symbol = self.normalize_symbol(symbol)
stream_name = f"{normalized_symbol.lower()}@depth@100ms"
# Create unsubscription message
unsubscription_msg = {
"method": "UNSUBSCRIBE",
"params": [stream_name],
"id": self.stream_id
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
if stream_name in self.active_streams:
self.active_streams.remove(stream_name)
self.stream_id += 1
logger.info(f"Unsubscribed from order book for {symbol} on Binance")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on Binance")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
normalized_symbol = self.normalize_symbol(symbol)
stream_name = f"{normalized_symbol.lower()}@trade"
# Create unsubscription message
unsubscription_msg = {
"method": "UNSUBSCRIBE",
"params": [stream_name],
"id": self.stream_id
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
if stream_name in self.active_streams:
self.active_streams.remove(stream_name)
self.stream_id += 1
logger.info(f"Unsubscribed from trades for {symbol} on Binance")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on Binance")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from Binance.
Returns:
List[str]: List of available symbols
"""
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.API_URL}/exchangeInfo") as response:
if response.status == 200:
data = await response.json()
symbols = [
symbol_info['symbol']
for symbol_info in data.get('symbols', [])
if symbol_info.get('status') == 'TRADING'
]
logger.info(f"Retrieved {len(symbols)} symbols from Binance")
return symbols
else:
logger.error(f"Failed to get symbols from Binance: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from Binance: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from Binance REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
normalized_symbol = self.normalize_symbol(symbol)
# Binance supports depths: 5, 10, 20, 50, 100, 500, 1000, 5000
valid_depths = [5, 10, 20, 50, 100, 500, 1000, 5000]
api_depth = min(valid_depths, key=lambda x: abs(x - depth))
url = f"{self.API_URL}/depth"
params = {
'symbol': normalized_symbol,
'limit': api_depth
}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
return self._parse_orderbook_snapshot(data, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse Binance order book data into OrderBookSnapshot.
Args:
data: Raw Binance order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc),
bids=bids,
asks=asks,
sequence_id=data.get('lastUpdateId')
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""
Handle order book depth update from Binance.
Args:
data: Order book update data
"""
try:
set_correlation_id()
# Extract symbol from stream name
stream = data.get('s', '').upper()
if not stream:
logger.warning("Order book update missing symbol")
return
# Parse bids and asks
bids = []
for bid_data in data.get('b', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('a', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=stream,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(data.get('E', 0) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=data.get('u') # Final update ID
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {stream}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: Dict) -> None:
"""
Handle trade update from Binance.
Args:
data: Trade update data
"""
try:
set_correlation_id()
# Extract trade data
symbol = data.get('s', '').upper()
if not symbol:
logger.warning("Trade update missing symbol")
return
price = float(data.get('p', 0))
size = float(data.get('q', 0))
# Validate data
if not validate_price(price) or not validate_volume(size):
logger.warning(f"Invalid trade data: price={price}, size={size}")
return
# Determine side (Binance uses 'm' field - true if buyer is market maker)
is_buyer_maker = data.get('m', False)
side = 'sell' if is_buyer_maker else 'buy'
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(data.get('T', 0) / 1000, tz=timezone.utc),
price=price,
size=size,
side=side,
trade_id=str(data.get('t', ''))
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {side} {size} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_error_message(self, data: Dict) -> None:
"""
Handle error message from Binance.
Args:
data: Error message data
"""
error_code = data.get('code', 'unknown')
error_msg = data.get('msg', 'Unknown error')
logger.error(f"Binance error {error_code}: {error_msg}")
# Handle specific error codes
if error_code == -1121: # Invalid symbol
logger.error("Invalid symbol error - check symbol format")
elif error_code == -1130: # Invalid listen key
logger.error("Invalid listen key - may need to reconnect")
def get_binance_stats(self) -> Dict[str, Any]:
"""Get Binance-specific statistics"""
base_stats = self.get_stats()
binance_stats = {
'active_streams': len(self.active_streams),
'stream_list': self.active_streams.copy(),
'next_stream_id': self.stream_id
}
base_stats.update(binance_stats)
return base_stats

View File

@ -0,0 +1,270 @@
"""
Bitfinex exchange connector implementation.
Supports WebSocket connections to Bitfinex with proper channel subscription management.
"""
import json
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class BitfinexConnector(BaseExchangeConnector):
"""
Bitfinex WebSocket connector implementation.
Supports:
- Channel subscription management
- Order book streams
- Trade streams
- Symbol normalization
"""
# Bitfinex WebSocket URLs
WEBSOCKET_URL = "wss://api-pub.bitfinex.com/ws/2"
API_URL = "https://api-pub.bitfinex.com"
def __init__(self, api_key: str = None, api_secret: str = None):
"""Initialize Bitfinex connector."""
super().__init__("bitfinex", self.WEBSOCKET_URL)
self.api_key = api_key
self.api_secret = api_secret
# Bitfinex-specific message handlers
self.message_handlers.update({
'subscribed': self._handle_subscription_response,
'unsubscribed': self._handle_unsubscription_response,
'error': self._handle_error_message,
'info': self._handle_info_message
})
# Channel management
self.channels = {} # channel_id -> channel_info
self.subscribed_symbols = set()
logger.info("Bitfinex connector initialized")
def _get_message_type(self, data) -> str:
"""Determine message type from Bitfinex message data."""
if isinstance(data, dict):
if 'event' in data:
return data['event']
elif 'error' in data:
return 'error'
elif isinstance(data, list) and len(data) >= 2:
# Data message format: [CHANNEL_ID, data]
return 'data'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol to Bitfinex format."""
# Bitfinex uses 't' prefix for trading pairs
if symbol.upper() == 'BTCUSDT':
return 'tBTCUSD'
elif symbol.upper() == 'ETHUSDT':
return 'tETHUSD'
elif symbol.upper().endswith('USDT'):
base = symbol[:-4].upper()
return f"t{base}USD"
else:
# Generic conversion
normalized = symbol.upper().replace('-', '').replace('/', '')
return f"t{normalized}" if not normalized.startswith('t') else normalized
def _denormalize_symbol(self, bitfinex_symbol: str) -> str:
"""Convert Bitfinex symbol back to standard format."""
if bitfinex_symbol.startswith('t'):
symbol = bitfinex_symbol[1:] # Remove 't' prefix
if symbol.endswith('USD'):
return symbol[:-3] + 'USDT'
return symbol
return bitfinex_symbol
async def subscribe_orderbook(self, symbol: str) -> None:
"""Subscribe to order book updates for a symbol."""
try:
set_correlation_id()
bitfinex_symbol = self.normalize_symbol(symbol)
subscription_msg = {
"event": "subscribe",
"channel": "book",
"symbol": bitfinex_symbol,
"prec": "P0",
"freq": "F0",
"len": "25"
}
success = await self._send_message(subscription_msg)
if success:
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_symbols.add(bitfinex_symbol)
logger.info(f"Subscribed to order book for {symbol} ({bitfinex_symbol}) on Bitfinex")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on Bitfinex")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""Subscribe to trade updates for a symbol."""
try:
set_correlation_id()
bitfinex_symbol = self.normalize_symbol(symbol)
subscription_msg = {
"event": "subscribe",
"channel": "trades",
"symbol": bitfinex_symbol
}
success = await self._send_message(subscription_msg)
if success:
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_symbols.add(bitfinex_symbol)
logger.info(f"Subscribed to trades for {symbol} ({bitfinex_symbol}) on Bitfinex")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on Bitfinex")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""Unsubscribe from order book updates."""
# Implementation would find the channel ID and send unsubscribe message
pass
async def unsubscribe_trades(self, symbol: str) -> None:
"""Unsubscribe from trade updates."""
# Implementation would find the channel ID and send unsubscribe message
pass
async def get_symbols(self) -> List[str]:
"""Get available symbols from Bitfinex."""
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.API_URL}/v1/symbols") as response:
if response.status == 200:
data = await response.json()
symbols = [self._denormalize_symbol(f"t{s.upper()}") for s in data]
logger.info(f"Retrieved {len(symbols)} symbols from Bitfinex")
return symbols
else:
logger.error(f"Failed to get symbols from Bitfinex: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from Bitfinex: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""Get order book snapshot from Bitfinex REST API."""
try:
import aiohttp
bitfinex_symbol = self.normalize_symbol(symbol)
url = f"{self.API_URL}/v2/book/{bitfinex_symbol}/P0"
params = {'len': min(depth, 100)}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
return self._parse_orderbook_snapshot(data, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: List, symbol: str) -> OrderBookSnapshot:
"""Parse Bitfinex order book data."""
try:
bids = []
asks = []
for level in data:
price = float(level[0])
count = int(level[1])
amount = float(level[2])
if validate_price(price) and validate_volume(abs(amount)):
if amount > 0:
bids.append(PriceLevel(price=price, size=amount))
else:
asks.append(PriceLevel(price=price, size=abs(amount)))
return OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc),
bids=bids,
asks=asks
)
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_subscription_response(self, data: Dict) -> None:
"""Handle subscription response."""
channel_id = data.get('chanId')
channel = data.get('channel')
symbol = data.get('symbol', '')
if channel_id:
self.channels[channel_id] = {
'channel': channel,
'symbol': symbol
}
logger.info(f"Bitfinex subscription confirmed: {channel} for {symbol} (ID: {channel_id})")
async def _handle_unsubscription_response(self, data: Dict) -> None:
"""Handle unsubscription response."""
channel_id = data.get('chanId')
if channel_id in self.channels:
del self.channels[channel_id]
logger.info(f"Bitfinex unsubscription confirmed for channel {channel_id}")
async def _handle_error_message(self, data: Dict) -> None:
"""Handle error message."""
error_msg = data.get('msg', 'Unknown error')
error_code = data.get('code', 'unknown')
logger.error(f"Bitfinex error {error_code}: {error_msg}")
async def _handle_info_message(self, data: Dict) -> None:
"""Handle info message."""
logger.info(f"Bitfinex info: {data}")
def get_bitfinex_stats(self) -> Dict[str, Any]:
"""Get Bitfinex-specific statistics."""
base_stats = self.get_stats()
bitfinex_stats = {
'active_channels': len(self.channels),
'subscribed_symbols': list(self.subscribed_symbols),
'authenticated': bool(self.api_key and self.api_secret)
}
base_stats.update(bitfinex_stats)
return base_stats

View File

@ -0,0 +1,605 @@
"""
Bybit exchange connector implementation.
Supports WebSocket connections to Bybit with unified trading account support.
"""
import json
import hmac
import hashlib
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class BybitConnector(BaseExchangeConnector):
"""
Bybit WebSocket connector implementation.
Supports:
- Unified Trading Account (UTA) WebSocket streams
- Order book streams
- Trade streams
- Symbol normalization
- Authentication for private channels
"""
# Bybit WebSocket URLs
WEBSOCKET_URL = "wss://stream.bybit.com/v5/public/spot"
WEBSOCKET_PRIVATE_URL = "wss://stream.bybit.com/v5/private"
TESTNET_URL = "wss://stream-testnet.bybit.com/v5/public/spot"
API_URL = "https://api.bybit.com"
def __init__(self, use_testnet: bool = False, api_key: str = None, api_secret: str = None):
"""
Initialize Bybit connector.
Args:
use_testnet: Whether to use testnet environment
api_key: API key for authentication (optional)
api_secret: API secret for authentication (optional)
"""
websocket_url = self.TESTNET_URL if use_testnet else self.WEBSOCKET_URL
super().__init__("bybit", websocket_url)
# Authentication credentials (optional)
self.api_key = api_key
self.api_secret = api_secret
self.use_testnet = use_testnet
# Bybit-specific message handlers
self.message_handlers.update({
'orderbook': self._handle_orderbook_update,
'publicTrade': self._handle_trade_update,
'pong': self._handle_pong,
'subscribe': self._handle_subscription_response
})
# Subscription tracking
self.subscribed_topics = set()
self.req_id = 1
logger.info(f"Bybit connector initialized ({'testnet' if use_testnet else 'mainnet'})")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from Bybit message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Bybit V5 API message format
if 'topic' in data:
topic = data['topic']
if 'orderbook' in topic:
return 'orderbook'
elif 'publicTrade' in topic:
return 'publicTrade'
else:
return topic
elif 'op' in data:
return data['op'] # 'subscribe', 'unsubscribe', 'ping', 'pong'
elif 'success' in data:
return 'response'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to Bybit format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: Bybit symbol format (e.g., 'BTCUSDT')
"""
# Bybit uses uppercase symbols without separators (same as Binance)
normalized = symbol.upper().replace('-', '').replace('/', '')
# Validate symbol format
if not validate_symbol(normalized):
raise ValidationError(f"Invalid symbol format: {symbol}", "INVALID_SYMBOL")
return normalized
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
normalized_symbol = self.normalize_symbol(symbol)
topic = f"orderbook.50.{normalized_symbol}"
# Create subscription message
subscription_msg = {
"op": "subscribe",
"args": [topic],
"req_id": str(self.req_id)
}
self.req_id += 1
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_topics.add(topic)
logger.info(f"Subscribed to order book for {symbol} on Bybit")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on Bybit")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
normalized_symbol = self.normalize_symbol(symbol)
topic = f"publicTrade.{normalized_symbol}"
# Create subscription message
subscription_msg = {
"op": "subscribe",
"args": [topic],
"req_id": str(self.req_id)
}
self.req_id += 1
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_topics.add(topic)
logger.info(f"Subscribed to trades for {symbol} on Bybit")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on Bybit")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
normalized_symbol = self.normalize_symbol(symbol)
topic = f"orderbook.50.{normalized_symbol}"
# Create unsubscription message
unsubscription_msg = {
"op": "unsubscribe",
"args": [topic],
"req_id": str(self.req_id)
}
self.req_id += 1
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_topics.discard(topic)
logger.info(f"Unsubscribed from order book for {symbol} on Bybit")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on Bybit")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
normalized_symbol = self.normalize_symbol(symbol)
topic = f"publicTrade.{normalized_symbol}"
# Create unsubscription message
unsubscription_msg = {
"op": "unsubscribe",
"args": [topic],
"req_id": str(self.req_id)
}
self.req_id += 1
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_topics.discard(topic)
logger.info(f"Unsubscribed from trades for {symbol} on Bybit")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on Bybit")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from Bybit.
Returns:
List[str]: List of available symbols
"""
try:
import aiohttp
api_url = "https://api-testnet.bybit.com" if self.use_testnet else self.API_URL
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/v5/market/instruments-info",
params={"category": "spot"}) as response:
if response.status == 200:
data = await response.json()
if data.get('retCode') != 0:
logger.error(f"Bybit API error: {data.get('retMsg')}")
return []
symbols = []
instruments = data.get('result', {}).get('list', [])
for instrument in instruments:
if instrument.get('status') == 'Trading':
symbol = instrument.get('symbol', '')
symbols.append(symbol)
logger.info(f"Retrieved {len(symbols)} symbols from Bybit")
return symbols
else:
logger.error(f"Failed to get symbols from Bybit: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from Bybit: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from Bybit REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
normalized_symbol = self.normalize_symbol(symbol)
api_url = "https://api-testnet.bybit.com" if self.use_testnet else self.API_URL
# Bybit supports depths: 1, 25, 50, 100, 200
valid_depths = [1, 25, 50, 100, 200]
api_depth = min(valid_depths, key=lambda x: abs(x - depth))
url = f"{api_url}/v5/market/orderbook"
params = {
'category': 'spot',
'symbol': normalized_symbol,
'limit': api_depth
}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
if data.get('retCode') != 0:
logger.error(f"Bybit API error: {data.get('retMsg')}")
return None
result = data.get('result', {})
return self._parse_orderbook_snapshot(result, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse Bybit order book data into OrderBookSnapshot.
Args:
data: Raw Bybit order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('b', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('a', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=int(data.get('u', 0))
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""
Handle order book update from Bybit.
Args:
data: Order book update data
"""
try:
set_correlation_id()
# Extract symbol from topic
topic = data.get('topic', '')
if not topic.startswith('orderbook'):
logger.warning("Invalid orderbook topic")
return
# Extract symbol from topic: orderbook.50.BTCUSDT
parts = topic.split('.')
if len(parts) < 3:
logger.warning("Invalid orderbook topic format")
return
symbol = parts[2]
orderbook_data = data.get('data', {})
# Parse bids and asks
bids = []
for bid_data in orderbook_data.get('b', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in orderbook_data.get('a', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=int(orderbook_data.get('u', 0))
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {symbol}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: Dict) -> None:
"""
Handle trade update from Bybit.
Args:
data: Trade update data
"""
try:
set_correlation_id()
# Extract symbol from topic
topic = data.get('topic', '')
if not topic.startswith('publicTrade'):
logger.warning("Invalid trade topic")
return
# Extract symbol from topic: publicTrade.BTCUSDT
parts = topic.split('.')
if len(parts) < 2:
logger.warning("Invalid trade topic format")
return
symbol = parts[1]
trades_data = data.get('data', [])
# Process each trade
for trade_data in trades_data:
price = float(trade_data.get('p', 0))
size = float(trade_data.get('v', 0))
# Validate data
if not validate_price(price) or not validate_volume(size):
logger.warning(f"Invalid trade data: price={price}, size={size}")
continue
# Determine side (Bybit uses 'S' field)
side_flag = trade_data.get('S', '')
side = 'buy' if side_flag == 'Buy' else 'sell'
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(trade_data.get('T', 0)) / 1000, tz=timezone.utc),
price=price,
size=size,
side=side,
trade_id=str(trade_data.get('i', ''))
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {side} {size} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_subscription_response(self, data: Dict) -> None:
"""
Handle subscription response from Bybit.
Args:
data: Subscription response data
"""
try:
success = data.get('success', False)
req_id = data.get('req_id', '')
op = data.get('op', '')
if success:
logger.info(f"Bybit {op} successful (req_id: {req_id})")
else:
ret_msg = data.get('ret_msg', 'Unknown error')
logger.error(f"Bybit {op} failed: {ret_msg} (req_id: {req_id})")
except Exception as e:
logger.error(f"Error handling subscription response: {e}")
async def _handle_pong(self, data: Dict) -> None:
"""
Handle pong response from Bybit.
Args:
data: Pong response data
"""
logger.debug("Received Bybit pong")
def _get_auth_signature(self, timestamp: str, recv_window: str = "5000") -> str:
"""
Generate authentication signature for Bybit.
Args:
timestamp: Current timestamp
recv_window: Receive window
Returns:
str: Authentication signature
"""
if not self.api_key or not self.api_secret:
return ""
try:
param_str = f"GET/realtime{timestamp}{self.api_key}{recv_window}"
signature = hmac.new(
self.api_secret.encode('utf-8'),
param_str.encode('utf-8'),
hashlib.sha256
).hexdigest()
return signature
except Exception as e:
logger.error(f"Error generating auth signature: {e}")
return ""
async def _send_ping(self) -> None:
"""Send ping to keep connection alive."""
try:
ping_msg = {
"op": "ping",
"req_id": str(self.req_id)
}
self.req_id += 1
await self._send_message(ping_msg)
logger.debug("Sent ping to Bybit")
except Exception as e:
logger.error(f"Error sending ping: {e}")
def get_bybit_stats(self) -> Dict[str, Any]:
"""Get Bybit-specific statistics."""
base_stats = self.get_stats()
bybit_stats = {
'subscribed_topics': list(self.subscribed_topics),
'use_testnet': self.use_testnet,
'authenticated': bool(self.api_key and self.api_secret),
'next_req_id': self.req_id
}
base_stats.update(bybit_stats)
return base_stats

View File

@ -0,0 +1,206 @@
"""
Circuit breaker pattern implementation for exchange connections.
"""
import time
from enum import Enum
from typing import Optional, Callable, Any
from ..utils.logging import get_logger
logger = get_logger(__name__)
class CircuitState(Enum):
"""Circuit breaker states"""
CLOSED = "closed" # Normal operation
OPEN = "open" # Circuit is open, calls fail fast
HALF_OPEN = "half_open" # Testing if service is back
class CircuitBreaker:
"""
Circuit breaker to prevent cascading failures in exchange connections.
States:
- CLOSED: Normal operation, requests pass through
- OPEN: Circuit is open, requests fail immediately
- HALF_OPEN: Testing if service is back, limited requests allowed
"""
def __init__(
self,
failure_threshold: int = 5,
recovery_timeout: int = 60,
expected_exception: type = Exception,
name: str = "CircuitBreaker"
):
"""
Initialize circuit breaker.
Args:
failure_threshold: Number of failures before opening circuit
recovery_timeout: Time in seconds before attempting recovery
expected_exception: Exception type that triggers circuit breaker
name: Name for logging purposes
"""
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout
self.expected_exception = expected_exception
self.name = name
# State tracking
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time: Optional[float] = None
self._next_attempt_time: Optional[float] = None
logger.info(f"Circuit breaker '{name}' initialized with threshold={failure_threshold}")
@property
def state(self) -> CircuitState:
"""Get current circuit state"""
return self._state
@property
def failure_count(self) -> int:
"""Get current failure count"""
return self._failure_count
def _should_attempt_reset(self) -> bool:
"""Check if we should attempt to reset the circuit"""
if self._state != CircuitState.OPEN:
return False
if self._next_attempt_time is None:
return False
return time.time() >= self._next_attempt_time
def _on_success(self) -> None:
"""Handle successful operation"""
if self._state == CircuitState.HALF_OPEN:
logger.info(f"Circuit breaker '{self.name}' reset to CLOSED after successful test")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time = None
self._next_attempt_time = None
def _on_failure(self) -> None:
"""Handle failed operation"""
self._failure_count += 1
self._last_failure_time = time.time()
if self._state == CircuitState.HALF_OPEN:
# Failed during test, go back to OPEN
logger.warning(f"Circuit breaker '{self.name}' failed during test, returning to OPEN")
self._state = CircuitState.OPEN
self._next_attempt_time = time.time() + self.recovery_timeout
elif self._failure_count >= self.failure_threshold:
# Too many failures, open the circuit
logger.error(
f"Circuit breaker '{self.name}' OPENED after {self._failure_count} failures"
)
self._state = CircuitState.OPEN
self._next_attempt_time = time.time() + self.recovery_timeout
def call(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute function with circuit breaker protection.
Args:
func: Function to execute
*args: Function arguments
**kwargs: Function keyword arguments
Returns:
Function result
Raises:
CircuitBreakerOpenError: When circuit is open
Original exception: When function fails
"""
# Check if we should attempt reset
if self._should_attempt_reset():
logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN")
self._state = CircuitState.HALF_OPEN
# Fail fast if circuit is open
if self._state == CircuitState.OPEN:
raise CircuitBreakerOpenError(
f"Circuit breaker '{self.name}' is OPEN. "
f"Next attempt in {self._next_attempt_time - time.time():.1f}s"
)
try:
# Execute the function
result = func(*args, **kwargs)
self._on_success()
return result
except self.expected_exception as e:
self._on_failure()
raise e
async def call_async(self, func: Callable, *args, **kwargs) -> Any:
"""
Execute async function with circuit breaker protection.
Args:
func: Async function to execute
*args: Function arguments
**kwargs: Function keyword arguments
Returns:
Function result
Raises:
CircuitBreakerOpenError: When circuit is open
Original exception: When function fails
"""
# Check if we should attempt reset
if self._should_attempt_reset():
logger.info(f"Circuit breaker '{self.name}' attempting reset to HALF_OPEN")
self._state = CircuitState.HALF_OPEN
# Fail fast if circuit is open
if self._state == CircuitState.OPEN:
raise CircuitBreakerOpenError(
f"Circuit breaker '{self.name}' is OPEN. "
f"Next attempt in {self._next_attempt_time - time.time():.1f}s"
)
try:
# Execute the async function
result = await func(*args, **kwargs)
self._on_success()
return result
except self.expected_exception as e:
self._on_failure()
raise e
def reset(self) -> None:
"""Manually reset the circuit breaker"""
logger.info(f"Circuit breaker '{self.name}' manually reset")
self._state = CircuitState.CLOSED
self._failure_count = 0
self._last_failure_time = None
self._next_attempt_time = None
def get_stats(self) -> dict:
"""Get circuit breaker statistics"""
return {
'name': self.name,
'state': self._state.value,
'failure_count': self._failure_count,
'failure_threshold': self.failure_threshold,
'last_failure_time': self._last_failure_time,
'next_attempt_time': self._next_attempt_time,
'recovery_timeout': self.recovery_timeout
}
class CircuitBreakerOpenError(Exception):
"""Exception raised when circuit breaker is open"""
pass

View File

@ -0,0 +1,650 @@
"""
Coinbase Pro exchange connector implementation.
Supports WebSocket connections to Coinbase Pro (now Coinbase Advanced Trade).
"""
import json
import hmac
import hashlib
import base64
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class CoinbaseConnector(BaseExchangeConnector):
"""
Coinbase Pro WebSocket connector implementation.
Supports:
- Order book level2 streams
- Trade streams (matches)
- Symbol normalization
- Authentication for private channels (if needed)
"""
# Coinbase Pro WebSocket URLs
WEBSOCKET_URL = "wss://ws-feed.exchange.coinbase.com"
SANDBOX_URL = "wss://ws-feed-public.sandbox.exchange.coinbase.com"
API_URL = "https://api.exchange.coinbase.com"
def __init__(self, use_sandbox: bool = False, api_key: str = None,
api_secret: str = None, passphrase: str = None):
"""
Initialize Coinbase connector.
Args:
use_sandbox: Whether to use sandbox environment
api_key: API key for authentication (optional)
api_secret: API secret for authentication (optional)
passphrase: API passphrase for authentication (optional)
"""
websocket_url = self.SANDBOX_URL if use_sandbox else self.WEBSOCKET_URL
super().__init__("coinbase", websocket_url)
# Authentication credentials (optional)
self.api_key = api_key
self.api_secret = api_secret
self.passphrase = passphrase
self.use_sandbox = use_sandbox
# Coinbase-specific message handlers
self.message_handlers.update({
'l2update': self._handle_orderbook_update,
'match': self._handle_trade_update,
'snapshot': self._handle_orderbook_snapshot,
'error': self._handle_error_message,
'subscriptions': self._handle_subscription_response
})
# Channel management
self.subscribed_channels = set()
self.product_ids = set()
logger.info(f"Coinbase connector initialized ({'sandbox' if use_sandbox else 'production'})")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from Coinbase message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Coinbase uses 'type' field for message type
return data.get('type', 'unknown')
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to Coinbase format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: Coinbase product ID format (e.g., 'BTC-USD')
"""
# Convert standard format to Coinbase product ID
if symbol.upper() == 'BTCUSDT':
return 'BTC-USD'
elif symbol.upper() == 'ETHUSDT':
return 'ETH-USD'
elif symbol.upper() == 'ADAUSDT':
return 'ADA-USD'
elif symbol.upper() == 'DOTUSDT':
return 'DOT-USD'
elif symbol.upper() == 'LINKUSDT':
return 'LINK-USD'
else:
# Generic conversion: BTCUSDT -> BTC-USD
if symbol.endswith('USDT'):
base = symbol[:-4]
return f"{base}-USD"
elif symbol.endswith('USD'):
base = symbol[:-3]
return f"{base}-USD"
else:
# Assume it's already in correct format or try to parse
if '-' in symbol:
return symbol.upper()
else:
# Default fallback
return symbol.upper()
def _denormalize_symbol(self, product_id: str) -> str:
"""
Convert Coinbase product ID back to standard format.
Args:
product_id: Coinbase product ID (e.g., 'BTC-USD')
Returns:
str: Standard symbol format (e.g., 'BTCUSDT')
"""
if '-' in product_id:
base, quote = product_id.split('-', 1)
if quote == 'USD':
return f"{base}USDT"
else:
return f"{base}{quote}"
return product_id
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book level2 updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
product_id = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"type": "subscribe",
"product_ids": [product_id],
"channels": ["level2"]
}
# Add authentication if credentials provided
if self.api_key and self.api_secret and self.passphrase:
subscription_msg.update(self._get_auth_headers(subscription_msg))
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_channels.add('level2')
self.product_ids.add(product_id)
logger.info(f"Subscribed to order book for {symbol} ({product_id}) on Coinbase")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on Coinbase")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates (matches) for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
product_id = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"type": "subscribe",
"product_ids": [product_id],
"channels": ["matches"]
}
# Add authentication if credentials provided
if self.api_key and self.api_secret and self.passphrase:
subscription_msg.update(self._get_auth_headers(subscription_msg))
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_channels.add('matches')
self.product_ids.add(product_id)
logger.info(f"Subscribed to trades for {symbol} ({product_id}) on Coinbase")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on Coinbase")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
product_id = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"type": "unsubscribe",
"product_ids": [product_id],
"channels": ["level2"]
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.product_ids.discard(product_id)
logger.info(f"Unsubscribed from order book for {symbol} ({product_id}) on Coinbase")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on Coinbase")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
product_id = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"type": "unsubscribe",
"product_ids": [product_id],
"channels": ["matches"]
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.product_ids.discard(product_id)
logger.info(f"Unsubscribed from trades for {symbol} ({product_id}) on Coinbase")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on Coinbase")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from Coinbase.
Returns:
List[str]: List of available symbols in standard format
"""
try:
import aiohttp
api_url = "https://api-public.sandbox.exchange.coinbase.com" if self.use_sandbox else self.API_URL
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/products") as response:
if response.status == 200:
data = await response.json()
symbols = []
for product in data:
if product.get('status') == 'online' and product.get('trading_disabled') is False:
product_id = product.get('id', '')
# Convert to standard format
standard_symbol = self._denormalize_symbol(product_id)
symbols.append(standard_symbol)
logger.info(f"Retrieved {len(symbols)} symbols from Coinbase")
return symbols
else:
logger.error(f"Failed to get symbols from Coinbase: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from Coinbase: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from Coinbase REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve (Coinbase supports up to 50)
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
product_id = self.normalize_symbol(symbol)
api_url = "https://api-public.sandbox.exchange.coinbase.com" if self.use_sandbox else self.API_URL
# Coinbase supports level 1, 2, or 3
level = 2 # Level 2 gives us aggregated order book
url = f"{api_url}/products/{product_id}/book"
params = {'level': level}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
return self._parse_orderbook_snapshot(data, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse Coinbase order book data into OrderBookSnapshot.
Args:
data: Raw Coinbase order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc),
bids=bids,
asks=asks,
sequence_id=data.get('sequence')
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
def _get_auth_headers(self, message: Dict) -> Dict[str, str]:
"""
Generate authentication headers for Coinbase Pro API.
Args:
message: Message to authenticate
Returns:
Dict: Authentication headers
"""
if not all([self.api_key, self.api_secret, self.passphrase]):
return {}
try:
timestamp = str(time.time())
message_str = json.dumps(message, separators=(',', ':'))
# Create signature
message_to_sign = timestamp + 'GET' + '/users/self/verify' + message_str
signature = base64.b64encode(
hmac.new(
base64.b64decode(self.api_secret),
message_to_sign.encode('utf-8'),
hashlib.sha256
).digest()
).decode('utf-8')
return {
'CB-ACCESS-KEY': self.api_key,
'CB-ACCESS-SIGN': signature,
'CB-ACCESS-TIMESTAMP': timestamp,
'CB-ACCESS-PASSPHRASE': self.passphrase
}
except Exception as e:
logger.error(f"Error generating auth headers: {e}")
return {}
async def _handle_orderbook_snapshot(self, data: Dict) -> None:
"""
Handle order book snapshot from Coinbase.
Args:
data: Order book snapshot data
"""
try:
set_correlation_id()
product_id = data.get('product_id', '')
if not product_id:
logger.warning("Order book snapshot missing product_id")
return
symbol = self._denormalize_symbol(product_id)
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc),
bids=bids,
asks=asks,
sequence_id=data.get('sequence')
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book snapshot for {symbol}")
except Exception as e:
logger.error(f"Error handling order book snapshot: {e}")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""
Handle order book level2 update from Coinbase.
Args:
data: Order book update data
"""
try:
set_correlation_id()
product_id = data.get('product_id', '')
if not product_id:
logger.warning("Order book update missing product_id")
return
symbol = self._denormalize_symbol(product_id)
# Coinbase l2update format: changes array with [side, price, size]
changes = data.get('changes', [])
bids = []
asks = []
for change in changes:
if len(change) >= 3:
side = change[0] # 'buy' or 'sell'
price = float(change[1])
size = float(change[2])
if validate_price(price) and validate_volume(size):
if side == 'buy':
bids.append(PriceLevel(price=price, size=size))
elif side == 'sell':
asks.append(PriceLevel(price=price, size=size))
# Create order book update (partial snapshot)
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromisoformat(data.get('time', '').replace('Z', '+00:00')),
bids=bids,
asks=asks,
sequence_id=data.get('sequence')
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {symbol}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: Dict) -> None:
"""
Handle trade (match) update from Coinbase.
Args:
data: Trade update data
"""
try:
set_correlation_id()
product_id = data.get('product_id', '')
if not product_id:
logger.warning("Trade update missing product_id")
return
symbol = self._denormalize_symbol(product_id)
price = float(data.get('price', 0))
size = float(data.get('size', 0))
# Validate data
if not validate_price(price) or not validate_volume(size):
logger.warning(f"Invalid trade data: price={price}, size={size}")
return
# Determine side (Coinbase uses 'side' field for taker side)
side = data.get('side', 'unknown') # 'buy' or 'sell'
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromisoformat(data.get('time', '').replace('Z', '+00:00')),
price=price,
size=size,
side=side,
trade_id=str(data.get('trade_id', ''))
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {side} {size} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_subscription_response(self, data: Dict) -> None:
"""
Handle subscription confirmation from Coinbase.
Args:
data: Subscription response data
"""
try:
channels = data.get('channels', [])
logger.info(f"Coinbase subscription confirmed for channels: {channels}")
except Exception as e:
logger.error(f"Error handling subscription response: {e}")
async def _handle_error_message(self, data: Dict) -> None:
"""
Handle error message from Coinbase.
Args:
data: Error message data
"""
message = data.get('message', 'Unknown error')
reason = data.get('reason', '')
logger.error(f"Coinbase error: {message}")
if reason:
logger.error(f"Coinbase error reason: {reason}")
# Handle specific error types
if 'Invalid signature' in message:
logger.error("Authentication failed - check API credentials")
elif 'Product not found' in message:
logger.error("Invalid product ID - check symbol mapping")
def get_coinbase_stats(self) -> Dict[str, Any]:
"""Get Coinbase-specific statistics."""
base_stats = self.get_stats()
coinbase_stats = {
'subscribed_channels': list(self.subscribed_channels),
'product_ids': list(self.product_ids),
'use_sandbox': self.use_sandbox,
'authenticated': bool(self.api_key and self.api_secret and self.passphrase)
}
base_stats.update(coinbase_stats)
return base_stats

View File

@ -0,0 +1,271 @@
"""
Connection management with exponential backoff and retry logic.
"""
import asyncio
import random
from typing import Optional, Callable, Any
from ..utils.logging import get_logger
from ..utils.exceptions import ConnectionError
logger = get_logger(__name__)
class ExponentialBackoff:
"""Exponential backoff strategy for connection retries"""
def __init__(
self,
initial_delay: float = 1.0,
max_delay: float = 300.0,
multiplier: float = 2.0,
jitter: bool = True
):
"""
Initialize exponential backoff.
Args:
initial_delay: Initial delay in seconds
max_delay: Maximum delay in seconds
multiplier: Backoff multiplier
jitter: Whether to add random jitter
"""
self.initial_delay = initial_delay
self.max_delay = max_delay
self.multiplier = multiplier
self.jitter = jitter
self.current_delay = initial_delay
self.attempt_count = 0
def get_delay(self) -> float:
"""Get next delay value"""
delay = min(self.current_delay, self.max_delay)
# Add jitter to prevent thundering herd
if self.jitter:
delay = delay * (0.5 + random.random() * 0.5)
# Update for next attempt
self.current_delay *= self.multiplier
self.attempt_count += 1
return delay
def reset(self) -> None:
"""Reset backoff to initial state"""
self.current_delay = self.initial_delay
self.attempt_count = 0
class ConnectionManager:
"""
Manages connection lifecycle with retry logic and health monitoring.
"""
def __init__(
self,
name: str,
max_retries: int = 10,
initial_delay: float = 1.0,
max_delay: float = 300.0,
health_check_interval: int = 30
):
"""
Initialize connection manager.
Args:
name: Connection name for logging
max_retries: Maximum number of retry attempts
initial_delay: Initial retry delay in seconds
max_delay: Maximum retry delay in seconds
health_check_interval: Health check interval in seconds
"""
self.name = name
self.max_retries = max_retries
self.health_check_interval = health_check_interval
self.backoff = ExponentialBackoff(initial_delay, max_delay)
self.is_connected = False
self.connection_attempts = 0
self.last_error: Optional[Exception] = None
self.health_check_task: Optional[asyncio.Task] = None
# Callbacks
self.on_connect: Optional[Callable] = None
self.on_disconnect: Optional[Callable] = None
self.on_error: Optional[Callable] = None
self.on_health_check: Optional[Callable] = None
logger.info(f"Connection manager '{name}' initialized")
async def connect(self, connect_func: Callable) -> bool:
"""
Attempt to establish connection with retry logic.
Args:
connect_func: Async function that establishes the connection
Returns:
bool: True if connection successful, False otherwise
"""
self.connection_attempts = 0
self.backoff.reset()
while self.connection_attempts < self.max_retries:
try:
logger.info(f"Attempting to connect '{self.name}' (attempt {self.connection_attempts + 1})")
# Attempt connection
await connect_func()
# Connection successful
self.is_connected = True
self.connection_attempts = 0
self.last_error = None
self.backoff.reset()
logger.info(f"Connection '{self.name}' established successfully")
# Start health check
await self._start_health_check()
# Notify success
if self.on_connect:
try:
await self.on_connect()
except Exception as e:
logger.warning(f"Error in connect callback: {e}")
return True
except Exception as e:
self.connection_attempts += 1
self.last_error = e
logger.warning(
f"Connection '{self.name}' failed (attempt {self.connection_attempts}): {e}"
)
# Notify error
if self.on_error:
try:
await self.on_error(e)
except Exception as callback_error:
logger.warning(f"Error in error callback: {callback_error}")
# Check if we should retry
if self.connection_attempts >= self.max_retries:
logger.error(f"Connection '{self.name}' failed after {self.max_retries} attempts")
break
# Wait before retry
delay = self.backoff.get_delay()
logger.info(f"Retrying connection '{self.name}' in {delay:.1f} seconds")
await asyncio.sleep(delay)
self.is_connected = False
return False
async def disconnect(self, disconnect_func: Optional[Callable] = None) -> None:
"""
Disconnect and cleanup.
Args:
disconnect_func: Optional async function to handle disconnection
"""
logger.info(f"Disconnecting '{self.name}'")
# Stop health check
await self._stop_health_check()
# Execute disconnect function
if disconnect_func:
try:
await disconnect_func()
except Exception as e:
logger.warning(f"Error during disconnect: {e}")
self.is_connected = False
# Notify disconnect
if self.on_disconnect:
try:
await self.on_disconnect()
except Exception as e:
logger.warning(f"Error in disconnect callback: {e}")
logger.info(f"Connection '{self.name}' disconnected")
async def reconnect(self, connect_func: Callable, disconnect_func: Optional[Callable] = None) -> bool:
"""
Reconnect by disconnecting first then connecting.
Args:
connect_func: Async function that establishes the connection
disconnect_func: Optional async function to handle disconnection
Returns:
bool: True if reconnection successful, False otherwise
"""
logger.info(f"Reconnecting '{self.name}'")
# Disconnect first
await self.disconnect(disconnect_func)
# Wait a bit before reconnecting
await asyncio.sleep(1.0)
# Attempt to connect
return await self.connect(connect_func)
async def _start_health_check(self) -> None:
"""Start periodic health check"""
if self.health_check_task:
return
self.health_check_task = asyncio.create_task(self._health_check_loop())
logger.debug(f"Health check started for '{self.name}'")
async def _stop_health_check(self) -> None:
"""Stop health check"""
if self.health_check_task:
self.health_check_task.cancel()
try:
await self.health_check_task
except asyncio.CancelledError:
pass
self.health_check_task = None
logger.debug(f"Health check stopped for '{self.name}'")
async def _health_check_loop(self) -> None:
"""Health check loop"""
while self.is_connected:
try:
await asyncio.sleep(self.health_check_interval)
if self.on_health_check:
is_healthy = await self.on_health_check()
if not is_healthy:
logger.warning(f"Health check failed for '{self.name}'")
self.is_connected = False
break
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Health check error for '{self.name}': {e}")
self.is_connected = False
break
def get_stats(self) -> dict:
"""Get connection statistics"""
return {
'name': self.name,
'is_connected': self.is_connected,
'connection_attempts': self.connection_attempts,
'max_retries': self.max_retries,
'current_delay': self.backoff.current_delay,
'backoff_attempts': self.backoff.attempt_count,
'last_error': str(self.last_error) if self.last_error else None,
'health_check_active': self.health_check_task is not None
}

View File

@ -0,0 +1,601 @@
"""
Gate.io exchange connector implementation.
Supports WebSocket connections to Gate.io with their WebSocket v4 API.
"""
import json
import hmac
import hashlib
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class GateIOConnector(BaseExchangeConnector):
"""
Gate.io WebSocket connector implementation.
Supports:
- WebSocket v4 API
- Order book streams
- Trade streams
- Symbol normalization
- Authentication for private channels
"""
# Gate.io WebSocket URLs
WEBSOCKET_URL = "wss://api.gateio.ws/ws/v4/"
TESTNET_URL = "wss://fx-api-testnet.gateio.ws/ws/v4/"
API_URL = "https://api.gateio.ws"
def __init__(self, use_testnet: bool = False, api_key: str = None, api_secret: str = None):
"""
Initialize Gate.io connector.
Args:
use_testnet: Whether to use testnet environment
api_key: API key for authentication (optional)
api_secret: API secret for authentication (optional)
"""
websocket_url = self.TESTNET_URL if use_testnet else self.WEBSOCKET_URL
super().__init__("gateio", websocket_url)
# Authentication credentials (optional)
self.api_key = api_key
self.api_secret = api_secret
self.use_testnet = use_testnet
# Gate.io-specific message handlers
self.message_handlers.update({
'spot.order_book_update': self._handle_orderbook_update,
'spot.trades': self._handle_trade_update,
'spot.pong': self._handle_pong,
'error': self._handle_error_message
})
# Subscription tracking
self.subscribed_channels = set()
self.request_id = 1
logger.info(f"Gate.io connector initialized ({'testnet' if use_testnet else 'mainnet'})")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from Gate.io message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Gate.io v4 API message format
if 'method' in data:
return data['method'] # 'spot.order_book_update', 'spot.trades', etc.
elif 'error' in data:
return 'error'
elif 'result' in data:
return 'result'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to Gate.io format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: Gate.io symbol format (e.g., 'BTC_USDT')
"""
# Gate.io uses underscore-separated format
if symbol.upper() == 'BTCUSDT':
return 'BTC_USDT'
elif symbol.upper() == 'ETHUSDT':
return 'ETH_USDT'
elif symbol.upper().endswith('USDT'):
base = symbol[:-4].upper()
return f"{base}_USDT"
elif symbol.upper().endswith('USD'):
base = symbol[:-3].upper()
return f"{base}_USD"
else:
# Assume it's already in correct format or add underscore
if '_' not in symbol:
# Try to split common patterns
if len(symbol) >= 6:
# Assume last 4 chars are quote currency
base = symbol[:-4].upper()
quote = symbol[-4:].upper()
return f"{base}_{quote}"
else:
return symbol.upper()
else:
return symbol.upper()
def _denormalize_symbol(self, gateio_symbol: str) -> str:
"""
Convert Gate.io symbol back to standard format.
Args:
gateio_symbol: Gate.io symbol format (e.g., 'BTC_USDT')
Returns:
str: Standard symbol format (e.g., 'BTCUSDT')
"""
if '_' in gateio_symbol:
return gateio_symbol.replace('_', '')
return gateio_symbol
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
gateio_symbol = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"method": "spot.order_book",
"params": [gateio_symbol, 20, "0"], # symbol, limit, interval
"id": self.request_id
}
self.request_id += 1
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_channels.add(f"spot.order_book:{gateio_symbol}")
logger.info(f"Subscribed to order book for {symbol} ({gateio_symbol}) on Gate.io")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on Gate.io")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
gateio_symbol = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"method": "spot.trades",
"params": [gateio_symbol],
"id": self.request_id
}
self.request_id += 1
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_channels.add(f"spot.trades:{gateio_symbol}")
logger.info(f"Subscribed to trades for {symbol} ({gateio_symbol}) on Gate.io")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on Gate.io")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
gateio_symbol = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"method": "spot.unsubscribe",
"params": [f"spot.order_book", gateio_symbol],
"id": self.request_id
}
self.request_id += 1
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_channels.discard(f"spot.order_book:{gateio_symbol}")
logger.info(f"Unsubscribed from order book for {symbol} ({gateio_symbol}) on Gate.io")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on Gate.io")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
gateio_symbol = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"method": "spot.unsubscribe",
"params": ["spot.trades", gateio_symbol],
"id": self.request_id
}
self.request_id += 1
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_channels.discard(f"spot.trades:{gateio_symbol}")
logger.info(f"Unsubscribed from trades for {symbol} ({gateio_symbol}) on Gate.io")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on Gate.io")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from Gate.io.
Returns:
List[str]: List of available symbols in standard format
"""
try:
import aiohttp
api_url = "https://fx-api-testnet.gateio.ws" if self.use_testnet else self.API_URL
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/api/v4/spot/currency_pairs") as response:
if response.status == 200:
data = await response.json()
symbols = []
for pair_info in data:
if pair_info.get('trade_status') == 'tradable':
pair_id = pair_info.get('id', '')
# Convert to standard format
standard_symbol = self._denormalize_symbol(pair_id)
symbols.append(standard_symbol)
logger.info(f"Retrieved {len(symbols)} symbols from Gate.io")
return symbols
else:
logger.error(f"Failed to get symbols from Gate.io: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from Gate.io: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from Gate.io REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
gateio_symbol = self.normalize_symbol(symbol)
api_url = "https://fx-api-testnet.gateio.ws" if self.use_testnet else self.API_URL
# Gate.io supports various depths
api_depth = min(depth, 100)
url = f"{api_url}/api/v4/spot/order_book"
params = {
'currency_pair': gateio_symbol,
'limit': api_depth
}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
return self._parse_orderbook_snapshot(data, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse Gate.io order book data into OrderBookSnapshot.
Args:
data: Raw Gate.io order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc), # Gate.io doesn't provide timestamp in snapshot
bids=bids,
asks=asks,
sequence_id=data.get('id')
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""
Handle order book update from Gate.io.
Args:
data: Order book update data
"""
try:
set_correlation_id()
params = data.get('params', [])
if len(params) < 2:
logger.warning("Invalid order book update format")
return
# Gate.io format: [symbol, order_book_data]
gateio_symbol = params[0]
symbol = self._denormalize_symbol(gateio_symbol)
book_data = params[1]
# Parse bids and asks
bids = []
for bid_data in book_data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in book_data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(book_data.get('t', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=book_data.get('id')
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {symbol}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: Dict) -> None:
"""
Handle trade update from Gate.io.
Args:
data: Trade update data
"""
try:
set_correlation_id()
params = data.get('params', [])
if len(params) < 2:
logger.warning("Invalid trade update format")
return
# Gate.io format: [symbol, [trade_data]]
gateio_symbol = params[0]
symbol = self._denormalize_symbol(gateio_symbol)
trades_data = params[1]
# Process each trade
for trade_data in trades_data:
price = float(trade_data.get('price', 0))
amount = float(trade_data.get('amount', 0))
# Validate data
if not validate_price(price) or not validate_volume(amount):
logger.warning(f"Invalid trade data: price={price}, amount={amount}")
continue
# Determine side (Gate.io uses 'side' field)
side = trade_data.get('side', 'unknown').lower()
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(trade_data.get('time', 0)), tz=timezone.utc),
price=price,
size=amount,
side=side,
trade_id=str(trade_data.get('id', ''))
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {side} {amount} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_pong(self, data: Dict) -> None:
"""
Handle pong response from Gate.io.
Args:
data: Pong response data
"""
logger.debug("Received Gate.io pong")
async def _handle_error_message(self, data: Dict) -> None:
"""
Handle error message from Gate.io.
Args:
data: Error message data
"""
error_info = data.get('error', {})
code = error_info.get('code', 'unknown')
message = error_info.get('message', 'Unknown error')
logger.error(f"Gate.io error {code}: {message}")
def _get_auth_signature(self, method: str, url: str, query_string: str,
payload: str, timestamp: str) -> str:
"""
Generate authentication signature for Gate.io.
Args:
method: HTTP method
url: Request URL
query_string: Query string
payload: Request payload
timestamp: Request timestamp
Returns:
str: Authentication signature
"""
if not self.api_key or not self.api_secret:
return ""
try:
# Create signature string
message = f"{method}\n{url}\n{query_string}\n{hashlib.sha512(payload.encode()).hexdigest()}\n{timestamp}"
# Generate signature
signature = hmac.new(
self.api_secret.encode('utf-8'),
message.encode('utf-8'),
hashlib.sha512
).hexdigest()
return signature
except Exception as e:
logger.error(f"Error generating auth signature: {e}")
return ""
async def _send_ping(self) -> None:
"""Send ping to keep connection alive."""
try:
ping_msg = {
"method": "spot.ping",
"params": [],
"id": self.request_id
}
self.request_id += 1
await self._send_message(ping_msg)
logger.debug("Sent ping to Gate.io")
except Exception as e:
logger.error(f"Error sending ping: {e}")
def get_gateio_stats(self) -> Dict[str, Any]:
"""Get Gate.io-specific statistics."""
base_stats = self.get_stats()
gateio_stats = {
'subscribed_channels': list(self.subscribed_channels),
'use_testnet': self.use_testnet,
'authenticated': bool(self.api_key and self.api_secret),
'next_request_id': self.request_id
}
base_stats.update(gateio_stats)
return base_stats

View File

@ -0,0 +1,660 @@
"""
Huobi Global exchange connector implementation.
Supports WebSocket connections to Huobi with proper symbol mapping.
"""
import json
import gzip
import hmac
import hashlib
import base64
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class HuobiConnector(BaseExchangeConnector):
"""
Huobi Global WebSocket connector implementation.
Supports:
- Order book streams
- Trade streams
- Symbol normalization
- GZIP message decompression
- Authentication for private channels
"""
# Huobi WebSocket URLs
WEBSOCKET_URL = "wss://api.huobi.pro/ws"
WEBSOCKET_PRIVATE_URL = "wss://api.huobi.pro/ws/v2"
API_URL = "https://api.huobi.pro"
def __init__(self, api_key: str = None, api_secret: str = None):
"""
Initialize Huobi connector.
Args:
api_key: API key for authentication (optional)
api_secret: API secret for authentication (optional)
"""
super().__init__("huobi", self.WEBSOCKET_URL)
# Authentication credentials (optional)
self.api_key = api_key
self.api_secret = api_secret
# Huobi-specific message handlers
self.message_handlers.update({
'market.*.depth.step0': self._handle_orderbook_update,
'market.*.trade.detail': self._handle_trade_update,
'ping': self._handle_ping,
'pong': self._handle_pong
})
# Subscription tracking
self.subscribed_topics = set()
logger.info("Huobi connector initialized")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from Huobi message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Huobi message format
if 'ping' in data:
return 'ping'
elif 'pong' in data:
return 'pong'
elif 'ch' in data:
# Data channel message
channel = data['ch']
if 'depth' in channel:
return 'market.*.depth.step0'
elif 'trade' in channel:
return 'market.*.trade.detail'
else:
return channel
elif 'subbed' in data:
return 'subscription_response'
elif 'unsubbed' in data:
return 'unsubscription_response'
elif 'status' in data and data.get('status') == 'error':
return 'error'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to Huobi format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: Huobi symbol format (e.g., 'btcusdt')
"""
# Huobi uses lowercase symbols
normalized = symbol.lower().replace('-', '').replace('/', '')
# Validate symbol format
if not validate_symbol(normalized.upper()):
raise ValidationError(f"Invalid symbol format: {symbol}", "INVALID_SYMBOL")
return normalized
def _denormalize_symbol(self, huobi_symbol: str) -> str:
"""
Convert Huobi symbol back to standard format.
Args:
huobi_symbol: Huobi symbol format (e.g., 'btcusdt')
Returns:
str: Standard symbol format (e.g., 'BTCUSDT')
"""
return huobi_symbol.upper()
async def _decompress_message(self, message: bytes) -> str:
"""
Decompress GZIP message from Huobi.
Args:
message: Compressed message bytes
Returns:
str: Decompressed message string
"""
try:
return gzip.decompress(message).decode('utf-8')
except Exception as e:
logger.error(f"Error decompressing message: {e}")
return ""
async def _process_message(self, message: str) -> None:
"""
Override message processing to handle GZIP compression.
Args:
message: Raw message (could be compressed)
"""
try:
# Check if message is compressed (binary)
if isinstance(message, bytes):
message = await self._decompress_message(message)
if not message:
return
# Parse JSON message
data = json.loads(message)
# Handle ping/pong first
if 'ping' in data:
await self._handle_ping(data)
return
# Determine message type and route to appropriate handler
message_type = self._get_message_type(data)
if message_type in self.message_handlers:
await self.message_handlers[message_type](data)
else:
logger.debug(f"Unhandled message type '{message_type}' from {self.exchange_name}")
except json.JSONDecodeError as e:
logger.warning(f"Invalid JSON message from {self.exchange_name}: {e}")
except Exception as e:
logger.error(f"Error processing message from {self.exchange_name}: {e}")
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
huobi_symbol = self.normalize_symbol(symbol)
topic = f"market.{huobi_symbol}.depth.step0"
# Create subscription message
subscription_msg = {
"sub": topic,
"id": str(int(time.time()))
}
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_topics.add(topic)
logger.info(f"Subscribed to order book for {symbol} ({huobi_symbol}) on Huobi")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on Huobi")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
huobi_symbol = self.normalize_symbol(symbol)
topic = f"market.{huobi_symbol}.trade.detail"
# Create subscription message
subscription_msg = {
"sub": topic,
"id": str(int(time.time()))
}
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_topics.add(topic)
logger.info(f"Subscribed to trades for {symbol} ({huobi_symbol}) on Huobi")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on Huobi")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
huobi_symbol = self.normalize_symbol(symbol)
topic = f"market.{huobi_symbol}.depth.step0"
# Create unsubscription message
unsubscription_msg = {
"unsub": topic,
"id": str(int(time.time()))
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_topics.discard(topic)
logger.info(f"Unsubscribed from order book for {symbol} ({huobi_symbol}) on Huobi")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on Huobi")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
huobi_symbol = self.normalize_symbol(symbol)
topic = f"market.{huobi_symbol}.trade.detail"
# Create unsubscription message
unsubscription_msg = {
"unsub": topic,
"id": str(int(time.time()))
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_topics.discard(topic)
logger.info(f"Unsubscribed from trades for {symbol} ({huobi_symbol}) on Huobi")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on Huobi")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from Huobi.
Returns:
List[str]: List of available symbols in standard format
"""
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.API_URL}/v1/common/symbols") as response:
if response.status == 200:
data = await response.json()
if data.get('status') != 'ok':
logger.error(f"Huobi API error: {data}")
return []
symbols = []
symbol_data = data.get('data', [])
for symbol_info in symbol_data:
if symbol_info.get('state') == 'online':
symbol = symbol_info.get('symbol', '')
# Convert to standard format
standard_symbol = self._denormalize_symbol(symbol)
symbols.append(standard_symbol)
logger.info(f"Retrieved {len(symbols)} symbols from Huobi")
return symbols
else:
logger.error(f"Failed to get symbols from Huobi: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from Huobi: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from Huobi REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
huobi_symbol = self.normalize_symbol(symbol)
# Huobi supports depths: 5, 10, 20
valid_depths = [5, 10, 20]
api_depth = min(valid_depths, key=lambda x: abs(x - depth))
url = f"{self.API_URL}/market/depth"
params = {
'symbol': huobi_symbol,
'depth': api_depth,
'type': 'step0'
}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
if data.get('status') != 'ok':
logger.error(f"Huobi API error: {data}")
return None
tick_data = data.get('tick', {})
return self._parse_orderbook_snapshot(tick_data, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse Huobi order book data into OrderBookSnapshot.
Args:
data: Raw Huobi order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=data.get('version')
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""
Handle order book update from Huobi.
Args:
data: Order book update data
"""
try:
set_correlation_id()
# Extract symbol from channel
channel = data.get('ch', '')
if not channel:
logger.warning("Order book update missing channel")
return
# Parse channel: market.btcusdt.depth.step0
parts = channel.split('.')
if len(parts) < 2:
logger.warning("Invalid order book channel format")
return
huobi_symbol = parts[1]
symbol = self._denormalize_symbol(huobi_symbol)
tick_data = data.get('tick', {})
# Parse bids and asks
bids = []
for bid_data in tick_data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in tick_data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(tick_data.get('ts', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=tick_data.get('version')
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {symbol}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: Dict) -> None:
"""
Handle trade update from Huobi.
Args:
data: Trade update data
"""
try:
set_correlation_id()
# Extract symbol from channel
channel = data.get('ch', '')
if not channel:
logger.warning("Trade update missing channel")
return
# Parse channel: market.btcusdt.trade.detail
parts = channel.split('.')
if len(parts) < 2:
logger.warning("Invalid trade channel format")
return
huobi_symbol = parts[1]
symbol = self._denormalize_symbol(huobi_symbol)
tick_data = data.get('tick', {})
trades_data = tick_data.get('data', [])
# Process each trade
for trade_data in trades_data:
price = float(trade_data.get('price', 0))
amount = float(trade_data.get('amount', 0))
# Validate data
if not validate_price(price) or not validate_volume(amount):
logger.warning(f"Invalid trade data: price={price}, amount={amount}")
continue
# Determine side (Huobi uses 'direction' field)
direction = trade_data.get('direction', 'unknown')
side = 'buy' if direction == 'buy' else 'sell'
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(trade_data.get('ts', 0)) / 1000, tz=timezone.utc),
price=price,
size=amount,
side=side,
trade_id=str(trade_data.get('tradeId', trade_data.get('id', '')))
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {side} {amount} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_ping(self, data: Dict) -> None:
"""
Handle ping message from Huobi and respond with pong.
Args:
data: Ping message data
"""
try:
ping_value = data.get('ping')
if ping_value:
# Respond with pong
pong_msg = {"pong": ping_value}
await self._send_message(pong_msg)
logger.debug(f"Responded to Huobi ping with pong: {ping_value}")
except Exception as e:
logger.error(f"Error handling ping: {e}")
async def _handle_pong(self, data: Dict) -> None:
"""
Handle pong response from Huobi.
Args:
data: Pong response data
"""
logger.debug("Received Huobi pong")
def _get_auth_signature(self, method: str, host: str, path: str,
params: Dict[str, str]) -> str:
"""
Generate authentication signature for Huobi.
Args:
method: HTTP method
host: API host
path: Request path
params: Request parameters
Returns:
str: Authentication signature
"""
if not self.api_key or not self.api_secret:
return ""
try:
# Sort parameters
sorted_params = sorted(params.items())
query_string = '&'.join([f"{k}={v}" for k, v in sorted_params])
# Create signature string
signature_string = f"{method}\n{host}\n{path}\n{query_string}"
# Generate signature
signature = base64.b64encode(
hmac.new(
self.api_secret.encode('utf-8'),
signature_string.encode('utf-8'),
hashlib.sha256
).digest()
).decode('utf-8')
return signature
except Exception as e:
logger.error(f"Error generating auth signature: {e}")
return ""
def get_huobi_stats(self) -> Dict[str, Any]:
"""Get Huobi-specific statistics."""
base_stats = self.get_stats()
huobi_stats = {
'subscribed_topics': list(self.subscribed_topics),
'authenticated': bool(self.api_key and self.api_secret)
}
base_stats.update(huobi_stats)
return base_stats

View File

@ -0,0 +1,708 @@
"""
Kraken exchange connector implementation.
Supports WebSocket connections to Kraken exchange with their specific message format.
"""
import json
import hashlib
import hmac
import base64
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class KrakenConnector(BaseExchangeConnector):
"""
Kraken WebSocket connector implementation.
Supports:
- Order book streams
- Trade streams
- Symbol normalization for Kraken format
- Authentication for private channels (if needed)
"""
# Kraken WebSocket URLs
WEBSOCKET_URL = "wss://ws.kraken.com"
WEBSOCKET_AUTH_URL = "wss://ws-auth.kraken.com"
API_URL = "https://api.kraken.com"
def __init__(self, api_key: str = None, api_secret: str = None):
"""
Initialize Kraken connector.
Args:
api_key: API key for authentication (optional)
api_secret: API secret for authentication (optional)
"""
super().__init__("kraken", self.WEBSOCKET_URL)
# Authentication credentials (optional)
self.api_key = api_key
self.api_secret = api_secret
# Kraken-specific message handlers
self.message_handlers.update({
'book-10': self._handle_orderbook_update,
'book-25': self._handle_orderbook_update,
'book-100': self._handle_orderbook_update,
'book-500': self._handle_orderbook_update,
'book-1000': self._handle_orderbook_update,
'trade': self._handle_trade_update,
'systemStatus': self._handle_system_status,
'subscriptionStatus': self._handle_subscription_status,
'heartbeat': self._handle_heartbeat
})
# Kraken-specific tracking
self.channel_map = {} # channel_id -> (channel_name, symbol)
self.subscription_ids = {} # symbol -> subscription_id
self.system_status = 'unknown'
logger.info("Kraken connector initialized")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from Kraken message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# Kraken messages can be arrays or objects
if isinstance(data, list) and len(data) >= 2:
# Data message format: [channelID, data, channelName, pair]
if len(data) >= 4:
channel_name = data[2]
return channel_name
else:
return 'unknown'
elif isinstance(data, dict):
# Status/control messages
if 'event' in data:
return data['event']
elif 'errorMessage' in data:
return 'error'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to Kraken format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: Kraken pair format (e.g., 'XBT/USD')
"""
# Kraken uses different symbol names
symbol_map = {
'BTCUSDT': 'XBT/USD',
'ETHUSDT': 'ETH/USD',
'ADAUSDT': 'ADA/USD',
'DOTUSDT': 'DOT/USD',
'LINKUSDT': 'LINK/USD',
'LTCUSDT': 'LTC/USD',
'XRPUSDT': 'XRP/USD',
'BCHUSDT': 'BCH/USD',
'EOSUSDT': 'EOS/USD',
'XLMUSDT': 'XLM/USD'
}
if symbol.upper() in symbol_map:
return symbol_map[symbol.upper()]
else:
# Generic conversion: BTCUSDT -> BTC/USD
if symbol.endswith('USDT'):
base = symbol[:-4]
return f"{base}/USD"
elif symbol.endswith('USD'):
base = symbol[:-3]
return f"{base}/USD"
else:
# Assume it's already in correct format
return symbol.upper()
def _denormalize_symbol(self, kraken_pair: str) -> str:
"""
Convert Kraken pair back to standard format.
Args:
kraken_pair: Kraken pair format (e.g., 'XBT/USD')
Returns:
str: Standard symbol format (e.g., 'BTCUSDT')
"""
# Reverse mapping
reverse_map = {
'XBT/USD': 'BTCUSDT',
'ETH/USD': 'ETHUSDT',
'ADA/USD': 'ADAUSDT',
'DOT/USD': 'DOTUSDT',
'LINK/USD': 'LINKUSDT',
'LTC/USD': 'LTCUSDT',
'XRP/USD': 'XRPUSDT',
'BCH/USD': 'BCHUSDT',
'EOS/USD': 'EOSUSDT',
'XLM/USD': 'XLMUSDT'
}
if kraken_pair in reverse_map:
return reverse_map[kraken_pair]
else:
# Generic conversion: BTC/USD -> BTCUSDT
if '/' in kraken_pair:
base, quote = kraken_pair.split('/', 1)
if quote == 'USD':
return f"{base}USDT"
else:
return f"{base}{quote}"
return kraken_pair
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
kraken_pair = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"event": "subscribe",
"pair": [kraken_pair],
"subscription": {
"name": "book",
"depth": 25 # 25 levels
}
}
# Add authentication if credentials provided
if self.api_key and self.api_secret:
subscription_msg["subscription"]["token"] = self._get_auth_token()
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
logger.info(f"Subscribed to order book for {symbol} ({kraken_pair}) on Kraken")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on Kraken")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
kraken_pair = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"event": "subscribe",
"pair": [kraken_pair],
"subscription": {
"name": "trade"
}
}
# Add authentication if credentials provided
if self.api_key and self.api_secret:
subscription_msg["subscription"]["token"] = self._get_auth_token()
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
logger.info(f"Subscribed to trades for {symbol} ({kraken_pair}) on Kraken")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on Kraken")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
kraken_pair = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"event": "unsubscribe",
"pair": [kraken_pair],
"subscription": {
"name": "book"
}
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
logger.info(f"Unsubscribed from order book for {symbol} ({kraken_pair}) on Kraken")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on Kraken")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
kraken_pair = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"event": "unsubscribe",
"pair": [kraken_pair],
"subscription": {
"name": "trade"
}
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
logger.info(f"Unsubscribed from trades for {symbol} ({kraken_pair}) on Kraken")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on Kraken")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from Kraken.
Returns:
List[str]: List of available symbols in standard format
"""
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.API_URL}/0/public/AssetPairs") as response:
if response.status == 200:
data = await response.json()
if data.get('error'):
logger.error(f"Kraken API error: {data['error']}")
return []
symbols = []
pairs = data.get('result', {})
for pair_name, pair_info in pairs.items():
# Skip dark pool pairs
if '.d' in pair_name:
continue
# Get the WebSocket pair name
ws_name = pair_info.get('wsname')
if ws_name:
# Convert to standard format
standard_symbol = self._denormalize_symbol(ws_name)
symbols.append(standard_symbol)
logger.info(f"Retrieved {len(symbols)} symbols from Kraken")
return symbols
else:
logger.error(f"Failed to get symbols from Kraken: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from Kraken: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from Kraken REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
kraken_pair = self.normalize_symbol(symbol)
url = f"{self.API_URL}/0/public/Depth"
params = {
'pair': kraken_pair,
'count': min(depth, 500) # Kraken max is 500
}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
if data.get('error'):
logger.error(f"Kraken API error: {data['error']}")
return None
result = data.get('result', {})
# Kraken returns data with the actual pair name as key
pair_data = None
for key, value in result.items():
if isinstance(value, dict) and 'bids' in value and 'asks' in value:
pair_data = value
break
if pair_data:
return self._parse_orderbook_snapshot(pair_data, symbol)
else:
logger.error(f"No order book data found for {symbol}")
return None
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse Kraken order book data into OrderBookSnapshot.
Args:
data: Raw Kraken order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc),
bids=bids,
asks=asks
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_orderbook_update(self, data: List) -> None:
"""
Handle order book update from Kraken.
Args:
data: Order book update data (Kraken array format)
"""
try:
set_correlation_id()
# Kraken format: [channelID, data, channelName, pair]
if len(data) < 4:
logger.warning("Invalid Kraken order book update format")
return
channel_id = data[0]
book_data = data[1]
channel_name = data[2]
kraken_pair = data[3]
symbol = self._denormalize_symbol(kraken_pair)
# Track channel mapping
self.channel_map[channel_id] = (channel_name, symbol)
# Parse order book data
bids = []
asks = []
# Kraken book data can have 'b' (bids), 'a' (asks), 'bs' (bid snapshot), 'as' (ask snapshot)
if 'b' in book_data:
for bid_data in book_data['b']:
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
if 'bs' in book_data: # Bid snapshot
for bid_data in book_data['bs']:
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
if 'a' in book_data:
for ask_data in book_data['a']:
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
if 'as' in book_data: # Ask snapshot
for ask_data in book_data['as']:
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc),
bids=bids,
asks=asks
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {symbol}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: List) -> None:
"""
Handle trade update from Kraken.
Args:
data: Trade update data (Kraken array format)
"""
try:
set_correlation_id()
# Kraken format: [channelID, data, channelName, pair]
if len(data) < 4:
logger.warning("Invalid Kraken trade update format")
return
channel_id = data[0]
trade_data = data[1]
channel_name = data[2]
kraken_pair = data[3]
symbol = self._denormalize_symbol(kraken_pair)
# Track channel mapping
self.channel_map[channel_id] = (channel_name, symbol)
# Process trade data (array of trades)
for trade_info in trade_data:
if len(trade_info) >= 6:
price = float(trade_info[0])
size = float(trade_info[1])
timestamp = float(trade_info[2])
side = trade_info[3] # 'b' for buy, 's' for sell
order_type = trade_info[4] # 'm' for market, 'l' for limit
misc = trade_info[5] if len(trade_info) > 5 else ''
# Validate data
if not validate_price(price) or not validate_volume(size):
logger.warning(f"Invalid trade data: price={price}, size={size}")
continue
# Convert side
trade_side = 'buy' if side == 'b' else 'sell'
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(timestamp, tz=timezone.utc),
price=price,
size=size,
side=trade_side,
trade_id=f"{timestamp}_{price}_{size}" # Generate ID
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {trade_side} {size} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_system_status(self, data: Dict) -> None:
"""
Handle system status message from Kraken.
Args:
data: System status data
"""
try:
status = data.get('status', 'unknown')
version = data.get('version', 'unknown')
self.system_status = status
logger.info(f"Kraken system status: {status} (version: {version})")
if status != 'online':
logger.warning(f"Kraken system not online: {status}")
except Exception as e:
logger.error(f"Error handling system status: {e}")
async def _handle_subscription_status(self, data: Dict) -> None:
"""
Handle subscription status message from Kraken.
Args:
data: Subscription status data
"""
try:
status = data.get('status', 'unknown')
channel_name = data.get('channelName', 'unknown')
pair = data.get('pair', 'unknown')
subscription = data.get('subscription', {})
if status == 'subscribed':
logger.info(f"Kraken subscription confirmed: {channel_name} for {pair}")
# Store subscription ID if provided
if 'channelID' in data:
channel_id = data['channelID']
symbol = self._denormalize_symbol(pair)
self.channel_map[channel_id] = (channel_name, symbol)
elif status == 'unsubscribed':
logger.info(f"Kraken unsubscription confirmed: {channel_name} for {pair}")
elif status == 'error':
error_message = data.get('errorMessage', 'Unknown error')
logger.error(f"Kraken subscription error: {error_message}")
except Exception as e:
logger.error(f"Error handling subscription status: {e}")
async def _handle_heartbeat(self, data: Dict) -> None:
"""
Handle heartbeat message from Kraken.
Args:
data: Heartbeat data
"""
logger.debug("Received Kraken heartbeat")
def _get_auth_token(self) -> str:
"""
Generate authentication token for Kraken WebSocket.
Returns:
str: Authentication token
"""
if not self.api_key or not self.api_secret:
return ""
try:
# This is a simplified version - actual Kraken auth is more complex
# and requires getting a token from the REST API first
nonce = str(int(time.time() * 1000))
message = nonce + self.api_key
signature = hmac.new(
base64.b64decode(self.api_secret),
message.encode('utf-8'),
hashlib.sha512
).hexdigest()
return f"{self.api_key}:{signature}:{nonce}"
except Exception as e:
logger.error(f"Error generating auth token: {e}")
return ""
def get_kraken_stats(self) -> Dict[str, Any]:
"""Get Kraken-specific statistics."""
base_stats = self.get_stats()
kraken_stats = {
'system_status': self.system_status,
'channel_mappings': len(self.channel_map),
'authenticated': bool(self.api_key and self.api_secret)
}
base_stats.update(kraken_stats)
return base_stats

View File

@ -0,0 +1,776 @@
"""
KuCoin exchange connector implementation.
Supports WebSocket connections to KuCoin with proper token-based authentication.
"""
import json
import hmac
import hashlib
import base64
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class KuCoinConnector(BaseExchangeConnector):
"""
KuCoin WebSocket connector implementation.
Supports:
- Token-based authentication
- Order book streams
- Trade streams
- Symbol normalization
- Bullet connection protocol
"""
# KuCoin API URLs
API_URL = "https://api.kucoin.com"
SANDBOX_API_URL = "https://openapi-sandbox.kucoin.com"
def __init__(self, use_sandbox: bool = False, api_key: str = None,
api_secret: str = None, passphrase: str = None):
"""
Initialize KuCoin connector.
Args:
use_sandbox: Whether to use sandbox environment
api_key: API key for authentication (optional)
api_secret: API secret for authentication (optional)
passphrase: API passphrase for authentication (optional)
"""
# KuCoin requires getting WebSocket URL from REST API
super().__init__("kucoin", "") # URL will be set after token retrieval
# Authentication credentials (optional)
self.api_key = api_key
self.api_secret = api_secret
self.passphrase = passphrase
self.use_sandbox = use_sandbox
# KuCoin-specific attributes
self.token = None
self.connect_id = None
self.ping_interval = 18000 # 18 seconds (KuCoin requirement)
self.ping_timeout = 10000 # 10 seconds
# KuCoin-specific message handlers
self.message_handlers.update({
'message': self._handle_data_message,
'welcome': self._handle_welcome_message,
'ack': self._handle_ack_message,
'error': self._handle_error_message,
'pong': self._handle_pong_message
})
# Subscription tracking
self.subscribed_topics = set()
self.subscription_id = 1
logger.info(f"KuCoin connector initialized ({'sandbox' if use_sandbox else 'live'})")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from KuCoin message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# KuCoin message format
if 'type' in data:
return data['type'] # 'message', 'welcome', 'ack', 'error', 'pong'
elif 'subject' in data:
# Data message with subject
return 'message'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to KuCoin format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: KuCoin symbol format (e.g., 'BTC-USDT')
"""
# KuCoin uses dash-separated format
if symbol.upper() == 'BTCUSDT':
return 'BTC-USDT'
elif symbol.upper() == 'ETHUSDT':
return 'ETH-USDT'
elif symbol.upper().endswith('USDT'):
base = symbol[:-4].upper()
return f"{base}-USDT"
elif symbol.upper().endswith('USD'):
base = symbol[:-3].upper()
return f"{base}-USD"
else:
# Assume it's already in correct format or add dash
if '-' not in symbol:
# Try to split common patterns
if len(symbol) >= 6:
# Assume last 4 chars are quote currency
base = symbol[:-4].upper()
quote = symbol[-4:].upper()
return f"{base}-{quote}"
else:
return symbol.upper()
else:
return symbol.upper()
def _denormalize_symbol(self, kucoin_symbol: str) -> str:
"""
Convert KuCoin symbol back to standard format.
Args:
kucoin_symbol: KuCoin symbol format (e.g., 'BTC-USDT')
Returns:
str: Standard symbol format (e.g., 'BTCUSDT')
"""
if '-' in kucoin_symbol:
return kucoin_symbol.replace('-', '')
return kucoin_symbol
async def _get_websocket_token(self) -> Optional[Dict[str, Any]]:
"""
Get WebSocket connection token from KuCoin REST API.
Returns:
Dict: Token information including WebSocket URL
"""
try:
import aiohttp
api_url = self.SANDBOX_API_URL if self.use_sandbox else self.API_URL
endpoint = "/api/v1/bullet-public"
# Use private endpoint if authenticated
if self.api_key and self.api_secret and self.passphrase:
endpoint = "/api/v1/bullet-private"
headers = self._get_auth_headers("POST", endpoint, "")
else:
headers = {}
async with aiohttp.ClientSession() as session:
async with session.post(f"{api_url}{endpoint}", headers=headers) as response:
if response.status == 200:
data = await response.json()
if data.get('code') != '200000':
logger.error(f"KuCoin token error: {data.get('msg')}")
return None
return data.get('data')
else:
logger.error(f"Failed to get KuCoin token: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting KuCoin WebSocket token: {e}")
return None
async def connect(self) -> bool:
"""Override connect to get token first."""
try:
# Get WebSocket token and URL
token_data = await self._get_websocket_token()
if not token_data:
logger.error("Failed to get KuCoin WebSocket token")
return False
self.token = token_data.get('token')
servers = token_data.get('instanceServers', [])
if not servers:
logger.error("No KuCoin WebSocket servers available")
return False
# Use first available server
server = servers[0]
self.websocket_url = f"{server['endpoint']}?token={self.token}&connectId={int(time.time() * 1000)}"
self.ping_interval = server.get('pingInterval', 18000)
self.ping_timeout = server.get('pingTimeout', 10000)
logger.info(f"KuCoin WebSocket URL: {server['endpoint']}")
# Now connect using the base connector method
return await super().connect()
except Exception as e:
logger.error(f"Error connecting to KuCoin: {e}")
return False
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
kucoin_symbol = self.normalize_symbol(symbol)
topic = f"/market/level2:{kucoin_symbol}"
# Create subscription message
subscription_msg = {
"id": str(self.subscription_id),
"type": "subscribe",
"topic": topic,
"privateChannel": False,
"response": True
}
self.subscription_id += 1
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_topics.add(topic)
logger.info(f"Subscribed to order book for {symbol} ({kucoin_symbol}) on KuCoin")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on KuCoin")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
kucoin_symbol = self.normalize_symbol(symbol)
topic = f"/market/match:{kucoin_symbol}"
# Create subscription message
subscription_msg = {
"id": str(self.subscription_id),
"type": "subscribe",
"topic": topic,
"privateChannel": False,
"response": True
}
self.subscription_id += 1
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_topics.add(topic)
logger.info(f"Subscribed to trades for {symbol} ({kucoin_symbol}) on KuCoin")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on KuCoin")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
kucoin_symbol = self.normalize_symbol(symbol)
topic = f"/market/level2:{kucoin_symbol}"
# Create unsubscription message
unsubscription_msg = {
"id": str(self.subscription_id),
"type": "unsubscribe",
"topic": topic,
"privateChannel": False,
"response": True
}
self.subscription_id += 1
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_topics.discard(topic)
logger.info(f"Unsubscribed from order book for {symbol} ({kucoin_symbol}) on KuCoin")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on KuCoin")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
kucoin_symbol = self.normalize_symbol(symbol)
topic = f"/market/match:{kucoin_symbol}"
# Create unsubscription message
unsubscription_msg = {
"id": str(self.subscription_id),
"type": "unsubscribe",
"topic": topic,
"privateChannel": False,
"response": True
}
self.subscription_id += 1
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_topics.discard(topic)
logger.info(f"Unsubscribed from trades for {symbol} ({kucoin_symbol}) on KuCoin")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on KuCoin")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from KuCoin.
Returns:
List[str]: List of available symbols in standard format
"""
try:
import aiohttp
api_url = self.SANDBOX_API_URL if self.use_sandbox else self.API_URL
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/api/v1/symbols") as response:
if response.status == 200:
data = await response.json()
if data.get('code') != '200000':
logger.error(f"KuCoin API error: {data.get('msg')}")
return []
symbols = []
symbol_data = data.get('data', [])
for symbol_info in symbol_data:
if symbol_info.get('enableTrading'):
symbol = symbol_info.get('symbol', '')
# Convert to standard format
standard_symbol = self._denormalize_symbol(symbol)
symbols.append(standard_symbol)
logger.info(f"Retrieved {len(symbols)} symbols from KuCoin")
return symbols
else:
logger.error(f"Failed to get symbols from KuCoin: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from KuCoin: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from KuCoin REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
kucoin_symbol = self.normalize_symbol(symbol)
api_url = self.SANDBOX_API_URL if self.use_sandbox else self.API_URL
url = f"{api_url}/api/v1/market/orderbook/level2_20"
params = {'symbol': kucoin_symbol}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
if data.get('code') != '200000':
logger.error(f"KuCoin API error: {data.get('msg')}")
return None
result = data.get('data', {})
return self._parse_orderbook_snapshot(result, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse KuCoin order book data into OrderBookSnapshot.
Args:
data: Raw KuCoin order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(data.get('time', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=int(data.get('sequence', 0))
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_data_message(self, data: Dict) -> None:
"""
Handle data message from KuCoin.
Args:
data: Data message
"""
try:
set_correlation_id()
subject = data.get('subject', '')
topic = data.get('topic', '')
message_data = data.get('data', {})
if 'level2' in subject:
await self._handle_orderbook_update(data)
elif 'match' in subject:
await self._handle_trade_update(data)
else:
logger.debug(f"Unhandled KuCoin subject: {subject}")
except Exception as e:
logger.error(f"Error handling data message: {e}")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""
Handle order book update from KuCoin.
Args:
data: Order book update data
"""
try:
topic = data.get('topic', '')
if not topic:
logger.warning("Order book update missing topic")
return
# Extract symbol from topic: /market/level2:BTC-USDT
parts = topic.split(':')
if len(parts) < 2:
logger.warning("Invalid order book topic format")
return
kucoin_symbol = parts[1]
symbol = self._denormalize_symbol(kucoin_symbol)
message_data = data.get('data', {})
changes = message_data.get('changes', {})
# Parse bids and asks changes
bids = []
for bid_data in changes.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in changes.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(message_data.get('time', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=int(message_data.get('sequenceEnd', 0))
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {symbol}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: Dict) -> None:
"""
Handle trade update from KuCoin.
Args:
data: Trade update data
"""
try:
topic = data.get('topic', '')
if not topic:
logger.warning("Trade update missing topic")
return
# Extract symbol from topic: /market/match:BTC-USDT
parts = topic.split(':')
if len(parts) < 2:
logger.warning("Invalid trade topic format")
return
kucoin_symbol = parts[1]
symbol = self._denormalize_symbol(kucoin_symbol)
message_data = data.get('data', {})
price = float(message_data.get('price', 0))
size = float(message_data.get('size', 0))
# Validate data
if not validate_price(price) or not validate_volume(size):
logger.warning(f"Invalid trade data: price={price}, size={size}")
return
# Determine side (KuCoin uses 'side' field)
side = message_data.get('side', 'unknown').lower()
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(message_data.get('time', 0)) / 1000, tz=timezone.utc),
price=price,
size=size,
side=side,
trade_id=str(message_data.get('tradeId', ''))
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {side} {size} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_welcome_message(self, data: Dict) -> None:
"""
Handle welcome message from KuCoin.
Args:
data: Welcome message data
"""
try:
connect_id = data.get('id')
if connect_id:
self.connect_id = connect_id
logger.info(f"KuCoin connection established with ID: {connect_id}")
except Exception as e:
logger.error(f"Error handling welcome message: {e}")
async def _handle_ack_message(self, data: Dict) -> None:
"""
Handle acknowledgment message from KuCoin.
Args:
data: Ack message data
"""
try:
msg_id = data.get('id', '')
logger.debug(f"KuCoin ACK received for message ID: {msg_id}")
except Exception as e:
logger.error(f"Error handling ack message: {e}")
async def _handle_error_message(self, data: Dict) -> None:
"""
Handle error message from KuCoin.
Args:
data: Error message data
"""
try:
code = data.get('code', 'unknown')
message = data.get('data', 'Unknown error')
logger.error(f"KuCoin error {code}: {message}")
except Exception as e:
logger.error(f"Error handling error message: {e}")
async def _handle_pong_message(self, data: Dict) -> None:
"""
Handle pong message from KuCoin.
Args:
data: Pong message data
"""
logger.debug("Received KuCoin pong")
def _get_auth_headers(self, method: str, endpoint: str, body: str) -> Dict[str, str]:
"""
Generate authentication headers for KuCoin API.
Args:
method: HTTP method
endpoint: API endpoint
body: Request body
Returns:
Dict: Authentication headers
"""
if not all([self.api_key, self.api_secret, self.passphrase]):
return {}
try:
timestamp = str(int(time.time() * 1000))
# Create signature string
str_to_sign = timestamp + method + endpoint + body
signature = base64.b64encode(
hmac.new(
self.api_secret.encode('utf-8'),
str_to_sign.encode('utf-8'),
hashlib.sha256
).digest()
).decode('utf-8')
# Create passphrase signature
passphrase_signature = base64.b64encode(
hmac.new(
self.api_secret.encode('utf-8'),
self.passphrase.encode('utf-8'),
hashlib.sha256
).digest()
).decode('utf-8')
return {
'KC-API-SIGN': signature,
'KC-API-TIMESTAMP': timestamp,
'KC-API-KEY': self.api_key,
'KC-API-PASSPHRASE': passphrase_signature,
'KC-API-KEY-VERSION': '2',
'Content-Type': 'application/json'
}
except Exception as e:
logger.error(f"Error generating auth headers: {e}")
return {}
async def _send_ping(self) -> None:
"""Send ping to keep connection alive."""
try:
ping_msg = {
"id": str(self.subscription_id),
"type": "ping"
}
self.subscription_id += 1
await self._send_message(ping_msg)
logger.debug("Sent ping to KuCoin")
except Exception as e:
logger.error(f"Error sending ping: {e}")
def get_kucoin_stats(self) -> Dict[str, Any]:
"""Get KuCoin-specific statistics."""
base_stats = self.get_stats()
kucoin_stats = {
'subscribed_topics': list(self.subscribed_topics),
'use_sandbox': self.use_sandbox,
'authenticated': bool(self.api_key and self.api_secret and self.passphrase),
'connect_id': self.connect_id,
'token_available': bool(self.token),
'next_subscription_id': self.subscription_id
}
base_stats.update(kucoin_stats)
return base_stats

View File

@ -0,0 +1,282 @@
"""
MEXC exchange connector implementation.
Supports WebSocket connections to MEXC with their WebSocket streams.
"""
import json
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class MEXCConnector(BaseExchangeConnector):
"""
MEXC WebSocket connector implementation.
Supports:
- Order book streams
- Trade streams
- Symbol normalization
"""
# MEXC WebSocket URLs
WEBSOCKET_URL = "wss://wbs.mexc.com/ws"
API_URL = "https://api.mexc.com"
def __init__(self, api_key: str = None, api_secret: str = None):
"""Initialize MEXC connector."""
super().__init__("mexc", self.WEBSOCKET_URL)
self.api_key = api_key
self.api_secret = api_secret
# MEXC-specific message handlers
self.message_handlers.update({
'spot@public.deals.v3.api': self._handle_trade_update,
'spot@public.increase.depth.v3.api': self._handle_orderbook_update,
'spot@public.limit.depth.v3.api': self._handle_orderbook_snapshot,
'pong': self._handle_pong
})
# Subscription tracking
self.subscribed_streams = set()
self.request_id = 1
logger.info("MEXC connector initialized")
def _get_message_type(self, data: Dict) -> str:
"""Determine message type from MEXC message data."""
if 'c' in data: # Channel
return data['c']
elif 'msg' in data:
return 'message'
elif 'pong' in data:
return 'pong'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""Normalize symbol to MEXC format."""
# MEXC uses uppercase without separators (same as Binance)
normalized = symbol.upper().replace('-', '').replace('/', '')
if not validate_symbol(normalized):
raise ValidationError(f"Invalid symbol format: {symbol}", "INVALID_SYMBOL")
return normalized
async def subscribe_orderbook(self, symbol: str) -> None:
"""Subscribe to order book updates for a symbol."""
try:
set_correlation_id()
mexc_symbol = self.normalize_symbol(symbol)
subscription_msg = {
"method": "SUBSCRIPTION",
"params": [f"spot@public.limit.depth.v3.api@{mexc_symbol}@20"]
}
success = await self._send_message(subscription_msg)
if success:
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_streams.add(f"spot@public.limit.depth.v3.api@{mexc_symbol}@20")
logger.info(f"Subscribed to order book for {symbol} ({mexc_symbol}) on MEXC")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on MEXC")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""Subscribe to trade updates for a symbol."""
try:
set_correlation_id()
mexc_symbol = self.normalize_symbol(symbol)
subscription_msg = {
"method": "SUBSCRIPTION",
"params": [f"spot@public.deals.v3.api@{mexc_symbol}"]
}
success = await self._send_message(subscription_msg)
if success:
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_streams.add(f"spot@public.deals.v3.api@{mexc_symbol}")
logger.info(f"Subscribed to trades for {symbol} ({mexc_symbol}) on MEXC")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on MEXC")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""Unsubscribe from order book updates."""
try:
mexc_symbol = self.normalize_symbol(symbol)
unsubscription_msg = {
"method": "UNSUBSCRIPTION",
"params": [f"spot@public.limit.depth.v3.api@{mexc_symbol}@20"]
}
success = await self._send_message(unsubscription_msg)
if success:
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_streams.discard(f"spot@public.limit.depth.v3.api@{mexc_symbol}@20")
logger.info(f"Unsubscribed from order book for {symbol} on MEXC")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""Unsubscribe from trade updates."""
try:
mexc_symbol = self.normalize_symbol(symbol)
unsubscription_msg = {
"method": "UNSUBSCRIPTION",
"params": [f"spot@public.deals.v3.api@{mexc_symbol}"]
}
success = await self._send_message(unsubscription_msg)
if success:
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_streams.discard(f"spot@public.deals.v3.api@{mexc_symbol}")
logger.info(f"Unsubscribed from trades for {symbol} on MEXC")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""Get available symbols from MEXC."""
try:
import aiohttp
async with aiohttp.ClientSession() as session:
async with session.get(f"{self.API_URL}/api/v3/exchangeInfo") as response:
if response.status == 200:
data = await response.json()
symbols = [
symbol_info['symbol']
for symbol_info in data.get('symbols', [])
if symbol_info.get('status') == 'TRADING'
]
logger.info(f"Retrieved {len(symbols)} symbols from MEXC")
return symbols
else:
logger.error(f"Failed to get symbols from MEXC: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from MEXC: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""Get order book snapshot from MEXC REST API."""
try:
import aiohttp
mexc_symbol = self.normalize_symbol(symbol)
url = f"{self.API_URL}/api/v3/depth"
params = {'symbol': mexc_symbol, 'limit': min(depth, 5000)}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
return self._parse_orderbook_snapshot(data, symbol)
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""Parse MEXC order book data."""
try:
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
return OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.now(timezone.utc),
bids=bids,
asks=asks,
sequence_id=data.get('lastUpdateId')
)
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""Handle order book update from MEXC."""
# Implementation would parse MEXC-specific order book update format
logger.debug("Received MEXC order book update")
async def _handle_orderbook_snapshot(self, data: Dict) -> None:
"""Handle order book snapshot from MEXC."""
# Implementation would parse MEXC-specific order book snapshot format
logger.debug("Received MEXC order book snapshot")
async def _handle_trade_update(self, data: Dict) -> None:
"""Handle trade update from MEXC."""
# Implementation would parse MEXC-specific trade format
logger.debug("Received MEXC trade update")
async def _handle_pong(self, data: Dict) -> None:
"""Handle pong response from MEXC."""
logger.debug("Received MEXC pong")
def get_mexc_stats(self) -> Dict[str, Any]:
"""Get MEXC-specific statistics."""
base_stats = self.get_stats()
mexc_stats = {
'subscribed_streams': list(self.subscribed_streams),
'authenticated': bool(self.api_key and self.api_secret),
'next_request_id': self.request_id
}
base_stats.update(mexc_stats)
return base_stats

View File

@ -0,0 +1,660 @@
"""
OKX exchange connector implementation.
Supports WebSocket connections to OKX with their V5 API WebSocket streams.
"""
import json
import hmac
import hashlib
import base64
import time
from typing import Dict, List, Optional, Any
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ConnectionError
from ..utils.validation import validate_symbol, validate_price, validate_volume
from .base_connector import BaseExchangeConnector
logger = get_logger(__name__)
class OKXConnector(BaseExchangeConnector):
"""
OKX WebSocket connector implementation.
Supports:
- V5 API WebSocket streams
- Order book streams
- Trade streams
- Symbol normalization
- Authentication for private channels
"""
# OKX WebSocket URLs
WEBSOCKET_URL = "wss://ws.okx.com:8443/ws/v5/public"
WEBSOCKET_PRIVATE_URL = "wss://ws.okx.com:8443/ws/v5/private"
DEMO_WEBSOCKET_URL = "wss://wspap.okx.com:8443/ws/v5/public?brokerId=9999"
API_URL = "https://www.okx.com"
def __init__(self, use_demo: bool = False, api_key: str = None,
api_secret: str = None, passphrase: str = None):
"""
Initialize OKX connector.
Args:
use_demo: Whether to use demo environment
api_key: API key for authentication (optional)
api_secret: API secret for authentication (optional)
passphrase: API passphrase for authentication (optional)
"""
websocket_url = self.DEMO_WEBSOCKET_URL if use_demo else self.WEBSOCKET_URL
super().__init__("okx", websocket_url)
# Authentication credentials (optional)
self.api_key = api_key
self.api_secret = api_secret
self.passphrase = passphrase
self.use_demo = use_demo
# OKX-specific message handlers
self.message_handlers.update({
'books': self._handle_orderbook_update,
'trades': self._handle_trade_update,
'error': self._handle_error_message,
'subscribe': self._handle_subscription_response,
'unsubscribe': self._handle_subscription_response
})
# Subscription tracking
self.subscribed_channels = set()
logger.info(f"OKX connector initialized ({'demo' if use_demo else 'live'})")
def _get_message_type(self, data: Dict) -> str:
"""
Determine message type from OKX message data.
Args:
data: Parsed message data
Returns:
str: Message type identifier
"""
# OKX V5 API message format
if 'event' in data:
return data['event'] # 'subscribe', 'unsubscribe', 'error'
elif 'arg' in data and 'data' in data:
# Data message
channel = data['arg'].get('channel', '')
return channel
elif 'op' in data:
return data['op'] # 'ping', 'pong'
return 'unknown'
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to OKX format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: OKX symbol format (e.g., 'BTC-USDT')
"""
# OKX uses dash-separated format
if symbol.upper() == 'BTCUSDT':
return 'BTC-USDT'
elif symbol.upper() == 'ETHUSDT':
return 'ETH-USDT'
elif symbol.upper().endswith('USDT'):
base = symbol[:-4].upper()
return f"{base}-USDT"
elif symbol.upper().endswith('USD'):
base = symbol[:-3].upper()
return f"{base}-USD"
else:
# Assume it's already in correct format or add dash
if '-' not in symbol:
# Try to split common patterns
if len(symbol) >= 6:
# Assume last 4 chars are quote currency
base = symbol[:-4].upper()
quote = symbol[-4:].upper()
return f"{base}-{quote}"
else:
return symbol.upper()
else:
return symbol.upper()
def _denormalize_symbol(self, okx_symbol: str) -> str:
"""
Convert OKX symbol back to standard format.
Args:
okx_symbol: OKX symbol format (e.g., 'BTC-USDT')
Returns:
str: Standard symbol format (e.g., 'BTCUSDT')
"""
if '-' in okx_symbol:
return okx_symbol.replace('-', '')
return okx_symbol
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
okx_symbol = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"op": "subscribe",
"args": [
{
"channel": "books",
"instId": okx_symbol
}
]
}
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'orderbook' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('orderbook')
self.subscribed_channels.add(f"books:{okx_symbol}")
logger.info(f"Subscribed to order book for {symbol} ({okx_symbol}) on OKX")
else:
logger.error(f"Failed to subscribe to order book for {symbol} on OKX")
except Exception as e:
logger.error(f"Error subscribing to order book for {symbol}: {e}")
raise
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
set_correlation_id()
okx_symbol = self.normalize_symbol(symbol)
# Create subscription message
subscription_msg = {
"op": "subscribe",
"args": [
{
"channel": "trades",
"instId": okx_symbol
}
]
}
# Send subscription
success = await self._send_message(subscription_msg)
if success:
# Track subscription
if symbol not in self.subscriptions:
self.subscriptions[symbol] = []
if 'trades' not in self.subscriptions[symbol]:
self.subscriptions[symbol].append('trades')
self.subscribed_channels.add(f"trades:{okx_symbol}")
logger.info(f"Subscribed to trades for {symbol} ({okx_symbol}) on OKX")
else:
logger.error(f"Failed to subscribe to trades for {symbol} on OKX")
except Exception as e:
logger.error(f"Error subscribing to trades for {symbol}: {e}")
raise
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
okx_symbol = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"op": "unsubscribe",
"args": [
{
"channel": "books",
"instId": okx_symbol
}
]
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'orderbook' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('orderbook')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_channels.discard(f"books:{okx_symbol}")
logger.info(f"Unsubscribed from order book for {symbol} ({okx_symbol}) on OKX")
else:
logger.error(f"Failed to unsubscribe from order book for {symbol} on OKX")
except Exception as e:
logger.error(f"Error unsubscribing from order book for {symbol}: {e}")
raise
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
try:
okx_symbol = self.normalize_symbol(symbol)
# Create unsubscription message
unsubscription_msg = {
"op": "unsubscribe",
"args": [
{
"channel": "trades",
"instId": okx_symbol
}
]
}
# Send unsubscription
success = await self._send_message(unsubscription_msg)
if success:
# Remove from tracking
if symbol in self.subscriptions and 'trades' in self.subscriptions[symbol]:
self.subscriptions[symbol].remove('trades')
if not self.subscriptions[symbol]:
del self.subscriptions[symbol]
self.subscribed_channels.discard(f"trades:{okx_symbol}")
logger.info(f"Unsubscribed from trades for {symbol} ({okx_symbol}) on OKX")
else:
logger.error(f"Failed to unsubscribe from trades for {symbol} on OKX")
except Exception as e:
logger.error(f"Error unsubscribing from trades for {symbol}: {e}")
raise
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols from OKX.
Returns:
List[str]: List of available symbols in standard format
"""
try:
import aiohttp
api_url = "https://www.okx.com"
async with aiohttp.ClientSession() as session:
async with session.get(f"{api_url}/api/v5/public/instruments",
params={"instType": "SPOT"}) as response:
if response.status == 200:
data = await response.json()
if data.get('code') != '0':
logger.error(f"OKX API error: {data.get('msg')}")
return []
symbols = []
instruments = data.get('data', [])
for instrument in instruments:
if instrument.get('state') == 'live':
inst_id = instrument.get('instId', '')
# Convert to standard format
standard_symbol = self._denormalize_symbol(inst_id)
symbols.append(standard_symbol)
logger.info(f"Retrieved {len(symbols)} symbols from OKX")
return symbols
else:
logger.error(f"Failed to get symbols from OKX: HTTP {response.status}")
return []
except Exception as e:
logger.error(f"Error getting symbols from OKX: {e}")
return []
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot from OKX REST API.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
try:
import aiohttp
okx_symbol = self.normalize_symbol(symbol)
api_url = "https://www.okx.com"
# OKX supports depths up to 400
api_depth = min(depth, 400)
url = f"{api_url}/api/v5/market/books"
params = {
'instId': okx_symbol,
'sz': api_depth
}
async with aiohttp.ClientSession() as session:
async with session.get(url, params=params) as response:
if response.status == 200:
data = await response.json()
if data.get('code') != '0':
logger.error(f"OKX API error: {data.get('msg')}")
return None
result_data = data.get('data', [])
if result_data:
return self._parse_orderbook_snapshot(result_data[0], symbol)
else:
return None
else:
logger.error(f"Failed to get order book for {symbol}: HTTP {response.status}")
return None
except Exception as e:
logger.error(f"Error getting order book snapshot for {symbol}: {e}")
return None
def _parse_orderbook_snapshot(self, data: Dict, symbol: str) -> OrderBookSnapshot:
"""
Parse OKX order book data into OrderBookSnapshot.
Args:
data: Raw OKX order book data
symbol: Trading symbol
Returns:
OrderBookSnapshot: Parsed order book
"""
try:
# Parse bids and asks
bids = []
for bid_data in data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(data.get('ts', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=int(data.get('seqId', 0))
)
return orderbook
except Exception as e:
logger.error(f"Error parsing order book snapshot: {e}")
raise ValidationError(f"Invalid order book data: {e}", "PARSE_ERROR")
async def _handle_orderbook_update(self, data: Dict) -> None:
"""
Handle order book update from OKX.
Args:
data: Order book update data
"""
try:
set_correlation_id()
# Extract symbol from arg
arg = data.get('arg', {})
okx_symbol = arg.get('instId', '')
if not okx_symbol:
logger.warning("Order book update missing instId")
return
symbol = self._denormalize_symbol(okx_symbol)
# Process each data item
for book_data in data.get('data', []):
# Parse bids and asks
bids = []
for bid_data in book_data.get('bids', []):
price = float(bid_data[0])
size = float(bid_data[1])
if validate_price(price) and validate_volume(size):
bids.append(PriceLevel(price=price, size=size))
asks = []
for ask_data in book_data.get('asks', []):
price = float(ask_data[0])
size = float(ask_data[1])
if validate_price(price) and validate_volume(size):
asks.append(PriceLevel(price=price, size=size))
# Create order book snapshot
orderbook = OrderBookSnapshot(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(book_data.get('ts', 0)) / 1000, tz=timezone.utc),
bids=bids,
asks=asks,
sequence_id=int(book_data.get('seqId', 0))
)
# Notify callbacks
self._notify_data_callbacks(orderbook)
logger.debug(f"Processed order book update for {symbol}")
except Exception as e:
logger.error(f"Error handling order book update: {e}")
async def _handle_trade_update(self, data: Dict) -> None:
"""
Handle trade update from OKX.
Args:
data: Trade update data
"""
try:
set_correlation_id()
# Extract symbol from arg
arg = data.get('arg', {})
okx_symbol = arg.get('instId', '')
if not okx_symbol:
logger.warning("Trade update missing instId")
return
symbol = self._denormalize_symbol(okx_symbol)
# Process each trade
for trade_data in data.get('data', []):
price = float(trade_data.get('px', 0))
size = float(trade_data.get('sz', 0))
# Validate data
if not validate_price(price) or not validate_volume(size):
logger.warning(f"Invalid trade data: price={price}, size={size}")
continue
# Determine side (OKX uses 'side' field)
side = trade_data.get('side', 'unknown').lower()
# Create trade event
trade = TradeEvent(
symbol=symbol,
exchange=self.exchange_name,
timestamp=datetime.fromtimestamp(int(trade_data.get('ts', 0)) / 1000, tz=timezone.utc),
price=price,
size=size,
side=side,
trade_id=str(trade_data.get('tradeId', ''))
)
# Notify callbacks
self._notify_data_callbacks(trade)
logger.debug(f"Processed trade for {symbol}: {side} {size} @ {price}")
except Exception as e:
logger.error(f"Error handling trade update: {e}")
async def _handle_subscription_response(self, data: Dict) -> None:
"""
Handle subscription response from OKX.
Args:
data: Subscription response data
"""
try:
event = data.get('event', '')
arg = data.get('arg', {})
channel = arg.get('channel', '')
inst_id = arg.get('instId', '')
if event == 'subscribe':
logger.info(f"OKX subscription confirmed: {channel} for {inst_id}")
elif event == 'unsubscribe':
logger.info(f"OKX unsubscription confirmed: {channel} for {inst_id}")
elif event == 'error':
error_msg = data.get('msg', 'Unknown error')
logger.error(f"OKX subscription error: {error_msg}")
except Exception as e:
logger.error(f"Error handling subscription response: {e}")
async def _handle_error_message(self, data: Dict) -> None:
"""
Handle error message from OKX.
Args:
data: Error message data
"""
error_code = data.get('code', 'unknown')
error_msg = data.get('msg', 'Unknown error')
logger.error(f"OKX error {error_code}: {error_msg}")
# Handle specific error codes
if error_code == '60012':
logger.error("Invalid request - check parameters")
elif error_code == '60013':
logger.error("Invalid channel - check channel name")
def _get_auth_headers(self, timestamp: str, method: str = "GET",
request_path: str = "/users/self/verify") -> Dict[str, str]:
"""
Generate authentication headers for OKX API.
Args:
timestamp: Current timestamp
method: HTTP method
request_path: Request path
Returns:
Dict: Authentication headers
"""
if not all([self.api_key, self.api_secret, self.passphrase]):
return {}
try:
# Create signature
message = timestamp + method + request_path
signature = base64.b64encode(
hmac.new(
self.api_secret.encode('utf-8'),
message.encode('utf-8'),
hashlib.sha256
).digest()
).decode('utf-8')
# Create passphrase signature
passphrase_signature = base64.b64encode(
hmac.new(
self.api_secret.encode('utf-8'),
self.passphrase.encode('utf-8'),
hashlib.sha256
).digest()
).decode('utf-8')
return {
'OK-ACCESS-KEY': self.api_key,
'OK-ACCESS-SIGN': signature,
'OK-ACCESS-TIMESTAMP': timestamp,
'OK-ACCESS-PASSPHRASE': passphrase_signature
}
except Exception as e:
logger.error(f"Error generating auth headers: {e}")
return {}
async def _send_ping(self) -> None:
"""Send ping to keep connection alive."""
try:
ping_msg = {"op": "ping"}
await self._send_message(ping_msg)
logger.debug("Sent ping to OKX")
except Exception as e:
logger.error(f"Error sending ping: {e}")
def get_okx_stats(self) -> Dict[str, Any]:
"""Get OKX-specific statistics."""
base_stats = self.get_stats()
okx_stats = {
'subscribed_channels': list(self.subscribed_channels),
'use_demo': self.use_demo,
'authenticated': bool(self.api_key and self.api_secret and self.passphrase)
}
base_stats.update(okx_stats)
return base_stats

273
COBY/docker/README.md Normal file
View File

@ -0,0 +1,273 @@
# Market Data Infrastructure Docker Setup
This directory contains Docker Compose configurations and scripts for deploying TimescaleDB and Redis infrastructure for the multi-exchange data aggregation system.
## 🏗️ Architecture
- **TimescaleDB**: Time-series database optimized for high-frequency market data
- **Redis**: High-performance caching layer for real-time data
- **Network**: Isolated Docker network for secure communication
## 📋 Prerequisites
- Docker Engine 20.10+
- Docker Compose 2.0+
- At least 4GB RAM available for containers
- 50GB+ disk space for data storage
## 🚀 Quick Start
1. **Copy environment file**:
```bash
cp .env.example .env
```
2. **Edit configuration** (update passwords and settings):
```bash
nano .env
```
3. **Deploy infrastructure**:
```bash
chmod +x deploy.sh
./deploy.sh
```
4. **Verify deployment**:
```bash
docker-compose -f timescaledb-compose.yml ps
```
## 📁 File Structure
```
docker/
├── timescaledb-compose.yml # Main Docker Compose configuration
├── init-scripts/ # Database initialization scripts
│ └── 01-init-timescaledb.sql
├── redis.conf # Redis configuration
├── .env # Environment variables
├── deploy.sh # Deployment script
├── backup.sh # Backup script
├── restore.sh # Restore script
└── README.md # This file
```
## ⚙️ Configuration
### Environment Variables
Key variables in `.env`:
```bash
# Database credentials
POSTGRES_PASSWORD=your_secure_password
POSTGRES_USER=market_user
POSTGRES_DB=market_data
# Redis settings
REDIS_PASSWORD=your_redis_password
# Performance tuning
POSTGRES_SHARED_BUFFERS=256MB
POSTGRES_EFFECTIVE_CACHE_SIZE=1GB
REDIS_MAXMEMORY=2gb
```
### TimescaleDB Configuration
The database is pre-configured with:
- Optimized PostgreSQL settings for time-series data
- TimescaleDB extension enabled
- Hypertables for automatic partitioning
- Retention policies (90 days for raw data)
- Continuous aggregates for common queries
- Proper indexes for query performance
### Redis Configuration
Redis is configured for:
- High-frequency data caching
- Memory optimization (2GB limit)
- Persistence with AOF and RDB
- Optimized for order book data structures
## 🔌 Connection Details
After deployment, connect using:
### TimescaleDB
```
Host: 192.168.0.10
Port: 5432
Database: market_data
Username: market_user
Password: (from .env file)
```
### Redis
```
Host: 192.168.0.10
Port: 6379
Password: (from .env file)
```
## 🗄️ Database Schema
The system creates the following tables:
- `order_book_snapshots`: Real-time order book data
- `trade_events`: Individual trade events
- `heatmap_data`: Aggregated price bucket data
- `ohlcv_data`: OHLCV candlestick data
- `exchange_status`: Exchange connection monitoring
- `system_metrics`: System performance metrics
## 💾 Backup & Restore
### Create Backup
```bash
chmod +x backup.sh
./backup.sh
```
Backups are stored in `./backups/` with timestamp.
### Restore from Backup
```bash
chmod +x restore.sh
./restore.sh market_data_backup_YYYYMMDD_HHMMSS.tar.gz
```
### Automated Backups
Set up a cron job for regular backups:
```bash
# Daily backup at 2 AM
0 2 * * * /path/to/docker/backup.sh
```
## 📊 Monitoring
### Health Checks
Check service health:
```bash
# TimescaleDB
docker exec market_data_timescaledb pg_isready -U market_user -d market_data
# Redis
docker exec market_data_redis redis-cli -a your_password ping
```
### View Logs
```bash
# All services
docker-compose -f timescaledb-compose.yml logs -f
# Specific service
docker-compose -f timescaledb-compose.yml logs -f timescaledb
```
### Database Queries
Connect to TimescaleDB:
```bash
docker exec -it market_data_timescaledb psql -U market_user -d market_data
```
Example queries:
```sql
-- Check table sizes
SELECT
schemaname,
tablename,
pg_size_pretty(pg_total_relation_size(schemaname||'.'||tablename)) as size
FROM pg_tables
WHERE schemaname = 'market_data';
-- Recent order book data
SELECT * FROM market_data.order_book_snapshots
ORDER BY timestamp DESC LIMIT 10;
-- Exchange status
SELECT * FROM market_data.exchange_status
ORDER BY timestamp DESC LIMIT 10;
```
## 🔧 Maintenance
### Update Images
```bash
docker-compose -f timescaledb-compose.yml pull
docker-compose -f timescaledb-compose.yml up -d
```
### Clean Up Old Data
```bash
# TimescaleDB has automatic retention policies
# Manual cleanup if needed:
docker exec market_data_timescaledb psql -U market_user -d market_data -c "
SELECT drop_chunks('market_data.order_book_snapshots', INTERVAL '30 days');
"
```
### Scale Resources
Edit `timescaledb-compose.yml` to adjust:
- Memory limits
- CPU limits
- Shared buffers
- Connection limits
## 🚨 Troubleshooting
### Common Issues
1. **Port conflicts**: Change ports in compose file if 5432/6379 are in use
2. **Memory issues**: Reduce shared_buffers and Redis maxmemory
3. **Disk space**: Monitor `/var/lib/docker/volumes/` usage
4. **Connection refused**: Check firewall settings and container status
### Performance Tuning
1. **TimescaleDB**:
- Adjust `shared_buffers` based on available RAM
- Tune `effective_cache_size` to 75% of system RAM
- Monitor query performance with `pg_stat_statements`
2. **Redis**:
- Adjust `maxmemory` based on data volume
- Monitor memory usage with `INFO memory`
- Use appropriate eviction policy
### Recovery Procedures
1. **Container failure**: `docker-compose restart <service>`
2. **Data corruption**: Restore from latest backup
3. **Network issues**: Check Docker network configuration
4. **Performance degradation**: Review logs and system metrics
## 🔐 Security
- Change default passwords in `.env`
- Use strong passwords (20+ characters)
- Restrict network access to trusted IPs
- Regular security updates
- Monitor access logs
- Enable SSL/TLS for production
## 📞 Support
For issues related to:
- TimescaleDB: Check [TimescaleDB docs](https://docs.timescale.com/)
- Redis: Check [Redis docs](https://redis.io/documentation)
- Docker: Check [Docker docs](https://docs.docker.com/)
## 🔄 Updates
This infrastructure supports:
- Rolling updates with zero downtime
- Blue-green deployments
- Automated failover
- Data migration scripts

108
COBY/docker/backup.sh Normal file
View File

@ -0,0 +1,108 @@
#!/bin/bash
# Backup script for market data infrastructure
# Run this script regularly to backup your data
set -e
# Configuration
BACKUP_DIR="./backups"
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
RETENTION_DAYS=30
# Load environment variables
if [ -f .env ]; then
source .env
fi
echo "🗄️ Starting backup process..."
# Create backup directory if it doesn't exist
mkdir -p "$BACKUP_DIR"
# Backup TimescaleDB
echo "📊 Backing up TimescaleDB..."
docker exec market_data_timescaledb pg_dump \
-U market_user \
-d market_data \
--verbose \
--no-password \
--format=custom \
--compress=9 \
> "$BACKUP_DIR/timescaledb_backup_$TIMESTAMP.dump"
if [ $? -eq 0 ]; then
echo "✅ TimescaleDB backup completed: timescaledb_backup_$TIMESTAMP.dump"
else
echo "❌ TimescaleDB backup failed"
exit 1
fi
# Backup Redis
echo "📦 Backing up Redis..."
docker exec market_data_redis redis-cli \
-a "$REDIS_PASSWORD" \
--rdb /data/redis_backup_$TIMESTAMP.rdb \
BGSAVE
# Wait for Redis backup to complete
sleep 5
# Copy Redis backup from container
docker cp market_data_redis:/data/redis_backup_$TIMESTAMP.rdb "$BACKUP_DIR/"
if [ $? -eq 0 ]; then
echo "✅ Redis backup completed: redis_backup_$TIMESTAMP.rdb"
else
echo "❌ Redis backup failed"
exit 1
fi
# Create backup metadata
cat > "$BACKUP_DIR/backup_$TIMESTAMP.info" << EOF
Backup Information
==================
Timestamp: $TIMESTAMP
Date: $(date)
TimescaleDB Backup: timescaledb_backup_$TIMESTAMP.dump
Redis Backup: redis_backup_$TIMESTAMP.rdb
Container Versions:
TimescaleDB: $(docker exec market_data_timescaledb psql -U market_user -d market_data -t -c "SELECT version();")
Redis: $(docker exec market_data_redis redis-cli -a "$REDIS_PASSWORD" INFO server | grep redis_version)
Database Size:
$(docker exec market_data_timescaledb psql -U market_user -d market_data -c "\l+")
EOF
# Compress backups
echo "🗜️ Compressing backups..."
tar -czf "$BACKUP_DIR/market_data_backup_$TIMESTAMP.tar.gz" \
-C "$BACKUP_DIR" \
"timescaledb_backup_$TIMESTAMP.dump" \
"redis_backup_$TIMESTAMP.rdb" \
"backup_$TIMESTAMP.info"
# Remove individual files after compression
rm "$BACKUP_DIR/timescaledb_backup_$TIMESTAMP.dump"
rm "$BACKUP_DIR/redis_backup_$TIMESTAMP.rdb"
rm "$BACKUP_DIR/backup_$TIMESTAMP.info"
echo "✅ Compressed backup created: market_data_backup_$TIMESTAMP.tar.gz"
# Clean up old backups
echo "🧹 Cleaning up old backups (older than $RETENTION_DAYS days)..."
find "$BACKUP_DIR" -name "market_data_backup_*.tar.gz" -mtime +$RETENTION_DAYS -delete
# Display backup information
BACKUP_SIZE=$(du -h "$BACKUP_DIR/market_data_backup_$TIMESTAMP.tar.gz" | cut -f1)
echo ""
echo "📋 Backup Summary:"
echo " File: market_data_backup_$TIMESTAMP.tar.gz"
echo " Size: $BACKUP_SIZE"
echo " Location: $BACKUP_DIR"
echo ""
echo "🔄 To restore from this backup:"
echo " ./restore.sh market_data_backup_$TIMESTAMP.tar.gz"
echo ""
echo "✅ Backup process completed successfully!"

112
COBY/docker/deploy.sh Normal file
View File

@ -0,0 +1,112 @@
#!/bin/bash
# Deployment script for market data infrastructure
# Run this on your Docker host at 192.168.0.10
set -e
echo "🚀 Deploying Market Data Infrastructure..."
# Check if Docker and Docker Compose are available
if ! command -v docker &> /dev/null; then
echo "❌ Docker is not installed or not in PATH"
exit 1
fi
if ! command -v docker-compose &> /dev/null && ! docker compose version &> /dev/null; then
echo "❌ Docker Compose is not installed or not in PATH"
exit 1
fi
# Set Docker Compose command
if docker compose version &> /dev/null; then
DOCKER_COMPOSE="docker compose"
else
DOCKER_COMPOSE="docker-compose"
fi
# Create necessary directories
echo "📁 Creating directories..."
mkdir -p ./data/timescale
mkdir -p ./data/redis
mkdir -p ./logs
mkdir -p ./backups
# Set proper permissions
echo "🔐 Setting permissions..."
chmod 755 ./data/timescale
chmod 755 ./data/redis
chmod 755 ./logs
chmod 755 ./backups
# Copy environment file if it doesn't exist
if [ ! -f .env ]; then
echo "📋 Creating .env file..."
cp .env.example .env
echo "⚠️ Please edit .env file with your specific configuration"
echo "⚠️ Default passwords are set - change them for production!"
fi
# Pull latest images
echo "📥 Pulling Docker images..."
$DOCKER_COMPOSE -f timescaledb-compose.yml pull
# Stop existing containers if running
echo "🛑 Stopping existing containers..."
$DOCKER_COMPOSE -f timescaledb-compose.yml down
# Start the services
echo "🏃 Starting services..."
$DOCKER_COMPOSE -f timescaledb-compose.yml up -d
# Wait for services to be ready
echo "⏳ Waiting for services to be ready..."
sleep 30
# Check service health
echo "🏥 Checking service health..."
# Check TimescaleDB
if docker exec market_data_timescaledb pg_isready -U market_user -d market_data; then
echo "✅ TimescaleDB is ready"
else
echo "❌ TimescaleDB is not ready"
exit 1
fi
# Check Redis
if docker exec market_data_redis redis-cli -a market_data_redis_2024 ping | grep -q PONG; then
echo "✅ Redis is ready"
else
echo "❌ Redis is not ready"
exit 1
fi
# Display connection information
echo ""
echo "🎉 Deployment completed successfully!"
echo ""
echo "📊 Connection Information:"
echo " TimescaleDB:"
echo " Host: 192.168.0.10"
echo " Port: 5432"
echo " Database: market_data"
echo " Username: market_user"
echo " Password: (check .env file)"
echo ""
echo " Redis:"
echo " Host: 192.168.0.10"
echo " Port: 6379"
echo " Password: (check .env file)"
echo ""
echo "📝 Next steps:"
echo " 1. Update your application configuration to use these connection details"
echo " 2. Test the connection from your application"
echo " 3. Set up monitoring and alerting"
echo " 4. Configure backup schedules"
echo ""
echo "🔍 To view logs:"
echo " docker-compose -f timescaledb-compose.yml logs -f"
echo ""
echo "🛑 To stop services:"
echo " docker-compose -f timescaledb-compose.yml down"

View File

@ -0,0 +1,214 @@
-- Initialize TimescaleDB extension and create market data schema
CREATE EXTENSION IF NOT EXISTS timescaledb;
-- Create database schema for market data
CREATE SCHEMA IF NOT EXISTS market_data;
-- Set search path
SET search_path TO market_data, public;
-- Order book snapshots table
CREATE TABLE IF NOT EXISTS order_book_snapshots (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bids JSONB NOT NULL,
asks JSONB NOT NULL,
sequence_id BIGINT,
mid_price DECIMAL(20,8),
spread DECIMAL(20,8),
bid_volume DECIMAL(30,8),
ask_volume DECIMAL(30,8),
created_at TIMESTAMPTZ DEFAULT NOW(),
PRIMARY KEY (timestamp, symbol, exchange)
);
-- Convert to hypertable
SELECT create_hypertable('order_book_snapshots', 'timestamp', if_not_exists => TRUE);
-- Create indexes for better query performance
CREATE INDEX IF NOT EXISTS idx_order_book_symbol_exchange ON order_book_snapshots (symbol, exchange, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_order_book_timestamp ON order_book_snapshots (timestamp DESC);
-- Trade events table
CREATE TABLE IF NOT EXISTS trade_events (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
price DECIMAL(20,8) NOT NULL,
size DECIMAL(30,8) NOT NULL,
side VARCHAR(4) NOT NULL,
trade_id VARCHAR(100) NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
PRIMARY KEY (timestamp, symbol, exchange, trade_id)
);
-- Convert to hypertable
SELECT create_hypertable('trade_events', 'timestamp', if_not_exists => TRUE);
-- Create indexes for trade events
CREATE INDEX IF NOT EXISTS idx_trade_events_symbol_exchange ON trade_events (symbol, exchange, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_trade_events_timestamp ON trade_events (timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_trade_events_price ON trade_events (symbol, price, timestamp DESC);
-- Aggregated heatmap data table
CREATE TABLE IF NOT EXISTS heatmap_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bucket_size DECIMAL(10,2) NOT NULL,
price_bucket DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
side VARCHAR(3) NOT NULL,
exchange_count INTEGER NOT NULL,
exchanges JSONB,
created_at TIMESTAMPTZ DEFAULT NOW(),
PRIMARY KEY (timestamp, symbol, bucket_size, price_bucket, side)
);
-- Convert to hypertable
SELECT create_hypertable('heatmap_data', 'timestamp', if_not_exists => TRUE);
-- Create indexes for heatmap data
CREATE INDEX IF NOT EXISTS idx_heatmap_symbol_bucket ON heatmap_data (symbol, bucket_size, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_heatmap_timestamp ON heatmap_data (timestamp DESC);
-- OHLCV data table
CREATE TABLE IF NOT EXISTS ohlcv_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
timeframe VARCHAR(10) NOT NULL,
open_price DECIMAL(20,8) NOT NULL,
high_price DECIMAL(20,8) NOT NULL,
low_price DECIMAL(20,8) NOT NULL,
close_price DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
trade_count INTEGER,
vwap DECIMAL(20,8),
created_at TIMESTAMPTZ DEFAULT NOW(),
PRIMARY KEY (timestamp, symbol, timeframe)
);
-- Convert to hypertable
SELECT create_hypertable('ohlcv_data', 'timestamp', if_not_exists => TRUE);
-- Create indexes for OHLCV data
CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe ON ohlcv_data (symbol, timeframe, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_ohlcv_timestamp ON ohlcv_data (timestamp DESC);
-- Exchange status tracking table
CREATE TABLE IF NOT EXISTS exchange_status (
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
status VARCHAR(20) NOT NULL, -- 'connected', 'disconnected', 'error'
last_message_time TIMESTAMPTZ,
error_message TEXT,
connection_count INTEGER DEFAULT 0,
created_at TIMESTAMPTZ DEFAULT NOW(),
PRIMARY KEY (timestamp, exchange)
);
-- Convert to hypertable
SELECT create_hypertable('exchange_status', 'timestamp', if_not_exists => TRUE);
-- Create indexes for exchange status
CREATE INDEX IF NOT EXISTS idx_exchange_status_exchange ON exchange_status (exchange, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_exchange_status_timestamp ON exchange_status (timestamp DESC);
-- System metrics table for monitoring
CREATE TABLE IF NOT EXISTS system_metrics (
metric_name VARCHAR(50) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
value DECIMAL(20,8) NOT NULL,
labels JSONB,
created_at TIMESTAMPTZ DEFAULT NOW(),
PRIMARY KEY (timestamp, metric_name)
);
-- Convert to hypertable
SELECT create_hypertable('system_metrics', 'timestamp', if_not_exists => TRUE);
-- Create indexes for system metrics
CREATE INDEX IF NOT EXISTS idx_system_metrics_name ON system_metrics (metric_name, timestamp DESC);
CREATE INDEX IF NOT EXISTS idx_system_metrics_timestamp ON system_metrics (timestamp DESC);
-- Create retention policies (keep data for 90 days by default)
SELECT add_retention_policy('order_book_snapshots', INTERVAL '90 days', if_not_exists => TRUE);
SELECT add_retention_policy('trade_events', INTERVAL '90 days', if_not_exists => TRUE);
SELECT add_retention_policy('heatmap_data', INTERVAL '90 days', if_not_exists => TRUE);
SELECT add_retention_policy('ohlcv_data', INTERVAL '365 days', if_not_exists => TRUE);
SELECT add_retention_policy('exchange_status', INTERVAL '30 days', if_not_exists => TRUE);
SELECT add_retention_policy('system_metrics', INTERVAL '30 days', if_not_exists => TRUE);
-- Create continuous aggregates for common queries
CREATE MATERIALIZED VIEW IF NOT EXISTS hourly_ohlcv
WITH (timescaledb.continuous) AS
SELECT
symbol,
exchange,
time_bucket('1 hour', timestamp) AS hour,
first(price, timestamp) AS open_price,
max(price) AS high_price,
min(price) AS low_price,
last(price, timestamp) AS close_price,
sum(size) AS volume,
count(*) AS trade_count,
avg(price) AS vwap
FROM trade_events
GROUP BY symbol, exchange, hour
WITH NO DATA;
-- Add refresh policy for continuous aggregate
SELECT add_continuous_aggregate_policy('hourly_ohlcv',
start_offset => INTERVAL '3 hours',
end_offset => INTERVAL '1 hour',
schedule_interval => INTERVAL '1 hour',
if_not_exists => TRUE);
-- Create view for latest order book data
CREATE OR REPLACE VIEW latest_order_books AS
SELECT DISTINCT ON (symbol, exchange)
symbol,
exchange,
timestamp,
bids,
asks,
mid_price,
spread,
bid_volume,
ask_volume
FROM order_book_snapshots
ORDER BY symbol, exchange, timestamp DESC;
-- Create view for latest heatmap data
CREATE OR REPLACE VIEW latest_heatmaps AS
SELECT DISTINCT ON (symbol, bucket_size, price_bucket, side)
symbol,
bucket_size,
price_bucket,
side,
timestamp,
volume,
exchange_count,
exchanges
FROM heatmap_data
ORDER BY symbol, bucket_size, price_bucket, side, timestamp DESC;
-- Grant permissions to market_user
GRANT ALL PRIVILEGES ON SCHEMA market_data TO market_user;
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA market_data TO market_user;
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA market_data TO market_user;
GRANT ALL PRIVILEGES ON ALL FUNCTIONS IN SCHEMA market_data TO market_user;
-- Set default privileges for future objects
ALTER DEFAULT PRIVILEGES IN SCHEMA market_data GRANT ALL ON TABLES TO market_user;
ALTER DEFAULT PRIVILEGES IN SCHEMA market_data GRANT ALL ON SEQUENCES TO market_user;
ALTER DEFAULT PRIVILEGES IN SCHEMA market_data GRANT ALL ON FUNCTIONS TO market_user;
-- Create database user for read-only access (for dashboards)
CREATE USER IF NOT EXISTS dashboard_user WITH PASSWORD 'dashboard_read_2024';
GRANT CONNECT ON DATABASE market_data TO dashboard_user;
GRANT USAGE ON SCHEMA market_data TO dashboard_user;
GRANT SELECT ON ALL TABLES IN SCHEMA market_data TO dashboard_user;
ALTER DEFAULT PRIVILEGES IN SCHEMA market_data GRANT SELECT ON TABLES TO dashboard_user;

View File

@ -0,0 +1,37 @@
#!/bin/bash
# Manual database initialization script
# Run this to initialize the TimescaleDB schema
echo "🔧 Initializing TimescaleDB schema..."
# Check if we can connect to the database
echo "📡 Testing connection to TimescaleDB..."
# You can run this command on your Docker host (192.168.0.10)
# Replace with your actual password from the .env file
PGPASSWORD="market_data_secure_pass_2024" psql -h 192.168.0.10 -p 5432 -U market_user -d market_data -c "SELECT version();"
if [ $? -eq 0 ]; then
echo "✅ Connection successful!"
echo "🏗️ Creating database schema..."
# Execute the initialization script
PGPASSWORD="market_data_secure_pass_2024" psql -h 192.168.0.10 -p 5432 -U market_user -d market_data -f ../docker/init-scripts/01-init-timescaledb.sql
if [ $? -eq 0 ]; then
echo "✅ Database schema initialized successfully!"
echo "📊 Verifying tables..."
PGPASSWORD="market_data_secure_pass_2024" psql -h 192.168.0.10 -p 5432 -U market_user -d market_data -c "\dt market_data.*"
else
echo "❌ Schema initialization failed"
exit 1
fi
else
echo "❌ Cannot connect to database"
exit 1
fi

131
COBY/docker/redis.conf Normal file
View File

@ -0,0 +1,131 @@
# Redis configuration for market data caching
# Optimized for high-frequency trading data
# Network settings
bind 0.0.0.0
port 6379
tcp-backlog 511
timeout 0
tcp-keepalive 300
# General settings
daemonize no
supervised no
pidfile /var/run/redis_6379.pid
loglevel notice
logfile ""
databases 16
# Snapshotting (persistence)
save 900 1
save 300 10
save 60 10000
stop-writes-on-bgsave-error yes
rdbcompression yes
rdbchecksum yes
dbfilename dump.rdb
dir /data
# Replication
replica-serve-stale-data yes
replica-read-only yes
repl-diskless-sync no
repl-diskless-sync-delay 5
repl-ping-replica-period 10
repl-timeout 60
repl-disable-tcp-nodelay no
repl-backlog-size 1mb
repl-backlog-ttl 3600
# Security
requirepass market_data_redis_2024
# Memory management
maxmemory 2gb
maxmemory-policy allkeys-lru
maxmemory-samples 5
# Lazy freeing
lazyfree-lazy-eviction no
lazyfree-lazy-expire no
lazyfree-lazy-server-del no
replica-lazy-flush no
# Threaded I/O
io-threads 4
io-threads-do-reads yes
# Append only file (AOF)
appendonly yes
appendfilename "appendonly.aof"
appendfsync everysec
no-appendfsync-on-rewrite no
auto-aof-rewrite-percentage 100
auto-aof-rewrite-min-size 64mb
aof-load-truncated yes
aof-use-rdb-preamble yes
# Lua scripting
lua-time-limit 5000
# Slow log
slowlog-log-slower-than 10000
slowlog-max-len 128
# Latency monitor
latency-monitor-threshold 100
# Event notification
notify-keyspace-events ""
# Hash settings (optimized for order book data)
hash-max-ziplist-entries 512
hash-max-ziplist-value 64
# List settings
list-max-ziplist-size -2
list-compress-depth 0
# Set settings
set-max-intset-entries 512
# Sorted set settings
zset-max-ziplist-entries 128
zset-max-ziplist-value 64
# HyperLogLog settings
hll-sparse-max-bytes 3000
# Streams settings
stream-node-max-bytes 4096
stream-node-max-entries 100
# Active rehashing
activerehashing yes
# Client settings
client-output-buffer-limit normal 0 0 0
client-output-buffer-limit replica 256mb 64mb 60
client-output-buffer-limit pubsub 32mb 8mb 60
client-query-buffer-limit 1gb
# Protocol settings
proto-max-bulk-len 512mb
# Frequency settings
hz 10
# Dynamic HZ
dynamic-hz yes
# AOF rewrite settings
aof-rewrite-incremental-fsync yes
# RDB settings
rdb-save-incremental-fsync yes
# Jemalloc settings
jemalloc-bg-thread yes
# TLS settings (disabled for internal network)
tls-port 0

188
COBY/docker/restore.sh Normal file
View File

@ -0,0 +1,188 @@
#!/bin/bash
# Restore script for market data infrastructure
# Usage: ./restore.sh <backup_file.tar.gz>
set -e
# Check if backup file is provided
if [ $# -eq 0 ]; then
echo "❌ Usage: $0 <backup_file.tar.gz>"
echo "Available backups:"
ls -la ./backups/market_data_backup_*.tar.gz 2>/dev/null || echo "No backups found"
exit 1
fi
BACKUP_FILE="$1"
RESTORE_DIR="./restore_temp"
TIMESTAMP=$(date +"%Y%m%d_%H%M%S")
# Load environment variables
if [ -f .env ]; then
source .env
fi
echo "🔄 Starting restore process..."
echo "📁 Backup file: $BACKUP_FILE"
# Check if backup file exists
if [ ! -f "$BACKUP_FILE" ]; then
echo "❌ Backup file not found: $BACKUP_FILE"
exit 1
fi
# Create temporary restore directory
mkdir -p "$RESTORE_DIR"
# Extract backup
echo "📦 Extracting backup..."
tar -xzf "$BACKUP_FILE" -C "$RESTORE_DIR"
# Find extracted files
TIMESCALE_BACKUP=$(find "$RESTORE_DIR" -name "timescaledb_backup_*.dump" | head -1)
REDIS_BACKUP=$(find "$RESTORE_DIR" -name "redis_backup_*.rdb" | head -1)
BACKUP_INFO=$(find "$RESTORE_DIR" -name "backup_*.info" | head -1)
if [ -z "$TIMESCALE_BACKUP" ] || [ -z "$REDIS_BACKUP" ]; then
echo "❌ Invalid backup file structure"
rm -rf "$RESTORE_DIR"
exit 1
fi
# Display backup information
if [ -f "$BACKUP_INFO" ]; then
echo "📋 Backup Information:"
cat "$BACKUP_INFO"
echo ""
fi
# Confirm restore
read -p "⚠️ This will replace all existing data. Continue? (y/N): " -n 1 -r
echo
if [[ ! $REPLY =~ ^[Yy]$ ]]; then
echo "❌ Restore cancelled"
rm -rf "$RESTORE_DIR"
exit 1
fi
# Stop services
echo "🛑 Stopping services..."
docker-compose -f timescaledb-compose.yml down
# Backup current data (just in case)
echo "💾 Creating safety backup of current data..."
mkdir -p "./backups/pre_restore_$TIMESTAMP"
docker run --rm -v market_data_timescale_data:/data -v "$(pwd)/backups/pre_restore_$TIMESTAMP":/backup alpine tar czf /backup/current_timescale.tar.gz -C /data .
docker run --rm -v market_data_redis_data:/data -v "$(pwd)/backups/pre_restore_$TIMESTAMP":/backup alpine tar czf /backup/current_redis.tar.gz -C /data .
# Start only TimescaleDB for restore
echo "🏃 Starting TimescaleDB for restore..."
docker-compose -f timescaledb-compose.yml up -d timescaledb
# Wait for TimescaleDB to be ready
echo "⏳ Waiting for TimescaleDB to be ready..."
sleep 30
# Check if TimescaleDB is ready
if ! docker exec market_data_timescaledb pg_isready -U market_user -d market_data; then
echo "❌ TimescaleDB is not ready"
exit 1
fi
# Drop existing database and recreate
echo "🗑️ Dropping existing database..."
docker exec market_data_timescaledb psql -U postgres -c "DROP DATABASE IF EXISTS market_data;"
docker exec market_data_timescaledb psql -U postgres -c "CREATE DATABASE market_data OWNER market_user;"
# Restore TimescaleDB
echo "📊 Restoring TimescaleDB..."
docker cp "$TIMESCALE_BACKUP" market_data_timescaledb:/tmp/restore.dump
docker exec market_data_timescaledb pg_restore \
-U market_user \
-d market_data \
--verbose \
--no-password \
/tmp/restore.dump
if [ $? -eq 0 ]; then
echo "✅ TimescaleDB restore completed"
else
echo "❌ TimescaleDB restore failed"
exit 1
fi
# Stop TimescaleDB
docker-compose -f timescaledb-compose.yml stop timescaledb
# Restore Redis data
echo "📦 Restoring Redis data..."
# Remove existing Redis data
docker volume rm market_data_redis_data 2>/dev/null || true
docker volume create market_data_redis_data
# Copy Redis backup to volume
docker run --rm -v market_data_redis_data:/data -v "$(pwd)/$RESTORE_DIR":/backup alpine cp "/backup/$(basename "$REDIS_BACKUP")" /data/dump.rdb
# Start all services
echo "🏃 Starting all services..."
docker-compose -f timescaledb-compose.yml up -d
# Wait for services to be ready
echo "⏳ Waiting for services to be ready..."
sleep 30
# Verify restore
echo "🔍 Verifying restore..."
# Check TimescaleDB
if docker exec market_data_timescaledb pg_isready -U market_user -d market_data; then
echo "✅ TimescaleDB is ready"
# Show table counts
echo "📊 Database table counts:"
docker exec market_data_timescaledb psql -U market_user -d market_data -c "
SELECT
schemaname,
tablename,
n_tup_ins as row_count
FROM pg_stat_user_tables
WHERE schemaname = 'market_data'
ORDER BY tablename;
"
else
echo "❌ TimescaleDB verification failed"
exit 1
fi
# Check Redis
if docker exec market_data_redis redis-cli -a "$REDIS_PASSWORD" ping | grep -q PONG; then
echo "✅ Redis is ready"
# Show Redis info
echo "📦 Redis database info:"
docker exec market_data_redis redis-cli -a "$REDIS_PASSWORD" INFO keyspace
else
echo "❌ Redis verification failed"
exit 1
fi
# Clean up
echo "🧹 Cleaning up temporary files..."
rm -rf "$RESTORE_DIR"
echo ""
echo "🎉 Restore completed successfully!"
echo ""
echo "📋 Restore Summary:"
echo " Source: $BACKUP_FILE"
echo " Timestamp: $TIMESTAMP"
echo " Safety backup: ./backups/pre_restore_$TIMESTAMP/"
echo ""
echo "⚠️ If you encounter any issues, you can restore the safety backup:"
echo " docker-compose -f timescaledb-compose.yml down"
echo " docker volume rm market_data_timescale_data market_data_redis_data"
echo " docker volume create market_data_timescale_data"
echo " docker volume create market_data_redis_data"
echo " docker run --rm -v market_data_timescale_data:/data -v $(pwd)/backups/pre_restore_$TIMESTAMP:/backup alpine tar xzf /backup/current_timescale.tar.gz -C /data"
echo " docker run --rm -v market_data_redis_data:/data -v $(pwd)/backups/pre_restore_$TIMESTAMP:/backup alpine tar xzf /backup/current_redis.tar.gz -C /data"
echo " docker-compose -f timescaledb-compose.yml up -d"

View File

@ -0,0 +1,78 @@
version: '3.8'
services:
timescaledb:
image: timescale/timescaledb:latest-pg15
container_name: market_data_timescaledb
restart: unless-stopped
environment:
POSTGRES_DB: market_data
POSTGRES_USER: market_user
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-market_data_secure_pass_2024}
POSTGRES_INITDB_ARGS: "--encoding=UTF-8 --lc-collate=C --lc-ctype=C"
# TimescaleDB specific settings
TIMESCALEDB_TELEMETRY: 'off'
ports:
- "5432:5432"
volumes:
- timescale_data:/var/lib/postgresql/data
- ./init-scripts:/docker-entrypoint-initdb.d
command: >
postgres
-c shared_preload_libraries=timescaledb
-c max_connections=200
-c shared_buffers=256MB
-c effective_cache_size=1GB
-c maintenance_work_mem=64MB
-c checkpoint_completion_target=0.9
-c wal_buffers=16MB
-c default_statistics_target=100
-c random_page_cost=1.1
-c effective_io_concurrency=200
-c work_mem=4MB
-c min_wal_size=1GB
-c max_wal_size=4GB
-c max_worker_processes=8
-c max_parallel_workers_per_gather=4
-c max_parallel_workers=8
-c max_parallel_maintenance_workers=4
healthcheck:
test: ["CMD-SHELL", "pg_isready -U market_user -d market_data"]
interval: 30s
timeout: 10s
retries: 3
start_period: 60s
networks:
- market_data_network
redis:
image: redis:7-alpine
container_name: market_data_redis
restart: unless-stopped
ports:
- "6379:6379"
volumes:
- redis_data:/data
- ./redis.conf:/usr/local/etc/redis/redis.conf
command: redis-server /usr/local/etc/redis/redis.conf
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 30s
timeout: 10s
retries: 3
start_period: 30s
networks:
- market_data_network
volumes:
timescale_data:
driver: local
redis_data:
driver: local
networks:
market_data_network:
driver: bridge
ipam:
config:
- subnet: 172.20.0.0/16

View File

@ -0,0 +1,168 @@
#!/usr/bin/env python3
"""
Example usage of Binance connector.
"""
import asyncio
import sys
from pathlib import Path
# Add COBY to path
sys.path.insert(0, str(Path(__file__).parent.parent))
from connectors.binance_connector import BinanceConnector
from utils.logging import setup_logging, get_logger
from models.core import OrderBookSnapshot, TradeEvent
# Setup logging
setup_logging(level='INFO', console_output=True)
logger = get_logger(__name__)
class BinanceExample:
"""Example Binance connector usage"""
def __init__(self):
self.connector = BinanceConnector()
self.orderbook_count = 0
self.trade_count = 0
# Add data callbacks
self.connector.add_data_callback(self.on_data_received)
self.connector.add_status_callback(self.on_status_changed)
def on_data_received(self, data):
"""Handle received data"""
if isinstance(data, OrderBookSnapshot):
self.orderbook_count += 1
logger.info(
f"📊 Order Book {self.orderbook_count}: {data.symbol} - "
f"Mid: ${data.mid_price:.2f}, Spread: ${data.spread:.2f}, "
f"Bids: {len(data.bids)}, Asks: {len(data.asks)}"
)
elif isinstance(data, TradeEvent):
self.trade_count += 1
logger.info(
f"💰 Trade {self.trade_count}: {data.symbol} - "
f"{data.side.upper()} {data.size} @ ${data.price:.2f}"
)
def on_status_changed(self, exchange, status):
"""Handle status changes"""
logger.info(f"🔄 {exchange} status changed to: {status.value}")
async def run_example(self):
"""Run the example"""
try:
logger.info("🚀 Starting Binance connector example")
# Connect to Binance
logger.info("🔌 Connecting to Binance...")
connected = await self.connector.connect()
if not connected:
logger.error("❌ Failed to connect to Binance")
return
logger.info("✅ Connected to Binance successfully")
# Get available symbols
logger.info("📋 Getting available symbols...")
symbols = await self.connector.get_symbols()
logger.info(f"📋 Found {len(symbols)} trading symbols")
# Show some popular symbols
popular_symbols = ['BTCUSDT', 'ETHUSDT', 'ADAUSDT', 'BNBUSDT']
available_popular = [s for s in popular_symbols if s in symbols]
logger.info(f"📋 Popular symbols available: {available_popular}")
# Get order book snapshot
if 'BTCUSDT' in symbols:
logger.info("📊 Getting BTC order book snapshot...")
orderbook = await self.connector.get_orderbook_snapshot('BTCUSDT', depth=10)
if orderbook:
logger.info(
f"📊 BTC Order Book: Mid=${orderbook.mid_price:.2f}, "
f"Spread=${orderbook.spread:.2f}"
)
# Subscribe to real-time data
logger.info("🔔 Subscribing to real-time data...")
# Subscribe to BTC order book and trades
if 'BTCUSDT' in symbols:
await self.connector.subscribe_orderbook('BTCUSDT')
await self.connector.subscribe_trades('BTCUSDT')
logger.info("✅ Subscribed to BTCUSDT order book and trades")
# Subscribe to ETH order book
if 'ETHUSDT' in symbols:
await self.connector.subscribe_orderbook('ETHUSDT')
logger.info("✅ Subscribed to ETHUSDT order book")
# Let it run for a while
logger.info("⏳ Collecting data for 30 seconds...")
await asyncio.sleep(30)
# Show statistics
stats = self.connector.get_binance_stats()
logger.info("📈 Final Statistics:")
logger.info(f" 📊 Order books received: {self.orderbook_count}")
logger.info(f" 💰 Trades received: {self.trade_count}")
logger.info(f" 📡 Total messages: {stats['message_count']}")
logger.info(f" ❌ Errors: {stats['error_count']}")
logger.info(f" 🔗 Active streams: {stats['active_streams']}")
logger.info(f" 📋 Subscriptions: {list(stats['subscriptions'].keys())}")
# Unsubscribe and disconnect
logger.info("🔌 Cleaning up...")
if 'BTCUSDT' in self.connector.subscriptions:
await self.connector.unsubscribe_orderbook('BTCUSDT')
await self.connector.unsubscribe_trades('BTCUSDT')
if 'ETHUSDT' in self.connector.subscriptions:
await self.connector.unsubscribe_orderbook('ETHUSDT')
await self.connector.disconnect()
logger.info("✅ Disconnected successfully")
except KeyboardInterrupt:
logger.info("⏹️ Interrupted by user")
except Exception as e:
logger.error(f"❌ Example failed: {e}")
finally:
# Ensure cleanup
try:
await self.connector.disconnect()
except:
pass
async def main():
"""Main function"""
example = BinanceExample()
await example.run_example()
if __name__ == "__main__":
print("Binance Connector Example")
print("=" * 25)
print("This example will:")
print("1. Connect to Binance WebSocket")
print("2. Get available trading symbols")
print("3. Subscribe to real-time order book and trade data")
print("4. Display received data for 30 seconds")
print("5. Show statistics and disconnect")
print()
print("Press Ctrl+C to stop early")
print("=" * 25)
try:
asyncio.run(main())
except KeyboardInterrupt:
print("\n👋 Example stopped by user")
except Exception as e:
print(f"\n❌ Example failed: {e}")
sys.exit(1)

View File

@ -0,0 +1,284 @@
"""
Example demonstrating multi-exchange connectivity with Binance, Coinbase, and Kraken.
Shows how to connect to multiple exchanges simultaneously and handle their data.
"""
import asyncio
import logging
from datetime import datetime
from ..connectors.binance_connector import BinanceConnector
from ..connectors.coinbase_connector import CoinbaseConnector
from ..connectors.kraken_connector import KrakenConnector
from ..models.core import OrderBookSnapshot, TradeEvent
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class MultiExchangeManager:
"""Manages connections to multiple exchanges."""
def __init__(self):
"""Initialize multi-exchange manager."""
# Initialize connectors
self.connectors = {
'binance': BinanceConnector(),
'coinbase': CoinbaseConnector(use_sandbox=True), # Use sandbox for testing
'kraken': KrakenConnector()
}
# Data tracking
self.data_received = {
'binance': {'orderbooks': 0, 'trades': 0},
'coinbase': {'orderbooks': 0, 'trades': 0},
'kraken': {'orderbooks': 0, 'trades': 0}
}
# Set up data callbacks
for name, connector in self.connectors.items():
connector.add_data_callback(lambda data, exchange=name: self._handle_data(exchange, data))
def _handle_data(self, exchange: str, data):
"""Handle data from any exchange."""
try:
if isinstance(data, OrderBookSnapshot):
self.data_received[exchange]['orderbooks'] += 1
logger.info(f"📊 {exchange.upper()}: Order book for {data.symbol} - "
f"Bids: {len(data.bids)}, Asks: {len(data.asks)}")
# Show best bid/ask if available
if data.bids and data.asks:
best_bid = max(data.bids, key=lambda x: x.price)
best_ask = min(data.asks, key=lambda x: x.price)
spread = best_ask.price - best_bid.price
logger.info(f" Best: {best_bid.price} / {best_ask.price} (spread: {spread:.2f})")
elif isinstance(data, TradeEvent):
self.data_received[exchange]['trades'] += 1
logger.info(f"💰 {exchange.upper()}: Trade {data.symbol} - "
f"{data.side} {data.size} @ {data.price}")
except Exception as e:
logger.error(f"Error handling data from {exchange}: {e}")
async def connect_all(self):
"""Connect to all exchanges."""
logger.info("Connecting to all exchanges...")
connection_tasks = []
for name, connector in self.connectors.items():
task = asyncio.create_task(self._connect_exchange(name, connector))
connection_tasks.append(task)
# Wait for all connections
results = await asyncio.gather(*connection_tasks, return_exceptions=True)
# Report results
for i, (name, result) in enumerate(zip(self.connectors.keys(), results)):
if isinstance(result, Exception):
logger.error(f"❌ Failed to connect to {name}: {result}")
elif result:
logger.info(f"✅ Connected to {name}")
else:
logger.warning(f"⚠️ Connection to {name} returned False")
async def _connect_exchange(self, name: str, connector) -> bool:
"""Connect to a single exchange."""
try:
return await connector.connect()
except Exception as e:
logger.error(f"Error connecting to {name}: {e}")
return False
async def subscribe_to_symbols(self, symbols: list):
"""Subscribe to order book and trade data for given symbols."""
logger.info(f"Subscribing to symbols: {symbols}")
for symbol in symbols:
for name, connector in self.connectors.items():
try:
if connector.is_connected:
# Subscribe to order book
await connector.subscribe_orderbook(symbol)
logger.info(f"📈 Subscribed to {symbol} order book on {name}")
# Subscribe to trades
await connector.subscribe_trades(symbol)
logger.info(f"💱 Subscribed to {symbol} trades on {name}")
# Small delay between subscriptions
await asyncio.sleep(0.5)
else:
logger.warning(f"⚠️ {name} not connected, skipping {symbol}")
except Exception as e:
logger.error(f"Error subscribing to {symbol} on {name}: {e}")
async def run_for_duration(self, duration_seconds: int):
"""Run data collection for specified duration."""
logger.info(f"Running data collection for {duration_seconds} seconds...")
start_time = datetime.now()
# Print statistics periodically
while (datetime.now() - start_time).seconds < duration_seconds:
await asyncio.sleep(10) # Print stats every 10 seconds
self._print_statistics()
logger.info("Data collection period completed")
def _print_statistics(self):
"""Print current data statistics."""
logger.info("📊 Current Statistics:")
total_orderbooks = 0
total_trades = 0
for exchange, stats in self.data_received.items():
orderbooks = stats['orderbooks']
trades = stats['trades']
total_orderbooks += orderbooks
total_trades += trades
logger.info(f" {exchange.upper()}: {orderbooks} order books, {trades} trades")
logger.info(f" TOTAL: {total_orderbooks} order books, {total_trades} trades")
async def disconnect_all(self):
"""Disconnect from all exchanges."""
logger.info("Disconnecting from all exchanges...")
for name, connector in self.connectors.items():
try:
await connector.disconnect()
logger.info(f"✅ Disconnected from {name}")
except Exception as e:
logger.error(f"Error disconnecting from {name}: {e}")
def get_connector_stats(self):
"""Get statistics from all connectors."""
stats = {}
for name, connector in self.connectors.items():
try:
if hasattr(connector, 'get_stats'):
stats[name] = connector.get_stats()
else:
stats[name] = {
'connected': connector.is_connected,
'exchange': connector.exchange_name
}
except Exception as e:
stats[name] = {'error': str(e)}
return stats
async def demonstrate_multi_exchange():
"""Demonstrate multi-exchange connectivity."""
logger.info("=== Multi-Exchange Connectivity Demo ===")
# Create manager
manager = MultiExchangeManager()
try:
# Connect to all exchanges
await manager.connect_all()
# Wait a moment for connections to stabilize
await asyncio.sleep(2)
# Subscribe to some popular symbols
symbols = ['BTCUSDT', 'ETHUSDT']
await manager.subscribe_to_symbols(symbols)
# Run data collection for 30 seconds
await manager.run_for_duration(30)
# Print final statistics
logger.info("=== Final Statistics ===")
manager._print_statistics()
# Print connector statistics
logger.info("=== Connector Statistics ===")
connector_stats = manager.get_connector_stats()
for exchange, stats in connector_stats.items():
logger.info(f"{exchange.upper()}: {stats}")
except Exception as e:
logger.error(f"Error in multi-exchange demo: {e}")
finally:
# Clean up
await manager.disconnect_all()
async def test_individual_connectors():
"""Test each connector individually."""
logger.info("=== Individual Connector Tests ===")
# Test Binance
logger.info("Testing Binance connector...")
binance = BinanceConnector()
try:
symbols = await binance.get_symbols()
logger.info(f"Binance symbols available: {len(symbols)}")
# Test order book snapshot
orderbook = await binance.get_orderbook_snapshot('BTCUSDT')
if orderbook:
logger.info(f"Binance order book: {len(orderbook.bids)} bids, {len(orderbook.asks)} asks")
except Exception as e:
logger.error(f"Binance test error: {e}")
# Test Coinbase
logger.info("Testing Coinbase connector...")
coinbase = CoinbaseConnector(use_sandbox=True)
try:
symbols = await coinbase.get_symbols()
logger.info(f"Coinbase symbols available: {len(symbols)}")
# Test order book snapshot
orderbook = await coinbase.get_orderbook_snapshot('BTCUSDT')
if orderbook:
logger.info(f"Coinbase order book: {len(orderbook.bids)} bids, {len(orderbook.asks)} asks")
except Exception as e:
logger.error(f"Coinbase test error: {e}")
# Test Kraken
logger.info("Testing Kraken connector...")
kraken = KrakenConnector()
try:
symbols = await kraken.get_symbols()
logger.info(f"Kraken symbols available: {len(symbols)}")
# Test order book snapshot
orderbook = await kraken.get_orderbook_snapshot('BTCUSDT')
if orderbook:
logger.info(f"Kraken order book: {len(orderbook.bids)} bids, {len(orderbook.asks)} asks")
except Exception as e:
logger.error(f"Kraken test error: {e}")
async def main():
"""Run all demonstrations."""
logger.info("Starting Multi-Exchange Examples...")
try:
# Test individual connectors first
await test_individual_connectors()
await asyncio.sleep(2)
# Then test multi-exchange connectivity
await demonstrate_multi_exchange()
logger.info("All multi-exchange examples completed successfully!")
except Exception as e:
logger.error(f"Error running examples: {e}")
if __name__ == "__main__":
# Run the examples
asyncio.run(main())

View File

@ -0,0 +1,276 @@
"""
Example showing how to integrate COBY system with existing orchestrator.
Demonstrates drop-in replacement and mode switching capabilities.
"""
import asyncio
import logging
from datetime import datetime, timedelta
# Import the COBY data provider replacement
from ..integration.data_provider_replacement import COBYDataProvider
from ..integration.orchestrator_adapter import MarketTick
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def demonstrate_basic_usage():
"""Demonstrate basic COBY data provider usage."""
logger.info("=== Basic COBY Data Provider Usage ===")
# Initialize COBY data provider (drop-in replacement)
data_provider = COBYDataProvider()
try:
# Test basic data access methods
logger.info("Testing basic data access...")
# Get current price
current_price = data_provider.get_current_price('BTCUSDT')
logger.info(f"Current BTC price: ${current_price}")
# Get historical data
historical_data = data_provider.get_historical_data('BTCUSDT', '1m', limit=10)
if historical_data is not None:
logger.info(f"Historical data shape: {historical_data.shape}")
logger.info(f"Latest close price: ${historical_data['close'].iloc[-1]}")
# Get COB data
cob_data = data_provider.get_latest_cob_data('BTCUSDT')
if cob_data:
logger.info(f"Latest COB data: {cob_data}")
# Get data quality indicators
quality = data_provider.adapter.get_data_quality_indicators('BTCUSDT')
logger.info(f"Data quality score: {quality.get('quality_score', 0)}")
except Exception as e:
logger.error(f"Error in basic usage: {e}")
finally:
await data_provider.close()
async def demonstrate_subscription_system():
"""Demonstrate the subscription system."""
logger.info("=== COBY Subscription System ===")
data_provider = COBYDataProvider()
try:
# Set up tick subscription
tick_count = 0
def tick_callback(tick: MarketTick):
nonlocal tick_count
tick_count += 1
logger.info(f"Received tick #{tick_count}: {tick.symbol} @ ${tick.price}")
# Subscribe to ticks
subscriber_id = data_provider.subscribe_to_ticks(
tick_callback,
symbols=['BTCUSDT', 'ETHUSDT'],
subscriber_name='example_subscriber'
)
logger.info(f"Subscribed to ticks with ID: {subscriber_id}")
# Set up COB data subscription
cob_count = 0
def cob_callback(symbol: str, data: dict):
nonlocal cob_count
cob_count += 1
logger.info(f"Received COB data #{cob_count} for {symbol}")
cob_subscriber_id = data_provider.subscribe_to_cob_raw_ticks(cob_callback)
logger.info(f"Subscribed to COB data with ID: {cob_subscriber_id}")
# Wait for some data
logger.info("Waiting for data updates...")
await asyncio.sleep(10)
# Unsubscribe
data_provider.unsubscribe(subscriber_id)
data_provider.unsubscribe(cob_subscriber_id)
logger.info("Unsubscribed from all feeds")
except Exception as e:
logger.error(f"Error in subscription demo: {e}")
finally:
await data_provider.close()
async def demonstrate_mode_switching():
"""Demonstrate switching between live and replay modes."""
logger.info("=== COBY Mode Switching ===")
data_provider = COBYDataProvider()
try:
# Start in live mode
current_mode = data_provider.get_current_mode()
logger.info(f"Current mode: {current_mode}")
# Get some live data
live_price = data_provider.get_current_price('BTCUSDT')
logger.info(f"Live price: ${live_price}")
# Switch to replay mode
logger.info("Switching to replay mode...")
start_time = datetime.utcnow() - timedelta(hours=1)
end_time = datetime.utcnow() - timedelta(minutes=30)
success = await data_provider.switch_to_replay_mode(
start_time=start_time,
end_time=end_time,
speed=10.0, # 10x speed
symbols=['BTCUSDT']
)
if success:
logger.info("Successfully switched to replay mode")
# Get replay status
replay_status = data_provider.get_replay_status()
if replay_status:
logger.info(f"Replay progress: {replay_status['progress']:.2%}")
logger.info(f"Replay speed: {replay_status['speed']}x")
# Wait for some replay data
await asyncio.sleep(5)
# Get data during replay
replay_price = data_provider.get_current_price('BTCUSDT')
logger.info(f"Replay price: ${replay_price}")
# Switch back to live mode
logger.info("Switching back to live mode...")
success = await data_provider.switch_to_live_mode()
if success:
logger.info("Successfully switched back to live mode")
current_mode = data_provider.get_current_mode()
logger.info(f"Current mode: {current_mode}")
except Exception as e:
logger.error(f"Error in mode switching demo: {e}")
finally:
await data_provider.close()
async def demonstrate_orchestrator_compatibility():
"""Demonstrate compatibility with orchestrator interface."""
logger.info("=== Orchestrator Compatibility ===")
data_provider = COBYDataProvider()
try:
# Test methods that orchestrator uses
logger.info("Testing orchestrator-compatible methods...")
# Build base data input (used by ML models)
base_data = data_provider.build_base_data_input('BTCUSDT')
if base_data:
features = base_data.get_feature_vector()
logger.info(f"Feature vector shape: {features.shape}")
# Get feature matrix (used by ML models)
feature_matrix = data_provider.get_feature_matrix(
'BTCUSDT',
timeframes=['1m', '5m'],
window_size=20
)
if feature_matrix is not None:
logger.info(f"Feature matrix shape: {feature_matrix.shape}")
# Get pivot bounds (used for normalization)
pivot_bounds = data_provider.get_pivot_bounds('BTCUSDT')
if pivot_bounds:
logger.info(f"Price range: ${pivot_bounds.price_min:.2f} - ${pivot_bounds.price_max:.2f}")
# Get COB imbalance (used for market microstructure analysis)
imbalance = data_provider.get_current_cob_imbalance('BTCUSDT')
logger.info(f"Order book imbalance: {imbalance['imbalance']:.3f}")
# Get system status
status = data_provider.get_cached_data_summary()
logger.info(f"System status: {status}")
# Test compatibility methods
data_provider.start_centralized_data_collection()
data_provider.invalidate_ohlcv_cache('BTCUSDT')
logger.info("All orchestrator compatibility tests passed!")
except Exception as e:
logger.error(f"Error in compatibility demo: {e}")
finally:
await data_provider.close()
async def demonstrate_performance_monitoring():
"""Demonstrate performance monitoring capabilities."""
logger.info("=== Performance Monitoring ===")
data_provider = COBYDataProvider()
try:
# Get initial statistics
initial_stats = data_provider.get_subscriber_stats()
logger.info(f"Initial stats: {initial_stats}")
# Get data quality information
quality_info = data_provider.get_cob_data_quality()
logger.info(f"Data quality info: {quality_info}")
# Get WebSocket status
ws_status = data_provider.get_cob_websocket_status()
logger.info(f"WebSocket status: {ws_status}")
# Monitor system metadata
system_metadata = data_provider.adapter.get_system_metadata()
logger.info(f"System components health: {system_metadata['components']}")
logger.info(f"Active subscribers: {system_metadata['active_subscribers']}")
except Exception as e:
logger.error(f"Error in performance monitoring: {e}")
finally:
await data_provider.close()
async def main():
"""Run all demonstration examples."""
logger.info("Starting COBY Integration Examples...")
try:
# Run all demonstrations
await demonstrate_basic_usage()
await asyncio.sleep(1)
await demonstrate_subscription_system()
await asyncio.sleep(1)
await demonstrate_mode_switching()
await asyncio.sleep(1)
await demonstrate_orchestrator_compatibility()
await asyncio.sleep(1)
await demonstrate_performance_monitoring()
logger.info("All COBY integration examples completed successfully!")
except Exception as e:
logger.error(f"Error running examples: {e}")
if __name__ == "__main__":
# Run the examples
asyncio.run(main())

View File

@ -0,0 +1,8 @@
"""
Integration layer for the COBY multi-exchange data aggregation system.
Provides compatibility interfaces for seamless integration with existing systems.
"""
from .orchestrator_adapter import COBYOrchestratorAdapter, MarketTick, PivotBounds
__all__ = ['COBYOrchestratorAdapter', 'MarketTick', 'PivotBounds']

View File

@ -0,0 +1,390 @@
"""
Drop-in replacement for the existing DataProvider class using COBY system.
Provides full compatibility with the orchestrator interface.
"""
import asyncio
import logging
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable, Union
from pathlib import Path
from .orchestrator_adapter import COBYOrchestratorAdapter, MarketTick, PivotBounds
from ..config import Config
from ..utils.logging import get_logger
logger = get_logger(__name__)
class COBYDataProvider:
"""
Drop-in replacement for DataProvider using COBY system.
Provides full compatibility with existing orchestrator interface while
leveraging COBY's multi-exchange data aggregation capabilities.
"""
def __init__(self, config_path: Optional[str] = None):
"""
Initialize COBY data provider.
Args:
config_path: Optional path to configuration file
"""
# Initialize COBY configuration
self.config = Config()
# Initialize COBY adapter
self.adapter = COBYOrchestratorAdapter(self.config)
# Initialize adapter components
asyncio.run(self.adapter._initialize_components())
# Compatibility attributes
self.symbols = self.config.exchanges.symbols
self.exchanges = self.config.exchanges.exchanges
logger.info("COBY data provider initialized")
# === CORE DATA METHODS ===
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
refresh: bool = False) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data."""
return self.adapter.get_historical_data(symbol, timeframe, limit, refresh)
def get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price for a symbol."""
return self.adapter.get_current_price(symbol)
def get_live_price_from_api(self, symbol: str) -> Optional[float]:
"""Get live price from API (low-latency method)."""
return self.adapter.get_live_price_from_api(symbol)
def build_base_data_input(self, symbol: str) -> Optional[Any]:
"""Build base data input for ML models."""
return self.adapter.build_base_data_input(symbol)
# === COB DATA METHODS ===
def get_cob_raw_ticks(self, symbol: str, count: int = 1000) -> List[Dict]:
"""Get raw COB ticks for a symbol."""
return self.adapter.get_cob_raw_ticks(symbol, count)
def get_cob_1s_aggregated(self, symbol: str, count: int = 300) -> List[Dict]:
"""Get 1s aggregated COB data with $1 price buckets."""
return self.adapter.get_cob_1s_aggregated(symbol, count)
def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
"""Get latest COB raw tick for a symbol."""
return self.adapter.get_latest_cob_data(symbol)
def get_latest_cob_aggregated(self, symbol: str) -> Optional[Dict]:
"""Get latest 1s aggregated COB data for a symbol."""
return self.adapter.get_latest_cob_aggregated(symbol)
def get_current_cob_imbalance(self, symbol: str) -> Dict[str, float]:
"""Get current COB imbalance metrics for a symbol."""
try:
latest_data = self.get_latest_cob_data(symbol)
if not latest_data:
return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}
bid_volume = latest_data.get('bid_volume', 0.0)
ask_volume = latest_data.get('ask_volume', 0.0)
total_volume = bid_volume + ask_volume
imbalance = 0.0
if total_volume > 0:
imbalance = (bid_volume - ask_volume) / total_volume
return {
'bid_volume': bid_volume,
'ask_volume': ask_volume,
'imbalance': imbalance
}
except Exception as e:
logger.error(f"Error getting COB imbalance for {symbol}: {e}")
return {'bid_volume': 0.0, 'ask_volume': 0.0, 'imbalance': 0.0}
def get_cob_price_buckets(self, symbol: str, timeframe_seconds: int = 60) -> Dict:
"""Get price bucket analysis for a timeframe."""
try:
# Get aggregated data for the timeframe
count = timeframe_seconds # 1 second per data point
aggregated_data = self.get_cob_1s_aggregated(symbol, count)
if not aggregated_data:
return {}
# Combine all buckets
combined_bid_buckets = {}
combined_ask_buckets = {}
for data_point in aggregated_data:
for price, volume in data_point.get('bid_buckets', {}).items():
combined_bid_buckets[price] = combined_bid_buckets.get(price, 0) + volume
for price, volume in data_point.get('ask_buckets', {}).items():
combined_ask_buckets[price] = combined_ask_buckets.get(price, 0) + volume
return {
'symbol': symbol,
'timeframe_seconds': timeframe_seconds,
'bid_buckets': combined_bid_buckets,
'ask_buckets': combined_ask_buckets,
'timestamp': datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"Error getting price buckets for {symbol}: {e}")
return {}
def get_cob_websocket_status(self) -> Dict[str, Any]:
"""Get COB WebSocket status."""
try:
system_metadata = self.adapter.get_system_metadata()
connectors = system_metadata.get('components', {}).get('connectors', {})
return {
'connected': any(connectors.values()),
'exchanges': connectors,
'last_update': datetime.utcnow().isoformat(),
'mode': system_metadata.get('mode', 'unknown')
}
except Exception as e:
logger.error(f"Error getting WebSocket status: {e}")
return {'connected': False, 'error': str(e)}
# === SUBSCRIPTION METHODS ===
def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
symbols: List[str] = None,
subscriber_name: str = None) -> str:
"""Subscribe to tick data updates."""
return self.adapter.subscribe_to_ticks(callback, symbols, subscriber_name)
def subscribe_to_cob_raw_ticks(self, callback: Callable[[str, Dict], None]) -> str:
"""Subscribe to raw COB tick updates."""
return self.adapter.subscribe_to_cob_raw_ticks(callback)
def subscribe_to_cob_aggregated(self, callback: Callable[[str, Dict], None]) -> str:
"""Subscribe to 1s aggregated COB updates."""
return self.adapter.subscribe_to_cob_aggregated(callback)
def subscribe_to_training_data(self, callback: Callable[[str, dict], None]) -> str:
"""Subscribe to training data updates."""
return self.adapter.subscribe_to_training_data(callback)
def subscribe_to_model_predictions(self, callback: Callable[[str, dict], None]) -> str:
"""Subscribe to model prediction updates."""
return self.adapter.subscribe_to_model_predictions(callback)
def unsubscribe(self, subscriber_id: str) -> bool:
"""Unsubscribe from data feeds."""
return self.adapter.unsubscribe(subscriber_id)
# === MODE SWITCHING ===
async def switch_to_live_mode(self) -> bool:
"""Switch to live data mode."""
return await self.adapter.switch_to_live_mode()
async def switch_to_replay_mode(self, start_time: datetime, end_time: datetime,
speed: float = 1.0, symbols: List[str] = None) -> bool:
"""Switch to replay data mode."""
return await self.adapter.switch_to_replay_mode(start_time, end_time, speed, symbols)
def get_current_mode(self) -> str:
"""Get current data mode."""
return self.adapter.get_current_mode()
def get_replay_status(self) -> Optional[Dict[str, Any]]:
"""Get replay session status."""
return self.adapter.get_replay_status()
# === COMPATIBILITY METHODS ===
def start_centralized_data_collection(self) -> None:
"""Start centralized data collection."""
self.adapter.start_centralized_data_collection()
def start_training_data_collection(self) -> None:
"""Start training data collection."""
self.adapter.start_training_data_collection()
def invalidate_ohlcv_cache(self, symbol: str) -> None:
"""Invalidate OHLCV cache for a symbol."""
self.adapter.invalidate_ohlcv_cache(symbol)
def get_latest_candles(self, symbol: str, timeframe: str, limit: int = 100) -> pd.DataFrame:
"""Get the latest candles from cached data."""
return self.get_historical_data(symbol, timeframe, limit) or pd.DataFrame()
def get_price_at_index(self, symbol: str, index: int, timeframe: str = '1m') -> Optional[float]:
"""Get price at specific index for backtesting."""
try:
df = self.get_historical_data(symbol, timeframe, limit=index + 10)
if df is not None and len(df) > index:
return float(df.iloc[-(index + 1)]['close'])
return None
except Exception as e:
logger.error(f"Error getting price at index {index} for {symbol}: {e}")
return None
# === PIVOT AND MARKET STRUCTURE (MOCK IMPLEMENTATIONS) ===
def get_pivot_bounds(self, symbol: str) -> Optional[PivotBounds]:
"""Get pivot bounds for a symbol (mock implementation)."""
try:
# Get recent price data
df = self.get_historical_data(symbol, '1m', limit=1000)
if df is None or df.empty:
return None
# Calculate basic pivot levels
high_prices = df['high'].values
low_prices = df['low'].values
volumes = df['volume'].values
price_max = float(np.max(high_prices))
price_min = float(np.min(low_prices))
volume_max = float(np.max(volumes))
volume_min = float(np.min(volumes))
# Simple support/resistance calculation
price_range = price_max - price_min
support_levels = [price_min + i * price_range / 10 for i in range(1, 5)]
resistance_levels = [price_max - i * price_range / 10 for i in range(1, 5)]
return PivotBounds(
symbol=symbol,
price_max=price_max,
price_min=price_min,
volume_max=volume_max,
volume_min=volume_min,
pivot_support_levels=support_levels,
pivot_resistance_levels=resistance_levels,
pivot_context={'method': 'simple'},
created_timestamp=datetime.utcnow(),
data_period_start=df.index[0].to_pydatetime(),
data_period_end=df.index[-1].to_pydatetime(),
total_candles_analyzed=len(df)
)
except Exception as e:
logger.error(f"Error getting pivot bounds for {symbol}: {e}")
return None
def get_pivot_normalized_features(self, symbol: str, df: pd.DataFrame) -> Optional[pd.DataFrame]:
"""Get dataframe with pivot-normalized features."""
try:
pivot_bounds = self.get_pivot_bounds(symbol)
if not pivot_bounds:
return df
# Add normalized features
df_copy = df.copy()
price_range = pivot_bounds.get_price_range()
if price_range > 0:
df_copy['normalized_close'] = (df_copy['close'] - pivot_bounds.price_min) / price_range
df_copy['normalized_high'] = (df_copy['high'] - pivot_bounds.price_min) / price_range
df_copy['normalized_low'] = (df_copy['low'] - pivot_bounds.price_min) / price_range
return df_copy
except Exception as e:
logger.error(f"Error getting pivot normalized features for {symbol}: {e}")
return df
# === FEATURE EXTRACTION METHODS ===
def get_feature_matrix(self, symbol: str, timeframes: List[str] = None,
window_size: int = 20) -> Optional[np.ndarray]:
"""Get feature matrix for ML models."""
try:
if not timeframes:
timeframes = ['1m', '5m', '15m']
features = []
for timeframe in timeframes:
df = self.get_historical_data(symbol, timeframe, limit=window_size + 10)
if df is not None and len(df) >= window_size:
# Extract basic features
closes = df['close'].values[-window_size:]
volumes = df['volume'].values[-window_size:]
# Normalize features
close_mean = np.mean(closes)
close_std = np.std(closes) + 1e-8
normalized_closes = (closes - close_mean) / close_std
volume_mean = np.mean(volumes)
volume_std = np.std(volumes) + 1e-8
normalized_volumes = (volumes - volume_mean) / volume_std
features.extend(normalized_closes)
features.extend(normalized_volumes)
if features:
return np.array(features, dtype=np.float32)
return None
except Exception as e:
logger.error(f"Error getting feature matrix for {symbol}: {e}")
return None
# === SYSTEM STATUS AND STATISTICS ===
def get_cached_data_summary(self) -> Dict[str, Any]:
"""Get summary of cached data."""
try:
system_metadata = self.adapter.get_system_metadata()
return {
'system': 'COBY',
'mode': system_metadata.get('mode'),
'statistics': system_metadata.get('statistics', {}),
'components_healthy': system_metadata.get('components', {}),
'active_subscribers': system_metadata.get('active_subscribers', 0)
}
except Exception as e:
logger.error(f"Error getting cached data summary: {e}")
return {'error': str(e)}
def get_cob_data_quality(self) -> Dict[str, Any]:
"""Get COB data quality information."""
try:
quality_info = {}
for symbol in self.symbols:
quality_info[symbol] = self.adapter.get_data_quality_indicators(symbol)
return quality_info
except Exception as e:
logger.error(f"Error getting COB data quality: {e}")
return {'error': str(e)}
def get_subscriber_stats(self) -> Dict[str, Any]:
"""Get subscriber statistics."""
return self.adapter.get_stats()
# === CLEANUP ===
async def close(self) -> None:
"""Close all connections and cleanup."""
await self.adapter.close()
def __del__(self):
"""Cleanup on deletion."""
try:
asyncio.run(self.close())
except:
pass

View File

@ -0,0 +1,888 @@
"""
Orchestrator integration adapter for COBY system.
Provides compatibility layer for seamless integration with existing orchestrator.
"""
import asyncio
import logging
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Callable, Union
from dataclasses import dataclass, field
import uuid
from collections import deque
import threading
from ..storage.storage_manager import StorageManager
from ..replay.replay_manager import HistoricalReplayManager
from ..caching.redis_manager import RedisManager
from ..aggregation.aggregation_engine import StandardAggregationEngine
from ..processing.data_processor import StandardDataProcessor
from ..connectors.binance_connector import BinanceConnector
from ..models.core import OrderBookSnapshot, TradeEvent, HeatmapData, ReplayStatus
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import IntegrationError, ValidationError
from ..config import Config
logger = get_logger(__name__)
@dataclass
class MarketTick:
"""Market tick data structure compatible with orchestrator"""
symbol: str
price: float
volume: float
timestamp: datetime
side: str = "unknown"
exchange: str = "binance"
subscriber_name: str = "unknown"
@dataclass
class PivotBounds:
"""Pivot bounds structure compatible with orchestrator"""
symbol: str
price_max: float
price_min: float
volume_max: float
volume_min: float
pivot_support_levels: List[float]
pivot_resistance_levels: List[float]
pivot_context: Dict[str, Any]
created_timestamp: datetime
data_period_start: datetime
data_period_end: datetime
total_candles_analyzed: int
def get_price_range(self) -> float:
return self.price_max - self.price_min
def normalize_price(self, price: float) -> float:
return (price - self.price_min) / self.get_price_range()
class COBYOrchestratorAdapter:
"""
Adapter that makes COBY system compatible with existing orchestrator interface.
Provides:
- Data provider interface compatibility
- Live/replay mode switching
- Data quality indicators
- Subscription management
- Caching and performance optimization
"""
def __init__(self, config: Config):
"""
Initialize orchestrator adapter.
Args:
config: COBY system configuration
"""
self.config = config
# Core components
self.storage_manager = StorageManager(config)
self.replay_manager = HistoricalReplayManager(self.storage_manager, config)
self.redis_manager = RedisManager()
self.aggregation_engine = StandardAggregationEngine()
self.data_processor = StandardDataProcessor()
# Exchange connectors
self.connectors = {
'binance': BinanceConnector()
}
# Mode management
self.mode = 'live' # 'live' or 'replay'
self.current_replay_session = None
# Subscription management
self.subscribers = {
'ticks': {},
'cob_raw': {},
'cob_aggregated': {},
'training_data': {},
'model_predictions': {}
}
self.subscriber_lock = threading.Lock()
# Data caching
self.tick_cache = {}
self.orderbook_cache = {}
self.price_cache = {}
# Statistics
self.stats = {
'ticks_processed': 0,
'orderbooks_processed': 0,
'subscribers_active': 0,
'cache_hits': 0,
'cache_misses': 0
}
# Initialize components
self._initialize_components()
logger.info("COBY orchestrator adapter initialized")
async def _initialize_components(self):
"""Initialize all COBY components."""
try:
# Initialize storage
await self.storage_manager.initialize()
# Initialize Redis cache
await self.redis_manager.initialize()
# Initialize connectors
for name, connector in self.connectors.items():
await connector.connect()
connector.add_data_callback(self._handle_connector_data)
logger.info("COBY components initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize COBY components: {e}")
raise IntegrationError(f"Component initialization failed: {e}")
# === ORCHESTRATOR COMPATIBILITY METHODS ===
def get_historical_data(self, symbol: str, timeframe: str, limit: int = 1000,
refresh: bool = False) -> Optional[pd.DataFrame]:
"""Get historical OHLCV data compatible with orchestrator interface."""
try:
set_correlation_id()
# Convert timeframe to minutes
timeframe_minutes = self._parse_timeframe(timeframe)
if not timeframe_minutes:
logger.warning(f"Unsupported timeframe: {timeframe}")
return None
# Calculate time range
end_time = datetime.utcnow()
start_time = end_time - timedelta(minutes=timeframe_minutes * limit)
# Get data from storage
if self.mode == 'replay' and self.current_replay_session:
# Use replay data
data = asyncio.run(self.storage_manager.get_historical_data(
symbol, start_time, end_time, 'ohlcv'
))
else:
# Use live data from cache or storage
cache_key = f"ohlcv:{symbol}:{timeframe}:{limit}"
cached_data = asyncio.run(self.redis_manager.get(cache_key))
if cached_data and not refresh:
self.stats['cache_hits'] += 1
return pd.DataFrame(cached_data)
self.stats['cache_misses'] += 1
data = asyncio.run(self.storage_manager.get_historical_data(
symbol, start_time, end_time, 'ohlcv'
))
# Cache the result
if data:
asyncio.run(self.redis_manager.set(cache_key, data, ttl=60))
if not data:
return None
# Convert to DataFrame compatible with orchestrator
df = pd.DataFrame(data)
if not df.empty:
df['timestamp'] = pd.to_datetime(df['timestamp'])
df.set_index('timestamp', inplace=True)
df = df.sort_index()
return df
except Exception as e:
logger.error(f"Error getting historical data for {symbol}: {e}")
return None
def get_current_price(self, symbol: str) -> Optional[float]:
"""Get current price for a symbol."""
try:
# Check cache first
if symbol in self.price_cache:
cached_price, timestamp = self.price_cache[symbol]
if (datetime.utcnow() - timestamp).seconds < 5: # 5 second cache
return cached_price
# Get latest orderbook
latest_orderbook = asyncio.run(
self.storage_manager.get_latest_orderbook(symbol)
)
if latest_orderbook and latest_orderbook.get('mid_price'):
price = float(latest_orderbook['mid_price'])
self.price_cache[symbol] = (price, datetime.utcnow())
return price
return None
except Exception as e:
logger.error(f"Error getting current price for {symbol}: {e}")
return None
def get_live_price_from_api(self, symbol: str) -> Optional[float]:
"""Get live price from API (low-latency method)."""
return self.get_current_price(symbol)
def build_base_data_input(self, symbol: str) -> Optional[Any]:
"""Build base data input compatible with orchestrator models."""
try:
# This would need to be implemented based on the specific
# BaseDataInput class used by the orchestrator
# For now, return a mock object that provides the interface
class MockBaseDataInput:
def __init__(self, symbol: str, adapter):
self.symbol = symbol
self.adapter = adapter
def get_feature_vector(self) -> np.ndarray:
# Return feature vector from COBY data
return self.adapter._get_feature_vector(self.symbol)
return MockBaseDataInput(symbol, self)
except Exception as e:
logger.error(f"Error building base data input for {symbol}: {e}")
return None
def _get_feature_vector(self, symbol: str) -> np.ndarray:
"""Get feature vector for ML models."""
try:
# Get latest market data
latest_orderbook = asyncio.run(
self.storage_manager.get_latest_orderbook(symbol)
)
if not latest_orderbook:
return np.zeros(100, dtype=np.float32) # Default size
# Extract features from orderbook
features = []
# Price features
if latest_orderbook.get('mid_price'):
features.append(float(latest_orderbook['mid_price']))
if latest_orderbook.get('spread'):
features.append(float(latest_orderbook['spread']))
# Volume features
if latest_orderbook.get('bid_volume'):
features.append(float(latest_orderbook['bid_volume']))
if latest_orderbook.get('ask_volume'):
features.append(float(latest_orderbook['ask_volume']))
# Pad or truncate to expected size
target_size = 100
if len(features) < target_size:
features.extend([0.0] * (target_size - len(features)))
elif len(features) > target_size:
features = features[:target_size]
return np.array(features, dtype=np.float32)
except Exception as e:
logger.error(f"Error getting feature vector for {symbol}: {e}")
return np.zeros(100, dtype=np.float32)
# === COB DATA METHODS ===
def get_cob_raw_ticks(self, symbol: str, count: int = 1000) -> List[Dict]:
"""Get raw COB ticks for a symbol."""
try:
# Get recent orderbook snapshots
end_time = datetime.utcnow()
start_time = end_time - timedelta(minutes=15) # 15 minutes of data
data = asyncio.run(self.storage_manager.get_historical_data(
symbol, start_time, end_time, 'orderbook'
))
if not data:
return []
# Convert to COB tick format
ticks = []
for item in data[-count:]:
tick = {
'symbol': item['symbol'],
'timestamp': item['timestamp'].isoformat(),
'mid_price': item.get('mid_price'),
'spread': item.get('spread'),
'bid_volume': item.get('bid_volume'),
'ask_volume': item.get('ask_volume'),
'exchange': item['exchange']
}
ticks.append(tick)
return ticks
except Exception as e:
logger.error(f"Error getting COB raw ticks for {symbol}: {e}")
return []
def get_cob_1s_aggregated(self, symbol: str, count: int = 300) -> List[Dict]:
"""Get 1s aggregated COB data with $1 price buckets."""
try:
# Get heatmap data
bucket_size = self.config.aggregation.bucket_size
start_time = datetime.utcnow() - timedelta(seconds=count)
heatmap_data = asyncio.run(
self.storage_manager.get_heatmap_data(symbol, bucket_size, start_time)
)
if not heatmap_data:
return []
# Group by timestamp and aggregate
aggregated = {}
for item in heatmap_data:
timestamp = item['timestamp']
if timestamp not in aggregated:
aggregated[timestamp] = {
'timestamp': timestamp.isoformat(),
'symbol': symbol,
'bid_buckets': {},
'ask_buckets': {},
'total_bid_volume': 0,
'total_ask_volume': 0
}
price_bucket = float(item['price_bucket'])
volume = float(item['volume'])
side = item['side']
if side == 'bid':
aggregated[timestamp]['bid_buckets'][price_bucket] = volume
aggregated[timestamp]['total_bid_volume'] += volume
else:
aggregated[timestamp]['ask_buckets'][price_bucket] = volume
aggregated[timestamp]['total_ask_volume'] += volume
# Return sorted by timestamp
result = list(aggregated.values())
result.sort(key=lambda x: x['timestamp'])
return result[-count:]
except Exception as e:
logger.error(f"Error getting COB 1s aggregated data for {symbol}: {e}")
return []
def get_latest_cob_data(self, symbol: str) -> Optional[Dict]:
"""Get latest COB raw tick for a symbol."""
try:
latest_orderbook = asyncio.run(
self.storage_manager.get_latest_orderbook(symbol)
)
if not latest_orderbook:
return None
return {
'symbol': symbol,
'timestamp': latest_orderbook['timestamp'].isoformat(),
'mid_price': latest_orderbook.get('mid_price'),
'spread': latest_orderbook.get('spread'),
'bid_volume': latest_orderbook.get('bid_volume'),
'ask_volume': latest_orderbook.get('ask_volume'),
'exchange': latest_orderbook['exchange']
}
except Exception as e:
logger.error(f"Error getting latest COB data for {symbol}: {e}")
return None
def get_latest_cob_aggregated(self, symbol: str) -> Optional[Dict]:
"""Get latest 1s aggregated COB data for a symbol."""
try:
aggregated_data = self.get_cob_1s_aggregated(symbol, count=1)
return aggregated_data[0] if aggregated_data else None
except Exception as e:
logger.error(f"Error getting latest COB aggregated data for {symbol}: {e}")
return None
# === SUBSCRIPTION METHODS ===
def subscribe_to_ticks(self, callback: Callable[[MarketTick], None],
symbols: List[str] = None,
subscriber_name: str = None) -> str:
"""Subscribe to tick data updates."""
try:
subscriber_id = str(uuid.uuid4())
with self.subscriber_lock:
self.subscribers['ticks'][subscriber_id] = {
'callback': callback,
'symbols': symbols or [],
'subscriber_name': subscriber_name or 'unknown',
'created_at': datetime.utcnow()
}
self.stats['subscribers_active'] += 1
logger.info(f"Added tick subscriber {subscriber_id} for {subscriber_name}")
return subscriber_id
except Exception as e:
logger.error(f"Error adding tick subscriber: {e}")
return ""
def subscribe_to_cob_raw_ticks(self, callback: Callable[[str, Dict], None]) -> str:
"""Subscribe to raw COB tick updates."""
try:
subscriber_id = str(uuid.uuid4())
with self.subscriber_lock:
self.subscribers['cob_raw'][subscriber_id] = {
'callback': callback,
'created_at': datetime.utcnow()
}
self.stats['subscribers_active'] += 1
logger.info(f"Added COB raw tick subscriber {subscriber_id}")
return subscriber_id
except Exception as e:
logger.error(f"Error adding COB raw tick subscriber: {e}")
return ""
def subscribe_to_cob_aggregated(self, callback: Callable[[str, Dict], None]) -> str:
"""Subscribe to 1s aggregated COB updates."""
try:
subscriber_id = str(uuid.uuid4())
with self.subscriber_lock:
self.subscribers['cob_aggregated'][subscriber_id] = {
'callback': callback,
'created_at': datetime.utcnow()
}
self.stats['subscribers_active'] += 1
logger.info(f"Added COB aggregated subscriber {subscriber_id}")
return subscriber_id
except Exception as e:
logger.error(f"Error adding COB aggregated subscriber: {e}")
return ""
def subscribe_to_training_data(self, callback: Callable[[str, dict], None]) -> str:
"""Subscribe to training data updates."""
try:
subscriber_id = str(uuid.uuid4())
with self.subscriber_lock:
self.subscribers['training_data'][subscriber_id] = {
'callback': callback,
'created_at': datetime.utcnow()
}
self.stats['subscribers_active'] += 1
logger.info(f"Added training data subscriber {subscriber_id}")
return subscriber_id
except Exception as e:
logger.error(f"Error adding training data subscriber: {e}")
return ""
def subscribe_to_model_predictions(self, callback: Callable[[str, dict], None]) -> str:
"""Subscribe to model prediction updates."""
try:
subscriber_id = str(uuid.uuid4())
with self.subscriber_lock:
self.subscribers['model_predictions'][subscriber_id] = {
'callback': callback,
'created_at': datetime.utcnow()
}
self.stats['subscribers_active'] += 1
logger.info(f"Added model prediction subscriber {subscriber_id}")
return subscriber_id
except Exception as e:
logger.error(f"Error adding model prediction subscriber: {e}")
return ""
def unsubscribe(self, subscriber_id: str) -> bool:
"""Unsubscribe from all data feeds."""
try:
with self.subscriber_lock:
removed = False
for category in self.subscribers:
if subscriber_id in self.subscribers[category]:
del self.subscribers[category][subscriber_id]
self.stats['subscribers_active'] -= 1
removed = True
break
if removed:
logger.info(f"Removed subscriber {subscriber_id}")
return removed
except Exception as e:
logger.error(f"Error removing subscriber {subscriber_id}: {e}")
return False
# === MODE SWITCHING ===
async def switch_to_live_mode(self) -> bool:
"""Switch to live data mode."""
try:
if self.mode == 'live':
logger.info("Already in live mode")
return True
# Stop replay session if active
if self.current_replay_session:
await self.replay_manager.stop_replay(self.current_replay_session)
self.current_replay_session = None
# Start live connectors
for name, connector in self.connectors.items():
if not connector.is_connected:
await connector.connect()
self.mode = 'live'
logger.info("Switched to live data mode")
return True
except Exception as e:
logger.error(f"Error switching to live mode: {e}")
return False
async def switch_to_replay_mode(self, start_time: datetime, end_time: datetime,
speed: float = 1.0, symbols: List[str] = None) -> bool:
"""Switch to replay data mode."""
try:
if self.mode == 'replay' and self.current_replay_session:
await self.replay_manager.stop_replay(self.current_replay_session)
# Create replay session
session_id = self.replay_manager.create_replay_session(
start_time=start_time,
end_time=end_time,
speed=speed,
symbols=symbols or self.config.exchanges.symbols
)
# Add data callback for replay
self.replay_manager.add_data_callback(session_id, self._handle_replay_data)
# Start replay
await self.replay_manager.start_replay(session_id)
self.current_replay_session = session_id
self.mode = 'replay'
logger.info(f"Switched to replay mode: {start_time} to {end_time}")
return True
except Exception as e:
logger.error(f"Error switching to replay mode: {e}")
return False
def get_current_mode(self) -> str:
"""Get current data mode (live or replay)."""
return self.mode
def get_replay_status(self) -> Optional[Dict[str, Any]]:
"""Get current replay session status."""
if not self.current_replay_session:
return None
session = self.replay_manager.get_replay_status(self.current_replay_session)
if not session:
return None
return {
'session_id': session.session_id,
'status': session.status.value,
'progress': session.progress,
'current_time': session.current_time.isoformat(),
'speed': session.speed,
'events_replayed': session.events_replayed,
'total_events': session.total_events
}
# === DATA QUALITY AND METADATA ===
def get_data_quality_indicators(self, symbol: str) -> Dict[str, Any]:
"""Get data quality indicators for a symbol."""
try:
# Get recent data statistics
end_time = datetime.utcnow()
start_time = end_time - timedelta(minutes=5)
orderbook_data = asyncio.run(self.storage_manager.get_historical_data(
symbol, start_time, end_time, 'orderbook'
))
trade_data = asyncio.run(self.storage_manager.get_historical_data(
symbol, start_time, end_time, 'trades'
))
# Calculate quality metrics
quality = {
'symbol': symbol,
'timestamp': datetime.utcnow().isoformat(),
'orderbook_updates': len(orderbook_data) if orderbook_data else 0,
'trade_events': len(trade_data) if trade_data else 0,
'data_freshness_seconds': 0,
'exchange_coverage': [],
'quality_score': 0.0
}
# Calculate data freshness
if orderbook_data:
latest_timestamp = max(item['timestamp'] for item in orderbook_data)
quality['data_freshness_seconds'] = (
datetime.utcnow() - latest_timestamp
).total_seconds()
# Get exchange coverage
if orderbook_data:
exchanges = set(item['exchange'] for item in orderbook_data)
quality['exchange_coverage'] = list(exchanges)
# Calculate quality score (0-1)
score = 0.0
if quality['orderbook_updates'] > 0:
score += 0.4
if quality['trade_events'] > 0:
score += 0.3
if quality['data_freshness_seconds'] < 10:
score += 0.3
quality['quality_score'] = score
return quality
except Exception as e:
logger.error(f"Error getting data quality for {symbol}: {e}")
return {
'symbol': symbol,
'timestamp': datetime.utcnow().isoformat(),
'quality_score': 0.0,
'error': str(e)
}
def get_system_metadata(self) -> Dict[str, Any]:
"""Get system metadata and status."""
try:
return {
'system': 'COBY',
'version': '1.0.0',
'mode': self.mode,
'timestamp': datetime.utcnow().isoformat(),
'components': {
'storage': self.storage_manager.is_healthy(),
'redis': True, # Simplified check
'connectors': {
name: connector.is_connected
for name, connector in self.connectors.items()
}
},
'statistics': self.stats,
'replay_session': self.get_replay_status(),
'active_subscribers': sum(
len(subs) for subs in self.subscribers.values()
)
}
except Exception as e:
logger.error(f"Error getting system metadata: {e}")
return {'error': str(e)}
# === DATA HANDLERS ===
async def _handle_connector_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> None:
"""Handle data from exchange connectors."""
try:
# Store data
if isinstance(data, OrderBookSnapshot):
await self.storage_manager.store_orderbook(data)
self.stats['orderbooks_processed'] += 1
# Create market tick for subscribers
if data.bids and data.asks:
best_bid = max(data.bids, key=lambda x: x.price)
best_ask = min(data.asks, key=lambda x: x.price)
mid_price = (best_bid.price + best_ask.price) / 2
tick = MarketTick(
symbol=data.symbol,
price=mid_price,
volume=best_bid.size + best_ask.size,
timestamp=data.timestamp,
exchange=data.exchange
)
await self._notify_tick_subscribers(tick)
# Create COB data for subscribers
cob_data = {
'symbol': data.symbol,
'timestamp': data.timestamp.isoformat(),
'bids': [{'price': b.price, 'size': b.size} for b in data.bids[:10]],
'asks': [{'price': a.price, 'size': a.size} for a in data.asks[:10]],
'exchange': data.exchange
}
await self._notify_cob_raw_subscribers(data.symbol, cob_data)
elif isinstance(data, TradeEvent):
await self.storage_manager.store_trade(data)
self.stats['ticks_processed'] += 1
# Create market tick
tick = MarketTick(
symbol=data.symbol,
price=data.price,
volume=data.size,
timestamp=data.timestamp,
side=data.side,
exchange=data.exchange
)
await self._notify_tick_subscribers(tick)
except Exception as e:
logger.error(f"Error handling connector data: {e}")
async def _handle_replay_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> None:
"""Handle data from replay system."""
try:
# Process replay data same as live data
await self._handle_connector_data(data)
except Exception as e:
logger.error(f"Error handling replay data: {e}")
async def _notify_tick_subscribers(self, tick: MarketTick) -> None:
"""Notify tick subscribers."""
try:
with self.subscriber_lock:
subscribers = self.subscribers['ticks'].copy()
for subscriber_id, sub_info in subscribers.items():
try:
callback = sub_info['callback']
symbols = sub_info['symbols']
# Check if subscriber wants this symbol
if not symbols or tick.symbol in symbols:
if asyncio.iscoroutinefunction(callback):
await callback(tick)
else:
callback(tick)
except Exception as e:
logger.error(f"Error notifying tick subscriber {subscriber_id}: {e}")
except Exception as e:
logger.error(f"Error notifying tick subscribers: {e}")
async def _notify_cob_raw_subscribers(self, symbol: str, data: Dict) -> None:
"""Notify COB raw tick subscribers."""
try:
with self.subscriber_lock:
subscribers = self.subscribers['cob_raw'].copy()
for subscriber_id, sub_info in subscribers.items():
try:
callback = sub_info['callback']
if asyncio.iscoroutinefunction(callback):
await callback(symbol, data)
else:
callback(symbol, data)
except Exception as e:
logger.error(f"Error notifying COB raw subscriber {subscriber_id}: {e}")
except Exception as e:
logger.error(f"Error notifying COB raw subscribers: {e}")
# === UTILITY METHODS ===
def _parse_timeframe(self, timeframe: str) -> Optional[int]:
"""Parse timeframe string to minutes."""
try:
if timeframe.endswith('m'):
return int(timeframe[:-1])
elif timeframe.endswith('h'):
return int(timeframe[:-1]) * 60
elif timeframe.endswith('d'):
return int(timeframe[:-1]) * 24 * 60
else:
return None
except:
return None
def start_centralized_data_collection(self) -> None:
"""Start centralized data collection (compatibility method)."""
logger.info("Centralized data collection started (COBY mode)")
def start_training_data_collection(self) -> None:
"""Start training data collection (compatibility method)."""
logger.info("Training data collection started (COBY mode)")
def invalidate_ohlcv_cache(self, symbol: str) -> None:
"""Invalidate OHLCV cache for a symbol."""
try:
# Clear Redis cache for this symbol
cache_pattern = f"ohlcv:{symbol}:*"
asyncio.run(self.redis_manager.delete_pattern(cache_pattern))
# Clear local price cache
if symbol in self.price_cache:
del self.price_cache[symbol]
except Exception as e:
logger.error(f"Error invalidating cache for {symbol}: {e}")
async def close(self) -> None:
"""Close all connections and cleanup."""
try:
# Stop replay session
if self.current_replay_session:
await self.replay_manager.stop_replay(self.current_replay_session)
# Close connectors
for connector in self.connectors.values():
await connector.disconnect()
# Close storage
await self.storage_manager.close()
# Close Redis
await self.redis_manager.close()
logger.info("COBY orchestrator adapter closed")
except Exception as e:
logger.error(f"Error closing adapter: {e}")
def get_stats(self) -> Dict[str, Any]:
"""Get adapter statistics."""
return {
**self.stats,
'mode': self.mode,
'active_subscribers': sum(len(subs) for subs in self.subscribers.values()),
'cache_size': len(self.price_cache),
'replay_session': self.current_replay_session
}

View File

@ -0,0 +1,17 @@
"""
Interface definitions for the multi-exchange data aggregation system.
"""
from .exchange_connector import ExchangeConnector
from .data_processor import DataProcessor
from .aggregation_engine import AggregationEngine
from .storage_manager import StorageManager
from .replay_manager import ReplayManager
__all__ = [
'ExchangeConnector',
'DataProcessor',
'AggregationEngine',
'StorageManager',
'ReplayManager'
]

View File

@ -0,0 +1,139 @@
"""
Interface for data aggregation and heatmap generation.
"""
from abc import ABC, abstractmethod
from typing import Dict, List
from ..models.core import (
OrderBookSnapshot, PriceBuckets, HeatmapData,
ImbalanceMetrics, ConsolidatedOrderBook
)
class AggregationEngine(ABC):
"""Aggregates data into price buckets and heatmaps"""
@abstractmethod
def create_price_buckets(self, orderbook: OrderBookSnapshot,
bucket_size: float) -> PriceBuckets:
"""
Convert order book data to price buckets.
Args:
orderbook: Order book snapshot
bucket_size: Size of each price bucket
Returns:
PriceBuckets: Aggregated price bucket data
"""
pass
@abstractmethod
def update_heatmap(self, symbol: str, buckets: PriceBuckets) -> HeatmapData:
"""
Update heatmap data with new price buckets.
Args:
symbol: Trading symbol
buckets: Price bucket data
Returns:
HeatmapData: Updated heatmap visualization data
"""
pass
@abstractmethod
def calculate_imbalances(self, orderbook: OrderBookSnapshot) -> ImbalanceMetrics:
"""
Calculate order book imbalance metrics.
Args:
orderbook: Order book snapshot
Returns:
ImbalanceMetrics: Calculated imbalance metrics
"""
pass
@abstractmethod
def aggregate_across_exchanges(self, symbol: str,
orderbooks: List[OrderBookSnapshot]) -> ConsolidatedOrderBook:
"""
Aggregate order book data from multiple exchanges.
Args:
symbol: Trading symbol
orderbooks: List of order book snapshots from different exchanges
Returns:
ConsolidatedOrderBook: Consolidated order book data
"""
pass
@abstractmethod
def calculate_volume_weighted_price(self, orderbooks: List[OrderBookSnapshot]) -> float:
"""
Calculate volume-weighted average price across exchanges.
Args:
orderbooks: List of order book snapshots
Returns:
float: Volume-weighted average price
"""
pass
@abstractmethod
def get_market_depth(self, orderbook: OrderBookSnapshot,
depth_levels: List[float]) -> Dict[float, Dict[str, float]]:
"""
Calculate market depth at different price levels.
Args:
orderbook: Order book snapshot
depth_levels: List of depth percentages (e.g., [0.1, 0.5, 1.0])
Returns:
Dict: Market depth data {level: {'bid_volume': x, 'ask_volume': y}}
"""
pass
@abstractmethod
def smooth_heatmap(self, heatmap: HeatmapData, smoothing_factor: float) -> HeatmapData:
"""
Apply smoothing to heatmap data to reduce noise.
Args:
heatmap: Raw heatmap data
smoothing_factor: Smoothing factor (0.0 to 1.0)
Returns:
HeatmapData: Smoothed heatmap data
"""
pass
@abstractmethod
def calculate_liquidity_score(self, orderbook: OrderBookSnapshot) -> float:
"""
Calculate liquidity score for an order book.
Args:
orderbook: Order book snapshot
Returns:
float: Liquidity score (0.0 to 1.0)
"""
pass
@abstractmethod
def detect_support_resistance(self, heatmap: HeatmapData) -> Dict[str, List[float]]:
"""
Detect support and resistance levels from heatmap data.
Args:
heatmap: Heatmap data
Returns:
Dict: {'support': [prices], 'resistance': [prices]}
"""
pass

View File

@ -0,0 +1,119 @@
"""
Interface for data processing and normalization.
"""
from abc import ABC, abstractmethod
from typing import Dict, Union, List, Optional
from ..models.core import OrderBookSnapshot, TradeEvent, OrderBookMetrics
class DataProcessor(ABC):
"""Processes and normalizes raw exchange data"""
@abstractmethod
def normalize_orderbook(self, raw_data: Dict, exchange: str) -> OrderBookSnapshot:
"""
Normalize raw order book data to standard format.
Args:
raw_data: Raw order book data from exchange
exchange: Exchange name
Returns:
OrderBookSnapshot: Normalized order book data
"""
pass
@abstractmethod
def normalize_trade(self, raw_data: Dict, exchange: str) -> TradeEvent:
"""
Normalize raw trade data to standard format.
Args:
raw_data: Raw trade data from exchange
exchange: Exchange name
Returns:
TradeEvent: Normalized trade data
"""
pass
@abstractmethod
def validate_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> bool:
"""
Validate normalized data for quality and consistency.
Args:
data: Normalized data to validate
Returns:
bool: True if data is valid, False otherwise
"""
pass
@abstractmethod
def calculate_metrics(self, orderbook: OrderBookSnapshot) -> OrderBookMetrics:
"""
Calculate metrics from order book data.
Args:
orderbook: Order book snapshot
Returns:
OrderBookMetrics: Calculated metrics
"""
pass
@abstractmethod
def detect_anomalies(self, data: Union[OrderBookSnapshot, TradeEvent]) -> List[str]:
"""
Detect anomalies in the data.
Args:
data: Data to analyze for anomalies
Returns:
List[str]: List of detected anomaly descriptions
"""
pass
@abstractmethod
def filter_data(self, data: Union[OrderBookSnapshot, TradeEvent],
criteria: Dict) -> bool:
"""
Filter data based on criteria.
Args:
data: Data to filter
criteria: Filtering criteria
Returns:
bool: True if data passes filter, False otherwise
"""
pass
@abstractmethod
def enrich_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> Dict:
"""
Enrich data with additional metadata.
Args:
data: Data to enrich
Returns:
Dict: Enriched data with metadata
"""
pass
@abstractmethod
def get_data_quality_score(self, data: Union[OrderBookSnapshot, TradeEvent]) -> float:
"""
Calculate data quality score.
Args:
data: Data to score
Returns:
float: Quality score between 0.0 and 1.0
"""
pass

View File

@ -0,0 +1,189 @@
"""
Base interface for exchange WebSocket connectors.
"""
from abc import ABC, abstractmethod
from typing import Callable, List, Optional
from ..models.core import ConnectionStatus, OrderBookSnapshot, TradeEvent
class ExchangeConnector(ABC):
"""Base interface for exchange WebSocket connectors"""
def __init__(self, exchange_name: str):
self.exchange_name = exchange_name
self._data_callbacks: List[Callable] = []
self._status_callbacks: List[Callable] = []
self._connection_status = ConnectionStatus.DISCONNECTED
@abstractmethod
async def connect(self) -> bool:
"""
Establish connection to the exchange WebSocket.
Returns:
bool: True if connection successful, False otherwise
"""
pass
@abstractmethod
async def disconnect(self) -> None:
"""Disconnect from the exchange WebSocket."""
pass
@abstractmethod
async def subscribe_orderbook(self, symbol: str) -> None:
"""
Subscribe to order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
pass
@abstractmethod
async def subscribe_trades(self, symbol: str) -> None:
"""
Subscribe to trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
pass
@abstractmethod
async def unsubscribe_orderbook(self, symbol: str) -> None:
"""
Unsubscribe from order book updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
pass
@abstractmethod
async def unsubscribe_trades(self, symbol: str) -> None:
"""
Unsubscribe from trade updates for a symbol.
Args:
symbol: Trading symbol (e.g., 'BTCUSDT')
"""
pass
def get_connection_status(self) -> ConnectionStatus:
"""
Get current connection status.
Returns:
ConnectionStatus: Current connection status
"""
return self._connection_status
def add_data_callback(self, callback: Callable) -> None:
"""
Add callback for data updates.
Args:
callback: Function to call when data is received
Signature: callback(data: Union[OrderBookSnapshot, TradeEvent])
"""
if callback not in self._data_callbacks:
self._data_callbacks.append(callback)
def remove_data_callback(self, callback: Callable) -> None:
"""
Remove data callback.
Args:
callback: Callback function to remove
"""
if callback in self._data_callbacks:
self._data_callbacks.remove(callback)
def add_status_callback(self, callback: Callable) -> None:
"""
Add callback for status updates.
Args:
callback: Function to call when status changes
Signature: callback(exchange: str, status: ConnectionStatus)
"""
if callback not in self._status_callbacks:
self._status_callbacks.append(callback)
def remove_status_callback(self, callback: Callable) -> None:
"""
Remove status callback.
Args:
callback: Callback function to remove
"""
if callback in self._status_callbacks:
self._status_callbacks.remove(callback)
def _notify_data_callbacks(self, data):
"""Notify all data callbacks of new data."""
for callback in self._data_callbacks:
try:
callback(data)
except Exception as e:
# Log error but don't stop other callbacks
print(f"Error in data callback: {e}")
def _notify_status_callbacks(self, status: ConnectionStatus):
"""Notify all status callbacks of status change."""
self._connection_status = status
for callback in self._status_callbacks:
try:
callback(self.exchange_name, status)
except Exception as e:
# Log error but don't stop other callbacks
print(f"Error in status callback: {e}")
@abstractmethod
async def get_symbols(self) -> List[str]:
"""
Get list of available trading symbols.
Returns:
List[str]: List of available symbols
"""
pass
@abstractmethod
def normalize_symbol(self, symbol: str) -> str:
"""
Normalize symbol to exchange format.
Args:
symbol: Standard symbol format (e.g., 'BTCUSDT')
Returns:
str: Exchange-specific symbol format
"""
pass
@abstractmethod
async def get_orderbook_snapshot(self, symbol: str, depth: int = 20) -> Optional[OrderBookSnapshot]:
"""
Get current order book snapshot.
Args:
symbol: Trading symbol
depth: Number of price levels to retrieve
Returns:
OrderBookSnapshot: Current order book or None if unavailable
"""
pass
@property
def name(self) -> str:
"""Get exchange name."""
return self.exchange_name
@property
def is_connected(self) -> bool:
"""Check if connector is connected."""
return self._connection_status == ConnectionStatus.CONNECTED

View File

@ -0,0 +1,212 @@
"""
Interface for historical data replay functionality.
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Optional, Callable, Dict, Any
from ..models.core import ReplaySession, ReplayStatus
class ReplayManager(ABC):
"""Provides historical data replay functionality"""
@abstractmethod
def create_replay_session(self, start_time: datetime, end_time: datetime,
speed: float = 1.0, symbols: Optional[List[str]] = None,
exchanges: Optional[List[str]] = None) -> str:
"""
Create a new replay session.
Args:
start_time: Replay start time
end_time: Replay end time
speed: Playback speed multiplier (1.0 = real-time)
symbols: List of symbols to replay (None = all)
exchanges: List of exchanges to replay (None = all)
Returns:
str: Session ID
"""
pass
@abstractmethod
async def start_replay(self, session_id: str) -> None:
"""
Start replay session.
Args:
session_id: Session ID to start
"""
pass
@abstractmethod
async def pause_replay(self, session_id: str) -> None:
"""
Pause replay session.
Args:
session_id: Session ID to pause
"""
pass
@abstractmethod
async def resume_replay(self, session_id: str) -> None:
"""
Resume paused replay session.
Args:
session_id: Session ID to resume
"""
pass
@abstractmethod
async def stop_replay(self, session_id: str) -> None:
"""
Stop replay session.
Args:
session_id: Session ID to stop
"""
pass
@abstractmethod
def get_replay_status(self, session_id: str) -> Optional[ReplaySession]:
"""
Get replay session status.
Args:
session_id: Session ID
Returns:
ReplaySession: Session status or None if not found
"""
pass
@abstractmethod
def list_replay_sessions(self) -> List[ReplaySession]:
"""
List all replay sessions.
Returns:
List[ReplaySession]: List of all sessions
"""
pass
@abstractmethod
def delete_replay_session(self, session_id: str) -> bool:
"""
Delete replay session.
Args:
session_id: Session ID to delete
Returns:
bool: True if deleted successfully, False otherwise
"""
pass
@abstractmethod
def set_replay_speed(self, session_id: str, speed: float) -> bool:
"""
Change replay speed for active session.
Args:
session_id: Session ID
speed: New playback speed multiplier
Returns:
bool: True if speed changed successfully, False otherwise
"""
pass
@abstractmethod
def seek_replay(self, session_id: str, timestamp: datetime) -> bool:
"""
Seek to specific timestamp in replay.
Args:
session_id: Session ID
timestamp: Target timestamp
Returns:
bool: True if seek successful, False otherwise
"""
pass
@abstractmethod
def add_data_callback(self, session_id: str, callback: Callable) -> bool:
"""
Add callback for replay data.
Args:
session_id: Session ID
callback: Function to call with replay data
Signature: callback(data: Union[OrderBookSnapshot, TradeEvent])
Returns:
bool: True if callback added successfully, False otherwise
"""
pass
@abstractmethod
def remove_data_callback(self, session_id: str, callback: Callable) -> bool:
"""
Remove data callback from replay session.
Args:
session_id: Session ID
callback: Callback function to remove
Returns:
bool: True if callback removed successfully, False otherwise
"""
pass
@abstractmethod
def add_status_callback(self, session_id: str, callback: Callable) -> bool:
"""
Add callback for replay status changes.
Args:
session_id: Session ID
callback: Function to call on status change
Signature: callback(session_id: str, status: ReplayStatus)
Returns:
bool: True if callback added successfully, False otherwise
"""
pass
@abstractmethod
async def get_available_data_range(self, symbol: str,
exchange: Optional[str] = None) -> Optional[Dict[str, datetime]]:
"""
Get available data time range for replay.
Args:
symbol: Trading symbol
exchange: Exchange name (None = all exchanges)
Returns:
Dict: {'start': datetime, 'end': datetime} or None if no data
"""
pass
@abstractmethod
def validate_replay_request(self, start_time: datetime, end_time: datetime,
symbols: Optional[List[str]] = None,
exchanges: Optional[List[str]] = None) -> List[str]:
"""
Validate replay request parameters.
Args:
start_time: Requested start time
end_time: Requested end time
symbols: Requested symbols
exchanges: Requested exchanges
Returns:
List[str]: List of validation errors (empty if valid)
"""
pass

View File

@ -0,0 +1,215 @@
"""
Interface for data storage and retrieval.
"""
from abc import ABC, abstractmethod
from datetime import datetime
from typing import List, Dict, Optional, Any
from ..models.core import OrderBookSnapshot, TradeEvent, HeatmapData, SystemMetrics
class StorageManager(ABC):
"""Manages data persistence and retrieval"""
@abstractmethod
async def store_orderbook(self, data: OrderBookSnapshot) -> bool:
"""
Store order book snapshot to database.
Args:
data: Order book snapshot to store
Returns:
bool: True if stored successfully, False otherwise
"""
pass
@abstractmethod
async def store_trade(self, data: TradeEvent) -> bool:
"""
Store trade event to database.
Args:
data: Trade event to store
Returns:
bool: True if stored successfully, False otherwise
"""
pass
@abstractmethod
async def store_heatmap(self, data: HeatmapData) -> bool:
"""
Store heatmap data to database.
Args:
data: Heatmap data to store
Returns:
bool: True if stored successfully, False otherwise
"""
pass
@abstractmethod
async def store_metrics(self, data: SystemMetrics) -> bool:
"""
Store system metrics to database.
Args:
data: System metrics to store
Returns:
bool: True if stored successfully, False otherwise
"""
pass
@abstractmethod
async def get_historical_orderbooks(self, symbol: str, exchange: str,
start: datetime, end: datetime,
limit: Optional[int] = None) -> List[OrderBookSnapshot]:
"""
Retrieve historical order book data.
Args:
symbol: Trading symbol
exchange: Exchange name
start: Start timestamp
end: End timestamp
limit: Maximum number of records to return
Returns:
List[OrderBookSnapshot]: Historical order book data
"""
pass
@abstractmethod
async def get_historical_trades(self, symbol: str, exchange: str,
start: datetime, end: datetime,
limit: Optional[int] = None) -> List[TradeEvent]:
"""
Retrieve historical trade data.
Args:
symbol: Trading symbol
exchange: Exchange name
start: Start timestamp
end: End timestamp
limit: Maximum number of records to return
Returns:
List[TradeEvent]: Historical trade data
"""
pass
@abstractmethod
async def get_latest_orderbook(self, symbol: str, exchange: str) -> Optional[OrderBookSnapshot]:
"""
Get latest order book snapshot.
Args:
symbol: Trading symbol
exchange: Exchange name
Returns:
OrderBookSnapshot: Latest order book or None if not found
"""
pass
@abstractmethod
async def get_latest_heatmap(self, symbol: str, bucket_size: float) -> Optional[HeatmapData]:
"""
Get latest heatmap data.
Args:
symbol: Trading symbol
bucket_size: Price bucket size
Returns:
HeatmapData: Latest heatmap or None if not found
"""
pass
@abstractmethod
async def get_ohlcv_data(self, symbol: str, exchange: str, timeframe: str,
start: datetime, end: datetime) -> List[Dict[str, Any]]:
"""
Get OHLCV candlestick data.
Args:
symbol: Trading symbol
exchange: Exchange name
timeframe: Timeframe (e.g., '1m', '5m', '1h')
start: Start timestamp
end: End timestamp
Returns:
List[Dict]: OHLCV data
"""
pass
@abstractmethod
async def batch_store_orderbooks(self, data: List[OrderBookSnapshot]) -> int:
"""
Store multiple order book snapshots in batch.
Args:
data: List of order book snapshots
Returns:
int: Number of records stored successfully
"""
pass
@abstractmethod
async def batch_store_trades(self, data: List[TradeEvent]) -> int:
"""
Store multiple trade events in batch.
Args:
data: List of trade events
Returns:
int: Number of records stored successfully
"""
pass
@abstractmethod
def setup_database_schema(self) -> None:
"""
Set up database schema and tables.
Should be idempotent - safe to call multiple times.
"""
pass
@abstractmethod
async def cleanup_old_data(self, retention_days: int) -> int:
"""
Clean up old data based on retention policy.
Args:
retention_days: Number of days to retain data
Returns:
int: Number of records deleted
"""
pass
@abstractmethod
async def get_storage_stats(self) -> Dict[str, Any]:
"""
Get storage statistics.
Returns:
Dict: Storage statistics (table sizes, record counts, etc.)
"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""
Check storage system health.
Returns:
bool: True if healthy, False otherwise
"""
pass

31
COBY/models/__init__.py Normal file
View File

@ -0,0 +1,31 @@
"""
Data models for the multi-exchange data aggregation system.
"""
from .core import (
OrderBookSnapshot,
PriceLevel,
TradeEvent,
PriceBuckets,
HeatmapData,
HeatmapPoint,
ConnectionStatus,
OrderBookMetrics,
ImbalanceMetrics,
ConsolidatedOrderBook,
ReplayStatus
)
__all__ = [
'OrderBookSnapshot',
'PriceLevel',
'TradeEvent',
'PriceBuckets',
'HeatmapData',
'HeatmapPoint',
'ConnectionStatus',
'OrderBookMetrics',
'ImbalanceMetrics',
'ConsolidatedOrderBook',
'ReplayStatus'
]

324
COBY/models/core.py Normal file
View File

@ -0,0 +1,324 @@
"""
Core data models for the multi-exchange data aggregation system.
"""
from dataclasses import dataclass, field
from datetime import datetime
from typing import List, Dict, Optional, Any
from enum import Enum
class ConnectionStatus(Enum):
"""Exchange connection status"""
DISCONNECTED = "disconnected"
CONNECTING = "connecting"
CONNECTED = "connected"
RECONNECTING = "reconnecting"
ERROR = "error"
class ReplayStatus(Enum):
"""Replay session status"""
CREATED = "created"
RUNNING = "running"
PAUSED = "paused"
STOPPED = "stopped"
COMPLETED = "completed"
ERROR = "error"
@dataclass
class PriceLevel:
"""Individual price level in order book"""
price: float
size: float
count: Optional[int] = None
def __post_init__(self):
"""Validate price level data"""
if self.price <= 0:
raise ValueError("Price must be positive")
if self.size < 0:
raise ValueError("Size cannot be negative")
@dataclass
class OrderBookSnapshot:
"""Standardized order book snapshot"""
symbol: str
exchange: str
timestamp: datetime
bids: List[PriceLevel]
asks: List[PriceLevel]
sequence_id: Optional[int] = None
def __post_init__(self):
"""Validate and sort order book data"""
if not self.symbol:
raise ValueError("Symbol cannot be empty")
if not self.exchange:
raise ValueError("Exchange cannot be empty")
# Sort bids descending (highest price first)
self.bids.sort(key=lambda x: x.price, reverse=True)
# Sort asks ascending (lowest price first)
self.asks.sort(key=lambda x: x.price)
@property
def mid_price(self) -> Optional[float]:
"""Calculate mid price"""
if self.bids and self.asks:
return (self.bids[0].price + self.asks[0].price) / 2
return None
@property
def spread(self) -> Optional[float]:
"""Calculate bid-ask spread"""
if self.bids and self.asks:
return self.asks[0].price - self.bids[0].price
return None
@property
def bid_volume(self) -> float:
"""Total bid volume"""
return sum(level.size for level in self.bids)
@property
def ask_volume(self) -> float:
"""Total ask volume"""
return sum(level.size for level in self.asks)
@dataclass
class TradeEvent:
"""Standardized trade event"""
symbol: str
exchange: str
timestamp: datetime
price: float
size: float
side: str # 'buy' or 'sell'
trade_id: str
def __post_init__(self):
"""Validate trade event data"""
if not self.symbol:
raise ValueError("Symbol cannot be empty")
if not self.exchange:
raise ValueError("Exchange cannot be empty")
if self.price <= 0:
raise ValueError("Price must be positive")
if self.size <= 0:
raise ValueError("Size must be positive")
if self.side not in ['buy', 'sell']:
raise ValueError("Side must be 'buy' or 'sell'")
if not self.trade_id:
raise ValueError("Trade ID cannot be empty")
@dataclass
class PriceBuckets:
"""Aggregated price buckets for heatmap"""
symbol: str
timestamp: datetime
bucket_size: float
bid_buckets: Dict[float, float] = field(default_factory=dict) # price -> volume
ask_buckets: Dict[float, float] = field(default_factory=dict) # price -> volume
def __post_init__(self):
"""Validate price buckets"""
if self.bucket_size <= 0:
raise ValueError("Bucket size must be positive")
def get_bucket_price(self, price: float) -> float:
"""Get bucket price for a given price"""
return round(price / self.bucket_size) * self.bucket_size
def add_bid(self, price: float, volume: float):
"""Add bid volume to appropriate bucket"""
bucket_price = self.get_bucket_price(price)
self.bid_buckets[bucket_price] = self.bid_buckets.get(bucket_price, 0) + volume
def add_ask(self, price: float, volume: float):
"""Add ask volume to appropriate bucket"""
bucket_price = self.get_bucket_price(price)
self.ask_buckets[bucket_price] = self.ask_buckets.get(bucket_price, 0) + volume
@dataclass
class HeatmapPoint:
"""Individual heatmap data point"""
price: float
volume: float
intensity: float # 0.0 to 1.0
side: str # 'bid' or 'ask'
def __post_init__(self):
"""Validate heatmap point"""
if self.price <= 0:
raise ValueError("Price must be positive")
if self.volume < 0:
raise ValueError("Volume cannot be negative")
if not 0 <= self.intensity <= 1:
raise ValueError("Intensity must be between 0 and 1")
if self.side not in ['bid', 'ask']:
raise ValueError("Side must be 'bid' or 'ask'")
@dataclass
class HeatmapData:
"""Heatmap visualization data"""
symbol: str
timestamp: datetime
bucket_size: float
data: List[HeatmapPoint] = field(default_factory=list)
def __post_init__(self):
"""Validate heatmap data"""
if self.bucket_size <= 0:
raise ValueError("Bucket size must be positive")
def add_point(self, price: float, volume: float, side: str, max_volume: float = None):
"""Add a heatmap point with calculated intensity"""
if max_volume is None:
max_volume = max((point.volume for point in self.data), default=volume)
intensity = min(volume / max_volume, 1.0) if max_volume > 0 else 0.0
point = HeatmapPoint(price=price, volume=volume, intensity=intensity, side=side)
self.data.append(point)
def get_bids(self) -> List[HeatmapPoint]:
"""Get bid points sorted by price descending"""
bids = [point for point in self.data if point.side == 'bid']
return sorted(bids, key=lambda x: x.price, reverse=True)
def get_asks(self) -> List[HeatmapPoint]:
"""Get ask points sorted by price ascending"""
asks = [point for point in self.data if point.side == 'ask']
return sorted(asks, key=lambda x: x.price)
@dataclass
class OrderBookMetrics:
"""Order book analysis metrics"""
symbol: str
exchange: str
timestamp: datetime
mid_price: float
spread: float
spread_percentage: float
bid_volume: float
ask_volume: float
volume_imbalance: float # (bid_volume - ask_volume) / (bid_volume + ask_volume)
depth_10: float # Volume within 10 price levels
depth_50: float # Volume within 50 price levels
def __post_init__(self):
"""Validate metrics"""
if self.mid_price <= 0:
raise ValueError("Mid price must be positive")
if self.spread < 0:
raise ValueError("Spread cannot be negative")
@dataclass
class ImbalanceMetrics:
"""Order book imbalance metrics"""
symbol: str
timestamp: datetime
volume_imbalance: float
price_imbalance: float
depth_imbalance: float
momentum_score: float # Derived from recent imbalance changes
def __post_init__(self):
"""Validate imbalance metrics"""
if not -1 <= self.volume_imbalance <= 1:
raise ValueError("Volume imbalance must be between -1 and 1")
@dataclass
class ConsolidatedOrderBook:
"""Consolidated order book from multiple exchanges"""
symbol: str
timestamp: datetime
exchanges: List[str]
bids: List[PriceLevel]
asks: List[PriceLevel]
weighted_mid_price: float
total_bid_volume: float
total_ask_volume: float
exchange_weights: Dict[str, float] = field(default_factory=dict)
def __post_init__(self):
"""Validate consolidated order book"""
if not self.exchanges:
raise ValueError("At least one exchange must be specified")
if self.weighted_mid_price <= 0:
raise ValueError("Weighted mid price must be positive")
@dataclass
class ExchangeStatus:
"""Exchange connection and health status"""
exchange: str
status: ConnectionStatus
last_message_time: Optional[datetime] = None
error_message: Optional[str] = None
connection_count: int = 0
uptime_percentage: float = 0.0
message_rate: float = 0.0 # Messages per second
def __post_init__(self):
"""Validate exchange status"""
if not self.exchange:
raise ValueError("Exchange name cannot be empty")
if not 0 <= self.uptime_percentage <= 100:
raise ValueError("Uptime percentage must be between 0 and 100")
@dataclass
class SystemMetrics:
"""System performance metrics"""
timestamp: datetime
cpu_usage: float
memory_usage: float
disk_usage: float
network_io: Dict[str, float] = field(default_factory=dict)
database_connections: int = 0
redis_connections: int = 0
active_websockets: int = 0
messages_per_second: float = 0.0
processing_latency: float = 0.0 # Milliseconds
def __post_init__(self):
"""Validate system metrics"""
if not 0 <= self.cpu_usage <= 100:
raise ValueError("CPU usage must be between 0 and 100")
if not 0 <= self.memory_usage <= 100:
raise ValueError("Memory usage must be between 0 and 100")
@dataclass
class ReplaySession:
"""Historical data replay session"""
session_id: str
start_time: datetime
end_time: datetime
speed: float # Playback speed multiplier
status: ReplayStatus
current_time: Optional[datetime] = None
progress: float = 0.0 # 0.0 to 1.0
symbols: List[str] = field(default_factory=list)
exchanges: List[str] = field(default_factory=list)
def __post_init__(self):
"""Validate replay session"""
if not self.session_id:
raise ValueError("Session ID cannot be empty")
if self.start_time >= self.end_time:
raise ValueError("Start time must be before end time")
if self.speed <= 0:
raise ValueError("Speed must be positive")
if not 0 <= self.progress <= 1:
raise ValueError("Progress must be between 0 and 1")

View File

@ -0,0 +1,15 @@
"""
Data processing and normalization components for the COBY system.
"""
from .data_processor import StandardDataProcessor
from .quality_checker import DataQualityChecker
from .anomaly_detector import AnomalyDetector
from .metrics_calculator import MetricsCalculator
__all__ = [
'StandardDataProcessor',
'DataQualityChecker',
'AnomalyDetector',
'MetricsCalculator'
]

View File

@ -0,0 +1,329 @@
"""
Anomaly detection for market data.
"""
import statistics
from typing import Dict, List, Union, Optional, Deque
from collections import deque
from datetime import datetime, timedelta
from ..models.core import OrderBookSnapshot, TradeEvent
from ..utils.logging import get_logger
from ..utils.timing import get_current_timestamp
logger = get_logger(__name__)
class AnomalyDetector:
"""
Detects anomalies in market data using statistical methods.
Detects:
- Price spikes and drops
- Volume anomalies
- Spread anomalies
- Frequency anomalies
"""
def __init__(self, window_size: int = 100, z_score_threshold: float = 3.0):
"""
Initialize anomaly detector.
Args:
window_size: Size of rolling window for statistics
z_score_threshold: Z-score threshold for anomaly detection
"""
self.window_size = window_size
self.z_score_threshold = z_score_threshold
# Rolling windows for statistics
self.price_windows: Dict[str, Deque[float]] = {}
self.volume_windows: Dict[str, Deque[float]] = {}
self.spread_windows: Dict[str, Deque[float]] = {}
self.timestamp_windows: Dict[str, Deque[datetime]] = {}
logger.info(f"Anomaly detector initialized with window_size={window_size}, threshold={z_score_threshold}")
def detect_orderbook_anomalies(self, orderbook: OrderBookSnapshot) -> List[str]:
"""
Detect anomalies in order book data.
Args:
orderbook: Order book snapshot to analyze
Returns:
List[str]: List of detected anomalies
"""
anomalies = []
key = f"{orderbook.symbol}_{orderbook.exchange}"
try:
# Price anomalies
if orderbook.mid_price:
price_anomalies = self._detect_price_anomalies(key, orderbook.mid_price)
anomalies.extend(price_anomalies)
# Volume anomalies
total_volume = orderbook.bid_volume + orderbook.ask_volume
volume_anomalies = self._detect_volume_anomalies(key, total_volume)
anomalies.extend(volume_anomalies)
# Spread anomalies
if orderbook.spread and orderbook.mid_price:
spread_pct = (orderbook.spread / orderbook.mid_price) * 100
spread_anomalies = self._detect_spread_anomalies(key, spread_pct)
anomalies.extend(spread_anomalies)
# Frequency anomalies
frequency_anomalies = self._detect_frequency_anomalies(key, orderbook.timestamp)
anomalies.extend(frequency_anomalies)
# Update windows
self._update_windows(key, orderbook)
except Exception as e:
logger.error(f"Error detecting order book anomalies: {e}")
anomalies.append(f"Anomaly detection error: {e}")
if anomalies:
logger.warning(f"Anomalies detected in {orderbook.symbol}@{orderbook.exchange}: {anomalies}")
return anomalies
def detect_trade_anomalies(self, trade: TradeEvent) -> List[str]:
"""
Detect anomalies in trade data.
Args:
trade: Trade event to analyze
Returns:
List[str]: List of detected anomalies
"""
anomalies = []
key = f"{trade.symbol}_{trade.exchange}_trade"
try:
# Price anomalies
price_anomalies = self._detect_price_anomalies(key, trade.price)
anomalies.extend(price_anomalies)
# Volume anomalies
volume_anomalies = self._detect_volume_anomalies(key, trade.size)
anomalies.extend(volume_anomalies)
# Update windows
self._update_trade_windows(key, trade)
except Exception as e:
logger.error(f"Error detecting trade anomalies: {e}")
anomalies.append(f"Anomaly detection error: {e}")
if anomalies:
logger.warning(f"Trade anomalies detected in {trade.symbol}@{trade.exchange}: {anomalies}")
return anomalies
def _detect_price_anomalies(self, key: str, price: float) -> List[str]:
"""Detect price anomalies using z-score"""
anomalies = []
if key not in self.price_windows:
self.price_windows[key] = deque(maxlen=self.window_size)
return anomalies
window = self.price_windows[key]
if len(window) < 10: # Need minimum data points
return anomalies
try:
mean_price = statistics.mean(window)
std_price = statistics.stdev(window)
if std_price > 0:
z_score = abs(price - mean_price) / std_price
if z_score > self.z_score_threshold:
direction = "spike" if price > mean_price else "drop"
anomalies.append(f"Price {direction}: {price:.6f} (z-score: {z_score:.2f})")
except statistics.StatisticsError:
pass # Not enough data or all values are the same
return anomalies
def _detect_volume_anomalies(self, key: str, volume: float) -> List[str]:
"""Detect volume anomalies using z-score"""
anomalies = []
volume_key = f"{key}_volume"
if volume_key not in self.volume_windows:
self.volume_windows[volume_key] = deque(maxlen=self.window_size)
return anomalies
window = self.volume_windows[volume_key]
if len(window) < 10:
return anomalies
try:
mean_volume = statistics.mean(window)
std_volume = statistics.stdev(window)
if std_volume > 0:
z_score = abs(volume - mean_volume) / std_volume
if z_score > self.z_score_threshold:
direction = "spike" if volume > mean_volume else "drop"
anomalies.append(f"Volume {direction}: {volume:.6f} (z-score: {z_score:.2f})")
except statistics.StatisticsError:
pass
return anomalies
def _detect_spread_anomalies(self, key: str, spread_pct: float) -> List[str]:
"""Detect spread anomalies using z-score"""
anomalies = []
spread_key = f"{key}_spread"
if spread_key not in self.spread_windows:
self.spread_windows[spread_key] = deque(maxlen=self.window_size)
return anomalies
window = self.spread_windows[spread_key]
if len(window) < 10:
return anomalies
try:
mean_spread = statistics.mean(window)
std_spread = statistics.stdev(window)
if std_spread > 0:
z_score = abs(spread_pct - mean_spread) / std_spread
if z_score > self.z_score_threshold:
direction = "widening" if spread_pct > mean_spread else "tightening"
anomalies.append(f"Spread {direction}: {spread_pct:.4f}% (z-score: {z_score:.2f})")
except statistics.StatisticsError:
pass
return anomalies
def _detect_frequency_anomalies(self, key: str, timestamp: datetime) -> List[str]:
"""Detect frequency anomalies in data updates"""
anomalies = []
timestamp_key = f"{key}_timestamp"
if timestamp_key not in self.timestamp_windows:
self.timestamp_windows[timestamp_key] = deque(maxlen=self.window_size)
return anomalies
window = self.timestamp_windows[timestamp_key]
if len(window) < 5:
return anomalies
try:
# Calculate intervals between updates
intervals = []
for i in range(1, len(window)):
interval = (window[i] - window[i-1]).total_seconds()
intervals.append(interval)
if len(intervals) >= 5:
mean_interval = statistics.mean(intervals)
std_interval = statistics.stdev(intervals)
# Check current interval
current_interval = (timestamp - window[-1]).total_seconds()
if std_interval > 0:
z_score = abs(current_interval - mean_interval) / std_interval
if z_score > self.z_score_threshold:
if current_interval > mean_interval:
anomalies.append(f"Update delay: {current_interval:.1f}s (expected: {mean_interval:.1f}s)")
else:
anomalies.append(f"Update burst: {current_interval:.1f}s (expected: {mean_interval:.1f}s)")
except (statistics.StatisticsError, IndexError):
pass
return anomalies
def _update_windows(self, key: str, orderbook: OrderBookSnapshot) -> None:
"""Update rolling windows with new data"""
# Update price window
if orderbook.mid_price:
if key not in self.price_windows:
self.price_windows[key] = deque(maxlen=self.window_size)
self.price_windows[key].append(orderbook.mid_price)
# Update volume window
total_volume = orderbook.bid_volume + orderbook.ask_volume
volume_key = f"{key}_volume"
if volume_key not in self.volume_windows:
self.volume_windows[volume_key] = deque(maxlen=self.window_size)
self.volume_windows[volume_key].append(total_volume)
# Update spread window
if orderbook.spread and orderbook.mid_price:
spread_pct = (orderbook.spread / orderbook.mid_price) * 100
spread_key = f"{key}_spread"
if spread_key not in self.spread_windows:
self.spread_windows[spread_key] = deque(maxlen=self.window_size)
self.spread_windows[spread_key].append(spread_pct)
# Update timestamp window
timestamp_key = f"{key}_timestamp"
if timestamp_key not in self.timestamp_windows:
self.timestamp_windows[timestamp_key] = deque(maxlen=self.window_size)
self.timestamp_windows[timestamp_key].append(orderbook.timestamp)
def _update_trade_windows(self, key: str, trade: TradeEvent) -> None:
"""Update rolling windows with trade data"""
# Update price window
if key not in self.price_windows:
self.price_windows[key] = deque(maxlen=self.window_size)
self.price_windows[key].append(trade.price)
# Update volume window
volume_key = f"{key}_volume"
if volume_key not in self.volume_windows:
self.volume_windows[volume_key] = deque(maxlen=self.window_size)
self.volume_windows[volume_key].append(trade.size)
def get_statistics(self) -> Dict[str, Dict[str, float]]:
"""Get current statistics for all tracked symbols"""
stats = {}
for key, window in self.price_windows.items():
if len(window) >= 2:
try:
stats[key] = {
'price_mean': statistics.mean(window),
'price_std': statistics.stdev(window),
'price_min': min(window),
'price_max': max(window),
'data_points': len(window)
}
except statistics.StatisticsError:
stats[key] = {'error': 'insufficient_data'}
return stats
def reset_windows(self, key: Optional[str] = None) -> None:
"""Reset rolling windows for a specific key or all keys"""
if key:
# Reset specific key
self.price_windows.pop(key, None)
self.volume_windows.pop(f"{key}_volume", None)
self.spread_windows.pop(f"{key}_spread", None)
self.timestamp_windows.pop(f"{key}_timestamp", None)
else:
# Reset all windows
self.price_windows.clear()
self.volume_windows.clear()
self.spread_windows.clear()
self.timestamp_windows.clear()
logger.info(f"Reset anomaly detection windows for {key or 'all keys'}")

View File

@ -0,0 +1,378 @@
"""
Main data processor implementation.
"""
from typing import Dict, Union, List, Optional, Any
from ..interfaces.data_processor import DataProcessor
from ..models.core import OrderBookSnapshot, TradeEvent, OrderBookMetrics
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ValidationError, ProcessingError
from ..utils.timing import get_current_timestamp
from .quality_checker import DataQualityChecker
from .anomaly_detector import AnomalyDetector
from .metrics_calculator import MetricsCalculator
logger = get_logger(__name__)
class StandardDataProcessor(DataProcessor):
"""
Standard implementation of data processor interface.
Provides:
- Data normalization and validation
- Quality checking
- Anomaly detection
- Metrics calculation
- Data enrichment
"""
def __init__(self):
"""Initialize data processor with components"""
self.quality_checker = DataQualityChecker()
self.anomaly_detector = AnomalyDetector()
self.metrics_calculator = MetricsCalculator()
# Processing statistics
self.processed_orderbooks = 0
self.processed_trades = 0
self.quality_failures = 0
self.anomalies_detected = 0
logger.info("Standard data processor initialized")
def normalize_orderbook(self, raw_data: Dict, exchange: str) -> OrderBookSnapshot:
"""
Normalize raw order book data to standard format.
Args:
raw_data: Raw order book data from exchange
exchange: Exchange name
Returns:
OrderBookSnapshot: Normalized order book data
"""
try:
set_correlation_id()
# This is a generic implementation - specific exchanges would override
# For now, assume data is already in correct format
if isinstance(raw_data, OrderBookSnapshot):
return raw_data
# If raw_data is a dict, try to construct OrderBookSnapshot
# This would be customized per exchange
raise NotImplementedError(
"normalize_orderbook should be implemented by exchange-specific processors"
)
except Exception as e:
logger.error(f"Error normalizing order book data: {e}")
raise ProcessingError(f"Normalization failed: {e}", "NORMALIZE_ERROR")
def normalize_trade(self, raw_data: Dict, exchange: str) -> TradeEvent:
"""
Normalize raw trade data to standard format.
Args:
raw_data: Raw trade data from exchange
exchange: Exchange name
Returns:
TradeEvent: Normalized trade data
"""
try:
set_correlation_id()
# This is a generic implementation - specific exchanges would override
if isinstance(raw_data, TradeEvent):
return raw_data
# If raw_data is a dict, try to construct TradeEvent
# This would be customized per exchange
raise NotImplementedError(
"normalize_trade should be implemented by exchange-specific processors"
)
except Exception as e:
logger.error(f"Error normalizing trade data: {e}")
raise ProcessingError(f"Normalization failed: {e}", "NORMALIZE_ERROR")
def validate_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> bool:
"""
Validate normalized data for quality and consistency.
Args:
data: Normalized data to validate
Returns:
bool: True if data is valid, False otherwise
"""
try:
set_correlation_id()
if isinstance(data, OrderBookSnapshot):
quality_score, issues = self.quality_checker.check_orderbook_quality(data)
self.processed_orderbooks += 1
if quality_score < 0.5: # Threshold for acceptable quality
self.quality_failures += 1
logger.warning(f"Low quality order book data: score={quality_score:.2f}, issues={issues}")
return False
return True
elif isinstance(data, TradeEvent):
quality_score, issues = self.quality_checker.check_trade_quality(data)
self.processed_trades += 1
if quality_score < 0.5:
self.quality_failures += 1
logger.warning(f"Low quality trade data: score={quality_score:.2f}, issues={issues}")
return False
return True
else:
logger.error(f"Unknown data type for validation: {type(data)}")
return False
except Exception as e:
logger.error(f"Error validating data: {e}")
return False
def calculate_metrics(self, orderbook: OrderBookSnapshot) -> OrderBookMetrics:
"""
Calculate metrics from order book data.
Args:
orderbook: Order book snapshot
Returns:
OrderBookMetrics: Calculated metrics
"""
try:
set_correlation_id()
return self.metrics_calculator.calculate_orderbook_metrics(orderbook)
except Exception as e:
logger.error(f"Error calculating metrics: {e}")
raise ProcessingError(f"Metrics calculation failed: {e}", "METRICS_ERROR")
def detect_anomalies(self, data: Union[OrderBookSnapshot, TradeEvent]) -> List[str]:
"""
Detect anomalies in the data.
Args:
data: Data to analyze for anomalies
Returns:
List[str]: List of detected anomaly descriptions
"""
try:
set_correlation_id()
if isinstance(data, OrderBookSnapshot):
anomalies = self.anomaly_detector.detect_orderbook_anomalies(data)
elif isinstance(data, TradeEvent):
anomalies = self.anomaly_detector.detect_trade_anomalies(data)
else:
logger.error(f"Unknown data type for anomaly detection: {type(data)}")
return ["Unknown data type"]
if anomalies:
self.anomalies_detected += len(anomalies)
return anomalies
except Exception as e:
logger.error(f"Error detecting anomalies: {e}")
return [f"Anomaly detection error: {e}"]
def filter_data(self, data: Union[OrderBookSnapshot, TradeEvent], criteria: Dict) -> bool:
"""
Filter data based on criteria.
Args:
data: Data to filter
criteria: Filtering criteria
Returns:
bool: True if data passes filter, False otherwise
"""
try:
set_correlation_id()
# Symbol filter
if 'symbols' in criteria:
allowed_symbols = criteria['symbols']
if data.symbol not in allowed_symbols:
return False
# Exchange filter
if 'exchanges' in criteria:
allowed_exchanges = criteria['exchanges']
if data.exchange not in allowed_exchanges:
return False
# Quality filter
if 'min_quality' in criteria:
min_quality = criteria['min_quality']
if isinstance(data, OrderBookSnapshot):
quality_score, _ = self.quality_checker.check_orderbook_quality(data)
elif isinstance(data, TradeEvent):
quality_score, _ = self.quality_checker.check_trade_quality(data)
else:
quality_score = 0.0
if quality_score < min_quality:
return False
# Price range filter
if 'price_range' in criteria:
price_range = criteria['price_range']
min_price, max_price = price_range
if isinstance(data, OrderBookSnapshot):
price = data.mid_price
elif isinstance(data, TradeEvent):
price = data.price
else:
return False
if price and (price < min_price or price > max_price):
return False
# Volume filter for trades
if 'min_volume' in criteria and isinstance(data, TradeEvent):
min_volume = criteria['min_volume']
if data.size < min_volume:
return False
return True
except Exception as e:
logger.error(f"Error filtering data: {e}")
return False
def enrich_data(self, data: Union[OrderBookSnapshot, TradeEvent]) -> Dict:
"""
Enrich data with additional metadata.
Args:
data: Data to enrich
Returns:
Dict: Enriched data with metadata
"""
try:
set_correlation_id()
enriched = {
'original_data': data,
'processing_timestamp': get_current_timestamp(),
'processor_version': '1.0.0'
}
# Add quality metrics
if isinstance(data, OrderBookSnapshot):
quality_score, quality_issues = self.quality_checker.check_orderbook_quality(data)
enriched['quality_score'] = quality_score
enriched['quality_issues'] = quality_issues
# Add calculated metrics
try:
metrics = self.calculate_metrics(data)
enriched['metrics'] = {
'mid_price': metrics.mid_price,
'spread': metrics.spread,
'spread_percentage': metrics.spread_percentage,
'volume_imbalance': metrics.volume_imbalance,
'depth_10': metrics.depth_10,
'depth_50': metrics.depth_50
}
except Exception as e:
enriched['metrics_error'] = str(e)
# Add liquidity score
try:
liquidity_score = self.metrics_calculator.calculate_liquidity_score(data)
enriched['liquidity_score'] = liquidity_score
except Exception as e:
enriched['liquidity_error'] = str(e)
elif isinstance(data, TradeEvent):
quality_score, quality_issues = self.quality_checker.check_trade_quality(data)
enriched['quality_score'] = quality_score
enriched['quality_issues'] = quality_issues
# Add trade-specific enrichments
enriched['trade_value'] = data.price * data.size
enriched['side_numeric'] = 1 if data.side == 'buy' else -1
# Add anomaly detection results
anomalies = self.detect_anomalies(data)
enriched['anomalies'] = anomalies
enriched['anomaly_count'] = len(anomalies)
return enriched
except Exception as e:
logger.error(f"Error enriching data: {e}")
return {
'original_data': data,
'enrichment_error': str(e)
}
def get_data_quality_score(self, data: Union[OrderBookSnapshot, TradeEvent]) -> float:
"""
Calculate data quality score.
Args:
data: Data to score
Returns:
float: Quality score between 0.0 and 1.0
"""
try:
set_correlation_id()
if isinstance(data, OrderBookSnapshot):
quality_score, _ = self.quality_checker.check_orderbook_quality(data)
elif isinstance(data, TradeEvent):
quality_score, _ = self.quality_checker.check_trade_quality(data)
else:
logger.error(f"Unknown data type for quality scoring: {type(data)}")
return 0.0
return quality_score
except Exception as e:
logger.error(f"Error calculating quality score: {e}")
return 0.0
def get_processing_stats(self) -> Dict[str, Any]:
"""Get processing statistics"""
return {
'processed_orderbooks': self.processed_orderbooks,
'processed_trades': self.processed_trades,
'quality_failures': self.quality_failures,
'anomalies_detected': self.anomalies_detected,
'quality_failure_rate': (
self.quality_failures / max(1, self.processed_orderbooks + self.processed_trades)
),
'anomaly_rate': (
self.anomalies_detected / max(1, self.processed_orderbooks + self.processed_trades)
),
'quality_checker_summary': self.quality_checker.get_quality_summary(),
'anomaly_detector_stats': self.anomaly_detector.get_statistics()
}
def reset_stats(self) -> None:
"""Reset processing statistics"""
self.processed_orderbooks = 0
self.processed_trades = 0
self.quality_failures = 0
self.anomalies_detected = 0
logger.info("Processing statistics reset")

View File

@ -0,0 +1,275 @@
"""
Metrics calculation for order book analysis.
"""
from typing import Dict, List, Optional
from ..models.core import OrderBookSnapshot, OrderBookMetrics, ImbalanceMetrics
from ..utils.logging import get_logger
logger = get_logger(__name__)
class MetricsCalculator:
"""
Calculates various metrics from order book data.
Metrics include:
- Basic metrics (mid price, spread, volumes)
- Imbalance metrics
- Depth metrics
- Liquidity metrics
"""
def __init__(self):
"""Initialize metrics calculator"""
logger.info("Metrics calculator initialized")
def calculate_orderbook_metrics(self, orderbook: OrderBookSnapshot) -> OrderBookMetrics:
"""
Calculate comprehensive order book metrics.
Args:
orderbook: Order book snapshot
Returns:
OrderBookMetrics: Calculated metrics
"""
try:
# Basic calculations
mid_price = self._calculate_mid_price(orderbook)
spread = self._calculate_spread(orderbook)
spread_percentage = (spread / mid_price * 100) if mid_price > 0 else 0.0
# Volume calculations
bid_volume = sum(level.size for level in orderbook.bids)
ask_volume = sum(level.size for level in orderbook.asks)
# Imbalance calculation
total_volume = bid_volume + ask_volume
volume_imbalance = ((bid_volume - ask_volume) / total_volume) if total_volume > 0 else 0.0
# Depth calculations
depth_10 = self._calculate_depth(orderbook, 10)
depth_50 = self._calculate_depth(orderbook, 50)
return OrderBookMetrics(
symbol=orderbook.symbol,
exchange=orderbook.exchange,
timestamp=orderbook.timestamp,
mid_price=mid_price,
spread=spread,
spread_percentage=spread_percentage,
bid_volume=bid_volume,
ask_volume=ask_volume,
volume_imbalance=volume_imbalance,
depth_10=depth_10,
depth_50=depth_50
)
except Exception as e:
logger.error(f"Error calculating order book metrics: {e}")
raise
def calculate_imbalance_metrics(self, orderbook: OrderBookSnapshot) -> ImbalanceMetrics:
"""
Calculate order book imbalance metrics.
Args:
orderbook: Order book snapshot
Returns:
ImbalanceMetrics: Calculated imbalance metrics
"""
try:
# Volume imbalance
bid_volume = sum(level.size for level in orderbook.bids)
ask_volume = sum(level.size for level in orderbook.asks)
total_volume = bid_volume + ask_volume
volume_imbalance = ((bid_volume - ask_volume) / total_volume) if total_volume > 0 else 0.0
# Price imbalance (weighted by volume)
price_imbalance = self._calculate_price_imbalance(orderbook)
# Depth imbalance
depth_imbalance = self._calculate_depth_imbalance(orderbook)
# Momentum score (simplified - would need historical data for full implementation)
momentum_score = volume_imbalance * 0.5 + price_imbalance * 0.3 + depth_imbalance * 0.2
return ImbalanceMetrics(
symbol=orderbook.symbol,
timestamp=orderbook.timestamp,
volume_imbalance=volume_imbalance,
price_imbalance=price_imbalance,
depth_imbalance=depth_imbalance,
momentum_score=momentum_score
)
except Exception as e:
logger.error(f"Error calculating imbalance metrics: {e}")
raise
def _calculate_mid_price(self, orderbook: OrderBookSnapshot) -> float:
"""Calculate mid price"""
if not orderbook.bids or not orderbook.asks:
return 0.0
best_bid = orderbook.bids[0].price
best_ask = orderbook.asks[0].price
return (best_bid + best_ask) / 2.0
def _calculate_spread(self, orderbook: OrderBookSnapshot) -> float:
"""Calculate bid-ask spread"""
if not orderbook.bids or not orderbook.asks:
return 0.0
best_bid = orderbook.bids[0].price
best_ask = orderbook.asks[0].price
return best_ask - best_bid
def _calculate_depth(self, orderbook: OrderBookSnapshot, levels: int) -> float:
"""Calculate market depth for specified number of levels"""
bid_depth = sum(
level.size for level in orderbook.bids[:levels]
)
ask_depth = sum(
level.size for level in orderbook.asks[:levels]
)
return bid_depth + ask_depth
def _calculate_price_imbalance(self, orderbook: OrderBookSnapshot) -> float:
"""Calculate price-weighted imbalance"""
if not orderbook.bids or not orderbook.asks:
return 0.0
# Calculate volume-weighted average prices for top levels
bid_vwap = self._calculate_vwap(orderbook.bids[:5])
ask_vwap = self._calculate_vwap(orderbook.asks[:5])
if bid_vwap == 0 or ask_vwap == 0:
return 0.0
mid_price = (bid_vwap + ask_vwap) / 2.0
# Normalize imbalance
price_imbalance = (bid_vwap - ask_vwap) / mid_price if mid_price > 0 else 0.0
return max(-1.0, min(1.0, price_imbalance))
def _calculate_depth_imbalance(self, orderbook: OrderBookSnapshot) -> float:
"""Calculate depth imbalance across multiple levels"""
levels_to_check = [5, 10, 20]
imbalances = []
for levels in levels_to_check:
bid_depth = sum(level.size for level in orderbook.bids[:levels])
ask_depth = sum(level.size for level in orderbook.asks[:levels])
total_depth = bid_depth + ask_depth
if total_depth > 0:
imbalance = (bid_depth - ask_depth) / total_depth
imbalances.append(imbalance)
# Return weighted average of imbalances
if imbalances:
return sum(imbalances) / len(imbalances)
return 0.0
def _calculate_vwap(self, levels: List) -> float:
"""Calculate volume-weighted average price for price levels"""
if not levels:
return 0.0
total_volume = sum(level.size for level in levels)
if total_volume == 0:
return 0.0
weighted_sum = sum(level.price * level.size for level in levels)
return weighted_sum / total_volume
def calculate_liquidity_score(self, orderbook: OrderBookSnapshot) -> float:
"""
Calculate liquidity score based on depth and spread.
Args:
orderbook: Order book snapshot
Returns:
float: Liquidity score (0.0 to 1.0)
"""
try:
if not orderbook.bids or not orderbook.asks:
return 0.0
# Spread component (lower spread = higher liquidity)
spread = self._calculate_spread(orderbook)
mid_price = self._calculate_mid_price(orderbook)
if mid_price == 0:
return 0.0
spread_pct = (spread / mid_price) * 100
spread_score = max(0.0, 1.0 - (spread_pct / 5.0)) # Normalize to 5% max spread
# Depth component (higher depth = higher liquidity)
total_depth = self._calculate_depth(orderbook, 10)
depth_score = min(1.0, total_depth / 100.0) # Normalize to 100 units max depth
# Volume balance component (more balanced = higher liquidity)
bid_volume = sum(level.size for level in orderbook.bids[:10])
ask_volume = sum(level.size for level in orderbook.asks[:10])
total_volume = bid_volume + ask_volume
if total_volume > 0:
imbalance = abs(bid_volume - ask_volume) / total_volume
balance_score = 1.0 - imbalance
else:
balance_score = 0.0
# Weighted combination
liquidity_score = (spread_score * 0.4 + depth_score * 0.4 + balance_score * 0.2)
return max(0.0, min(1.0, liquidity_score))
except Exception as e:
logger.error(f"Error calculating liquidity score: {e}")
return 0.0
def get_market_summary(self, orderbook: OrderBookSnapshot) -> Dict[str, float]:
"""
Get comprehensive market summary.
Args:
orderbook: Order book snapshot
Returns:
Dict[str, float]: Market summary metrics
"""
try:
metrics = self.calculate_orderbook_metrics(orderbook)
imbalance = self.calculate_imbalance_metrics(orderbook)
liquidity = self.calculate_liquidity_score(orderbook)
return {
'mid_price': metrics.mid_price,
'spread': metrics.spread,
'spread_percentage': metrics.spread_percentage,
'bid_volume': metrics.bid_volume,
'ask_volume': metrics.ask_volume,
'volume_imbalance': metrics.volume_imbalance,
'depth_10': metrics.depth_10,
'depth_50': metrics.depth_50,
'price_imbalance': imbalance.price_imbalance,
'depth_imbalance': imbalance.depth_imbalance,
'momentum_score': imbalance.momentum_score,
'liquidity_score': liquidity
}
except Exception as e:
logger.error(f"Error generating market summary: {e}")
return {}

View File

@ -0,0 +1,288 @@
"""
Data quality checking and validation for market data.
"""
from typing import Dict, List, Union, Optional, Tuple
from datetime import datetime, timezone
from ..models.core import OrderBookSnapshot, TradeEvent
from ..utils.logging import get_logger
from ..utils.validation import validate_price, validate_volume, validate_symbol
from ..utils.timing import get_current_timestamp
logger = get_logger(__name__)
class DataQualityChecker:
"""
Comprehensive data quality checker for market data.
Validates:
- Data structure integrity
- Price and volume ranges
- Timestamp consistency
- Cross-validation between related data points
"""
def __init__(self):
"""Initialize quality checker with default thresholds"""
# Quality thresholds
self.max_spread_percentage = 10.0 # Maximum spread as % of mid price
self.max_price_change_percentage = 50.0 # Maximum price change between updates
self.min_volume_threshold = 0.000001 # Minimum meaningful volume
self.max_timestamp_drift = 300 # Maximum seconds drift from current time
# Price history for validation
self.price_history: Dict[str, Dict[str, float]] = {} # symbol -> exchange -> last_price
logger.info("Data quality checker initialized")
def check_orderbook_quality(self, orderbook: OrderBookSnapshot) -> Tuple[float, List[str]]:
"""
Check order book data quality.
Args:
orderbook: Order book snapshot to validate
Returns:
Tuple[float, List[str]]: Quality score (0.0-1.0) and list of issues
"""
issues = []
quality_score = 1.0
try:
# Basic structure validation
structure_issues = self._check_orderbook_structure(orderbook)
issues.extend(structure_issues)
quality_score -= len(structure_issues) * 0.1
# Price validation
price_issues = self._check_orderbook_prices(orderbook)
issues.extend(price_issues)
quality_score -= len(price_issues) * 0.15
# Volume validation
volume_issues = self._check_orderbook_volumes(orderbook)
issues.extend(volume_issues)
quality_score -= len(volume_issues) * 0.1
# Spread validation
spread_issues = self._check_orderbook_spread(orderbook)
issues.extend(spread_issues)
quality_score -= len(spread_issues) * 0.2
# Timestamp validation
timestamp_issues = self._check_timestamp(orderbook.timestamp)
issues.extend(timestamp_issues)
quality_score -= len(timestamp_issues) * 0.1
# Cross-validation with history
history_issues = self._check_price_history(orderbook)
issues.extend(history_issues)
quality_score -= len(history_issues) * 0.15
# Update price history
self._update_price_history(orderbook)
except Exception as e:
logger.error(f"Error checking order book quality: {e}")
issues.append(f"Quality check error: {e}")
quality_score = 0.0
# Ensure score is within bounds
quality_score = max(0.0, min(1.0, quality_score))
if issues:
logger.debug(f"Order book quality issues for {orderbook.symbol}@{orderbook.exchange}: {issues}")
return quality_score, issues de
f check_trade_quality(self, trade: TradeEvent) -> Tuple[float, List[str]]:
"""
Check trade data quality.
Args:
trade: Trade event to validate
Returns:
Tuple[float, List[str]]: Quality score (0.0-1.0) and list of issues
"""
issues = []
quality_score = 1.0
try:
# Basic structure validation
if not validate_symbol(trade.symbol):
issues.append("Invalid symbol format")
if not trade.exchange:
issues.append("Missing exchange")
if not trade.trade_id:
issues.append("Missing trade ID")
# Price validation
if not validate_price(trade.price):
issues.append(f"Invalid price: {trade.price}")
# Volume validation
if not validate_volume(trade.size):
issues.append(f"Invalid size: {trade.size}")
if trade.size < self.min_volume_threshold:
issues.append(f"Size below threshold: {trade.size}")
# Side validation
if trade.side not in ['buy', 'sell']:
issues.append(f"Invalid side: {trade.side}")
# Timestamp validation
timestamp_issues = self._check_timestamp(trade.timestamp)
issues.extend(timestamp_issues)
# Calculate quality score
quality_score -= len(issues) * 0.2
except Exception as e:
logger.error(f"Error checking trade quality: {e}")
issues.append(f"Quality check error: {e}")
quality_score = 0.0
# Ensure score is within bounds
quality_score = max(0.0, min(1.0, quality_score))
if issues:
logger.debug(f"Trade quality issues for {trade.symbol}@{trade.exchange}: {issues}")
return quality_score, issues
def _check_orderbook_structure(self, orderbook: OrderBookSnapshot) -> List[str]:
"""Check basic order book structure"""
issues = []
if not validate_symbol(orderbook.symbol):
issues.append("Invalid symbol format")
if not orderbook.exchange:
issues.append("Missing exchange")
if not orderbook.bids:
issues.append("No bid levels")
if not orderbook.asks:
issues.append("No ask levels")
return issues
def _check_orderbook_prices(self, orderbook: OrderBookSnapshot) -> List[str]:
"""Check order book price validity"""
issues = []
# Check bid prices (should be descending)
for i, bid in enumerate(orderbook.bids):
if not validate_price(bid.price):
issues.append(f"Invalid bid price at level {i}: {bid.price}")
if i > 0 and bid.price >= orderbook.bids[i-1].price:
issues.append(f"Bid prices not descending at level {i}")
# Check ask prices (should be ascending)
for i, ask in enumerate(orderbook.asks):
if not validate_price(ask.price):
issues.append(f"Invalid ask price at level {i}: {ask.price}")
if i > 0 and ask.price <= orderbook.asks[i-1].price:
issues.append(f"Ask prices not ascending at level {i}")
# Check bid-ask ordering
if orderbook.bids and orderbook.asks:
if orderbook.bids[0].price >= orderbook.asks[0].price:
issues.append("Best bid >= best ask (crossed book)")
return issues def
_check_orderbook_volumes(self, orderbook: OrderBookSnapshot) -> List[str]:
"""Check order book volume validity"""
issues = []
# Check bid volumes
for i, bid in enumerate(orderbook.bids):
if not validate_volume(bid.size):
issues.append(f"Invalid bid volume at level {i}: {bid.size}")
if bid.size < self.min_volume_threshold:
issues.append(f"Bid volume below threshold at level {i}: {bid.size}")
# Check ask volumes
for i, ask in enumerate(orderbook.asks):
if not validate_volume(ask.size):
issues.append(f"Invalid ask volume at level {i}: {ask.size}")
if ask.size < self.min_volume_threshold:
issues.append(f"Ask volume below threshold at level {i}: {ask.size}")
return issues
def _check_orderbook_spread(self, orderbook: OrderBookSnapshot) -> List[str]:
"""Check order book spread validity"""
issues = []
if orderbook.mid_price and orderbook.spread:
spread_percentage = (orderbook.spread / orderbook.mid_price) * 100
if spread_percentage > self.max_spread_percentage:
issues.append(f"Spread too wide: {spread_percentage:.2f}%")
if spread_percentage < 0:
issues.append(f"Negative spread: {spread_percentage:.2f}%")
return issues
def _check_timestamp(self, timestamp: datetime) -> List[str]:
"""Check timestamp validity"""
issues = []
if not timestamp:
issues.append("Missing timestamp")
return issues
# Check if timestamp is timezone-aware
if timestamp.tzinfo is None:
issues.append("Timestamp missing timezone info")
# Check timestamp drift
current_time = get_current_timestamp()
time_diff = abs((timestamp - current_time).total_seconds())
if time_diff > self.max_timestamp_drift:
issues.append(f"Timestamp drift too large: {time_diff:.1f}s")
return issues
def _check_price_history(self, orderbook: OrderBookSnapshot) -> List[str]:
"""Check price consistency with history"""
issues = []
key = f"{orderbook.symbol}_{orderbook.exchange}"
if key in self.price_history and orderbook.mid_price:
last_price = self.price_history[key]
price_change = abs(orderbook.mid_price - last_price) / last_price * 100
if price_change > self.max_price_change_percentage:
issues.append(f"Large price change: {price_change:.2f}%")
return issues
def _update_price_history(self, orderbook: OrderBookSnapshot) -> None:
"""Update price history for future validation"""
if orderbook.mid_price:
key = f"{orderbook.symbol}_{orderbook.exchange}"
self.price_history[key] = orderbook.mid_price
def get_quality_summary(self) -> Dict[str, int]:
"""Get summary of quality checks performed"""
return {
'symbols_tracked': len(self.price_history),
'max_spread_percentage': self.max_spread_percentage,
'max_price_change_percentage': self.max_price_change_percentage,
'min_volume_threshold': self.min_volume_threshold,
'max_timestamp_drift': self.max_timestamp_drift
}

8
COBY/replay/__init__.py Normal file
View File

@ -0,0 +1,8 @@
"""
Historical data replay system for the COBY multi-exchange data aggregation system.
Provides configurable playback of historical market data with session management.
"""
from .replay_manager import HistoricalReplayManager
__all__ = ['HistoricalReplayManager']

View File

@ -0,0 +1,665 @@
"""
Historical data replay manager implementation.
Provides configurable playback of historical market data with session management.
"""
import asyncio
import uuid
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Optional, Callable, Any, Union
from dataclasses import replace
from ..interfaces.replay_manager import ReplayManager
from ..models.core import ReplaySession, ReplayStatus, OrderBookSnapshot, TradeEvent
from ..storage.storage_manager import StorageManager
from ..utils.logging import get_logger, set_correlation_id
from ..utils.exceptions import ReplayError, ValidationError
from ..utils.timing import get_current_timestamp
from ..config import Config
logger = get_logger(__name__)
class HistoricalReplayManager(ReplayManager):
"""
Implementation of historical data replay functionality.
Provides:
- Session-based replay management
- Configurable playback speeds
- Real-time data streaming
- Session controls (start/pause/stop/seek)
- Data filtering by symbol and exchange
"""
def __init__(self, storage_manager: StorageManager, config: Config):
"""
Initialize replay manager.
Args:
storage_manager: Storage manager for data access
config: System configuration
"""
self.storage_manager = storage_manager
self.config = config
# Session management
self.sessions: Dict[str, ReplaySession] = {}
self.session_tasks: Dict[str, asyncio.Task] = {}
self.session_callbacks: Dict[str, Dict[str, List[Callable]]] = {}
# Performance tracking
self.stats = {
'sessions_created': 0,
'sessions_completed': 0,
'sessions_failed': 0,
'total_events_replayed': 0,
'avg_replay_speed': 0.0
}
logger.info("Historical replay manager initialized")
def create_replay_session(self, start_time: datetime, end_time: datetime,
speed: float = 1.0, symbols: Optional[List[str]] = None,
exchanges: Optional[List[str]] = None) -> str:
"""Create a new replay session."""
try:
set_correlation_id()
# Validate parameters
validation_errors = self.validate_replay_request(start_time, end_time, symbols, exchanges)
if validation_errors:
raise ValidationError(f"Invalid replay request: {', '.join(validation_errors)}")
# Generate session ID
session_id = str(uuid.uuid4())
# Create session
session = ReplaySession(
session_id=session_id,
start_time=start_time,
end_time=end_time,
current_time=start_time,
speed=speed,
status=ReplayStatus.CREATED,
symbols=symbols or [],
exchanges=exchanges or [],
created_at=get_current_timestamp(),
events_replayed=0,
total_events=0,
progress=0.0
)
# Store session
self.sessions[session_id] = session
self.session_callbacks[session_id] = {
'data': [],
'status': []
}
self.stats['sessions_created'] += 1
logger.info(f"Created replay session {session_id} for {start_time} to {end_time}")
return session_id
except Exception as e:
logger.error(f"Failed to create replay session: {e}")
raise ReplayError(f"Session creation failed: {e}")
async def start_replay(self, session_id: str) -> None:
"""Start replay session."""
try:
set_correlation_id()
if session_id not in self.sessions:
raise ReplayError(f"Session {session_id} not found")
session = self.sessions[session_id]
if session.status == ReplayStatus.RUNNING:
logger.warning(f"Session {session_id} is already running")
return
# Update session status
session.status = ReplayStatus.RUNNING
session.started_at = get_current_timestamp()
# Notify status callbacks
await self._notify_status_callbacks(session_id, ReplayStatus.RUNNING)
# Start replay task
task = asyncio.create_task(self._replay_task(session_id))
self.session_tasks[session_id] = task
logger.info(f"Started replay session {session_id}")
except Exception as e:
logger.error(f"Failed to start replay session {session_id}: {e}")
await self._set_session_error(session_id, str(e))
raise ReplayError(f"Failed to start replay: {e}")
async def pause_replay(self, session_id: str) -> None:
"""Pause replay session."""
try:
if session_id not in self.sessions:
raise ReplayError(f"Session {session_id} not found")
session = self.sessions[session_id]
if session.status != ReplayStatus.RUNNING:
logger.warning(f"Session {session_id} is not running")
return
# Update session status
session.status = ReplayStatus.PAUSED
session.paused_at = get_current_timestamp()
# Cancel replay task
if session_id in self.session_tasks:
self.session_tasks[session_id].cancel()
del self.session_tasks[session_id]
# Notify status callbacks
await self._notify_status_callbacks(session_id, ReplayStatus.PAUSED)
logger.info(f"Paused replay session {session_id}")
except Exception as e:
logger.error(f"Failed to pause replay session {session_id}: {e}")
raise ReplayError(f"Failed to pause replay: {e}")
async def resume_replay(self, session_id: str) -> None:
"""Resume paused replay session."""
try:
if session_id not in self.sessions:
raise ReplayError(f"Session {session_id} not found")
session = self.sessions[session_id]
if session.status != ReplayStatus.PAUSED:
logger.warning(f"Session {session_id} is not paused")
return
# Resume from current position
await self.start_replay(session_id)
logger.info(f"Resumed replay session {session_id}")
except Exception as e:
logger.error(f"Failed to resume replay session {session_id}: {e}")
raise ReplayError(f"Failed to resume replay: {e}")
async def stop_replay(self, session_id: str) -> None:
"""Stop replay session."""
try:
if session_id not in self.sessions:
raise ReplayError(f"Session {session_id} not found")
session = self.sessions[session_id]
# Update session status
session.status = ReplayStatus.STOPPED
session.stopped_at = get_current_timestamp()
# Cancel replay task
if session_id in self.session_tasks:
self.session_tasks[session_id].cancel()
try:
await self.session_tasks[session_id]
except asyncio.CancelledError:
pass
del self.session_tasks[session_id]
# Notify status callbacks
await self._notify_status_callbacks(session_id, ReplayStatus.STOPPED)
logger.info(f"Stopped replay session {session_id}")
except Exception as e:
logger.error(f"Failed to stop replay session {session_id}: {e}")
raise ReplayError(f"Failed to stop replay: {e}")
def get_replay_status(self, session_id: str) -> Optional[ReplaySession]:
"""Get replay session status."""
return self.sessions.get(session_id)
def list_replay_sessions(self) -> List[ReplaySession]:
"""List all replay sessions."""
return list(self.sessions.values())
def delete_replay_session(self, session_id: str) -> bool:
"""Delete replay session."""
try:
if session_id not in self.sessions:
return False
# Stop session if running
if self.sessions[session_id].status == ReplayStatus.RUNNING:
asyncio.create_task(self.stop_replay(session_id))
# Clean up
del self.sessions[session_id]
if session_id in self.session_callbacks:
del self.session_callbacks[session_id]
logger.info(f"Deleted replay session {session_id}")
return True
except Exception as e:
logger.error(f"Failed to delete replay session {session_id}: {e}")
return False
def set_replay_speed(self, session_id: str, speed: float) -> bool:
"""Change replay speed for active session."""
try:
if session_id not in self.sessions:
return False
if speed <= 0:
raise ValueError("Speed must be positive")
session = self.sessions[session_id]
session.speed = speed
logger.info(f"Set replay speed to {speed}x for session {session_id}")
return True
except Exception as e:
logger.error(f"Failed to set replay speed for session {session_id}: {e}")
return False
def seek_replay(self, session_id: str, timestamp: datetime) -> bool:
"""Seek to specific timestamp in replay."""
try:
if session_id not in self.sessions:
return False
session = self.sessions[session_id]
# Validate timestamp is within session range
if timestamp < session.start_time or timestamp > session.end_time:
logger.warning(f"Seek timestamp {timestamp} outside session range")
return False
# Update current time
session.current_time = timestamp
# Recalculate progress
total_duration = (session.end_time - session.start_time).total_seconds()
elapsed_duration = (timestamp - session.start_time).total_seconds()
session.progress = elapsed_duration / total_duration if total_duration > 0 else 0.0
logger.info(f"Seeked to {timestamp} in session {session_id}")
return True
except Exception as e:
logger.error(f"Failed to seek in session {session_id}: {e}")
return False
def add_data_callback(self, session_id: str, callback: Callable) -> bool:
"""Add callback for replay data."""
try:
if session_id not in self.session_callbacks:
return False
self.session_callbacks[session_id]['data'].append(callback)
logger.debug(f"Added data callback for session {session_id}")
return True
except Exception as e:
logger.error(f"Failed to add data callback for session {session_id}: {e}")
return False
def remove_data_callback(self, session_id: str, callback: Callable) -> bool:
"""Remove data callback from replay session."""
try:
if session_id not in self.session_callbacks:
return False
callbacks = self.session_callbacks[session_id]['data']
if callback in callbacks:
callbacks.remove(callback)
logger.debug(f"Removed data callback for session {session_id}")
return True
return False
except Exception as e:
logger.error(f"Failed to remove data callback for session {session_id}: {e}")
return False
def add_status_callback(self, session_id: str, callback: Callable) -> bool:
"""Add callback for replay status changes."""
try:
if session_id not in self.session_callbacks:
return False
self.session_callbacks[session_id]['status'].append(callback)
logger.debug(f"Added status callback for session {session_id}")
return True
except Exception as e:
logger.error(f"Failed to add status callback for session {session_id}: {e}")
return False
async def get_available_data_range(self, symbol: str,
exchange: Optional[str] = None) -> Optional[Dict[str, datetime]]:
"""Get available data time range for replay."""
try:
# Query database for data range
if exchange:
query = """
SELECT
MIN(timestamp) as start_time,
MAX(timestamp) as end_time
FROM order_book_snapshots
WHERE symbol = $1 AND exchange = $2
"""
result = await self.storage_manager.connection_pool.fetchrow(query, symbol, exchange)
else:
query = """
SELECT
MIN(timestamp) as start_time,
MAX(timestamp) as end_time
FROM order_book_snapshots
WHERE symbol = $1
"""
result = await self.storage_manager.connection_pool.fetchrow(query, symbol)
if result and result['start_time'] and result['end_time']:
return {
'start': result['start_time'],
'end': result['end_time']
}
return None
except Exception as e:
logger.error(f"Failed to get data range for {symbol}: {e}")
return None
def validate_replay_request(self, start_time: datetime, end_time: datetime,
symbols: Optional[List[str]] = None,
exchanges: Optional[List[str]] = None) -> List[str]:
"""Validate replay request parameters."""
errors = []
# Validate time range
if start_time >= end_time:
errors.append("Start time must be before end time")
# Check if time range is too large (more than 30 days)
if (end_time - start_time).days > 30:
errors.append("Time range cannot exceed 30 days")
# Check if start time is too far in the past (more than 1 year)
if (get_current_timestamp() - start_time).days > 365:
errors.append("Start time cannot be more than 1 year ago")
# Validate symbols
if symbols:
for symbol in symbols:
if not symbol or len(symbol) < 3:
errors.append(f"Invalid symbol: {symbol}")
# Validate exchanges
if exchanges:
valid_exchanges = self.config.exchanges.exchanges
for exchange in exchanges:
if exchange not in valid_exchanges:
errors.append(f"Unsupported exchange: {exchange}")
return errors
async def _replay_task(self, session_id: str) -> None:
"""Main replay task that streams historical data."""
try:
session = self.sessions[session_id]
# Calculate total events for progress tracking
await self._calculate_total_events(session_id)
# Stream data
await self._stream_historical_data(session_id)
# Mark as completed
session.status = ReplayStatus.COMPLETED
session.completed_at = get_current_timestamp()
session.progress = 1.0
await self._notify_status_callbacks(session_id, ReplayStatus.COMPLETED)
self.stats['sessions_completed'] += 1
logger.info(f"Completed replay session {session_id}")
except asyncio.CancelledError:
logger.info(f"Replay session {session_id} was cancelled")
except Exception as e:
logger.error(f"Replay task failed for session {session_id}: {e}")
await self._set_session_error(session_id, str(e))
self.stats['sessions_failed'] += 1
async def _calculate_total_events(self, session_id: str) -> None:
"""Calculate total number of events for progress tracking."""
try:
session = self.sessions[session_id]
# Build query conditions
conditions = ["timestamp >= $1", "timestamp <= $2"]
params = [session.start_time, session.end_time]
param_count = 2
if session.symbols:
param_count += 1
conditions.append(f"symbol = ANY(${param_count})")
params.append(session.symbols)
if session.exchanges:
param_count += 1
conditions.append(f"exchange = ANY(${param_count})")
params.append(session.exchanges)
where_clause = " AND ".join(conditions)
# Count order book events
orderbook_query = f"""
SELECT COUNT(*) FROM order_book_snapshots
WHERE {where_clause}
"""
orderbook_count = await self.storage_manager.connection_pool.fetchval(
orderbook_query, *params
)
# Count trade events
trade_query = f"""
SELECT COUNT(*) FROM trade_events
WHERE {where_clause}
"""
trade_count = await self.storage_manager.connection_pool.fetchval(
trade_query, *params
)
session.total_events = (orderbook_count or 0) + (trade_count or 0)
logger.debug(f"Session {session_id} has {session.total_events} total events")
except Exception as e:
logger.error(f"Failed to calculate total events for session {session_id}: {e}")
session.total_events = 0
async def _stream_historical_data(self, session_id: str) -> None:
"""Stream historical data for replay session."""
session = self.sessions[session_id]
# Build query conditions
conditions = ["timestamp >= $1", "timestamp <= $2"]
params = [session.current_time, session.end_time]
param_count = 2
if session.symbols:
param_count += 1
conditions.append(f"symbol = ANY(${param_count})")
params.append(session.symbols)
if session.exchanges:
param_count += 1
conditions.append(f"exchange = ANY(${param_count})")
params.append(session.exchanges)
where_clause = " AND ".join(conditions)
# Query both order book and trade data, ordered by timestamp
query = f"""
(
SELECT 'orderbook' as type, timestamp, symbol, exchange,
bids, asks, sequence_id, mid_price, spread, bid_volume, ask_volume,
NULL as price, NULL as size, NULL as side, NULL as trade_id
FROM order_book_snapshots
WHERE {where_clause}
)
UNION ALL
(
SELECT 'trade' as type, timestamp, symbol, exchange,
NULL as bids, NULL as asks, NULL as sequence_id,
NULL as mid_price, NULL as spread, NULL as bid_volume, NULL as ask_volume,
price, size, side, trade_id
FROM trade_events
WHERE {where_clause}
)
ORDER BY timestamp ASC
"""
# Stream data in chunks
chunk_size = 1000
offset = 0
last_timestamp = session.current_time
while session.status == ReplayStatus.RUNNING:
# Fetch chunk
chunk_query = f"{query} LIMIT {chunk_size} OFFSET {offset}"
rows = await self.storage_manager.connection_pool.fetch(chunk_query, *params)
if not rows:
break
# Process each row
for row in rows:
if session.status != ReplayStatus.RUNNING:
break
# Calculate delay based on replay speed
if last_timestamp < row['timestamp']:
time_diff = (row['timestamp'] - last_timestamp).total_seconds()
delay = time_diff / session.speed
if delay > 0:
await asyncio.sleep(delay)
# Create data object
if row['type'] == 'orderbook':
data = await self._create_orderbook_from_row(row)
else:
data = await self._create_trade_from_row(row)
# Notify data callbacks
await self._notify_data_callbacks(session_id, data)
# Update session progress
session.events_replayed += 1
session.current_time = row['timestamp']
if session.total_events > 0:
session.progress = session.events_replayed / session.total_events
last_timestamp = row['timestamp']
self.stats['total_events_replayed'] += 1
offset += chunk_size
async def _create_orderbook_from_row(self, row: Dict) -> OrderBookSnapshot:
"""Create OrderBookSnapshot from database row."""
import json
from ..models.core import PriceLevel
# Parse bids and asks from JSON
bids_data = json.loads(row['bids']) if row['bids'] else []
asks_data = json.loads(row['asks']) if row['asks'] else []
bids = [PriceLevel(price=b['price'], size=b['size'], count=b.get('count'))
for b in bids_data]
asks = [PriceLevel(price=a['price'], size=a['size'], count=a.get('count'))
for a in asks_data]
return OrderBookSnapshot(
symbol=row['symbol'],
exchange=row['exchange'],
timestamp=row['timestamp'],
bids=bids,
asks=asks,
sequence_id=row['sequence_id']
)
async def _create_trade_from_row(self, row: Dict) -> TradeEvent:
"""Create TradeEvent from database row."""
return TradeEvent(
symbol=row['symbol'],
exchange=row['exchange'],
timestamp=row['timestamp'],
price=float(row['price']),
size=float(row['size']),
side=row['side'],
trade_id=row['trade_id']
)
async def _notify_data_callbacks(self, session_id: str,
data: Union[OrderBookSnapshot, TradeEvent]) -> None:
"""Notify all data callbacks for a session."""
if session_id in self.session_callbacks:
callbacks = self.session_callbacks[session_id]['data']
for callback in callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(data)
else:
callback(data)
except Exception as e:
logger.error(f"Data callback error for session {session_id}: {e}")
async def _notify_status_callbacks(self, session_id: str, status: ReplayStatus) -> None:
"""Notify all status callbacks for a session."""
if session_id in self.session_callbacks:
callbacks = self.session_callbacks[session_id]['status']
for callback in callbacks:
try:
if asyncio.iscoroutinefunction(callback):
await callback(session_id, status)
else:
callback(session_id, status)
except Exception as e:
logger.error(f"Status callback error for session {session_id}: {e}")
async def _set_session_error(self, session_id: str, error_message: str) -> None:
"""Set session to error state."""
if session_id in self.sessions:
session = self.sessions[session_id]
session.status = ReplayStatus.ERROR
session.error_message = error_message
session.stopped_at = get_current_timestamp()
await self._notify_status_callbacks(session_id, ReplayStatus.ERROR)
def get_stats(self) -> Dict[str, Any]:
"""Get replay manager statistics."""
active_sessions = sum(1 for s in self.sessions.values()
if s.status == ReplayStatus.RUNNING)
return {
**self.stats,
'active_sessions': active_sessions,
'total_sessions': len(self.sessions),
'session_statuses': {
status.value: sum(1 for s in self.sessions.values() if s.status == status)
for status in ReplayStatus
}
}

34
COBY/requirements.txt Normal file
View File

@ -0,0 +1,34 @@
# Core dependencies for COBY system
asyncpg>=0.29.0 # PostgreSQL/TimescaleDB async driver
redis>=5.0.0 # Redis client
websockets>=12.0 # WebSocket client library
aiohttp>=3.9.0 # Async HTTP client/server
fastapi>=0.104.0 # API framework
uvicorn>=0.24.0 # ASGI server
pydantic>=2.5.0 # Data validation
python-multipart>=0.0.6 # Form data parsing
# Data processing
pandas>=2.1.0 # Data manipulation
numpy>=1.24.0 # Numerical computing
scipy>=1.11.0 # Scientific computing
# Utilities
python-dotenv>=1.0.0 # Environment variable loading
structlog>=23.2.0 # Structured logging
click>=8.1.0 # CLI framework
rich>=13.7.0 # Rich text and beautiful formatting
# Development dependencies
pytest>=7.4.0 # Testing framework
pytest-asyncio>=0.21.0 # Async testing
pytest-cov>=4.1.0 # Coverage reporting
black>=23.11.0 # Code formatting
isort>=5.12.0 # Import sorting
flake8>=6.1.0 # Linting
mypy>=1.7.0 # Type checking
# Optional dependencies for enhanced features
prometheus-client>=0.19.0 # Metrics collection
grafana-api>=1.0.3 # Grafana integration
psutil>=5.9.0 # System monitoring

11
COBY/storage/__init__.py Normal file
View File

@ -0,0 +1,11 @@
"""
Storage layer for the multi-exchange data aggregation system.
Provides TimescaleDB integration, connection pooling, and schema management.
"""
from .timescale_manager import TimescaleManager
from .connection_pool import ConnectionPoolManager
from .schema import SchemaManager
from .storage_manager import StorageManager
__all__ = ['TimescaleManager', 'ConnectionPoolManager', 'SchemaManager', 'StorageManager']

View File

@ -0,0 +1,219 @@
"""
Database connection pool management with health monitoring and automatic recovery.
"""
import asyncio
import logging
from typing import Optional, Dict, Any
from datetime import datetime, timedelta
import asyncpg
from asyncpg import Pool
from ..config import Config
from ..utils.exceptions import ConnectionError
logger = logging.getLogger(__name__)
class ConnectionPoolManager:
"""Manages database connection pools with health monitoring and recovery."""
def __init__(self, config: Config):
self.config = config
self.pool: Optional[Pool] = None
self._connection_string = self._build_connection_string()
self._health_check_interval = 30 # seconds
self._health_check_task: Optional[asyncio.Task] = None
self._last_health_check = datetime.utcnow()
self._connection_failures = 0
self._max_failures = 5
def _build_connection_string(self) -> str:
"""Build PostgreSQL connection string from config."""
return (
f"postgresql://{self.config.database.user}:{self.config.database.password}"
f"@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
)
async def initialize(self) -> None:
"""Initialize connection pool with health monitoring."""
try:
logger.info("Creating database connection pool...")
self.pool = await asyncpg.create_pool(
self._connection_string,
min_size=5,
max_size=self.config.database.pool_size,
command_timeout=60,
server_settings={
'jit': 'off',
'timezone': 'UTC',
'statement_timeout': '30s',
'idle_in_transaction_session_timeout': '60s'
},
init=self._init_connection
)
# Test initial connection
await self._test_connection()
# Start health monitoring
self._health_check_task = asyncio.create_task(self._health_monitor())
logger.info(f"Database connection pool initialized with {self.config.db_min_connections}-{self.config.db_max_connections} connections")
except Exception as e:
logger.error(f"Failed to initialize connection pool: {e}")
raise ConnectionError(f"Connection pool initialization failed: {e}")
async def _init_connection(self, conn: asyncpg.Connection) -> None:
"""Initialize individual database connections."""
try:
# Set connection-specific settings
await conn.execute("SET timezone = 'UTC'")
await conn.execute("SET statement_timeout = '30s'")
# Test TimescaleDB extension
result = await conn.fetchval("SELECT extname FROM pg_extension WHERE extname = 'timescaledb'")
if not result:
logger.warning("TimescaleDB extension not found in database")
except Exception as e:
logger.error(f"Failed to initialize connection: {e}")
raise
async def _test_connection(self) -> bool:
"""Test database connection health."""
try:
async with self.pool.acquire() as conn:
await conn.execute('SELECT 1')
self._connection_failures = 0
return True
except Exception as e:
self._connection_failures += 1
logger.error(f"Connection test failed (attempt {self._connection_failures}): {e}")
if self._connection_failures >= self._max_failures:
logger.critical("Maximum connection failures reached, attempting pool recreation")
await self._recreate_pool()
return False
async def _recreate_pool(self) -> None:
"""Recreate connection pool after failures."""
try:
if self.pool:
await self.pool.close()
self.pool = None
# Wait before recreating
await asyncio.sleep(5)
self.pool = await asyncpg.create_pool(
self._connection_string,
min_size=5,
max_size=self.config.database.pool_size,
command_timeout=60,
server_settings={
'jit': 'off',
'timezone': 'UTC'
},
init=self._init_connection
)
self._connection_failures = 0
logger.info("Connection pool recreated successfully")
except Exception as e:
logger.error(f"Failed to recreate connection pool: {e}")
# Will retry on next health check
async def _health_monitor(self) -> None:
"""Background task to monitor connection pool health."""
while True:
try:
await asyncio.sleep(self._health_check_interval)
if self.pool:
await self._test_connection()
self._last_health_check = datetime.utcnow()
# Log pool statistics periodically
if datetime.utcnow().minute % 5 == 0: # Every 5 minutes
stats = self.get_pool_stats()
logger.debug(f"Connection pool stats: {stats}")
except asyncio.CancelledError:
logger.info("Health monitor task cancelled")
break
except Exception as e:
logger.error(f"Health monitor error: {e}")
async def close(self) -> None:
"""Close connection pool and stop monitoring."""
if self._health_check_task:
self._health_check_task.cancel()
try:
await self._health_check_task
except asyncio.CancelledError:
pass
if self.pool:
await self.pool.close()
logger.info("Database connection pool closed")
def get_pool_stats(self) -> Dict[str, Any]:
"""Get connection pool statistics."""
if not self.pool:
return {"status": "not_initialized"}
return {
"status": "active",
"size": self.pool.get_size(),
"max_size": self.pool.get_max_size(),
"min_size": self.pool.get_min_size(),
"connection_failures": self._connection_failures,
"last_health_check": self._last_health_check.isoformat(),
"health_check_interval": self._health_check_interval
}
def is_healthy(self) -> bool:
"""Check if connection pool is healthy."""
if not self.pool:
return False
# Check if health check is recent
time_since_check = datetime.utcnow() - self._last_health_check
if time_since_check > timedelta(seconds=self._health_check_interval * 2):
return False
# Check failure count
return self._connection_failures < self._max_failures
async def acquire(self):
"""Acquire a connection from the pool."""
if not self.pool:
raise ConnectionError("Connection pool not initialized")
return self.pool.acquire()
async def execute(self, query: str, *args) -> None:
"""Execute a query using a pooled connection."""
async with self.acquire() as conn:
return await conn.execute(query, *args)
async def fetch(self, query: str, *args) -> list:
"""Fetch multiple rows using a pooled connection."""
async with self.acquire() as conn:
return await conn.fetch(query, *args)
async def fetchrow(self, query: str, *args):
"""Fetch a single row using a pooled connection."""
async with self.acquire() as conn:
return await conn.fetchrow(query, *args)
async def fetchval(self, query: str, *args):
"""Fetch a single value using a pooled connection."""
async with self.acquire() as conn:
return await conn.fetchval(query, *args)

271
COBY/storage/migrations.py Normal file
View File

@ -0,0 +1,271 @@
"""
Database migration system for schema updates.
"""
from typing import List, Dict, Any
from datetime import datetime
from ..utils.logging import get_logger
from ..utils.exceptions import StorageError
from .connection_pool import db_pool
logger = get_logger(__name__)
class Migration:
"""Base class for database migrations"""
def __init__(self, version: str, description: str):
self.version = version
self.description = description
async def up(self) -> None:
"""Apply the migration"""
raise NotImplementedError
async def down(self) -> None:
"""Rollback the migration"""
raise NotImplementedError
class MigrationManager:
"""Manages database schema migrations"""
def __init__(self):
self.migrations: List[Migration] = []
def register_migration(self, migration: Migration) -> None:
"""Register a migration"""
self.migrations.append(migration)
# Sort by version
self.migrations.sort(key=lambda m: m.version)
async def initialize_migration_table(self) -> None:
"""Create migration tracking table"""
query = """
CREATE TABLE IF NOT EXISTS market_data.schema_migrations (
version VARCHAR(50) PRIMARY KEY,
description TEXT NOT NULL,
applied_at TIMESTAMPTZ DEFAULT NOW()
);
"""
await db_pool.execute_command(query)
logger.debug("Migration table initialized")
async def get_applied_migrations(self) -> List[str]:
"""Get list of applied migration versions"""
try:
query = "SELECT version FROM market_data.schema_migrations ORDER BY version"
rows = await db_pool.execute_query(query)
return [row['version'] for row in rows]
except Exception:
# Table might not exist yet
return []
async def apply_migration(self, migration: Migration) -> bool:
"""Apply a single migration"""
try:
logger.info(f"Applying migration {migration.version}: {migration.description}")
async with db_pool.get_transaction() as conn:
# Apply the migration
await migration.up()
# Record the migration
await conn.execute(
"INSERT INTO market_data.schema_migrations (version, description) VALUES ($1, $2)",
migration.version,
migration.description
)
logger.info(f"Migration {migration.version} applied successfully")
return True
except Exception as e:
logger.error(f"Failed to apply migration {migration.version}: {e}")
return False
async def rollback_migration(self, migration: Migration) -> bool:
"""Rollback a single migration"""
try:
logger.info(f"Rolling back migration {migration.version}: {migration.description}")
async with db_pool.get_transaction() as conn:
# Rollback the migration
await migration.down()
# Remove the migration record
await conn.execute(
"DELETE FROM market_data.schema_migrations WHERE version = $1",
migration.version
)
logger.info(f"Migration {migration.version} rolled back successfully")
return True
except Exception as e:
logger.error(f"Failed to rollback migration {migration.version}: {e}")
return False
async def migrate_up(self, target_version: str = None) -> bool:
"""Apply all pending migrations up to target version"""
try:
await self.initialize_migration_table()
applied_migrations = await self.get_applied_migrations()
pending_migrations = [
m for m in self.migrations
if m.version not in applied_migrations
]
if target_version:
pending_migrations = [
m for m in pending_migrations
if m.version <= target_version
]
if not pending_migrations:
logger.info("No pending migrations to apply")
return True
logger.info(f"Applying {len(pending_migrations)} pending migrations")
for migration in pending_migrations:
if not await self.apply_migration(migration):
return False
logger.info("All migrations applied successfully")
return True
except Exception as e:
logger.error(f"Migration failed: {e}")
return False
async def migrate_down(self, target_version: str) -> bool:
"""Rollback migrations down to target version"""
try:
applied_migrations = await self.get_applied_migrations()
migrations_to_rollback = [
m for m in reversed(self.migrations)
if m.version in applied_migrations and m.version > target_version
]
if not migrations_to_rollback:
logger.info("No migrations to rollback")
return True
logger.info(f"Rolling back {len(migrations_to_rollback)} migrations")
for migration in migrations_to_rollback:
if not await self.rollback_migration(migration):
return False
logger.info("All migrations rolled back successfully")
return True
except Exception as e:
logger.error(f"Migration rollback failed: {e}")
return False
async def get_migration_status(self) -> Dict[str, Any]:
"""Get current migration status"""
try:
applied_migrations = await self.get_applied_migrations()
status = {
'total_migrations': len(self.migrations),
'applied_migrations': len(applied_migrations),
'pending_migrations': len(self.migrations) - len(applied_migrations),
'current_version': applied_migrations[-1] if applied_migrations else None,
'latest_version': self.migrations[-1].version if self.migrations else None,
'migrations': []
}
for migration in self.migrations:
status['migrations'].append({
'version': migration.version,
'description': migration.description,
'applied': migration.version in applied_migrations
})
return status
except Exception as e:
logger.error(f"Failed to get migration status: {e}")
return {}
# Example migrations
class InitialSchemaMigration(Migration):
"""Initial schema creation migration"""
def __init__(self):
super().__init__("001", "Create initial schema and tables")
async def up(self) -> None:
"""Create initial schema"""
from .schema import DatabaseSchema
queries = DatabaseSchema.get_all_creation_queries()
for query in queries:
await db_pool.execute_command(query)
async def down(self) -> None:
"""Drop initial schema"""
# Drop tables in reverse order
tables = [
'system_metrics',
'exchange_status',
'ohlcv_data',
'heatmap_data',
'trade_events',
'order_book_snapshots'
]
for table in tables:
await db_pool.execute_command(f"DROP TABLE IF EXISTS market_data.{table} CASCADE")
class AddIndexesMigration(Migration):
"""Add performance indexes migration"""
def __init__(self):
super().__init__("002", "Add performance indexes")
async def up(self) -> None:
"""Add indexes"""
from .schema import DatabaseSchema
queries = DatabaseSchema.get_index_creation_queries()
for query in queries:
await db_pool.execute_command(query)
async def down(self) -> None:
"""Drop indexes"""
indexes = [
'idx_order_book_symbol_exchange',
'idx_order_book_timestamp',
'idx_trade_events_symbol_exchange',
'idx_trade_events_timestamp',
'idx_trade_events_price',
'idx_heatmap_symbol_bucket',
'idx_heatmap_timestamp',
'idx_ohlcv_symbol_timeframe',
'idx_ohlcv_timestamp',
'idx_exchange_status_exchange',
'idx_exchange_status_timestamp',
'idx_system_metrics_name',
'idx_system_metrics_timestamp'
]
for index in indexes:
await db_pool.execute_command(f"DROP INDEX IF EXISTS market_data.{index}")
# Global migration manager
migration_manager = MigrationManager()
# Register default migrations
migration_manager.register_migration(InitialSchemaMigration())
migration_manager.register_migration(AddIndexesMigration())

338
COBY/storage/schema.py Normal file
View File

@ -0,0 +1,338 @@
"""
Database schema management and migration system.
Handles schema versioning, migrations, and database structure updates.
"""
import logging
from typing import Dict, List, Optional
from datetime import datetime
import asyncpg
logger = logging.getLogger(__name__)
class SchemaManager:
"""Manages database schema versions and migrations."""
def __init__(self, connection_pool):
self.pool = connection_pool
self.current_version = "1.0.0"
async def initialize_schema_tracking(self) -> None:
"""Initialize schema version tracking table."""
try:
async with self.pool.acquire() as conn:
await conn.execute("""
CREATE TABLE IF NOT EXISTS schema_migrations (
version VARCHAR(20) PRIMARY KEY,
applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
description TEXT,
checksum VARCHAR(64)
);
""")
# Record initial schema version
await conn.execute("""
INSERT INTO schema_migrations (version, description)
VALUES ($1, $2)
ON CONFLICT (version) DO NOTHING
""", self.current_version, "Initial schema setup")
logger.info("Schema tracking initialized")
except Exception as e:
logger.error(f"Failed to initialize schema tracking: {e}")
raise
async def get_current_schema_version(self) -> Optional[str]:
"""Get the current schema version from database."""
try:
async with self.pool.acquire() as conn:
version = await conn.fetchval("""
SELECT version FROM schema_migrations
ORDER BY applied_at DESC LIMIT 1
""")
return version
except Exception as e:
logger.error(f"Failed to get schema version: {e}")
return None
async def apply_migration(self, version: str, description: str, sql_commands: List[str]) -> bool:
"""Apply a database migration."""
try:
async with self.pool.acquire() as conn:
async with conn.transaction():
# Check if migration already applied
existing = await conn.fetchval("""
SELECT version FROM schema_migrations WHERE version = $1
""", version)
if existing:
logger.info(f"Migration {version} already applied")
return True
# Apply migration commands
for sql_command in sql_commands:
await conn.execute(sql_command)
# Record migration
await conn.execute("""
INSERT INTO schema_migrations (version, description)
VALUES ($1, $2)
""", version, description)
logger.info(f"Applied migration {version}: {description}")
return True
except Exception as e:
logger.error(f"Failed to apply migration {version}: {e}")
return False
async def create_base_schema(self) -> bool:
"""Create the base database schema with all tables and indexes."""
migration_commands = [
# Enable TimescaleDB extension
"CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE;",
# Order book snapshots table
"""
CREATE TABLE IF NOT EXISTS order_book_snapshots (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bids JSONB NOT NULL,
asks JSONB NOT NULL,
sequence_id BIGINT,
mid_price DECIMAL(20,8),
spread DECIMAL(20,8),
bid_volume DECIMAL(30,8),
ask_volume DECIMAL(30,8),
PRIMARY KEY (timestamp, symbol, exchange)
);
""",
# Trade events table
"""
CREATE TABLE IF NOT EXISTS trade_events (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
price DECIMAL(20,8) NOT NULL,
size DECIMAL(30,8) NOT NULL,
side VARCHAR(4) NOT NULL,
trade_id VARCHAR(100) NOT NULL,
PRIMARY KEY (timestamp, symbol, exchange, trade_id)
);
""",
# Heatmap data table
"""
CREATE TABLE IF NOT EXISTS heatmap_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bucket_size DECIMAL(10,2) NOT NULL,
price_bucket DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
side VARCHAR(3) NOT NULL,
exchange_count INTEGER NOT NULL,
PRIMARY KEY (timestamp, symbol, bucket_size, price_bucket, side)
);
""",
# OHLCV data table
"""
CREATE TABLE IF NOT EXISTS ohlcv_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
timeframe VARCHAR(10) NOT NULL,
open_price DECIMAL(20,8) NOT NULL,
high_price DECIMAL(20,8) NOT NULL,
low_price DECIMAL(20,8) NOT NULL,
close_price DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
trade_count INTEGER,
PRIMARY KEY (timestamp, symbol, timeframe)
);
"""
]
return await self.apply_migration(
"1.0.0",
"Create base schema with core tables",
migration_commands
)
async def create_hypertables(self) -> bool:
"""Convert tables to TimescaleDB hypertables."""
hypertable_commands = [
"SELECT create_hypertable('order_book_snapshots', 'timestamp', if_not_exists => TRUE);",
"SELECT create_hypertable('trade_events', 'timestamp', if_not_exists => TRUE);",
"SELECT create_hypertable('heatmap_data', 'timestamp', if_not_exists => TRUE);",
"SELECT create_hypertable('ohlcv_data', 'timestamp', if_not_exists => TRUE);"
]
return await self.apply_migration(
"1.0.1",
"Convert tables to hypertables",
hypertable_commands
)
async def create_indexes(self) -> bool:
"""Create performance indexes."""
index_commands = [
# Order book snapshots indexes
"CREATE INDEX IF NOT EXISTS idx_obs_symbol_time ON order_book_snapshots (symbol, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_obs_exchange_time ON order_book_snapshots (exchange, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_obs_symbol_exchange ON order_book_snapshots (symbol, exchange, timestamp DESC);",
# Trade events indexes
"CREATE INDEX IF NOT EXISTS idx_trades_symbol_time ON trade_events (symbol, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_trades_exchange_time ON trade_events (exchange, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_trades_price ON trade_events (symbol, price, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_trades_side ON trade_events (symbol, side, timestamp DESC);",
# Heatmap data indexes
"CREATE INDEX IF NOT EXISTS idx_heatmap_symbol_time ON heatmap_data (symbol, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_heatmap_bucket ON heatmap_data (symbol, bucket_size, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_heatmap_side ON heatmap_data (symbol, side, timestamp DESC);",
# OHLCV indexes
"CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe ON ohlcv_data (symbol, timeframe, timestamp DESC);"
]
return await self.apply_migration(
"1.0.2",
"Create performance indexes",
index_commands
)
async def setup_retention_policies(self) -> bool:
"""Set up data retention policies."""
retention_commands = [
"SELECT add_retention_policy('order_book_snapshots', INTERVAL '30 days', if_not_exists => TRUE);",
"SELECT add_retention_policy('trade_events', INTERVAL '90 days', if_not_exists => TRUE);",
"SELECT add_retention_policy('heatmap_data', INTERVAL '60 days', if_not_exists => TRUE);",
"SELECT add_retention_policy('ohlcv_data', INTERVAL '1 year', if_not_exists => TRUE);"
]
return await self.apply_migration(
"1.0.3",
"Setup data retention policies",
retention_commands
)
async def create_continuous_aggregates(self) -> bool:
"""Create continuous aggregates for better query performance."""
aggregate_commands = [
# 1-minute OHLCV aggregates from trades
"""
CREATE MATERIALIZED VIEW IF NOT EXISTS trades_1m
WITH (timescaledb.continuous) AS
SELECT
time_bucket('1 minute', timestamp) AS bucket,
symbol,
exchange,
first(price, timestamp) AS open_price,
max(price) AS high_price,
min(price) AS low_price,
last(price, timestamp) AS close_price,
sum(size) AS volume,
count(*) AS trade_count
FROM trade_events
GROUP BY bucket, symbol, exchange;
""",
# 5-minute order book statistics
"""
CREATE MATERIALIZED VIEW IF NOT EXISTS orderbook_stats_5m
WITH (timescaledb.continuous) AS
SELECT
time_bucket('5 minutes', timestamp) AS bucket,
symbol,
exchange,
avg(mid_price) AS avg_mid_price,
avg(spread) AS avg_spread,
avg(bid_volume) AS avg_bid_volume,
avg(ask_volume) AS avg_ask_volume,
count(*) AS snapshot_count
FROM order_book_snapshots
WHERE mid_price IS NOT NULL
GROUP BY bucket, symbol, exchange;
"""
]
return await self.apply_migration(
"1.0.4",
"Create continuous aggregates",
aggregate_commands
)
async def setup_complete_schema(self) -> bool:
"""Set up the complete database schema with all components."""
try:
# Initialize schema tracking
await self.initialize_schema_tracking()
# Apply all migrations in order
migrations = [
self.create_base_schema,
self.create_hypertables,
self.create_indexes,
self.setup_retention_policies,
self.create_continuous_aggregates
]
for migration in migrations:
success = await migration()
if not success:
logger.error(f"Failed to apply migration: {migration.__name__}")
return False
logger.info("Complete database schema setup successful")
return True
except Exception as e:
logger.error(f"Failed to setup complete schema: {e}")
return False
async def get_schema_info(self) -> Dict:
"""Get information about the current schema state."""
try:
async with self.pool.acquire() as conn:
# Get applied migrations
migrations = await conn.fetch("""
SELECT version, applied_at, description
FROM schema_migrations
ORDER BY applied_at
""")
# Get table information
tables = await conn.fetch("""
SELECT
schemaname,
tablename,
pg_size_pretty(pg_total_relation_size(schemaname||'.'||tablename)) as size
FROM pg_tables
WHERE schemaname = 'public'
AND tablename IN ('order_book_snapshots', 'trade_events', 'heatmap_data', 'ohlcv_data')
""")
# Get hypertable information
hypertables = await conn.fetch("""
SELECT hypertable_name, num_chunks, compression_enabled
FROM timescaledb_information.hypertables
WHERE hypertable_schema = 'public'
""")
return {
"migrations": [dict(m) for m in migrations],
"tables": [dict(t) for t in tables],
"hypertables": [dict(h) for h in hypertables]
}
except Exception as e:
logger.error(f"Failed to get schema info: {e}")
return {}

View File

@ -0,0 +1,270 @@
"""
Comprehensive storage manager that integrates TimescaleDB, connection pooling, and schema management.
"""
import asyncio
import logging
from typing import Dict, List, Optional, Any
from datetime import datetime, timedelta
from .timescale_manager import TimescaleManager
from .connection_pool import ConnectionPoolManager
from .schema import SchemaManager
from ..models.core import OrderBookSnapshot, TradeEvent, HeatmapData
from ..config import Config
from ..utils.exceptions import DatabaseError, ConnectionError
logger = logging.getLogger(__name__)
class StorageManager:
"""Unified storage manager for all database operations."""
def __init__(self, config: Config):
self.config = config
self.connection_pool = ConnectionPoolManager(config)
self.schema_manager = SchemaManager(self.connection_pool)
self.timescale_manager = TimescaleManager(config)
self._initialized = False
async def initialize(self) -> None:
"""Initialize all storage components."""
try:
logger.info("Initializing storage manager...")
# Initialize connection pool
await self.connection_pool.initialize()
# Set up database schema
await self.schema_manager.setup_complete_schema()
# Initialize TimescaleDB manager with existing pool
self.timescale_manager.pool = self.connection_pool.pool
self._initialized = True
logger.info("Storage manager initialized successfully")
except Exception as e:
logger.error(f"Failed to initialize storage manager: {e}")
raise ConnectionError(f"Storage initialization failed: {e}")
async def close(self) -> None:
"""Close all storage connections."""
if self.timescale_manager:
await self.timescale_manager.close()
if self.connection_pool:
await self.connection_pool.close()
logger.info("Storage manager closed")
def is_healthy(self) -> bool:
"""Check if storage system is healthy."""
return self._initialized and self.connection_pool.is_healthy()
# Order book operations
async def store_orderbook(self, data: OrderBookSnapshot) -> bool:
"""Store order book snapshot."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.store_orderbook(data)
async def batch_store_orderbooks(self, data_list: List[OrderBookSnapshot]) -> bool:
"""Store multiple order book snapshots."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.batch_store_orderbooks(data_list)
async def get_latest_orderbook(self, symbol: str, exchange: Optional[str] = None) -> Optional[Dict]:
"""Get latest order book snapshot."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.get_latest_orderbook(symbol, exchange)
# Trade operations
async def store_trade(self, data: TradeEvent) -> bool:
"""Store trade event."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.store_trade(data)
async def batch_store_trades(self, data_list: List[TradeEvent]) -> bool:
"""Store multiple trade events."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.batch_store_trades(data_list)
# Heatmap operations
async def store_heatmap_data(self, data: HeatmapData) -> bool:
"""Store heatmap data."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.store_heatmap_data(data)
async def get_heatmap_data(self, symbol: str, bucket_size: float,
start: Optional[datetime] = None) -> List[Dict]:
"""Get heatmap data for visualization."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.get_heatmap_data(symbol, bucket_size, start)
# Historical data operations
async def get_historical_data(self, symbol: str, start: datetime, end: datetime,
data_type: str = 'orderbook') -> List[Dict]:
"""Get historical data for a symbol within time range."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
return await self.timescale_manager.get_historical_data(symbol, start, end, data_type)
# System operations
async def get_system_stats(self) -> Dict[str, Any]:
"""Get comprehensive system statistics."""
if not self._initialized:
return {"status": "not_initialized"}
try:
# Get database stats
db_stats = await self.timescale_manager.get_database_stats()
# Get connection pool stats
pool_stats = self.connection_pool.get_pool_stats()
# Get schema info
schema_info = await self.schema_manager.get_schema_info()
return {
"status": "healthy" if self.is_healthy() else "unhealthy",
"database": db_stats,
"connection_pool": pool_stats,
"schema": schema_info,
"initialized": self._initialized
}
except Exception as e:
logger.error(f"Failed to get system stats: {e}")
return {"status": "error", "error": str(e)}
async def health_check(self) -> Dict[str, Any]:
"""Perform comprehensive health check."""
health_status = {
"healthy": True,
"components": {},
"timestamp": datetime.utcnow().isoformat()
}
try:
# Check connection pool
pool_healthy = self.connection_pool.is_healthy()
health_status["components"]["connection_pool"] = {
"healthy": pool_healthy,
"stats": self.connection_pool.get_pool_stats()
}
# Test database connection
try:
async with self.connection_pool.acquire() as conn:
await conn.execute('SELECT 1')
health_status["components"]["database"] = {"healthy": True}
except Exception as e:
health_status["components"]["database"] = {
"healthy": False,
"error": str(e)
}
health_status["healthy"] = False
# Check schema version
try:
current_version = await self.schema_manager.get_current_schema_version()
health_status["components"]["schema"] = {
"healthy": True,
"version": current_version
}
except Exception as e:
health_status["components"]["schema"] = {
"healthy": False,
"error": str(e)
}
health_status["healthy"] = False
# Overall health
health_status["healthy"] = all(
comp.get("healthy", False)
for comp in health_status["components"].values()
)
except Exception as e:
logger.error(f"Health check failed: {e}")
health_status["healthy"] = False
health_status["error"] = str(e)
return health_status
# Maintenance operations
async def cleanup_old_data(self, days: int = 30) -> Dict[str, int]:
"""Clean up old data beyond retention period."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
try:
cutoff_date = datetime.utcnow() - timedelta(days=days)
cleanup_stats = {}
async with self.connection_pool.acquire() as conn:
# Clean up old order book snapshots
result = await conn.execute("""
DELETE FROM order_book_snapshots
WHERE timestamp < $1
""", cutoff_date)
cleanup_stats["order_book_snapshots"] = int(result.split()[-1])
# Clean up old trade events
result = await conn.execute("""
DELETE FROM trade_events
WHERE timestamp < $1
""", cutoff_date)
cleanup_stats["trade_events"] = int(result.split()[-1])
# Clean up old heatmap data
result = await conn.execute("""
DELETE FROM heatmap_data
WHERE timestamp < $1
""", cutoff_date)
cleanup_stats["heatmap_data"] = int(result.split()[-1])
logger.info(f"Cleaned up old data: {cleanup_stats}")
return cleanup_stats
except Exception as e:
logger.error(f"Failed to cleanup old data: {e}")
raise DatabaseError(f"Data cleanup failed: {e}")
async def optimize_database(self) -> bool:
"""Run database optimization tasks."""
if not self._initialized:
raise DatabaseError("Storage manager not initialized")
try:
async with self.connection_pool.acquire() as conn:
# Analyze tables for better query planning
tables = ['order_book_snapshots', 'trade_events', 'heatmap_data', 'ohlcv_data']
for table in tables:
await conn.execute(f"ANALYZE {table}")
# Vacuum tables to reclaim space
for table in tables:
await conn.execute(f"VACUUM {table}")
logger.info("Database optimization completed")
return True
except Exception as e:
logger.error(f"Database optimization failed: {e}")
return False

View File

@ -0,0 +1,540 @@
"""
TimescaleDB connection manager and database operations.
Provides connection pooling, schema management, and optimized time-series operations.
"""
import asyncio
import logging
from datetime import datetime
from typing import Dict, List, Optional, Any
from dataclasses import asdict
import asyncpg
from asyncpg import Pool, Connection
import json
from ..models.core import OrderBookSnapshot, TradeEvent, HeatmapData, PriceBuckets
from ..utils.exceptions import DatabaseError, ConnectionError
from ..config import Config
logger = logging.getLogger(__name__)
class TimescaleManager:
"""Manages TimescaleDB connections and operations for time-series data."""
def __init__(self, config: Config):
self.config = config
self.pool: Optional[Pool] = None
self._connection_string = self._build_connection_string()
def _build_connection_string(self) -> str:
"""Build PostgreSQL connection string from config."""
return (
f"postgresql://{self.config.database.user}:{self.config.database.password}"
f"@{self.config.database.host}:{self.config.database.port}/{self.config.database.name}"
)
async def initialize(self) -> None:
"""Initialize connection pool and database schema."""
try:
logger.info("Initializing TimescaleDB connection pool...")
self.pool = await asyncpg.create_pool(
self._connection_string,
min_size=5,
max_size=self.config.database.pool_size,
command_timeout=60,
server_settings={
'jit': 'off', # Disable JIT for better performance with time-series
'timezone': 'UTC'
}
)
# Test connection
async with self.pool.acquire() as conn:
await conn.execute('SELECT 1')
logger.info("TimescaleDB connection pool initialized successfully")
# Initialize database schema
await self.setup_database_schema()
except Exception as e:
logger.error(f"Failed to initialize TimescaleDB: {e}")
raise ConnectionError(f"TimescaleDB initialization failed: {e}")
async def close(self) -> None:
"""Close connection pool."""
if self.pool:
await self.pool.close()
logger.info("TimescaleDB connection pool closed")
async def setup_database_schema(self) -> None:
"""Create database schema with hypertables and indexes."""
try:
async with self.pool.acquire() as conn:
# Enable TimescaleDB extension
await conn.execute("CREATE EXTENSION IF NOT EXISTS timescaledb CASCADE;")
# Create order book snapshots table
await conn.execute("""
CREATE TABLE IF NOT EXISTS order_book_snapshots (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bids JSONB NOT NULL,
asks JSONB NOT NULL,
sequence_id BIGINT,
mid_price DECIMAL(20,8),
spread DECIMAL(20,8),
bid_volume DECIMAL(30,8),
ask_volume DECIMAL(30,8),
PRIMARY KEY (timestamp, symbol, exchange)
);
""")
# Convert to hypertable if not already
try:
await conn.execute("""
SELECT create_hypertable('order_book_snapshots', 'timestamp',
if_not_exists => TRUE);
""")
except Exception as e:
if "already a hypertable" not in str(e):
logger.warning(f"Could not create hypertable for order_book_snapshots: {e}")
# Create trade events table
await conn.execute("""
CREATE TABLE IF NOT EXISTS trade_events (
id BIGSERIAL,
symbol VARCHAR(20) NOT NULL,
exchange VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
price DECIMAL(20,8) NOT NULL,
size DECIMAL(30,8) NOT NULL,
side VARCHAR(4) NOT NULL,
trade_id VARCHAR(100) NOT NULL,
PRIMARY KEY (timestamp, symbol, exchange, trade_id)
);
""")
# Convert to hypertable if not already
try:
await conn.execute("""
SELECT create_hypertable('trade_events', 'timestamp',
if_not_exists => TRUE);
""")
except Exception as e:
if "already a hypertable" not in str(e):
logger.warning(f"Could not create hypertable for trade_events: {e}")
# Create heatmap data table
await conn.execute("""
CREATE TABLE IF NOT EXISTS heatmap_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
bucket_size DECIMAL(10,2) NOT NULL,
price_bucket DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
side VARCHAR(3) NOT NULL,
exchange_count INTEGER NOT NULL,
PRIMARY KEY (timestamp, symbol, bucket_size, price_bucket, side)
);
""")
# Convert to hypertable if not already
try:
await conn.execute("""
SELECT create_hypertable('heatmap_data', 'timestamp',
if_not_exists => TRUE);
""")
except Exception as e:
if "already a hypertable" not in str(e):
logger.warning(f"Could not create hypertable for heatmap_data: {e}")
# Create OHLCV data table
await conn.execute("""
CREATE TABLE IF NOT EXISTS ohlcv_data (
symbol VARCHAR(20) NOT NULL,
timestamp TIMESTAMPTZ NOT NULL,
timeframe VARCHAR(10) NOT NULL,
open_price DECIMAL(20,8) NOT NULL,
high_price DECIMAL(20,8) NOT NULL,
low_price DECIMAL(20,8) NOT NULL,
close_price DECIMAL(20,8) NOT NULL,
volume DECIMAL(30,8) NOT NULL,
trade_count INTEGER,
PRIMARY KEY (timestamp, symbol, timeframe)
);
""")
# Convert to hypertable if not already
try:
await conn.execute("""
SELECT create_hypertable('ohlcv_data', 'timestamp',
if_not_exists => TRUE);
""")
except Exception as e:
if "already a hypertable" not in str(e):
logger.warning(f"Could not create hypertable for ohlcv_data: {e}")
# Create indexes for better query performance
await self._create_indexes(conn)
# Set up data retention policies
await self._setup_retention_policies(conn)
logger.info("Database schema setup completed successfully")
except Exception as e:
logger.error(f"Failed to setup database schema: {e}")
raise DatabaseError(f"Schema setup failed: {e}")
async def _create_indexes(self, conn: Connection) -> None:
"""Create indexes for optimized queries."""
indexes = [
# Order book snapshots indexes
"CREATE INDEX IF NOT EXISTS idx_obs_symbol_time ON order_book_snapshots (symbol, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_obs_exchange_time ON order_book_snapshots (exchange, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_obs_symbol_exchange ON order_book_snapshots (symbol, exchange, timestamp DESC);",
# Trade events indexes
"CREATE INDEX IF NOT EXISTS idx_trades_symbol_time ON trade_events (symbol, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_trades_exchange_time ON trade_events (exchange, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_trades_price ON trade_events (symbol, price, timestamp DESC);",
# Heatmap data indexes
"CREATE INDEX IF NOT EXISTS idx_heatmap_symbol_time ON heatmap_data (symbol, timestamp DESC);",
"CREATE INDEX IF NOT EXISTS idx_heatmap_bucket ON heatmap_data (symbol, bucket_size, timestamp DESC);",
# OHLCV indexes
"CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe ON ohlcv_data (symbol, timeframe, timestamp DESC);"
]
for index_sql in indexes:
try:
await conn.execute(index_sql)
except Exception as e:
logger.warning(f"Could not create index: {e}")
async def _setup_retention_policies(self, conn: Connection) -> None:
"""Set up data retention policies for automatic cleanup."""
try:
# Keep raw order book data for 30 days
await conn.execute("""
SELECT add_retention_policy('order_book_snapshots', INTERVAL '30 days',
if_not_exists => TRUE);
""")
# Keep trade events for 90 days
await conn.execute("""
SELECT add_retention_policy('trade_events', INTERVAL '90 days',
if_not_exists => TRUE);
""")
# Keep heatmap data for 60 days
await conn.execute("""
SELECT add_retention_policy('heatmap_data', INTERVAL '60 days',
if_not_exists => TRUE);
""")
# Keep OHLCV data for 1 year
await conn.execute("""
SELECT add_retention_policy('ohlcv_data', INTERVAL '1 year',
if_not_exists => TRUE);
""")
logger.info("Data retention policies configured")
except Exception as e:
logger.warning(f"Could not set up retention policies: {e}")
async def store_orderbook(self, data: OrderBookSnapshot) -> bool:
"""Store order book snapshot in database."""
try:
async with self.pool.acquire() as conn:
# Calculate derived metrics
mid_price = None
spread = None
bid_volume = sum(level.size for level in data.bids)
ask_volume = sum(level.size for level in data.asks)
if data.bids and data.asks:
best_bid = max(data.bids, key=lambda x: x.price).price
best_ask = min(data.asks, key=lambda x: x.price).price
mid_price = (best_bid + best_ask) / 2
spread = best_ask - best_bid
# Convert price levels to JSON
bids_json = json.dumps([asdict(level) for level in data.bids])
asks_json = json.dumps([asdict(level) for level in data.asks])
await conn.execute("""
INSERT INTO order_book_snapshots
(symbol, exchange, timestamp, bids, asks, sequence_id,
mid_price, spread, bid_volume, ask_volume)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (timestamp, symbol, exchange) DO UPDATE SET
bids = EXCLUDED.bids,
asks = EXCLUDED.asks,
sequence_id = EXCLUDED.sequence_id,
mid_price = EXCLUDED.mid_price,
spread = EXCLUDED.spread,
bid_volume = EXCLUDED.bid_volume,
ask_volume = EXCLUDED.ask_volume
""", data.symbol, data.exchange, data.timestamp, bids_json, asks_json,
data.sequence_id, mid_price, spread, bid_volume, ask_volume)
return True
except Exception as e:
logger.error(f"Failed to store order book data: {e}")
raise DatabaseError(f"Order book storage failed: {e}")
async def store_trade(self, data: TradeEvent) -> bool:
"""Store trade event in database."""
try:
async with self.pool.acquire() as conn:
await conn.execute("""
INSERT INTO trade_events
(symbol, exchange, timestamp, price, size, side, trade_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (timestamp, symbol, exchange, trade_id) DO NOTHING
""", data.symbol, data.exchange, data.timestamp, data.price,
data.size, data.side, data.trade_id)
return True
except Exception as e:
logger.error(f"Failed to store trade data: {e}")
raise DatabaseError(f"Trade storage failed: {e}")
async def store_heatmap_data(self, data: HeatmapData) -> bool:
"""Store heatmap data in database."""
try:
async with self.pool.acquire() as conn:
# Prepare batch insert data
insert_data = []
for point in data.data:
insert_data.append((
data.symbol, data.timestamp, data.bucket_size,
point.price, point.volume, point.side, 1 # exchange_count
))
if insert_data:
await conn.executemany("""
INSERT INTO heatmap_data
(symbol, timestamp, bucket_size, price_bucket, volume, side, exchange_count)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (timestamp, symbol, bucket_size, price_bucket, side)
DO UPDATE SET
volume = heatmap_data.volume + EXCLUDED.volume,
exchange_count = heatmap_data.exchange_count + EXCLUDED.exchange_count
""", insert_data)
return True
except Exception as e:
logger.error(f"Failed to store heatmap data: {e}")
raise DatabaseError(f"Heatmap storage failed: {e}")
async def get_latest_orderbook(self, symbol: str, exchange: Optional[str] = None) -> Optional[Dict]:
"""Get latest order book snapshot for a symbol."""
try:
async with self.pool.acquire() as conn:
if exchange:
query = """
SELECT * FROM order_book_snapshots
WHERE symbol = $1 AND exchange = $2
ORDER BY timestamp DESC LIMIT 1
"""
row = await conn.fetchrow(query, symbol, exchange)
else:
query = """
SELECT * FROM order_book_snapshots
WHERE symbol = $1
ORDER BY timestamp DESC LIMIT 1
"""
row = await conn.fetchrow(query, symbol)
if row:
return dict(row)
return None
except Exception as e:
logger.error(f"Failed to get latest order book: {e}")
raise DatabaseError(f"Order book retrieval failed: {e}")
async def get_historical_data(self, symbol: str, start: datetime, end: datetime,
data_type: str = 'orderbook') -> List[Dict]:
"""Get historical data for a symbol within time range."""
try:
async with self.pool.acquire() as conn:
if data_type == 'orderbook':
query = """
SELECT * FROM order_book_snapshots
WHERE symbol = $1 AND timestamp >= $2 AND timestamp <= $3
ORDER BY timestamp ASC
"""
elif data_type == 'trades':
query = """
SELECT * FROM trade_events
WHERE symbol = $1 AND timestamp >= $2 AND timestamp <= $3
ORDER BY timestamp ASC
"""
elif data_type == 'heatmap':
query = """
SELECT * FROM heatmap_data
WHERE symbol = $1 AND timestamp >= $2 AND timestamp <= $3
ORDER BY timestamp ASC
"""
else:
raise ValueError(f"Unknown data type: {data_type}")
rows = await conn.fetch(query, symbol, start, end)
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"Failed to get historical data: {e}")
raise DatabaseError(f"Historical data retrieval failed: {e}")
async def get_heatmap_data(self, symbol: str, bucket_size: float,
start: Optional[datetime] = None) -> List[Dict]:
"""Get heatmap data for visualization."""
try:
async with self.pool.acquire() as conn:
if start:
query = """
SELECT price_bucket, volume, side, timestamp
FROM heatmap_data
WHERE symbol = $1 AND bucket_size = $2 AND timestamp >= $3
ORDER BY timestamp DESC, price_bucket ASC
"""
rows = await conn.fetch(query, symbol, bucket_size, start)
else:
# Get latest heatmap data
query = """
SELECT price_bucket, volume, side, timestamp
FROM heatmap_data
WHERE symbol = $1 AND bucket_size = $2
AND timestamp = (
SELECT MAX(timestamp) FROM heatmap_data
WHERE symbol = $1 AND bucket_size = $2
)
ORDER BY price_bucket ASC
"""
rows = await conn.fetch(query, symbol, bucket_size)
return [dict(row) for row in rows]
except Exception as e:
logger.error(f"Failed to get heatmap data: {e}")
raise DatabaseError(f"Heatmap data retrieval failed: {e}")
async def batch_store_orderbooks(self, data_list: List[OrderBookSnapshot]) -> bool:
"""Store multiple order book snapshots in a single transaction."""
if not data_list:
return True
try:
async with self.pool.acquire() as conn:
async with conn.transaction():
insert_data = []
for data in data_list:
# Calculate derived metrics
mid_price = None
spread = None
bid_volume = sum(level.size for level in data.bids)
ask_volume = sum(level.size for level in data.asks)
if data.bids and data.asks:
best_bid = max(data.bids, key=lambda x: x.price).price
best_ask = min(data.asks, key=lambda x: x.price).price
mid_price = (best_bid + best_ask) / 2
spread = best_ask - best_bid
# Convert price levels to JSON
bids_json = json.dumps([asdict(level) for level in data.bids])
asks_json = json.dumps([asdict(level) for level in data.asks])
insert_data.append((
data.symbol, data.exchange, data.timestamp, bids_json, asks_json,
data.sequence_id, mid_price, spread, bid_volume, ask_volume
))
await conn.executemany("""
INSERT INTO order_book_snapshots
(symbol, exchange, timestamp, bids, asks, sequence_id,
mid_price, spread, bid_volume, ask_volume)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
ON CONFLICT (timestamp, symbol, exchange) DO UPDATE SET
bids = EXCLUDED.bids,
asks = EXCLUDED.asks,
sequence_id = EXCLUDED.sequence_id,
mid_price = EXCLUDED.mid_price,
spread = EXCLUDED.spread,
bid_volume = EXCLUDED.bid_volume,
ask_volume = EXCLUDED.ask_volume
""", insert_data)
return True
except Exception as e:
logger.error(f"Failed to batch store order books: {e}")
raise DatabaseError(f"Batch order book storage failed: {e}")
async def batch_store_trades(self, data_list: List[TradeEvent]) -> bool:
"""Store multiple trade events in a single transaction."""
if not data_list:
return True
try:
async with self.pool.acquire() as conn:
async with conn.transaction():
insert_data = [(
data.symbol, data.exchange, data.timestamp, data.price,
data.size, data.side, data.trade_id
) for data in data_list]
await conn.executemany("""
INSERT INTO trade_events
(symbol, exchange, timestamp, price, size, side, trade_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ON CONFLICT (timestamp, symbol, exchange, trade_id) DO NOTHING
""", insert_data)
return True
except Exception as e:
logger.error(f"Failed to batch store trades: {e}")
raise DatabaseError(f"Batch trade storage failed: {e}")
async def get_database_stats(self) -> Dict[str, Any]:
"""Get database statistics and health information."""
try:
async with self.pool.acquire() as conn:
stats = {}
# Get table sizes
tables = ['order_book_snapshots', 'trade_events', 'heatmap_data', 'ohlcv_data']
for table in tables:
row = await conn.fetchrow(f"""
SELECT
COUNT(*) as row_count,
pg_size_pretty(pg_total_relation_size('{table}')) as size
FROM {table}
""")
stats[table] = dict(row)
# Get connection pool stats
stats['connection_pool'] = {
'size': self.pool.get_size(),
'max_size': self.pool.get_max_size(),
'min_size': self.pool.get_min_size()
}
return stats
except Exception as e:
logger.error(f"Failed to get database stats: {e}")
return {}

274
COBY/test_integration.py Normal file
View File

@ -0,0 +1,274 @@
#!/usr/bin/env python3
"""
Integration test script for COBY system components.
Run this to test the TimescaleDB integration and basic functionality.
"""
import asyncio
import sys
from datetime import datetime, timezone
from pathlib import Path
# Add COBY to path
sys.path.insert(0, str(Path(__file__).parent))
from config import config
from storage.timescale_manager import TimescaleManager
from models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from utils.logging import setup_logging, get_logger
# Setup logging
setup_logging(level='INFO', console_output=True)
logger = get_logger(__name__)
async def test_database_connection():
"""Test basic database connectivity"""
logger.info("🔌 Testing database connection...")
try:
manager = TimescaleManager()
await manager.initialize()
# Test health check
is_healthy = await manager.health_check()
if is_healthy:
logger.info("✅ Database connection: HEALTHY")
else:
logger.error("❌ Database connection: UNHEALTHY")
return False
# Test storage stats
stats = await manager.get_storage_stats()
logger.info(f"📊 Found {len(stats.get('table_sizes', []))} tables")
for table_info in stats.get('table_sizes', []):
logger.info(f" 📋 {table_info['table']}: {table_info['size']}")
await manager.close()
return True
except Exception as e:
logger.error(f"❌ Database test failed: {e}")
return False
async def test_data_storage():
"""Test storing and retrieving data"""
logger.info("💾 Testing data storage operations...")
try:
manager = TimescaleManager()
await manager.initialize()
# Create test order book
test_orderbook = OrderBookSnapshot(
symbol="BTCUSDT",
exchange="test_exchange",
timestamp=datetime.now(timezone.utc),
bids=[
PriceLevel(price=50000.0, size=1.5, count=3),
PriceLevel(price=49999.0, size=2.0, count=5)
],
asks=[
PriceLevel(price=50001.0, size=1.0, count=2),
PriceLevel(price=50002.0, size=1.5, count=4)
],
sequence_id=12345
)
# Test storing order book
result = await manager.store_orderbook(test_orderbook)
if result:
logger.info("✅ Order book storage: SUCCESS")
else:
logger.error("❌ Order book storage: FAILED")
return False
# Test retrieving order book
retrieved = await manager.get_latest_orderbook("BTCUSDT", "test_exchange")
if retrieved:
logger.info(f"✅ Order book retrieval: SUCCESS (mid_price: {retrieved.mid_price})")
else:
logger.error("❌ Order book retrieval: FAILED")
return False
# Create test trade
test_trade = TradeEvent(
symbol="BTCUSDT",
exchange="test_exchange",
timestamp=datetime.now(timezone.utc),
price=50000.5,
size=0.1,
side="buy",
trade_id="test_trade_123"
)
# Test storing trade
result = await manager.store_trade(test_trade)
if result:
logger.info("✅ Trade storage: SUCCESS")
else:
logger.error("❌ Trade storage: FAILED")
return False
await manager.close()
return True
except Exception as e:
logger.error(f"❌ Data storage test failed: {e}")
return False
async def test_batch_operations():
"""Test batch storage operations"""
logger.info("📦 Testing batch operations...")
try:
manager = TimescaleManager()
await manager.initialize()
# Create batch of order books
orderbooks = []
for i in range(5):
orderbook = OrderBookSnapshot(
symbol="ETHUSDT",
exchange="test_exchange",
timestamp=datetime.now(timezone.utc),
bids=[PriceLevel(price=3000.0 + i, size=1.0)],
asks=[PriceLevel(price=3001.0 + i, size=1.0)],
sequence_id=i
)
orderbooks.append(orderbook)
# Test batch storage
result = await manager.batch_store_orderbooks(orderbooks)
if result == 5:
logger.info(f"✅ Batch order book storage: SUCCESS ({result} records)")
else:
logger.error(f"❌ Batch order book storage: PARTIAL ({result}/5 records)")
return False
# Create batch of trades
trades = []
for i in range(10):
trade = TradeEvent(
symbol="ETHUSDT",
exchange="test_exchange",
timestamp=datetime.now(timezone.utc),
price=3000.0 + (i * 0.1),
size=0.05,
side="buy" if i % 2 == 0 else "sell",
trade_id=f"batch_trade_{i}"
)
trades.append(trade)
# Test batch trade storage
result = await manager.batch_store_trades(trades)
if result == 10:
logger.info(f"✅ Batch trade storage: SUCCESS ({result} records)")
else:
logger.error(f"❌ Batch trade storage: PARTIAL ({result}/10 records)")
return False
await manager.close()
return True
except Exception as e:
logger.error(f"❌ Batch operations test failed: {e}")
return False
async def test_configuration():
"""Test configuration system"""
logger.info("⚙️ Testing configuration system...")
try:
# Test database configuration
db_url = config.get_database_url()
logger.info(f"✅ Database URL: {db_url.replace(config.database.password, '***')}")
# Test Redis configuration
redis_url = config.get_redis_url()
logger.info(f"✅ Redis URL: {redis_url.replace(config.redis.password, '***')}")
# Test bucket sizes
btc_bucket = config.get_bucket_size('BTCUSDT')
eth_bucket = config.get_bucket_size('ETHUSDT')
logger.info(f"✅ Bucket sizes: BTC=${btc_bucket}, ETH=${eth_bucket}")
# Test configuration dict
config_dict = config.to_dict()
logger.info(f"✅ Configuration loaded: {len(config_dict)} sections")
return True
except Exception as e:
logger.error(f"❌ Configuration test failed: {e}")
return False
async def run_all_tests():
"""Run all integration tests"""
logger.info("🚀 Starting COBY Integration Tests")
logger.info("=" * 50)
tests = [
("Configuration", test_configuration),
("Database Connection", test_database_connection),
("Data Storage", test_data_storage),
("Batch Operations", test_batch_operations)
]
results = []
for test_name, test_func in tests:
logger.info(f"\n🧪 Running {test_name} test...")
try:
result = await test_func()
results.append((test_name, result))
if result:
logger.info(f"{test_name}: PASSED")
else:
logger.error(f"{test_name}: FAILED")
except Exception as e:
logger.error(f"{test_name}: ERROR - {e}")
results.append((test_name, False))
# Summary
logger.info("\n" + "=" * 50)
logger.info("📋 TEST SUMMARY")
logger.info("=" * 50)
passed = sum(1 for _, result in results if result)
total = len(results)
for test_name, result in results:
status = "✅ PASSED" if result else "❌ FAILED"
logger.info(f"{test_name:20} {status}")
logger.info(f"\nOverall: {passed}/{total} tests passed")
if passed == total:
logger.info("🎉 All tests passed! System is ready.")
return True
else:
logger.error("⚠️ Some tests failed. Check configuration and database connection.")
return False
if __name__ == "__main__":
print("COBY Integration Test Suite")
print("=" * 30)
# Run tests
success = asyncio.run(run_all_tests())
if success:
print("\n🎉 Integration tests completed successfully!")
print("The system is ready for the next development phase.")
sys.exit(0)
else:
print("\n❌ Integration tests failed!")
print("Please check the logs and fix any issues before proceeding.")
sys.exit(1)

3
COBY/tests/__init__.py Normal file
View File

@ -0,0 +1,3 @@
"""
Test modules for the COBY multi-exchange data aggregation system.
"""

View File

@ -0,0 +1,271 @@
"""
Comprehensive tests for all exchange connectors.
Tests the consistency and compatibility across all implemented connectors.
"""
import asyncio
import pytest
from unittest.mock import Mock, AsyncMock
from ..connectors.binance_connector import BinanceConnector
from ..connectors.coinbase_connector import CoinbaseConnector
from ..connectors.kraken_connector import KrakenConnector
from ..connectors.bybit_connector import BybitConnector
from ..connectors.okx_connector import OKXConnector
from ..connectors.huobi_connector import HuobiConnector
class TestAllConnectors:
"""Test suite for all exchange connectors."""
@pytest.fixture
def all_connectors(self):
"""Create instances of all connectors for testing."""
return {
'binance': BinanceConnector(),
'coinbase': CoinbaseConnector(use_sandbox=True),
'kraken': KrakenConnector(),
'bybit': BybitConnector(use_testnet=True),
'okx': OKXConnector(use_demo=True),
'huobi': HuobiConnector()
}
def test_all_connectors_initialization(self, all_connectors):
"""Test that all connectors initialize correctly."""
expected_names = ['binance', 'coinbase', 'kraken', 'bybit', 'okx', 'huobi']
for name, connector in all_connectors.items():
assert connector.exchange_name == name
assert hasattr(connector, 'websocket_url')
assert hasattr(connector, 'message_handlers')
assert hasattr(connector, 'subscriptions')
def test_interface_consistency(self, all_connectors):
"""Test that all connectors implement the required interface methods."""
required_methods = [
'connect',
'disconnect',
'subscribe_orderbook',
'subscribe_trades',
'unsubscribe_orderbook',
'unsubscribe_trades',
'get_symbols',
'get_orderbook_snapshot',
'normalize_symbol',
'get_connection_status',
'add_data_callback',
'remove_data_callback'
]
for name, connector in all_connectors.items():
for method in required_methods:
assert hasattr(connector, method), f"{name} missing method {method}"
assert callable(getattr(connector, method)), f"{name}.{method} not callable"
def test_symbol_normalization_consistency(self, all_connectors):
"""Test symbol normalization across all connectors."""
test_symbols = ['BTCUSDT', 'ETHUSDT', 'btcusdt', 'BTC-USDT', 'BTC/USDT']
for name, connector in all_connectors.items():
for symbol in test_symbols:
try:
normalized = connector.normalize_symbol(symbol)
assert isinstance(normalized, str)
assert len(normalized) > 0
print(f"{name}: {symbol} -> {normalized}")
except Exception as e:
print(f"{name} failed to normalize {symbol}: {e}")
@pytest.mark.asyncio
async def test_subscription_interface(self, all_connectors):
"""Test subscription interface consistency."""
for name, connector in all_connectors.items():
# Mock the _send_message method
connector._send_message = AsyncMock(return_value=True)
try:
# Test order book subscription
await connector.subscribe_orderbook('BTCUSDT')
assert 'BTCUSDT' in connector.subscriptions
# Test trade subscription
await connector.subscribe_trades('ETHUSDT')
assert 'ETHUSDT' in connector.subscriptions
# Test unsubscription
await connector.unsubscribe_orderbook('BTCUSDT')
await connector.unsubscribe_trades('ETHUSDT')
print(f"{name} subscription interface works")
except Exception as e:
print(f"{name} subscription interface failed: {e}")
def test_message_type_detection(self, all_connectors):
"""Test message type detection across connectors."""
# Test with generic message structures
test_messages = [
{'type': 'test'},
{'event': 'test'},
{'op': 'test'},
{'ch': 'test'},
{'topic': 'test'},
[1, {}, 'test', 'symbol'], # Kraken format
{'unknown': 'data'}
]
for name, connector in all_connectors.items():
for msg in test_messages:
try:
msg_type = connector._get_message_type(msg)
assert isinstance(msg_type, str)
print(f"{name}: {msg} -> {msg_type}")
except Exception as e:
print(f"{name} failed to detect message type for {msg}: {e}")
def test_statistics_interface(self, all_connectors):
"""Test statistics interface consistency."""
for name, connector in all_connectors.items():
try:
stats = connector.get_stats()
assert isinstance(stats, dict)
assert 'exchange' in stats
assert stats['exchange'] == name
assert 'connection_status' in stats
print(f"{name} statistics interface works")
except Exception as e:
print(f"{name} statistics interface failed: {e}")
def test_callback_system(self, all_connectors):
"""Test callback system consistency."""
for name, connector in all_connectors.items():
try:
# Test adding callback
def test_callback(data):
pass
connector.add_data_callback(test_callback)
assert test_callback in connector.data_callbacks
# Test removing callback
connector.remove_data_callback(test_callback)
assert test_callback not in connector.data_callbacks
print(f"{name} callback system works")
except Exception as e:
print(f"{name} callback system failed: {e}")
def test_connection_status(self, all_connectors):
"""Test connection status interface."""
for name, connector in all_connectors.items():
try:
status = connector.get_connection_status()
assert hasattr(status, 'value') # Should be an enum
# Test is_connected property
is_connected = connector.is_connected
assert isinstance(is_connected, bool)
print(f"{name} connection status interface works")
except Exception as e:
print(f"{name} connection status interface failed: {e}")
async def test_connector_compatibility():
"""Test compatibility across all connectors."""
print("=== Testing All Exchange Connectors ===")
connectors = {
'binance': BinanceConnector(),
'coinbase': CoinbaseConnector(use_sandbox=True),
'kraken': KrakenConnector(),
'bybit': BybitConnector(use_testnet=True),
'okx': OKXConnector(use_demo=True),
'huobi': HuobiConnector()
}
# Test basic functionality
for name, connector in connectors.items():
try:
print(f"\nTesting {name.upper()} connector:")
# Test initialization
assert connector.exchange_name == name
print(f" ✓ Initialization: {connector.exchange_name}")
# Test symbol normalization
btc_symbol = connector.normalize_symbol('BTCUSDT')
eth_symbol = connector.normalize_symbol('ETHUSDT')
print(f" ✓ Symbol normalization: BTCUSDT -> {btc_symbol}, ETHUSDT -> {eth_symbol}")
# Test message type detection
test_msg = {'type': 'test'} if name != 'kraken' else [1, {}, 'test', 'symbol']
msg_type = connector._get_message_type(test_msg)
print(f" ✓ Message type detection: {msg_type}")
# Test statistics
stats = connector.get_stats()
print(f" ✓ Statistics: {len(stats)} fields")
# Test connection status
status = connector.get_connection_status()
print(f" ✓ Connection status: {status.value}")
print(f"{name.upper()} connector passed all tests")
except Exception as e:
print(f"{name.upper()} connector failed: {e}")
print("\n=== All Connector Tests Completed ===")
return True
async def test_multi_connector_data_flow():
"""Test data flow across multiple connectors simultaneously."""
print("=== Testing Multi-Connector Data Flow ===")
connectors = {
'binance': BinanceConnector(),
'coinbase': CoinbaseConnector(use_sandbox=True),
'kraken': KrakenConnector()
}
# Set up data collection
received_data = {name: [] for name in connectors.keys()}
def create_callback(exchange_name):
def callback(data):
received_data[exchange_name].append(data)
print(f"Received data from {exchange_name}: {type(data).__name__}")
return callback
# Add callbacks to all connectors
for name, connector in connectors.items():
connector.add_data_callback(create_callback(name))
connector._send_message = AsyncMock(return_value=True)
# Test subscription to same symbol across exchanges
symbol = 'BTCUSDT'
for name, connector in connectors.items():
try:
await connector.subscribe_orderbook(symbol)
await connector.subscribe_trades(symbol)
print(f"✓ Subscribed to {symbol} on {name}")
except Exception as e:
print(f"✗ Failed to subscribe to {symbol} on {name}: {e}")
print("Multi-connector data flow test completed")
return True
if __name__ == "__main__":
# Run all tests
async def run_all_tests():
await test_connector_compatibility()
await test_multi_connector_data_flow()
print("✅ All connector tests completed successfully")
asyncio.run(run_all_tests())

View File

@ -0,0 +1,341 @@
"""
Tests for Binance exchange connector.
"""
import pytest
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timezone
from ..connectors.binance_connector import BinanceConnector
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
@pytest.fixture
def binance_connector():
"""Create Binance connector for testing"""
return BinanceConnector()
@pytest.fixture
def sample_binance_orderbook_data():
"""Sample Binance order book data"""
return {
"lastUpdateId": 1027024,
"bids": [
["4.00000000", "431.00000000"],
["3.99000000", "9.00000000"]
],
"asks": [
["4.00000200", "12.00000000"],
["4.01000000", "18.00000000"]
]
}
@pytest.fixture
def sample_binance_depth_update():
"""Sample Binance depth update message"""
return {
"e": "depthUpdate",
"E": 1672515782136,
"s": "BTCUSDT",
"U": 157,
"u": 160,
"b": [
["50000.00", "0.25"],
["49999.00", "0.50"]
],
"a": [
["50001.00", "0.30"],
["50002.00", "0.40"]
]
}
@pytest.fixture
def sample_binance_trade_update():
"""Sample Binance trade update message"""
return {
"e": "trade",
"E": 1672515782136,
"s": "BTCUSDT",
"t": 12345,
"p": "50000.50",
"q": "0.10",
"b": 88,
"a": 50,
"T": 1672515782134,
"m": False,
"M": True
}
class TestBinanceConnector:
"""Test cases for BinanceConnector"""
def test_initialization(self, binance_connector):
"""Test connector initialization"""
assert binance_connector.exchange_name == "binance"
assert binance_connector.websocket_url == BinanceConnector.WEBSOCKET_URL
assert len(binance_connector.message_handlers) >= 3
assert binance_connector.stream_id == 1
assert binance_connector.active_streams == []
def test_normalize_symbol(self, binance_connector):
"""Test symbol normalization"""
# Test standard format
assert binance_connector.normalize_symbol("BTCUSDT") == "BTCUSDT"
# Test with separators
assert binance_connector.normalize_symbol("BTC-USDT") == "BTCUSDT"
assert binance_connector.normalize_symbol("BTC/USDT") == "BTCUSDT"
# Test lowercase
assert binance_connector.normalize_symbol("btcusdt") == "BTCUSDT"
# Test invalid symbol
with pytest.raises(Exception):
binance_connector.normalize_symbol("")
def test_get_message_type(self, binance_connector):
"""Test message type detection"""
# Test depth update
depth_msg = {"e": "depthUpdate", "s": "BTCUSDT"}
assert binance_connector._get_message_type(depth_msg) == "depthUpdate"
# Test trade update
trade_msg = {"e": "trade", "s": "BTCUSDT"}
assert binance_connector._get_message_type(trade_msg) == "trade"
# Test error message
error_msg = {"error": {"code": -1121, "msg": "Invalid symbol"}}
assert binance_connector._get_message_type(error_msg) == "error"
# Test unknown message
unknown_msg = {"data": "something"}
assert binance_connector._get_message_type(unknown_msg) == "unknown"
def test_parse_orderbook_snapshot(self, binance_connector, sample_binance_orderbook_data):
"""Test order book snapshot parsing"""
orderbook = binance_connector._parse_orderbook_snapshot(
sample_binance_orderbook_data,
"BTCUSDT"
)
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == "BTCUSDT"
assert orderbook.exchange == "binance"
assert len(orderbook.bids) == 2
assert len(orderbook.asks) == 2
assert orderbook.sequence_id == 1027024
# Check bid data
assert orderbook.bids[0].price == 4.0
assert orderbook.bids[0].size == 431.0
# Check ask data
assert orderbook.asks[0].price == 4.000002
assert orderbook.asks[0].size == 12.0
@pytest.mark.asyncio
async def test_handle_orderbook_update(self, binance_connector, sample_binance_depth_update):
"""Test order book update handling"""
# Mock callback
callback_called = False
received_data = None
def mock_callback(data):
nonlocal callback_called, received_data
callback_called = True
received_data = data
binance_connector.add_data_callback(mock_callback)
# Handle update
await binance_connector._handle_orderbook_update(sample_binance_depth_update)
# Verify callback was called
assert callback_called
assert isinstance(received_data, OrderBookSnapshot)
assert received_data.symbol == "BTCUSDT"
assert received_data.exchange == "binance"
assert len(received_data.bids) == 2
assert len(received_data.asks) == 2
@pytest.mark.asyncio
async def test_handle_trade_update(self, binance_connector, sample_binance_trade_update):
"""Test trade update handling"""
# Mock callback
callback_called = False
received_data = None
def mock_callback(data):
nonlocal callback_called, received_data
callback_called = True
received_data = data
binance_connector.add_data_callback(mock_callback)
# Handle update
await binance_connector._handle_trade_update(sample_binance_trade_update)
# Verify callback was called
assert callback_called
assert isinstance(received_data, TradeEvent)
assert received_data.symbol == "BTCUSDT"
assert received_data.exchange == "binance"
assert received_data.price == 50000.50
assert received_data.size == 0.10
assert received_data.side == "buy" # m=False means buyer is not maker
assert received_data.trade_id == "12345"
@pytest.mark.asyncio
async def test_subscribe_orderbook(self, binance_connector):
"""Test order book subscription"""
# Mock WebSocket send
binance_connector._send_message = AsyncMock(return_value=True)
# Subscribe
await binance_connector.subscribe_orderbook("BTCUSDT")
# Verify subscription was sent
binance_connector._send_message.assert_called_once()
call_args = binance_connector._send_message.call_args[0][0]
assert call_args["method"] == "SUBSCRIBE"
assert "btcusdt@depth@100ms" in call_args["params"]
assert call_args["id"] == 1
# Verify tracking
assert "BTCUSDT" in binance_connector.subscriptions
assert "orderbook" in binance_connector.subscriptions["BTCUSDT"]
assert "btcusdt@depth@100ms" in binance_connector.active_streams
assert binance_connector.stream_id == 2
@pytest.mark.asyncio
async def test_subscribe_trades(self, binance_connector):
"""Test trade subscription"""
# Mock WebSocket send
binance_connector._send_message = AsyncMock(return_value=True)
# Subscribe
await binance_connector.subscribe_trades("ETHUSDT")
# Verify subscription was sent
binance_connector._send_message.assert_called_once()
call_args = binance_connector._send_message.call_args[0][0]
assert call_args["method"] == "SUBSCRIBE"
assert "ethusdt@trade" in call_args["params"]
assert call_args["id"] == 1
# Verify tracking
assert "ETHUSDT" in binance_connector.subscriptions
assert "trades" in binance_connector.subscriptions["ETHUSDT"]
assert "ethusdt@trade" in binance_connector.active_streams
@pytest.mark.asyncio
async def test_unsubscribe_orderbook(self, binance_connector):
"""Test order book unsubscription"""
# Setup initial subscription
binance_connector.subscriptions["BTCUSDT"] = ["orderbook"]
binance_connector.active_streams.append("btcusdt@depth@100ms")
# Mock WebSocket send
binance_connector._send_message = AsyncMock(return_value=True)
# Unsubscribe
await binance_connector.unsubscribe_orderbook("BTCUSDT")
# Verify unsubscription was sent
binance_connector._send_message.assert_called_once()
call_args = binance_connector._send_message.call_args[0][0]
assert call_args["method"] == "UNSUBSCRIBE"
assert "btcusdt@depth@100ms" in call_args["params"]
# Verify tracking removal
assert "BTCUSDT" not in binance_connector.subscriptions
assert "btcusdt@depth@100ms" not in binance_connector.active_streams
@pytest.mark.asyncio
@patch('aiohttp.ClientSession.get')
async def test_get_symbols(self, mock_get, binance_connector):
"""Test getting available symbols"""
# Mock API response
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value={
"symbols": [
{"symbol": "BTCUSDT", "status": "TRADING"},
{"symbol": "ETHUSDT", "status": "TRADING"},
{"symbol": "ADAUSDT", "status": "BREAK"} # Should be filtered out
]
})
mock_get.return_value.__aenter__.return_value = mock_response
# Get symbols
symbols = await binance_connector.get_symbols()
# Verify results
assert len(symbols) == 2
assert "BTCUSDT" in symbols
assert "ETHUSDT" in symbols
assert "ADAUSDT" not in symbols # Filtered out due to status
@pytest.mark.asyncio
@patch('aiohttp.ClientSession.get')
async def test_get_orderbook_snapshot(self, mock_get, binance_connector, sample_binance_orderbook_data):
"""Test getting order book snapshot"""
# Mock API response
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=sample_binance_orderbook_data)
mock_get.return_value.__aenter__.return_value = mock_response
# Get order book snapshot
orderbook = await binance_connector.get_orderbook_snapshot("BTCUSDT", depth=20)
# Verify results
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == "BTCUSDT"
assert orderbook.exchange == "binance"
assert len(orderbook.bids) == 2
assert len(orderbook.asks) == 2
def test_get_binance_stats(self, binance_connector):
"""Test getting Binance-specific statistics"""
# Add some test data
binance_connector.active_streams = ["btcusdt@depth@100ms", "ethusdt@trade"]
binance_connector.stream_id = 5
stats = binance_connector.get_binance_stats()
# Verify Binance-specific stats
assert stats['active_streams'] == 2
assert len(stats['stream_list']) == 2
assert stats['next_stream_id'] == 5
# Verify base stats are included
assert 'exchange' in stats
assert 'connection_status' in stats
assert 'message_count' in stats
if __name__ == "__main__":
# Run a simple test
async def simple_test():
connector = BinanceConnector()
# Test symbol normalization
normalized = connector.normalize_symbol("BTC-USDT")
print(f"Symbol normalization: BTC-USDT -> {normalized}")
# Test message type detection
msg_type = connector._get_message_type({"e": "depthUpdate"})
print(f"Message type detection: {msg_type}")
print("Simple Binance connector test completed")
asyncio.run(simple_test())

View File

@ -0,0 +1,321 @@
"""
Unit tests for Bybit exchange connector.
"""
import asyncio
import pytest
from unittest.mock import Mock, AsyncMock, patch
from datetime import datetime, timezone
from ..connectors.bybit_connector import BybitConnector
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
class TestBybitConnector:
"""Test suite for Bybit connector."""
@pytest.fixture
def connector(self):
"""Create connector instance for testing."""
return BybitConnector(use_testnet=True)
def test_initialization(self, connector):
"""Test connector initializes correctly."""
assert connector.exchange_name == "bybit"
assert connector.use_testnet is True
assert connector.TESTNET_URL in connector.websocket_url
assert 'orderbook' in connector.message_handlers
assert 'publicTrade' in connector.message_handlers
def test_symbol_normalization(self, connector):
"""Test symbol normalization to Bybit format."""
# Test standard conversions (Bybit uses same format as Binance)
assert connector.normalize_symbol('BTCUSDT') == 'BTCUSDT'
assert connector.normalize_symbol('ETHUSDT') == 'ETHUSDT'
assert connector.normalize_symbol('btcusdt') == 'BTCUSDT'
# Test with separators
assert connector.normalize_symbol('BTC-USDT') == 'BTCUSDT'
assert connector.normalize_symbol('BTC/USDT') == 'BTCUSDT'
def test_message_type_detection(self, connector):
"""Test message type detection."""
# Test orderbook message
orderbook_message = {
'topic': 'orderbook.50.BTCUSDT',
'data': {'b': [], 'a': []}
}
assert connector._get_message_type(orderbook_message) == 'orderbook'
# Test trade message
trade_message = {
'topic': 'publicTrade.BTCUSDT',
'data': []
}
assert connector._get_message_type(trade_message) == 'publicTrade'
# Test operation message
op_message = {'op': 'subscribe', 'success': True}
assert connector._get_message_type(op_message) == 'subscribe'
# Test response message
response_message = {'success': True, 'ret_msg': 'OK'}
assert connector._get_message_type(response_message) == 'response'
@pytest.mark.asyncio
async def test_subscription_methods(self, connector):
"""Test subscription and unsubscription methods."""
# Mock the _send_message method
connector._send_message = AsyncMock(return_value=True)
# Test order book subscription
await connector.subscribe_orderbook('BTCUSDT')
# Verify subscription was tracked
assert 'BTCUSDT' in connector.subscriptions
assert 'orderbook' in connector.subscriptions['BTCUSDT']
assert 'orderbook.50.BTCUSDT' in connector.subscribed_topics
# Verify correct message was sent
connector._send_message.assert_called()
call_args = connector._send_message.call_args[0][0]
assert call_args['op'] == 'subscribe'
assert 'orderbook.50.BTCUSDT' in call_args['args']
# Test trade subscription
await connector.subscribe_trades('ETHUSDT')
assert 'ETHUSDT' in connector.subscriptions
assert 'trades' in connector.subscriptions['ETHUSDT']
assert 'publicTrade.ETHUSDT' in connector.subscribed_topics
# Test unsubscription
await connector.unsubscribe_orderbook('BTCUSDT')
# Verify unsubscription
if 'BTCUSDT' in connector.subscriptions:
assert 'orderbook' not in connector.subscriptions['BTCUSDT']
@pytest.mark.asyncio
async def test_orderbook_snapshot_parsing(self, connector):
"""Test parsing order book snapshot data."""
# Mock order book data from Bybit
mock_data = {
'u': 12345,
'ts': 1609459200000,
'b': [
['50000.00', '1.5'],
['49999.00', '2.0']
],
'a': [
['50001.00', '1.2'],
['50002.00', '0.8']
]
}
# Parse the data
orderbook = connector._parse_orderbook_snapshot(mock_data, 'BTCUSDT')
# Verify parsing
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == 'BTCUSDT'
assert orderbook.exchange == 'bybit'
assert orderbook.sequence_id == 12345
# Verify bids
assert len(orderbook.bids) == 2
assert orderbook.bids[0].price == 50000.00
assert orderbook.bids[0].size == 1.5
# Verify asks
assert len(orderbook.asks) == 2
assert orderbook.asks[0].price == 50001.00
assert orderbook.asks[0].size == 1.2
@pytest.mark.asyncio
async def test_orderbook_update_handling(self, connector):
"""Test handling order book update messages."""
# Mock callback
callback_called = False
received_data = None
def mock_callback(data):
nonlocal callback_called, received_data
callback_called = True
received_data = data
connector.add_data_callback(mock_callback)
# Mock Bybit orderbook update message
update_message = {
'topic': 'orderbook.50.BTCUSDT',
'ts': 1609459200000,
'data': {
'u': 12345,
'b': [['50000.00', '1.5']],
'a': [['50001.00', '1.2']]
}
}
# Handle the message
await connector._handle_orderbook_update(update_message)
# Verify callback was called
assert callback_called
assert isinstance(received_data, OrderBookSnapshot)
assert received_data.symbol == 'BTCUSDT'
assert received_data.exchange == 'bybit'
assert received_data.sequence_id == 12345
@pytest.mark.asyncio
async def test_trade_handling(self, connector):
"""Test handling trade messages."""
# Mock callback
callback_called = False
received_trades = []
def mock_callback(data):
nonlocal callback_called
callback_called = True
received_trades.append(data)
connector.add_data_callback(mock_callback)
# Mock Bybit trade message
trade_message = {
'topic': 'publicTrade.BTCUSDT',
'ts': 1609459200000,
'data': [
{
'T': 1609459200000,
'p': '50000.50',
'v': '0.1',
'S': 'Buy',
'i': '12345'
}
]
}
# Handle the message
await connector._handle_trade_update(trade_message)
# Verify callback was called
assert callback_called
assert len(received_trades) == 1
trade = received_trades[0]
assert isinstance(trade, TradeEvent)
assert trade.symbol == 'BTCUSDT'
assert trade.exchange == 'bybit'
assert trade.price == 50000.50
assert trade.size == 0.1
assert trade.side == 'buy'
assert trade.trade_id == '12345'
@pytest.mark.asyncio
async def test_get_symbols(self, connector):
"""Test getting available symbols."""
# Mock HTTP response
mock_response_data = {
'retCode': 0,
'result': {
'list': [
{
'symbol': 'BTCUSDT',
'status': 'Trading'
},
{
'symbol': 'ETHUSDT',
'status': 'Trading'
},
{
'symbol': 'DISABLEDUSDT',
'status': 'Closed'
}
]
}
}
with patch('aiohttp.ClientSession.get') as mock_get:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=mock_response_data)
mock_get.return_value.__aenter__.return_value = mock_response
symbols = await connector.get_symbols()
# Should only return trading symbols
assert 'BTCUSDT' in symbols
assert 'ETHUSDT' in symbols
assert 'DISABLEDUSDT' not in symbols
@pytest.mark.asyncio
async def test_get_orderbook_snapshot(self, connector):
"""Test getting order book snapshot from REST API."""
# Mock HTTP response
mock_orderbook = {
'retCode': 0,
'result': {
'ts': 1609459200000,
'u': 12345,
'b': [['50000.00', '1.5']],
'a': [['50001.00', '1.2']]
}
}
with patch('aiohttp.ClientSession.get') as mock_get:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=mock_orderbook)
mock_get.return_value.__aenter__.return_value = mock_response
orderbook = await connector.get_orderbook_snapshot('BTCUSDT')
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == 'BTCUSDT'
assert orderbook.exchange == 'bybit'
assert len(orderbook.bids) == 1
assert len(orderbook.asks) == 1
def test_statistics(self, connector):
"""Test getting connector statistics."""
# Add some test data
connector.subscribed_topics.add('orderbook.50.BTCUSDT')
stats = connector.get_bybit_stats()
assert stats['exchange'] == 'bybit'
assert 'orderbook.50.BTCUSDT' in stats['subscribed_topics']
assert stats['use_testnet'] is True
assert 'authenticated' in stats
async def test_bybit_integration():
"""Integration test for Bybit connector."""
connector = BybitConnector(use_testnet=True)
try:
# Test basic functionality
assert connector.exchange_name == "bybit"
# Test symbol normalization
assert connector.normalize_symbol('BTCUSDT') == 'BTCUSDT'
assert connector.normalize_symbol('btc-usdt') == 'BTCUSDT'
# Test message type detection
test_message = {'topic': 'orderbook.50.BTCUSDT', 'data': {}}
assert connector._get_message_type(test_message) == 'orderbook'
print("✓ Bybit connector integration test passed")
return True
except Exception as e:
print(f"✗ Bybit connector integration test failed: {e}")
return False
if __name__ == "__main__":
# Run integration test
success = asyncio.run(test_bybit_integration())
if not success:
exit(1)

View File

@ -0,0 +1,364 @@
"""
Unit tests for Coinbase exchange connector.
"""
import asyncio
import pytest
from unittest.mock import Mock, AsyncMock, patch
from datetime import datetime, timezone
from ..connectors.coinbase_connector import CoinbaseConnector
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
class TestCoinbaseConnector:
"""Test suite for Coinbase connector."""
@pytest.fixture
def connector(self):
"""Create connector instance for testing."""
return CoinbaseConnector(use_sandbox=True)
def test_initialization(self, connector):
"""Test connector initializes correctly."""
assert connector.exchange_name == "coinbase"
assert connector.use_sandbox is True
assert connector.SANDBOX_URL in connector.websocket_url
assert 'l2update' in connector.message_handlers
assert 'match' in connector.message_handlers
def test_symbol_normalization(self, connector):
"""Test symbol normalization to Coinbase format."""
# Test standard conversions
assert connector.normalize_symbol('BTCUSDT') == 'BTC-USD'
assert connector.normalize_symbol('ETHUSDT') == 'ETH-USD'
assert connector.normalize_symbol('ADAUSDT') == 'ADA-USD'
# Test generic conversion
assert connector.normalize_symbol('LINKUSDT') == 'LINK-USD'
# Test already correct format
assert connector.normalize_symbol('BTC-USD') == 'BTC-USD'
def test_symbol_denormalization(self, connector):
"""Test converting Coinbase format back to standard."""
assert connector._denormalize_symbol('BTC-USD') == 'BTCUSDT'
assert connector._denormalize_symbol('ETH-USD') == 'ETHUSDT'
assert connector._denormalize_symbol('ADA-USD') == 'ADAUSDT'
# Test other quote currencies
assert connector._denormalize_symbol('BTC-EUR') == 'BTCEUR'
def test_message_type_detection(self, connector):
"""Test message type detection."""
# Test l2update message
l2_message = {'type': 'l2update', 'product_id': 'BTC-USD'}
assert connector._get_message_type(l2_message) == 'l2update'
# Test match message
match_message = {'type': 'match', 'product_id': 'BTC-USD'}
assert connector._get_message_type(match_message) == 'match'
# Test error message
error_message = {'type': 'error', 'message': 'Invalid signature'}
assert connector._get_message_type(error_message) == 'error'
# Test unknown message
unknown_message = {'data': 'something'}
assert connector._get_message_type(unknown_message) == 'unknown'
@pytest.mark.asyncio
async def test_subscription_methods(self, connector):
"""Test subscription and unsubscription methods."""
# Mock the _send_message method
connector._send_message = AsyncMock(return_value=True)
# Test order book subscription
await connector.subscribe_orderbook('BTCUSDT')
# Verify subscription was tracked
assert 'BTCUSDT' in connector.subscriptions
assert 'orderbook' in connector.subscriptions['BTCUSDT']
assert 'level2' in connector.subscribed_channels
assert 'BTC-USD' in connector.product_ids
# Verify correct message was sent
connector._send_message.assert_called()
call_args = connector._send_message.call_args[0][0]
assert call_args['type'] == 'subscribe'
assert 'BTC-USD' in call_args['product_ids']
assert 'level2' in call_args['channels']
# Test trade subscription
await connector.subscribe_trades('ETHUSDT')
assert 'ETHUSDT' in connector.subscriptions
assert 'trades' in connector.subscriptions['ETHUSDT']
assert 'matches' in connector.subscribed_channels
assert 'ETH-USD' in connector.product_ids
# Test unsubscription
await connector.unsubscribe_orderbook('BTCUSDT')
# Verify unsubscription
if 'BTCUSDT' in connector.subscriptions:
assert 'orderbook' not in connector.subscriptions['BTCUSDT']
@pytest.mark.asyncio
async def test_orderbook_snapshot_parsing(self, connector):
"""Test parsing order book snapshot data."""
# Mock order book data from Coinbase
mock_data = {
'sequence': 12345,
'bids': [
['50000.00', '1.5', 1],
['49999.00', '2.0', 2]
],
'asks': [
['50001.00', '1.2', 1],
['50002.00', '0.8', 1]
]
}
# Parse the data
orderbook = connector._parse_orderbook_snapshot(mock_data, 'BTCUSDT')
# Verify parsing
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == 'BTCUSDT'
assert orderbook.exchange == 'coinbase'
assert orderbook.sequence_id == 12345
# Verify bids
assert len(orderbook.bids) == 2
assert orderbook.bids[0].price == 50000.00
assert orderbook.bids[0].size == 1.5
assert orderbook.bids[1].price == 49999.00
assert orderbook.bids[1].size == 2.0
# Verify asks
assert len(orderbook.asks) == 2
assert orderbook.asks[0].price == 50001.00
assert orderbook.asks[0].size == 1.2
assert orderbook.asks[1].price == 50002.00
assert orderbook.asks[1].size == 0.8
@pytest.mark.asyncio
async def test_orderbook_update_handling(self, connector):
"""Test handling order book l2update messages."""
# Mock callback
callback_called = False
received_data = None
def mock_callback(data):
nonlocal callback_called, received_data
callback_called = True
received_data = data
connector.add_data_callback(mock_callback)
# Mock l2update message
update_message = {
'type': 'l2update',
'product_id': 'BTC-USD',
'time': '2023-01-01T12:00:00.000000Z',
'changes': [
['buy', '50000.00', '1.5'],
['sell', '50001.00', '1.2']
]
}
# Handle the message
await connector._handle_orderbook_update(update_message)
# Verify callback was called
assert callback_called
assert isinstance(received_data, OrderBookSnapshot)
assert received_data.symbol == 'BTCUSDT'
assert received_data.exchange == 'coinbase'
# Verify bids and asks
assert len(received_data.bids) == 1
assert received_data.bids[0].price == 50000.00
assert received_data.bids[0].size == 1.5
assert len(received_data.asks) == 1
assert received_data.asks[0].price == 50001.00
assert received_data.asks[0].size == 1.2
@pytest.mark.asyncio
async def test_trade_handling(self, connector):
"""Test handling trade (match) messages."""
# Mock callback
callback_called = False
received_data = None
def mock_callback(data):
nonlocal callback_called, received_data
callback_called = True
received_data = data
connector.add_data_callback(mock_callback)
# Mock match message
trade_message = {
'type': 'match',
'product_id': 'BTC-USD',
'time': '2023-01-01T12:00:00.000000Z',
'price': '50000.50',
'size': '0.1',
'side': 'buy',
'trade_id': 12345
}
# Handle the message
await connector._handle_trade_update(trade_message)
# Verify callback was called
assert callback_called
assert isinstance(received_data, TradeEvent)
assert received_data.symbol == 'BTCUSDT'
assert received_data.exchange == 'coinbase'
assert received_data.price == 50000.50
assert received_data.size == 0.1
assert received_data.side == 'buy'
assert received_data.trade_id == '12345'
@pytest.mark.asyncio
async def test_error_handling(self, connector):
"""Test error message handling."""
# Test error message
error_message = {
'type': 'error',
'message': 'Invalid signature',
'reason': 'Authentication failed'
}
# Should not raise exception
await connector._handle_error_message(error_message)
@pytest.mark.asyncio
async def test_get_symbols(self, connector):
"""Test getting available symbols."""
# Mock HTTP response
mock_products = [
{
'id': 'BTC-USD',
'status': 'online',
'trading_disabled': False
},
{
'id': 'ETH-USD',
'status': 'online',
'trading_disabled': False
},
{
'id': 'DISABLED-USD',
'status': 'offline',
'trading_disabled': True
}
]
with patch('aiohttp.ClientSession.get') as mock_get:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=mock_products)
mock_get.return_value.__aenter__.return_value = mock_response
symbols = await connector.get_symbols()
# Should only return online, enabled symbols
assert 'BTCUSDT' in symbols
assert 'ETHUSDT' in symbols
assert 'DISABLEDUSDT' not in symbols
@pytest.mark.asyncio
async def test_get_orderbook_snapshot(self, connector):
"""Test getting order book snapshot from REST API."""
# Mock HTTP response
mock_orderbook = {
'sequence': 12345,
'bids': [['50000.00', '1.5']],
'asks': [['50001.00', '1.2']]
}
with patch('aiohttp.ClientSession.get') as mock_get:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=mock_orderbook)
mock_get.return_value.__aenter__.return_value = mock_response
orderbook = await connector.get_orderbook_snapshot('BTCUSDT')
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == 'BTCUSDT'
assert orderbook.exchange == 'coinbase'
assert len(orderbook.bids) == 1
assert len(orderbook.asks) == 1
def test_authentication_headers(self, connector):
"""Test authentication header generation."""
# Set up credentials
connector.api_key = 'test_key'
connector.api_secret = 'dGVzdF9zZWNyZXQ=' # base64 encoded 'test_secret'
connector.passphrase = 'test_passphrase'
# Test message
test_message = {'type': 'subscribe', 'channels': ['level2']}
# Generate headers
headers = connector._get_auth_headers(test_message)
# Verify headers are present
assert 'CB-ACCESS-KEY' in headers
assert 'CB-ACCESS-SIGN' in headers
assert 'CB-ACCESS-TIMESTAMP' in headers
assert 'CB-ACCESS-PASSPHRASE' in headers
assert headers['CB-ACCESS-KEY'] == 'test_key'
assert headers['CB-ACCESS-PASSPHRASE'] == 'test_passphrase'
def test_statistics(self, connector):
"""Test getting connector statistics."""
# Add some test data
connector.subscribed_channels.add('level2')
connector.product_ids.add('BTC-USD')
stats = connector.get_coinbase_stats()
assert stats['exchange'] == 'coinbase'
assert 'level2' in stats['subscribed_channels']
assert 'BTC-USD' in stats['product_ids']
assert stats['use_sandbox'] is True
assert 'authenticated' in stats
async def test_coinbase_integration():
"""Integration test for Coinbase connector."""
connector = CoinbaseConnector(use_sandbox=True)
try:
# Test basic functionality
assert connector.exchange_name == "coinbase"
# Test symbol normalization
assert connector.normalize_symbol('BTCUSDT') == 'BTC-USD'
assert connector._denormalize_symbol('BTC-USD') == 'BTCUSDT'
# Test message type detection
test_message = {'type': 'l2update', 'product_id': 'BTC-USD'}
assert connector._get_message_type(test_message) == 'l2update'
print("✓ Coinbase connector integration test passed")
return True
except Exception as e:
print(f"✗ Coinbase connector integration test failed: {e}")
return False
if __name__ == "__main__":
# Run integration test
success = asyncio.run(test_coinbase_integration())
if not success:
exit(1)

View File

@ -0,0 +1,304 @@
"""
Tests for data processing components.
"""
import pytest
from datetime import datetime, timezone
from ..processing.data_processor import StandardDataProcessor
from ..processing.quality_checker import DataQualityChecker
from ..processing.anomaly_detector import AnomalyDetector
from ..processing.metrics_calculator import MetricsCalculator
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
@pytest.fixture
def data_processor():
"""Create data processor for testing"""
return StandardDataProcessor()
@pytest.fixture
def quality_checker():
"""Create quality checker for testing"""
return DataQualityChecker()
@pytest.fixture
def anomaly_detector():
"""Create anomaly detector for testing"""
return AnomalyDetector()
@pytest.fixture
def metrics_calculator():
"""Create metrics calculator for testing"""
return MetricsCalculator()
@pytest.fixture
def sample_orderbook():
"""Create sample order book for testing"""
return OrderBookSnapshot(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
bids=[
PriceLevel(price=50000.0, size=1.5),
PriceLevel(price=49999.0, size=2.0),
PriceLevel(price=49998.0, size=1.0)
],
asks=[
PriceLevel(price=50001.0, size=1.0),
PriceLevel(price=50002.0, size=1.5),
PriceLevel(price=50003.0, size=2.0)
]
)
@pytest.fixture
def sample_trade():
"""Create sample trade for testing"""
return TradeEvent(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
price=50000.5,
size=0.1,
side="buy",
trade_id="test_trade_123"
)
class TestDataQualityChecker:
"""Test cases for DataQualityChecker"""
def test_orderbook_quality_check(self, quality_checker, sample_orderbook):
"""Test order book quality checking"""
quality_score, issues = quality_checker.check_orderbook_quality(sample_orderbook)
assert 0.0 <= quality_score <= 1.0
assert isinstance(issues, list)
# Good order book should have high quality score
assert quality_score > 0.8
def test_trade_quality_check(self, quality_checker, sample_trade):
"""Test trade quality checking"""
quality_score, issues = quality_checker.check_trade_quality(sample_trade)
assert 0.0 <= quality_score <= 1.0
assert isinstance(issues, list)
# Good trade should have high quality score
assert quality_score > 0.8
def test_invalid_orderbook_detection(self, quality_checker):
"""Test detection of invalid order book"""
# Create invalid order book with crossed spread
invalid_orderbook = OrderBookSnapshot(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
bids=[PriceLevel(price=50002.0, size=1.0)], # Bid higher than ask
asks=[PriceLevel(price=50001.0, size=1.0)] # Ask lower than bid
)
quality_score, issues = quality_checker.check_orderbook_quality(invalid_orderbook)
assert quality_score < 0.8
assert any("crossed book" in issue.lower() for issue in issues)
class TestAnomalyDetector:
"""Test cases for AnomalyDetector"""
def test_orderbook_anomaly_detection(self, anomaly_detector, sample_orderbook):
"""Test order book anomaly detection"""
# First few order books should not trigger anomalies
for _ in range(5):
anomalies = anomaly_detector.detect_orderbook_anomalies(sample_orderbook)
assert isinstance(anomalies, list)
def test_trade_anomaly_detection(self, anomaly_detector, sample_trade):
"""Test trade anomaly detection"""
# First few trades should not trigger anomalies
for _ in range(5):
anomalies = anomaly_detector.detect_trade_anomalies(sample_trade)
assert isinstance(anomalies, list)
def test_price_spike_detection(self, anomaly_detector):
"""Test price spike detection"""
# Create normal order books
for i in range(20):
normal_orderbook = OrderBookSnapshot(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
bids=[PriceLevel(price=50000.0 + i, size=1.0)],
asks=[PriceLevel(price=50001.0 + i, size=1.0)]
)
anomaly_detector.detect_orderbook_anomalies(normal_orderbook)
# Create order book with price spike
spike_orderbook = OrderBookSnapshot(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
bids=[PriceLevel(price=60000.0, size=1.0)], # 20% spike
asks=[PriceLevel(price=60001.0, size=1.0)]
)
anomalies = anomaly_detector.detect_orderbook_anomalies(spike_orderbook)
assert len(anomalies) > 0
assert any("spike" in anomaly.lower() for anomaly in anomalies)
class TestMetricsCalculator:
"""Test cases for MetricsCalculator"""
def test_orderbook_metrics_calculation(self, metrics_calculator, sample_orderbook):
"""Test order book metrics calculation"""
metrics = metrics_calculator.calculate_orderbook_metrics(sample_orderbook)
assert metrics.symbol == "BTCUSDT"
assert metrics.exchange == "binance"
assert metrics.mid_price == 50000.5 # (50000 + 50001) / 2
assert metrics.spread == 1.0 # 50001 - 50000
assert metrics.spread_percentage > 0
assert metrics.bid_volume == 4.5 # 1.5 + 2.0 + 1.0
assert metrics.ask_volume == 4.5 # 1.0 + 1.5 + 2.0
assert metrics.volume_imbalance == 0.0 # Equal volumes
def test_imbalance_metrics_calculation(self, metrics_calculator, sample_orderbook):
"""Test imbalance metrics calculation"""
imbalance = metrics_calculator.calculate_imbalance_metrics(sample_orderbook)
assert imbalance.symbol == "BTCUSDT"
assert -1.0 <= imbalance.volume_imbalance <= 1.0
assert -1.0 <= imbalance.price_imbalance <= 1.0
assert -1.0 <= imbalance.depth_imbalance <= 1.0
assert -1.0 <= imbalance.momentum_score <= 1.0
def test_liquidity_score_calculation(self, metrics_calculator, sample_orderbook):
"""Test liquidity score calculation"""
liquidity_score = metrics_calculator.calculate_liquidity_score(sample_orderbook)
assert 0.0 <= liquidity_score <= 1.0
assert liquidity_score > 0.5 # Good order book should have decent liquidity
class TestStandardDataProcessor:
"""Test cases for StandardDataProcessor"""
def test_data_validation(self, data_processor, sample_orderbook, sample_trade):
"""Test data validation"""
# Valid data should pass validation
assert data_processor.validate_data(sample_orderbook) is True
assert data_processor.validate_data(sample_trade) is True
def test_metrics_calculation(self, data_processor, sample_orderbook):
"""Test metrics calculation through processor"""
metrics = data_processor.calculate_metrics(sample_orderbook)
assert metrics.symbol == "BTCUSDT"
assert metrics.mid_price > 0
assert metrics.spread > 0
def test_anomaly_detection(self, data_processor, sample_orderbook, sample_trade):
"""Test anomaly detection through processor"""
orderbook_anomalies = data_processor.detect_anomalies(sample_orderbook)
trade_anomalies = data_processor.detect_anomalies(sample_trade)
assert isinstance(orderbook_anomalies, list)
assert isinstance(trade_anomalies, list)
def test_data_filtering(self, data_processor, sample_orderbook, sample_trade):
"""Test data filtering"""
# Test symbol filter
criteria = {'symbols': ['BTCUSDT']}
assert data_processor.filter_data(sample_orderbook, criteria) is True
assert data_processor.filter_data(sample_trade, criteria) is True
criteria = {'symbols': ['ETHUSDT']}
assert data_processor.filter_data(sample_orderbook, criteria) is False
assert data_processor.filter_data(sample_trade, criteria) is False
# Test price range filter
criteria = {'price_range': (40000, 60000)}
assert data_processor.filter_data(sample_orderbook, criteria) is True
assert data_processor.filter_data(sample_trade, criteria) is True
criteria = {'price_range': (60000, 70000)}
assert data_processor.filter_data(sample_orderbook, criteria) is False
assert data_processor.filter_data(sample_trade, criteria) is False
def test_data_enrichment(self, data_processor, sample_orderbook, sample_trade):
"""Test data enrichment"""
orderbook_enriched = data_processor.enrich_data(sample_orderbook)
trade_enriched = data_processor.enrich_data(sample_trade)
# Check enriched data structure
assert 'original_data' in orderbook_enriched
assert 'quality_score' in orderbook_enriched
assert 'anomalies' in orderbook_enriched
assert 'processing_timestamp' in orderbook_enriched
assert 'original_data' in trade_enriched
assert 'quality_score' in trade_enriched
assert 'anomalies' in trade_enriched
assert 'trade_value' in trade_enriched
def test_quality_score_calculation(self, data_processor, sample_orderbook, sample_trade):
"""Test quality score calculation"""
orderbook_score = data_processor.get_data_quality_score(sample_orderbook)
trade_score = data_processor.get_data_quality_score(sample_trade)
assert 0.0 <= orderbook_score <= 1.0
assert 0.0 <= trade_score <= 1.0
# Good data should have high quality scores
assert orderbook_score > 0.8
assert trade_score > 0.8
def test_processing_stats(self, data_processor, sample_orderbook, sample_trade):
"""Test processing statistics"""
# Process some data
data_processor.validate_data(sample_orderbook)
data_processor.validate_data(sample_trade)
stats = data_processor.get_processing_stats()
assert 'processed_orderbooks' in stats
assert 'processed_trades' in stats
assert 'quality_failures' in stats
assert 'anomalies_detected' in stats
assert stats['processed_orderbooks'] >= 1
assert stats['processed_trades'] >= 1
if __name__ == "__main__":
# Run simple tests
processor = StandardDataProcessor()
# Test with sample data
orderbook = OrderBookSnapshot(
symbol="BTCUSDT",
exchange="test",
timestamp=datetime.now(timezone.utc),
bids=[PriceLevel(price=50000.0, size=1.0)],
asks=[PriceLevel(price=50001.0, size=1.0)]
)
# Test validation
is_valid = processor.validate_data(orderbook)
print(f"Order book validation: {'PASSED' if is_valid else 'FAILED'}")
# Test metrics
metrics = processor.calculate_metrics(orderbook)
print(f"Metrics calculation: mid_price={metrics.mid_price}, spread={metrics.spread}")
# Test quality score
quality_score = processor.get_data_quality_score(orderbook)
print(f"Quality score: {quality_score:.2f}")
print("Simple data processor test completed")

View File

@ -0,0 +1,398 @@
"""
Unit tests for Kraken exchange connector.
"""
import asyncio
import pytest
from unittest.mock import Mock, AsyncMock, patch
from datetime import datetime, timezone
from ..connectors.kraken_connector import KrakenConnector
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
class TestKrakenConnector:
"""Test suite for Kraken connector."""
@pytest.fixture
def connector(self):
"""Create connector instance for testing."""
return KrakenConnector()
def test_initialization(self, connector):
"""Test connector initializes correctly."""
assert connector.exchange_name == "kraken"
assert connector.WEBSOCKET_URL in connector.websocket_url
assert 'book-25' in connector.message_handlers
assert 'trade' in connector.message_handlers
assert connector.system_status == 'unknown'
def test_symbol_normalization(self, connector):
"""Test symbol normalization to Kraken format."""
# Test standard conversions
assert connector.normalize_symbol('BTCUSDT') == 'XBT/USD'
assert connector.normalize_symbol('ETHUSDT') == 'ETH/USD'
assert connector.normalize_symbol('ADAUSDT') == 'ADA/USD'
# Test generic conversion
assert connector.normalize_symbol('LINKUSDT') == 'LINK/USD'
# Test already correct format
assert connector.normalize_symbol('XBT/USD') == 'XBT/USD'
def test_symbol_denormalization(self, connector):
"""Test converting Kraken format back to standard."""
assert connector._denormalize_symbol('XBT/USD') == 'BTCUSDT'
assert connector._denormalize_symbol('ETH/USD') == 'ETHUSDT'
assert connector._denormalize_symbol('ADA/USD') == 'ADAUSDT'
# Test other quote currencies
assert connector._denormalize_symbol('BTC/EUR') == 'BTCEUR'
def test_message_type_detection(self, connector):
"""Test message type detection."""
# Test order book message (array format)
book_message = [123, {'b': [['50000', '1.5']]}, 'book-25', 'XBT/USD']
assert connector._get_message_type(book_message) == 'book-25'
# Test trade message (array format)
trade_message = [456, [['50000', '0.1', 1609459200, 'b', 'm', '']], 'trade', 'XBT/USD']
assert connector._get_message_type(trade_message) == 'trade'
# Test status message (object format)
status_message = {'event': 'systemStatus', 'status': 'online'}
assert connector._get_message_type(status_message) == 'systemStatus'
# Test subscription message
sub_message = {'event': 'subscriptionStatus', 'status': 'subscribed'}
assert connector._get_message_type(sub_message) == 'subscriptionStatus'
# Test unknown message
unknown_message = {'data': 'something'}
assert connector._get_message_type(unknown_message) == 'unknown'
@pytest.mark.asyncio
async def test_subscription_methods(self, connector):
"""Test subscription and unsubscription methods."""
# Mock the _send_message method
connector._send_message = AsyncMock(return_value=True)
# Test order book subscription
await connector.subscribe_orderbook('BTCUSDT')
# Verify subscription was tracked
assert 'BTCUSDT' in connector.subscriptions
assert 'orderbook' in connector.subscriptions['BTCUSDT']
# Verify correct message was sent
connector._send_message.assert_called()
call_args = connector._send_message.call_args[0][0]
assert call_args['event'] == 'subscribe'
assert 'XBT/USD' in call_args['pair']
assert call_args['subscription']['name'] == 'book'
# Test trade subscription
await connector.subscribe_trades('ETHUSDT')
assert 'ETHUSDT' in connector.subscriptions
assert 'trades' in connector.subscriptions['ETHUSDT']
# Test unsubscription
await connector.unsubscribe_orderbook('BTCUSDT')
# Verify unsubscription message
call_args = connector._send_message.call_args[0][0]
assert call_args['event'] == 'unsubscribe'
@pytest.mark.asyncio
async def test_orderbook_snapshot_parsing(self, connector):
"""Test parsing order book snapshot data."""
# Mock order book data from Kraken
mock_data = {
'bids': [
['50000.00', '1.5', 1609459200],
['49999.00', '2.0', 1609459201]
],
'asks': [
['50001.00', '1.2', 1609459200],
['50002.00', '0.8', 1609459201]
]
}
# Parse the data
orderbook = connector._parse_orderbook_snapshot(mock_data, 'BTCUSDT')
# Verify parsing
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == 'BTCUSDT'
assert orderbook.exchange == 'kraken'
# Verify bids
assert len(orderbook.bids) == 2
assert orderbook.bids[0].price == 50000.00
assert orderbook.bids[0].size == 1.5
assert orderbook.bids[1].price == 49999.00
assert orderbook.bids[1].size == 2.0
# Verify asks
assert len(orderbook.asks) == 2
assert orderbook.asks[0].price == 50001.00
assert orderbook.asks[0].size == 1.2
assert orderbook.asks[1].price == 50002.00
assert orderbook.asks[1].size == 0.8
@pytest.mark.asyncio
async def test_orderbook_update_handling(self, connector):
"""Test handling order book update messages."""
# Mock callback
callback_called = False
received_data = None
def mock_callback(data):
nonlocal callback_called, received_data
callback_called = True
received_data = data
connector.add_data_callback(mock_callback)
# Mock Kraken order book update message
update_message = [
123, # channel ID
{
'b': [['50000.00', '1.5', '1609459200.123456']],
'a': [['50001.00', '1.2', '1609459200.123456']]
},
'book-25',
'XBT/USD'
]
# Handle the message
await connector._handle_orderbook_update(update_message)
# Verify callback was called
assert callback_called
assert isinstance(received_data, OrderBookSnapshot)
assert received_data.symbol == 'BTCUSDT'
assert received_data.exchange == 'kraken'
# Verify channel mapping was stored
assert 123 in connector.channel_map
assert connector.channel_map[123] == ('book-25', 'BTCUSDT')
# Verify bids and asks
assert len(received_data.bids) == 1
assert received_data.bids[0].price == 50000.00
assert received_data.bids[0].size == 1.5
assert len(received_data.asks) == 1
assert received_data.asks[0].price == 50001.00
assert received_data.asks[0].size == 1.2
@pytest.mark.asyncio
async def test_trade_handling(self, connector):
"""Test handling trade messages."""
# Mock callback
callback_called = False
received_trades = []
def mock_callback(data):
nonlocal callback_called
callback_called = True
received_trades.append(data)
connector.add_data_callback(mock_callback)
# Mock Kraken trade message (array of trades)
trade_message = [
456, # channel ID
[
['50000.50', '0.1', 1609459200.123456, 'b', 'm', ''],
['50001.00', '0.05', 1609459201.123456, 's', 'l', '']
],
'trade',
'XBT/USD'
]
# Handle the message
await connector._handle_trade_update(trade_message)
# Verify callback was called
assert callback_called
assert len(received_trades) == 2
# Verify first trade (buy)
trade1 = received_trades[0]
assert isinstance(trade1, TradeEvent)
assert trade1.symbol == 'BTCUSDT'
assert trade1.exchange == 'kraken'
assert trade1.price == 50000.50
assert trade1.size == 0.1
assert trade1.side == 'buy'
# Verify second trade (sell)
trade2 = received_trades[1]
assert trade2.price == 50001.00
assert trade2.size == 0.05
assert trade2.side == 'sell'
# Verify channel mapping was stored
assert 456 in connector.channel_map
assert connector.channel_map[456] == ('trade', 'BTCUSDT')
@pytest.mark.asyncio
async def test_system_status_handling(self, connector):
"""Test handling system status messages."""
# Mock system status message
status_message = {
'event': 'systemStatus',
'status': 'online',
'version': '1.0.0'
}
# Handle the message
await connector._handle_system_status(status_message)
# Verify status was updated
assert connector.system_status == 'online'
@pytest.mark.asyncio
async def test_subscription_status_handling(self, connector):
"""Test handling subscription status messages."""
# Mock subscription status message
sub_message = {
'event': 'subscriptionStatus',
'status': 'subscribed',
'channelName': 'book-25',
'channelID': 123,
'pair': 'XBT/USD',
'subscription': {'name': 'book', 'depth': 25}
}
# Handle the message
await connector._handle_subscription_status(sub_message)
# Verify channel mapping was stored
assert 123 in connector.channel_map
assert connector.channel_map[123] == ('book-25', 'BTCUSDT')
@pytest.mark.asyncio
async def test_get_symbols(self, connector):
"""Test getting available symbols."""
# Mock HTTP response
mock_response_data = {
'error': [],
'result': {
'XBTUSD': {
'wsname': 'XBT/USD'
},
'ETHUSD': {
'wsname': 'ETH/USD'
},
'XBTUSD.d': { # Dark pool - should be filtered out
'wsname': 'XBT/USD.d'
}
}
}
with patch('aiohttp.ClientSession.get') as mock_get:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=mock_response_data)
mock_get.return_value.__aenter__.return_value = mock_response
symbols = await connector.get_symbols()
# Should only return non-dark pool symbols
assert 'BTCUSDT' in symbols
assert 'ETHUSDT' in symbols
# Dark pool should be filtered out
assert len([s for s in symbols if '.d' in s]) == 0
@pytest.mark.asyncio
async def test_get_orderbook_snapshot(self, connector):
"""Test getting order book snapshot from REST API."""
# Mock HTTP response
mock_orderbook = {
'error': [],
'result': {
'XBTUSD': {
'bids': [['50000.00', '1.5', 1609459200]],
'asks': [['50001.00', '1.2', 1609459200]]
}
}
}
with patch('aiohttp.ClientSession.get') as mock_get:
mock_response = AsyncMock()
mock_response.status = 200
mock_response.json = AsyncMock(return_value=mock_orderbook)
mock_get.return_value.__aenter__.return_value = mock_response
orderbook = await connector.get_orderbook_snapshot('BTCUSDT')
assert isinstance(orderbook, OrderBookSnapshot)
assert orderbook.symbol == 'BTCUSDT'
assert orderbook.exchange == 'kraken'
assert len(orderbook.bids) == 1
assert len(orderbook.asks) == 1
def test_authentication_token(self, connector):
"""Test authentication token generation."""
# Set up credentials
connector.api_key = 'test_key'
connector.api_secret = 'dGVzdF9zZWNyZXQ=' # base64 encoded
# Generate token
token = connector._get_auth_token()
# Should return a token (simplified implementation)
assert isinstance(token, str)
assert len(token) > 0
def test_statistics(self, connector):
"""Test getting connector statistics."""
# Add some test data
connector.system_status = 'online'
connector.channel_map[123] = ('book-25', 'BTCUSDT')
stats = connector.get_kraken_stats()
assert stats['exchange'] == 'kraken'
assert stats['system_status'] == 'online'
assert stats['channel_mappings'] == 1
assert 'authenticated' in stats
async def test_kraken_integration():
"""Integration test for Kraken connector."""
connector = KrakenConnector()
try:
# Test basic functionality
assert connector.exchange_name == "kraken"
# Test symbol normalization
assert connector.normalize_symbol('BTCUSDT') == 'XBT/USD'
assert connector._denormalize_symbol('XBT/USD') == 'BTCUSDT'
# Test message type detection
test_message = [123, {}, 'book-25', 'XBT/USD']
assert connector._get_message_type(test_message) == 'book-25'
# Test status message
status_message = {'event': 'systemStatus', 'status': 'online'}
assert connector._get_message_type(status_message) == 'systemStatus'
print("✓ Kraken connector integration test passed")
return True
except Exception as e:
print(f"✗ Kraken connector integration test failed: {e}")
return False
if __name__ == "__main__":
# Run integration test
success = asyncio.run(test_kraken_integration())
if not success:
exit(1)

View File

@ -0,0 +1,385 @@
"""
Integration tests for COBY orchestrator compatibility.
Tests the adapter's compatibility with the existing orchestrator interface.
"""
import asyncio
import logging
import pytest
from datetime import datetime, timedelta
from unittest.mock import Mock, AsyncMock
from ..integration.orchestrator_adapter import COBYOrchestratorAdapter, MarketTick
from ..integration.data_provider_replacement import COBYDataProvider
from ..config import Config
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class TestOrchestratorIntegration:
"""Test suite for orchestrator integration."""
@pytest.fixture
async def adapter(self):
"""Create adapter instance for testing."""
config = Config()
adapter = COBYOrchestratorAdapter(config)
# Mock the storage manager for testing
adapter.storage_manager = Mock()
adapter.storage_manager.initialize = AsyncMock()
adapter.storage_manager.is_healthy = Mock(return_value=True)
adapter.storage_manager.get_latest_orderbook = AsyncMock(return_value={
'symbol': 'BTCUSDT',
'timestamp': datetime.utcnow(),
'mid_price': 50000.0,
'spread': 0.01,
'bid_volume': 10.5,
'ask_volume': 8.3,
'exchange': 'binance'
})
adapter.storage_manager.get_historical_data = AsyncMock(return_value=[
{
'timestamp': datetime.utcnow() - timedelta(minutes=i),
'open': 50000 + i,
'high': 50010 + i,
'low': 49990 + i,
'close': 50005 + i,
'volume': 100 + i,
'symbol': 'BTCUSDT',
'exchange': 'binance'
}
for i in range(100)
])
# Mock Redis manager
adapter.redis_manager = Mock()
adapter.redis_manager.initialize = AsyncMock()
adapter.redis_manager.get = AsyncMock(return_value=None)
adapter.redis_manager.set = AsyncMock()
# Mock connectors
adapter.connectors = {'binance': Mock()}
adapter.connectors['binance'].connect = AsyncMock()
adapter.connectors['binance'].is_connected = True
await adapter._initialize_components()
return adapter
@pytest.fixture
async def data_provider(self):
"""Create data provider replacement for testing."""
# Mock the adapter initialization
provider = COBYDataProvider()
# Use the same mocks as adapter
provider.adapter.storage_manager = Mock()
provider.adapter.storage_manager.get_latest_orderbook = AsyncMock(return_value={
'symbol': 'BTCUSDT',
'timestamp': datetime.utcnow(),
'mid_price': 50000.0,
'spread': 0.01,
'bid_volume': 10.5,
'ask_volume': 8.3,
'exchange': 'binance'
})
return provider
async def test_adapter_initialization(self, adapter):
"""Test adapter initializes correctly."""
assert adapter is not None
assert adapter.mode == 'live'
assert adapter.config is not None
assert 'binance' in adapter.connectors
async def test_get_current_price(self, adapter):
"""Test getting current price."""
price = adapter.get_current_price('BTCUSDT')
assert price == 50000.0
async def test_get_historical_data(self, adapter):
"""Test getting historical data."""
df = adapter.get_historical_data('BTCUSDT', '1m', limit=50)
assert df is not None
assert len(df) == 100 # Mock returns 100 records
assert 'open' in df.columns
assert 'high' in df.columns
assert 'low' in df.columns
assert 'close' in df.columns
assert 'volume' in df.columns
async def test_build_base_data_input(self, adapter):
"""Test building base data input."""
base_data = adapter.build_base_data_input('BTCUSDT')
assert base_data is not None
assert hasattr(base_data, 'get_feature_vector')
features = base_data.get_feature_vector()
assert isinstance(features, type(features)) # numpy array
assert len(features) == 100 # Expected feature size
async def test_cob_data_methods(self, adapter):
"""Test COB data access methods."""
# Mock COB data
adapter.storage_manager.get_historical_data = AsyncMock(return_value=[
{
'symbol': 'BTCUSDT',
'timestamp': datetime.utcnow(),
'mid_price': 50000.0,
'spread': 0.01,
'bid_volume': 10.5,
'ask_volume': 8.3,
'exchange': 'binance'
}
])
# Test raw ticks
raw_ticks = adapter.get_cob_raw_ticks('BTCUSDT', count=10)
assert isinstance(raw_ticks, list)
# Test latest COB data
latest_cob = adapter.get_latest_cob_data('BTCUSDT')
assert latest_cob is not None
assert latest_cob['symbol'] == 'BTCUSDT'
assert 'mid_price' in latest_cob
async def test_subscription_management(self, adapter):
"""Test subscription methods."""
callback_called = False
received_tick = None
def tick_callback(tick):
nonlocal callback_called, received_tick
callback_called = True
received_tick = tick
# Subscribe to ticks
subscriber_id = adapter.subscribe_to_ticks(
tick_callback,
symbols=['BTCUSDT'],
subscriber_name='test_subscriber'
)
assert subscriber_id != ""
assert len(adapter.subscribers['ticks']) == 1
# Simulate tick notification
test_tick = MarketTick(
symbol='BTCUSDT',
price=50000.0,
volume=1.5,
timestamp=datetime.utcnow(),
exchange='binance'
)
await adapter._notify_tick_subscribers(test_tick)
assert callback_called
assert received_tick is not None
assert received_tick.symbol == 'BTCUSDT'
# Unsubscribe
success = adapter.unsubscribe(subscriber_id)
assert success
assert len(adapter.subscribers['ticks']) == 0
async def test_mode_switching(self, adapter):
"""Test switching between live and replay modes."""
# Initially in live mode
assert adapter.get_current_mode() == 'live'
# Mock replay manager
adapter.replay_manager = Mock()
adapter.replay_manager.create_replay_session = Mock(return_value='test_session_123')
adapter.replay_manager.add_data_callback = Mock()
adapter.replay_manager.start_replay = AsyncMock()
adapter.replay_manager.stop_replay = AsyncMock()
# Switch to replay mode
start_time = datetime.utcnow() - timedelta(hours=1)
end_time = datetime.utcnow()
success = await adapter.switch_to_replay_mode(
start_time=start_time,
end_time=end_time,
speed=2.0,
symbols=['BTCUSDT']
)
assert success
assert adapter.get_current_mode() == 'replay'
assert adapter.current_replay_session == 'test_session_123'
# Switch back to live mode
success = await adapter.switch_to_live_mode()
assert success
assert adapter.get_current_mode() == 'live'
assert adapter.current_replay_session is None
async def test_data_quality_indicators(self, adapter):
"""Test data quality indicators."""
quality = adapter.get_data_quality_indicators('BTCUSDT')
assert quality is not None
assert quality['symbol'] == 'BTCUSDT'
assert 'quality_score' in quality
assert 'timestamp' in quality
assert isinstance(quality['quality_score'], float)
assert 0.0 <= quality['quality_score'] <= 1.0
async def test_system_metadata(self, adapter):
"""Test system metadata retrieval."""
metadata = adapter.get_system_metadata()
assert metadata is not None
assert metadata['system'] == 'COBY'
assert metadata['version'] == '1.0.0'
assert 'mode' in metadata
assert 'components' in metadata
assert 'statistics' in metadata
async def test_data_provider_compatibility(self, data_provider):
"""Test data provider replacement compatibility."""
# Test core methods exist and work
assert hasattr(data_provider, 'get_historical_data')
assert hasattr(data_provider, 'get_current_price')
assert hasattr(data_provider, 'build_base_data_input')
assert hasattr(data_provider, 'subscribe_to_ticks')
assert hasattr(data_provider, 'get_cob_raw_ticks')
# Test current price
price = data_provider.get_current_price('BTCUSDT')
assert price == 50000.0
# Test COB imbalance
imbalance = data_provider.get_current_cob_imbalance('BTCUSDT')
assert 'bid_volume' in imbalance
assert 'ask_volume' in imbalance
assert 'imbalance' in imbalance
# Test WebSocket status
status = data_provider.get_cob_websocket_status()
assert 'connected' in status
assert 'exchanges' in status
async def test_error_handling(self, adapter):
"""Test error handling in various scenarios."""
# Test with invalid symbol
price = adapter.get_current_price('INVALID_SYMBOL')
# Should not raise exception, may return None
# Test with storage error
adapter.storage_manager.get_latest_orderbook = AsyncMock(side_effect=Exception("Storage error"))
price = adapter.get_current_price('BTCUSDT')
# Should handle error gracefully
# Test subscription with invalid callback
subscriber_id = adapter.subscribe_to_ticks(None, ['BTCUSDT'])
# Should handle gracefully
async def test_performance_metrics(self, adapter):
"""Test performance metrics and statistics."""
# Get initial stats
initial_stats = adapter.get_stats()
assert 'ticks_processed' in initial_stats
assert 'orderbooks_processed' in initial_stats
# Simulate some data processing
from ..models.core import OrderBookSnapshot, PriceLevel
test_orderbook = OrderBookSnapshot(
symbol='BTCUSDT',
exchange='binance',
timestamp=datetime.utcnow(),
bids=[PriceLevel(price=49999.0, size=1.5)],
asks=[PriceLevel(price=50001.0, size=1.2)]
)
await adapter._handle_connector_data(test_orderbook)
# Check stats updated
updated_stats = adapter.get_stats()
assert updated_stats['orderbooks_processed'] >= initial_stats['orderbooks_processed']
async def test_integration_suite():
"""Run the complete integration test suite."""
logger.info("Starting COBY orchestrator integration tests...")
try:
# Create test instances
config = Config()
adapter = COBYOrchestratorAdapter(config)
# Mock components for testing
adapter.storage_manager = Mock()
adapter.storage_manager.initialize = AsyncMock()
adapter.storage_manager.is_healthy = Mock(return_value=True)
adapter.redis_manager = Mock()
adapter.redis_manager.initialize = AsyncMock()
adapter.connectors = {'binance': Mock()}
adapter.connectors['binance'].connect = AsyncMock()
await adapter._initialize_components()
# Run basic functionality tests
logger.info("Testing basic functionality...")
# Test price retrieval
adapter.storage_manager.get_latest_orderbook = AsyncMock(return_value={
'symbol': 'BTCUSDT',
'timestamp': datetime.utcnow(),
'mid_price': 50000.0,
'spread': 0.01,
'bid_volume': 10.5,
'ask_volume': 8.3,
'exchange': 'binance'
})
price = adapter.get_current_price('BTCUSDT')
assert price == 50000.0
logger.info(f"✓ Current price retrieval: {price}")
# Test subscription
callback_called = False
def test_callback(tick):
nonlocal callback_called
callback_called = True
subscriber_id = adapter.subscribe_to_ticks(test_callback, ['BTCUSDT'])
assert subscriber_id != ""
logger.info(f"✓ Subscription created: {subscriber_id}")
# Test data quality
quality = adapter.get_data_quality_indicators('BTCUSDT')
assert quality['symbol'] == 'BTCUSDT'
logger.info(f"✓ Data quality check: {quality['quality_score']}")
# Test system metadata
metadata = adapter.get_system_metadata()
assert metadata['system'] == 'COBY'
logger.info(f"✓ System metadata: {metadata['mode']}")
logger.info("All integration tests passed successfully!")
return True
except Exception as e:
logger.error(f"Integration test failed: {e}")
return False
if __name__ == "__main__":
# Run the integration tests
success = asyncio.run(test_integration_suite())
if success:
print("✓ COBY orchestrator integration tests completed successfully")
else:
print("✗ COBY orchestrator integration tests failed")
exit(1)

View File

@ -0,0 +1,347 @@
"""
Tests for Redis caching system.
"""
import pytest
import asyncio
from datetime import datetime, timezone
from ..caching.redis_manager import RedisManager
from ..caching.cache_keys import CacheKeys
from ..caching.data_serializer import DataSerializer
from ..models.core import OrderBookSnapshot, HeatmapData, PriceLevel, HeatmapPoint
@pytest.fixture
async def redis_manager():
"""Create and initialize Redis manager for testing"""
manager = RedisManager()
await manager.initialize()
yield manager
await manager.close()
@pytest.fixture
def cache_keys():
"""Create cache keys helper"""
return CacheKeys()
@pytest.fixture
def data_serializer():
"""Create data serializer"""
return DataSerializer()
@pytest.fixture
def sample_orderbook():
"""Create sample order book for testing"""
return OrderBookSnapshot(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
bids=[
PriceLevel(price=50000.0, size=1.5),
PriceLevel(price=49999.0, size=2.0)
],
asks=[
PriceLevel(price=50001.0, size=1.0),
PriceLevel(price=50002.0, size=1.5)
]
)
@pytest.fixture
def sample_heatmap():
"""Create sample heatmap for testing"""
heatmap = HeatmapData(
symbol="BTCUSDT",
timestamp=datetime.now(timezone.utc),
bucket_size=1.0
)
# Add some sample points
heatmap.data = [
HeatmapPoint(price=50000.0, volume=1.5, intensity=0.8, side='bid'),
HeatmapPoint(price=50001.0, volume=1.0, intensity=0.6, side='ask'),
HeatmapPoint(price=49999.0, volume=2.0, intensity=1.0, side='bid'),
HeatmapPoint(price=50002.0, volume=1.5, intensity=0.7, side='ask')
]
return heatmap
class TestCacheKeys:
"""Test cases for CacheKeys"""
def test_orderbook_key_generation(self, cache_keys):
"""Test order book key generation"""
key = cache_keys.orderbook_key("BTCUSDT", "binance")
assert key == "ob:binance:BTCUSDT"
def test_heatmap_key_generation(self, cache_keys):
"""Test heatmap key generation"""
# Exchange-specific heatmap
key1 = cache_keys.heatmap_key("BTCUSDT", 1.0, "binance")
assert key1 == "hm:binance:BTCUSDT:1.0"
# Consolidated heatmap
key2 = cache_keys.heatmap_key("BTCUSDT", 1.0)
assert key2 == "hm:consolidated:BTCUSDT:1.0"
def test_ttl_determination(self, cache_keys):
"""Test TTL determination for different key types"""
ob_key = cache_keys.orderbook_key("BTCUSDT", "binance")
hm_key = cache_keys.heatmap_key("BTCUSDT", 1.0)
assert cache_keys.get_ttl(ob_key) == cache_keys.ORDERBOOK_TTL
assert cache_keys.get_ttl(hm_key) == cache_keys.HEATMAP_TTL
def test_key_parsing(self, cache_keys):
"""Test cache key parsing"""
ob_key = cache_keys.orderbook_key("BTCUSDT", "binance")
parsed = cache_keys.parse_key(ob_key)
assert parsed['type'] == 'orderbook'
assert parsed['exchange'] == 'binance'
assert parsed['symbol'] == 'BTCUSDT'
class TestDataSerializer:
"""Test cases for DataSerializer"""
def test_simple_data_serialization(self, data_serializer):
"""Test serialization of simple data types"""
test_data = {
'string': 'test',
'number': 42,
'float': 3.14,
'boolean': True,
'list': [1, 2, 3],
'nested': {'key': 'value'}
}
# Serialize and deserialize
serialized = data_serializer.serialize(test_data)
deserialized = data_serializer.deserialize(serialized)
assert deserialized == test_data
def test_orderbook_serialization(self, data_serializer, sample_orderbook):
"""Test order book serialization"""
# Serialize and deserialize
serialized = data_serializer.serialize(sample_orderbook)
deserialized = data_serializer.deserialize(serialized)
assert isinstance(deserialized, OrderBookSnapshot)
assert deserialized.symbol == sample_orderbook.symbol
assert deserialized.exchange == sample_orderbook.exchange
assert len(deserialized.bids) == len(sample_orderbook.bids)
assert len(deserialized.asks) == len(sample_orderbook.asks)
def test_heatmap_serialization(self, data_serializer, sample_heatmap):
"""Test heatmap serialization"""
# Test specialized heatmap serialization
serialized = data_serializer.serialize_heatmap(sample_heatmap)
deserialized = data_serializer.deserialize_heatmap(serialized)
assert isinstance(deserialized, HeatmapData)
assert deserialized.symbol == sample_heatmap.symbol
assert deserialized.bucket_size == sample_heatmap.bucket_size
assert len(deserialized.data) == len(sample_heatmap.data)
# Check first point
original_point = sample_heatmap.data[0]
deserialized_point = deserialized.data[0]
assert deserialized_point.price == original_point.price
assert deserialized_point.volume == original_point.volume
assert deserialized_point.side == original_point.side
class TestRedisManager:
"""Test cases for RedisManager"""
@pytest.mark.asyncio
async def test_basic_set_get(self, redis_manager):
"""Test basic set and get operations"""
# Set a simple value
key = "test:basic"
value = {"test": "data", "number": 42}
success = await redis_manager.set(key, value, ttl=60)
assert success is True
# Get the value back
retrieved = await redis_manager.get(key)
assert retrieved == value
# Clean up
await redis_manager.delete(key)
@pytest.mark.asyncio
async def test_orderbook_caching(self, redis_manager, sample_orderbook):
"""Test order book caching"""
# Cache order book
success = await redis_manager.cache_orderbook(sample_orderbook)
assert success is True
# Retrieve order book
retrieved = await redis_manager.get_orderbook(
sample_orderbook.symbol,
sample_orderbook.exchange
)
assert retrieved is not None
assert isinstance(retrieved, OrderBookSnapshot)
assert retrieved.symbol == sample_orderbook.symbol
assert retrieved.exchange == sample_orderbook.exchange
@pytest.mark.asyncio
async def test_heatmap_caching(self, redis_manager, sample_heatmap):
"""Test heatmap caching"""
# Cache heatmap
success = await redis_manager.set_heatmap(
sample_heatmap.symbol,
sample_heatmap,
exchange="binance"
)
assert success is True
# Retrieve heatmap
retrieved = await redis_manager.get_heatmap(
sample_heatmap.symbol,
exchange="binance"
)
assert retrieved is not None
assert isinstance(retrieved, HeatmapData)
assert retrieved.symbol == sample_heatmap.symbol
assert len(retrieved.data) == len(sample_heatmap.data)
@pytest.mark.asyncio
async def test_multi_operations(self, redis_manager):
"""Test multi-get and multi-set operations"""
# Prepare test data
test_data = {
"test:multi1": {"value": 1},
"test:multi2": {"value": 2},
"test:multi3": {"value": 3}
}
# Multi-set
success = await redis_manager.mset(test_data, ttl=60)
assert success is True
# Multi-get
keys = list(test_data.keys())
values = await redis_manager.mget(keys)
assert len(values) == 3
assert all(v is not None for v in values)
# Verify values
for i, key in enumerate(keys):
assert values[i] == test_data[key]
# Clean up
for key in keys:
await redis_manager.delete(key)
@pytest.mark.asyncio
async def test_key_expiration(self, redis_manager):
"""Test key expiration"""
key = "test:expiration"
value = {"expires": "soon"}
# Set with short TTL
success = await redis_manager.set(key, value, ttl=1)
assert success is True
# Should exist immediately
exists = await redis_manager.exists(key)
assert exists is True
# Wait for expiration
await asyncio.sleep(2)
# Should not exist after expiration
exists = await redis_manager.exists(key)
assert exists is False
@pytest.mark.asyncio
async def test_cache_miss(self, redis_manager):
"""Test cache miss behavior"""
# Try to get non-existent key
value = await redis_manager.get("test:nonexistent")
assert value is None
# Check statistics
stats = redis_manager.get_stats()
assert stats['misses'] > 0
@pytest.mark.asyncio
async def test_health_check(self, redis_manager):
"""Test Redis health check"""
health = await redis_manager.health_check()
assert isinstance(health, dict)
assert 'redis_ping' in health
assert 'total_keys' in health
assert 'hit_rate' in health
# Should be able to ping
assert health['redis_ping'] is True
@pytest.mark.asyncio
async def test_statistics_tracking(self, redis_manager):
"""Test statistics tracking"""
# Reset stats
redis_manager.reset_stats()
# Perform some operations
await redis_manager.set("test:stats1", {"data": 1})
await redis_manager.set("test:stats2", {"data": 2})
await redis_manager.get("test:stats1")
await redis_manager.get("test:nonexistent")
# Check statistics
stats = redis_manager.get_stats()
assert stats['sets'] >= 2
assert stats['gets'] >= 2
assert stats['hits'] >= 1
assert stats['misses'] >= 1
assert stats['total_operations'] >= 4
# Clean up
await redis_manager.delete("test:stats1")
await redis_manager.delete("test:stats2")
if __name__ == "__main__":
# Run simple tests
async def simple_test():
manager = RedisManager()
await manager.initialize()
# Test basic operations
success = await manager.set("test", {"simple": "test"}, ttl=60)
print(f"Set operation: {'SUCCESS' if success else 'FAILED'}")
value = await manager.get("test")
print(f"Get operation: {'SUCCESS' if value else 'FAILED'}")
# Test ping
ping_result = await manager.ping()
print(f"Ping test: {'SUCCESS' if ping_result else 'FAILED'}")
# Get statistics
stats = manager.get_stats()
print(f"Statistics: {stats}")
# Clean up
await manager.delete("test")
await manager.close()
print("Simple Redis test completed")
asyncio.run(simple_test())

View File

@ -0,0 +1,153 @@
"""
Test script for the historical data replay system.
"""
import asyncio
import logging
from datetime import datetime, timedelta
from ..config import Config
from ..storage.storage_manager import StorageManager
from ..replay.replay_manager import HistoricalReplayManager
from ..models.core import ReplayStatus
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_replay_system():
"""Test the replay system functionality."""
# Initialize components
config = Config()
storage_manager = StorageManager(config)
try:
# Initialize storage
logger.info("Initializing storage manager...")
await storage_manager.initialize()
# Initialize replay manager
logger.info("Initializing replay manager...")
replay_manager = HistoricalReplayManager(storage_manager, config)
# Test data range query
logger.info("Testing data range query...")
data_range = await replay_manager.get_available_data_range("BTCUSDT")
if data_range:
logger.info(f"Available data range: {data_range['start']} to {data_range['end']}")
else:
logger.warning("No data available for BTCUSDT")
return
# Create test replay session
logger.info("Creating replay session...")
start_time = data_range['start']
end_time = start_time + timedelta(minutes=5) # 5 minute replay
session_id = replay_manager.create_replay_session(
start_time=start_time,
end_time=end_time,
speed=10.0, # 10x speed
symbols=["BTCUSDT"],
exchanges=["binance"]
)
logger.info(f"Created session: {session_id}")
# Add data callback
data_count = 0
def data_callback(data):
nonlocal data_count
data_count += 1
if data_count % 100 == 0:
logger.info(f"Received {data_count} data points")
replay_manager.add_data_callback(session_id, data_callback)
# Add status callback
def status_callback(session_id, status):
logger.info(f"Session {session_id} status changed to: {status.value}")
replay_manager.add_status_callback(session_id, status_callback)
# Start replay
logger.info("Starting replay...")
await replay_manager.start_replay(session_id)
# Monitor progress
while True:
session = replay_manager.get_replay_status(session_id)
if not session:
break
if session.status in [ReplayStatus.COMPLETED, ReplayStatus.ERROR, ReplayStatus.STOPPED]:
break
logger.info(f"Progress: {session.progress:.2%}, Events: {session.events_replayed}")
await asyncio.sleep(2)
# Final status
final_session = replay_manager.get_replay_status(session_id)
if final_session:
logger.info(f"Final status: {final_session.status.value}")
logger.info(f"Total events replayed: {final_session.events_replayed}")
logger.info(f"Total data callbacks: {data_count}")
# Test session controls
logger.info("Testing session controls...")
# Create another session for control testing
control_session_id = replay_manager.create_replay_session(
start_time=start_time,
end_time=end_time,
speed=1.0,
symbols=["BTCUSDT"]
)
# Start and immediately pause
await replay_manager.start_replay(control_session_id)
await asyncio.sleep(1)
await replay_manager.pause_replay(control_session_id)
# Test seek
seek_time = start_time + timedelta(minutes=2)
success = replay_manager.seek_replay(control_session_id, seek_time)
logger.info(f"Seek to {seek_time}: {'success' if success else 'failed'}")
# Test speed change
success = replay_manager.set_replay_speed(control_session_id, 5.0)
logger.info(f"Speed change to 5x: {'success' if success else 'failed'}")
# Resume and stop
await replay_manager.resume_replay(control_session_id)
await asyncio.sleep(2)
await replay_manager.stop_replay(control_session_id)
# Get statistics
stats = replay_manager.get_stats()
logger.info(f"Replay manager stats: {stats}")
# List all sessions
sessions = replay_manager.list_replay_sessions()
logger.info(f"Total sessions: {len(sessions)}")
# Clean up
for session in sessions:
replay_manager.delete_replay_session(session.session_id)
logger.info("Replay system test completed successfully!")
except Exception as e:
logger.error(f"Test failed: {e}")
raise
finally:
# Clean up
await storage_manager.close()
if __name__ == "__main__":
asyncio.run(test_replay_system())

View File

@ -0,0 +1,101 @@
"""
Test script for TimescaleDB integration.
Tests database connection, schema creation, and basic operations.
"""
import asyncio
import logging
from datetime import datetime
from decimal import Decimal
from ..config import Config
from ..storage.storage_manager import StorageManager
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
async def test_timescale_integration():
"""Test TimescaleDB integration with basic operations."""
# Initialize configuration
config = Config()
# Create storage manager
storage = StorageManager(config)
try:
# Initialize storage
logger.info("Initializing storage manager...")
await storage.initialize()
# Test health check
logger.info("Running health check...")
health = await storage.health_check()
logger.info(f"Health status: {health}")
# Test schema info
logger.info("Getting schema information...")
schema_info = await storage.get_system_stats()
logger.info(f"Schema info: {schema_info}")
# Test order book storage
logger.info("Testing order book storage...")
test_orderbook = OrderBookSnapshot(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.utcnow(),
bids=[
PriceLevel(price=50000.0, size=1.5),
PriceLevel(price=49999.0, size=2.0),
],
asks=[
PriceLevel(price=50001.0, size=1.2),
PriceLevel(price=50002.0, size=0.8),
],
sequence_id=12345
)
success = await storage.store_orderbook(test_orderbook)
logger.info(f"Order book storage result: {success}")
# Test trade storage
logger.info("Testing trade storage...")
test_trade = TradeEvent(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.utcnow(),
price=50000.5,
size=0.1,
side="buy",
trade_id="test_trade_123"
)
success = await storage.store_trade(test_trade)
logger.info(f"Trade storage result: {success}")
# Test data retrieval
logger.info("Testing data retrieval...")
latest_orderbook = await storage.get_latest_orderbook("BTCUSDT", "binance")
logger.info(f"Latest order book: {latest_orderbook is not None}")
# Test system stats
logger.info("Getting system statistics...")
stats = await storage.get_system_stats()
logger.info(f"System stats: {stats}")
logger.info("All tests completed successfully!")
except Exception as e:
logger.error(f"Test failed: {e}")
raise
finally:
# Clean up
await storage.close()
if __name__ == "__main__":
asyncio.run(test_timescale_integration())

View File

@ -0,0 +1,192 @@
"""
Tests for TimescaleDB storage manager.
"""
import pytest
import asyncio
from datetime import datetime, timezone
from ..storage.timescale_manager import TimescaleManager
from ..models.core import OrderBookSnapshot, TradeEvent, PriceLevel
from ..config import config
@pytest.fixture
async def storage_manager():
"""Create and initialize storage manager for testing"""
manager = TimescaleManager()
await manager.initialize()
yield manager
await manager.close()
@pytest.fixture
def sample_orderbook():
"""Create sample order book for testing"""
return OrderBookSnapshot(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
bids=[
PriceLevel(price=50000.0, size=1.5, count=3),
PriceLevel(price=49999.0, size=2.0, count=5)
],
asks=[
PriceLevel(price=50001.0, size=1.0, count=2),
PriceLevel(price=50002.0, size=1.5, count=4)
],
sequence_id=12345
)
@pytest.fixture
def sample_trade():
"""Create sample trade for testing"""
return TradeEvent(
symbol="BTCUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
price=50000.5,
size=0.1,
side="buy",
trade_id="test_trade_123"
)
class TestTimescaleManager:
"""Test cases for TimescaleManager"""
@pytest.mark.asyncio
async def test_health_check(self, storage_manager):
"""Test storage health check"""
is_healthy = await storage_manager.health_check()
assert is_healthy is True
@pytest.mark.asyncio
async def test_store_orderbook(self, storage_manager, sample_orderbook):
"""Test storing order book snapshot"""
result = await storage_manager.store_orderbook(sample_orderbook)
assert result is True
@pytest.mark.asyncio
async def test_store_trade(self, storage_manager, sample_trade):
"""Test storing trade event"""
result = await storage_manager.store_trade(sample_trade)
assert result is True
@pytest.mark.asyncio
async def test_get_latest_orderbook(self, storage_manager, sample_orderbook):
"""Test retrieving latest order book"""
# Store the order book first
await storage_manager.store_orderbook(sample_orderbook)
# Retrieve it
retrieved = await storage_manager.get_latest_orderbook(
sample_orderbook.symbol,
sample_orderbook.exchange
)
assert retrieved is not None
assert retrieved.symbol == sample_orderbook.symbol
assert retrieved.exchange == sample_orderbook.exchange
assert len(retrieved.bids) == len(sample_orderbook.bids)
assert len(retrieved.asks) == len(sample_orderbook.asks)
@pytest.mark.asyncio
async def test_batch_store_orderbooks(self, storage_manager):
"""Test batch storing order books"""
orderbooks = []
for i in range(5):
orderbook = OrderBookSnapshot(
symbol="ETHUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
bids=[PriceLevel(price=3000.0 + i, size=1.0)],
asks=[PriceLevel(price=3001.0 + i, size=1.0)],
sequence_id=i
)
orderbooks.append(orderbook)
result = await storage_manager.batch_store_orderbooks(orderbooks)
assert result == 5
@pytest.mark.asyncio
async def test_batch_store_trades(self, storage_manager):
"""Test batch storing trades"""
trades = []
for i in range(5):
trade = TradeEvent(
symbol="ETHUSDT",
exchange="binance",
timestamp=datetime.now(timezone.utc),
price=3000.0 + i,
size=0.1,
side="buy" if i % 2 == 0 else "sell",
trade_id=f"test_trade_{i}"
)
trades.append(trade)
result = await storage_manager.batch_store_trades(trades)
assert result == 5
@pytest.mark.asyncio
async def test_get_storage_stats(self, storage_manager):
"""Test getting storage statistics"""
stats = await storage_manager.get_storage_stats()
assert isinstance(stats, dict)
assert 'table_sizes' in stats
assert 'record_counts' in stats
assert 'connection_pool' in stats
@pytest.mark.asyncio
async def test_historical_data_retrieval(self, storage_manager, sample_orderbook, sample_trade):
"""Test retrieving historical data"""
# Store some data first
await storage_manager.store_orderbook(sample_orderbook)
await storage_manager.store_trade(sample_trade)
# Define time range
start_time = datetime.now(timezone.utc).replace(hour=0, minute=0, second=0, microsecond=0)
end_time = datetime.now(timezone.utc).replace(hour=23, minute=59, second=59, microsecond=999999)
# Retrieve historical order books
orderbooks = await storage_manager.get_historical_orderbooks(
sample_orderbook.symbol,
sample_orderbook.exchange,
start_time,
end_time,
limit=10
)
assert isinstance(orderbooks, list)
# Retrieve historical trades
trades = await storage_manager.get_historical_trades(
sample_trade.symbol,
sample_trade.exchange,
start_time,
end_time,
limit=10
)
assert isinstance(trades, list)
if __name__ == "__main__":
# Run a simple test
async def simple_test():
manager = TimescaleManager()
await manager.initialize()
# Test health check
is_healthy = await manager.health_check()
print(f"Health check: {'PASSED' if is_healthy else 'FAILED'}")
# Test storage stats
stats = await manager.get_storage_stats()
print(f"Storage stats: {len(stats)} categories")
await manager.close()
print("Simple test completed")
asyncio.run(simple_test())

Some files were not shown because too many files have changed in this diff Show More