We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
2 parents 9f97689 + 5ef2a2f commit c303adcCopy full SHA for c303adc
1 file changed
array_api_strict/_elementwise_functions.py
@@ -352,5 +352,7 @@ def sign(x: Array, /) -> Array:
352
raise TypeError("Only numeric dtypes are allowed in sign")
353
# Special treatment to work around non-compliant NumPy 1.x behaviour
354
if x.dtype in _complex_floating_dtypes:
355
- return x/abs(x)
+ _x = x._array
356
+ _result = _x / np.abs(np.where(_x != 0, _x, np.asarray(1.0, dtype=_x.dtype)))
357
+ return Array._new(_result, device=x.device)
358
return Array._new(np.sign(x._array), device=x.device)
0 commit comments