Skip to content

Commit 15ac5b9

Browse files
committed
Address comments
1 parent bba9af2 commit 15ac5b9

File tree

4 files changed

+7
-9
lines changed

4 files changed

+7
-9
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/LoweringPatterns.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,8 @@ void populateVectorStepLoweringPatterns(RewritePatternSet &patterns,
241241

242242
/// Populate the pattern set with the following patterns:
243243
///
244-
/// [FlattenGather]
245-
/// Flattens 2 or more dimensional `vector.gather` ops by unrolling the
244+
/// [UnrollGather]
245+
/// Unrolls 2 or more dimensional `vector.gather` ops by unrolling the
246246
/// outermost dimension.
247247
void populateVectorGatherLoweringPatterns(RewritePatternSet &patterns,
248248
PatternBenefit benefit = 1);

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ class VectorGatherOpConversion
273273
if (vType.getRank() > 1)
274274
return failure();
275275

276-
auto loc = gather->getLoc();
276+
Location loc = gather->getLoc();
277277

278278
// Resolve alignment.
279279
unsigned align;
@@ -284,12 +284,10 @@ class VectorGatherOpConversion
284284
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.getBase(),
285285
adaptor.getIndices(), rewriter);
286286
Value base = adaptor.getBase();
287-
288-
// Handle the simple case of 1-D vector.
289-
// Resolve address.
290287
Value ptrs =
291288
getIndexedPtrs(rewriter, loc, *this->getTypeConverter(), memRefType,
292289
base, ptr, adaptor.getIndexVec(), vType);
290+
293291
// Replace with the gather intrinsic.
294292
rewriter.replaceOpWithNewOp<LLVM::masked_gather>(
295293
gather, typeConverter->convertType(vType), ptrs, adaptor.getMask(),

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ struct UnrollGather : OpRewritePattern<vector::GatherOp> {
107107
/// ```mlir
108108
/// %subview = memref.subview %M (...)
109109
/// : memref<100x3xf32> to memref<100xf32, strided<[3]>>
110-
/// %gather = vector.gather %subview[%idxs] (...) : memref<100xf32,
111-
/// strided<[3]>>
110+
/// %gather = vector.gather %subview[%idxs] (...)
111+
/// : memref<100xf32, strided<[3]>>
112112
/// ```
113113
/// ==>
114114
/// ```mlir

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1677,7 +1677,7 @@ func.func @gather_with_mask(%arg0: memref<?xf32>, %arg1: vector<2x3xi32>, %arg2:
16771677
func.func @gather_with_mask_scalable(%arg0: memref<?xf32>, %arg1: vector<2x[3]xi32>, %arg2: vector<2x[3]xf32>) -> vector<2x[3]xf32> {
16781678
%0 = arith.constant 0: index
16791679
// vector.constant_mask only supports 'none set' or 'all set' scalable
1680-
// dimensions, hence [1, 3] rather than [1, 2] as in the example for fixed
1680+
// dimensions, hence [2, 3] rather than [2, 2] as in the example for fixed
16811681
// width vectors above.
16821682
%1 = vector.constant_mask [2, 3] : vector<2x[3]xi1>
16831683
%2 = vector.gather %arg0[%0][%arg1], %1, %arg2 : memref<?xf32>, vector<2x[3]xi32>, vector<2x[3]xi1>, vector<2x[3]xf32> into vector<2x[3]xf32>

0 commit comments

Comments
 (0)