Commit 78dc1e4
Author: Tobias Gysi
[mlir][linalg][python] Add shape-only tensor support to OpDSL.
Add an index_dim annotation to specify the shape-to-loop mapping of shape-only tensors. A shape-only tensor is not accessed within the body of the operation but is required to span the iteration space of certain operations such as pooling.

Differential Revision: https://reviews.llvm.org/D104767
1 parent e0f2744 commit 78dc1e4
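The core of the change is a single new TensorDef parameter. A minimal excerpt, taken verbatim from the pooling definition added to core_named_ops.py below (S.KH/S.KW are OpDSL size symbols, D.kh/D.kw loop dimensions):

    # Shape-only tensor: K is never read in the body, but its KH x KW shape
    # must span the window loops, so index_dims maps shape to loop dimensions.
    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw])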

File tree: 8 files changed, +345 -40 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 78 additions & 3 deletions
@@ -309,21 +309,25 @@ structured_op: !LinalgStructuredOpConfig
 metadata: !LinalgOpMetadata
   name: depthwise_conv_2d_input_nhwc_filter_hwc_poly
   cpp_class_name: DepthwiseConv2DInputNhwcFilterHwcPolyOp
-  doc: A depth-wise 2-D convolution operation.
+  doc: |-
+    Performs depth-wise 2-D convolution.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
 structured_op: !LinalgStructuredOpConfig
   args:
   - !LinalgOperandDefConfig
     name: I
     usage: InputOperand
     type_var: T1
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s0, s6, s7, s3)>
+      (s0, s4, s5, s3)>
   - !LinalgOperandDefConfig
     name: K
     usage: InputOperand
     type_var: T2
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
-      (s4, s5, s3)>
+      (s6, s7, s3)>
   - !LinalgOperandDefConfig
     name: O
     usage: OutputOperand
@@ -383,6 +387,77 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: pooling_nhwc_sum_poly
+  cpp_class_name: PoolingNhwcSumPolyOp
+  doc: |-
+    Performs sum pooling.
+
+    Numeric casting is performed on the input operand, promoting it to the same
+    data type as the accumulator/output.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    usage: InputOperand
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s4, s5, s3)>
+  - !LinalgOperandDefConfig
+    name: K
+    usage: InputOperand
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s10, s11)>
+  - !LinalgOperandDefConfig
+    name: O
+    usage: OutputOperand
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s2, s3)>
+  - !LinalgOperandDefConfig
+    name: strides
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s6, s7)>
+  - !LinalgOperandDefConfig
+    name: dilations
+    usage: IndexAttribute
+    type_var: I64
+    attribute_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s8, s9)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d2, d3 * s6 + d0 * s8, d4 * s7 + d1 * s9, d5)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d0, d1)>
+    - affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9,
+      s10, s11] -> (d2, d3, d4, d5)>
+  iterator_types:
+  - reduction
+  - reduction
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_apply:
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          symbolic_cast:
+            type_var: U
+            operands:
+            - !ScalarExpression
+              scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: fill_rng_2d
   cpp_class_name: FillRng2DOp
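Note: LinalgNamedStructuredOps.yaml is generated from the Python definitions in core_named_ops.py (changed below), so this YAML hunk mirrors the Python changes. Per the Linalg OpDSL docs, the file is regenerated with a command along the lines of:

    python -m mlir.dialects.linalg.opdsl.dump_oplib .ops.core_named_ops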

mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Lines changed: 17 additions & 4 deletions
@@ -151,13 +151,15 @@ class OperandDef:
   def __init__(self,
                kind: OperandKind,
                type_var: TypeVar,
-               size_exprs: Optional[Sequence[AffineExprDef]] = None):
+               size_exprs: Optional[Sequence[AffineExprDef]] = None,
+               index_dims: Optional[Sequence[DimDef]] = None):
     if not isinstance(type_var, TypeVar):
       raise ValueError(
           f"OperandDef requires a TypeVar but got {repr(type_var)}")
     self.owner = None  # type: Optional["LinalgOpDef"]
     self.type_var = type_var
     self.size_exprs = size_exprs
+    self.index_dims = index_dims
     self.kind = kind
     self.name = None  # type: Optional[str]
     self.registered_index = -1  # type: int
@@ -174,7 +176,8 @@ def __hash__(self):
 
   def __repr__(self):
     return (f"{self.name}:OperandDef(kind={self.kind.name}, "
-            f"type={repr(self.type_var)}, size_exprs={self.size_exprs})")
+            f"type={repr(self.type_var)}, size_exprs={self.size_exprs}), "
+            f"index_dims={self.index_dims})")
 
 
 class TensorDef:
@@ -184,15 +187,25 @@ class TensorDef:
   to the body of the structured op. A unique name identifies the tensor operands
   and an index determines their position in the operation's parameter list. A
   tensor definition takes type, a shape, and an optional flag to mark output
-  tensors.
+  tensors. Additionally, a tuple of index dimensions may be used to map the
+  tensor to the loop dimensions of the operation. This mapping is needed to
+  compute the indexing map of shape-only tensors that have no uses.
   """
 
   def __init__(self,
                type_var: TypeVar,
                *shape: AffineExprDef,
+               index_dims: Optional[Sequence[DimDef]] = None,
                output: bool = False):
+    if index_dims and len(shape) != len(index_dims):
+      raise ValueError(f"Expected the shape rank {len(shape)} to match the "
+                       f"number of index_dims {len(index_dims)}")
+    if index_dims and any(not isinstance(dim, DimDef) for dim in index_dims):
+      raise ValueError(f"TensorDef requires index dims of type DimDef but "
+                       f"got {type(index_dims)}")
     kind = OperandKind.OutputTensor if output else OperandKind.InputTensor
-    self.operand_def = OperandDef(kind, type_var, size_exprs=shape)
+    self.operand_def = OperandDef(
+        kind, type_var, size_exprs=shape, index_dims=index_dims)
 
   def __getitem__(self, dims) -> TensorUse:
     assert self.operand_def.owner, "TensorDef is not attached to an op"
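A minimal sketch of the new TensorDef validation in isolation, assuming the usual wildcard import used by the OpDSL tests; the error messages follow the checks added above:

    from mlir.dialects.linalg.opdsl.lang import *

    # Valid: rank-2 shape (KH, KW) mapped onto the two window loops kh and kw.
    K = TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw])

    # Rank mismatch: two shape symbols but a single index dimension.
    try:
      TensorDef(T2, S.KH, S.KW, index_dims=[D.kh])
    except ValueError as e:
      print(e)  # Expected the shape rank 2 to match the number of index_dims 1

    # Type mismatch: index_dims entries must be DimDef, not SymbolDef.
    try:
      TensorDef(T2, S.KH, S.KW, index_dims=[S.KH, S.KW])
    except ValueError as e:
      print(e)  # TensorDef requires index dims of type DimDef but got ...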

mlir/python/mlir/dialects/linalg/opdsl/lang/config.py

Lines changed: 26 additions & 1 deletion
@@ -138,26 +138,39 @@ def __init__(self,
       read_use.collect_scalar_uses(collected_scalar_uses)
       read_use.collect_indices(collected_indices)
 
-    # Collect all attribute definitions
+    # Collect all attribute definitions.
     collected_attr_defs = list()
     for operand in registered_operands:
       if operand.kind == OperandKind.Attribute:
         collected_attr_defs.append(operand)
 
+    # Collect all tensors with manual indexing annotation.
+    collected_index_defs = list()
+    for operand in registered_operands:
+      if operand.index_dims:
+        collected_index_defs.append(operand)
+
     # Add all definitions before uses, so process twice.
     for use in collected_tensor_uses:
       self.add_operand(use.operand_def)
     for use in collected_scalar_uses:
       self.add_operand(use.operand_def)
     for definition in collected_attr_defs:
       self.add_operand(definition)
+    for definition in collected_index_defs:
+      if definition not in self.operands:
+        self.add_operand(definition)
+      self.add_indexed_operand(definition)
     for use in collected_tensor_uses:
       self.add_tensor_use(use)
 
     # Normalize all shape and indexing maps now that full count of dims and
     # symbols are known.
     for cuse in self.uses.values():
       cuse.indexing_map = self._normalize_affine_map(cuse.indexing_map)
+    for definition in collected_index_defs:
+      self.operands[definition].indexing_map = self._normalize_affine_map(
+          self.operands[definition].indexing_map)
     for operand_config in self.operands.values():
       if operand_config.shape_map:
         operand_config.shape_map = self._normalize_affine_map(
@@ -278,6 +291,18 @@ def add_operand(self, operand_def: OperandDef):
       self.operands[operand_def] = OperandDefConfig(
           operand_def, shape_map=affine_map)
 
+  def add_indexed_operand(self, operand_def: OperandDef):
+    with self.context:
+      local_state = AffineBuildState(
+          global_state=self.affine_state, allow_new_symbols=False)
+      exprs = []
+      for expr in operand_def.index_dims:
+        exprs.append(expr.build(state=local_state))
+      self.operands[operand_def].indexing_map = _ir.AffineMap.get(
+          dim_count=local_state.dim_count,
+          symbol_count=local_state.symbol_count,
+          exprs=exprs)
+
   def add_tensor_use(self, tensor_use: TensorUse):
     if tensor_use in self.uses:
       return
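Concretely, for the pooling example K's index_dims [D.kh, D.kw] become the static indexing map (d0, d1) in the YAML above, since the op's loops normalize to d0=kh, d1=kw, d2=n, d3=oh, d4=ow, d5=c. A standalone toy illustration of that resolution (plain Python, not the OpDSL API):

    # Toy model: the op's normalized loop order, d0..d5, as in the YAML above.
    loop_order = ["kh", "kw", "n", "oh", "ow", "c"]

    def toy_indexing_map(index_dims):
      # Each annotated dimension resolves to its position in the loop order.
      return tuple(f"d{loop_order.index(d)}" for d in index_dims)

    print(toy_indexing_map(["kh", "kw"]))  # ('d0', 'd1'), i.e. K's map (d0, d1)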

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 21 additions & 1 deletion
@@ -81,12 +81,32 @@ def depthwise_conv_2d_input_nhwc_filter_hwc_poly(
     O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
     strides=AttributeDef(S.SH, S.SW),
     dilations=AttributeDef(S.DH, S.DW)):
-  """A depth-wise 2-D convolution operation."""
+  """Performs depth-wise 2-D convolution.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
   O[D.n, D.oh, D.ow, D.c] += cast(
       U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
            D.c]) * cast(U, K[D.kh, D.kw, D.c])
 
 
+@linalg_structured_op
+def pooling_nhwc_sum_poly(
+    I=TensorDef(T1, S.N, S.H, S.W, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs sum pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  O[D.n, D.oh, D.ow, D.c] += cast(
+      U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c])
+
+
 @linalg_structured_op
 def fill_rng_2d(
     min=ScalarDef(F64),
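The annotation is not specific to sum pooling. As a hypothetical sketch (not part of this commit, and assuming OpDSL's ReduceFn.max reduction), a max variant would reuse the same shape-only window tensor:

    @linalg_structured_op
    def pooling_nhwc_max_poly(
        I=TensorDef(T1, S.N, S.H, S.W, S.C),
        # Same shape-only window tensor as in the sum variant above.
        K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
        O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
        strides=AttributeDef(S.SH, S.SW),
        dilations=AttributeDef(S.DH, S.DW)):
      """Performs max pooling (hypothetical sketch)."""
      O[D.n, D.oh, D.ow, D.c] = ReduceFn.max[D.kh, D.kw](
          cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH,
                    D.ow * S.SW + D.kw * S.DW, D.c]))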

mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Lines changed: 28 additions & 0 deletions
@@ -60,6 +60,34 @@ func @generalize_depthwise_conv_2d_input_nhwc_filter_hwc_poly_i32(%input : tenso
 
 // -----
 
+func @generalize_pooling_nhwc_sum_poly_f32(%input : tensor<1x4x16x1xf32>, %shape: tensor<2x2xf32>, %output: tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32> {
+  %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xf32>, tensor<2x2xf32>) outs(%output : tensor<1x2x4x1xf32>) -> tensor<1x2x4x1xf32>
+  return %0: tensor<1x2x4x1xf32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_sum_poly_f32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: f32, %[[SHAPE_ARG:.+]]: f32, %[[OUT_ARG:.+]]: f32)
+// CHECK-NEXT: %[[ADD:.+]] = addf %[[OUT_ARG]], %[[IN_ARG]] : f32
+// CHECK-NEXT: linalg.yield %[[ADD]] : f32
+// CHECK-NEXT: -> tensor<1x2x4x1xf32>
+
+// -----
+
+func @generalize_pooling_nhwc_sum_poly_i32(%input : tensor<1x4x16x1xi32>, %shape: tensor<2x2xi32>, %output: tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32> {
+  %0 = linalg.pooling_nhwc_sum_poly {dilations = dense<[1, 2]> : tensor<2xi64>, strides = dense<[2, 4]> : tensor<2xi64>}
+    ins(%input, %shape : tensor<1x4x16x1xi32>, tensor<2x2xi32>) outs(%output : tensor<1x2x4x1xi32>) -> tensor<1x2x4x1xi32>
+  return %0: tensor<1x2x4x1xi32>
+}
+
+// CHECK-LABEL: @generalize_pooling_nhwc_sum_poly_i32
+// CHECK: ^{{.*}}(%[[IN_ARG:.+]]: i32, %[[SHAPE_ARG:.+]]: i32, %[[OUT_ARG:.+]]: i32)
+// CHECK-NEXT: %[[ADD:.+]] = addi %[[OUT_ARG]], %[[IN_ARG]] : i32
+// CHECK-NEXT: linalg.yield %[[ADD]] : i32
+// CHECK-NEXT: -> tensor<1x2x4x1xi32>
+
+// -----
+
 func @generalize_fill_rng_2d_f32(%min: f64, %max: f64, %seed: i32, %O: tensor<16x32xf32>) -> tensor<16x32xf32> {
   %0 = linalg.fill_rng_2d ins(%min, %max, %seed: f64, f64, i32) outs(%O : tensor<16x32xf32>) -> tensor<16x32xf32>
   return %0: tensor<16x32xf32>
