Skip to content

Commit be2277f

Browse files
effective-lightfhahn
authored andcommitted
[Matrix] Support #pragma clang fp
From https://bugs.llvm.org/show_bug.cgi?id=49739: Currently, `#pragma clang fp` are ignored for matrix types. For the code below, the `contract` fast-math flag should be added to the generated call to `llvm.matrix.multiply` and `fadd` ``` typedef float fx2x2_t __attribute__((matrix_type(2, 2))); void foo(fx2x2_t &A, fx2x2_t &C, fx2x2_t &B) { #pragma clang fp contract(fast) C = A*B + C; } ``` Reviewed By: fhahn, mibintc Differential Revision: https://reviews.llvm.org/D100834
1 parent 4393668 commit be2277f

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

clang/lib/CodeGen/CGExprScalar.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,6 +732,7 @@ class ScalarExprEmitter
732732
BO->getLHS()->getType().getCanonicalType());
733733
auto *RHSMatTy = dyn_cast<ConstantMatrixType>(
734734
BO->getRHS()->getType().getCanonicalType());
735+
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
735736
if (LHSMatTy && RHSMatTy)
736737
return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
737738
LHSMatTy->getNumColumns(),
@@ -3206,6 +3207,7 @@ Value *ScalarExprEmitter::EmitDiv(const BinOpInfo &Ops) {
32063207
"first operand must be a matrix");
32073208
assert(BO->getRHS()->getType().getCanonicalType()->isArithmeticType() &&
32083209
"second operand must be an arithmetic type");
3210+
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, Ops.FPFeatures);
32093211
return MB.CreateScalarDiv(Ops.LHS, Ops.RHS,
32103212
Ops.Ty->hasUnsignedIntegerRepresentation());
32113213
}
@@ -3585,6 +3587,7 @@ Value *ScalarExprEmitter::EmitAdd(const BinOpInfo &op) {
35853587

35863588
if (op.Ty->isConstantMatrixType()) {
35873589
llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
3590+
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
35883591
return MB.CreateAdd(op.LHS, op.RHS);
35893592
}
35903593

@@ -3734,6 +3737,7 @@ Value *ScalarExprEmitter::EmitSub(const BinOpInfo &op) {
37343737

37353738
if (op.Ty->isConstantMatrixType()) {
37363739
llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
3740+
CodeGenFunction::CGFPOptionsRAII FPOptsRAII(CGF, op.FPFeatures);
37373741
return MB.CreateSub(op.LHS, op.RHS);
37383742
}
37393743

clang/test/CodeGen/fp-matrix-pragma.c

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
// RUN: %clang -emit-llvm -S -fenable-matrix -mllvm -disable-llvm-optzns %s -o - | FileCheck %s
2+
3+
typedef float fx2x2_t __attribute__((matrix_type(2, 2)));
4+
typedef int ix2x2_t __attribute__((matrix_type(2, 2)));
5+
6+
fx2x2_t fp_matrix_contract(fx2x2_t a, fx2x2_t b, float c, float d) {
7+
// CHECK: call contract <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32
8+
// CHECK: fdiv contract <4 x float>
9+
// CHECK: fmul contract <4 x float>
10+
#pragma clang fp contract(fast)
11+
return (a * b / c) * d;
12+
}
13+
14+
fx2x2_t fp_matrix_reassoc(fx2x2_t a, fx2x2_t b, fx2x2_t c) {
15+
// CHECK: fadd reassoc <4 x float>
16+
// CHECK: fsub reassoc <4 x float>
17+
#pragma clang fp reassociate(on)
18+
return a + b - c;
19+
}
20+
21+
fx2x2_t fp_matrix_ops(fx2x2_t a, fx2x2_t b, fx2x2_t c) {
22+
// CHECK: call reassoc contract <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32
23+
// CHECK: fadd reassoc contract <4 x float>
24+
#pragma clang fp contract(fast) reassociate(on)
25+
return a * b + c;
26+
}
27+
28+
fx2x2_t fp_matrix_compound_ops(fx2x2_t a, fx2x2_t b, fx2x2_t c, fx2x2_t d,
29+
float e, float f) {
30+
// CHECK: call reassoc contract <4 x float> @llvm.matrix.multiply.v4f32.v4f32.v4f32
31+
// CHECK: fadd reassoc contract <4 x float>
32+
// CHECK: fsub reassoc contract <4 x float>
33+
// CHECK: fmul reassoc contract <4 x float>
34+
// CHECK: fdiv reassoc contract <4 x float>
35+
#pragma clang fp contract(fast) reassociate(on)
36+
a *= b;
37+
a += c;
38+
a -= d;
39+
a *= e;
40+
a /= f;
41+
42+
return a;
43+
}
44+
45+
ix2x2_t int_matrix_ops(ix2x2_t a, ix2x2_t b, ix2x2_t c) {
46+
// CHECK: call <4 x i32> @llvm.matrix.multiply.v4i32.v4i32.v4i32
47+
// CHECK: add <4 x i32>
48+
#pragma clang fp contract(fast) reassociate(on)
49+
return a * b + c;
50+
}

0 commit comments

Comments
 (0)