File tree Expand file tree Collapse file tree 2 files changed +20
-4
lines changed
lib/Dialect/Polynomial/IR Expand file tree Collapse file tree 2 files changed +20
-4
lines changed Original file line number Diff line number Diff line change 17
17
#include " mlir/IR/PatternMatch.h"
18
18
#include " mlir/Support/LogicalResult.h"
19
19
#include " llvm/ADT/APInt.h"
20
+ #include < iostream>
20
21
21
22
using namespace mlir ;
22
23
using namespace mlir ::polynomial;
@@ -107,16 +108,21 @@ LogicalResult MulScalarOp::verify() {
107
108
// / Test if a value is a primitive nth root of unity modulo cmod.
108
109
bool isPrimitiveNthRootOfUnity (const APInt &root, const APInt &n,
109
110
const APInt &cmod) {
111
+ // The first or subsequent multiplications, may overflow the input bit width,
112
+ // so scale them up to ensure they do not overflow.
113
+ unsigned requiredBitWidth =
114
+ std::max (root.getActiveBits () * 2 , cmod.getActiveBits () * 2 );
110
115
// Root bitwidth may be 1 less then cmod.
111
- APInt r = APInt (root).zext (cmod.getBitWidth ());
112
- assert (r.ule (cmod) && " root must be less than cmod" );
113
- unsigned upperBound = n.getZExtValue ();
116
+ APInt r = APInt (root).zextOrTrunc (requiredBitWidth);
117
+ APInt cmodExt = APInt (cmod).zextOrTrunc (requiredBitWidth);
118
+ assert (r.ule (cmodExt) && " root must be less than cmod" );
119
+ uint64_t upperBound = n.getZExtValue ();
114
120
115
121
APInt a = r;
116
122
for (size_t k = 1 ; k < upperBound; k++) {
117
123
if (a.isOne ())
118
124
return false ;
119
- a = (a * r).urem (cmod );
125
+ a = (a * r).urem (cmodExt );
120
126
}
121
127
return a.isOne ();
122
128
}
Original file line number Diff line number Diff line change 18
18
#ntt_ring = #polynomial.ring <coefficientType =i32 , coefficientModulus =256 , polynomialModulus =#ntt_poly >
19
19
!ntt_poly_ty = !polynomial.polynomial <ring =#ntt_ring >
20
20
21
+ #ntt_poly_2 = #polynomial.int_polynomial <1 + x **65536 >
22
+ #ntt_ring_2 = #polynomial.ring <coefficientType = i32 , coefficientModulus = 786433 : i32 , polynomialModulus =#ntt_poly_2 >
23
+ #ntt_ring_2_root = #polynomial.primitive_root <value =283965 :i32 , degree =131072 :i32 >
24
+ !ntt_poly_ty_2 = !polynomial.polynomial <ring =#ntt_ring_2 >
25
+
21
26
module {
22
27
func.func @test_multiply () -> !polynomial.polynomial <ring =#ring1 > {
23
28
%c0 = arith.constant 0 : index
@@ -95,6 +100,11 @@ module {
95
100
return
96
101
}
97
102
103
+ func.func @test_ntt_with_overflowing_root (%0 : !ntt_poly_ty_2 ) {
104
+ %1 = polynomial.ntt %0 {root =#ntt_ring_2_root } : !ntt_poly_ty_2 -> tensor <65536 xi32 , #ntt_ring_2 >
105
+ return
106
+ }
107
+
98
108
func.func @test_intt (%0 : tensor <8 xi32 , #ntt_ring >) {
99
109
%1 = polynomial.intt %0 {root =#polynomial.primitive_root <value =31 :i32 , degree =8 :index >} : tensor <8 xi32 , #ntt_ring > -> !ntt_poly_ty
100
110
return
You can’t perform that action at this time.
0 commit comments