Skip to content

Commit 4078b11

Browse files
authored
[MLIR][Affine] Fix fusion crash for non-int/fp memref elt types (llvm#126829)
Fix assumption on memref elt types being int or float during private memref creation in affine fusion. Fixes: llvm#121020
1 parent 6936fad commit 4078b11

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,9 @@ struct GreedyFusion {
759759
const DenseSet<Value> &srcEscapingMemRefs,
760760
unsigned producerId, unsigned consumerId,
761761
bool removeSrcNode) {
762+
// We can't generate private memrefs if their size can't be computed.
763+
if (!getMemRefIntOrFloatEltSizeInBytes(cast<MemRefType>(memref.getType())))
764+
return false;
762765
const Node *consumerNode = mdg->getNode(consumerId);
763766
// If `memref` is an escaping one, do not create a private memref
764767
// for the below scenarios, since doing so will leave the escaping

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
2+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer fusion-maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
23
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{fusion-maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
34
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(spirv.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=SPIRV
45

@@ -345,3 +346,37 @@ func.func @same_memref_load_multiple_stores(%producer : memref<32xf32>, %produce
345346
// PRODUCER-CONSUMER-NEXT: }
346347
return
347348
}
349+
350+
#map = affine_map<()[s0] -> (s0 + 5)>
351+
#map1 = affine_map<()[s0] -> (s0 + 17)>
352+
353+
// Test with non-int/float memref types.
354+
355+
// PRODUCER-CONSUMER-MAXIMAL-LABEL: func @memref_index_type
356+
func.func @memref_index_type() {
357+
%0 = llvm.mlir.constant(2 : index) : i64
358+
%2 = llvm.mlir.constant(0 : index) : i64
359+
%3 = builtin.unrealized_conversion_cast %2 : i64 to index
360+
%alloc = memref.alloc() {alignment = 64 : i64} : memref<8x18xf32>
361+
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<3xf32>
362+
%alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<3xindex>
363+
affine.for %arg3 = 0 to 3 {
364+
%4 = affine.load %alloc_2[%arg3] : memref<3xindex>
365+
%5 = builtin.unrealized_conversion_cast %4 : index to i64
366+
%6 = llvm.sub %0, %5 : i64
367+
%7 = builtin.unrealized_conversion_cast %6 : i64 to index
368+
affine.store %7, %alloc_2[%arg3] : memref<3xindex>
369+
}
370+
affine.for %arg3 = 0 to 3 {
371+
%4 = affine.load %alloc_2[%arg3] : memref<3xindex>
372+
%5 = affine.apply #map()[%4]
373+
%6 = affine.apply #map1()[%3]
374+
%7 = memref.load %alloc[%5, %6] : memref<8x18xf32>
375+
affine.store %7, %alloc_1[%arg3] : memref<3xf32>
376+
}
377+
// Expect fusion.
378+
// PRODUCER-CONSUMER-MAXIMAL: affine.for
379+
// PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for
380+
// PRODUCER-CONSUMER-MAXIMAL: return
381+
return
382+
}

0 commit comments

Comments
 (0)