Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions pandasai/core/response/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ def _validate_response(self, result: dict):
raise InvalidOutputValueMismatch(
"Invalid output: Expected a numeric value for result type 'number', but received a non-numeric value."
)
# NaN / inf are floats, so they pass the isinstance check above and would be
# returned as a valid number. They almost always come from an aggregation over
# empty data (e.g. df["x"].mean() on a zero-row result) - reject instead of
# silently returning NaN as the answer.
if isinstance(result["value"], float) and not np.isfinite(result["value"]):
raise InvalidOutputValueMismatch(
"Invalid output: Numeric result is NaN or infinite (likely an aggregation over empty data)."
)
elif result["type"] == "string":
if not isinstance(result["value"], str):
raise InvalidOutputValueMismatch(
Expand Down
33 changes: 33 additions & 0 deletions tests/test_nan_number_rejected.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Regression test: a NaN/inf numeric result must be rejected, not returned as valid.

NaN and inf are floats, so they pass the isinstance(value, (int, float, np.int64))
check in _validate_response and would be wrapped in a NumberResponse and returned as
the answer. They almost always come from an aggregation over an empty result
(e.g. df["sales"].mean() when a WHERE clause matched zero rows).

with the fix -> PASS (raises InvalidOutputValueMismatch)
without it -> FAIL (returns NumberResponse(nan) silently)
"""
import numpy as np
import pytest

from pandasai.core.response.parser import ResponseParser
from pandasai.exceptions import InvalidOutputValueMismatch


def test_nan_number_rejected():
parser = ResponseParser()
with pytest.raises(InvalidOutputValueMismatch, match="NaN"):
parser.parse({"type": "number", "value": float("nan")})


def test_inf_number_rejected():
parser = ResponseParser()
with pytest.raises(InvalidOutputValueMismatch, match="NaN"):
parser.parse({"type": "number", "value": float("inf")})


def test_normal_number_still_ok():
parser = ResponseParser()
resp = parser.parse({"type": "number", "value": 42})
assert resp.value == 42