-
Notifications
You must be signed in to change notification settings - Fork 14.3k
Upstream polynomial.ntt and polynomial.intt #90992
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a couple of stylistic nits. I didn't try getting deep into the math, LMK if I should.
}]; | ||
let arguments = (ins Polynomial_PolynomialType:$input); | ||
let results = (outs RankedTensorOf<[AnyInteger]>:$output); | ||
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: would functional-type
work here instead of explicit arrows?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC this one avoids wrapping the sole input in parentheses in the assembly format.
FWIW, everything I'm upstreaming has been reviewed for mathematical content in https://github.com/google/heir. I don't know if that precludes the need for a secondary "math" review upstream, but at least you can be assured that we have gone a few steps ahead. In particular, we have a lowering of the NTT op via standard algorithms and functional tests for correctness. |
@llvm/pr-subscribers-mlir Author: Jeremy Kun (j2kun) ChangesThese two ops represent a number-theoretic transform of a polynomial to a tensor of evaluations of the polynomial at a list of powers of primitive roots of the polynomial. To support this, a new optional attribute is added to the ring attribute to specify the primitive root of unity used for the NTT. A verifier for the op is added to ensure the chosen root is a primitive nth root of unity. Full diff: https://github.com/llvm/llvm-project/pull/90992.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
index d3e3ac55677f86..ed1f4ce8b7e599 100644
--- a/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
+++ b/mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td
@@ -79,7 +79,7 @@ def Polynomial_PolynomialAttr : Polynomial_Attr<"Polynomial", "polynomial"> {
#poly = #polynomial.polynomial<x**1024 + 1>
```
}];
- let parameters = (ins "Polynomial":$polynomial);
+ let parameters = (ins "::mlir::polynomial::Polynomial":$polynomial);
let hasCustomAssemblyFormat = 1;
}
@@ -122,10 +122,19 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
let parameters = (ins
"Type": $coefficientType,
- OptionalParameter<"IntegerAttr">: $coefficientModulus,
- OptionalParameter<"PolynomialAttr">: $polynomialModulus
+ OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
+ OptionalParameter<"::mlir::polynomial::PolynomialAttr">: $polynomialModulus,
+ OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
);
+ let builders = [
+ AttrBuilder<
+ (ins "::mlir::Type":$coefficientTy,
+ "::mlir::IntegerAttr":$coefficientModulusAttr,
+ "::mlir::polynomial::PolynomialAttr":$polynomialModulusAttr), [{
+ return $_get($_ctxt, coefficientTy, coefficientModulusAttr, polynomialModulusAttr, nullptr);
+ }]>
+ ];
let hasCustomAssemblyFormat = 1;
}
@@ -416,4 +425,45 @@ def Polynomial_ConstantOp : Polynomial_Op<"constant", [Pure]> {
let assemblyFormat = "$input attr-dict `:` type($output)";
}
+def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
+ let summary = "Computes point-value tensor representation of a polynomial.";
+ let description = [{
+ `polynomial.ntt` computes the forward integer Number Theoretic Transform
+ (NTT) on the input polynomial. It returns a tensor containing a point-value
+ representation of the input polynomial. The output tensor has shape equal
+ to the degree of the ring's `polynomialModulus`. The polynomial's RingAttr
+ is embedded as the encoding attribute of the output tensor.
+
+ Given an input polynomial `F(x)` over a ring whose `polynomialModulus` has
+ degree `n`, and a primitive `n`-th root of unity `omega_n`, the output is
+ the list of $n$ evaluations
+
+ `f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
+
+ The choice of primitive root is determined by subsequent lowerings.
+ }];
+ let arguments = (ins Polynomial_PolynomialType:$input);
+ let results = (outs RankedTensorOf<[AnyInteger]>:$output);
+ let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+ let hasVerifier = 1;
+}
+
+def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
+ let summary = "Computes the reverse integer Number Theoretic Transform (NTT).";
+ let description = [{
+ `polynomial.intt` computes the reverse integer Number Theoretic Transform
+ (INTT) on the input tensor. This is the inverse operation of the
+ `polynomial.ntt` operation.
+
+ The input tensor is interpreted as a point-value representation of the
+ output polynomial at powers of a primitive `n`-th root of unity (see
+ `polynomial.ntt`). The ring of the polynomial is taken from the required
+ encoding attribute of the tensor.
+ }];
+ let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
+ let results = (outs Polynomial_PolynomialType:$output);
+ let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
+ let hasVerifier = 1;
+}
+
#endif // POLYNOMIAL_OPS
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
index f1ec2be72a33ab..45263a5e97e72d 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp
@@ -202,11 +202,29 @@ Attribute RingAttr::parse(AsmParser &parser, Type type) {
polyAttr = attr;
}
+ Polynomial poly = polyAttr.getPolynomial();
+ APInt root(coefficientModulusAttr.getValue().getBitWidth(), 0);
+ IntegerAttr rootAttr = nullptr;
+ if (succeeded(parser.parseOptionalComma())) {
+ if (failed(parser.parseKeyword("primitiveRoot")))
+ return {};
+
+ if (failed(parser.parseEqual()))
+ return {};
+
+ ParseResult result = parser.parseInteger(root);
+ if (failed(result)) {
+ parser.emitError(parser.getCurrentLocation(), "invalid primitiveRoot");
+ return {};
+ }
+ rootAttr = IntegerAttr::get(coefficientModulusAttr.getType(), root);
+ }
+
if (failed(parser.parseGreater()))
return {};
return RingAttr::get(parser.getContext(), ty, coefficientModulusAttr,
- polyAttr);
+ polyAttr, rootAttr);
}
} // namespace polynomial
diff --git a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
index 8e2bb5f27dc6cc..1d5c7be4b6752a 100644
--- a/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
+++ b/mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp
@@ -104,3 +104,80 @@ LogicalResult MulScalarOp::verify() {
return success();
}
+
+// Test if a value is a primitive nth root of unity modulo cmod
+bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
+ const APInt &cmod) {
+ // root bitwidth may be 1 less then cmod
+ APInt r = APInt(root).zext(cmod.getBitWidth());
+ assert(r.ule(cmod) && "root must be less than cmod");
+
+ APInt a = r;
+ for (size_t k = 1; k < n; k++) {
+ if (a.isOne())
+ return false;
+ a = (a * r).urem(cmod);
+ }
+ return a.isOne();
+}
+
+static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
+ RankedTensorType tensorType) {
+ auto encoding = tensorType.getEncoding();
+ if (!encoding) {
+ return op->emitOpError()
+ << "a ring encoding was not provided to the tensor";
+ }
+ auto encodedRing = dyn_cast<RingAttr>(encoding);
+ if (!encodedRing) {
+ return op->emitOpError()
+ << "the provided tensor encoding is not a ring attribute";
+ }
+
+ if (encodedRing != ring) {
+ return op->emitOpError()
+ << "encoded ring type " << encodedRing
+ << " is not equivalent to the polynomial ring " << ring;
+ }
+
+ auto polyDegree = ring.getPolynomialModulus().getPolynomial().getDegree();
+ auto tensorShape = tensorType.getShape();
+ bool compatible = tensorShape.size() == 1 && tensorShape[0] == polyDegree;
+ if (!compatible) {
+ InFlightDiagnostic diag = op->emitOpError()
+ << "tensor type " << tensorType
+ << " does not match output type " << ring;
+ diag.attachNote() << "the tensor must have shape [d] where d "
+ "is exactly the degree of the polynomialModulus of "
+ "the polynomial type's ring attribute";
+ return diag;
+ }
+
+ if (!ring.getPrimitiveRoot()) {
+ return op->emitOpError()
+ << "ring type " << ring << " does not provide a primitive root "
+ << "of unity, which is required to express an NTT";
+ }
+
+ if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
+ ring.getCoefficientModulus().getValue())) {
+ return op->emitOpError()
+ << "ring type " << ring << " has a primitiveRoot attribute '"
+ << ring.getPrimitiveRoot()
+ << "' that is not a primitive root of the coefficient ring";
+ }
+
+ return success();
+}
+
+LogicalResult NTTOp::verify() {
+ auto ring = getInput().getType().getRing();
+ auto tensorType = getOutput().getType();
+ return verifyNTTOp(this->getOperation(), ring, tensorType);
+}
+
+LogicalResult INTTOp::verify() {
+ auto tensorType = getInput().getType();
+ auto ring = getOutput().getType().getRing();
+ return verifyNTTOp(this->getOperation(), ring, tensorType);
+}
diff --git a/mlir/test/Dialect/Polynomial/ops.mlir b/mlir/test/Dialect/Polynomial/ops.mlir
index ea1b279fa1ff96..a29cfc2e9cc549 100644
--- a/mlir/test/Dialect/Polynomial/ops.mlir
+++ b/mlir/test/Dialect/Polynomial/ops.mlir
@@ -10,9 +10,13 @@
#one_plus_x_squared = #polynomial.polynomial<1 + x**2>
#ideal = #polynomial.polynomial<-1 + x**1024>
-#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=18, polynomialModulus=#ideal>
+#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
!poly_ty = !polynomial.polynomial<#ring>
+#ntt_poly = #polynomial.polynomial<-1 + x**8>
+#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
+!ntt_poly_ty = !polynomial.polynomial<#ntt_ring>
+
module {
func.func @test_multiply() -> !polynomial.polynomial<#ring1> {
%c0 = arith.constant 0 : index
@@ -79,4 +83,14 @@ module {
%1 = polynomial.constant <1 + x**2> : !polynomial.polynomial<#ring1>
return
}
+
+ func.func @test_ntt(%0 : !ntt_poly_ty) {
+ %1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
+ return
+ }
+
+ func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
+ %1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
+ return
+ }
}
diff --git a/mlir/test/Dialect/Polynomial/ops_errors.mlir b/mlir/test/Dialect/Polynomial/ops_errors.mlir
index c34a7de30e5fe5..9029017256be3a 100644
--- a/mlir/test/Dialect/Polynomial/ops_errors.mlir
+++ b/mlir/test/Dialect/Polynomial/ops_errors.mlir
@@ -51,3 +51,90 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
%poly = polynomial.mul_scalar %arg0, %scalar : !ty, i32
return %poly : !ty
}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error@+1 {{a ring encoding was not provided to the tensor}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error@+1 {{tensor encoding is not a ring attribute}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
+ // expected-error@+1 {{not equivalent to the polynomial ring}}
+ %1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=31>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
+ // expected-error@below {{does not match output type}}
+ // expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
+ %1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**1024>
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_ntt
+// CHECK-NOT: polynomial.ntt
+func.func @test_invalid_ntt(%0 : !poly_ty) {
+ // expected-error@+1 {{does not provide a primitive root of unity, which is required to express an NTT}}
+ %1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
+ return
+}
+
+// -----
+
+#my_poly = #polynomial.polynomial<-1 + x**8>
+// A valid root is 31
+#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256, polynomialModulus=#my_poly, primitiveRoot=32>
+!poly_ty = !polynomial.polynomial<#ring>
+
+// CHECK-NOT: @test_invalid_intt
+// CHECK-NOT: polynomial.intt
+func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
+ // expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
+ %1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
+ return
+}
|
Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
Co-authored-by: Oleksandr "Alex" Zinenko <[email protected]>
These two ops represent a number-theoretic transform of a polynomial to a tensor of evaluations of the polynomial at a list of powers of primitive roots of the polynomial.
To support this, a new optional attribute is added to the ring attribute to specify the primitive root of unity used for the NTT. A verifier for the op is added to ensure the chosen root is a primitive nth root of unity.