Skip to content

Commit df58e70

Browse files
committed
jax eig now has GPU support
1 parent c6187c1 commit df58e70

File tree

1 file changed

+0
-8
lines changed

1 file changed

+0
-8
lines changed

keras/src/ops/linalg_test.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,6 @@ def test_det(self):
350350
def test_eig(self):
351351
x = np.random.rand(2, 3, 3)
352352
x = x @ x.transpose((0, 2, 1))
353-
if backend.backend() == "jax":
354-
import jax
355-
356-
if jax.default_backend() == "gpu":
357-
# eig not implemented for jax on gpu backend
358-
with self.assertRaises(NotImplementedError):
359-
linalg.eig(x)
360-
return
361353
w, v = map(ops.convert_to_numpy, linalg.eig(x))
362354
x_reconstructed = (v * w[..., None, :]) @ v.transpose((0, 2, 1))
363355
self.assertAllClose(x_reconstructed, x, atol=1e-4)

0 commit comments

Comments
 (0)