diff --git a/.gitignore b/.gitignore index e4a1ce7..8045718 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ crypto/sol/logs/token_info.json crypto/sol/logs/transation_details.json .env app_data.db +crypto/sol/.vs/* diff --git a/crypto/sol/app.py b/crypto/sol/app.py index 8867c9f..3d6ecb6 100644 --- a/crypto/sol/app.py +++ b/crypto/sol/app.py @@ -1,233 +1,180 @@ -import concurrent.futures -import threading -import queue -import uvicorn -import os -from dotenv import load_dotenv +import asyncio import datetime import json -import websockets import logging -from modules.webui import init_app -from modules.utils import telegram_utils, logging, get_pk -from modules.log_processor import watch_for_new_logs -from modules.SolanaAPI import SAPI -from config import DO_WATCH_WALLET -from asgiref.wsgi import WsgiToAsgi -from multiprocessing import Process -import time +import os +from typing import Dict, Any +import uvicorn +from asgiref.wsgi import WsgiToAsgi +from dotenv import load_dotenv + +from config import DO_WATCH_WALLET +from modules.SolanaAPI import SAPI +from modules.log_processor import watch_for_new_logs +from modules.utils import telegram_utils +from modules.webui import init_app, teardown_app + +# Load environment variables load_dotenv() load_dotenv('.env.secret') -def save_log(log): - try: - os.makedirs('./logs', exist_ok=True) - timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") - filename = f"./logs/log_{timestamp}.json" - with open(filename, 'w') as f: - json.dump(log, f, indent=2) - except Exception as e: - logging.error(f"Error saving RPC log: {e}") - -PROCESSING_LOG = False -log_queue = queue.Queue() - -def process_log(log_result): - global PROCESSING_LOG - tr_details = { - "order_id": None, - "token_in": None, - "token_out": None, - "amount_in": 0, - "amount_out": 0, - "amount_in_USD": 0, - "amount_out_USD": 0, - "percentage_swapped": 0 - } - - if log_result['value']['err']: - return - - logs = log_result['value']['logs'] - try: - PROCESSING_LOG = True - swap_operations = ['Program log: Instruction: Swap', 'Program log: Instruction: Swap2', - 'Program log: Instruction: SwapExactAmountIn', 'Program log: Instruction: SwapV2'] - - if any(op in logs for op in swap_operations): - save_log(log_result) - tx_signature_str = log_result['value']['signature'] - - before_source_balance = 0 - source_token_change = 0 - - i = 0 - while i < len(logs): - log_entry = logs[i] - - if tr_details["order_id"] is None and "order_id" in log_entry: - tr_details["order_id"] = log_entry.split(":")[-1].strip() - tr_details["token_in"] = logs[i + 1].split(":")[-1].strip() - tr_details["token_out"] = logs[i + 2].split(":")[-1].strip() - - if "source_token_change" in log_entry: - parts = log_entry.split(", ") - for part in parts: - if "source_token_change" in part: - tr_details["amount_in"] = float(part.split(":")[-1].strip()) / 10 ** 6 - elif "destination_token_change" in part: - tr_details["amount_out"] = float(part.split(":")[-1].strip()) / 10 ** 6 - - if "before_source_balance" in log_entry: - parts = log_entry.split(", ") - for part in parts: - if "before_source_balance" in part: - before_source_balance = float(part.split(":")[-1].strip()) / 10 ** 6 - if "source_token_change" in log_entry: - parts = log_entry.split(", ") - for part in parts: - if "source_token_change" in part: - source_token_change = float(part.split(":")[-1].strip()) / 10 ** 6 - i += 1 - - try: - if tr_details["token_in"] is None or tr_details["token_out"] is None or \ - tr_details["amount_in"] == 0 or tr_details["amount_out"] == 0: - logging.warning("Incomplete swap details found in logs. Getting details from transaction") - tr_details = SAPI.get_transaction_details_info(tx_signature_str, logs) - - - if before_source_balance > 0 and source_token_change > 0: - tr_details["percentage_swapped"] = (source_token_change / before_source_balance) * 100 - if tr_details["percentage_swapped"] > 100: - tr_details["percentage_swapped"] = tr_details["percentage_swapped"] / 1000 - - try: - token_in = SAPI.dex.TOKENS_INFO[tr_details["token_in"]] - token_out = SAPI.dex.TOKENS_INFO[tr_details["token_out"]] - - tr_details["symbol_in"] = token_in.get('symbol') - tr_details["symbol_out"] = token_out.get('symbol') - tr_details['amount_in_USD'] = tr_details['amount_in'] * token_in.get('price', 0) - tr_details['amount_out_USD'] = tr_details['amount_out'] * token_out.get('price', 0) - - except Exception as e: - logging.error(f"Error fetching token prices: {e}") - - message_text = ( - f"Swap detected: \n" - f"{tr_details['amount_in_USD']:.2f} worth of {tr_details['symbol_in']} " - f"({tr_details['percentage_swapped']:.2f}% ) swapped for {tr_details['symbol_out']} \n" - ) - telegram_utils.send_telegram_message(message_text) - SAPI.follow_move(tr_details) - SAPI.save_token_info() - - except Exception as e: - logging.error(f"Error acquiring log details and following: {e}") - telegram_utils.send_telegram_message(f"Not followed! Error following move.") - - except Exception as e: - logging.error(f"Error processing log: {e}") - - PROCESSING_LOG = False - return tr_details - -# def process_messages(websocket): -# try: -# while True: -# response = websocket.recv() -# response_data = json.loads(response) -# logger.debug(f"Received response: {response_data}") - -# if 'result' in response_data: -# new_sub_id = response_data['result'] -# if int(new_sub_id) > 1: -# subscription_id = new_sub_id -# logger.info(f"Subscription successful. New id: {subscription_id}") -# elif new_sub_id: -# logger.info(f"Existing subscription confirmed: {subscription_id}") -# else: -# return None -# return subscription_id -# elif 'params' in response_data: -# log = response_data['params']['result'] -# logger.debug(f"Received transaction log: {log}") -# log_queue.put(log) -# else: -# logger.warning(f"Unexpected response: {response_data}") - -# except Exception as e: -# logger.error(f"An error occurred: {e}") - -def log_processor_worker(): - with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: - while True: - try: - log = log_queue.get() - executor.submit(process_log, log) - except Exception as e: - logger.error(f"Error in log processor worker: {e}") - finally: - log_queue.task_done() - +# Configure logging logger = logging.getLogger(__name__) -app = init_app() -asgi_app = WsgiToAsgi(app) -def run_with_retry(task_func, *args, **kwargs): - while True: +class LogProcessor: + @staticmethod + def save_log(log: Dict[str, Any]) -> None: + """Save log to JSON file with timestamp.""" try: - task_func(*args, **kwargs) + os.makedirs('./logs', exist_ok=True) + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f") + filename = f"./logs/log_{timestamp}.json" + with open(filename, 'w') as f: + json.dump(log, f, indent=2) except Exception as e: - error_msg = f"Error in task {task_func.__name__}: {e}" - logger.error(error_msg) - telegram_utils.send_telegram_message(error_msg) - time.sleep(5) + logger.error(f"Error saving RPC log: {e}") -def init_bot(): - # Initialize bot components - # pk = get_pk() - telegram_utils.initialize() - telegram_utils.send_telegram_message("Solana Agent Started. Connecting to mainnet...") - - # Start monitoring tasks - threading.Thread(target=run_with_retry, args=(watch_for_new_logs,), daemon=True).start() - - if DO_WATCH_WALLET: - threading.Thread(target=run_with_retry, args=(SAPI.wallet_watch_loop,), daemon=True).start() + @staticmethod + def extract_transaction_details(logs: list) -> Dict[str, Any]: + """Extract transaction details from logs.""" + tr_details = { + "order_id": None, + "token_in": None, + "token_out": None, + "amount_in": 0, + "amount_out": 0, + "amount_in_USD": 0, + "amount_out_USD": 0, + "percentage_swapped": 0 + } -def run_server(): - uvicorn.run( + before_source_balance = 0 + source_token_change = 0 + + for i, log_entry in enumerate(logs): + if tr_details["order_id"] is None and "order_id" in log_entry: + tr_details["order_id"] = log_entry.split(":")[-1].strip() + tr_details["token_in"] = logs[i + 1].split(":")[-1].strip() + tr_details["token_out"] = logs[i + 2].split(":")[-1].strip() + + if "source_token_change" in log_entry: + parts = log_entry.split(", ") + for part in parts: + if "source_token_change" in part: + tr_details["amount_in"] = float(part.split(":")[-1].strip()) / 10 ** 6 + elif "destination_token_change" in part: + tr_details["amount_out"] = float(part.split(":")[-1].strip()) / 10 ** 6 + + if "before_source_balance" in log_entry: + before_source_balance = float(log_entry.split(":")[-1].strip()) / 10 ** 6 + if "source_token_change" in log_entry: + source_token_change = float(log_entry.split(":")[-1].strip()) / 10 ** 6 + + if before_source_balance > 0 and source_token_change > 0: + tr_details["percentage_swapped"] = (source_token_change / before_source_balance) * 100 + if tr_details["percentage_swapped"] > 100: + tr_details["percentage_swapped"] /= 1000 + + return tr_details + + @staticmethod + async def process_log(log_result: Dict[str, Any]) -> Dict[str, Any]: + """Process a single log entry.""" + if log_result['value']['err']: + return + + logs = log_result['value']['logs'] + swap_operations = [ + 'Program log: Instruction: Swap', + 'Program log: Instruction: Swap2', + 'Program log: Instruction: SwapExactAmountIn', + 'Program log: Instruction: SwapV2' + ] + + try: + if not any(op in logs for op in swap_operations): + return + + LogProcessor.save_log(log_result) + tx_signature = log_result['value']['signature'] + tr_details = LogProcessor.extract_transaction_details(logs) + + if not all([tr_details["token_in"], tr_details["token_out"], + tr_details["amount_in"], tr_details["amount_out"]]): + tr_details = await SAPI.get_transaction_details_info(tx_signature, logs) + + # Update token information + token_in = SAPI.dex.TOKENS_INFO[tr_details["token_in"]] + token_out = SAPI.dex.TOKENS_INFO[tr_details["token_out"]] + + tr_details.update({ + "symbol_in": token_in.get('symbol'), + "symbol_out": token_out.get('symbol'), + "amount_in_USD": tr_details['amount_in'] * token_in.get('price', 0), + "amount_out_USD": tr_details['amount_out'] * token_out.get('price', 0) + }) + + # Send notification + message = ( + f"Swap detected: \n" + f"{tr_details['amount_in_USD']:.2f} worth of {tr_details['symbol_in']} " + f"({tr_details['percentage_swapped']:.2f}% ) swapped for {tr_details['symbol_out']}" + ) + await telegram_utils.send_telegram_message(message) + + # Follow up actions + await SAPI.follow_move(tr_details) + await SAPI.save_token_info() + + except Exception as e: + logger.error(f"Error processing log: {e}") + await telegram_utils.send_telegram_message("Not followed! Error following move.") + + return tr_details + +class Bot: + @staticmethod + async def initialize(): + """Initialize bot and start monitoring.""" + await telegram_utils.initialize() + await telegram_utils.send_telegram_message("Solana Agent Started. Connecting to mainnet...") + + asyncio.create_task(watch_for_new_logs()) + if DO_WATCH_WALLET: + asyncio.create_task(SAPI.wallet_watch_loop()) + +async def start_server(): + """Run the ASGI server.""" + config = uvicorn.Config( "app:asgi_app", host="0.0.0.0", port=3001, log_level="info", reload=True ) + server = uvicorn.Server(config) + await server.serve() -def main(): - # Start log processor worker - log_processor_thread = threading.Thread(target=log_processor_worker, daemon=True) - log_processor_thread.start() +async def main(): + """Main application entry point.""" + # Initialize app and create ASGI wrapper + app = await init_app() + global asgi_app + asgi_app = WsgiToAsgi(app) # Initialize bot - init_bot() - - # Start server in a separate process - server_process = Process(target=run_server) - server_process.start() + await Bot.initialize() + # Start server try: - # Keep main process running - while True: - time.sleep(1) + await start_server() except KeyboardInterrupt: logger.info("Shutting down...") - finally: - server_process.terminate() - server_process.join() + await teardown_app() if __name__ == '__main__': - main() \ No newline at end of file + try: + asyncio.run(main()) + except KeyboardInterrupt: + logger.info("Application terminated by user") \ No newline at end of file diff --git a/crypto/sol/modules/SolanaAPI.py b/crypto/sol/modules/SolanaAPI.py index 59e57f0..cbce708 100644 --- a/crypto/sol/modules/SolanaAPI.py +++ b/crypto/sol/modules/SolanaAPI.py @@ -41,6 +41,7 @@ from solders import message from jupiter_python_sdk.jupiter import Jupiter import asyncio +import contextlib import json import logging import random @@ -99,6 +100,8 @@ class SolanaWS: async def connect(self): while True: + if self.websocket is None or self.websocket.closed: + await self.connect() try: current_url = random.choice(SOLANA_ENDPOINTS) self.websocket = await websockets.connect(current_url, ping_interval=30, ping_timeout=10) @@ -262,8 +265,13 @@ class SolanaAPI: async def process_messages(self, solana_ws): while True: - message = await solana_ws.message_queue.get() - await self.process_transaction(message) + try: + message = await solana_ws.message_queue.get() + await self.process_transaction(message) + except asyncio.CancelledError: + break + except Exception as e: + logger.error(f"Error processing message: {e}") _first_subscription = True @@ -1045,27 +1053,31 @@ class SolanaDEX: base_url = "https://api.coingecko.com/api/v3/simple/token_price/solana" prices = {} - async def fetch_single_price(session, address): + async def fetch_single_price(session, address, retries=3, backoff_factor=0.5): params = { "contract_addresses": address, "vs_currencies": self.DISPLAY_CURRENCY.lower() } - try: - async with session.get(base_url, params=params) as response: - if response.status == 200: - data = await response.json() - if address in data and self.DISPLAY_CURRENCY.lower() in data[address]: - return address, data[address][self.DISPLAY_CURRENCY.lower()] - else: - logging.warning(f"Failed to get price for {address} from CoinGecko. Status: {response.status}") - except Exception as e: - logging.error(f"Error fetching price for {address} from CoinGecko: {str(e)}") + for attempt in range(retries): + try: + async with session.get(base_url, params=params) as response: + if response.status == 200: + data = await response.json() + if address in data and self.DISPLAY_CURRENCY.lower() in data[address]: + return address, data[address][self.DISPLAY_CURRENCY.lower()] + elif response.status == 429: + logging.warning(f"Rate limit exceeded for {address}. Retrying...") + await asyncio.sleep(backoff_factor * (2 ** attempt)) + else: + logging.warning(f"Failed to get price for {address} from CoinGecko. Status: {response.status}") + except Exception as e: + logging.error(f"Error fetching price for {address} from CoinGecko: {str(e)}") return address, None async with aiohttp.ClientSession() as session: tasks = [fetch_single_price(session, address) for address in token_addresses] results = await asyncio.gather(*tasks) - + for address, price in results: if price is not None: prices[address] = price @@ -1169,7 +1181,11 @@ class SolanaDEX: async def get_wallet_balances(self, wallet_address, doGetTokenName=True): balances = {} + if not asyncio.get_event_loop().is_running(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) logging.info(f"Getting balances for wallet: {wallet_address}") + response = None try: response = await self.solana_client.get_token_accounts_by_owner_json_parsed( Pubkey.from_string(wallet_address), @@ -1189,16 +1205,17 @@ class SolanaDEX: mint = info['mint'] decimals = int(info['tokenAmount']['decimals']) amount = int(info['tokenAmount']['amount']) - amount = float(amount /10**decimals) + amount = int(amount) if amount > 1: + amount = float(amount / 10**decimals) if mint in self.TOKENS_INFO: token_name = self.TOKENS_INFO[mint].get('symbol') elif doGetTokenName: token_name = await self.get_token_metadata_symbol(mint) or 'N/A' self.TOKENS_INFO[mint] = {'symbol': token_name} await asyncio.sleep(2) - - self.TOKENS_INFO[mint]['holdedAmount'] = round(amount,decimals) + + self.TOKENS_INFO[mint]['holdedAmount'] = round(amount, decimals) self.TOKENS_INFO[mint]['decimals'] = decimals balances[mint] = { 'name': token_name or 'N/A', @@ -1227,7 +1244,10 @@ class SolanaDEX: except Exception as e: logging.error(f"Error getting wallet balances: {str(e)}") - logging.info(f"Found {len(response.value)} ({len(balances)} non zero) token accounts for wallet: {wallet_address}") + if response and response.value: + logging.info(f"Found {len(response.value)} ({len(balances)} non zero) token accounts for wallet: {wallet_address}") + else: + logging.warning(f"No token accounts found for wallet: {wallet_address}") return balances async def convert_balances_to_currency(self, balances, sol_price): diff --git a/crypto/sol/modules/log_processor.py b/crypto/sol/modules/log_processor.py index 63af635..b23a90a 100644 --- a/crypto/sol/modules/log_processor.py +++ b/crypto/sol/modules/log_processor.py @@ -1,7 +1,7 @@ import os import asyncio from pathlib import Path -from .storage import store_transaction, prisma_client +from .storage import Storage from .SolanaAPI import SolanaAPI LOG_DIRECTORY = "./logs" @@ -28,7 +28,7 @@ async def process_log_file(file_path): transaction_data = await solana_api.process_wh(data) # Check if the transaction already exists - existing_transaction = await prisma_client.transaction.find_first( + existing_transaction = await Storage.get_prisma().transaction.find_first( where={'solana_signature': solana_signature} ) @@ -46,7 +46,8 @@ async def process_log_file(file_path): 'solana_signature': solana_signature, 'details': details } - await store_transaction(transaction_data) + storage = Storage() + await storage.store_transaction(transaction_data) # Rename the file to append '_saved' new_file_path = file_path.with_name(file_path.stem + "_saved" + file_path.suffix) diff --git a/crypto/sol/modules/storage.py b/crypto/sol/modules/storage.py index 0c6d351..09a2af3 100644 --- a/crypto/sol/modules/storage.py +++ b/crypto/sol/modules/storage.py @@ -1,8 +1,10 @@ -import sys -import os -sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - from typing import NamedTuple +from enum import Enum +from datetime import datetime +import json +from prisma import Prisma +from typing import List, Optional, Dict +import asyncio class Transaction(NamedTuple): wallet: str @@ -14,218 +16,243 @@ class Transaction(NamedTuple): amount_out: float value_out_usd: float tx_signature: str - -from enum import Enum -from datetime import datetime -from enum import Enum -import json -from prisma import Prisma +class TransactionType(Enum): + BUY = "BUY" + SELL = "SELL" class TransactionStatus(Enum): PENDING = "PENDING" SENT = "SENT" CONFIRMED = "CONFIRMED" -# Initialize the Prisma client -prisma_client = Prisma() -async def init_db(): - await prisma_client.connect() - -async def store_transaction(transaction: Transaction): - """ - Store a transaction record in the database. - """ - await prisma_client.transaction.create( - data={ - 'wallet_id': transaction.wallet, - 'timestamp': datetime.now().isoformat(), - 'type': transaction.transaction_type, - 'sell_currency': transaction.symbol_in, - 'sell_amount': transaction.amount_in, - 'sell_value': transaction.value_in_usd, - 'buy_currency': transaction.symbol_out, - 'buy_amount': transaction.amount_out, - 'buy_value': transaction.value_out_usd, - 'solana_signature': transaction.tx_signature, - 'details': json.dumps({}), - 'status': TransactionStatus.PENDING.value - } - ) - -async def update_holdings(wallet_id, currency, amount_change): - holding = await prisma_client.holding.find_first( - where={ - 'wallet_id': wallet_id, - 'currency': currency - } - ) - if holding: - new_amount = holding.amount + amount_change - await prisma_client.holding.update( - where={'id': holding.id}, - data={ - 'amount': new_amount, - 'last_updated': datetime.now().isoformat() - } - ) - else: - await prisma_client.holding.create( - data={ - 'wallet_id': wallet_id, - 'currency': currency, - 'amount': amount_change, - 'last_updated': datetime.now().isoformat() - } - ) - -async def get_wallet_holdings(wallet_id): - return await prisma_client.holding.find_many( - where={'wallet_id': wallet_id}, - select={'currency': True, 'amount': True} - ) - -async def get_transaction_history(wallet_id, start_date=None, end_date=None, include_closed=False): - filters = {'wallet_id': wallet_id} - if not include_closed: - filters['closed'] = False - if start_date: - filters['timestamp'] = {'gte': start_date} - if end_date: - filters['timestamp'] = {'lte': end_date} +class Storage: + _instance: Optional['Storage'] = None + _lock = asyncio.Lock() + _initialized = False + prisma = Prisma() # Class-level Prisma instance - return await prisma_client.transaction.find_many( - where=filters, - order={'timestamp': 'desc'} - ) + STABLECOINS = ['USDC', 'USDT', 'SOL'] -async def close_transaction(transaction_id): - await prisma_client.transaction.update( - where={'id': transaction_id}, - data={'closed': True} - ) + def __new__(cls): + if cls._instance is None: + cls._instance = super(Storage, cls).__new__(cls) + return cls._instance -async def get_open_transactions(wallet_id, currency): - return await prisma_client.transaction.find_many( - where={ - 'wallet_id': wallet_id, - 'buy_currency': currency, - 'closed': False - }, - order={'timestamp': 'asc'} - ) + async def __ainit__(self): + if self._initialized: + return + + await self.prisma.connect() + self._initialized = True -async def calculate_current_holdings(wallet_id): - transactions = await prisma_client.transaction.group_by( - by=['buy_currency'], - where={'wallet_id': wallet_id, 'closed': False}, - _sum={'buy_amount': True, 'sell_amount': True} - ) - return [ - { - 'currency': t.buy_currency, - 'amount': t._sum.buy_amount - (t._sum.sell_amount or 0) + @classmethod + async def get_instance(cls) -> 'Storage': + if not cls._instance: + async with cls._lock: + if not cls._instance: + cls._instance = cls() + await cls._instance.__ainit__() + return cls._instance + @classmethod + def get_prisma(cls): + return cls.prisma + async def disconnect(self): + if self._initialized: + await self.prisma.disconnect() + self._initialized = False + + def __init__(self): + self.prisma = Prisma() + self.users = { + "db": {"id": 1, "username": "db", "email": "user1@example.com", "password": "db"}, + "popov": {"id": 2, "username": "popov", "email": "user2@example.com", "password": "popov"} } - for t in transactions if t._sum.buy_amount > (t._sum.sell_amount or 0) - ] -STABLECOINS = ['USDC', 'USDT', 'SOL'] + def is_connected(self): + return self.prisma.is_connected() -async def is_transaction_closed(wallet_id, transaction_id): - transaction = await prisma_client.transaction.find_unique( - where={'id': transaction_id} - ) - if transaction: - sold_amount = await prisma_client.transaction.aggregate( - _sum={'sell_amount': True}, + async def connect(self): + await self.prisma.connect() + + async def disconnect(self): + await self.prisma.disconnect() + + async def store_transaction(self, transaction: Transaction): + return await self.prisma.transaction.create( + data={ + 'wallet_id': transaction.wallet, + 'timestamp': datetime.now().isoformat(), + 'type': transaction.transaction_type, + 'sell_currency': transaction.symbol_in, + 'sell_amount': transaction.amount_in, + 'sell_value': transaction.value_in_usd, + 'buy_currency': transaction.symbol_out, + 'buy_amount': transaction.amount_out, + 'buy_value': transaction.value_out_usd, + 'solana_signature': transaction.tx_signature, + 'details': json.dumps({}), + 'status': TransactionStatus.PENDING.value + } + ) + + async def update_holdings(self, wallet_id: str, currency: str, amount_change: float): + holding = await self.prisma.holding.find_first( where={ 'wallet_id': wallet_id, - 'sell_currency': transaction.buy_currency, - 'timestamp': {'gt': transaction.timestamp} + 'currency': currency } ) - return sold_amount._sum.sell_amount >= transaction.buy_amount - return False + if holding: + new_amount = holding.amount + amount_change + return await self.prisma.holding.update( + where={'id': holding.id}, + data={ + 'amount': new_amount, + 'last_updated': datetime.now().isoformat() + } + ) + else: + return await self.prisma.holding.create( + data={ + 'wallet_id': wallet_id, + 'currency': currency, + 'amount': amount_change, + 'last_updated': datetime.now().isoformat() + } + ) -async def close_completed_transactions(wallet_id): - transactions = await prisma_client.transaction.find_many( - where={ - 'wallet_id': wallet_id, - 'closed': False, - 'buy_currency': {'notIn': STABLECOINS} - } - ) - for transaction in transactions: - if await is_transaction_closed(wallet_id, transaction.id): - await close_transaction(transaction.id) + async def get_wallet_holdings(self, wallet_id: str): + return await self.prisma.holding.find_many( + where={'wallet_id': wallet_id}, + select={'currency': True, 'amount': True} + ) -async def get_profit_loss(wallet_id, currency, start_date=None, end_date=None): - filters = { - 'wallet_id': wallet_id, - 'OR': [ - {'sell_currency': currency}, - {'buy_currency': currency} + async def get_transaction_history(self, wallet_id: str, start_date: Optional[str] = None, + end_date: Optional[str] = None, include_closed: bool = False): + filters = {'wallet_id': wallet_id} + if not include_closed: + filters['closed'] = False + if start_date: + filters['timestamp'] = {'gte': start_date} + if end_date: + filters['timestamp'] = {'lte': end_date} + + return await self.prisma.transaction.find_many( + where=filters, + order={'timestamp': 'desc'} + ) + + async def close_transaction(self, transaction_id: int): + return await self.prisma.transaction.update( + where={'id': transaction_id}, + data={'closed': True} + ) + + async def get_open_transactions(self, wallet_id: str, currency: str): + return await self.prisma.transaction.find_many( + where={ + 'wallet_id': wallet_id, + 'buy_currency': currency, + 'closed': False + }, + order={'timestamp': 'asc'} + ) + + async def calculate_current_holdings(self, wallet_id: str): + transactions = await self.prisma.transaction.group_by( + by=['buy_currency'], + where={'wallet_id': wallet_id, 'closed': False}, + _sum={'buy_amount': True, 'sell_amount': True} + ) + return [ + { + 'currency': t.buy_currency, + 'amount': t._sum.buy_amount - (t._sum.sell_amount or 0) + } + for t in transactions if t._sum.buy_amount > (t._sum.sell_amount or 0) ] - } - if start_date: - filters['timestamp'] = {'gte': start_date} - if end_date: - filters['timestamp'] = {'lte': end_date} - - result = await prisma_client.transaction.aggregate( - _sum={ - 'sell_value': True, - 'buy_value': True - }, - where=filters - ) - return (result._sum.sell_value or 0) - (result._sum.buy_value or 0) -# # # # # # USERS + async def is_transaction_closed(self, wallet_id: str, transaction_id: int) -> bool: + transaction = await self.prisma.transaction.find_unique( + where={'id': transaction_id} + ) + if transaction: + sold_amount = await self.prisma.transaction.aggregate( + _sum={'sell_amount': True}, + where={ + 'wallet_id': wallet_id, + 'sell_currency': transaction.buy_currency, + 'timestamp': {'gt': transaction.timestamp} + } + ) + return sold_amount._sum.sell_amount >= transaction.buy_amount + return False -# For this example, we'll use a simple dictionary to store users -users = { - "db": {"id": 1, "username": "db", "email": "user1@example.com", "password": "db"}, - "popov": {"id": 2, "username": "popov", "email": "user2@example.com", "password": "popov"} -} + async def close_completed_transactions(self, wallet_id: str): + transactions = await self.prisma.transaction.find_many( + where={ + 'wallet_id': wallet_id, + 'closed': False, + 'buy_currency': {'notIn': self.STABLECOINS} + } + ) + for transaction in transactions: + if await self.is_transaction_closed(wallet_id, transaction.id): + await self.close_transaction(transaction.id) -def get_or_create_user(email, google_id): - user = next((u for u in users.values() if u['email'] == email), None) - if not user: - user_id = max(u['id'] for u in users.values()) + 1 - username = email.split('@')[0] # Use the part before @ as username - user = { - 'id': user_id, - 'username': username, - 'email': email, - 'google_id': google_id + async def get_profit_loss(self, wallet_id: str, currency: str, + start_date: Optional[str] = None, end_date: Optional[str] = None): + filters = { + 'wallet_id': wallet_id, + 'OR': [ + {'sell_currency': currency}, + {'buy_currency': currency} + ] } - users[username] = user - return user + if start_date: + filters['timestamp'] = {'gte': start_date} + if end_date: + filters['timestamp'] = {'lte': end_date} + + result = await self.prisma.transaction.aggregate( + _sum={ + 'sell_value': True, + 'buy_value': True + }, + where=filters + ) + return (result._sum.sell_value or 0) - (result._sum.buy_value or 0) -def authenticate_user(username, password): - """ - Authenticate a user based on username and password. - Returns user data if authentication is successful, None otherwise. - """ - user = users.get(username) - if user and user['password'] == password: - return {"id": user['id'], "username": user['username'], "email": user['email']} - return None + # User management methods + def get_or_create_user(self, email: str, google_id: str): + user = next((u for u in self.users.values() if u['email'] == email), None) + if not user: + user_id = max(u['id'] for u in self.users.values()) + 1 + username = email.split('@')[0] + user = { + 'id': user_id, + 'username': username, + 'email': email, + 'google_id': google_id + } + self.users[username] = user + return user -def get_user_by_id(user_id): - """ - Retrieve a user by their ID. - """ - for user in users.values(): - if user['id'] == int(user_id): + def authenticate_user(self, username: str, password: str): + user = self.users.get(username) + if user and user['password'] == password: return {"id": user['id'], "username": user['username'], "email": user['email']} - return None + return None -def store_api_key(user_id, api_key): - """ - Store the generated API key for a user. - """ - # In a real application, you would store this in a database - # For this example, we'll just print it - print(f"Storing API key {api_key} for user {user_id}") + def get_user_by_id(self, user_id: int): + for user in self.users.values(): + if user['id'] == int(user_id): + return {"id": user['id'], "username": user['username'], "email": user['email']} + return None + + def store_api_key(self, user_id: int, api_key: str): + print(f"Storing API key {api_key} for user {user_id}") + + +storage = Storage() diff --git a/crypto/sol/modules/webui.py b/crypto/sol/modules/webui.py index 01b1c05..53a8c60 100644 --- a/crypto/sol/modules/webui.py +++ b/crypto/sol/modules/webui.py @@ -14,13 +14,14 @@ from config import LIQUIDITY_TOKENS, YOUR_WALLET from modules import storage, utils, SolanaAPI from modules.utils import async_safe_call, decode_instruction_data +from modules.storage import Storage import os import logging from datetime import datetime on_transaction = None -def init_app(tr_handler=None): +async def init_app(tr_handler=None): global on_transaction on_transaction = tr_handler app = Flask(__name__, template_folder='../templates', static_folder='../static') @@ -29,6 +30,15 @@ def init_app(tr_handler=None): executor = ThreadPoolExecutor(max_workers=10) # Adjust the number of workers as needed login_manager = LoginManager(app) login_manager.login_view = 'login' + + storage = Storage() + + # Ensure database connection + async def ensure_storage_connection(): + if not storage.is_connected(): + await storage.connect() + + asyncio.run(ensure_storage_connection()) # oauth = OAuth(app) # google = oauth.remote_app( @@ -481,6 +491,9 @@ def init_app(tr_handler=None): return app +def teardown_app(): + # Close the database connection + storage.disconnect() # Function to find the latest log file def get_latest_log_file(wh:bool): @@ -499,4 +512,4 @@ def get_latest_log_file(wh:bool): utils.log.error(f"Error fetching latest log file: {e}") return None -export = init_app +export = init_app, teardown_app