Skip to content

feat: Implement custom OpenAI-compatible LLM support #79

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"name": "playwright-cdp-test",
"version": "1.0.0",
"description": "A script to test Playwright CDP connection.",
"main": "test_cdp.js",
"scripts": {
"start": "node test_cdp.js"
},
"dependencies": {
"playwright": "^1.40.0"
},
"author": "",
"license": "ISC"
}
21 changes: 17 additions & 4 deletions stagehand/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions stagehand/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
"dist"
],
"scripts": {
"build": "tsc && shx chmod +x dist/*.js",
"prepare": "npm run build",
"watch": "tsc --watch"
"build": "npx tsc && shx chmod +x dist/*.js",
"watch": "npx tsc --watch"
},
"dependencies": {
"@browserbasehq/sdk": "^2.0.0",
"@browserbasehq/stagehand": "^2.0.0",
"@modelcontextprotocol/sdk": "^1.0.3",
"@modelcontextprotocol/server-stagehand": "file:",
"@playwright/test": "^1.49.0"
"@playwright/test": "^1.49.0",
"openai": "^4.0.0"
},
"devDependencies": {
"shx": "^0.3.4",
Expand Down
192 changes: 192 additions & 0 deletions stagehand/src/customOpenAIClient.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import {
AvailableModel,
CreateChatCompletionOptions,
LLMClient,
Logger,
LogLine,
} from "@browserbasehq/stagehand";
import OpenAI from "openai";
import type {
ChatCompletion,
ChatCompletionContentPartImage,
ChatCompletionContentPartText,
ChatCompletionCreateParamsNonStreaming,
ChatCompletionMessageParam,
} from "openai/resources/chat/completions";
import { z } from "zod";

/**
 * Error thrown when a chat-completion response cannot be turned into the
 * structured result the caller asked for (missing content, unparseable JSON,
 * or a schema-validation failure).
 */
class CreateChatCompletionResponseError extends Error {
  constructor(message: string) {
    super(message);
    // Tag the error so upstream handlers can distinguish it by name.
    this.name = CreateChatCompletionResponseError.name;
  }
}

/**
 * Checks `data` against a zod schema without throwing.
 *
 * @returns true when `schema.parse(data)` succeeds, false otherwise.
 */
function validateZodSchema(schema: z.ZodTypeAny, data: unknown): boolean {
  let isValid = true;
  try {
    schema.parse(data);
  } catch {
    isValid = false;
  }
  return isValid;
}

/**
 * Stagehand LLMClient backed by any OpenAI-compatible chat-completions
 * endpoint. The caller supplies an already-configured OpenAI SDK client, so
 * custom base URLs and API keys are handled upstream of this wrapper.
 *
 * Vision is not supported: an `image` option is logged at error level and the
 * request otherwise proceeds without it.
 */
export class CustomOpenAIClientWrapper extends LLMClient {
  public type = "openai" as const;

  // Field declarations grouped before the constructor for readability;
  // class-field initializers still run right after super(), so initialization
  // order is unchanged.
  public hasVision: boolean = false;
  public clientOptions: Record<string, any> = {};

  /** Pre-configured OpenAI SDK client pointing at the target endpoint. */
  private client: OpenAI;

  constructor({ modelName, client }: { modelName: string; client: OpenAI }) {
    super(modelName as AvailableModel);
    this.client = client;
    this.modelName = modelName as AvailableModel;
  }

  /**
   * Sends a non-streaming chat completion request.
   *
   * When `options.response_model` carries a zod schema, the model is asked
   * for a JSON object; the reply is JSON-parsed and validated against the
   * schema, retrying up to `retries` more times before throwing
   * `CreateChatCompletionResponseError`. In both the structured and
   * plain-text cases the result is returned as `{ data, usage }` cast to `T`.
   */
  async createChatCompletion<T = ChatCompletion>({
    options,
    retries = 3,
    logger,
  }: CreateChatCompletionOptions): Promise<T> {
    const { image, requestId, ...optionsWithoutImageAndRequestId } = options;

    // Fall back to stderr JSON logging so diagnostics are never dropped.
    const effectiveLogger: Logger =
      logger || ((logLine: LogLine) => console.error(JSON.stringify(logLine)));

    // NOTE(review): assumes Stagehand log levels 0 = error, 1 = info —
    // confirm against the Logger contract.
    const errorLevel = 0;
    const infoLevel = 1;

    if (image) {
      effectiveLogger({
        message: "Image provided. Vision is not currently supported by this custom client.",
        level: errorLevel,
        auxiliary: {
          requestId: { value: String(requestId), type: "string" },
          component: { value: "CustomOpenAIClientWrapper", type: "string" },
          imageProvided: { value: String(true), type: "boolean" },
        },
      });
    }

    effectiveLogger({
      message: "Creating chat completion with CustomOpenAIClientWrapper",
      level: infoLevel,
      auxiliary: {
        options: { value: JSON.stringify({ ...optionsWithoutImageAndRequestId, requestId }), type: "object" },
        modelName: { value: this.modelName, type: "string" },
        component: { value: "CustomOpenAIClientWrapper", type: "string" },
      },
    });

    // Request a JSON object whenever a structured response is expected.
    let responseFormatPayload: { type: "json_object" } | undefined = undefined;
    if (options.response_model && options.response_model.schema) {
      responseFormatPayload = { type: "json_object" };
    }

    // Strip response_model (a Stagehand-only field) before building the
    // OpenAI request body; the underscore name marks it as deliberately unused.
    const { response_model: _responseModel, ...openaiOptions } = {
      ...optionsWithoutImageAndRequestId,
      model: this.modelName,
    };

    // Coerce Stagehand message parts into the OpenAI SDK's param types.
    const formattedMessages: ChatCompletionMessageParam[] = options.messages.map(
      (message): ChatCompletionMessageParam => {
        if (typeof message.content === 'string') {
          return message as ChatCompletionMessageParam;
        }
        if (Array.isArray(message.content)) {
          const contentParts = message.content.map((part) => {
            if (part.type === 'image_url') {
              return part as ChatCompletionContentPartImage;
            }
            return part as ChatCompletionContentPartText;
          });
          return { ...message, content: contentParts } as ChatCompletionMessageParam;
        }
        return message as ChatCompletionMessageParam;
      }
    );

    const body: ChatCompletionCreateParamsNonStreaming = {
      ...openaiOptions,
      messages: formattedMessages,
      model: this.modelName,
      response_format: responseFormatPayload,
      stream: false,
      // Map Stagehand tool definitions onto OpenAI function-tool entries.
      tools: options.tools?.map((tool) => ({
        function: {
          name: tool.name,
          description: tool.description,
          parameters: tool.parameters,
        },
        type: "function",
      })),
    };

    const response = await this.client.chat.completions.create(body);

    effectiveLogger({
      message: "Response received from OpenAI compatible API",
      level: infoLevel,
      auxiliary: {
        choiceCount: { value: String(response.choices.length), type: "integer" },
        firstChoiceFinishReason: { value: String(response.choices[0]?.finish_reason), type: "string" },
        usage: { value: JSON.stringify(response.usage), type: "object" },
        requestId: { value: String(requestId), type: "string" },
        component: { value: "CustomOpenAIClientWrapper", type: "string" },
      },
    });

    if (options.response_model && options.response_model.schema) {
      const extractedData = response.choices[0]?.message?.content;
      // == null deliberately catches both null and undefined.
      if (extractedData == null) {
        effectiveLogger({
          message: "No content in response message for structured response.",
          level: errorLevel,
          auxiliary: {
            component: { value: "CustomOpenAIClientWrapper", type: "string" },
            requestId: { value: String(requestId), type: "string" }
          }
        });
        throw new CreateChatCompletionResponseError("No content in response message for structured response.");
      }

      let parsedData;
      try {
        parsedData = JSON.parse(extractedData);
      } catch (e) {
        // `catch` bindings are untyped under strict settings; narrow before
        // reading `.message` (previously `e: any` could log "undefined" for
        // non-Error throws).
        const parseErrorMessage = e instanceof Error ? e.message : String(e);
        effectiveLogger({
          message: `Failed to parse JSON response: ${parseErrorMessage}`,
          level: errorLevel,
          auxiliary: {
            component: { value: "CustomOpenAIClientWrapper", type: "string" },
            originalResponse: { value: extractedData, type: "string" },
            requestId: { value: String(requestId), type: "string" }
          }
        });
        if (retries > 0) {
          // Retry the full request; the model may produce valid JSON next time.
          return this.createChatCompletion({ options, logger, retries: retries - 1 });
        }
        throw new CreateChatCompletionResponseError(`Failed to parse JSON response: ${extractedData}`);
      }

      if (!validateZodSchema(options.response_model.schema, parsedData)) {
        effectiveLogger({
          message: "Invalid response schema after parsing.",
          level: errorLevel,
          auxiliary: {
            component: { value: "CustomOpenAIClientWrapper", type: "string" },
            parsedDataJSON: { value: JSON.stringify(parsedData), type: "object" },
            requestId: { value: String(requestId), type: "string" }
          }
        });
        if (retries > 0) {
          return this.createChatCompletion({ options, logger, retries: retries - 1 });
        }
        throw new CreateChatCompletionResponseError("Invalid response schema");
      }
      return { data: parsedData, usage: response.usage } as T;
    }

    // Unstructured case: hand back the raw message content (may be null if
    // the model returned no content — callers handle that).
    return { data: response.choices[0]?.message?.content, usage: response.usage } as T;
  }
}
Loading