Skip to content

Commit 3049c76

Browse files
authored
[mlir][vector][spirv] Lower vector.load and vector.store to SPIR-V (#71674)
Add patterns to lower vector.load to spirv.load and vector.store to spirv.store.
1 parent 28233b1 commit 3049c76

File tree

2 files changed

+182
-9
lines changed

2 files changed

+182
-9
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,76 @@ struct VectorShuffleOpConvert final
509509
}
510510
};
511511

512+
struct VectorLoadOpConverter final
513+
: public OpConversionPattern<vector::LoadOp> {
514+
using OpConversionPattern::OpConversionPattern;
515+
516+
LogicalResult
517+
matchAndRewrite(vector::LoadOp loadOp, OpAdaptor adaptor,
518+
ConversionPatternRewriter &rewriter) const override {
519+
auto memrefType = loadOp.getMemRefType();
520+
auto attr =
521+
dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
522+
if (!attr)
523+
return rewriter.notifyMatchFailure(
524+
loadOp, "expected spirv.storage_class memory space");
525+
526+
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
527+
auto loc = loadOp.getLoc();
528+
Value accessChain =
529+
spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
530+
adaptor.getIndices(), loc, rewriter);
531+
if (!accessChain)
532+
return rewriter.notifyMatchFailure(
533+
loadOp, "failed to get memref element pointer");
534+
535+
spirv::StorageClass storageClass = attr.getValue();
536+
auto vectorType = loadOp.getVectorType();
537+
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
538+
Value castedAccessChain =
539+
rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
540+
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
541+
castedAccessChain);
542+
543+
return success();
544+
}
545+
};
546+
547+
struct VectorStoreOpConverter final
548+
: public OpConversionPattern<vector::StoreOp> {
549+
using OpConversionPattern::OpConversionPattern;
550+
551+
LogicalResult
552+
matchAndRewrite(vector::StoreOp storeOp, OpAdaptor adaptor,
553+
ConversionPatternRewriter &rewriter) const override {
554+
auto memrefType = storeOp.getMemRefType();
555+
auto attr =
556+
dyn_cast_or_null<spirv::StorageClassAttr>(memrefType.getMemorySpace());
557+
if (!attr)
558+
return rewriter.notifyMatchFailure(
559+
storeOp, "expected spirv.storage_class memory space");
560+
561+
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
562+
auto loc = storeOp.getLoc();
563+
Value accessChain =
564+
spirv::getElementPtr(typeConverter, memrefType, adaptor.getBase(),
565+
adaptor.getIndices(), loc, rewriter);
566+
if (!accessChain)
567+
return rewriter.notifyMatchFailure(
568+
storeOp, "failed to get memref element pointer");
569+
570+
spirv::StorageClass storageClass = attr.getValue();
571+
auto vectorType = storeOp.getVectorType();
572+
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
573+
Value castedAccessChain =
574+
rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
575+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
576+
adaptor.getValueToStore());
577+
578+
return success();
579+
}
580+
};
581+
512582
struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
513583
using OpRewritePattern::OpRewritePattern;
514584

@@ -614,15 +684,16 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
614684

615685
void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
616686
RewritePatternSet &patterns) {
617-
patterns.add<VectorBitcastConvert, VectorBroadcastConvert,
618-
VectorExtractElementOpConvert, VectorExtractOpConvert,
619-
VectorExtractStridedSliceOpConvert,
620-
VectorFmaOpConvert<spirv::GLFmaOp>,
621-
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
622-
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
623-
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
624-
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
625-
VectorSplatPattern>(typeConverter, patterns.getContext());
687+
patterns.add<
688+
VectorBitcastConvert, VectorBroadcastConvert,
689+
VectorExtractElementOpConvert, VectorExtractOpConvert,
690+
VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
691+
VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
692+
VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
693+
VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
694+
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
695+
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
696+
typeConverter, patterns.getContext());
626697
}
627698

628699
void mlir::populateVectorReductionToSPIRVDotProductPatterns(

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,3 +631,105 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
631631
%1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
632632
return %1 : vector<1xf32>
633633
}
634+
635+
// -----
636+
637+
// Tests for lowering vector.load / vector.store to spirv.load / spirv.store.
// The target environment enables the StorageBuffer storage class extension
// required by the memrefs below. CHECK lines are FileCheck directives and
// are preserved verbatim.
module attributes {
  spirv.target_env = #spirv.target_env<
    #spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>
} {

// 1-D load: index is scaled by stride 1 and folded into an access chain.
// CHECK-LABEL: @vector_load
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S5]] : vector<4xf32>
// CHECK: return %[[R0]] : vector<4xf32>
func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
  %idx = arith.constant 0 : index
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
  return %0: vector<4xf32>
}

// 2-D load: both indices are linearized (row stride 4) before the access chain.
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: %[[R0:.+]] = spirv.Load "StorageBuffer" %[[S8]] : vector<4xf32>
// CHECK: return %[[R0]] : vector<4xf32>
func.func @vector_load_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
  %idx_0 = arith.constant 0 : index
  %idx_1 = arith.constant 1 : index
  %0 = vector.load %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
  return %0: vector<4xf32>
}

// 1-D store: mirrors @vector_load but ends in spirv.Store of the argument.
// CHECK-LABEL: @vector_store
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S2:.+]] = spirv.IMul %[[CST3]], %[[S1]] : i32
// CHECK: %[[S3:.+]] = spirv.IAdd %[[CST2]], %[[S2]] : i32
// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S3]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S5:.+]] = spirv.Bitcast %[[S4]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: spirv.Store "StorageBuffer" %[[S5]], %[[ARG1]] : vector<4xf32>
func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
  %idx = arith.constant 0 : index
  vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
  return
}

// 2-D store: mirrors @vector_load_2d but ends in spirv.Store of the argument.
// CHECK-LABEL: @vector_store_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C1]] : index to i32
// CHECK: %[[CST0_1:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST0_2:.+]] = spirv.Constant 0 : i32
// CHECK: %[[CST4:.+]] = spirv.Constant 4 : i32
// CHECK: %[[S3:.+]] = spirv.IMul %[[CST4]], %[[S1]] : i32
// CHECK: %[[S4:.+]] = spirv.IAdd %[[CST0_2]], %[[S3]] : i32
// CHECK: %[[CST1:.+]] = spirv.Constant 1 : i32
// CHECK: %[[S5:.+]] = spirv.IMul %[[CST1]], %[[S2]] : i32
// CHECK: %[[S6:.+]] = spirv.IAdd %[[S4]], %[[S5]] : i32
// CHECK: %[[S7:.+]] = spirv.AccessChain %[[S0]][%[[CST0_1]], %[[S6]]] : !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
// CHECK: %[[S8:.+]] = spirv.Bitcast %[[S7]] : !spirv.ptr<f32, StorageBuffer> to !spirv.ptr<vector<4xf32>, StorageBuffer>
// CHECK: spirv.Store "StorageBuffer" %[[S8]], %[[ARG1]] : vector<4xf32>
func.func @vector_store_2d(%arg0 : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<4xf32>) {
  %idx_0 = arith.constant 0 : index
  %idx_1 = arith.constant 1 : index
  vector.store %arg1, %arg0[%idx_0, %idx_1] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>>, vector<4xf32>
  return
}

} // end module

0 commit comments

Comments
 (0)