Skip to content

Commit 009228a

Browse files
committed
Add verifier checks for Gather
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <[email protected]> Change-Id: I16685bceef25f428669c5412d897b6918a424119
1 parent 8b38238 commit 009228a

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2167,8 +2167,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
21672167
}
21682168

21692169
LogicalResult tosa::GatherOp::verify() {
2170-
return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2171-
/* outType = */ getOutput().getType());
2170+
if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2171+
/* outType = */ getOutput().getType())
2172+
.failed()) {
2173+
return failure();
2174+
}
2175+
2176+
const ShapeAdaptor valuesShape(getValues().getType());
2177+
const ShapeAdaptor indicesShape(getIndices().getType());
2178+
const ShapeAdaptor outputShape(getOutput().getType());
2179+
2180+
int64_t N = ShapedType::kDynamic;
2181+
int64_t W = ShapedType::kDynamic;
2182+
int64_t C = ShapedType::kDynamic;
2183+
2184+
if (valuesShape.hasRank()) {
2185+
N = valuesShape.getDimSize(0);
2186+
C = valuesShape.getDimSize(2);
2187+
}
2188+
if (indicesShape.hasRank()) {
2189+
const int64_t indicesN = indicesShape.getDimSize(0);
2190+
W = indicesShape.getDimSize(1);
2191+
if (N == ShapedType::kDynamic)
2192+
N = indicesN;
2193+
else if (indicesN != ShapedType::kDynamic && N != indicesN)
2194+
return emitOpError() << "requires indices dimension 0 to have size " << N
2195+
<< ", got " << indicesN;
2196+
}
2197+
if (outputShape.hasRank()) {
2198+
const int64_t outputN = outputShape.getDimSize(0);
2199+
const int64_t outputW = outputShape.getDimSize(1);
2200+
const int64_t outputC = outputShape.getDimSize(2);
2201+
if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2202+
N != outputN)
2203+
return emitOpError() << "requires output dimension 0 to have size " << N
2204+
<< ", got " << outputN;
2205+
2206+
if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2207+
W != outputW)
2208+
return emitOpError() << "requires output dimension 1 to have size " << W
2209+
<< ", got " << outputW;
2210+
if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2211+
C != outputC)
2212+
return emitOpError() << "requires output dimension 2 to have size " << C
2213+
<< ", got " << outputC;
2214+
}
2215+
return success();
21722216
}
21732217

21742218
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,3 +206,35 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
206206
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
207207
return %0 : tensor<2x?xf32>
208208
}
209+
210+
// -----
211+
// CHECK-LABEL: @test_gather_invalid_indices_N
212+
func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
213+
// expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
214+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
215+
return %0 : tensor<13x26x3xf32>
216+
}
217+
218+
// -----
219+
// CHECK-LABEL: test_gather_invalid_out_N
220+
func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
221+
// expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
222+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
223+
return %0 : tensor<12x26x3xf32>
224+
}
225+
226+
// -----
227+
// CHECK-LABEL: test_gather_invalid_out_W
228+
func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
229+
// expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
230+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
231+
return %0 : tensor<13x28x3xf32>
232+
}
233+
234+
// -----
235+
// CHECK-LABEL: test_gather_invalid_out_C
236+
func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
237+
// expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
238+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
239+
return %0 : tensor<13x26x8xf32>
240+
}

0 commit comments

Comments
 (0)