Commit 06aee6c

Merge branch 'main' into angelayi/aoti_api_update
2 parents 195aba8 + 3ce9c8e commit 06aee6c

File tree

11 files changed: +209, -34 lines

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
name: Run the README instructions - with stories - on Linux aarch64

on:
  pull_request:
  push:
    branches:
      - main
  workflow_dispatch:

jobs:
  test-readme-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs readme

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

  test-quantization-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs quantization

  test-gguf-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs gguf

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

  test-advanced-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs advanced

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

  test-evaluation-cpu:
    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
    permissions:
      id-token: write
      contents: read
    with:
      runner: linux.arm64.2xlarge
      docker-image: "pytorch/manylinux2_28_aarch64-builder:cpu-aarch64-main"
      gpu-arch-type: cpu-aarch64
      timeout: 60
      script: |
        echo "::group::Print machine info"
        uname -a
        echo "::endgroup::"

        TORCHCHAT_DEVICE=cpu .ci/scripts/run-docs evaluation

        echo "::group::Completion"
        echo "tests complete"
        echo "*******************************************"
        echo "::endgroup::"

.github/workflows/run-readme-pr-mps.yml

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@ jobs:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:
       runner: macos-m1-14
-      timeout: 50
+      timeout: 60
      script: |
        conda create -y -n test-readme-mps-macos python=3.10.11 llvm-openmp
        conda activate test-readme-mps-macos
@@ -63,7 +63,7 @@ jobs:
   test-gguf-mps-macos:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:
-      runner: macos-m1-14 # neeps MPS, was macos-m1-stable
+      runner: macos-m1-14 # needs MPS, was macos-m1-stable
      script: |
        set -x
        conda create -y -n test-quantization-mps-macos python=3.10.11
@@ -90,7 +90,7 @@ jobs:
   test-advanced-mps-macos:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     with:
-      runner: macos-m1-14 # neeps MPS, was macos-m1-stable
+      runner: macos-m1-14 # needs MPS, was macos-m1-stable
      script: |
        set -x
        conda create -y -n test-quantization-mps-macos python=3.10.11

docs/ADVANCED-USERS.md

Lines changed: 2 additions & 2 deletions
@@ -479,7 +479,7 @@ in a Python-free environment with AOT Inductor and ExecuTorch.
 | Hardware | OS | Eager | Eager + Compile | AOT Compile | ET Runtime |
 |-----|------|-----|-----|-----|-----|
 | x86 | Linux |||||
-| aarch64 | Linux | n/t | n/t | n/t | n/t |
+| aarch64 | Linux | | | | n/t |
 | aarch64 | macOS |||||
 | AMD GPU | Linux |||||
 | Nvidia GPU | Linux |||||
@@ -490,7 +490,7 @@ in a Python-free environment with AOT Inductor and ExecuTorch.
 | Mobile GPU (Vulkan) | Android |||||
 | CoreML | iOS |||||
 | Hexagon DSP | Android |||||
-| Raspberry Pi 4/5 | Raspbian | n/t | n/t | n/t ||
+| Raspberry Pi 4/5 | Raspbian | | | ||
 | Raspberry Pi 4/5 | Android |||| n/t |
 | ARM 32b (up to v7) | any |||||

install/install_requirements.sh

Lines changed: 26 additions & 16 deletions
@@ -51,19 +51,13 @@ echo "Using pip executable: $PIP_EXECUTABLE"
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # PYTORCH_NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-PYTORCH_NIGHTLY_VERSION=dev20241218
+PYTORCH_NIGHTLY_VERSION=dev20250119

 # Nightly version for torchvision
-VISION_NIGHTLY_VERSION=dev20241218
+VISION_NIGHTLY_VERSION=dev20250119

 # Nightly version for torchtune
-TUNE_NIGHTLY_VERSION=dev20241218
-
-# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
-(
-  set -x
-  $PIP_EXECUTABLE uninstall -y triton
-)
+TUNE_NIGHTLY_VERSION=dev20250119

 # The pip repository that hosts nightly torch packages. cpu by default.
 # If cuda is available, based on presence of nvidia-smi, install the pytorch nightly
@@ -74,16 +68,28 @@ then
 elif [[ -x "$(command -v rocminfo)" ]];
 then
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/rocm6.2"
+elif [[ -x "$(command -v xpu-smi)" ]];
+then
+  TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/xpu"
 else
   TORCH_NIGHTLY_URL="https://download.pytorch.org/whl/nightly/cpu"
 fi

 # pip packages needed by exir.
-REQUIREMENTS_TO_INSTALL=(
-  torch=="2.6.0.${PYTORCH_NIGHTLY_VERSION}"
-  torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
-  torchtune=="0.5.0.${TUNE_NIGHTLY_VERSION}"
-)
+if [[ -x "$(command -v xpu-smi)" ]];
+then
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.7.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.6.0"
+  )
+else
+  REQUIREMENTS_TO_INSTALL=(
+    torch=="2.7.0.${PYTORCH_NIGHTLY_VERSION}"
+    torchvision=="0.22.0.${VISION_NIGHTLY_VERSION}"
+    torchtune=="0.6.0.${TUNE_NIGHTLY_VERSION}"
+  )
+fi

 #
 # First install requirements in install/requirements.txt. Older torch may be
@@ -95,6 +101,12 @@ REQUIREMENTS_TO_INSTALL=(
   $PIP_EXECUTABLE install -r install/requirements.txt --extra-index-url "${TORCH_NIGHTLY_URL}"
 )

+# Uninstall triton, as nightly will depend on pytorch-triton, which is one and the same
+(
+  set -x
+  $PIP_EXECUTABLE uninstall -y triton
+)
+
 # Install the requirements. --extra-index-url tells pip to look for package
 # versions on the provided URL if they aren't available on the default URL.
 (
@@ -116,8 +128,6 @@ if [[ -x "$(command -v nvidia-smi)" ]]; then
     $PYTHON_EXECUTABLE torchchat/utils/scripts/patch_triton.py
   )
 fi
-
-
 (
   set -x
   $PIP_EXECUTABLE install evaluate=="0.4.3" lm-eval=="0.4.2" psutil=="6.0.0"
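The script now keys the nightly index off whichever accelerator tool is present (`nvidia-smi`, `rocminfo`, or the new `xpu-smi` branch). A quick way to confirm which backends the installed nightly actually exposes is a check along these lines — a minimal sketch, not part of this change, assuming a wheel recent enough to ship `torch.xpu`:

```python
# backend_check.py -- illustrative sanity check, not part of this commit
import torch

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("MPS available:", torch.backends.mps.is_available())
# torch.xpu ships with recent builds; guard for older wheels
print("XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())
```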

torchchat/cli/builder.py

Lines changed: 19 additions & 1 deletion
@@ -69,10 +69,16 @@ class BuilderArgs:
     prefill_possible: bool = False
     dynamic_shapes: bool = False
     max_seq_length: Optional[int] = None
+    attention_backend: str = "math"

     def __post_init__(self):
         if self.device is None:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            if torch.cuda.is_available():
+                self.device = "cuda"
+            elif torch.xpu.is_available():
+                self.device = "xpu"
+            else:
+                self.device = "cpu"

         if not (
             (self.checkpoint_path and self.checkpoint_path.is_file())
@@ -178,6 +184,17 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
         pp = getattr(args, "pp", 1)
         tp = getattr(args, "tp", 1)
         chpt_from = getattr(args, "chpt_from", "hf")
+        sdp_backend_dict = {
+            'math': torch.nn.attention.SDPBackend.MATH,
+            'flash_attention': torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            'efficient_attention': torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+            'cudnn_attention': torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+        }
+        attention_backend = sdp_backend_dict[args.attention_backend]
+        if args.device == "cpu" and (args.attention_backend == "efficient_attention"
+                                     or args.attention_backend == "cudnn_attention"):
+            print(f"Warning: {args.attention_backend} is not supported on CPU. Using math instead.")
+            attention_backend = torch.nn.attention.SDPBackend.MATH
         return cls(
             checkpoint_dir=checkpoint_dir,
             checkpoint_path=checkpoint_path,
@@ -202,6 +219,7 @@ def from_args(cls, args: argparse.Namespace) -> "BuilderArgs":
             is_chat_model=is_chat_model,
             dynamic_shapes=getattr(args, "dynamic_shapes", False),
             max_seq_length=getattr(args, "max_seq_length", None),
+            attention_backend=attention_backend,
         )

     @classmethod
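The new `sdp_backend_dict` maps the `--attention-backend` string onto `torch.nn.attention.SDPBackend` values and downgrades CPU runs to `MATH` when a GPU-only backend was requested. A minimal standalone sketch of that selection logic (the helper name here is made up for illustration, not torchchat API):

```python
from torch.nn.attention import SDPBackend

SDP_BACKENDS = {
    "math": SDPBackend.MATH,
    "flash_attention": SDPBackend.FLASH_ATTENTION,
    "efficient_attention": SDPBackend.EFFICIENT_ATTENTION,
    "cudnn_attention": SDPBackend.CUDNN_ATTENTION,
}


def pick_sdp_backend(name: str, device: str) -> SDPBackend:
    """Mirror the fallback in BuilderArgs.from_args: efficient/cudnn attention
    are not supported on CPU, so fall back to MATH there."""
    backend = SDP_BACKENDS[name]
    if device == "cpu" and name in ("efficient_attention", "cudnn_attention"):
        print(f"Warning: {name} is not supported on CPU. Using math instead.")
        backend = SDPBackend.MATH
    return backend


print(pick_sdp_backend("flash_attention", "cuda"))  # SDPBackend.FLASH_ATTENTION
print(pick_sdp_backend("cudnn_attention", "cpu"))   # falls back to SDPBackend.MATH
```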

torchchat/cli/cli.py

Lines changed: 9 additions & 2 deletions
@@ -176,8 +176,15 @@ def _add_model_config_args(parser, verb: str) -> None:
         "--device",
         type=str,
         default=None,
-        choices=["fast", "cpu", "cuda", "mps"],
-        help="Hardware device to use. Options: fast, cpu, cuda, mps",
+        choices=["fast", "cpu", "cuda", "mps", "xpu"],
+        help="Hardware device to use. Options: fast, cpu, cuda, mps, xpu",
+    )
+    model_config_parser.add_argument(
+        "--attention-backend",
+        type=str,
+        default="math",
+        choices=["math", "flash_attention", "efficient_attention", "cudnn_attention"],
+        help="SDPBackend to use. Options: MATH, FLASH_ATTENTION, EFFICIENT_ATTENTION, CUDNN_ATTENTION",
     )

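Because the flag is declared with `choices`, argparse rejects unknown backend names at parse time and falls back to `math` when the flag is omitted. A small self-contained sketch of that behavior (standalone parser, not torchchat's real `_add_model_config_args`):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--attention-backend",
    type=str,
    default="math",
    choices=["math", "flash_attention", "efficient_attention", "cudnn_attention"],
)

print(parser.parse_args([]).attention_backend)  # -> math
print(parser.parse_args(["--attention-backend", "flash_attention"]).attention_backend)
# parser.parse_args(["--attention-backend", "sdpa"]) exits with an "invalid choice" error
```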

torchchat/generate.py

Lines changed: 15 additions & 3 deletions
@@ -26,6 +26,7 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.distributed.pipelining import PipelineStage, ScheduleGPipe
+from torch._C import _SDPBackend as SDPBackend

 from PIL import Image

@@ -531,6 +532,7 @@ def decode_n_tokens(
         callback=lambda _: _,
         eos_token_id: int = 2,
         eot_id: Optional[int] = None,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         **sampling_kwargs,
     ):
         new_tokens, new_probs = [], []
@@ -539,7 +541,7 @@ def decode_n_tokens(
             num_new_tokens - 1
         ): # -1 to save space to run an EoS if dont generate it naturally
             # Actually better for Inductor to codegen attention here
-            with torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]):
+            with torch.nn.attention.sdpa_kernel([attention_backend]):

                 out_token = cur_token.clone()
                 next_token, next_prob = self.decode_one_token(
@@ -683,6 +685,7 @@ def generate(
         sequential_prefill=True,
         callback=lambda x: x,
         max_seq_length: int,
+        attention_backend: SDPBackend = torch.nn.attention.SDPBackend.MATH,
         seed: Optional[int] = None,
         **sampling_kwargs,
     ) -> torch.Tensor:
@@ -799,6 +802,7 @@ def generate(
                 if self.is_llama3_model
                 else None
             ),
+            attention_backend=attention_backend,
             **sampling_kwargs,
         ):
             generated_tokens.append(generated_token.view(-1))
@@ -1122,7 +1126,7 @@ def chat(
                 messages_to_encode.append(
                     {"role": "system", "content": self.system_prompt}
                 )
-            messages_to_encode.append({"role": "system", "content": prompt})
+            messages_to_encode.append({"role": "user", "content": prompt})
             encoded = self.chat_formatter.encode_dialog_prompt(
                 messages_to_encode, add_generation_prompt=True,
             )
@@ -1186,6 +1190,7 @@ def callback(x, *, done_generating=False):
                 start_pos=start_pos,
                 skip_cache_setup=not is_first_sample,
                 max_seq_length=max_seq_length,
+                attention_backend=self.builder_args.attention_backend,
             )
             for token_tensor, metrics in generator_func:
                 if token_tensor is not None:
@@ -1203,8 +1208,10 @@ def callback(x, *, done_generating=False):
             if hasattr(prof, "export_chrome_trace"):
                 if self.builder_args.device == "cpu":
                     print(prof.key_averages().table(sort_by="self_cpu_time_total"))
-                else:
+                elif self.builder_args.device == "cuda":
                     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
+                else:
+                    print(prof.key_averages().table(sort_by="self_xpu_time_total"))
                 prof.export_chrome_trace(f"{self.profile}.json")

         if start_pos >= max_seq_length:
@@ -1289,6 +1296,9 @@ def callback(x, *, done_generating=False):
         )
         if torch.cuda.is_available():
             print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
+        if torch.xpu.is_available():
+            print(f"Memory used: {torch.xpu.max_memory_reserved() / 1e9:.02f} GB")
+


 class DistributedGenerator(LocalGenerator):
@@ -1615,6 +1625,8 @@ def run_generator(
     )
     if torch.cuda.is_available():
         torch.cuda.reset_peak_memory_stats()
+    if torch.xpu.is_available():
+        torch.xpu.reset_peak_memory_stats()

     for _ in gen.chat(generator_args):
         pass
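With the plumbing above, `decode_n_tokens` wraps each decode step in `torch.nn.attention.sdpa_kernel([attention_backend])` instead of hard-coding `MATH`. A minimal sketch of what that context manager restricts, using made-up tensor shapes (illustrative only, not torchchat code):

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# toy query/key/value: (batch, heads, seq_len, head_dim)
q, k, v = (torch.randn(1, 8, 16, 64) for _ in range(3))

# Only the listed backend(s) may be used for scaled_dot_product_attention
# inside the block. MATH mirrors the old default; FLASH_ATTENTION or
# CUDNN_ATTENTION become selectable on CUDA through --attention-backend.
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 16, 64])
```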
