Skip to content

Commit 5198205

Browse files
authored
[mlir][affine] Make [de]linearize_index a valid source of dims (#138929)
There's a sense in which affine.linearize_index and affine.delinearize_index are special-cases of affine.apply (which get their own ops to enable better code generation and more accurate canonicalization). Therefore, allow these operations to be dimension operands for operations like affine.load just like affine.apply can be.
1 parent af6261b commit 5198205

File tree

3 files changed

+53
-1
lines changed

3 files changed

+53
-1
lines changed

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,8 @@ bool mlir::affine::isValidDim(Value value) {
305305
// *) It is valid as a symbol.
306306
// *) It is an induction variable.
307307
// *) It is the result of an affine apply operation with dimension id operands.
308+
// *) It is the result of a more specialized index transformation (ex.
309+
// delinearize_index or linearize_index) with dimension id operands.
308310
bool mlir::affine::isValidDim(Value value, Region *region) {
309311
// The value must be an index type.
310312
if (!value.getType().isIndex())
@@ -325,6 +327,11 @@ bool mlir::affine::isValidDim(Value value, Region *region) {
325327
// Affine apply operation is ok if all of its operands are ok.
326328
if (auto applyOp = dyn_cast<AffineApplyOp>(op))
327329
return applyOp.isValidDim(region);
330+
// delinearize_index and linearize_index are special forms of apply
331+
// and so are valid dimensions if all their arguments are valid dimensions.
332+
if (isa<AffineDelinearizeIndexOp, AffineLinearizeIndexOp>(op))
333+
return llvm::all_of(op->getOperands(),
334+
[&](Value arg) { return ::isValidDim(arg, region); });
328335
// The dim op is okay if its operand memref/tensor is defined at the top
329336
// level.
330337
if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,34 @@ func.func @dynamic_dimension_index() {
544544

545545
// -----
546546

547+
func.func @dynamic_linearized_index() {
548+
"unknown.region"() ({
549+
%idx = "unknown.test"() : () -> (index)
550+
%memref = "unknown.test"() : () -> memref<?xf32>
551+
%pos = affine.linearize_index [%idx, %idx] by (8) : index
552+
// expected-error@below {{op operand cannot be used as a dimension id}}
553+
affine.load %memref[%pos] : memref<?xf32>
554+
"unknown.terminator"() : () -> ()
555+
}) : () -> ()
556+
return
557+
}
558+
559+
// -----
560+
561+
func.func @dynamic_delinearized_index() {
562+
"unknown.region"() ({
563+
%idx = "unknown.test"() : () -> (index)
564+
%memref = "unknown.test"() : () -> memref<?x?xf32>
565+
%pos0, %pos1 = affine.delinearize_index %idx into (8) : index, index
566+
// expected-error@below {{op operand cannot be used as a dimension id}}
567+
affine.load %memref[%pos0, %pos1] : memref<?x?xf32>
568+
"unknown.terminator"() : () -> ()
569+
}) : () -> ()
570+
return
571+
}
572+
573+
// -----
574+
547575
#map = affine_map<() -> ()>
548576
#map1 = affine_map<() -> (1)>
549577
func.func @no_lower_bound() {

mlir/test/Dialect/Affine/ops.mlir

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,23 @@ func.func @valid_symbol_affine_scope(%n : index, %A : memref<?xf32>) {
148148

149149
// -----
150150

151+
// Test dimension constraints for linearize_index and delinearize_index
152+
153+
// CHECK-LABEL: func @valid_dim_linearize_delinearize
154+
func.func @valid_dim_linearize_delinearize(%m : index, %n : index, %A : memref<?xf32>, %B: memref<?x32x?xf32>) {
155+
affine.for %0 = 0 to %m {
156+
affine.for %1 = 0 to %n {
157+
%load_idx = affine.linearize_index disjoint [%0, %1] by (%m, %n) : index
158+
%store_idx0, %store_idx1 = affine.delinearize_index %n into (32) : index, index
159+
%v = affine.load %A[%load_idx] : memref<?xf32>
160+
affine.store %v, %B[%0, %store_idx1, %store_idx0] : memref<?x32x?xf32>
161+
}
162+
}
163+
return
164+
}
165+
166+
// -----
167+
151168
// Test the fact that module op always provides an affine scope.
152169

153170
%idx = "test.foo"() : () -> (index)
@@ -309,7 +326,7 @@ func.func @linearize_mixed(%index0: index, %index1: index, %index2: index, %basi
309326
module {
310327
func.func @gpu_launch_affine() {
311328
%c1 = arith.constant 1 : index
312-
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
329+
gpu.launch blocks(%arg0, %arg1, %arg2) in (%arg6 = %c1, %arg7 = %c1, %arg8 = %c1)
313330
threads(%arg3, %arg4, %arg5) in (%arg9 = %c1, %arg10 = %c1, %arg11 = %c1) {
314331
%thread_id_x = gpu.thread_id x
315332
%c128 = arith.constant 128 : index

0 commit comments

Comments
 (0)