Skip to content

Commit 647d4ba

Browse files
committed
[Matrix] Add support for matrix-by-scalar division.
This patch extends the matrix spec to allow matrix-by-scalar division. Originally, support for `/` was left out to avoid ambiguity for the matrix-matrix version of `/`, which could be either element-wise or defined as the matrix multiplication M1 * (1/M2). For the matrix-scalar version, no ambiguity exists; `*` is also an element-wise operation in that case. Matrix-by-scalar division is commonly supported by systems including Matlab, Mathematica, and NumPy. Reviewed By: rjmccall Differential Revision: https://reviews.llvm.org/D97857
1 parent 4eee13f commit 647d4ba

File tree

9 files changed

+192
-15
lines changed

9 files changed

+192
-15
lines changed

clang/docs/MatrixTypes.rst

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -118,15 +118,21 @@ more explicit.
118118
Matrix Type Binary Operators
119119
----------------------------
120120

121-
Each matrix type supports the following binary operators: ``+``, ``-`` and ``*``. The ``*``
122-
operator provides matrix multiplication, while ``+`` and ``-`` are performed
123-
element-wise. There are also scalar versions of the operators, which take a
124-
matrix type and the matrix element type. The operation is applied to all
125-
elements of the matrix using the scalar value.
126-
127-
For ``BIN_OP`` in ``+``, ``-``, ``*`` given the expression ``M1 BIN_OP M2`` where
128-
at least one of ``M1`` or ``M2`` is of matrix type and, for `*`, the other is of
129-
a real type:
121+
Given two matrices, the ``+`` and ``-`` operators perform element-wise addition
122+
and subtraction, while the ``*`` operator performs matrix multiplication.
123+
``+``, ``-``, ``*``, and ``/`` can also be used with a matrix and a scalar
124+
value, applying the operation to each element of the matrix.
125+
126+
Earlier versions of this extension did not support division by a scalar.
127+
You can test for the availability of this feature with
128+
``__has_extension(matrix_types_scalar_division)``.
129+
130+
For the expression ``M1 BIN_OP M2`` where
131+
* ``BIN_OP`` is one of ``+`` or ``-``, one of ``M1`` and ``M2`` is of matrix
132+
type, and the other is of matrix type or real type; or
133+
* ``BIN_OP`` is ``*``, one of ``M1`` and ``M2`` is of matrix type, and the
134+
other is of a real type; or
135+
* ``BIN_OP`` is ``/``, ``M1`` is of matrix type, and ``M2`` is of a real type:
130136

131137
* The usual arithmetic conversions are applied to ``M1`` and ``M2``. [ Note: if ``M1`` or
132138
``M2`` are of a real type, they are broadcast to matrices here. — end note ]

clang/include/clang/Basic/Features.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ EXTENSION(pragma_clang_attribute_external_declaration, true)
266266
EXTENSION(gnu_asm, LangOpts.GNUAsm)
267267
EXTENSION(gnu_asm_goto_with_outputs, LangOpts.GNUAsm)
268268
EXTENSION(matrix_types, LangOpts.MatrixTypes)
269+
EXTENSION(matrix_types_scalar_division, true)
269270

270271
FEATURE(cxx_abi_relative_vtable, LangOpts.CPlusPlus && LangOpts.RelativeCXXABIVTables)
271272

clang/lib/AST/Type.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2081,8 +2081,9 @@ bool Type::isUnsignedIntegerOrEnumerationType() const {
20812081
bool Type::hasUnsignedIntegerRepresentation() const {
20822082
if (const auto *VT = dyn_cast<VectorType>(CanonicalType))
20832083
return VT->getElementType()->isUnsignedIntegerOrEnumerationType();
2084-
else
2085-
return isUnsignedIntegerOrEnumerationType();
2084+
if (const auto *VT = dyn_cast<MatrixType>(CanonicalType))
2085+
return VT->getElementType()->isUnsignedIntegerOrEnumerationType();
2086+
return isUnsignedIntegerOrEnumerationType();
20862087
}
20872088

20882089
bool Type::isFloatingType() const {

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3198,6 +3198,20 @@ Value *ScalarExprEmitter::EmitDiv(const BinOpInfo &Ops) {
31983198
}
31993199
}
32003200

3201+
if (Ops.Ty->isConstantMatrixType()) {
3202+
llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
3203+
// We need to check the types of the operands of the operator to get the
3204+
// correct matrix dimensions.
3205+
auto *BO = cast<BinaryOperator>(Ops.E);
3206+
assert(
3207+
isa<ConstantMatrixType>(BO->getLHS()->getType().getCanonicalType()) &&
3208+
"first operand must be a matrix");
3209+
assert(BO->getRHS()->getType().getCanonicalType()->isArithmeticType() &&
3210+
"second operand must be an arithmetic type");
3211+
return MB.CreateScalarDiv(Ops.LHS, Ops.RHS,
3212+
Ops.Ty->hasUnsignedIntegerRepresentation());
3213+
}
3214+
32013215
if (Ops.LHS->getType()->isFPOrFPVectorTy()) {
32023216
llvm::Value *Val;
32033217
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);

clang/lib/Sema/SemaExpr.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10222,14 +10222,19 @@ QualType Sema::CheckMultiplyDivideOperands(ExprResult &LHS, ExprResult &RHS,
1022210222
bool IsCompAssign, bool IsDiv) {
1022310223
checkArithmeticNull(*this, LHS, RHS, Loc, /*IsCompare=*/false);
1022410224

10225-
if (LHS.get()->getType()->isVectorType() ||
10226-
RHS.get()->getType()->isVectorType())
10225+
QualType LHSTy = LHS.get()->getType();
10226+
QualType RHSTy = RHS.get()->getType();
10227+
if (LHSTy->isVectorType() || RHSTy->isVectorType())
1022710228
return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign,
1022810229
/*AllowBothBool*/getLangOpts().AltiVec,
1022910230
/*AllowBoolConversions*/false);
10230-
if (!IsDiv && (LHS.get()->getType()->isConstantMatrixType() ||
10231-
RHS.get()->getType()->isConstantMatrixType()))
10231+
if (!IsDiv &&
10232+
(LHSTy->isConstantMatrixType() || RHSTy->isConstantMatrixType()))
1023210233
return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign);
10234+
// For division, only matrix-by-scalar is supported. Other combinations with
10235+
// matrix types are invalid.
10236+
if (IsDiv && LHSTy->isConstantMatrixType() && RHSTy->isArithmeticType())
10237+
return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign);
1023310238

1023410239
QualType compType = UsualArithmeticConversions(
1023510240
LHS, RHS, Loc, IsCompAssign ? ACK_CompAssign : ACK_Arithmetic);

clang/test/CodeGen/matrix-type-operators.c

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,102 @@ void multiply_compound_int_matrix_constant(ix9x3_t a) {
729729
a *= 5;
730730
}
731731

732+
// CHECK-LABEL: @divide_double_matrix_scalar_float(
733+
// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
734+
// CHECK-NEXT: [[S:%.*]] = load float, float* %s.addr, align 4
735+
// CHECK-NEXT: [[S_EXT:%.*]] = fpext float [[S]] to double
736+
// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> poison, double [[S_EXT]], i32 0
737+
// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> poison, <25 x i32> zeroinitializer
738+
// CHECK-NEXT: [[RES:%.*]] = fdiv <25 x double> [[A]], [[VECSPLAT]]
739+
// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8
740+
// CHECK-NEXT: ret void
741+
//
742+
void divide_double_matrix_scalar_float(dx5x5_t a, float s) {
743+
a = a / s;
744+
}
745+
746+
// CHECK-LABEL: @divide_double_matrix_scalar_double(
747+
// CHECK: [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
748+
// CHECK-NEXT: [[S:%.*]] = load double, double* %s.addr, align 8
749+
// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <25 x double> poison, double [[S]], i32 0
750+
// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <25 x double> [[VECINSERT]], <25 x double> poison, <25 x i32> zeroinitializer
751+
// CHECK-NEXT: [[RES:%.*]] = fdiv <25 x double> [[A]], [[VECSPLAT]]
752+
// CHECK-NEXT: store <25 x double> [[RES]], <25 x double>* {{.*}}, align 8
753+
// CHECK-NEXT: ret void
754+
//
755+
void divide_double_matrix_scalar_double(dx5x5_t a, double s) {
756+
a = a / s;
757+
}
758+
759+
// CHECK-LABEL: @divide_float_matrix_scalar_double(
760+
// CHECK: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR:%.*]], align 4
761+
// CHECK-NEXT: [[S:%.*]] = load double, double* %s.addr, align 8
762+
// CHECK-NEXT: [[S_TRUNC:%.*]] = fptrunc double [[S]] to float
763+
// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <6 x float> poison, float [[S_TRUNC]], i32 0
764+
// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <6 x float> [[VECINSERT]], <6 x float> poison, <6 x i32> zeroinitializer
765+
// CHECK-NEXT: [[RES:%.*]] = fdiv <6 x float> [[MAT]], [[VECSPLAT]]
766+
// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4
767+
// CHECK-NEXT: ret void
768+
//
769+
void divide_float_matrix_scalar_double(fx2x3_t b, double s) {
770+
b = b / s;
771+
}
772+
773+
// CHECK-LABEL: @divide_int_matrix_scalar_short(
774+
// CHECK: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4
775+
// CHECK-NEXT: [[S:%.*]] = load i16, i16* %s.addr, align 2
776+
// CHECK-NEXT: [[S_EXT:%.*]] = sext i16 [[S]] to i32
777+
// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> poison, i32 [[S_EXT]], i32 0
778+
// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> poison, <27 x i32> zeroinitializer
779+
// CHECK-NEXT: [[RES:%.*]] = sdiv <27 x i32> [[MAT]], [[VECSPLAT]]
780+
// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
781+
// CHECK-NEXT: ret void
782+
//
783+
void divide_int_matrix_scalar_short(ix9x3_t b, short s) {
784+
b = b / s;
785+
}
786+
787+
// CHECK-LABEL: @divide_int_matrix_scalar_ull(
788+
// CHECK: [[MAT:%.*]] = load <27 x i32>, <27 x i32>* [[MAT_ADDR:%.*]], align 4
789+
// CHECK-NEXT: [[S:%.*]] = load i64, i64* %s.addr, align 8
790+
// CHECK-NEXT: [[S_TRUNC:%.*]] = trunc i64 [[S]] to i32
791+
// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <27 x i32> poison, i32 [[S_TRUNC]], i32 0
792+
// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <27 x i32> [[VECINSERT]], <27 x i32> poison, <27 x i32> zeroinitializer
793+
// CHECK-NEXT: [[RES:%.*]] = sdiv <27 x i32> [[MAT]], [[VECSPLAT]]
794+
// CHECK-NEXT: store <27 x i32> [[RES]], <27 x i32>* [[MAT_ADDR]], align 4
795+
// CHECK-NEXT: ret void
796+
//
797+
void divide_int_matrix_scalar_ull(ix9x3_t b, unsigned long long s) {
798+
b = b / s;
799+
}
800+
801+
// CHECK-LABEL: @divide_ull_matrix_scalar_ull(
802+
// CHECK: [[MAT:%.*]] = load <8 x i64>, <8 x i64>* [[MAT_ADDR:%.*]], align 8
803+
// CHECK-NEXT: [[S:%.*]] = load i64, i64* %s.addr, align 8
804+
// CHECK-NEXT: [[VECINSERT:%.*]] = insertelement <8 x i64> poison, i64 [[S]], i32 0
805+
// CHECK-NEXT: [[VECSPLAT:%.*]] = shufflevector <8 x i64> [[VECINSERT]], <8 x i64> poison, <8 x i32> zeroinitializer
806+
// CHECK-NEXT: [[RES:%.*]] = udiv <8 x i64> [[MAT]], [[VECSPLAT]]
807+
// CHECK-NEXT: store <8 x i64> [[RES]], <8 x i64>* [[MAT_ADDR]], align 8
808+
// CHECK-NEXT: ret void
809+
//
810+
void divide_ull_matrix_scalar_ull(ullx4x2_t b, unsigned long long s) {
811+
b = b / s;
812+
}
813+
814+
// CHECK-LABEL: @divide_float_matrix_constant(
815+
// CHECK-NEXT: entry:
816+
// CHECK-NEXT: [[A_ADDR:%.*]] = alloca [6 x float], align 4
817+
// CHECK-NEXT: [[MAT_ADDR:%.*]] = bitcast [6 x float]* [[A_ADDR]] to <6 x float>*
818+
// CHECK-NEXT: store <6 x float> [[A:%.*]], <6 x float>* [[MAT_ADDR]], align 4
819+
// CHECK-NEXT: [[MAT:%.*]] = load <6 x float>, <6 x float>* [[MAT_ADDR]], align 4
820+
// CHECK-NEXT: [[RES:%.*]] = fdiv <6 x float> [[MAT]], <float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00, float 2.500000e+00>
821+
// CHECK-NEXT: store <6 x float> [[RES]], <6 x float>* [[MAT_ADDR]], align 4
822+
// CHECK-NEXT: ret void
823+
//
824+
void divide_float_matrix_constant(fx2x3_t a) {
825+
a = a / 2.5;
826+
}
827+
732828
// Tests for the matrix type operators.
733829

734830
typedef double dx5x5_t __attribute__((matrix_type(5, 5)));

clang/test/CodeGen/matrix-type.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
#error Expected extension 'matrix_types' to be enabled
55
#endif
66

7+
#if !__has_extension(matrix_types_scalar_division)
8+
#error Expected extension 'matrix_types_scalar_division' to be enabled
9+
#endif
10+
711
typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
812

913
// CHECK: %struct.Matrix = type { i8, [12 x float], float }

clang/test/Sema/matrix-type-operators.c

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,40 @@ void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float sf, char *p) {
9494
// expected-error@-1 {{assigning to 'float' from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
9595
}
9696

97+
void mat_scalar_divide(sx10x10_t a, sx5x10_t b, float sf, char *p) {
98+
// Shape of division result does not match the type of b.
99+
b = a / sf;
100+
// expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
101+
b = sf / a;
102+
// expected-error@-1 {{invalid operands to binary expression ('float' and 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))'))}}
103+
104+
a = a / p;
105+
// expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'char *')}}
106+
a = p / a;
107+
// expected-error@-1 {{invalid operands to binary expression ('char *' and 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))'))}}
108+
109+
sf = a / sf;
110+
// expected-error@-1 {{assigning to 'float' from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
111+
}
112+
113+
void matrix_matrix_divide(sx10x10_t a, sx5x10_t b, ix10x5_t c, ix10x10_t d, float sf, char *p) {
114+
// Matrix by matrix division is not supported.
115+
a = a / a;
116+
// expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x10_t')}}
117+
118+
b = a / a;
119+
// expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x10_t')}}
120+
121+
// Check element type mismatches.
122+
a = b / c;
123+
// expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'ix10x5_t' (aka 'int __attribute__((matrix_type(10, 5)))'))}}
124+
d = a / a;
125+
// expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x10_t')}}
126+
127+
p = a / a;
128+
// expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x10_t')}}
129+
}
130+
97131
sx5x10_t get_matrix();
98132

99133
void insert(sx5x10_t a, float f) {

llvm/include/llvm/IR/MatrixBuilder.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,22 @@ template <class IRBuilderTy> class MatrixBuilder {
215215
return B.CreateMul(LHS, RHS);
216216
}
217217

218+
/// Divide matrix \p LHS by scalar \p RHS. If the operands are integers, \p
219+
/// IsUnsigned indicates whether UDiv or SDiv should be used.
220+
Value *CreateScalarDiv(Value *LHS, Value *RHS, bool IsUnsigned) {
221+
assert(LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy());
222+
assert(!isa<ScalableVectorType>(LHS->getType()) &&
223+
"LHS Assumed to be fixed width");
224+
RHS =
225+
B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getElementCount(),
226+
RHS, "scalar.splat");
227+
return cast<VectorType>(LHS->getType())
228+
->getElementType()
229+
->isFloatingPointTy()
230+
? B.CreateFDiv(LHS, RHS)
231+
: (IsUnsigned ? B.CreateUDiv(LHS, RHS) : B.CreateSDiv(LHS, RHS));
232+
}
233+
218234
/// Extracts the element at (\p RowIdx, \p ColumnIdx) from \p Matrix.
219235
Value *CreateExtractElement(Value *Matrix, Value *RowIdx, Value *ColumnIdx,
220236
unsigned NumRows, Twine const &Name = "") {

0 commit comments

Comments
 (0)