@@ -36,12 +36,12 @@ class Einsum(OpFromGraph):
36
36
Wrapper Op for Einsum graphs
37
37
"""
38
38
39
- __props__ = ("subscripts" , "path" , "optimized " )
39
+ __props__ = ("subscripts" , "path" , "optimize " )
40
40
41
- def __init__ (self , * args , subscripts : str , path : str , optimized : bool , ** kwargs ):
41
+ def __init__ (self , * args , subscripts : str , path : str , optimize : bool , ** kwargs ):
42
42
self .subscripts = subscripts
43
43
self .path = path
44
- self .optimized = optimized
44
+ self .optimize = optimize
45
45
super ().__init__ (* args , ** kwargs , strict = True )
46
46
47
47
@@ -224,7 +224,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
224
224
shapes = [operand .type .shape for operand in operands ]
225
225
226
226
if None in itertools .chain .from_iterable (shapes ):
227
- # We mark optimized = False, even in cases where there is no ordering optimization to be done
227
+ # We mark optimize = False, even in cases where there is no ordering optimization to be done
228
228
# because the inner graph may have to accommodate dynamic shapes.
229
229
# If those shapes become known later we will likely want to rebuild the Op (unless we inline it)
230
230
if len (operands ) == 1 :
@@ -235,7 +235,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
235
235
# We use (1,0) and not (0,1) because that's what opt_einsum tends to prefer, and so the Op signatures will match more often
236
236
path = [(1 , 0 ) for i in range (len (operands ) - 1 )]
237
237
contraction_list = contraction_list_from_path (subscripts , operands , path )
238
- optimized = (
238
+ optimize = (
239
239
len (operands ) <= 2
240
240
) # If there are only 1 or 2 operands, there is no optimization to be done?
241
241
else :
@@ -248,7 +248,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
248
248
shapes = True ,
249
249
)
250
250
path = [contraction [0 ] for contraction in contraction_list ]
251
- optimized = True
251
+ optimize = True
252
252
253
253
def sum_uniques (
254
254
operand : TensorVariable , names : str , uniques : list [str ]
@@ -413,6 +413,6 @@ def sum_repeats(
413
413
inputs = list (operands ),
414
414
outputs = [einsum_result ],
415
415
path = tuple (path ),
416
- optimized = optimized ,
416
+ optimize = optimize ,
417
417
)(* operands )
418
418
return cast (TensorVariable , out )
0 commit comments