Commit e4048df (1 parent: 9cb93be)

add OAI client as well

File tree: 1 file changed (+58 −8 lines)

packages/tasks/src/snippets/python.ts

Lines changed: 58 additions & 8 deletions
@@ -28,7 +28,7 @@ export const snippetConversational = (
 		max_tokens?: GenerationParameters["max_tokens"];
 		top_p?: GenerationParameters["top_p"];
 	}
-): InferenceSnippet => {
+): InferenceSnippet[] => {
 	const streaming = opts?.streaming ?? true;
 	const messages: ChatCompletionInputMessage[] = opts?.messages ?? [
 		{ role: "user", content: "What is the capital of France?" },
@@ -41,8 +41,10 @@ export const snippetConversational = (
 	};
 
 	if (streaming) {
-		return {
-			content: `from huggingface_hub import InferenceClient
+		return [
+			{
+				client: "huggingface_hub",
+				content: `from huggingface_hub import InferenceClient
 
 client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
 
@@ -57,10 +59,34 @@ stream = client.chat.completions.create(
 
 for chunk in stream:
 	print(chunk.choices[0].delta.content)`,
-		};
+			},
+			{
+				client: "openai",
+				content: `from openai import OpenAI
+
+client = OpenAI(
+	base_url="https://api-inference.huggingface.co/v1/",
+	api_key="${accessToken || "{API_TOKEN}"}"
+)
+
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+stream = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })},
+	stream=True
+)
+
+for chunk in stream:
+	print(chunk.choices[0].delta.content)`,
+			},
+		];
 	} else {
-		return {
-			content: `from huggingface_hub import InferenceClient
+		return [
+			{
+				client: "huggingface_hub",
+				content: `from huggingface_hub import InferenceClient
 
 client = InferenceClient(api_key="${accessToken || "{API_TOKEN}"}")
 
@@ -73,7 +99,27 @@ completion = client.chat.completions.create(
 )
 
 print(completion.choices[0].message)`,
-		};
+			},
+			{
+				client: "openai",
+				content: `from openai import OpenAI
+
+client = OpenAI(
+	base_url="https://api-inference.huggingface.co/v1/",
+	api_key="${accessToken || "{API_TOKEN}"}"
+)
+
+messages = ${formatGenerationMessages({ messages, sep: ",\n\t", start: `[\n\t`, end: `\n]` })}
+
+completion = client.chat.completions.create(
+	model="${model.id}",
+	messages=messages,
+	${formatGenerationConfig({ config, sep: ",\n\t", start: "", end: "", connector: "=" })}
+)
+
+print(completion.choices[0].message)`,
+			},
+		];
 	}
 };
 
@@ -220,7 +266,11 @@ output = query({
 export const pythonSnippets: Partial<
 	Record<
 		PipelineType,
-		(model: ModelDataMinimal, accessToken: string, opts?: Record<string, unknown>) => InferenceSnippet
+		(
+			model: ModelDataMinimal,
+			accessToken: string,
+			opts?: Record<string, unknown>
+		) => InferenceSnippet | InferenceSnippet[]
 	>
 > = {
 	// Same order as in tasks/src/pipelines.ts

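For context, the new "openai" entry works because the Hugging Face Inference API exposes an OpenAI-compatible chat-completions route at https://api-inference.huggingface.co/v1/. Below is a rough sketch of what the streaming variant renders to once the template placeholders are filled in; the model ID and max_tokens value are illustrative stand-ins, not values taken from this commit.

from openai import OpenAI

# Point the OpenAI client at the HF Inference API's OpenAI-compatible endpoint.
client = OpenAI(
    base_url="https://api-inference.huggingface.co/v1/",
    api_key="{API_TOKEN}"  # replace with a real Hugging Face access token
)

messages = [
    {"role": "user", "content": "What is the capital of France?"}
]

# stream=True yields incremental deltas instead of one final message.
stream = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model ID
    messages=messages,
    max_tokens=500,  # illustrative generation parameter
    stream=True
)

for chunk in stream:
    print(chunk.choices[0].delta.content)

The non-streaming variant is identical except that it drops stream=True and reads completion.choices[0].message from the single response.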