Skip to content

Commit 4bbc58f

Browse files
committed
change broadcasting behaviour
1 parent d7ffee8 commit 4bbc58f

File tree

16 files changed

+532
-252
lines changed

16 files changed

+532
-252
lines changed

pytensor/sparse/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -954,7 +954,7 @@ class DenseFromSparse(Op):
954954
955955
"""
956956

957-
__props__ = ()
957+
__props__ = ("sparse_grad",)
958958

959959
def __init__(self, structured=True):
960960
self.sparse_grad = structured

pytensor/sparse/rewriting.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1099,13 +1099,13 @@ def c_code_cache_version(self):
10991099
csm_grad_c = CSMGradC()
11001100

11011101

1102-
@node_rewriter([csm_grad(None)])
1102+
@node_rewriter([csm_grad()])
11031103
def local_csm_grad_c(fgraph, node):
11041104
"""
11051105
csm_grad(None) -> csm_grad_c
11061106
11071107
"""
1108-
if node.op == csm_grad(None):
1108+
if node.op == csm_grad():
11091109
return [csm_grad_c(*node.inputs)]
11101110
return False
11111111

pytensor/sparse/type.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@ def __init__(
7373
):
7474
if shape is None and broadcastable is None:
7575
shape = (None, None)
76-
76+
if broadcastable is None:
77+
broadcastable = (False, False)
78+
if broadcastable != (False, False):
79+
raise ValueError("Broadcasting sparse types is not yet implemented")
7780
if format not in self.format_cls:
7881
raise ValueError(
7982
f'unsupported format "{format}" not in list',
@@ -95,7 +98,9 @@ def clone(
9598
dtype = self.dtype
9699
if shape is None:
97100
shape = self.shape
98-
return type(self)(format, dtype, shape=shape, **kwargs)
101+
return type(self)(
102+
format, dtype, shape=shape, broadcastable=broadcastable, **kwargs
103+
)
99104

100105
def filter(self, value, strict=False, allow_downcast=None):
101106
if isinstance(value, Variable):

pytensor/tensor/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def get_scalar_constant_value(
396396
for i in v.owner.inputs
397397
]
398398
ret = [[None]]
399-
v.owner.op.perform(v.owner, const, ret)
399+
v.owner.op.scalar_op.perform(v.owner, const, ret)
400400
return np.asarray(ret[0][0].copy())
401401
elif (
402402
isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor)

0 commit comments

Comments
 (0)