Skip to content

Commit f345f7e

Browse files
author
gysit
committed
[mlir][OpDSL] Support pointwise ops with rank zero inputs.
Allow pointwise operations to take rank zero input tensors, similarly to scalar inputs. Use an empty indexing map to broadcast rank zero tensors to the iteration domain of the operation.

Depends On D120734

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D120807
1 parent 52f7578 commit f345f7e

File tree

5 files changed

+42
-10
lines changed

5 files changed

+42
-10
lines changed

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,11 @@ def emit_generic_structured_op(op_config: LinalgStructuredOpConfig, *ins: Value,
187187
if arg_def.operand_def.kind == OperandKind.SCALAR:
188188
indexing_maps.append(scalar_map)
189189
if arg_def.operand_def.is_tensor():
190-
indexing_maps.append(tensor_map)
190+
idx = arg_def.operand_def.registered_index
191+
if idx < len(ins) and ShapedType(ins[idx].type).rank == 0:
192+
indexing_maps.append(scalar_map)
193+
else:
194+
indexing_maps.append(tensor_map)
191195
indexing_maps_attr = ArrayAttr.get(
192196
[AffineMapAttr.get(am) for am in indexing_maps])
193197

mlir/test/Dialect/Linalg/generalize-named-polymorphic-ops.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,3 +320,18 @@ func @generalize_elemwise_mul(%lhs : tensor<4x8xf32>, %rhs : tensor<4x8xf32>, %o
320320

321321
// CHECK-LABEL: @generalize_elemwise_mul
322322
// CHECK: = arith.mulf
323+
324+
// -----
325+
326+
// Verifies pointwise ops support rank zero input tensors
327+
func @generalize_elemwise_rank_zero(%lhs : tensor<f32>, %rhs : tensor<f32>, %output : tensor<4x8xf32>) -> tensor<4x8xf32> {
328+
%0 = linalg.elemwise_binary {fun = #linalg.binary_fn<sub>}
329+
ins(%lhs, %rhs: tensor<f32>, tensor<f32>)
330+
outs(%output: tensor<4x8xf32>) -> tensor<4x8xf32>
331+
return %0: tensor<4x8xf32>
332+
}
333+
334+
// CHECK-LABEL: @generalize_elemwise_rank_zero
335+
// CHECK: linalg.generic
336+
// CHECK-SAME: iterator_types = ["parallel", "parallel"]
337+
// CHECK: = arith.subf

mlir/test/python/dialects/linalg/opdsl/emit_fill.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
1616
O[None] = TypeFn.cast_signed(U, value)
1717

18+
@linalg_structured_op
19+
def fill_rank_zero_poly(I=TensorDef(T1), O=TensorDef(U, output=True)):
20+
O[None] = TypeFn.cast_signed(U, I[None])
1821

1922
with Context() as ctx, Location.unknown():
2023
module = Module.create()
@@ -25,6 +28,8 @@ def fill_poly(value=ScalarDef(T1), O=TensorDef(U, output=True)):
2528
# CHECK-DAG: #[[$MAP0:.+]] = affine_map<() -> ()>
2629
# CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
2730
# CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
31+
# CHECK-DAG: #[[$MAP3:.+]] = affine_map<(d0, d1, d2) -> ()>
32+
# CHECK-DAG: #[[$MAP4:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
2833

2934
# CHECK-LABEL: @test_fill_0d
3035
# CHECK: linalg.generic
@@ -42,5 +47,13 @@ def test_fill_0d(value, init_result):
4247
def test_fill_2d(value, init_result):
4348
return fill_poly(value, outs=[init_result])
4449

50+
# CHECK-LABEL: @test_fill_rank_zero_3d
51+
# CHECK: linalg.generic
52+
# CHECK-SAME: indexing_maps = [#[[$MAP3]], #[[$MAP4]]]
53+
# CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
54+
@builtin.FuncOp.from_py_func(
55+
RankedTensorType.get([], f32), RankedTensorType.get([4, 8, 16], f32))
56+
def test_fill_rank_zero_3d(input, init_result):
57+
return fill_rank_zero_poly(input, outs=[init_result])
4558

4659
print(module)

mlir/test/python/integration/dialects/linalg/opsrun.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,19 @@ def log(*args):
2525
%v1 = arith.constant 1.0 : f32
2626
%v2 = arith.constant 2.0 : f32
2727
28-
%lhs = memref.alloc() : memref<4x8xf32>
28+
%lhs = memref.alloc() : memref<f32>
2929
%rhs = memref.alloc() : memref<4x8xf32>
3030
%O0 = memref.alloc() : memref<4x8xf32>
3131
%O1 = memref.alloc() : memref<4x8xf32>
32-
linalg.fill(%v1, %lhs) : f32, memref<4x8xf32>
32+
linalg.fill(%v1, %lhs) : f32, memref<f32>
3333
linalg.fill(%v2, %rhs) : f32, memref<4x8xf32>
3434
linalg.fill(%v0, %O0) : f32, memref<4x8xf32>
3535
linalg.fill(%v0, %O1) : f32, memref<4x8xf32>
3636
3737
call @elemwise_exp_add_on_buffers(%lhs, %rhs, %O0) :
38-
(memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
38+
(memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
3939
call @elemwise_log_mul_on_buffers(%lhs, %rhs, %O1) :
40-
(memref<4x8xf32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
40+
(memref<f32>, memref<4x8xf32>, memref<4x8xf32>) -> ()
4141
4242
%c0 = arith.constant 0 : index
4343
%res0 = memref.load %O0[%c0, %c0] : memref<4x8xf32>
@@ -212,14 +212,14 @@ def test_elemwise_builtin():
212212
with InsertionPoint(module.body):
213213

214214
@builtin.FuncOp.from_py_func(
215-
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
215+
MemRefType.get((), f32), MemRefType.get((4, 8), f32),
216216
MemRefType.get((4, 8), f32))
217217
def elemwise_exp_add_on_buffers(lhs, rhs, out):
218218
linalg.elemwise_unary(lhs, outs=[out])
219219
linalg.elemwise_binary(out, rhs, outs=[out])
220220

221221
@builtin.FuncOp.from_py_func(
222-
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
222+
MemRefType.get((), f32), MemRefType.get((4, 8), f32),
223223
MemRefType.get((4, 8), f32))
224224
def elemwise_log_mul_on_buffers(lhs, rhs, out):
225225
linalg.elemwise_unary(lhs, outs=[out], fun=UnaryFn.log)
@@ -251,14 +251,14 @@ def test_elemwise_generic():
251251
with InsertionPoint(module.body):
252252

253253
@builtin.FuncOp.from_py_func(
254-
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
254+
MemRefType.get((), f32), MemRefType.get((4, 8), f32),
255255
MemRefType.get((4, 8), f32))
256256
def elemwise_exp_add_on_buffers(lhs, rhs, out):
257257
linalg.elemwise_unary(lhs, outs=[out], emit_generic=True)
258258
linalg.elemwise_binary(out, rhs, outs=[out], emit_generic=True)
259259

260260
@builtin.FuncOp.from_py_func(
261-
MemRefType.get((4, 8), f32), MemRefType.get((4, 8), f32),
261+
MemRefType.get((), f32), MemRefType.get((4, 8), f32),
262262
MemRefType.get((4, 8), f32))
263263
def elemwise_log_mul_on_buffers(lhs, rhs, out):
264264
linalg.elemwise_unary(

mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ ArrayAttr {0}::indexing_maps() {{
678678
getNumParallelLoops(), context);
679679
SmallVector<AffineMap> indexingMaps;
680680
for (OpOperand *opOperand : getInputAndOutputOperands())
681-
indexingMaps.push_back(isScalar(opOperand) ? scalarMap : tensorMap);
681+
indexingMaps.push_back(getRank(opOperand) == 0 ? scalarMap : tensorMap);
682682
return Builder(getContext()).getAffineMapArrayAttr(indexingMaps);
683683
}
684684
)FMT";

0 commit comments

Comments
 (0)