storage module
This commit is contained in:
@ -1,8 +1,10 @@
|
||||
import sys
|
||||
import os
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from typing import NamedTuple
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
import json
|
||||
from prisma import Prisma
|
||||
from typing import List, Optional, Dict
|
||||
import asyncio
|
||||
|
||||
class Transaction(NamedTuple):
|
||||
wallet: str
|
||||
@ -14,218 +16,243 @@ class Transaction(NamedTuple):
|
||||
amount_out: float
|
||||
value_out_usd: float
|
||||
tx_signature: str
|
||||
|
||||
from enum import Enum
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
import json
|
||||
from prisma import Prisma
|
||||
|
||||
class TransactionType(Enum):
|
||||
BUY = "BUY"
|
||||
SELL = "SELL"
|
||||
class TransactionStatus(Enum):
|
||||
PENDING = "PENDING"
|
||||
SENT = "SENT"
|
||||
CONFIRMED = "CONFIRMED"
|
||||
|
||||
# Initialize the Prisma client
|
||||
prisma_client = Prisma()
|
||||
|
||||
async def init_db():
|
||||
await prisma_client.connect()
|
||||
|
||||
async def store_transaction(transaction: Transaction):
|
||||
"""
|
||||
Store a transaction record in the database.
|
||||
"""
|
||||
await prisma_client.transaction.create(
|
||||
data={
|
||||
'wallet_id': transaction.wallet,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'type': transaction.transaction_type,
|
||||
'sell_currency': transaction.symbol_in,
|
||||
'sell_amount': transaction.amount_in,
|
||||
'sell_value': transaction.value_in_usd,
|
||||
'buy_currency': transaction.symbol_out,
|
||||
'buy_amount': transaction.amount_out,
|
||||
'buy_value': transaction.value_out_usd,
|
||||
'solana_signature': transaction.tx_signature,
|
||||
'details': json.dumps({}),
|
||||
'status': TransactionStatus.PENDING.value
|
||||
}
|
||||
)
|
||||
|
||||
async def update_holdings(wallet_id, currency, amount_change):
|
||||
holding = await prisma_client.holding.find_first(
|
||||
where={
|
||||
'wallet_id': wallet_id,
|
||||
'currency': currency
|
||||
}
|
||||
)
|
||||
if holding:
|
||||
new_amount = holding.amount + amount_change
|
||||
await prisma_client.holding.update(
|
||||
where={'id': holding.id},
|
||||
data={
|
||||
'amount': new_amount,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
else:
|
||||
await prisma_client.holding.create(
|
||||
data={
|
||||
'wallet_id': wallet_id,
|
||||
'currency': currency,
|
||||
'amount': amount_change,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
async def get_wallet_holdings(wallet_id):
|
||||
return await prisma_client.holding.find_many(
|
||||
where={'wallet_id': wallet_id},
|
||||
select={'currency': True, 'amount': True}
|
||||
)
|
||||
|
||||
async def get_transaction_history(wallet_id, start_date=None, end_date=None, include_closed=False):
|
||||
filters = {'wallet_id': wallet_id}
|
||||
if not include_closed:
|
||||
filters['closed'] = False
|
||||
if start_date:
|
||||
filters['timestamp'] = {'gte': start_date}
|
||||
if end_date:
|
||||
filters['timestamp'] = {'lte': end_date}
|
||||
class Storage:
|
||||
_instance: Optional['Storage'] = None
|
||||
_lock = asyncio.Lock()
|
||||
_initialized = False
|
||||
prisma = Prisma() # Class-level Prisma instance
|
||||
|
||||
return await prisma_client.transaction.find_many(
|
||||
where=filters,
|
||||
order={'timestamp': 'desc'}
|
||||
)
|
||||
STABLECOINS = ['USDC', 'USDT', 'SOL']
|
||||
|
||||
async def close_transaction(transaction_id):
|
||||
await prisma_client.transaction.update(
|
||||
where={'id': transaction_id},
|
||||
data={'closed': True}
|
||||
)
|
||||
def __new__(cls):
|
||||
if cls._instance is None:
|
||||
cls._instance = super(Storage, cls).__new__(cls)
|
||||
return cls._instance
|
||||
|
||||
async def get_open_transactions(wallet_id, currency):
|
||||
return await prisma_client.transaction.find_many(
|
||||
where={
|
||||
'wallet_id': wallet_id,
|
||||
'buy_currency': currency,
|
||||
'closed': False
|
||||
},
|
||||
order={'timestamp': 'asc'}
|
||||
)
|
||||
async def __ainit__(self):
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
await self.prisma.connect()
|
||||
self._initialized = True
|
||||
|
||||
async def calculate_current_holdings(wallet_id):
|
||||
transactions = await prisma_client.transaction.group_by(
|
||||
by=['buy_currency'],
|
||||
where={'wallet_id': wallet_id, 'closed': False},
|
||||
_sum={'buy_amount': True, 'sell_amount': True}
|
||||
)
|
||||
return [
|
||||
{
|
||||
'currency': t.buy_currency,
|
||||
'amount': t._sum.buy_amount - (t._sum.sell_amount or 0)
|
||||
@classmethod
|
||||
async def get_instance(cls) -> 'Storage':
|
||||
if not cls._instance:
|
||||
async with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = cls()
|
||||
await cls._instance.__ainit__()
|
||||
return cls._instance
|
||||
@classmethod
|
||||
def get_prisma(cls):
|
||||
return cls.prisma
|
||||
async def disconnect(self):
|
||||
if self._initialized:
|
||||
await self.prisma.disconnect()
|
||||
self._initialized = False
|
||||
|
||||
def __init__(self):
|
||||
self.prisma = Prisma()
|
||||
self.users = {
|
||||
"db": {"id": 1, "username": "db", "email": "user1@example.com", "password": "db"},
|
||||
"popov": {"id": 2, "username": "popov", "email": "user2@example.com", "password": "popov"}
|
||||
}
|
||||
for t in transactions if t._sum.buy_amount > (t._sum.sell_amount or 0)
|
||||
]
|
||||
|
||||
STABLECOINS = ['USDC', 'USDT', 'SOL']
|
||||
def is_connected(self):
|
||||
return self.prisma.is_connected()
|
||||
|
||||
async def is_transaction_closed(wallet_id, transaction_id):
|
||||
transaction = await prisma_client.transaction.find_unique(
|
||||
where={'id': transaction_id}
|
||||
)
|
||||
if transaction:
|
||||
sold_amount = await prisma_client.transaction.aggregate(
|
||||
_sum={'sell_amount': True},
|
||||
async def connect(self):
|
||||
await self.prisma.connect()
|
||||
|
||||
async def disconnect(self):
|
||||
await self.prisma.disconnect()
|
||||
|
||||
async def store_transaction(self, transaction: Transaction):
|
||||
return await self.prisma.transaction.create(
|
||||
data={
|
||||
'wallet_id': transaction.wallet,
|
||||
'timestamp': datetime.now().isoformat(),
|
||||
'type': transaction.transaction_type,
|
||||
'sell_currency': transaction.symbol_in,
|
||||
'sell_amount': transaction.amount_in,
|
||||
'sell_value': transaction.value_in_usd,
|
||||
'buy_currency': transaction.symbol_out,
|
||||
'buy_amount': transaction.amount_out,
|
||||
'buy_value': transaction.value_out_usd,
|
||||
'solana_signature': transaction.tx_signature,
|
||||
'details': json.dumps({}),
|
||||
'status': TransactionStatus.PENDING.value
|
||||
}
|
||||
)
|
||||
|
||||
async def update_holdings(self, wallet_id: str, currency: str, amount_change: float):
|
||||
holding = await self.prisma.holding.find_first(
|
||||
where={
|
||||
'wallet_id': wallet_id,
|
||||
'sell_currency': transaction.buy_currency,
|
||||
'timestamp': {'gt': transaction.timestamp}
|
||||
'currency': currency
|
||||
}
|
||||
)
|
||||
return sold_amount._sum.sell_amount >= transaction.buy_amount
|
||||
return False
|
||||
if holding:
|
||||
new_amount = holding.amount + amount_change
|
||||
return await self.prisma.holding.update(
|
||||
where={'id': holding.id},
|
||||
data={
|
||||
'amount': new_amount,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
else:
|
||||
return await self.prisma.holding.create(
|
||||
data={
|
||||
'wallet_id': wallet_id,
|
||||
'currency': currency,
|
||||
'amount': amount_change,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
)
|
||||
|
||||
async def close_completed_transactions(wallet_id):
|
||||
transactions = await prisma_client.transaction.find_many(
|
||||
where={
|
||||
'wallet_id': wallet_id,
|
||||
'closed': False,
|
||||
'buy_currency': {'notIn': STABLECOINS}
|
||||
}
|
||||
)
|
||||
for transaction in transactions:
|
||||
if await is_transaction_closed(wallet_id, transaction.id):
|
||||
await close_transaction(transaction.id)
|
||||
async def get_wallet_holdings(self, wallet_id: str):
|
||||
return await self.prisma.holding.find_many(
|
||||
where={'wallet_id': wallet_id},
|
||||
select={'currency': True, 'amount': True}
|
||||
)
|
||||
|
||||
async def get_profit_loss(wallet_id, currency, start_date=None, end_date=None):
|
||||
filters = {
|
||||
'wallet_id': wallet_id,
|
||||
'OR': [
|
||||
{'sell_currency': currency},
|
||||
{'buy_currency': currency}
|
||||
async def get_transaction_history(self, wallet_id: str, start_date: Optional[str] = None,
|
||||
end_date: Optional[str] = None, include_closed: bool = False):
|
||||
filters = {'wallet_id': wallet_id}
|
||||
if not include_closed:
|
||||
filters['closed'] = False
|
||||
if start_date:
|
||||
filters['timestamp'] = {'gte': start_date}
|
||||
if end_date:
|
||||
filters['timestamp'] = {'lte': end_date}
|
||||
|
||||
return await self.prisma.transaction.find_many(
|
||||
where=filters,
|
||||
order={'timestamp': 'desc'}
|
||||
)
|
||||
|
||||
async def close_transaction(self, transaction_id: int):
|
||||
return await self.prisma.transaction.update(
|
||||
where={'id': transaction_id},
|
||||
data={'closed': True}
|
||||
)
|
||||
|
||||
async def get_open_transactions(self, wallet_id: str, currency: str):
|
||||
return await self.prisma.transaction.find_many(
|
||||
where={
|
||||
'wallet_id': wallet_id,
|
||||
'buy_currency': currency,
|
||||
'closed': False
|
||||
},
|
||||
order={'timestamp': 'asc'}
|
||||
)
|
||||
|
||||
async def calculate_current_holdings(self, wallet_id: str):
|
||||
transactions = await self.prisma.transaction.group_by(
|
||||
by=['buy_currency'],
|
||||
where={'wallet_id': wallet_id, 'closed': False},
|
||||
_sum={'buy_amount': True, 'sell_amount': True}
|
||||
)
|
||||
return [
|
||||
{
|
||||
'currency': t.buy_currency,
|
||||
'amount': t._sum.buy_amount - (t._sum.sell_amount or 0)
|
||||
}
|
||||
for t in transactions if t._sum.buy_amount > (t._sum.sell_amount or 0)
|
||||
]
|
||||
}
|
||||
if start_date:
|
||||
filters['timestamp'] = {'gte': start_date}
|
||||
if end_date:
|
||||
filters['timestamp'] = {'lte': end_date}
|
||||
|
||||
result = await prisma_client.transaction.aggregate(
|
||||
_sum={
|
||||
'sell_value': True,
|
||||
'buy_value': True
|
||||
},
|
||||
where=filters
|
||||
)
|
||||
return (result._sum.sell_value or 0) - (result._sum.buy_value or 0)
|
||||
|
||||
# # # # # # USERS
|
||||
async def is_transaction_closed(self, wallet_id: str, transaction_id: int) -> bool:
|
||||
transaction = await self.prisma.transaction.find_unique(
|
||||
where={'id': transaction_id}
|
||||
)
|
||||
if transaction:
|
||||
sold_amount = await self.prisma.transaction.aggregate(
|
||||
_sum={'sell_amount': True},
|
||||
where={
|
||||
'wallet_id': wallet_id,
|
||||
'sell_currency': transaction.buy_currency,
|
||||
'timestamp': {'gt': transaction.timestamp}
|
||||
}
|
||||
)
|
||||
return sold_amount._sum.sell_amount >= transaction.buy_amount
|
||||
return False
|
||||
|
||||
# For this example, we'll use a simple dictionary to store users
|
||||
users = {
|
||||
"db": {"id": 1, "username": "db", "email": "user1@example.com", "password": "db"},
|
||||
"popov": {"id": 2, "username": "popov", "email": "user2@example.com", "password": "popov"}
|
||||
}
|
||||
async def close_completed_transactions(self, wallet_id: str):
|
||||
transactions = await self.prisma.transaction.find_many(
|
||||
where={
|
||||
'wallet_id': wallet_id,
|
||||
'closed': False,
|
||||
'buy_currency': {'notIn': self.STABLECOINS}
|
||||
}
|
||||
)
|
||||
for transaction in transactions:
|
||||
if await self.is_transaction_closed(wallet_id, transaction.id):
|
||||
await self.close_transaction(transaction.id)
|
||||
|
||||
def get_or_create_user(email, google_id):
|
||||
user = next((u for u in users.values() if u['email'] == email), None)
|
||||
if not user:
|
||||
user_id = max(u['id'] for u in users.values()) + 1
|
||||
username = email.split('@')[0] # Use the part before @ as username
|
||||
user = {
|
||||
'id': user_id,
|
||||
'username': username,
|
||||
'email': email,
|
||||
'google_id': google_id
|
||||
async def get_profit_loss(self, wallet_id: str, currency: str,
|
||||
start_date: Optional[str] = None, end_date: Optional[str] = None):
|
||||
filters = {
|
||||
'wallet_id': wallet_id,
|
||||
'OR': [
|
||||
{'sell_currency': currency},
|
||||
{'buy_currency': currency}
|
||||
]
|
||||
}
|
||||
users[username] = user
|
||||
return user
|
||||
if start_date:
|
||||
filters['timestamp'] = {'gte': start_date}
|
||||
if end_date:
|
||||
filters['timestamp'] = {'lte': end_date}
|
||||
|
||||
result = await self.prisma.transaction.aggregate(
|
||||
_sum={
|
||||
'sell_value': True,
|
||||
'buy_value': True
|
||||
},
|
||||
where=filters
|
||||
)
|
||||
return (result._sum.sell_value or 0) - (result._sum.buy_value or 0)
|
||||
|
||||
def authenticate_user(username, password):
|
||||
"""
|
||||
Authenticate a user based on username and password.
|
||||
Returns user data if authentication is successful, None otherwise.
|
||||
"""
|
||||
user = users.get(username)
|
||||
if user and user['password'] == password:
|
||||
return {"id": user['id'], "username": user['username'], "email": user['email']}
|
||||
return None
|
||||
# User management methods
|
||||
def get_or_create_user(self, email: str, google_id: str):
|
||||
user = next((u for u in self.users.values() if u['email'] == email), None)
|
||||
if not user:
|
||||
user_id = max(u['id'] for u in self.users.values()) + 1
|
||||
username = email.split('@')[0]
|
||||
user = {
|
||||
'id': user_id,
|
||||
'username': username,
|
||||
'email': email,
|
||||
'google_id': google_id
|
||||
}
|
||||
self.users[username] = user
|
||||
return user
|
||||
|
||||
def get_user_by_id(user_id):
|
||||
"""
|
||||
Retrieve a user by their ID.
|
||||
"""
|
||||
for user in users.values():
|
||||
if user['id'] == int(user_id):
|
||||
def authenticate_user(self, username: str, password: str):
|
||||
user = self.users.get(username)
|
||||
if user and user['password'] == password:
|
||||
return {"id": user['id'], "username": user['username'], "email": user['email']}
|
||||
return None
|
||||
return None
|
||||
|
||||
def store_api_key(user_id, api_key):
|
||||
"""
|
||||
Store the generated API key for a user.
|
||||
"""
|
||||
# In a real application, you would store this in a database
|
||||
# For this example, we'll just print it
|
||||
print(f"Storing API key {api_key} for user {user_id}")
|
||||
def get_user_by_id(self, user_id: int):
|
||||
for user in self.users.values():
|
||||
if user['id'] == int(user_id):
|
||||
return {"id": user['id'], "username": user['username'], "email": user['email']}
|
||||
return None
|
||||
|
||||
def store_api_key(self, user_id: int, api_key: str):
|
||||
print(f"Storing API key {api_key} for user {user_id}")
|
||||
|
||||
|
||||
storage = Storage()
|
||||
|
Reference in New Issue
Block a user