@@ -1,4 +1,5 @@
 import collections
+import itertools
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise
@@ -9,6 +10,8 @@
     normalize_axis_index,
     normalize_axis_tuple,
 )
+from opt_einsum.helpers import find_contraction
+from opt_einsum.parser import parse_einsum_input

 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor import TensorLike, vectorize
@@ -32,14 +35,15 @@ class Einsum(OpFromGraph):
     Wrapper Op for Einsum graphs
     """

-    __props__ = ("subscripts", "optimize")
+    __props__ = ("subscripts", "path", "optimized")

     def __init__(
-        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
+        self, *args, subscripts: str, path: str, optimized: bool, **kwargs
     ):
         self.subscripts = subscripts
-        self.optimize = optimize
-        super().__init__(*args, **kwargs)
+        self.path = path
+        self.optimized = optimized
+        super().__init__(*args, **kwargs, strict=True)


 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
@@ -142,6 +146,50 @@ def _general_dot(
     return cast(TensorVariable, out)


+PATH = tuple[tuple[int] | tuple[int, int], ...]
+
+def contraction_list_from_path(subscripts: str, operands: Sequence[TensorLike], path: PATH):
+    """Build a contraction list for ``subscripts`` from a pre-computed contraction ``path``.
+
+    Code adapted from opt_einsum.
+    """
+    fake_operands = [np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands]
+    input_subscripts, output_subscript, operands = parse_einsum_input((subscripts, *fake_operands))
+
+    # Build a few useful lists and sets
+    input_list = input_subscripts.split(",")
+    input_sets = [set(x) for x in input_list]
+    output_set = set(output_subscript)
+
+    # Build contraction tuples (positions, gemm, einsum_str, remaining)
+    contraction_list = []
+    for cnum, contract_inds in enumerate(path):
+        # Make sure we remove inds from right to left
+        contract_inds = tuple(sorted(contract_inds, reverse=True))
+
+        contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+        out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+        tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+        # Last contraction
+        if (cnum - len(path)) == -1:
+            idx_result = output_subscript
+        else:
+            # Use tensordot order to minimize transpositions
+            all_input_inds = "".join(tmp_inputs)
+            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
+
+        input_list.append(idx_result)
+        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+        # We only need the first three entries to build the forward graph
+        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction_list.append(contraction)
+
+    return contraction_list
+
+
 def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
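A quick sketch of how this helper is meant to be used (assuming it stays importable from `pytensor.tensor.einsum`; the printed entries follow the opt_einsum contraction-tuple convention noted in the comments above):

    import pytensor.tensor as pt
    from pytensor.tensor.einsum import contraction_list_from_path

    x = pt.tensor("x", shape=(None, 3))
    y = pt.tensor("y", shape=(3, None))
    z = pt.tensor("z", shape=(None, 5))

    # Naive left-to-right path: contract (x, y) first, then the result with z
    clist = contraction_list_from_path("ij,jk,kl->il", [x, y, z], [(1, 0), (1, 0)])
    for contract_inds, idx_removed, einsum_str, _, _ in clist:
        print(contract_inds, idx_removed, einsum_str)

Each step records which operand positions were popped off the stack, which indices were summed out, and the einsum string for that pairwise contraction.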
@@ -168,18 +216,33 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     # TODO: Do we need this as dependency?
     from opt_einsum import contract_path

-    operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+    operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in operands]

-    # TODE: Do fast path at creation time, and optimize only in fast_run
-    _, contraction_list = contract_path(
-        subscripts,
-        *shapes,
-        einsum_call=True,
-        use_blas=True,
-        optimize="optimal",
-        shapes=True,
-    )
+    if None in itertools.chain.from_iterable(shapes):
+        # We mark optimized = False, even in cases where there is no ordering optimization to be done,
+        # because the inner graph may have to accommodate dynamic shapes.
+        # If those shapes become known later we will likely want to rebuild the Op (unless we inline it).
+        if len(operands) == 1:
+            path = [(0,)]
+        else:
+            # Create a default path of repeated (1, 0) that executes left to right cyclically,
+            # with intermediate outputs being pushed to the end of the stack.
+            # We use (1, 0) and not (0, 1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often.
+            path = [(1, 0) for i in range(len(operands) - 1)]
+        contraction_list = contraction_list_from_path(subscripts, operands, path)
+        optimized = False
+    else:
+        _, contraction_list = contract_path(
+            subscripts,
+            *shapes,
+            einsum_call=True,
+            use_blas=True,
+            optimize="optimal",
+            shapes=True,
+        )
+        path = [contraction[0] for contraction in contraction_list]
+        optimized = True

     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
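As an illustration of the branch above, here is a hedged sketch of the intended behavior (assuming `einsum` is re-exported as `pt.einsum` and that the result is the output of a single `Einsum` node):

    import pytensor.tensor as pt

    # All shapes static: contract_path computes the "optimal" ordering up front
    a = pt.tensor("a", shape=(2, 3))
    b = pt.tensor("b", shape=(3, 4))
    c = pt.tensor("c", shape=(4, 5))
    out = pt.einsum("ij,jk,kl->il", a, b, c)
    assert out.owner.op.optimized

    # One unknown dim: fall back to the naive left-to-right path [(1, 0), (1, 0)]
    a_dyn = pt.tensor("a", shape=(None, 3))
    out_dyn = pt.einsum("ij,jk,kl->il", a_dyn, b, c)
    assert not out_dyn.owner.op.optimized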
@@ -246,6 +309,7 @@ def sum_repeats(
         lhs, rhs = map(einsum_operands.pop, operand_indices)
         lhs_names, rhs_names = input_names

+        # TODO: Do this as well?
         # handle cases where one side of a contracting or batch dimension is 1
         # but its counterpart is not.
         # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
@@ -323,6 +387,8 @@ def sum_repeats(
                 axes=(lhs_cont, rhs_cont),
                 batch_axes=(lhs_batch, rhs_batch),
             )
+        else:
+            raise ValueError(f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}")

     # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
     assert len(names) == len(result_names) == len(set(names))
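For intuition on the two-operand steps dispatched to `_general_dot` above: a contraction like "bij,bjk->bik" sums over `j` (the `axes` argument) while `b` is carried along as a batch dimension (the `batch_axes` argument). A plain NumPy sketch of that equivalence:

    import numpy as np

    lhs = np.random.rand(2, 3, 4)  # axes: b, i, j
    rhs = np.random.rand(2, 4, 5)  # axes: b, j, k

    # Contract j (axes), batch over b (batch_axes)
    out = np.einsum("bij,bjk->bik", lhs, rhs)
    ref = np.stack([l @ r for l, r in zip(lhs, rhs)])
    np.testing.assert_allclose(out, ref)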
@@ -338,5 +404,7 @@ def sum_repeats(
         subscripts=subscripts,
         inputs=list(operands),
         outputs=[einsum_result],
+        path=tuple(path),
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)
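Putting it together, the intended end-to-end usage looks like this sketch (assuming the public `pt.einsum` and `pytensor.function` entry points):

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(4, 5))
    y = pt.tensor("y", shape=(5, 6))
    out = pt.einsum("ij,jk->ik", x, y)

    f = pytensor.function([x, y], out)
    a = np.random.rand(4, 5)
    b = np.random.rand(5, 6)
    np.testing.assert_allclose(f(a, b), np.einsum("ij,jk->ik", a, b))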