Skip to content

Commit df5ccf5

Browse files
committed
[mlir][vector] add higher dimensional support to gather/scatter
Similar to masked load/store and compress/expand, the gather and scatter operations now allow for higher-dimensional uses. Note that to support the mixed-type index, the new syntax is: vector.gather %base [%i,%j] [%kvector] .... The first client of this generalization is the sparse compiler, which needs to define gathers and scatters on dense operands of higher dimensions too. Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D97422
1 parent c62dabc commit df5ccf5

File tree

14 files changed

+244
-115
lines changed

14 files changed

+244
-115
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,11 +1574,14 @@ def Vector_MaskedLoadOp :
15741574
closely correspond to those of the `llvm.masked.load`
15751575
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-load-intrinsics).
15761576

1577-
Example:
1577+
Examples:
15781578

15791579
```mlir
15801580
%0 = vector.maskedload %base[%i], %mask, %pass_thru
15811581
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
1582+
1583+
%1 = vector.maskedload %base[%i, %j], %mask, %pass_thru
1584+
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
15821585
```
15831586
}];
15841587
let extraClassDeclaration = [{
@@ -1625,11 +1628,14 @@ def Vector_MaskedStoreOp :
16251628
closely correspond to those of the `llvm.masked.store`
16261629
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-store-intrinsics).
16271630

1628-
Example:
1631+
Examples:
16291632

16301633
```mlir
16311634
vector.maskedstore %base[%i], %mask, %value
16321635
: memref<?xf32>, vector<8xi1>, vector<8xf32>
1636+
1637+
vector.maskedstore %base[%i, %j], %mask, %value
1638+
: memref<?x?xf32>, vector<16xi1>, vector<16xf32>
16331639
```
16341640
}];
16351641
let extraClassDeclaration = [{
@@ -1652,7 +1658,8 @@ def Vector_MaskedStoreOp :
16521658
def Vector_GatherOp :
16531659
Vector_Op<"gather">,
16541660
Arguments<(ins Arg<AnyMemRef, "", [MemRead]>:$base,
1655-
VectorOfRankAndType<[1], [AnyInteger]>:$indices,
1661+
Variadic<Index>:$indices,
1662+
VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
16561663
VectorOfRankAndType<[1], [I1]>:$mask,
16571664
VectorOfRank<[1]>:$pass_thru)>,
16581665
Results<(outs VectorOfRank<[1]>:$result)> {
@@ -1661,9 +1668,10 @@ def Vector_GatherOp :
16611668

16621669
let description = [{
16631670
The gather operation gathers elements from memory into a 1-D vector as
1664-
defined by a base and a 1-D index vector, but only if the corresponding
1665-
bit is set in a 1-D mask vector. Otherwise, the element is taken from a
1666-
1-D pass-through vector. Informally the semantics are:
1671+
defined by a base with indices and an additional 1-D index vector, but
1672+
only if the corresponding bit is set in a 1-D mask vector. Otherwise, the
1673+
element is taken from a 1-D pass-through vector. Informally the semantics
1674+
are:
16671675
```
16681676
result[0] := mask[0] ? base[index[0]] : pass_thru[0]
16691677
result[1] := mask[1] ? base[index[1]] : pass_thru[1]
@@ -1677,19 +1685,22 @@ def Vector_GatherOp :
16771685
correspond to those of the `llvm.masked.gather`
16781686
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-gather-intrinsics).
16791687

1680-
Example:
1688+
Examples:
16811689

16821690
```mlir
1683-
%g = vector.gather %base[%indices], %mask, %pass_thru
1691+
%0 = vector.gather %base[%c0][%v], %mask, %pass_thru
16841692
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1693+
1694+
%1 = vector.gather %base[%i, %j][%v], %mask, %pass_thru
1695+
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
16851696
```
16861697
}];
16871698
let extraClassDeclaration = [{
16881699
MemRefType getMemRefType() {
16891700
return base().getType().cast<MemRefType>();
16901701
}
1691-
VectorType getIndicesVectorType() {
1692-
return indices().getType().cast<VectorType>();
1702+
VectorType getIndexVectorType() {
1703+
return index_vec().getType().cast<VectorType>();
16931704
}
16941705
VectorType getMaskVectorType() {
16951706
return mask().getType().cast<VectorType>();
@@ -1701,25 +1712,29 @@ def Vector_GatherOp :
17011712
return result().getType().cast<VectorType>();
17021713
}
17031714
}];
1704-
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $pass_thru attr-dict `:` "
1705-
"type($base) `,` type($indices) `,` type($mask) `,` type($pass_thru) `into` type($result)";
1715+
let assemblyFormat =
1716+
"$base `[` $indices `]` `[` $index_vec `]` `,` "
1717+
"$mask `,` $pass_thru attr-dict `:` type($base) `,` "
1718+
"type($index_vec) `,` type($mask) `,` type($pass_thru) "
1719+
"`into` type($result)";
17061720
let hasCanonicalizer = 1;
17071721
}
17081722

17091723
def Vector_ScatterOp :
17101724
Vector_Op<"scatter">,
17111725
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
1712-
VectorOfRankAndType<[1], [AnyInteger]>:$indices,
1726+
Variadic<Index>:$indices,
1727+
VectorOfRankAndType<[1], [AnyInteger]>:$index_vec,
17131728
VectorOfRankAndType<[1], [I1]>:$mask,
17141729
VectorOfRank<[1]>:$valueToStore)> {
17151730

17161731
let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
17171732

17181733
let description = [{
17191734
The scatter operation scatters elements from a 1-D vector into memory as
1720-
defined by a base and a 1-D index vector, but only if the corresponding
1721-
bit in a 1-D mask vector is set. Otherwise, no action is taken for that
1722-
element. Informally the semantics are:
1735+
defined by a base with indices and an additional 1-D index vector, but
1736+
only if the corresponding bit in a 1-D mask vector is set. Otherwise, no
1737+
action is taken for that element. Informally the semantics are:
17231738
```
17241739
if (mask[0]) base[index[0]] = value[0]
17251740
if (mask[1]) base[index[1]] = value[1]
@@ -1736,19 +1751,22 @@ def Vector_ScatterOp :
17361751
correspond to those of the `llvm.masked.scatter`
17371752
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-scatter-intrinsics).
17381753

1739-
Example:
1754+
Examples:
17401755

17411756
```mlir
1742-
vector.scatter %base[%indices], %mask, %value
1757+
vector.scatter %base[%c0][%v], %mask, %value
17431758
: memref<?xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
1759+
1760+
vector.scatter %base[%i, %j][%v], %mask, %value
1761+
: memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
17441762
```
17451763
}];
17461764
let extraClassDeclaration = [{
17471765
MemRefType getMemRefType() {
17481766
return base().getType().cast<MemRefType>();
17491767
}
1750-
VectorType getIndicesVectorType() {
1751-
return indices().getType().cast<VectorType>();
1768+
VectorType getIndexVectorType() {
1769+
return index_vec().getType().cast<VectorType>();
17521770
}
17531771
VectorType getMaskVectorType() {
17541772
return mask().getType().cast<VectorType>();
@@ -1758,8 +1776,9 @@ def Vector_ScatterOp :
17581776
}
17591777
}];
17601778
let assemblyFormat =
1761-
"$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
1762-
"type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)";
1779+
"$base `[` $indices `]` `[` $index_vec `]` `,` "
1780+
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
1781+
"type($index_vec) `,` type($mask) `,` type($valueToStore)";
17631782
let hasCanonicalizer = 1;
17641783
}
17651784

@@ -1792,11 +1811,14 @@ def Vector_ExpandLoadOp :
17921811
correspond to those of the `llvm.masked.expandload`
17931812
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-expandload-intrinsics).
17941813

1795-
Example:
1814+
Examples:
17961815

17971816
```mlir
17981817
%0 = vector.expandload %base[%i], %mask, %pass_thru
17991818
: memref<?xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
1819+
1820+
%1 = vector.expandload %base[%i, %j], %mask, %pass_thru
1821+
: memref<?x?xf32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
18001822
```
18011823
}];
18021824
let extraClassDeclaration = [{
@@ -1846,11 +1868,14 @@ def Vector_CompressStoreOp :
18461868
correspond to those of the `llvm.masked.compressstore`
18471869
[intrinsic](https://llvm.org/docs/LangRef.html#llvm-masked-compressstore-intrinsics).
18481870

1849-
Example:
1871+
Examples:
18501872

18511873
```mlir
18521874
vector.compressstore %base[%i], %mask, %value
18531875
: memref<?xf32>, vector<8xi1>, vector<8xf32>
1876+
1877+
vector.compressstore %base[%i, %j], %mask, %value
1878+
: memref<?x?xf32>, vector<16xi1>, vector<16xf32>
18541879
```
18551880
}];
18561881
let extraClassDeclaration = [{

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -178,34 +178,21 @@ LogicalResult getMemRefAlignment(LLVMTypeConverter &typeConverter,
178178
return success();
179179
}
180180

181-
// Helper that returns the base address of a memref.
182-
static LogicalResult getBase(ConversionPatternRewriter &rewriter, Location loc,
183-
Value memref, MemRefType memRefType, Value &base) {
184-
// Inspect stride and offset structure.
185-
//
186-
// TODO: flat memory only for now, generalize
187-
//
181+
// Add an index vector component to a base pointer. This almost always succeeds
182+
// unless the last stride is non-unit or the memory space is not zero.
183+
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
184+
Location loc, Value memref, Value base,
185+
Value index, MemRefType memRefType,
186+
VectorType vType, Value &ptrs) {
188187
int64_t offset;
189188
SmallVector<int64_t, 4> strides;
190189
auto successStrides = getStridesAndOffset(memRefType, strides, offset);
191-
if (failed(successStrides) || strides.size() != 1 || strides[0] != 1 ||
192-
offset != 0 || memRefType.getMemorySpace() != 0)
193-
return failure();
194-
base = MemRefDescriptor(memref).alignedPtr(rewriter, loc);
195-
return success();
196-
}
197-
198-
// Helper that returns vector of pointers given a memref base with index vector.
199-
static LogicalResult getIndexedPtrs(ConversionPatternRewriter &rewriter,
200-
Location loc, Value memref, Value indices,
201-
MemRefType memRefType, VectorType vType,
202-
Type iType, Value &ptrs) {
203-
Value base;
204-
if (failed(getBase(rewriter, loc, memref, memRefType, base)))
190+
if (failed(successStrides) || strides.back() != 1 ||
191+
memRefType.getMemorySpace() != 0)
205192
return failure();
206193
auto pType = MemRefDescriptor(memref).getElementPtrType();
207194
auto ptrsType = LLVM::getFixedVectorType(pType, vType.getDimSize(0));
208-
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, indices);
195+
ptrs = rewriter.create<LLVM::GEPOp>(loc, ptrsType, base, index);
209196
return success();
210197
}
211198

@@ -435,19 +422,20 @@ class VectorGatherOpConversion
435422
ConversionPatternRewriter &rewriter) const override {
436423
auto loc = gather->getLoc();
437424
auto adaptor = vector::GatherOpAdaptor(operands);
425+
MemRefType memRefType = gather.getMemRefType();
438426

439427
// Resolve alignment.
440428
unsigned align;
441-
if (failed(getMemRefAlignment(*getTypeConverter(), gather.getMemRefType(),
442-
align)))
429+
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
443430
return failure();
444431

445-
// Get index ptrs.
446-
VectorType vType = gather.getVectorType();
447-
Type iType = gather.getIndicesVectorType().getElementType();
432+
// Resolve address.
448433
Value ptrs;
449-
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
450-
gather.getMemRefType(), vType, iType, ptrs)))
434+
VectorType vType = gather.getVectorType();
435+
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
436+
adaptor.indices(), rewriter);
437+
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
438+
adaptor.index_vec(), memRefType, vType, ptrs)))
451439
return failure();
452440

453441
// Replace with the gather intrinsic.
@@ -469,19 +457,20 @@ class VectorScatterOpConversion
469457
ConversionPatternRewriter &rewriter) const override {
470458
auto loc = scatter->getLoc();
471459
auto adaptor = vector::ScatterOpAdaptor(operands);
460+
MemRefType memRefType = scatter.getMemRefType();
472461

473462
// Resolve alignment.
474463
unsigned align;
475-
if (failed(getMemRefAlignment(*getTypeConverter(), scatter.getMemRefType(),
476-
align)))
464+
if (failed(getMemRefAlignment(*getTypeConverter(), memRefType, align)))
477465
return failure();
478466

479-
// Get index ptrs.
480-
VectorType vType = scatter.getVectorType();
481-
Type iType = scatter.getIndicesVectorType().getElementType();
467+
// Resolve address.
482468
Value ptrs;
483-
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
484-
scatter.getMemRefType(), vType, iType, ptrs)))
469+
VectorType vType = scatter.getVectorType();
470+
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
471+
adaptor.indices(), rewriter);
472+
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), ptr,
473+
adaptor.index_vec(), memRefType, vType, ptrs)))
485474
return failure();
486475

487476
// Replace with the scatter intrinsic.
@@ -507,8 +496,8 @@ class VectorExpandLoadOpConversion
507496

508497
// Resolve address.
509498
auto vtype = typeConverter->convertType(expand.getVectorType());
510-
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
511-
adaptor.indices(), rewriter);
499+
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
500+
adaptor.indices(), rewriter);
512501

513502
rewriter.replaceOpWithNewOp<LLVM::masked_expandload>(
514503
expand, vtype, ptr, adaptor.mask(), adaptor.pass_thru());
@@ -530,8 +519,8 @@ class VectorCompressStoreOpConversion
530519
MemRefType memRefType = compress.getMemRefType();
531520

532521
// Resolve address.
533-
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
534-
adaptor.indices(), rewriter);
522+
Value ptr = getStridedElementPtr(loc, memRefType, adaptor.base(),
523+
adaptor.indices(), rewriter);
535524

536525
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
537526
compress, adaptor.valueToStore(), ptr, adaptor.mask());

mlir/lib/Dialect/Linalg/Transforms/Sparsification.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,13 @@ static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
652652
Location loc = ptr.getLoc();
653653
VectorType vtp = vectorType(codegen, ptr);
654654
Value pass = rewriter.create<ConstantOp>(loc, vtp, rewriter.getZeroAttr(vtp));
655-
if (args.back().getType().isa<VectorType>())
656-
return rewriter.create<vector::GatherOp>(loc, vtp, ptr, args.back(),
657-
codegen.curVecMask, pass);
655+
if (args.back().getType().isa<VectorType>()) {
656+
SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
657+
Value indexVec = args.back();
658+
scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
659+
return rewriter.create<vector::GatherOp>(
660+
loc, vtp, ptr, scalarArgs, indexVec, codegen.curVecMask, pass);
661+
}
658662
return rewriter.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
659663
codegen.curVecMask, pass);
660664
}
@@ -663,12 +667,16 @@ static Value genVectorLoad(CodeGen &codegen, PatternRewriter &rewriter,
663667
static void genVectorStore(CodeGen &codegen, PatternRewriter &rewriter,
664668
Value rhs, Value ptr, ArrayRef<Value> args) {
665669
Location loc = ptr.getLoc();
666-
if (args.back().getType().isa<VectorType>())
667-
rewriter.create<vector::ScatterOp>(loc, ptr, args.back(),
670+
if (args.back().getType().isa<VectorType>()) {
671+
SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
672+
Value indexVec = args.back();
673+
scalarArgs.back() = rewriter.create<ConstantIndexOp>(loc, 0);
674+
rewriter.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
668675
codegen.curVecMask, rhs);
669-
else
670-
rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
671-
rhs);
676+
return;
677+
}
678+
rewriter.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
679+
rhs);
672680
}
673681

674682
/// Generates a vectorized invariant. Here we rely on subsequent loop
@@ -985,11 +993,15 @@ static Operation *genWhile(Merger &merger, CodeGen &codegen,
985993
unsigned tensor = merger.tensor(b);
986994
assert(idx == merger.index(b));
987995
types.push_back(indexType);
996+
assert(codegen.pidxs[tensor][idx].getType().isa<IndexType>() &&
997+
"type mismatch for sparse index");
988998
operands.push_back(codegen.pidxs[tensor][idx]);
989999
}
9901000
}
9911001
if (needsUniv) {
9921002
types.push_back(indexType);
1003+
assert(codegen.loops[idx].getType().isa<IndexType>() &&
1004+
"type_mismatch for universal index");
9931005
operands.push_back(codegen.loops[idx]);
9941006
}
9951007
Location loc = op.getLoc();
@@ -1160,6 +1172,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
11601172
genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
11611173
return;
11621174
}
1175+
assert(codegen.curVecLength == 1);
11631176

11641177
// Construct iteration lattices for current loop index, with L0 at top.
11651178
// Then emit initialization code for the loop sequence at this level.
@@ -1239,6 +1252,7 @@ static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
12391252
}
12401253
genInvariants(merger, codegen, rewriter, op, exp, ldx, /*hoist=*/false);
12411254
codegen.loops[idx] = Value();
1255+
codegen.curVecLength = 1;
12421256
}
12431257

12441258
namespace {

0 commit comments

Comments
 (0)