Skip to content

Commit 768942e

Browse files
committed
build: add some tensorflow install rules
1 parent 98c01d6 commit 768942e

File tree

1 file changed

+86
-11
lines changed
  • utils/swift_build_support/swift_build_support/products

1 file changed

+86
-11
lines changed

utils/swift_build_support/swift_build_support/products/tensorflow.py

Lines changed: 86 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# ----------------------------------------------------------------------------
1212

1313
import os
14+
import shutil
1415

1516
from . import product
1617
from .. import shell
@@ -68,6 +69,9 @@ def build(self, host_target):
6869
self.args.install_destdir),
6970
'-D', 'CMAKE_MAKE_PROGRAM={}'.format(self.toolchain.ninja),
7071
'-D', 'CMAKE_Swift_COMPILER={}'.format(swiftc),
72+
# SWIFT_ENABLE_TENSORFLOW
73+
'-D', 'USE_BUNDLED_CTENSORFLOW=YES',
74+
# SWIFT_ENABLE_TENSORFLOW END
7175
'-D', 'TensorFlow_INCLUDE_DIR={}'.format(tensorflow_source_dir),
7276
'-D', 'TensorFlow_LIBRARY={}'.format(
7377
os.path.join(tensorflow_source_dir, 'bazel-bin', 'tensorflow',
@@ -90,7 +94,7 @@ def test(self, host_target):
9094
pass
9195

9296
def should_install(self, host_target):
93-
return self.args.install_tensorflow_swift_apis
97+
return self.args.build_tensorflow_swift_apis
9498

9599
def install(self, host_target):
96100
shell.call([
@@ -116,6 +120,15 @@ def is_build_script_impl_product(cls):
116120
def should_build(self, host_target):
117121
return self.args.build_tensorflow_swift_apis
118122

123+
def _get_tensorflow_library(self, host):
124+
if host.startswith('macosx'):
125+
return ('libtensorflow.2.1.0.dylib', 'libtensorflow.dylib')
126+
127+
if host.startswith('linux'):
128+
return ('libtensorflow.so.2.1.0', 'libtensorflow.so')
129+
130+
raise RuntimeError('unknown host target {}'.format(host))
131+
119132
def build(self, host_target):
120133
with shell.pushd(self.source_dir):
121134
shell.call([
@@ -130,14 +143,8 @@ def build(self, host_target):
130143
"//tensorflow:tensorflow",
131144
])
132145

133-
if host_target.startswith('macosx'):
134-
suffixed_lib_name = 'libtensorflow.2.1.0.dylib'
135-
unsuffixed_lib_name = 'libtensorflow.dylib'
136-
elif host_target.startswith('linux'):
137-
suffixed_lib_name = 'libtensorflow.so.2.1.0'
138-
unsuffixed_lib_name = 'libtensorflow.so'
139-
else:
140-
raise RuntimeError('unknown host target {}'.format(host_target))
146+
(suffixed_lib_name, unsuffixed_lib_name) = \
147+
self._get_tensorflow_library(host_target)
141148

142149
# NOTE: ignore the race condition here ....
143150
try:
@@ -157,8 +164,76 @@ def test(self, host_target):
157164
pass
158165

159166
def should_install(self, host_target):
160-
return False
167+
return self.args.build_tensorflow_swift_apis
161168

162169
def install(self, host_target):
163-
pass
170+
(suffixed_lib_name, unsuffixed_lib_name) = \
171+
self._get_tensorflow_library(host_target)
172+
173+
subdir = None
174+
if host_target.startswith('macsox'):
175+
subdir = 'macosx'
176+
if host_target.startswith('linux'):
177+
subdir = 'linux'
178+
179+
if not subdir:
180+
raise RuntimeError('unknown host target {}'.format(host_target))
181+
182+
try:
183+
os.unlink(os.path.join(self.install_toolchain_path(),
184+
'usr', 'lib', 'swift',
185+
subdir, suffixed_lib_name))
186+
os.makedirs(os.path.join(self.install_toolchain_path(),
187+
'usr', 'lib', 'swift', subdir))
188+
except OSError:
189+
pass
190+
shutil.copy(os.path.join(self.source_dir, 'bazel-bin', 'tensorflow',
191+
suffixed_lib_name),
192+
os.path.join(self.install_toolchain_path(),
193+
'usr', 'lib', 'swift',
194+
subdir, suffixed_lib_name))
195+
196+
try:
197+
os.unlink(os.path.join(self.install_toolchain_path(),
198+
'usr', 'lib', 'swift',
199+
subdir, unsuffixed_lib_name))
200+
except OSError:
201+
pass
202+
os.symlink(suffixed_lib_name,
203+
os.path.join(self.install_toolchain_path(),
204+
'usr', 'lib', 'swift',
205+
subdir, unsuffixed_lib_name))
206+
207+
try:
208+
shutil.rmtree(os.path.join(self.install_toolchain_path(),
209+
'usr', 'lib', 'swift', 'tensorflow'))
210+
os.makedirs(os.path.join(self.install_toolchain_path(),
211+
'usr', 'lib', 'swift', 'tensorflow', 'c',
212+
'eager'))
213+
except OSError:
214+
pass
215+
for header in (
216+
'c_api.h',
217+
'c_api_experimental.h',
218+
'tf_attrtype.h',
219+
'tf_datatype.h',
220+
'tf_status.h',
221+
'tf_tensor.h',
222+
'eager/c_api.h',
223+
):
224+
shutil.copy(os.path.join(self.source_dir, 'tensorflow', 'c', header),
225+
os.path.join(self.install_toolchain_path(),
226+
'usr', 'lib', 'swift', 'tensorflow', 'c',
227+
header))
228+
229+
for name in (
230+
'CTensorFlow.h',
231+
'module.modulemap',
232+
):
233+
shutil.copy(os.path.join(self.source_dir, '..',
234+
'tensorflow-swift-apis', 'Sources',
235+
'CTensorFlow', name),
236+
os.path.join(self.install_toolchain_path(),
237+
'usr', 'lib', 'swift', 'tensorflow', name))
238+
164239
# SWIFT_ENABLE_TENSORFLOW END

0 commit comments

Comments
 (0)