Skip to content

[WIP][X86][AMX] Support AMX constant #92280

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/include/llvm-c/Core.h
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,7 @@ typedef enum {
LLVMInstructionValueKind,
LLVMPoisonValueValueKind,
LLVMConstantTargetNoneValueKind,
LLVMConstantAMXNoneValueKind,
} LLVMValueKind;

typedef enum {
Expand Down
21 changes: 21 additions & 0 deletions llvm/include/llvm/IR/Constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,27 @@ class ConstantTargetNone final : public ConstantData {
}
};

/// A constant AMX type default initializer
class ConstantAMXNone final : public ConstantData {
friend class Constant;

explicit ConstantAMXNone(Type *T)
: ConstantData(T, Value::ConstantAMXNoneVal) {}

void destroyConstantImpl();

public:
ConstantAMXNone(const ConstantAMXNone &) = delete;

/// Static factory methods - Return objects of the specified value.
static ConstantAMXNone *get(Type *T);

/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(const Value *V) {
return V->getValueID() == ConstantAMXNoneVal;
}
};

/// The address of a basic block.
///
class BlockAddress final : public Constant {
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/Value.def
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ HANDLE_CONSTANT(ConstantInt)
HANDLE_CONSTANT(ConstantFP)
HANDLE_CONSTANT(ConstantTargetNone)
HANDLE_CONSTANT(ConstantPointerNull)
HANDLE_CONSTANT(ConstantAMXNone)
HANDLE_CONSTANT(ConstantTokenNone)

HANDLE_CONSTANT_MARKER(ConstantFirstVal, Function)
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,12 @@ SDValue SelectionDAGBuilder::getValueImpl(const Value *V) {
DAG.getConstant(0, getCurSDLoc(), MVT::nxv16i1));
}

if (VT == MVT::x86amx) {
assert(C->isNullValue() && "Can only zero this target type!");
return DAG.getNode(ISD::BITCAST, getCurSDLoc(), VT,
DAG.getConstant(0, getCurSDLoc(), MVT::v256i32));
}

VectorType *VecTy = cast<VectorType>(V->getType());

// Now that we know the number and type of the elements, get that number of
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/IR/AsmWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1564,7 +1564,8 @@ static void WriteConstantInternal(raw_ostream &Out, const Constant *CV,
return;
}

if (isa<ConstantAggregateZero>(CV) || isa<ConstantTargetNone>(CV)) {
if (isa<ConstantAggregateZero>(CV) || isa<ConstantTargetNone>(CV) ||
isa<ConstantAMXNone>(CV)) {
Out << "zeroinitializer";
return;
}
Expand Down
22 changes: 21 additions & 1 deletion llvm/lib/IR/Constants.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ bool Constant::isNullValue() const {
// constant zero is zero for aggregates, cpnull is null for pointers, none for
// tokens.
return isa<ConstantAggregateZero>(this) || isa<ConstantPointerNull>(this) ||
isa<ConstantTokenNone>(this) || isa<ConstantTargetNone>(this);
isa<ConstantTokenNone>(this) || isa<ConstantTargetNone>(this) ||
isa<ConstantAMXNone>(this);
}

bool Constant::isAllOnesValue() const {
Expand Down Expand Up @@ -391,6 +392,8 @@ Constant *Constant::getNullValue(Type *Ty) {
return ConstantTokenNone::get(Ty->getContext());
case Type::TargetExtTyID:
return ConstantTargetNone::get(cast<TargetExtType>(Ty));
case Type::X86_AMXTyID:
return ConstantAMXNone::get(Ty);
default:
// Function, Label, or Opaque type?
llvm_unreachable("Cannot create a null constant of that type!");
Expand Down Expand Up @@ -1805,6 +1808,23 @@ void ConstantTargetNone::destroyConstantImpl() {
getContext().pImpl->CTNConstants.erase(getType());
}

//---- ConstantAMXNone::get() implementation.
//

ConstantAMXNone *ConstantAMXNone::get(Type *Ty) {
std::unique_ptr<ConstantAMXNone> &Entry =
Ty->getContext().pImpl->CAMXConstants[Ty];
if (!Entry)
Entry.reset(new ConstantAMXNone(Ty));

return Entry.get();
}

/// Remove the constant from the constant table.
void ConstantAMXNone::destroyConstantImpl() {
getContext().pImpl->CAMXConstants.erase(getType());
}

UndefValue *UndefValue::get(Type *Ty) {
std::unique_ptr<UndefValue> &Entry = Ty->getContext().pImpl->UVConstants[Ty];
if (!Entry)
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/IR/LLVMContextImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1549,6 +1549,8 @@ class LLVMContextImpl {

DenseMap<TargetExtType *, std::unique_ptr<ConstantTargetNone>> CTNConstants;

DenseMap<Type *, std::unique_ptr<ConstantAMXNone>> CAMXConstants;

DenseMap<Type *, std::unique_ptr<UndefValue>> UVConstants;

DenseMap<Type *, std::unique_ptr<PoisonValue>> PVConstants;
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5265,9 +5265,9 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
for (Value *V : Call.args()) {
if (auto *MD = dyn_cast<MetadataAsValue>(V))
visitMetadataAsValue(*MD, Call.getCaller());
if (auto *Const = dyn_cast<Constant>(V))
/*if (auto *Const = dyn_cast<Constant>(V))
Check(!Const->getType()->isX86_AMXTy(),
"const x86_amx is not allowed in argument!");
"const x86_amx is not allowed in argument!");*/
}

switch (ID) {
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43482,6 +43482,18 @@ static SDValue combineBitcast(SDNode *N, SelectionDAG &DAG,
// vxi1 types.
if (DCI.isBeforeLegalize()) {
SDLoc dl(N);

if (VT == MVT::x86amx) {
SDValue Intrin =
DAG.getTargetConstant(Intrinsic::x86_tilezero_internal, dl,
TLI.getPointerTy(DAG.getDataLayout()));
// FIXME: We need to rebuild the Row and Col from its user.
SDValue Row = DAG.getConstant(8, dl, MVT::i16);
SDValue Col = DAG.getConstant(8, dl, MVT::i16);
return DAG.getNode(ISD::INTRINSIC_W_CHAIN, dl, {MVT::x86amx, MVT::Other},
{DAG.getEntryNode(), Intrin, Row, Col});
}

if (SDValue V = combineBitcastvxi1(DAG, VT, N0, dl, Subtarget))
return V;

Expand Down
Loading