Skip to content

Commit 67c1615

Browse files
committed
[MLIR] Add vector support for fpexp and fptrunc.
Differential Revision: https://reviews.llvm.org/D75150
1 parent 0d65000 commit 67c1615

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2601,6 +2601,10 @@ bool FPExtOp::areCastCompatible(Type a, Type b) {
26012601
if (auto fa = a.dyn_cast<FloatType>())
26022602
if (auto fb = b.dyn_cast<FloatType>())
26032603
return fa.getWidth() < fb.getWidth();
2604+
if (auto va = a.dyn_cast<VectorType>())
2605+
if (auto vb = b.dyn_cast<VectorType>())
2606+
return va.getShape().equals(vb.getShape()) &&
2607+
areCastCompatible(va.getElementType(), vb.getElementType());
26042608
return false;
26052609
}
26062610

@@ -2612,6 +2616,10 @@ bool FPTruncOp::areCastCompatible(Type a, Type b) {
26122616
if (auto fa = a.dyn_cast<FloatType>())
26132617
if (auto fb = b.dyn_cast<FloatType>())
26142618
return fa.getWidth() > fb.getWidth();
2619+
if (auto va = a.dyn_cast<VectorType>())
2620+
if (auto vb = b.dyn_cast<VectorType>())
2621+
return va.getShape().equals(vb.getShape()) &&
2622+
areCastCompatible(va.getElementType(), vb.getElementType());
26152623
return false;
26162624
}
26172625

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,18 @@ func @fpext(%arg0 : f16, %arg1 : f32) {
485485
return
486486
}
487487

488+
// Checking conversion of integer types to floating point.
489+
// CHECK-LABEL: @fpext
490+
func @fpext_vector(%arg0 : vector<2xf16>, %arg1 : vector<2xf32>) {
491+
// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x float>">
492+
%0 = fpext %arg0: vector<2xf16> to vector<2xf32>
493+
// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x half>"> to !llvm<"<2 x double>">
494+
%1 = fpext %arg0: vector<2xf16> to vector<2xf64>
495+
// CHECK-NEXT: = llvm.fpext {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x double>">
496+
%2 = fpext %arg1: vector<2xf32> to vector<2xf64>
497+
return
498+
}
499+
488500
// Checking conversion of integer types to floating point.
489501
// CHECK-LABEL: @fptrunc
490502
func @fptrunc(%arg0 : f32, %arg1 : f64) {
@@ -497,6 +509,18 @@ func @fptrunc(%arg0 : f32, %arg1 : f64) {
497509
return
498510
}
499511

512+
// Checking conversion of integer types to floating point.
513+
// CHECK-LABEL: @fptrunc
514+
func @fptrunc_vector(%arg0 : vector<2xf32>, %arg1 : vector<2xf64>) {
515+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x float>"> to !llvm<"<2 x half>">
516+
%0 = fptrunc %arg0: vector<2xf32> to vector<2xf16>
517+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x half>">
518+
%1 = fptrunc %arg1: vector<2xf64> to vector<2xf16>
519+
// CHECK-NEXT: = llvm.fptrunc {{.*}} : !llvm<"<2 x double>"> to !llvm<"<2 x float>">
520+
%2 = fptrunc %arg1: vector<2xf64> to vector<2xf32>
521+
return
522+
}
523+
500524
// Check sign and zero extension and truncation of integers.
501525
// CHECK-LABEL: @integer_extension_and_truncation
502526
func @integer_extension_and_truncation() {

mlir/test/IR/core-ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,12 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
506506
// CHECK: %{{[0-9]+}} = sqrt %arg0 : tensor<4x4x?xf32>
507507
%142 = sqrt %t : tensor<4x4x?xf32>
508508

509+
// CHECK: = fpext {{.*}} : vector<4xf32> to vector<4xf64>
510+
%143 = fpext %vcf32 : vector<4xf32> to vector<4xf64>
511+
512+
// CHECK: = fptrunc {{.*}} : vector<4xf32> to vector<4xf16>
513+
%144 = fptrunc %vcf32 : vector<4xf32> to vector<4xf16>
514+
509515
return
510516
}
511517

mlir/test/IR/invalid-ops.mlir

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,46 @@ func @fpext_f32_to_i32(%arg0 : f32) {
563563

564564
// -----
565565

566+
func @fpext_vec(%arg0 : vector<2xf16>) {
567+
// expected-error@+1 {{are cast incompatible}}
568+
%0 = fpext %arg0 : vector<2xf16> to vector<3xf32>
569+
return
570+
}
571+
572+
// -----
573+
574+
func @fpext_vec_f32_to_f16(%arg0 : vector<2xf32>) {
575+
// expected-error@+1 {{are cast incompatible}}
576+
%0 = fpext %arg0 : vector<2xf32> to vector<2xf16>
577+
return
578+
}
579+
580+
// -----
581+
582+
func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) {
583+
// expected-error@+1 {{are cast incompatible}}
584+
%0 = fpext %arg0 : vector<2xf16> to vector<2xf16>
585+
return
586+
}
587+
588+
// -----
589+
590+
func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
591+
// expected-error@+1 {{are cast incompatible}}
592+
%0 = fpext %arg0 : vector<2xi32> to vector<2xf32>
593+
return
594+
}
595+
596+
// -----
597+
598+
func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
599+
// expected-error@+1 {{are cast incompatible}}
600+
%0 = fpext %arg0 : vector<2xf32> to vector<2xi32>
601+
return
602+
}
603+
604+
// -----
605+
566606
func @fptrunc_f16_to_f32(%arg0 : f16) {
567607
// expected-error@+1 {{are cast incompatible}}
568608
%0 = fptrunc %arg0 : f16 to f32
@@ -595,6 +635,46 @@ func @fptrunc_f32_to_i32(%arg0 : f32) {
595635

596636
// -----
597637

638+
func @fptrunc_vec(%arg0 : vector<2xf16>) {
639+
// expected-error@+1 {{are cast incompatible}}
640+
%0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32>
641+
return
642+
}
643+
644+
// -----
645+
646+
func @fptrunc_vec_f16_to_f32(%arg0 : vector<2xf16>) {
647+
// expected-error@+1 {{are cast incompatible}}
648+
%0 = fptrunc %arg0 : vector<2xf16> to vector<2xf32>
649+
return
650+
}
651+
652+
// -----
653+
654+
func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) {
655+
// expected-error@+1 {{are cast incompatible}}
656+
%0 = fptrunc %arg0 : vector<2xf32> to vector<2xf32>
657+
return
658+
}
659+
660+
// -----
661+
662+
func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
663+
// expected-error@+1 {{are cast incompatible}}
664+
%0 = fptrunc %arg0 : vector<2xi32> to vector<2xf32>
665+
return
666+
}
667+
668+
// -----
669+
670+
func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
671+
// expected-error@+1 {{are cast incompatible}}
672+
%0 = fptrunc %arg0 : vector<2xf32> to vector<2xi32>
673+
return
674+
}
675+
676+
// -----
677+
598678
func @sexti_index_as_operand(%arg0 : index) {
599679
// expected-error@+1 {{'index' is not a valid operand type}}
600680
%0 = sexti %arg0 : index to i128

0 commit comments

Comments
 (0)