Skip to content

Commit 0c1b1b2

Browse files
committed
[mlir][spirv] Implement SPIR-V lowering for vector.deinterleave
1. Added a conversion for vector.deinterleave to the VectorToSPIRV pass. 2. Added LIT tests for the new conversion.
1 parent bf7c505 commit 0c1b1b2

File tree

2 files changed

+121
-3
lines changed

2 files changed

+121
-3
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,74 @@ struct VectorInterleaveOpConvert final
618618
}
619619
};
620620

621+
struct VectorDeinterleaveOpConvert final
622+
: public OpConversionPattern<vector::DeinterleaveOp> {
623+
using OpConversionPattern::OpConversionPattern;
624+
625+
LogicalResult
626+
matchAndRewrite(vector::DeinterleaveOp deinterleaveOp, OpAdaptor adaptor,
627+
ConversionPatternRewriter &rewriter) const override {
628+
629+
// Check the result vector type.
630+
VectorType oldResultType = deinterleaveOp.getResultVectorType();
631+
Type newResultType = getTypeConverter()->convertType(oldResultType);
632+
if (!newResultType)
633+
return rewriter.notifyMatchFailure(deinterleaveOp,
634+
"unsupported result vector type");
635+
636+
// Get location.
637+
Location loc = deinterleaveOp->getLoc();
638+
639+
// Deinterleave the indices.
640+
VectorType sourceType = deinterleaveOp.getSourceVectorType();
641+
int n = sourceType.getNumElements();
642+
643+
// Output vectors of size 1 are converted to scalars by the type converter.
644+
// We cannot use `spirv::VectorShuffleOp` directly in this case, and need to
645+
// use `spirv::CompositeExtractOp`.
646+
if (n == 2) {
647+
spirv::CompositeExtractOp compositeExtractZero =
648+
rewriter.create<spirv::CompositeExtractOp>(
649+
loc, newResultType, adaptor.getSource(),
650+
rewriter.getI32ArrayAttr({0}));
651+
652+
spirv::CompositeExtractOp compositeExtractOne =
653+
rewriter.create<spirv::CompositeExtractOp>(
654+
loc, newResultType, adaptor.getSource(),
655+
rewriter.getI32ArrayAttr({1}));
656+
657+
rewriter.replaceOp(deinterleaveOp,
658+
{compositeExtractZero, compositeExtractOne});
659+
return success();
660+
}
661+
662+
// Indices for `res1`.
663+
auto seqEven = llvm::seq<int64_t>(n / 2);
664+
auto indicesEven =
665+
llvm::map_to_vector(seqEven, [](int i) { return i * 2; });
666+
667+
// Indices for `res2`.
668+
auto seqOdd = llvm::seq<int64_t>(n / 2);
669+
auto indicesOdd =
670+
llvm::map_to_vector(seqOdd, [](int i) { return i * 2 + 1; });
671+
672+
// Create two SPIR-V shuffles.
673+
spirv::VectorShuffleOp shuffleEven =
674+
rewriter.create<spirv::VectorShuffleOp>(
675+
loc, newResultType, adaptor.getSource(), adaptor.getSource(),
676+
rewriter.getI32ArrayAttr(indicesEven));
677+
678+
spirv::VectorShuffleOp shuffleOdd = rewriter.create<spirv::VectorShuffleOp>(
679+
loc, newResultType, adaptor.getSource(), adaptor.getSource(),
680+
rewriter.getI32ArrayAttr(indicesOdd));
681+
682+
// Replace deinterleaveOp with SPIR-V shuffles.
683+
rewriter.replaceOp(deinterleaveOp, {shuffleEven, shuffleOdd});
684+
685+
return success();
686+
}
687+
};
688+
621689
struct VectorLoadOpConverter final
622690
: public OpConversionPattern<vector::LoadOp> {
623691
using OpConversionPattern::OpConversionPattern;
@@ -862,9 +930,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
862930
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
863931
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
864932
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
865-
VectorInterleaveOpConvert, VectorSplatPattern, VectorLoadOpConverter,
866-
VectorStoreOpConverter>(typeConverter, patterns.getContext(),
867-
PatternBenefit(1));
933+
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
934+
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
935+
typeConverter, patterns.getContext(), PatternBenefit(1));
868936

869937
// Make sure that the more specialized dot product pattern has higher benefit
870938
// than the generic one that extracts all elements.

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,56 @@ func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf3
507507

508508
// -----
509509

510+
// CHECK-LABEL: func @deinterleave_return0
511+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
512+
// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
513+
// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
514+
// CHECK: return %[[SHUFFLE0]]
515+
func.func @deinterleave_return0(%a: vector<4xf32>) -> vector<2xf32> {
516+
%0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
517+
return %0 : vector<2xf32>
518+
}
519+
520+
// -----
521+
522+
// CHECK-LABEL: func @deinterleave_return1
523+
// CHECK-SAME: (%[[ARG0:.+]]: vector<4xf32>)
524+
// CHECK: %[[SHUFFLE0:.*]] = spirv.VectorShuffle [0 : i32, 2 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
525+
// CHECK: %[[SHUFFLE1:.*]] = spirv.VectorShuffle [1 : i32, 3 : i32] %[[ARG0]], %[[ARG0]] : vector<4xf32>, vector<4xf32> -> vector<2xf32>
526+
// CHECK: return %[[SHUFFLE1]]
527+
func.func @deinterleave_return1(%a: vector<4xf32>) -> vector<2xf32> {
528+
%0, %1 = vector.deinterleave %a : vector<4xf32> -> vector<2xf32>
529+
return %1 : vector<2xf32>
530+
}
531+
532+
// -----
533+
534+
// CHECK-LABEL: func @deinterleave_scalar_return0
535+
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
536+
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
537+
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
538+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT0]] : f32 to vector<1xf32>
539+
// CHECK: return %[[RES]]
540+
func.func @deinterleave_scalar_return0(%a: vector<2xf32>) -> vector<1xf32> {
541+
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
542+
return %0 : vector<1xf32>
543+
}
544+
545+
// -----
546+
547+
// CHECK-LABEL: func @deinterleave_scalar_return1
548+
// CHECK-SAME: (%[[ARG0:.+]]: vector<2xf32>)
549+
// CHECK: %[[EXTRACT0:.*]] = spirv.CompositeExtract %[[ARG0]][0 : i32] : vector<2xf32>
550+
// CHECK: %[[EXTRACT1:.*]] = spirv.CompositeExtract %[[ARG0]][1 : i32] : vector<2xf32>
551+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[EXTRACT1]] : f32 to vector<1xf32>
552+
// CHECK: return %[[RES]]
553+
func.func @deinterleave_scalar_return1(%a: vector<2xf32>) -> vector<1xf32> {
554+
%0, %1 = vector.deinterleave %a: vector<2xf32> -> vector<1xf32>
555+
return %1 : vector<1xf32>
556+
}
557+
558+
// -----
559+
510560
// CHECK-LABEL: func @reduction_add
511561
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
512562
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

0 commit comments

Comments
 (0)