import collections
from collections.abc import Sequence
- from functools import reduce
+ from functools import partial, reduce
from itertools import pairwise
+ from typing import cast

- from numpy.core.numeric import normalize_axis_index  # type: ignore
+ import numpy as np
+ from numpy.core.numeric import (  # type: ignore
+     normalize_axis_index,
+     normalize_axis_tuple,
+ )

from pytensor.compile.builders import OpFromGraph
+ from pytensor.tensor import vectorize
from pytensor.tensor.basic import (
    arange,
    get_vector_length,
@@ -54,6 +60,62 @@ def _removechars(s, chars):
    return s.translate(str.maketrans(dict.fromkeys(chars)))


+ def _batched_tensordot(
+     vars: tuple[TensorVariable, TensorVariable],
+     axes: Sequence[Sequence[int]],  # Should be length 2
+     batch_axes: Sequence[Sequence[int]],  # Should be length 2
+ ) -> TensorVariable:
+     # Shortcut for the non-batched case
+     if not batch_axes[0] and not batch_axes[1]:
+         return tensordot(*vars, axes=axes)
+
+     # Normalize axes; thankfully the numpy helper does not sort the axes!
+     axes = [
+         normalize_axis_tuple(var_axes, var.ndim) for var, var_axes in zip(vars, axes)
+     ]
+     batch_axes = [
+         normalize_axis_tuple(var_axes, var.ndim)
+         for var, var_axes in zip(vars, batch_axes)
+     ]
+     n_batch_axes = [len(var_batch_axes) for var_batch_axes in batch_axes]
+     if any(
+         var_batch_axes != tuple(range(var_n_batch_axes))
+         for var_batch_axes, var_n_batch_axes in zip(batch_axes, n_batch_axes)
+     ):
+         # Would need to transpose/expand_dims to align batch dims on the left and then transpose back
+         raise NotImplementedError(
+             f"Arbitrary batch dims location not yet supported, got: {batch_axes}"
+         )
+
+     lhs, rhs = vars
+     lhs_axes, rhs_axes = axes
+     lhs_n_batch_axes, rhs_n_batch_axes = n_batch_axes
+
+     # Create the gufunc signature of the core tensordot
+     lhs_signature = [f"l{i}" for i in range(lhs.type.ndim)]
+     rhs_signature = [f"r{i}" for i in range(rhs.type.ndim)]
+     # Aligned (contracted) axes get the same dimension name
+     for i, (lhs_axis, rhs_axis) in enumerate(zip(lhs_axes, rhs_axes)):
+         lhs_signature[lhs_axis] = rhs_signature[rhs_axis] = f"a{i}"
+     # Trim away the batch ndims
+     lhs_signature = lhs_signature[lhs_n_batch_axes:]
+     rhs_signature = rhs_signature[rhs_n_batch_axes:]
+     out_signature = [
+         lhs_dim for lhs_dim in lhs_signature if not lhs_dim.startswith("a")
+     ] + [rhs_dim for rhs_dim in rhs_signature if not rhs_dim.startswith("a")]
+     signature = f"({','.join(lhs_signature)}),({','.join(rhs_signature)})->({','.join(out_signature)})"
+     # Adjust the contracted axes for the core (batch-free) case
+     core_lhs_axes = tuple(np.array(lhs_axes) - lhs_n_batch_axes)
+     core_rhs_axes = tuple(np.array(rhs_axes) - rhs_n_batch_axes)
+
+     # TODO: Make sure this looks reasonable after optimizations.
+     # Right now we have some Blockwise(Reshape) that will slow things down!
+     out = vectorize(
+         partial(tensordot, axes=[core_lhs_axes, core_rhs_axes]), signature=signature
+     )(lhs, rhs)
+     return cast(TensorVariable, out)
+
+
def einsum(subscripts: str, *operands):
    """
    Multiplication and summation of tensors using the Einstein summation convention.
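To make the signature construction in `_batched_tensordot` concrete, here is an illustrative trace (not part of the diff) for a batched matrix multiplication; the shapes and axis choices are assumptions picked for the example.

```python
# Illustrative only: what _batched_tensordot builds for lhs of shape (batch, m, k)
# contracted with rhs of shape (batch, k, n), i.e. axes=((2,), (1,)), batch_axes=((0,), (0,)).
lhs_signature = ["l0", "l1", "l2"]
rhs_signature = ["r0", "r1", "r2"]
lhs_signature[2] = rhs_signature[1] = "a0"  # the contracted pair shares a name
lhs_signature = lhs_signature[1:]  # trim the single batch dim -> ["l1", "a0"]
rhs_signature = rhs_signature[1:]  #                           -> ["a0", "r2"]
# gufunc signature: "(l1,a0),(a0,r2)->(l1,r2)"
# Core axes shift down by the batch ndim: core_lhs_axes=(1,), core_rhs_axes=(0,),
# so the vectorized core is tensordot(lhs_core, rhs_core, axes=[(1,), (0,)]).
```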
@@ -199,8 +261,6 @@ def sum_repeats(
            lhs_batch, rhs_batch = tuple(
                zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names])
            )
-             if lhs_batch or rhs_batch:
-                 raise NotImplementedError("Batch dimensions are not yet supported")
        else:
            lhs_batch = rhs_batch = ()

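For reference, a minimal sketch (subscripts chosen purely for illustration) of how those lines derive the batch axis positions from the operand label strings:

```python
# Hypothetical inputs: einsum("bij,bjk->bik", ...) yields these name strings.
lhs_names, rhs_names, batch_names = "bij", "bjk", "b"
lhs_batch, rhs_batch = tuple(
    zip(*[(lhs_names.find(n), rhs_names.find(n)) for n in batch_names])
)
print(lhs_batch, rhs_batch)  # (0,) (0,) -- the batch label sits at axis 0 of both operands
```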
@@ -226,10 +286,16 @@ def sum_repeats(
            # needing a transpose.
            names = batch_names_str + remaining_rhs_names + remaining_lhs_names
            if names == result_names:
-                 operand = tensordot(rhs, lhs, (rhs_cont, lhs_cont))
+                 operand = _batched_tensordot(
+                     (rhs, lhs), (rhs_cont, lhs_cont), (rhs_batch, lhs_batch)
+                 )
            else:
                names = batch_names_str + remaining_lhs_names + remaining_rhs_names
-                 operand = tensordot(lhs, rhs, axes=(lhs_cont, rhs_cont))
+                 operand = _batched_tensordot(
+                     (lhs, rhs),
+                     axes=(lhs_cont, rhs_cont),
+                     batch_axes=(lhs_batch, rhs_batch),
+                 )

            # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
            assert len(names) == len(result_names) == len(set(names))
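End to end, the change lets contractions with batch dimensions route through `_batched_tensordot` instead of raising. A hypothetical usage sketch, assuming einsum is exposed under `pytensor.tensor` (shapes are made up):

```python
import pytensor.tensor as pt

x = pt.tensor3("x")  # e.g. (batch, m, k)
y = pt.tensor3("y")  # e.g. (batch, k, n)
# Previously this path raised "Batch dimensions are not yet supported".
out = pt.einsum("bij,bjk->bik", x, y)
```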