import collections
+ import itertools
from collections.abc import Sequence
from functools import partial, reduce
from itertools import pairwise
    normalize_axis_index,
    normalize_axis_tuple,
)
+ from opt_einsum.helpers import find_contraction
+ from opt_einsum.parser import parse_einsum_input

from pytensor.compile.builders import OpFromGraph
from pytensor.tensor import TensorLike, vectorize
@@ -32,14 +35,13 @@ class Einsum(OpFromGraph):
    Wrapper Op for Einsum graphs
    """

-     __props__ = ("subscripts", "optimize")
+     __props__ = ("subscripts", "path", "optimized")

-     def __init__(
-         self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
-     ):
+     def __init__(self, *args, subscripts: str, path: "PATH", optimized: bool, **kwargs):
        self.subscripts = subscripts
-         self.optimize = optimize
-         super().__init__(*args, **kwargs)
+         self.path = path
+         self.optimized = optimized
+         super().__init__(*args, **kwargs, strict=True)


def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
@@ -142,6 +144,57 @@ def _general_dot(
    return cast(TensorVariable, out)


+ PATH = tuple[tuple[int] | tuple[int, int], ...]
+
+
+ def contraction_list_from_path(
+     subscripts: str, operands: Sequence[TensorLike], path: PATH
+ ):
+     """Build the list of contraction steps that results from applying ``path``
+     to ``subscripts``, one step per single-operand or pairwise contraction.
+
+     Code adapted from opt_einsum.
+     """
+     fake_operands = [
+         np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
+     ]
+     input_subscripts, output_subscript, operands = parse_einsum_input(
+         (subscripts, *fake_operands)
+     )
+
+     # Build a few useful lists and sets
+     input_list = input_subscripts.split(",")
+     input_sets = [set(x) for x in input_list]
+     output_set = set(output_subscript)
+
+     # Build contraction tuples (positions, gemm, einsum_str, remaining)
+     contraction_list = []
+     for cnum, contract_inds in enumerate(path):
+         # Make sure we remove inds from right to left
+         contract_inds = tuple(sorted(contract_inds, reverse=True))
+
+         contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+         out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+         tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+         # Last contraction
+         if (cnum - len(path)) == -1:
+             idx_result = output_subscript
+         else:
+             # Use tensordot order to minimize transpositions
+             all_input_inds = "".join(tmp_inputs)
+             idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
+
+         input_list.append(idx_result)
+         einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+         # We only need the first three entries to build the forward graph
+         contraction = (contract_inds, idx_removed, einsum_str, None, None)
+         contraction_list.append(contraction)
+
+     return contraction_list
+
+
def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    """
    Multiplication and summation of tensors using the Einstein summation convention.
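For orientation, here is a hedged sketch of what `contraction_list_from_path` above produces for a simple matrix chain; the tensor names and shapes are illustrative assumptions, not part of this change:

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3))
    y = pt.tensor("y", shape=(3, 4))
    z = pt.tensor("z", shape=(4, 5))

    # The default left-to-right path pops the two front operands at each step
    # and pushes the intermediate result to the end of the stack.
    steps = contraction_list_from_path("ij,jk,kl->il", [x, y, z], path=[(1, 0), (1, 0)])
    # Each step is (contract_inds, idx_removed, einsum_str, None, None), e.g.:
    #   ((1, 0), {"j"}, "jk,ij->ki", None, None)
    #   ((1, 0), {"k"}, "ki,kl->il", None, None)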
@@ -168,18 +221,33 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
    # TODO: Do we need this as dependency?
    from opt_einsum import contract_path

-     operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+     operands = [as_tensor(operand) for operand in operands]
    shapes = [operand.type.shape for operand in operands]

-     # TODO: Do fast path at creation time, and optimize only in fast_run
-     _, contraction_list = contract_path(
-         subscripts,
-         *shapes,
-         einsum_call=True,
-         use_blas=True,
-         optimize="optimal",
-         shapes=True,
-     )
+     if None in itertools.chain.from_iterable(shapes):
+         # We mark optimized = False even when there is no ordering optimization to be done,
+         # because the inner graph may have to accommodate dynamic shapes.
+         # If those shapes become known later, we will likely want to rebuild the Op (unless we inline it).
+         if len(operands) == 1:
+             path = [(0,)]
+         else:
+             # Create a default path of repeating (1, 0) that executes left to right cyclically,
+             # with intermediate outputs being pushed to the end of the stack.
+             # We use (1, 0) and not (0, 1) because that's what opt_einsum tends to prefer,
+             # so the Op signatures will match more often.
+             path = [(1, 0) for _ in range(len(operands) - 1)]
+         contraction_list = contraction_list_from_path(subscripts, operands, path)
+         optimized = False
+     else:
+         _, contraction_list = contract_path(
+             subscripts,
+             *shapes,
+             einsum_call=True,
+             use_blas=True,
+             optimize="optimal",
+             shapes=True,
+         )
+         path = [contraction[0] for contraction in contraction_list]
+         optimized = True

    def sum_uniques(
        operand: TensorVariable, names: str, uniques: list[str]
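As a usage sketch of the shape-based dispatch above (assuming this module's `einsum` is in scope; names and shapes are illustrative):

    import pytensor.tensor as pt

    a = pt.tensor("a", shape=(None, 3))  # None marks a dynamic dim
    b = pt.tensor("b", shape=(3, None))
    out = einsum("ij,jk->ik", a, b)
    # Any unknown dim forces the default left-to-right path and optimized=False;
    # with fully static shapes, opt_einsum's contract_path picks the "optimal"
    # ordering instead and that path is recorded on the Op.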
@@ -246,6 +314,7 @@ def sum_repeats(
            lhs, rhs = map(einsum_operands.pop, operand_indices)
            lhs_names, rhs_names = input_names

+             # TODO: Do this as well?
            # handle cases where one side of a contracting or batch dimension is 1
            # but its counterpart is not.
            # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
@@ -323,6 +392,10 @@ def sum_repeats(
                axes=(lhs_cont, rhs_cont),
                batch_axes=(lhs_batch, rhs_batch),
            )
+         else:
+             raise ValueError(
+                 f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+             )

        # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
        assert len(names) == len(result_names) == len(set(names))
@@ -338,5 +411,7 @@ def sum_repeats(
        subscripts=subscripts,
        inputs=list(operands),
        outputs=[einsum_result],
+         path=tuple(path),
+         optimized=optimized,
    )(*operands)
    return cast(TensorVariable, out)
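Finally, a hedged sketch showing that the chosen path and the optimization flag are recoverable from the returned graph via the new `__props__` (shapes illustrative; the printed path is only an example of what contract_path may choose):

    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(2, 3))
    y = pt.tensor("y", shape=(3, 4))
    out = einsum("ij,jk->ik", x, y)
    print(out.owner.op.subscripts)  # "ij,jk->ik"
    print(out.owner.op.path)        # e.g. ((0, 1),)
    print(out.owner.op.optimized)   # True, since all shapes are static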