Skip to content

Commit 97eb257

Browse files
committed
[mlir][polynomial] ensure primitive root calculation doesn't overflow
1 parent 30f51bf commit 97eb257

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/IR/PatternMatch.h"
1818
#include "mlir/Support/LogicalResult.h"
1919
#include "llvm/ADT/APInt.h"
20+
#include <iostream>
2021

2122
using namespace mlir;
2223
using namespace mlir::polynomial;
@@ -107,16 +108,21 @@ LogicalResult MulScalarOp::verify() {
107108
/// Test if a value is a primitive nth root of unity modulo cmod.
108109
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
109110
const APInt &cmod) {
111+
// The first or subsequent multiplications, may overflow the input bit width,
112+
// so scale them up to ensure they do not overflow.
113+
unsigned requiredBitWidth =
114+
std::max(root.getActiveBits() * 2, cmod.getActiveBits() * 2);
110115
// Root bitwidth may be 1 less then cmod.
111-
APInt r = APInt(root).zext(cmod.getBitWidth());
112-
assert(r.ule(cmod) && "root must be less than cmod");
113-
unsigned upperBound = n.getZExtValue();
116+
APInt r = APInt(root).zextOrTrunc(requiredBitWidth);
117+
APInt cmodExt = APInt(cmod).zextOrTrunc(requiredBitWidth);
118+
assert(r.ule(cmodExt) && "root must be less than cmod");
119+
uint64_t upperBound = n.getZExtValue();
114120

115121
APInt a = r;
116122
for (size_t k = 1; k < upperBound; k++) {
117123
if (a.isOne())
118124
return false;
119-
a = (a * r).urem(cmod);
125+
a = (a * r).urem(cmodExt);
120126
}
121127
return a.isOne();
122128
}

mlir/test/Dialect/Polynomial/ops.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
1919
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
2020

21+
#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
22+
#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
23+
#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
24+
!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>
25+
2126
module {
2227
func.func @test_multiply() -> !polynomial.polynomial<ring=#ring1> {
2328
%c0 = arith.constant 0 : index
@@ -95,6 +100,11 @@ module {
95100
return
96101
}
97102

103+
func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
104+
%1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
105+
return
106+
}
107+
98108
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
99109
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
100110
return

0 commit comments

Comments
 (0)