Skip to content

Commit f597218

Browse files
committed
Handle one-element input vector case and remove cmake/bazel dependencies
1 parent 2c4d103 commit f597218

File tree

4 files changed

+27
-3
lines changed

4 files changed

+27
-3
lines changed

mlir/lib/Conversion/VectorToSPIRV/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,5 @@ add_mlir_conversion_library(MLIRVectorToSPIRV
1414
MLIRSPIRVDialect
1515
MLIRSPIRVConversion
1616
MLIRVectorDialect
17-
MLIRVectorTransforms
1817
MLIRTransforms
1918
)

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -601,6 +601,19 @@ struct VectorInterleaveOpConvert final
601601

602602
// Interleave the indices
603603
int n = sourceType.getNumElements();
604+
605+
// Input vectors of size 1 are converted to scalars by the type converter.
606+
// We cannot use spirv::VectorShuffleOp directly in this case, and need to
607+
// use spirv::CompositeConstructOp.
608+
if (n == 1) {
609+
SmallVector<Value> newOperands(2);
610+
newOperands[0] = adaptor.getLhs();
611+
newOperands[1] = adaptor.getRhs();
612+
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(
613+
interleaveOp, newResultType, newOperands);
614+
return success();
615+
}
616+
604617
auto seq = llvm::seq<int64_t>(2 * n);
605618
auto indices = llvm::to_vector(
606619
llvm::map_range(seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
@@ -609,7 +622,7 @@ struct VectorInterleaveOpConvert final
609622
rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
610623
interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
611624
rewriter.getI32ArrayAttr(indices));
612-
625+
613626
return success();
614627
}
615628
};

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -494,6 +494,19 @@ func.func @interleave(%a: vector<2xf32>, %b: vector<2xf32>) -> vector<4xf32> {
494494

495495
// -----
496496

497+
// CHECK-LABEL: func @interleave_size1
498+
// CHECK-SAME: (%[[ARG0:.+]]: vector<1xf32>, %[[ARG1:.+]]: vector<1xf32>)
499+
// CHECK: %[[V0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : vector<1xf32> to f32
500+
// CHECK: %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
501+
// CHECK: %[[RES:.*]] = spirv.CompositeConstruct %[[V0]], %[[V1]] : (f32, f32) -> vector<2xf32>
502+
// CHECK: return %[[RES]]
503+
func.func @interleave_size1(%a: vector<1xf32>, %b: vector<1xf32>) -> vector<2xf32> {
504+
%0 = vector.interleave %a, %b : vector<1xf32>
505+
return %0 : vector<2xf32>
506+
}
507+
508+
// -----
509+
497510
// CHECK-LABEL: func @reduction_add
498511
// CHECK-SAME: (%[[V:.+]]: vector<4xi32>)
499512
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<4xi32>

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4976,7 +4976,6 @@ cc_library(
49764976
":VectorToLLVM",
49774977
":VectorToSCF",
49784978
":VectorTransformOpsIncGen",
4979-
":VectorTransforms",
49804979
":X86VectorTransforms",
49814980
],
49824981
)

0 commit comments

Comments
 (0)