Skip to content

[SPIR-V] Translate complex nested vector expressions instead of lowering them #5183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 1 addition & 38 deletions llvm-spirv/lib/SPIRV/SPIRVLowerConstExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,53 +167,16 @@ void SPIRVLowerConstExprBase::visit(Module *M) {
};

WorkList.pop_front();
auto LowerConstantVec = [&II, &LowerOp, &WorkList,
&M](ConstantVector *Vec,
unsigned NumOfOp) -> Value * {
if (std::all_of(Vec->op_begin(), Vec->op_end(), [](Value *V) {
return isa<ConstantExpr>(V) || isa<Function>(V);
})) {
// Expand a vector of constexprs and construct it back with
// series of insertelement instructions
std::list<Value *> OpList;
std::transform(Vec->op_begin(), Vec->op_end(),
std::back_inserter(OpList),
[LowerOp](Value *V) { return LowerOp(V); });
Value *Repl = nullptr;
unsigned Idx = 0;
auto *PhiII = dyn_cast<PHINode>(II);
auto *InsPoint =
PhiII ? &PhiII->getIncomingBlock(NumOfOp)->back() : II;
std::list<Instruction *> ReplList;
for (auto V : OpList) {
if (auto *Inst = dyn_cast<Instruction>(V))
ReplList.push_back(Inst);
Repl = InsertElementInst::Create(
(Repl ? Repl : UndefValue::get(Vec->getType())), V,
ConstantInt::get(Type::getInt32Ty(M->getContext()), Idx++), "",
InsPoint);
}
WorkList.splice(WorkList.begin(), ReplList);
return Repl;
}
return nullptr;
};

for (unsigned OI = 0, OE = II->getNumOperands(); OI != OE; ++OI) {
auto *Op = II->getOperand(OI);
if (auto *Vec = dyn_cast<ConstantVector>(Op)) {
Value *ReplInst = LowerConstantVec(Vec, OI);
if (ReplInst)
II->replaceUsesOfWith(Op, ReplInst);
} else if (auto CE = dyn_cast<ConstantExpr>(Op)) {
if (auto *CE = dyn_cast<ConstantExpr>(Op)) {
WorkList.push_front(cast<Instruction>(LowerOp(CE)));
} else if (auto MDAsVal = dyn_cast<MetadataAsValue>(Op)) {
Metadata *MD = MDAsVal->getMetadata();
if (auto ConstMD = dyn_cast<ConstantAsMetadata>(MD)) {
Constant *C = ConstMD->getValue();
Value *ReplInst = nullptr;
if (auto *Vec = dyn_cast<ConstantVector>(C))
ReplInst = LowerConstantVec(Vec, OI);
if (auto *CE = dyn_cast<ConstantExpr>(C))
ReplInst = LowerOp(CE);
if (ReplInst) {
Expand Down
7 changes: 5 additions & 2 deletions llvm-spirv/lib/SPIRV/SPIRVWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,8 @@ SPIRVValue *LLVMToSPIRVBase::transValue(Value *V, SPIRVBasicBlock *BB,

SPIRVDBG(dbgs() << "[transValue] " << *V << '\n');
assert((!isa<Instruction>(V) || isa<GetElementPtrInst>(V) ||
isa<CastInst>(V) || BB) &&
isa<CastInst>(V) || isa<ExtractElementInst>(V) ||
isa<BinaryOperator>(V) || BB) &&
"Invalid SPIRV BB");

auto BV = transValueWithoutDecoration(V, BB, CreateForward, FuncTrans);
Expand All @@ -995,7 +996,9 @@ SPIRVInstruction *LLVMToSPIRVBase::transBinaryInst(BinaryOperator *B,
transBoolOpCode(Op0, OpCodeMap::map(LLVMOC)), transType(B->getType()),
Op0, transValue(B->getOperand(1), BB), BB);

if (isUnfusedMulAdd(B)) {
// BinaryOperator can have no parent if it is handled as an expression inside
// another instruction.
if (B->getParent() && isUnfusedMulAdd(B)) {
Function *F = B->getFunction();
SPIRVDBG(dbgs() << "[fp-contract] disabled for " << F->getName()
<< ": possible fma candidate " << *B << '\n');
Expand Down
13 changes: 12 additions & 1 deletion llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,18 @@ SPIRVSpecConstantOp *createSpecConstantOpInst(SPIRVInstruction *Inst) {
auto OC = Inst->getOpCode();
assert(isSpecConstantOpAllowedOp(OC) &&
"Op code not allowed for OpSpecConstantOp");
auto Ops = Inst->getIds(Inst->getOperands());
std::vector<SPIRVWord> Ops;

// CompositeExtract/Insert operations use zero-based numbering for their
// indexes (containted in instruction operands). All their operands are
// Literals, so we can pass them as is for further handling.
if (OC == OpCompositeExtract || OC == OpCompositeInsert) {
auto *SPIRVInst = static_cast<SPIRVInstTemplateBase *>(Inst);
Ops = SPIRVInst->getOpWords();
} else {
Ops = Inst->getIds(Inst->getOperands());
}

Ops.insert(Ops.begin(), OC);
return static_cast<SPIRVSpecConstantOp *>(SPIRVSpecConstantOp::create(
OpSpecConstantOp, Inst->getType(), Inst->getId(), Ops, nullptr,
Expand Down
20 changes: 19 additions & 1 deletion llvm-spirv/lib/SPIRV/libSPIRV/SPIRVInstruction.h
Original file line number Diff line number Diff line change
Expand Up @@ -2757,11 +2757,29 @@ _SPIRV_OP(ImageQuerySamples, true, 4)
#define _SPIRV_OP(x, ...) \
typedef SPIRVInstTemplate<SPIRVInstTemplateBase, Op##x, __VA_ARGS__> SPIRV##x;
// Other instructions
_SPIRV_OP(SpecConstantOp, true, 4, true, 0)
_SPIRV_OP(GenericPtrMemSemantics, true, 4, false)
_SPIRV_OP(GenericCastToPtrExplicit, true, 5, false, 1)
#undef _SPIRV_OP

class SPIRVSpecConstantOpBase : public SPIRVInstTemplateBase {
public:
bool isOperandLiteral(unsigned I) const override {
// If SpecConstant results from CompositeExtract/Insert operation, then all
// operands are expected to be literals.
switch (Ops[0]) { // Opcode of underlying SpecConstant operation
case OpCompositeExtract:
case OpCompositeInsert:
return true;
default:
return SPIRVInstTemplateBase::isOperandLiteral(I);
}
}
};

typedef SPIRVInstTemplate<SPIRVSpecConstantOpBase, OpSpecConstantOp, true, 4,
true, 0>
SPIRVSpecConstantOp;

class SPIRVAssumeTrueKHR : public SPIRVInstruction {
public:
static const Op OC = OpAssumeTrueKHR;
Expand Down
94 changes: 94 additions & 0 deletions llvm-spirv/test/complex-constexpr-vector.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
; RUN: llvm-as %s -o %t.bc
; RUN: llvm-spirv %t.bc -o %t.spv
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
; RUN: llvm-dis %t.rev.bc
; RUN: FileCheck < %t.rev.ll %s --check-prefix=CHECK-LLVM

target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64"
target triple = "spir64"

define linkonce_odr hidden spir_func void @foo() {
entry:
; CHECK-SPIRV-DAG: Constant [[#]] [[#CONSTANT1:]] 65793
; CHECK-SPIRV-DAG: Constant [[#]] [[#CONSTANT2:]] 131586

; CHECK-SPIRV: ConstantComposite [[#]] [[#COMPOS0:]] [[#CONSTANT1]]
; 124 is OpBitcast opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES0:]] 124 [[#COMPOS0]]

; 81 is OpCompositeExtract opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES0:]] 81 [[#BITCAST_RES0]] 0
; CHECK-SPIRV: ConstantComposite [[#]] [[#COMPOS1:]] [[#CONSTANT2]]

; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES1:]] 124 [[#COMPOS1]]
; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES1:]] 81 [[#BITCAST_RES1]] 0
; 129 is OpFAdd opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#MEMBER_1:]] 129 [[#EXTRACT_RES0:]] [[#EXTRACT_RES1]]

; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES2:]] 81 [[#BITCAST_RES0]] 1
; CHECK-SPIRV: SpecConstantOp [[#]] [[#EXTRACT_RES3:]] 81 [[#BITCAST_RES1]] 1
; CHECK-SPIRV: SpecConstantOp [[#]] [[#MEMBER_2:]] 129 [[#EXTRACT_RES2]] [[#EXTRACT_RES3]]

; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES2:]] 81 [[#BITCAST_RES0]] 2
; CHECK-SPIRV: SpecConstantOp [[#]] [[#BITCAST_RES2:]] 81 [[#BITCAST_RES1]] 2
; CHECK-SPIRV: SpecConstantOp [[#]] [[#MEMBER_3:]] 129 [[#]] [[#BITCAST_RES2]]

; CHECK-SPIRV: Undef [[#]] [[#MEMBER_4:]]
; CHECK-SPIRV: ConstantComposite [[#]] [[#FINAL_COMPOS:]] [[#MEMBER_1]] [[#MEMBER_2]] [[#MEMBER_3]] [[#MEMBER_4]]
; CHECK-SPIRV: DebugValue [[#]] [[#FINAL_COMPOS]]

; CHECK-LLVM: call void @llvm.dbg.value(
; CHECK-LLVM-SAME: metadata <4 x half> <
; CHECK-LLVM-SAME: half fadd (
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 0),
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 0)),
; CHECK-LLVM-SAME: half fadd (
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 1),
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 1)),
; CHECK-LLVM-SAME: half fadd (
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 2),
; CHECK-LLVM-SAME: half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 2)),
; CHECK-LLVM-SAME: half undef>,
; CHECK-LLVM-SAME: metadata ![[#]], metadata !DIExpression()), !dbg ![[#]]
call void @llvm.dbg.value(
metadata <4 x half> <
half fadd (
half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 0),
half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 0)),
half fadd (
half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 1),
half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 1)),
half fadd (
half extractelement (<4 x half> bitcast (<2 x i32> <i32 65793, i32 65793> to <4 x half>), i32 2),
half extractelement (<4 x half> bitcast (<2 x i32> <i32 131586, i32 131586> to <4 x half>), i32 2)),
half undef>,
metadata !12, metadata !DIExpression()), !dbg !7
ret void
}

; Function Attrs: nofree nosync nounwind readnone speculatable willreturn
declare void @llvm.dbg.value(metadata, metadata, metadata)

!llvm.dbg.cu = !{!0}
!llvm.module.flags = !{!3, !4}
!opencl.used.extensions = !{!2}
!opencl.used.optional.core.features = !{!2}
!opencl.compiler.options = !{!2}
!llvm.ident = !{!5}

!0 = distinct !DICompileUnit(language: DW_LANG_C_plus_plus_14, file: !1, producer: "clang version 13.0.0 (https://github.com/intel/llvm.git)", isOptimized: false, runtimeVersion: 0, emissionKind: FullDebug, enums: !2, nameTableKind: None)
!1 = !DIFile(filename: "main.cpp", directory: "/export/users")
!2 = !{}
!3 = !{i32 2, !"Debug Info Version", i32 3}
!4 = !{i32 1, !"wchar_size", i32 4}
!5 = !{!"clang version 13.0.0"}
!6 = distinct !DISubprogram(name: "main", scope: !1, file: !1, line: 1, type: !8, scopeLine: 4, flags: DIFlagPrototyped, spFlags: DISPFlagDefinition, unit: !0, retainedNodes: !2)
!7 = !DILocation(line: 1, scope: !6, inlinedAt: !11)
!8 = !DISubroutineType(types: !9)
!9 = !{!10}
!10 = !DIBasicType(name: "int", size: 32, encoding: DW_ATE_signed)
!11 = !DILocation(line: 1, column: 0, scope: !6)
!12 = !DILocalVariable(name: "resVec", scope: !6, file: !1, line: 1, type: !13)
!13 = distinct !DICompositeType(tag: DW_TAG_class_type, name: "vec<cl::sycl::detail::half_impl::half, 3>", scope: !6, file: !1, line: 1, size: 64, flags: DIFlagTypePassByValue, elements: !2)
15 changes: 9 additions & 6 deletions llvm-spirv/test/constexpr_phi.ll
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,22 @@
; RUN: FileCheck < %t.r.ll %s --check-prefix=CHECK-LLVM

; CHECK-SPIRV: Name [[#F:]] "_Z3runiiPi"

; 117 is OpConvertPtrToU opcode
; CHECK-SPIRV: SpecConstantOp [[#]] [[#SpecConst0:]] 117 [[#F1Ptr:]]
; CHECK-SPIRV: SpecConstantOp [[#]] [[#SpecConst1:]] 117 [[#F2Ptr:]]
; CHECK-SPIRV: ConstantComposite [[#]] [[#Compos0:]] [[#SpecConst0]] [[#SpecConst0]]
; CHECK-SPIRV: ConstantComposite [[#]] [[#Compos1:]] [[#SpecConst0]] [[#SpecConst1]]

; CHECK-SPIRV: Function [[#]] [[#F]] [[#]] [[#]]
; CHECK-SPIRV: Label [[#L1:]]
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins1:]] [[#]] [[#]] 0
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins2:]] [[#]] [[#Ins1]] 1
; CHECK-SPIRV: BranchConditional [[#]] [[#L2:]] [[#L3:]]
; CHECK-SPIRV: Label [[#L2]]
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins3:]] [[#]] [[#]] 0
; CHECK-SPIRV: CompositeInsert [[#]] [[#Ins4:]] [[#]] [[#Ins3]] 1
; CHECK-SPIRV: Branch [[#L3]]
; CHECK-SPIRV: Label [[#L3]]
; CHECK-NEXT-SPIRV: Phi [[#]] [[#]]
; CHECK-SAME-SPIRV: [[#Ins2]] [[#L1]]
; CHECK-SAME-SPIRV: [[#Ins4]] [[#L2]]
; CHECK-SAME-SPIRV: [[#Compos0]] [[#L1]]
; CHECK-SAME-SPIRV: [[#Compos1]] [[#L2]]

; CHECK-LLVM: br label %[[#L:]]
; CHECK-LLVM: [[#L]]:
Expand Down
Loading