Skip to content

[mlir][vector][spirv] Lower vector.transfer_read and vector.transfer_write to SPIR-V #69708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 96 additions & 1 deletion mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,99 @@ struct VectorShuffleOpConvert final
}
};

// Conversion pattern for vector.transfer_read. It only matches reads that have
// no mask, no out-of-bounds dimension, and a memref source whose memory space
// is a spirv::StorageClassAttr. On a match it emits a spirv.Bitcast of the
// element pointer to a pointer-to-vector and replaces the op with a spirv.Load.
// NOTE(review): the permutation_map of the op is never inspected in this
// pattern, so non-identity maps are not rejected (this is the concern raised in
// the review thread embedded further down).
struct VectorTransferReadOpConverter final
: public OpConversionPattern<vector::TransferReadOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::TransferReadOp transferReadOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// Masked reads are not handled by this lowering.
if (transferReadOp.getMask())
return rewriter.notifyMatchFailure(transferReadOp,
"unsupported transfer_read with mask");

// Reads with any out-of-bounds dimension are rejected; nothing below
// materializes the padding value.
if (transferReadOp.hasOutOfBoundsDim())
return rewriter.notifyMatchFailure(
transferReadOp,
"unsupported transfer_read with out-of-bound dimensions");

// Only memref sources are supported (any other source type fails the cast).
auto sourceType = transferReadOp.getSource().getType();
auto memrefType = dyn_cast<MemRefType>(sourceType);
if (!memrefType)
return rewriter.notifyMatchFailure(transferReadOp, "not a memref source");

// The memory space must be a SPIR-V storage class attribute; its value is
// reused below to build the pointer-to-vector type. Note that this exit
// returns a bare failure() without a diagnostic message.
auto attr =
dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
if (!attr)
return failure();

// Compute the pointer for the converted source and indices. Judging by the
// tests added in this change, this yields a spirv.AccessChain result that
// points at a scalar element.
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
auto loc = transferReadOp.getLoc();
Value accessChain =
spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works when the permutation_map is an identity map? I don't see checks to exclude other cases.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I revoked the fix. I am still thinking why we need identity map here. Can you give more information about it?

adaptor.getIndices(), loc, rewriter);
// A null result from getElementPtr is treated as a match failure.
if (!accessChain)
return failure();

// Reinterpret the element pointer as a pointer to the op's (unconverted)
// vector type in the same storage class, then load the whole vector through
// it.
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = transferReadOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
Value castedAccessChain =
rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
rewriter.replaceOpWithNewOp<spirv::LoadOp>(transferReadOp, vectorType,
castedAccessChain);

return success();
}
};

/// Lowers `vector.transfer_write` on a memref to a `spirv.Bitcast` of the
/// addressed element pointer followed by a `spirv.Store` of the whole vector.
///
/// The lowering writes the vector as one contiguous store starting at the
/// addressed element, so it only matches when all of the following hold:
///   * the op has no mask;
///   * no dimension is marked as potentially out of bounds;
///   * the permutation map is a minor identity (no transpose/broadcast);
///   * the written vector is rank-1;
///   * the destination is a memref whose memory space is a
///     `spirv::StorageClassAttr`.
/// Every rejected case reports a match-failure reason through the rewriter.
struct VectorTransferWriteOpConverter final
    : public OpConversionPattern<vector::TransferWriteOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(vector::TransferWriteOp transferWriteOp, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (transferWriteOp.getMask())
      return rewriter.notifyMatchFailure(
          transferWriteOp, "unsupported transfer_write with mask");

    if (transferWriteOp.hasOutOfBoundsDim())
      return rewriter.notifyMatchFailure(
          transferWriteOp,
          "unsupported transfer_write with out-of-bound dimensions");

    // A single store through a bitcasted pointer writes consecutive elements
    // along the innermost memref dimension. Any other permutation map would
    // require a different element order, so reject it instead of miscompiling.
    if (!transferWriteOp.getPermutationMap().isMinorIdentity())
      return rewriter.notifyMatchFailure(
          transferWriteOp,
          "unsupported transfer_write with non-minor-identity permutation map");

    // The pointee of the bitcasted pointer must be a 1-D vector; rank-0 and
    // multi-dimensional vectors cannot be stored with a single spirv.Store.
    auto vectorType = transferWriteOp.getVectorType();
    if (vectorType.getRank() != 1)
      return rewriter.notifyMatchFailure(
          transferWriteOp, "unsupported transfer_write of non-1-D vector");

    auto sourceType = transferWriteOp.getSource().getType();
    auto memrefType = dyn_cast<MemRefType>(sourceType);
    if (!memrefType)
      return rewriter.notifyMatchFailure(transferWriteOp,
                                         "not a memref source");

    // The storage class is needed to build the pointer-to-vector type below.
    auto attr =
        dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
    if (!attr)
      return rewriter.notifyMatchFailure(
          transferWriteOp,
          "memref memory space is not a spirv storage class attribute");

    const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
    auto loc = transferWriteOp.getLoc();
    Value accessChain =
        spirv::getElementPtr(typeConverter, memrefType, adaptor.getSource(),
                             adaptor.getIndices(), loc, rewriter);
    if (!accessChain)
      return rewriter.notifyMatchFailure(
          transferWriteOp, "failed to compute the element pointer");

    // Reinterpret the element pointer as a pointer to the vector type in the
    // same storage class and store the converted vector operand through it.
    spirv::StorageClass storageClass = attr.getValue();
    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
    Value castedAccessChain =
        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
    rewriter.replaceOpWithNewOp<spirv::StoreOp>(
        transferWriteOp, castedAccessChain, adaptor.getVector());

    return success();
  }
};

struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -622,7 +715,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorSplatPattern>(typeConverter, patterns.getContext());
VectorSplatPattern, VectorTransferReadOpConverter,
VectorTransferWriteOpConverter>(typeConverter,
patterns.getContext());
}

void mlir::populateVectorReductionToSPIRVDotProductPatterns(
Expand Down
103 changes: 103 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,106 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
%1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
return %1 : vector<1xf32>
}

// -----

module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// Test case: a 1-D vector.transfer_read of vector<4xf32> from a
// memref<4xf32> in the StorageBuffer storage class, starting at index 0.
// The expected output is an access chain to the f32 element, a bitcast of that
// pointer to a vector<4xf32> pointer, and a single spirv.Load.
// CHECK-LABEL: @transfer_read
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
// CHECK: return %[[R0]]
func.func @transfer_read(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
%idx = arith.constant 0 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = vector.transfer_read %arg0[%idx], %cst_0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
return %0: vector<4xf32>
}

// Test case: a vector.transfer_read of vector<4xf32> from a 2-D
// memref<4x4xf32> at indices [0, 1], explicitly marked in_bounds.
// The expected output linearizes the two indices (multipliers 4 and 1) into one
// offset into the flattened 16-element array, then bitcasts and loads.
// CHECK-LABEL: @transfer_read_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
// CHECK: return %[[R0]]
func.func @transfer_read_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
%idx_0 = arith.constant 0 : index
%idx_1 = arith.constant 1 : index
%cst_0 = arith.constant 0.000000e+00 : f32
%0 = vector.transfer_read %arg0[%idx_0, %idx_1], %cst_0 {in_bounds = [true]} : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
return %0: vector<4xf32>
}

// Test case: a 1-D vector.transfer_write of vector<4xf32> into a
// memref<4xf32> in the StorageBuffer storage class, starting at index 0.
// The expected output is an access chain to the f32 element, a bitcast of that
// pointer to a vector<4xf32> pointer, and a single spirv.Store of the argument.
// CHECK-LABEL: @transfer_write
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
func.func @transfer_write(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
%idx = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%idx] : vector<4xf32>, memref<4xf32, #spirv.storage_class<StorageBuffer>>
return
}

// Test case: a vector.transfer_write of vector<4xf32> into a 2-D
// memref<4x4xf32> at indices [0, 1], explicitly marked in_bounds.
// The expected output linearizes the two indices (multipliers 4 and 1) into one
// offset into the flattened 16-element array, then bitcasts and stores.
// CHECK-LABEL: @transfer_write_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>
func.func @transfer_write_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
%idx_0 = arith.constant 0 : index
%idx_1 = arith.constant 1 : index
vector.transfer_write %arg1, %arg0[%idx_0, %idx_1] {in_bounds = [true]} : vector<4xf32>, memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
return
}

} // end module