Skip to content

Commit 1e32bbc

Browse files
xytintelmengfei25LuFinchratnampafengyuan14
authored
Sync main into release/2.6 branch (#1117)
Reset to bfdbaf4 --------- Co-authored-by: mengfei25 <[email protected]> Co-authored-by: LuFengqing <[email protected]> Co-authored-by: Ratnam Parikh <[email protected]> Co-authored-by: Feng Yuan <[email protected]>
1 parent 59672bb commit 1e32bbc

File tree

11 files changed

+130
-41
lines changed

11 files changed

+130
-41
lines changed

.github/scripts/apply_torch_pr.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"https://github.com/pytorch/pytorch/pull/126516",
1414
# Modify the tolerance level in TIMM benchmark
1515
"https://github.com/pytorch/pytorch/pull/129735",
16+
# [XPU] Update XPU C Shim Header
17+
"https://github.com/pytorch/pytorch/pull/141086",
1618
]
1719
)
1820
parser.add_argument('--extra-pr-list', '-e', nargs='+',default=[])

.github/scripts/env.sh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
#!/bin/bash
2-
source /opt/intel/oneapi/pytorch-gpu-dev-0.5/oneapi-vars.sh
2+
source /opt/intel/oneapi/compiler/latest/env/vars.sh
3+
source /opt/intel/oneapi/umf/latest/env/vars.sh
34
source /opt/intel/oneapi/pti/latest/env/vars.sh

.github/workflows/_linux_ut.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ jobs:
9797
run: |
9898
source activate xpu_op_${ZE_AFFINITY_MASK}
9999
source .github/scripts/env.sh
100-
pip install mkl-static mkl-include
100+
pip install mkl-static==2025.0.1 mkl-include==2025.0.1
101101
cd ../pytorch
102102
if [[ ${{ inputs.abi }} == '0' ]]; then
103103
export _GLIBCXX_USE_CXX11_ABI=0

.github/workflows/nightly_ondemand.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ jobs:
123123
conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
124124
conda create -n e2e_ci python=${{ env.python }} cmake ninja -y
125125
source activate e2e_ci
126-
pip install mkl-static mkl-include
126+
pip install mkl-static==2025.0.1 mkl-include==2025.0.1
127127
pip install pandas scipy tqdm
128128
- name: Prepare Stock Pytorch
129129
run: |

.github/workflows/nightly_ondemand_rolling.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ jobs:
125125
conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
126126
conda create -n e2e_ci python=${{ env.python }} cmake ninja -y
127127
source activate e2e_ci
128-
pip install mkl-static mkl-include
128+
pip install mkl-static==2025.0.1 mkl-include==2025.0.1
129129
pip install pandas scipy tqdm
130130
- name: Prepare Stock Pytorch
131131
run: |

.github/workflows/nightly_ondemand_whl.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ jobs:
9898
conda remove --all -y -n e2e_ci || rm -rf $(dirname ${CONDA_EXE})/../envs/e2e_ci
9999
conda create -n e2e_ci python=${{ env.python }} cmake ninja -y
100100
source activate e2e_ci
101-
pip install mkl-static mkl-include
101+
pip install mkl-static==2025.0.1 mkl-include==2025.0.1
102102
pip install pandas scipy tqdm
103103
- name: Prepare Stock Pytorch
104104
run: |

cmake/BuildFlags.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "MSVC"
122122
set(SYCL_OFFLINE_COMPILER_CG_OPTIONS "-options '${SYCL_OFFLINE_COMPILER_CG_OPTIONS}'")
123123

124124
if(WIN32)
125-
set(AOT_TARGETS "ats-m150,lnl-m,mtl-u,mtl-h")
125+
set(AOT_TARGETS "ats-m150,mtl-u,mtl-h,xe2-lpg,xe2-hpg")
126126
else()
127127
set(AOT_TARGETS "pvc,xe-lpg,ats-m150")
128128
endif()

src/ATen/native/xpu/sycl/MultiTensorApply.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ static inline int64_t multi_tensor_apply_fused_kernel_get_chunk_size() {
6868
}
6969

7070
template <typename T, typename Y, typename U, typename... ArgTypes>
71-
struct MultiTensorApplyKernelFunctor {
71+
struct MultiTensorApplyKernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
7272
void operator()(sycl::nd_item<1> item_id) const {
7373
// Expand the tuple elements manually and call the callable
7474
expandAndCall(item_id, std::index_sequence_for<ArgTypes...>());
@@ -85,6 +85,12 @@ struct MultiTensorApplyKernelFunctor {
8585
callable(callable_),
8686
args(std::make_tuple(args_...)) {}
8787

88+
void sycl_ker_config_convention(sycl::handler& cgh) {
89+
if constexpr (std::is_base_of_v<__SYCL_KER_CONFIG_CONVENTION__, U>) {
90+
callable.sycl_ker_config_convention(cgh);
91+
}
92+
}
93+
8894
private:
8995
template <std::size_t... Indices>
9096
void expandAndCall(sycl::nd_item<1> item_id, std::index_sequence<Indices...>)
@@ -117,7 +123,6 @@ void launch_multi_tensor_apply_kernel(
117123
U callable,
118124
int num_wg,
119125
ArgTypes... args) {
120-
121126
auto& q = getCurrentSYCLQueue();
122127
int64_t simd = syclMaxSubGroupSize();
123128
int64_t max_wg_size = multi_tensor_apply_kernel_get_wg_size(simd);
@@ -226,7 +231,6 @@ void multi_tensor_apply(
226231
std::vector<std::vector<at::Tensor>>& tensor_lists,
227232
T callable,
228233
ArgTypes... args) {
229-
230234
TORCH_CHECK(
231235
tensor_lists.size() == depth,
232236
"Number of tensor lists has to match he depth");

src/ATen/native/xpu/sycl/ScatterGatherKernels.cpp

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -143,41 +143,64 @@ struct alignas(N) OpaqueType {
143143
char data[N];
144144
};
145145

146-
template <int work_group_size, int thread_work_size, typename func_t>
146+
template <typename func_t>
147147
struct ScatterGatherElementwiseKernelFunctor {
148148
void operator()(sycl::nd_item<1> item) const {
149-
constexpr int nv = work_group_size * thread_work_size;
149+
int nv = work_group_size_ * thread_work_size_;
150150
auto wg_id = item.get_group_linear_id();
151151
auto local_id = item.get_local_linear_id();
152152
int idx = nv * wg_id + local_id;
153-
#pragma unroll
154-
for (int i = 0; i < thread_work_size; ++i) {
153+
for (int i = 0; i < thread_work_size_; ++i) {
155154
if (idx < N_) {
156155
f_(idx);
157-
idx += work_group_size;
156+
idx += work_group_size_;
158157
}
159158
}
160159
}
161-
ScatterGatherElementwiseKernelFunctor(int N, func_t f) : N_(N), f_(f) {}
160+
ScatterGatherElementwiseKernelFunctor(
161+
int N,
162+
func_t f,
163+
int work_group_size,
164+
int thread_work_size)
165+
: N_(N),
166+
f_(f),
167+
work_group_size_(work_group_size),
168+
thread_work_size_(thread_work_size) {}
162169

163170
private:
164171
int N_;
165172
func_t f_;
173+
int work_group_size_;
174+
int thread_work_size_;
166175
};
167176

168-
template <int nt, int vt, typename func_t>
177+
template <typename func_t>
169178
static void launch_scatter_gather_kernel(int64_t N, const func_t& f) {
170179
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
171180
if (N == 0) {
172181
return;
173182
}
174183

175-
sycl::range<1> local_range{(size_t)nt};
176-
int num_workgroups = (N + nt * vt - 1) / (nt * vt);
177-
sycl::range<1> global_range{(size_t)(num_workgroups * nt)};
178-
179-
auto caller =
180-
ScatterGatherElementwiseKernelFunctor<nt, vt, func_t>((int)N, f);
184+
using KernelFn = ScatterGatherElementwiseKernelFunctor<func_t>;
185+
int64_t max_wg_size = syclMaxWorkGroupSize<KernelFn>();
186+
int outputSize = N;
187+
int work_group_size = outputSize > max_wg_size ? max_wg_size : outputSize;
188+
const auto target_global_size = syclMaxWorkItemsPerTile();
189+
// Each work group size is work_group_size, one full device launch is
190+
// target_global_size, so we can calculate max work group num as below
191+
const int max_work_group_num = target_global_size / work_group_size;
192+
int work_group_num = outputSize / work_group_size < max_work_group_num
193+
? outputSize / work_group_size
194+
: max_work_group_num;
195+
int draft_work_group_num =
196+
(outputSize + work_group_size - 1) / work_group_size;
197+
198+
int thread_work_size = draft_work_group_num / work_group_num + 1;
199+
200+
sycl::range<1> local_range(work_group_size);
201+
sycl::range<1> global_range(work_group_num * work_group_size);
202+
203+
auto caller = KernelFn((int)N, f, work_group_size, thread_work_size);
181204
sycl_kernel_submit(
182205
global_range, local_range, at::xpu::getCurrentSYCLQueue(), caller);
183206
}
@@ -268,11 +291,7 @@ struct ScatterGatherInternalKernel {
268291
numel,
269292
f);
270293

271-
// TODO: optimize it
272-
constexpr int group_work_items = 256;
273-
constexpr int work_size_per_item = 4;
274-
launch_scatter_gather_kernel<group_work_items, work_size_per_item>(
275-
iter.numel(), loop);
294+
launch_scatter_gather_kernel(iter.numel(), loop);
276295
}
277296
};
278297

@@ -521,11 +540,7 @@ struct ScatterFillInternalKernel {
521540
decltype(offset_calc),
522541
func_t>(self_ptr, index_ptr, offset_calc, index_stride, f, src_val);
523542

524-
// TODO: optimize it
525-
constexpr int group_work_items = 256;
526-
constexpr int work_size_per_item = 4;
527-
launch_scatter_gather_kernel<group_work_items, work_size_per_item>(
528-
iter.numel(), loop);
543+
launch_scatter_gather_kernel(iter.numel(), loop);
529544
}
530545
};
531546

src/BuildOnWindows.cmake

Lines changed: 75 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,6 @@
33
set(TORCH_XPU_OPS_LIBRARIES)
44
set(SYCL_LINK_LIBRARIES_KEYWORD PRIVATE)
55

6-
# Walk around cyclic dependence
7-
# libtorch_xpu.so links to libtorch_xpu_ops.a
8-
# Load libtorch_xpu_ops_aten.so explicitly by torch/__init__.py:_load_dll_libraries (Break cycle)
9-
# libtorch_xpu_ops_aten.so links to libtorch_xpu_ops_sycl_unary_binary_kernels.so and libtorch_xpu_ops_sycl_kernels.so
10-
# libtorch_xpu_ops_sycl_unary_binary_kernels.so and libtorch_xpu_ops_sycl_kernels.so links to libtorch_xpu.so
116
add_library(
127
torch_xpu_ops
138
STATIC
@@ -21,7 +16,6 @@ add_library(
2116
${ATen_XPU_NATIVE_CPP_SRCS}
2217
${ATen_XPU_GEN_SRCS})
2318
install(TARGETS torch_xpu_ops_aten DESTINATION "${TORCH_INSTALL_LIB_DIR}")
24-
# target_compile_definitions(torch_xpu_ops_aten PRIVATE CAFFE2_BUILD_MAIN_LIB)
2519
target_compile_definitions(torch_xpu_ops_aten PRIVATE TORCH_XPU_BUILD_MAIN_LIB)
2620
target_link_libraries(torch_xpu_ops_aten PUBLIC torch_xpu)
2721
target_link_libraries(torch_xpu_ops_aten PUBLIC torch_cpu)
@@ -48,8 +42,11 @@ else()
4842
set(ATen_XPU_SYCL_REDUCE_SRCS)
4943
set(ATen_XPU_SYCL_ACTIVATION_SRCS)
5044
set(ATen_XPU_SYCL_FOREACH_SRCS)
45+
set(ATen_XPU_SYCL_TENSOR_SRCS)
46+
set(ATen_XPU_SYCL_NORM_LOSS_SRCS)
47+
set(ATen_XPU_SYCL_POLY_SRCS)
48+
set(ATen_XPU_SYCL_DISTRIBUTION_SRCS)
5149
set(ATen_XPU_SYCL_OTHERS_SRCS)
52-
5350
foreach(sycl_src ${ATen_XPU_SYCL_SRCS})
5451
string(REGEX MATCH "Binary" IS_BINARY ${sycl_src})
5552
string(REGEX MATCH "Unary" IS_UNARY ${sycl_src})
@@ -63,6 +60,13 @@ else()
6360
string(REGEX MATCH "Activation" IS_ACTIVATION ${sycl_src})
6461
string(REGEX MATCH "Foreach" IS_FOREACH ${sycl_src})
6562
string(REGEX MATCH "Reduce" IS_REDUCE ${sycl_src})
63+
string(REGEX MATCH "Tensor" IS_TENSOR ${sycl_src})
64+
string(REGEX MATCH "Norm" IS_NORM ${sycl_src})
65+
string(REGEX MATCH "Loss" IS_LOSS ${sycl_src})
66+
string(REGEX MATCH "Polynomial" IS_POLY ${sycl_src})
67+
#Move resize kernel to Norm and Loss lib, to resolve symbol.
68+
string(REGEX MATCH "Resize" IS_RESIZE ${sycl_src})
69+
string(REGEX MATCH "Distribution" IS_DISTRIBUTION ${sycl_src})
6670
6771
if(NOT IS_FOREACH STREQUAL "")
6872
list(APPEND ATen_XPU_SYCL_FOREACH_SRCS ${sycl_src})
@@ -74,11 +78,18 @@ else()
7478
list(APPEND ATen_XPU_SYCL_REDUCE_SRCS ${sycl_src})
7579
elseif(NOT IS_ACTIVATION STREQUAL "")
7680
list(APPEND ATen_XPU_SYCL_ACTIVATION_SRCS ${sycl_src})
81+
elseif(NOT IS_TENSOR STREQUAL "")
82+
list(APPEND ATen_XPU_SYCL_TENSOR_SRCS ${sycl_src})
83+
elseif(NOT IS_DISTRIBUTION STREQUAL "")
84+
list(APPEND ATen_XPU_SYCL_DISTRIBUTION_SRCS ${sycl_src})
85+
elseif(NOT IS_NORM STREQUAL "" OR NOT IS_LOSS STREQUAL "" OR NOT IS_RESIZE STREQUAL "")
86+
list(APPEND ATen_XPU_SYCL_NORM_LOSS_SRCS ${sycl_src})
87+
elseif(NOT IS_POLY STREQUAL "")
88+
list(APPEND ATen_XPU_SYCL_POLY_SRCS ${sycl_src})
7789
else()
7890
list(APPEND ATen_XPU_SYCL_OTHERS_SRCS ${sycl_src})
7991
endif()
8092
endforeach()
81-
8293
# Binary kernel lib
8394
set(sycl_binary_lib torch_xpu_ops_sycl_binary_kernels)
8495
sycl_add_library(
@@ -148,7 +159,63 @@ else()
148159
149160
# Decouple with PyTorch cmake definition.
150161
install(TARGETS ${sycl_foreach_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}")
162+
163+
# Tensor kernel lib
164+
set(sycl_tensor_lib torch_xpu_ops_sycl_tensor_kernels)
165+
sycl_add_library(
166+
${sycl_tensor_lib}
167+
SHARED
168+
SYCL_SOURCES ${ATen_XPU_SYCL_TENSOR_SRCS})
169+
target_compile_definitions(${sycl_tensor_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB)
170+
target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_tensor_lib})
171+
target_link_libraries(${sycl_tensor_lib} PUBLIC torch_xpu)
172+
list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_tensor_lib})
151173
174+
# Decouple with PyTorch cmake definition.
175+
install(TARGETS ${sycl_tensor_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}")
176+
177+
# Norm and Loss kernel lib
178+
set(sycl_norm_loss_lib torch_xpu_ops_sycl_norm_loss_kernels)
179+
sycl_add_library(
180+
${sycl_norm_loss_lib}
181+
SHARED
182+
SYCL_SOURCES ${ATen_XPU_SYCL_NORM_LOSS_SRCS})
183+
target_compile_definitions(${sycl_norm_loss_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB)
184+
target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_norm_loss_lib})
185+
target_link_libraries(${sycl_norm_loss_lib} PUBLIC torch_xpu)
186+
list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_norm_loss_lib})
187+
188+
# Decouple with PyTorch cmake definition.
189+
install(TARGETS ${sycl_norm_loss_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}")
190+
191+
# Polynomial kernel lib
192+
set(sycl_poly_lib torch_xpu_ops_sycl_poly_kernels)
193+
sycl_add_library(
194+
${sycl_poly_lib}
195+
SHARED
196+
SYCL_SOURCES ${ATen_XPU_SYCL_POLY_SRCS})
197+
target_compile_definitions(${sycl_poly_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB)
198+
target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_poly_lib})
199+
target_link_libraries(${sycl_poly_lib} PUBLIC torch_xpu)
200+
list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_poly_lib})
201+
202+
# Decouple with PyTorch cmake definition.
203+
install(TARGETS ${sycl_poly_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}")
204+
205+
# Distribution kernel lib
206+
set(sycl_dist_lib torch_xpu_ops_sycl_dist_kernels)
207+
sycl_add_library(
208+
${sycl_dist_lib}
209+
SHARED
210+
SYCL_SOURCES ${ATen_XPU_SYCL_DISTRIBUTION_SRCS})
211+
target_compile_definitions(${sycl_dist_lib} PRIVATE TORCH_XPU_BUILD_MAIN_LIB)
212+
target_link_libraries(torch_xpu_ops_aten PUBLIC ${sycl_dist_lib})
213+
target_link_libraries(${sycl_dist_lib} PUBLIC torch_xpu)
214+
list(APPEND TORCH_XPU_OPS_LIBRARIES ${sycl_dist_lib})
215+
216+
# Decouple with PyTorch cmake definition.
217+
install(TARGETS ${sycl_dist_lib} DESTINATION "${TORCH_INSTALL_LIB_DIR}")
218+
152219
# Other kernel lib
153220
set(sycl_lib torch_xpu_ops_sycl_kernels)
154221
sycl_add_library(

test/xpu/test_binary_ufuncs_xpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def to_np(value):
6565
else:
6666
self.assertRaisesRegex(
6767
RuntimeError,
68-
"Found dtype \\w+ but expected \\w+",
68+
r"result type \w+ can't be cast to the desired output type \w+",
6969
lambda: actual.pow_(exponent),
7070
)
7171

0 commit comments

Comments
 (0)