Skip to content

Commit 2ad297d

Browse files
committed
[mlir][spirv] Handle zero-element tensors in spirv type conversion
Return gracefully instead of crashing. Add missing type conversion tests. Fixes: #61044 Reviewed By: qedawkins Differential Revision: https://reviews.llvm.org/D156942
1 parent 7ef1718 commit 2ad297d

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -390,8 +390,14 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
390390
return nullptr;
391391
}
392392

393-
auto arrayElemCount = *tensorSize / *scalarSize;
394-
auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
393+
int64_t arrayElemCount = *tensorSize / *scalarSize;
394+
if (arrayElemCount == 0) {
395+
LLVM_DEBUG(llvm::dbgs()
396+
<< type << " illegal: cannot handle zero-element tensors\n");
397+
return nullptr;
398+
}
399+
400+
Type arrayElemType = convertScalarType(targetEnv, options, scalarType);
395401
if (!arrayElemType)
396402
return nullptr;
397403
std::optional<int64_t> arrayElemSize =

mlir/test/Conversion/TensorToSPIRV/tensor-ops-to-spirv.mlir

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
// RUN: mlir-opt -split-input-file -convert-tensor-to-spirv -verify-diagnostics %s | FileCheck %s
1+
// RUN: mlir-opt --split-input-file --convert-tensor-to-spirv \
2+
// RUN: --verify-diagnostics %s | FileCheck %s
23

34
//===----------------------------------------------------------------------===//
45
// tensor.extract
@@ -27,3 +28,38 @@ func.func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
2728
// CHECK: spirv.ReturnValue %[[VAL]]
2829
return %extract : i32
2930
}
31+
32+
// -----
33+
34+
//===----------------------------------------------------------------------===//
35+
// Type conversion
36+
//===----------------------------------------------------------------------===//
37+
38+
// CHECK-LABEL: func @tensor_0d
39+
// CHECK-NEXT: spirv.Constant 1 : i32
40+
func.func @tensor_0d() -> () {
41+
%x = arith.constant dense<1> : tensor<i32>
42+
return
43+
}
44+
45+
// CHECK-LABEL: func @tensor_1d
46+
// CHECK-NEXT: spirv.Constant dense<[1, 2, 3]> : tensor<3xi32> : !spirv.array<3 x i32>
47+
func.func @tensor_1d() -> () {
48+
%x = arith.constant dense<[1, 2, 3]> : tensor<3xi32>
49+
return
50+
}
51+
52+
// CHECK-LABEL: func @tensor_2d
53+
// CHECK-NEXT: spirv.Constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
54+
func.func @tensor_2d() -> () {
55+
%x = arith.constant dense<[[1, 2, 3], [4, 5, 6]]> : tensor<2x3xi32>
56+
return
57+
}
58+
59+
// We do not handle zero-element tensors yet. Just make we do not crash on them.
60+
// CHECK-LABEL: func @tensor_2d_empty
61+
// CHECK-NEXT: arith.constant dense<>
62+
func.func @tensor_2d_empty() -> () {
63+
%x = arith.constant dense<> : tensor<2x0xi32>
64+
return
65+
}

0 commit comments

Comments
 (0)