Remove backend-specific strides + dilation_rate restriction from DepthwiseConv and SeparableConv #22598
Code Review
This pull request introduces validation in depthwise_conv and separable_conv to prevent the simultaneous use of strides and dilation rates greater than one, which is not supported. The implementation raises a ValueError when these conditions are met, and the tests have been updated to assert this behavior. Feedback focuses on improving maintainability by extracting the duplicated max-value calculation logic into a helper function and ensuring that the type-checking in tests is consistent with the implementation by supporting both lists and tuples.
Codecov Report
✅ All modified and coverable lines are covered by tests.
@@ Coverage Diff @@
## master #22598 +/- ##
==========================================
- Coverage 82.95% 82.94% -0.01%
==========================================
Files 596 596
Lines 69200 69196 -4
Branches 10806 10804 -2
==========================================
- Hits 57402 57398 -4
Misses 8969 8969
Partials 2829 2829
hertschuh left a comment
What it looks like to me is that it was working on all backends except TensorFlow. So I would argue that we should do the reverse, i.e. remove the check on BaseDepthwiseConv and BaseSeparableConv and accept the fact that it doesn't work on TensorFlow. I think this check was inherited from Keras 2, which was TensorFlow only.
@hertschuh Good point! I have reversed the approach and removed the global validation from depthwise_conv and separable_conv. The TF-specific skip is restored in the correctness tests, and the shape tests are back to asserting the correct output shapes for backends that support the combination.
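For readers following along, the condition behind a TF-specific skip can be sketched roughly as below. This is only an illustration, not the code in this PR: the helper name `should_skip_correctness_check` is hypothetical, and the backend name is passed in as a plain string (in a real test it would presumably come from `keras.backend.backend()`).

```python
def should_skip_correctness_check(backend_name, strides, dilation_rate):
    """Hypothetical helper: True only for the combination TensorFlow rejects.

    `strides` and `dilation_rate` may each be an int or a list/tuple of ints.
    """

    def _max(value):
        # Reduce an int or a sequence of ints to its largest entry.
        return max(value) if isinstance(value, (list, tuple)) else value

    return (
        backend_name == "tensorflow"
        and _max(strides) > 1
        and _max(dilation_rate) > 1
    )
```

With this shape of check, other backends still run the correctness comparison for strides > 1 combined with dilation_rate > 1, and TensorFlow is skipped only for that one combination.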
This PR is now empty.
Forgot to push the changes. Will do that soon.
I am seeing a failure from OpenVINO that is not related to the change in this PR.
hertschuh left a comment
A couple of very small tweaks:
Also, can you rebase?
…le_conv ops The low-level keras.ops.depthwise_conv and keras.ops.separable_conv functions lacked the same strides > 1 + dilation_rate > 1 validation that already exists in the high-level DepthwiseConv2D and SeparableConv2D layers. This caused the symbolic (static) shape inference to return incorrect shapes when both strides > 1 and dilation_rate > 1 are used together, since the formula-based shape computation disagrees with what TF's depthwise/separable conv ops actually produce. Also added the validation to both ops and updated the tests to assert a ValueError instead of silently skipping the unsupported combination on the TF backend.
…s+dilation The static and dynamic shape tests were asserting on output shapes for strides=2 + dilation_rate=2 combinations, which are now correctly rejected with a ValueError. Also extract _get_seq_max helper to avoid duplicated logic in depthwise_conv and separable_conv, and use (list, tuple) consistently in the test isinstance checks.
As per reviewer feedback, strides+dilation works on all backends except TensorFlow. Remove the ValueError validation added to depthwise_conv() and separable_conv() in nn.py, restore the original shape assertions in static/dynamic shape tests, and restore the TF-specific skip in the correctness tests.
…parableConv as it only applies to TensorFlow
83893d1 to e39ebe8
hertschuh left a comment
Can you update the PR description and title? It's no longer what this is.
A couple of things don't conform to the code format:
It is done. Please check.
The strides > 1 + dilation_rate > 1 combination is not universally unsupported. It only fails on the TensorFlow backend. The previous validation in BaseDepthwiseConv.__init__ and BaseSeparableConv.__init__ was raising a ValueError for all backends, which was too strict. This PR removes that blanket restriction and its associated docstring notes and tests, allowing JAX, PyTorch, and other backends to use the combination freely. Also adds compute_output_shape validation in BaseSeparableConv.build() so invalid output dimensions are still caught at build time regardless of backend.
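As a rough illustration of what catching invalid output dimensions at build time can mean, here is a minimal sketch of the standard per-dimension output length formula for a strided, dilated convolution. It is not the Keras implementation: the function name `conv_output_length` and its error messages are assumptions made for this example, and only "valid" and "same" padding are covered.

```python
import math


def conv_output_length(
    input_length, kernel_size, strides=1, dilation_rate=1, padding="valid"
):
    """Illustrative output length along one spatial dimension.

    Raises ValueError when the result would not be a positive size.
    """
    if padding == "same":
        # With "same" padding the length depends only on the stride.
        output_length = math.ceil(input_length / strides)
    elif padding == "valid":
        # Dilation spreads the kernel taps over a wider window.
        effective_kernel = (kernel_size - 1) * dilation_rate + 1
        output_length = (input_length - effective_kernel) // strides + 1
    else:
        raise ValueError(f"Unsupported padding: {padding!r}")

    if output_length <= 0:
        raise ValueError(
            "Computed output length is not positive: "
            f"input_length={input_length}, kernel_size={kernel_size}, "
            f"strides={strides}, dilation_rate={dilation_rate}"
        )
    return output_length
```

For example, an input of length 10 with kernel_size=3, strides=2, dilation_rate=2 and "valid" padding has an effective kernel of 5 and an output length of 3, while an input of length 4 with the same settings is rejected. The arithmetic involves nothing backend-specific, so a check of this kind applies equally to every backend.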