Skip to content

Commit f0ed46e

Browse files
fix AddrSpaceCast
1 parent 06fd8df commit f0ed46e

File tree

6 files changed

+161
-28
lines changed

6 files changed

+161
-28
lines changed

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,11 @@ SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
11281128
SPIRVType *Type = getSPIRVTypeForVReg(VReg);
11291129
assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
11301130
Type->getOperand(1).isImm() && "Pointer type is expected");
1131+
return getPointerStorageClass(Type);
1132+
}
1133+
1134+
SPIRV::StorageClass::StorageClass
1135+
SPIRVGlobalRegistry::getPointerStorageClass(const SPIRVType *Type) const {
11311136
return static_cast<SPIRV::StorageClass::StorageClass>(
11321137
Type->getOperand(1).getImm());
11331138
}

llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,8 @@ class SPIRVGlobalRegistry {
388388

389389
// Gets the storage class of the pointer type assigned to this vreg.
390390
SPIRV::StorageClass::StorageClass getPointerStorageClass(Register VReg) const;
391+
SPIRV::StorageClass::StorageClass
392+
getPointerStorageClass(const SPIRVType *Type) const;
391393

392394
// Return the number of bits SPIR-V pointers and size_t variables require.
393395
unsigned getPointerSize() const { return PointerSize; }

llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp

Lines changed: 76 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,36 +1156,87 @@ static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
11561156
bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
11571157
const SPIRVType *ResType,
11581158
MachineInstr &I) const {
1159-
// If the AddrSpaceCast user is single and in OpConstantComposite or
1160-
// OpVariable, we should select OpSpecConstantOp.
1161-
auto UIs = MRI->use_instructions(ResVReg);
1162-
if (!UIs.empty() && ++UIs.begin() == UIs.end() &&
1163-
(UIs.begin()->getOpcode() == SPIRV::OpConstantComposite ||
1164-
UIs.begin()->getOpcode() == SPIRV::OpVariable ||
1165-
isSpvIntrinsic(*UIs.begin(), Intrinsic::spv_init_global))) {
1166-
Register NewReg = I.getOperand(1).getReg();
1167-
MachineBasicBlock &BB = *I.getParent();
1168-
SPIRVType *SpvBaseTy = GR.getOrCreateSPIRVIntegerType(8, I, TII);
1169-
ResType = GR.getOrCreateSPIRVPointerType(SpvBaseTy, I, TII,
1170-
SPIRV::StorageClass::Generic);
1171-
bool Result =
1172-
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpSpecConstantOp))
1173-
.addDef(ResVReg)
1174-
.addUse(GR.getSPIRVTypeID(ResType))
1175-
.addImm(static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric))
1176-
.addUse(NewReg)
1177-
.constrainAllUses(TII, TRI, RBI);
1178-
return Result;
1179-
}
1159+
MachineBasicBlock &BB = *I.getParent();
1160+
const DebugLoc &DL = I.getDebugLoc();
1161+
11801162
Register SrcPtr = I.getOperand(1).getReg();
11811163
SPIRVType *SrcPtrTy = GR.getSPIRVTypeForVReg(SrcPtr);
1182-
SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtr);
1183-
SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResVReg);
1164+
// don't generate a cast for a null that is represented by OpTypeInt
1165+
if (SrcPtrTy->getOpcode() != SPIRV::OpTypePointer ||
1166+
ResType->getOpcode() != SPIRV::OpTypePointer)
1167+
return BuildMI(BB, I, DL, TII.get(TargetOpcode::COPY))
1168+
.addDef(ResVReg)
1169+
.addUse(SrcPtr)
1170+
.constrainAllUses(TII, TRI, RBI);
1171+
1172+
SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass(SrcPtrTy);
1173+
SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass(ResType);
1174+
1175+
// AddrSpaceCast uses within OpVariable and OpConstantComposite instructions
1176+
// are expressed by OpSpecConstantOp with an Opcode.
1177+
bool IsGRef = false;
1178+
bool IsAllowedRefs =
1179+
std::all_of(MRI->use_instr_begin(ResVReg), MRI->use_instr_end(),
1180+
[&IsGRef](auto const &It) {
1181+
unsigned Opcode = It.getOpcode();
1182+
if (Opcode == SPIRV::OpConstantComposite ||
1183+
Opcode == SPIRV::OpVariable ||
1184+
isSpvIntrinsic(It, Intrinsic::spv_init_global))
1185+
return IsGRef = true;
1186+
return Opcode == SPIRV::OpName;
1187+
});
1188+
if (IsAllowedRefs && IsGRef) {
1189+
// TODO: insert a check whether the Kernel capability was declared.
1190+
unsigned SpecOpcode =
1191+
DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr(SrcSC)
1192+
? static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric)
1193+
: (SrcSC == SPIRV::StorageClass::Generic &&
1194+
isGenericCastablePtr(DstSC)
1195+
? static_cast<uint32_t>(SPIRV::Opcode::GenericCastToPtr)
1196+
: 0);
1197+
if (SpecOpcode) {
1198+
// TODO: OpConstantComposite expects i8*, so we are forced to forget a
1199+
// correct value of ResType and use general i8* instead. Maybe this should
1200+
// be addressed in the emit-intrinsic step to infer a correct
1201+
// OpConstantComposite type.
1202+
SPIRVType *NewResType = GR.getOrCreateSPIRVPointerType(
1203+
GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, DstSC);
1204+
bool Result = BuildMI(BB, I, DL, TII.get(SPIRV::OpSpecConstantOp))
1205+
.addDef(ResVReg)
1206+
.addUse(GR.getSPIRVTypeID(NewResType))
1207+
.addImm(SpecOpcode)
1208+
.addUse(SrcPtr)
1209+
.constrainAllUses(TII, TRI, RBI);
1210+
return Result;
1211+
} else if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
1212+
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
1213+
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1214+
Register Tmp = MRI->createVirtualRegister(&SPIRV::pIDRegClass);
1215+
MRI->setType(Tmp, LLT::pointer(0, 64));
1216+
GR.assignSPIRVTypeToVReg(GenericPtrTy, Tmp, *BB.getParent());
1217+
MachineInstrBuilder MIB =
1218+
BuildMI(BB, I, DL, TII.get(SPIRV::OpSpecConstantOp))
1219+
.addDef(Tmp)
1220+
.addUse(GR.getSPIRVTypeID(GenericPtrTy))
1221+
.addImm(static_cast<uint32_t>(SPIRV::Opcode::PtrCastToGeneric))
1222+
.addUse(SrcPtr);
1223+
GR.add(MIB.getInstr(), BB.getParent(), Tmp);
1224+
bool Result = MIB.constrainAllUses(TII, TRI, RBI);
1225+
SPIRVType *NewResType = GR.getOrCreateSPIRVPointerType(
1226+
GR.getOrCreateSPIRVIntegerType(8, I, TII), I, TII, DstSC);
1227+
return Result &&
1228+
BuildMI(BB, I, DL, TII.get(SPIRV::OpSpecConstantOp))
1229+
.addDef(ResVReg)
1230+
.addUse(GR.getSPIRVTypeID(NewResType))
1231+
.addImm(static_cast<uint32_t>(SPIRV::Opcode::GenericCastToPtr))
1232+
.addUse(Tmp)
1233+
.constrainAllUses(TII, TRI, RBI);
1234+
}
1235+
}
11841236

11851237
// don't generate a cast between identical storage classes
11861238
if (SrcSC == DstSC)
1187-
return BuildMI(*I.getParent(), I, I.getDebugLoc(),
1188-
TII.get(TargetOpcode::COPY))
1239+
return BuildMI(BB, I, DL, TII.get(TargetOpcode::COPY))
11891240
.addDef(ResVReg)
11901241
.addUse(SrcPtr)
11911242
.constrainAllUses(TII, TRI, RBI);
@@ -1201,8 +1252,6 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
12011252
Register Tmp = MRI->createVirtualRegister(&SPIRV::iIDRegClass);
12021253
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
12031254
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1204-
MachineBasicBlock &BB = *I.getParent();
1205-
const DebugLoc &DL = I.getDebugLoc();
12061255
bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
12071256
.addDef(Tmp)
12081257
.addUse(GR.getSPIRVTypeID(GenericPtrTy))

llvm/lib/Target/SPIRV/SPIRVPreLegalizer.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,21 @@ static SPIRVType *propagateSPIRVType(MachineInstr *MI, SPIRVGlobalRegistry *GR,
294294
default:
295295
break;
296296
}
297-
if (SpvType)
297+
if (SpvType) {
298+
// check if the address space needs correction
299+
LLT RegType = MRI.getType(Reg);
300+
if (SpvType->getOpcode() == SPIRV::OpTypePointer &&
301+
RegType.isPointer() &&
302+
storageClassToAddressSpace(GR->getPointerStorageClass(SpvType)) !=
303+
RegType.getAddressSpace()) {
304+
const SPIRVSubtarget &ST =
305+
MI->getParent()->getParent()->getSubtarget<SPIRVSubtarget>();
306+
SpvType = GR->getOrCreateSPIRVPointerType(
307+
GR->getPointeeType(SpvType), *MI, *ST.getInstrInfo(),
308+
addressSpaceToStorageClass(RegType.getAddressSpace(), ST));
309+
}
298310
GR->assignSPIRVTypeToVReg(SpvType, Reg, MIB.getMF());
311+
}
299312
if (!MRI.getRegClassOrNull(Reg))
300313
MRI.setRegClass(Reg, SpvType ? GR->getRegClass(SpvType)
301314
: &SPIRV::iIDRegClass);

llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1631,6 +1631,7 @@ multiclass OpcodeOperand<bits<32> value> {
16311631
defm InBoundsAccessChain : OpcodeOperand<66>;
16321632
defm InBoundsPtrAccessChain : OpcodeOperand<70>;
16331633
defm PtrCastToGeneric : OpcodeOperand<121>;
1634+
defm GenericCastToPtr : OpcodeOperand<122>;
16341635
defm Bitcast : OpcodeOperand<124>;
16351636
defm ConvertPtrToU : OpcodeOperand<117>;
16361637
defm ConvertUToPtr : OpcodeOperand<120>;
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv64-unknown-unknown %s -o - -filetype=obj | spirv-val %}
3+
4+
; CHECK-DAG: OpName %[[F:.*]] "F"
5+
; CHECK-DAG: OpName %[[B:.*]] "B"
6+
; CHECK-DAG: OpName %[[G1:.*]] "G1"
7+
; CHECK-DAG: OpName %[[G2:.*]] "G2"
8+
; CHECK-DAG: OpName %[[X:.*]] "X"
9+
; CHECK-DAG: OpName %[[Y:.*]] "Y"
10+
; CHECK-DAG: OpName %[[G3:.*]] "G3"
11+
; CHECK-DAG: OpName %[[G4:.*]] "G4"
12+
13+
; CHECK-DAG: %[[Int:.*]] = OpTypeInt 32 0
14+
; CHECK-DAG: %[[Char:.*]] = OpTypeInt 8 0
15+
; CHECK-DAG: %[[GenPtrChar:.*]] = OpTypePointer Generic %[[Char]]
16+
; CHECK-DAG: %[[CWPtrChar:.*]] = OpTypePointer CrossWorkgroup %[[Char]]
17+
; CHECK-DAG: %[[Arr1:.*]] = OpTypeArray %[[CWPtrChar]] %[[#]]
18+
; CHECK-DAG: %[[Struct1:.*]] = OpTypeStruct %8
19+
; CHECK-DAG: %[[Arr2:.*]] = OpTypeArray %[[GenPtrChar]] %[[#]]
20+
; CHECK-DAG: %[[Struct2:.*]] = OpTypeStruct %[[Arr2]]
21+
; CHECK-DAG: %[[GenPtr:.*]] = OpTypePointer Generic %[[Int]]
22+
; CHECK-DAG: %[[CWPtr:.*]] = OpTypePointer CrossWorkgroup %[[Int]]
23+
; CHECK-DAG: %[[WPtr:.*]] = OpTypePointer Workgroup %[[Int]]
24+
25+
; CHECK-DAG: %[[F]] = OpVariable %[[CWPtr]] CrossWorkgroup %[[#]]
26+
; CHECK-DAG: %[[GenF:.*]] = OpSpecConstantOp %[[GenPtrChar]] 121 %[[F]]
27+
; CHECK-DAG: %[[B]] = OpVariable %[[CWPtr]] CrossWorkgroup %[[#]]
28+
; CHECK-DAG: %[[GenB:.*]] = OpSpecConstantOp %[[GenPtrChar]] 121 %[[B]]
29+
; CHECK-DAG: %[[GenFB:.*]] = OpConstantComposite %[[Arr2]] %[[GenF]] %[[GenB]]
30+
; CHECK-DAG: %[[GenBF:.*]] = OpConstantComposite %[[Arr2]] %[[GenB]] %[[GenF]]
31+
; CHECK-DAG: %[[CG1:.*]] = OpConstantComposite %[[Struct2]] %[[GenFB]]
32+
; CHECK-DAG: %[[CG2:.*]] = OpConstantComposite %[[Struct2]] %[[GenBF]]
33+
34+
; CHECK-DAG: %[[X]] = OpVariable %[[WPtr]] Workgroup %[[#]]
35+
; CHECK-DAG: %[[GenX:.*]] = OpSpecConstantOp %[[GenPtr]] 121 %[[X]]
36+
; CHECK-DAG: %[[CWX:.*]] = OpSpecConstantOp %[[CWPtrChar]] 122 %[[GenX]]
37+
; CHECK-DAG: %[[Y]] = OpVariable %[[WPtr]] Workgroup %[[#]]
38+
; CHECK-DAG: %[[GenY:.*]] = OpSpecConstantOp %[[GenPtr]] 121 %[[Y]]
39+
; CHECK-DAG: %[[CWY:.*]] = OpSpecConstantOp %[[CWPtrChar]] 122 %[[GenY]]
40+
; CHECK-DAG: %[[CWXY:.*]] = OpConstantComposite %[[Arr1]] %[[CWX]] %[[CWY]]
41+
; CHECK-DAG: %[[CWYX:.*]] = OpConstantComposite %[[Arr1]] %[[CWY]] %[[CWX]]
42+
; CHECK-DAG: %[[CG3:.*]] = OpConstantComposite %[[Struct1]] %[[CWXY]]
43+
; CHECK-DAG: %[[CG4:.*]] = OpConstantComposite %[[Struct1]] %[[CWYX]]
44+
45+
; CHECK-DAG: %[[G4]] = OpVariable %[[#]] CrossWorkgroup %[[CG4]]
46+
; CHECK-DAG: %[[G3]] = OpVariable %[[#]] CrossWorkgroup %[[CG3]]
47+
; CHECK-DAG: %[[G2]] = OpVariable %[[#]] CrossWorkgroup %[[CG2]]
48+
; CHECK-DAG: %[[G1]] = OpVariable %[[#]] CrossWorkgroup %[[CG1]]
49+
50+
@F = addrspace(1) constant i32 0
51+
@B = addrspace(1) constant i32 1
52+
@G1 = addrspace(1) constant { [2 x ptr addrspace(4)] } { [2 x ptr addrspace(4)] [ptr addrspace(4) addrspacecast (ptr addrspace(1) @F to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @B to ptr addrspace(4))] }
53+
@G2 = addrspace(1) constant { [2 x ptr addrspace(4)] } { [2 x ptr addrspace(4)] [ptr addrspace(4) addrspacecast (ptr addrspace(1) @B to ptr addrspace(4)), ptr addrspace(4) addrspacecast (ptr addrspace(1) @F to ptr addrspace(4))] }
54+
55+
@X = addrspace(3) constant i32 0
56+
@Y = addrspace(3) constant i32 1
57+
@G3 = addrspace(1) constant { [2 x ptr addrspace(1)] } { [2 x ptr addrspace(1)] [ptr addrspace(1) addrspacecast (ptr addrspace(3) @X to ptr addrspace(1)), ptr addrspace(1) addrspacecast (ptr addrspace(3) @Y to ptr addrspace(1))] }
58+
@G4 = addrspace(1) constant { [2 x ptr addrspace(1)] } { [2 x ptr addrspace(1)] [ptr addrspace(1) addrspacecast (ptr addrspace(3) @Y to ptr addrspace(1)), ptr addrspace(1) addrspacecast (ptr addrspace(3) @X to ptr addrspace(1))] }
59+
60+
define void @foo() {
61+
entry:
62+
ret void
63+
}

0 commit comments

Comments
 (0)