Skip to content

Commit d443069

Browse files
committed
Pass provider inputs in raw payloads
1 parent 7f574bd commit d443069

21 files changed

+43
-27
lines changed

packages/inference/src/snippets/python.ts

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ interface TemplateParams {
1717
baseUrl?: string;
1818
fullUrl?: string;
1919
inputs?: object;
20+
providerInputs?: object;
2021
model?: ModelDataMinimal;
2122
provider?: InferenceProvider;
2223
providerModelId?: string;
@@ -115,6 +116,18 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
115116
{ chatCompletion: templateName.includes("conversational"), task: model.pipeline_tag as InferenceTask }
116117
);
117118

119+
/// Parse request.info.body as JSON when it is a string (i.e. not a binary payload).
120+
/// This is the body sent to the provider. Important for snippets with a raw payload (e.g. curl, requests, etc.)
121+
let providerInputs = inputs;
122+
const bodyAsObj = request.info.body;
123+
if (typeof bodyAsObj === "string") {
124+
try {
125+
providerInputs = JSON.parse(bodyAsObj);
126+
} catch (e) {
127+
console.error("Failed to parse body as JSON", e);
128+
}
129+
}
130+
118131
/// Prepare template injection data
119132
const params: TemplateParams = {
120133
accessToken,
@@ -126,6 +139,11 @@ const snippetGenerator = (templateName: string, inputPreparationFn?: InputPrepar
126139
asJsonString: formatBody(inputs, "json"),
127140
asPythonString: indentString(formatBody(inputs, "python"), 4),
128141
},
142+
providerInputs: {
143+
asObj: providerInputs,
144+
asJsonString: formatBody(providerInputs, "json"),
145+
asPythonString: indentString(formatBody(providerInputs, "python"), 4),
146+
},
129147
model,
130148
provider,
131149
providerModelId: providerModelId ?? model.id,

packages/inference/src/snippets/templates/python/requests/automaticSpeechRecognition.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ def query(filename):
44
response = requests.post(API_URL, headers=headers, data=data)
55
return response.json()
66

7-
output = query({{ inputs.asObj.inputs }})
7+
output = query({{ providerInputs.asObj.inputs }})

packages/inference/src/snippets/templates/python/requests/basic.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ def query(payload):
33
return response.json()
44

55
output = query({
6-
"inputs": {{ inputs.asObj.inputs }},
6+
"inputs": {{ providerInputs.asObj.inputs }},
77
})

packages/inference/src/snippets/templates/python/requests/basicFile.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ def query(filename):
44
response = requests.post(API_URL, headers=headers, data=data)
55
return response.json()
66

7-
output = query({{ inputs.asObj.inputs }})
7+
output = query({{ providerInputs.asObj.inputs }})

packages/inference/src/snippets/templates/python/requests/conversational.jinja

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ def query(payload):
33
return response.json()
44

55
response = query({
6-
"model": "{{ providerModelId }}",
7-
{{ inputs.asJsonString }}
6+
{{ providerInputs.asJsonString }}
87
})
98

109
print(response["choices"][0]["message"])

packages/inference/src/snippets/templates/python/requests/conversationalStream.jinja

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@ def query(payload):
88
yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))
99

1010
chunks = query({
11-
"model": "{{ providerModelId }}",
12-
{{ inputs.asJsonString }},
11+
{{ providerInputs.asJsonString }},
1312
"stream": True,
1413
})
1514

packages/inference/src/snippets/templates/python/requests/documentQuestionAnswering.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,5 @@ def query(payload):
66
return response.json()
77

88
output = query({
9-
{{ inputs.asJsonString }},
9+
{{ providerInputs.asJsonString }},
1010
})

packages/inference/src/snippets/templates/python/requests/imageToImage.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ def query(payload):
66
return response.content
77

88
image_bytes = query({
9-
{{ inputs.asJsonString }}
9+
{{ providerInputs.asJsonString }}
1010
})
1111

1212
# You can access the image with PIL.Image for example

packages/inference/src/snippets/templates/python/requests/tabular.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ def query(payload):
44

55
response = query({
66
"inputs": {
7-
"data": {{ inputs.asObj.inputs }}
7+
"data": {{ providerInputs.asObj.inputs }}
88
},
99
})

packages/inference/src/snippets/templates/python/requests/textToAudio.jinja

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ def query(payload):
44
return response.content
55

66
audio_bytes = query({
7-
"inputs": {{ inputs.asObj.inputs }},
7+
"inputs": {{ providerInputs.asObj.inputs }},
88
})
99
# You can access the audio with IPython.display for example
1010
from IPython.display import Audio
@@ -15,7 +15,7 @@ def query(payload):
1515
return response.json()
1616

1717
audio, sampling_rate = query({
18-
"inputs": {{ inputs.asObj.inputs }},
18+
"inputs": {{ providerInputs.asObj.inputs }},
1919
})
2020
# You can access the audio with IPython.display for example
2121
from IPython.display import Audio

packages/inference/src/snippets/templates/python/requests/textToImage.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ def query(payload):
44
return response.content
55

66
image_bytes = query({
7-
"inputs": {{ inputs.asObj.inputs }},
7+
"inputs": {{ providerInputs.asObj.inputs }},
88
})
99

1010
# You can access the image with PIL.Image for example

packages/inference/src/snippets/templates/python/requests/zeroShotClassification.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ def query(payload):
33
return response.json()
44

55
output = query({
6-
"inputs": {{ inputs.asObj.inputs }},
6+
"inputs": {{ providerInputs.asObj.inputs }},
77
"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
88
})

packages/inference/src/snippets/templates/python/requests/zeroShotImageClassification.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@ def query(data):
99
return response.json()
1010

1111
output = query({
12-
"image_path": {{ inputs.asObj.inputs }},
12+
"image_path": {{ providerInputs.asObj.inputs }},
1313
"parameters": {"candidate_labels": ["cat", "dog", "llama"]},
1414
})

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.requests.hf-inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ def query(payload):
88
return response.json()
99

1010
response = query({
11-
"model": "meta-llama/Llama-3.1-8B-Instruct",
1211
"messages": [
1312
{
1413
"role": "user",
1514
"content": "What is the capital of France?"
1615
}
1716
],
18-
"max_tokens": 500
17+
"max_tokens": 500,
18+
"model": "meta-llama/Llama-3.1-8B-Instruct"
1919
})
2020

2121
print(response["choices"][0]["message"])

packages/tasks-gen/snippets-fixtures/conversational-llm-non-stream/1.requests.together.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@ def query(payload):
88
return response.json()
99

1010
response = query({
11-
"model": "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
1211
"messages": [
1312
{
1413
"role": "user",
1514
"content": "What is the capital of France?"
1615
}
1716
],
18-
"max_tokens": 500
17+
"max_tokens": 500,
18+
"model": "<together alias for meta-llama/Llama-3.1-8B-Instruct>"
1919
})
2020

2121
print(response["choices"][0]["message"])

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.requests.hf-inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ def query(payload):
1414
yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))
1515

1616
chunks = query({
17-
"model": "meta-llama/Llama-3.1-8B-Instruct",
1817
"messages": [
1918
{
2019
"role": "user",
2120
"content": "What is the capital of France?"
2221
}
2322
],
2423
"max_tokens": 500,
24+
"model": "meta-llama/Llama-3.1-8B-Instruct",
2525
"stream": True,
2626
})
2727

packages/tasks-gen/snippets-fixtures/conversational-llm-stream/1.requests.together.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@ def query(payload):
1414
yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))
1515

1616
chunks = query({
17-
"model": "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
1817
"messages": [
1918
{
2019
"role": "user",
2120
"content": "What is the capital of France?"
2221
}
2322
],
2423
"max_tokens": 500,
24+
"model": "<together alias for meta-llama/Llama-3.1-8B-Instruct>",
2525
"stream": True,
2626
})
2727

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.requests.fireworks-ai.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ def query(payload):
88
return response.json()
99

1010
response = query({
11-
"model": "<fireworks-ai alias for meta-llama/Llama-3.2-11B-Vision-Instruct>",
1211
"messages": [
1312
{
1413
"role": "user",
@@ -26,7 +25,8 @@ def query(payload):
2625
]
2726
}
2827
],
29-
"max_tokens": 500
28+
"max_tokens": 500,
29+
"model": "<fireworks-ai alias for meta-llama/Llama-3.2-11B-Vision-Instruct>"
3030
})
3131

3232
print(response["choices"][0]["message"])

packages/tasks-gen/snippets-fixtures/conversational-vlm-non-stream/1.requests.hf-inference.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ def query(payload):
88
return response.json()
99

1010
response = query({
11-
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
1211
"messages": [
1312
{
1413
"role": "user",
@@ -26,7 +25,8 @@ def query(payload):
2625
]
2726
}
2827
],
29-
"max_tokens": 500
28+
"max_tokens": 500,
29+
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct"
3030
})
3131

3232
print(response["choices"][0]["message"])

packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.requests.fireworks-ai.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def query(payload):
1414
yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))
1515

1616
chunks = query({
17-
"model": "<fireworks-ai alias for meta-llama/Llama-3.2-11B-Vision-Instruct>",
1817
"messages": [
1918
{
2019
"role": "user",
@@ -33,6 +32,7 @@ def query(payload):
3332
}
3433
],
3534
"max_tokens": 500,
35+
"model": "<fireworks-ai alias for meta-llama/Llama-3.2-11B-Vision-Instruct>",
3636
"stream": True,
3737
})
3838

packages/tasks-gen/snippets-fixtures/conversational-vlm-stream/1.requests.hf-inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@ def query(payload):
1414
yield json.loads(line.decode("utf-8").lstrip("data:").rstrip("/n"))
1515

1616
chunks = query({
17-
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
1817
"messages": [
1918
{
2019
"role": "user",
@@ -33,6 +32,7 @@ def query(payload):
3332
}
3433
],
3534
"max_tokens": 500,
35+
"model": "meta-llama/Llama-3.2-11B-Vision-Instruct",
3636
"stream": True,
3737
})
3838

0 commit comments

Comments
 (0)