Skip to content

Commit 8345b86

Browse files
author
Nicolas Vasilache
committed
[mlir][Vector] Add lowering of 1-D vector transfer_read/write to masked load/store
Summary: This revision adds support for lowering 1-D vector transfers to LLVM. A mask of the vector length is created by comparing the base offset + linear index against the dimension of the memref. In each position where the access is in bounds (i.e. offset + vector index < dim), the mask is set to 1. A notable fact is that the lowering uses llvm.dialect_cast, which allows writing the code in the simplest form by targeting the simplest mix of vector and LLVM dialects and letting other conversions kick in. Differential Revision: https://reviews.llvm.org/D77703
1 parent 413467f commit 8345b86

File tree

8 files changed

+318
-70
lines changed

8 files changed

+318
-70
lines changed

mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -398,6 +398,29 @@ class ConvertToLLVMPattern : public ConversionPattern {
398398
Value createIndexConstant(ConversionPatternRewriter &builder, Location loc,
399399
uint64_t value) const;
400400

401+
// Given subscript indices and array sizes in row-major order,
402+
// i_n, i_{n-1}, ..., i_1
403+
// s_n, s_{n-1}, ..., s_1
404+
// obtain a value that corresponds to the linearized subscript
405+
// \sum_k i_k * \prod_{j=1}^{k-1} s_j
406+
// by accumulating the running linearized value.
407+
// Note that `indices` and `allocSizes` are passed in the same order as they
408+
// appear in load/store operations and memref type declarations.
409+
Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
410+
ArrayRef<Value> indices,
411+
ArrayRef<Value> allocSizes) const;
412+
413+
// This is a strided getElementPtr variant that linearizes subscripts as:
414+
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
415+
Value getStridedElementPtr(Location loc, Type elementTypePtr,
416+
Value descriptor, ArrayRef<Value> indices,
417+
ArrayRef<int64_t> strides, int64_t offset,
418+
ConversionPatternRewriter &rewriter) const;
419+
420+
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
421+
ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
422+
llvm::Module &module) const;
423+
401424
protected:
402425
/// Reference to the type converter, with potential extensions.
403426
LLVMTypeConverter &typeConverter;

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class LLVMType : public mlir::Type::TypeBase<LLVMType, mlir::Type,
7373

7474
/// Vector type utilities.
7575
LLVMType getVectorElementType();
76+
unsigned getVectorNumElements();
7677
bool isVectorTy();
7778

7879
/// Function type utilities.

mlir/include/mlir/IR/Builders.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class Builder {
111111
IntegerAttr getI16IntegerAttr(int16_t value);
112112
IntegerAttr getI32IntegerAttr(int32_t value);
113113
IntegerAttr getI64IntegerAttr(int64_t value);
114+
IntegerAttr getIndexAttr(int64_t value);
114115

115116
/// Signed and unsigned integer attribute getters.
116117
IntegerAttr getSI32IntegerAttr(int32_t value);

mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp

Lines changed: 55 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,61 @@ Value ConvertToLLVMPattern::createIndexConstant(
735735
return createIndexAttrConstant(builder, loc, getIndexType(), value);
736736
}
737737

738+
Value ConvertToLLVMPattern::linearizeSubscripts(
739+
ConversionPatternRewriter &builder, Location loc, ArrayRef<Value> indices,
740+
ArrayRef<Value> allocSizes) const {
741+
assert(indices.size() == allocSizes.size() &&
742+
"mismatching number of indices and allocation sizes");
743+
assert(!indices.empty() && "cannot linearize a 0-dimensional access");
744+
745+
Value linearized = indices.front();
746+
for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
747+
linearized = builder.create<LLVM::MulOp>(
748+
loc, this->getIndexType(), ArrayRef<Value>{linearized, allocSizes[i]});
749+
linearized = builder.create<LLVM::AddOp>(
750+
loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
751+
}
752+
return linearized;
753+
}
754+
755+
Value ConvertToLLVMPattern::getStridedElementPtr(
756+
Location loc, Type elementTypePtr, Value descriptor,
757+
ArrayRef<Value> indices, ArrayRef<int64_t> strides, int64_t offset,
758+
ConversionPatternRewriter &rewriter) const {
759+
MemRefDescriptor memRefDescriptor(descriptor);
760+
761+
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
762+
Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
763+
? memRefDescriptor.offset(rewriter, loc)
764+
: this->createIndexConstant(rewriter, loc, offset);
765+
766+
for (int i = 0, e = indices.size(); i < e; ++i) {
767+
Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
768+
? memRefDescriptor.stride(rewriter, loc, i)
769+
: this->createIndexConstant(rewriter, loc, strides[i]);
770+
Value additionalOffset =
771+
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
772+
offsetValue =
773+
rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
774+
}
775+
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
776+
}
777+
778+
Value ConvertToLLVMPattern::getDataPtr(Location loc, MemRefType type,
779+
Value memRefDesc,
780+
ArrayRef<Value> indices,
781+
ConversionPatternRewriter &rewriter,
782+
llvm::Module &module) const {
783+
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
784+
int64_t offset;
785+
SmallVector<int64_t, 4> strides;
786+
auto successStrides = getStridesAndOffset(type, strides, offset);
787+
assert(succeeded(successStrides) && "unexpected non-strided memref");
788+
(void)successStrides;
789+
return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
790+
offset, rewriter);
791+
}
792+
738793
/// Only retain those attributes that are not constructed by
739794
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
740795
/// attributes.
@@ -1913,70 +1968,6 @@ struct LoadStoreOpLowering : public ConvertOpToLLVMPattern<Derived> {
19131968
MemRefType type = cast<Derived>(op).getMemRefType();
19141969
return isSupportedMemRefType(type) ? success() : failure();
19151970
}
1916-
1917-
// Given subscript indices and array sizes in row-major order,
1918-
// i_n, i_{n-1}, ..., i_1
1919-
// s_n, s_{n-1}, ..., s_1
1920-
// obtain a value that corresponds to the linearized subscript
1921-
// \sum_k i_k * \prod_{j=1}^{k-1} s_j
1922-
// by accumulating the running linearized value.
1923-
// Note that `indices` and `allocSizes` are passed in the same order as they
1924-
// appear in load/store operations and memref type declarations.
1925-
Value linearizeSubscripts(ConversionPatternRewriter &builder, Location loc,
1926-
ArrayRef<Value> indices,
1927-
ArrayRef<Value> allocSizes) const {
1928-
assert(indices.size() == allocSizes.size() &&
1929-
"mismatching number of indices and allocation sizes");
1930-
assert(!indices.empty() && "cannot linearize a 0-dimensional access");
1931-
1932-
Value linearized = indices.front();
1933-
for (int i = 1, nSizes = allocSizes.size(); i < nSizes; ++i) {
1934-
linearized = builder.create<LLVM::MulOp>(
1935-
loc, this->getIndexType(),
1936-
ArrayRef<Value>{linearized, allocSizes[i]});
1937-
linearized = builder.create<LLVM::AddOp>(
1938-
loc, this->getIndexType(), ArrayRef<Value>{linearized, indices[i]});
1939-
}
1940-
return linearized;
1941-
}
1942-
1943-
// This is a strided getElementPtr variant that linearizes subscripts as:
1944-
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
1945-
Value getStridedElementPtr(Location loc, Type elementTypePtr,
1946-
Value descriptor, ArrayRef<Value> indices,
1947-
ArrayRef<int64_t> strides, int64_t offset,
1948-
ConversionPatternRewriter &rewriter) const {
1949-
MemRefDescriptor memRefDescriptor(descriptor);
1950-
1951-
Value base = memRefDescriptor.alignedPtr(rewriter, loc);
1952-
Value offsetValue = offset == MemRefType::getDynamicStrideOrOffset()
1953-
? memRefDescriptor.offset(rewriter, loc)
1954-
: this->createIndexConstant(rewriter, loc, offset);
1955-
1956-
for (int i = 0, e = indices.size(); i < e; ++i) {
1957-
Value stride = strides[i] == MemRefType::getDynamicStrideOrOffset()
1958-
? memRefDescriptor.stride(rewriter, loc, i)
1959-
: this->createIndexConstant(rewriter, loc, strides[i]);
1960-
Value additionalOffset =
1961-
rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
1962-
offsetValue =
1963-
rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
1964-
}
1965-
return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
1966-
}
1967-
1968-
Value getDataPtr(Location loc, MemRefType type, Value memRefDesc,
1969-
ArrayRef<Value> indices, ConversionPatternRewriter &rewriter,
1970-
llvm::Module &module) const {
1971-
LLVM::LLVMType ptrType = MemRefDescriptor(memRefDesc).getElementType();
1972-
int64_t offset;
1973-
SmallVector<int64_t, 4> strides;
1974-
auto successStrides = getStridesAndOffset(type, strides, offset);
1975-
assert(succeeded(successStrides) && "unexpected non-strided memref");
1976-
(void)successStrides;
1977-
return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
1978-
offset, rewriter);
1979-
}
19801971
};
19811972

19821973
// Load operation is lowered to obtaining a pointer to the indexed element

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 139 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1515
#include "mlir/Dialect/StandardOps/IR/Ops.h"
1616
#include "mlir/Dialect/Vector/VectorOps.h"
17+
#include "mlir/IR/AffineMap.h"
1718
#include "mlir/IR/Attributes.h"
1819
#include "mlir/IR/Builders.h"
1920
#include "mlir/IR/MLIRContext.h"
@@ -894,6 +895,129 @@ class VectorTypeCastOpConversion : public ConvertToLLVMPattern {
894895
}
895896
};
896897

898+
template <typename ConcreteOp>
899+
void replaceTransferOp(ConversionPatternRewriter &rewriter,
900+
LLVMTypeConverter &typeConverter, Location loc,
901+
Operation *op, ArrayRef<Value> operands, Value dataPtr,
902+
Value mask);
903+
904+
template <>
905+
void replaceTransferOp<TransferReadOp>(ConversionPatternRewriter &rewriter,
906+
LLVMTypeConverter &typeConverter,
907+
Location loc, Operation *op,
908+
ArrayRef<Value> operands, Value dataPtr,
909+
Value mask) {
910+
auto xferOp = cast<TransferReadOp>(op);
911+
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
912+
VectorType fillType = xferOp.getVectorType();
913+
Value fill = rewriter.create<SplatOp>(loc, fillType, xferOp.padding());
914+
fill = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(fillType), fill);
915+
916+
auto vecTy = toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
917+
rewriter.replaceOpWithNewOp<LLVM::MaskedLoadOp>(
918+
op, vecTy, dataPtr, mask, ValueRange{fill},
919+
rewriter.getI32IntegerAttr(1));
920+
}
921+
922+
template <>
923+
void replaceTransferOp<TransferWriteOp>(ConversionPatternRewriter &rewriter,
924+
LLVMTypeConverter &typeConverter,
925+
Location loc, Operation *op,
926+
ArrayRef<Value> operands, Value dataPtr,
927+
Value mask) {
928+
auto adaptor = TransferWriteOpOperandAdaptor(operands);
929+
rewriter.replaceOpWithNewOp<LLVM::MaskedStoreOp>(
930+
op, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(1));
931+
}
932+
933+
static TransferReadOpOperandAdaptor
934+
getTransferOpAdapter(TransferReadOp xferOp, ArrayRef<Value> operands) {
935+
return TransferReadOpOperandAdaptor(operands);
936+
}
937+
938+
static TransferWriteOpOperandAdaptor
939+
getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef<Value> operands) {
940+
return TransferWriteOpOperandAdaptor(operands);
941+
}
942+
943+
/// Conversion pattern that converts a 1-D vector transfer read/write op into a
944+
/// sequence of:
945+
/// 1. Bitcast to vector form.
946+
/// 2. Create an offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
947+
/// 3. Create a mask where offsetVector is compared against memref upper bound.
948+
/// 4. Rewrite op as a masked read or write.
949+
template <typename ConcreteOp>
950+
class VectorTransferConversion : public ConvertToLLVMPattern {
951+
public:
952+
explicit VectorTransferConversion(MLIRContext *context,
953+
LLVMTypeConverter &typeConv)
954+
: ConvertToLLVMPattern(ConcreteOp::getOperationName(), context,
955+
typeConv) {}
956+
957+
LogicalResult
958+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
959+
ConversionPatternRewriter &rewriter) const override {
960+
auto xferOp = cast<ConcreteOp>(op);
961+
auto adaptor = getTransferOpAdapter(xferOp, operands);
962+
if (xferOp.getMemRefType().getRank() != 1)
963+
return failure();
964+
if (!xferOp.permutation_map().isIdentity())
965+
return failure();
966+
967+
auto toLLVMTy = [&](Type t) { return typeConverter.convertType(t); };
968+
969+
Location loc = op->getLoc();
970+
Type i64Type = rewriter.getIntegerType(64);
971+
MemRefType memRefType = xferOp.getMemRefType();
972+
973+
// 1. Get the source/dst address as an LLVM vector pointer.
974+
// TODO: support alignment when possible.
975+
Value dataPtr = getDataPtr(loc, memRefType, adaptor.memref(),
976+
adaptor.indices(), rewriter, getModule());
977+
auto vecTy =
978+
toLLVMTy(xferOp.getVectorType()).template cast<LLVM::LLVMType>();
979+
auto vectorDataPtr =
980+
rewriter.create<LLVM::BitcastOp>(loc, vecTy.getPointerTo(), dataPtr);
981+
982+
// 2. Create a vector with linear indices [ 0 .. vector_length - 1 ].
983+
unsigned vecWidth = vecTy.getVectorNumElements();
984+
VectorType vectorCmpType = VectorType::get(vecWidth, i64Type);
985+
SmallVector<int64_t, 8> indices;
986+
indices.reserve(vecWidth);
987+
for (unsigned i = 0; i < vecWidth; ++i)
988+
indices.push_back(i);
989+
Value linearIndices = rewriter.create<ConstantOp>(
990+
loc, vectorCmpType,
991+
DenseElementsAttr::get(vectorCmpType, ArrayRef<int64_t>(indices)));
992+
linearIndices = rewriter.create<LLVM::DialectCastOp>(
993+
loc, toLLVMTy(vectorCmpType), linearIndices);
994+
995+
// 3. Create offsetVector = [ offset + 0 .. offset + vector_length - 1 ].
996+
Value offsetIndex = *(xferOp.indices().begin());
997+
offsetIndex = rewriter.create<IndexCastOp>(
998+
loc, vectorCmpType.getElementType(), offsetIndex);
999+
Value base = rewriter.create<SplatOp>(loc, vectorCmpType, offsetIndex);
1000+
Value offsetVector = rewriter.create<AddIOp>(loc, base, linearIndices);
1001+
1002+
// 4. Let dim be the memref dimension, compute the vector comparison mask:
1003+
// [ offset + 0 .. offset + vector_length - 1 ] < [ dim .. dim ]
1004+
Value dim = rewriter.create<DimOp>(loc, xferOp.memref(), 0);
1005+
dim =
1006+
rewriter.create<IndexCastOp>(loc, vectorCmpType.getElementType(), dim);
1007+
dim = rewriter.create<SplatOp>(loc, vectorCmpType, dim);
1008+
Value mask =
1009+
rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, offsetVector, dim);
1010+
mask = rewriter.create<LLVM::DialectCastOp>(loc, toLLVMTy(mask.getType()),
1011+
mask);
1012+
1013+
// 5. Rewrite as a masked read / write.
1014+
replaceTransferOp<ConcreteOp>(rewriter, typeConverter, loc, op, operands,
1015+
vectorDataPtr, mask);
1016+
1017+
return success();
1018+
}
1019+
};
1020+
8971021
class VectorPrintOpConversion : public ConvertToLLVMPattern {
8981022
public:
8991023
explicit VectorPrintOpConversion(MLIRContext *context,
@@ -1079,16 +1203,25 @@ class VectorStridedSliceOpConversion : public OpRewritePattern<StridedSliceOp> {
10791203
void mlir::populateVectorToLLVMConversionPatterns(
10801204
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
10811205
MLIRContext *ctx = converter.getDialect()->getContext();
1206+
// clang-format off
10821207
patterns.insert<VectorFMAOpNDRewritePattern,
10831208
VectorInsertStridedSliceOpDifferentRankRewritePattern,
10841209
VectorInsertStridedSliceOpSameRankRewritePattern,
10851210
VectorStridedSliceOpConversion>(ctx);
1086-
patterns.insert<VectorBroadcastOpConversion, VectorReductionOpConversion,
1087-
VectorShuffleOpConversion, VectorExtractElementOpConversion,
1088-
VectorExtractOpConversion, VectorFMAOp1DConversion,
1089-
VectorInsertElementOpConversion, VectorInsertOpConversion,
1090-
VectorTypeCastOpConversion, VectorPrintOpConversion>(
1091-
ctx, converter);
1211+
patterns
1212+
.insert<VectorBroadcastOpConversion,
1213+
VectorReductionOpConversion,
1214+
VectorShuffleOpConversion,
1215+
VectorExtractElementOpConversion,
1216+
VectorExtractOpConversion,
1217+
VectorFMAOp1DConversion,
1218+
VectorInsertElementOpConversion,
1219+
VectorInsertOpConversion,
1220+
VectorPrintOpConversion,
1221+
VectorTransferConversion<TransferReadOp>,
1222+
VectorTransferConversion<TransferWriteOp>,
1223+
VectorTypeCastOpConversion>(ctx, converter);
1224+
// clang-format on
10921225
}
10931226

10941227
void mlir::populateVectorToLLVMMatrixConversionPatterns(

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,6 +1774,9 @@ bool LLVMType::isArrayTy() { return getUnderlyingType()->isArrayTy(); }
17741774
LLVMType LLVMType::getVectorElementType() {
17751775
return get(getContext(), getUnderlyingType()->getVectorElementType());
17761776
}
1777+
unsigned LLVMType::getVectorNumElements() {
1778+
return getUnderlyingType()->getVectorNumElements();
1779+
}
17771780
bool LLVMType::isVectorTy() { return getUnderlyingType()->isVectorTy(); }
17781781

17791782
/// Function type utilities.

mlir/lib/IR/Builders.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ DictionaryAttr Builder::getDictionaryAttr(ArrayRef<NamedAttribute> value) {
9393
return DictionaryAttr::get(value, context);
9494
}
9595

96+
IntegerAttr Builder::getIndexAttr(int64_t value) {
97+
return IntegerAttr::get(getIndexType(), APInt(64, value));
98+
}
99+
96100
IntegerAttr Builder::getI64IntegerAttr(int64_t value) {
97101
return IntegerAttr::get(getIntegerType(64), APInt(64, value));
98102
}

0 commit comments

Comments
 (0)