Skip to content

Commit be5ef79

Browse files
authored
HIP: Supress transformation warning in softmax.cu
loops with bounds not known at compile time can not be unrolled. when ncols_template == 0, the bounds of the loop are not constexpr, thus llvm cant unroll the loops here.
1 parent cae9fb4 commit be5ef79

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

ggml/src/ggml-cuda/softmax.cu

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@ __device__ float __forceinline__ t2f32<half>(half val) {
1313
return __half2float(val);
1414
}
1515

16+
// When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
17+
// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
18+
#ifdef __clang__
19+
#pragma clang diagnostic push
20+
#pragma clang diagnostic ignored "-Wpass-failed"
21+
#endif
1622
template <bool use_shared, int ncols_template, int block_size_template, typename T>
1723
static __global__ void soft_max_f32(
1824
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@@ -118,6 +124,9 @@ static __global__ void soft_max_f32(
118124
dst[col] = vals[col] * inv_sum;
119125
}
120126
}
127+
#ifdef __clang__
128+
#pragma clang diagnostic pop
129+
#endif
121130

122131
static __global__ void soft_max_back_f32(
123132
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {

0 commit comments

Comments
 (0)