Skip to content

Commit eb6a336

Browse files
arsenmpravinjagtap
authored andcommitted
AMDGPU: Add basic verification for mfma scale intrinsics (llvm#117048)
Verify the format is valid and the type is one of the expected i32 vectors. Verify the used vector types at least cover the requirements of the corresponding format operand.
1 parent 8b5d4f3 commit eb6a336

File tree

3 files changed

+283
-6
lines changed

3 files changed

+283
-6
lines changed

llvm/include/llvm/IR/IntrinsicsAMDGPU.td

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2992,12 +2992,10 @@ class AMDGPUMfmaIntrinsic<LLVMType DestTy, LLVMType SrcABTy> :
29922992
// blgp.
29932993
//
29942994
// These should be <8 x i32> for f8 formats, <6 x i32> for f6 formats,
2995-
// and <4 x i32> for f4 formats. If the format control bits imply a
2996-
// smaller type than used, the high elements will be truncated.
2997-
//
2998-
// If the format control bits imply a larger type than used, the high
2999-
// elements are padded with undef.
3000-
2995+
// and <4 x i32> for f4 formats. It is invalid to use a format that
2996+
// requires more registers than the corresponding vector type (e.g. it
2997+
// is illegal to use <6 x i32> in operand 0 if cbsz specifies an f8
2998+
// format that requires 8 registers).
30012999
class AMDGPUMfmaScaleIntrinsic<LLVMType DestTy> :
30023000
DefaultAttrsIntrinsic<[DestTy],
30033001
[llvm_anyvector_ty, llvm_anyvector_ty, DestTy,

llvm/lib/IR/Verifier.cpp

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6369,6 +6369,55 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
63696369
"Value for inactive lanes must be a VGPR function argument", &Call);
63706370
break;
63716371
}
6372+
case Intrinsic::amdgcn_mfma_scale_f32_16x16x128_f8f6f4:
6373+
case Intrinsic::amdgcn_mfma_scale_f32_32x32x64_f8f6f4: {
6374+
Value *Src0 = Call.getArgOperand(0);
6375+
Value *Src1 = Call.getArgOperand(1);
6376+
6377+
uint64_t CBSZ = cast<ConstantInt>(Call.getArgOperand(3))->getZExtValue();
6378+
uint64_t BLGP = cast<ConstantInt>(Call.getArgOperand(4))->getZExtValue();
6379+
Check(CBSZ <= 4, "invalid value for cbsz format", Call,
6380+
Call.getArgOperand(3));
6381+
Check(BLGP <= 4, "invalid value for blgp format", Call,
6382+
Call.getArgOperand(4));
6383+
6384+
// AMDGPU::MFMAScaleFormats values
6385+
auto getFormatNumRegs = [](unsigned FormatVal) {
6386+
switch (FormatVal) {
6387+
case 0:
6388+
case 1:
6389+
return 8u;
6390+
case 2:
6391+
case 3:
6392+
return 6u;
6393+
case 4:
6394+
return 4u;
6395+
default:
6396+
llvm_unreachable("invalid format value");
6397+
}
6398+
};
6399+
6400+
auto isValidSrcASrcBVector = [](FixedVectorType *Ty) {
6401+
if (!Ty || !Ty->getElementType()->isIntegerTy(32))
6402+
return false;
6403+
unsigned NumElts = Ty->getNumElements();
6404+
return NumElts == 4 || NumElts == 6 || NumElts == 8;
6405+
};
6406+
6407+
auto *Src0Ty = dyn_cast<FixedVectorType>(Src0->getType());
6408+
auto *Src1Ty = dyn_cast<FixedVectorType>(Src1->getType());
6409+
Check(isValidSrcASrcBVector(Src0Ty),
6410+
"operand 0 must be 4, 6 or 8 element i32 vector", &Call, Src0);
6411+
Check(isValidSrcASrcBVector(Src1Ty),
6412+
"operand 1 must be 4, 6 or 8 element i32 vector", &Call, Src1);
6413+
6414+
// Permit excess registers for the format.
6415+
Check(Src0Ty->getNumElements() >= getFormatNumRegs(CBSZ),
6416+
"invalid vector type for format", &Call, Src0, Call.getArgOperand(3));
6417+
Check(Src1Ty->getNumElements() >= getFormatNumRegs(BLGP),
6418+
"invalid vector type for format", &Call, Src1, Call.getArgOperand(5));
6419+
break;
6420+
}
63726421
case Intrinsic::nvvm_setmaxnreg_inc_sync_aligned_u32:
63736422
case Intrinsic::nvvm_setmaxnreg_dec_sync_aligned_u32: {
63746423
Value *V = Call.getArgOperand(0);

0 commit comments

Comments
 (0)