diff --git a/backend/btrixcloud/storages.py b/backend/btrixcloud/storages.py index 3a182113..a1540694 100644 --- a/backend/btrixcloud/storages.py +++ b/backend/btrixcloud/storages.py @@ -11,7 +11,7 @@ from typing import ( TYPE_CHECKING, ) from urllib.parse import urlsplit -from contextlib import asynccontextmanager +from contextlib import asynccontextmanager, contextmanager import asyncio import heapq @@ -275,10 +275,11 @@ class StorageOps: ) as client: yield client, bucket, key - def get_sync_s3_client( - self, storage: S3Storage, use_access=False - ) -> tuple[S3Client, str, str, str]: + @contextmanager + def get_sync_client(self, org: Organization) -> Iterator[tuple[S3Client, str, str]]: """context manager for s3 client""" + storage = self.get_org_primary_storage(org) + endpoint_url = storage.endpoint_url if not endpoint_url.endswith("/"): @@ -289,19 +290,17 @@ class StorageOps: endpoint_url = parts.scheme + "://" + parts.netloc - client = boto3.client( - "s3", - region_name=storage.region, - endpoint_url=endpoint_url, - aws_access_key_id=storage.access_key, - aws_secret_access_key=storage.secret_key, - ) - - public_endpoint_url = ( - storage.endpoint_url if not use_access else storage.access_endpoint_url - ) - - return client, bucket, key, public_endpoint_url + try: + client = boto3.client( + "s3", + region_name=storage.region, + endpoint_url=endpoint_url, + aws_access_key_id=storage.access_key, + aws_secret_access_key=storage.secret_key, + ) + yield client, bucket, key + finally: + client.close() async def verify_storage_upload(self, storage: S3Storage, filename: str) -> None: """Test credentials and storage endpoint by uploading an empty test file""" @@ -367,14 +366,6 @@ class StorageOps: await client.put_object(Bucket=bucket, Key=key, Body=data) - def get_sync_client( - self, org: Organization, use_access=False - ) -> tuple[S3Client, str, str, str]: - """get sync client""" - s3storage = self.get_org_primary_storage(org) - - return self.get_sync_s3_client(s3storage, use_access=use_access) - # pylint: disable=too-many-arguments,too-many-locals async def do_upload_multipart( self, @@ -525,22 +516,21 @@ class StorageOps: contexts: List[str], ) -> Iterator[bytes]: """Return filtered stream of logs from specified WACZs sorted by timestamp""" - client, bucket, key, _ = self.get_sync_client(org) + with self.get_sync_client(org) as (client, bucket, key): + loop = asyncio.get_event_loop() - loop = asyncio.get_event_loop() + resp = await loop.run_in_executor( + None, + self._sync_get_logs, + wacz_files, + log_levels, + contexts, + client, + bucket, + key, + ) - resp = await loop.run_in_executor( - None, - self._sync_get_logs, - wacz_files, - log_levels, - contexts, - client, - bucket, - key, - ) - - return resp + return resp def _sync_get_logs( self, @@ -674,15 +664,14 @@ class StorageOps: ) -> Iterator[bytes]: """return an iter for downloading a stream nested wacz file from list of files""" - client, bucket, key, _ = self.get_sync_client(org) + with self.get_sync_client(org) as (client, bucket, key): + loop = asyncio.get_event_loop() - loop = asyncio.get_event_loop() + resp = await loop.run_in_executor( + None, self._sync_dl, files, client, bucket, key + ) - resp = await loop.run_in_executor( - None, self._sync_dl, files, client, bucket, key - ) - - return resp + return resp # ============================================================================