Skip to content

Commit 4fa46d3

Browse files
naromero77amd authored and pytorchmergebot committed
TunableOp: Performance Improvement (pytorch#135371)
This PR reduces the overhead on the CPU side by eliminating the use of c10::str when creating signatures. Instead, we use the fmt library. TunableOp overhead on the CPU is reduced by roughly 40%. The improvement is most noticeable on small GEMMs. This PR does not contain any bug fixes or new features. Pull Request resolved: pytorch#135371 Approved by: https://github.com/jeffdaily
1 parent da57849 commit 4fa46d3

File tree

6 files changed

+43
-33
lines changed

6 files changed

+43
-33
lines changed

aten/src/ATen/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -266,6 +266,9 @@ endif()
266266

267267
if(USE_CUDA)
268268
list(APPEND ATen_CUDA_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/cuda)
269+
# The next two lines are needed because TunableOp uses third_party/fmt
270+
list(APPEND ATen_CUDA_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
271+
list(APPEND ATen_CUDA_DEPENDENCY_LIBS fmt::fmt-header-only)
269272
list(APPEND ATen_CUDA_CU_SRCS
270273
${cuda_cu}
271274
${native_cuda_cu}
@@ -309,6 +312,9 @@ if(USE_ROCM)
309312
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
310313
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
311314
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
315+
# The next two lines are needed because TunableOp uses third_party/fmt
316+
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
317+
list(APPEND ATen_HIP_DEPENDENCY_LIBS fmt::fmt-header-only)
312318
list(APPEND ATen_HIP_SRCS
313319
${ATen_HIP_SRCS}
314320
${hip_hip}

aten/src/ATen/cuda/tunable/GemmCommon.h

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <ATen/ops/allclose.h>
2323
#include <ATen/ops/from_blob.h>
2424
#endif
25+
#include <fmt/printf.h>
2526

2627
namespace at::cuda::tunable {
2728

@@ -30,15 +31,15 @@ enum class BlasOp {
3031
T = 1
3132
};
3233

33-
inline std::string BlasOpToString(BlasOp op) {
34+
inline char BlasOpToString(BlasOp op) {
3435
switch (op) {
3536
case BlasOp::N:
36-
return "N";
37+
return 'N';
3738
case BlasOp::T:
38-
return "T";
39+
return 'T';
3940
}
4041
TORCH_CHECK(false, "unrecognized BlasOp");
41-
return "N";
42+
return 'N';
4243
}
4344

4445
namespace detail {
@@ -81,7 +82,7 @@ struct GemmParams : OpParams {
8182
}
8283

8384
std::string Signature() const override {
84-
return c10::str(transa, transb, "_", m, "_", n, "_", k);
85+
return fmt::sprintf("%c%c_%ld_%ld_%ld", transa, transb, m, n, k);
8586
}
8687

8788
size_t GetSizeA() const {
@@ -158,7 +159,7 @@ struct GemmParams : OpParams {
158159
template <typename T>
159160
struct GemmAndBiasParams : OpParams {
160161
std::string Signature() const override {
161-
return c10::str(transa, transb, "_", m, "_", n, "_", k);
162+
return fmt::sprintf("%c%c_%ld_%ld_%ld", transa, transb, m, n, k);
162163
}
163164

164165
size_t GetSize(bool duplicate_inputs) const {
@@ -228,7 +229,7 @@ struct GemmStridedBatchedParams : OpParams {
228229
}
229230

230231
std::string Signature() const override {
231-
return c10::str(transa, transb, "_", m, "_", n, "_", k, "_B_", batch);
232+
return fmt::sprintf("%c%c_%ld_%ld_%ld_B_%ld", transa, transb, m, n, k, batch);
232233
}
233234

234235
size_t GetSizeA() const {
@@ -313,7 +314,7 @@ struct ScaledGemmParams : OpParams {
313314
}
314315

315316
std::string Signature() const override {
316-
return c10::str(transa, transb, "_", m, "_", n, "_", k);
317+
return fmt::sprintf("%c%c_%ld_%ld_%ld", transa, transb, m, n, k);
317318
}
318319

319320
size_t GetSizeA() const {

aten/src/ATen/cuda/tunable/GemmHipblaslt.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <ATen/cuda/tunable/GemmCommon.h>
1010
#include <c10/cuda/CUDACachingAllocator.h>
1111
#include <c10/util/StringUtil.h>
12+
#include <fmt/printf.h>
1213

1314
#include <hipblaslt/hipblaslt.h>
1415
#include <hipblaslt/hipblaslt-ext.hpp>
@@ -578,8 +579,7 @@ auto GetHipBlasLtTypeStringAndOps() {
578579
auto algo = heuristic_result[i].algo;
579580
int algo_index = hipblaslt_ext::getIndexFromAlgo(algo);
580581
auto callable = std::make_unique<HipblasltGemmOp<AT, BT, CT, ALayout, BLayout, ParamsT>>(algo);
581-
std::string type_string = c10::str(
582-
"Gemm_Hipblaslt_", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), "_", algo_index);
582+
std::string type_string = fmt::sprintf("Gemm_Hipblaslt_%c%c_%d", _charFromhipblasOp(transa_outer), _charFromhipblasOp(transb_outer), algo_index);
583583
ret.emplace_back(type_string, std::move(callable));
584584
}
585585

aten/src/ATen/cuda/tunable/GemmRocblas.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <ATen/cuda/tunable/TunableOp.h>
88
#include <ATen/cuda/tunable/GemmCommon.h>
99
#include <c10/util/StringUtil.h>
10+
#include <fmt/printf.h>
1011

1112
#define ROCBLAS_BETA_FEATURES_API
1213
#include <rocblas/rocblas.h>
@@ -197,7 +198,7 @@ auto GetRocBlasGemmTypeStringAndOps() {
197198
std::vector<std::pair<std::string, std::unique_ptr<Callable<GemmParams<T>>>>> ret;
198199
for (size_t i = 0; i < solutions.size(); ++i) {
199200
auto callable = std::make_unique<RocblasGemmOp<T>>(solutions[i]);
200-
ret.emplace_back(std::make_pair(c10::str("Gemm_Rocblas_", solutions[i]), std::move(callable)));
201+
ret.emplace_back(std::make_pair(fmt::sprintf("Gemm_Rocblas_%d", solutions[i]), std::move(callable)));
201202
}
202203
return ret;
203204
}

aten/src/ATen/cuda/tunable/Tunable.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,16 +48,17 @@ static OstreamPtr get_stream(std::string filename) {
4848

4949
}
5050

51-
static void TunableLog(int level, const std::string& msg) {
51+
template<class... Types>
52+
static void TunableLog(int level, Types... args) {
5253
static const char *env_file = getenv("PYTORCH_TUNABLEOP_VERBOSE_FILENAME");
5354
static const char *env_verbose = getenv("PYTORCH_TUNABLEOP_VERBOSE");
5455
static int level_user = env_verbose ? atoi(env_verbose) : 0;
5556
static auto streamptr = detail::get_stream(env_file ? env_file : "err");
5657
if (level_user >= level) {
57-
(*streamptr) << msg <<std::endl;
58+
(*streamptr) << c10::str(args...) << std::endl;
5859
}
5960
}
60-
#define TUNABLE_LOGV(LEVEL, ...) TunableLog(LEVEL, c10::str(__VA_ARGS__))
61+
#define TUNABLE_LOGV(LEVEL, ...) TunableLog(LEVEL, __VA_ARGS__)
6162
#define TUNABLE_LOG1(...) TUNABLE_LOGV(1, __VA_ARGS__)
6263
#define TUNABLE_LOG2(...) TUNABLE_LOGV(2, __VA_ARGS__)
6364
#define TUNABLE_LOG3(...) TUNABLE_LOGV(3, __VA_ARGS__)

aten/src/ATen/cuda/tunable/TunableGemm.h

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <c10/util/Float8_e5m2.h>
2323
#include <c10/util/Float8_e5m2fnuz.h>
2424
#include <c10/util/StringUtil.h>
25+
#include <fmt/printf.h>
2526

2627
namespace at::cuda::tunable {
2728

@@ -135,57 +136,57 @@ inline bool IsZero(c10::complex<float> v) {
135136
}
136137

137138
template <typename T>
138-
inline std::string TypeName(T v) {
139+
inline const char* TypeName(T v) {
139140
return "unknown";
140141
}
141142

142143
template <>
143-
inline std::string TypeName(float v) {
144+
inline const char* TypeName(float v) {
144145
return "float";
145146
}
146147

147148
template <>
148-
inline std::string TypeName(double v) {
149+
inline const char* TypeName(double v) {
149150
return "double";
150151
}
151152

152153
template <>
153-
inline std::string TypeName(BFloat16 v) {
154+
inline const char* TypeName(BFloat16 v) {
154155
return "BFloat16";
155156
}
156157

157158
template <>
158-
inline std::string TypeName(Half v) {
159+
inline const char* TypeName(Half v) {
159160
return "Half";
160161
}
161162

162163
template <>
163-
inline std::string TypeName(Float8_e4m3fn v) {
164+
inline const char* TypeName(Float8_e4m3fn v) {
164165
return "Float8_e4m3fn";
165166
}
166167

167168
template <>
168-
inline std::string TypeName(Float8_e5m2 v) {
169+
inline const char* TypeName(Float8_e5m2 v) {
169170
return "Float8_e5m2";
170171
}
171172

172173
template <>
173-
inline std::string TypeName(Float8_e4m3fnuz v) {
174+
inline const char* TypeName(Float8_e4m3fnuz v) {
174175
return "Float8_e4m3fnuz";
175176
}
176177

177178
template <>
178-
inline std::string TypeName(Float8_e5m2fnuz v) {
179+
inline const char* TypeName(Float8_e5m2fnuz v) {
179180
return "Float8_e5m2fnuz";
180181
}
181182

182183
template <>
183-
inline std::string TypeName(c10::complex<double> v) {
184+
inline const char* TypeName(c10::complex<double> v) {
184185
return "c10::complex<double>";
185186
}
186187

187188
template <>
188-
inline std::string TypeName(c10::complex<float> v) {
189+
inline const char* TypeName(c10::complex<float> v) {
189190
return "c10::complex<float>";
190191
}
191192

@@ -218,7 +219,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
218219
}
219220

220221
std::string Signature() override {
221-
return c10::str("GemmTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
222+
return fmt::sprintf("GemmTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
222223
}
223224
};
224225

@@ -244,7 +245,7 @@ class GemmAndBiasTunableOp : public TunableOp<GemmAndBiasParams<T>, StreamTimer>
244245
}
245246

246247
std::string Signature() override {
247-
return c10::str("GemmAndBiasTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
248+
return fmt::sprintf("GemmAndBiasTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
248249
}
249250
};
250251

@@ -277,7 +278,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
277278
}
278279

279280
std::string Signature() override {
280-
return c10::str("GemmStridedBatchedTunableOp_", TypeName<T>(T{}), "_", BlasOpToString(ALayout), BlasOpToString(BLayout));
281+
return fmt::sprintf("GemmStridedBatchedTunableOp_%s_%c%c", TypeName<T>(T{}), BlasOpToString(ALayout), BlasOpToString(BLayout));
281282
}
282283
};
283284

@@ -295,11 +296,11 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
295296
}
296297

297298
std::string Signature() override {
298-
return c10::str("ScaledGemmTunableOp",
299-
"_", TypeName<AT>(AT{}),
300-
"_", TypeName<BT>(BT{}),
301-
"_", TypeName<CT>(CT{}),
302-
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
299+
return fmt::sprintf("ScaledGemmTunableOp_%s_%s_%s_%c%c",
300+
TypeName<AT>(AT{}),
301+
TypeName<BT>(BT{}),
302+
TypeName<CT>(CT{}),
303+
BlasOpToString(ALayout), BlasOpToString(BLayout));
303304
}
304305
};
305306

0 commit comments

Comments
 (0)