@@ -1,4 +1,5 @@
 import collections
+import itertools
 from collections.abc import Sequence
 from functools import partial, reduce
 from itertools import pairwise
@@ -9,6 +10,8 @@
     normalize_axis_index,
     normalize_axis_tuple,
 )
+from opt_einsum.helpers import find_contraction
+from opt_einsum.parser import parse_einsum_input
 
 from pytensor.compile.builders import OpFromGraph
 from pytensor.tensor import TensorLike, vectorize
@@ -32,14 +35,13 @@ class Einsum(OpFromGraph):
     Wrapper Op for Einsum graphs
     """
 
-    __props__ = ("subscripts", "optimize")
+    __props__ = ("subscripts", "path", "optimized")
 
-    def __init__(
-        self, *args, subscripts: str, optimize: str | None = "optimal", **kwargs
-    ):
+    def __init__(self, *args, subscripts: str, path: str, optimized: bool, **kwargs):
         self.subscripts = subscripts
-        self.optimize = optimize
-        super().__init__(*args, **kwargs)
+        self.path = path
+        self.optimized = optimized
+        super().__init__(*args, **kwargs, strict=True)
 
 
 def _iota(shape: TensorVariable, axis: int) -> TensorVariable:
@@ -140,6 +142,57 @@ def _general_dot(
     return cast(TensorVariable, out)
 
 
+PATH = tuple[tuple[int] | tuple[int, int]]
+
+
+def contraction_list_from_path(
+    subscripts: str, operands: Sequence[TensorLike], path: PATH
+):
+    """Build a contraction list that follows a predetermined contraction ``path``.
+
+    Code adapted from opt_einsum.
+    """
+    fake_operands = [
+        np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
+    ]
+    input_subscripts, output_subscript, operands = parse_einsum_input(
+        (subscripts, *fake_operands)
+    )
+
+    # Build a few useful lists and sets
+    input_list = input_subscripts.split(",")
+    input_sets = [set(x) for x in input_list]
+    output_set = set(output_subscript)
+
+    # Build contraction tuple (positions, gemm, einsum_str, remaining)
+    contraction_list = []
+    for cnum, contract_inds in enumerate(path):
+        # Make sure we remove inds from right to left
+        contract_inds = tuple(sorted(contract_inds, reverse=True))
+
+        contract_tuple = find_contraction(contract_inds, input_sets, output_set)
+        out_inds, input_sets, idx_removed, idx_contract = contract_tuple
+
+        tmp_inputs = [input_list.pop(x) for x in contract_inds]
+
+        # Last contraction
+        if (cnum - len(path)) == -1:
+            idx_result = output_subscript
+        else:
+            # Use tensordot order to minimize transpositions
+            all_input_inds = "".join(tmp_inputs)
+            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
+
+        input_list.append(idx_result)
+        einsum_str = ",".join(tmp_inputs) + "->" + idx_result
+
+        # We only need the first three inputs to build the forward graph
+        contraction = (contract_inds, idx_removed, einsum_str, None, None)
+        contraction_list.append(contraction)
+
+    return contraction_list
+
+
 def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     """
     Multiplication and summation of tensors using the Einstein summation convention.
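
Note: each contraction entry built above mirrors what opt_einsum itself returns when called with `einsum_call=True`, as the next hunk does. A minimal sketch of that format, using opt_einsum's public `contract_path` with made-up shapes (illustrative, not part of the PR):

from opt_einsum import contract_path

# With einsum_call=True the return value is (path, contraction_list), where each
# entry is (contract_inds, idx_removed, einsum_str, remaining, blas_flag).
_, contractions = contract_path(
    "ij,jk,kl->il", (2, 3), (3, 4), (4, 5), einsum_call=True, shapes=True
)
for contract_inds, idx_removed, einsum_str, _, _ in contractions:
    print(contract_inds, idx_removed, einsum_str)
# Roughly: (0, 1) {'j'} 'ij,jk->ik'   then   (0, 1) {'k'} 'ik,kl->il'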
@@ -166,18 +219,35 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
     # TODO: Do we need this as dependency?
     from opt_einsum import contract_path
 
-    operands = cast(tuple[TensorVariable], tuple(map(as_tensor, operands)))
+    operands = [as_tensor(operand) for operand in operands]
     shapes = [operand.type.shape for operand in operands]
 
-    # TODE: Do fast path at creation time, and optimize only in fast_run
-    _, contraction_list = contract_path(
-        subscripts,
-        *shapes,
-        einsum_call=True,
-        use_blas=True,
-        optimize="optimal",
-        shapes=True,
-    )
+    if None in itertools.chain.from_iterable(shapes):
+        # We mark optimized = False even when there is no contraction order to
+        # optimize, because the inner graph may have to accommodate dynamic shapes.
+        # If those shapes become known later we will likely want to rebuild the Op
+        # (unless we inline it).
+        if len(operands) == 1:
+            path = [(0,)]
+        else:
+            # Create a default path of repeated (1, 0) steps that executes left to
+            # right cyclically, with intermediate outputs pushed to the end of the stack.
+            # We use (1, 0) and not (0, 1) because that's what opt_einsum tends to
+            # prefer, so the Op signatures will match more often.
+            path = [(1, 0) for i in range(len(operands) - 1)]
+        contraction_list = contraction_list_from_path(subscripts, operands, path)
+        # With only 1 or 2 operands there is no contraction order to choose,
+        # so the default path is already optimal.
+        optimized = len(operands) <= 2
+    else:
+        _, contraction_list = contract_path(
+            subscripts,
+            *shapes,
+            einsum_call=True,
+            use_blas=True,
+            optimize="optimal",
+            shapes=True,
+        )
+        path = [contraction[0] for contraction in contraction_list]
+        optimized = True
 
     def sum_uniques(
         operand: TensorVariable, names: str, uniques: list[str]
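
Note: to make the "left to right cyclically" comment in the hunk above concrete, here is a small stand-alone sketch (plain Python, symbolic operand labels, not from the PR) of how repeated (1, 0) steps walk the operand stack:

# Each (1, 0) step pops positions 1 and 0 (the two leftmost operands) and
# appends the intermediate result to the end of the stack.
stack = ["A", "B", "C", "D"]
while len(stack) > 1:
    b, a = stack.pop(1), stack.pop(0)
    stack.append(f"({a}*{b})")
    print(stack)
# ['C', 'D', '(A*B)']
# ['(A*B)', '(C*D)']
# ['((A*B)*(C*D))']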
@@ -244,6 +314,7 @@ def sum_repeats(
             lhs, rhs = map(einsum_operands.pop, operand_indices)
             lhs_names, rhs_names = input_names
 
+            # TODO: Do this as well?
             # handle cases where one side of a contracting or batch dimension is 1
             # but its counterpart is not.
             # lhs, lhs_names = filter_singleton_dims(lhs, lhs_names, shape(rhs),
@@ -321,6 +392,10 @@ def sum_repeats(
                 axes=(lhs_cont, rhs_cont),
                 batch_axes=(lhs_batch, rhs_batch),
             )
+        else:
+            raise ValueError(
+                f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
+            )
 
         # the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
         assert len(names) == len(result_names) == len(set(names))
@@ -336,5 +411,7 @@ def sum_repeats(
         subscripts=subscripts,
         inputs=list(operands),
         outputs=[einsum_result],
+        path=tuple(path),
+        optimized=optimized,
     )(*operands)
     return cast(TensorVariable, out)
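
Note: a quick usage sketch of the two code paths, assuming this file lands as pytensor.tensor.einsum and that pt.tensor accepts a name plus a shape kwarg; the printed values follow from the logic above and are illustrative, not captured test output:

import pytensor.tensor as pt
from pytensor.tensor.einsum import einsum

# Static shapes: opt_einsum chooses the contraction path, optimized=True.
x, y, z = (pt.tensor(name, shape=(5, 5)) for name in "xyz")
out = einsum("ij,jk,kl->il", x, y, z)
print(out.owner.op.path, out.owner.op.optimized)

# Any unknown shape: fall back to the default (1, 0) path; with more than
# two operands the Op is marked optimized=False.
x, y, z = (pt.tensor(name, shape=(None, None)) for name in "xyz")
out = einsum("ij,jk,kl->il", x, y, z)
print(out.owner.op.path, out.owner.op.optimized)  # ((1, 0), (1, 0)) False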