storage module

This commit is contained in:
Dobromir Popov 2024-11-14 15:39:35 +02:00
parent eebba5d6b4
commit b623c4cb15
6 changed files with 427 additions and 418 deletions

1
.gitignore vendored
View File

@ -28,3 +28,4 @@ crypto/sol/logs/token_info.json
crypto/sol/logs/transation_details.json crypto/sol/logs/transation_details.json
.env .env
app_data.db app_data.db
crypto/sol/.vs/*

View File

@ -1,26 +1,31 @@
import concurrent.futures import asyncio
import threading
import queue
import uvicorn
import os
from dotenv import load_dotenv
import datetime import datetime
import json import json
import websockets
import logging import logging
from modules.webui import init_app import os
from modules.utils import telegram_utils, logging, get_pk from typing import Dict, Any
from modules.log_processor import watch_for_new_logs
from modules.SolanaAPI import SAPI
from config import DO_WATCH_WALLET
from asgiref.wsgi import WsgiToAsgi
from multiprocessing import Process
import time
import uvicorn
from asgiref.wsgi import WsgiToAsgi
from dotenv import load_dotenv
from config import DO_WATCH_WALLET
from modules.SolanaAPI import SAPI
from modules.log_processor import watch_for_new_logs
from modules.utils import telegram_utils
from modules.webui import init_app, teardown_app
# Load environment variables
load_dotenv() load_dotenv()
load_dotenv('.env.secret') load_dotenv('.env.secret')
def save_log(log): # Configure logging
logger = logging.getLogger(__name__)
class LogProcessor:
@staticmethod
def save_log(log: Dict[str, Any]) -> None:
"""Save log to JSON file with timestamp."""
try: try:
os.makedirs('./logs', exist_ok=True) os.makedirs('./logs', exist_ok=True)
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
@ -28,13 +33,11 @@ def save_log(log):
with open(filename, 'w') as f: with open(filename, 'w') as f:
json.dump(log, f, indent=2) json.dump(log, f, indent=2)
except Exception as e: except Exception as e:
logging.error(f"Error saving RPC log: {e}") logger.error(f"Error saving RPC log: {e}")
PROCESSING_LOG = False @staticmethod
log_queue = queue.Queue() def extract_transaction_details(logs: list) -> Dict[str, Any]:
"""Extract transaction details from logs."""
def process_log(log_result):
global PROCESSING_LOG
tr_details = { tr_details = {
"order_id": None, "order_id": None,
"token_in": None, "token_in": None,
@ -46,26 +49,10 @@ def process_log(log_result):
"percentage_swapped": 0 "percentage_swapped": 0
} }
if log_result['value']['err']:
return
logs = log_result['value']['logs']
try:
PROCESSING_LOG = True
swap_operations = ['Program log: Instruction: Swap', 'Program log: Instruction: Swap2',
'Program log: Instruction: SwapExactAmountIn', 'Program log: Instruction: SwapV2']
if any(op in logs for op in swap_operations):
save_log(log_result)
tx_signature_str = log_result['value']['signature']
before_source_balance = 0 before_source_balance = 0
source_token_change = 0 source_token_change = 0
i = 0 for i, log_entry in enumerate(logs):
while i < len(logs):
log_entry = logs[i]
if tr_details["order_id"] is None and "order_id" in log_entry: if tr_details["order_id"] is None and "order_id" in log_entry:
tr_details["order_id"] = log_entry.split(":")[-1].strip() tr_details["order_id"] = log_entry.split(":")[-1].strip()
tr_details["token_in"] = logs[i + 1].split(":")[-1].strip() tr_details["token_in"] = logs[i + 1].split(":")[-1].strip()
@ -80,154 +67,114 @@ def process_log(log_result):
tr_details["amount_out"] = float(part.split(":")[-1].strip()) / 10 ** 6 tr_details["amount_out"] = float(part.split(":")[-1].strip()) / 10 ** 6
if "before_source_balance" in log_entry: if "before_source_balance" in log_entry:
parts = log_entry.split(", ") before_source_balance = float(log_entry.split(":")[-1].strip()) / 10 ** 6
for part in parts:
if "before_source_balance" in part:
before_source_balance = float(part.split(":")[-1].strip()) / 10 ** 6
if "source_token_change" in log_entry: if "source_token_change" in log_entry:
parts = log_entry.split(", ") source_token_change = float(log_entry.split(":")[-1].strip()) / 10 ** 6
for part in parts:
if "source_token_change" in part:
source_token_change = float(part.split(":")[-1].strip()) / 10 ** 6
i += 1
try:
if tr_details["token_in"] is None or tr_details["token_out"] is None or \
tr_details["amount_in"] == 0 or tr_details["amount_out"] == 0:
logging.warning("Incomplete swap details found in logs. Getting details from transaction")
tr_details = SAPI.get_transaction_details_info(tx_signature_str, logs)
if before_source_balance > 0 and source_token_change > 0: if before_source_balance > 0 and source_token_change > 0:
tr_details["percentage_swapped"] = (source_token_change / before_source_balance) * 100 tr_details["percentage_swapped"] = (source_token_change / before_source_balance) * 100
if tr_details["percentage_swapped"] > 100: if tr_details["percentage_swapped"] > 100:
tr_details["percentage_swapped"] = tr_details["percentage_swapped"] / 1000 tr_details["percentage_swapped"] /= 1000
return tr_details
@staticmethod
async def process_log(log_result: Dict[str, Any]) -> Dict[str, Any]:
"""Process a single log entry."""
if log_result['value']['err']:
return
logs = log_result['value']['logs']
swap_operations = [
'Program log: Instruction: Swap',
'Program log: Instruction: Swap2',
'Program log: Instruction: SwapExactAmountIn',
'Program log: Instruction: SwapV2'
]
try: try:
if not any(op in logs for op in swap_operations):
return
LogProcessor.save_log(log_result)
tx_signature = log_result['value']['signature']
tr_details = LogProcessor.extract_transaction_details(logs)
if not all([tr_details["token_in"], tr_details["token_out"],
tr_details["amount_in"], tr_details["amount_out"]]):
tr_details = await SAPI.get_transaction_details_info(tx_signature, logs)
# Update token information
token_in = SAPI.dex.TOKENS_INFO[tr_details["token_in"]] token_in = SAPI.dex.TOKENS_INFO[tr_details["token_in"]]
token_out = SAPI.dex.TOKENS_INFO[tr_details["token_out"]] token_out = SAPI.dex.TOKENS_INFO[tr_details["token_out"]]
tr_details["symbol_in"] = token_in.get('symbol') tr_details.update({
tr_details["symbol_out"] = token_out.get('symbol') "symbol_in": token_in.get('symbol'),
tr_details['amount_in_USD'] = tr_details['amount_in'] * token_in.get('price', 0) "symbol_out": token_out.get('symbol'),
tr_details['amount_out_USD'] = tr_details['amount_out'] * token_out.get('price', 0) "amount_in_USD": tr_details['amount_in'] * token_in.get('price', 0),
"amount_out_USD": tr_details['amount_out'] * token_out.get('price', 0)
})
except Exception as e: # Send notification
logging.error(f"Error fetching token prices: {e}") message = (
message_text = (
f"<b>Swap detected: </b>\n" f"<b>Swap detected: </b>\n"
f"{tr_details['amount_in_USD']:.2f} worth of {tr_details['symbol_in']} " f"{tr_details['amount_in_USD']:.2f} worth of {tr_details['symbol_in']} "
f"({tr_details['percentage_swapped']:.2f}% ) swapped for {tr_details['symbol_out']} \n" f"({tr_details['percentage_swapped']:.2f}% ) swapped for {tr_details['symbol_out']}"
) )
telegram_utils.send_telegram_message(message_text) await telegram_utils.send_telegram_message(message)
SAPI.follow_move(tr_details)
SAPI.save_token_info() # Follow up actions
await SAPI.follow_move(tr_details)
await SAPI.save_token_info()
except Exception as e: except Exception as e:
logging.error(f"Error acquiring log details and following: {e}") logger.error(f"Error processing log: {e}")
telegram_utils.send_telegram_message(f"Not followed! Error following move.") await telegram_utils.send_telegram_message("Not followed! Error following move.")
except Exception as e:
logging.error(f"Error processing log: {e}")
PROCESSING_LOG = False
return tr_details return tr_details
# def process_messages(websocket): class Bot:
# try: @staticmethod
# while True: async def initialize():
# response = websocket.recv() """Initialize bot and start monitoring."""
# response_data = json.loads(response) await telegram_utils.initialize()
# logger.debug(f"Received response: {response_data}") await telegram_utils.send_telegram_message("Solana Agent Started. Connecting to mainnet...")
# if 'result' in response_data:
# new_sub_id = response_data['result']
# if int(new_sub_id) > 1:
# subscription_id = new_sub_id
# logger.info(f"Subscription successful. New id: {subscription_id}")
# elif new_sub_id:
# logger.info(f"Existing subscription confirmed: {subscription_id}")
# else:
# return None
# return subscription_id
# elif 'params' in response_data:
# log = response_data['params']['result']
# logger.debug(f"Received transaction log: {log}")
# log_queue.put(log)
# else:
# logger.warning(f"Unexpected response: {response_data}")
# except Exception as e:
# logger.error(f"An error occurred: {e}")
def log_processor_worker():
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
while True:
try:
log = log_queue.get()
executor.submit(process_log, log)
except Exception as e:
logger.error(f"Error in log processor worker: {e}")
finally:
log_queue.task_done()
logger = logging.getLogger(__name__)
app = init_app()
asgi_app = WsgiToAsgi(app)
def run_with_retry(task_func, *args, **kwargs):
while True:
try:
task_func(*args, **kwargs)
except Exception as e:
error_msg = f"Error in task {task_func.__name__}: {e}"
logger.error(error_msg)
telegram_utils.send_telegram_message(error_msg)
time.sleep(5)
def init_bot():
# Initialize bot components
# pk = get_pk()
telegram_utils.initialize()
telegram_utils.send_telegram_message("Solana Agent Started. Connecting to mainnet...")
# Start monitoring tasks
threading.Thread(target=run_with_retry, args=(watch_for_new_logs,), daemon=True).start()
asyncio.create_task(watch_for_new_logs())
if DO_WATCH_WALLET: if DO_WATCH_WALLET:
threading.Thread(target=run_with_retry, args=(SAPI.wallet_watch_loop,), daemon=True).start() asyncio.create_task(SAPI.wallet_watch_loop())
def run_server(): async def start_server():
uvicorn.run( """Run the ASGI server."""
config = uvicorn.Config(
"app:asgi_app", "app:asgi_app",
host="0.0.0.0", host="0.0.0.0",
port=3001, port=3001,
log_level="info", log_level="info",
reload=True reload=True
) )
server = uvicorn.Server(config)
await server.serve()
def main(): async def main():
# Start log processor worker """Main application entry point."""
log_processor_thread = threading.Thread(target=log_processor_worker, daemon=True) # Initialize app and create ASGI wrapper
log_processor_thread.start() app = await init_app()
global asgi_app
asgi_app = WsgiToAsgi(app)
# Initialize bot # Initialize bot
init_bot() await Bot.initialize()
# Start server in a separate process
server_process = Process(target=run_server)
server_process.start()
# Start server
try: try:
# Keep main process running await start_server()
while True:
time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
logger.info("Shutting down...") logger.info("Shutting down...")
finally: await teardown_app()
server_process.terminate()
server_process.join()
if __name__ == '__main__': if __name__ == '__main__':
main() try:
asyncio.run(main())
except KeyboardInterrupt:
logger.info("Application terminated by user")

View File

@ -41,6 +41,7 @@ from solders import message
from jupiter_python_sdk.jupiter import Jupiter from jupiter_python_sdk.jupiter import Jupiter
import asyncio import asyncio
import contextlib
import json import json
import logging import logging
import random import random
@ -99,6 +100,8 @@ class SolanaWS:
async def connect(self): async def connect(self):
while True: while True:
if self.websocket is None or self.websocket.closed:
await self.connect()
try: try:
current_url = random.choice(SOLANA_ENDPOINTS) current_url = random.choice(SOLANA_ENDPOINTS)
self.websocket = await websockets.connect(current_url, ping_interval=30, ping_timeout=10) self.websocket = await websockets.connect(current_url, ping_interval=30, ping_timeout=10)
@ -262,8 +265,13 @@ class SolanaAPI:
async def process_messages(self, solana_ws): async def process_messages(self, solana_ws):
while True: while True:
try:
message = await solana_ws.message_queue.get() message = await solana_ws.message_queue.get()
await self.process_transaction(message) await self.process_transaction(message)
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"Error processing message: {e}")
_first_subscription = True _first_subscription = True
@ -1045,17 +1053,21 @@ class SolanaDEX:
base_url = "https://api.coingecko.com/api/v3/simple/token_price/solana" base_url = "https://api.coingecko.com/api/v3/simple/token_price/solana"
prices = {} prices = {}
async def fetch_single_price(session, address): async def fetch_single_price(session, address, retries=3, backoff_factor=0.5):
params = { params = {
"contract_addresses": address, "contract_addresses": address,
"vs_currencies": self.DISPLAY_CURRENCY.lower() "vs_currencies": self.DISPLAY_CURRENCY.lower()
} }
for attempt in range(retries):
try: try:
async with session.get(base_url, params=params) as response: async with session.get(base_url, params=params) as response:
if response.status == 200: if response.status == 200:
data = await response.json() data = await response.json()
if address in data and self.DISPLAY_CURRENCY.lower() in data[address]: if address in data and self.DISPLAY_CURRENCY.lower() in data[address]:
return address, data[address][self.DISPLAY_CURRENCY.lower()] return address, data[address][self.DISPLAY_CURRENCY.lower()]
elif response.status == 429:
logging.warning(f"Rate limit exceeded for {address}. Retrying...")
await asyncio.sleep(backoff_factor * (2 ** attempt))
else: else:
logging.warning(f"Failed to get price for {address} from CoinGecko. Status: {response.status}") logging.warning(f"Failed to get price for {address} from CoinGecko. Status: {response.status}")
except Exception as e: except Exception as e:
@ -1169,7 +1181,11 @@ class SolanaDEX:
async def get_wallet_balances(self, wallet_address, doGetTokenName=True): async def get_wallet_balances(self, wallet_address, doGetTokenName=True):
balances = {} balances = {}
if not asyncio.get_event_loop().is_running():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
logging.info(f"Getting balances for wallet: {wallet_address}") logging.info(f"Getting balances for wallet: {wallet_address}")
response = None
try: try:
response = await self.solana_client.get_token_accounts_by_owner_json_parsed( response = await self.solana_client.get_token_accounts_by_owner_json_parsed(
Pubkey.from_string(wallet_address), Pubkey.from_string(wallet_address),
@ -1189,8 +1205,9 @@ class SolanaDEX:
mint = info['mint'] mint = info['mint']
decimals = int(info['tokenAmount']['decimals']) decimals = int(info['tokenAmount']['decimals'])
amount = int(info['tokenAmount']['amount']) amount = int(info['tokenAmount']['amount'])
amount = float(amount /10**decimals) amount = int(amount)
if amount > 1: if amount > 1:
amount = float(amount / 10**decimals)
if mint in self.TOKENS_INFO: if mint in self.TOKENS_INFO:
token_name = self.TOKENS_INFO[mint].get('symbol') token_name = self.TOKENS_INFO[mint].get('symbol')
elif doGetTokenName: elif doGetTokenName:
@ -1198,7 +1215,7 @@ class SolanaDEX:
self.TOKENS_INFO[mint] = {'symbol': token_name} self.TOKENS_INFO[mint] = {'symbol': token_name}
await asyncio.sleep(2) await asyncio.sleep(2)
self.TOKENS_INFO[mint]['holdedAmount'] = round(amount,decimals) self.TOKENS_INFO[mint]['holdedAmount'] = round(amount, decimals)
self.TOKENS_INFO[mint]['decimals'] = decimals self.TOKENS_INFO[mint]['decimals'] = decimals
balances[mint] = { balances[mint] = {
'name': token_name or 'N/A', 'name': token_name or 'N/A',
@ -1227,7 +1244,10 @@ class SolanaDEX:
except Exception as e: except Exception as e:
logging.error(f"Error getting wallet balances: {str(e)}") logging.error(f"Error getting wallet balances: {str(e)}")
if response and response.value:
logging.info(f"Found {len(response.value)} ({len(balances)} non zero) token accounts for wallet: {wallet_address}") logging.info(f"Found {len(response.value)} ({len(balances)} non zero) token accounts for wallet: {wallet_address}")
else:
logging.warning(f"No token accounts found for wallet: {wallet_address}")
return balances return balances
async def convert_balances_to_currency(self, balances, sol_price): async def convert_balances_to_currency(self, balances, sol_price):

View File

@ -1,7 +1,7 @@
import os import os
import asyncio import asyncio
from pathlib import Path from pathlib import Path
from .storage import store_transaction, prisma_client from .storage import Storage
from .SolanaAPI import SolanaAPI from .SolanaAPI import SolanaAPI
LOG_DIRECTORY = "./logs" LOG_DIRECTORY = "./logs"
@ -28,7 +28,7 @@ async def process_log_file(file_path):
transaction_data = await solana_api.process_wh(data) transaction_data = await solana_api.process_wh(data)
# Check if the transaction already exists # Check if the transaction already exists
existing_transaction = await prisma_client.transaction.find_first( existing_transaction = await Storage.get_prisma().transaction.find_first(
where={'solana_signature': solana_signature} where={'solana_signature': solana_signature}
) )
@ -46,7 +46,8 @@ async def process_log_file(file_path):
'solana_signature': solana_signature, 'solana_signature': solana_signature,
'details': details 'details': details
} }
await store_transaction(transaction_data) storage = Storage()
await storage.store_transaction(transaction_data)
# Rename the file to append '_saved' # Rename the file to append '_saved'
new_file_path = file_path.with_name(file_path.stem + "_saved" + file_path.suffix) new_file_path = file_path.with_name(file_path.stem + "_saved" + file_path.suffix)

View File

@ -1,8 +1,10 @@
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import NamedTuple from typing import NamedTuple
from enum import Enum
from datetime import datetime
import json
from prisma import Prisma
from typing import List, Optional, Dict
import asyncio
class Transaction(NamedTuple): class Transaction(NamedTuple):
wallet: str wallet: str
@ -15,28 +17,69 @@ class Transaction(NamedTuple):
value_out_usd: float value_out_usd: float
tx_signature: str tx_signature: str
from enum import Enum class TransactionType(Enum):
from datetime import datetime BUY = "BUY"
from enum import Enum SELL = "SELL"
import json
from prisma import Prisma
class TransactionStatus(Enum): class TransactionStatus(Enum):
PENDING = "PENDING" PENDING = "PENDING"
SENT = "SENT" SENT = "SENT"
CONFIRMED = "CONFIRMED" CONFIRMED = "CONFIRMED"
# Initialize the Prisma client
prisma_client = Prisma()
async def init_db(): class Storage:
await prisma_client.connect() _instance: Optional['Storage'] = None
_lock = asyncio.Lock()
_initialized = False
prisma = Prisma() # Class-level Prisma instance
async def store_transaction(transaction: Transaction): STABLECOINS = ['USDC', 'USDT', 'SOL']
"""
Store a transaction record in the database. def __new__(cls):
""" if cls._instance is None:
await prisma_client.transaction.create( cls._instance = super(Storage, cls).__new__(cls)
return cls._instance
async def __ainit__(self):
if self._initialized:
return
await self.prisma.connect()
self._initialized = True
@classmethod
async def get_instance(cls) -> 'Storage':
if not cls._instance:
async with cls._lock:
if not cls._instance:
cls._instance = cls()
await cls._instance.__ainit__()
return cls._instance
@classmethod
def get_prisma(cls):
return cls.prisma
async def disconnect(self):
if self._initialized:
await self.prisma.disconnect()
self._initialized = False
def __init__(self):
self.prisma = Prisma()
self.users = {
"db": {"id": 1, "username": "db", "email": "user1@example.com", "password": "db"},
"popov": {"id": 2, "username": "popov", "email": "user2@example.com", "password": "popov"}
}
def is_connected(self):
return self.prisma.is_connected()
async def connect(self):
await self.prisma.connect()
async def disconnect(self):
await self.prisma.disconnect()
async def store_transaction(self, transaction: Transaction):
return await self.prisma.transaction.create(
data={ data={
'wallet_id': transaction.wallet, 'wallet_id': transaction.wallet,
'timestamp': datetime.now().isoformat(), 'timestamp': datetime.now().isoformat(),
@ -53,8 +96,8 @@ async def store_transaction(transaction: Transaction):
} }
) )
async def update_holdings(wallet_id, currency, amount_change): async def update_holdings(self, wallet_id: str, currency: str, amount_change: float):
holding = await prisma_client.holding.find_first( holding = await self.prisma.holding.find_first(
where={ where={
'wallet_id': wallet_id, 'wallet_id': wallet_id,
'currency': currency 'currency': currency
@ -62,7 +105,7 @@ async def update_holdings(wallet_id, currency, amount_change):
) )
if holding: if holding:
new_amount = holding.amount + amount_change new_amount = holding.amount + amount_change
await prisma_client.holding.update( return await self.prisma.holding.update(
where={'id': holding.id}, where={'id': holding.id},
data={ data={
'amount': new_amount, 'amount': new_amount,
@ -70,7 +113,7 @@ async def update_holdings(wallet_id, currency, amount_change):
} }
) )
else: else:
await prisma_client.holding.create( return await self.prisma.holding.create(
data={ data={
'wallet_id': wallet_id, 'wallet_id': wallet_id,
'currency': currency, 'currency': currency,
@ -79,13 +122,14 @@ async def update_holdings(wallet_id, currency, amount_change):
} }
) )
async def get_wallet_holdings(wallet_id): async def get_wallet_holdings(self, wallet_id: str):
return await prisma_client.holding.find_many( return await self.prisma.holding.find_many(
where={'wallet_id': wallet_id}, where={'wallet_id': wallet_id},
select={'currency': True, 'amount': True} select={'currency': True, 'amount': True}
) )
async def get_transaction_history(wallet_id, start_date=None, end_date=None, include_closed=False): async def get_transaction_history(self, wallet_id: str, start_date: Optional[str] = None,
end_date: Optional[str] = None, include_closed: bool = False):
filters = {'wallet_id': wallet_id} filters = {'wallet_id': wallet_id}
if not include_closed: if not include_closed:
filters['closed'] = False filters['closed'] = False
@ -94,19 +138,19 @@ async def get_transaction_history(wallet_id, start_date=None, end_date=None, inc
if end_date: if end_date:
filters['timestamp'] = {'lte': end_date} filters['timestamp'] = {'lte': end_date}
return await prisma_client.transaction.find_many( return await self.prisma.transaction.find_many(
where=filters, where=filters,
order={'timestamp': 'desc'} order={'timestamp': 'desc'}
) )
async def close_transaction(transaction_id): async def close_transaction(self, transaction_id: int):
await prisma_client.transaction.update( return await self.prisma.transaction.update(
where={'id': transaction_id}, where={'id': transaction_id},
data={'closed': True} data={'closed': True}
) )
async def get_open_transactions(wallet_id, currency): async def get_open_transactions(self, wallet_id: str, currency: str):
return await prisma_client.transaction.find_many( return await self.prisma.transaction.find_many(
where={ where={
'wallet_id': wallet_id, 'wallet_id': wallet_id,
'buy_currency': currency, 'buy_currency': currency,
@ -115,8 +159,8 @@ async def get_open_transactions(wallet_id, currency):
order={'timestamp': 'asc'} order={'timestamp': 'asc'}
) )
async def calculate_current_holdings(wallet_id): async def calculate_current_holdings(self, wallet_id: str):
transactions = await prisma_client.transaction.group_by( transactions = await self.prisma.transaction.group_by(
by=['buy_currency'], by=['buy_currency'],
where={'wallet_id': wallet_id, 'closed': False}, where={'wallet_id': wallet_id, 'closed': False},
_sum={'buy_amount': True, 'sell_amount': True} _sum={'buy_amount': True, 'sell_amount': True}
@ -129,14 +173,12 @@ async def calculate_current_holdings(wallet_id):
for t in transactions if t._sum.buy_amount > (t._sum.sell_amount or 0) for t in transactions if t._sum.buy_amount > (t._sum.sell_amount or 0)
] ]
STABLECOINS = ['USDC', 'USDT', 'SOL'] async def is_transaction_closed(self, wallet_id: str, transaction_id: int) -> bool:
transaction = await self.prisma.transaction.find_unique(
async def is_transaction_closed(wallet_id, transaction_id):
transaction = await prisma_client.transaction.find_unique(
where={'id': transaction_id} where={'id': transaction_id}
) )
if transaction: if transaction:
sold_amount = await prisma_client.transaction.aggregate( sold_amount = await self.prisma.transaction.aggregate(
_sum={'sell_amount': True}, _sum={'sell_amount': True},
where={ where={
'wallet_id': wallet_id, 'wallet_id': wallet_id,
@ -147,19 +189,20 @@ async def is_transaction_closed(wallet_id, transaction_id):
return sold_amount._sum.sell_amount >= transaction.buy_amount return sold_amount._sum.sell_amount >= transaction.buy_amount
return False return False
async def close_completed_transactions(wallet_id): async def close_completed_transactions(self, wallet_id: str):
transactions = await prisma_client.transaction.find_many( transactions = await self.prisma.transaction.find_many(
where={ where={
'wallet_id': wallet_id, 'wallet_id': wallet_id,
'closed': False, 'closed': False,
'buy_currency': {'notIn': STABLECOINS} 'buy_currency': {'notIn': self.STABLECOINS}
} }
) )
for transaction in transactions: for transaction in transactions:
if await is_transaction_closed(wallet_id, transaction.id): if await self.is_transaction_closed(wallet_id, transaction.id):
await close_transaction(transaction.id) await self.close_transaction(transaction.id)
async def get_profit_loss(wallet_id, currency, start_date=None, end_date=None): async def get_profit_loss(self, wallet_id: str, currency: str,
start_date: Optional[str] = None, end_date: Optional[str] = None):
filters = { filters = {
'wallet_id': wallet_id, 'wallet_id': wallet_id,
'OR': [ 'OR': [
@ -172,7 +215,7 @@ async def get_profit_loss(wallet_id, currency, start_date=None, end_date=None):
if end_date: if end_date:
filters['timestamp'] = {'lte': end_date} filters['timestamp'] = {'lte': end_date}
result = await prisma_client.transaction.aggregate( result = await self.prisma.transaction.aggregate(
_sum={ _sum={
'sell_value': True, 'sell_value': True,
'buy_value': True 'buy_value': True
@ -181,51 +224,35 @@ async def get_profit_loss(wallet_id, currency, start_date=None, end_date=None):
) )
return (result._sum.sell_value or 0) - (result._sum.buy_value or 0) return (result._sum.sell_value or 0) - (result._sum.buy_value or 0)
# # # # # # USERS # User management methods
def get_or_create_user(self, email: str, google_id: str):
# For this example, we'll use a simple dictionary to store users user = next((u for u in self.users.values() if u['email'] == email), None)
users = {
"db": {"id": 1, "username": "db", "email": "user1@example.com", "password": "db"},
"popov": {"id": 2, "username": "popov", "email": "user2@example.com", "password": "popov"}
}
def get_or_create_user(email, google_id):
user = next((u for u in users.values() if u['email'] == email), None)
if not user: if not user:
user_id = max(u['id'] for u in users.values()) + 1 user_id = max(u['id'] for u in self.users.values()) + 1
username = email.split('@')[0] # Use the part before @ as username username = email.split('@')[0]
user = { user = {
'id': user_id, 'id': user_id,
'username': username, 'username': username,
'email': email, 'email': email,
'google_id': google_id 'google_id': google_id
} }
users[username] = user self.users[username] = user
return user return user
def authenticate_user(username, password): def authenticate_user(self, username: str, password: str):
""" user = self.users.get(username)
Authenticate a user based on username and password.
Returns user data if authentication is successful, None otherwise.
"""
user = users.get(username)
if user and user['password'] == password: if user and user['password'] == password:
return {"id": user['id'], "username": user['username'], "email": user['email']} return {"id": user['id'], "username": user['username'], "email": user['email']}
return None return None
def get_user_by_id(user_id): def get_user_by_id(self, user_id: int):
""" for user in self.users.values():
Retrieve a user by their ID.
"""
for user in users.values():
if user['id'] == int(user_id): if user['id'] == int(user_id):
return {"id": user['id'], "username": user['username'], "email": user['email']} return {"id": user['id'], "username": user['username'], "email": user['email']}
return None return None
def store_api_key(user_id, api_key): def store_api_key(self, user_id: int, api_key: str):
"""
Store the generated API key for a user.
"""
# In a real application, you would store this in a database
# For this example, we'll just print it
print(f"Storing API key {api_key} for user {user_id}") print(f"Storing API key {api_key} for user {user_id}")
storage = Storage()

View File

@ -14,13 +14,14 @@ from config import LIQUIDITY_TOKENS, YOUR_WALLET
from modules import storage, utils, SolanaAPI from modules import storage, utils, SolanaAPI
from modules.utils import async_safe_call, decode_instruction_data from modules.utils import async_safe_call, decode_instruction_data
from modules.storage import Storage
import os import os
import logging import logging
from datetime import datetime from datetime import datetime
on_transaction = None on_transaction = None
def init_app(tr_handler=None): async def init_app(tr_handler=None):
global on_transaction global on_transaction
on_transaction = tr_handler on_transaction = tr_handler
app = Flask(__name__, template_folder='../templates', static_folder='../static') app = Flask(__name__, template_folder='../templates', static_folder='../static')
@ -30,6 +31,15 @@ def init_app(tr_handler=None):
login_manager = LoginManager(app) login_manager = LoginManager(app)
login_manager.login_view = 'login' login_manager.login_view = 'login'
storage = Storage()
# Ensure database connection
async def ensure_storage_connection():
if not storage.is_connected():
await storage.connect()
asyncio.run(ensure_storage_connection())
# oauth = OAuth(app) # oauth = OAuth(app)
# google = oauth.remote_app( # google = oauth.remote_app(
# 'google', # 'google',
@ -481,6 +491,9 @@ def init_app(tr_handler=None):
return app return app
def teardown_app():
# Close the database connection
storage.disconnect()
# Function to find the latest log file # Function to find the latest log file
def get_latest_log_file(wh:bool): def get_latest_log_file(wh:bool):
@ -499,4 +512,4 @@ def get_latest_log_file(wh:bool):
utils.log.error(f"Error fetching latest log file: {e}") utils.log.error(f"Error fetching latest log file: {e}")
return None return None
export = init_app export = init_app, teardown_app