@@ -62,49 +62,8 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
62
62
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
63
63
Value memRefDesc, ValueRange indices,
64
64
LLVM::GEPNoWrapFlags noWrapFlags) const {
65
-
66
- auto [strides, offset] = type.getStridesAndOffset ();
67
-
68
- MemRefDescriptor memRefDescriptor (memRefDesc);
69
- // Use a canonical representation of the start address so that later
70
- // optimizations have a longer sequence of instructions to CSE.
71
- // If we don't do that we would sprinkle the memref.offset in various
72
- // position of the different address computations.
73
- Value base =
74
- memRefDescriptor.bufferPtr (rewriter, loc, *getTypeConverter (), type);
75
-
76
- LLVM::IntegerOverflowFlags intOverflowFlags =
77
- LLVM::IntegerOverflowFlags::none;
78
- if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
79
- intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
80
- }
81
- if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
82
- intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
83
- }
84
-
85
- Type indexType = getIndexType ();
86
- Value index;
87
- for (int i = 0 , e = indices.size (); i < e; ++i) {
88
- Value increment = indices[i];
89
- if (strides[i] != 1 ) { // Skip if stride is 1.
90
- Value stride =
91
- ShapedType::isDynamic (strides[i])
92
- ? memRefDescriptor.stride (rewriter, loc, i)
93
- : createIndexAttrConstant (rewriter, loc, indexType, strides[i]);
94
- increment = rewriter.create <LLVM::MulOp>(loc, increment, stride,
95
- intOverflowFlags);
96
- }
97
- index = index ? rewriter.create <LLVM::AddOp>(loc, index, increment,
98
- intOverflowFlags)
99
- : increment;
100
- }
101
-
102
- Type elementPtrType = memRefDescriptor.getElementPtrType ();
103
- return index ? rewriter.create <LLVM::GEPOp>(
104
- loc, elementPtrType,
105
- getTypeConverter ()->convertType (type.getElementType ()),
106
- base, index, noWrapFlags)
107
- : base;
65
+ return LLVM::getStridedElementPtr (rewriter, loc, *getTypeConverter (), type,
66
+ memRefDesc, indices, noWrapFlags);
108
67
}
109
68
110
69
// Check if the MemRefType `type` is supported by the lowering. We currently
@@ -524,3 +483,52 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
524
483
525
484
return res;
526
485
}
486
+
487
+ Value mlir::LLVM::getStridedElementPtr (OpBuilder &builder, Location loc,
488
+ const LLVMTypeConverter &converter,
489
+ MemRefType type, Value memRefDesc,
490
+ ValueRange indices,
491
+ LLVM::GEPNoWrapFlags noWrapFlags) {
492
+ auto [strides, offset] = type.getStridesAndOffset ();
493
+
494
+ MemRefDescriptor memRefDescriptor (memRefDesc);
495
+ // Use a canonical representation of the start address so that later
496
+ // optimizations have a longer sequence of instructions to CSE.
497
+ // If we don't do that we would sprinkle the memref.offset in various
498
+ // position of the different address computations.
499
+ Value base = memRefDescriptor.bufferPtr (builder, loc, converter, type);
500
+
501
+ LLVM::IntegerOverflowFlags intOverflowFlags =
502
+ LLVM::IntegerOverflowFlags::none;
503
+ if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
504
+ intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
505
+ }
506
+ if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
507
+ intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
508
+ }
509
+
510
+ Type indexType = converter.getIndexType ();
511
+ Value index;
512
+ for (int i = 0 , e = indices.size (); i < e; ++i) {
513
+ Value increment = indices[i];
514
+ if (strides[i] != 1 ) { // Skip if stride is 1.
515
+ Value stride =
516
+ ShapedType::isDynamic (strides[i])
517
+ ? memRefDescriptor.stride (builder, loc, i)
518
+ : builder.create <LLVM::ConstantOp>(
519
+ loc, indexType, builder.getIndexAttr (strides[i]));
520
+ increment =
521
+ builder.create <LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
522
+ }
523
+ index = index ? builder.create <LLVM::AddOp>(loc, index, increment,
524
+ intOverflowFlags)
525
+ : increment;
526
+ }
527
+
528
+ Type elementPtrType = memRefDescriptor.getElementPtrType ();
529
+ return index ? builder.create <LLVM::GEPOp>(
530
+ loc, elementPtrType,
531
+ converter.convertType (type.getElementType ()), base, index,
532
+ noWrapFlags)
533
+ : base;
534
+ }
0 commit comments