Skip to content

Commit 3d896d9

Browse files
committed
change broadcasting behaviour
1 parent f4de2fd commit 3d896d9

File tree

16 files changed

+532
-263
lines changed

16 files changed

+532
-263
lines changed

pytensor/sparse/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ class DenseFromSparse(Op):
954954
955955
"""
956956

957-
__props__ = ()
957+
__props__ = ("sparse_grad",)
958958

959959
def __init__(self, structured=True):
960960
self.sparse_grad = structured

pytensor/sparse/rewriting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,13 +1099,13 @@ def c_code_cache_version(self):
10991099
csm_grad_c = CSMGradC()
11001100

11011101

1102-
@node_rewriter([csm_grad(None)])
1102+
@node_rewriter([csm_grad()])
11031103
def local_csm_grad_c(fgraph, node):
11041104
"""
11051105
csm_grad(None) -> csm_grad_c
11061106
11071107
"""
1108-
if node.op == csm_grad(None):
1108+
if node.op == csm_grad():
11091109
return [csm_grad_c(*node.inputs)]
11101110
return False
11111111

pytensor/sparse/type.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,10 @@ def __init__(
7474
):
7575
if shape is None and broadcastable is None:
7676
shape = (None, None)
77-
77+
if broadcastable is None:
78+
broadcastable = (False, False)
79+
if broadcastable != (False, False):
80+
raise ValueError("Broadcasting sparse types is not yet implemented")
7881
if format not in self.format_cls:
7982
raise ValueError(
8083
f'unsupported format "{format}" not in list',
@@ -96,7 +99,9 @@ def clone(
9699
dtype = self.dtype
97100
if shape is None:
98101
shape = self.shape
99-
return type(self)(format, dtype, shape=shape, **kwargs)
102+
return type(self)(
103+
format, dtype, shape=shape, broadcastable=broadcastable, **kwargs
104+
)
100105

101106
def filter(self, value, strict=False, allow_downcast=None):
102107
if isinstance(value, Variable):

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def get_scalar_constant_value(
396396
for i in v.owner.inputs
397397
]
398398
ret = [[None]]
399-
v.owner.op.perform(v.owner, const, ret)
399+
v.owner.op.scalar_op.perform(v.owner, const, ret)
400400
return np.asarray(ret[0][0].copy())
401401
elif (
402402
isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor)

0 commit comments

Comments
 (0)