Skip to content

Commit ede40da

Browse files
authored
[mlir][tensor] Add check for indices of tensor.gather (#106894)
This patch add a check for indices of `tensor.gather` and `tensor.scatter`. For that the length of gather_dims/scatter_dims should match the size of last dimension of the indices. Fix #94901.
1 parent 5acd9d1 commit ede40da

File tree

2 files changed

+69
-23
lines changed

2 files changed

+69
-23
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,7 +1288,8 @@ RankedTensorType GatherOp::inferResultType(RankedTensorType sourceType,
12881288
}
12891289

12901290
static LogicalResult
1291-
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
1291+
verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims,
1292+
ArrayRef<int64_t> indices, int64_t rank,
12921293
StringRef gatherOrScatter, StringRef sourceOrDest) {
12931294
if (dims.empty())
12941295
return op->emitOpError(gatherOrScatter) << "_dims must be non-empty";
@@ -1297,6 +1298,9 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
12971298
if (numGatherDims > rank)
12981299
return op->emitOpError(gatherOrScatter)
12991300
<< "_dims overflow " << sourceOrDest << " rank";
1301+
if (indices.empty() || indices.back() != numGatherDims)
1302+
return op->emitOpError(gatherOrScatter)
1303+
<< "_dims length must match the size of last dimension of indices";
13001304
for (int64_t val : dims) {
13011305
if (val < 0)
13021306
return op->emitOpError(gatherOrScatter)
@@ -1316,7 +1320,8 @@ verifyGatherOrScatterDims(Operation *op, ArrayRef<int64_t> dims, int64_t rank,
13161320
LogicalResult GatherOp::verify() {
13171321
int64_t sourceRank = getSourceType().getRank();
13181322
ArrayRef<int64_t> gatherDims = getGatherDims();
1319-
if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims, sourceRank,
1323+
if (failed(verifyGatherOrScatterDims(getOperation(), gatherDims,
1324+
getIndicesType().getShape(), sourceRank,
13201325
"gather", "source")))
13211326
return failure();
13221327

@@ -3530,7 +3535,8 @@ void ScatterOp::getAsmResultNames(
35303535
LogicalResult ScatterOp::verify() {
35313536
int64_t destRank = getDestType().getRank();
35323537
ArrayRef<int64_t> scatterDims = getScatterDims();
3533-
if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims, destRank,
3538+
if (failed(verifyGatherOrScatterDims(getOperation(), scatterDims,
3539+
getIndicesType().getShape(), destRank,
35343540
"scatter", "dest")))
35353541
return failure();
35363542

mlir/test/Dialect/Tensor/invalid.mlir

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -455,41 +455,59 @@ func.func @gather_coordinate_rank_overflow(
455455

456456
// -----
457457

458+
func.func @gather_coordinate_rank_mismatch0(
459+
%source: tensor<4x5x6xf32>, %indices: tensor<index>) {
460+
// expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
461+
%out = tensor.gather %source[%indices] gather_dims([0, 1, 2]):
462+
(tensor<4x5x6xf32>, tensor<index>) -> tensor<1x2xf32>
463+
}
464+
465+
// -----
466+
467+
func.func @gather_coordinate_rank_mismatch1(
468+
%source: tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
469+
// expected-error@+1 {{gather_dims length must match the size of last dimension of indices}}
470+
%out = tensor.gather %source[%indices] gather_dims([0, 1, 2]):
471+
(tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32>
472+
}
473+
474+
// -----
475+
458476
func.func @gather_coordinate_negative(
459-
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
477+
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
460478
// expected-error@+1 {{gather_dims value must be non-negative}}
461479
%out = tensor.gather %source[%indices] gather_dims([-1]):
462-
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
480+
(tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
463481
return
464482
}
465483

466484
// -----
467485

468486
func.func @gather_coordinate_overflow(
469-
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
487+
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
470488
// expected-error@+1 {{gather_dims value must be smaller than source rank}}
471489
%out = tensor.gather %source[%indices] gather_dims([42]):
472-
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
490+
(tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
473491
return
474492
}
475493

476494
// -----
477495

478-
func.func @gather_coordinate_overflow(
479-
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
496+
func.func @gather_coordinate_increase(
497+
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
480498
// expected-error@+1 {{gather_dims values must be strictly increasing}}
481499
%out = tensor.gather %source[%indices] gather_dims([1, 0]):
482-
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
500+
(tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32>
483501
return
484502
}
485503

486504
// -----
487505

488506
func.func @gather_wrong_result_type(
489-
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
507+
%source : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
490508
// expected-error@+1 {{result type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<1x2x1xf32>')}}
491509
%out = tensor.gather %source[%indices] gather_dims([0, 2]):
492-
(tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
510+
(tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
493511
return
494512
}
495513

@@ -517,56 +535,78 @@ func.func @scatter_coordinate_rank_overflow(
517535

518536
// -----
519537

538+
func.func @scatter_coordinate_rank_mismatch0(
539+
%source : tensor<f32>,
540+
%dest : tensor<4x5x6xf32>, %indices: tensor<index>) {
541+
// expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
542+
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique:
543+
(tensor<f32>, tensor<4x5x6xf32>, tensor<index>) -> tensor<1x2xf32>
544+
return
545+
}
546+
547+
// -----
548+
549+
func.func @scatter_coordinate_rank_mismatch1(
550+
%source : tensor<f32>,
551+
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
552+
// expected-error@+1 {{scatter_dims length must match the size of last dimension of indices}}
553+
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 1, 2]) unique:
554+
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2xf32>
555+
return
556+
}
557+
558+
// -----
559+
520560
func.func @scatter_coordinate_negative(
521561
%source : tensor<f32>,
522-
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
562+
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
523563
// expected-error@+1 {{scatter_dims value must be non-negative}}
524564
%out = tensor.scatter %source into %dest[%indices] scatter_dims([-1]) unique:
525-
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
565+
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
526566
return
527567
}
528568

529569
// -----
530570

531571
func.func @scatter_coordinate_overflow(
532572
%source : tensor<f32>,
533-
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
573+
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x1xindex>) {
534574
// expected-error@+1 {{scatter_dims value must be smaller than dest rank}}
535575
%out = tensor.scatter %source into %dest[%indices] scatter_dims([42]) unique:
536-
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
576+
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x1xindex>) -> tensor<1x2x1xf32>
537577
return
538578
}
539579

540580
// -----
541581

542-
func.func @scatter_coordinate_overflow(
582+
func.func @scatter_coordinate_increase(
543583
%source : tensor<f32>,
544-
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
584+
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
545585
// expected-error@+1 {{scatter_dims values must be strictly increasing}}
546586
%out = tensor.scatter %source into %dest[%indices] scatter_dims([1, 0]) unique:
547-
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1x1x1xf32>
587+
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1x1xf32>
548588
return
549589
}
550590

551591
// -----
552592

553593
func.func @scatter_missing_unique(
554594
%source : tensor<f32>,
555-
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
595+
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
556596
// expected-error@+1 {{requires 'unique' attribute to be set}}
557597
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]):
558-
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
598+
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
559599
return
560600
}
561601

562602
// -----
563603

564604
func.func @scatter_wrong_result_type(
565605
%source : tensor<f32>,
566-
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x3xindex>) {
606+
%dest : tensor<4x5x6xf32>, %indices: tensor<1x2x2xindex>) {
567607
// expected-error@+1 {{source type mismatch: expected 'tensor<1x2x1x5x1xf32>' or its rank-reduced variant 'tensor<1x2x5xf32>' (got: 'tensor<f32>')}}
568608
%out = tensor.scatter %source into %dest[%indices] scatter_dims([0, 2]) unique:
569-
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x3xindex>) -> tensor<1x2x1xf32>
609+
(tensor<f32>, tensor<4x5x6xf32>, tensor<1x2x2xindex>) -> tensor<1x2x1xf32>
570610
return
571611
}
572612

0 commit comments

Comments
 (0)