Skip to content

Commit ff49e34

Browse files
MrSidimsvmaksimo
authored andcommitted
Fix translation of APInt constants
Previously APInt constants were being stored into uint64_t value with following encoding/decoding. Now they are being packed into SPIRVWords array directly. Signed-off-by: Dmitry Sidorov <[email protected]>
1 parent ee7aded commit ff49e34

File tree

7 files changed

+75
-13
lines changed

7 files changed

+75
-13
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1497,10 +1497,27 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
14971497
}
14981498
switch (BT->getOpCode()) {
14991499
case OpTypeBool:
1500-
case OpTypeInt:
1500+
case OpTypeInt: {
1501+
const unsigned NumBits = BT->getBitWidth();
1502+
if (NumBits > 64) {
1503+
// Translate arbitrary precision integer constants
1504+
const unsigned RawDataNumWords = BConst->getNumWords();
1505+
const unsigned BigValNumWords = (RawDataNumWords + 1) / 2;
1506+
std::vector<uint64_t> BigValVec(BigValNumWords);
1507+
const SPIRVWord *RawData = BConst->getSPIRVWords();
1508+
// SPIRV words are integers of 32-bit width, meanwhile llvm::APInt
1509+
// is storing data using an array of 64-bit words. Here we pack SPIRV
1510+
// words into 64-bit integer array.
1511+
for (size_t I = 0; I != RawDataNumWords; ++I)
1512+
BigValVec[I / 2] =
1513+
(I % 2) ? BigValVec[I / 2] | ((uint64_t)RawData[I] << 32)
1514+
: BigValVec[I / 2] | ((uint64_t)RawData[I]);
1515+
return mapValue(BV, ConstantInt::get(LT, APInt(NumBits, BigValVec)));
1516+
}
15011517
return mapValue(
15021518
BV, ConstantInt::get(LT, ConstValue,
15031519
static_cast<SPIRVTypeInt *>(BT)->isSigned()));
1520+
}
15041521
case OpTypeFloat: {
15051522
const llvm::fltSemantics *FS = nullptr;
15061523
switch (BT->getFloatBitWidth()) {

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -684,8 +684,17 @@ SPIRVValue *LLVMToSPIRV::transConstant(Value *V) {
684684
return BM->addNullConstant(transType(AggType));
685685
}
686686

687-
if (auto ConstI = dyn_cast<ConstantInt>(V))
687+
if (auto ConstI = dyn_cast<ConstantInt>(V)) {
688+
unsigned BitWidth = ConstI->getType()->getBitWidth();
689+
if (BitWidth > 64) {
690+
BM->getErrorLog().checkError(
691+
BM->isAllowedToUseExtension(
692+
ExtensionID::SPV_INTEL_arbitrary_precision_integers),
693+
SPIRVEC_InvalidBitWidth, std::to_string(BitWidth));
694+
return BM->addConstant(transType(V->getType()), ConstI->getValue());
695+
}
688696
return BM->addConstant(transType(V->getType()), ConstI->getZExtValue());
697+
}
689698

690699
if (auto ConstFP = dyn_cast<ConstantFP>(V)) {
691700
auto BT = static_cast<SPIRVType *>(transType(V->getType()));

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
#include "SPIRVType.h"
4949
#include "SPIRVValue.h"
5050

51+
#include "llvm/ADT/APInt.h"
52+
5153
#include <set>
5254
#include <unordered_map>
5355
#include <unordered_set>
@@ -261,6 +263,7 @@ class SPIRVModuleImpl : public SPIRVModule {
261263
SPIRVFunction *F) override;
262264
SPIRVValue *addConstant(SPIRVValue *) override;
263265
SPIRVValue *addConstant(SPIRVType *, uint64_t) override;
266+
SPIRVValue *addConstant(SPIRVType *, llvm::APInt) override;
264267
SPIRVValue *addSpecConstant(SPIRVType *, uint64_t) override;
265268
SPIRVValue *addDoubleConstant(SPIRVTypeFloat *, double) override;
266269
SPIRVValue *addFloatConstant(SPIRVTypeFloat *, float) override;
@@ -1021,6 +1024,25 @@ SPIRVValue *SPIRVModuleImpl::addConstant(SPIRVType *Ty, uint64_t V) {
10211024
return addConstant(new SPIRVConstant(this, Ty, getId(), V));
10221025
}
10231026

1027+
// Complete constructor for AP integer constant
1028+
template <spv::Op OC>
1029+
SPIRVConstantBase<OC>::SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType,
1030+
SPIRVId TheId, llvm::APInt &TheValue)
1031+
: SPIRVValue(M, 0, OC, TheType, TheId) {
1032+
const uint64_t *BigValArr = TheValue.getRawData();
1033+
for (size_t I = 0; I != TheValue.getNumWords(); ++I) {
1034+
Union.Words[I * 2 + 1] =
1035+
(uint32_t)((BigValArr[I] & 0xFFFFFFFF00000000LL) >> 32);
1036+
Union.Words[I * 2] = (uint32_t)(BigValArr[I] & 0xFFFFFFFFLL);
1037+
}
1038+
recalculateWordCount();
1039+
validate();
1040+
}
1041+
1042+
SPIRVValue *SPIRVModuleImpl::addConstant(SPIRVType *Ty, llvm::APInt V) {
1043+
return addConstant(new SPIRVConstant(this, Ty, getId(), V));
1044+
}
1045+
10241046
SPIRVValue *SPIRVModuleImpl::addIntegerConstant(SPIRVTypeInt *Ty, uint64_t V) {
10251047
if (Ty->getBitWidth() == 32) {
10261048
unsigned I32 = static_cast<unsigned>(V);

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,10 @@
4848
#include <string>
4949
#include <vector>
5050

51+
namespace llvm {
52+
class APInt;
53+
} // namespace llvm
54+
5155
namespace SPIRV {
5256

5357
template <Op> class SPIRVConstantBase;
@@ -252,6 +256,7 @@ class SPIRVModule {
252256
SPIRVFunction *F) = 0;
253257
virtual SPIRVValue *addConstant(SPIRVValue *) = 0;
254258
virtual SPIRVValue *addConstant(SPIRVType *, uint64_t) = 0;
259+
virtual SPIRVValue *addConstant(SPIRVType *, llvm::APInt) = 0;
255260
virtual SPIRVValue *addSpecConstant(SPIRVType *, uint64_t) = 0;
256261
virtual SPIRVValue *addDoubleConstant(SPIRVTypeFloat *, double) = 0;
257262
virtual SPIRVValue *addFloatConstant(SPIRVTypeFloat *, float) = 0;

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class SPIRVTypeInt : public SPIRVType {
185185
(BitWidth <= 64 ||
186186
(Module->isAllowedToUseExtension(
187187
ExtensionID::SPV_INTEL_arbitrary_precision_integers) &&
188-
BitWidth <= 1024)) &&
188+
BitWidth <= 2048)) &&
189189
"Invalid bit width");
190190
}
191191

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVValue.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747
#include "SPIRVEntry.h"
4848
#include "SPIRVType.h"
4949

50+
namespace llvm {
51+
class APInt;
52+
} // namespace llvm
53+
5054
#include <iostream>
5155

5256
namespace SPIRV {
@@ -146,6 +150,9 @@ template <spv::Op OC> class SPIRVConstantBase : public SPIRVValue {
146150
recalculateWordCount();
147151
validate();
148152
}
153+
// Incomplete constructor for AP integer constant
154+
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
155+
llvm::APInt &TheValue);
149156
// Complete constructor for float constant
150157
SPIRVConstantBase(SPIRVModule *M, SPIRVType *TheType, SPIRVId TheId,
151158
float TheValue)
@@ -167,6 +174,8 @@ template <spv::Op OC> class SPIRVConstantBase : public SPIRVValue {
167174
uint64_t getZExtIntValue() const { return Union.UInt64Val; }
168175
float getFloatValue() const { return Union.FloatVal; }
169176
double getDoubleValue() const { return Union.DoubleVal; }
177+
unsigned getNumWords() const { return NumWords; }
178+
SPIRVWord *getSPIRVWords() { return Union.Words; }
170179

171180
protected:
172181
void recalculateWordCount() {
@@ -175,7 +184,7 @@ template <spv::Op OC> class SPIRVConstantBase : public SPIRVValue {
175184
}
176185
void validate() const override {
177186
SPIRVValue::validate();
178-
assert(NumWords >= 1 && NumWords <= 32 && "Invalid constant size");
187+
assert(NumWords >= 1 && NumWords <= 64 && "Invalid constant size");
179188
}
180189
void encode(spv_ostream &O) const override {
181190
getEncoder(O) << Type << Id;
@@ -197,7 +206,7 @@ template <spv::Op OC> class SPIRVConstantBase : public SPIRVValue {
197206
uint64_t UInt64Val;
198207
float FloatVal;
199208
double DoubleVal;
200-
SPIRVWord Words[32];
209+
SPIRVWord Words[64];
201210
UnionType() { UInt64Val = 0; }
202211
} Union;
203212
};

llvm-spirv/test/capability-arbitrary-precision-integers.ll

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818
; CHECK-SPIRV-DAG: TypeInt [[#I96:]] 96 0
1919
; CHECK-SPIRV-DAG: TypeInt [[#I128:]] 128 0
2020
; CHECK-SPIRV-DAG: TypeInt [[#I256:]] 256 0
21-
; CHECK-SPIRV-DAG: TypeInt [[#I1024:]] 1024 0
22-
; CHECK-SPIRV-DAG: Constant [[#I96]] [[#]] 1 0 0
21+
; CHECK-SPIRV-DAG: TypeInt [[#I2048:]] 2048 0
22+
; CHECK-SPIRV-DAG: Constant [[#I96]] [[#]] 4 0 1
2323
; CHECK-SPIRV-DAG: Constant [[#I128]] [[#]] 1 0 0 0
2424
; CHECK-SPIRV-DAG: Constant [[#I256]] [[#]] 1 0 0 0 0 0 0 0
25-
; CHECK-SPIRV-DAG: Constant [[#I1024]] [[#]] 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
25+
; CHECK-SPIRV-DAG: Constant [[#I2048]] [[#]] 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
2626

2727
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
2828
target triple = "spir64-unknown-unknown"
@@ -36,7 +36,7 @@ target triple = "spir64-unknown-unknown"
3636
@d = addrspace(1) global i96 0, align 8
3737
@e = addrspace(1) global i128 0, align 8
3838
@f = addrspace(1) global i256 0, align 8
39-
@g = addrspace(1) global i1024 0, align 8
39+
@g = addrspace(1) global i2048 0, align 8
4040

4141
; Function Attrs: noinline nounwind optnone
4242
; CHECK-LLVM: void @_Z4funci(i30 %a)
@@ -50,13 +50,13 @@ entry:
5050
store i30 1, i30* %a.addr, align 4
5151
; CHECK-LLVM: store i48 -4294901761, i48 addrspace(1)* @c
5252
store i48 -4294901761, i48 addrspace(1)* @c, align 8
53-
store i96 1, i96 addrspace(1)* @d, align 8
54-
; CHECK-LLVM: store i96 1, i96 addrspace(1)* @d
53+
store i96 18446744073709551620, i96 addrspace(1)* @d, align 8
54+
; CHECK-LLVM: store i96 18446744073709551620, i96 addrspace(1)* @d
5555
store i128 1, i128 addrspace(1)* @e, align 8
5656
; CHECK-LLVM: store i128 1, i128 addrspace(1)* @e
5757
store i256 1, i256 addrspace(1)* @f, align 8
5858
; CHECK-LLVM: store i256 1, i256 addrspace(1)* @f
59-
store i1024 1, i1024 addrspace(1)* @g, align 8
60-
; CHECK-LLVM: store i1024 1, i1024 addrspace(1)* @g
59+
store i2048 1, i2048 addrspace(1)* @g, align 8
60+
; CHECK-LLVM: store i2048 1, i2048 addrspace(1)* @g
6161
ret void
6262
}

0 commit comments

Comments
 (0)