|
1 |
| -from collections.abc import Callable |
2 | 1 | from functools import singledispatch
|
3 | 2 | from textwrap import dedent, indent
|
4 |
| -from typing import Any |
5 | 3 |
|
6 | 4 | import numba
|
7 | 5 | import numpy as np
|
8 | 6 | from numba.core.extending import overload
|
9 | 7 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
|
10 | 8 |
|
11 | 9 | from pytensor import config
|
12 |
| -from pytensor.graph.basic import Apply |
13 | 10 | from pytensor.graph.op import Op
|
14 | 11 | from pytensor.link.numba.dispatch import basic as numba_basic
|
15 | 12 | from pytensor.link.numba.dispatch.basic import (
|
@@ -124,42 +121,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
|
124 | 121 | """
|
125 | 122 |
|
126 | 123 |
|
127 |
| -def create_vectorize_func( |
128 |
| - scalar_op_fn: Callable, |
129 |
| - node: Apply, |
130 |
| - use_signature: bool = False, |
131 |
| - identity: Any | None = None, |
132 |
| - **kwargs, |
133 |
| -) -> Callable: |
134 |
| - r"""Create a vectorized Numba function from a `Apply`\s Python function.""" |
135 |
| - |
136 |
| - if len(node.outputs) > 1: |
137 |
| - raise NotImplementedError( |
138 |
| - "Multi-output Elemwise Ops are not supported by the Numba backend" |
139 |
| - ) |
140 |
| - |
141 |
| - if use_signature: |
142 |
| - signature = [create_numba_signature(node, force_scalar=True)] |
143 |
| - else: |
144 |
| - signature = [] |
145 |
| - |
146 |
| - target = ( |
147 |
| - getattr(node.tag, "numba__vectorize_target", None) |
148 |
| - or config.numba__vectorize_target |
149 |
| - ) |
150 |
| - |
151 |
| - numba_vectorized_fn = numba_basic.numba_vectorize( |
152 |
| - signature, identity=identity, target=target, fastmath=config.numba__fastmath |
153 |
| - ) |
154 |
| - |
155 |
| - py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn) |
156 |
| - |
157 |
| - elemwise_fn = numba_vectorized_fn(scalar_op_fn) |
158 |
| - elemwise_fn.py_scalar_func = py_scalar_func |
159 |
| - |
160 |
| - return elemwise_fn |
161 |
| - |
162 |
| - |
163 | 124 | def create_multiaxis_reducer(
|
164 | 125 | scalar_op,
|
165 | 126 | identity,
|
|
0 commit comments