Skip to content

Commit cf08b92

Browse files
authored
Revert "Add sparse support to ops.ones_like and ops.zeros_like." (#21302)
Arguably, the result of `ones_like` and `zeros_like` should not be sparse.
1 parent dba71da commit cf08b92

File tree

4 files changed

+6
-16
lines changed

4 files changed

+6
-16
lines changed

keras/src/backend/jax/numpy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -920,12 +920,10 @@ def not_equal(x1, x2):
920920
return jnp.not_equal(x1, x2)
921921

922922

923-
@sparse.elementwise_unary(linear=False)
924923
def ones_like(x, dtype=None):
925924
return jnp.ones_like(x, dtype=dtype)
926925

927926

928-
@sparse.elementwise_unary(linear=True)
929927
def zeros_like(x, dtype=None):
930928
return jnp.zeros_like(x, dtype=dtype)
931929

keras/src/backend/tensorflow/numpy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1840,12 +1840,10 @@ def not_equal(x1, x2):
18401840
return tf.not_equal(x1, x2)
18411841

18421842

1843-
@sparse.elementwise_unary
18441843
def ones_like(x, dtype=None):
18451844
return tf.ones_like(x, dtype=dtype)
18461845

18471846

1848-
@sparse.elementwise_unary
18491847
def zeros_like(x, dtype=None):
18501848
return tf.zeros_like(x, dtype=dtype)
18511849

keras/src/ops/numpy.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4516,8 +4516,7 @@ def call(self, x, dtype=None):
45164516
def compute_output_spec(self, x, dtype=None):
45174517
if dtype is None:
45184518
dtype = x.dtype
4519-
sparse = getattr(x, "sparse", False)
4520-
return KerasTensor(x.shape, dtype=dtype, sparse=sparse)
4519+
return KerasTensor(x.shape, dtype=dtype)
45214520

45224521

45234522
@keras_export(["keras.ops.ones_like", "keras.ops.numpy.ones_like"])
@@ -4543,8 +4542,7 @@ def call(self, x, dtype=None):
45434542
def compute_output_spec(self, x, dtype=None):
45444543
if dtype is None:
45454544
dtype = x.dtype
4546-
sparse = getattr(x, "sparse", False)
4547-
return KerasTensor(x.shape, dtype=dtype, sparse=sparse)
4545+
return KerasTensor(x.shape, dtype=dtype)
45484546

45494547

45504548
@keras_export(

keras/src/ops/numpy_test.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5141,7 +5141,6 @@ class SparseTest(testing.TestCase):
51415141
"imag",
51425142
"log1p",
51435143
"negative",
5144-
"ones_like",
51455144
"real",
51465145
"round",
51475146
"sign",
@@ -5151,7 +5150,6 @@ class SparseTest(testing.TestCase):
51515150
"square",
51525151
"tan",
51535152
"tanh",
5154-
"zeros_like",
51555153
]
51565154
ELEMENTWISE_UNARY_OPS_TESTS = [
51575155
{
@@ -5333,11 +5331,10 @@ def test_elementwise_unary_sparse_correctness(
53335331
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
53345332
x = create_sparse_tensor(x)
53355333
x_np = backend.convert_to_numpy(x)
5336-
expected = np_op(x_np) * backend.convert_to_numpy(knp.ones_like(x))
53375334

5338-
self.assertAllClose(op_function(x), expected)
5335+
self.assertAllClose(op_function(x), np_op(x_np))
53395336
self.assertSameSparseness(op_function(x), x)
5340-
self.assertAllClose(op_class()(x), expected)
5337+
self.assertAllClose(op_class()(x), np_op(x_np))
53415338
self.assertSameSparseness(op_class()(x), x)
53425339

53435340
@parameterized.named_parameters(ELEMENTWISE_UNARY_OPS_TESTS)
@@ -5350,11 +5347,10 @@ def test_elementwise_unary_indexed_slices_correctness(
53505347
x = np.array([[1, 0.5, -0.7], [0.9, 0.2, -1]])
53515348
x = create_indexed_slices(x)
53525349
x_np = backend.convert_to_numpy(x)
5353-
expected = np_op(x_np) * backend.convert_to_numpy(knp.ones_like(x))
53545350

5355-
self.assertAllClose(op_function(x), expected)
5351+
self.assertAllClose(op_function(x), np_op(x_np))
53565352
self.assertSameSparseness(op_function(x), x)
5357-
self.assertAllClose(op_class()(x), expected)
5353+
self.assertAllClose(op_class()(x), np_op(x_np))
53585354
self.assertSameSparseness(op_class()(x), x)
53595355

53605356
@parameterized.named_parameters(OTHER_UNARY_OPS_TESTS)

0 commit comments

Comments
 (0)