File tree Expand file tree Collapse file tree 2 files changed +19
-5
lines changed
lib/Dialect/Polynomial/IR Expand file tree Collapse file tree 2 files changed +19
-5
lines changed Original file line number Diff line number Diff line change @@ -114,16 +114,20 @@ LogicalResult MulScalarOp::verify() {
114
114
// / Test if a value is a primitive nth root of unity modulo cmod.
115
115
bool isPrimitiveNthRootOfUnity (const APInt &root, const APInt &n,
116
116
const APInt &cmod) {
117
- // Root bitwidth may be 1 less then cmod.
118
- APInt r = APInt (root).zext (cmod.getBitWidth ());
119
- assert (r.ule (cmod) && " root must be less than cmod" );
120
- unsigned upperBound = n.getZExtValue ();
117
+ // The first or subsequent multiplications, may overflow the input bit width,
118
+ // so scale them up to ensure they do not overflow.
119
+ unsigned requiredBitWidth =
120
+ std::max (root.getActiveBits () * 2 , cmod.getActiveBits () * 2 );
121
+ APInt r = APInt (root).zextOrTrunc (requiredBitWidth);
122
+ APInt cmodExt = APInt (cmod).zextOrTrunc (requiredBitWidth);
123
+ assert (r.ule (cmodExt) && " root must be less than cmod" );
124
+ uint64_t upperBound = n.getZExtValue ();
121
125
122
126
APInt a = r;
123
127
for (size_t k = 1 ; k < upperBound; k++) {
124
128
if (a.isOne ())
125
129
return false ;
126
- a = (a * r).urem (cmod );
130
+ a = (a * r).urem (cmodExt );
127
131
}
128
132
return a.isOne ();
129
133
}
Original file line number Diff line number Diff line change 18
18
#ntt_ring = #polynomial.ring <coefficientType =i32 , coefficientModulus =256 , polynomialModulus =#ntt_poly >
19
19
!ntt_poly_ty = !polynomial.polynomial <ring =#ntt_ring >
20
20
21
+ #ntt_poly_2 = #polynomial.int_polynomial <1 + x **65536 >
22
+ #ntt_ring_2 = #polynomial.ring <coefficientType = i32 , coefficientModulus = 786433 : i32 , polynomialModulus =#ntt_poly_2 >
23
+ #ntt_ring_2_root = #polynomial.primitive_root <value =283965 :i32 , degree =131072 :i32 >
24
+ !ntt_poly_ty_2 = !polynomial.polynomial <ring =#ntt_ring_2 >
25
+
21
26
module {
22
27
func.func @test_multiply () -> !polynomial.polynomial <ring =#ring1 > {
23
28
%c0 = arith.constant 0 : index
@@ -95,6 +100,11 @@ module {
95
100
return
96
101
}
97
102
103
+ func.func @test_ntt_with_overflowing_root (%0 : !ntt_poly_ty_2 ) {
104
+ %1 = polynomial.ntt %0 {root =#ntt_ring_2_root } : !ntt_poly_ty_2 -> tensor <65536 xi32 , #ntt_ring_2 >
105
+ return
106
+ }
107
+
98
108
func.func @test_intt (%0 : tensor <8 xi32 , #ntt_ring >) {
99
109
%1 = polynomial.intt %0 {root =#polynomial.primitive_root <value =31 :i32 , degree =8 :index >} : tensor <8 xi32 , #ntt_ring > -> !ntt_poly_ty
100
110
return
You can’t perform that action at this time.
0 commit comments