Skip to content

Commit cb6797e

Browse files
Closes gh-1250 by fixing stride simplification logic
When simplifying the iteration space, consider the strides of the output array (the last stride argument) first, since output arrays should never have zero strides.
1 parent 616c21e commit cb6797e

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

dpctl/tensor/libtensor/include/utils/strided_iters.hpp

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -444,10 +444,10 @@ int simplify_iteration_two_strides(const int nd,
444444
(strides2[i1] < 0) ? -strides2[i1] : strides2[i1];
445445
auto abs_str2_i2 =
446446
(strides2[i2] < 0) ? -strides2[i2] : strides2[i2];
447-
return (abs_str1_i1 > abs_str1_i2) ||
448-
(abs_str1_i1 == abs_str1_i2 &&
449-
(abs_str2_i1 > abs_str2_i2 ||
450-
(abs_str2_i1 == abs_str2_i2 && shape[i1] > shape[i2])));
447+
return (abs_str2_i1 > abs_str2_i2) ||
448+
(abs_str2_i1 == abs_str2_i2 &&
449+
(abs_str1_i1 > abs_str1_i2 ||
450+
(abs_str1_i1 == abs_str1_i2 && shape[i1] > shape[i2])));
451451
});
452452

453453
std::vector<ShapeTy> shape_w;
@@ -600,12 +600,12 @@ int simplify_iteration_three_strides(const int nd,
600600
(strides3[i1] < 0) ? -strides3[i1] : strides3[i1];
601601
auto abs_str3_i2 =
602602
(strides3[i2] < 0) ? -strides3[i2] : strides3[i2];
603-
return (abs_str1_i1 > abs_str1_i2) ||
604-
((abs_str1_i1 == abs_str1_i2) &&
603+
return (abs_str3_i1 > abs_str3_i2) ||
604+
((abs_str3_i1 == abs_str3_i2) &&
605605
((abs_str2_i1 > abs_str2_i2) ||
606606
((abs_str2_i1 == abs_str2_i2) &&
607-
((abs_str3_i1 > abs_str3_i2) ||
608-
((abs_str3_i1 == abs_str3_i2) &&
607+
((abs_str1_i1 > abs_str1_i2) ||
608+
((abs_str1_i1 == abs_str1_i2) &&
609609
(shape[i1] > shape[i2]))))));
610610
});
611611

@@ -768,14 +768,14 @@ int simplify_iteration_four_strides(const int nd,
768768
(strides4[i1] < 0) ? -strides4[i1] : strides4[i1];
769769
auto abs_str4_i2 =
770770
(strides4[i2] < 0) ? -strides4[i2] : strides4[i2];
771-
return (abs_str1_i1 > abs_str1_i2) ||
772-
((abs_str1_i1 == abs_str1_i2) &&
773-
((abs_str2_i1 > abs_str2_i2) ||
774-
((abs_str2_i1 == abs_str2_i2) &&
775-
((abs_str3_i1 > abs_str3_i2) ||
776-
((abs_str3_i1 == abs_str3_i2) &&
777-
((abs_str4_i1 > abs_str4_i2) ||
778-
((abs_str4_i1 == abs_str4_i2) &&
771+
return (abs_str4_i1 > abs_str4_i2) ||
772+
((abs_str4_i1 == abs_str4_i2) &&
773+
((abs_str3_i1 > abs_str3_i2) ||
774+
((abs_str3_i1 == abs_str3_i2) &&
775+
((abs_str2_i1 > abs_str2_i2) ||
776+
((abs_str2_i1 == abs_str2_i2) &&
777+
((abs_str1_i1 > abs_str1_i2) ||
778+
((abs_str1_i1 == abs_str1_i2) &&
779779
(shape[i1] > shape[i2]))))))));
780780
});
781781

0 commit comments

Comments
 (0)