Skip to content

Commit c27a958

Browse files
[MLIR][LLVMIR] Extend llrint,lrint,lround for vectors of float (#136225)
Matching langref. Note that `llround` is different than the rest.
1 parent 89f930a commit c27a958

File tree

3 files changed

+66
-21
lines changed

3 files changed

+66
-21
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,18 @@ def LLVM_PowOp : LLVM_BinarySameArgsIntrOpF<"pow">;
153153
def LLVM_PowIOp : LLVM_PowFI<"powi">;
154154
def LLVM_RintOp : LLVM_UnaryIntrOpF<"rint">;
155155
def LLVM_NearbyintOp : LLVM_UnaryIntrOpF<"nearbyint">;
156-
class LLVM_IntRoundIntrOpBase<string func> :
156+
class LLVM_IntRoundIntrOpBase<string func, Type element = LLVM_AnyFloat> :
157157
LLVM_OneResultIntrOp<func, [0], [0], [Pure]> {
158-
let arguments = (ins LLVM_AnyFloat:$val);
158+
let arguments = (ins element:$val);
159159
let assemblyFormat = "`(` operands `)` attr-dict `:` "
160160
"functional-type(operands, results)";
161161
}
162-
def LLVM_LroundOp : LLVM_IntRoundIntrOpBase<"lround">;
162+
class LLVM_ScalarOrVectorIntRoundIntrOpBase<string func> :
163+
LLVM_IntRoundIntrOpBase<func, LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
164+
def LLVM_LroundOp : LLVM_ScalarOrVectorIntRoundIntrOpBase<"lround">;
163165
def LLVM_LlroundOp : LLVM_IntRoundIntrOpBase<"llround">;
164-
def LLVM_LrintOp : LLVM_IntRoundIntrOpBase<"lrint">;
165-
def LLVM_LlrintOp : LLVM_IntRoundIntrOpBase<"llrint">;
166+
def LLVM_LrintOp : LLVM_ScalarOrVectorIntRoundIntrOpBase<"lrint">;
167+
def LLVM_LlrintOp : LLVM_ScalarOrVectorIntRoundIntrOpBase<"llrint">;
166168
def LLVM_BitReverseOp : LLVM_UnaryIntrOpI<"bitreverse">;
167169
def LLVM_ByteSwapOp : LLVM_UnaryIntrOpI<"bswap">;
168170
def LLVM_CountLeadingZerosOp : LLVM_CountZerosIntrOp<"ctlz">;

mlir/test/Target/LLVMIR/Import/intrinsic.ll

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -235,15 +235,23 @@ define void @nearbyint_test(float %0, double %1, <8 x float> %2, <8 x double> %3
235235
ret void
236236
}
237237
; CHECK-LABEL: llvm.func @lround_test
238-
define void @lround_test(float %0, double %1) {
238+
define void @lround_test(float %0, double %1, <2 x float> %2, <2 x double> %3) {
239239
; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i32
240-
%3 = call i32 @llvm.lround.i32.f32(float %0)
240+
%5 = call i32 @llvm.lround.i32.f32(float %0)
241241
; CHECK: llvm.intr.lround(%{{.*}}) : (f32) -> i64
242-
%4 = call i64 @llvm.lround.i64.f32(float %0)
242+
%6 = call i64 @llvm.lround.i64.f32(float %0)
243243
; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i32
244-
%5 = call i32 @llvm.lround.i32.f64(double %1)
244+
%7 = call i32 @llvm.lround.i32.f64(double %1)
245245
; CHECK: llvm.intr.lround(%{{.*}}) : (f64) -> i64
246-
%6 = call i64 @llvm.lround.i64.f64(double %1)
246+
%8 = call i64 @llvm.lround.i64.f64(double %1)
247+
; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf32>) -> vector<2xi32>
248+
%9 = call <2 x i32> @llvm.lround.v2i32.v2f32(<2 x float> %2)
249+
; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf64>) -> vector<2xi32>
250+
%10 = call <2 x i32> @llvm.lround.v2i32.v2f64(<2 x double> %3)
251+
; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf32>) -> vector<2xi64>
252+
%11 = call <2 x i64> @llvm.lround.v2i64.v2f32(<2 x float> %2)
253+
; CHECK: llvm.intr.lround(%{{.*}}) : (vector<2xf64>) -> vector<2xi64>
254+
%12 = call <2 x i64> @llvm.lround.v2i64.v2f64(<2 x double> %3)
247255
ret void
248256
}
249257
; CHECK-LABEL: llvm.func @llround_test
@@ -255,23 +263,35 @@ define void @llround_test(float %0, double %1) {
255263
ret void
256264
}
257265
; CHECK-LABEL: llvm.func @lrint_test
258-
define void @lrint_test(float %0, double %1) {
266+
define void @lrint_test(float %0, double %1, <2 x float> %2, <2 x double> %3) {
259267
; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i32
260-
%3 = call i32 @llvm.lrint.i32.f32(float %0)
268+
%5 = call i32 @llvm.lrint.i32.f32(float %0)
261269
; CHECK: llvm.intr.lrint(%{{.*}}) : (f32) -> i64
262-
%4 = call i64 @llvm.lrint.i64.f32(float %0)
270+
%6 = call i64 @llvm.lrint.i64.f32(float %0)
263271
; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i32
264-
%5 = call i32 @llvm.lrint.i32.f64(double %1)
272+
%7 = call i32 @llvm.lrint.i32.f64(double %1)
265273
; CHECK: llvm.intr.lrint(%{{.*}}) : (f64) -> i64
266-
%6 = call i64 @llvm.lrint.i64.f64(double %1)
274+
%8 = call i64 @llvm.lrint.i64.f64(double %1)
275+
; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf32>) -> vector<2xi32>
276+
%9 = call <2 x i32> @llvm.lrint.v2i32.v2f32(<2 x float> %2)
277+
; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf64>) -> vector<2xi32>
278+
%10 = call <2 x i32> @llvm.lrint.v2i32.v2f64(<2 x double> %3)
279+
; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf32>) -> vector<2xi64>
280+
%11 = call <2 x i64> @llvm.lrint.v2i64.v2f32(<2 x float> %2)
281+
; CHECK: llvm.intr.lrint(%{{.*}}) : (vector<2xf64>) -> vector<2xi64>
282+
%12 = call <2 x i64> @llvm.lrint.v2i64.v2f64(<2 x double> %3)
267283
ret void
268284
}
269285
; CHECK-LABEL: llvm.func @llrint_test
270-
define void @llrint_test(float %0, double %1) {
286+
define void @llrint_test(float %0, double %1, <2 x float> %2, <2 x double> %3) {
271287
; CHECK: llvm.intr.llrint(%{{.*}}) : (f32) -> i64
272-
%3 = call i64 @llvm.llrint.i64.f32(float %0)
288+
%5 = call i64 @llvm.llrint.i64.f32(float %0)
273289
; CHECK: llvm.intr.llrint(%{{.*}}) : (f64) -> i64
274-
%4 = call i64 @llvm.llrint.i64.f64(double %1)
290+
%6 = call i64 @llvm.llrint.i64.f64(double %1)
291+
; CHECK: llvm.intr.llrint(%{{.*}}) : (vector<2xf32>) -> vector<2xi64>
292+
%7 = call <2 x i64> @llvm.llrint.v2i64.v2f32(<2 x float> %2)
293+
; CHECK: llvm.intr.llrint(%{{.*}}) : (vector<2xf64>) -> vector<2xi64>
294+
%8 = call <2 x i64> @llvm.llrint.v2i64.v2f64(<2 x double> %3)
275295
ret void
276296
}
277297

mlir/test/Target/LLVMIR/llvmir-intrinsics.mlir

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,8 @@ llvm.func @nearbyint_test(%arg0 : f32, %arg1 : f64, %arg2 : vector<8xf32>, %arg3
242242
}
243243

244244
// CHECK-LABEL: @lround_test
245-
llvm.func @lround_test(%arg0 : f32, %arg1 : f64) {
245+
llvm.func @lround_test(%arg0 : f32, %arg1 : f64,
246+
%arg2 : vector<2xf32>, %arg3 : vector<2xf64>) {
246247
// CHECK: call i32 @llvm.lround.i32.f32
247248
"llvm.intr.lround"(%arg0) : (f32) -> i32
248249
// CHECK: call i64 @llvm.lround.i64.f32
@@ -251,6 +252,14 @@ llvm.func @lround_test(%arg0 : f32, %arg1 : f64) {
251252
"llvm.intr.lround"(%arg1) : (f64) -> i32
252253
// CHECK: call i64 @llvm.lround.i64.f64
253254
"llvm.intr.lround"(%arg1) : (f64) -> i64
255+
// CHECK: call <2 x i32> @llvm.lround.v2i32.v2f32
256+
"llvm.intr.lround"(%arg2) : (vector<2xf32>) -> vector<2xi32>
257+
// CHECK: call <2 x i32> @llvm.lround.v2i32.v2f64
258+
"llvm.intr.lround"(%arg3) : (vector<2xf64>) -> vector<2xi32>
259+
// CHECK: call <2 x i64> @llvm.lround.v2i64.v2f32
260+
"llvm.intr.lround"(%arg2) : (vector<2xf32>) -> vector<2xi64>
261+
// CHECK: call <2 x i64> @llvm.lround.v2i64.v2f64
262+
"llvm.intr.lround"(%arg3) : (vector<2xf64>) -> vector<2xi64>
254263
llvm.return
255264
}
256265

@@ -264,7 +273,8 @@ llvm.func @llround_test(%arg0 : f32, %arg1 : f64) {
264273
}
265274

266275
// CHECK-LABEL: @lrint_test
267-
llvm.func @lrint_test(%arg0 : f32, %arg1 : f64) {
276+
llvm.func @lrint_test(%arg0 : f32, %arg1 : f64,
277+
%arg2 : vector<2xf32>, %arg3 : vector<2xf64>) {
268278
// CHECK: call i32 @llvm.lrint.i32.f32
269279
"llvm.intr.lrint"(%arg0) : (f32) -> i32
270280
// CHECK: call i64 @llvm.lrint.i64.f32
@@ -273,15 +283,28 @@ llvm.func @lrint_test(%arg0 : f32, %arg1 : f64) {
273283
"llvm.intr.lrint"(%arg1) : (f64) -> i32
274284
// CHECK: call i64 @llvm.lrint.i64.f64
275285
"llvm.intr.lrint"(%arg1) : (f64) -> i64
286+
// CHECK: call <2 x i32> @llvm.lrint.v2i32.v2f32
287+
"llvm.intr.lrint"(%arg2) : (vector<2xf32>) -> vector<2xi32>
288+
// CHECK: call <2 x i32> @llvm.lrint.v2i32.v2f64
289+
"llvm.intr.lrint"(%arg3) : (vector<2xf64>) -> vector<2xi32>
290+
// CHECK: call <2 x i64> @llvm.lrint.v2i64.v2f32
291+
"llvm.intr.lrint"(%arg2) : (vector<2xf32>) -> vector<2xi64>
292+
// CHECK: call <2 x i64> @llvm.lrint.v2i64.v2f64
293+
"llvm.intr.lrint"(%arg3) : (vector<2xf64>) -> vector<2xi64>
276294
llvm.return
277295
}
278296

279297
// CHECK-LABEL: @llrint_test
280-
llvm.func @llrint_test(%arg0 : f32, %arg1 : f64) {
298+
llvm.func @llrint_test(%arg0 : f32, %arg1 : f64,
299+
%arg2 : vector<2xf32>, %arg3 : vector<2xf64>) {
281300
// CHECK: call i64 @llvm.llrint.i64.f32
282301
"llvm.intr.llrint"(%arg0) : (f32) -> i64
283302
// CHECK: call i64 @llvm.llrint.i64.f64
284303
"llvm.intr.llrint"(%arg1) : (f64) -> i64
304+
// CHECK: call <2 x i64> @llvm.llrint.v2i64.v2f32
305+
"llvm.intr.llrint"(%arg2) : (vector<2xf32>) -> vector<2xi64>
306+
// CHECK: call <2 x i64> @llvm.llrint.v2i64.v2f64
307+
"llvm.intr.llrint"(%arg3) : (vector<2xf64>) -> vector<2xi64>
285308
llvm.return
286309
}
287310

0 commit comments

Comments
 (0)