Skip to content

Commit 6fa671f

Browse files
authored
[MLIR][Affine] Fix fusion crash from memory space int assumption (#127032)
Fix fusion crash from memory space int assumption from assumption on int attr-based memory spaces. Fixes: #118759
1 parent e8dba3a commit 6fa671f

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,14 +339,14 @@ static Value createPrivateMemRef(AffineForOp forOp,
339339
auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
340340
assert(eltSize && "memrefs with size elt types expected");
341341
uint64_t bufSize = *eltSize * *numElements;
342-
unsigned newMemSpace;
342+
Attribute newMemSpace;
343343
if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
344-
newMemSpace = *fastMemorySpace;
344+
newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
345345
} else {
346-
newMemSpace = oldMemRefType.getMemorySpaceAsInt();
346+
newMemSpace = oldMemRefType.getMemorySpace();
347347
}
348348
auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
349-
{}, newMemSpace);
349+
/*map=*/AffineMap(), newMemSpace);
350350

351351
// Create new private memref for fused loop 'forOp'. 'newShape' is always
352352
// a constant shape.

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,8 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
358358
return
359359
}
360360

361+
// -----
362+
361363
#map = affine_map<()[s0] -> (s0 + 5)>
362364
#map1 = affine_map<()[s0] -> (s0 + 17)>
363365

@@ -391,3 +393,35 @@ func.func @memref_index_type() {
391393
// PRODUCER-CONSUMER-MAXIMAL: return
392394
return
393395
}
396+
397+
// -----
398+
399+
#map = affine_map<(d0) -> (d0)>
400+
#map1 =affine_map<(d0) -> (d0 + 1)>
401+
402+
// Test non-integer memory spaces.
403+
404+
// PRODUCER-CONSUMER-LABEL: func @non_int_memory_space
405+
func.func @non_int_memory_space() {
406+
%alloc = memref.alloc() : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
407+
affine.for %arg0 = 0 to 64 {
408+
affine.for %arg1 = 0 to 8 {
409+
%0 = affine.apply #map(%arg1)
410+
%1 = affine.load %alloc[%arg0, %0] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
411+
affine.store %1, %alloc[%arg0, %arg1] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
412+
}
413+
}
414+
affine.for %arg0 = 16 to 32 {
415+
affine.for %arg1 = 0 to 8 {
416+
%0 = affine.apply #map(%arg1)
417+
%1 = affine.load %alloc[%arg0, %0] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
418+
affine.store %1, %alloc[%arg0, %arg1] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
419+
}
420+
}
421+
// Fused nest.
422+
// PRODUCER-CONSUMER-NEXT: memref.alloc()
423+
// PRODUCER-CONSUMER-NEXT: memref.alloc()
424+
// PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 16 to 32
425+
// PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 8
426+
return
427+
}

0 commit comments

Comments
 (0)