# Source code for spikelab.batch_jobs.storage_s3

"""S3-compatible storage helpers for batch job artifacts."""

from __future__ import annotations

from pathlib import Path
from typing import Optional

try:
    import boto3
except ImportError:  # pragma: no cover
    boto3 = None  # type: ignore[assignment]

from ..data_loaders.s3_utils import parse_s3_url
from .models import StoragePathTemplates


class S3StorageClient:
    """Small wrapper around boto3 for upload/download URI handling.

    Path layout is controlled by *path_templates* (a
    :class:`StoragePathTemplates` instance loaded from the active profile).
    Each template is a ``str.format`` pattern rendered with the ``prefix``,
    ``run_id`` and ``filename`` keyword arguments.
    """

    # Default cap on the number of keys returned by ``list_output_files``.
    DEFAULT_LIST_OUTPUT_LIMIT = 10_000

    def __init__(
        self,
        *,
        prefix: Optional[str] = None,
        path_templates: Optional[StoragePathTemplates] = None,
        endpoint_url: Optional[str] = None,
        region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
    ) -> None:
        """Configure the storage client.

        Parameters:
            prefix (str | None): Base URI that templates are rendered
                against. ``None`` or ``""`` means "no prefix configured";
                a trailing ``/`` is appended when missing.
            path_templates (StoragePathTemplates | None): Path layout;
                defaults to ``StoragePathTemplates()``.
            endpoint_url, region_name, aws_access_key_id,
            aws_secret_access_key, aws_session_token (str | None):
                Forwarded verbatim to ``boto3.client("s3", ...)``.
        """
        # ``prefix`` normalisation: ``None`` or empty string stays
        # ``None`` (no bucket-level base configured); a non-empty
        # string gets a trailing ``/`` appended if missing so
        # downstream ``prefix + filename`` concatenation produces
        # a valid S3 URI. Spelt out as three branches instead of a
        # nested ternary for readability. The ``not prefix`` check
        # (rather than ``is None``) is intentional — empty string
        # is a documented synonym for "no prefix".
        if not prefix:
            self.prefix = None
        elif prefix.endswith("/"):
            self.prefix = prefix
        else:
            self.prefix = f"{prefix}/"
        self.endpoint_url = endpoint_url
        self.region_name = region_name
        self._templates = path_templates or StoragePathTemplates()
        # When boto3 is available, eagerly construct the client so
        # tests that patch ``spikelab.batch_jobs.storage_s3.boto3``
        # for the duration of the constructor get the patched client
        # (the original behaviour). When boto3 is None, defer the
        # ImportError until a method that actually needs the client
        # is called — this lets pure-string operations
        # (``build_uri``, ``output_prefix_for_run``,
        # ``logs_prefix_for_run``) succeed on hosts without the
        # optional dependency installed, e.g. ``cli._cmd_render`` →
        # ``_build_session`` → here.
        self._boto3_kwargs = {
            "endpoint_url": endpoint_url,
            "region_name": region_name,
            "aws_access_key_id": aws_access_key_id,
            "aws_secret_access_key": aws_secret_access_key,
            "aws_session_token": aws_session_token,
        }
        if boto3 is not None:
            self._client_instance = boto3.client("s3", **self._boto3_kwargs)
        else:
            self._client_instance = None

    @property
    def _client(self):
        """Return the boto3 S3 client, deferring the ImportError to first
        use when boto3 was not available at construction time.
        """
        if self._client_instance is None:
            if boto3 is None:
                raise ImportError("boto3 is required for S3 storage: pip install boto3")
            self._client_instance = boto3.client("s3", **self._boto3_kwargs)
        return self._client_instance

    def _require_prefix(self) -> str:
        """Return the configured prefix, or raise ``ValueError`` if unset."""
        if not self.prefix:
            raise ValueError(
                "S3 prefix is not configured. Set it in the profile or command."
            )
        return self.prefix

    def _run_prefix(self, category: str, run_id: str) -> str:
        """Render the *category* template with an empty filename.

        Returns ``""`` when no prefix is configured. The template
        attribute is only read when a prefix exists.
        """
        if not self.prefix:
            return ""
        template = getattr(self._templates, category)
        return template.format(prefix=self.prefix, run_id=run_id, filename="")

    def build_uri(self, *, run_id: str, filename: str, category: str = "inputs") -> str:
        """Build an S3 URI for a file using the active path templates.

        ``category`` should be one of the keys defined on
        ``StoragePathTemplates`` (``"inputs"``, ``"outputs"``, ``"logs"``).
        An unknown category silently falls back to the ``inputs`` template
        and emits a ``UserWarning`` so typos ("input", "logs/", etc.) don't
        quietly land in the wrong S3 prefix.

        Raises:
            ValueError: If no prefix is configured.
        """
        prefix = self._require_prefix()
        if not hasattr(self._templates, category):
            import warnings

            known = sorted(
                k for k in vars(self._templates).keys() if not k.startswith("_")
            )
            warnings.warn(
                f"build_uri: unknown category={category!r}; falling back "
                f"to the ``inputs`` template. Known categories: {known}. "
                "Check for typos.",
                UserWarning,
                stacklevel=2,
            )
        template = getattr(self._templates, category, self._templates.inputs)
        return template.format(prefix=prefix, run_id=run_id, filename=filename)

    def upload_file(self, *, local_path: str, s3_uri: str) -> str:
        """Upload a local file to S3 and return the URI.

        Raises ``FileNotFoundError`` if ``local_path`` does not exist rather
        than deferring to boto3's less informative error.
        """
        if not Path(local_path).is_file():
            raise FileNotFoundError(
                f"upload_file: local_path={local_path!r} does not exist "
                "or is not a regular file."
            )
        bucket, key = parse_s3_url(s3_uri)
        self._client.upload_file(local_path, bucket, key)
        return s3_uri

    def upload_bundle(self, *, local_zip: str, run_id: str) -> str:
        """Upload a zip bundle to S3 under the inputs path template."""
        filename = Path(local_zip).name
        uri = self.build_uri(run_id=run_id, filename=filename, category="inputs")
        return self.upload_file(local_path=local_zip, s3_uri=uri)

    def output_prefix_for_run(self, run_id: str) -> str:
        """Return the S3 prefix for a run's output files (``""`` if no prefix)."""
        return self._run_prefix("outputs", run_id)

    def logs_prefix_for_run(self, run_id: str) -> str:
        """Return the S3 prefix for a run's log files (``""`` if no prefix)."""
        return self._run_prefix("logs", run_id)

    def download_file(self, *, s3_uri: str, local_path: str) -> str:
        """Download a single file from S3.

        Parameters:
            s3_uri (str): Full ``s3://bucket/key`` URI.
            local_path (str): Destination path on disk; missing parent
                directories are created.

        Returns:
            local_path (str): The same *local_path* for convenience.
        """
        bucket, key = parse_s3_url(s3_uri)
        Path(local_path).parent.mkdir(parents=True, exist_ok=True)
        self._client.download_file(bucket, key, local_path)
        return local_path

    def download_output(self, *, run_id: str, filename: str, local_dir: str) -> str:
        """Download a file from the output prefix of a run.

        Parameters:
            run_id (str): Run identifier.
            filename (str): Name of the file within the output prefix.
                Names that resolve outside ``local_dir`` (e.g. via ``..``
                or an absolute path), or to ``local_dir`` itself (``""``,
                ``"."``), are rejected.
            local_dir (str): Local directory to save the file into.

        Returns:
            local_path (str): Absolute path of the downloaded file.

        Raises:
            ValueError: If *filename* is rejected by the path-traversal
                guard, or if no S3 prefix is configured.
        """
        # Path-traversal guard: ``filename`` flows directly into the
        # local filesystem destination. A malicious or buggy upstream
        # (e.g. an S3 listing entry with ``..`` segments) could escape
        # ``local_dir`` and clobber arbitrary files. Resolve both paths
        # and assert the destination stays under the dir.
        local_dir_resolved = Path(local_dir).resolve()
        target = (local_dir_resolved / filename).resolve()
        try:
            target.relative_to(local_dir_resolved)
        except ValueError:
            raise ValueError(
                f"filename={filename!r} resolves outside local_dir={local_dir!r}; "
                "path-traversal segments (e.g. '..') are not allowed."
            ) from None
        # A name resolving to the directory itself would make the
        # download target a directory and fail obscurely inside boto3.
        if target == local_dir_resolved:
            raise ValueError(
                f"filename={filename!r} does not name a file inside "
                f"local_dir={local_dir!r}."
            )
        # Without a prefix, ``output_prefix_for_run`` returns "" and the
        # bare filename would be passed on as a (non-S3) URI; fail with
        # the same clear message ``build_uri`` uses instead.
        self._require_prefix()
        s3_uri = self.output_prefix_for_run(run_id) + filename
        return self.download_file(s3_uri=s3_uri, local_path=str(target))

    def list_output_files(self, run_id: str, *, max_keys: Optional[int] = None) -> list[str]:
        """List object keys under the output prefix of a run.

        Parameters:
            run_id (str): Run identifier.
            max_keys (int | None): Cap on the number of keys returned.
                Defaults to ``DEFAULT_LIST_OUTPUT_LIMIT`` (10000) to guard
                against unbounded memory use on long-running jobs that
                produced thousands of intermediate files (QC figures,
                per-recording reports, etc.). Pass an explicit larger
                value if the caller really needs the full list; exceeding
                the cap raises ``ValueError`` rather than silently
                truncating.

        Returns:
            keys (list[str]): S3 object keys found under the output
                prefix; ``[]`` when no prefix is configured.

        Raises:
            ValueError: When ``max_keys`` is negative, or when more than
                ``max_keys`` objects exist under the prefix.
        """
        cap = self.DEFAULT_LIST_OUTPUT_LIMIT if max_keys is None else max_keys
        if cap < 0:
            raise ValueError(f"list_output_files: max_keys={cap} must be >= 0.")
        prefix = self.output_prefix_for_run(run_id)
        if not prefix:
            return []
        bucket, key_prefix = parse_s3_url(prefix)
        paginator = self._client.get_paginator("list_objects_v2")
        keys: list[str] = []
        # Pages are fetched lazily, so raising as soon as the cap is
        # exceeded avoids requesting the remaining pages.
        for page in paginator.paginate(Bucket=bucket, Prefix=key_prefix):
            for obj in page.get("Contents", []):
                keys.append(obj["Key"])
                if len(keys) > cap:
                    raise ValueError(
                        f"list_output_files: more than max_keys={cap} objects "
                        f"under prefix={prefix!r}. Pass a larger ``max_keys`` "
                        "if this is expected; otherwise narrow the run_id."
                    )
        return keys