[mlir] ArithToLLVM: fix memref bitcast lowering #125148

Merged
merged 3 commits into from
Feb 12, 2025

Conversation

Hardcode84
Contributor

`arith.bitcast` is allowed on memrefs, and such code can actually be generated by the IREE `ConvertBf16ArithToF32Pass`. `LLVM::detail::vectorOneToOneRewrite` doesn't properly check its operand types and will generate a bitcast between structs, which is illegal.

With opaque pointers this is a no-op for memrefs, so we can just add a type check in `LLVM::detail::vectorOneToOneRewrite` and add a separate pattern that removes the op if the converted source and result types are the same.
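
To illustrate why the converted types coincide, here is a minimal standalone sketch in plain C++ with no MLIR dependency. `ToyMemRef`, `convertToyMemRef` and `isIdentityBitcast` are hypothetical names that only mimic the descriptor layout visible in the lit test of this PR; they are not part of the MLIR API, and the model assumes a rank of at least 1.

```cpp
#include <cassert>
#include <string>

// Toy stand-in for a ranked memref type; not an MLIR class.
struct ToyMemRef {
  std::string elementType; // e.g. "i16" or "bf16"
  int rank;                // assumed to be >= 1 in this sketch
};

// Hypothetical converter that mimics the descriptor layout from the test:
// (allocated ptr, aligned ptr, offset, sizes[rank], strides[rank]).
// With opaque pointers the element type does not appear in the result.
std::string convertToyMemRef(const ToyMemRef &type) {
  std::string dims = "array<" + std::to_string(type.rank) + " x i64>";
  return "!llvm.struct<(ptr, ptr, i64, " + dims + ", " + dims + ")>";
}

// The check performed by the identity pattern, restated on the toy types:
// the bitcast is a no-op when both sides convert to the same type.
bool isIdentityBitcast(const ToyMemRef &src, const ToyMemRef &dst) {
  return convertToyMemRef(src) == convertToyMemRef(dst);
}
```

In this model `memref<?xi16>` and `memref<?xbf16>` both map to the same struct, so the cast can simply forward its operand.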

@llvmbot
Member

llvmbot commented Jan 31, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Ivan Butygin (Hardcode84)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/125148.diff

3 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp (+20)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+6-1)
  • (modified) mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir (+12)
diff --git a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
index 754ed898142936..b726faa92a03a0 100644
--- a/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
+++ b/mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp
@@ -54,6 +54,23 @@ struct ConstrainedVectorConvertToLLVMPattern
   }
 };
 
+/// No-op bitcast.
+struct IdentityBitcastLowering final
+    : public OpConversionPattern<arith::BitcastOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const final {
+    Value src = adaptor.getIn();
+    if (src.getType() != getTypeConverter()->convertType(op.getType()))
+      return rewriter.notifyMatchFailure(op, "Types are different");
+
+    rewriter.replaceOp(op, src);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Straightforward Op Lowerings
 //===----------------------------------------------------------------------===//
@@ -524,6 +541,9 @@ void mlir::arith::registerConvertArithToLLVMInterface(
 
 void mlir::arith::populateArithToLLVMConversionPatterns(
     const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
+
+  patterns.add<IdentityBitcastLowering>(converter, patterns.getContext());
+
   // clang-format off
   patterns.add<
     AddFOpLowering,
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 626135c10a3e96..c9d3b57b0d596e 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
+static bool isVectorCompatibleType(Type type) {
+  return isa<LLVM::LLVMArrayType, VectorType, IntegerType, FloatType>(type) &&
+         LLVM::isCompatibleType(type);
+}
+
 LogicalResult LLVM::detail::vectorOneToOneRewrite(
     Operation *op, StringRef targetOp, ValueRange operands,
     ArrayRef<NamedAttribute> targetAttrs,
@@ -111,7 +116,7 @@ LogicalResult LLVM::detail::vectorOneToOneRewrite(
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
-  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
+  if (!llvm::all_of(operands.getTypes(), isVectorCompatibleType))
     return failure();
 
   auto llvmNDVectorTy = operands[0].getType();
diff --git a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
index 1dabacfd8a47cc..9a6c4bca88f3bf 100644
--- a/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
+++ b/mlir/test/Conversion/ArithToLLVM/arith-to-llvm.mlir
@@ -727,3 +727,15 @@ func.func @ops_supporting_overflow(%arg0: i64, %arg1: i64) {
   %3 = arith.shli %arg0, %arg1 overflow<nsw, nuw> : i64
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @memref_bitcast
+//  CHECK-SAME:   (%[[ARG:.*]]: memref<?xi16>)
+//       CHECK:   %[[V1:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<?xi16> to !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
+//       CHECK:   %[[V2:.*]] = builtin.unrealized_conversion_cast %[[V1]] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xbf16>
+//       CHECK:   return %[[V2]]
+func.func @memref_bitcast(%1: memref<?xi16>) -> memref<?xbf16> {
+  %2 = arith.bitcast %1 : memref<?xi16> to memref<?xbf16>
+  func.return %2 : memref<?xbf16>
+}

matchAndRewrite(arith::BitcastOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
Value src = adaptor.getIn();
if (src.getType() != getTypeConverter()->convertType(op.getType()))
Member

I find this a bit difficult to follow. Would it make sense to change this to `if (!isa<MemRefType>(src.getType()))`? And rename the pattern to `MemRefBitcastLowering`?

Contributor Author

Not sure if we need any memref-specific logic here, just handling same input/output converted types should be enough.

@@ -103,6 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success();
}

static bool isVectorCompatibleType(Type type) {
Member

Can `LLVM::isCompatibleType` just be used here? It already checks for `LLVMArrayType`, `VectorType`, etc. Alternatively, there is also `LLVM::isCompatibleVectorType`, which may be useful here.

Contributor Author

It was using `LLVM::isCompatibleType` before, but that is too broad; I specifically want to limit this transform to scalar and vector types.

Contributor

Instead of fiddling, can you just set PatternBenefit on the bitcast pattern?

Contributor Author

While PatternBenefit will probably work in this specific case, I still think there is a potential problem in vectorOneToOneRewrite as it can generate llvm bitcasts for unsupported types like structs.
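
The disagreement in this thread can be sketched with a toy model in standalone C++. `ToyKind` and both predicates are hypothetical stand-ins, not the MLIR functions: the broad check accepts any kind listed here, including structs such as a converted memref descriptor, while the narrowed check follows the `isVectorCompatibleType` helper from the diff and accepts only scalars, vectors and arrays.

```cpp
#include <cassert>

// Hypothetical coarse classification of already-converted LLVM types.
enum class ToyKind { Integer, Float, Vector, Array, Pointer, Struct };

// Stand-in for the broad check: in this toy model every kind listed above
// counts as LLVM-compatible, structs included.
bool isCompatibleToy(ToyKind) { return true; }

// Stand-in for the narrowed check from the diff: only scalars, vectors and
// arrays (the representation of multi-dimensional vectors) are accepted.
bool isVectorCompatibleToy(ToyKind kind) {
  switch (kind) {
  case ToyKind::Integer:
  case ToyKind::Float:
  case ToyKind::Vector:
  case ToyKind::Array:
    return true;
  case ToyKind::Pointer:
  case ToyKind::Struct:
    return false;
  }
  return false;
}
```

Under this model a struct passes the broad check but is rejected by the narrowed one, which is the behavior the author argues for here.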

@Hardcode84
Contributor Author

ping

@gysit
Contributor

gysit commented Feb 11, 2025

Looks generally OK to me, but this is outside my area of expertise (especially with regard to which types the vector lowering supports). Is there an owner who could be assigned for a review?

Contributor

@krzysz00 krzysz00 left a comment

I think there're ways to deal with this that don't require messing with VectorPattern.cpp

@@ -54,6 +54,25 @@ struct ConstrainedVectorConvertToLLVMPattern
}
};

/// No-op bitcast. Propagate type input arg if converted source and dest types
/// are the same.
struct IdentityBitcastLowering final
Contributor

I think, to get the behavior you want here where identity bitcast on memref gets folded away before the general pattern for arith.bitcast kicks in, you want to set a PatternBenefit

Contributor

@krzysz00 krzysz00 left a comment

I see your point about tightening up the bounds on a utility that shouldn't be used for stuff that doesn't work.

What I was worried about is that the one-to-one pattern might be used outside of the arith lowerings, where being able to deal with a struct could come up.

At the very least, we should allow pointers.

But, in the interests of not blocking things and of it not being an unreasonable tightening, approved.

@Hardcode84 Hardcode84 force-pushed the fix-memref-bitcast-llvm branch from 88d92bf to 01d05df Compare February 12, 2025 02:42
@Hardcode84
Contributor Author

> What I was worried about is that the one-to-one pattern might be used outside of the arith lowerings where being able to deal with a struct could come up
> At the very least, we should allow pointers.

Added pointer support, but it will most likely never be triggered, as the only upstream usages are the arith/math lowerings.

@Hardcode84 Hardcode84 merged commit 79010e2 into llvm:main Feb 12, 2025
8 checks passed
@Hardcode84 Hardcode84 deleted the fix-memref-bitcast-llvm branch February 12, 2025 11:19
@llvm-ci
Collaborator

llvm-ci commented Feb 12, 2025

LLVM Buildbot has detected a new failure on builder mlir-rocm-mi200 running on mi200-buildbot while building mlir at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/177/builds/12872

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/Dialect/Complex/CPU/correctness.mlir' FAILED ********************
Exit Code: 2

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/mlir-opt /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir    -one-shot-bufferize="bufferize-function-boundaries" --canonicalize    -convert-scf-to-cf --convert-complex-to-standard    -finalize-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm    -convert-vector-to-llvm -convert-complex-to-llvm    -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm    -reconcile-unrealized-casts | /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/mlir-runner   -e entry -entry-point-result=void    -shared-libs=/vol/worker/mi200-buildbot/mlir-rocm-mi200/build/lib/libmlir_c_runner_utils.so | /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/FileCheck /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
# executed command: /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/mlir-opt /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir -one-shot-bufferize=bufferize-function-boundaries --canonicalize -convert-scf-to-cf --convert-complex-to-standard -finalize-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm -convert-vector-to-llvm -convert-complex-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/mlir-runner -e entry -entry-point-result=void -shared-libs=/vol/worker/mi200-buildbot/mlir-rocm-mi200/build/lib/libmlir_c_runner_utils.so
# .---command stderr------------
# | loc("<stdin>":486:11): error: Dialect `arith' not found for custom op 'arith.select' 
# | could not parse the input IR
# `-----------------------------
# error: command failed with exit status: 1
# executed command: /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/FileCheck /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
# .---command stderr------------
# | FileCheck error: '<stdin>' is empty.
# | FileCheck command line:  /vol/worker/mi200-buildbot/mlir-rocm-mi200/build/bin/FileCheck /vol/worker/mi200-buildbot/mlir-rocm-mi200/llvm-project/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
# `-----------------------------
# error: command failed with exit status: 2

--

********************


@llvm-ci
Collaborator

llvm-ci commented Feb 12, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/10142

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR :: Integration/Dialect/Complex/CPU/correctness.mlir' FAILED ********************
Exit Code: 2

Command Output (stdout):
--
# RUN: at line 1
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir    -one-shot-bufferize="bufferize-function-boundaries" --canonicalize    -convert-scf-to-cf --convert-complex-to-standard    -finalize-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm    -convert-vector-to-llvm -convert-complex-to-llvm    -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm    -reconcile-unrealized-casts | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner   -e entry -entry-point-result=void    -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_c_runner_utils.so | /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-opt /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir -one-shot-bufferize=bufferize-function-boundaries --canonicalize -convert-scf-to-cf --convert-complex-to-standard -finalize-memref-to-llvm -convert-math-to-llvm -convert-math-to-libm -convert-vector-to-llvm -convert-complex-to-llvm -convert-func-to-llvm -convert-arith-to-llvm -convert-cf-to-llvm -reconcile-unrealized-casts
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/mlir-runner -e entry -entry-point-result=void -shared-libs=/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/lib/libmlir_c_runner_utils.so
# .---command stderr------------
# | loc("<stdin>":486:11): error: Dialect `arith' not found for custom op 'arith.select' 
# | could not parse the input IR
# `-----------------------------
# error: command failed with exit status: 1
# executed command: /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
# .---command stderr------------
# | FileCheck error: '<stdin>' is empty.
# | FileCheck command line:  /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/bin/FileCheck /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/test/Integration/Dialect/Complex/CPU/correctness.mlir
# `-----------------------------
# error: command failed with exit status: 2

--

********************


@llvm-ci
Collaborator

llvm-ci commented Feb 12, 2025

LLVM Buildbot has detected a new failure on builder mlir-nvidia running on mlir-nvidia while building mlir at step 7 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/10206

Here is the relevant piece of the build log for the reference
Step 7 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)

Hardcode84 added a commit that referenced this pull request Feb 12, 2025
Hardcode84 added a commit that referenced this pull request Feb 12, 2025
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Feb 12, 2025
@Hardcode84
Contributor Author

@krzysz00 You were actually right: it caused issues with the lowering of `select` on complex values, and we had no lit tests for that, only integration tests. I will go with the pattern benefit for now.

Hardcode84 added a commit to Hardcode84/llvm-project that referenced this pull request Feb 12, 2025
`arith.bitcast` is allowed on memrefs and such code can actually be
generated by IREE `ConvertBf16ArithToF32Pass`.
`LLVM::detail::vectorOneToOneRewrite` doesn't properly check its types
and will generate bitcast between structs which is illegal.

With the opaque pointers this is a no-op operation for memref so we can
just add a separate pattern which removes op if converted types are the same.
@Hardcode84
Contributor Author

Updated PR #126939

Hardcode84 added a commit that referenced this pull request Feb 12, 2025
…6939)

Reland #125148

Limiting vector pattern caused issues with `select` of complex lowering,
which wasn't caught as it was missing lit tests. Keep the pattern as is
for now and instead set a higher benefit to `IdentityBitcastLowering` so
it will always run before the vector pattern.
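
The relanded approach relies on pattern ordering. A minimal standalone sketch of that idea follows, in plain C++. `ToyPattern`, `applyFirstMatch`, the name `GenericBitcastLowering` and the benefit values 1 and 2 are hypothetical and not taken from the reland. Patterns are tried in order of decreasing benefit, so the identity pattern sees the op first and the generic one only runs when the converted types differ.

```cpp
#include <algorithm>
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Toy rewrite pattern: a name, a benefit, and a match predicate over the
// converted source and result type names.
struct ToyPattern {
  std::string name;
  unsigned benefit;
  std::function<bool(const std::string &, const std::string &)> matches;
};

// Try patterns in order of decreasing benefit and return the name of the
// first one that matches, or an empty string if none does.
std::string applyFirstMatch(std::vector<ToyPattern> patterns,
                            const std::string &srcType,
                            const std::string &dstType) {
  std::stable_sort(patterns.begin(), patterns.end(),
                   [](const ToyPattern &a, const ToyPattern &b) {
                     return a.benefit > b.benefit;
                   });
  for (const ToyPattern &pattern : patterns)
    if (pattern.matches(srcType, dstType))
      return pattern.name;
  return "";
}

// Hypothetical pattern set: the generic pattern accepts everything, while the
// identity pattern only matches equal converted types but has higher benefit.
std::vector<ToyPattern> makeToyBitcastPatterns() {
  return {
      {"GenericBitcastLowering", 1,
       [](const std::string &, const std::string &) { return true; }},
      {"IdentityBitcastLowering", 2,
       [](const std::string &src, const std::string &dst) {
         return src == dst;
       }},
  };
}
```

With this setup the generic pattern is left untouched, which matches the intent of keeping the vector pattern as is and only changing which pattern gets the first chance to match.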