@@ -408,9 +408,8 @@ int simplify_iteration_stride(const int nd,
408
408
409
409
The new shape and new strides, as well as the offset
410
410
`(new_shape, new_strides1, disp1, new_stride2, disp2)` are such that
411
- iterating over them will traverse the same pairs of elements, possibly in
412
- different order.
413
-
411
+ iterating over them will traverse the same set of pairs of elements,
412
+ possibly in a different order.
414
413
*/
415
414
template <class ShapeTy , class StridesTy >
416
415
int simplify_iteration_two_strides (const int nd,
@@ -447,7 +446,7 @@ int simplify_iteration_two_strides(const int nd,
447
446
auto str1_p = strides1[p];
448
447
auto str2_p = strides2[p];
449
448
shape_w.push_back (sh_p);
450
- if (str1_p < 0 && str2_p < 0 ) {
449
+ if (str1_p <= 0 && str2_p <= 0 && std::min (str1_p, str2_p) < 0 ) {
451
450
disp1 += str1_p * (sh_p - 1 );
452
451
str1_p = -str1_p;
453
452
disp2 += str2_p * (sh_p - 1 );
@@ -468,7 +467,7 @@ int simplify_iteration_two_strides(const int nd,
468
467
StridesTy jump1 = strides1_w[i] - (shape_w[i + 1 ] - 1 ) * str1;
469
468
StridesTy jump2 = strides2_w[i] - (shape_w[i + 1 ] - 1 ) * str2;
470
469
471
- if (jump1 == str1 and jump2 == str2) {
470
+ if (jump1 == str1 && jump2 == str2) {
472
471
changed = true ;
473
472
shape_w[i] *= shape_w[i + 1 ];
474
473
for (int j = i; j < nd_; ++j) {
@@ -540,3 +539,148 @@ contract_iter2(vecT shape, vecT strides1, vecT strides2)
540
539
out_strides2.resize (nd);
541
540
return std::make_tuple (out_shape, out_strides1, disp1, out_strides2, disp2);
542
541
}
542
+
543
+ /*
544
+ For purposes of iterating over pairs of elements of three arrays
545
+ with `shape` and strides `strides1`, `strides2`, `strides3` given as
546
+ pointers `simplify_iteration_three_strides(nd, shape_ptr, strides1_ptr,
547
+ strides2_ptr, strides3_ptr, disp1, disp2, disp3)`
548
+ may modify memory and returns new length of these arrays.
549
+
550
+ The new shape and new strides, as well as the offset
551
+ `(new_shape, new_strides1, disp1, new_stride2, disp2, new_stride3, disp3)`
552
+ are such that iterating over them will traverse the same set of tuples of
553
+ elements, possibly in a different order.
554
+ */
555
+ template <class ShapeTy , class StridesTy >
556
+ int simplify_iteration_three_strides (const int nd,
557
+ ShapeTy *shape,
558
+ StridesTy *strides1,
559
+ StridesTy *strides2,
560
+ StridesTy *strides3,
561
+ StridesTy &disp1,
562
+ StridesTy &disp2,
563
+ StridesTy &disp3)
564
+ {
565
+ disp1 = std::ptrdiff_t (0 );
566
+ disp2 = std::ptrdiff_t (0 );
567
+ if (nd < 2 )
568
+ return nd;
569
+
570
+ std::vector<int > pos (nd);
571
+ std::iota (pos.begin (), pos.end (), 0 );
572
+
573
+ std::stable_sort (
574
+ pos.begin (), pos.end (), [&strides1, &shape](int i1, int i2) {
575
+ auto abs_str1 = (strides1[i1] < 0 ) ? -strides1[i1] : strides1[i1];
576
+ auto abs_str2 = (strides1[i2] < 0 ) ? -strides1[i2] : strides1[i2];
577
+ return (abs_str1 > abs_str2) ||
578
+ (abs_str1 == abs_str2 && shape[i1] > shape[i2]);
579
+ });
580
+
581
+ std::vector<ShapeTy> shape_w;
582
+ std::vector<StridesTy> strides1_w;
583
+ std::vector<StridesTy> strides2_w;
584
+ std::vector<StridesTy> strides3_w;
585
+
586
+ bool contractable = true ;
587
+ for (int i = 0 ; i < nd; ++i) {
588
+ auto p = pos[i];
589
+ auto sh_p = shape[p];
590
+ auto str1_p = strides1[p];
591
+ auto str2_p = strides2[p];
592
+ auto str3_p = strides3[p];
593
+ shape_w.push_back (sh_p);
594
+ if (str1_p <= 0 && str2_p <= 0 && str3_p <= 0 &&
595
+ std::min (std::min (str1_p, str2_p), str3_p) < 0 )
596
+ {
597
+ disp1 += str1_p * (sh_p - 1 );
598
+ str1_p = -str1_p;
599
+ disp2 += str2_p * (sh_p - 1 );
600
+ str2_p = -str2_p;
601
+ disp3 += str3_p * (sh_p - 1 );
602
+ str3_p = -str3_p;
603
+ }
604
+ if (str1_p < 0 || str2_p < 0 || str3_p < 0 ) {
605
+ contractable = false ;
606
+ }
607
+ strides1_w.push_back (str1_p);
608
+ strides2_w.push_back (str2_p);
609
+ strides3_w.push_back (str3_p);
610
+ }
611
+ int nd_ = nd;
612
+ while (contractable) {
613
+ bool changed = false ;
614
+ for (int i = 0 ; i + 1 < nd_; ++i) {
615
+ StridesTy str1 = strides1_w[i + 1 ];
616
+ StridesTy str2 = strides2_w[i + 1 ];
617
+ StridesTy str3 = strides3_w[i + 1 ];
618
+ StridesTy jump1 = strides1_w[i] - (shape_w[i + 1 ] - 1 ) * str1;
619
+ StridesTy jump2 = strides2_w[i] - (shape_w[i + 1 ] - 1 ) * str2;
620
+ StridesTy jump3 = strides3_w[i] - (shape_w[i + 1 ] - 1 ) * str3;
621
+
622
+ if (jump1 == str1 && jump2 == str2 && jump3 == str3) {
623
+ changed = true ;
624
+ shape_w[i] *= shape_w[i + 1 ];
625
+ for (int j = i; j < nd_; ++j) {
626
+ strides1_w[j] = strides1_w[j + 1 ];
627
+ }
628
+ for (int j = i; j < nd_; ++j) {
629
+ strides2_w[j] = strides2_w[j + 1 ];
630
+ }
631
+ for (int j = i; j < nd_; ++j) {
632
+ strides3_w[j] = strides3_w[j + 1 ];
633
+ }
634
+ for (int j = i + 1 ; j + 1 < nd_; ++j) {
635
+ shape_w[j] = shape_w[j + 1 ];
636
+ }
637
+ --nd_;
638
+ break ;
639
+ }
640
+ }
641
+ if (!changed)
642
+ break ;
643
+ }
644
+ for (int i = 0 ; i < nd_; ++i) {
645
+ shape[i] = shape_w[i];
646
+ }
647
+ for (int i = 0 ; i < nd_; ++i) {
648
+ strides1[i] = strides1_w[i];
649
+ }
650
+ for (int i = 0 ; i < nd_; ++i) {
651
+ strides2[i] = strides2_w[i];
652
+ }
653
+ for (int i = 0 ; i < nd_; ++i) {
654
+ strides3[i] = strides3_w[i];
655
+ }
656
+
657
+ return nd_;
658
+ }
659
+
660
+ template <typename T, class Error , typename vecT = std::vector<T>>
661
+ std::tuple<vecT, vecT, T, vecT, T, vecT, T>
662
+ contract_iter3 (vecT shape, vecT strides1, vecT strides2, vecT strides3)
663
+ {
664
+ const size_t dim = shape.size ();
665
+ if (dim != strides1.size () || dim != strides2.size () ||
666
+ dim != strides3.size ()) {
667
+ throw Error (" Shape and strides must be of equal size." );
668
+ }
669
+ vecT out_shape = shape;
670
+ vecT out_strides1 = strides1;
671
+ vecT out_strides2 = strides2;
672
+ vecT out_strides3 = strides3;
673
+ T disp1 (0 );
674
+ T disp2 (0 );
675
+ T disp3 (0 );
676
+
677
+ int nd = simplify_iteration_three_strides (
678
+ dim, out_shape.data (), out_strides1.data (), out_strides2.data (),
679
+ out_strides3.data (), disp1, disp2, disp3);
680
+ out_shape.resize (nd);
681
+ out_strides1.resize (nd);
682
+ out_strides2.resize (nd);
683
+ out_strides3.resize (nd);
684
+ return std::make_tuple (out_shape, out_strides1, disp1, out_strides2, disp2,
685
+ out_strides3, disp3);
686
+ }
0 commit comments