Skip to content

Commit 36c8e1a

Browse files
committed
Fix ellipsis and drop dependency on opt_einsum
1 parent f94e02c commit 36c8e1a

File tree

4 files changed: 12 additions and 23 deletions

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ jobs:
154154
shell: micromamba-shell {0}
155155
run: |
156156
157-
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy opt_einsum pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
157+
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock
158158
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
159159
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
160160
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
@@ -215,7 +215,7 @@ jobs:
215215
- name: Install dependencies
216216
shell: micromamba-shell {0}
217217
run: |
218-
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy opt_einsum pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
218+
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
219219
pip install -e ./
220220
micromamba list && pip freeze
221221
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'

environment.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ dependencies:
1111
- compilers
1212
- numpy>=1.17.0,<2
1313
- scipy>=0.14,<1.14.0
14-
- opt_einsum
1514
- filelock>=3.15
1615
- etuples
1716
- logical-unification

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ dependencies = [
4949
"setuptools>=59.0.0",
5050
"scipy>=0.14,<1.14",
5151
"numpy>=1.17.0,<2",
52-
"opt_einsum",
5352
"filelock>=3.15",
5453
"etuples",
5554
"logical-unification",

pytensor/tensor/einsum.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
from typing import cast
77

88
import numpy as np
9+
from numpy.core.einsumfunc import _find_contraction, _parse_einsum_input # type: ignore
910
from numpy.core.numeric import ( # type: ignore
1011
normalize_axis_index,
1112
normalize_axis_tuple,
1213
)
13-
from opt_einsum.helpers import find_contraction
14-
from opt_einsum.parser import parse_einsum_input
1514

1615
from pytensor.compile.builders import OpFromGraph
1716
from pytensor.tensor import TensorLike
@@ -129,9 +128,6 @@ def _general_dot(
129128
core_lhs_axes = tuple(np.array(lhs_axes) - lhs_n_batch_axes)
130129
core_rhs_axes = tuple(np.array(rhs_axes) - rhs_n_batch_axes)
131130

132-
# TODO: tensordot produces very complicated graphs unnecessarily
133-
# In some cases we are just doing elemwise addition after some transpositions
134-
# We also have some Blockwise(Reshape) that will slow down things!
135131
if signature == "(),()->()":
136132
# Just a multiplication
137133
out = lhs * rhs
@@ -146,7 +142,7 @@ def _general_dot(
146142
PATH = tuple[tuple[int] | tuple[int, int]]
147143

148144

149-
def contraction_list_from_path(
145+
def _contraction_list_from_path(
150146
subscripts: str, operands: Sequence[TensorLike], path: PATH
151147
):
152148
"""
@@ -189,7 +185,7 @@ def contraction_list_from_path(
189185
fake_operands = [
190186
np.zeros([1 if dim == 1 else 0 for dim in x.type.shape]) for x in operands
191187
]
192-
input_subscripts, output_subscript, operands = parse_einsum_input(
188+
input_subscripts, output_subscript, operands = _parse_einsum_input(
193189
(subscripts, *fake_operands)
194190
)
195191

@@ -204,7 +200,7 @@ def contraction_list_from_path(
204200
# Make sure we remove inds from right to left
205201
contract_inds = tuple(sorted(contract_inds, reverse=True))
206202

207-
contract_tuple = find_contraction(contract_inds, input_sets, output_set)
203+
contract_tuple = _find_contraction(contract_inds, input_sets, output_set)
208204
out_inds, input_sets, idx_removed, idx_contract = contract_tuple
209205

210206
tmp_inputs = [input_list.pop(x) for x in contract_inds]
@@ -354,12 +350,6 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
354350

355351
# TODO: Is this doing something clever about unknown shapes?
356352
# contract_path = _poly_einsum_handlers.get(ty, _default_poly_einsum_handler)
357-
# using einsum_call=True here is an internal api for opt_einsum... sorry
358-
359-
# TODO: Handle None static shapes
360-
# TODO: Do we need this as dependency?
361-
from opt_einsum import contract_path
362-
363353
operands = [as_tensor(operand) for operand in operands]
364354
shapes = [operand.type.shape for operand in operands]
365355

@@ -375,20 +365,21 @@ def einsum(subscripts: str, *operands: "TensorLike") -> TensorVariable:
375365
path = [(0,)]
376366
else:
377367
path = [(1, 0) for i in range(len(operands) - 1)]
378-
contraction_list = contraction_list_from_path(subscripts, operands, path)
368+
contraction_list = _contraction_list_from_path(subscripts, operands, path)
379369

380370
# If there are only 1 or 2 operands, there is no optimization to be done?
381371
optimize = len(operands) <= 2
382372
else:
383373
# Case 2: All operands have known shapes. In this case, we can use opt_einsum to compute the optimal
384374
# contraction order.
385-
_, contraction_list = contract_path(
375+
# Call _implementation to bypass dispatch
376+
_, contraction_list = np.einsum_path(
386377
subscripts,
387-
*shapes,
378+
# Numpy einsum_path requires arrays even though only the shapes matter
379+
# It's not trivial to duck-type our way around because of internal call to `asanyarray`
380+
*[np.empty(shape) for shape in shapes],
388381
einsum_call=True,
389-
use_blas=True,
390382
optimize="optimal",
391-
shapes=True,
392383
)
393384
path = [contraction[0] for contraction in contraction_list]
394385
optimize = True

0 commit comments

Comments (0)