Skip to content

Commit 44765a3

Browse files
llvm-beanzpuja2196
authored andcommitted
[HLSL] Vector Usual Arithmetic Conversions (#110195)
HLSL has a different set of usual arithmetic conversions for vector types to resolve a common type for binary operator expressions. This PR implements the current spec proposal from: microsoft/hlsl-specs#311 There is one case that may need additional handling for implicitly truncating vector<T,1> to T early to allow other transformations. Fixes #106253 Re-lands #108659
1 parent 7dbaf49 commit 44765a3

File tree

7 files changed

+598
-4
lines changed

7 files changed

+598
-4
lines changed

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12395,6 +12395,9 @@ def err_hlsl_operator_unsupported : Error<
1239512395

1239612396
def err_hlsl_param_qualifier_mismatch :
1239712397
Error<"conflicting parameter qualifier %0 on parameter %1">;
12398+
def err_hlsl_vector_compound_assignment_truncation : Error<
12399+
"left hand operand of type %0 to compound assignment cannot be truncated "
12400+
"when used with right hand operand of type %1">;
1239812401

1239912402
def warn_hlsl_impcast_vector_truncation : Warning<
1240012403
"implicit conversion truncates vector: %0 to %1">, InGroup<Conversion>;

clang/include/clang/Driver/Options.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2979,7 +2979,7 @@ def flax_vector_conversions_EQ : Joined<["-"], "flax-vector-conversions=">, Grou
29792979
"LangOptions::LaxVectorConversionKind::Integer",
29802980
"LangOptions::LaxVectorConversionKind::All"]>,
29812981
MarshallingInfoEnum<LangOpts<"LaxVectorConversions">,
2982-
open_cl.KeyPath #
2982+
!strconcat("(", open_cl.KeyPath, " || ", hlsl.KeyPath, ")") #
29832983
" ? LangOptions::LaxVectorConversionKind::None" #
29842984
" : LangOptions::LaxVectorConversionKind::All">;
29852985
def flax_vector_conversions : Flag<["-"], "flax-vector-conversions">, Group<f_Group>,

clang/include/clang/Sema/Sema.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7423,7 +7423,8 @@ class Sema final : public SemaBase {
74237423
SourceLocation Loc,
74247424
BinaryOperatorKind Opc);
74257425
QualType CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
7426-
SourceLocation Loc);
7426+
SourceLocation Loc,
7427+
BinaryOperatorKind Opc);
74277428

74287429
/// Context in which we're performing a usual arithmetic conversion.
74297430
enum ArithConvKind {

clang/include/clang/Sema/SemaHLSL.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,11 @@ class SemaHLSL : public SemaBase {
6363
std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages);
6464
void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU);
6565

66+
QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
67+
QualType LHSType, QualType RHSType,
68+
bool IsCompAssign);
69+
void emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS, BinaryOperatorKind Opc);
70+
6671
void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL);
6772
void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL);
6873
void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL);

clang/lib/Sema/SemaExpr.cpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10133,6 +10133,10 @@ QualType Sema::CheckVectorOperands(ExprResult &LHS, ExprResult &RHS,
1013310133
const VectorType *RHSVecType = RHSType->getAs<VectorType>();
1013410134
assert(LHSVecType || RHSVecType);
1013510135

10136+
if (getLangOpts().HLSL)
10137+
return HLSL().handleVectorBinOpConversion(LHS, RHS, LHSType, RHSType,
10138+
IsCompAssign);
10139+
1013610140
// AltiVec-style "vector bool op vector bool" combinations are allowed
1013710141
// for some operators but not others.
1013810142
if (!AllowBothBool && LHSVecType &&
@@ -12863,7 +12867,8 @@ static void diagnoseXorMisusedAsPow(Sema &S, const ExprResult &XorLHS,
1286312867
}
1286412868

1286512869
QualType Sema::CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
12866-
SourceLocation Loc) {
12870+
SourceLocation Loc,
12871+
BinaryOperatorKind Opc) {
1286712872
// Ensure that either both operands are of the same vector type, or
1286812873
// one operand is of a vector type and the other is of its element type.
1286912874
QualType vType = CheckVectorOperands(LHS, RHS, Loc, false,
@@ -12883,6 +12888,15 @@ QualType Sema::CheckVectorLogicalOperands(ExprResult &LHS, ExprResult &RHS,
1288312888
if (!getLangOpts().CPlusPlus &&
1288412889
!(isa<ExtVectorType>(vType->getAs<VectorType>())))
1288512890
return InvalidLogicalVectorOperands(Loc, LHS, RHS);
12891+
// Beginning with HLSL 2021, HLSL disallows logical operators on vector
12892+
// operands and instead requires the use of the `and`, `or`, `any`, `all`, and
12893+
// `select` functions.
12894+
if (getLangOpts().HLSL &&
12895+
getLangOpts().getHLSLVersion() >= LangOptionsBase::HLSL_2021) {
12896+
(void)InvalidOperands(Loc, LHS, RHS);
12897+
HLSL().emitLogicalOperatorFixIt(LHS.get(), RHS.get(), Opc);
12898+
return QualType();
12899+
}
1288612900

1288712901
return GetSignedVectorType(LHS.get()->getType());
1288812902
}
@@ -13054,7 +13068,7 @@ inline QualType Sema::CheckLogicalOperands(ExprResult &LHS, ExprResult &RHS,
1305413068
// Check vector operands differently.
1305513069
if (LHS.get()->getType()->isVectorType() ||
1305613070
RHS.get()->getType()->isVectorType())
13057-
return CheckVectorLogicalOperands(LHS, RHS, Loc);
13071+
return CheckVectorLogicalOperands(LHS, RHS, Loc, Opc);
1305813072

1305913073
bool EnumConstantInBoolContext = false;
1306013074
for (const ExprResult &HS : {LHS, RHS}) {

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,194 @@ void SemaHLSL::DiagnoseAttrStageMismatch(
401401
<< (AllowedStages.size() != 1) << join(StageStrings, ", ");
402402
}
403403

404+
template <CastKind Kind>
405+
static void castVector(Sema &S, ExprResult &E, QualType &Ty, unsigned Sz) {
406+
if (const auto *VTy = Ty->getAs<VectorType>())
407+
Ty = VTy->getElementType();
408+
Ty = S.getASTContext().getExtVectorType(Ty, Sz);
409+
E = S.ImpCastExprToType(E.get(), Ty, Kind);
410+
}
411+
412+
template <CastKind Kind>
413+
static QualType castElement(Sema &S, ExprResult &E, QualType Ty) {
414+
E = S.ImpCastExprToType(E.get(), Ty, Kind);
415+
return Ty;
416+
}
417+
418+
static QualType handleFloatVectorBinOpConversion(
419+
Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
420+
QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
421+
bool LHSFloat = LElTy->isRealFloatingType();
422+
bool RHSFloat = RElTy->isRealFloatingType();
423+
424+
if (LHSFloat && RHSFloat) {
425+
if (IsCompAssign ||
426+
SemaRef.getASTContext().getFloatingTypeOrder(LElTy, RElTy) > 0)
427+
return castElement<CK_FloatingCast>(SemaRef, RHS, LHSType);
428+
429+
return castElement<CK_FloatingCast>(SemaRef, LHS, RHSType);
430+
}
431+
432+
if (LHSFloat)
433+
return castElement<CK_IntegralToFloating>(SemaRef, RHS, LHSType);
434+
435+
assert(RHSFloat);
436+
if (IsCompAssign)
437+
return castElement<clang::CK_FloatingToIntegral>(SemaRef, RHS, LHSType);
438+
439+
return castElement<CK_IntegralToFloating>(SemaRef, LHS, RHSType);
440+
}
441+
442+
static QualType handleIntegerVectorBinOpConversion(
443+
Sema &SemaRef, ExprResult &LHS, ExprResult &RHS, QualType LHSType,
444+
QualType RHSType, QualType LElTy, QualType RElTy, bool IsCompAssign) {
445+
446+
int IntOrder = SemaRef.Context.getIntegerTypeOrder(LElTy, RElTy);
447+
bool LHSSigned = LElTy->hasSignedIntegerRepresentation();
448+
bool RHSSigned = RElTy->hasSignedIntegerRepresentation();
449+
auto &Ctx = SemaRef.getASTContext();
450+
451+
// If both types have the same signedness, use the higher ranked type.
452+
if (LHSSigned == RHSSigned) {
453+
if (IsCompAssign || IntOrder >= 0)
454+
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
455+
456+
return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
457+
}
458+
459+
// If the unsigned type has greater than or equal rank of the signed type, use
460+
// the unsigned type.
461+
if (IntOrder != (LHSSigned ? 1 : -1)) {
462+
if (IsCompAssign || RHSSigned)
463+
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
464+
return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
465+
}
466+
467+
// At this point the signed type has higher rank than the unsigned type, which
468+
// means it will be the same size or bigger. If the signed type is bigger, it
469+
// can represent all the values of the unsigned type, so select it.
470+
if (Ctx.getIntWidth(LElTy) != Ctx.getIntWidth(RElTy)) {
471+
if (IsCompAssign || LHSSigned)
472+
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
473+
return castElement<CK_IntegralCast>(SemaRef, LHS, RHSType);
474+
}
475+
476+
// This is a bit of an odd duck case in HLSL. It shouldn't happen, but can due
477+
// to C/C++ leaking through. The place this happens today is long vs long
478+
// long. When arguments are vector<unsigned long, N> and vector<long long, N>,
479+
// the long long has higher rank than long even though they are the same size.
480+
481+
// If this is a compound assignment cast the right hand side to the left hand
482+
// side's type.
483+
if (IsCompAssign)
484+
return castElement<CK_IntegralCast>(SemaRef, RHS, LHSType);
485+
486+
// If this isn't a compound assignment we convert to unsigned long long.
487+
QualType ElTy = Ctx.getCorrespondingUnsignedType(LHSSigned ? LElTy : RElTy);
488+
QualType NewTy = Ctx.getExtVectorType(
489+
ElTy, RHSType->castAs<VectorType>()->getNumElements());
490+
(void)castElement<CK_IntegralCast>(SemaRef, RHS, NewTy);
491+
492+
return castElement<CK_IntegralCast>(SemaRef, LHS, NewTy);
493+
}
494+
495+
static CastKind getScalarCastKind(ASTContext &Ctx, QualType DestTy,
496+
QualType SrcTy) {
497+
if (DestTy->isRealFloatingType() && SrcTy->isRealFloatingType())
498+
return CK_FloatingCast;
499+
if (DestTy->isIntegralType(Ctx) && SrcTy->isIntegralType(Ctx))
500+
return CK_IntegralCast;
501+
if (DestTy->isRealFloatingType())
502+
return CK_IntegralToFloating;
503+
assert(SrcTy->isRealFloatingType() && DestTy->isIntegralType(Ctx));
504+
return CK_FloatingToIntegral;
505+
}
506+
507+
QualType SemaHLSL::handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS,
508+
QualType LHSType,
509+
QualType RHSType,
510+
bool IsCompAssign) {
511+
const auto *LVecTy = LHSType->getAs<VectorType>();
512+
const auto *RVecTy = RHSType->getAs<VectorType>();
513+
auto &Ctx = getASTContext();
514+
515+
// If the LHS is not a vector and this is a compound assignment, we truncate
516+
// the argument to a scalar then convert it to the LHS's type.
517+
if (!LVecTy && IsCompAssign) {
518+
QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
519+
RHS = SemaRef.ImpCastExprToType(RHS.get(), RElTy, CK_HLSLVectorTruncation);
520+
RHSType = RHS.get()->getType();
521+
if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
522+
return LHSType;
523+
RHS = SemaRef.ImpCastExprToType(RHS.get(), LHSType,
524+
getScalarCastKind(Ctx, LHSType, RHSType));
525+
return LHSType;
526+
}
527+
528+
unsigned EndSz = std::numeric_limits<unsigned>::max();
529+
unsigned LSz = 0;
530+
if (LVecTy)
531+
LSz = EndSz = LVecTy->getNumElements();
532+
if (RVecTy)
533+
EndSz = std::min(RVecTy->getNumElements(), EndSz);
534+
assert(EndSz != std::numeric_limits<unsigned>::max() &&
535+
"one of the above should have had a value");
536+
537+
// In a compound assignment, the left operand does not change type, the right
538+
// operand is converted to the type of the left operand.
539+
if (IsCompAssign && LSz != EndSz) {
540+
Diag(LHS.get()->getBeginLoc(),
541+
diag::err_hlsl_vector_compound_assignment_truncation)
542+
<< LHSType << RHSType;
543+
return QualType();
544+
}
545+
546+
if (RVecTy && RVecTy->getNumElements() > EndSz)
547+
castVector<CK_HLSLVectorTruncation>(SemaRef, RHS, RHSType, EndSz);
548+
if (!IsCompAssign && LVecTy && LVecTy->getNumElements() > EndSz)
549+
castVector<CK_HLSLVectorTruncation>(SemaRef, LHS, LHSType, EndSz);
550+
551+
if (!RVecTy)
552+
castVector<CK_VectorSplat>(SemaRef, RHS, RHSType, EndSz);
553+
if (!IsCompAssign && !LVecTy)
554+
castVector<CK_VectorSplat>(SemaRef, LHS, LHSType, EndSz);
555+
556+
// If we're at the same type after resizing we can stop here.
557+
if (Ctx.hasSameUnqualifiedType(LHSType, RHSType))
558+
return Ctx.getCommonSugaredType(LHSType, RHSType);
559+
560+
QualType LElTy = LHSType->castAs<VectorType>()->getElementType();
561+
QualType RElTy = RHSType->castAs<VectorType>()->getElementType();
562+
563+
// Handle conversion for floating point vectors.
564+
if (LElTy->isRealFloatingType() || RElTy->isRealFloatingType())
565+
return handleFloatVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
566+
LElTy, RElTy, IsCompAssign);
567+
568+
assert(LElTy->isIntegralType(Ctx) && RElTy->isIntegralType(Ctx) &&
569+
"HLSL Vectors can only contain integer or floating point types");
570+
return handleIntegerVectorBinOpConversion(SemaRef, LHS, RHS, LHSType, RHSType,
571+
LElTy, RElTy, IsCompAssign);
572+
}
573+
574+
void SemaHLSL::emitLogicalOperatorFixIt(Expr *LHS, Expr *RHS,
575+
BinaryOperatorKind Opc) {
576+
assert((Opc == BO_LOr || Opc == BO_LAnd) &&
577+
"Called with non-logical operator");
578+
llvm::SmallVector<char, 256> Buff;
579+
llvm::raw_svector_ostream OS(Buff);
580+
PrintingPolicy PP(SemaRef.getLangOpts());
581+
StringRef NewFnName = Opc == BO_LOr ? "or" : "and";
582+
OS << NewFnName << "(";
583+
LHS->printPretty(OS, nullptr, PP);
584+
OS << ", ";
585+
RHS->printPretty(OS, nullptr, PP);
586+
OS << ")";
587+
SourceRange FullRange = SourceRange(LHS->getBeginLoc(), RHS->getEndLoc());
588+
SemaRef.Diag(LHS->getBeginLoc(), diag::note_function_suggestion)
589+
<< NewFnName << FixItHint::CreateReplacement(FullRange, OS.str());
590+
}
591+
404592
void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) {
405593
llvm::VersionTuple SMVersion =
406594
getASTContext().getTargetInfo().getTriple().getOSVersion();

0 commit comments

Comments
 (0)