"""
.. _weight_streaming_example:

Weight Streaming
=======================

Weight streaming in TensorRT is a powerful feature designed to overcome GPU memory limitations
when working with large models. It enables running models larger than available GPU memory
by streaming weight data from host (CPU) memory to GPU memory during inference.

Streaming larger amounts of memory will likely result in lower performance. However, if
streaming weights allows the user to run larger batch sizes, it can lead to higher throughput.
This increased throughput can sometimes outweigh the slowdown caused by streaming weights.
The optimal amount of memory to stream varies depending on the specific model and hardware.
Experimenting with different memory limits can help find the best balance between streaming
overhead and batch size benefits. A small sketch that sweeps several budget sizes is included
at the end of this example.

This example uses a pre-trained Llama-2 model and shows how to use the weight streaming feature with
Torch-TensorRT.
 1. compile option - build the TRT engine with the weight streaming feature
 2. runtime api - control the weight streaming budget via a context manager
"""

# %%
# Imports and Model Definition
# ----------------------------------

import copy
import timeit

import numpy as np
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM
from utils import export_llm


def time_generate(model, inputs, output_seq_length, iterations=10):
    """
    Measure the average time to generate a sequence over a given number of iterations
    """
    # We only support single input (B x seq_len) for LLMs now
    input_seq = inputs[0]
    with torch.no_grad():
        timings = []
        for _ in range(iterations):
            start_time = timeit.default_timer()
            inputs_copy = copy.copy(input_seq)
            # Greedy decoding: generate tokens until the sequence reaches output_seq_length
            while inputs_copy.shape[1] <= output_seq_length:
                outputs = model(inputs_copy)
                logits = outputs.logits
                next_token_logits = logits[:, -1, :]
                next_tokens = torch.argmax(next_token_logits, dim=-1)
                inputs_copy = torch.cat([inputs_copy, next_tokens[:, None]], dim=-1)
            torch.cuda.synchronize()
            end_time = timeit.default_timer()
            timings.append(end_time - start_time)

    times = np.array(timings)
    time_mean_ms = np.mean(times) * 1000

    return time_mean_ms


# Load the LLaMA-2 model
DEVICE = torch.device("cuda:0")
llama_path = "meta-llama/Llama-2-7b-chat-hf"
with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        llama_path, use_cache=False, attn_implementation="eager"
    ).eval()

# Set input and output sequence lengths
isl = 128
osl = 256

# Create random input tensors
input_tensors = [torch.randint(0, 5, (1, isl), dtype=torch.int64).cuda()]
# Convert the model to half precision (FP16)
model = model.half()
# Export the LLM model into an ExportedProgram with dynamic shapes
llama2_ep = export_llm(model, input_tensors[0], max_seq_len=osl)

# %%
# Compiler option
# ----------------------------------
#
# The enable_weight_streaming=True and use_explicit_typing=True options are required to build
# the engine with the weight streaming feature. use_explicit_typing=True creates a
# `strongly typed network <https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#strongly-typed-networks>`_, and only float32 precision is allowed in the enabled_precisions option.
#

# Create a TensorRT-compiled model
trt_model = torch_tensorrt.dynamo.compile(
    llama2_ep,
    inputs=input_tensors,
    enabled_precisions={torch.float32},
    truncate_double=True,
    device=DEVICE,
    use_explicit_typing=True,
    enable_weight_streaming=True,
)

# Warm up for 3 iterations
_ = time_generate(trt_model, input_tensors, osl, 3)

# %%
# Running with automatic budget size
# ----------------------------------
#
# Once the enable_weight_streaming compile option is specified, an automatic budget size is configured.
# This automatic size may not always provide the optimal solution because the automatically determined
# budget lacks insight into the user's specific memory constraints and usage patterns.

# Weight streaming context to get current weight budget information
weight_streaming_ctx = torch_tensorrt.runtime.weight_streaming(trt_model)
# Measure the mean latency of the model with weight streaming
mean_latency = time_generate(trt_model, input_tensors, osl, 1)
# Calculate the percentage of current weight budget used
weight_budget_pct = (
    weight_streaming_ctx.device_budget / weight_streaming_ctx.total_device_budget * 100
)
print(
    f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
)

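# %%
# Choosing a budget from free GPU memory
# ----------------------------------------
#
# The automatic budget has no knowledge of how much GPU memory the rest of the
# application needs. The snippet below is a minimal sketch of one possible
# heuristic (an illustration, not part of the Torch-TensorRT API): query the free
# device memory, reserve some headroom for activations and other allocations,
# and consider handing the remainder to the weights. The resulting value could
# be assigned to device_budget exactly as in the context manager section below.

# Hypothetical headroom (bytes) reserved for activations and other allocations
headroom_bytes = 2 * 1024**3
# Free device memory in bytes, as reported by the CUDA driver
free_bytes, _ = torch.cuda.mem_get_info(DEVICE)
# Candidate weight budget, clamped to the valid range [0, total_device_budget]
candidate_budget = max(
    0, min(weight_streaming_ctx.total_device_budget, free_bytes - headroom_bytes)
)
print(
    f"Free GPU memory: {free_bytes} bytes. Candidate weight streaming budget: {candidate_budget} bytes"
)
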
# %%
# Running with weight streaming context manager
# ----------------------------------------------
#
# The weight streaming budget can be limited by using the weight streaming context manager.
# The permissible range for the budget size is from 0 to ctx.total_device_budget.
# A value of 0 gives maximum memory savings by keeping the minimum amount of weights on the GPU,
# while a value equal to ctx.total_device_budget disables weight streaming.
# If multiple TRT engines are created, the budget is distributed proportionally among them.

# Use a context manager for weight streaming
with torch_tensorrt.runtime.weight_streaming(trt_model) as weight_streaming_ctx:
    # Get the total size of streamable weights in the engine
    streamable_budget = weight_streaming_ctx.total_device_budget

    # Scenario 1: Automatic weight streaming budget
    # Get the automatically determined weight streaming budget
    requested_budget = weight_streaming_ctx.get_automatic_weight_streaming_budget()
    # Set the device budget to the automatically determined value
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with automatic budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
        * 100
    )
    print(
        f"Set auto weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
    )

    # Scenario 2: Manual 10% weight streaming budget
    # Set the budget to 10% of the total streamable weights
    requested_budget = int(streamable_budget * 0.1)
    weight_streaming_ctx.device_budget = requested_budget
    # Measure the mean latency with 10% budget
    mean_latency = time_generate(trt_model, input_tensors, osl, 1)
    # Calculate the percentage of the weight budget used
    weight_budget_pct = (
        weight_streaming_ctx.device_budget
        / weight_streaming_ctx.total_device_budget
        * 100
    )
    print(
        f"Set weight streaming budget as {weight_budget_pct}%. {weight_streaming_ctx.device_budget} bytes out of {weight_streaming_ctx.total_device_budget}. mean latency = {mean_latency} ms"
    )
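
# %%
# Sweeping several budget sizes (a minimal sketch)
# --------------------------------------------------
#
# The best budget depends on the model, the hardware, and the batch size, so the
# introduction suggests experimenting with different memory limits. The loop below
# is a small sketch of such a sweep using only the runtime API shown above; the
# specific percentages are arbitrary choices for illustration.

# Try a few fractions of the total streamable weight size and report the latency
with torch_tensorrt.runtime.weight_streaming(trt_model) as ws_ctx:
    total_budget = ws_ctx.total_device_budget
    for pct in (0.1, 0.25, 0.5, 1.0):
        # Set the budget to the requested fraction of the streamable weights
        ws_ctx.device_budget = int(total_budget * pct)
        # Measure the mean latency at this budget
        latency = time_generate(trt_model, input_tensors, osl, 1)
        print(
            f"budget = {pct * 100:.0f}% ({ws_ctx.device_budget} / {total_budget} bytes), mean latency = {latency} ms"
        )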