Skip to content

Commit 1065869

Browse files
committed
[Matrix] Add matrix type to Clang.
This patch adds a matrix type to Clang as described in the draft specification in clang/docs/MatrixSupport.rst. It introduces a new option -fenable-matrix, which can be used to enable the matrix support. The patch adds new MatrixType and DependentSizedMatrixType types along with the plumbing required. Loads of and stores to pointers to matrix values are lowered to memory operations on 1-D IR arrays. After loading, the loaded values are cast to a vector. This ensures matrix values use the alignment of the element type, instead of LLVM's large vector alignment. The operators and builtins described in the draft spec will will be added in follow-up patches. Reviewers: martong, rsmith, Bigcheese, anemet, dexonsmith, rjmccall, aaron.ballman Reviewed By: rjmccall Differential Revision: https://reviews.llvm.org/D72281
1 parent b51df26 commit 1065869

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1852
-9
lines changed

clang/include/clang/AST/ASTContext.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ class ASTContext : public RefCountedBase<ASTContext> {
194194
DependentAddressSpaceTypes;
195195
mutable llvm::FoldingSet<VectorType> VectorTypes;
196196
mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes;
197+
mutable llvm::FoldingSet<ConstantMatrixType> MatrixTypes;
198+
mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes;
197199
mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes;
198200
mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&>
199201
FunctionProtoTypes;
@@ -1326,6 +1328,20 @@ class ASTContext : public RefCountedBase<ASTContext> {
13261328
Expr *SizeExpr,
13271329
SourceLocation AttrLoc) const;
13281330

1331+
/// Return the unique reference to the matrix type of the specified element
1332+
/// type and size
1333+
///
1334+
/// \pre \p ElementType must be a valid matrix element type (see
1335+
/// MatrixType::isValidElementType).
1336+
QualType getConstantMatrixType(QualType ElementType, unsigned NumRows,
1337+
unsigned NumColumns) const;
1338+
1339+
/// Return the unique reference to the matrix type of the specified element
1340+
/// type and size
1341+
QualType getDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
1342+
Expr *ColumnExpr,
1343+
SourceLocation AttrLoc) const;
1344+
13291345
QualType getDependentAddressSpaceType(QualType PointeeType,
13301346
Expr *AddrSpaceExpr,
13311347
SourceLocation AttrLoc) const;

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,17 @@ DEF_TRAVERSE_TYPE(VectorType, { TRY_TO(TraverseType(T->getElementType())); })
10061006

10071007
DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); })
10081008

1009+
DEF_TRAVERSE_TYPE(ConstantMatrixType,
1010+
{ TRY_TO(TraverseType(T->getElementType())); })
1011+
1012+
DEF_TRAVERSE_TYPE(DependentSizedMatrixType, {
1013+
if (T->getRowExpr())
1014+
TRY_TO(TraverseStmt(T->getRowExpr()));
1015+
if (T->getColumnExpr())
1016+
TRY_TO(TraverseStmt(T->getColumnExpr()));
1017+
TRY_TO(TraverseType(T->getElementType()));
1018+
})
1019+
10091020
DEF_TRAVERSE_TYPE(FunctionNoProtoType,
10101021
{ TRY_TO(TraverseType(T->getReturnType())); })
10111022

@@ -1258,6 +1269,18 @@ DEF_TRAVERSE_TYPELOC(ExtVectorType, {
12581269
TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
12591270
})
12601271

1272+
DEF_TRAVERSE_TYPELOC(ConstantMatrixType, {
1273+
TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
1274+
TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
1275+
TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
1276+
})
1277+
1278+
DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, {
1279+
TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
1280+
TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
1281+
TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
1282+
})
1283+
12611284
DEF_TRAVERSE_TYPELOC(FunctionNoProtoType,
12621285
{ TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); })
12631286

clang/include/clang/AST/Type.h

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,19 @@ class alignas(8) Type : public ExtQualsTypeCommonBase {
16541654
uint32_t NumElements;
16551655
};
16561656

1657+
class ConstantMatrixTypeBitfields {
1658+
friend class ConstantMatrixType;
1659+
1660+
unsigned : NumTypeBits;
1661+
1662+
/// Number of rows and columns. Using 20 bits allows supporting very large
1663+
/// matrixes, while keeping 24 bits to accommodate NumTypeBits.
1664+
unsigned NumRows : 20;
1665+
unsigned NumColumns : 20;
1666+
1667+
static constexpr uint32_t MaxElementsPerDimension = (1 << 20) - 1;
1668+
};
1669+
16571670
class AttributedTypeBitfields {
16581671
friend class AttributedType;
16591672

@@ -1763,6 +1776,7 @@ class alignas(8) Type : public ExtQualsTypeCommonBase {
17631776
TypeWithKeywordBitfields TypeWithKeywordBits;
17641777
ElaboratedTypeBitfields ElaboratedTypeBits;
17651778
VectorTypeBitfields VectorTypeBits;
1779+
ConstantMatrixTypeBitfields ConstantMatrixTypeBits;
17661780
SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
17671781
TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
17681782
DependentTemplateSpecializationTypeBitfields
@@ -2021,6 +2035,7 @@ class alignas(8) Type : public ExtQualsTypeCommonBase {
20212035
bool isComplexIntegerType() const; // GCC _Complex integer type.
20222036
bool isVectorType() const; // GCC vector type.
20232037
bool isExtVectorType() const; // Extended vector type.
2038+
bool isConstantMatrixType() const; // Matrix type.
20242039
bool isDependentAddressSpaceType() const; // value-dependent address space qualifier
20252040
bool isObjCObjectPointerType() const; // pointer to ObjC object
20262041
bool isObjCRetainableType() const; // ObjC object or block pointer
@@ -3390,6 +3405,131 @@ class ExtVectorType : public VectorType {
33903405
}
33913406
};
33923407

3408+
/// Represents a matrix type, as defined in the Matrix Types clang extensions.
3409+
/// __attribute__((matrix_type(rows, columns))), where "rows" specifies
3410+
/// number of rows and "columns" specifies the number of columns.
3411+
class MatrixType : public Type, public llvm::FoldingSetNode {
3412+
protected:
3413+
friend class ASTContext;
3414+
3415+
/// The element type of the matrix.
3416+
QualType ElementType;
3417+
3418+
MatrixType(QualType ElementTy, QualType CanonElementTy);
3419+
3420+
MatrixType(TypeClass TypeClass, QualType ElementTy, QualType CanonElementTy,
3421+
const Expr *RowExpr = nullptr, const Expr *ColumnExpr = nullptr);
3422+
3423+
public:
3424+
/// Returns type of the elements being stored in the matrix
3425+
QualType getElementType() const { return ElementType; }
3426+
3427+
/// Valid elements types are the following:
3428+
/// * an integer type (as in C2x 6.2.5p19), but excluding enumerated types
3429+
/// and _Bool
3430+
/// * the standard floating types float or double
3431+
/// * a half-precision floating point type, if one is supported on the target
3432+
static bool isValidElementType(QualType T) {
3433+
return T->isDependentType() ||
3434+
(T->isRealType() && !T->isBooleanType() && !T->isEnumeralType());
3435+
}
3436+
3437+
bool isSugared() const { return false; }
3438+
QualType desugar() const { return QualType(this, 0); }
3439+
3440+
static bool classof(const Type *T) {
3441+
return T->getTypeClass() == ConstantMatrix ||
3442+
T->getTypeClass() == DependentSizedMatrix;
3443+
}
3444+
};
3445+
3446+
/// Represents a concrete matrix type with constant number of rows and columns
3447+
class ConstantMatrixType final : public MatrixType {
3448+
protected:
3449+
friend class ASTContext;
3450+
3451+
/// The element type of the matrix.
3452+
QualType ElementType;
3453+
3454+
ConstantMatrixType(QualType MatrixElementType, unsigned NRows,
3455+
unsigned NColumns, QualType CanonElementType);
3456+
3457+
ConstantMatrixType(TypeClass typeClass, QualType MatrixType, unsigned NRows,
3458+
unsigned NColumns, QualType CanonElementType);
3459+
3460+
public:
3461+
/// Returns the number of rows in the matrix.
3462+
unsigned getNumRows() const { return ConstantMatrixTypeBits.NumRows; }
3463+
3464+
/// Returns the number of columns in the matrix.
3465+
unsigned getNumColumns() const { return ConstantMatrixTypeBits.NumColumns; }
3466+
3467+
/// Returns the number of elements required to embed the matrix into a vector.
3468+
unsigned getNumElementsFlattened() const {
3469+
return ConstantMatrixTypeBits.NumRows * ConstantMatrixTypeBits.NumColumns;
3470+
}
3471+
3472+
/// Returns true if \p NumElements is a valid matrix dimension.
3473+
static bool isDimensionValid(uint64_t NumElements) {
3474+
return NumElements > 0 &&
3475+
NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension;
3476+
}
3477+
3478+
void Profile(llvm::FoldingSetNodeID &ID) {
3479+
Profile(ID, getElementType(), getNumRows(), getNumColumns(),
3480+
getTypeClass());
3481+
}
3482+
3483+
static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType,
3484+
unsigned NumRows, unsigned NumColumns,
3485+
TypeClass TypeClass) {
3486+
ID.AddPointer(ElementType.getAsOpaquePtr());
3487+
ID.AddInteger(NumRows);
3488+
ID.AddInteger(NumColumns);
3489+
ID.AddInteger(TypeClass);
3490+
}
3491+
3492+
static bool classof(const Type *T) {
3493+
return T->getTypeClass() == ConstantMatrix;
3494+
}
3495+
};
3496+
3497+
/// Represents a matrix type where the type and the number of rows and columns
3498+
/// is dependent on a template.
3499+
class DependentSizedMatrixType final : public MatrixType {
3500+
friend class ASTContext;
3501+
3502+
const ASTContext &Context;
3503+
Expr *RowExpr;
3504+
Expr *ColumnExpr;
3505+
3506+
SourceLocation loc;
3507+
3508+
DependentSizedMatrixType(const ASTContext &Context, QualType ElementType,
3509+
QualType CanonicalType, Expr *RowExpr,
3510+
Expr *ColumnExpr, SourceLocation loc);
3511+
3512+
public:
3513+
QualType getElementType() const { return ElementType; }
3514+
Expr *getRowExpr() const { return RowExpr; }
3515+
Expr *getColumnExpr() const { return ColumnExpr; }
3516+
SourceLocation getAttributeLoc() const { return loc; }
3517+
3518+
bool isSugared() const { return false; }
3519+
QualType desugar() const { return QualType(this, 0); }
3520+
3521+
static bool classof(const Type *T) {
3522+
return T->getTypeClass() == DependentSizedMatrix;
3523+
}
3524+
3525+
void Profile(llvm::FoldingSetNodeID &ID) {
3526+
Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr());
3527+
}
3528+
3529+
static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
3530+
QualType ElementType, Expr *RowExpr, Expr *ColumnExpr);
3531+
};
3532+
33933533
/// FunctionType - C99 6.7.5.3 - Function Declarators. This is the common base
33943534
/// class of FunctionNoProtoType and FunctionProtoType.
33953535
class FunctionType : public Type {
@@ -6605,6 +6745,10 @@ inline bool Type::isExtVectorType() const {
66056745
return isa<ExtVectorType>(CanonicalType);
66066746
}
66076747

6748+
inline bool Type::isConstantMatrixType() const {
6749+
return isa<ConstantMatrixType>(CanonicalType);
6750+
}
6751+
66086752
inline bool Type::isDependentAddressSpaceType() const {
66096753
return isa<DependentAddressSpaceType>(CanonicalType);
66106754
}

clang/include/clang/AST/TypeLoc.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1735,6 +1735,7 @@ class DependentAddressSpaceTypeLoc
17351735

17361736
void initializeLocal(ASTContext &Context, SourceLocation loc) {
17371737
setAttrNameLoc(loc);
1738+
setAttrOperandParensRange(loc);
17381739
setAttrOperandParensRange(SourceRange(loc));
17391740
setAttrExprOperand(getTypePtr()->getAddrSpaceExpr());
17401741
}
@@ -1774,6 +1775,68 @@ class DependentSizedExtVectorTypeLoc :
17741775
DependentSizedExtVectorType> {
17751776
};
17761777

1778+
struct MatrixTypeLocInfo {
1779+
SourceLocation AttrLoc;
1780+
SourceRange OperandParens;
1781+
Expr *RowOperand;
1782+
Expr *ColumnOperand;
1783+
};
1784+
1785+
class MatrixTypeLoc : public ConcreteTypeLoc<UnqualTypeLoc, MatrixTypeLoc,
1786+
MatrixType, MatrixTypeLocInfo> {
1787+
public:
1788+
/// The location of the attribute name, i.e.
1789+
/// float __attribute__((matrix_type(4, 2)))
1790+
/// ^~~~~~~~~~~~~~~~~
1791+
SourceLocation getAttrNameLoc() const { return getLocalData()->AttrLoc; }
1792+
void setAttrNameLoc(SourceLocation loc) { getLocalData()->AttrLoc = loc; }
1793+
1794+
/// The attribute's row operand, if it has one.
1795+
/// float __attribute__((matrix_type(4, 2)))
1796+
/// ^
1797+
Expr *getAttrRowOperand() const { return getLocalData()->RowOperand; }
1798+
void setAttrRowOperand(Expr *e) { getLocalData()->RowOperand = e; }
1799+
1800+
/// The attribute's column operand, if it has one.
1801+
/// float __attribute__((matrix_type(4, 2)))
1802+
/// ^
1803+
Expr *getAttrColumnOperand() const { return getLocalData()->ColumnOperand; }
1804+
void setAttrColumnOperand(Expr *e) { getLocalData()->ColumnOperand = e; }
1805+
1806+
/// The location of the parentheses around the operand, if there is
1807+
/// an operand.
1808+
/// float __attribute__((matrix_type(4, 2)))
1809+
/// ^ ^
1810+
SourceRange getAttrOperandParensRange() const {
1811+
return getLocalData()->OperandParens;
1812+
}
1813+
void setAttrOperandParensRange(SourceRange range) {
1814+
getLocalData()->OperandParens = range;
1815+
}
1816+
1817+
SourceRange getLocalSourceRange() const {
1818+
SourceRange range(getAttrNameLoc());
1819+
range.setEnd(getAttrOperandParensRange().getEnd());
1820+
return range;
1821+
}
1822+
1823+
void initializeLocal(ASTContext &Context, SourceLocation loc) {
1824+
setAttrNameLoc(loc);
1825+
setAttrOperandParensRange(loc);
1826+
setAttrRowOperand(nullptr);
1827+
setAttrColumnOperand(nullptr);
1828+
}
1829+
};
1830+
1831+
class ConstantMatrixTypeLoc
1832+
: public InheritingConcreteTypeLoc<MatrixTypeLoc, ConstantMatrixTypeLoc,
1833+
ConstantMatrixType> {};
1834+
1835+
class DependentSizedMatrixTypeLoc
1836+
: public InheritingConcreteTypeLoc<MatrixTypeLoc,
1837+
DependentSizedMatrixTypeLoc,
1838+
DependentSizedMatrixType> {};
1839+
17771840
// FIXME: location of the '_Complex' keyword.
17781841
class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
17791842
ComplexTypeLoc,

clang/include/clang/AST/TypeProperties.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,41 @@ let Class = DependentSizedExtVectorType in {
224224
}]>;
225225
}
226226

227+
let Class = MatrixType in {
228+
def : Property<"elementType", QualType> {
229+
let Read = [{ node->getElementType() }];
230+
}
231+
}
232+
233+
let Class = ConstantMatrixType in {
234+
def : Property<"numRows", UInt32> {
235+
let Read = [{ node->getNumRows() }];
236+
}
237+
def : Property<"numColumns", UInt32> {
238+
let Read = [{ node->getNumColumns() }];
239+
}
240+
241+
def : Creator<[{
242+
return ctx.getConstantMatrixType(elementType, numRows, numColumns);
243+
}]>;
244+
}
245+
246+
let Class = DependentSizedMatrixType in {
247+
def : Property<"rows", ExprRef> {
248+
let Read = [{ node->getRowExpr() }];
249+
}
250+
def : Property<"columns", ExprRef> {
251+
let Read = [{ node->getColumnExpr() }];
252+
}
253+
def : Property<"attributeLoc", SourceLocation> {
254+
let Read = [{ node->getAttributeLoc() }];
255+
}
256+
257+
def : Creator<[{
258+
return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc);
259+
}]>;
260+
}
261+
227262
let Class = FunctionType in {
228263
def : Property<"returnType", QualType> {
229264
let Read = [{ node->getReturnType() }];

clang/include/clang/Basic/Attr.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2460,6 +2460,15 @@ def VecTypeHint : InheritableAttr {
24602460
let Documentation = [Undocumented];
24612461
}
24622462

2463+
def MatrixType : TypeAttr {
2464+
let Spellings = [Clang<"matrix_type">];
2465+
let Subjects = SubjectList<[TypedefName], ErrorDiag>;
2466+
let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">];
2467+
let Documentation = [Undocumented];
2468+
let ASTNode = 0;
2469+
let PragmaAttributeSupport = 0;
2470+
}
2471+
24632472
def Visibility : InheritableAttr {
24642473
let Clone = 0;
24652474
let Spellings = [GCC<"visibility">];

0 commit comments

Comments
 (0)