Skip to content

Commit d517117

Browse files
committed
[mlir] python bindings for vector transform ops
Provide Python bindings for transform ops defined in the vector dialect. All of these ops are sufficiently simple that no mixins are necessary for them to be nicely usable. Reviewed By: ingomueller-net Differential Revision: https://reviews.llvm.org/D156554
1 parent 1f8618f commit d517117

File tree

6 files changed

+237
-4
lines changed

6 files changed

+237
-4
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ def ApplyLowerBroadcastPatternsOp : Op<Transform_Dialect,
8181
let assemblyFormat = "attr-dict";
8282
}
8383

84-
// TODO: evolve lowering_strategy to proper enums.
8584
def ApplyLowerContractionPatternsOp : Op<Transform_Dialect,
8685
"apply_patterns.vector.lower_contraction",
8786
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
@@ -143,7 +142,6 @@ def ApplyMaterializeMasksPatternsOp : Op<Transform_Dialect,
143142
let assemblyFormat = "attr-dict";
144143
}
145144

146-
// TODO: evolve lowering_strategy to proper enums.
147145
def ApplyLowerMultiReductionPatternsOp : Op<Transform_Dialect,
148146
"apply_patterns.vector.lower_multi_reduction",
149147
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
@@ -232,7 +230,6 @@ def ApplyLowerTransferPatternsOp : Op<Transform_Dialect,
232230
}];
233231
}
234232

235-
// TODO: evolve lowering_strategy to proper enums.
236233
def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
237234
"apply_patterns.vector.lower_transpose",
238235
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
@@ -259,7 +256,6 @@ def ApplyLowerTransposePatternsOp : Op<Transform_Dialect,
259256
}];
260257
}
261258

262-
// TODO: evolve split_transfer_strategy to proper enums.
263259
def ApplySplitTransferFullPartialPatternsOp : Op<Transform_Dialect,
264260
"apply_patterns.vector.split_transfer_full_partial",
265261
[DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {

mlir/python/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,24 @@ declare_mlir_dialect_extension_python_bindings(
192192
DIALECT_NAME transform
193193
EXTENSION_NAME structured_transform)
194194

195+
declare_mlir_dialect_extension_python_bindings(
196+
ADD_TO_PARENT MLIRPythonSources.Dialects
197+
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
198+
TD_FILE dialects/VectorTransformOps.td
199+
SOURCES
200+
dialects/transform/vector.py
201+
DIALECT_NAME transform
202+
EXTENSION_NAME vector_transform)
203+
204+
set(LLVM_TARGET_DEFINITIONS "${CMAKE_CURRENT_SOURCE_DIR}/mlir/dialects/VectorTransformOps.td")
205+
mlir_tablegen("dialects/_vector_transform_enum_gen.py" -gen-python-enum-bindings)
206+
add_public_tablegen_target(MLIRVectorTransformPyEnumGen)
207+
declare_mlir_python_sources(
208+
MLIRPythonSources.Dialects.vector_transform.enum_gen
209+
ROOT_DIR "${CMAKE_CURRENT_BINARY_DIR}"
210+
ADD_TO_PARENT MLIRPythonSources.Dialects.vector_transform
211+
SOURCES "dialects/_vector_transform_enum_gen.py" )
212+
195213
declare_mlir_dialect_python_bindings(
196214
ADD_TO_PARENT MLIRPythonSources.Dialects
197215
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
//===-- VectorTransformOps.td ------------------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Entry point of the Python bindings generator for the vector transform ops.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
14+
#ifndef PYTHON_BINDINGS_VECTORTRANSFORMOPS
15+
#define PYTHON_BINDINGS_VECTORTRANSFORMOPS
16+
17+
include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.td"
18+
19+
#endif // PYTHON_BINDINGS_VECTORTRANSFORMOPS
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2+
# See https://llvm.org/LICENSE.txt for license information.
3+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4+
5+
from .._vector_transform_enum_gen import *
6+
from .._vector_transform_ops_gen import *
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# RUN: %PYTHON %s | FileCheck %s
2+
3+
from mlir.ir import *
4+
from mlir.dialects import transform
5+
from mlir.dialects.transform import vector
6+
7+
8+
def run_apply_patterns(f):
9+
with Context(), Location.unknown():
10+
module = Module.create()
11+
with InsertionPoint(module.body):
12+
sequence = transform.SequenceOp(
13+
transform.FailurePropagationMode.PROPAGATE,
14+
[],
15+
transform.AnyOpType.get(),
16+
)
17+
with InsertionPoint(sequence.body):
18+
apply = transform.ApplyPatternsOp(sequence.bodyTarget)
19+
with InsertionPoint(apply.patterns):
20+
f()
21+
transform.YieldOp()
22+
print("\nTEST:", f.__name__)
23+
print(module)
24+
return f
25+
26+
27+
@run_apply_patterns
28+
def non_configurable_patterns():
29+
# CHECK-LABEL: TEST: non_configurable_patterns
30+
# CHECK: apply_patterns
31+
# CHECK: transform.apply_patterns.vector.cast_away_vector_leading_one_dim
32+
vector.ApplyCastAwayVectorLeadingOneDimPatternsOp()
33+
# CHECK: transform.apply_patterns.vector.rank_reducing_subview_patterns
34+
vector.ApplyRankReducingSubviewPatternsOp()
35+
# CHECK: transform.apply_patterns.vector.transfer_permutation_patterns
36+
vector.ApplyTransferPermutationPatternsOp()
37+
# CHECK: transform.apply_patterns.vector.lower_broadcast
38+
vector.ApplyLowerBroadcastPatternsOp()
39+
# CHECK: transform.apply_patterns.vector.lower_masks
40+
vector.ApplyLowerMasksPatternsOp()
41+
# CHECK: transform.apply_patterns.vector.lower_masked_transfers
42+
vector.ApplyLowerMaskedTransfersPatternsOp()
43+
# CHECK: transform.apply_patterns.vector.materialize_masks
44+
vector.ApplyMaterializeMasksPatternsOp()
45+
# CHECK: transform.apply_patterns.vector.lower_outerproduct
46+
vector.ApplyLowerOuterProductPatternsOp()
47+
# CHECK: transform.apply_patterns.vector.lower_gather
48+
vector.ApplyLowerGatherPatternsOp()
49+
# CHECK: transform.apply_patterns.vector.lower_scan
50+
vector.ApplyLowerScanPatternsOp()
51+
# CHECK: transform.apply_patterns.vector.lower_shape_cast
52+
vector.ApplyLowerShapeCastPatternsOp()
53+
54+
55+
@run_apply_patterns
56+
def configurable_patterns():
57+
# CHECK-LABEL: TEST: configurable_patterns
58+
# CHECK: apply_patterns
59+
# CHECK: transform.apply_patterns.vector.lower_transfer
60+
# CHECK-SAME: max_transfer_rank = 4
61+
vector.ApplyLowerTransferPatternsOp(max_transfer_rank=4)
62+
# CHECK: transform.apply_patterns.vector.transfer_to_scf
63+
# CHECK-SAME: max_transfer_rank = 3
64+
# CHECK-SAME: full_unroll = true
65+
vector.ApplyTransferToScfPatternsOp(max_transfer_rank=3, full_unroll=True)
66+
67+
68+
@run_apply_patterns
69+
def enum_configurable_patterns():
70+
# CHECK: transform.apply_patterns.vector.lower_contraction
71+
vector.ApplyLowerContractionPatternsOp()
72+
# CHECK: transform.apply_patterns.vector.lower_contraction
73+
# CHECK-SAME: lowering_strategy = matmulintrinsics
74+
vector.ApplyLowerContractionPatternsOp(
75+
lowering_strategy=vector.VectorContractLowering.MATMUL
76+
)
77+
# CHECK: transform.apply_patterns.vector.lower_contraction
78+
# CHECK-SAME: lowering_strategy = parallelarith
79+
vector.ApplyLowerContractionPatternsOp(
80+
lowering_strategy=vector.VectorContractLowering.PARALLEL_ARITH
81+
)
82+
83+
# CHECK: transform.apply_patterns.vector.lower_multi_reduction
84+
vector.ApplyLowerMultiReductionPatternsOp()
85+
# CHECK: transform.apply_patterns.vector.lower_multi_reduction
86+
# This is the default mode, not printed.
87+
vector.ApplyLowerMultiReductionPatternsOp(
88+
lowering_strategy=vector.VectorMultiReductionLowering.INNER_PARALLEL
89+
)
90+
# CHECK: transform.apply_patterns.vector.lower_multi_reduction
91+
# CHECK-SAME: lowering_strategy = innerreduction
92+
vector.ApplyLowerMultiReductionPatternsOp(
93+
lowering_strategy=vector.VectorMultiReductionLowering.INNER_REDUCTION
94+
)
95+
96+
# CHECK: transform.apply_patterns.vector.lower_transpose
97+
# CHECK-SAME: lowering_strategy = eltwise
98+
# CHECK-SAME: avx2_lowering_strategy = false
99+
vector.ApplyLowerTransposePatternsOp()
100+
# CHECK: transform.apply_patterns.vector.lower_transpose
101+
# CHECK-SAME: lowering_strategy = eltwise
102+
# CHECK-SAME: avx2_lowering_strategy = false
103+
vector.ApplyLowerTransposePatternsOp(
104+
lowering_strategy=vector.VectorTransposeLowering.ELT_WISE
105+
)
106+
# CHECK: transform.apply_patterns.vector.lower_transpose
107+
# CHECK-SAME: lowering_strategy = flat_transpose
108+
# CHECK-SAME: avx2_lowering_strategy = false
109+
vector.ApplyLowerTransposePatternsOp(
110+
lowering_strategy=vector.VectorTransposeLowering.FLAT
111+
)
112+
# CHECK: transform.apply_patterns.vector.lower_transpose
113+
# CHECK-SAME: lowering_strategy = shuffle_1d
114+
# CHECK-SAME: avx2_lowering_strategy = false
115+
vector.ApplyLowerTransposePatternsOp(
116+
lowering_strategy=vector.VectorTransposeLowering.SHUFFLE1_D
117+
)
118+
# CHECK: transform.apply_patterns.vector.lower_transpose
119+
# CHECK-SAME: lowering_strategy = shuffle_16x16
120+
# CHECK-SAME: avx2_lowering_strategy = false
121+
vector.ApplyLowerTransposePatternsOp(
122+
lowering_strategy=vector.VectorTransposeLowering.SHUFFLE16X16
123+
)
124+
# CHECK: transform.apply_patterns.vector.lower_transpose
125+
# CHECK-SAME: lowering_strategy = flat_transpose
126+
# CHECK-SAME: avx2_lowering_strategy = true
127+
vector.ApplyLowerTransposePatternsOp(
128+
lowering_strategy=vector.VectorTransposeLowering.FLAT,
129+
avx2_lowering_strategy=True,
130+
)
131+
132+
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
133+
vector.ApplySplitTransferFullPartialPatternsOp()
134+
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
135+
# CHECK-SAME: split_transfer_strategy = none
136+
vector.ApplySplitTransferFullPartialPatternsOp(
137+
split_transfer_strategy=vector.VectorTransferSplit.NONE
138+
)
139+
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
140+
# CHECK-SAME: split_transfer_strategy = "vector-transfer"
141+
vector.ApplySplitTransferFullPartialPatternsOp(
142+
split_transfer_strategy=vector.VectorTransferSplit.VECTOR_TRANSFER
143+
)
144+
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
145+
# This is the default mode, not printed.
146+
vector.ApplySplitTransferFullPartialPatternsOp(
147+
split_transfer_strategy=vector.VectorTransferSplit.LINALG_COPY
148+
)
149+
# CHECK: transform.apply_patterns.vector.split_transfer_full_partial
150+
# CHECK-SAME: split_transfer_strategy = "force-in-bounds"
151+
vector.ApplySplitTransferFullPartialPatternsOp(
152+
split_transfer_strategy=vector.VectorTransferSplit.FORCE_IN_BOUNDS
153+
)

utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,45 @@ gentbl_filegroup(
902902
],
903903
)
904904

905+
gentbl_filegroup(
906+
name = "VectorTransformEnumPyGen",
907+
tbl_outs = [
908+
(
909+
["-gen-python-enum-bindings"],
910+
"mlir/dialects/_vector_transform_enum_gen.py",
911+
),
912+
],
913+
tblgen = "//mlir:mlir-tblgen",
914+
td_file = "mlir/dialects/VectorTransformOps.td",
915+
deps = [
916+
"//mlir:OpBaseTdFiles",
917+
"//mlir:TransformDialectTdFiles",
918+
"//mlir:VectorTransformOpsTdFiles",
919+
],
920+
)
921+
922+
gentbl_filegroup(
923+
name = "VectorTransformOpsPyGen",
924+
tbl_outs = [
925+
(
926+
[
927+
"-gen-python-op-bindings",
928+
"-bind-dialect=transform",
929+
"-dialect-extension=vector_transform",
930+
],
931+
"mlir/dialects/_vector_transform_ops_gen.py",
932+
),
933+
],
934+
tblgen = "//mlir:mlir-tblgen",
935+
td_file = "mlir/dialects/VectorTransformOps.td",
936+
deps = [
937+
"//mlir:OpBaseTdFiles",
938+
"//mlir:TransformDialectTdFiles",
939+
"//mlir:VectorTransformOpsTdFiles",
940+
],
941+
)
942+
943+
905944
filegroup(
906945
name = "TransformOpsPyFiles",
907946
srcs = [
@@ -919,6 +958,8 @@ filegroup(
919958
":StructuredTransformOpsPyGen",
920959
":TransformEnumPyGen",
921960
":TransformOpsPyGen",
961+
":VectorTransformEnumPyGen",
962+
":VectorTransformOpsPyGen",
922963
],
923964
)
924965

0 commit comments

Comments
 (0)