"""Live GPU memory watchdog for spike-sorting runs.
Symmetric to :class:`HostMemoryWatchdog` but watches GPU VRAM via
``pynvml`` (or ``nvidia-smi`` as a fallback). Trips when the
device's used memory crosses the configured percentage thresholds;
on trip it terminates registered subprocesses, runs registered kill
callbacks, and interrupts the main thread via
``_thread.interrupt_main``, where the per-recording catch site
surfaces the trip as a
:class:`spikelab.spike_sorting._exceptions.GpuMemoryWatchdogError`.
The watchdog narrows its measurement to the device the sort is using
(KS4 ``torch_device``, RT-Sort ``device``, KS2-Docker default
``cuda:0``) so unrelated GPUs running other workloads are ignored.
Detection priority:
1. ``pynvml`` (already an optional spikelab dep) — fastest, exact
API for free/used/total memory per device.
2. ``nvidia-smi`` parse — fallback when ``pynvml`` is missing.
3. No-op when neither is available — the watchdog reports as
disabled rather than raising.
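Typical per-recording usage (a sketch; ``config`` and
``run_sorter`` stand in for the caller's pipeline objects)::

    device = resolve_active_device(config)
    with GpuMemoryWatchdog(device_index=device):
        run_sorter(config)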
"""
from __future__ import annotations
import _thread
import contextvars
import logging
import subprocess
import threading
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from .._exceptions import GpuMemoryWatchdogError, GpuThermalWatchdogError
from ._audit import append_audit_event
_logger = logging.getLogger(__name__)
# Throttle reason bits we surface as warnings (NVML docs:
# https://docs.nvidia.com/deploy/nvml-api/group__nvmlClocksThrottleReasons.html).
# We deliberately ignore Idle/AppClocks because those reflect benign
# OS scheduling rather than hardware distress.
_THROTTLE_REASON_LABELS = (
(0x4, "SW power cap"),
(0x8, "HW slowdown"),
(0x20, "SW thermal slowdown"),
(0x40, "HW thermal slowdown"),
(0x80, "HW power brake"),
)
_active_gpu_watchdog: contextvars.ContextVar[Optional["GpuMemoryWatchdog"]] = (
contextvars.ContextVar("active_gpu_memory_watchdog", default=None)
)
def get_active_gpu_watchdog() -> Optional["GpuMemoryWatchdog"]:
"""Return the GPU watchdog active for the current context, or None.
Mirror of :func:`._watchdog.get_active_watchdog` for the GPU
watchdog. Lets the per-recording :class:`KeyboardInterrupt`
catch site discover a tripped GPU watchdog and convert the
interrupt into the appropriate classified error.
Returns:
watchdog (GpuMemoryWatchdog or None): The active instance,
or ``None`` when no GPU watchdog is currently running.
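Example (a sketch of the catch-site pattern; ``run_sort`` and
``recording`` are stand-ins for the per-recording sort call)::

    try:
        run_sort(recording)
    except KeyboardInterrupt:
        watchdog = get_active_gpu_watchdog()
        if watchdog is not None and watchdog.tripped():
            raise watchdog.make_error() from None
        raise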
"""
return _active_gpu_watchdog.get()
class _PynvmlSession:
"""Long-lived pynvml handle for one device.
Initialises pynvml once, caches the per-device handle, and reads
memory / temperature / throttle reasons via the cached handle.
Replaces the per-call ``nvmlInit/nvmlShutdown`` pattern, which
serialised every poll and added measurable overhead at the
default 2-second cadence.
Best-effort: per-method failures return ``None`` rather than
raising. ``shutdown()`` is idempotent and safe to call on a
session that never initialised.
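Example (a sketch of the session lifecycle)::

    session = _PynvmlSession(device_index=0)
    if session.start():
        memory = session.read_memory()         # (used_pct, total_gb) or None
        temp_c = session.read_temperature_c()  # float or None
        session.shutdown()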
"""
def __init__(self, device_index: int) -> None:
self.device_index = int(device_index)
self._pynvml = None
self._handle = None
def start(self) -> bool:
"""Initialise pynvml and resolve the device handle.
Returns:
ok (bool): ``True`` when pynvml is importable and the
handle resolves; ``False`` otherwise.
"""
try:
import pynvml
except ImportError:
return False
try:
pynvml.nvmlInit()
except Exception:
return False
try:
handle = pynvml.nvmlDeviceGetHandleByIndex(self.device_index)
except Exception:
try:
pynvml.nvmlShutdown()
except Exception:
pass
return False
self._pynvml = pynvml
self._handle = handle
return True
def read_memory(self) -> Optional[Tuple[float, float]]:
"""Return ``(used_pct, total_gb)`` or ``None`` on failure."""
if self._handle is None or self._pynvml is None:
return None
try:
info = self._pynvml.nvmlDeviceGetMemoryInfo(self._handle)
except Exception:
return None
total = float(info.total)
if total <= 0:
return None
return float(info.used) / total * 100.0, total / (1024**3)
def read_temperature_c(self) -> Optional[float]:
"""Return device temperature in degrees Celsius, or ``None``."""
if self._handle is None or self._pynvml is None:
return None
try:
# Sensor 0 is NVML_TEMPERATURE_GPU on every supported device.
return float(self._pynvml.nvmlDeviceGetTemperature(self._handle, 0))
except Exception:
return None
def read_throttle_reasons(self) -> Optional[int]:
"""Return the active throttle-reasons bitmask, or ``None``."""
if self._handle is None or self._pynvml is None:
return None
try:
return int(
self._pynvml.nvmlDeviceGetCurrentClocksThrottleReasons(self._handle)
)
except Exception:
return None
def shutdown(self) -> None:
if self._pynvml is not None:
try:
self._pynvml.nvmlShutdown()
except Exception:
pass
self._pynvml = None
self._handle = None
def _format_throttle_reasons(mask: int) -> str:
"""Render an active throttle-reasons mask as a comma-separated string."""
parts = [label for bit, label in _THROTTLE_REASON_LABELS if mask & bit]
return ", ".join(parts)
def _resolve_device_index(device: Optional[str]) -> int:
"""Return the integer device index for a torch-style device string.
Accepts ``"cuda"``, ``"cuda:0"``, ``"cuda:1"``, integer-like
strings, and ``None`` (interpreted as device 0). Falls back to 0
on parse failure rather than raising — the watchdog is
best-effort.
Parameters:
device (str or None): Torch-style device identifier.
Returns:
index (int): Device index (>= 0).
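Example (values follow the parsing rules above)::

    _resolve_device_index("cuda:1")     # -> 1
    _resolve_device_index("cuda")       # -> 0
    _resolve_device_index(None)         # -> 0
    _resolve_device_index("not-a-gpu")  # -> 0 (best-effort fallback)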
"""
if device is None:
return 0
s = str(device).strip().lower()
if s in ("", "cuda"):
return 0
if ":" in s:
try:
return max(0, int(s.split(":", 1)[1]))
except ValueError:
return 0
if s.isdigit():
return int(s)
return 0
def _read_gpu_memory_pynvml(device_index: int) -> Optional[Tuple[float, float]]:
"""Return ``(used_pct, total_gb)`` for *device_index* via pynvml.
Reuses :class:`_PynvmlSession` so the init / shutdown lifecycle
is owned by a single class rather than duplicated across the
per-call helper. Returns ``None`` when pynvml is missing or the
read fails.
"""
session = _PynvmlSession(device_index)
if not session.start():
return None
try:
return session.read_memory()
finally:
session.shutdown()
def _read_gpu_memory_nvidia_smi(
device_index: int,
) -> Optional[Tuple[float, float]]:
"""Return ``(used_pct, total_gb)`` via parsing ``nvidia-smi``.
Returns ``None`` when nvidia-smi is unavailable or the device
index is out of range.
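The query yields one CSV row per device, e.g. (illustrative
values)::

    0, 3251, 24576
    1, 102, 24576

which is parsed as device index, used MiB, and total MiB.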
"""
try:
out = subprocess.check_output(
[
"nvidia-smi",
"--query-gpu=index,memory.used,memory.total",
"--format=csv,noheader,nounits",
],
text=True,
timeout=5,
)
except (subprocess.SubprocessError, FileNotFoundError):
return None
for line in out.strip().splitlines():
parts = [p.strip() for p in line.split(",")]
if len(parts) < 3:
continue
try:
idx = int(parts[0])
used_mib = float(parts[1])
total_mib = float(parts[2])
except ValueError:
continue
if idx != device_index or total_mib <= 0:
continue
return used_mib / total_mib * 100.0, total_mib / 1024.0
return None
def capture_gpu_snapshot(output_path, *, header: str = "") -> Optional[str]:
"""Write a GPU diagnostic snapshot to disk for postmortem analysis.
Captures the current ``nvidia-smi`` output and (if PyTorch is
available with CUDA) ``torch.cuda.memory_summary`` for every
visible device. The result is a plain-text file the operator can
inspect to determine which process owned the GPU memory or what
PyTorch's allocator thought it had reserved.
Best-effort: failures during capture are recorded in the file
rather than raising.
Parameters:
output_path (path-like): Destination file path. Parent
directories are created if missing.
header (str): Optional banner prepended to the file (e.g.
"Host memory watchdog trip at 93.2%").
Returns:
path (str or None): The string path on success, ``None`` on
failure.
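Example (a sketch; the path and header are illustrative)::

    capture_gpu_snapshot(
        "/runs/rec01/gpu_snapshot_at_trip.txt",
        header="GPU memory watchdog trip at 96.1%",
    )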
"""
import datetime as _dt
from pathlib import Path as _Path
target = _Path(output_path)
lines: List[str] = []
if header:
lines.append(header)
lines.append("=" * len(header))
lines.append("")
lines.append(f"Captured: {_dt.datetime.now().isoformat(timespec='seconds')}")
lines.append("")
# nvidia-smi. Force the C locale so column labels and units in
# the human-readable output are predictable English text — the
# snapshot is grepped through after the fact, and operators on
# non-English hosts (de_DE, ja_JP, etc.) would otherwise have to
# decode localised labels.
import os as _os
smi_env = {**_os.environ, "LANG": "C", "LC_ALL": "C"}
lines.append("-- nvidia-smi --")
try:
out = subprocess.check_output(
["nvidia-smi"],
text=True,
timeout=10,
env=smi_env,
)
lines.append(out.rstrip())
except (subprocess.SubprocessError, FileNotFoundError) as exc:
lines.append(f"(nvidia-smi unavailable: {exc!r})")
lines.append("")
# torch memory summary
lines.append("-- torch.cuda.memory_summary --")
try:
import torch
if torch.cuda.is_available():
for i in range(torch.cuda.device_count()):
lines.append(f"\nDevice {i}:")
try:
lines.append(torch.cuda.memory_summary(device=i, abbreviated=True))
except Exception as exc:
lines.append(f"(memory_summary failed: {exc!r})")
else:
lines.append("(torch.cuda.is_available() = False)")
except ImportError:
lines.append("(torch not installed)")
except Exception as exc:
lines.append(f"(torch.cuda probe failed: {exc!r})")
try:
target.parent.mkdir(parents=True, exist_ok=True)
target.write_text("\n".join(lines), encoding="utf-8")
return str(target)
except Exception as exc:
_logger.error("snapshot failed to write %s: %r", target, exc)
return None
def read_gpu_memory(
device_index: int,
) -> Optional[Tuple[float, float]]:
"""Return ``(used_pct, total_gb)`` for *device_index*, or ``None``.
Tries ``pynvml`` first, then ``nvidia-smi``. Returns ``None``
when neither source can produce a reading (e.g. no NVIDIA driver,
or the index is out of range).
Parameters:
device_index (int): Zero-based GPU index.
Returns:
info (tuple[float, float] or None): ``(used_pct, total_gb)``
on success.
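Example (a sketch)::

    info = read_gpu_memory(0)
    if info is not None:
        used_pct, total_gb = info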
"""
info = _read_gpu_memory_pynvml(device_index)
if info is not None:
return info
return _read_gpu_memory_nvidia_smi(device_index)
def _try_capture_snapshot_to_results(log_path, header: str) -> None:
"""Write a GPU snapshot to the per-recording results folder.
Used by watchdog abort paths to leave a postmortem artefact at
``<results_folder>/gpu_snapshot_at_trip.txt``. The watchdog must
pass the log path captured at ``__enter__`` time on the main
thread — the watchdog's polling thread cannot reliably look up
the ``get_active_log_path`` ContextVar because Python does not
propagate ContextVars across thread boundaries.
Best-effort: failures (None log_path, write failure, etc.) are
silent so a snapshot bug never breaks the surrounding watchdog.
Parameters:
log_path (Path or None): Per-recording log path; the
results folder is its parent. ``None`` short-circuits.
header (str): Banner to prepend to the snapshot file.
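Example (a sketch; the paths are illustrative)::

    # Writes /runs/rec01/gpu_snapshot_at_trip.txt next to the log.
    _try_capture_snapshot_to_results(
        "/runs/rec01/spike_sorting.log",
        "GPU memory watchdog trip at 96.1%",
    )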
"""
if log_path is None:
return
try:
from pathlib import Path as _Path
results_folder = _Path(log_path).parent
target = results_folder / "gpu_snapshot_at_trip.txt"
capture_gpu_snapshot(target, header=header)
except Exception as exc:
_logger.error("snapshot failed to capture on trip: %r", exc)
def _resolve_rt_sort_device(config) -> int:
"""Resolver for the RT-Sort sorter."""
return _resolve_device_index(getattr(config.rt_sort, "device", None))
def _resolve_kilosort4_device(config) -> int:
"""Resolver for the Kilosort4 host sorter."""
params = getattr(config.sorter, "sorter_params", None) or {}
return _resolve_device_index(params.get("torch_device"))
# Registry mapping sorter name (lowercase) to a callable that takes
# the config and returns the integer device index to monitor.
# Sorters absent from this registry default to device 0 — the
# convention followed by KS2-Docker and other CUDA-using sorters
# that don't expose a configurable device.
#
# Adding a new GPU-using sorter: write a ``_resolve_<sorter>_device``
# function that pulls the device index from wherever the sorter
# config holds it, then add an entry here. No other call site needs
# updating.
_DEVICE_RESOLVERS: Dict[str, Callable[[Any], int]] = {
"rt_sort": _resolve_rt_sort_device,
"kilosort4": _resolve_kilosort4_device,
}
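# A registration sketch for a hypothetical new sorter (the names
# ``mysorter`` and ``config.mysorter.device`` are illustrative, not
# a real sorter):
#
#     def _resolve_mysorter_device(config) -> int:
#         return _resolve_device_index(getattr(config.mysorter, "device", None))
#
#     _DEVICE_RESOLVERS["mysorter"] = _resolve_mysorter_device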
def resolve_active_device(config) -> int:
"""Pick the GPU device index implied by the sorter config.
The watchdog measures only this device so unrelated GPUs running
other workloads are ignored.
Parameters:
config (SortingPipelineConfig): Pipeline configuration.
Returns:
index (int): Device index to monitor (defaults to 0 for any
sorter not registered in :data:`_DEVICE_RESOLVERS`).
"""
sorter_name = getattr(config.sorter, "sorter_name", "").lower()
resolver = _DEVICE_RESOLVERS.get(sorter_name)
if resolver is None:
return 0
return resolver(config)
class GpuMemoryWatchdog:
"""Daemon-thread watchdog that aborts on GPU VRAM or thermal pressure.
Use as a context manager around the per-recording sort. Each
poll inspects three signals:
* **VRAM usage** — crossing ``warn_pct`` logs a rate-limited
warning; crossing ``abort_pct`` terminates registered
subprocesses, runs kill callbacks, and interrupts the main
thread, where the trip is surfaced as a
:class:`GpuMemoryWatchdogError`.
* **Device temperature** — crossing ``warn_temp_c`` logs a
rate-limited warning; crossing ``abort_temp_c`` aborts the same
way with a :class:`GpuThermalWatchdogError`. Sustained
operation above the GPU's thermal junction limit risks
driver-level throttling that silently degrades sort output.
* **Active throttle reasons** — when the device reports SW/HW
power-cap or thermal slowdown, logs a rate-limited warning
(no abort: the device is already protecting itself).
Parameters:
device_index (int): GPU index to monitor. Use
:func:`resolve_active_device` to pick from the config.
warn_pct (float): Used-memory percentage at which to warn.
Defaults to ``85.0``.
abort_pct (float): Used-memory percentage at which to abort.
Defaults to ``95.0``.
poll_interval_s (float): Seconds between polls. Defaults to
``2.0``.
warn_repeat_s (float): Minimum seconds between repeated
warnings. Defaults to ``30.0``.
kill_grace_s (float): Seconds between ``terminate()`` and
``kill()`` on registered subprocesses.
warn_temp_c (float or None): Temperature in degrees Celsius
at which to warn. ``None`` disables the warn-stage temp
check. Defaults to ``85.0``.
abort_temp_c (float or None): Temperature at which to abort.
``None`` disables thermal aborts. Defaults to ``92.0``.
monitor_throttle_reasons (bool): When True, surface NVML
throttle reasons (SW power cap, HW thermal slowdown,
HW power brake) as rate-limited warnings. Defaults to
``True``.
Notes:
- Thermal monitoring requires ``pynvml``; the
``nvidia-smi``-only fallback path used by
:func:`read_gpu_memory` does not surface temperature.
When pynvml is missing, thermal/throttle checks silently
degrade while VRAM monitoring continues via nvidia-smi.
- Disabled (no-op context manager) when no usable GPU info
source is available.
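Example (a sketch; ``config`` and ``launch_sorter_container``
are stand-ins for the caller's pipeline objects)::

    device = resolve_active_device(config)
    with GpuMemoryWatchdog(device, warn_pct=80.0, abort_pct=92.0) as wd:
        proc = launch_sorter_container(config)
        wd.register_subprocess(proc, kill_grace_s=10.0)
        try:
            proc.wait()
        finally:
            wd.unregister_subprocess(proc)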
"""
def __init__(
self,
device_index: int = 0,
*,
warn_pct: float = 85.0,
abort_pct: float = 95.0,
poll_interval_s: float = 2.0,
warn_repeat_s: float = 30.0,
kill_grace_s: float = 5.0,
warn_temp_c: Optional[float] = 85.0,
abort_temp_c: Optional[float] = 92.0,
monitor_throttle_reasons: bool = True,
) -> None:
if not 0.0 < warn_pct < abort_pct <= 100.0:
raise ValueError(
f"warn_pct ({warn_pct}) and abort_pct ({abort_pct}) must "
f"satisfy 0 < warn_pct < abort_pct <= 100."
)
if np.isnan(poll_interval_s) or poll_interval_s <= 0.0:
raise ValueError(
f"poll_interval_s must be positive, got {poll_interval_s}."
)
if np.isnan(kill_grace_s) or kill_grace_s < 0.0:
raise ValueError(f"kill_grace_s must be non-negative, got {kill_grace_s}.")
if (
warn_temp_c is not None
and abort_temp_c is not None
and not 0.0 < warn_temp_c < abort_temp_c
):
raise ValueError(
f"warn_temp_c ({warn_temp_c}) and abort_temp_c "
f"({abort_temp_c}) must satisfy 0 < warn_temp_c < "
"abort_temp_c."
)
self.device_index = int(device_index)
self.warn_pct = float(warn_pct)
self.abort_pct = float(abort_pct)
self.poll_interval_s = float(poll_interval_s)
self.warn_repeat_s = float(warn_repeat_s)
self.kill_grace_s = float(kill_grace_s)
self.warn_temp_c = float(warn_temp_c) if warn_temp_c is not None else None
self.abort_temp_c = float(abort_temp_c) if abort_temp_c is not None else None
self.monitor_throttle_reasons = bool(monitor_throttle_reasons)
self._subprocesses: List[Tuple[subprocess.Popen, float]] = []
self._kill_callbacks: List[Callable[[], None]] = []
self._lock = threading.Lock()
self._stop_event = threading.Event()
self._thread: Optional[threading.Thread] = None
self._tripped = False
self._tripped_kind: Optional[str] = None
self._used_pct_at_trip: Optional[float] = None
self._temp_c_at_trip: Optional[float] = None
self._last_warn_t = 0.0
self._last_temp_warn_t = 0.0
self._last_throttle_warn_t = 0.0
self._enabled = False
self._session: Optional[_PynvmlSession] = None
self._token: Optional[contextvars.Token] = None
# Captured at ``__enter__`` time on the main thread because
# ContextVars do not propagate to the polling thread.
self._snapshot_log_path = None
# Set True when the trip cascade ran but
# ``_thread.interrupt_main`` raised — the main thread did not
# receive the KeyboardInterrupt and will surface a downstream
# error instead. Catch sites read this via
# :meth:`interrupt_delivery_failed`.
self._interrupt_main_failed = False
# ------------------------------------------------------------------
# Trip-state queries
# ------------------------------------------------------------------
def tripped(self) -> bool:
"""Return True once the watchdog has fired its abort path."""
return self._tripped
def interrupt_delivery_failed(self) -> bool:
"""Return True if the trip fired but ``_thread.interrupt_main`` raised.
When True, GPU protection ran successfully (subprocesses
terminated, kill callbacks invoked) but the main thread did
not receive a ``KeyboardInterrupt``. The pipeline's catch
site checks this to reclassify a downstream exception caused
by the now-dead subprocess.
Returns:
failed (bool): True only when the watchdog tripped and
the interrupt delivery raised.
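Example (a sketch of the reclassification check; ``exc`` stands
in for the downstream exception caught by the pipeline)::

    watchdog = get_active_gpu_watchdog()
    if (
        watchdog is not None
        and watchdog.tripped()
        and watchdog.interrupt_delivery_failed()
    ):
        raise watchdog.make_error() from exc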
"""
return self._interrupt_main_failed
def used_pct_at_trip(self) -> Optional[float]:
"""Return the used-memory percent at the trip moment, or None."""
return self._used_pct_at_trip
def temperature_c_at_trip(self) -> Optional[float]:
"""Return the device temperature at the trip moment, or None."""
return self._temp_c_at_trip
def trip_kind(self) -> Optional[str]:
"""Return ``"memory"``, ``"thermal"``, or ``None`` if not tripped."""
return self._tripped_kind
def make_error(
self, message: Optional[str] = None
) -> Union[GpuMemoryWatchdogError, GpuThermalWatchdogError]:
"""Build the trip-kind-appropriate watchdog error.
Parameters:
message (str or None): Override the default message.
Returns:
err: :class:`GpuMemoryWatchdogError` for VRAM trips,
:class:`GpuThermalWatchdogError` for temperature
trips. Falls back to a memory-shaped error when the
trip kind is unset.
"""
if self._tripped_kind == "thermal":
if message is None:
temp = (
f"{self._temp_c_at_trip:.1f}"
if self._temp_c_at_trip is not None
else "?"
)
abort_temp = (
f"{self.abort_temp_c:.1f}" if self.abort_temp_c is not None else "?"
)
message = (
f"GPU thermal watchdog tripped: device "
f"{self.device_index} at {temp} C "
f"(abort threshold {abort_temp} C)."
)
return GpuThermalWatchdogError(
message,
device_index=self.device_index,
temperature_c_at_trip=self._temp_c_at_trip,
abort_temp_c=self.abort_temp_c,
)
if message is None:
pct = (
f"{self._used_pct_at_trip:.1f}"
if self._used_pct_at_trip is not None
else "?"
)
message = (
f"GPU watchdog tripped: device {self.device_index} used "
f"{pct}% (abort threshold {self.abort_pct:.1f}%)."
)
return GpuMemoryWatchdogError(
message,
device_index=self.device_index,
used_pct_at_trip=self._used_pct_at_trip,
abort_pct=self.abort_pct,
)
# ------------------------------------------------------------------
# Registration (subprocesses + kill callbacks)
# ------------------------------------------------------------------
def register_subprocess(
self,
popen: subprocess.Popen,
*,
kill_grace_s: Optional[float] = None,
) -> None:
"""Track a subprocess for termination on watchdog abort."""
grace = self.kill_grace_s if kill_grace_s is None else float(kill_grace_s)
with self._lock:
self._subprocesses.append((popen, grace))
def unregister_subprocess(self, popen: subprocess.Popen) -> None:
"""Stop tracking a previously registered subprocess."""
with self._lock:
self._subprocesses = [
(p, g) for (p, g) in self._subprocesses if p is not popen
]
def register_kill_callback(self, callback: Callable[[], None]) -> None:
"""Track a zero-arg callable to invoke on watchdog abort."""
with self._lock:
self._kill_callbacks.append(callback)
def unregister_kill_callback(self, callback: Callable[[], None]) -> None:
"""Stop tracking a previously registered kill callback."""
with self._lock:
self._kill_callbacks = [
c for c in self._kill_callbacks if c is not callback
]
# ------------------------------------------------------------------
# Context manager
# ------------------------------------------------------------------
def __enter__(self) -> "GpuMemoryWatchdog":
# Capture the active per-recording log path on the main
# thread; the daemon polling thread cannot read the
# ContextVar reliably.
try:
from ._inactivity import get_active_log_path
self._snapshot_log_path = get_active_log_path()
except Exception:
self._snapshot_log_path = None
# Probe once before starting the thread so we can disable
# cleanly when no GPU info source is available.
info = read_gpu_memory(self.device_index)
if info is None:
_logger.warning(
"no GPU info available for device %d (no pynvml, no "
"nvidia-smi). Disabled.",
self.device_index,
)
self._enabled = False
return self
self._enabled = True
# Publish the active watchdog so the per-recording
# ``KeyboardInterrupt`` catch site can convert a
# ``_thread.interrupt_main`` from this watchdog into a
# classified error rather than letting it bubble up raw.
self._token = _active_gpu_watchdog.set(self)
used_pct, total_gb = info
# Try to set up a long-lived pynvml session for the polling
# thread. Falls back to per-poll ``read_gpu_memory`` (which
# uses nvidia-smi) when pynvml is unavailable; in that case
# thermal / throttle monitoring degrades silently because
# nvidia-smi-only does not expose those signals here.
session = _PynvmlSession(self.device_index)
if session.start():
self._session = session
else:
self._session = None
thermal_str = ""
if self._session is not None and (
self.warn_temp_c is not None or self.abort_temp_c is not None
):
initial_temp = self._session.read_temperature_c()
if initial_temp is not None:
thermal_str = (
f" temp_warn>={self.warn_temp_c} "
f"abort>={self.abort_temp_c} (now {initial_temp:.1f}C)"
)
_logger.info(
"active: device=%d (%.1f GB) start=%.1f%% warn>=%.1f%% "
"abort>=%.1f%% poll=%.1fs%s",
self.device_index,
total_gb,
used_pct,
self.warn_pct,
self.abort_pct,
self.poll_interval_s,
thermal_str,
)
self._stop_event.clear()
self._thread = threading.Thread(
target=self._poll_loop,
name=f"GpuMemoryWatchdog[{self.device_index}]",
daemon=True,
)
self._thread.start()
return self
def __exit__(self, exc_type, exc, tb) -> None:
self._stop_event.set()
if self._thread is not None:
self._thread.join(timeout=self.poll_interval_s + 1.0)
self._thread = None
if self._token is not None:
try:
_active_gpu_watchdog.reset(self._token)
except (LookupError, ValueError, RuntimeError):
# Another context modified the var between set/reset,
# or the token was already consumed (Python 3.10+
# raises RuntimeError on re-used tokens).
pass
self._token = None
if self._session is not None:
self._session.shutdown()
self._session = None
# ------------------------------------------------------------------
# Internals
# ------------------------------------------------------------------
def _poll_loop(self) -> None:
"""Polling loop: warn, then trip, then exit."""
# Defer the first poll so __enter__ has time to return.
if self._stop_event.wait(self.poll_interval_s):
return
blind_threshold_s = 5.0 * self.warn_repeat_s
# VRAM and thermal/throttle have independent failure modes
# (the same _session may succeed for one and fail for the
# other on driver corner cases) so they get independent
# blindness trackers.
vram_blind_started_t: Optional[float] = None
vram_blind_warned = False
thermal_blind_started_t: Optional[float] = None
thermal_blind_warned = False
while not self._stop_event.is_set():
now = time.time()
# Memory: prefer the cached pynvml session, fall back
# to the free-function reader (which uses nvidia-smi).
if self._session is not None:
info = self._session.read_memory()
else:
info = read_gpu_memory(self.device_index)
if info is None:
if vram_blind_started_t is None:
vram_blind_started_t = now
elif (
not vram_blind_warned
and now - vram_blind_started_t >= blind_threshold_s
):
self._warn_blind_vram(now - vram_blind_started_t)
vram_blind_warned = True
else:
vram_blind_started_t = None
vram_blind_warned = False
used_pct, _total_gb = info
if used_pct >= self.abort_pct:
self._on_abort(used_pct)
return
if used_pct >= self.warn_pct:
self._maybe_warn(used_pct)
# Thermal + throttle reasons require pynvml; skip when
# the session is unavailable. When _session is present
# but a configured sub-read returns None, that counts
# as thermal blindness for the warning tracker.
if self._session is not None:
thermal_configured = (
self.warn_temp_c is not None
or self.abort_temp_c is not None
or self.monitor_throttle_reasons
)
thermal_unreadable = False
if self.warn_temp_c is not None or self.abort_temp_c is not None:
temp_c = self._session.read_temperature_c()
if temp_c is None:
thermal_unreadable = True
else:
if (
self.abort_temp_c is not None
and temp_c >= self.abort_temp_c
):
self._on_thermal_abort(temp_c)
return
if self.warn_temp_c is not None and temp_c >= self.warn_temp_c:
self._maybe_warn_temp(temp_c)
if self.monitor_throttle_reasons:
mask = self._session.read_throttle_reasons()
if mask is None:
thermal_unreadable = True
elif mask & sum(bit for bit, _ in _THROTTLE_REASON_LABELS):
self._maybe_warn_throttle(mask)
if thermal_configured and thermal_unreadable:
if thermal_blind_started_t is None:
thermal_blind_started_t = now
elif (
not thermal_blind_warned
and now - thermal_blind_started_t >= blind_threshold_s
):
self._warn_blind_thermal(now - thermal_blind_started_t)
thermal_blind_warned = True
else:
thermal_blind_started_t = None
thermal_blind_warned = False
self._stop_event.wait(self.poll_interval_s)
def _maybe_warn(self, used_pct: float) -> None:
"""Print a warning if enough time has passed since the last one."""
now = time.time()
if now - self._last_warn_t < self.warn_repeat_s:
return
self._last_warn_t = now
_logger.warning(
"device %d VRAM at %.1f%% (warn=%.1f%% / abort=%.1f%%).",
self.device_index,
used_pct,
self.warn_pct,
self.abort_pct,
)
append_audit_event(
watchdog="gpu_memory",
event="warn",
log_path=self._snapshot_log_path,
device_index=self.device_index,
used_pct=used_pct,
warn_pct=self.warn_pct,
abort_pct=self.abort_pct,
)
def _maybe_warn_temp(self, temp_c: float) -> None:
"""Print a thermal warning if rate-limit allows."""
now = time.time()
if now - self._last_temp_warn_t < self.warn_repeat_s:
return
self._last_temp_warn_t = now
abort = f"{self.abort_temp_c:.1f}" if self.abort_temp_c is not None else "off"
_logger.warning(
"device %d temperature %.1f C (warn>=%.1f / abort>=%s).",
self.device_index,
temp_c,
self.warn_temp_c,
abort,
)
append_audit_event(
watchdog="gpu_thermal",
event="warn",
log_path=self._snapshot_log_path,
device_index=self.device_index,
temperature_c=temp_c,
warn_temp_c=self.warn_temp_c,
abort_temp_c=self.abort_temp_c,
)
def _maybe_warn_throttle(self, mask: int) -> None:
"""Print a throttle-reason warning if rate-limit allows."""
now = time.time()
if now - self._last_throttle_warn_t < self.warn_repeat_s:
return
self._last_throttle_warn_t = now
reasons = _format_throttle_reasons(mask) or f"mask=0x{mask:x}"
_logger.warning(
"device %d throttling — %s.",
self.device_index,
reasons,
)
append_audit_event(
watchdog="gpu_throttle",
event="warn",
log_path=self._snapshot_log_path,
device_index=self.device_index,
throttle_mask=int(mask),
throttle_reasons=reasons,
)
def _warn_blind_vram(self, blind_for: float) -> None:
_logger.warning(
"VRAM reading for device %d unreadable for %.1fs — "
"watchdog is blind to VRAM-pressure aborts until readings "
"recover.",
self.device_index,
blind_for,
)
append_audit_event(
watchdog="gpu_memory",
event="blind_warn",
source="vram",
log_path=self._snapshot_log_path,
device_index=self.device_index,
blind_for_s=blind_for,
)
def _warn_blind_thermal(self, blind_for: float) -> None:
_logger.warning(
"thermal/throttle reading for device %d unreadable for "
"%.1fs — watchdog is blind to thermal aborts until "
"readings recover.",
self.device_index,
blind_for,
)
append_audit_event(
watchdog="gpu_memory",
event="blind_warn",
source="thermal",
log_path=self._snapshot_log_path,
device_index=self.device_index,
blind_for_s=blind_for,
)
def _on_thermal_abort(self, temp_c: float) -> None:
"""Trip on thermal threshold; terminate, run callbacks, raise."""
self._tripped = True
self._tripped_kind = "thermal"
self._temp_c_at_trip = temp_c
abort = f"{self.abort_temp_c:.1f}" if self.abort_temp_c is not None else "?"
_logger.error(
"THERMAL ABORT: device %d at %.1f C (>= %s C). "
"Terminating subprocesses and raising into main thread.",
self.device_index,
temp_c,
abort,
)
append_audit_event(
watchdog="gpu_thermal",
event="abort",
log_path=self._snapshot_log_path,
device_index=self.device_index,
temperature_c=temp_c,
abort_temp_c=self.abort_temp_c,
)
_try_capture_snapshot_to_results(
self._snapshot_log_path,
f"GPU thermal watchdog trip — device {self.device_index} at "
f"{temp_c:.1f} C",
)
self._kill_targets_and_interrupt()
def _on_abort(self, used_pct: float) -> None:
"""Record trip, terminate subprocesses, run callbacks, interrupt main."""
self._tripped = True
self._tripped_kind = "memory"
self._used_pct_at_trip = used_pct
_logger.error(
"ABORT: device %d VRAM at %.1f%% (>= %.1f%%). Terminating "
"subprocesses and raising into main thread.",
self.device_index,
used_pct,
self.abort_pct,
)
append_audit_event(
watchdog="gpu_memory",
event="abort",
log_path=self._snapshot_log_path,
device_index=self.device_index,
used_pct=used_pct,
abort_pct=self.abort_pct,
)
_try_capture_snapshot_to_results(
self._snapshot_log_path,
f"GPU memory watchdog trip — device {self.device_index} at "
f"{used_pct:.1f}%",
)
self._kill_targets_and_interrupt()
def _kill_targets(self) -> None:
"""Terminate registered subprocesses and run kill callbacks.
Side-effect-only: returns nothing and swallows callback errors;
the only exceptions it propagates are ``SystemExit`` and
``KeyboardInterrupt`` raised by a kill callback.
Split from :meth:`_kill_targets_and_interrupt` so a future
telemetry-only abort path can run the destructive cascade
without injecting a ``KeyboardInterrupt`` into the main
thread.
Notes:
- When ``entries`` is empty (no subprocesses registered),
the ``if entries:`` gate skips the grace sleep — there
is nothing to terminate, so callbacks fire immediately.
The ``default=self.kill_grace_s`` on the ``max(...)``
call is unreachable in that case but kept for clarity.
"""
with self._lock:
entries = list(self._subprocesses)
callbacks = list(self._kill_callbacks)
for popen, _grace in entries:
try:
if popen.poll() is None:
popen.terminate()
except Exception as exc:
_logger.error(
"terminate() failed for pid=%s: %s",
getattr(popen, "pid", "?"),
exc,
)
if entries:
time.sleep(max((g for _, g in entries), default=self.kill_grace_s))
for popen, _grace in entries:
try:
if popen.poll() is None:
popen.kill()
except Exception as exc:
_logger.error(
"kill() failed for pid=%s: %s",
getattr(popen, "pid", "?"),
exc,
)
for cb in callbacks:
try:
cb()
except (SystemExit, KeyboardInterrupt):
# An in-process kill callback delivers KeyboardInterrupt
# via _thread.interrupt_main(); SystemExit signals
# operator-requested abort. Both must propagate.
raise
except Exception as exc:
_logger.error("kill_callback raised: %r; continuing.", exc)
def _kill_targets_and_interrupt(self) -> None:
"""Common subprocess + callback termination + interrupt main.
Used by the abort paths (``_on_abort``, ``_on_thermal_abort``)
which both want destructive cleanup followed by
``_thread.interrupt_main`` to convert the daemon-thread trip
into a classified main-thread error.
"""
self._kill_targets()
# If __exit__ ran while we were mid-cascade (terminate +
# grace + callbacks can take several seconds), the with-block
# has already torn down. Sending interrupt_main() now would
# land a phantom KeyboardInterrupt in whatever code is running
# next — the next sort, an exception handler, or the
# interactive prompt. Skip it.
if self._stop_event.is_set():
_logger.info("suppressing interrupt_main: watchdog is already exiting.")
return
try:
_thread.interrupt_main()
except Exception as exc:
self._interrupt_main_failed = True
_logger.error("failed to interrupt main: %s", exc)
append_audit_event(
watchdog="gpu_memory",
event="interrupt_delivery_failed",
log_path=self._snapshot_log_path,
device_index=self.device_index,
error=repr(exc),
)