11
11
# ----------------------------------------------------------------------------
12
12
13
13
import os
14
+ import shutil
14
15
15
16
from . import product
16
17
from .. import shell
@@ -68,6 +69,9 @@ def build(self, host_target):
68
69
self .args .install_destdir ),
69
70
'-D' , 'CMAKE_MAKE_PROGRAM={}' .format (self .toolchain .ninja ),
70
71
'-D' , 'CMAKE_Swift_COMPILER={}' .format (swiftc ),
72
+ # SWIFT_ENABLE_TENSORFLOW
73
+ '-D' , 'USE_BUNDLED_CTENSORFLOW=YES' ,
74
+ # SWIFT_ENABLE_TENSORFLOW END
71
75
'-D' , 'TensorFlow_INCLUDE_DIR={}' .format (tensorflow_source_dir ),
72
76
'-D' , 'TensorFlow_LIBRARY={}' .format (
73
77
os .path .join (tensorflow_source_dir , 'bazel-bin' , 'tensorflow' ,
@@ -90,7 +94,7 @@ def test(self, host_target):
90
94
pass
91
95
92
96
def should_install (self , host_target ):
93
- return self .args .install_tensorflow_swift_apis
97
+ return self .args .build_tensorflow_swift_apis
94
98
95
99
def install (self , host_target ):
96
100
shell .call ([
@@ -116,6 +120,15 @@ def is_build_script_impl_product(cls):
116
120
def should_build (self , host_target ):
117
121
return self .args .build_tensorflow_swift_apis
118
122
123
+ def _get_tensorflow_library (self , host ):
124
+ if host .startswith ('macosx' ):
125
+ return ('libtensorflow.2.1.0.dylib' , 'libtensorflow.dylib' )
126
+
127
+ if host .startswith ('linux' ):
128
+ return ('libtensorflow.so.2.1.0' , 'libtensorflow.so' )
129
+
130
+ raise RuntimeError ('unknown host target {}' .format (host ))
131
+
119
132
def build (self , host_target ):
120
133
with shell .pushd (self .source_dir ):
121
134
shell .call ([
@@ -130,14 +143,8 @@ def build(self, host_target):
130
143
"//tensorflow:tensorflow" ,
131
144
])
132
145
133
- if host_target .startswith ('macosx' ):
134
- suffixed_lib_name = 'libtensorflow.2.1.0.dylib'
135
- unsuffixed_lib_name = 'libtensorflow.dylib'
136
- elif host_target .startswith ('linux' ):
137
- suffixed_lib_name = 'libtensorflow.so.2.1.0'
138
- unsuffixed_lib_name = 'libtensorflow.so'
139
- else :
140
- raise RuntimeError ('unknown host target {}' .format (host_target ))
146
+ (suffixed_lib_name , unsuffixed_lib_name ) = \
147
+ self ._get_tensorflow_library (host_target )
141
148
142
149
# NOTE: ignore the race condition here ....
143
150
try :
@@ -157,8 +164,76 @@ def test(self, host_target):
157
164
pass
158
165
159
166
def should_install (self , host_target ):
160
- return False
167
+ return self . args . build_tensorflow_swift_apis
161
168
162
169
def install (self , host_target ):
163
- pass
170
+ (suffixed_lib_name , unsuffixed_lib_name ) = \
171
+ self ._get_tensorflow_library (host_target )
172
+
173
+ subdir = None
174
+ if host_target .startswith ('macsox' ):
175
+ subdir = 'macosx'
176
+ if host_target .startswith ('linux' ):
177
+ subdir = 'linux'
178
+
179
+ if not subdir :
180
+ raise RuntimeError ('unknown host target {}' .format (host_target ))
181
+
182
+ try :
183
+ os .unlink (os .path .join (self .install_toolchain_path (),
184
+ 'usr' , 'lib' , 'swift' ,
185
+ subdir , suffixed_lib_name ))
186
+ os .makedirs (os .path .join (self .install_toolchain_path (),
187
+ 'usr' , 'lib' , 'swift' , subdir ))
188
+ except OSError :
189
+ pass
190
+ shutil .copy (os .path .join (self .source_dir , 'bazel-bin' , 'tensorflow' ,
191
+ suffixed_lib_name ),
192
+ os .path .join (self .install_toolchain_path (),
193
+ 'usr' , 'lib' , 'swift' ,
194
+ subdir , suffixed_lib_name ))
195
+
196
+ try :
197
+ os .unlink (os .path .join (self .install_toolchain_path (),
198
+ 'usr' , 'lib' , 'swift' ,
199
+ subdir , unsuffixed_lib_name ))
200
+ except OSError :
201
+ pass
202
+ os .symlink (suffixed_lib_name ,
203
+ os .path .join (self .install_toolchain_path (),
204
+ 'usr' , 'lib' , 'swift' ,
205
+ subdir , unsuffixed_lib_name ))
206
+
207
+ try :
208
+ shutil .rmtree (os .path .join (self .install_toolchain_path (),
209
+ 'usr' , 'lib' , 'swift' , 'tensorflow' ))
210
+ os .makedirs (os .path .join (self .install_toolchain_path (),
211
+ 'usr' , 'lib' , 'swift' , 'tensorflow' , 'c' ,
212
+ 'eager' ))
213
+ except OSError :
214
+ pass
215
+ for header in (
216
+ 'c_api.h' ,
217
+ 'c_api_experimental.h' ,
218
+ 'tf_attrtype.h' ,
219
+ 'tf_datatype.h' ,
220
+ 'tf_status.h' ,
221
+ 'tf_tensor.h' ,
222
+ 'eager/c_api.h' ,
223
+ ):
224
+ shutil .copy (os .path .join (self .source_dir , 'tensorflow' , 'c' , header ),
225
+ os .path .join (self .install_toolchain_path (),
226
+ 'usr' , 'lib' , 'swift' , 'tensorflow' , 'c' ,
227
+ header ))
228
+
229
+ for name in (
230
+ 'CTensorFlow.h' ,
231
+ 'module.modulemap' ,
232
+ ):
233
+ shutil .copy (os .path .join (self .source_dir , '..' ,
234
+ 'tensorflow-swift-apis' , 'Sources' ,
235
+ 'CTensorFlow' , name ),
236
+ os .path .join (self .install_toolchain_path (),
237
+ 'usr' , 'lib' , 'swift' , 'tensorflow' , name ))
238
+
164
239
# SWIFT_ENABLE_TENSORFLOW END
0 commit comments