Skip to content

Commit 419c6da

Browse files
authored
[mlir][LLVM] Verify too many indices in GEP verifier (#70174)
The current verifier stopped verification with a success value as soon as a type was encountered that cannot be indexed into. The correct behaviour in this case is to error out as there are too many indices for the element type. Not doing so leads to bad user-experience as an invalid GEP is likely to fail only later during LLVM IR translation. This PR implements the correct verification behaviour. Some tests upstream had to also be fixed as they were creating invalid GEPs. Fixes #70168
1 parent 2399c77 commit 419c6da

File tree

5 files changed

+44
-79
lines changed

5 files changed

+44
-79
lines changed

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 28 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -664,90 +664,51 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
664664
});
665665
}
666666

667-
namespace {
668-
/// Base class for llvm::Error related to GEP index.
669-
class GEPIndexError : public llvm::ErrorInfo<GEPIndexError> {
670-
protected:
671-
unsigned indexPos;
672-
673-
public:
674-
static char ID;
675-
676-
std::error_code convertToErrorCode() const override {
677-
return llvm::inconvertibleErrorCode();
678-
}
679-
680-
explicit GEPIndexError(unsigned pos) : indexPos(pos) {}
681-
};
682-
683-
/// llvm::Error for out-of-bound GEP index.
684-
struct GEPIndexOutOfBoundError
685-
: public llvm::ErrorInfo<GEPIndexOutOfBoundError, GEPIndexError> {
686-
static char ID;
687-
688-
using ErrorInfo::ErrorInfo;
689-
690-
void log(llvm::raw_ostream &os) const override {
691-
os << "index " << indexPos << " indexing a struct is out of bounds";
692-
}
693-
};
694-
695-
/// llvm::Error for non-static GEP index indexing a struct.
696-
struct GEPStaticIndexError
697-
: public llvm::ErrorInfo<GEPStaticIndexError, GEPIndexError> {
698-
static char ID;
699-
700-
using ErrorInfo::ErrorInfo;
701-
702-
void log(llvm::raw_ostream &os) const override {
703-
os << "expected index " << indexPos << " indexing a struct "
704-
<< "to be constant";
705-
}
706-
};
707-
} // end anonymous namespace
708-
709-
char GEPIndexError::ID = 0;
710-
char GEPIndexOutOfBoundError::ID = 0;
711-
char GEPStaticIndexError::ID = 0;
712-
713-
/// For the given `structIndices` and `indices`, check if they're complied
714-
/// with `baseGEPType`, especially check against LLVMStructTypes nested within.
715-
static llvm::Error verifyStructIndices(Type baseGEPType, unsigned indexPos,
716-
GEPIndicesAdaptor<ValueRange> indices) {
667+
/// For the given `indices`, check if they comply with `baseGEPType`,
668+
/// especially check against LLVMStructTypes nested within.
669+
static LogicalResult
670+
verifyStructIndices(Type baseGEPType, unsigned indexPos,
671+
GEPIndicesAdaptor<ValueRange> indices,
672+
function_ref<InFlightDiagnostic()> emitOpError) {
717673
if (indexPos >= indices.size())
718674
// Stop searching
719-
return llvm::Error::success();
675+
return success();
720676

721-
return llvm::TypeSwitch<Type, llvm::Error>(baseGEPType)
722-
.Case<LLVMStructType>([&](LLVMStructType structType) -> llvm::Error {
677+
return TypeSwitch<Type, LogicalResult>(baseGEPType)
678+
.Case<LLVMStructType>([&](LLVMStructType structType) -> LogicalResult {
723679
if (!indices[indexPos].is<IntegerAttr>())
724-
return llvm::make_error<GEPStaticIndexError>(indexPos);
680+
return emitOpError() << "expected index " << indexPos
681+
<< " indexing a struct to be constant";
725682

726683
int32_t gepIndex = indices[indexPos].get<IntegerAttr>().getInt();
727684
ArrayRef<Type> elementTypes = structType.getBody();
728685
if (gepIndex < 0 ||
729686
static_cast<size_t>(gepIndex) >= elementTypes.size())
730-
return llvm::make_error<GEPIndexOutOfBoundError>(indexPos);
687+
return emitOpError() << "index " << indexPos
688+
<< " indexing a struct is out of bounds";
731689

732690
// Instead of recursively going into every children types, we only
733691
// dive into the one indexed by gepIndex.
734692
return verifyStructIndices(elementTypes[gepIndex], indexPos + 1,
735-
indices);
693+
indices, emitOpError);
736694
})
737695
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
738-
LLVMArrayType>([&](auto containerType) -> llvm::Error {
696+
LLVMArrayType>([&](auto containerType) -> LogicalResult {
739697
return verifyStructIndices(containerType.getElementType(), indexPos + 1,
740-
indices);
698+
indices, emitOpError);
741699
})
742-
.Default(
743-
[](auto otherType) -> llvm::Error { return llvm::Error::success(); });
700+
.Default([&](auto otherType) -> LogicalResult {
701+
return emitOpError()
702+
<< "type " << otherType << " cannot be indexed (index #"
703+
<< indexPos << ")";
704+
});
744705
}
745706

746-
/// Driver function around `recordStructIndices`. Note that we always check
747-
/// from the second GEP index since the first one is always dynamic.
748-
static llvm::Error verifyStructIndices(Type baseGEPType,
749-
GEPIndicesAdaptor<ValueRange> indices) {
750-
return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices);
707+
/// Driver function around `verifyStructIndices`.
708+
static LogicalResult
709+
verifyStructIndices(Type baseGEPType, GEPIndicesAdaptor<ValueRange> indices,
710+
function_ref<InFlightDiagnostic()> emitOpError) {
711+
return verifyStructIndices(baseGEPType, /*indexPos=*/1, indices, emitOpError);
751712
}
752713

753714
LogicalResult LLVM::GEPOp::verify() {
@@ -763,11 +724,8 @@ LogicalResult LLVM::GEPOp::verify() {
763724
return emitOpError("expected as many dynamic indices as specified in '")
764725
<< getRawConstantIndicesAttrName().getValue() << "'";
765726

766-
if (llvm::Error err =
767-
verifyStructIndices(getSourceElementType(), getIndices()))
768-
return emitOpError() << llvm::toString(std::move(err));
769-
770-
return success();
727+
return verifyStructIndices(getSourceElementType(), getIndices(),
728+
[&] { return emitOpError(); });
771729
}
772730

773731
Type LLVM::GEPOp::getSourceElementType() {

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1431,3 +1431,11 @@ llvm.func @invalid_variadic_call(%arg: i32) {
14311431
"llvm.call"(%arg) <{callee = @variadic}> : (i32) -> ()
14321432
llvm.return
14331433
}
1434+
1435+
// -----
1436+
1437+
llvm.func @foo(%arg: !llvm.ptr) {
1438+
// expected-error@+1 {{type '!llvm.ptr' cannot be indexed (index #1)}}
1439+
%0 = llvm.getelementptr %arg[0, 4] : (!llvm.ptr) -> !llvm.ptr, !llvm.ptr
1440+
llvm.return
1441+
}

mlir/test/Dialect/LLVMIR/mem2reg.mlir

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,7 @@ llvm.func @trivial_get_element_ptr() {
549549
%1 = llvm.mlir.constant(2 : i64) : i64
550550
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
551551
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
552-
%4 = llvm.getelementptr %3[0, 0, 0] : (!llvm.ptr) -> !llvm.ptr, i8
552+
%4 = llvm.getelementptr %3[0] : (!llvm.ptr) -> !llvm.ptr, i8
553553
llvm.intr.lifetime.start 2, %3 : !llvm.ptr
554554
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
555555
llvm.return
@@ -563,9 +563,8 @@ llvm.func @nontrivial_get_element_ptr() {
563563
%1 = llvm.mlir.constant(2 : i64) : i64
564564
// CHECK: = llvm.alloca
565565
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
566-
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
567-
%4 = llvm.getelementptr %3[0, 1, 0] : (!llvm.ptr) -> !llvm.ptr, i8
568-
llvm.intr.lifetime.start 2, %3 : !llvm.ptr
566+
%4 = llvm.getelementptr %2[1] : (!llvm.ptr) -> !llvm.ptr, i8
567+
llvm.intr.lifetime.start 2, %2 : !llvm.ptr
569568
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
570569
llvm.return
571570
}
@@ -579,7 +578,7 @@ llvm.func @dynamic_get_element_ptr() {
579578
// CHECK: = llvm.alloca
580579
%2 = llvm.alloca %0 x i8 {alignment = 8 : i64} : (i32) -> !llvm.ptr
581580
%3 = llvm.bitcast %2 : !llvm.ptr to !llvm.ptr
582-
%4 = llvm.getelementptr %3[0, %0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
581+
%4 = llvm.getelementptr %3[%0] : (!llvm.ptr, i32) -> !llvm.ptr, i8
583582
llvm.intr.lifetime.start 2, %3 : !llvm.ptr
584583
llvm.intr.lifetime.start 2, %4 : !llvm.ptr
585584
llvm.return

mlir/test/Dialect/LLVMIR/roundtrip-typed-pointers.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@ func.func @ops(%arg0: i32) {
66
// Memory-related operations.
77
//
88
// CHECK-NEXT: %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr<f64>
9-
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
9+
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
1010
// CHECK-NEXT: %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr<f64>
1111
// CHECK-NEXT: llvm.store %[[VALUE]], %[[ALLOCA]] : !llvm.ptr<f64>
1212
// CHECK-NEXT: %{{.*}} = llvm.bitcast %[[ALLOCA]] : !llvm.ptr<f64> to !llvm.ptr<i64>
1313
%13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr<f64>
14-
%14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr<f64>, i32, i32) -> !llvm.ptr<f64>
14+
%14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr<f64>, i32) -> !llvm.ptr<f64>
1515
%15 = llvm.load %14 : !llvm.ptr<f64>
1616
llvm.store %15, %13 : !llvm.ptr<f64>
1717
%16 = llvm.bitcast %13 : !llvm.ptr<f64> to !llvm.ptr<i64>

mlir/test/Dialect/LLVMIR/roundtrip.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,11 @@ func.func @ops(%arg0: i32, %arg1: f32,
5050
// Memory-related operations.
5151
//
5252
// CHECK-NEXT: %[[ALLOCA:.*]] = llvm.alloca %[[I32]] x f64 : (i32) -> !llvm.ptr
53-
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]], %[[I32]]] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
53+
// CHECK-NEXT: %[[GEP:.*]] = llvm.getelementptr %[[ALLOCA]][%[[I32]]] : (!llvm.ptr, i32) -> !llvm.ptr, f64
5454
// CHECK-NEXT: %[[VALUE:.*]] = llvm.load %[[GEP]] : !llvm.ptr -> f64
5555
// CHECK-NEXT: llvm.store %[[VALUE]], %[[ALLOCA]] : f64, !llvm.ptr
5656
%13 = llvm.alloca %arg0 x f64 : (i32) -> !llvm.ptr
57-
%14 = llvm.getelementptr %13[%arg0, %arg0] : (!llvm.ptr, i32, i32) -> !llvm.ptr, f64
57+
%14 = llvm.getelementptr %13[%arg0] : (!llvm.ptr, i32) -> !llvm.ptr, f64
5858
%15 = llvm.load %14 : !llvm.ptr -> f64
5959
llvm.store %15, %13 : f64, !llvm.ptr
6060

0 commit comments

Comments
 (0)