Skip to content

Commit 33481c9

Browse files
committed
[mlgo] Fetch models from path / URL
Allow custom location for pre-trained models used when AOT-compiling policies. Differential Revision: https://reviews.llvm.org/D96796
1 parent a3c783d commit 33481c9

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

llvm/cmake/modules/TensorFlowCompile.cmake

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
1+
# Ensure the ${model} is available at ${final_path}.
2+
#
3+
function(tfgetmodel model final_path)
4+
if (IS_ABSOLUTE ${model})
5+
set(${final_path} ${model} PARENT_SCOPE)
6+
else()
7+
set(${final_path}
8+
${CMAKE_CURRENT_SOURCE_DIR}/${model} PARENT_SCOPE)
9+
endif()
10+
endfunction()
11+
112
# Run the tensorflow compiler (saved_model_cli) on the saved model in the
213
# ${model} directory, looking for the ${tag_set} tag set, and the SignatureDef
314
# ${signature_def_key}.
415
# Produce a pair of files called ${fname}.h and ${fname}.o in the
516
# ${CMAKE_CURRENT_BINARY_DIR}. The generated header will define a C++ class
617
# called ${cpp_class} - which may be a namespace-qualified class name.
718
function(tfcompile model tag_set signature_def_key fname cpp_class)
8-
if (IS_ABSOLUTE ${model})
9-
set(LLVM_ML_MODELS_ABSOLUTE ${model})
10-
else()
11-
set(LLVM_ML_MODELS_ABSOLUTE
12-
${CMAKE_CURRENT_SOURCE_DIR}/${model})
13-
endif()
14-
19+
tfgetmodel(${model} LLVM_ML_MODELS_ABSOLUTE)
20+
message("Using model at " ${LLVM_ML_MODELS_ABSOLUTE})
1521
set(prefix ${CMAKE_CURRENT_BINARY_DIR}/${fname})
1622
set(obj_file ${prefix}.o)
1723
set(hdr_file ${prefix}.h)

llvm/lib/Analysis/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
if (DEFINED LLVM_HAVE_TF_AOT OR DEFINED LLVM_HAVE_TF_API)
22
if (DEFINED LLVM_HAVE_TF_AOT)
3+
set(LLVM_INLINER_MODEL_PATH "models/inliner"
4+
CACHE STRING
5+
"ML-driven inliner policy location (path to saved model)")
36
include(TensorFlowCompile)
4-
tfcompile(models/inliner serve action InlinerSizeModel llvm::InlinerSizeModel)
7+
tfcompile(${LLVM_INLINER_MODEL_PATH} serve action InlinerSizeModel llvm::InlinerSizeModel)
58
list(APPEND GeneratedMLSources
69
$<TARGET_OBJECTS:tf_xla_runtime_objects>
710
${GENERATED_OBJS}

0 commit comments

Comments
 (0)