Skip to content

[SYCL][Matrix] Switch to SPV_KHR_cooperative_matrix extension #13316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,23 @@ llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
"spirv.JointMatrixINTEL", {CompTy}, Params);
}

llvm::Type *
getCooperativeMatrixKHRExtType(llvm::Type *CompTy,
ArrayRef<TemplateArgument> TemplateArgs,
const unsigned Val = 0) {
assert(TemplateArgs.size() == 5 &&
"Wrong CooperativeMatrixKHR template parameters number");
std::vector<unsigned> Params;
for (size_t I = 1; I != TemplateArgs.size(); ++I) {
assert(TemplateArgs[I].getKind() == TemplateArgument::Integral &&
"Wrong CooperativeMatrixKHR template parameter");
Params.push_back(TemplateArgs[I].getAsIntegral().getExtValue());
}

return llvm::TargetExtType::get(
CompTy->getContext(), "spirv.CooperativeMatrixKHR", {CompTy}, Params);
}

/// ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V JointMatrixINTEL type.
Expand Down Expand Up @@ -363,6 +380,39 @@ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
return getJointMatrixINTELExtType(CompTy, TemplateArgs);
}

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
/// The expected representation is:
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%, %cols%,
/// %use%)
llvm::Type *CodeGenTypes::ConvertSPVCooperativeMatrixType(RecordDecl *RD) {
auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
ArrayRef<TemplateArgument> TemplateArgs =
TemplateDecl->getTemplateArgs().asArray();
assert(TemplateArgs[0].getKind() == TemplateArgument::Type &&
"1st CooperativeMatrixKHR template parameter must be type");
llvm::Type *CompTy = ConvertType(TemplateArgs[0].getAsType());

if (CompTy->isStructTy()) {
StringRef LlvmTyName = CompTy->getStructName();
// Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
if (LlvmTyName.starts_with("class.sycl::") ||
LlvmTyName.starts_with("class.__sycl_internal::"))
LlvmTyName = LlvmTyName.rsplit("::").second;
if (LlvmTyName == "half") {
CompTy = llvm::Type::getHalfTy(getLLVMContext());
} else if (LlvmTyName == "tf32") {
CompTy = llvm::Type::getFloatTy(getLLVMContext());
} else if (LlvmTyName == "bfloat16") {
CompTy = llvm::Type::getInt16Ty(getLLVMContext());
} else {
llvm_unreachable("Wrong matrix base type!");
}
}
return getCooperativeMatrixKHRExtType(CompTy, TemplateArgs);
}

/// ConvertType - Convert the specified type to its LLVM form.
llvm::Type *CodeGenTypes::ConvertType(QualType T) {
T = Context.getCanonicalType(T);
Expand Down Expand Up @@ -654,6 +704,10 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
"__spv::__spirv_JointMatrixINTEL") {
ResultType = ConvertSYCLJointMatrixINTELType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_CooperativeMatrixKHR") {
ResultType = ConvertSPVCooperativeMatrixType(RD);
break;
} else if (RD && RD->getQualifiedNameAsString() ==
"__spv::__spirv_TaskSequenceINTEL") {
ResultType = llvm::TargetExtType::get(getLLVMContext(),
Expand Down
9 changes: 9 additions & 0 deletions clang/lib/CodeGen/CodeGenTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,15 @@ class CodeGenTypes {
/// %use%, (optional) %element_type_interpretation%)
llvm::Type *ConvertSYCLJointMatrixINTELType(RecordDecl *RD);

/// ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
/// which is represented as a pointer to a structure to LLVM extension type
/// with the parameters that follow SPIR-V CooperativeMatrixKHR type.
/// The expected representation is:
/// target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%,
/// %cols%, %use%)
///
llvm::Type *ConvertSPVCooperativeMatrixType(RecordDecl *RD);

/// GetFunctionType - Get the LLVM function type for \arg Info.
llvm::FunctionType *GetFunctionType(const CGFunctionInfo &Info);

Expand Down
3 changes: 2 additions & 1 deletion clang/lib/Driver/ToolChains/Clang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10395,7 +10395,8 @@ void SPIRVTranslator::ConstructJob(Compilation &C, const JobAction &JA,
",+SPV_KHR_uniform_group_instructions"
",+SPV_INTEL_masked_gather_scatter"
",+SPV_INTEL_tensor_float32_conversion"
",+SPV_INTEL_optnone";
",+SPV_INTEL_optnone"
",+SPV_KHR_cooperative_matrix";
if (ShouldPreserveMetadata)
ExtArg += ",+SPV_KHR_non_semantic_info";
if (IsCPU)
Expand Down
41 changes: 41 additions & 0 deletions clang/test/CodeGenSYCL/cooperative_matrix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// RUN: %clang_cc1 -triple spir64-unknown-unknown -disable-llvm-passes -emit-llvm %s -o - | FileCheck %s
// Test that SPIR-V codegen generates the expected LLVM struct name for the
// CooperativeMatrixKHR type.
#include <stddef.h>
#include <stdint.h>

namespace __spv {
template <typename T, uint32_t S, size_t R, size_t C, uint32_t U>
struct __spirv_CooperativeMatrixKHR;
}

// CHECK: @_Z2f1{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 5, 10, 0)
void f1(__spv::__spirv_CooperativeMatrixKHR<float, 3, 5, 10, 0> *matrix) {}

// CHECK: @_Z2f2{{.*}}(target("spirv.CooperativeMatrixKHR", i64, 3, 10, 2, 1)
void f2(__spv::__spirv_CooperativeMatrixKHR<uint64_t, 3, 10, 2, 1> *matrix) {}

// CHECK: @_Z2f3{{.*}}(target("spirv.CooperativeMatrixKHR", i8, 3, 10, 2, 2)
void f3(__spv::__spirv_CooperativeMatrixKHR<char, 3, 10, 2, 2> *matrix) {}

namespace sycl {
class half {};
class bfloat16 {};
class tf32 {};
}
typedef sycl::half my_half;

// CHECK: @_Z2f4{{.*}}(target("spirv.CooperativeMatrixKHR", half, 3, 10, 2, 0)
void f4(__spv::__spirv_CooperativeMatrixKHR<my_half, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f5{{.*}}(target("spirv.CooperativeMatrixKHR", i16, 3, 10, 2, 0)
void f5(__spv::__spirv_CooperativeMatrixKHR<sycl::bfloat16, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f6{{.*}}(target("spirv.CooperativeMatrixKHR", i128, 3, 10, 2, 0)
void f6(__spv::__spirv_CooperativeMatrixKHR<_BitInt(128), 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f7{{.*}}(target("spirv.CooperativeMatrixKHR", float, 3, 10, 2, 0)
void f7(__spv::__spirv_CooperativeMatrixKHR<sycl::tf32, 3, 10, 2, 0> *matrix) {}

// CHECK: @_Z2f8{{.*}}(target("spirv.CooperativeMatrixKHR", double, 3, 5, 10, 0)
void f8(__spv::__spirv_CooperativeMatrixKHR<double, 3, 5, 10, 0> *matrix) {}
4 changes: 3 additions & 1 deletion clang/test/Driver/sycl-spirv-ext.c
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
// CHECK-DEFAULT-SAME:,+SPV_KHR_uniform_group_instructions
// CHECK-DEFAULT-SAME:,+SPV_INTEL_masked_gather_scatter
// CHECK-DEFAULT-SAME:,+SPV_INTEL_tensor_float32_conversion
// CHECK-DEFAULT-SAME:,+SPV_INTEL_optnone"
// CHECK-DEFAULT-SAME:,+SPV_INTEL_optnone
// CHECK-DEFAULT-SAME:,+SPV_KHR_cooperative_matrix"
// CHECK-FPGA-HW: llvm-spirv{{.*}}"-spirv-ext=-all
// CHECK-FPGA-HW-SAME:,+SPV_EXT_shader_atomic_float_add
// CHECK-FPGA-HW-SAME:,+SPV_EXT_shader_atomic_float_min_max
Expand Down Expand Up @@ -119,5 +120,6 @@
// CHECK-CPU-SAME:,+SPV_INTEL_masked_gather_scatter
// CHECK-CPU-SAME:,+SPV_INTEL_tensor_float32_conversion
// CHECK-CPU-SAME:,+SPV_INTEL_optnone
// CHECK-CPU-SAME:,+SPV_KHR_cooperative_matrix
// CHECK-CPU-SAME:,+SPV_INTEL_fp_max_error"

29 changes: 29 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,34 @@ void SPIRVRegularizeLLVMBase::expandSYCLTypeUsing(Module *M) {
expandVIDWithSYCLTypeByValComp(F);
}

void SPIRVRegularizeLLVMBase::finishSROACooperativeMatrix(Module *M) {
for (auto &F : *M) {
if (!F.getName().starts_with("_Z19__spirv_AccessChain"))
continue;
for (auto I = F.user_begin(), E = F.user_end(); I != E;) {
if (auto *CI = dyn_cast<CallInst>(*I++)) {
Instruction *Ptr =
dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
StructType *WrapperMatrixTy =
dyn_cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType());
if (!WrapperMatrixTy)
continue;
Type *MatrixTy = WrapperMatrixTy->getElementType(0);
AllocaInst *Alloca = nullptr;
{
IRBuilder Builder(CI);
IRBuilderBase::InsertPointGuard IG(Builder);
Builder.SetInsertPointPastAllocas(CI->getParent()->getParent());
Alloca = Builder.CreateAlloca(MatrixTy);
}
Ptr->replaceAllUsesWith(Alloca);
Ptr->dropAllReferences();
Ptr->eraseFromParent();
}
}
}
}

bool SPIRVRegularizeLLVMBase::runRegularizeLLVM(Module &Module) {
M = &Module;
Ctx = &M->getContext();
Expand Down Expand Up @@ -464,6 +492,7 @@ void regularizeWithOverflowInstrinsics(StringRef MangledName, CallInst *Call,
bool SPIRVRegularizeLLVMBase::regularize() {
eraseUselessFunctions(M);
expandSYCLTypeUsing(M);
finishSROACooperativeMatrix(M);

for (auto &GV : M->globals()) {
SPIRVBuiltinVariableKind Kind;
Expand Down
1 change: 1 addition & 0 deletions llvm-spirv/lib/SPIRV/SPIRVRegularizeLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class SPIRVRegularizeLLVMBase {
void expandSYCLTypeUsing(llvm::Module *M);
void expandVEDWithSYCLTypeSRetArg(llvm::Function *F);
void expandVIDWithSYCLTypeByValComp(llvm::Function *F);
void finishSROACooperativeMatrix(llvm::Module *M);

// According to the specification, the operands of a shift instruction must be
// a scalar/vector of integer. When LLVM-IR contains a shift instruction with
Expand Down
42 changes: 42 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6506,6 +6506,48 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
transValue(CI->getArgOperand(1), BB), MemoryAccess,
BB);
}
/* case OpAccessChain: {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to remove

Instruction *Ptr = dyn_cast<Instruction>(CI->getArgOperand(0)->stripPointerCasts());
// auto Class = transType(
// CI->getArgOperand(0)->getType())->getPointerStorageClass();
auto Class = StorageClassFunction;
SPIRVValue *Base = ValueMap[Ptr];
SPIRVType *SPVMatTy = Base->getType()->getPointerElementType();
if (!SPVMatTy->isTypeCooperativeMatrixKHR()) {
Type *MatrixTy = nullptr;
for (User *PtrU : Ptr->users()) {
auto *Inst = dyn_cast<Instruction>(PtrU);
Inst = dyn_cast<Instruction>(Inst->stripPointerCasts());
if (isa<StoreInst>(Inst)) {
MatrixTy = cast<StoreInst>(Inst)->getValueOperand()->getType();
break;
}
if (isa<LoadInst>(Inst)) {
MatrixTy = cast<LoadInst>(Inst)->getType();
break;
}
}
Type *MatrixTy = cast<StructType>(cast<AllocaInst>(Ptr)->getAllocatedType())->getElementType(0);
assert(MatrixTy && "not a matrix type");
SPVMatTy = transType(MatrixTy);
SPIRVType *PtrToMatrixType = BM->addPointerType(Class, SPVMatTy);
auto *EntryBB = BB->getParent()->getBasicBlock(0);
Base = BM->addVariable(
PtrToMatrixType, false, spv::internal::LinkageTypeInternal, nullptr,
Ptr->getName().str() + "_matrix", Class, EntryBB);
ValueMap[Ptr] = Base;
// auto *Base = BM->addUnaryInst(OpBitcast, PtrToMatrixType,
// transValue(CI->getArgOperand(0), BB), BB);
}
std::vector<SPIRVValue *> Indices;
for (size_t I = 1; I < CI->arg_size(); ++I) {
Indices.emplace_back(transValue(CI->getArgOperand(I), BB));
}
SPIRVType *ElemTy =
static_cast<SPIRVTypeCooperativeMatrixKHR *>(SPVMatTy)->getCompType();
SPIRVType *ElemPtrTy = BM->addPointerType(Class, ElemTy);
return BM->addAccessChainInst(ElemPtrTy, Base, Indices, BB, false);
}*/
case OpCompositeConstruct: {
std::vector<SPIRVId> Operands = {
transValue(CI->getArgOperand(0), BB)->getId()};
Expand Down
14 changes: 14 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ class SPIRVModuleImpl : public SPIRVModule {
SPIRVInstruction *addPtrAccessChainInst(SPIRVType *, SPIRVValue *,
std::vector<SPIRVValue *>,
SPIRVBasicBlock *, bool) override;
SPIRVInstruction *addAccessChainInst(SPIRVType *, SPIRVValue *,
std::vector<SPIRVValue *>,
SPIRVBasicBlock *, bool) override;
SPIRVInstruction *addAsyncGroupCopy(SPIRVValue *Scope, SPIRVValue *Dest,
SPIRVValue *Src, SPIRVValue *NumElems,
SPIRVValue *Stride, SPIRVValue *Event,
Expand Down Expand Up @@ -1703,6 +1706,17 @@ SPIRVModuleImpl::addPtrAccessChainInst(SPIRVType *Type, SPIRVValue *Base,
BB);
}

SPIRVInstruction *
SPIRVModuleImpl::addAccessChainInst(SPIRVType *Type, SPIRVValue *Base,
std::vector<SPIRVValue *> Indices,
SPIRVBasicBlock *BB, bool IsInBounds) {
return addInstruction(
SPIRVInstTemplateBase::create(
IsInBounds ? OpInBoundsAccessChain : OpAccessChain, Type,
getId(), getVec(Base->getId(), Base->getIds(Indices)), BB, this),
BB);
}

SPIRVInstruction *SPIRVModuleImpl::addAsyncGroupCopy(
SPIRVValue *Scope, SPIRVValue *Dest, SPIRVValue *Src, SPIRVValue *NumElems,
SPIRVValue *Stride, SPIRVValue *Event, SPIRVBasicBlock *BB) {
Expand Down
3 changes: 3 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ class SPIRVModule {
virtual SPIRVInstruction *addPtrAccessChainInst(SPIRVType *, SPIRVValue *,
std::vector<SPIRVValue *>,
SPIRVBasicBlock *, bool) = 0;
virtual SPIRVInstruction *addAccessChainInst(SPIRVType *, SPIRVValue *,
std::vector<SPIRVValue *>,
SPIRVBasicBlock *, bool) = 0;
virtual SPIRVInstruction *
addAsyncGroupCopy(SPIRVValue *Scope, SPIRVValue *Dest, SPIRVValue *Src,
SPIRVValue *NumElems, SPIRVValue *Stride, SPIRVValue *Event,
Expand Down
Loading