Skip to content

Commit 2dbaac2

Browse files
committed
Update MLIR to support nusw and nuw in GEP.
nusw and nuw were introduced in getelementptr, this patch plumbs them in MLIR. Since inbounds implies nusw, this patch also adds an inboundsFlag to represent the concept of raw inbounds with no nusw implication, and have the inbounds literal captured as the combination of inboundsFlag and nusw. Fixes: iree#20482 Signed-off-by: Lin, Peiyong <[email protected]>
1 parent b6746b0 commit 2dbaac2

File tree

6 files changed

+63
-19
lines changed

6 files changed

+63
-19
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMEnums.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -876,4 +876,32 @@ def UWTableKindEnum : LLVM_EnumAttr<
876876
let cppNamespace = "::mlir::LLVM::uwtable";
877877
}
878878

879+
//===----------------------------------------------------------------------===//
880+
// GEPNoWrapFlags
881+
//===----------------------------------------------------------------------===//
882+
883+
// These values must match llvm::GEPNoWrapFlags ones.
884+
// See llvm/include/llvm/IR/GEPNoWrapFlags.h.
885+
// Since inbounds implies nusw, create an inboundsFlag that represents the
886+
// concept of raw inbounds with no nusw implication and the actual inbounds
887+
// literal will be captured as the combination of inboundsFlag and nusw.
888+
889+
def GEPNone : I32BitEnumCaseNone<"none">;
890+
def GEPInboundsFlag : I32BitEnumCaseBit<"inboundsFlag", 0, "inbounds_flag">;
891+
def GEPNusw : I32BitEnumCaseBit<"nusw", 1>;
892+
def GEPNuw : I32BitEnumCaseBit<"nuw", 2>;
893+
def GEPInbounds : BitEnumCaseGroup<"inbounds", [GEPInboundsFlag, GEPNusw]>;
894+
895+
def GEPNoWrapFlags : I32BitEnum<
896+
"GEPNoWrapFlags",
897+
"::mlir::LLVM::GEPNoWrapFlags",
898+
[GEPNone, GEPInboundsFlag, GEPNusw, GEPNuw, GEPInbounds]> {
899+
let cppNamespace = "::mlir::LLVM";
900+
let printBitEnumPrimaryGroups = 1;
901+
}
902+
903+
def GEPNoWrapFlagsProp : EnumProp<GEPNoWrapFlags> {
904+
let defaultValue = interfaceType # "::none";
905+
}
906+
879907
#endif // LLVMIR_ENUMS

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
291291
Variadic<LLVM_ScalarOrVectorOf<AnySignlessInteger>>:$dynamicIndices,
292292
DenseI32ArrayAttr:$rawConstantIndices,
293293
TypeAttr:$elem_type,
294-
UnitAttr:$inbounds);
294+
GEPNoWrapFlagsProp:$gepNoWrapFlags);
295295
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
296296
let skipDefaultBuilders = 1;
297297

@@ -303,8 +303,10 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
303303
as indices. In the case of indexing within a structure, it is required to
304304
either use constant indices directly, or supply a constant SSA value.
305305

306-
An optional 'inbounds' attribute specifies the low-level pointer arithmetic
306+
Optional attributes can be used to specify the low-level pointer arithmetic
307307
overflow behavior that LLVM uses after lowering the operation to LLVM IR.
308+
The acceptable attributes could be one of or the combination of 'inbounds',
309+
'nusw' or 'nuw'.
308310

309311
Examples:
310312

@@ -323,10 +325,12 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
323325

324326
let builders = [
325327
OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr,
326-
"ValueRange":$indices, CArg<"bool", "false">:$inbounds,
328+
"ValueRange":$indices,
329+
CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$gepNoWrapFlags,
327330
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
328331
OpBuilder<(ins "Type":$resultType, "Type":$elementType, "Value":$basePtr,
329-
"ArrayRef<GEPArg>":$indices, CArg<"bool", "false">:$inbounds,
332+
"ArrayRef<GEPArg>":$indices,
333+
CArg<"GEPNoWrapFlags", "GEPNoWrapFlags::none">:$gepNoWrapFlags,
330334
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
331335
];
332336
let llvmBuilder = [{
@@ -343,10 +347,13 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [Pure,
343347
}
344348
Type baseElementType = op.getElemType();
345349
llvm::Type *elementType = moduleTranslation.convertType(baseElementType);
346-
$res = builder.CreateGEP(elementType, $base, indices, "", $inbounds);
350+
$res = builder.CreateGEP(elementType, $base, indices, "",
351+
llvm::GEPNoWrapFlags::fromRaw(
352+
static_cast<unsigned>(
353+
op.getGepNoWrapFlags())));
347354
}];
348355
let assemblyFormat = [{
349-
(`inbounds` $inbounds^)?
356+
($gepNoWrapFlags^)?
350357
$base `[` custom<GEPIndices>($dynamicIndices, $rawConstantIndices) `]` attr-dict
351358
`:` functional-type(operands, results) `,` $elem_type
352359
}];

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -673,29 +673,29 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices,
673673

674674
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
675675
Type elementType, Value basePtr, ArrayRef<GEPArg> indices,
676-
bool inbounds, ArrayRef<NamedAttribute> attributes) {
676+
GEPNoWrapFlags noWrapFlags,
677+
ArrayRef<NamedAttribute> attributes) {
677678
SmallVector<int32_t> rawConstantIndices;
678679
SmallVector<Value> dynamicIndices;
679680
destructureIndices(elementType, indices, rawConstantIndices, dynamicIndices);
680681

681682
result.addTypes(resultType);
682683
result.addAttributes(attributes);
683-
result.addAttribute(getRawConstantIndicesAttrName(result.name),
684-
builder.getDenseI32ArrayAttr(rawConstantIndices));
685-
if (inbounds) {
686-
result.addAttribute(getInboundsAttrName(result.name),
687-
builder.getUnitAttr());
688-
}
689-
result.addAttribute(kElemTypeAttrName, TypeAttr::get(elementType));
684+
result.getOrAddProperties<Properties>().rawConstantIndices =
685+
builder.getDenseI32ArrayAttr(rawConstantIndices);
686+
result.getOrAddProperties<Properties>().gepNoWrapFlags = noWrapFlags;
687+
result.getOrAddProperties<Properties>().elem_type =
688+
TypeAttr::get(elementType);
690689
result.addOperands(basePtr);
691690
result.addOperands(dynamicIndices);
692691
}
693692

694693
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
695694
Type elementType, Value basePtr, ValueRange indices,
696-
bool inbounds, ArrayRef<NamedAttribute> attributes) {
695+
GEPNoWrapFlags noWrapFlags,
696+
ArrayRef<NamedAttribute> attributes) {
697697
build(builder, result, resultType, elementType, basePtr,
698-
SmallVector<GEPArg>(indices), inbounds, attributes);
698+
SmallVector<GEPArg>(indices), noWrapFlags, attributes);
699699
}
700700

701701
static ParseResult

mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ DeletionKind LLVM::GEPOp::rewire(const DestructurableMemorySlot &slot,
891891
auto byteType = IntegerType::get(builder.getContext(), 8);
892892
auto newPtr = builder.createOrFold<LLVM::GEPOp>(
893893
getLoc(), getResult().getType(), byteType, newSlot.ptr,
894-
ArrayRef<GEPArg>(accessInfo->subslotOffset), getInbounds());
894+
ArrayRef<GEPArg>(accessInfo->subslotOffset), getGepNoWrapFlags());
895895
getResult().replaceAllUsesWith(newPtr);
896896
return DeletionKind::Delete;
897897
}

mlir/lib/Target/LLVMIR/ModuleImport.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1984,8 +1984,9 @@ LogicalResult ModuleImport::convertInstruction(llvm::Instruction *inst) {
19841984
}
19851985

19861986
Type type = convertType(inst->getType());
1987-
auto gepOp = builder.create<GEPOp>(loc, type, sourceElementType, *basePtr,
1988-
indices, gepInst->isInBounds());
1987+
auto gepOp = builder.create<GEPOp>(
1988+
loc, type, sourceElementType, *basePtr, indices,
1989+
gepInst->isInBounds() ? GEPNoWrapFlags::inbounds : GEPNoWrapFlags());
19891990
mapValue(inst, gepOp);
19901991
return success();
19911992
}

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,14 @@ llvm.func @gep(%ptr: !llvm.ptr, %idx: i64,
10571057
llvm.getelementptr %ptr[%idx, 1, 0] : (!llvm.ptr, i64) -> !llvm.ptr, !llvm.struct<(i32, struct<(i32, f32)>)>
10581058
// CHECK: = getelementptr inbounds { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
10591059
llvm.getelementptr inbounds %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1060+
// CHECK: = getelementptr inbounds nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1061+
llvm.getelementptr inbounds | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1062+
// CHECK: = getelementptr nusw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1063+
llvm.getelementptr nusw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1064+
// CHECK: = getelementptr nusw nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1065+
llvm.getelementptr nusw | nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
1066+
// CHECK: = getelementptr nuw { [10 x float] }, ptr %{{.*}}, i64 %{{.*}}, i32 0, i64 %{{.*}}
1067+
llvm.getelementptr nuw %ptr2[%idx, 0, %idx] : (!llvm.ptr, i64, i64) -> !llvm.ptr, !llvm.struct<(array<10 x f32>)>
10601068
llvm.return
10611069
}
10621070

0 commit comments

Comments
 (0)