File tree Expand file tree Collapse file tree 2 files changed +19
-4
lines changed
lib/Dialect/Polynomial/IR Expand file tree Collapse file tree 2 files changed +19
-4
lines changed Original file line number Diff line number Diff line change @@ -107,16 +107,21 @@ LogicalResult MulScalarOp::verify() {
107
107
// / Test if a value is a primitive nth root of unity modulo cmod.
108
108
bool isPrimitiveNthRootOfUnity (const APInt &root, const APInt &n,
109
109
const APInt &cmod) {
110
+ // The first or subsequent multiplications, may overflow the input bit width,
111
+ // so scale them up to ensure they do not overflow.
112
+ unsigned requiredBitWidth =
113
+ std::max (root.getActiveBits () * 2 , cmod.getActiveBits () * 2 );
110
114
// Root bitwidth may be 1 less then cmod.
111
- APInt r = APInt (root).zext (cmod.getBitWidth ());
112
- assert (r.ule (cmod) && " root must be less than cmod" );
113
- unsigned upperBound = n.getZExtValue ();
115
+ APInt r = APInt (root).zextOrTrunc (requiredBitWidth);
116
+ APInt cmodExt = APInt (cmod).zextOrTrunc (requiredBitWidth);
117
+ assert (r.ule (cmodExt) && " root must be less than cmod" );
118
+ uint64_t upperBound = n.getZExtValue ();
114
119
115
120
APInt a = r;
116
121
for (size_t k = 1 ; k < upperBound; k++) {
117
122
if (a.isOne ())
118
123
return false ;
119
- a = (a * r).urem (cmod );
124
+ a = (a * r).urem (cmodExt );
120
125
}
121
126
return a.isOne ();
122
127
}
Original file line number Diff line number Diff line change 18
18
#ntt_ring = #polynomial.ring <coefficientType =i32 , coefficientModulus =256 , polynomialModulus =#ntt_poly >
19
19
!ntt_poly_ty = !polynomial.polynomial <ring =#ntt_ring >
20
20
21
+ #ntt_poly_2 = #polynomial.int_polynomial <1 + x **65536 >
22
+ #ntt_ring_2 = #polynomial.ring <coefficientType = i32 , coefficientModulus = 786433 : i32 , polynomialModulus =#ntt_poly_2 >
23
+ #ntt_ring_2_root = #polynomial.primitive_root <value =283965 :i32 , degree =131072 :i32 >
24
+ !ntt_poly_ty_2 = !polynomial.polynomial <ring =#ntt_ring_2 >
25
+
21
26
module {
22
27
func.func @test_multiply () -> !polynomial.polynomial <ring =#ring1 > {
23
28
%c0 = arith.constant 0 : index
@@ -95,6 +100,11 @@ module {
95
100
return
96
101
}
97
102
103
+ func.func @test_ntt_with_overflowing_root (%0 : !ntt_poly_ty_2 ) {
104
+ %1 = polynomial.ntt %0 {root =#ntt_ring_2_root } : !ntt_poly_ty_2 -> tensor <65536 xi32 , #ntt_ring_2 >
105
+ return
106
+ }
107
+
98
108
func.func @test_intt (%0 : tensor <8 xi32 , #ntt_ring >) {
99
109
%1 = polynomial.intt %0 {root =#polynomial.primitive_root <value =31 :i32 , degree =8 :index >} : tensor <8 xi32 , #ntt_ring > -> !ntt_poly_ty
100
110
return
You can’t perform that action at this time.
0 commit comments