
Commit 40213c9

Pali gemma modeling (#1895)

drbh and Narsil authored
This PR adds PaliGemma modeling code.

Blog post: https://huggingface.co/blog/paligemma
Transformers PR: huggingface/transformers#30814

Install the latest changes and run with:

```bash
# get the weights
# text-generation-server download-weights gv-hf/PaliGemma-base-224px-hf

# run TGI
text-generation-launcher --model-id gv-hf/PaliGemma-base-224px-hf
```

A basic example sending various requests:

```python
from huggingface_hub import InferenceClient

client = InferenceClient("http://127.0.0.1:3000")

images = [
    "https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png",
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/rabbit.png",
]

prompts = [
    "What animal is in this image?",
    "Name three colors in this image.",
    "What are 10 colors in this image?",
    "Where is the cow standing?",
    "answer en Where is the cow standing?",
    "Is there a bird in the image?",
    "Is there a cow in the image?",
    "Is there a rabbit in the image?",
    "how many birds are in the image?",
    "how many rabbits are in the image?",
]

for img in images:
    print(f"\nImage: {img.split('/')[-1]}")
    for prompt in prompts:
        # prepend the image in markdown form so TGI fetches and embeds it
        inputs = f"![]({img}){prompt}\n"
        generated_output = client.text_generation(
            inputs, max_new_tokens=30, do_sample=False, stream=False
        )
        print([f"{prompt}\n{generated_output}"])
```

---------

Co-authored-by: Nicolas Patry <[email protected]>
1 parent 6c715f8 commit 40213c9
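For reference, the same request can also be sent over TGI's raw HTTP API instead of the Python client. The sketch below uses TGI's standard `/generate` route and mirrors the payload shape from the example above; it is illustrative and not part of this commit:

```python
# Sketch: calling TGI's /generate endpoint directly with requests.
# Assumes a local launch on port 3000 as shown in the commit message.
import requests

inputs = (
    "![](https://huggingface.co/datasets/hf-internal-testing/fixtures-captioning/resolve/main/cow_beach_1.png)"
    "Where is the cow standing?\n"
)
json_data = {
    "inputs": inputs,
    "parameters": {"max_new_tokens": 30, "do_sample": False},
}
response = requests.post("http://127.0.0.1:3000/generate", json=json_data)
print(response.json()["generated_text"])
```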

File tree

23 files changed: +1148 −157 lines

.github/workflows/build.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@ jobs:
     runs-on: ubuntu-latest
     env:
       AWS_REGION: us-east-1
-      EC2_AMI_ID: ami-03cfed9ea28f4b002
+      EC2_AMI_ID: ami-0789b6925c11b1fb2
       EC2_INSTANCE_TYPE: g5.12xlarge
       EC2_SUBNET_ID: subnet-931b34f5,subnet-ecb993cd,subnet-943dc2d8,subnet-45371f1a,subnet-ee93e0df,subnet-fddc3dfc
       EC2_SECURITY_GROUP: sg-030175c435ac141d6
```

Dockerfile

Lines changed: 2 additions & 1 deletion
```diff
@@ -43,7 +43,7 @@ ARG PYTORCH_VERSION=2.3.0
 ARG PYTHON_VERSION=3.10
 # Keep in sync with `server/pyproject.toml
 ARG CUDA_VERSION=12.1
-ARG MAMBA_VERSION=23.3.1-1
+ARG MAMBA_VERSION=24.3.0-0
 ARG CUDA_CHANNEL=nvidia
 ARG INSTALL_CHANNEL=pytorch
 # Automatically set by buildx
@@ -181,6 +181,7 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-ins
         ca-certificates \
         make \
         curl \
+        git \
         && rm -rf /var/lib/apt/lists/*

 # Copy conda with PyTorch installed
```
integration-tests/images/cow_beach.png

Binary file added, 65.7 KB
Lines changed: 25 additions & 0 deletions
```json
{
  "details": {
    "best_of_sequences": null,
    "finish_reason": "eos_token",
    "generated_tokens": 2,
    "prefill": [],
    "seed": null,
    "tokens": [
      {
        "id": 54901,
        "logprob": -0.72753906,
        "special": false,
        "text": "beach"
      },
      {
        "id": 1,
        "logprob": -0.011009216,
        "special": true,
        "text": "<eos>"
      }
    ],
    "top_tokens": null
  },
  "generated_text": "beach"
}
```
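The snapshot records two generated tokens, but only the non-special token contributes to the visible output. A quick illustrative check of that relationship (not part of the commit; the local path is hypothetical):

```python
# Illustrative only: generated_text should equal the concatenation of the
# non-special tokens ("beach"); the special <eos> token is excluded.
import json

with open("snapshot.json") as f:  # hypothetical local copy of the snapshot
    snapshot = json.load(f)

text = "".join(t["text"] for t in snapshot["details"]["tokens"] if not t["special"])
assert text == snapshot["generated_text"] == "beach"
```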
Lines changed: 39 additions & 0 deletions
```python
import pytest
import requests
import io
import base64


@pytest.fixture(scope="module")
def flash_pali_gemma_handle(launcher):
    with launcher(
        "google/paligemma-3b-pt-224",
        num_shard=1,
        revision="float16",
        max_input_length=4000,
        max_total_tokens=4096,
    ) as handle:
        yield handle


@pytest.fixture(scope="module")
async def flash_pali_gemma(flash_pali_gemma_handle):
    await flash_pali_gemma_handle.health(300)
    return flash_pali_gemma_handle.client


def get_cow_beach():
    with open("integration-tests/images/cow_beach.png", "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return f"data:image/png;base64,{encoded_string.decode('utf-8')}"


@pytest.mark.asyncio
@pytest.mark.private
async def test_flash_pali_gemma(flash_pali_gemma, response_snapshot):
    cow = get_cow_beach()
    inputs = f"![]({cow})Where is the cow standing?\n"
    response = await flash_pali_gemma.generate(inputs, max_new_tokens=20)

    assert response.generated_text == "beach"
    assert response == response_snapshot
```
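Running this test locally would typically mean invoking pytest on the integration-tests suite. The sketch below assumes a file path of integration-tests/models/test_flash_pali_gemma.py (not shown in this diff) and that the suite's `private` marker gates tests requiring gated-model access:

```python
# Hypothetical test runner; the path and marker usage are assumptions based
# on TGI's integration-tests layout, not confirmed by this diff.
import sys
import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main(
            ["integration-tests/models/test_flash_pali_gemma.py", "-m", "private", "-sv"]
        )
    )
```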

router/src/config.rs

Lines changed: 19 additions & 2 deletions
```diff
@@ -100,15 +100,13 @@ impl LlavaNext {
 }

 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct ClipVisionModel {
     image_size: usize,
     patch_size: usize,
 }

 #[derive(Clone, Debug, Serialize, Deserialize)]
-#[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
 pub struct Idefics2 {}

@@ -118,6 +116,24 @@ impl Idefics2 {
     }
 }

+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct PaliTextConfig {
+    num_image_tokens: usize,
+}
+
+#[derive(Clone, Debug, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub struct Paligemma {
+    text_config: PaliTextConfig,
+}
+
+impl Paligemma {
+    pub fn get_number_of_features(&self, _height: usize, _width: usize) -> usize {
+        self.text_config.num_image_tokens
+    }
+}
+
 #[derive(Clone, Debug, Serialize, Deserialize)]
 #[serde(tag = "model_type")]
 #[serde(rename_all = "snake_case")]
@@ -140,6 +156,7 @@ pub enum Config {
     Phi3,
     Llama,
     Baichuan,
+    Paligemma(Paligemma),
     Gemma,
     Cohere,
     Drbx,
```
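The `Config` enum is tagged with `model_type`, so a checkpoint whose config.json declares `"model_type": "paligemma"` deserializes into the new variant, and `get_number_of_features` simply returns the fixed `text_config.num_image_tokens` regardless of image size. A minimal Python sketch of that lookup; the 256 value is an assumption for 224px checkpoints ((224 / 14)² SigLIP patches), since the diff itself only defines the field names:

```python
# Sketch of what the new Paligemma config variant reads from config.json.
import json

raw = """
{
  "model_type": "paligemma",
  "text_config": { "num_image_tokens": 256 }
}
"""
config = json.loads(raw)

def get_number_of_features(config: dict, height: int, width: int) -> int:
    # Mirrors Paligemma::get_number_of_features: the image always occupies a
    # fixed number of token slots, so height and width are ignored.
    return config["text_config"]["num_image_tokens"]

print(get_number_of_features(config, 224, 224))  # -> 256 (assumed value)
```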

router/src/validation.rs

Lines changed: 24 additions & 0 deletions
```diff
@@ -544,6 +544,30 @@ fn prepare_input(
                 inputs = modified_inputs;
                 tokenizer_query
             }
+            Some(Config::Paligemma(config)) => {
+                let mut modified_inputs = String::with_capacity(inputs.len());
+                let mut tokenizer_query = String::with_capacity(inputs.len());
+                let mut start = 0;
+                for chunk in RE.find_iter(&inputs) {
+                    let chunk_start = chunk.start();
+                    let chunk_end = chunk.end();
+                    if chunk_start != start {
+                        modified_inputs.push_str(&inputs[start..chunk_start]);
+                        tokenizer_query.push_str(&inputs[start..chunk_start]);
+                    }
+                    let (image_uri, height, width) = fetch_image(&inputs[chunk_start..chunk_end])?;
+                    let slots = config.get_number_of_features(height, width);
+                    tokenizer_query.push_str(&"<image>".repeat(slots));
+                    modified_inputs.push_str(&image_uri);
+                    start = chunk_end;
+                }
+                if start != inputs.len() - 1 {
+                    modified_inputs.push_str(&inputs[start..]);
+                    tokenizer_query.push_str(&inputs[start..]);
+                }
+                inputs = modified_inputs;
+                tokenizer_query
+            }
             Some(Config::Idefics2(config)) => {
                 let mut modified_inputs = String::with_capacity(inputs.len());
                 let mut tokenizer_query = String::with_capacity(inputs.len());
```
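In effect, the new match arm rewrites each markdown image in the prompt: the tokenizer query receives `num_image_tokens` copies of the `<image>` placeholder, while the modified input keeps the fetched image URI for later pixel processing. A Python sketch of that transformation, simplified to skip the actual image fetch and using an assumed stand-in for the router's image regex:

```python
# Python sketch of the Paligemma arm above (illustrative, not the actual code).
import re

IMAGE_RE = re.compile(r"!\[\]\([^)]*\)")  # assumed markdown-image pattern

def prepare_paligemma_input(inputs: str, num_image_tokens: int) -> tuple[str, str]:
    modified_inputs = []
    tokenizer_query = []
    start = 0
    for m in IMAGE_RE.finditer(inputs):
        # text between images is copied verbatim into both outputs
        modified_inputs.append(inputs[start:m.start()])
        tokenizer_query.append(inputs[start:m.start()])
        # the tokenizer sees one "<image>" placeholder per image token slot
        tokenizer_query.append("<image>" * num_image_tokens)
        # the modified input keeps the image reference for pixel processing
        modified_inputs.append(m.group(0))
        start = m.end()
    # trailing text after the last image
    modified_inputs.append(inputs[start:])
    tokenizer_query.append(inputs[start:])
    return "".join(modified_inputs), "".join(tokenizer_query)

_, query = prepare_paligemma_input(
    "![](https://example.com/cow.png)Where is the cow standing?\n", 3
)
print(query)  # <image><image><image>Where is the cow standing?\n
```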
