[Fix] issues mentioned in comments of #22472 #22595

Closed
ChiragSW wants to merge 11 commits into keras-team:master from ChiragSW:issue#22472#extra

@ChiragSW
Contributor

@ChiragSW ChiragSW commented Mar 31, 2026

Fixes issues mentioned in comments of #22472

Actually, we have discovered that many APIs in the keras.ops module exhibit the issue described above: they lack validity checks when accepting keras.Input as an input.
These include:

keras.ops.flip
keras.ops.log_softmax
keras.ops.sparse_categorical_crossentropy
keras.ops.roll
keras.ops.sparsemax
keras.ops.trace
keras.ops.image.pad_images
keras.ops.image.crop_images

Root cause

keras.ops functions branch like:

    if any_symbolic_tensors(inputs):
        call Operation(...).symbolic_call(), which relies on compute_output_spec() for validation
    else:
        run backend/eager code paths where the real argument checks happen

So any checks that only existed in the eager helper/backend path were skipped for keras.Input.
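The branch described above can be illustrated with a minimal, framework-free sketch. Everything here is hypothetical (`Placeholder`, `op_sum`, `eager_sum`, `symbolic_sum` are made-up names, not Keras APIs); it only shows how a check that lives solely in the eager helper is never reached for a symbolic input.

```python
class Placeholder:
    """Hypothetical stand-in for a symbolic tensor such as keras.Input."""

    def __init__(self, shape):
        self.shape = shape


def eager_sum(x, axis):
    # The argument check lives only in the eager helper.
    if axis != -1:
        raise ValueError(f"Only axis=-1 is supported. Received: axis={axis}")
    return sum(x)


def symbolic_sum(x, axis):
    # Shape inference only: `axis` is never validated on this path.
    return Placeholder(x.shape[:-1])


def op_sum(x, axis=-1):
    # Mirrors the branch described above, with made-up function names.
    if isinstance(x, Placeholder):
        return symbolic_sum(x, axis)
    return eager_sum(x, axis)


try:
    op_sum([1.0, 2.0], axis=0)
except ValueError as e:
    print("eager path:", e)

# The same invalid argument slips through for a symbolic input.
out = op_sum(Placeholder((None, 3)), axis=0)
print("symbolic path returned shape", out.shape)
```

In this sketch the eager call raises while the symbolic call silently returns a result, which is the kind of gap the PR sets out to close.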

Fix

keras.ops.image.pad_images / crop_images (keras/src/ops/image.py)

  1. pad_images(): added the missing “must specify exactly two of …” argument validation in the public wrapper so it runs for both eager and symbolic inputs.
  2. crop_images(): added the same “exactly two of …” validation in the wrapper.
  3. CropImages.compute_output_spec(): added validation by inferring missing crop amounts when input_height/input_width are known, and raising if the inferred values would be negative.
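As a rough sketch of the two checks listed above, the following uses generic, hypothetical argument names (`start`, `end`, `target`, `input_size`) for one spatial dimension. It is not the code from keras/src/ops/image.py; it only illustrates an "exactly two of" check followed by inferring the missing amount when the static size is known.

```python
def check_exactly_two(**kwargs):
    """Raise unless exactly two of the keyword arguments are not None."""
    given = [name for name, value in kwargs.items() if value is not None]
    if len(given) != 2:
        names = ", ".join(kwargs)
        raise ValueError(
            f"Must specify exactly two of: {names}. Received: {kwargs}"
        )


def infer_crop_amounts(input_size, start=None, end=None, target=None):
    """Fill in the missing one of (start, end, target) along one dimension."""
    check_exactly_two(start=start, end=end, target=target)
    if input_size is None:
        # Unknown static size: nothing more can be checked symbolically.
        return start, end, target
    if start is None:
        start = input_size - end - target
    elif end is None:
        end = input_size - start - target
    else:
        target = input_size - start - end
    if min(start, end, target) < 0:
        raise ValueError(
            "Inferred crop amounts must be non-negative. Received: "
            f"input_size={input_size}, start={start}, end={end}, "
            f"target={target}"
        )
    return start, end, target
```

Because the first check needs no shape information, it can run for both eager and symbolic inputs; the second only applies when the input size is statically known.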

keras.ops.flip, keras.ops.roll, keras.ops.trace (keras/src/ops/numpy.py)

  1. flip(): validate axis is None, int, or a sequence of ints.
  2. roll(): validate axis/shift compatibility.
  3. trace(): validate axis1 != axis2 during symbolic shape inference.

keras.ops.log_softmax, keras.ops.sparsemax, keras.ops.sparse_categorical_crossentropy (keras/src/ops/nn.py)

  1. log_softmax() and sparsemax(): validate axis bounds against rank when rank is known (so axis=-3 for a rank-2 input raises for symbolic inputs too).
  2. sparse_categorical_crossentropy(): enforce the existing constraint axis == -1 in the wrapper so symbolic inputs don’t bypass it.

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces input validation and axis canonicalization for several image, neural network, and numpy operations to improve API reliability. However, the logic used to determine tensor rank is fragile, as it fails for Python lists and causes crashes for KerasTensor instances with unknown rank. Additionally, the roll operation incorrectly validates the shift argument by not supporting broadcasting, and some manual checks are redundant with existing utility functions. The feedback highlights the need for more robust tensor conversion and consistent use of backend utilities to handle all valid input types and edge cases without crashing.

Comment thread keras/src/ops/nn.py Outdated
Comment on lines +1017 to +1038
    ndim = len(getattr(x, "shape", []))
    if isinstance(axis, int):
        if axis < -ndim or axis >= ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for array of dimension {ndim}"
            )
        axis = axis if axis >= 0 else axis + ndim
    elif isinstance(axis, tuple):
        canonical_axis = []
        for a in axis:
            if not isinstance(a, int):
                raise TypeError(
                    "Argument `axis` must be an integer or tuple of integers. "
                    f"Received: axis={axis}"
                )
            if a < -ndim or a >= ndim:
                raise ValueError(
                    f"axis {a} is out of bounds for array of dimension {ndim}"
                )
            a = a if a >= 0 else a + ndim
            canonical_axis.append(a)
        axis = tuple(canonical_axis)

high

Using len(getattr(x, "shape", [])) to determine rank is problematic for two reasons:

  1. It returns 0 for Python lists (which lack a .shape attribute), causing valid calls like log_softmax([1.0, 2.0]) to fail with an out-of-bounds error.
  2. It will crash with a TypeError if x.shape is None, which is the case for KerasTensor instances with unknown rank.

It is recommended to use backend.convert_to_tensor(x) first and then check if x.shape is not None before validating the axis.
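The two failure modes can be reproduced without any backend. In the sketch below, `UnknownRank` and `static_ndim` are hypothetical names, and counting list nesting depth is only a stand-in for the `backend.convert_to_tensor(x)` step recommended above.

```python
class UnknownRank:
    """Hypothetical stand-in for a symbolic tensor whose rank is unknown."""

    shape = None


def fragile_ndim(x):
    # The pattern under review.
    return len(getattr(x, "shape", []))


def static_ndim(x):
    """Return the rank if it can be determined statically, else None."""
    shape = getattr(x, "shape", None)
    if shape is not None:
        return len(shape)
    if isinstance(x, (list, tuple)):
        # Stand-in for tensor conversion: follow the first element down.
        ndim = 0
        while isinstance(x, (list, tuple)):
            ndim += 1
            if not x:
                break
            x = x[0]
        return ndim
    return None


print(fragile_ndim([1.0, 2.0]))  # 0, although the list is rank 1
print(static_ndim([1.0, 2.0]))
print(static_ndim(UnknownRank()))

try:
    fragile_ndim(UnknownRank())
except TypeError as e:
    print("fragile_ndim crashed:", e)
```

Returning `None` for an unknown rank lets the caller skip axis validation instead of raising a misleading out-of-bounds error.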

References
  1. APIs should be consistent with established conventions (NumPy) and handle common input types like lists gracefully. (link)

Comment thread keras/src/ops/nn.py Outdated
Comment on lines +1099 to +1108
    ndim = len(getattr(x, "shape", []))
    if not isinstance(axis, int):
        raise TypeError(
            f"Argument `axis` must be an integer. Received: axis={axis}"
        )
    if axis < -ndim or axis >= ndim:
        raise ValueError(
            f"axis {axis} is out of bounds for array of dimension {ndim}"
        )
    axis = axis if axis >= 0 else axis + ndim

high

This block has the same ndim detection issue as log_softmax. It will fail for Python lists and crash for tensors with unknown rank. Please convert x to a tensor and verify x.shape is not None before proceeding with validation.

Comment thread keras/src/ops/numpy.py Outdated
Comment on lines +3514 to +3529
    ndim = len(getattr(x, "shape", []))
    if isinstance(axis, int):
        canonicalize_axis(axis, ndim)
    elif isinstance(axis, (tuple, list)):
        for a in axis:
            if not isinstance(a, int):
                raise TypeError(
                    "Argument `axis` must be an integer or a sequence of "
                    f"integers. Received: axis={axis}"
                )
            canonicalize_axis(a, ndim)
    else:
        raise TypeError(
            "Argument `axis` must be an integer, a sequence of integers, "
            f"or `None`. Received: axis={axis}"
        )

high

The ndim detection here fails for Python lists and crashes for tensors with unknown rank. Additionally, this manual validation logic is redundant with backend_utils.canonicalize_axis. Consider converting x to a tensor first and using the utility function safely.

Comment thread keras/src/ops/numpy.py Outdated
Comment on lines +7032 to +7069
    ndim = len(getattr(x, "shape", []))
    if axis is None:
        if isinstance(shift, (tuple, list)):
            raise ValueError(
                "When `axis` is `None`, `shift` must be an integer. "
                f"Received: shift={shift}"
            )
    elif isinstance(axis, int):
        canonicalize_axis(axis, ndim)
        if isinstance(shift, (tuple, list)):
            raise ValueError(
                "When `axis` is an integer, `shift` must be an integer. "
                f"Received: shift={shift}"
            )
    elif isinstance(axis, (tuple, list)):
        for a in axis:
            if not isinstance(a, int):
                raise TypeError(
                    "Argument `axis` must be an integer or a sequence of "
                    f"integers. Received: axis={axis}"
                )
            canonicalize_axis(a, ndim)
        if not isinstance(shift, (tuple, list)) or len(shift) != len(axis):
            raise ValueError(
                "`shift` and `axis` must have the same size. "
                f"Received: shift={shift}, axis={axis}"
            )
        for s in shift:
            if not isinstance(s, int):
                raise TypeError(
                    "Argument `shift` must be an integer or a sequence of "
                    f"integers. Received: shift={shift}"
                )
    else:
        raise TypeError(
            "Argument `axis` must be an integer, a sequence of integers, "
            f"or `None`. Received: axis={axis}"
        )

high

There are two issues in this block:

  1. The ndim detection fails for lists and crashes on unknown rank tensors.
  2. The logic for shift validation when axis is a sequence (lines 7054-7058) is incorrect. It currently raises a ValueError if shift is a single integer but axis is a tuple/list. NumPy allows a single integer shift to be broadcast across all specified axes.

Suggested fix for the shift logic:

        if isinstance(shift, (tuple, list)):
            if len(shift) != len(axis):
                raise ValueError(
                    "`shift` and `axis` must have the same size. "
                    f"Received: shift={shift}, axis={axis}"
                )
            for s in shift:
                if not isinstance(s, int):
                    raise TypeError(
                        "Argument `shift` must be an integer or a sequence of "
                        f"integers. Received: shift={shift}"
                    )
        elif not isinstance(shift, int):
            raise TypeError(
                "Argument `shift` must be an integer or a sequence of "
                f"integers. Received: shift={shift}"
            )
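For illustration, the suggested logic can be wrapped in a small standalone function and exercised directly. `validate_roll_shift` is a hypothetical name, and returning a normalized tuple is an assumption added here for convenience, not part of the suggestion.

```python
def validate_roll_shift(shift, axis):
    """Check `shift` against a tuple/list `axis`; a single int broadcasts."""
    if isinstance(shift, (tuple, list)):
        if len(shift) != len(axis):
            raise ValueError(
                "`shift` and `axis` must have the same size. "
                f"Received: shift={shift}, axis={axis}"
            )
        for s in shift:
            if not isinstance(s, int):
                raise TypeError(
                    "Argument `shift` must be an integer or a sequence of "
                    f"integers. Received: shift={shift}"
                )
        return tuple(shift)
    if not isinstance(shift, int):
        raise TypeError(
            "Argument `shift` must be an integer or a sequence of "
            f"integers. Received: shift={shift}"
        )
    # A single integer applies to every listed axis.
    return (shift,) * len(axis)


print(validate_roll_shift(1, (0, 1)))
print(validate_roll_shift([1, 2], (0, 1)))
```

With this shape of check, an integer shift with a tuple axis is accepted, while a length mismatch or a non-integer shift still raises.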

Comment thread keras/src/ops/numpy.py Outdated
Comment on lines +7897 to +7904
    ndim = len(getattr(x, "shape", []))
    axis1 = canonicalize_axis(axis1, ndim)
    axis2 = canonicalize_axis(axis2, ndim)
    if axis1 == axis2:
        raise ValueError(
            f"axis1 and axis2 must be different. Received: axis1={axis1}, "
            f"axis2={axis2}"
        )

high

Same ndim detection issue as other ops: this will fail for Python lists and crash for tensors with unknown rank. Please ensure x is converted to a tensor and its shape is verified before calling canonicalize_axis.

Comment thread keras/src/ops/numpy.py Outdated
Comment on lines +7846 to +7853
    ndim = len(getattr(x, "shape", []))
    axis1 = canonicalize_axis(self.axis1, ndim)
    axis2 = canonicalize_axis(self.axis2, ndim)
    if axis1 == axis2:
        raise ValueError(
            f"axis1 and axis2 must be different. Received: "
            f"axis1={self.axis1}, axis2={self.axis2}"
        )

medium

In compute_output_spec, x is a KerasTensor. If the rank is unknown (x.shape is None), len(getattr(x, "shape", [])) will crash. Since trace requires at least 2 dimensions, you should handle the None shape case by returning a KerasTensor with unknown shape or raising a more descriptive error if rank must be known.
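One way to read the first option (propagate the unknown shape) is sketched below with plain tuples instead of KerasTensor objects. `trace_output_shape` is a hypothetical helper, and using `None` to mean "unknown rank" is an assumption of this sketch.

```python
def trace_output_shape(shape, axis1=0, axis2=1):
    """Static output shape of a trace-like op; None means unknown rank."""
    if shape is None:
        # Unknown rank in, unknown shape out: nothing to validate yet.
        return None
    ndim = len(shape)
    if ndim < 2:
        raise ValueError(
            f"trace requires at least 2 dimensions. Received: shape={shape}"
        )
    normalized = []
    for axis in (axis1, axis2):
        if not -ndim <= axis < ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for array of dimension {ndim}"
            )
        normalized.append(axis % ndim)
    if normalized[0] == normalized[1]:
        raise ValueError(
            "axis1 and axis2 must be different. Received: "
            f"axis1={axis1}, axis2={axis2}"
        )
    # The two traced axes are removed from the output shape.
    return tuple(d for i, d in enumerate(shape) if i not in normalized)


print(trace_output_shape(None))
print(trace_output_shape((None, 3, 3)))
```

Comparing the axes after normalization catches pairs such as `axis1=0, axis2=-2` on a rank-2 input, which refer to the same dimension.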

@codecov-commenter

codecov-commenter commented Mar 31, 2026

Codecov Report

❌ Patch coverage is 59.43396% with 86 lines in your changes missing coverage. Please review.
✅ Project coverage is 83.01%. Comparing base (a2e97e1) to head (94b9793).
⚠️ Report is 88 commits behind head on master.

Files with missing lines            Patch %   Lines
keras/src/ops/image.py              31.91%    16 Missing and 16 partials ⚠️
keras/src/ops/numpy.py              66.27%    16 Missing and 13 partials ⚠️
keras/src/ops/nn.py                 64.81%    8 Missing and 11 partials ⚠️
keras/src/utils/rng_utils.py        60.00%    2 Missing and 2 partials ⚠️
keras/src/ops/operation_utils.py    81.81%    2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22595      +/-   ##
==========================================
- Coverage   83.26%   83.01%   -0.26%     
==========================================
  Files         596      596              
  Lines       67828    69050    +1222     
  Branches    10562    10855     +293     
==========================================
+ Hits        56480    57322     +842     
- Misses       8605     8861     +256     
- Partials     2743     2867     +124     
Flag                Coverage Δ
keras               82.83% <59.43%> (-0.26%) ⬇️
keras-jax           59.01% <58.49%> (-0.81%) ⬇️
keras-numpy         54.85% <58.49%> (+0.41%) ⬆️
keras-openvino      59.32% <58.49%> (+7.62%) ⬆️
keras-tensorflow    60.57% <59.43%> (-0.57%) ⬇️
keras-torch         59.34% <58.49%> (-0.66%) ⬇️

@ChiragSW
Contributor Author

ChiragSW commented Apr 1, 2026

There are some crashes and bugs that need to be fixed. I will look into them soon.

@ChiragSW
Contributor Author

ChiragSW commented Apr 2, 2026

Summary of new changes

  • Updated many ops to stop using len(getattr(x, "shape", [])) for rank detection, since it breaks on Python lists (no .shape) and can crash when x.shape is None for unknown-rank symbolic tensors.
  • log_softmax, sparsemax, flip, roll, and trace now convert x to a tensor and canonicalize the axis only when the rank is actually known, so those failures do not happen.
  • Also fixed roll’s validation so a single integer shift is allowed when axis is a tuple or list.
  • Trace.compute_output_spec now handles unknown-rank symbolic inputs by raising an error instead of crashing.

@ChiragSW
Contributor Author

ChiragSW commented Apr 2, 2026

The changes resolve the issues. Please review @keerthanakadiri @hertschuh

Collaborator

@hertschuh hertschuh left a comment

I'm not following the context of this. It seems unrelated to the linked bug.

Are we missing some validation for the axis argument in some places?

If so, can you add tests with self.assertRaises to demonstrate the problem?

Also, there should never be any code in between if any_symbolic_tensors(...) and the call to symbolic_call. Some validation can happen before the if. But some validation has to happen in the backend-specific implementation after convert_to_tensor.

Comment thread keras/src/ops/nn.py Outdated
Comment on lines +1018 to +1023
    if isinstance(axis, int):
        if axis < -ndim or axis >= ndim:
            raise ValueError(
                f"axis {axis} is out of bounds for array of dimension {ndim}"
            )
        axis = axis if axis >= 0 else axis + ndim

Use canonicalize_axis for this.

@ChiragSW
Contributor Author

ChiragSW commented Apr 4, 2026

The problem is that Keras ops have two execution paths, symbolic and eager. Validation checks only lived in the eager path, so invalid inputs via keras.Input would silently pass through without any error.

Fix:

  • Run shared validation before the any_symbolic_tensors branch when rank is statically knowable
  • Keep validation after convert_to_tensor on the eager path
  • For purely symbolic cases, put validation inside compute_output_spec

Tests added:
Across nn_test.py, numpy_test.py, image_test.py, and operation_utils_test.py, the new tests assert that bad axis or rank inputs on keras.Input now correctly raise ValueError.

@ChiragSW
Contributor Author

ChiragSW commented Apr 9, 2026

The issue has been fixed. Please review @hertschuh

Collaborator

@hertschuh hertschuh left a comment

First, this PR is combining a few unrelated things into one. Please split this PR into at least 3 separate PRs:

  • the image padding / cropping validation
  • the axis verification / canonicalization
  • the RNG seed generator changes (and I don't understand the context of these changes)

Then, I see a lot of code duplication like this:

    if axis is not None:
        ndim = get_static_tensor_ndim(x)
        if isinstance(axis, int):
            if ndim is not None:
                canonicalize_axis(axis, ndim)
        elif isinstance(axis, (tuple, list)):
            for a in axis:
                if not isinstance(a, int):
                    raise TypeError(
                        "Argument `axis` must be an integer or a sequence "
                        f"of integers. Received: axis={axis}"
                    )
                if ndim is not None:
                    canonicalize_axis(a, ndim)
        else:
            raise TypeError(
                "Argument `axis` must be an integer, a sequence of "
                f"integers, or `None`. Received: axis={axis}"
            )

Please create a helper function called canonicalize_axes after canonicalize_axis. It will take an int or a list of ints and always return a tuple of ints after doing the validation.
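A minimal sketch of what such a helper might look like follows. The local `canonicalize_axis` is only a stand-in modeled on the existing Keras utility of the same name, and the `(axes, num_dims)` signature of `canonicalize_axes` is an assumption, since the request above does not spell one out.

```python
def canonicalize_axis(axis, num_dims):
    """Stand-in for the existing Keras utility of the same name."""
    if not -num_dims <= axis < num_dims:
        raise ValueError(
            f"axis {axis} is out of bounds for an array with {num_dims} "
            "dimensions."
        )
    return axis + num_dims if axis < 0 else axis


def canonicalize_axes(axes, num_dims):
    """Validate an int or a sequence of ints; return non-negative ints."""
    if isinstance(axes, int):
        axes = (axes,)
    elif not isinstance(axes, (tuple, list)):
        raise TypeError(
            "Argument `axis` must be an integer or a sequence of "
            f"integers. Received: axis={axes}"
        )
    for a in axes:
        if not isinstance(a, int):
            raise TypeError(
                "Argument `axis` must be an integer or a sequence of "
                f"integers. Received: axis={axes}"
            )
    # Always hand back a tuple, whatever form the caller passed in.
    return tuple(canonicalize_axis(a, num_dims) for a in axes)


print(canonicalize_axes(-1, 3))
print(canonicalize_axes([0, -1], 3))
```

With a helper of this shape, each op's validation could shrink to a single call, which addresses the duplication shown in the snippet above.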

Next, I see this pattern a lot:

    # Some validation
    if any_symbolic_tensors(...):
        ...
    x = convert_to_tensor(x)
    # The same validation

This is not a pattern that we should use:

  • It creates a ton of code duplication
  • We should let the backend implementation do the x = convert_to_tensor(x) part; sometimes the backend implementation needs to look at the type or value of x before converting it to a tensor

If you can't do the validation because you don't fully know the shape of the input, it means it shouldn't be here. It's ok to move the validation into the backend-specific implementations (as long as it's factored as a one-liner) and into symbolic_call.
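The layout described here might look roughly like the following framework-free sketch. All names (`Placeholder`, `check_axis`, `backend_flip`, `Flip`, `flip`) are hypothetical, and the real Keras classes are considerably more involved; the point is only where the shared one-line check sits.

```python
class Placeholder:
    """Hypothetical symbolic tensor: a static shape (or None) and no data."""

    def __init__(self, shape):
        self.shape = shape


def check_axis(axis, ndim):
    # Shared validator, so each call site stays a one-liner.
    if not isinstance(axis, int) or not -ndim <= axis < ndim:
        raise ValueError(f"Invalid axis={axis} for a rank-{ndim} input.")


def backend_flip(x, axis):
    # Stand-in for a backend implementation: it converts the input itself,
    # then validates against the now-known rank.
    x = list(x)
    check_axis(axis, 1)
    return x[::-1]


class Flip:
    def __init__(self, axis):
        self.axis = axis

    def symbolic_call(self, x):
        # Validate only when the rank is statically known.
        if x.shape is not None:
            check_axis(self.axis, len(x.shape))
        return Placeholder(x.shape)


def flip(x, axis=0):
    # Nothing sits between the symbolic check and symbolic_call.
    if isinstance(x, Placeholder):
        return Flip(axis).symbolic_call(x)
    return backend_flip(x, axis)


print(flip([1, 2, 3]))
```

Here the public wrapper holds no validation of its own, so there is nothing to duplicate around the conversion step.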

Comment thread keras/src/ops/image.py
Comment on lines +13 to +14
    and hasattr(images, "_keras_history")
    and images._keras_history.operation.__class__.__name__ == "InputLayer"

I don't understand why you are looking at the keras history, you should never have to do that. Also, you shouldn't even care if it's a keras tensor or a normal tensor.

Comment on lines +10 to 11
import keras
from keras.src import backend

Do not import keras in unit tests; please import the feature from keras.src.

@ChiragSW
Contributor Author

ChiragSW commented Apr 16, 2026

For now I will close this PR and put up 3 separate PRs. I will add details of what each PR fixes. Should I go ahead with this process, @hertschuh?
The 3 PRs will be:

  1. image padding / cropping validation
  2. axis verification / canonicalization
  3. RNG seed generator changes

@hertschuh
Collaborator

Yes, thank you!
