Skip to content

Commit 1b610e6

Browse files
authored
[MLIR][Math] Add floating point value folders (#127947)
1 parent 1b78ff6 commit 1b610e6

File tree

3 files changed

+140
-0
lines changed

3 files changed

+140
-0
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,7 @@ def Math_IsFiniteOp : Math_FloatClassificationOp<"isfinite"> {
736736
%f = math.isfinite %a : f32
737737
```
738738
}];
739+
let hasFolder = 1;
739740
}
740741

741742
//===----------------------------------------------------------------------===//
@@ -754,6 +755,7 @@ def Math_IsInfOp : Math_FloatClassificationOp<"isinf"> {
754755
%f = math.isinf %a : f32
755756
```
756757
}];
758+
let hasFolder = 1;
757759
}
758760

759761
//===----------------------------------------------------------------------===//
@@ -772,6 +774,7 @@ def Math_IsNaNOp : Math_FloatClassificationOp<"isnan"> {
772774
%f = math.isnan %a : f32
773775
```
774776
}];
777+
let hasFolder = 1;
775778
}
776779

777780

@@ -791,6 +794,7 @@ def Math_IsNormalOp : Math_FloatClassificationOp<"isnormal"> {
791794
%f = math.isnormal %a : f32
792795
```
793796
}];
797+
let hasFolder = 1;
794798
}
795799

796800
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -579,6 +579,70 @@ OpFoldResult math::ExpM1Op::fold(FoldAdaptor adaptor) {
579579
});
580580
}
581581

582+
//===----------------------------------------------------------------------===//
583+
// IsFiniteOp folder
584+
//===----------------------------------------------------------------------===//
585+
586+
OpFoldResult math::IsFiniteOp::fold(FoldAdaptor adaptor) {
587+
if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
588+
return BoolAttr::get(val.getContext(), val.getValue().isFinite());
589+
}
590+
if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
591+
return DenseElementsAttr::get(
592+
cast<ShapedType>(getType()),
593+
APInt(1, splat.getSplatValue<APFloat>().isFinite()));
594+
}
595+
return {};
596+
}
597+
598+
//===----------------------------------------------------------------------===//
599+
// IsInfOp folder
600+
//===----------------------------------------------------------------------===//
601+
602+
OpFoldResult math::IsInfOp::fold(FoldAdaptor adaptor) {
603+
if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
604+
return BoolAttr::get(val.getContext(), val.getValue().isInfinity());
605+
}
606+
if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
607+
return DenseElementsAttr::get(
608+
cast<ShapedType>(getType()),
609+
APInt(1, splat.getSplatValue<APFloat>().isInfinity()));
610+
}
611+
return {};
612+
}
613+
614+
//===----------------------------------------------------------------------===//
615+
// IsNaNOp folder
616+
//===----------------------------------------------------------------------===//
617+
618+
OpFoldResult math::IsNaNOp::fold(FoldAdaptor adaptor) {
619+
if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
620+
return BoolAttr::get(val.getContext(), val.getValue().isNaN());
621+
}
622+
if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
623+
return DenseElementsAttr::get(
624+
cast<ShapedType>(getType()),
625+
APInt(1, splat.getSplatValue<APFloat>().isNaN()));
626+
}
627+
return {};
628+
}
629+
630+
//===----------------------------------------------------------------------===//
631+
// IsNormalOp folder
632+
//===----------------------------------------------------------------------===//
633+
634+
OpFoldResult math::IsNormalOp::fold(FoldAdaptor adaptor) {
635+
if (auto val = dyn_cast_or_null<FloatAttr>(adaptor.getOperand())) {
636+
return BoolAttr::get(val.getContext(), val.getValue().isNormal());
637+
}
638+
if (auto splat = dyn_cast_or_null<SplatElementsAttr>(adaptor.getOperand())) {
639+
return DenseElementsAttr::get(
640+
cast<ShapedType>(getType()),
641+
APInt(1, splat.getSplatValue<APFloat>().isNormal()));
642+
}
643+
return {};
644+
}
645+
582646
//===----------------------------------------------------------------------===//
583647
// TanOp folder
584648
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Math/canonicalize.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,3 +492,75 @@ func.func @abs_poison() -> f32 {
492492
%1 = math.absf %0 : f32
493493
return %1 : f32
494494
}
495+
496+
// CHECK-LABEL: @isfinite_fold
497+
// CHECK: %[[cst:.+]] = arith.constant true
498+
// CHECK: return %[[cst]]
499+
func.func @isfinite_fold() -> i1 {
500+
%c = arith.constant 2.0 : f32
501+
%r = math.isfinite %c : f32
502+
return %r : i1
503+
}
504+
505+
// CHECK-LABEL: @isfinite_fold_vec
506+
// CHECK: %[[cst:.+]] = arith.constant dense<true> : vector<4xi1>
507+
// CHECK: return %[[cst]]
508+
func.func @isfinite_fold_vec() -> (vector<4xi1>) {
509+
%v1 = arith.constant dense<2.0> : vector<4xf32>
510+
%0 = math.isfinite %v1 : vector<4xf32>
511+
return %0 : vector<4xi1>
512+
}
513+
514+
// CHECK-LABEL: @isinf_fold
515+
// CHECK: %[[cst:.+]] = arith.constant false
516+
// CHECK: return %[[cst]]
517+
func.func @isinf_fold() -> i1 {
518+
%c = arith.constant 2.0 : f32
519+
%r = math.isinf %c : f32
520+
return %r : i1
521+
}
522+
523+
// CHECK-LABEL: @isinf_fold_vec
524+
// CHECK: %[[cst:.+]] = arith.constant dense<false> : vector<4xi1>
525+
// CHECK: return %[[cst]]
526+
func.func @isinf_fold_vec() -> (vector<4xi1>) {
527+
%v1 = arith.constant dense<2.0> : vector<4xf32>
528+
%0 = math.isinf %v1 : vector<4xf32>
529+
return %0 : vector<4xi1>
530+
}
531+
532+
// CHECK-LABEL: @isnan_fold
533+
// CHECK: %[[cst:.+]] = arith.constant false
534+
// CHECK: return %[[cst]]
535+
func.func @isnan_fold() -> i1 {
536+
%c = arith.constant 2.0 : f32
537+
%r = math.isnan %c : f32
538+
return %r : i1
539+
}
540+
541+
// CHECK-LABEL: @isnan_fold_vec
542+
// CHECK: %[[cst:.+]] = arith.constant dense<false> : vector<4xi1>
543+
// CHECK: return %[[cst]]
544+
func.func @isnan_fold_vec() -> (vector<4xi1>) {
545+
%v1 = arith.constant dense<2.0> : vector<4xf32>
546+
%0 = math.isnan %v1 : vector<4xf32>
547+
return %0 : vector<4xi1>
548+
}
549+
550+
// CHECK-LABEL: @isnormal_fold
551+
// CHECK: %[[cst:.+]] = arith.constant true
552+
// CHECK: return %[[cst]]
553+
func.func @isnormal_fold() -> i1 {
554+
%c = arith.constant 2.0 : f32
555+
%r = math.isnormal %c : f32
556+
return %r : i1
557+
}
558+
559+
// CHECK-LABEL: @isnormal_fold_vec
560+
// CHECK: %[[cst:.+]] = arith.constant dense<true> : vector<4xi1>
561+
// CHECK: return %[[cst]]
562+
func.func @isnormal_fold_vec() -> (vector<4xi1>) {
563+
%v1 = arith.constant dense<2.0> : vector<4xf32>
564+
%0 = math.isnormal %v1 : vector<4xf32>
565+
return %0 : vector<4xi1>
566+
}

0 commit comments

Comments
 (0)