6
6
from typing import cast
7
7
8
8
import numpy as np
9
+ from numpy .core .einsumfunc import _find_contraction , _parse_einsum_input # type: ignore
9
10
from numpy .core .numeric import ( # type: ignore
10
11
normalize_axis_index ,
11
12
normalize_axis_tuple ,
12
13
)
13
- from opt_einsum .helpers import find_contraction
14
- from opt_einsum .parser import parse_einsum_input
15
14
16
15
from pytensor .compile .builders import OpFromGraph
17
16
from pytensor .tensor import TensorLike
@@ -129,9 +128,6 @@ def _general_dot(
129
128
core_lhs_axes = tuple (np .array (lhs_axes ) - lhs_n_batch_axes )
130
129
core_rhs_axes = tuple (np .array (rhs_axes ) - rhs_n_batch_axes )
131
130
132
- # TODO: tensordot produces very complicated graphs unnecessarily
133
- # In some cases we are just doing elemwise addition after some transpositions
134
- # We also have some Blockwise(Reshape) that will slow down things!
135
131
if signature == "(),()->()" :
136
132
# Just a multiplication
137
133
out = lhs * rhs
@@ -146,7 +142,7 @@ def _general_dot(
146
142
PATH = tuple [tuple [int ] | tuple [int , int ]]
147
143
148
144
149
- def contraction_list_from_path (
145
+ def _contraction_list_from_path (
150
146
subscripts : str , operands : Sequence [TensorLike ], path : PATH
151
147
):
152
148
"""
@@ -189,7 +185,7 @@ def contraction_list_from_path(
189
185
fake_operands = [
190
186
np .zeros ([1 if dim == 1 else 0 for dim in x .type .shape ]) for x in operands
191
187
]
192
- input_subscripts , output_subscript , operands = parse_einsum_input (
188
+ input_subscripts , output_subscript , operands = _parse_einsum_input (
193
189
(subscripts , * fake_operands )
194
190
)
195
191
@@ -204,7 +200,7 @@ def contraction_list_from_path(
204
200
# Make sure we remove inds from right to left
205
201
contract_inds = tuple (sorted (contract_inds , reverse = True ))
206
202
207
- contract_tuple = find_contraction (contract_inds , input_sets , output_set )
203
+ contract_tuple = _find_contraction (contract_inds , input_sets , output_set )
208
204
out_inds , input_sets , idx_removed , idx_contract = contract_tuple
209
205
210
206
tmp_inputs = [input_list .pop (x ) for x in contract_inds ]
@@ -354,12 +350,6 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
354
350
355
351
# TODO: Is this doing something clever about unknown shapes?
356
352
# contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
357
- # using einsum_call=True here is an internal api for opt_einsum... sorry
358
-
359
- # TODO: Handle None static shapes
360
- # TODO: Do we need this as dependency?
361
- from opt_einsum import contract_path
362
-
363
353
operands = [as_tensor (operand ) for operand in operands ]
364
354
shapes = [operand .type .shape for operand in operands ]
365
355
@@ -375,20 +365,21 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
375
365
path = [(0 ,)]
376
366
else :
377
367
path = [(1 , 0 ) for i in range (len (operands ) - 1 )]
378
- contraction_list = contraction_list_from_path (subscripts , operands , path )
368
+ contraction_list = _contraction_list_from_path (subscripts , operands , path )
379
369
380
370
# If there are only 1 or 2 operands, there is no optimization to be done?
381
371
optimize = len (operands ) <= 2
382
372
else :
383
373
# Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
384
374
# contraction order.
385
- _ , contraction_list = contract_path (
375
+ # Call _implementation to bypass dispatch
376
+ _ , contraction_list = np .einsum_path (
386
377
subscripts ,
387
- * shapes ,
378
+ # Numpy einsum_path requires arrays even though only the shapes matter
379
+ # It's not trivial to duck-type our way around because of internal call to `asanyarray`
380
+ * [np .empty (shape ) for shape in shapes ],
388
381
einsum_call = True ,
389
- use_blas = True ,
390
382
optimize = "optimal" ,
391
- shapes = True ,
392
383
)
393
384
path = [contraction [0 ] for contraction in contraction_list ]
394
385
optimize = True
0 commit comments