Skip to content

Commit 7e6e545

Browse files
[mlir][polynomial] Move primitive root attr to ring attr
Related to llvm#93227 and google/heir#993 When ntt/intt ops are emitted as a result of pattern rewrite, the primitive root attr must be provided in some way, and it is convenient for it to be provided in ring attr. As for using different primitive root for the same polynomial, to_tensor/tensor.cast/from_tensor should be enough for changing primitiveRoot attribute in RingAttr.
1 parent 72fb379 commit 7e6e545

File tree

8 files changed

+94
-72
lines changed

8 files changed

+94
-72
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -311,12 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
311311

312312
`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
313313

314-
The choice of primitive root may be optionally specified.
314+
The choice of primitive root is specified in the primitiveRootAttr of RingAttr.
315+
Its degree affects the behavior of ntt performed, with n-th primitive root
316+
performing cyclic convolution and 2n-th primitive root performing negacyclic
317+
convolution.
315318
}];
316-
let arguments = (ins
317-
Polynomial_PolynomialType:$input,
318-
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
319-
);
319+
let arguments = (ins Polynomial_PolynomialType:$input);
320320
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
321321
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
322322
let hasCanonicalizer = 1;
@@ -335,12 +335,12 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
335335
`polynomial.ntt`). The ring of the polynomial is taken from the required
336336
encoding attribute of the tensor.
337337

338-
The choice of primitive root may be optionally specified.
338+
The choice of primitive root is specified in the primitiveRootAttr of RingAttr.
339+
Its degree affects the behavior of ntt performed, with n-th primitive root
340+
performing cyclic convolution and 2n-th primitive root performing negacyclic
341+
convolution.
339342
}];
340-
let arguments = (
341-
ins RankedTensorOf<[AnyInteger]>:$input,
342-
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
343-
);
343+
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
344344
let results = (outs Polynomial_PolynomialType:$output);
345345
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
346346
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,26 @@ def Polynomial_TypedFloatPolynomialAttr : Polynomial_Attr<
126126
}];
127127
}
128128

129+
def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
130+
let summary = "an attribute containing an integer and its degree as a root of unity";
131+
let description = [{
132+
A primitive root attribute stores an integer root `value` and an integer
133+
`degree`, corresponding to a primitive root of unity of the given degree in
134+
an unspecified ring.
135+
136+
Example:
137+
138+
```mlir
139+
#poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
140+
```
141+
}];
142+
let parameters = (ins
143+
"::mlir::IntegerAttr":$value,
144+
"::mlir::IntegerAttr":$degree
145+
);
146+
let assemblyFormat = "`<` struct(params) `>`";
147+
}
148+
129149
def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
130150
let summary = "an attribute specifying a polynomial ring";
131151
let description = [{
@@ -142,6 +162,9 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
142162
modulus. For single-variable polynomials, an "polynomialModulus" is always specificed
143163
via a single polynomial, which we call `polynomialModulus`.
144164

165+
For ntt/intt and mul to ntt/intt optimization to work, an n-th or 2n-th
166+
_primitiveRoot_ should be specified.
167+
145168
An expressive example is polynomials with i32 coefficients, whose
146169
coefficients are taken modulo `2**32 - 5`, with a polynomial modulus of
147170
`x**1024 - 1`.
@@ -177,46 +200,25 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
177200
let parameters = (ins
178201
"Type": $coefficientType,
179202
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
180-
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
203+
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
204+
OptionalParameter<"::mlir::polynomial::PrimitiveRootAttr">: $primitiveRoot
181205
);
182206
let genVerifyDecl = 1;
183207
let assemblyFormat = "`<` struct(params) `>`";
184208
let builders = [
185209
AttrBuilderWithInferredContext<
186210
(ins "::mlir::Type":$coefficientTy,
187211
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
188-
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
212+
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
213+
CArg<"::mlir::polynomial::PrimitiveRootAttr", "nullptr"> :$primitiveRootAttr), [{
189214
return $_get(
190215
coefficientTy.getContext(),
191216
coefficientTy,
192217
coefficientModulusAttr,
193-
polynomialModulusAttr);
218+
polynomialModulusAttr,
219+
primitiveRootAttr);
194220
}]>,
195221
];
196222
}
197223

198-
def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
199-
let summary = "an attribute containing an integer and its degree as a root of unity";
200-
let description = [{
201-
A primitive root attribute stores an integer root `value` and an integer
202-
`degree`, corresponding to a primitive root of unity of the given degree in
203-
an unspecified ring.
204-
205-
This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
206-
to specify the root of unity used in lowering the transform.
207-
208-
Example:
209-
210-
```mlir
211-
#poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
212-
```
213-
}];
214-
let parameters = (ins
215-
"::mlir::IntegerAttr":$value,
216-
"::mlir::IntegerAttr":$degree
217-
);
218-
let assemblyFormat = "`<` struct(params) `>`";
219-
}
220-
221-
222224
#endif // POLYNOMIAL_ATTRIBUTES

mlir/lib/Dialect/Polynomial/IR/PolynomialAttributes.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,8 @@ Attribute FloatPolynomialAttr::parse(AsmParser &parser, Type type) {
206206
LogicalResult
207207
RingAttr::verify(function_ref<mlir::InFlightDiagnostic()> emitError,
208208
Type coefficientType, IntegerAttr coefficientModulus,
209-
IntPolynomialAttr polynomialModulus) {
209+
IntPolynomialAttr polynomialModulus,
210+
PrimitiveRootAttr primitiveRoot) {
210211
if (coefficientModulus) {
211212
auto coeffIntType = llvm::dyn_cast<IntegerType>(coefficientType);
212213
if (!coeffIntType) {

mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@ include "mlir/Dialect/Polynomial/IR/Polynomial.td"
1414
include "mlir/IR/OpBase.td"
1515
include "mlir/IR/PatternBase.td"
1616

17-
def Equal : Constraint<CPred<"$0 == $1">>;
18-
1917
// Get a -1 integer attribute of the same type as the polynomial SSA value's
2018
// ring coefficient type.
2119
def getMinusOne
@@ -30,15 +28,13 @@ def SubAsAdd : Pat<
3028
(Arith_ConstantOp (getMinusOne $g))))>;
3129

3230
def INTTAfterNTT : Pat<
33-
(Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
34-
(replaceWithValue $poly),
35-
[(Equal $r1, $r2)]
31+
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),
32+
(replaceWithValue $poly)
3633
>;
3734

3835
def NTTAfterINTT : Pat<
39-
(Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
40-
(replaceWithValue $tensor),
41-
[(Equal $r1, $r2)]
36+
(Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
37+
(replaceWithValue $tensor)
4238
>;
4339

4440
#endif // POLYNOMIAL_CANONICALIZATION

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,7 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
134134
/// Verify that the types involved in an NTT or INTT operation are
135135
/// compatible.
136136
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
137-
RankedTensorType tensorType,
138-
std::optional<PrimitiveRootAttr> root) {
137+
RankedTensorType tensorType) {
139138
Attribute encoding = tensorType.getEncoding();
140139
if (!encoding) {
141140
return op->emitOpError()
@@ -166,9 +165,10 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
166165
return diag;
167166
}
168167

169-
if (root.has_value()) {
170-
APInt rootValue = root.value().getValue().getValue();
171-
APInt rootDegree = root.value().getDegree().getValue();
168+
auto root = ring.getPrimitiveRoot();
169+
if (root) {
170+
APInt rootValue = root.getValue().getValue();
171+
APInt rootDegree = root.getDegree().getValue();
172172
APInt cmod = ring.getCoefficientModulus().getValue();
173173
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
174174
return op->emitOpError()
@@ -177,19 +177,22 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
177177
<< "of unity mod " << cmod.getZExtValue()
178178
<< ", with the specified degree " << rootDegree.getZExtValue();
179179
}
180+
} else {
181+
return op->emitOpError()
182+
<< "primitive root not provided but ntt/intt op called";
180183
}
181184

182185
return success();
183186
}
184187

185188
LogicalResult NTTOp::verify() {
186189
return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
187-
getOutput().getType(), getRoot());
190+
getOutput().getType());
188191
}
189192

190193
LogicalResult INTTOp::verify() {
191194
return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
192-
getInput().getType(), getRoot());
195+
getInput().getType());
193196
}
194197

195198
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {

mlir/test/Dialect/Polynomial/canonicalization.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: mlir-opt -canonicalize %s | FileCheck %s
22
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
3-
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
43
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
4+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#root>
55
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
66
!tensor_ty = tensor<8xi32, #ntt_ring>
77

@@ -11,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
1111
// CHECK-NOT: polynomial.ntt
1212
// CHECK-NOT: polynomial.intt
1313
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
14-
%t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
15-
%p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
14+
%t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
15+
%p1 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
1616
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
1717
// CHECK: return %[[RESULT]] : [[T]]
1818
return %p2 : !ntt_poly_ty
@@ -24,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
2424
// CHECK-NOT: polynomial.intt
2525
// CHECK-NOT: polynomial.ntt
2626
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
27-
%p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
28-
%t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
27+
%p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
28+
%t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
2929
%t2 = arith.addi %t1, %t1 : !tensor_ty
3030
// CHECK: return %[[RESULT]] : [[T]]
3131
return %t2 : !tensor_ty

mlir/test/Dialect/Polynomial/ops.mlir

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
!poly_ty = !polynomial.polynomial<ring=#ring>
1616

1717
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
18-
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
18+
#ntt_ring_root = #polynomial.primitive_root<value=31:i32, degree=8:index>
19+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=#ntt_ring_root>
1920
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
2021

2122
#ntt_poly_2 = #polynomial.int_polynomial<1 + x**65536>
22-
#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2>
2323
#ntt_ring_2_root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
24+
#ntt_ring_2 = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#ntt_poly_2, primitiveRoot=#ntt_ring_2_root>
2425
!ntt_poly_ty_2 = !polynomial.polynomial<ring=#ntt_ring_2>
2526

2627
module {
@@ -96,17 +97,17 @@ module {
9697
}
9798

9899
func.func @test_ntt(%0 : !ntt_poly_ty) {
99-
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
100+
%1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
100101
return
101102
}
102103

103104
func.func @test_ntt_with_overflowing_root(%0 : !ntt_poly_ty_2) {
104-
%1 = polynomial.ntt %0 {root=#ntt_ring_2_root} : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
105+
%1 = polynomial.ntt %0 : !ntt_poly_ty_2 -> tensor<65536xi32, #ntt_ring_2>
105106
return
106107
}
107108

108109
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
109-
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
110+
%1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
110111
return
111112
}
112113
}

mlir/test/Dialect/Polynomial/ops_errors.mlir

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,36 +55,39 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
5555
// -----
5656

5757
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
58-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
58+
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
59+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
5960
!poly_ty = !polynomial.polynomial<ring=#ring>
6061

6162
// CHECK-NOT: @test_invalid_ntt
6263
// CHECK-NOT: polynomial.ntt
6364
func.func @test_invalid_ntt(%0 : !poly_ty) {
6465
// expected-error@below {{expects a ring encoding to be provided to the tensor}}
65-
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
66+
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
6667
return
6768
}
6869

6970
// -----
7071

7172
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
72-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
73+
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
74+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
7375
!poly_ty = !polynomial.polynomial<ring=#ring>
7476

7577
// CHECK-NOT: @test_invalid_ntt
7678
// CHECK-NOT: polynomial.ntt
7779
func.func @test_invalid_ntt(%0 : !poly_ty) {
7880
// expected-error@below {{tensor encoding is not a ring attribute}}
79-
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
81+
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
8082
return
8183
}
8284

8385
// -----
8486

8587
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
88+
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
8689
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
87-
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
90+
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
8891
!poly_ty = !polynomial.polynomial<ring=#ring>
8992

9093
// CHECK-NOT: @test_invalid_intt
@@ -98,29 +101,45 @@ func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
98101
// -----
99102

100103
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
101-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
104+
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
105+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
102106
!poly_ty = !polynomial.polynomial<ring=#ring>
103107

104108
// CHECK-NOT: @test_invalid_intt
105109
// CHECK-NOT: polynomial.intt
106110
func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
107111
// expected-error@below {{does not match output type}}
108112
// expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
109-
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
113+
%1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
110114
return
111115
}
112116

113117
// -----
114118

115119
#my_poly = #polynomial.int_polynomial<-1 + x**8>
116120
// A valid root is 31
117-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
121+
#root = #polynomial.primitive_root<value=32:i32, degree=8:index>
122+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=#root>
118123
!poly_ty = !polynomial.polynomial<ring=#ring>
119124

120125
// CHECK-NOT: @test_invalid_intt
121126
// CHECK-NOT: polynomial.intt
122127
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
123128
// expected-error@below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
124-
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
129+
%1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
130+
return
131+
}
132+
133+
// -----
134+
135+
#my_poly = #polynomial.int_polynomial<-1 + x**8>
136+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
137+
!poly_ty = !polynomial.polynomial<ring=#ring>
138+
139+
// CHECK-NOT: @test_invalid_intt
140+
// CHECK-NOT: polynomial.intt
141+
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
142+
// expected-error@below {{primitive root not provided but ntt/intt op called}}
143+
%1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
125144
return
126145
}

0 commit comments

Comments
 (0)