Skip to content

Commit e1c0db0

Browse files
committed
Add verifier checks for Scatter
This adds verifier checks for the scatter op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <[email protected]> Change-Id: I59531fa63e2d1dbd2865e0ef9b08b76991915c9a
1 parent 009228a commit e1c0db0

File tree

9 files changed

+166
-27
lines changed

9 files changed

+166
-27
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2386,6 +2386,73 @@ LogicalResult tosa::ScatterOp::verify() {
23862386
.failed()) {
23872387
return failure();
23882388
}
2389+
2390+
const ShapeAdaptor valuesInShape(getValuesIn().getType());
2391+
const ShapeAdaptor indicesShape(getIndices().getType());
2392+
const ShapeAdaptor inputShape(getInput().getType());
2393+
const ShapeAdaptor outputShape(getValuesOut().getType());
2394+
2395+
int64_t N = ShapedType::kDynamic;
2396+
int64_t K = ShapedType::kDynamic;
2397+
int64_t W = ShapedType::kDynamic;
2398+
int64_t C = ShapedType::kDynamic;
2399+
if (valuesInShape.hasRank()) {
2400+
N = valuesInShape.getDimSize(0);
2401+
K = valuesInShape.getDimSize(1);
2402+
C = valuesInShape.getDimSize(2);
2403+
}
2404+
if (indicesShape.hasRank()) {
2405+
const int64_t indicesN = indicesShape.getDimSize(0);
2406+
W = indicesShape.getDimSize(1);
2407+
if (N == ShapedType::kDynamic)
2408+
N = indicesN;
2409+
else if (indicesN != ShapedType::kDynamic && N != indicesN)
2410+
return emitOpError() << "requires indices dimension 0 to have size " << N
2411+
<< ", got " << indicesN;
2412+
}
2413+
if (inputShape.hasRank()) {
2414+
const int64_t inputN = inputShape.getDimSize(0);
2415+
const int64_t inputW = inputShape.getDimSize(1);
2416+
const int64_t inputC = inputShape.getDimSize(2);
2417+
if (N == ShapedType::kDynamic)
2418+
N = inputN;
2419+
else if (inputN != ShapedType::kDynamic && N != inputN)
2420+
return emitOpError() << "requires input dimension 0 to have size " << N
2421+
<< ", got " << inputN;
2422+
if (W == ShapedType::kDynamic)
2423+
W = inputW;
2424+
else if (inputW != ShapedType::kDynamic && W != inputW)
2425+
return emitOpError() << "requires input dimension 1 to have size " << W
2426+
<< ", got " << inputW;
2427+
2428+
if (C == ShapedType::kDynamic)
2429+
C = inputC;
2430+
else if (inputC != ShapedType::kDynamic && C != inputC)
2431+
return emitOpError() << "requires input dimension 2 to have size " << C
2432+
<< ", got " << inputC;
2433+
}
2434+
if (outputShape.hasRank()) {
2435+
const int64_t outputN = outputShape.getDimSize(0);
2436+
const int64_t outputK = outputShape.getDimSize(1);
2437+
const int64_t outputC = outputShape.getDimSize(2);
2438+
if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2439+
N != outputN)
2440+
return emitOpError() << "requires values_out dimension 0 to have size "
2441+
<< N << ", got " << outputN;
2442+
if (K == ShapedType::kDynamic)
2443+
K = outputK;
2444+
else if (outputK != ShapedType::kDynamic && K != outputK)
2445+
return emitOpError() << "requires values_out dimension 1 to have size "
2446+
<< K << ", got " << outputK;
2447+
if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2448+
C != outputC)
2449+
return emitOpError() << "requires values_out dimension 2 to have size "
2450+
<< C << ", got " << outputC;
2451+
}
2452+
if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2453+
return emitOpError() << "requires dimensions K >= W, got K=" << K
2454+
<< " and W=" << W;
2455+
23892456
return success();
23902457
}
23912458

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
583583

584584
// -----
585585
// CHECK-LABEL: scatter
586-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
586+
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
587587
// CHECK: profiles: [ [pro_int, pro_fp] ]
588588
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
589-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
590-
return %0 : tensor<13x21x3xf32>
589+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
590+
return %0 : tensor<13x28x3xf32>
591591
}
592592

593593
// -----

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
243243
}
244244

245245
// -----
246-
func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
246+
func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
247247
// expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
248-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
249-
return %0 : tensor<13x21x3xbf16>
248+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
249+
return %0 : tensor<13x26x3xbf16>
250250
}
251251

252252
// -----

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,10 +1087,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
10871087

10881088
// -----
10891089

1090-
func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
1090+
func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
10911091
// expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1092-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
1093-
return %0 : tensor<13x210000000x3xf32>
1092+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
1093+
return %0 : tensor<13x260000000x3xf32>
10941094
}
10951095

10961096
// -----

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -671,9 +671,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
671671

672672
// -----
673673
// CHECK-LABEL: scatter
674-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
675-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
676-
return %0 : tensor<13x21x3xf32>
674+
func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
675+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
676+
return %0 : tensor<13x52x3xf32>
677677
}
678678

679679
// -----
@@ -951,9 +951,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
951951

952952
// -----
953953
// CHECK-LABEL: scatter_f8E5M2
954-
func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
955-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
956-
return %0 : tensor<13x21x3xf8E5M2>
954+
func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
955+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
956+
return %0 : tensor<13x52x3xf8E5M2>
957957
}
958958

959959
// -----
@@ -1103,7 +1103,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
11031103

11041104
// -----
11051105
// CHECK-LABEL: scatter_f8E4M3FN
1106-
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1107-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1108-
return %0 : tensor<13x21x3xf8E4M3FN>
1106+
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
1107+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
1108+
return %0 : tensor<13x29x3xf8E4M3FN>
11091109
}

mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
310310
}
311311

312312
// -----
313-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
313+
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
314314
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
315-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
316-
return %0 : tensor<13x21x3xf32>
315+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
316+
return %0 : tensor<13x28x3xf32>
317317
}
318318

319319
// -----

mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
242242
}
243243

244244
// -----
245-
func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
245+
func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
246246
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
247-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
248-
return %0 : tensor<13x21x3xi32>
247+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
248+
return %0 : tensor<13x27x3xi32>
249249
}
250250

251251
// -----

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
687687
// -----
688688

689689
// CHECK-LABEL: @scatter_static
690-
func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
691-
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
692-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
690+
func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
691+
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
692+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
693693
return
694694
}
695695

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,75 @@ func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1
238238
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
239239
return %0 : tensor<13x26x8xf32>
240240
}
241+
242+
// -----
243+
244+
// CHECK-LABEL: @scatter_invalid_indices_N
245+
func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
246+
// expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
247+
%1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
248+
return
249+
}
250+
251+
// -----
252+
253+
// CHECK-LABEL: @scatter_invalid_input_N
254+
func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
255+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
256+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
257+
return
258+
}
259+
260+
// -----
261+
262+
// CHECK-LABEL: @scatter_invalid_out_N
263+
func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
264+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
265+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
266+
return
267+
}
268+
269+
// -----
270+
271+
// CHECK-LABEL: @scatter_invalid_out_K
272+
func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
273+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
274+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
275+
return
276+
}
277+
278+
// -----
279+
280+
// CHECK-LABEL: @scatter_invalid_input_W
281+
func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
282+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
283+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
284+
return
285+
}
286+
287+
// -----
288+
289+
// CHECK-LABEL: @scatter_invalid_input_C
290+
func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
291+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
292+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
293+
return
294+
}
295+
296+
// -----
297+
298+
// CHECK-LABEL: @scatter_invalid_out_C
299+
func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
300+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
301+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
302+
return
303+
}
304+
305+
// -----
306+
307+
// CHECK-LABEL: @scatter_invalid_K_W
308+
func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
309+
// expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
310+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
311+
return
312+
}

0 commit comments

Comments
 (0)