Skip to content

Commit b6577d3

Browse files
committed
Add AffineScope trait to scf.index_switch
Adding the `AffineScope` trait to `scf.index_switch` solves #64287. This defines a new top level scope for symbols, which proves highly practical as it enables a broader range of things to be represented as sequences of affine loop nests. I suggest that we add this trait to other SCF operations like `scf.for`, `scf.while` as well, as that will be very useful, and should not have any major issues that we cannot solve.
1 parent 19d1da5 commit b6577d3

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1101,7 +1101,7 @@ def WhileOp : SCF_Op<"while",
11011101
// IndexSwitchOp
11021102
//===----------------------------------------------------------------------===//
11031103

1104-
def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects,
1104+
def IndexSwitchOp : SCF_Op<"index_switch", [RecursiveMemoryEffects, AffineScope,
11051105
SingleBlockImplicitTerminator<"scf::YieldOp">,
11061106
DeclareOpInterfaceMethods<RegionBranchOpInterface,
11071107
["getRegionInvocationBounds",

mlir/test/Dialect/Affine/parallelize.mlir

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,22 @@ func.func @iter_arg_memrefs(%in: memref<10xf32>) {
323323
}
324324
return
325325
}
326+
327+
// CHECK-LABEL: @scf_affine_scope
328+
func.func @scf_affine_scope() {
329+
%c0 = arith.constant 0 : index
330+
%0 = tensor.empty(%c0) : tensor<?xi1>
331+
%1 = bufferization.to_memref %0 : memref<?xi1>
332+
%alloc = memref.alloc(%c0) : memref<?xi1>
333+
%2 = scf.index_switch %c0 -> tensor<?x31x6xf16>
334+
default {
335+
%dim = memref.dim %1, %c0 : memref<?xi1>
336+
affine.for %arg0 = 0 to %dim {
337+
%3 = affine.load %1[%arg0] : memref<?xi1>
338+
affine.store %3, %alloc[%arg0] : memref<?xi1>
339+
}
340+
%3 = tensor.empty(%c0) : tensor<?x31x6xf16>
341+
scf.yield %3 : tensor<?x31x6xf16>
342+
}
343+
return
344+
}

0 commit comments

Comments
 (0)