Skip to content

[DO NOT MERGE][DO NOT REVIEW] Specialization constant #1084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm-spirv/include/LLVMSPIRVLib.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ bool writeSpirv(Module *M, const SPIRV::TranslatorOpts &Opts, std::ostream &OS,
bool readSpirv(LLVMContext &C, const SPIRV::TranslatorOpts &Opts,
std::istream &IS, Module *&M, std::string &ErrMsg);

using SpecConstInfoTy = std::pair<uint32_t, uint32_t>;
void getSpecConstInfo(std::istream &IS,
std::vector<SpecConstInfoTy> &SpecConstInfo);

/// \brief Convert a SPIRVModule into LLVM IR.
/// \returns null on failure.
std::unique_ptr<Module>
Expand Down
14 changes: 14 additions & 0 deletions llvm-spirv/include/LLVMSPIRVOpts.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <cassert>
#include <cstdint>
#include <map>
#include <unordered_map>

namespace SPIRV {

Expand Down Expand Up @@ -99,12 +100,25 @@ class TranslatorOpts {

void enableGenArgNameMD() { GenKernelArgNameMD = true; }

void setSpecConst(uint32_t SpecId, uint64_t SpecValue) {
ExternalSpecialization[SpecId] = SpecValue;
}

bool getSpecializationConstant(uint32_t SpecId, uint64_t &Value) const {
auto It = ExternalSpecialization.find(SpecId);
if (It == ExternalSpecialization.end())
return false;
Value = It->second;
return true;
}

private:
// Common translation options
VersionNumber MaxVersion = VersionNumber::MaximumVersion;
ExtensionsStatusMap ExtStatusMap;
// SPIR-V to LLVM translation options
bool GenKernelArgNameMD;
std::unordered_map<uint32_t, uint64_t> ExternalSpecialization;
};

} // namespace SPIRV
Expand Down
87 changes: 79 additions & 8 deletions llvm-spirv/lib/SPIRV/SPIRVReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1327,15 +1327,28 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,

// Translation of non-instruction values
switch (OC) {
case OpConstant: {
case OpConstant:
case OpSpecConstant: {
SPIRVConstant *BConst = static_cast<SPIRVConstant *>(BV);
SPIRVType *BT = BV->getType();
Type *LT = transType(BT);
uint64_t ConstValue = BConst->getZExtIntValue();
SPIRVWord SpecId = 0;
if (OC == OpSpecConstant && BV->hasDecorate(DecorationSpecId, 0, &SpecId)) {
// Update the value with possibly provided external specialization.
if (BM->getSpecializationConstant(SpecId, ConstValue)) {
assert(
(BT->getBitWidth() == 64 ||
(ConstValue >> BT->getBitWidth()) == 0) &&
"Size of externally provided specialization constant value doesn't"
"fit into the specialization constant type");
}
}
switch (BT->getOpCode()) {
case OpTypeBool:
case OpTypeInt:
return mapValue(
BV, ConstantInt::get(LT, BConst->getZExtIntValue(),
BV, ConstantInt::get(LT, ConstValue,
static_cast<SPIRVTypeInt *>(BT)->isSigned()));
case OpTypeFloat: {
const llvm::fltSemantics *FS = nullptr;
Expand All @@ -1350,12 +1363,10 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
FS = &APFloat::IEEEdouble();
break;
default:
llvm_unreachable("invalid float type");
llvm_unreachable("invalid floating-point type");
}
return mapValue(
BV, ConstantFP::get(*Context,
APFloat(*FS, APInt(BT->getFloatBitWidth(),
BConst->getZExtIntValue()))));
APFloat FPConstValue(*FS, APInt(BT->getFloatBitWidth(), ConstValue));
return mapValue(BV, ConstantFP::get(*Context, FPConstValue));
}
default:
llvm_unreachable("Not implemented");
Expand All @@ -1369,12 +1380,27 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
case OpConstantFalse:
return mapValue(BV, ConstantInt::getFalse(*Context));

case OpSpecConstantTrue:
case OpSpecConstantFalse: {
bool IsTrue = OC == OpSpecConstantTrue;
SPIRVWord SpecId = 0;
if (BV->hasDecorate(DecorationSpecId, 0, &SpecId)) {
uint64_t ConstValue = 0;
if (BM->getSpecializationConstant(SpecId, ConstValue)) {
IsTrue = ConstValue;
}
}
return mapValue(BV, IsTrue ? ConstantInt::getTrue(*Context)
: ConstantInt::getFalse(*Context));
}

case OpConstantNull: {
auto LT = transType(BV->getType());
return mapValue(BV, Constant::getNullValue(LT));
}

case OpConstantComposite: {
case OpConstantComposite:
case OpSpecConstantComposite: {
auto BCC = static_cast<SPIRVConstantComposite *>(BV);
std::vector<Constant *> CV;
for (auto &I : BCC->getElements())
Expand Down Expand Up @@ -3622,3 +3648,48 @@ bool llvm::readSpirv(LLVMContext &C, const SPIRV::TranslatorOpts &Opts,

return true;
}

void llvm::getSpecConstInfo(std::istream &IS,
std::vector<SpecConstInfoTy> &SpecConstInfo) {
std::unique_ptr<SPIRVModule> BM(SPIRVModule::createSPIRVModule());
BM->setAutoAddExtensions(false);
SPIRVDecoder D(IS, *BM);
SPIRVWord Magic;
D >> Magic;
if (!BM->getErrorLog().checkError(Magic == MagicNumber, SPIRVEC_InvalidModule,
"invalid magic number")) {
return;
}
// Skip the rest of the header
D.ignore(4);

// According to the logical layout of SPIRV module (p2.4 of the spec),
// all constant instructions must appear before function declarations.
while (D.OpCode != OpFunction && D.getWordCountAndOpCode()) {
switch (D.OpCode) {
case OpDecorate:
// The decoration is added to the module in scope of SPIRVDecorate::decode
D.getEntry();
break;
case OpTypeBool:
case OpTypeInt:
case OpTypeFloat:
BM->addEntry(D.getEntry());
break;
case OpSpecConstant:
case OpSpecConstantTrue:
case OpSpecConstantFalse: {
auto *C = BM->addConstant(static_cast<SPIRVValue *>(D.getEntry()));
SPIRVWord SpecConstIdLiteral = 0;
if (C->hasDecorate(DecorationSpecId, 0, &SpecConstIdLiteral)) {
SPIRVType *Ty = C->getType();
uint32_t SpecConstSize = Ty->isTypeBool() ? 1 : Ty->getBitWidth() / 8;
SpecConstInfo.emplace_back(SpecConstIdLiteral, SpecConstSize);
}
break;
}
default:
D.ignoreInstruction();
}
}
}
27 changes: 26 additions & 1 deletion llvm-spirv/lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1749,9 +1749,12 @@ SPIRVValue *LLVMToSPIRV::transDirectCallInst(CallInst *CI,
return oclTransSpvcCastSampler(CI, BB);

if (oclIsBuiltin(MangledName, &DemangledName) ||
isDecoratedSPIRVFunc(F, &DemangledName))
isDecoratedSPIRVFunc(F, &DemangledName)) {
if (auto BV = transBuiltinToConstant(DemangledName, CI))
return BV;
if (auto BV = transBuiltinToInst(DemangledName, MangledName, CI, BB))
return BV;
}

SmallVector<std::string, 2> Dec;
if (isBuiltinTransToExtInst(CI->getCalledFunction(), &ExtSetKind, &ExtOp,
Expand Down Expand Up @@ -2042,6 +2045,28 @@ void LLVMToSPIRV::oclGetMutatedArgumentTypesByBuiltin(
ChangedType[1] = getSamplerType(F->getParent());
}

SPIRVValue *
LLVMToSPIRV::transBuiltinToConstant(const std::string &DemangledName,
CallInst *CI) {
Op OC = getSPIRVFuncOC(DemangledName);
if(!isSpecConstantOpCode(OC))
return nullptr;
Type *Ty = CI->getArgOperand(1)->getType();
assert(Ty == CI->getType() && "Type mismatch!");
Value *V = CI->getArgOperand(1);
uint64_t Val = 0;
if(Ty->isIntegerTy())
Val = cast<ConstantInt>(V)->getZExtValue();
else if(Ty->isFloatingPointTy())
Val = cast<ConstantFP>(V)->getValueAPF().bitcastToAPInt().getZExtValue();
else
return nullptr;
SPIRVValue *SC = BM->addSpecConstant(transType(Ty), Val);
uint64_t SpecId = cast<ConstantInt>(CI->getArgOperand(0))->getZExtValue();
SC->addDecorate(DecorationSpecId, SpecId);
return SC;
}

SPIRVInstruction *
LLVMToSPIRV::transBuiltinToInst(const std::string &DemangledName,
const std::string &MangledName, CallInst *CI,
Expand Down
2 changes: 2 additions & 0 deletions llvm-spirv/lib/SPIRV/SPIRVWriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class LLVMToSPIRV : public ModulePass {
SmallVectorImpl<std::string> *Dec = nullptr);
bool oclIsKernel(Function *F);
bool transOCLKernelMetadata();
SPIRVValue *transBuiltinToConstant(const std::string &DemangledName,
CallInst *CI);
SPIRVInstruction *transBuiltinToInst(const std::string &DemangledName,
const std::string &MangledName,
CallInst *CI, SPIRVBasicBlock *BB);
Expand Down
4 changes: 0 additions & 4 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVEntry.h
Original file line number Diff line number Diff line change
Expand Up @@ -771,10 +771,6 @@ template <spv::Op OC> bool isa(SPIRVEntry *E) {
_SPIRV_OP(Nop)
_SPIRV_OP(SourceContinued)
_SPIRV_OP(TypeRuntimeArray)
_SPIRV_OP(SpecConstantTrue)
_SPIRV_OP(SpecConstantFalse)
_SPIRV_OP(SpecConstant)
_SPIRV_OP(SpecConstantComposite)
_SPIRV_OP(Image)
_SPIRV_OP(ImageTexelPointer)
_SPIRV_OP(ImageSampleDrefImplicitLod)
Expand Down
11 changes: 11 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ class SPIRVModuleImpl : public SPIRVModule {
const std::vector<SPIRVValue *> &) override;
SPIRVValue *addConstant(SPIRVValue *) override;
SPIRVValue *addConstant(SPIRVType *, uint64_t) override;
SPIRVValue *addSpecConstant(SPIRVType *, uint64_t) override;
SPIRVValue *addDoubleConstant(SPIRVTypeFloat *, double) override;
SPIRVValue *addFloatConstant(SPIRVTypeFloat *, float) override;
SPIRVValue *addIntegerConstant(SPIRVTypeInt *, uint64_t) override;
Expand Down Expand Up @@ -1047,6 +1048,16 @@ SPIRVValue *SPIRVModuleImpl::addUndef(SPIRVType *TheType) {
return addConstant(new SPIRVUndef(this, TheType, getId()));
}

SPIRVValue *SPIRVModuleImpl::addSpecConstant(SPIRVType *Ty, uint64_t V) {
if (Ty->isTypeBool()) {
if (V)
return add(new SPIRVSpecConstantTrue(this, Ty, getId()));
else
return add(new SPIRVSpecConstantFalse(this, Ty, getId()));
}
return add(new SPIRVSpecConstant(this, Ty, getId(), V));
}

// Instruction creation functions

SPIRVInstruction *
Expand Down
9 changes: 8 additions & 1 deletion llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@

namespace SPIRV {

template <Op> class SPIRVConstantBase;
using SPIRVConstant = SPIRVConstantBase<OpConstant>;

class SPIRVBasicBlock;
class SPIRVConstant;
class SPIRVEntry;
class SPIRVFunction;
class SPIRVInstruction;
Expand Down Expand Up @@ -245,6 +247,7 @@ class SPIRVModule {
addCompositeConstant(SPIRVType *, const std::vector<SPIRVValue *> &) = 0;
virtual SPIRVValue *addConstant(SPIRVValue *) = 0;
virtual SPIRVValue *addConstant(SPIRVType *, uint64_t) = 0;
virtual SPIRVValue *addSpecConstant(SPIRVType *, uint64_t) = 0;
virtual SPIRVValue *addDoubleConstant(SPIRVTypeFloat *, double) = 0;
virtual SPIRVValue *addFloatConstant(SPIRVTypeFloat *, float) = 0;
virtual SPIRVValue *addIntegerConstant(SPIRVTypeInt *, uint64_t) = 0;
Expand Down Expand Up @@ -449,6 +452,10 @@ class SPIRVModule {
return TranslationOpts.isGenArgNameMDEnabled();
}

bool getSpecializationConstant(SPIRVWord SpecId, uint64_t &ConstValue) {
return TranslationOpts.getSpecializationConstant(SpecId, ConstValue);
}

// I/O functions
friend spv_ostream &operator<<(spv_ostream &O, SPIRVModule &M);
friend std::istream &operator>>(std::istream &I, SPIRVModule &M);
Expand Down
5 changes: 5 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVOpCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,11 @@ inline bool isTypeOpCode(Op OpCode) {
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL;
}

inline bool isSpecConstantOpCode(Op OpCode) {
unsigned OC = OpCode;
return OpSpecConstantTrue <= OC && OC <= OpSpecConstantOp;
}

inline bool isConstantOpCode(Op OpCode) {
unsigned OC = OpCode;
return (OpConstantTrue <= OC && OC <= OpSpecConstantOp) || OC == OpUndef ||
Expand Down
22 changes: 22 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVStream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
#include "SPIRVNameMapEnum.h"
#include "SPIRVOpCode.h"

#include <limits> // std::numeric_limits

namespace SPIRV {

/// Write string with quote. Replace " with \".
Expand Down Expand Up @@ -256,6 +258,12 @@ SPIRVEntry *SPIRVDecoder::getEntry() {
}
}

if (!M.getErrorLog().checkError(Entry->isImplemented(),
SPIRVEC_UnimplementedOpCode,
std::to_string(Entry->getOpCode()))) {
M.setInvalid();
}

assert(!IS.bad() && !IS.fail() && "SPIRV stream fails");
return Entry;
}
Expand All @@ -266,6 +274,20 @@ void SPIRVDecoder::validate() const {
assert(!IS.bad() && "Bad iInput stream");
}

// Skip \param n words in SPIR-V binary stream.
// In case of SPIR-V text format always skip until the end of the line.
void SPIRVDecoder::ignore(size_t N) {
#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (SPIRVUseTextFormat) {
IS.ignore(std::numeric_limits<std::streamsize>::max(), '\n');
return;
}
#endif
IS.ignore(N * sizeof(SPIRVWord));
}

void SPIRVDecoder::ignoreInstruction() { ignore(WordCount - 1); }

spv_ostream &operator<<(spv_ostream &O, const SPIRVNL &E) {
#ifdef _SPIRV_SUPPORT_TEXT_FMT
if (SPIRVUseTextFormat)
Expand Down
2 changes: 2 additions & 0 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ class SPIRVDecoder {
bool getWordCountAndOpCode();
SPIRVEntry *getEntry();
void validate() const;
void ignore(size_t N);
void ignoreInstruction();

std::istream &IS;
SPIRVModule &M;
Expand Down
3 changes: 1 addition & 2 deletions llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class SPIRVTypeFloat : public SPIRVType {
SPIRVCapVec CV;
if (isTypeFloat(16)) {
CV.push_back(CapabilityFloat16Buffer);
auto Extensions = getModule()->getExtension();
auto Extensions = getModule()->getSourceExtension();
if (std::any_of(Extensions.begin(), Extensions.end(),
[](const std::string &I) { return I == "cl_khr_fp16"; }))
CV.push_back(CapabilityFloat16);
Expand Down Expand Up @@ -355,7 +355,6 @@ class SPIRVTypeMatrix : public SPIRVType {
SPIRVWord ColCount; // Column Count
};

class SPIRVConstant;
class SPIRVTypeArray : public SPIRVType {
public:
// Complete constructor
Expand Down
Loading