Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 40 additions & 29 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def __init__(self, config: Config) -> None:
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
# flag to guard against recursive find_spec calls, e.g. triggered by PYTHON_LAZY_IMPORTS=all
# where accessing a lazily-imported name inside find_spec triggers another find_spec call (#14632)
self._in_find_spec = False
self._basenames_to_check_rewrite = {"conftest"}
self._marked_for_rewrite_cache: dict[str, bool] = {}
self._session_paths_checked = False
Expand All @@ -107,38 +110,46 @@ def find_spec(
) -> importlib.machinery.ModuleSpec | None:
if self._writing_pyc:
return None
state = self.config.stash[assertstate_key]
if self._early_rewrite_bailout(name, state):
return None
state.trace(f"find_module called for: {name}")

# Type ignored because mypy is confused about the `self` binding here.
spec = self._find_spec(name, path) # type: ignore

if (
# the import machinery could not find a file to import
spec is None
# this is a namespace package (without `__init__.py`)
# there's nothing to rewrite there
or spec.origin is None
# we can only rewrite source files
or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
# if the file doesn't exist, we can't rewrite it
or not os.path.exists(spec.origin)
):
if self._in_find_spec:
# Guard against recursive find_spec calls, e.g. triggered by PYTHON_LAZY_IMPORTS=all
# where accessing a lazily-imported name inside find_spec triggers another find_spec call.
return None
else:
fn = spec.origin
self._in_find_spec = True
try:
state = self.config.stash[assertstate_key]
if self._early_rewrite_bailout(name, state):
return None
state.trace(f"find_module called for: {name}")

# Type ignored because mypy is confused about the `self` binding here.
spec = self._find_spec(name, path) # type: ignore

if (
# the import machinery could not find a file to import
spec is None
# this is a namespace package (without `__init__.py`)
# there's nothing to rewrite there
or spec.origin is None
# we can only rewrite source files
or not isinstance(spec.loader, importlib.machinery.SourceFileLoader)
# if the file doesn't exist, we can't rewrite it
or not os.path.exists(spec.origin)
):
return None
else:
fn = spec.origin

if not self._should_rewrite(name, fn, state):
return None
if not self._should_rewrite(name, fn, state):
return None

return importlib.util.spec_from_file_location(
name,
fn,
loader=self,
submodule_search_locations=spec.submodule_search_locations,
)
return importlib.util.spec_from_file_location(
name,
fn,
loader=self,
submodule_search_locations=spec.submodule_search_locations,
)
finally:
self._in_find_spec = False

def create_module(
self, spec: importlib.machinery.ModuleSpec
Expand Down
45 changes: 45 additions & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2424,3 +2424,48 @@ def test():
)
reprec = pytester.inline_run("-p", "no:terminalreporter")
reprec.assertoutcome(passed=1)


def test_rewrite_hook_reentrancy_guard(pytestconfig, pytester: Pytester) -> None:
"""AssertionRewritingHook.find_spec returns None when called recursively,
e.g. when PYTHON_LAZY_IMPORTS=all triggers an import inside find_spec (#14632)."""
hook = AssertionRewritingHook(pytestconfig)

# Simulate a recursive find_spec call from inside the hook body, as would happen
# when PYTHON_LAZY_IMPORTS=all causes a lazy-import resolution mid-find_spec.
original_early_bailout = hook._early_rewrite_bailout
recursive_call_returned_none = []

def spy_early_bailout(name, state):
# Attempt a nested find_spec call; the re-entrancy guard must return None.
result = hook.find_spec("fnmatch")
recursive_call_returned_none.append(result is None)
return original_early_bailout(name, state)

hook._early_rewrite_bailout = spy_early_bailout # type: ignore[method-assign]
pytester.syspathinsert()
pytester.makepyfile(test_foo="def test_foo(): pass")
hook.find_spec("test_foo")

assert recursive_call_returned_none == [True], (
"Recursive find_spec call should return None (re-entrancy guard)"
)


@pytest.mark.skipif(
sys.version_info < (3, 15),

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not technically required. on < 3.15, it does just does nothing. it can still be checked.

reason="PYTHON_LAZY_IMPORTS requires Python 3.15+",
)
def test_lazy_imports_all_does_not_crash_pytest(
pytester: Pytester, monkeypatch: pytest.MonkeyPatch
) -> None:
"""pytest does not crash with PYTHON_LAZY_IMPORTS=all (#14632)."""
pytester.makepyfile(
"""
def test_foo():
assert 1 == 1
"""
)
monkeypatch.setenv("PYTHON_LAZY_IMPORTS", "all")
result = pytester.runpytest_subprocess()
assert result.ret == 0
Loading