Skip to content

Commit 0a7431a

Browse files
committed
move primitive root attr to ntt/intt ops
1 parent 73eb9b3 commit 0a7431a

File tree

4 files changed

+80
-52
lines changed

4 files changed

+80
-52
lines changed

mlir/include/mlir/Dialect/Polynomial/IR/Polynomial.td

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@ def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
277277
Polynomial_TypedIntPolynomialAttr
278278
]>;
279279

280-
// Not deriving from Polynomial_Op due to need for custom assembly format
281-
def Polynomial_ConstantOp : Op<Polynomial_Dialect, "constant",
280+
def Polynomial_ConstantOp : Polynomial_Op<"constant",
282281
[Pure, InferTypeOpAdaptor]> {
283282
let summary = "Define a constant polynomial via an attribute.";
284283
let description = [{
@@ -312,9 +311,12 @@ def Polynomial_NTTOp : Polynomial_Op<"ntt", [Pure]> {
312311

313312
`f[k] = F(omega[n]^k) ; k = {0, ..., n-1}`
314313

315-
The choice of primitive root is determined by subsequent lowerings.
314+
The choice of primitive root may be optionally specified.
316315
}];
317-
let arguments = (ins Polynomial_PolynomialType:$input);
316+
let arguments = (ins
317+
Polynomial_PolynomialType:$input,
318+
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
319+
);
318320
let results = (outs RankedTensorOf<[AnyInteger]>:$output);
319321
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
320322
let hasCanonicalizer = 1;
@@ -332,8 +334,13 @@ def Polynomial_INTTOp : Polynomial_Op<"intt", [Pure]> {
332334
output polynomial at powers of a primitive `n`-th root of unity (see
333335
`polynomial.ntt`). The ring of the polynomial is taken from the required
334336
encoding attribute of the tensor.
337+
338+
The choice of primitive root may be optionally specified.
335339
}];
336-
let arguments = (ins RankedTensorOf<[AnyInteger]>:$input);
340+
let arguments = (
341+
ins RankedTensorOf<[AnyInteger]>:$input,
342+
OptionalAttr<Polynomial_PrimitiveRootAttr>:$root
343+
);
337344
let results = (outs Polynomial_PolynomialType:$output);
338345
let assemblyFormat = "$input attr-dict `:` qualified(type($input)) `->` type($output)";
339346
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Polynomial/IR/PolynomialAttributes.td

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -166,24 +166,45 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring"> {
166166
let parameters = (ins
167167
"Type": $coefficientType,
168168
OptionalParameter<"::mlir::IntegerAttr">: $coefficientModulus,
169-
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus,
170-
OptionalParameter<"::mlir::IntegerAttr">: $primitiveRoot
169+
OptionalParameter<"::mlir::polynomial::IntPolynomialAttr">: $polynomialModulus
171170
);
172171
let assemblyFormat = "`<` struct(params) `>`";
173172
let builders = [
174173
AttrBuilderWithInferredContext<
175174
(ins "::mlir::Type":$coefficientTy,
176175
CArg<"::mlir::IntegerAttr", "nullptr"> :$coefficientModulusAttr,
177-
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr,
178-
CArg<"::mlir::IntegerAttr", "nullptr"> :$primitiveRootAttr), [{
176+
CArg<"::mlir::polynomial::IntPolynomialAttr", "nullptr"> :$polynomialModulusAttr), [{
179177
return $_get(
180178
coefficientTy.getContext(),
181179
coefficientTy,
182180
coefficientModulusAttr,
183-
polynomialModulusAttr,
184-
primitiveRootAttr);
181+
polynomialModulusAttr);
185182
}]>,
186183
];
187184
}
188185

186+
def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
187+
let summary = "an attribute containing an integer and its degree as a root of unity";
188+
let description = [{
189+
A primitive root attribute stores an integer root `value` and an integer
190+
`degree`, corresponding to a primitive root of unity of the given degree in
191+
an unspecified ring.
192+
193+
This is used as an attribute on `polynomial.ntt` and `polynomial.intt` ops
194+
to specify the root of unity used in lowering the transform.
195+
196+
Example:
197+
198+
```mlir
199+
#poly = #polynomial.primitive_root<value=123 : i32, degree : 7 index>
200+
```
201+
}];
202+
let parameters = (ins
203+
"::mlir::IntegerAttr":$value,
204+
"::mlir::IntegerAttr":$degree
205+
);
206+
let assemblyFormat = "`<` struct(params) `>`";
207+
}
208+
209+
189210
#endif // POLYNOMIAL_ATTRIBUTES

mlir/lib/Dialect/Polynomial/IR/PolynomialCanonicalization.td

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ include "mlir/IR/PatternBase.td"
1717

1818
defvar DefOverflow = ConstantEnumCase<Arith_IntegerOverflowAttr, "none">;
1919

20+
def Equal : Constraint<CPred<"$0 == $1">>;
21+
2022
// Get a -1 integer attribute of the same type as the polynomial SSA value's
2123
// ring coefficient type.
2224
def getMinusOne
@@ -31,51 +33,51 @@ def SubAsAdd : Pat<
3133
(Arith_ConstantOp (getMinusOne $g))))>;
3234

3335
def INTTAfterNTT : Pat<
34-
(Polynomial_INTTOp (Polynomial_NTTOp $poly)),
36+
(Polynomial_INTTOp (Polynomial_NTTOp $poly, $r1), $r2),
3537
(replaceWithValue $poly),
36-
[]
38+
[(Equal $r1, $r2)]
3739
>;
3840

3941
def NTTAfterINTT : Pat<
40-
(Polynomial_NTTOp (Polynomial_INTTOp $tensor)),
42+
(Polynomial_NTTOp (Polynomial_INTTOp $tensor, $r1), $r2),
4143
(replaceWithValue $tensor),
42-
[]
44+
[(Equal $r1, $r2)]
4345
>;
4446

4547
// NTTs are expensive, and addition in coefficient or NTT domain should be
4648
// equivalently expensive, so reducing the number of NTTs is optimal.
4749
// ntt(a) + ntt(b) -> ntt(a + b)
4850
def NTTOfAdd : Pat<
4951
(Arith_AddIOp
50-
(Polynomial_NTTOp $p1),
51-
(Polynomial_NTTOp $p2),
52+
(Polynomial_NTTOp $p1, $r1),
53+
(Polynomial_NTTOp $p2, $r2),
5254
$overflow),
53-
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2)),
54-
[]
55+
(Polynomial_NTTOp (Polynomial_AddOp $p1, $p2), $r1),
56+
[(Equal $r1, $r2)]
5557
>;
5658
// intt(a) + intt(b) -> intt(a + b)
5759
def INTTOfAdd : Pat<
5860
(Polynomial_AddOp
59-
(Polynomial_INTTOp $t1),
60-
(Polynomial_INTTOp $t2)),
61-
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow)),
62-
[]
61+
(Polynomial_INTTOp $t1, $r1),
62+
(Polynomial_INTTOp $t2, $r2)),
63+
(Polynomial_INTTOp (Arith_AddIOp $t1, $t2, DefOverflow), $r1),
64+
[(Equal $r1, $r2)]
6365
>;
6466
// repeated for sub
6567
def NTTOfSub : Pat<
6668
(Arith_SubIOp
67-
(Polynomial_NTTOp $p1),
68-
(Polynomial_NTTOp $p2),
69+
(Polynomial_NTTOp $p1, $r1),
70+
(Polynomial_NTTOp $p2, $r2),
6971
$overflow),
70-
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2)),
71-
[]
72+
(Polynomial_NTTOp (Polynomial_SubOp $p1, $p2), $r1),
73+
[(Equal $r1, $r2)]
7274
>;
7375
def INTTOfSub : Pat<
7476
(Polynomial_SubOp
75-
(Polynomial_INTTOp $t1),
76-
(Polynomial_INTTOp $t2)),
77-
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow)),
78-
[]
77+
(Polynomial_INTTOp $t1, $r1),
78+
(Polynomial_INTTOp $t2, $r2)),
79+
(Polynomial_INTTOp (Arith_SubIOp $t1, $t2, DefOverflow), $r1),
80+
[(Equal $r1, $r2)]
7981
>;
8082

8183
#endif // POLYNOMIAL_CANONICALIZATION

mlir/lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,15 @@ LogicalResult MulScalarOp::verify() {
108108
}
109109

110110
/// Test if a value is a primitive nth root of unity modulo cmod.
111-
bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
111+
bool isPrimitiveNthRootOfUnity(const APInt &root, const APInt &n,
112112
const APInt &cmod) {
113113
// Root bitwidth may be 1 less then cmod.
114114
APInt r = APInt(root).zext(cmod.getBitWidth());
115115
assert(r.ule(cmod) && "root must be less than cmod");
116+
unsigned upperBound = n.getZExtValue();
116117

117118
APInt a = r;
118-
for (size_t k = 1; k < n; k++) {
119+
for (size_t k = 1; k < upperBound; k++) {
119120
if (a.isOne())
120121
return false;
121122
a = (a * r).urem(cmod);
@@ -126,7 +127,8 @@ bool isPrimitiveNthRootOfUnity(const APInt &root, const unsigned n,
126127
/// Verify that the types involved in an NTT or INTT operation are
127128
/// compatible.
128129
static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
129-
RankedTensorType tensorType) {
130+
RankedTensorType tensorType,
131+
std::optional<PrimitiveRootAttr> root) {
130132
Attribute encoding = tensorType.getEncoding();
131133
if (!encoding) {
132134
return op->emitOpError()
@@ -157,33 +159,29 @@ static LogicalResult verifyNTTOp(Operation *op, RingAttr ring,
157159
return diag;
158160
}
159161

160-
if (!ring.getPrimitiveRoot()) {
161-
return op->emitOpError()
162-
<< "ring type " << ring << " does not provide a primitive root "
163-
<< "of unity, which is required to express an NTT";
164-
}
165-
166-
if (!isPrimitiveNthRootOfUnity(ring.getPrimitiveRoot().getValue(), polyDegree,
167-
ring.getCoefficientModulus().getValue())) {
168-
return op->emitOpError()
169-
<< "ring type " << ring << " has a primitiveRoot attribute '"
170-
<< ring.getPrimitiveRoot()
171-
<< "' that is not a primitive root of the coefficient ring";
162+
if (root.has_value()) {
163+
APInt rootValue = root.value().getValue().getValue();
164+
APInt rootDegree = root.value().getDegree().getValue();
165+
APInt cmod = ring.getCoefficientModulus().getValue();
166+
if (!isPrimitiveNthRootOfUnity(rootValue, rootDegree, cmod)) {
167+
return op->emitOpError()
168+
<< "provided root " << rootValue.getZExtValue() << " is not a primitive root "
169+
<< "of unity mod " << cmod.getZExtValue() << ", with the specified degree "
170+
<< rootDegree.getZExtValue();
171+
}
172172
}
173173

174174
return success();
175175
}
176176

177177
LogicalResult NTTOp::verify() {
178-
auto ring = getInput().getType().getRing();
179-
auto tensorType = getOutput().getType();
180-
return verifyNTTOp(this->getOperation(), ring, tensorType);
178+
return verifyNTTOp(this->getOperation(), getInput().getType().getRing(),
179+
getOutput().getType(), getRoot());
181180
}
182181

183182
LogicalResult INTTOp::verify() {
184-
auto tensorType = getInput().getType();
185-
auto ring = getOutput().getType().getRing();
186-
return verifyNTTOp(this->getOperation(), ring, tensorType);
183+
return verifyNTTOp(this->getOperation(), getOutput().getType().getRing(),
184+
getInput().getType(), getRoot());
187185
}
188186

189187
ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {

0 commit comments

Comments
 (0)