Skip to content

Commit f50cfc4

Browse files
committed
[mlir] Require struct indices in LLVM::GEPOp to be constant
Recent commits added a possibility for indices in LLVM dialect GEP operations to be supplied directly as constant attributes to ensure they remain such until translation to LLVM IR happens. Make this required for indexing into LLVM struct types to match LLVM IR requirements, otherwise the translation would assert on constructing such IR. For better compatibility with MLIR-style operation construction interface, allow GEP operations to be constructed programmatically using Values pointing to known constant operations as struct indices. Depends On D116758 Reviewed By: wsmoses Differential Revision: https://reviews.llvm.org/D116759
1 parent 43ff4a6 commit f50cfc4

File tree

7 files changed

+146
-9
lines changed

7 files changed

+146
-9
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,9 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
350350
constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
351351
}];
352352
let hasFolder = 1;
353+
let verifier = [{
354+
return ::verify(*this);
355+
}];
353356
}
354357

355358
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 122 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,58 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
360360
// Code for LLVM::GEPOp.
361361
//===----------------------------------------------------------------------===//
362362

363+
/// Populates `indices` with positions of GEP indices that would correspond to
364+
/// LLVMStructTypes potentially nested in the given type. The type currently
365+
/// visited gets `currentIndex` and LLVM container types are visited
366+
/// recursively. The recursion is bounded and takes care of recursive types by
367+
/// means of the `visited` set.
368+
static void recordStructIndices(Type type, unsigned currentIndex,
369+
SmallVectorImpl<unsigned> &indices,
370+
SmallVectorImpl<unsigned> *structSizes,
371+
SmallPtrSet<Type, 4> &visited) {
372+
if (visited.contains(type))
373+
return;
374+
375+
visited.insert(type);
376+
377+
llvm::TypeSwitch<Type>(type)
378+
.Case<LLVMStructType>([&](LLVMStructType structType) {
379+
indices.push_back(currentIndex);
380+
if (structSizes)
381+
structSizes->push_back(structType.getBody().size());
382+
for (Type elementType : structType.getBody())
383+
recordStructIndices(elementType, currentIndex + 1, indices,
384+
structSizes, visited);
385+
})
386+
.Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
387+
LLVMArrayType>([&](auto containerType) {
388+
recordStructIndices(containerType.getElementType(), currentIndex + 1,
389+
indices, structSizes, visited);
390+
});
391+
}
392+
393+
/// Populates `indices` with positions of GEP indices that correspond to
394+
/// LLVMStructTypes potentially nested in the given `baseGEPType`, which must
395+
/// be either an LLVMPointer type or a vector thereof. If `structSizes` is
396+
/// provided, it is populated with sizes of the indexed structs for bounds
397+
/// verification purposes.
398+
static void
399+
findKnownStructIndices(Type baseGEPType, SmallVectorImpl<unsigned> &indices,
400+
SmallVectorImpl<unsigned> *structSizes = nullptr) {
401+
Type type = baseGEPType;
402+
if (auto vectorType = type.dyn_cast<VectorType>())
403+
type = vectorType.getElementType();
404+
if (auto scalableVectorType = type.dyn_cast<LLVMScalableVectorType>())
405+
type = scalableVectorType.getElementType();
406+
if (auto fixedVectorType = type.dyn_cast<LLVMFixedVectorType>())
407+
type = fixedVectorType.getElementType();
408+
409+
Type pointeeType = type.cast<LLVMPointerType>().getElementType();
410+
SmallPtrSet<Type, 4> visited;
411+
recordStructIndices(pointeeType, /*currentIndex=*/1, indices, structSizes,
412+
visited);
413+
}
414+
363415
void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
364416
Value basePtr, ValueRange operands,
365417
ArrayRef<NamedAttribute> attributes) {
@@ -372,11 +424,58 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
372424
Value basePtr, ValueRange indices,
373425
ArrayRef<int32_t> structIndices,
374426
ArrayRef<NamedAttribute> attributes) {
427+
SmallVector<Value> remainingIndices;
428+
SmallVector<int32_t> updatedStructIndices(structIndices.begin(),
429+
structIndices.end());
430+
SmallVector<unsigned> structRelatedPositions;
431+
findKnownStructIndices(basePtr.getType(), structRelatedPositions);
432+
433+
SmallVector<unsigned> operandsToErase;
434+
for (unsigned pos : structRelatedPositions) {
435+
// GEP may not be indexing as deep as some structs are located.
436+
if (pos >= structIndices.size())
437+
continue;
438+
439+
// If the index is already static, it's fine.
440+
if (structIndices[pos] != kDynamicIndex)
441+
continue;
442+
443+
// Find the corresponding operand.
444+
unsigned operandPos =
445+
std::count(structIndices.begin(), std::next(structIndices.begin(), pos),
446+
kDynamicIndex);
447+
448+
// Extract the constant value from the operand and put it into the attribute
449+
// instead.
450+
APInt staticIndexValue;
451+
bool matched =
452+
matchPattern(indices[operandPos], m_ConstantInt(&staticIndexValue));
453+
(void)matched;
454+
assert(matched && "index into a struct must be a constant");
455+
assert(staticIndexValue.sge(APInt::getSignedMinValue(/*numBits=*/32)) &&
456+
"struct index underflows 32-bit integer");
457+
assert(staticIndexValue.sle(APInt::getSignedMaxValue(/*numBits=*/32)) &&
458+
"struct index overflows 32-bit integer");
459+
auto staticIndex = static_cast<int32_t>(staticIndexValue.getSExtValue());
460+
updatedStructIndices[pos] = staticIndex;
461+
operandsToErase.push_back(operandPos);
462+
}
463+
464+
for (unsigned i = 0, e = indices.size(); i < e; ++i) {
465+
if (llvm::find(operandsToErase, i) == operandsToErase.end())
466+
remainingIndices.push_back(indices[i]);
467+
}
468+
469+
assert(remainingIndices.size() == static_cast<size_t>(llvm::count(
470+
updatedStructIndices, kDynamicIndex)) &&
471+
"exected as many index operands as dynamic index attr elements");
472+
375473
result.addTypes(resultType);
376474
result.addAttributes(attributes);
377-
result.addAttribute("structIndices", builder.getI32TensorAttr(structIndices));
475+
result.addAttribute("structIndices",
476+
builder.getI32TensorAttr(updatedStructIndices));
378477
result.addOperands(basePtr);
379-
result.addOperands(indices);
478+
result.addOperands(remainingIndices);
380479
}
381480

382481
static ParseResult
@@ -417,6 +516,27 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
417516
});
418517
}
419518

519+
LogicalResult verify(LLVM::GEPOp gepOp) {
520+
SmallVector<unsigned> indices;
521+
SmallVector<unsigned> structSizes;
522+
findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes);
523+
for (unsigned i = 0, e = indices.size(); i < e; ++i) {
524+
unsigned index = indices[i];
525+
// GEP may not be indexing as deep as some structs nested in the type.
526+
if (index >= gepOp.getStructIndices().getNumElements())
527+
continue;
528+
529+
int32_t staticIndex = gepOp.getStructIndices().getValues<int32_t>()[index];
530+
if (staticIndex == LLVM::GEPOp::kDynamicIndex)
531+
return gepOp.emitOpError() << "expected index " << index
532+
<< " indexing a struct to be constant";
533+
if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
534+
return gepOp.emitOpError()
535+
<< "index " << index << " indexing a struct is out of bounds";
536+
}
537+
return success();
538+
}
539+
420540
//===----------------------------------------------------------------------===//
421541
// Builder, printer and parser for for LLVM::LoadOp.
422542
//===----------------------------------------------------------------------===//

mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -501,8 +501,7 @@ func @memref_reshape(%input : memref<2x3xf32>, %shape : memref<?xindex>) {
501501
// CHECK: [[STRUCT_PTR:%.*]] = llvm.bitcast [[UNDERLYING_DESC]]
502502
// CHECK-SAME: !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<f32>, ptr<f32>, i64, i64)>>
503503
// CHECK: [[C0:%.*]] = llvm.mlir.constant(0 : index) : i64
504-
// CHECK: [[C3_I32:%.*]] = llvm.mlir.constant(3 : i32) : i32
505-
// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], [[C3_I32]]]
504+
// CHECK: [[SIZES_PTR:%.*]] = llvm.getelementptr [[STRUCT_PTR]]{{\[}}[[C0]], 3]
506505
// CHECK: [[STRIDES_PTR:%.*]] = llvm.getelementptr [[SIZES_PTR]]{{\[}}[[RANK]]]
507506
// CHECK: [[SHAPE_IN_PTR:%.*]] = llvm.extractvalue [[SHAPE]][1] : [[SHAPE_TY]]
508507
// CHECK: [[C1_:%.*]] = llvm.mlir.constant(1 : index) : i64

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -547,12 +547,11 @@ func @dim_of_unranked(%unranked: memref<*xi32>) -> index {
547547
// CHECK: %[[ZERO_D_DESC:.*]] = llvm.bitcast %[[RANKED_DESC]]
548548
// CHECK-SAME: : !llvm.ptr<i8> to !llvm.ptr<struct<(ptr<i32>, ptr<i32>, i64)>>
549549

550-
// CHECK: %[[C2_i32:.*]] = llvm.mlir.constant(2 : i32) : i32
551550
// CHECK: %[[C0_:.*]] = llvm.mlir.constant(0 : index) : i64
552551

553552
// CHECK: %[[OFFSET_PTR:.*]] = llvm.getelementptr %[[ZERO_D_DESC]]{{\[}}
554-
// CHECK-SAME: %[[C0_]], %[[C2_i32]]] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
555-
// CHECK-SAME: i64)>>, i64, i32) -> !llvm.ptr<i64>
553+
// CHECK-SAME: %[[C0_]], 2] : (!llvm.ptr<struct<(ptr<i32>, ptr<i32>,
554+
// CHECK-SAME: i64)>>, i64) -> !llvm.ptr<i64>
556555

557556
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64
558557
// CHECK: %[[INDEX_INC:.*]] = llvm.add %[[C1]], %{{.*}} : i64

mlir/test/Conversion/SPIRVToLLVM/memory-ops-to-llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ spv.func @access_chain() "None" {
1010
%0 = spv.Constant 1: i32
1111
%1 = spv.Variable : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>
1212
// CHECK: %[[ZERO:.*]] = llvm.mlir.constant(0 : i32) : i32
13-
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], %[[ONE]], %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32, i32) -> !llvm.ptr<f32>
13+
// CHECK: llvm.getelementptr %{{.*}}[%[[ZERO]], 1, %[[ONE]]] : (!llvm.ptr<struct<packed (f32, array<4 x f32>)>>, i32, i32) -> !llvm.ptr<f32>
1414
%2 = spv.AccessChain %1[%0, %0] : !spv.ptr<!spv.struct<(f32, !spv.array<4xf32>)>, Function>, i32, i32
1515
spv.Return
1616
}

mlir/test/Dialect/LLVMIR/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1234,3 +1234,19 @@ func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
12341234
nvvm.cp.async.shared.global %arg0, %arg1, 32
12351235
return
12361236
}
1237+
1238+
// -----
1239+
1240+
func @gep_struct_variable(%arg0: !llvm.ptr<struct<(i32)>>, %arg1: i32, %arg2: i32) {
1241+
// expected-error @below {{op expected index 1 indexing a struct to be constant}}
1242+
llvm.getelementptr %arg0[%arg1, %arg1] : (!llvm.ptr<struct<(i32)>>, i32, i32) -> !llvm.ptr<i32>
1243+
return
1244+
}
1245+
1246+
// -----
1247+
1248+
func @gep_out_of_bounds(%ptr: !llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, %idx: i64) {
1249+
// expected-error @below {{index 2 indexing a struct is out of bounds}}
1250+
llvm.getelementptr %ptr[%idx, 1, 3] : (!llvm.ptr<struct<(i32, struct<(i32, f32)>)>>, i64) -> !llvm.ptr<i32>
1251+
return
1252+
}

mlir/test/Target/LLVMIR/llvmir.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1444,7 +1444,7 @@ llvm.mlir.global linkonce @take_self_address() : !llvm.struct<(i32, !llvm.ptr<i3
14441444
%z32 = llvm.mlir.constant(0 : i32) : i32
14451445
%0 = llvm.mlir.undef : !llvm.struct<(i32, !llvm.ptr<i32>)>
14461446
%1 = llvm.mlir.addressof @take_self_address : !llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>
1447-
%2 = llvm.getelementptr %1[%z32, %z32] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32, i32) -> !llvm.ptr<i32>
1447+
%2 = llvm.getelementptr %1[%z32, 0] : (!llvm.ptr<!llvm.struct<(i32, !llvm.ptr<i32>)>>, i32) -> !llvm.ptr<i32>
14481448
%3 = llvm.insertvalue %z32, %0[0 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
14491449
%4 = llvm.insertvalue %2, %3[1 : i32] : !llvm.struct<(i32, !llvm.ptr<i32>)>
14501450
llvm.return %4 : !llvm.struct<(i32, !llvm.ptr<i32>)>

0 commit comments

Comments
 (0)