@@ -1156,36 +1156,87 @@ static bool isUSMStorageClass(SPIRV::StorageClass::StorageClass SC) {
1156
1156
bool SPIRVInstructionSelector::selectAddrSpaceCast (Register ResVReg,
1157
1157
const SPIRVType *ResType,
1158
1158
MachineInstr &I) const {
1159
- // If the AddrSpaceCast user is single and in OpConstantComposite or
1160
- // OpVariable, we should select OpSpecConstantOp.
1161
- auto UIs = MRI->use_instructions (ResVReg);
1162
- if (!UIs.empty () && ++UIs.begin () == UIs.end () &&
1163
- (UIs.begin ()->getOpcode () == SPIRV::OpConstantComposite ||
1164
- UIs.begin ()->getOpcode () == SPIRV::OpVariable ||
1165
- isSpvIntrinsic (*UIs.begin (), Intrinsic::spv_init_global))) {
1166
- Register NewReg = I.getOperand (1 ).getReg ();
1167
- MachineBasicBlock &BB = *I.getParent ();
1168
- SPIRVType *SpvBaseTy = GR.getOrCreateSPIRVIntegerType (8 , I, TII);
1169
- ResType = GR.getOrCreateSPIRVPointerType (SpvBaseTy, I, TII,
1170
- SPIRV::StorageClass::Generic);
1171
- bool Result =
1172
- BuildMI (BB, I, I.getDebugLoc (), TII.get (SPIRV::OpSpecConstantOp))
1173
- .addDef (ResVReg)
1174
- .addUse (GR.getSPIRVTypeID (ResType))
1175
- .addImm (static_cast <uint32_t >(SPIRV::Opcode::PtrCastToGeneric))
1176
- .addUse (NewReg)
1177
- .constrainAllUses (TII, TRI, RBI);
1178
- return Result;
1179
- }
1159
+ MachineBasicBlock &BB = *I.getParent ();
1160
+ const DebugLoc &DL = I.getDebugLoc ();
1161
+
1180
1162
Register SrcPtr = I.getOperand (1 ).getReg ();
1181
1163
SPIRVType *SrcPtrTy = GR.getSPIRVTypeForVReg (SrcPtr);
1182
- SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass (SrcPtr);
1183
- SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass (ResVReg);
1164
+ // don't generate a cast for a null that is represented by OpTypeInt
1165
+ if (SrcPtrTy->getOpcode () != SPIRV::OpTypePointer ||
1166
+ ResType->getOpcode () != SPIRV::OpTypePointer)
1167
+ return BuildMI (BB, I, DL, TII.get (TargetOpcode::COPY))
1168
+ .addDef (ResVReg)
1169
+ .addUse (SrcPtr)
1170
+ .constrainAllUses (TII, TRI, RBI);
1171
+
1172
+ SPIRV::StorageClass::StorageClass SrcSC = GR.getPointerStorageClass (SrcPtrTy);
1173
+ SPIRV::StorageClass::StorageClass DstSC = GR.getPointerStorageClass (ResType);
1174
+
1175
+ // AddrSpaceCast uses within OpVariable and OpConstantComposite instructions
1176
+ // are expressed by OpSpecConstantOp with an Opcode.
1177
+ bool IsGRef = false ;
1178
+ bool IsAllowedRefs =
1179
+ std::all_of (MRI->use_instr_begin (ResVReg), MRI->use_instr_end (),
1180
+ [&IsGRef](auto const &It) {
1181
+ unsigned Opcode = It.getOpcode ();
1182
+ if (Opcode == SPIRV::OpConstantComposite ||
1183
+ Opcode == SPIRV::OpVariable ||
1184
+ isSpvIntrinsic (It, Intrinsic::spv_init_global))
1185
+ return IsGRef = true ;
1186
+ return Opcode == SPIRV::OpName;
1187
+ });
1188
+ if (IsAllowedRefs && IsGRef) {
1189
+ // TODO: insert a check whether the Kernel capability was declared.
1190
+ unsigned SpecOpcode =
1191
+ DstSC == SPIRV::StorageClass::Generic && isGenericCastablePtr (SrcSC)
1192
+ ? static_cast <uint32_t >(SPIRV::Opcode::PtrCastToGeneric)
1193
+ : (SrcSC == SPIRV::StorageClass::Generic &&
1194
+ isGenericCastablePtr (DstSC)
1195
+ ? static_cast <uint32_t >(SPIRV::Opcode::GenericCastToPtr)
1196
+ : 0 );
1197
+ if (SpecOpcode) {
1198
+ // TODO: OpConstantComposite expects i8*, so we are forced to forget a
1199
+ // correct value of ResType and use general i8* instead. Maybe this should
1200
+ // be addressed in the emit-intrinsic step to infer a correct
1201
+ // OpConstantComposite type.
1202
+ SPIRVType *NewResType = GR.getOrCreateSPIRVPointerType (
1203
+ GR.getOrCreateSPIRVIntegerType (8 , I, TII), I, TII, DstSC);
1204
+ bool Result = BuildMI (BB, I, DL, TII.get (SPIRV::OpSpecConstantOp))
1205
+ .addDef (ResVReg)
1206
+ .addUse (GR.getSPIRVTypeID (NewResType))
1207
+ .addImm (SpecOpcode)
1208
+ .addUse (SrcPtr)
1209
+ .constrainAllUses (TII, TRI, RBI);
1210
+ return Result;
1211
+ } else if (isGenericCastablePtr (SrcSC) && isGenericCastablePtr (DstSC)) {
1212
+ SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType (
1213
+ GR.getPointeeType (SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1214
+ Register Tmp = MRI->createVirtualRegister (&SPIRV::pIDRegClass);
1215
+ MRI->setType (Tmp, LLT::pointer (0 , 64 ));
1216
+ GR.assignSPIRVTypeToVReg (GenericPtrTy, Tmp, *BB.getParent ());
1217
+ MachineInstrBuilder MIB =
1218
+ BuildMI (BB, I, DL, TII.get (SPIRV::OpSpecConstantOp))
1219
+ .addDef (Tmp)
1220
+ .addUse (GR.getSPIRVTypeID (GenericPtrTy))
1221
+ .addImm (static_cast <uint32_t >(SPIRV::Opcode::PtrCastToGeneric))
1222
+ .addUse (SrcPtr);
1223
+ GR.add (MIB.getInstr (), BB.getParent (), Tmp);
1224
+ bool Result = MIB.constrainAllUses (TII, TRI, RBI);
1225
+ SPIRVType *NewResType = GR.getOrCreateSPIRVPointerType (
1226
+ GR.getOrCreateSPIRVIntegerType (8 , I, TII), I, TII, DstSC);
1227
+ return Result &&
1228
+ BuildMI (BB, I, DL, TII.get (SPIRV::OpSpecConstantOp))
1229
+ .addDef (ResVReg)
1230
+ .addUse (GR.getSPIRVTypeID (NewResType))
1231
+ .addImm (static_cast <uint32_t >(SPIRV::Opcode::GenericCastToPtr))
1232
+ .addUse (Tmp)
1233
+ .constrainAllUses (TII, TRI, RBI);
1234
+ }
1235
+ }
1184
1236
1185
1237
// don't generate a cast between identical storage classes
1186
1238
if (SrcSC == DstSC)
1187
- return BuildMI (*I.getParent (), I, I.getDebugLoc (),
1188
- TII.get (TargetOpcode::COPY))
1239
+ return BuildMI (BB, I, DL, TII.get (TargetOpcode::COPY))
1189
1240
.addDef (ResVReg)
1190
1241
.addUse (SrcPtr)
1191
1242
.constrainAllUses (TII, TRI, RBI);
@@ -1201,8 +1252,6 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
1201
1252
Register Tmp = MRI->createVirtualRegister (&SPIRV::iIDRegClass);
1202
1253
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType (
1203
1254
GR.getPointeeType (SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
1204
- MachineBasicBlock &BB = *I.getParent ();
1205
- const DebugLoc &DL = I.getDebugLoc ();
1206
1255
bool Success = BuildMI (BB, I, DL, TII.get (SPIRV::OpPtrCastToGeneric))
1207
1256
.addDef (Tmp)
1208
1257
.addUse (GR.getSPIRVTypeID (GenericPtrTy))
0 commit comments