Skip to content

Commit c140783

Browse files
authored
[tosa] Add verifier checks for Scatter (#142661)
This adds verifier checks for the scatter op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <[email protected]>
1 parent 2718a47 commit c140783

File tree

9 files changed

+168
-29
lines changed

9 files changed

+168
-29
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2673,6 +2673,73 @@ LogicalResult tosa::ScatterOp::verify() {
26732673
.failed()) {
26742674
return failure();
26752675
}
2676+
2677+
const ShapeAdaptor valuesInShape(getValuesIn().getType());
2678+
const ShapeAdaptor indicesShape(getIndices().getType());
2679+
const ShapeAdaptor inputShape(getInput().getType());
2680+
const ShapeAdaptor outputShape(getValuesOut().getType());
2681+
2682+
int64_t N = ShapedType::kDynamic;
2683+
int64_t K = ShapedType::kDynamic;
2684+
int64_t W = ShapedType::kDynamic;
2685+
int64_t C = ShapedType::kDynamic;
2686+
if (valuesInShape.hasRank()) {
2687+
N = valuesInShape.getDimSize(0);
2688+
K = valuesInShape.getDimSize(1);
2689+
C = valuesInShape.getDimSize(2);
2690+
}
2691+
if (indicesShape.hasRank()) {
2692+
const int64_t indicesN = indicesShape.getDimSize(0);
2693+
W = indicesShape.getDimSize(1);
2694+
if (N == ShapedType::kDynamic)
2695+
N = indicesN;
2696+
else if (indicesN != ShapedType::kDynamic && N != indicesN)
2697+
return emitOpError() << "requires indices dimension 0 to have size " << N
2698+
<< ", got " << indicesN;
2699+
}
2700+
if (inputShape.hasRank()) {
2701+
const int64_t inputN = inputShape.getDimSize(0);
2702+
const int64_t inputW = inputShape.getDimSize(1);
2703+
const int64_t inputC = inputShape.getDimSize(2);
2704+
if (N == ShapedType::kDynamic)
2705+
N = inputN;
2706+
else if (inputN != ShapedType::kDynamic && N != inputN)
2707+
return emitOpError() << "requires input dimension 0 to have size " << N
2708+
<< ", got " << inputN;
2709+
if (W == ShapedType::kDynamic)
2710+
W = inputW;
2711+
else if (inputW != ShapedType::kDynamic && W != inputW)
2712+
return emitOpError() << "requires input dimension 1 to have size " << W
2713+
<< ", got " << inputW;
2714+
2715+
if (C == ShapedType::kDynamic)
2716+
C = inputC;
2717+
else if (inputC != ShapedType::kDynamic && C != inputC)
2718+
return emitOpError() << "requires input dimension 2 to have size " << C
2719+
<< ", got " << inputC;
2720+
}
2721+
if (outputShape.hasRank()) {
2722+
const int64_t outputN = outputShape.getDimSize(0);
2723+
const int64_t outputK = outputShape.getDimSize(1);
2724+
const int64_t outputC = outputShape.getDimSize(2);
2725+
if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2726+
N != outputN)
2727+
return emitOpError() << "requires values_out dimension 0 to have size "
2728+
<< N << ", got " << outputN;
2729+
if (K == ShapedType::kDynamic)
2730+
K = outputK;
2731+
else if (outputK != ShapedType::kDynamic && K != outputK)
2732+
return emitOpError() << "requires values_out dimension 1 to have size "
2733+
<< K << ", got " << outputK;
2734+
if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2735+
C != outputC)
2736+
return emitOpError() << "requires values_out dimension 2 to have size "
2737+
<< C << ", got " << outputC;
2738+
}
2739+
if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2740+
return emitOpError() << "requires dimensions K >= W, got K=" << K
2741+
<< " and W=" << W;
2742+
26762743
return success();
26772744
}
26782745

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
583583

584584
// -----
585585
// CHECK-LABEL: scatter
586-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
586+
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
587587
// CHECK: profiles: [ [pro_int, pro_fp] ]
588588
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
589-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
590-
return %0 : tensor<13x21x3xf32>
589+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
590+
return %0 : tensor<13x28x3xf32>
591591
}
592592

593593
// -----

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
243243
}
244244

245245
// -----
246-
func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
246+
func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
247247
// expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
248-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
249-
return %0 : tensor<13x21x3xbf16>
248+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
249+
return %0 : tensor<13x26x3xbf16>
250250
}
251251

252252
// -----

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
10801080

10811081
// -----
10821082

1083-
func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
1083+
func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
10841084
// expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1085-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
1086-
return %0 : tensor<13x210000000x3xf32>
1085+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
1086+
return %0 : tensor<13x260000000x3xf32>
10871087
}
10881088

10891089
// -----

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
730730

731731
// -----
732732
// CHECK-LABEL: scatter
733-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
734-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
735-
return %0 : tensor<13x21x3xf32>
733+
func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
734+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
735+
return %0 : tensor<13x52x3xf32>
736736
}
737737

738738
// -----
@@ -744,8 +744,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
744744

745745
// -----
746746
// CHECK-LABEL: scatter_unranked_indices
747-
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
748-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
747+
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
748+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
749749
return %0 : tensor<13x21x3xf32>
750750
}
751751

@@ -1026,9 +1026,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
10261026

10271027
// -----
10281028
// CHECK-LABEL: scatter_f8E5M2
1029-
func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
1030-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
1031-
return %0 : tensor<13x21x3xf8E5M2>
1029+
func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
1030+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
1031+
return %0 : tensor<13x52x3xf8E5M2>
10321032
}
10331033

10341034
// -----
@@ -1171,7 +1171,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
11711171

11721172
// -----
11731173
// CHECK-LABEL: scatter_f8E4M3FN
1174-
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1175-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1176-
return %0 : tensor<13x21x3xf8E4M3FN>
1174+
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
1175+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
1176+
return %0 : tensor<13x29x3xf8E4M3FN>
11771177
}

mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
310310
}
311311

312312
// -----
313-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
313+
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
314314
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
315-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
316-
return %0 : tensor<13x21x3xf32>
315+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
316+
return %0 : tensor<13x28x3xf32>
317317
}
318318

319319
// -----

mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
242242
}
243243

244244
// -----
245-
func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
245+
func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
246246
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
247-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
248-
return %0 : tensor<13x21x3xi32>
247+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
248+
return %0 : tensor<13x27x3xi32>
249249
}
250250

251251
// -----

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
656656
// -----
657657

658658
// CHECK-LABEL: @scatter_static
659-
func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
660-
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
661-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
659+
func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
660+
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
661+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
662662
return
663663
}
664664

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
864864
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
865865
return
866866
}
867+
868+
// -----
869+
870+
// CHECK-LABEL: @scatter_invalid_indices_N
871+
func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
872+
// expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
873+
%1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
874+
return
875+
}
876+
877+
// -----
878+
879+
// CHECK-LABEL: @scatter_invalid_input_N
880+
func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
881+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
882+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
883+
return
884+
}
885+
886+
// -----
887+
888+
// CHECK-LABEL: @scatter_invalid_out_N
889+
func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
890+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
891+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
892+
return
893+
}
894+
895+
// -----
896+
897+
// CHECK-LABEL: @scatter_invalid_out_K
898+
func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
899+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
900+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
901+
return
902+
}
903+
904+
// -----
905+
906+
// CHECK-LABEL: @scatter_invalid_input_W
907+
func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
908+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
909+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
910+
return
911+
}
912+
913+
// -----
914+
915+
// CHECK-LABEL: @scatter_invalid_input_C
916+
func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
917+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
918+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
919+
return
920+
}
921+
922+
// -----
923+
924+
// CHECK-LABEL: @scatter_invalid_out_C
925+
func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
926+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
927+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
928+
return
929+
}
930+
931+
// -----
932+
933+
// CHECK-LABEL: @scatter_invalid_K_W
934+
func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
935+
// expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
936+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
937+
return
938+
}

0 commit comments

Comments
 (0)