Skip to content

Commit b9f492c

Browse files
committed
[MLIR][Affine] Fix fusion crash from memory space int assumption
Fix fusion crash from memory space int assumption from assumption on int attr-based memory spaces. Fixes: #118759
1 parent 298caeb commit b9f492c

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -339,14 +339,14 @@ static Value createPrivateMemRef(AffineForOp forOp,
339339
auto eltSize = getMemRefIntOrFloatEltSizeInBytes(oldMemRefType);
340340
assert(eltSize && "memrefs with size elt types expected");
341341
uint64_t bufSize = *eltSize * *numElements;
342-
unsigned newMemSpace;
342+
Attribute newMemSpace;
343343
if (bufSize <= localBufSizeThreshold && fastMemorySpace.has_value()) {
344-
newMemSpace = *fastMemorySpace;
344+
newMemSpace = b.getI64IntegerAttr(*fastMemorySpace);
345345
} else {
346-
newMemSpace = oldMemRefType.getMemorySpaceAsInt();
346+
newMemSpace = oldMemRefType.getMemorySpace();
347347
}
348348
auto newMemRefType = MemRefType::get(newShape, oldMemRefType.getElementType(),
349-
{}, newMemSpace);
349+
/*map=*/AffineMap(), newMemSpace);
350350

351351
// Create new private memref for fused loop 'forOp'. 'newShape' is always
352352
// a constant shape.

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,33 @@ func.func @memref_index_type() {
391391
// PRODUCER-CONSUMER-MAXIMAL: return
392392
return
393393
}
394+
395+
#map = affine_map<(d0) -> (d0)>
396+
#map1 =affine_map<(d0) -> (d0 + 1)>
397+
398+
// Test non-integer memory spaces.
399+
400+
// PRODUCER-CONSUMER-LABEL: func @non_int_memory_space
401+
func.func @non_int_memory_space() {
402+
%alloc = memref.alloc() : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
403+
affine.for %arg0 = 0 to 64 {
404+
affine.for %arg1 = 0 to 8 {
405+
%0 = affine.apply #map(%arg1)
406+
%1 = affine.load %alloc[%arg0, %0] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
407+
affine.store %1, %alloc[%arg0, %arg1] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
408+
}
409+
}
410+
affine.for %arg0 = 16 to 32 {
411+
affine.for %arg1 = 0 to 8 {
412+
%0 = affine.apply #map(%arg1)
413+
%1 = affine.load %alloc[%arg0, %0] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
414+
affine.store %1, %alloc[%arg0, %arg1] : memref<256x8xf32, #spirv.storage_class<StorageBuffer>>
415+
}
416+
}
417+
// Fused nest.
418+
// PRODUCER-CONSUMER-NEXT: memref.alloc()
419+
// PRODUCER-CONSUMER-NEXT: memref.alloc()
420+
// PRODUCER-CONSUMER: affine.for %{{.*}} = 16 to 32
421+
// PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 8
422+
return
423+
}

0 commit comments

Comments
 (0)