Skip to content

Commit 3b2f26a

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] NFC : Fix check for scalar case handling in LinalgToLoops
The invertPermutation method does not return a nullptr anymore, but rather returns an empty map for the scalar case. Update the check in LinalgToLoops to reflect this. Also add test case for generating scalar code.
1 parent 03391df commit 3b2f26a

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,8 +652,8 @@ LinalgOpToLoopsImpl<LoopTy, ConcreteOpTy>::doit(Operation *op,
652652
linalgOp.indexing_maps().template getAsRange<AffineMapAttr>();
653653
auto maps =
654654
functional::map([](AffineMapAttr a) { return a.getValue(); }, mapsRange);
655-
auto invertedMap = inversePermutation(concatAffineMaps(maps));
656-
if (!invertedMap) {
655+
AffineMap invertedMap = inversePermutation(concatAffineMaps(maps));
656+
if (invertedMap.isEmpty()) {
657657
LinalgScopedEmitter<IndexedValueTy, ConcreteOpTy>::emitScalarImplementation(
658658
{}, linalgOp);
659659
return LinalgLoops();

mlir/test/Dialect/Linalg/loops.mlir

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -913,3 +913,46 @@ func @generic_const_init(%arg0: memref<?xf32>) {
913913
// CHECKPARALLEL: %[[CONST:.*]] = constant 1.000000e+00 : f32
914914
// CHECKPARALLEL: loop.parallel (%[[i:.*]])
915915
// CHECKPARALLEL: store %[[CONST]], %[[ARG0]]
916+
917+
#scalar_access = [
918+
affine_map<() -> ()>,
919+
affine_map<() -> ()>,
920+
affine_map<() -> ()>
921+
]
922+
#scalar_trait = {
923+
args_in = 2,
924+
args_out = 1,
925+
iterator_types = [],
926+
indexing_maps = #scalar_access,
927+
library_call = "some_external_fn"
928+
}
929+
func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
930+
{
931+
linalg.generic #scalar_trait %arg0, %arg1, %arg2 {
932+
^bb(%a : f32, %b : f32, %c : f32) :
933+
%0 = addf %a, %b : f32
934+
linalg.yield %0 : f32
935+
} : memref<f32>, memref<f32>, memref<f32>
936+
return
937+
}
938+
// CHECKLOOP-LABEL: @scalar_code
939+
// CHECKLOOP-SAME: %[[ARG0]]: memref<f32>
940+
// CHECKLOOP-SAME: %[[ARG1]]: memref<f32>
941+
// CHECKLOOP-SAME: %[[ARG2]]: memref<f32>
942+
// CHECKLOOP-NOT: loop.for
943+
// CHECKLOOP-DAG: load %[[ARG0]][]
944+
// CHECKLOOP-DAG: load %[[ARG1]][]
945+
// CHECKLOOP-DAG: load %[[ARG2]][]
946+
// CHECKLOOP: addf
947+
// CHECKLOOP: store %{{.*}}, %[[ARG2]][]
948+
949+
// CHECKPARALLEL-LABEL: @scalar_code
950+
// CHECKPARALLEL-SAME: %[[ARG0]]: memref<f32>
951+
// CHECKPARALLEL-SAME: %[[ARG1]]: memref<f32>
952+
// CHECKPARALLEL-SAME: %[[ARG2]]: memref<f32>
953+
// CHECKPARALLEL-NOT: loop.for
954+
// CHECKPARALLEL-DAG: load %[[ARG0]][]
955+
// CHECKPARALLEL-DAG: load %[[ARG1]][]
956+
// CHECKPARALLEL-DAG: load %[[ARG2]][]
957+
// CHECKPARALLEL: addf
958+
// CHECKPARALLEL: store %{{.*}}, %[[ARG2]][]

0 commit comments

Comments
 (0)