We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 21bc5fa commit 530e856Copy full SHA for 530e856
1 file changed
tests/test_torch.py
@@ -161,3 +161,18 @@ def test_round():
161
r = xp.round(x, decimals=1, out=o)
162
assert xp.all(r == o)
163
assert r is o
164
+
165
166
+def test_dynamo_array_namespace():
167
+ """Check that torch.compiling array_namespace does not incur graph breaks."""
168
+ from array_api_compat import array_namespace
169
170
+ def foo(x):
171
+ xp = array_namespace(x)
172
+ return xp.multiply(x, x)
173
174
+ bar = torch.compile(fullgraph=True)(foo)
175
176
+ x = torch.arange(3)
177
+ y = bar(x)
178
+ assert xp.all(y == x**2)
0 commit comments