Skip to content

Commit 3fdce73

Browse files
authored
Patching for generativeai Python package (#1329)
http://b/308644984
1 parent aab5d76 commit 3fdce73

File tree

4 files changed

+138
-0
lines changed

4 files changed

+138
-0
lines changed

Dockerfile.tmpl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,7 @@ RUN pip install mpld3 \
364364
eli5 \
365365
kaggle \
366366
kagglehub \
367+
google-generativeai \
367368
mock \
368369
pytest && \
369370
/tmp/clean-layer.sh

patches/sitecustomize.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import importlib
88
import importlib.machinery
99

10+
import wrapt
11+
1012
class GcpModuleFinder(importlib.abc.MetaPathFinder):
1113
_MODULES = [
1214
'google.cloud.bigquery',
@@ -73,3 +75,41 @@ def exec_module(self, module):
7375

7476
if not hasattr(sys, 'frozen'):
7577
sys.meta_path.insert(0, GcpModuleFinder())
78+
79+
@wrapt.when_imported('google.generativeai')
80+
def post_import_logic(module):
81+
if os.getenv('KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION') != None:
82+
return
83+
if (os.getenv('KAGGLE_DATA_PROXY_TOKEN') == None or
84+
os.getenv('KAGGLE_USER_SECRETS_TOKEN') == None or
85+
os.getenv('KAGGLE_DATA_PROXY_URL') == None):
86+
return
87+
88+
old_configure = module.configure
89+
90+
def new_configure(*args, **kwargs):
91+
if ('default_metadata' in kwargs):
92+
default_metadata = kwargs['default_metadata']
93+
else:
94+
default_metadata = []
95+
default_metadata.append(("x-kaggle-proxy-data", os.environ['KAGGLE_DATA_PROXY_TOKEN']))
96+
user_secrets_token = os.environ['KAGGLE_USER_SECRETS_TOKEN']
97+
default_metadata.append(('x-kaggle-authorization', f'Bearer {user_secrets_token}'))
98+
kwargs['default_metadata'] = default_metadata
99+
100+
if ('client_options' in kwargs):
101+
client_options = kwargs['client_options']
102+
else:
103+
client_options = {}
104+
client_options['api_endpoint'] = os.environ['KAGGLE_DATA_PROXY_URL']
105+
if os.getenv('KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY') != None:
106+
client_options['api_endpoint'] += '/palmapi'
107+
kwargs['transport'] = 'rest'
108+
elif 'transport' in kwargs and kwargs['transport'] == 'rest':
109+
client_options['api_endpoint'] += '/palmapi'
110+
kwargs['client_options'] = client_options
111+
112+
old_configure(*args, **kwargs)
113+
114+
module.configure = new_configure
115+
module.configure() # generativeai can use GOOGLE_API_KEY env variable, so make sure we have the other configs set
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import json
2+
import unittest
3+
import threading
4+
5+
from test.support.os_helper import EnvironmentVarGuard
6+
from urllib.parse import urlparse
7+
8+
from http.server import BaseHTTPRequestHandler, HTTPServer
9+
10+
class HTTPHandler(BaseHTTPRequestHandler):
11+
called = False
12+
path = None
13+
headers = {}
14+
15+
def do_HEAD(self):
16+
self.send_response(200)
17+
18+
def do_GET(self):
19+
HTTPHandler.path = self.path
20+
HTTPHandler.headers = self.headers
21+
HTTPHandler.called = True
22+
self.send_response(200)
23+
self.send_header("Content-type", "application/json")
24+
self.end_headers()
25+
26+
class TestGoogleGenerativeAiPatch(unittest.TestCase):
27+
endpoint = "http://127.0.0.1:80"
28+
29+
def test_proxy_enabled(self):
30+
env = EnvironmentVarGuard()
31+
secrets_token = "secrets_token"
32+
proxy_token = "proxy_token"
33+
env.set("KAGGLE_USER_SECRETS_TOKEN", secrets_token)
34+
env.set("KAGGLE_DATA_PROXY_TOKEN", proxy_token)
35+
env.set("KAGGLE_DATA_PROXY_URL", self.endpoint)
36+
env.set("KAGGLE_GOOGLE_GENERATIVE_AI_USE_REST_ONLY", "True")
37+
server_address = urlparse(self.endpoint)
38+
with env:
39+
with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd:
40+
threading.Thread(target=httpd.serve_forever).start()
41+
import google.generativeai as palm
42+
api_key = "NotARealAPIKey"
43+
palm.configure(api_key = api_key)
44+
try:
45+
for _ in palm.list_models():
46+
pass
47+
except:
48+
pass
49+
httpd.shutdown()
50+
self.assertTrue(HTTPHandler.called)
51+
self.assertIn("/palmapi", HTTPHandler.path)
52+
self.assertEqual(proxy_token, HTTPHandler.headers["x-kaggle-proxy-data"])
53+
self.assertEqual("Bearer {}".format(secrets_token), HTTPHandler.headers["x-kaggle-authorization"])
54+
self.assertEqual(api_key, HTTPHandler.headers["x-goog-api-key"])
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import json
2+
import unittest
3+
import threading
4+
5+
from test.support.os_helper import EnvironmentVarGuard
6+
from urllib.parse import urlparse
7+
8+
from http.server import BaseHTTPRequestHandler, HTTPServer
9+
10+
class HTTPHandler(BaseHTTPRequestHandler):
11+
called = False
12+
13+
def do_HEAD(self):
14+
self.send_response(200)
15+
16+
def do_GET(self):
17+
HTTPHandler.called = True
18+
self.send_response(200)
19+
self.send_header("Content-type", "application/json")
20+
self.end_headers()
21+
22+
class TestGoogleGenerativeAiPatchDisabled(unittest.TestCase):
23+
endpoint = "http://127.0.0.1:80"
24+
25+
def test_disabled(self):
26+
env = EnvironmentVarGuard()
27+
env.set("KAGGLE_USER_SECRETS_TOKEN", "foobar")
28+
env.set("KAGGLE_DATA_PROXY_TOKEN", "foobar")
29+
env.set("KAGGLE_DATA_PROXY_URL", self.endpoint)
30+
env.set("KAGGLE_DISABLE_GOOGLE_GENERATIVE_AI_INTEGRATION", "True")
31+
server_address = urlparse(self.endpoint)
32+
with env:
33+
with HTTPServer((server_address.hostname, server_address.port), HTTPHandler) as httpd:
34+
threading.Thread(target=httpd.serve_forever).start()
35+
import google.generativeai as palm
36+
palm.configure(api_key = "NotARealAPIKey")
37+
try:
38+
for _ in palm.list_models():
39+
pass
40+
except:
41+
pass
42+
httpd.shutdown()
43+
self.assertFalse(HTTPHandler.called)

0 commit comments

Comments
 (0)