Skip to content

Commit 656674a

Browse files
committed
[mlir][Vector] Align gather/scatter/expand/compress API
Align the vector gather/scatter/expand/compress API with the vector load/store/maskedload/maskedstore API. Reviewed By: aartbik Differential Revision: https://reviews.llvm.org/D96396
1 parent ee66e43 commit 656674a

File tree

4 files changed

+35
-33
lines changed

4 files changed

+35
-33
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1619,7 +1619,7 @@ def Vector_GatherOp :
16191619
VectorType getPassThruVectorType() {
16201620
return pass_thru().getType().cast<VectorType>();
16211621
}
1622-
VectorType getResultVectorType() {
1622+
VectorType getVectorType() {
16231623
return result().getType().cast<VectorType>();
16241624
}
16251625
}];
@@ -1633,7 +1633,7 @@ def Vector_ScatterOp :
16331633
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
16341634
VectorOfRankAndType<[1], [AnyInteger]>:$indices,
16351635
VectorOfRankAndType<[1], [I1]>:$mask,
1636-
VectorOfRank<[1]>:$value)> {
1636+
VectorOfRank<[1]>:$valueToStore)> {
16371637

16381638
let summary = "scatters elements from a vector into memory as defined by an index vector and mask";
16391639

@@ -1675,12 +1675,13 @@ def Vector_ScatterOp :
16751675
VectorType getMaskVectorType() {
16761676
return mask().getType().cast<VectorType>();
16771677
}
1678-
VectorType getValueVectorType() {
1679-
return value().getType().cast<VectorType>();
1678+
VectorType getVectorType() {
1679+
return valueToStore().getType().cast<VectorType>();
16801680
}
16811681
}];
1682-
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
1683-
"type($base) `,` type($indices) `,` type($mask) `,` type($value)";
1682+
let assemblyFormat =
1683+
"$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
1684+
"type($base) `,` type($indices) `,` type($mask) `,` type($valueToStore)";
16841685
let hasCanonicalizer = 1;
16851686
}
16861687

@@ -1730,7 +1731,7 @@ def Vector_ExpandLoadOp :
17301731
VectorType getPassThruVectorType() {
17311732
return pass_thru().getType().cast<VectorType>();
17321733
}
1733-
VectorType getResultVectorType() {
1734+
VectorType getVectorType() {
17341735
return result().getType().cast<VectorType>();
17351736
}
17361737
}];
@@ -1744,7 +1745,7 @@ def Vector_CompressStoreOp :
17441745
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
17451746
Variadic<Index>:$indices,
17461747
VectorOfRankAndType<[1], [I1]>:$mask,
1747-
VectorOfRank<[1]>:$value)> {
1748+
VectorOfRank<[1]>:$valueToStore)> {
17481749

17491750
let summary = "writes elements selectively from a vector as defined by a mask";
17501751

@@ -1781,12 +1782,13 @@ def Vector_CompressStoreOp :
17811782
VectorType getMaskVectorType() {
17821783
return mask().getType().cast<VectorType>();
17831784
}
1784-
VectorType getValueVectorType() {
1785-
return value().getType().cast<VectorType>();
1785+
VectorType getVectorType() {
1786+
return valueToStore().getType().cast<VectorType>();
17861787
}
17871788
}];
1788-
let assemblyFormat = "$base `[` $indices `]` `,` $mask `,` $value attr-dict `:` "
1789-
"type($base) `,` type($mask) `,` type($value)";
1789+
let assemblyFormat =
1790+
"$base `[` $indices `]` `,` $mask `,` $valueToStore attr-dict `:` "
1791+
"type($base) `,` type($mask) `,` type($valueToStore)";
17901792
let hasCanonicalizer = 1;
17911793
}
17921794

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ class VectorGatherOpConversion
446446
return failure();
447447

448448
// Get index ptrs.
449-
VectorType vType = gather.getResultVectorType();
449+
VectorType vType = gather.getVectorType();
450450
Type iType = gather.getIndicesVectorType().getElementType();
451451
Value ptrs;
452452
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
@@ -480,7 +480,7 @@ class VectorScatterOpConversion
480480
return failure();
481481

482482
// Get index ptrs.
483-
VectorType vType = scatter.getValueVectorType();
483+
VectorType vType = scatter.getVectorType();
484484
Type iType = scatter.getIndicesVectorType().getElementType();
485485
Value ptrs;
486486
if (failed(getIndexedPtrs(rewriter, loc, adaptor.base(), adaptor.indices(),
@@ -489,7 +489,7 @@ class VectorScatterOpConversion
489489

490490
// Replace with the scatter intrinsic.
491491
rewriter.replaceOpWithNewOp<LLVM::masked_scatter>(
492-
scatter, adaptor.value(), ptrs, adaptor.mask(),
492+
scatter, adaptor.valueToStore(), ptrs, adaptor.mask(),
493493
rewriter.getI32IntegerAttr(align));
494494
return success();
495495
}
@@ -509,7 +509,7 @@ class VectorExpandLoadOpConversion
509509
MemRefType memRefType = expand.getMemRefType();
510510

511511
// Resolve address.
512-
auto vtype = typeConverter->convertType(expand.getResultVectorType());
512+
auto vtype = typeConverter->convertType(expand.getVectorType());
513513
Value ptr = this->getStridedElementPtr(loc, memRefType, adaptor.base(),
514514
adaptor.indices(), rewriter);
515515

@@ -537,7 +537,7 @@ class VectorCompressStoreOpConversion
537537
adaptor.indices(), rewriter);
538538

539539
rewriter.replaceOpWithNewOp<LLVM::masked_compressstore>(
540-
compress, adaptor.value(), ptr, adaptor.mask());
540+
compress, adaptor.valueToStore(), ptr, adaptor.mask());
541541
return success();
542542
}
543543
};

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2534,7 +2534,7 @@ void MaskedStoreOp::getCanonicalizationPatterns(
25342534
static LogicalResult verify(GatherOp op) {
25352535
VectorType indicesVType = op.getIndicesVectorType();
25362536
VectorType maskVType = op.getMaskVectorType();
2537-
VectorType resVType = op.getResultVectorType();
2537+
VectorType resVType = op.getVectorType();
25382538
MemRefType memType = op.getMemRefType();
25392539

25402540
if (resVType.getElementType() != memType.getElementType())
@@ -2580,15 +2580,15 @@ void GatherOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
25802580
static LogicalResult verify(ScatterOp op) {
25812581
VectorType indicesVType = op.getIndicesVectorType();
25822582
VectorType maskVType = op.getMaskVectorType();
2583-
VectorType valueVType = op.getValueVectorType();
2583+
VectorType valueVType = op.getVectorType();
25842584
MemRefType memType = op.getMemRefType();
25852585

25862586
if (valueVType.getElementType() != memType.getElementType())
2587-
return op.emitOpError("base and value element type should match");
2587+
return op.emitOpError("base and valueToStore element type should match");
25882588
if (valueVType.getDimSize(0) != indicesVType.getDimSize(0))
2589-
return op.emitOpError("expected value dim to match indices dim");
2589+
return op.emitOpError("expected valueToStore dim to match indices dim");
25902590
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2591-
return op.emitOpError("expected value dim to match mask dim");
2591+
return op.emitOpError("expected valueToStore dim to match mask dim");
25922592
return success();
25932593
}
25942594

@@ -2624,7 +2624,7 @@ void ScatterOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
26242624
static LogicalResult verify(ExpandLoadOp op) {
26252625
VectorType maskVType = op.getMaskVectorType();
26262626
VectorType passVType = op.getPassThruVectorType();
2627-
VectorType resVType = op.getResultVectorType();
2627+
VectorType resVType = op.getVectorType();
26282628
MemRefType memType = op.getMemRefType();
26292629

26302630
if (resVType.getElementType() != memType.getElementType())
@@ -2671,15 +2671,15 @@ void ExpandLoadOp::getCanonicalizationPatterns(
26712671

26722672
static LogicalResult verify(CompressStoreOp op) {
26732673
VectorType maskVType = op.getMaskVectorType();
2674-
VectorType valueVType = op.getValueVectorType();
2674+
VectorType valueVType = op.getVectorType();
26752675
MemRefType memType = op.getMemRefType();
26762676

26772677
if (valueVType.getElementType() != memType.getElementType())
2678-
return op.emitOpError("base and value element type should match");
2678+
return op.emitOpError("base and valueToStore element type should match");
26792679
if (llvm::size(op.indices()) != memType.getRank())
26802680
return op.emitOpError("requires ") << memType.getRank() << " indices";
26812681
if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
2682-
return op.emitOpError("expected value dim to match mask dim");
2682+
return op.emitOpError("expected valueToStore dim to match mask dim");
26832683
return success();
26842684
}
26852685

@@ -2692,8 +2692,8 @@ class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
26922692
switch (get1DMaskFormat(compress.mask())) {
26932693
case MaskFormat::AllTrue:
26942694
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
2695-
compress, compress.value(), compress.base(), compress.indices(),
2696-
false);
2695+
compress, compress.valueToStore(), compress.base(),
2696+
compress.indices(), false);
26972697
return success();
26982698
case MaskFormat::AllFalse:
26992699
rewriter.eraseOp(compress);

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,7 +1332,7 @@ func @gather_pass_thru_type_mismatch(%base: memref<?xf32>, %indices: vector<16xi
13321332

13331333
func @scatter_base_type_mismatch(%base: memref<?xf64>, %indices: vector<16xi32>,
13341334
%mask: vector<16xi1>, %value: vector<16xf32>) {
1335-
// expected-error@+1 {{'vector.scatter' op base and value element type should match}}
1335+
// expected-error@+1 {{'vector.scatter' op base and valueToStore element type should match}}
13361336
vector.scatter %base[%indices], %mask, %value
13371337
: memref<?xf64>, vector<16xi32>, vector<16xi1>, vector<16xf32>
13381338
}
@@ -1350,7 +1350,7 @@ func @scatter_rank_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
13501350

13511351
func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32>,
13521352
%mask: vector<16xi1>, %value: vector<16xf32>) {
1353-
// expected-error@+1 {{'vector.scatter' op expected value dim to match indices dim}}
1353+
// expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match indices dim}}
13541354
vector.scatter %base[%indices], %mask, %value
13551355
: memref<?xf32>, vector<17xi32>, vector<16xi1>, vector<16xf32>
13561356
}
@@ -1359,7 +1359,7 @@ func @scatter_dim_indices_mismatch(%base: memref<?xf32>, %indices: vector<17xi32
13591359

13601360
func @scatter_dim_mask_mismatch(%base: memref<?xf32>, %indices: vector<16xi32>,
13611361
%mask: vector<17xi1>, %value: vector<16xf32>) {
1362-
// expected-error@+1 {{'vector.scatter' op expected value dim to match mask dim}}
1362+
// expected-error@+1 {{'vector.scatter' op expected valueToStore dim to match mask dim}}
13631363
vector.scatter %base[%indices], %mask, %value
13641364
: memref<?xf32>, vector<16xi32>, vector<17xi1>, vector<16xf32>
13651365
}
@@ -1400,15 +1400,15 @@ func @expand_memref_mismatch(%base: memref<?x?xf32>, %mask: vector<16xi1>, %pass
14001400

14011401
func @compress_base_type_mismatch(%base: memref<?xf64>, %mask: vector<16xi1>, %value: vector<16xf32>) {
14021402
%c0 = constant 0 : index
1403-
// expected-error@+1 {{'vector.compressstore' op base and value element type should match}}
1403+
// expected-error@+1 {{'vector.compressstore' op base and valueToStore element type should match}}
14041404
vector.compressstore %base[%c0], %mask, %value : memref<?xf64>, vector<16xi1>, vector<16xf32>
14051405
}
14061406

14071407
// -----
14081408

14091409
func @compress_dim_mask_mismatch(%base: memref<?xf32>, %mask: vector<17xi1>, %value: vector<16xf32>) {
14101410
%c0 = constant 0 : index
1411-
// expected-error@+1 {{'vector.compressstore' op expected value dim to match mask dim}}
1411+
// expected-error@+1 {{'vector.compressstore' op expected valueToStore dim to match mask dim}}
14121412
vector.compressstore %base[%c0], %mask, %value : memref<?xf32>, vector<17xi1>, vector<16xf32>
14131413
}
14141414

0 commit comments

Comments
 (0)