Skip to content

Commit ff09bd0

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Update permute_copy to support all scalar types
Summary: Context: D47863112 The new permute_copy broke tests since the previous one supported all types, whereas it was changed to only support real types and bool. This diff rectifies this by updating it to support all possible scalar types. Reviewed By: digantdesai Differential Revision: D47876145 fbshipit-source-id: 43673299da2b2f89d43899e00833f9200c1c9119
1 parent c2418e9 commit ff09bd0

File tree

2 files changed

+56
-1
lines changed

2 files changed

+56
-1
lines changed

kernels/portable/cpu/op_permute_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ Tensor& permute_copy_out(
5353

5454
const auto in_type = out.scalar_type();
5555
// in and out must be the same dtype
56-
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "permute_copy", CTYPE, [&] {
56+
ET_SWITCH_ALL_TYPES(in_type, ctx, "permute_copy", CTYPE, [&] {
5757
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
5858
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
5959

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,54 @@ inline size_t sizeof_scalar_type(exec_aten::ScalarType type) {
574574
} \
575575
}()
576576

577+
#define ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, ...) \
578+
ET_INTERNAL_SWITCH_CASE( \
579+
exec_aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \
580+
ET_INTERNAL_SWITCH_CASE( \
581+
exec_aten::ScalarType::Char, CTYPE_ALIAS, __VA_ARGS__) \
582+
ET_INTERNAL_SWITCH_CASE( \
583+
exec_aten::ScalarType::Short, CTYPE_ALIAS, __VA_ARGS__) \
584+
ET_INTERNAL_SWITCH_CASE( \
585+
exec_aten::ScalarType::Int, CTYPE_ALIAS, __VA_ARGS__) \
586+
ET_INTERNAL_SWITCH_CASE( \
587+
exec_aten::ScalarType::Long, CTYPE_ALIAS, __VA_ARGS__) \
588+
ET_INTERNAL_SWITCH_CASE( \
589+
exec_aten::ScalarType::Half, CTYPE_ALIAS, __VA_ARGS__) \
590+
ET_INTERNAL_SWITCH_CASE( \
591+
exec_aten::ScalarType::Float, CTYPE_ALIAS, __VA_ARGS__) \
592+
ET_INTERNAL_SWITCH_CASE( \
593+
exec_aten::ScalarType::Double, CTYPE_ALIAS, __VA_ARGS__) \
594+
ET_INTERNAL_SWITCH_CASE( \
595+
exec_aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \
596+
ET_INTERNAL_SWITCH_CASE( \
597+
exec_aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \
598+
ET_INTERNAL_SWITCH_CASE( \
599+
exec_aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) \
600+
ET_INTERNAL_SWITCH_CASE( \
601+
exec_aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \
602+
ET_INTERNAL_SWITCH_CASE( \
603+
exec_aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__) \
604+
ET_INTERNAL_SWITCH_CASE( \
605+
exec_aten::ScalarType::QUInt8, CTYPE_ALIAS, __VA_ARGS__) \
606+
ET_INTERNAL_SWITCH_CASE( \
607+
exec_aten::ScalarType::QInt32, CTYPE_ALIAS, __VA_ARGS__) \
608+
ET_INTERNAL_SWITCH_CASE( \
609+
exec_aten::ScalarType::BFloat16, CTYPE_ALIAS, __VA_ARGS__) \
610+
ET_INTERNAL_SWITCH_CASE( \
611+
exec_aten::ScalarType::QUInt4x2, CTYPE_ALIAS, __VA_ARGS__) \
612+
ET_INTERNAL_SWITCH_CASE( \
613+
exec_aten::ScalarType::QUInt2x4, CTYPE_ALIAS, __VA_ARGS__) \
614+
ET_INTERNAL_SWITCH_CASE( \
615+
exec_aten::ScalarType::Bits1x8, CTYPE_ALIAS, __VA_ARGS__) \
616+
ET_INTERNAL_SWITCH_CASE( \
617+
exec_aten::ScalarType::Bits2x4, CTYPE_ALIAS, __VA_ARGS__) \
618+
ET_INTERNAL_SWITCH_CASE( \
619+
exec_aten::ScalarType::Bits4x2, CTYPE_ALIAS, __VA_ARGS__) \
620+
ET_INTERNAL_SWITCH_CASE( \
621+
exec_aten::ScalarType::Bits8, CTYPE_ALIAS, __VA_ARGS__) \
622+
ET_INTERNAL_SWITCH_CASE( \
623+
exec_aten::ScalarType::Bits16, CTYPE_ALIAS, __VA_ARGS__)
624+
577625
#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES(CTYPE_ALIAS, ...) \
578626
ET_INTERNAL_SWITCH_CASE( \
579627
exec_aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \
@@ -700,6 +748,13 @@ inline size_t sizeof_scalar_type(exec_aten::ScalarType type) {
700748
// used to alias the ctype associated with the ScalarType that is being handled.
701749
//
702750

751+
#define ET_SWITCH_ALL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
752+
ET_INTERNAL_SWITCH( \
753+
TYPE, \
754+
CONTEXT, \
755+
NAME, \
756+
ET_INTERNAL_SWITCH_CASE_ALL_TYPES(CTYPE_ALIAS, __VA_ARGS__))
757+
703758
#define ET_SWITCH_REAL_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
704759
ET_INTERNAL_SWITCH( \
705760
TYPE, \

0 commit comments

Comments
 (0)