Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 52 additions & 7 deletions py/torch_tensorrt/runtime/_runtime_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,17 @@

from __future__ import annotations

import atexit
import logging
import os
import shutil
import threading
import weakref
from functools import partial
from typing import (
IO,
Any,
Callable,
Optional,
Protocol,
Sequence,
Expand Down Expand Up @@ -140,6 +144,14 @@ def ensure_materialized(self, runtime_config: Any) -> Any:
return self._cache


def _autosave_at_exit(ref: "weakref.ref[RuntimeCache]") -> None:
    """Interpreter-exit hook: autosave the cache behind ``ref`` if it is
    still alive.

    Kept at module level and handed only a weak reference, so the ``atexit``
    registration never holds a bound method of the handle (the rationale is
    spelled out in :meth:`RuntimeCache.__init__`).

    Args:
        ref: Weak reference to the handle. Calling it yields ``None`` once
            the handle has been collected, in which case nothing happens.
    """
    cache = ref()
    if cache is None:
        # Referent already collected -- nothing left to save.
        return
    cache._autosave_if_enabled()


class RuntimeCache:
"""User-facing handle for the TensorRT-RTX runtime kernel cache.

Expand Down Expand Up @@ -181,6 +193,11 @@ def __init__(
path: str = "",
autosave_on_del: bool = False,
) -> None:
# Set the atexit-token slot first so ``__del__`` can safely read it
# even if a later step in ``__init__`` raises and leaves the object
# partially constructed.
self._atexit_token: Optional[Callable[..., None]] = None

# Pick the backing that matches the active runtime. The torchbind
# class ``torch.classes.tensorrt.RuntimeCacheHandle`` is registered by
# the C++ shared library; if the .so isn't loaded
Expand All @@ -197,6 +214,17 @@ def __init__(
self._handle = _RuntimeCacheHandle(path=path)
self.autosave_on_del = autosave_on_del

# Engine-implicit handles must save before ``sys.meta_path`` is torn
# down at interpreter exit -- ``__del__`` then hits ``ImportError``
# from the torchbind property and the lazy ``filelock`` import.
# ``atexit`` fires before that teardown. ``partial`` over a
# ``weakref`` keeps the registration non-owning, so mid-program GC
# still runs ``__del__`` normally.
if autosave_on_del:
self._atexit_token = atexit.register(
partial(_autosave_at_exit, weakref.ref(self))
)

@property
def path(self) -> str:
"""The disk path the handle is anchored to. Single source of truth
Expand Down Expand Up @@ -300,15 +328,32 @@ def save(self, path: Optional[str] = None) -> None:
except OSError:
pass

def _autosave_if_enabled(self) -> None:
    """Run the one-shot, best-effort autosave used by ``__del__`` and the
    atexit hook.

    When ``autosave_on_del`` is set and ``path`` is non-empty, the flag is
    cleared *before* ``save()`` is invoked, so any later call is a no-op.
    Every ``Exception`` raised along the way -- including one raised while
    reading ``path`` -- is suppressed, so a shutdown-time ``ImportError``
    never surfaces as ``Exception ignored in __del__``.
    """
    try:
        if not (self.autosave_on_del and self.path):
            return
        # Disarm first so whichever caller arrives second finds the flag
        # off, even if ``save()`` below fails.
        self.autosave_on_del = False
        self.save()
    except Exception:
        pass

def __del__(self) -> None:
    """Best-effort autosave and atexit cleanup when the handle is collected.

    Never raises: the autosave goes through ``_autosave_if_enabled`` (which
    suppresses every ``Exception``), and the atexit cleanup is guarded
    separately so one step failing cannot prevent or repeat the other.
    """
    # Best-effort autosave for engine-implicit handles. The CM disables
    # this (``autosave_on_del=False``) since it saves on ``__exit__``;
    # user-constructed handles default to disabled so save timing stays
    # under the user's control. ``__del__`` can fire during interpreter
    # shutdown when imports/filesystem ops fail unpredictably -- swallow.
    #
    # Mid-program GC path. The companion ``_autosave_at_exit`` hook
    # covers the case where the handle survives until interpreter exit;
    # whichever path runs first flips ``autosave_on_del`` off so the
    # other no-ops.
    #
    # Deliberately no ``self.autosave_on_del and self.path`` pre-check
    # here: reading ``self.path`` can itself raise at shutdown, and the
    # helper already performs that check inside its own try/except.
    self._autosave_if_enabled()

    # Drop our atexit hook so the registry does not accumulate dead
    # entries across many engine-implicit handles in long-running
    # processes (each entry holds a now-dead weakref). ``getattr`` keeps
    # this safe for instances that never ran ``__init__``.
    token = getattr(self, "_atexit_token", None)
    if token is not None:
        try:
            atexit.unregister(token)
        except Exception:
            pass

Expand Down
93 changes: 93 additions & 0 deletions tests/py/dynamo/runtime/test_000_runtime_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import tempfile
import unittest
from unittest.mock import patch

import torch
import torch_tensorrt as torchtrt
Expand Down Expand Up @@ -334,6 +335,98 @@ def test_user_built_handle_no_autosave_by_default(self):
"User-built handle with autosave_on_del=False should not save on GC",
)

def test_del_swallows_shutdown_import_error_on_path(self):
    """``__del__`` must stay silent when reading ``path`` raises.

    At interpreter shutdown ``self.path`` (a property forwarding to
    ``self._handle.path``) can raise ``ImportError`` from a lazy import on
    a torn-down ``sys.meta_path``. The finalizer has to absorb that rather
    than let it surface as a noisy ``Exception ignored in __del__``.
    """
    import sys

    from torch_tensorrt.runtime._runtime_cache import RuntimeCache

    class _ShutdownBackend:
        """Stand-in backing object whose ``path`` fails like a torn-down
        interpreter would."""

        @property
        def path(self) -> str:
            raise ImportError(
                "sys.meta_path is None, Python is likely shutting down"
            )

    cache = RuntimeCache(path="/nonexistent/path", autosave_on_del=True)
    cache._handle = _ShutdownBackend()

    # Anything escaping ``__del__`` is reported through
    # ``sys.unraisablehook`` instead of being raised. Replacing the hook
    # with a Mock turns "nothing leaked" into ``assert_not_called``.
    with patch.object(sys, "unraisablehook") as unraisable_spy:
        del cache
        gc.collect()
        unraisable_spy.assert_not_called()

def test_atexit_hook_saves_via_weakref(self):
    """The exit hook dereferences the weakref, calls ``save()`` exactly
    once, and leaves ``autosave_on_del`` off so a later ``__del__`` has
    nothing left to do.
    """
    import weakref

    from torch_tensorrt.runtime._runtime_cache import (
        RuntimeCache,
        _autosave_at_exit,
    )

    with tempfile.TemporaryDirectory() as workdir:
        cache = RuntimeCache(
            path=os.path.join(workdir, "rc.bin"), autosave_on_del=True
        )

        # Stub out ``save`` so only the hook's control flow is exercised.
        with patch.object(cache, "save") as save_spy:
            _autosave_at_exit(weakref.ref(cache))

        save_spy.assert_called_once()
        self.assertFalse(
            cache.autosave_on_del,
            "atexit hook must flip autosave_on_del off so __del__ skips",
        )

def test_atexit_hook_no_op_on_dead_weakref(self):
    """A handle collected mid-program leaves the exit hook with a dead
    weakref; the hook must then do nothing and raise nothing."""
    import weakref

    from torch_tensorrt.runtime._runtime_cache import _autosave_at_exit

    class _WeakrefableDummy:
        pass

    # Drop the only strong reference before collecting so the weakref
    # is guaranteed to be dead by the time the hook sees it.
    sentinel = _WeakrefableDummy()
    dead_ref: weakref.ref = weakref.ref(sentinel)
    del sentinel
    gc.collect()
    self.assertIsNone(dead_ref(), "sentinel must be collected by gc")

    # Must not raise even though the referent is gone.
    _autosave_at_exit(dead_ref)

def test_atexit_token_unregistered_after_del(self):
    """``__del__`` removes the handle's atexit hook so the registry does
    not accumulate dead entries across many engine-implicit handles in
    long-running processes."""
    import atexit

    from torch_tensorrt.runtime._runtime_cache import RuntimeCache

    handle = RuntimeCache(path="/nonexistent/path", autosave_on_del=True)
    token = handle._atexit_token
    self.assertIsNotNone(token)

    # Spy on ``atexit.unregister`` to verify ``__del__`` cleaned up. Using
    # a mock avoids depending on private CPython implementation details
    # of the atexit registry (no ``atexit._exithandlers`` in modern
    # Python). ``wraps`` forwards the call to the real function, so the
    # hook registered by this test is actually removed instead of being
    # left behind in the process-wide registry.
    with patch.object(
        atexit, "unregister", wraps=atexit.unregister
    ) as mock_unregister:
        del handle
        gc.collect()
        mock_unregister.assert_called_once_with(token)


@unittest.skipIf(
not ENABLED_FEATURES.tensorrt_rtx,
Expand Down
Loading