@@ -360,6 +360,58 @@ SwitchOp::getMutableSuccessorOperands(unsigned index) {
360
360
// Code for LLVM::GEPOp.
361
361
// ===----------------------------------------------------------------------===//
362
362
363
+ // / Populates `indices` with positions of GEP indices that would correspond to
364
+ // / LLVMStructTypes potentially nested in the given type. The type currently
365
+ // / visited gets `currentIndex` and LLVM container types are visited
366
+ // / recursively. The recursion is bounded and takes care of recursive types by
367
+ // / means of the `visited` set.
368
+ static void recordStructIndices (Type type, unsigned currentIndex,
369
+ SmallVectorImpl<unsigned > &indices,
370
+ SmallVectorImpl<unsigned > *structSizes,
371
+ SmallPtrSet<Type, 4 > &visited) {
372
+ if (visited.contains (type))
373
+ return ;
374
+
375
+ visited.insert (type);
376
+
377
+ llvm::TypeSwitch<Type>(type)
378
+ .Case <LLVMStructType>([&](LLVMStructType structType) {
379
+ indices.push_back (currentIndex);
380
+ if (structSizes)
381
+ structSizes->push_back (structType.getBody ().size ());
382
+ for (Type elementType : structType.getBody ())
383
+ recordStructIndices (elementType, currentIndex + 1 , indices,
384
+ structSizes, visited);
385
+ })
386
+ .Case <VectorType, LLVMScalableVectorType, LLVMFixedVectorType,
387
+ LLVMArrayType>([&](auto containerType) {
388
+ recordStructIndices (containerType.getElementType (), currentIndex + 1 ,
389
+ indices, structSizes, visited);
390
+ });
391
+ }
392
+
393
+ // / Populates `indices` with positions of GEP indices that correspond to
394
+ // / LLVMStructTypes potentially nested in the given `baseGEPType`, which must
395
+ // / be either an LLVMPointer type or a vector thereof. If `structSizes` is
396
+ // / provided, it is populated with sizes of the indexed structs for bounds
397
+ // / verification purposes.
398
+ static void
399
+ findKnownStructIndices (Type baseGEPType, SmallVectorImpl<unsigned > &indices,
400
+ SmallVectorImpl<unsigned > *structSizes = nullptr ) {
401
+ Type type = baseGEPType;
402
+ if (auto vectorType = type.dyn_cast <VectorType>())
403
+ type = vectorType.getElementType ();
404
+ if (auto scalableVectorType = type.dyn_cast <LLVMScalableVectorType>())
405
+ type = scalableVectorType.getElementType ();
406
+ if (auto fixedVectorType = type.dyn_cast <LLVMFixedVectorType>())
407
+ type = fixedVectorType.getElementType ();
408
+
409
+ Type pointeeType = type.cast <LLVMPointerType>().getElementType ();
410
+ SmallPtrSet<Type, 4 > visited;
411
+ recordStructIndices (pointeeType, /* currentIndex=*/ 1 , indices, structSizes,
412
+ visited);
413
+ }
414
+
363
415
void GEPOp::build (OpBuilder &builder, OperationState &result, Type resultType,
364
416
Value basePtr, ValueRange operands,
365
417
ArrayRef<NamedAttribute> attributes) {
@@ -372,11 +424,58 @@ void GEPOp::build(OpBuilder &builder, OperationState &result, Type resultType,
372
424
Value basePtr, ValueRange indices,
373
425
ArrayRef<int32_t > structIndices,
374
426
ArrayRef<NamedAttribute> attributes) {
427
+ SmallVector<Value> remainingIndices;
428
+ SmallVector<int32_t > updatedStructIndices (structIndices.begin (),
429
+ structIndices.end ());
430
+ SmallVector<unsigned > structRelatedPositions;
431
+ findKnownStructIndices (basePtr.getType (), structRelatedPositions);
432
+
433
+ SmallVector<unsigned > operandsToErase;
434
+ for (unsigned pos : structRelatedPositions) {
435
+ // GEP may not be indexing as deep as some structs are located.
436
+ if (pos >= structIndices.size ())
437
+ continue ;
438
+
439
+ // If the index is already static, it's fine.
440
+ if (structIndices[pos] != kDynamicIndex )
441
+ continue ;
442
+
443
+ // Find the corresponding operand.
444
+ unsigned operandPos =
445
+ std::count (structIndices.begin (), std::next (structIndices.begin (), pos),
446
+ kDynamicIndex );
447
+
448
+ // Extract the constant value from the operand and put it into the attribute
449
+ // instead.
450
+ APInt staticIndexValue;
451
+ bool matched =
452
+ matchPattern (indices[operandPos], m_ConstantInt (&staticIndexValue));
453
+ (void )matched;
454
+ assert (matched && " index into a struct must be a constant" );
455
+ assert (staticIndexValue.sge (APInt::getSignedMinValue (/* numBits=*/ 32 )) &&
456
+ " struct index underflows 32-bit integer" );
457
+ assert (staticIndexValue.sle (APInt::getSignedMaxValue (/* numBits=*/ 32 )) &&
458
+ " struct index overflows 32-bit integer" );
459
+ auto staticIndex = static_cast <int32_t >(staticIndexValue.getSExtValue ());
460
+ updatedStructIndices[pos] = staticIndex;
461
+ operandsToErase.push_back (operandPos);
462
+ }
463
+
464
+ for (unsigned i = 0 , e = indices.size (); i < e; ++i) {
465
+ if (llvm::find (operandsToErase, i) == operandsToErase.end ())
466
+ remainingIndices.push_back (indices[i]);
467
+ }
468
+
469
+ assert (remainingIndices.size () == static_cast <size_t >(llvm::count (
470
+ updatedStructIndices, kDynamicIndex )) &&
471
+ " exected as many index operands as dynamic index attr elements" );
472
+
375
473
result.addTypes (resultType);
376
474
result.addAttributes (attributes);
377
- result.addAttribute (" structIndices" , builder.getI32TensorAttr (structIndices));
475
+ result.addAttribute (" structIndices" ,
476
+ builder.getI32TensorAttr (updatedStructIndices));
378
477
result.addOperands (basePtr);
379
- result.addOperands (indices );
478
+ result.addOperands (remainingIndices );
380
479
}
381
480
382
481
static ParseResult
@@ -417,6 +516,27 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
417
516
});
418
517
}
419
518
519
+ LogicalResult verify (LLVM::GEPOp gepOp) {
520
+ SmallVector<unsigned > indices;
521
+ SmallVector<unsigned > structSizes;
522
+ findKnownStructIndices (gepOp.getBase ().getType (), indices, &structSizes);
523
+ for (unsigned i = 0 , e = indices.size (); i < e; ++i) {
524
+ unsigned index = indices[i];
525
+ // GEP may not be indexing as deep as some structs nested in the type.
526
+ if (index >= gepOp.getStructIndices ().getNumElements ())
527
+ continue ;
528
+
529
+ int32_t staticIndex = gepOp.getStructIndices ().getValues <int32_t >()[index];
530
+ if (staticIndex == LLVM::GEPOp::kDynamicIndex )
531
+ return gepOp.emitOpError () << " expected index " << index
532
+ << " indexing a struct to be constant" ;
533
+ if (staticIndex < 0 || static_cast <unsigned >(staticIndex) >= structSizes[i])
534
+ return gepOp.emitOpError ()
535
+ << " index " << index << " indexing a struct is out of bounds" ;
536
+ }
537
+ return success ();
538
+ }
539
+
420
540
// ===----------------------------------------------------------------------===//
421
541
// Builder, printer and parser for for LLVM::LoadOp.
422
542
// ===----------------------------------------------------------------------===//
0 commit comments