[mlir][LLVM] Add nsw and nuw flags to trunc #115509

Merged 2 commits on Nov 11, 2024

Conversation

lfrenot
Contributor

@lfrenot lfrenot commented Nov 8, 2024

This implementation is based on the one already existing for the binary operations.

If the nuw keyword is present, and any of the truncated bits are non-zero, the result is a poison value. If the nsw keyword is present, and any of the truncated bits are not the same as the top bit of the truncation result, the result is a poison value.
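
The poison conditions above can be modeled in a few lines of C++. This is only an illustrative sketch, not code from the PR or from LLVM: `truncModel` is a hypothetical helper, the source width is fixed at 64 bits, and `std::nullopt` stands in for a poison value.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical model of `trunc` from i64 to iN with 1 <= dstBits < 64.
// std::nullopt is used here as a stand-in for poison (an assumption of
// this sketch, not how LLVM represents poison).
std::optional<uint64_t> truncModel(uint64_t value, unsigned dstBits,
                                   bool nuw, bool nsw) {
  const uint64_t lowMask = (uint64_t{1} << dstBits) - 1; // bits that survive
  const uint64_t result = value & lowMask;
  const uint64_t dropped = value >> dstBits;             // truncated bits
  const uint64_t droppedAllOnes = ~uint64_t{0} >> dstBits;

  // nuw: every truncated bit must be zero.
  if (nuw && dropped != 0)
    return std::nullopt;

  // nsw: every truncated bit must equal the top bit of the result.
  const bool topBit = (result >> (dstBits - 1)) & 1;
  if (nsw && dropped != (topBit ? droppedAllOnes : 0))
    return std::nullopt;

  return result;
}
```

Under this model, `truncModel(256, 8, /*nuw=*/true, /*nsw=*/false)` is poison because bit 8 is set, and `truncModel(255, 8, /*nuw=*/false, /*nsw=*/true)` is also poison because the truncated bits are all zero while the top bit of the result is one.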

@zero9178, @gysit and @Dinistro, could you take a look?

@llvmbot
Member

llvmbot commented Nov 8, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: None (lfrenot)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/115509.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+18-1)
  • (modified) mlir/test/Dialect/LLVMIR/roundtrip.mlir (+17)
  • (modified) mlir/test/Target/LLVMIR/Import/nsw_nuw.ll (+2)
  • (modified) mlir/test/Target/LLVMIR/nsw_nuw.mlir (+2)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index 315af2594047a5..ef81a068a36055 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -508,6 +508,23 @@ class LLVM_CastOp<string mnemonic, string instName, Type type,
       $_location, $_resultType, $arg);
   }];
 }
+class LLVM_CastOpWithOverflowFlag<string mnemonic, string instName, Type type,
+                  Type resultType, list<Trait> traits = []> :
+    LLVM_Op<mnemonic, !listconcat([Pure], [DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)>,
+    LLVM_Builder<"$res = builder.Create" # instName # "($arg, $_resultType, /*Name=*/\"\", op.hasNoUnsignedWrap(), op.hasNoSignedWrap());"> {
+  let arguments = (ins type:$arg, EnumProperty<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
+  let results = (outs resultType:$res);
+  let builders = [LLVM_OneResultOpBuilder];
+  let assemblyFormat = "$arg attr-dict `` custom<OverflowFlags>($overflowFlags) `:` type($arg) `to` type($res)";
+  string llvmInstName = instName;
+  string mlirBuilder = [{
+    auto op = $_builder.create<$_qualCppClassName>(
+      $_location, $_resultType, $arg);
+    moduleImport.setIntegerOverflowFlags(inst, op);
+    $res = op;
+  }];
+}
+
 def LLVM_BitcastOp : LLVM_CastOp<"bitcast", "BitCast", LLVM_AnyNonAggregate,
     LLVM_AnyNonAggregate, [DeclareOpInterfaceMethods<PromotableOpInterface>]> {
   let hasFolder = 1;
@@ -537,7 +554,7 @@ def LLVM_ZExtOp : LLVM_CastOp<"zext", "ZExt",
   let hasFolder = 1;
   let hasVerifier = 1;
 }
-def LLVM_TruncOp : LLVM_CastOp<"trunc", "Trunc",
+def LLVM_TruncOp : LLVM_CastOpWithOverflowFlag<"trunc", "Trunc",
                                LLVM_ScalarOrVectorOf<AnySignlessInteger>,
                                LLVM_ScalarOrVectorOf<AnySignlessInteger>>;
 def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "SIToFP",
diff --git a/mlir/test/Dialect/LLVMIR/roundtrip.mlir b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
index 682780c5f0a7df..73776df3484273 100644
--- a/mlir/test/Dialect/LLVMIR/roundtrip.mlir
+++ b/mlir/test/Dialect/LLVMIR/roundtrip.mlir
@@ -325,6 +325,23 @@ func.func @casts(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
   llvm.return
 }
 
+// CHECK-LABEL: @casts_overflow
+// CHECK-SAME: (%[[I32:.*]]: i32, %[[I64:.*]]: i64, %[[V4I32:.*]]: vector<4xi32>, %[[V4I64:.*]]: vector<4xi64>, %[[PTR:.*]]: !llvm.ptr)
+func.func @casts_overflow(%arg0: i32, %arg1: i64, %arg2: vector<4xi32>,
+            %arg3: vector<4xi64>, %arg4: !llvm.ptr) {
+// CHECK:  = llvm.trunc %[[I64]] overflow<nsw> : i64 to i56
+  %0 = llvm.trunc %arg1 overflow<nsw> : i64 to i56
+// CHECK:  = llvm.trunc %[[I64]] overflow<nuw> : i64 to i56
+  %1 = llvm.trunc %arg1 overflow<nuw> : i64 to i56
+// CHECK:  = llvm.trunc %[[I64]] overflow<nsw, nuw> : i64 to i56
+  %2 = llvm.trunc %arg1 overflow<nsw, nuw> : i64 to i56
+// CHECK:  = llvm.trunc %[[I64]] overflow<nsw, nuw> : i64 to i56
+  %3 = llvm.trunc %arg1 overflow<nuw, nsw> : i64 to i56
+// CHECK:  = llvm.trunc %[[V4I64]] overflow<nsw> : vector<4xi64> to vector<4xi56>
+  %4 = llvm.trunc %arg3 overflow<nsw> : vector<4xi64> to vector<4xi56>
+  llvm.return
+}
+
 // CHECK-LABEL: @vect
 func.func @vect(%arg0: vector<4xf32>, %arg1: i32, %arg2: f32, %arg3: !llvm.vec<2 x ptr>) {
 // CHECK:  = llvm.extractelement {{.*}} : vector<4xf32>
diff --git a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
index d08098a5e5dfe0..4af799da36dc08 100644
--- a/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
+++ b/mlir/test/Target/LLVMIR/Import/nsw_nuw.ll
@@ -10,5 +10,7 @@ define void @intflag_inst(i64 %arg1, i64 %arg2) {
   %3 = mul nsw nuw i64 %arg1, %arg2
   ; CHECK: llvm.shl %{{.*}}, %{{.*}} overflow<nsw, nuw> : i64
   %4 = shl nuw nsw i64 %arg1, %arg2
+  ; CHECK: llvm.trunc %{{.*}} overflow<nsw> : i64 to i32
+  %5 = trunc nsw i64 %arg1 to i32
   ret void
 }
diff --git a/mlir/test/Target/LLVMIR/nsw_nuw.mlir b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
index 6843c2ef0299c7..584aa05a04f7cf 100644
--- a/mlir/test/Target/LLVMIR/nsw_nuw.mlir
+++ b/mlir/test/Target/LLVMIR/nsw_nuw.mlir
@@ -10,5 +10,7 @@ llvm.func @intflags_func(%arg0: i64, %arg1: i64) {
   %2 = llvm.mul %arg0, %arg1 overflow <nsw, nuw> : i64
   // CHECK: %{{.*}} = shl nuw nsw i64 %{{.*}}, %{{.*}}
   %3 = llvm.shl %arg0, %arg1 overflow <nsw, nuw> : i64
+  // CHECK: %{{.*}} = trunc nuw i64 %{{.*}} to i32
+  %4 = llvm.trunc %arg1 overflow<nuw> : i64 to i32
   llvm.return
 }

Member

@zero9178 zero9178 left a comment

LGTM!

@zero9178 zero9178 requested review from gysit and Dinistro November 8, 2024 17:05
@tobiasgrosser
Contributor

Thank you. @lfrenot does not have merge rights, so please merge when you are happy with this.

Contributor

@gysit gysit left a comment

LGTM modulo nit.

custom<LLVMOpAttrs>(attr-dict) is apparently not necessary when looking at your tests.

let arguments = (ins type:$arg, EnumProperty<"IntegerOverflowFlags", "", "IntegerOverflowFlags::none">:$overflowFlags);
let results = (outs resultType:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "$arg attr-dict `` custom<OverflowFlags>($overflowFlags) `:` type($arg) `to` type($res)";
Contributor

Suggested change
- let assemblyFormat = "$arg attr-dict `` custom<OverflowFlags>($overflowFlags) `:` type($arg) `to` type($res)";
+ let assemblyFormat = "$arg custom<OverflowFlags>($overflowFlags) `` attr-dict `:` type($arg) `to` type($res)";

nit: Can you move the attribute dictionary after the overflow flags? I think this is how it is done for the other operations that take overflow flags.

I am somewhat surprised you don't need to elide overflow flags from the attribute dictionary when printing using custom<LLVMOpAttrs>(attr-dict). Maybe tablegen does elide the flag since it is passed to a custom print function.

Contributor

Since the integer overflow flags are properties they should not be printed as part of attr-dict and custom<LLVMOpAttrs>(attr-dict) is not needed. In fact, it seems like the print/parseLLVMOpAttrs only makes sense for the fast math flags these days.

Contributor Author

I'm not sure what you mean by "custom<LLVMOpAttrs>(attr-dict) is not needed". Does anything need to be changed besides the nit above?

Contributor

Changing the nit should be sufficient (especially since the tests pass). The purpose of custom<LLVMOpAttrs>(attr-dict) is to filter out specific attributes from the attribute dictionary. This is not needed anymore now that the flag is a property.
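
The difference between a flag stored in the attribute dictionary and a flag stored as a property can be illustrated with a toy model. This is a sketch under loud assumptions: it is not MLIR's API, `AttrDict` and `printAttrDict` are hypothetical names, and the `elided` set only mimics the filtering role described above.

```cpp
#include <map>
#include <set>
#include <string>

// Hypothetical stand-in for an op's discardable attribute dictionary.
using AttrDict = std::map<std::string, std::string>;

// Hypothetical printer: emits the dictionary, skipping any names listed in
// `elided`. This mimics a filtering attr-dict printer; it is not MLIR code.
std::string printAttrDict(const AttrDict &attrs,
                          const std::set<std::string> &elided = {}) {
  std::string out;
  for (const auto &[name, value] : attrs) {
    if (elided.count(name))
      continue; // filtered out, e.g. because it is printed elsewhere
    out += out.empty() ? "{" : ", ";
    out += name + " = " + value;
  }
  return out.empty() ? "" : out + "}";
}
```

In this model, a flag kept in the dictionary (`{"overflowFlags", "nsw"}`) has to be passed in `elided` to avoid being printed twice, whereas a flag kept outside the dictionary, as a property would be, never reaches `printAttrDict` at all, so no filtering step is needed.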

Contributor

I suppose the double backtick in the assembly format is not necessary, since there seem to be missing spaces in the printed assembly (see the failing tests in the CI). I would try the following:

custom<OverflowFlags>($overflowFlags) attr-dict

Contributor Author

That worked, thank you

@Dinistro
Contributor

Side note: @lfrenot, I suggest requesting review from us directly instead of pinging; this ensures that stuff shows up in the correct places in GitHub 🙂

@lfrenot
Contributor Author

lfrenot commented Nov 11, 2024

> Side note: @lfrenot, I suggest requesting review from us directly instead of pinging; this ensures that stuff shows up in the correct places in GitHub 🙂

I would, but I do not have the perms to request reviews

@Dinistro
Contributor

> > Side note: @lfrenot, I suggest requesting review from us directly instead of pinging; this ensures that stuff shows up in the correct places in GitHub 🙂
>
> I would, but I do not have the perms to request reviews

That seems weird, considering that you can ping people. I'll ask in the discord to figure out if this is intentional.

@lfrenot lfrenot force-pushed the mlir-llvm-trunc-overflow branch 2 times, most recently from d7f5df7 to b8834c9 Compare November 11, 2024 13:14
@tobiasgrosser
Contributor

@lfrenot may need commit access to get these rights. Let's give it a bit before we ask Chris for commit access.

@lfrenot lfrenot force-pushed the mlir-llvm-trunc-overflow branch from b8834c9 to d432191 Compare November 11, 2024 13:23
Co-authored-by: Tobias Gysi <[email protected]>
@lfrenot lfrenot force-pushed the mlir-llvm-trunc-overflow branch from d432191 to 5efc314 Compare November 11, 2024 14:59
@gysit gysit merged commit bf601ba into llvm:main Nov 11, 2024
8 checks passed
@lfrenot lfrenot deleted the mlir-llvm-trunc-overflow branch November 11, 2024 16:18
@Dinistro
Contributor

> @lfrenot may need commit access to get these rights. Let's give it a bit before we ask Chris for commit access.

This indeed seems to be a GitHub restriction. Not sure why this is the case, but that is not relevant for this PR.

@Dinistro
Contributor

I'm wondering if we should now also add these flags to arith.trunci?

Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024