Skip to content

Commit cd7eedb

Browse files
committed
update tests
1 parent 0a7431a commit cd7eedb

File tree

3 files changed

+33
-46
lines changed

3 files changed

+33
-46
lines changed

mlir/test/Dialect/Polynomial/canonicalization.mlir

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// RUN: mlir-opt -canonicalize %s | FileCheck %s
22
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
3-
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
3+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
4+
#root = #polynomial.primitive_root<value=31:i32, degree=8:index>
45
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
56
!tensor_ty = tensor<8xi32, #ntt_ring>
67

@@ -10,8 +11,8 @@ func.func @test_canonicalize_intt_after_ntt(%p0 : !ntt_poly_ty) -> !ntt_poly_ty
1011
// CHECK-NOT: polynomial.ntt
1112
// CHECK-NOT: polynomial.intt
1213
// CHECK: %[[RESULT:.+]] = polynomial.add %[[P]], %[[P]] : [[T]]
13-
%t0 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
14-
%p1 = polynomial.intt %t0: !tensor_ty -> !ntt_poly_ty
14+
%t0 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
15+
%p1 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
1516
%p2 = polynomial.add %p1, %p1 : !ntt_poly_ty
1617
// CHECK: return %[[RESULT]] : [[T]]
1718
return %p2 : !ntt_poly_ty
@@ -23,8 +24,8 @@ func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
2324
// CHECK-NOT: polynomial.intt
2425
// CHECK-NOT: polynomial.ntt
2526
// CHECK: %[[RESULT:.+]] = arith.addi %[[X]], %[[X]] : [[T]]
26-
%p0 = polynomial.intt %t0 : !tensor_ty -> !ntt_poly_ty
27-
%t1 = polynomial.ntt %p0 : !ntt_poly_ty -> !tensor_ty
27+
%p0 = polynomial.intt %t0 {root=#root} : !tensor_ty -> !ntt_poly_ty
28+
%t1 = polynomial.ntt %p0 {root=#root} : !ntt_poly_ty -> !tensor_ty
2829
%t2 = arith.addi %t1, %t1 : !tensor_ty
2930
// CHECK: return %[[RESULT]] : [[T]]
3031
return %t2 : !tensor_ty
@@ -51,10 +52,10 @@ func.func @test_canonicalize_sub(%poly0 : !sub_ty, %poly1 : !sub_ty) -> !sub_ty
5152
func.func @test_canonicalize_fold_add_through_ntt(
5253
%poly0 : !ntt_poly_ty,
5354
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
54-
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
55-
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
55+
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
56+
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
5657
%a_plus_b = arith.addi %0, %1 : !tensor_ty
57-
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
58+
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
5859
return %out : !ntt_poly_ty
5960
}
6061

@@ -65,10 +66,10 @@ func.func @test_canonicalize_fold_add_through_ntt(
6566
func.func @test_canonicalize_fold_add_through_intt(
6667
%tensor0 : !tensor_ty,
6768
%tensor1 : !tensor_ty) -> !tensor_ty {
68-
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
69-
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
69+
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
70+
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
7071
%a_plus_b = polynomial.add %0, %1 : !ntt_poly_ty
71-
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
72+
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
7273
return %out : !tensor_ty
7374
}
7475

@@ -80,10 +81,10 @@ func.func @test_canonicalize_fold_add_through_intt(
8081
func.func @test_canonicalize_fold_sub_through_ntt(
8182
%poly0 : !ntt_poly_ty,
8283
%poly1 : !ntt_poly_ty) -> !ntt_poly_ty {
83-
%0 = polynomial.ntt %poly0 : !ntt_poly_ty -> !tensor_ty
84-
%1 = polynomial.ntt %poly1 : !ntt_poly_ty -> !tensor_ty
84+
%0 = polynomial.ntt %poly0 {root=#root} : !ntt_poly_ty -> !tensor_ty
85+
%1 = polynomial.ntt %poly1 {root=#root} : !ntt_poly_ty -> !tensor_ty
8586
%a_plus_b = arith.subi %0, %1 : !tensor_ty
86-
%out = polynomial.intt %a_plus_b : !tensor_ty -> !ntt_poly_ty
87+
%out = polynomial.intt %a_plus_b {root=#root} : !tensor_ty -> !ntt_poly_ty
8788
return %out : !ntt_poly_ty
8889
}
8990

@@ -94,9 +95,9 @@ func.func @test_canonicalize_fold_sub_through_ntt(
9495
func.func @test_canonicalize_fold_sub_through_intt(
9596
%tensor0 : !tensor_ty,
9697
%tensor1 : !tensor_ty) -> !tensor_ty {
97-
%0 = polynomial.intt %tensor0 : !tensor_ty -> !ntt_poly_ty
98-
%1 = polynomial.intt %tensor1 : !tensor_ty -> !ntt_poly_ty
98+
%0 = polynomial.intt %tensor0 {root=#root} : !tensor_ty -> !ntt_poly_ty
99+
%1 = polynomial.intt %tensor1 {root=#root} : !tensor_ty -> !ntt_poly_ty
99100
%a_plus_b = polynomial.sub %0, %1 : !ntt_poly_ty
100-
%out = polynomial.ntt %a_plus_b : !ntt_poly_ty -> !tensor_ty
101+
%out = polynomial.ntt %a_plus_b {root=#root} : !ntt_poly_ty -> !tensor_ty
101102
return %out : !tensor_ty
102103
}

mlir/test/Dialect/Polynomial/ops.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
#one_plus_x_squared = #polynomial.int_polynomial<1 + x**2>
1212

1313
#ideal = #polynomial.int_polynomial<-1 + x**1024>
14-
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal, primitiveRoot=193>
14+
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ideal>
1515
!poly_ty = !polynomial.polynomial<ring=#ring>
1616

1717
#ntt_poly = #polynomial.int_polynomial<-1 + x**8>
18-
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly, primitiveRoot=31>
18+
#ntt_ring = #polynomial.ring<coefficientType=i32, coefficientModulus=256, polynomialModulus=#ntt_poly>
1919
!ntt_poly_ty = !polynomial.polynomial<ring=#ntt_ring>
2020

2121
module {
@@ -91,12 +91,12 @@ module {
9191
}
9292

9393
func.func @test_ntt(%0 : !ntt_poly_ty) {
94-
%1 = polynomial.ntt %0 : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
94+
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !ntt_poly_ty -> tensor<8xi32, #ntt_ring>
9595
return
9696
}
9797

9898
func.func @test_intt(%0 : tensor<8xi32, #ntt_ring>) {
99-
%1 = polynomial.intt %0 : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
99+
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<8xi32, #ntt_ring> -> !ntt_poly_ty
100100
return
101101
}
102102
}

mlir/test/Dialect/Polynomial/ops_errors.mlir

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -55,86 +55,72 @@ func.func @test_mul_scalar_wrong_type(%arg0: !ty) -> !ty {
5555
// -----
5656

5757
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
58-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
58+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
5959
!poly_ty = !polynomial.polynomial<ring=#ring>
6060

6161
// CHECK-NOT: @test_invalid_ntt
6262
// CHECK-NOT: polynomial.ntt
6363
func.func @test_invalid_ntt(%0 : !poly_ty) {
6464
// expected-error@below {{expects a ring encoding to be provided to the tensor}}
65-
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32>
65+
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32>
6666
return
6767
}
6868

6969
// -----
7070

7171
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
72-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
72+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
7373
!poly_ty = !polynomial.polynomial<ring=#ring>
7474

7575
// CHECK-NOT: @test_invalid_ntt
7676
// CHECK-NOT: polynomial.ntt
7777
func.func @test_invalid_ntt(%0 : !poly_ty) {
7878
// expected-error@below {{tensor encoding is not a ring attribute}}
79-
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #my_poly>
79+
%1 = polynomial.ntt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : !poly_ty -> tensor<1024xi32, #my_poly>
8080
return
8181
}
8282

8383
// -----
8484

8585
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
8686
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
87-
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
87+
#ring1 = #polynomial.ring<coefficientType=i16, coefficientModulus=257:i16, polynomialModulus=#my_poly>
8888
!poly_ty = !polynomial.polynomial<ring=#ring>
8989

9090
// CHECK-NOT: @test_invalid_intt
9191
// CHECK-NOT: polynomial.intt
9292
func.func @test_invalid_intt(%0 : tensor<1024xi32, #ring1>) {
9393
// expected-error@below {{not equivalent to the polynomial ring}}
94-
%1 = polynomial.intt %0 : tensor<1024xi32, #ring1> -> !poly_ty
94+
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1024xi32, #ring1> -> !poly_ty
9595
return
9696
}
9797

9898
// -----
9999

100100
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
101-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=31:i16>
101+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
102102
!poly_ty = !polynomial.polynomial<ring=#ring>
103103

104104
// CHECK-NOT: @test_invalid_intt
105105
// CHECK-NOT: polynomial.intt
106106
func.func @test_invalid_intt(%0 : tensor<1025xi32, #ring>) {
107107
// expected-error@below {{does not match output type}}
108108
// expected-note@below {{exactly the degree of the polynomialModulus of the polynomial type's ring attribute}}
109-
%1 = polynomial.intt %0 : tensor<1025xi32, #ring> -> !poly_ty
110-
return
111-
}
112-
113-
// -----
114-
115-
#my_poly = #polynomial.int_polynomial<-1 + x**1024>
116-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
117-
!poly_ty = !polynomial.polynomial<ring=#ring>
118-
119-
// CHECK-NOT: @test_invalid_ntt
120-
// CHECK-NOT: polynomial.ntt
121-
func.func @test_invalid_ntt(%0 : !poly_ty) {
122-
// expected-error@below {{does not provide a primitive root of unity, which is required to express an NTT}}
123-
%1 = polynomial.ntt %0 : !poly_ty -> tensor<1024xi32, #ring>
109+
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=31:i32, degree=8:index>} : tensor<1025xi32, #ring> -> !poly_ty
124110
return
125111
}
126112

127113
// -----
128114

129115
#my_poly = #polynomial.int_polynomial<-1 + x**8>
130116
// A valid root is 31
131-
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly, primitiveRoot=32:i16>
117+
#ring = #polynomial.ring<coefficientType=i16, coefficientModulus=256:i16, polynomialModulus=#my_poly>
132118
!poly_ty = !polynomial.polynomial<ring=#ring>
133119

134120
// CHECK-NOT: @test_invalid_intt
135121
// CHECK-NOT: polynomial.intt
136122
func.func @test_invalid_intt(%0 : tensor<8xi32, #ring>) {
137-
// expected-error@below {{has a primitiveRoot attribute '32 : i16' that is not a primitive root of the coefficient ring}}
138-
%1 = polynomial.intt %0 : tensor<8xi32, #ring> -> !poly_ty
123+
// expected-error@below {{provided root 32 is not a primitive root of unity mod 256, with the specified degree 8}}
124+
%1 = polynomial.intt %0 {root=#polynomial.primitive_root<value=32:i16, degree=8:index>} : tensor<8xi32, #ring> -> !poly_ty
139125
return
140126
}

0 commit comments

Comments
 (0)