dynamic shape for slice converter #2901

Merged
merged 2 commits into from Jul 11, 2024

Conversation

@apbose apbose commented Jun 7, 2024

I have only added the test cases; I have not refactored has_dynamic_shape and set_layer_name, since those changes are already part of #2853.

@apbose apbose requested a review from peri044 June 7, 2024 22:52
@github-actions github-actions bot added the component: tests label Jun 7, 2024
@github-actions github-actions bot requested a review from narendasan June 7, 2024 22:52
@@ -70,7 +147,7 @@ def forward(self, input):
         Input(
             shape=(1, 10, -1),
             dtype=torch.float32,
-            shape_ranges=[((1, 10, 1), (1, 10, 10), (1, 10, 10))],
+            shape_ranges=[(min_shape, opt_shape, max_shape)],

Use Input(min_shape=<>, opt_shape=<>, max_shape=<>)
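
For reference, a minimal sketch of the suggested form, assuming the standard torch_tensorrt.Input keyword arguments and reusing the bounds from the old shape_ranges tuple:

import torch
import torch_tensorrt

# Dynamic last dimension: 1 at minimum, 10 at the optimization point and maximum.
input_spec = torch_tensorrt.Input(
    min_shape=(1, 10, 1),
    opt_shape=(1, 10, 10),
    max_shape=(1, 10, 10),
    dtype=torch.float32,
)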


peri044 commented Jun 12, 2024

  1. Can we perform slice on dynamic dimensions ?
  2. I'm facing slice layer conversion issues on llama2

Here's the reproducer:

  1. Please log in via huggingface-cli (install it via pip install -U "huggingface_hub[cli]"):
    https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login. The user access token is available in your account settings.
    Script:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch_tensorrt

llama_path = "meta-llama/Llama-2-7b-hf"
model = AutoModelForCausalLM.from_pretrained(llama_path).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(llama_path)

base_prompt = "How many hours are in a day?"
base_inputs = tokenizer(base_prompt, return_tensors="pt").to("cuda:0")
input_ids = base_inputs.input_ids
pyt_out = model(input_ids)

seq_len = torch.export.Dim("seq_len", min=2, max=1024)

from torch.nn.attention import SDPBackend
with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
    ep = torch.export.export(model, (input_ids,), dynamic_shapes=({1: seq_len},))

trt_model = torch_tensorrt.dynamo.compile(
    ep,
    inputs=[input_ids],
    enabled_precisions={torch.float16},
    min_block_size=1,
    truncate_long_and_double=True,
    debug=True,
)

The error is:

DEBUG:torch_tensorrt.dynamo.conversion._TRTInterpreter:Converting node /model_layers_0_self_attn_rotary_emb/slice_3 (kind: aten.slice.Tensor, args: ('[SHUFFLE]-[aten_ops.unsqueeze.default]-[/model_layers_0_self_attn_rotary_emb/unsqueeze_5]_output <tensorrt.ITensor [shape=(1, 1, -1), dtype=DataType.INT32]>', 2, 0, 9223372036854775807))
Traceback (most recent call last):
  File "/work/TensorRT/llama2_auto.py", line 21, in <module>
    enabled_precisions={torch.float16},
  File "/work/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 227, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/work/TensorRT/py/torch_tensorrt/dynamo/_compiler.py", line 412, in compile_module
    trt_module = convert_module(
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 106, in convert_module
    interpreter_result = interpret_module_to_result(module, inputs, settings)
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py", line 87, in interpret_module_to_result
    interpreter_result = interpreter.run()
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 309, in run
    super().run()
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/fx/interpreter.py", line 145, in run
    self.env[node] = self.run_node(node)
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 348, in run_node
    trt_node: torch.fx.Node = super().run_node(n)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch/fx/interpreter.py", line 202, in run_node
    return getattr(self, n.op)(n.target, args, kwargs)
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py", line 456, in call_function
    return converter(self.ctx, target, args, kwargs, self._cur_node_name)
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 400, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 769, in aten_ops_slice
    return impl.slice.slice_op(
  File "/work/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 50, in slice_op
    assert input.shape[dim] != -1, "Can't slice on dynamic shape dimension!"
AssertionError: Can't slice on dynamic shape dimension!
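
For context, the assertion fires because the converter tries to fold start/stop into static bounds, which is impossible when input.shape[dim] == -1; the bounds have to be normalized against the runtime dimension size instead. A minimal pure-Python sketch of that normalization (plain slice semantics, not the converter's actual TensorRT implementation):

import sys

def normalize_slice_bounds(dim_size: int, start: int, stop: int, step: int = 1):
    """Resolve negative/sentinel slice bounds against a known dimension size.

    Mirrors Python/PyTorch slice semantics: negative indices count from the
    end, and stop == sys.maxsize (what aten.slice.Tensor uses for `:`) clamps
    to dim_size.
    """
    if start < 0:
        start = max(dim_size + start, 0)
    if stop == sys.maxsize or stop > dim_size:
        stop = dim_size
    elif stop < 0:
        stop = max(dim_size + stop, 0)
    num_elements = max((stop - start + step - 1) // step, 0)
    return start, stop, num_elements

# The failing node above: slice(dim=2, start=0, stop=9223372036854775807).
# With a runtime size of 7 for dim 2, this resolves to the full dimension.
assert normalize_slice_bounds(7, 0, sys.maxsize) == (0, 7, 7)

In the dynamic case the same arithmetic has to be expressed with TensorRT shape tensors and elementwise layers, since dim_size is only known at runtime.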

@github-actions github-actions bot added the component: conversion, component: converters, component: api [Python], and component: dynamo labels Jun 13, 2024
@apbose apbose force-pushed the dynamic_shapes_slice branch from aa9873a to a9467cd Compare June 13, 2024 21:16

apbose commented Jun 14, 2024

While testing the Llama model I get this error:

torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (seq_len)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of seq_len = L['input_ids'].size()[1] in the specified range seq_len <= 1024 satisfy the generated guard Ne(Mod(L['input_ids'].size()[1], 64), 0).

Could you provide me the path of the model used here: llama_path = "meta-llama/Llama-2-7b-hf"?


apbose commented Jun 26, 2024

Hi @peri044, the code now gets past the above error points and also the Python int overflow error. However, I now see this error:

Traceback (most recent call last):
  File "/code/torch_tensorrt/TensorRT/tests/py/dynamo/conversion/llama.py", line 20, in <module>
    trt_model = torch_tensorrt.dynamo.compile(ep, inputs=[input_ids],
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 227, in compile
    trt_gm = compile_module(gm, inputs, settings)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 365, in compile_module
    submodule_inputs = partitioning.construct_submodule_inputs(submodule)
  File "/root/.pyenv/versions/3.10.14/lib/python3.10/site-packages/torch_tensorrt/dynamo/partitioning/common.py", line 116, in construct_submodule_inputs
    raise AssertionError(
AssertionError: Input scaled_dot_product_attention does not contain metadata. Please ensure you have exported the graph correctly

Oddly, this comes after the engine creation debug prints:
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:Build TRT engine elapsed time: 0:00:11.626404
INFO:torch_tensorrt.dynamo.conversion._TRTInterpreter:TRT Engine uses: 363528004 bytes of Memory

The above error comes from construct_submodule_inputs, which I believe should run before engine creation. Is it because of a multi-GPU setting? WARNING: [Torch-TensorRT] - Detected this engine is being instantitated in a multi-GPU system with multi-device safe mode disabled.
Logs attached: llama_out.log

@github-actions github-actions bot left a comment

There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py	2024-06-26 01:32:20.088317+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py	2024-06-26 01:34:14.458289+00:00
@@ -72,11 +72,11 @@
                start_slice = input.shape
                start_slice[dim] = -1 * start
            if stop < 0 or stop_dynamic_None:
                stop_slice = [0] * len(input.shape)
                stop_slice[dim] = -1 * stop
-            if(stop == sys.maxsize):
+            if stop == sys.maxsize:
                stop_slice = [0] * len(input.shape)
            start_slice_tensor = cat(
                ctx,
                target,
                source_ir,
@@ -115,11 +115,11 @@
                    name + "_sub_start",
                    trt.ElementWiseOperation.SUB,
                    shape,
                    start_slice_tensor,
                )
-            if ((stop < 0) or stop_dynamic_None or stop == sys.maxsize):
+            if (stop < 0) or stop_dynamic_None or stop == sys.maxsize:
                shape = get_shape_with_dynamic_shape(
                    ctx, target, source_ir, name, output_shape, input
                )
                stop_slice_tensor = convert_binary_elementwise(
                    ctx,

@apbose apbose force-pushed the dynamic_shapes_slice branch from 6e5c1d3 to ce63f4c Compare June 26, 2024 01:42

peri044 commented Jun 26, 2024

@apbose I'm seeing the following error:

File "/home/dperi/Downloads/TensorRT/py/torch_tensorrt/dynamo/conversion/converter_utils.py", line 530, in convert_with_type_enforcement
    return func(ctx, target, new_args, new_kwargs, name)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/dperi/Downloads/TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py", line 884, in aten_ops_slice
    return impl.slice.slice_op(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/dperi/Downloads/TensorRT/py/torch_tensorrt/dynamo/conversion/impl/slice/ops.py", line 59, in slice_op
    stop_slice[dim] = stop
    ~~~~~~~~~~^^^^^
TypeError: __setitem__(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt_bindings.tensorrt.Dims, arg0: int, arg1: int) -> None
    2. (self: tensorrt_bindings.tensorrt.Dims, arg0: slice, arg1: tensorrt_bindings.tensorrt.Dims) -> None

Invoked with: (1, 1, 1024, 1024), 2, <tensorrt_bindings.tensorrt.ITensor object at 0x73495841e030>

While executing %slice_3 : [num_users=2] = call_function[target=torch.ops.aten.slice.Tensor](args = (%_frozen_param1, 2, 0, %sym_size_int_1), kwargs = {_itensor_to_tensor_meta: {<tensorrt_bindings.tensorrt.ITensor object at 0x7349584b9db0>: ((1, s0), torch.int64, False, (s0, 1), torch.contiguous_format, False, {}), <tensorrt_bindings.tensorrt.ITensor object at 0x734958ba4770>: None, <tensorrt_bindings.tensorrt.ITensor object at 0x7349585498f0>: None,

To reproduce:

  1. Please build torch_tensorrt using the llm_examples_main branch.
  2. Run this example: https://github.com/pytorch/TensorRT/blob/llm_examples_main/examples/dynamo/torch_export_gpt2.py and also comment out the torch_executed_ops line: https://github.com/pytorch/TensorRT/blob/llm_examples_main/examples/dynamo/torch_export_gpt2.py#L70
    Let me know if you're able to reproduce this error.
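
For what it's worth, the TypeError above is because tensorrt.Dims only accepts ints, so an ITensor stop (here the symbolic sequence length %sym_size_int_1) can't be written into stop_slice[dim]; the stop bounds have to be assembled as a runtime shape tensor instead. A rough sketch of that idea against the raw TensorRT Python API (the helper name is hypothetical, not the converter's actual code):

import numpy as np
import tensorrt as trt

def build_stop_shape_tensor(
    network: trt.INetworkDefinition,
    input: trt.ITensor,
    dim: int,
    stop,  # a Python int, or a 1-element INT32 trt.ITensor (symbolic size)
) -> trt.ITensor:
    """Concatenate per-dimension stop values into a runtime shape tensor."""
    rank = len(input.shape)
    # Runtime shape of the input; valid even when some dimensions are -1.
    shape = network.add_shape(input).get_output(0)
    parts = []
    for i in range(rank):
        if i == dim:
            if isinstance(stop, trt.ITensor):
                parts.append(stop)  # already a 1-element shape tensor
            else:
                const = network.add_constant((1,), np.array([stop], dtype=np.int32))
                parts.append(const.get_output(0))
        else:
            # Slice out the i-th element of the runtime shape (keep full dim).
            elem = network.add_slice(shape, (i,), (1,), (1,))
            parts.append(elem.get_output(0))
    concat = network.add_concatenation(parts)
    concat.axis = 0
    return concat.get_output(0)

The resulting tensor can then drive ISliceLayer.set_input for the dynamic bounds rather than a static Dims.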

@apbose apbose force-pushed the dynamic_shapes_slice branch 2 times, most recently from d242b7e to 454cf5a Compare June 28, 2024 01:24
@apbose apbose requested a review from peri044 July 8, 2024 21:48
@apbose apbose force-pushed the dynamic_shapes_slice branch from 454cf5a to ae5cf44 Compare July 8, 2024 22:04
Adding cases for slicing on dynamic dimension

handling the case when stop is the max int64 value in the dynamic dimension

Addressing GPT2 cases: when stop is an ITensor
Comment on lines 58 to 60
if stop is None:
    stop_dynamic_None = True if input.shape[dim] == -1 else False
if stop is None:
Can you merge the two if stop is None statements?
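
A sketch of the merge (the body of the second `if stop is None:` isn't visible in this hunk, so the comment stands in for it; the ternary also reduces to a plain boolean expression):

if stop is None:
    stop_dynamic_None = input.shape[dim] == -1
    # ... body of the second `if stop is None:` block continues here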

Comment on lines 72 to 73
for i in range(len(input.shape)):
    start_slice.append(0) if i == dim else start_slice.append(start)

Can you explain what is happening here? Why are we appending start (an ITensor) to all the other dimensions except dim?

Thanks for pointing this out. start_slice should have start appended for i==dim.
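
i.e., a minimal sketch of the corrected construction, placing start only at i == dim:

start_slice = [start if i == dim else 0 for i in range(len(input.shape))]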

@peri044 peri044 merged commit 0ef880d into main Jul 11, 2024
52 of 61 checks passed
cehongwang pushed a commit that referenced this pull request Jul 12, 2024
@Hukongtao

> AssertionError: Input scaled_dot_product_attention does not contain metadata. Please ensure you have exported the graph correctly

Hello, I got the same error as you. How did you solve it?
