Skip to content

Commit 624c9fc

Browse files
j2kunftynse
andauthored
Upstream polynomial.ntt and polynomial.intt (llvm#90992)
These two ops represent a number-theoretic transform of a polynomial to a tensor of evaluations of the polynomial at a list of powers of primitive roots of the polynomial. To support this, a new optional attribute is added to the ring attribute to specify the primitive root of unity used for the NTT. A verifier for the op is added to ensure the chosen root is a primitive nth root of unity. --------- Co-authored-by: Jeremy Kun <[email protected]> Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
1 parent 716eab7 commit 624c9fc

File tree

5 files changed

+251
-5
lines changed

5 files changed

+251
-5
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
7979
#poly = #polynomial.polynomial<x**1024 + 1>
8080
```
8181
}];
82-
let parameters = (ins "Polynomial":$polynomial);
82+
let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
8383
let hasCustomAssemblyFormat = 1;
8484
}
8585

@@ -122,10 +122,19 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
122122

123123
let parameters = (ins
124124
"Type": $coefficientType,
125-
OptionalParameter<"IntegerAttr">: $coefficientModulus,
126-
OptionalParameter<"PolynomialAttr">: $polynomialModulus
125+
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
126+
OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
127+
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
127128
);
128129

130+
let builders = [
131+
AttrBuilder<
132+
(ins "::mlir::Type":$coefficientTy,
133+
"::mlir::IntegerAttr":$coefficientModulusAttr,
134+
"::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
135+
return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
136+
}]>
137+
];
129138
let hasCustomAssemblyFormat = 1;
130139
}
131140

@@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
416425
let assemblyFormat = "$input attr-dict `:` type($output)";
417426
}
418427

428+
def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
429+
let summary = "Computes point-value tensor representation of a polynomial.";
430+
let description = [{
431+
`polynomial.ntt` computes the forward integer Number Theoretic Transform
432+
(NTT) on the input polynomial. It returns a tensor containing a point-value
433+
representation of the input polynomial. The output tensor has shape equal
434+
to the degree of the ring's `polynomialModulus`. The polynomial's RingAttr
435+
is embedded as the encoding attribute of the output tensor.
436+
437+
Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
438+
degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
439+
the list of $n$ evaluations
440+
441+
`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
442+
443+
The choice of primitive root is determined by subsequent lowerings.
444+
}];
445+
let arguments = (ins Polynomial_PolynomialType:$input);
446+
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
447+
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
448+
let hasVerifier = 1;
449+
}
450+
451+
def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
452+
let summary = "Computes the reverse integer Number Theoretic Transform (NTT).";
453+
let description = [{
454+
`polynomial.intt` computes the reverse integer Number Theoretic Transform
455+
(INTT) on the input tensor. This is the inverse operation of the
456+
`polynomial.ntt` operation.
457+
458+
The input tensor is interpreted as a point-value representation of the
459+
output polynomial at powers of a primitive `n`-th root of unity (see
460+
`polynomial.ntt`). The ring of the polynomial is taken from the required
461+
encoding attribute of the tensor.
462+
}];
463+
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
464+
let results = (outs Polynomial_PolynomialType:$output);
465+
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
466+
let hasVerifier = 1;
467+
}
468+
419469
#endif // POLYNOMIAL_OPS

mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,27 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
202202
polyAttr = attr;
203203
}
204204

205+
Polynomial poly = polyAttr.getPolynomial();
206+
APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
207+
IntegerAttr rootAttr = nullptr;
208+
if (succeeded(parser.parseOptionalComma())) {
209+
if (failed(parser.parseKeyword("primitiveRoot")) ||
210+
failed(parser.parseEqual()))
211+
return {};
212+
213+
ParseResult result = parser.parseInteger(root);
214+
if (failed(result)) {
215+
parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
216+
return {};
217+
}
218+
rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
219+
}
220+
205221
if (failed(parser.parseGreater()))
206222
return {};
207223

208224
return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
209-
polyAttr);
225+
polyAttr, rootAttr);
210226
}
211227

212228
} // namespace polynomial

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,82 @@ LogicalResult MulScalarOp::verify() {
104104

105105
return success();
106106
}
107+
108+
/// Test if a value is a primitive nth root of unity modulo cmod.
109+
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
110+
const APInt &cmod) {
111+
// Root bitwidth may be 1 less then cmod.
112+
APInt r = APInt(root).zext(cmod.getBitWidth());
113+
assert(r.ule(cmod) && "root must be less than cmod");
114+
115+
APInt a = r;
116+
for (size_t k = 1; k < n; k++) {
117+
if (a.isOne())
118+
return false;
119+
a = (a * r).urem(cmod);
120+
}
121+
return a.isOne();
122+
}
123+
124+
/// Verify that the types involved in an NTT or INTT operation are
125+
/// compatible.
126+
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
127+
RankedTensorType tensorType) {
128+
Attribute encoding = tensorType.getEncoding();
129+
if (!encoding) {
130+
return op->emitOpError()
131+
<< "expects a ring encoding to be provided to the tensor";
132+
}
133+
auto encodedRing = dyn_cast<RingAttr>(encoding);
134+
if (!encodedRing) {
135+
return op->emitOpError()
136+
<< "the provided tensor encoding is not a ring attribute";
137+
}
138+
139+
if (encodedRing != ring) {
140+
return op->emitOpError()
141+
<< "encoded ring type " << encodedRing
142+
<< " is not equivalent to the polynomial ring " << ring;
143+
}
144+
145+
unsigned polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
146+
ArrayRef<int64_t> tensorShape = tensorType.getShape();
147+
bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
148+
if (!compatible) {
149+
InFlightDiagnostic diag = op->emitOpError()
150+
<< "tensor type " << tensorType
151+
<< " does not match output type " << ring;
152+
diag.attachNote() << "the tensor must have shape [d] where d "
153+
"is exactly the degree of the polynomialModulus of "
154+
"the polynomial type's ring attribute";
155+
return diag;
156+
}
157+
158+
if (!ring.getPrimitiveRoot()) {
159+
return op->emitOpError()
160+
<< "ring type " << ring << " does not provide a primitive root "
161+
<< "of unity, which is required to express an NTT";
162+
}
163+
164+
if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
165+
ring.getCoefficientModulus().getValue())) {
166+
return op->emitOpError()
167+
<< "ring type " << ring << " has a primitiveRoot attribute '"
168+
<< ring.getPrimitiveRoot()
169+
<< "' that is not a primitive root of the coefficient ring";
170+
}
171+
172+
return success();
173+
}
174+
175+
LogicalResult NTTOp::verify() {
176+
auto ring = getInput().getType().getRing();
177+
auto tensorType = getOutput().getType();
178+
return verifyNTTOp(this->getOperation(), ring, tensorType);
179+
}
180+
181+
LogicalResult INTTOp::verify() {
182+
auto tensorType = getInput().getType();
183+
auto ring = getOutput().getType().getRing();
184+
return verifyNTTOp(this->getOperation(), ring, tensorType);
185+
}

mlir/test/Dialect/Polynomial/ops.mlir

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,13 @@
1010
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
1111

1212
#ideal = #polynomial.polynomial<-1 + x**1024>
13-
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
13+
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
1414
!poly_ty = !polynomial.polynomial<#ring>
1515

16+
#ntt_poly = #polynomial.polynomial<-1 + x**8>
17+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
18+
!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
19+
1620
module {
1721
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
1822
%c0 = arith.constant 0 : index
@@ -79,4 +83,14 @@ module {
7983
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
8084
return
8185
}
86+
87+
func.func @test_ntt(%0 : !ntt_poly_ty) {
88+
%1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
89+
return
90+
}
91+
92+
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
93+
%1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
94+
return
95+
}
8296
}

mlir/test/Dialect/Polynomial/ops_errors.mlir

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
5151
%poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
5252
return %poly : !ty
5353
}
54+
55+
// -----
56+
57+
#my_poly = #polynomial.polynomial<-1 + x**1024>
58+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
59+
!poly_ty = !polynomial.polynomial<#ring>
60+
61+
// CHECK-NOT: @test_invalid_ntt
62+
// CHECK-NOT: polynomial.ntt
63+
func.func @test_invalid_ntt(%0 : !poly_ty) {
64+
// expected-error@below {{expects a ring encoding to be provided to the tensor}}
65+
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
66+
return
67+
}
68+
69+
// -----
70+
71+
#my_poly = #polynomial.polynomial<-1 + x**1024>
72+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
73+
!poly_ty = !polynomial.polynomial<#ring>
74+
75+
// CHECK-NOT: @test_invalid_ntt
76+
// CHECK-NOT: polynomial.ntt
77+
func.func @test_invalid_ntt(%0 : !poly_ty) {
78+
// expected-error@below {{tensor encoding is not a ring attribute}}
79+
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
80+
return
81+
}
82+
83+
// -----
84+
85+
#my_poly = #polynomial.polynomial<-1 + x**1024>
86+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
87+
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
88+
!poly_ty = !polynomial.polynomial<#ring>
89+
90+
// CHECK-NOT: @test_invalid_intt
91+
// CHECK-NOT: polynomial.intt
92+
func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
93+
// expected-error@below {{not equivalent to the polynomial ring}}
94+
%1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
95+
return
96+
}
97+
98+
// -----
99+
100+
#my_poly = #polynomial.polynomial<-1 + x**1024>
101+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
102+
!poly_ty = !polynomial.polynomial<#ring>
103+
104+
// CHECK-NOT: @test_invalid_intt
105+
// CHECK-NOT: polynomial.intt
106+
func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
107+
// expected-error@below {{does not match output type}}
108+
// expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
109+
%1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
110+
return
111+
}
112+
113+
// -----
114+
115+
#my_poly = #polynomial.polynomial<-1 + x**1024>
116+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
117+
!poly_ty = !polynomial.polynomial<#ring>
118+
119+
// CHECK-NOT: @test_invalid_ntt
120+
// CHECK-NOT: polynomial.ntt
121+
func.func @test_invalid_ntt(%0 : !poly_ty) {
122+
// expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
123+
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
124+
return
125+
}
126+
127+
// -----
128+
129+
#my_poly = #polynomial.polynomial<-1 + x**8>
130+
// A valid root is 31
131+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
132+
!poly_ty = !polynomial.polynomial<#ring>
133+
134+
// CHECK-NOT: @test_invalid_intt
135+
// CHECK-NOT: polynomial.intt
136+
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
137+
// expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
138+
%1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
139+
return
140+
}

0 commit comments

Comments
 (0)