Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, nan_to_num, one_hot, pad
from ._delegation import (
argpartition,
isclose,
nan_to_num,
one_hot,
pad,
partition,
)
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand All @@ -23,6 +30,7 @@
__all__ = [
"__version__",
"apply_where",
"argpartition",
"at",
"atleast_nd",
"broadcast_shapes",
Expand All @@ -37,6 +45,7 @@
"nunique",
"one_hot",
"pad",
"partition",
"setdiff1d",
"sinc",
]
107 changes: 107 additions & 0 deletions src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,3 +326,110 @@ def pad(
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]

return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp)


def partition(
a: Array,
kth: int,
*,
xp: ModuleType | None = None,
Comment thread
lucascolley marked this conversation as resolved.
) -> Array:
"""
Return a partitioned copy of an array.
Comment thread
cakedev0 marked this conversation as resolved.

Parameters
----------
a : 1-dimensional array
Comment thread
cakedev0 marked this conversation as resolved.
Outdated
Input array.
kth : int
Element index to partition by.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
partitioned_array
Array of the same type and shape as a.
Comment thread
cakedev0 marked this conversation as resolved.
Outdated
"""
# Validate inputs.
if xp is None:
xp = array_namespace(a)
if a.ndim != 1:
msg = "only 1-dimensional arrays are currently supported"
raise NotImplementedError(msg)

# Delegate where possible.
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
return xp.partition(a, kth)
if is_jax_namespace(xp):
from jax import numpy

return numpy.partition(a, kth)
Comment thread
cakedev0 marked this conversation as resolved.
Outdated

# Use top-k when possible:
if is_torch_namespace(xp):
from torch import topk

a_left, indices_left = topk(a, kth, largest=False, sorted=False)
Comment thread
cakedev0 marked this conversation as resolved.
Outdated
mask_right = xp.ones(a.shape, dtype=bool)
mask_right[indices_left] = False
return xp.concat((a_left, a[mask_right]))
# Note: dask topk/argtopk sort the return values, so it's
# not much more efficient than sorting everything when
# kth is not small compared to x.size

return _funcs.partition(a, kth, xp=xp)
Comment thread
cakedev0 marked this conversation as resolved.
Outdated


def argpartition(
a: Array,
kth: int,
*,
xp: ModuleType | None = None,
) -> Array:
"""
Perform an indirect partition along the given axis.

Parameters
----------
a : 1-dimensional array
Input array.
kth : int
Element index to partition by.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
index_array
Array of indices that partition `a` along the specified axis.
"""
# Validate inputs.
if xp is None:
xp = array_namespace(a)
if a.ndim != 1:
msg = "only 1-dimensional arrays are currently supported"
raise NotImplementedError(msg)
Comment thread
cakedev0 marked this conversation as resolved.

# Delegate where possible.
if is_numpy_namespace(xp) or is_cupy_namespace(xp):
return xp.argpartition(a, kth)
if is_jax_namespace(xp):
from jax import numpy

return numpy.argpartition(a, kth)

# Use top-k when possible:
if is_torch_namespace(xp):
from torch import topk
Comment thread
cakedev0 marked this conversation as resolved.
Outdated

_, indices = topk(a, kth, largest=False, sorted=False)
mask = xp.ones(a.shape, dtype=bool)
mask[indices] = False
indices_above = xp.arange(a.shape[0])[mask]
return xp.concat((indices, indices_above))
# Note: dask topk/argtopk sort the return values, so it's
# not much more efficient than sorting everything when
# kth is not small compared to x.size

return _funcs.argpartition(a, kth, xp=xp)
Comment thread
cakedev0 marked this conversation as resolved.
Outdated
20 changes: 20 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,3 +1029,23 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
)
return xp.sin(y) / y


def partition( # numpydoc ignore=PR01,RT01
x: Array,
kth: int, # noqa: ARG001
*,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.sort(x, stable=False)


def argpartition( # numpydoc ignore=PR01,RT01
x: Array,
kth: int, # noqa: ARG001
*,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""
return xp.argsort(x, stable=False)
17 changes: 17 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from array_api_extra import (
apply_where,
argpartition,
at,
atleast_nd,
broadcast_shapes,
Expand All @@ -25,6 +26,7 @@
nunique,
one_hot,
pad,
partition,
setdiff1d,
sinc,
)
Expand Down Expand Up @@ -1298,3 +1300,18 @@ def test_device(self, xp: ModuleType, device: Device):

def test_xp(self, xp: ModuleType):
xp_assert_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))


class TestPartition:
def test_basic(self, xp: ModuleType):
# Using 0-dimensional array
rng = np.random.default_rng(2847)

for _ in range(100):
n = rng.integers(1, 1000)
x = xp.asarray(rng.random(size=n))
k = int(rng.integers(1, n - 1))
y = partition(x, k)
assert xp.max(y[:k]) <= xp.min(y[k:])
y = x[argpartition(x, k)]
assert xp.max(y[:k]) <= xp.min(y[k:])
Loading