Skip to content

Commit 300b94e

Browse files
Tai78641lhutton1
authored andcommitted
[mlir][tosa] Add verifier checks for Gather (llvm#137204)
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. --------- Signed-off-by: Tai Ly <[email protected]> Co-authored-by: Luke Hutton <[email protected]>
1 parent 9c4ee7f commit 300b94e

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
22622262
}
22632263

22642264
LogicalResult tosa::GatherOp::verify() {
2265-
return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2266-
/* outType = */ getOutput().getType());
2265+
if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2266+
/* outType = */ getOutput().getType())
2267+
.failed()) {
2268+
return failure();
2269+
}
2270+
2271+
const ShapeAdaptor valuesShape(getValues().getType());
2272+
const ShapeAdaptor indicesShape(getIndices().getType());
2273+
const ShapeAdaptor outputShape(getOutput().getType());
2274+
2275+
int64_t N = ShapedType::kDynamic;
2276+
int64_t W = ShapedType::kDynamic;
2277+
int64_t C = ShapedType::kDynamic;
2278+
2279+
if (valuesShape.hasRank()) {
2280+
N = valuesShape.getDimSize(0);
2281+
C = valuesShape.getDimSize(2);
2282+
}
2283+
if (indicesShape.hasRank()) {
2284+
const int64_t indicesN = indicesShape.getDimSize(0);
2285+
W = indicesShape.getDimSize(1);
2286+
if (N == ShapedType::kDynamic)
2287+
N = indicesN;
2288+
else if (indicesN != ShapedType::kDynamic && N != indicesN)
2289+
return emitOpError() << "requires indices dimension 0 to have size " << N
2290+
<< ", got " << indicesN;
2291+
}
2292+
if (outputShape.hasRank()) {
2293+
const int64_t outputN = outputShape.getDimSize(0);
2294+
const int64_t outputW = outputShape.getDimSize(1);
2295+
const int64_t outputC = outputShape.getDimSize(2);
2296+
if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2297+
N != outputN)
2298+
return emitOpError() << "requires output dimension 0 to have size " << N
2299+
<< ", got " << outputN;
2300+
2301+
if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2302+
W != outputW)
2303+
return emitOpError() << "requires output dimension 1 to have size " << W
2304+
<< ", got " << outputW;
2305+
if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2306+
C != outputC)
2307+
return emitOpError() << "requires output dimension 2 to have size " << C
2308+
<< ", got " << outputC;
2309+
}
2310+
return success();
22672311
}
22682312

22692313
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,3 +370,36 @@ func.func @test_error_scalar_input_with_per_channel(%arg0: tensor<i8>) -> tensor
370370
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = true, rounding_mode = "SINGLE_ROUND", per_channel = true, input_unsigned = false, output_unsigned = false} : (tensor<i8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi16>) -> tensor<i16>
371371
return %0 : tensor<i16>
372372
}
373+
374+
// -----
375+
376+
// CHECK-LABEL: @test_gather_invalid_indices_N
377+
func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
378+
// expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
379+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
380+
return %0 : tensor<13x26x3xf32>
381+
}
382+
383+
// -----
384+
// CHECK-LABEL: test_gather_invalid_out_N
385+
func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
386+
// expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
387+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
388+
return %0 : tensor<12x26x3xf32>
389+
}
390+
391+
// -----
392+
// CHECK-LABEL: test_gather_invalid_out_W
393+
func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
394+
// expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
395+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
396+
return %0 : tensor<13x28x3xf32>
397+
}
398+
399+
// -----
400+
// CHECK-LABEL: test_gather_invalid_out_C
401+
func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
402+
// expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
403+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
404+
return %0 : tensor<13x26x8xf32>
405+
}

0 commit comments

Comments
 (0)