-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AMDGPU][wmma] - Add tied wmma intrinsic #69903
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -865,22 +865,26 @@ def WMMAOpcode3AddrMappingTable : WMMAMappingTable { | |
// it converts the default pseudo to the pseudo where src2 is not the same as vdst. | ||
// 3) @earlyclobber on the destination satisfies the constraint during RA. | ||
|
||
multiclass WMMAInst<string Suffix, string Instr, VOPProfile P, SDPatternOperator node = null_frag, RegisterOperand _Src01RC64 = VRegSrc_256, WMMAType Type> { | ||
multiclass WMMAInst<string Suffix, string Instr, VOPProfile P, SDPatternOperator node = null_frag, RegisterOperand _Src01RC64 = VRegSrc_256, WMMAType Type, bit convertibleTo3Addr> { | ||
|
||
defvar WMMAConstraints2Addr = "@earlyclobber $vdst,$vdst = $src2"; | ||
defvar WMMAConstraints3Addr = "@earlyclobber $vdst"; | ||
|
||
defvar WMMAProfile = VOPProfileWMMA<P, Suffix, _Src01RC64, Type.hasClamp, Type.hasOpsel>; | ||
let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in { | ||
let Constraints = WMMAConstraints2Addr, isConvertibleToThreeAddress = 1 in { | ||
let Constraints = WMMAConstraints2Addr, isConvertibleToThreeAddress = convertibleTo3Addr in { | ||
def _twoaddr # Suffix : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>; | ||
} | ||
let Constraints = WMMAConstraints3Addr, SchedRW = [Write32Bit, Write32Bit] in { | ||
def _threeaddr # Suffix : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>; | ||
} | ||
} | ||
def : WMMAOpcodeMapping<!cast<Instruction>(NAME # _twoaddr # Suffix), | ||
if convertibleTo3Addr then { | ||
let Mnemonic = Instr, mayRaiseFPException = 0, ReadsModeReg = 0 in { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: instead of repeating this "let" clause here, I think it would be neater to put this new code inside the existing "let" clause, i.e. move it up to just before the closing brace on line 878. Yes this will mean that the "let" clause now applies to the WMMAOpcodeMapping part too, but that should be harmless. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the suggestion! I think I did that now, if I understood your comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I attempted this, but this caused an Error. Moving the
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks, and sorry for my bad idea. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The idea itself was not bad! It would look better. I am also surprised that it does not work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that's explained by the implementation of |
||
let Constraints = WMMAConstraints3Addr, SchedRW = [Write32Bit, Write32Bit] in { | ||
def _threeaddr # Suffix : VOP3P_Pseudo<Instr # Suffix, WMMAProfile>; | ||
} | ||
} | ||
def : WMMAOpcodeMapping<!cast<Instruction>(NAME # _twoaddr # Suffix), | ||
!cast<Instruction>(NAME # _threeaddr # Suffix)>; | ||
} | ||
|
||
if !eq(Type, WMMAOpSel) then { | ||
def : WMMAOpSelPat<!cast<Instruction>(NAME # _twoaddr # Suffix), node, P>; | ||
|
@@ -893,21 +897,25 @@ multiclass WMMAInst<string Suffix, string Instr, VOPProfile P, SDPatternOperator | |
|
||
|
||
let WaveSizePredicate = isWave32 in { | ||
defm V_WMMA_F32_16X16X16_F16 : WMMAInst<"_w32", "v_wmma_f32_16x16x16_f16", VOP_V8F32_V16F16_V16F16_V8F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular>; | ||
defm V_WMMA_F32_16X16X16_BF16 : WMMAInst<"_w32", "v_wmma_f32_16x16x16_bf16", VOP_V8F32_V16I16_V16I16_V8F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular>; | ||
defm V_WMMA_F16_16X16X16_F16 : WMMAInst<"_w32", "v_wmma_f16_16x16x16_f16", VOP_V16F16_V16F16_V16F16_V16F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel>; | ||
defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w32", "v_wmma_bf16_16x16x16_bf16", VOP_V16I16_V16I16_V16I16_V16I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel>; | ||
defm V_WMMA_I32_16X16X16_IU8 : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu8", VOP_V8I32_V4I32_V4I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp>; | ||
defm V_WMMA_I32_16X16X16_IU4 : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu4", VOP_V8I32_V2I32_V2I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64, WMMAUIClamp>; | ||
defm V_WMMA_F32_16X16X16_F16 : WMMAInst<"_w32", "v_wmma_f32_16x16x16_f16", VOP_V8F32_V16F16_V16F16_V8F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular, 1>; | ||
defm V_WMMA_F32_16X16X16_BF16 : WMMAInst<"_w32", "v_wmma_f32_16x16x16_bf16", VOP_V8F32_V16I16_V16I16_V8F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular, 1>; | ||
defm V_WMMA_F16_16X16X16_F16 : WMMAInst<"_w32", "v_wmma_f16_16x16x16_f16", VOP_V16F16_V16F16_V16F16_V16F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel, 1>; | ||
defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w32", "v_wmma_bf16_16x16x16_bf16", VOP_V16I16_V16I16_V16I16_V16I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel, 1>; | ||
defm V_WMMA_F16_16X16X16_F16_TIED : WMMAInst<"_w32", "v_wmma_f16_16x16x16_f16", VOP_V16F16_V16F16_V16F16_V16F16, int_amdgcn_wmma_f16_16x16x16_f16_tied, VRegSrc_256, WMMAOpSel, 0>; | ||
defm V_WMMA_BF16_16X16X16_BF16_TIED : WMMAInst<"_w32", "v_wmma_bf16_16x16x16_bf16", VOP_V16I16_V16I16_V16I16_V16I16, int_amdgcn_wmma_bf16_16x16x16_bf16_tied, VRegSrc_256, WMMAOpSel, 0>; | ||
defm V_WMMA_I32_16X16X16_IU8 : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu8", VOP_V8I32_V4I32_V4I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp, 1>; | ||
defm V_WMMA_I32_16X16X16_IU4 : WMMAInst<"_w32", "v_wmma_i32_16x16x16_iu4", VOP_V8I32_V2I32_V2I32_V8I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64, WMMAUIClamp, 1>; | ||
} | ||
|
||
let WaveSizePredicate = isWave64 in { | ||
defm V_WMMA_F32_16X16X16_F16 : WMMAInst<"_w64", "v_wmma_f32_16x16x16_f16", VOP_V4F32_V16F16_V16F16_V4F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular>; | ||
defm V_WMMA_F32_16X16X16_BF16 : WMMAInst<"_w64", "v_wmma_f32_16x16x16_bf16", VOP_V4F32_V16I16_V16I16_V4F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular>; | ||
defm V_WMMA_F16_16X16X16_F16 : WMMAInst<"_w64", "v_wmma_f16_16x16x16_f16", VOP_V8F16_V16F16_V16F16_V8F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel>; | ||
defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w64", "v_wmma_bf16_16x16x16_bf16", VOP_V8I16_V16I16_V16I16_V8I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel>; | ||
defm V_WMMA_I32_16X16X16_IU8 : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu8", VOP_V4I32_V4I32_V4I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp>; | ||
defm V_WMMA_I32_16X16X16_IU4 : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu4", VOP_V4I32_V2I32_V2I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64, WMMAUIClamp>; | ||
defm V_WMMA_F32_16X16X16_F16 : WMMAInst<"_w64", "v_wmma_f32_16x16x16_f16", VOP_V4F32_V16F16_V16F16_V4F32, int_amdgcn_wmma_f32_16x16x16_f16, VRegSrc_256, WMMARegular, 1>; | ||
defm V_WMMA_F32_16X16X16_BF16 : WMMAInst<"_w64", "v_wmma_f32_16x16x16_bf16", VOP_V4F32_V16I16_V16I16_V4F32, int_amdgcn_wmma_f32_16x16x16_bf16, VRegSrc_256, WMMARegular, 1>; | ||
defm V_WMMA_F16_16X16X16_F16 : WMMAInst<"_w64", "v_wmma_f16_16x16x16_f16", VOP_V8F16_V16F16_V16F16_V8F16, int_amdgcn_wmma_f16_16x16x16_f16, VRegSrc_256, WMMAOpSel, 1>; | ||
defm V_WMMA_BF16_16X16X16_BF16 : WMMAInst<"_w64", "v_wmma_bf16_16x16x16_bf16", VOP_V8I16_V16I16_V16I16_V8I16, int_amdgcn_wmma_bf16_16x16x16_bf16, VRegSrc_256, WMMAOpSel, 1>; | ||
defm V_WMMA_F16_16X16X16_F16_TIED : WMMAInst<"_w64", "v_wmma_f16_16x16x16_f16", VOP_V8F16_V16F16_V16F16_V8F16, int_amdgcn_wmma_f16_16x16x16_f16_tied, VRegSrc_256, WMMAOpSel, 0>; | ||
defm V_WMMA_BF16_16X16X16_BF16_TIED : WMMAInst<"_w64", "v_wmma_bf16_16x16x16_bf16", VOP_V8I16_V16I16_V16I16_V8I16, int_amdgcn_wmma_bf16_16x16x16_bf16_tied, VRegSrc_256, WMMAOpSel, 0>; | ||
defm V_WMMA_I32_16X16X16_IU8 : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu8", VOP_V4I32_V4I32_V4I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu8, VRegSrc_128, WMMAUIClamp, 1>; | ||
defm V_WMMA_I32_16X16X16_IU4 : WMMAInst<"_w64", "v_wmma_i32_16x16x16_iu4", VOP_V4I32_V2I32_V2I32_V4I32, int_amdgcn_wmma_i32_16x16x16_iu4, VRegSrc_64, WMMAUIClamp, 1>; | ||
|
||
} | ||
|
||
|
Uh oh!
There was an error while loading. Please reload this page.