@@ -535,6 +535,18 @@ inline bool tensor_has_rank_greater_or_equal_to(
535
535
return true ;
536
536
}
537
537
538
+ inline bool tensor_has_rank_smaller_or_equal_to (
539
+ exec_aten::Tensor t,
540
+ size_t rank) {
541
+ ET_LOG_MSG_AND_RETURN_IF_FALSE (
542
+ t.dim () <= rank,
543
+ " Expected tensor.dim() to be <= %zu, but got %zu" ,
544
+ static_cast <size_t >(rank),
545
+ static_cast <size_t >(t.dim ()));
546
+
547
+ return true ;
548
+ }
549
+
538
550
inline bool tensor_has_dim (exec_aten::Tensor t, int64_t d) {
539
551
if (t.dim () == 0 ) {
540
552
ET_LOG_MSG_AND_RETURN_IF_FALSE (
@@ -551,6 +563,25 @@ inline bool tensor_has_dim(exec_aten::Tensor t, int64_t d) {
551
563
return true ;
552
564
}
553
565
566
+ inline bool tensor_dim_has_index (exec_aten::Tensor t, int64_t d, int64_t ix) {
567
+ // Indexing ops don't support zero-dim tensors
568
+ ET_CHECK (t.dim () != 0 );
569
+ if (d < 0 ) {
570
+ d += t.dim ();
571
+ }
572
+ // Dimension must have been already checked by tensor_has_dim
573
+ ET_CHECK (d >= 0 && d < t.dim ());
574
+
575
+ ET_LOG_MSG_AND_RETURN_IF_FALSE (
576
+ ix >= -t.size (d) && ix < t.size (d),
577
+ " index %" PRId64 " out of range [-%zu,%zu) at dimension %" PRId64 " )" ,
578
+ ix,
579
+ static_cast <size_t >(t.size (d)),
580
+ static_cast <size_t >(t.size (d)),
581
+ d);
582
+ return true ;
583
+ }
584
+
554
585
inline bool tensors_have_same_size_at_dims (
555
586
exec_aten::Tensor a,
556
587
size_t dim_a,
0 commit comments