Skip to content

Commit 45d8759

Browse files
authored
Emit nuw and nsw for mul and add when lowering to llvm.getelementptr (#140966)
Now that the GEP no-wrap flags are known when lowering to llvm.getelementptr, we can also propagate them to the llvm.mul and llvm.add ops generated for the index computation: nuw is emitted when the GEP carries nuw (no unsigned wrap), and nsw is emitted when the GEP carries nusw (no unsigned-signed wrap). Fixes: iree-org/iree#20483. Signed-off-by: Lin, Peiyong <[email protected]>
1 parent 656d9ba commit 45d8759

File tree

3 files changed

+26
-15
lines changed

3 files changed

+26
-15
lines changed

mlir/lib/Conversion/LLVMCommon/Pattern.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,15 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
7373
Value base =
7474
memRefDescriptor.bufferPtr(rewriter, loc, *getTypeConverter(), type);
7575

76+
LLVM::IntegerOverflowFlags intOverflowFlags =
77+
LLVM::IntegerOverflowFlags::none;
78+
if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
79+
intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
80+
}
81+
if (LLVM::bitEnumContainsAny(noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
82+
intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
83+
}
84+
7685
Type indexType = getIndexType();
7786
Value index;
7887
for (int i = 0, e = indices.size(); i < e; ++i) {
@@ -82,10 +91,12 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
8291
ShapedType::isDynamic(strides[i])
8392
? memRefDescriptor.stride(rewriter, loc, i)
8493
: createIndexAttrConstant(rewriter, loc, indexType, strides[i]);
85-
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride);
94+
increment = rewriter.create<LLVM::MulOp>(loc, increment, stride,
95+
intOverflowFlags);
8696
}
87-
index =
88-
index ? rewriter.create<LLVM::AddOp>(loc, index, increment) : increment;
97+
index = index ? rewriter.create<LLVM::AddOp>(loc, index, increment,
98+
intOverflowFlags)
99+
: increment;
89100
}
90101

91102
Type elementPtrType = memRefDescriptor.getElementPtrType();

mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,8 @@ func.func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) {
175175
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
176176
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
177177
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
178-
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
179-
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
178+
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
179+
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
180180
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
181181
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
182182
%0 = memref.load %mixed[%i, %j] : memref<42x?xf32>
@@ -192,8 +192,8 @@ func.func @dynamic_load(%dynamic : memref<?x?xf32>, %i : index, %j : index) {
192192
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
193193
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
194194
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
195-
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
196-
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
195+
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
196+
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
197197
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
198198
// CHECK-NEXT: llvm.load %[[addr]] : !llvm.ptr -> f32
199199
%0 = memref.load %dynamic[%i, %j] : memref<?x?xf32>
@@ -230,8 +230,8 @@ func.func @dynamic_store(%dynamic : memref<?x?xf32>, %i : index, %j : index, %va
230230
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
231231
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
232232
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
233-
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
234-
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
233+
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
234+
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
235235
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
236236
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
237237
memref.store %val, %dynamic[%i, %j] : memref<?x?xf32>
@@ -247,8 +247,8 @@ func.func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val :
247247
// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]]
248248
// CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
249249
// CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
250-
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64
251-
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] : i64
250+
// CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] overflow<nsw, nuw> : i64
251+
// CHECK-NEXT: %[[off1:.*]] = llvm.add %[[offI]], %[[J]] overflow<nsw, nuw> : i64
252252
// CHECK-NEXT: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
253253
// CHECK-NEXT: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
254254
memref.store %val, %mixed[%i, %j] : memref<42x?xf32>

mlir/test/Conversion/MemRefToLLVM/convert-static-memref-ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,8 @@ func.func @static_load(%static : memref<10x42xf32>, %i : index, %j : index) {
138138
// CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
139139
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
140140
// CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
141-
// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
142-
// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
141+
// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
142+
// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
143143
// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
144144
// CHECK: llvm.load %[[addr]] : !llvm.ptr -> f32
145145
%0 = memref.load %static[%i, %j] : memref<10x42xf32>
@@ -166,8 +166,8 @@ func.func @static_store(%static : memref<10x42xf32>, %i : index, %j : index, %va
166166
// CHECK-DAG: %[[JJ:.*]] = builtin.unrealized_conversion_cast %[[J]]
167167
// CHECK: %[[ptr:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
168168
// CHECK: %[[st0:.*]] = llvm.mlir.constant(42 : index) : i64
169-
// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] : i64
170-
// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] : i64
169+
// CHECK: %[[offI:.*]] = llvm.mul %[[II]], %[[st0]] overflow<nsw, nuw> : i64
170+
// CHECK: %[[off1:.*]] = llvm.add %[[offI]], %[[JJ]] overflow<nsw, nuw> : i64
171171
// CHECK: %[[addr:.*]] = llvm.getelementptr inbounds|nuw %[[ptr]][%[[off1]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
172172
// CHECK: llvm.store %{{.*}}, %[[addr]] : f32, !llvm.ptr
173173

0 commit comments

Comments
 (0)