1
- #define TM 8
2
- #define TK 8
3
-
4
/// Thin, non-owning handle around a caller-provided buffer of `T`.
///
/// The struct stores only the raw pointer it is given; it never allocates,
/// copies, or frees the pointed-to memory. `NUM_ROWS` and `NUM_COLS` are
/// carried purely in the type -- no member of this struct reads them.
///
/// @tparam T         Element type of the underlying buffer.
/// @tparam NUM_ROWS  Row count encoded in the type (not used by the members).
/// @tparam NUM_COLS  Column count encoded in the type (not used by the members).
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
public:
  /// Pointer to the first element of the wrapped buffer. Left public so code
  /// that accesses `mat` directly keeps compiling.
  T *mat;

  /// Wraps `data` without taking ownership. Intentionally left non-explicit so
  /// existing call sites relying on implicit conversion are not broken.
  big_matrix(T *data) : mat(data) {}

  /// @return The wrapped buffer pointer (mutable access).
  T *get_data() { return mat; }

  /// @return The wrapped buffer pointer as read-only; this overload makes the
  ///         accessor usable through a `const big_matrix` object or reference.
  const T *get_data() const { return mat; }

  /// Re-points the handle at `data`; the previously wrapped buffer is left
  /// untouched.
  void set_data(T *data) { mat = data; }
};
1
+ constexpr size_t TM = 8 ;
2
+ constexpr size_t TK = 8 ;
13
3
14
4
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
15
5
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
@@ -60,7 +50,6 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
60
50
accC.template get_multi_ptr <access::decorated::no>() +
61
51
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
62
52
N, layout::row_major);
63
- joint_matrix_fill (sg, sub_a, 42 );
64
53
for (int k = 0 ; k < K; k += TK) {
65
54
joint_matrix_load (
66
55
sg, sub_a,
@@ -75,13 +64,12 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
75
64
// If no rounding to tf32 function is called, joint_matrix_mad
76
65
// function will work on truncated floats.
77
66
joint_matrix_apply (sg, sub_a,
78
- [=](float x) { x = round_to_tf32 (x); });
67
+ [=](float & x) { x = round_to_tf32 (x); });
79
68
joint_matrix_apply (sg, sub_b,
80
- [=](float x) { x = round_to_tf32 (x); });
69
+ [=](float & x) { x = round_to_tf32 (x); });
81
70
joint_matrix_mad (sg, sub_c, sub_a, sub_b, sub_c);
82
71
}
83
72
84
- joint_matrix_apply (sg, sub_a, [=](float x) { x *= 2 ; });
85
73
joint_matrix_store (
86
74
sg, sub_c,
87
75
accC.template get_multi_ptr <access::decorated::no>() +
@@ -91,43 +79,21 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
91
79
}).wait ();
92
80
}
93
81
94
- static constexpr size_t MATRIX_M = TM * 2 ;
95
- static constexpr size_t MATRIX_N = TN * 2 ;
96
- static constexpr size_t MATRIX_K = TK * 2 ;
97
- float A[MATRIX_M][MATRIX_K];
98
- float B[MATRIX_K][MATRIX_N];
99
- float C[MATRIX_M][MATRIX_N];
100
- float D[MATRIX_M][MATRIX_N];
101
-
102
/// Reference matrix multiply-accumulate: C += A * B.
///
/// All three buffers are indexed as flat row-major arrays:
///   A_mem[m * K + k], B_mem[k * N + n], C_mem[m * N + n].
/// The products are accumulated onto whatever C_mem already holds, so the
/// caller must initialise C_mem beforehand.
///
/// @param A_mem  Left operand, M x K elements.
/// @param B_mem  Right operand, K x N elements.
/// @param C_mem  Accumulator / result, M x N elements (read and written).
/// @param M      Number of rows of A and C.
/// @param N      Number of columns of B and C.
/// @param K      Number of columns of A / rows of B.
void matrix_multiply_ref(float *A_mem, float *B_mem, float *C_mem, int M, int N,
                         int K) {
  // With any non-positive dimension the loops below would not execute a single
  // accumulation; bail out before converting to unsigned sizes.
  if (M <= 0 || N <= 0 || K <= 0)
    return;

  // Do the flat-index arithmetic in size_t: products such as m * K can exceed
  // INT_MAX for large matrices, and signed overflow is undefined behaviour.
  const size_t rows = static_cast<size_t>(M);
  const size_t cols = static_cast<size_t>(N);
  const size_t depth = static_cast<size_t>(K);

  for (size_t m = 0; m < rows; ++m) {
    for (size_t n = 0; n < cols; ++n) {
      // Accumulate straight into C_mem (not a local copy) so results are
      // unchanged even if the output buffer aliases one of the inputs.
      float &out = C_mem[m * cols + n];
      for (size_t k = 0; k < depth; ++k) {
        const float va = A_mem[m * depth + k];
        const float vb = B_mem[k * cols + n];
        out += va * vb;
      }
    }
  }
}
113
-
114
82
int main () {
115
- for (int i = 0 ; i < MATRIX_M; i++) {
116
- for (int j = 0 ; j < MATRIX_K; j++) {
117
- A[i][j] = 1 .0f * (i + j);
118
- }
119
- }
120
- for (int i = 0 ; i < MATRIX_K; i++) {
121
- for (int j = 0 ; j < MATRIX_N; j++) {
122
- B[i][j] = 2 .0f * i + 3 .0f * j;
123
- }
124
- }
125
- for (int i = 0 ; i < MATRIX_M; i++) {
126
- for (int j = 0 ; j < MATRIX_N; j++) {
127
- C[i][j] = 1.0 ;
128
- D[i][j] = 1.0 ;
129
- }
130
- }
83
+ static constexpr size_t MATRIX_M = TM * 2 ;
84
+ static constexpr size_t MATRIX_N = TN * 2 ;
85
+ static constexpr size_t MATRIX_K = TK * 2 ;
86
+ float A[MATRIX_M][MATRIX_K];
87
+ float B[MATRIX_K][MATRIX_N];
88
+ float C[MATRIX_M][MATRIX_N];
89
+ float D[MATRIX_M][MATRIX_N];
90
+
91
+ matrix_fill (MATRIX_M, MATRIX_K, (float *)A,
92
+ [](int i, int j) { return 1 .0f * (i + j); });
93
+ matrix_fill (MATRIX_K, MATRIX_N, (float *)B,
94
+ [](int i, int j) { return 2 .0f * i + 3 .0f * j; });
95
+ matrix_fill (MATRIX_M, MATRIX_N, (float *)C, 1 .0f );
96
+ matrix_fill (MATRIX_M, MATRIX_N, (float *)D, 1 .0f );
131
97
132
98
big_matrix<float , MATRIX_M, MATRIX_N> MC ((float *)&C);
133
99
big_matrix<float , MATRIX_M, MATRIX_N> MD ((float *)&D);
@@ -137,13 +103,7 @@ int main() {
137
103
matrix_multiply_ref ((float *)A, (float *)B, (float *)D, MATRIX_M, MATRIX_N,
138
104
MATRIX_K);
139
105
140
- bool res = true ;
141
- for (int i = 0 ; i < MATRIX_M; i++) {
142
- for (int j = 0 ; j < MATRIX_N; j++) {
143
- if (C[i][j] != D[i][j])
144
- res = false ;
145
- }
146
- }
106
+ bool res = matrix_compare (MATRIX_M, MATRIX_N, (float *)C, (float *)D);
147
107
std::cout << (res ? " passed" : " failed" ) << std::endl;
148
108
return !res;
149
109
}
0 commit comments