From eacb80e5ff3f75e18c672e8345a37fae334e02f5 Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Tue, 23 Jun 2026 15:57:56 -0700 Subject: [PATCH 1/2] fix(runtime): autosave engine-implicit RuntimeCache via atexit + weakref The previous `__del__`-only autosave path silently lost cache updates when the engine survived until interpreter exit (typical for inference servers). Python tears down `sys.meta_path` early in shutdown; the torchbind attribute access in `self._handle.path` and the lazy `filelock` import inside `save()` then raise `ImportError: sys.meta_path is None`, which escaped `__del__` and surfaced as a noisy `Exception ignored in __del__` once per surviving handle. `atexit` callbacks run *before* module teardown, so the torchbind path and lazy imports still resolve there. Register an `atexit` hook from `__init__` whenever `autosave_on_del=True`. The hook closes over a `weakref.ref(self)` so it doesn't pin the handle alive: a handle that dies mid-program still goes through `__del__`, and the atexit hook later sees a dead weakref and no-ops. Other design points worth calling out: * `_autosave_at_exit` is a module-level helper, not a bound method. A bound method captures `self` via `__self__`, which would defeat the weakref. The free function lets the closure carry only the weakref. * Both `__del__` and the atexit hook flip `autosave_on_del` off before saving so whichever path runs first wins and the other no-ops -- no double-save, no double-leak risk. * `__del__` unregisters its atexit token. Without this, a long-running process that churns engine-implicit handles (model swaps, A/B rollouts) accumulates dead atexit entries -- small per entry but unbounded. * Exception swallowing now lives in the shared `_autosave_if_enabled` helper, whose `try` wraps the whole autosave body (including the `self.path` read), and the `atexit.unregister` call in `__del__` has its own `try`, so any residual attribute-access failure during late-shutdown corner cases is swallowed rather than leaking to `sys.unraisablehook`. 
Tests added in `TestRuntimeCacheAutosave`: - `test_del_swallows_shutdown_import_error_on_path`: monkey-patches `_handle.path` to raise the shutdown `ImportError`; asserts via `sys.unraisablehook` that nothing leaks. - `test_atexit_hook_saves_via_weakref`: exercises the helper directly, verifies it saves and flips `autosave_on_del`. - `test_atexit_hook_no_op_on_dead_weakref`: dead weakref => no-op, no exception. - `test_atexit_token_unregistered_after_del`: `atexit.unregister` is spied to confirm `__del__` cleaned up. Refs #4359 --- py/torch_tensorrt/runtime/_runtime_cache.py | 59 ++++++++++-- .../dynamo/runtime/test_000_runtime_cache.py | 93 +++++++++++++++++++ 2 files changed, 145 insertions(+), 7 deletions(-) diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py index 19f244d797..9e82fdf443 100644 --- a/py/torch_tensorrt/runtime/_runtime_cache.py +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -22,13 +22,17 @@ from __future__ import annotations +import atexit import logging import os import shutil import threading +import weakref +from functools import partial from typing import ( IO, Any, + Callable, Optional, Protocol, Sequence, @@ -140,6 +144,14 @@ def ensure_materialized(self, runtime_config: Any) -> Any: return self._cache +def _autosave_at_exit(ref: "weakref.ref[RuntimeCache]") -> None: + """Module-level so the atexit closure only holds a weakref, not a bound + method (see :meth:`RuntimeCache.__init__` for full rationale).""" + rc = ref() + if rc is not None: + rc._autosave_if_enabled() + + class RuntimeCache: """User-facing handle for the TensorRT-RTX runtime kernel cache. @@ -181,6 +193,11 @@ def __init__( path: str = "", autosave_on_del: bool = False, ) -> None: + # Set the atexit-token slot first so ``__del__`` can safely read it + # even if a later step in ``__init__`` raises and leaves the object + # partially constructed. 
+ self._atexit_token: Optional[Callable[..., None]] = None + # Pick the backing that matches the active runtime. The torchbind # class ``torch.classes.tensorrt.RuntimeCacheHandle`` is registered by # the C++ shared library; if the .so isn't loaded @@ -197,6 +214,17 @@ def __init__( self._handle = _RuntimeCacheHandle(path=path) self.autosave_on_del = autosave_on_del + # Engine-implicit handles must save before ``sys.meta_path`` is torn + # down at interpreter exit -- ``__del__`` then hits ``ImportError`` + # from the torchbind property and the lazy ``filelock`` import. + # ``atexit`` fires before that teardown. ``partial`` over a + # ``weakref`` keeps the registration non-owning, so mid-program GC + # still runs ``__del__`` normally. + if autosave_on_del: + self._atexit_token = atexit.register( + partial(_autosave_at_exit, weakref.ref(self)) + ) + @property def path(self) -> str: """The disk path the handle is anchored to. Single source of truth @@ -300,15 +328,32 @@ def save(self, path: Optional[str] = None) -> None: except OSError: pass + def _autosave_if_enabled(self) -> None: + """Idempotent autosave shared by ``__del__`` and the atexit hook. + Flips ``autosave_on_del`` off; swallows any exception so a + shutdown-time ``ImportError`` never leaks as + ``Exception ignored in __del__``. + """ + try: + if self.autosave_on_del and self.path: + self.autosave_on_del = False + self.save() + except Exception: + pass + def __del__(self) -> None: - # Best-effort autosave for engine-implicit handles. The CM disables - # this (``autosave_on_del=False``) since it saves on ``__exit__``; - # user-constructed handles default to disabled so save timing stays - # under the user's control. ``__del__`` can fire during interpreter - # shutdown when imports/filesystem ops fail unpredictably -- swallow. - if self.autosave_on_del and self.path: + # Mid-program GC path. 
The companion ``_autosave_at_exit`` hook + # covers the case where the handle survives until interpreter exit; + # whichever path runs first flips ``autosave_on_del`` off so the + # other no-ops. + self._autosave_if_enabled() + + # Drop our atexit hook so the registry does not accumulate dead + # entries across many engine-implicit handles in long-running + # processes (each entry holds a now-dead weakref). + if self._atexit_token is not None: try: - self.save() + atexit.unregister(self._atexit_token) except Exception: pass diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index b076a29647..ab8188aada 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -4,6 +4,7 @@ import shutil import tempfile import unittest +from unittest.mock import patch import torch import torch_tensorrt as torchtrt @@ -334,6 +335,98 @@ def test_user_built_handle_no_autosave_by_default(self): "User-built handle with autosave_on_del=False should not save on GC", ) + def test_del_swallows_shutdown_import_error_on_path(self): + """During interpreter shutdown ``self.path`` (a property that forwards + to ``self._handle.path``) can raise ``ImportError`` from a lazy import + triggered on a torn-down ``sys.meta_path``. ``__del__`` must wrap the + entire body in try/except so this does not surface as a noisy + ``Exception ignored in __del__``. + """ + import sys + + from torch_tensorrt.runtime._runtime_cache import RuntimeCache + + handle = RuntimeCache(path="/nonexistent/path", autosave_on_del=True) + + class _Boom: + @property + def path(self) -> str: + raise ImportError( + "sys.meta_path is None, Python is likely shutting down" + ) + + handle._handle = _Boom() + + # An exception escaping ``__del__`` reaches the interpreter via + # ``sys.unraisablehook`` rather than ordinary stderr. 
Swap the hook + # for a Mock so the call (if any) is recorded and the contract -- + # "nothing leaks" -- maps to ``assert_not_called``. + with patch.object(sys, "unraisablehook") as mock_hook: + del handle + gc.collect() + mock_hook.assert_not_called() + + def test_atexit_hook_saves_via_weakref(self): + """``_autosave_at_exit`` resolves the weakref and invokes ``save()``, + and flips ``autosave_on_del`` off so a subsequent ``__del__`` no-ops. + """ + import weakref + + from torch_tensorrt.runtime._runtime_cache import ( + RuntimeCache, + _autosave_at_exit, + ) + + with tempfile.TemporaryDirectory() as tmp: + path = os.path.join(tmp, "rc.bin") + handle = RuntimeCache(path=path, autosave_on_del=True) + + with patch.object(handle, "save") as mock_save: + _autosave_at_exit(weakref.ref(handle)) + mock_save.assert_called_once() + self.assertFalse( + handle.autosave_on_del, + "atexit hook must flip autosave_on_del off so __del__ skips", + ) + + def test_atexit_hook_no_op_on_dead_weakref(self): + """If the handle was already collected mid-program, the atexit hook + sees a dead weakref and does nothing -- no exceptions, no save.""" + import weakref + + from torch_tensorrt.runtime._runtime_cache import _autosave_at_exit + + class _WeakrefableDummy: + pass + + ref: weakref.ref = weakref.ref(_WeakrefableDummy()) + gc.collect() + self.assertIsNone(ref(), "sentinel must be collected by gc") + + # Must not raise even though ref() is dead. + _autosave_at_exit(ref) + + def test_atexit_token_unregistered_after_del(self): + """``__del__`` removes the handle's atexit hook so the registry does + not accumulate dead entries across many engine-implicit handles in + long-running processes.""" + import atexit + + from torch_tensorrt.runtime._runtime_cache import RuntimeCache + + handle = RuntimeCache(path="/nonexistent/path", autosave_on_del=True) + token = handle._atexit_token + self.assertIsNotNone(token) + + # Spy on ``atexit.unregister`` to verify ``__del__`` cleaned up. 
Using + # a mock avoids depending on private CPython implementation details + # of the atexit registry (no ``atexit._exithandlers`` in modern + # Python). + with patch.object(atexit, "unregister") as mock_unregister: + del handle + gc.collect() + mock_unregister.assert_called_once_with(token) + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx, From 78635f9f78118e40c68be46b8a6921e5d6c3133f Mon Sep 17 00:00:00 2001 From: tejaswinp Date: Fri, 26 Jun 2026 11:52:31 -0700 Subject: [PATCH 2/2] fix(runtime): make engine wrapper + RuntimeCache survive torch.save `torch.save(module, ...)` walks the wrapper's `__dict__` and serializes `_runtime_settings`, which in turn pickles the engine-implicit ``RuntimeCache``. That object now holds an `atexit` token -- ``partial(_autosave_at_exit, weakref.ref(self))`` -- whose ``weakref`` is not picklable. Result on CI: ``test_mutable_torchtrt_module::test_save`` crashes with ``TypeError: cannot pickle 'weakref.ReferenceType'``. Layered fix: 1. ``TorchTensorRTModule.__getstate__`` / ``__setstate__`` exclude ``_runtime_settings`` and ``_implicit_cache_handle`` from the pickle stream and reset both to defaults on load. This mirrors ``set_extra_state`` (line 515) and the documented intent at line 236: "RuntimeSettings are intentionally NOT serialized: they're per-engine, in-memory init values, not part of the engine's identity (see #4310)." The pickle path now behaves the same as state_dict + load_state_dict. 2. ``RuntimeCache.__getstate__`` / ``__setstate__`` (defense in depth): strip ``_atexit_token`` on pickle, re-register a fresh atexit hook on unpickle if ``autosave_on_del`` was on. Makes standalone-pickle of ``RuntimeCache`` safe for any caller, not just the wrapper. 3. ``RuntimeCache.__del__`` uses ``getattr(self, "_atexit_token", None)`` for the unregister step. 
``__init__`` sets the slot first, but protocols like ``copy.deepcopy`` bypass ``__init__`` and can leave an instance partially constructed mid-state-copy; ``getattr`` keeps ``__del__`` quiet in that edge case instead of raising ``AttributeError`` to ``sys.unraisablehook``. New test ``test_pickle_round_trip_strips_atexit_token`` exercises the standalone-pickle path with a stub ``_handle`` to isolate from the orthogonal python-runtime ``threading.Lock`` pickle limitation. Existing ``TestRuntimeCacheAutosave`` + full ``test_000_runtime_cache.py`` continue to pass. Refs #4359 --- .../dynamo/runtime/_TorchTensorRTModule.py | 31 +++++++++++++++ py/torch_tensorrt/runtime/_runtime_cache.py | 39 ++++++++++++++++++- .../dynamo/runtime/test_000_runtime_cache.py | 39 +++++++++++++++++++ 3 files changed, 107 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py index edc4297ed4..22f297eaf7 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py @@ -536,6 +536,37 @@ def set_extra_state(self, state: SerializedTorchTensorRTModuleFmt) -> None: self.output_binding_names = state[3] self.target_device = self._resolve_target_device() + def __getstate__(self) -> dict[str, Any]: + """Exclude per-engine, in-memory state from the pickle stream. + + Mirrors the ``set_extra_state`` reset (line 515) so that + ``torch.save(module)`` / ``torch.load`` behaves the same way as + ``state_dict`` / ``load_state_dict`` w.r.t. ``RuntimeSettings``: + the caller must reapply any ``runtime_cache`` / strategy / cuda-graph + configuration after load. See ``_pack_engine_info`` for the matching + cpp-side exclusion (engine bytes never carry these fields). 
+ + ``_implicit_cache_handle`` is dropped alongside ``_runtime_settings`` + because it aliases the same ``RuntimeCache`` instance and would + otherwise drag a ``weakref`` (via the handle's ``atexit`` closure) + and a Python-only ``threading.Lock`` (when the python-runtime path + is active) into pickle -- neither is picklable. + """ + get_state = getattr(super(), "__getstate__", None) + state = (get_state() if get_state else self.__dict__).copy() + state.pop("_runtime_settings", None) + state.pop("_implicit_cache_handle", None) + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + state.setdefault("_runtime_settings", RuntimeSettings()) + state.setdefault("_implicit_cache_handle", None) + set_state = getattr(super(), "__setstate__", None) + if set_state is not None: + set_state(state) + else: + self.__dict__.update(state) + def set_pre_allocated_outputs(self, enable: bool) -> None: self.get_engine().use_pre_allocated_outputs = enable diff --git a/py/torch_tensorrt/runtime/_runtime_cache.py b/py/torch_tensorrt/runtime/_runtime_cache.py index 9e82fdf443..6059719e09 100644 --- a/py/torch_tensorrt/runtime/_runtime_cache.py +++ b/py/torch_tensorrt/runtime/_runtime_cache.py @@ -341,6 +341,31 @@ def _autosave_if_enabled(self) -> None: except Exception: pass + def __getstate__(self) -> dict[str, Any]: + """Strip the unpicklable atexit token from the pickle stream. + + The token is a ``partial`` over a ``weakref`` -- both of which are + per-process artifacts and ``weakref`` is unpicklable. The pickled + state carries only ``_handle`` (cpp torchbind persists path-only; + see ``register_jit_hooks.cpp``) and ``autosave_on_del``; + ``__setstate__`` reconstructs a fresh atexit hook in the loading + process if autosave was enabled. 
+ """ + state = self.__dict__.copy() + state["_atexit_token"] = None + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__.update(state) + # Re-register atexit autosave in the loading process if it was + # active in the saving one. The fresh ``weakref.ref(self)`` is + # bound to the *new* instance, so the loading-process GC behavior + # mirrors what ``__init__`` would have set up directly. + if self.autosave_on_del and self._atexit_token is None: + self._atexit_token = atexit.register( + partial(_autosave_at_exit, weakref.ref(self)) + ) + def __del__(self) -> None: # Mid-program GC path. The companion ``_autosave_at_exit`` hook # covers the case where the handle survives until interpreter exit; @@ -351,9 +376,19 @@ def __del__(self) -> None: # Drop our atexit hook so the registry does not accumulate dead # entries across many engine-implicit handles in long-running # processes (each entry holds a now-dead weakref). - if self._atexit_token is not None: + # + # Use ``getattr`` rather than direct attribute access: protocols + # like ``copy.deepcopy`` can crash mid-state-copy on an unrelated + # field (e.g. a pre-existing ``threading.Lock`` somewhere else in + # the object graph) and leave the new instance with only some of + # its attributes set. ``__init__`` never ran on that object, so + # ``self._atexit_token`` may simply not exist when ``__del__`` + # fires -- ``getattr`` with a default makes that case a no-op + # instead of raising ``AttributeError`` to ``sys.unraisablehook``. 
+ token = getattr(self, "_atexit_token", None) + if token is not None: try: - atexit.unregister(self._atexit_token) + atexit.unregister(token) except Exception: pass diff --git a/tests/py/dynamo/runtime/test_000_runtime_cache.py b/tests/py/dynamo/runtime/test_000_runtime_cache.py index ab8188aada..135995ce85 100644 --- a/tests/py/dynamo/runtime/test_000_runtime_cache.py +++ b/tests/py/dynamo/runtime/test_000_runtime_cache.py @@ -427,6 +427,45 @@ def test_atexit_token_unregistered_after_del(self): gc.collect() mock_unregister.assert_called_once_with(token) + def test_pickle_round_trip_strips_atexit_token(self): + """Standalone ``RuntimeCache`` pickle: the unpicklable ``partial`` + over ``weakref`` is stripped on ``__getstate__`` and a fresh atexit + hook is wired up by ``__setstate__`` when ``autosave_on_del`` was on. + + ``_handle`` is stubbed with a picklable placeholder so that the test + isolates ``RuntimeCache.__getstate__/__setstate__`` from an + orthogonal pre-existing limitation: the python-runtime + ``_RuntimeCacheHandle`` carries a ``threading.Lock`` that pickle + can't serialize. The cpp-rt torchbind handle pickles to path-only + (see ``register_jit_hooks.cpp``). + """ + import pickle + from types import SimpleNamespace + + from torch_tensorrt.runtime._runtime_cache import RuntimeCache + + original = RuntimeCache(path="/nonexistent/path", autosave_on_del=True) + self.assertIsNotNone(original._atexit_token) + + # Sidestep the python-rt ``threading.Lock`` so we only exercise the + # RuntimeCache state-transition logic. 
+ original._handle = SimpleNamespace(path="/nonexistent/path") + + blob = pickle.dumps(original) + loaded = pickle.loads(blob) + + self.assertTrue(loaded.autosave_on_del) + self.assertEqual(loaded.path, "/nonexistent/path") + self.assertIsNotNone( + loaded._atexit_token, + "autosave_on_del=True must re-wire atexit on unpickle", + ) + self.assertIsNot( + loaded._atexit_token, + original._atexit_token, + "loaded handle must own its own atexit token (fresh weakref)", + ) + @unittest.skipIf( not ENABLED_FEATURES.tensorrt_rtx,