@@ -9,6 +9,7 @@
 from pytensor.gradient import DisconnectedType
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
+from pytensor.graph.replace import _vectorize_node
 from pytensor.link.c.op import COp
 from pytensor.link.c.params_type import ParamsType
 from pytensor.link.c.type import Generic
@@ -25,7 +26,7 @@
     stack,
     switch,
 )
-from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.shape import shape, specify_broadcastable
 from pytensor.tensor.type import (
@@ -2873,7 +2874,11 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


-_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
+_matrix_matrix_matmul = Blockwise(
+    _dot,
+    signature="(m,k),(k,n)->(m,n)",
+    gufunc_spec=("numpy.matmul", 2, 1),
+)


 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -2937,6 +2942,15 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     return out


+@_vectorize_node.register(Dot)
+def vectorize_node_to_matmul(op, node, batched_x, batched_y):
+    old_x, old_y = node.inputs
+    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
+        return matmul(batched_x, batched_y).owner
+    else:
+        return vectorize_node_fallback(op, node, batched_x, batched_y)
+
+
 __all__ = [
     "max_and_argmax",
     "max",
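For reference, a minimal sketch (not part of the diff) of what the new dispatch buys: vectorizing a core 2D Dot node with batched inputs now yields the matmul graph, i.e. the Blockwise carrying the "numpy.matmul" gufunc spec, rather than a generic blockwised Dot via the fallback. It assumes a public vectorize_node helper lives alongside _vectorize_node in pytensor.graph.replace and that the pytensor.tensor constructors used here are available in this version.

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_node  # assumed public counterpart of _vectorize_node

# Core 2D matrix product: builds a plain Dot node.
x = pt.matrix("x")   # shape (m, k)
y = pt.matrix("y")   # shape (k, n)
out = pt.dot(x, y)

# Batched replacements for the original inputs.
bx = pt.tensor3("bx")  # shape (batch, m, k)
by = pt.tensor3("by")  # shape (batch, k, n)

# With the Dot registration above, this should return the node built by
# matmul(bx, by), i.e. the Blockwise matmul with signature
# "(m,k),(k,n)->(m,n)", instead of falling back to Blockwise(Dot).
new_node = vectorize_node(out.owner, bx, by)
print(type(new_node.op))  # expected: Blockwise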