33
33
class Einsum(OpFromGraph):
    """
    Wrapper Op for Einsum graphs.

    Notes
    -----
    The `optimized` prop indicates whether the inner graph was optimized, which can only be
    done when all shapes are statically known. This is now determined at graph creation time
    only. We could introduce a rewrite that tries to optimize the graph if static shapes
    become known later (e.g., after use of `clone_replace` or shape inference during
    rewrites).

    Also, once the graph is optimized, it could be inlined for potential further
    optimization that consider the rest of the graph.

    This prop is different from the `optimize` kwarg in numpy that determines what kind
    (if any) of optimization is desired. We haven't decided whether we want to provide this
    functionality.
    """

    # NOTE: props must be hashable/comparable; trailing whitespace inside these
    # literals would silently break Op equality, so they are exact strings.
    __props__ = ("subscripts", "path", "optimized")

    def __init__(self, *args, subscripts: str, path: str, optimized: bool, **kwargs):
        # subscripts: the einsum specification string (e.g. "ij,jk->ik")
        # path: the contraction path the inner graph was built from
        # optimized: whether the inner graph was built with an optimized
        #   contraction order (only possible when all shapes are static)
        self.subscripts = subscripts
        self.path = path
        self.optimized = optimized
        # strict=True: presumably forbids implicit broadcasting/type coercion of
        # the inner-graph inputs — TODO confirm against OpFromGraph semantics.
        super().__init__(*args, **kwargs, strict=True)
45
58
46
59
@@ -368,7 +381,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
368
381
contraction_list = _contraction_list_from_path (subscripts , operands , path )
369
382
370
383
# If there are only 1 or 2 operands, there is no optimization to be done?
371
- optimize = len (operands ) <= 2
384
+ optimized = len (operands ) <= 2
372
385
else :
373
386
# Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
374
387
# contraction order.
@@ -382,7 +395,7 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
382
395
optimize = "optimal" ,
383
396
)
384
397
path = [contraction [0 ] for contraction in contraction_list ]
385
- optimize = True
398
+ optimized = True
386
399
387
400
def sum_uniques (
388
401
operand : TensorVariable , names : str , uniques : list [str ]
@@ -550,6 +563,6 @@ def sum_repeats(
550
563
inputs = list (operands ),
551
564
outputs = [einsum_result ],
552
565
path = tuple (path ),
553
- optimize = optimize ,
566
+ optimized = optimized ,
554
567
)(* operands )
555
568
return cast (TensorVariable , out )
0 commit comments