Skip to content

[MLIR][ROCDL] Add Scale Convert Packed FP4 <-> F32/BF16/F16 Support for GFX950 #140676

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 118 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,10 @@ def ROCDL_V2I16Type : FixedVectorOfLengthAndType<[2], [I16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getI16Type())">;

// Type constraint for a fixed-length vector of exactly two f32 elements.
// Declared buildable so that a result of this type can be constructed from
// the builder alone as VectorType::get({2}, f32).
def ROCDL_V2F32Type : FixedVectorOfLengthAndType<[2], [F32]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getF32Type())">;

// Type constraint for a fixed-length vector of exactly two f16 elements.
// Declared buildable so that a result of this type can be constructed from
// the builder alone as VectorType::get({2}, f16).
def ROCDL_V2F16Type : FixedVectorOfLengthAndType<[2], [F16]>,
BuildableType<"::mlir::VectorType::get("
"{2},$_builder.getF16Type())">;
Expand Down Expand Up @@ -1005,6 +1009,120 @@ def ROCDL_CvtScaleF32SrBf8F32Op :
}];
}

//===---------------------------------------------------------------------===//
// 4-bit float scale intrinsics
//===---------------------------------------------------------------------===//
// Operands are listed in the order they are passed to the intrinsic call:
// (old, src0, src1, scale, byteSel).
def ROCDL_CvtScaleF32PkFp4F32Op :
    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f32", [], [], [Pure], 1>,
    Arguments<(ins I32:$old, F32:$src0, F32:$src1, F32:$scale,
                   I32:$byteSel)> {
  let summary = "Scale and convert two f32 values to packed fp4";
  let description = [{
    Scale `src0` and `src1` by the exponent in `scale`, then convert them to
    packed fp4. Store the result into the `byteSel`th byte of `old`,
    preserving the others.
  }];
  let assemblyFormat = [{
    attr-dict $src0 `,` $src1 `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
  }];
}

// Operands are listed in the order they are passed to the intrinsic call:
// (old, src, scale, byteSel).
def ROCDL_CvtScaleF32PkFp4F16Op :
    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.f16", [], [], [Pure], 1>,
    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, F32:$scale,
                   I32:$byteSel)> {
  let summary = "Scale and convert packed f16 to packed fp4";
  let description = [{
    Scale the two f16 values in `src` by the exponent in `scale`, then convert
    them to packed fp4. Store the result into the `byteSel`th byte of `old`,
    preserving the others.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
  }];
}

// Operands are listed in the order they are passed to the intrinsic call:
// (old, src, scale, byteSel).
def ROCDL_CvtScaleF32PkFp4Bf16Op :
    ROCDL_IntrOp<"cvt.scalef32.pk.fp4.bf16", [], [], [Pure], 1>,
    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, F32:$scale,
                   I32:$byteSel)> {
  let summary = "Scale and convert packed bf16 to packed fp4";
  let description = [{
    Scale the two bf16 values in `src` by the exponent in `scale`, then convert
    them to packed fp4. Store the result into the `byteSel`th byte of `old`,
    preserving the others.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
  }];
}

// Operands are listed in the order they are passed to the intrinsic call:
// (old, src, seed, scale, byteSel).
def ROCDL_CvtScaleF32SrPkFp4F32Op :
    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f32", [], [], [Pure], 1>,
    Arguments<(ins I32:$old, ROCDL_V2F32Type:$src, I32:$seed, F32:$scale,
                   I32:$byteSel)> {
  let summary = "Scale and convert f32 to packed fp4 using stochastic rounding";
  let description = [{
    Scale `src` by the exponent in `scale`, then convert to packed fp4 with
    stochastic rounding using seed data in `seed`. Store the result into the
    `byteSel`th byte of `old`, preserving the others.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
  }];
}

// Operands are listed in the order they are passed to the intrinsic call:
// (old, src, seed, scale, byteSel).
def ROCDL_CvtScaleF32SrPkFp4F16Op :
    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.f16", [], [], [Pure], 1>,
    Arguments<(ins I32:$old, ROCDL_V2F16Type:$src, I32:$seed, F32:$scale,
                   I32:$byteSel)> {
  let summary = "Scale and convert f16 to packed fp4 using stochastic rounding";
  let description = [{
    Scale `src` by the exponent in `scale`, then convert to packed fp4 with
    stochastic rounding using seed data in `seed`. Store the result into the
    `byteSel`th byte of `old`, preserving the others.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
  }];
}

// Operands are listed in the order they are passed to the intrinsic call:
// (old, src, seed, scale, byteSel).
def ROCDL_CvtScaleF32SrPkFp4Bf16Op :
    ROCDL_IntrOp<"cvt.scalef32.sr.pk.fp4.bf16", [], [], [Pure], 1>,
    Arguments<(ins I32:$old, ROCDL_V2BF16Type:$src, I32:$seed, F32:$scale,
                   I32:$byteSel)> {
  let summary = "Scale and convert bf16 to packed fp4 using stochastic rounding";
  let description = [{
    Scale `src` by the exponent in `scale`, then convert to packed fp4 with
    stochastic rounding using seed data in `seed`. Store the result into the
    `byteSel`th byte of `old`, preserving the others.
  }];
  let assemblyFormat = [{
    attr-dict $src `,` $seed `,` $scale `->` $old `[` $byteSel `]` `:` type($res)
  }];
}

// Operands are listed in the order they are passed to the intrinsic call:
// (src, scale, byteSel). The assembly form prints `byteSel` as an index on
// `src`.
def ROCDL_CvtScaleF32PkF32Fp4Op :
    ROCDL_IntrOp<"cvt.scalef32.pk.f32.fp4", [], [], [Pure], 1>,
    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
  let summary = "Convert packed fp4 to packed f32 and scale";
  let description = [{
    Convert the packed fp4 values of `src` selected by `byteSel` to packed
    f32, then scale the packed values by the exponent in `scale`.
  }];
  let assemblyFormat = [{
    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
  }];
}


// Operands are listed in the order they are passed to the intrinsic call:
// (src, scale, byteSel). The assembly form prints `byteSel` as an index on
// `src`.
def ROCDL_CvtScaleF32PkF16Fp4Op :
    ROCDL_IntrOp<"cvt.scalef32.pk.f16.fp4", [], [], [Pure], 1>,
    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
  let summary = "Convert packed fp4 to packed f16 and scale";
  let description = [{
    Convert the packed fp4 values of `src` selected by `byteSel` to packed
    f16, then scale the packed values by the exponent in `scale`.
  }];
  let assemblyFormat = [{
    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
  }];
}

// Operands are listed in the order they are passed to the intrinsic call:
// (src, scale, byteSel). The assembly form prints `byteSel` as an index on
// `src`.
def ROCDL_CvtScaleF32PkBf16Fp4Op :
    ROCDL_IntrOp<"cvt.scalef32.pk.bf16.fp4", [], [], [Pure], 1>,
    Arguments<(ins I32:$src, F32:$scale, I32:$byteSel)> {
  let summary = "Convert packed fp4 to packed bf16 and scale";
  let description = [{
    Convert the packed fp4 values of `src` selected by `byteSel` to packed
    bf16, then scale the packed values by the exponent in `scale`.
  }];
  let assemblyFormat = [{
    attr-dict $src `[` $byteSel `]` `,` $scale `:` type($res)
  }];
}
//===---------------------------------------------------------------------===//
// 8-bit float intrinsics
//===---------------------------------------------------------------------===//
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Dialect/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,31 @@ llvm.func @rocdl_v2f16_v2i16(%source: vector<2xf16>, %source2: vector<2xbf16>, %
llvm.return %source_scaled : vector<2xi16>
}

// Round-trip test for the packed fp4 scale-convert ops. Each CHECK also pins
// the printed result type so that a wrong result type constraint is caught,
// not just the op mnemonic.
llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
  // CHECK-LABEL: @rocdl_4bit_packed_floats
  // CHECK: rocdl.cvt.scalef32.pk.fp4.f32 {{.*}} : i32
  // CHECK: rocdl.cvt.scalef32.pk.fp4.f16 {{.*}} : i32
  // CHECK: rocdl.cvt.scalef32.pk.fp4.bf16 {{.*}} : i32
  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f32 {{.*}} : i32
  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.f16 {{.*}} : i32
  // CHECK: rocdl.cvt.scalef32.sr.pk.fp4.bf16 {{.*}} : i32
  // CHECK: rocdl.cvt.scalef32.pk.f32.fp4 {{.*}} : vector<2xf32>
  // CHECK: rocdl.cvt.scalef32.pk.f16.fp4 {{.*}} : vector<2xf16>
  // CHECK: rocdl.cvt.scalef32.pk.bf16.fp4 {{.*}} : vector<2xbf16>
  %c0 = llvm.mlir.constant(0 : i32) : i32
  %scale = llvm.mlir.constant(1.0 : f32) : f32
  // float -> packed fp4; each result feeds the next op as its `old` operand.
  %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
  %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
  %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
  // Stochastic-rounding variants take an extra seed operand.
  %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
  %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
  %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
  // packed fp4 -> packed float; results are intentionally unused.
  %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
  %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
  %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
  llvm.return %sr3 : i32
}

llvm.func @rocdl.s.waitcnt() {
// CHECK-LABEL: rocdl.s.waitcnt
// CHECK: rocdl.s.waitcnt 0
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Target/LLVMIR/rocdl.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1145,6 +1145,31 @@ llvm.func @rocdl_16bit_packed_floats(%sourceA: f32, %sourceB: f32) -> vector<2xf
llvm.return %source : vector<2xf16>
}

// Translation test for the packed fp4 scale-convert ops. LLVM value names are
// captured with FileCheck variables instead of hard-coded numbers, so the
// test verifies the old -> result chaining without depending on the implicit
// numbering of unnamed LLVM values.
llvm.func @rocdl_4bit_packed_floats(%old: i32, %source0: f32, %source1: f32, %source: vector<2xf32>, %source_half: vector<2xf16>, %source_bfloat: vector<2xbf16>, %stoch: i32) -> i32 {
  // CHECK-LABEL: @rocdl_4bit_packed_floats
  // CHECK-SAME: (i32 %[[OLD:[0-9]+]], float %[[SRC0:[0-9]+]], float %[[SRC1:[0-9]+]], <2 x float> %[[SRC:[0-9]+]], <2 x half> %[[HALF:[0-9]+]], <2 x bfloat> %[[BF:[0-9]+]], i32 %[[SEED:[0-9]+]])
  // CHECK: %[[PK1:[0-9]+]] = call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f32(i32 %[[OLD]], float %[[SRC0]], float %[[SRC1]], float 1.000000e+00, i32 0)
  // CHECK: %[[PK2:[0-9]+]] = call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.f16(i32 %[[PK1]], <2 x half> %[[HALF]], float 1.000000e+00, i32 0)
  // CHECK: %[[PK3:[0-9]+]] = call i32 @llvm.amdgcn.cvt.scalef32.pk.fp4.bf16(i32 %[[PK2]], <2 x bfloat> %[[BF]], float 1.000000e+00, i32 0)
  // CHECK: %[[SR1:[0-9]+]] = call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f32(i32 %[[PK3]], <2 x float> %[[SRC]], i32 %[[SEED]], float 1.000000e+00, i32 0)
  // CHECK: %[[SR2:[0-9]+]] = call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.f16(i32 %[[SR1]], <2 x half> %[[HALF]], i32 %[[SEED]], float 1.000000e+00, i32 0)
  // CHECK: %[[SR3:[0-9]+]] = call i32 @llvm.amdgcn.cvt.scalef32.sr.pk.fp4.bf16(i32 %[[SR2]], <2 x bfloat> %[[BF]], i32 %[[SEED]], float 1.000000e+00, i32 0)
  // CHECK: call <2 x float> @llvm.amdgcn.cvt.scalef32.pk.f32.fp4(i32 %[[OLD]], float 1.000000e+00, i32 0)
  // CHECK: call <2 x half> @llvm.amdgcn.cvt.scalef32.pk.f16.fp4(i32 %[[OLD]], float 1.000000e+00, i32 0)
  // CHECK: call <2 x bfloat> @llvm.amdgcn.cvt.scalef32.pk.bf16.fp4(i32 %[[OLD]], float 1.000000e+00, i32 0)
  // CHECK: ret i32 %[[SR3]]
  %c0 = llvm.mlir.constant(0 : i32) : i32
  %scale = llvm.mlir.constant(1.0 : f32) : f32
  // float -> packed fp4; each result feeds the next op as its `old` operand.
  %pk1 = rocdl.cvt.scalef32.pk.fp4.f32 %source0, %source1, %scale -> %old[%c0] : i32
  %pk2 = rocdl.cvt.scalef32.pk.fp4.f16 %source_half, %scale -> %pk1[%c0] : i32
  %pk3 = rocdl.cvt.scalef32.pk.fp4.bf16 %source_bfloat, %scale -> %pk2[%c0] : i32
  // Stochastic-rounding variants take an extra seed operand.
  %sr1 = rocdl.cvt.scalef32.sr.pk.fp4.f32 %source, %stoch, %scale -> %pk3[%c0] : i32
  %sr2 = rocdl.cvt.scalef32.sr.pk.fp4.f16 %source_half, %stoch, %scale -> %sr1[%c0] : i32
  %sr3 = rocdl.cvt.scalef32.sr.pk.fp4.bf16 %source_bfloat, %stoch, %scale -> %sr2[%c0] : i32
  // packed fp4 -> packed float; results are intentionally unused.
  %pk4 = rocdl.cvt.scalef32.pk.f32.fp4 %old[%c0], %scale : vector<2xf32>
  %pk5 = rocdl.cvt.scalef32.pk.f16.fp4 %old[%c0], %scale : vector<2xf16>
  %pk6 = rocdl.cvt.scalef32.pk.bf16.fp4 %old[%c0], %scale : vector<2xbf16>
  llvm.return %sr3 : i32
}

llvm.func @rocdl_atomic_attrs(%ptr: !llvm.ptr<1>, %data: f32) {
// CHECK-LABEL: @rocdl_atomic_attrs
// CHECK: atomicrmw
Expand Down