Skip to content

Commit 6d2fd3d

Browse files
[mlir][linalg] Replace monomorphic contration ops with polymorphic variants.
* Moves `batch_matmul`, `matmul`, `matvec`, `vectmat`, `dot` to the new mechanism. * This is not just an NFC change, in addition to using a new code generation mechanism, it also activates symbolic casting, allowing mixed precision operands and results. * These definitions were generated from DSL by the tool: https://github.com/stellaraccident/mlir-linalgpy/blob/main/mlir_linalg/oplib/core.py (will be upstreamed in a subsequent set of changes). Reviewed By: nicolasvasilache, ThomasRaoux Differential Revision: https://reviews.llvm.org/D97719
1 parent d36a15d commit 6d2fd3d

File tree

3 files changed

+259
-44
lines changed

3 files changed

+259
-44
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 250 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
--- !LinalgOpConfig
22
metadata: !LinalgOpMetadata
3-
name: polymorphic_matmul
4-
cpp_op_name: PolymorphicMatmulOp
3+
name: matmul
4+
cpp_op_name: MatmulOp
55
doc: |-
6-
Type polymorphic matrix multiplication.
6+
Performs a matrix multiplacation of two 2D inputs.
77
8-
This op is presently here to test a new path for generation and will replace
9-
the existing 'matmul' op when ready. Do not use.
8+
Numeric casting is performed on the operands to the inner multiply, promoting
9+
them to the same data type as the accumulator/output.
1010
implements:
1111
- LinalgContractionOpInterface
1212
structured_op: !LinalgStructuredOpConfig
@@ -60,4 +60,249 @@ structured_op: !LinalgStructuredOpConfig
6060
operands:
6161
- !ScalarExpression
6262
scalar_arg: B
63+
--- !LinalgOpConfig
64+
metadata: !LinalgOpMetadata
65+
name: batch_matmul
66+
cpp_op_name: BatchMatmulOp
67+
doc: |-
68+
Performs a batched matrix multiplacation of two 3D inputs.
69+
70+
Numeric casting is performed on the operands to the inner multiply, promoting
71+
them to the same data type as the accumulator/output.
72+
implements:
73+
- LinalgContractionOpInterface
74+
structured_op: !LinalgStructuredOpConfig
75+
args:
76+
- !<LinalgTensorDef>
77+
name: A
78+
usage: input
79+
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s3)>
80+
element_type_var: T1
81+
- !<LinalgTensorDef>
82+
name: B
83+
usage: input
84+
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s3, s2)>
85+
element_type_var: T2
86+
- !<LinalgTensorDef>
87+
name: C
88+
usage: output
89+
shape: affine_map<()[s0, s1, s2, s3] -> (s0, s1, s2)>
90+
element_type_var: U
91+
indexing_maps: !LinalgIndexingMapsConfig
92+
static_indexing_maps:
93+
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d3)>
94+
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d3, d2)>
95+
- affine_map<(d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0, d1, d2)>
96+
iterator_types:
97+
- parallel
98+
- parallel
99+
- parallel
100+
- reduction
101+
assignments:
102+
- !ScalarAssign
103+
arg: C
104+
value: !ScalarExpression
105+
scalar_apply:
106+
fn_name: add
107+
operands:
108+
- !ScalarExpression
109+
scalar_arg: C
110+
- !ScalarExpression
111+
scalar_apply:
112+
fn_name: mul
113+
operands:
114+
- !ScalarExpression
115+
symbolic_cast:
116+
type_var: U
117+
operands:
118+
- !ScalarExpression
119+
scalar_arg: A
120+
- !ScalarExpression
121+
symbolic_cast:
122+
type_var: U
123+
operands:
124+
- !ScalarExpression
125+
scalar_arg: B
126+
--- !LinalgOpConfig
127+
metadata: !LinalgOpMetadata
128+
name: matvec
129+
cpp_op_name: MatvecOp
130+
doc: |-
131+
Performs a matrix-vector multiplication.
132+
133+
Numeric casting is performed on the operands to the inner multiply, promoting
134+
them to the same data type as the accumulator/output.
135+
implements:
136+
- LinalgContractionOpInterface
137+
structured_op: !LinalgStructuredOpConfig
138+
args:
139+
- !<LinalgTensorDef>
140+
name: A
141+
usage: input
142+
shape: affine_map<()[s0, s1] -> (s0, s1)>
143+
element_type_var: T1
144+
- !<LinalgTensorDef>
145+
name: y
146+
usage: input
147+
shape: affine_map<()[s0, s1] -> (s1)>
148+
element_type_var: T2
149+
- !<LinalgTensorDef>
150+
name: x
151+
usage: output
152+
shape: affine_map<()[s0, s1] -> (s0)>
153+
element_type_var: U
154+
indexing_maps: !LinalgIndexingMapsConfig
155+
static_indexing_maps:
156+
- affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
157+
- affine_map<(d0, d1)[s0, s1] -> (d1)>
158+
- affine_map<(d0, d1)[s0, s1] -> (d0)>
159+
iterator_types:
160+
- parallel
161+
- reduction
162+
assignments:
163+
- !ScalarAssign
164+
arg: x
165+
value: !ScalarExpression
166+
scalar_apply:
167+
fn_name: add
168+
operands:
169+
- !ScalarExpression
170+
scalar_arg: x
171+
- !ScalarExpression
172+
scalar_apply:
173+
fn_name: mul
174+
operands:
175+
- !ScalarExpression
176+
symbolic_cast:
177+
type_var: U
178+
operands:
179+
- !ScalarExpression
180+
scalar_arg: A
181+
- !ScalarExpression
182+
symbolic_cast:
183+
type_var: U
184+
operands:
185+
- !ScalarExpression
186+
scalar_arg: y
187+
--- !LinalgOpConfig
188+
metadata: !LinalgOpMetadata
189+
name: vecmat
190+
cpp_op_name: VecmatOp
191+
doc: |-
192+
Performs a vector-matrix multiplacation.
193+
194+
Numeric casting is performed on the operands to the inner multiply, promoting
195+
them to the same data type as the accumulator/output.
196+
implements:
197+
- LinalgContractionOpInterface
198+
structured_op: !LinalgStructuredOpConfig
199+
args:
200+
- !<LinalgTensorDef>
201+
name: y
202+
usage: input
203+
shape: affine_map<()[s0, s1] -> (s1)>
204+
element_type_var: T1
205+
- !<LinalgTensorDef>
206+
name: A
207+
usage: input
208+
shape: affine_map<()[s0, s1] -> (s1, s0)>
209+
element_type_var: T2
210+
- !<LinalgTensorDef>
211+
name: x
212+
usage: output
213+
shape: affine_map<()[s0, s1] -> (s0)>
214+
element_type_var: U
215+
indexing_maps: !LinalgIndexingMapsConfig
216+
static_indexing_maps:
217+
- affine_map<(d0, d1)[s0, s1] -> (d1)>
218+
- affine_map<(d0, d1)[s0, s1] -> (d1, d0)>
219+
- affine_map<(d0, d1)[s0, s1] -> (d0)>
220+
iterator_types:
221+
- parallel
222+
- reduction
223+
assignments:
224+
- !ScalarAssign
225+
arg: x
226+
value: !ScalarExpression
227+
scalar_apply:
228+
fn_name: add
229+
operands:
230+
- !ScalarExpression
231+
scalar_arg: x
232+
- !ScalarExpression
233+
scalar_apply:
234+
fn_name: mul
235+
operands:
236+
- !ScalarExpression
237+
symbolic_cast:
238+
type_var: U
239+
operands:
240+
- !ScalarExpression
241+
scalar_arg: y
242+
- !ScalarExpression
243+
symbolic_cast:
244+
type_var: U
245+
operands:
246+
- !ScalarExpression
247+
scalar_arg: A
248+
--- !LinalgOpConfig
249+
metadata: !LinalgOpMetadata
250+
name: dot
251+
cpp_op_name: DotOp
252+
doc: |-
253+
Performs a dot product of two vectors to a scalar result.
254+
255+
Numeric casting is performed on the operands to the inner multiply, promoting
256+
them to the same data type as the accumulator/output.
257+
implements:
258+
- LinalgContractionOpInterface
259+
structured_op: !LinalgStructuredOpConfig
260+
args:
261+
- !<LinalgTensorDef>
262+
name: A
263+
usage: input
264+
shape: affine_map<()[s0] -> (s0)>
265+
element_type_var: T1
266+
- !<LinalgTensorDef>
267+
name: B
268+
usage: input
269+
shape: affine_map<()[s0] -> (s0)>
270+
element_type_var: T2
271+
- !<LinalgTensorDef>
272+
name: C
273+
usage: output
274+
shape: affine_map<()[s0] -> ()>
275+
element_type_var: U
276+
indexing_maps: !LinalgIndexingMapsConfig
277+
static_indexing_maps:
278+
- affine_map<(d0)[s0] -> (d0)>
279+
- affine_map<(d0)[s0] -> (d0)>
280+
- affine_map<(d0)[s0] -> ()>
281+
iterator_types:
282+
- reduction
283+
assignments:
284+
- !ScalarAssign
285+
arg: C
286+
value: !ScalarExpression
287+
scalar_apply:
288+
fn_name: add
289+
operands:
290+
- !ScalarExpression
291+
scalar_arg: C
292+
- !ScalarExpression
293+
scalar_apply:
294+
fn_name: mul
295+
operands:
296+
- !ScalarExpression
297+
symbolic_cast:
298+
type_var: U
299+
operands:
300+
- !ScalarExpression
301+
scalar_arg: A
302+
- !ScalarExpression
303+
symbolic_cast:
304+
type_var: U
305+
operands:
306+
- !ScalarExpression
307+
scalar_arg: B
63308

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOpsSpec.tc

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,3 @@
1-
ods_def<MatmulOp>
2-
implements_interface<LinalgContractionOpInterface> :
3-
def matmul(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
4-
C(m, n) = std_addf<k>(C(m, n), std_mulf(A(m, k), B(k, n)));
5-
}
6-
71
ods_def<MatmulColumnMajorOp>
82
implements_interface<LinalgContractionOpInterface> :
93
def matmul_column_major(A: f32(K, M), B: f32(N, K)) -> (C: f32(N, M)) {
@@ -30,12 +24,6 @@ def matmul_i32_i32_i32(A: i32(M, K), B: i32(K, N)) -> (C: i32(M, N)) {
3024
C(m, n) = std_addi<k>(C(m, n), std_muli(A(m, k), B(k, n)));
3125
}
3226

33-
ods_def<MatvecOp>
34-
implements_interface<LinalgContractionOpInterface> :
35-
def matvec(A: f32(M, N), y: f32(N)) -> (x: f32(M)) {
36-
x(m) = std_addf<n>(x(m), std_mulf(A(m, n), y(n)));
37-
}
38-
3927
ods_def<MatvecI8I8I32Op>
4028
implements_interface<LinalgContractionOpInterface> :
4129
def matvec_i8_i8_i32(A: i8(M, N), y: i8(N)) -> (x: i32(M)) {
@@ -54,12 +42,6 @@ def matvec_i32_i32_i32(A: i32(M, N), y: i32(N)) -> (x: i32(M)) {
5442
x(m) = std_addi<n>(x(m), std_muli(A(m, n), y(n)));
5543
}
5644

57-
ods_def<VecmatOp>
58-
implements_interface<LinalgContractionOpInterface> :
59-
def vecmat(y: f32(M), A: f32(M, N)) -> (x: f32(N)) {
60-
x(n) = std_addf<m>(x(n), std_mulf(y(m), A(m, n)));
61-
}
62-
6345
ods_def<VecmatI8I8I32Op>
6446
implements_interface<LinalgContractionOpInterface> :
6547
def vecmat_i8_i8_i32(y: i8(M), A: i8(M, N)) -> (x: i32(N)) {
@@ -78,12 +60,6 @@ def vecmat_i32_i32_i32(y: i32(M), A: i32(M, N)) -> (x: i32(N)) {
7860
x(n) = std_addi<m>(x(n), std_muli(y(m), A(m, n)));
7961
}
8062

81-
ods_def<DotOp>
82-
implements_interface<LinalgContractionOpInterface> :
83-
def dot(A: f32(M), B: f32(M)) -> (C: f32()) {
84-
C() = std_addf<m>(C(), std_mulf(A(m), B(m)));
85-
}
86-
8763
ods_def<DotI8I8I32Op>
8864
implements_interface<LinalgContractionOpInterface> :
8965
def dot_i8_i8_i32(A: i8(M), B: i8(M)) -> (C: i32()) {
@@ -102,12 +78,6 @@ def dot_i32_i32_i32(A: i32(M), B: i32(M)) -> (C: i32()) {
10278
C() = std_addi<m>(C(), std_muli(A(m), B(m)));
10379
}
10480

105-
ods_def<BatchMatmulOp>
106-
implements_interface<LinalgContractionOpInterface> :
107-
def batch_matmul(A: f32(Batch, M, K), B: f32(Batch, K, N)) -> (C: f32(Batch, M, N)) {
108-
C(b, m, n) = std_addf<k>(C(b, m, n), std_mulf(A(b, m, k), B(b, k, n)));
109-
}
110-
11181
ods_def<BatchMatmulI8I8I32Op>
11282
implements_interface<LinalgContractionOpInterface> :
11383
def batch_matmul_i8_i8_i32(A: i8(Batch, M, K), B: i8(Batch, K, N)) -> (C: i32(Batch, M, N)) {

0 commit comments

Comments
 (0)