Skip to content

Allow fixed vector operand for LLVM_AtomicRMWOp #110553

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,8 @@ def LLVM_ConstantOp
// Atomic operations.
//

def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger]>;
def LLVM_AtomicRMWType
: AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyPointer, AnySignlessInteger, LLVM_AnyFixedVector]>;

def LLVM_AtomicRMWOp : LLVM_MemAccessOpBase<"atomicrmw", [
TypesMatchWith<"result #0 and operand #1 have the same type",
Expand Down
10 changes: 9 additions & 1 deletion mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3010,8 +3010,16 @@ LogicalResult AtomicRMWOp::verify() {
auto valType = getVal().getType();
if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub ||
getBinOp() == AtomicBinOp::fmin || getBinOp() == AtomicBinOp::fmax) {
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
if (isCompatibleVectorType(valType)) {
if (isScalableVectorType(valType))
return emitOpError("expected LLVM IR fixed vector type");
Type elemType = getVectorElementType(valType);
if (!isCompatibleFloatingPointType(elemType))
return emitOpError(
"expected LLVM IR floating point type for vector element");
} else if (!isCompatibleFloatingPointType(valType)) {
return emitOpError("expected LLVM IR floating point type");
}
} else if (getBinOp() == AtomicBinOp::xchg) {
DataLayout dataLayout = DataLayout::closest(*this);
if (!isTypeCompatibleWithAtomicOp(valType, dataLayout))
Expand Down
15 changes: 15 additions & 0 deletions mlir/test/Dialect/LLVMIR/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,21 @@ func.func @atomicrmw_expected_float(%i32_ptr : !llvm.ptr, %i32 : i32) {

// -----

func.func @atomicrmw_scalable_vector(%ptr : !llvm.ptr, %f32_vec : vector<[2]xf32>) {
// expected-error@+1 {{'val' must be floating point LLVM type or LLVM pointer type or signless integer or LLVM dialect-compatible fixed-length vector type}}
%0 = llvm.atomicrmw fadd %ptr, %f32_vec unordered : !llvm.ptr, vector<[2]xf32>
llvm.return
}
// -----

func.func @atomicrmw_vector_expected_float(%ptr : !llvm.ptr, %i32_vec : vector<3xi32>) {
// expected-error@+1 {{expected LLVM IR floating point type for vector element}}
%0 = llvm.atomicrmw fadd %ptr, %i32_vec unordered : !llvm.ptr, vector<3xi32>
llvm.return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test with a scalable vector as well?

Copy link
Contributor Author

@joviliast joviliast Oct 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, please check https://github.com/llvm/llvm-project/pull/110553/files#diff-ff9a14cb96ea30dc57bad4dc2c44b34d54d57a25777288c26f305279e387f1a1R646
But it verifies tablegen introduced rule, not verifier.
Verifier checks scalability second time. In my opinion it is more consistant to have such check in verifier then don't have. WDYT?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I see there is also LLVM_AnyFixedVector. I think it is ok to keep the redundant check yes.

}

// -----

func.func @atomicrmw_unexpected_xchg_type(%i1_ptr : !llvm.ptr, %i1 : i1) {
// expected-error@+1 {{unexpected LLVM IR type for 'xchg' bin_op}}
%0 = llvm.atomicrmw xchg %i1_ptr, %i1 unordered : !llvm.ptr, i1
Expand Down
8 changes: 5 additions & 3 deletions mlir/test/Dialect/LLVMIR/roundtrip.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -420,11 +420,13 @@ func.func @atomic_store(%val : f32, %large_val : i256, %ptr : !llvm.ptr) {
}

// CHECK-LABEL: @atomicrmw
func.func @atomicrmw(%ptr : !llvm.ptr, %val : f32) {
func.func @atomicrmw(%ptr : !llvm.ptr, %f32 : f32, %f16_vec : vector<2xf16>) {
// CHECK: llvm.atomicrmw fadd %{{.*}}, %{{.*}} monotonic : !llvm.ptr, f32
%0 = llvm.atomicrmw fadd %ptr, %val monotonic : !llvm.ptr, f32
%0 = llvm.atomicrmw fadd %ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: llvm.atomicrmw volatile fsub %{{.*}}, %{{.*}} syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
%1 = llvm.atomicrmw volatile fsub %ptr, %val syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
%1 = llvm.atomicrmw volatile fsub %ptr, %f32 syncscope("singlethread") monotonic {alignment = 16 : i64} : !llvm.ptr, f32
// CHECK: llvm.atomicrmw fmin %{{.*}}, %{{.*}} monotonic : !llvm.ptr, vector<2xf16>
%2 = llvm.atomicrmw fmin %ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
llvm.return
}

Expand Down
13 changes: 11 additions & 2 deletions mlir/test/Target/LLVMIR/llvmir.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,8 @@ llvm.func @elements_constant_3d_array() -> !llvm.array<2 x array<2 x array<2 x i
// CHECK-LABEL: @atomicrmw
llvm.func @atomicrmw(
%f32_ptr : !llvm.ptr, %f32 : f32,
%i32_ptr : !llvm.ptr, %i32 : i32) {
%i32_ptr : !llvm.ptr, %i32 : i32,
%f16_vec_ptr : !llvm.ptr, %f16_vec : vector<2xf16>) {
// CHECK: atomicrmw fadd ptr %{{.*}}, float %{{.*}} monotonic
%0 = llvm.atomicrmw fadd %f32_ptr, %f32 monotonic : !llvm.ptr, f32
// CHECK: atomicrmw fsub ptr %{{.*}}, float %{{.*}} monotonic
Expand Down Expand Up @@ -1535,11 +1536,19 @@ llvm.func @atomicrmw(
%17 = llvm.atomicrmw usub_cond %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw usub_sat ptr %{{.*}}, i32 %{{.*}} monotonic
%18 = llvm.atomicrmw usub_sat %i32_ptr, %i32 monotonic : !llvm.ptr, i32
// CHECK: atomicrmw fadd ptr %{{.*}}, <2 x half> %{{.*}} monotonic
%19 = llvm.atomicrmw fadd %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also test fmin/fmax/fsub with vector.

Also the scalar cases are supported, as well as bfloat

Copy link
Contributor Author

@joviliast joviliast Oct 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also test fmin/fmax/fsub with vector.

I'm not sure about fmin/fmax/fsub, so currently I filtered other operations here.
I suggest just to fix a title and description for now. WDYT?

Also the scalar cases are supported, as well as bfloat

Bfloat also should be tested (done), agree, but scalar cases is out of scope of this PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're applying a bunch of restrictions to these that simply do not exist in the underlying IR. FP vectors are supported for all the FP operations (except xchg, for now, which doesn't really count)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ideally the LLVM dialect verifiers should follow the LLVM IR verifiers as closely as possible:
https://github.com/llvm/llvm-project/blob/6c331e50e4bfb4158d16ec3fe17ad7bb5c739e9f/llvm/lib/IR/Verifier.cpp#L4332C1-L4333C1
-> seems to contain the relevant code for this PR.

At the moment, the LLVM dialect verifiers are still very incomplete since it is mostly a lowering dialect in MLIR. However, we should try to avoid having verifiers that fail on correct LLVM IR (as it obviously was the case before your PR).

// CHECK: atomicrmw fsub ptr %{{.*}}, <2 x half> %{{.*}} monotonic
%20 = llvm.atomicrmw fsub %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
// CHECK: atomicrmw fmax ptr %{{.*}}, <2 x half> %{{.*}} monotonic
%21 = llvm.atomicrmw fmax %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>
// CHECK: atomicrmw fmin ptr %{{.*}}, <2 x half> %{{.*}} monotonic
%22 = llvm.atomicrmw fmin %f16_vec_ptr, %f16_vec monotonic : !llvm.ptr, vector<2xf16>

// CHECK: atomicrmw volatile
// CHECK-SAME: syncscope("singlethread")
// CHECK-SAME: align 8
%19 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
%23 = llvm.atomicrmw volatile udec_wrap %i32_ptr, %i32 syncscope("singlethread") monotonic {alignment = 8 : i64} : !llvm.ptr, i32
llvm.return
}

Expand Down
Loading