From 2e73148bea6886d716cacefc6e0359064d2c1fdf Mon Sep 17 00:00:00 2001
From: Ilya Kreymer
Date: Mon, 14 Aug 2023 18:29:28 -0700
Subject: [PATCH] fix redis connection leaks + exclusions error: (fixes #1065) (#1066)

* fix redis connection leaks + exclusions error: (fixes #1065)
- use contextmanager for accessing redis to ensure redis.close() is always called
- add get_redis_client() to k8sapi to ensure unified place to get redis client
- use connectionpool.from_url() until redis 5.0.0 is released to ensure auto close and single client settings are applied
- also: catch invalid regex passed to re.compile() in queue regex check, return 400 instead of 500 for invalid regex
- redis requirements: bump to 5.0.0rc2
---
 backend/btrixcloud/basecrawls.py |  17 ++--
 backend/btrixcloud/crawls.py     | 153 ++++++++++++++++---------------
 backend/btrixcloud/k8sapi.py     |  14 +++
 backend/btrixcloud/operator.py   |  34 ++++---
 backend/requirements.txt         |   2 +-
 5 files changed, 128 insertions(+), 92 deletions(-)

diff --git a/backend/btrixcloud/basecrawls.py b/backend/btrixcloud/basecrawls.py
index cd1c344f..73e1cc34 100644
--- a/backend/btrixcloud/basecrawls.py
+++ b/backend/btrixcloud/basecrawls.py
@@ -6,10 +6,11 @@ import os
 from datetime import timedelta
 from typing import Optional, List, Union
 import urllib.parse
+import contextlib
 
 from pydantic import UUID4
 from fastapi import HTTPException, Depends
-from redis import asyncio as aioredis, exceptions
+from redis import exceptions
 
 from .models import (
     CrawlFile,
@@ -216,8 +217,8 @@ class BaseCrawlOps:
         # more responsive, saves db update in operator
         if crawl.state in RUNNING_STATES:
             try:
-                redis = await self.get_redis(crawl.id)
-                crawl.stats = await get_redis_crawl_stats(redis, crawl.id)
+                async with self.get_redis(crawl.id) as redis:
+                    crawl.stats = await get_redis_crawl_stats(redis, crawl.id)
             # redis not available, ignore
             except exceptions.ConnectionError:
                 pass
@@ -281,13 +282,17 @@ class BaseCrawlOps:
         for update in updates:
             await self.crawls.find_one_and_update(*update)
 
+    @contextlib.asynccontextmanager
     async def get_redis(self, crawl_id):
         """get redis url for crawl id"""
         redis_url = self.crawl_manager.get_redis_url(crawl_id)
 
-        return await aioredis.from_url(
-            redis_url, encoding="utf-8", decode_responses=True
-        )
+        redis = await self.crawl_manager.get_redis_client(redis_url)
+
+        try:
+            yield redis
+        finally:
+            await redis.close()
 
     async def add_to_collection(
         self, crawl_ids: List[uuid.UUID], collection_id: uuid.UUID, org: Organization
diff --git a/backend/btrixcloud/crawls.py b/backend/btrixcloud/crawls.py
index 2ed4c5a7..80db4f3f 100644
--- a/backend/btrixcloud/crawls.py
+++ b/backend/btrixcloud/crawls.py
@@ -363,23 +363,26 @@ class CrawlOps(BaseCrawlOps):
         total = 0
         results = []
-        redis = None
 
         try:
-            redis = await self.get_redis(crawl_id)
+            async with self.get_redis(crawl_id) as redis:
+                total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
+                results = await self._crawl_queue_range(
+                    redis, f"{crawl_id}:q", offset, count
+                )
+                results = [json.loads(result)["url"] for result in results]
 
-            total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
-            results = await self._crawl_queue_range(
-                redis, f"{crawl_id}:q", offset, count
-            )
-            results = [json.loads(result)["url"] for result in results]
         except exceptions.ConnectionError:
             # can't connect to redis, likely not initialized yet
             pass
 
         matched = []
         if regex:
-            regex = re.compile(regex)
+            try:
+                regex = re.compile(regex)
+            except re.error as exc:
+                raise HTTPException(status_code=400, detail="invalid_regex") from exc
+
             matched = [result for result in results if regex.search(result)]
 
         return {"total": total, "results": results, "matched": matched}
 
@@ -387,25 +390,29 @@ class CrawlOps(BaseCrawlOps):
     async def match_crawl_queue(self, crawl_id, regex):
         """get list of urls that match regex"""
         total = 0
-        redis = None
-
-        try:
-            redis = await self.get_redis(crawl_id)
-            total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
-        except exceptions.ConnectionError:
-            # can't connect to redis, likely not initialized yet
-            pass
-
-        regex = re.compile(regex)
         matched = []
         step = 50
 
-        for count in range(0, total, step):
-            results = await self._crawl_queue_range(redis, f"{crawl_id}:q", count, step)
-            for result in results:
-                url = json.loads(result)["url"]
-                if regex.search(url):
-                    matched.append(url)
+        async with self.get_redis(crawl_id) as redis:
+            try:
+                total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
+            except exceptions.ConnectionError:
+                # can't connect to redis, likely not initialized yet
+                pass
+
+            try:
+                regex = re.compile(regex)
+            except re.error as exc:
+                raise HTTPException(status_code=400, detail="invalid_regex") from exc
+
+            for count in range(0, total, step):
+                results = await self._crawl_queue_range(
+                    redis, f"{crawl_id}:q", count, step
+                )
+                for result in results:
+                    url = json.loads(result)["url"]
+                    if regex.search(url):
+                        matched.append(url)
 
         return {"total": total, "matched": matched}
 
@@ -413,56 +420,58 @@
         """filter out urls that match regex"""
         # pylint: disable=too-many-locals
         total = 0
-        redis = None
-
         q_key = f"{crawl_id}:q"
         s_key = f"{crawl_id}:s"
-
-        try:
-            redis = await self.get_redis(crawl_id)
-            total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
-        except exceptions.ConnectionError:
-            # can't connect to redis, likely not initialized yet
-            pass
-
-        dircount = -1
-        regex = re.compile(regex)
         step = 50
-
-        count = 0
         num_removed = 0
-        # pylint: disable=fixme
-        # todo: do this in a more efficient way?
-        # currently quite inefficient as redis does not have a way
-        # to atomically check and remove value from list
-        # so removing each jsob block by value
-        while count < total:
-            if dircount == -1 and count > total / 2:
-                dircount = 1
-            results = await self._crawl_queue_range(redis, q_key, count, step)
-            count += step
+        async with self.get_redis(crawl_id) as redis:
+            try:
+                total = await self._crawl_queue_len(redis, f"{crawl_id}:q")
+            except exceptions.ConnectionError:
+                # can't connect to redis, likely not initialized yet
+                pass
 
-            qrems = []
-            srems = []
+            dircount = -1
 
-            for result in results:
-                url = json.loads(result)["url"]
-                if regex.search(url):
-                    srems.append(url)
-                    # await redis.srem(s_key, url)
-                    # res = await self._crawl_queue_rem(redis, q_key, result, dircount)
-                    qrems.append(result)
+            try:
+                regex = re.compile(regex)
+            except re.error as exc:
+                raise HTTPException(status_code=400, detail="invalid_regex") from exc
 
-            if not srems:
-                continue
+            count = 0
 
-            await redis.srem(s_key, *srems)
-            res = await self._crawl_queue_rem(redis, q_key, qrems, dircount)
-            if res:
-                count -= res
-                num_removed += res
-                print(f"Removed {res} from queue", flush=True)
+            # pylint: disable=fixme
+            # todo: do this in a more efficient way?
+            # currently quite inefficient as redis does not have a way
+            # to atomically check and remove value from list
+            # so removing each jsob block by value
+            while count < total:
+                if dircount == -1 and count > total / 2:
+                    dircount = 1
+                results = await self._crawl_queue_range(redis, q_key, count, step)
+                count += step
+
+                qrems = []
+                srems = []
+
+                for result in results:
+                    url = json.loads(result)["url"]
+                    if regex.search(url):
+                        srems.append(url)
+                        # await redis.srem(s_key, url)
+                        # res = await self._crawl_queue_rem(redis, q_key, result, dircount)
+                        qrems.append(result)
+
+                if not srems:
+                    continue
+
+                await redis.srem(s_key, *srems)
+                res = await self._crawl_queue_rem(redis, q_key, qrems, dircount)
+                if res:
+                    count -= res
+                    num_removed += res
+                    print(f"Removed {res} from queue", flush=True)
 
         return num_removed
 
@@ -475,13 +484,13 @@ class CrawlOps(BaseCrawlOps):
         skip = page * page_size
         upper_bound = skip + page_size - 1
 
-        try:
-            redis = await self.get_redis(crawl_id)
-            errors = await redis.lrange(f"{crawl_id}:e", skip, upper_bound)
-            total = await redis.llen(f"{crawl_id}:e")
-        except exceptions.ConnectionError:
-            # pylint: disable=raise-missing-from
-            raise HTTPException(status_code=503, detail="redis_connection_error")
+        async with self.get_redis(crawl_id) as redis:
+            try:
+                errors = await redis.lrange(f"{crawl_id}:e", skip, upper_bound)
+                total = await redis.llen(f"{crawl_id}:e")
+            except exceptions.ConnectionError:
+                # pylint: disable=raise-missing-from
+                raise HTTPException(status_code=503, detail="redis_connection_error")
 
         parsed_errors = parse_jsonl_error_messages(errors)
         return parsed_errors, total
diff --git a/backend/btrixcloud/k8sapi.py b/backend/btrixcloud/k8sapi.py
index 414ee62a..705c9b11 100644
--- a/backend/btrixcloud/k8sapi.py
+++ b/backend/btrixcloud/k8sapi.py
@@ -13,6 +13,9 @@ from kubernetes_asyncio.client.api import custom_objects_api
 from kubernetes_asyncio.utils import create_from_dict
 from kubernetes_asyncio.client.exceptions import ApiException
 
+from redis.asyncio import Redis
+from redis.asyncio.connection import ConnectionPool
+
 from fastapi.templating import Jinja2Templates
 
 from .utils import get_templates_dir, dt_now, to_k8s_date
@@ -62,6 +65,17 @@ class K8sAPI:
         )
         return redis_url
 
+    async def get_redis_client(self, redis_url):
+        """return redis client with correct params for one-time use"""
+        # manual settings until redis 5.0.0 is released
+        pool = ConnectionPool.from_url(redis_url, decode_responses=True)
+        redis = Redis(
+            connection_pool=pool,
+            decode_responses=True,
+        )
+        redis.auto_close_connection_pool = True
+        return redis
+
     # pylint: disable=too-many-arguments
     async def new_crawl_job(
         self, cid, userid, oid, scale=1, crawl_timeout=0, manual=True
diff --git a/backend/btrixcloud/operator.py b/backend/btrixcloud/operator.py
index 71b0da13..03dccd32 100644
--- a/backend/btrixcloud/operator.py
+++ b/backend/btrixcloud/operator.py
@@ -13,7 +13,6 @@ import yaml
 import humanize
 
 from pydantic import BaseModel
-from redis import asyncio as aioredis
 
 from .utils import (
     from_k8s_date,
@@ -430,6 +429,7 @@ class BtrixOperator(K8sAPI):
     async def cancel_crawl(self, redis_url, crawl_id, cid, status, state):
         """immediately cancel crawl with specified state
         return true if db mark_finished update succeeds"""
+        redis = None
         try:
             redis = await self._get_redis(redis_url)
             await self.mark_finished(redis, crawl_id, uuid.UUID(cid), status, state)
@@ -438,6 +438,10 @@ class BtrixOperator(K8sAPI):
         except:
             return False
 
+        finally:
+            if redis:
+                await redis.close()
+
     def _done_response(self, status, finalized=False):
         """done response for removing crawl"""
         return {
@@ -462,15 +466,16 @@
         """init redis, ensure connectivity"""
         redis = None
         try:
-            redis = await aioredis.from_url(
-                redis_url, encoding="utf-8", decode_responses=True
-            )
+            redis = await self.get_redis_client(redis_url)
             # test connection
             await redis.ping()
             return redis
 
         # pylint: disable=bare-except
         except:
+            if redis:
+                await redis.close()
+
             return None
 
     async def check_if_finished(self, crawl, status):
@@ -512,16 +517,16 @@
             status.resync_after = self.fast_retry_secs
             return status
 
-        # set state to running (if not already)
-        if status.state not in RUNNING_STATES:
-            await self.set_state(
-                "running",
-                status,
-                crawl.id,
-                allowed_from=["starting", "waiting_capacity"],
-            )
-
         try:
+            # set state to running (if not already)
+            if status.state not in RUNNING_STATES:
+                await self.set_state(
+                    "running",
+                    status,
+                    crawl.id,
+                    allowed_from=["starting", "waiting_capacity"],
+                )
+
             file_done = await redis.lpop(self.done_key)
 
             while file_done:
@@ -547,6 +552,9 @@
             print(f"Crawl get failed: {exc}, will try again")
             return status
 
+        finally:
+            await redis.close()
+
     def check_if_pods_running(self, pods):
         """check if at least one crawler pod has started"""
         try:
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 38a1ef60..46f718ef 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -5,7 +5,7 @@ loguru
 aiofiles
 kubernetes-asyncio==22.6.5
 aiobotocore
-redis>=4.2.0rc1
+redis>=5.0.0rc2
 pyyaml
 jinja2
 humanize
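
The sketch below is not part of the patch itself; it is a minimal illustration of the two patterns the commit combines: an asynccontextmanager that always closes the redis client it yields, built the same way as get_redis_client() (ConnectionPool.from_url() plus auto_close_connection_pool until redis 5.0.0 final), and converting re.error from a user-supplied exclusion regex into a 400 response instead of an unhandled 500. The standalone names open_redis, compile_user_regex, and queue_len, and taking redis_url directly as an argument, are assumptions made for illustration only.

import contextlib
import re

from fastapi import HTTPException
from redis.asyncio import Redis
from redis.asyncio.connection import ConnectionPool


@contextlib.asynccontextmanager
async def open_redis(redis_url):
    """yield a single-use redis client and always close it, even if the body raises"""
    # mirrors get_redis_client(): explicit pool + auto close until redis 5.0.0 is final
    pool = ConnectionPool.from_url(redis_url, decode_responses=True)
    redis = Redis(connection_pool=pool, decode_responses=True)
    redis.auto_close_connection_pool = True
    try:
        yield redis
    finally:
        await redis.close()


def compile_user_regex(regex):
    """turn an invalid user-supplied regex into a 400 instead of a 500"""
    try:
        return re.compile(regex)
    except re.error as exc:
        raise HTTPException(status_code=400, detail="invalid_regex") from exc


async def queue_len(redis_url, crawl_id):
    """example usage: the connection is released even if llen() raises"""
    async with open_redis(redis_url) as redis:
        return await redis.llen(f"{crawl_id}:q")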