@@ -455,41 +455,59 @@ func.func @gather_coordinate_rank_overflow(
455
455
456
456
// -----
457
457
458
+ func.func @gather_coordinate_rank_mismatch0 (
459
+ %source: tensor <4 x5 x6 xf32 >, %indices: tensor <index >) {
460
+ // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
461
+ %out = tensor.gather %source [%indices ] gather_dims ([0 , 1 , 2 ]):
462
+ (tensor <4 x5 x6 xf32 >, tensor <index >) -> tensor <1 x2 xf32 >
463
+ }
464
+
465
+ // -----
466
+
467
+ func.func @gather_coordinate_rank_mismatch1 (
468
+ %source: tensor <4 x5 x6 xf32 >, %indices: tensor <1 x2 x2 xindex >) {
469
+ // expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
470
+ %out = tensor.gather %source [%indices ] gather_dims ([0 , 1 , 2 ]):
471
+ (tensor <4 x5 x6 xf32 >, tensor <1 x2 x2 xindex >) -> tensor <1 x2 xf32 >
472
+ }
473
+
474
+ // -----
475
+
458
476
func.func @gather_coordinate_negative (
459
- %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
477
+ %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 1 x index >) {
460
478
// expected-error@+1 {{gather_dims value must be non-negative}}
461
479
%out = tensor.gather %source [%indices ] gather_dims ([-1 ]):
462
- (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x 2 x 1 x 1 x 1 x f32 >
480
+ (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 1 x index >) -> tensor <1 x 2 x 1 x f32 >
463
481
return
464
482
}
465
483
466
484
// -----
467
485
468
486
func.func @gather_coordinate_overflow (
469
- %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
487
+ %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 1 x index >) {
470
488
// expected-error@+1 {{gather_dims value must be smaller than source rank}}
471
489
%out = tensor.gather %source [%indices ] gather_dims ([42 ]):
472
- (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x 2 x 1 x 1 x 1 x f32 >
490
+ (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 1 x index >) -> tensor <1 x 2 x 1 x f32 >
473
491
return
474
492
}
475
493
476
494
// -----
477
495
478
- func.func @gather_coordinate_overflow (
479
- %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
496
+ func.func @gather_coordinate_increase (
497
+ %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 2 x index >) {
480
498
// expected-error@+1 {{gather_dims values must be strictly increasing}}
481
499
%out = tensor.gather %source [%indices ] gather_dims ([1 , 0 ]):
482
- (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x 2 x 1 x 1 x 1 x f32 >
500
+ (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 2 x index >) -> tensor <1 x 2 x 1 x 1 x f32 >
483
501
return
484
502
}
485
503
486
504
// -----
487
505
488
506
func.func @gather_wrong_result_type (
489
- %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
507
+ %source : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 2 x index >) {
490
508
// expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
491
509
%out = tensor.gather %source [%indices ] gather_dims ([0 , 2 ]):
492
- (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x2 x1 xf32 >
510
+ (tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 2 x index >) -> tensor <1 x2 x1 xf32 >
493
511
return
494
512
}
495
513
@@ -517,56 +535,78 @@ func.func @scatter_coordinate_rank_overflow(
517
535
518
536
// -----
519
537
538
+ func.func @scatter_coordinate_rank_mismatch0 (
539
+ %source : tensor <f32 >,
540
+ %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <index >) {
541
+ // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
542
+ %out = tensor.scatter %source into %dest [%indices ] scatter_dims ([0 , 1 , 2 ]) unique :
543
+ (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <index >) -> tensor <1 x2 xf32 >
544
+ return
545
+ }
546
+
547
+ // -----
548
+
549
+ func.func @scatter_coordinate_rank_mismatch1 (
550
+ %source : tensor <f32 >,
551
+ %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x2 x2 xindex >) {
552
+ // expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
553
+ %out = tensor.scatter %source into %dest [%indices ] scatter_dims ([0 , 1 , 2 ]) unique :
554
+ (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x2 x2 xindex >) -> tensor <1 x2 xf32 >
555
+ return
556
+ }
557
+
558
+ // -----
559
+
520
560
func.func @scatter_coordinate_negative (
521
561
%source : tensor <f32 >,
522
- %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
562
+ %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 1 x index >) {
523
563
// expected-error@+1 {{scatter_dims value must be non-negative}}
524
564
%out = tensor.scatter %source into %dest [%indices ] scatter_dims ([-1 ]) unique :
525
- (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x 2 x 1 x 1 x 1 x f32 >
565
+ (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 1 x index >) -> tensor <1 x 2 x 1 x f32 >
526
566
return
527
567
}
528
568
529
569
// -----
530
570
531
571
func.func @scatter_coordinate_overflow (
532
572
%source : tensor <f32 >,
533
- %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
573
+ %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 1 x index >) {
534
574
// expected-error@+1 {{scatter_dims value must be smaller than dest rank}}
535
575
%out = tensor.scatter %source into %dest [%indices ] scatter_dims ([42 ]) unique :
536
- (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x 2 x 1 x 1 x 1 x f32 >
576
+ (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 1 x index >) -> tensor <1 x 2 x 1 x f32 >
537
577
return
538
578
}
539
579
540
580
// -----
541
581
542
- func.func @scatter_coordinate_overflow (
582
+ func.func @scatter_coordinate_increase (
543
583
%source : tensor <f32 >,
544
- %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
584
+ %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 2 x index >) {
545
585
// expected-error@+1 {{scatter_dims values must be strictly increasing}}
546
586
%out = tensor.scatter %source into %dest [%indices ] scatter_dims ([1 , 0 ]) unique :
547
- (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x 2 x 1 x 1 x 1 x f32 >
587
+ (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 2 x index >) -> tensor <1 x 2 x 1 x 1 x f32 >
548
588
return
549
589
}
550
590
551
591
// -----
552
592
553
593
func.func @scatter_missing_unique (
554
594
%source : tensor <f32 >,
555
- %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
595
+ %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 2 x index >) {
556
596
// expected-error@+1 {{requires 'unique' attribute to be set}}
557
597
%out = tensor.scatter %source into %dest [%indices ] scatter_dims ([0 , 2 ]):
558
- (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x2 x1 xf32 >
598
+ (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 2 x index >) -> tensor <1 x2 x1 xf32 >
559
599
return
560
600
}
561
601
562
602
// -----
563
603
564
604
func.func @scatter_wrong_result_type (
565
605
%source : tensor <f32 >,
566
- %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 3 x index >) {
606
+ %dest : tensor <4 x5 x6 xf32 >, %indices: tensor <1 x 2 x 2 x index >) {
567
607
// expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
568
608
%out = tensor.scatter %source into %dest [%indices ] scatter_dims ([0 , 2 ]) unique :
569
- (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 3 x index >) -> tensor <1 x2 x1 xf32 >
609
+ (tensor <f32 >, tensor <4 x5 x6 xf32 >, tensor <1 x 2 x 2 x index >) -> tensor <1 x2 x1 xf32 >
570
610
return
571
611
}
572
612
0 commit comments