
Commit cd6b1b9

docs: Clean up testing and documentation

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>

1 parent a9f33e4 commit cd6b1b9

16 files changed (+52, -24 lines)
BUILD

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ pkg_tar(
         "//core/conversion/evaluators:include",
         "//core/execution:include",
         "//core/lowering:include",
-        "//core/lowering/irfusers:include",
+        "//core/lowering/passes:include",
         "//core/util:include",
         "//core/util/logging:include"
     ],

README.md

Lines changed: 4 additions & 0 deletions
@@ -149,6 +149,10 @@ Thanks for wanting to contribute! There are two main ways to handle supporting a

 You can register a converter for your op using the `NodeConverterRegistry` inside your application.

+## Known Limitations
+
+- You cannot use both Adaptive Pooling in PyTorch and also use TRTorch Dynamic input shape
+
 ## Structure of the repo

 | Component | Description |
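For context, the `NodeConverterRegistry` mentioned in the README lives in `core/conversion/converters`. As a rough, hypothetical sketch of what registering a converter might look like (the registrar name, `ConversionCtx`, and the `args` accessor are assumptions about this repo's conventions, not verbatim API):

    // Hypothetical converter registration for aten::relu; the exact names and
    // signatures in core/conversion/converters may differ.
    auto relu_registration = RegisterNodeConverters().pattern({
        "aten::relu(Tensor input) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
            auto in = args[0].ITensor();  // incoming TensorRT tensor for the op's input
            auto layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kRELU);
            ctx->AssociateValueAndTensor(n->outputs()[0],  // map the JIT output value
                                         layer->getOutput(0));  // to the TRT layer output
            return true;
        }});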

core/compiler.cpp

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
                                         ExtraInfo cfg) {
   // TODO: Should be doing a functional transform but need PR #31978
   // [jit] More robust mangling
-  // torch::jit::script::Module new_mod = mod.clone();
+  //torch::jit::script::Module new_mod = mod.clone();
   torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
   std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
   for (const torch::jit::script::Method& method : mod.get_methods()) {

core/lowering/lowering.cpp

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::RemoveDropout(g);
   passes::FuseFlattenLinear(g);
   passes::UnpackAddMM(g);
-  passes::ExpandLogSoftmax(g);
+  passes::UnpackLogSoftmax(g);
   //passes::RemoveDimExeception(g);
   //irfusers::UnpackBatchNorm(g);
   torch::jit::EliminateDeadCode(g);

core/lowering/passes/BUILD

Lines changed: 5 additions & 5 deletions
@@ -6,12 +6,12 @@ cc_library(
         "passes.h",
     ],
     srcs = [
+        "exception_elimination.cpp",
         "fuse_flatten_linear.cpp",
-        "expand_log_softmax.cpp",
         "remove_dropout.cpp",
+        "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
-        "exception_elimination.cpp",
-        "unpack_addmm.cpp"
+        "unpack_log_softmax.cpp",
     ],
     deps = [
         "//core/util:prelude",
@@ -23,7 +23,7 @@ load("@rules_pkg//:pkg.bzl", "pkg_tar")

 pkg_tar(
     name = "include",
-    package_dir = "core/lowering/irfusers/",
-    srcs = ["irfusers.h"],
+    package_dir = "core/lowering/passes/",
+    srcs = ["passes.h"],
 )

core/lowering/passes/exception_elimination.cpp

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ struct ExceptionOrPassPatternElimination {
     : graph_(std::move(graph)) {}

   void run() {
-    LOG_GRAPH("Pre exeception or pass elimination: " << *graph_);
     findExceptionOrPassNodes(graph_->block());
     torch::jit::EliminateDeadCode(graph_);
     LOG_GRAPH("Post exeception or pass elimination: " << *graph_);

core/lowering/passes/fuse_flatten_linear.cpp

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,8 @@
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {
@@ -38,6 +40,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph) {
   flatten_linear_bias_none_to_linear.RegisterRewritePattern(
     flatten_linear_bias_none_pattern, fused_linear_bias_none);
   flatten_linear_bias_none_to_linear.runOnGraph(graph);
+  LOG_GRAPH("Post flatten linear: " << *graph);
 }

 } // namespace passes
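Each rewritten pass now ends by dumping the transformed graph through `LOG_GRAPH`, which is why these files gain `#include "core/util/prelude.h"`. As a sketch of the idea only (the real macro is defined in `core/util` and may well differ), such a macro could look like:

    #include <sstream>

    // Hypothetical sketch: emit a message at kGRAPH severity so it only
    // appears when the reportable log level is raised to LogLevel::kGRAPH.
    #define LOG_GRAPH(s)                                              \
      do {                                                            \
        std::stringstream ss;                                         \
        ss << s;                                                      \
        core::util::logging::get_logger().log(                        \
            core::util::logging::LogLevel::kGRAPH, ss.str());         \
      } while (0)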

core/lowering/passes/passes.h

Lines changed: 2 additions & 2 deletions
@@ -8,10 +8,10 @@ namespace lowering {
 namespace passes {

 void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
-void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
-void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);

 } // namespace irfusers

core/lowering/passes/remove_dropout.cpp

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,8 @@
 #include <torch/csrc/jit/passes/fuse_linear.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {
@@ -20,6 +22,7 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
   remove_dropout.RegisterRewritePattern(
     dropout_pattern, no_dropout_pattern);
   remove_dropout.runOnGraph(graph);
+  LOG_GRAPH("Post remove dropout: " << *graph);
 }

 } // namespace passes
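For reference, `RemoveDropout` is a plain `SubgraphRewriter` pass: dropout is an identity at inference time, so the matched `aten::dropout` call is rewritten to return its input unchanged. A sketch of the pattern pair, with variable names assumed rather than copied from remove_dropout.cpp:

    // Illustrative TorchScript IR patterns; the exact strings in the pass may differ.
    std::string dropout_pattern = R"IR(
        graph(%input, %p, %train):
            %out = aten::dropout(%input, %p, %train)
            return (%out))IR";
    std::string no_dropout_pattern = R"IR(
        graph(%input, %p, %train):
            return (%input))IR";

    torch::jit::SubgraphRewriter remove_dropout;
    remove_dropout.RegisterRewritePattern(dropout_pattern, no_dropout_pattern);
    remove_dropout.runOnGraph(graph);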

core/lowering/passes/unpack_addmm.cpp

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,8 @@
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {
@@ -23,6 +25,7 @@ void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
   torch::jit::SubgraphRewriter unpack_addmm;
   unpack_addmm.RegisterRewritePattern(addmm_pattern, mm_add_pattern);
   unpack_addmm.runOnGraph(graph);
+  LOG_GRAPH("Post unpack addmm: " << *graph);
 }
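`UnpackAddMM` follows the same recipe: `aten::addmm(bias, mat1, mat2, beta, alpha)` fuses a matrix multiply with a bias add, and splitting it into `aten::mm` plus `aten::add` leaves simpler primitives for the converters to map onto TensorRT layers. A sketch of the decomposition, with the exact pattern text assumed (valid when beta == 1, which is what the JIT emits for nn.Linear):

    // Illustrative patterns for the addmm -> mm + add rewrite.
    std::string addmm_pattern = R"IR(
        graph(%bias, %mat1, %mat2, %beta, %alpha):
            %out = aten::addmm(%bias, %mat1, %mat2, %beta, %alpha)
            return (%out))IR";
    std::string mm_add_pattern = R"IR(
        graph(%bias, %mat1, %mat2, %beta, %alpha):
            %mm = aten::mm(%mat1, %mat2)
            %out = aten::add(%bias, %mm, %alpha)
            return (%out))IR";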
core/lowering/passes/unpack_batch_norm.cpp

Lines changed: 3 additions & 0 deletions
@@ -1,5 +1,7 @@
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {
@@ -39,6 +41,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph) {
   torch::jit::SubgraphRewriter unpack_batch_norm;
   unpack_batch_norm.RegisterRewritePattern(batch_norm_pattern, expanded_batch_norm_pattern);
   unpack_batch_norm.runOnGraph(graph);
+  LOG_GRAPH("Post unpack batchnorm: " << *graph);
 }
 } // Namespace passes
 } // namespace lowering
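`UnpackBatchNorm` expands `aten::batch_norm` into its inference-time arithmetic so each step maps onto an elementwise TensorRT layer. Per element the expansion is the standard formula, sketched here as scalar code:

    #include <cmath>

    // Scalar sketch of what aten::batch_norm computes at inference time;
    // UnpackBatchNorm rewrites the op into these elementwise steps.
    float batch_norm_scalar(float x, float mean, float var,
                            float gamma, float beta, float eps = 1e-5f) {
        return gamma * (x - mean) / std::sqrt(var + eps) + beta;
    }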

core/lowering/passes/expand_log_softmax.cpp renamed to core/lowering/passes/unpack_log_softmax.cpp

Lines changed: 4 additions & 1 deletion
@@ -1,12 +1,14 @@
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/subgraph_rewrite.h"

+#include "core/util/prelude.h"
+
 namespace trtorch {
 namespace core {
 namespace lowering {
 namespace passes {

-void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
+void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
   // Its easier for TensorRT if we seperate softmax and log
   // There might need to be a reshape inserted see:
   // https://github.com/onnx/onnx-tensorrt/blob/5dca8737851118f6ab8a33ea1f7bcb7c9f06caf5/builtin_op_importers.cpp#L1593
@@ -43,6 +45,7 @@ void ExpandLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph) {
   logsoftmax_none_to_softmax_log_none.RegisterRewritePattern(
     logsoftmax_none_pattern, softmax_log_none_pattern);
   logsoftmax_none_to_softmax_log_none.runOnGraph(graph);
+  LOG_GRAPH("Post unpack logsoftmax: " << *graph);
 }

 } // namespace passes
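As the in-source comment says, TensorRT handles a separate softmax followed by a log more readily than a fused `aten::log_softmax`, so the renamed pass simply splits the op: log_softmax(x, dim) becomes log(softmax(x, dim)). A sketch of the pattern pair, exact text assumed:

    // Illustrative patterns for the log_softmax -> softmax + log rewrite.
    std::string logsoftmax_pattern = R"IR(
        graph(%input, %dim, %dtype):
            %out = aten::log_softmax(%input, %dim, %dtype)
            return (%out))IR";
    std::string softmax_log_pattern = R"IR(
        graph(%input, %dim, %dtype):
            %softmax = aten::softmax(%input, %dim, %dtype)
            %out = aten::log(%softmax)
            return (%out))IR";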

core/util/logging/TRTorchLogger.cpp

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ namespace {
 TRTorchLogger& get_global_logger() {
 #ifndef NDEBUG
   static TRTorchLogger global_logger("[TRTorch - Debug Build] - ",
-                                     LogLevel::kGRAPH,
+                                     LogLevel::kDEBUG,
                                      true);
 #else
   static TRTorchLogger global_logger("[TRTorch] - ",

cpp/api/include/trtorch/logging.h

Lines changed: 10 additions & 9 deletions
@@ -9,11 +9,12 @@ namespace logging {
  * Emum for setting message severity
  */
 enum Level {
-  kINTERNAL_ERROR,
-  kERROR,
-  kWARNING,
-  kINFO,
-  kDEBUG,
+  kINTERNAL_ERROR, // Only print messages for internal errors
+  kERROR, // Print all internal errors and errors (default)
+  kWARNING, // Print warnings and errors
+  kINFO, // Print all info, warnings and errors
+  kDEBUG, // Print all debug info, info, warnings and errors
+  kGRAPH, // Print everything including the intermediate graphs of the lowering phase
 };

 // Are these ones necessary for the user?
@@ -35,7 +36,7 @@ TRTORCH_API void set_reportable_log_level(Level lvl);
 TRTORCH_API void set_is_colored_output_on(bool colored_output_on);

 /**
- * @brief Get the current reportable log level
+ * @brief Get the current reportable log level
  */
 TRTORCH_API Level get_reportable_log_level();

@@ -45,10 +46,10 @@ TRTORCH_API Level get_reportable_log_level();
 TRTORCH_API bool get_is_colored_output_on();

 /**
- * @brief Adds a message to the global log
+ * @brief Adds a message to the global log
  *
- * @param lvl: trtorch::logging::Level - Severity of the message
- * @param msg: std::string - Message to be logged
+ * @param lvl: trtorch::logging::Level - Severity of the message
+ * @param msg: std::string - Message to be logged
  */
 // Dont know if we want this?
 TRTORCH_API void log(Level lvl, std::string msg);
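With `kGRAPH` now exposed in the public `Level` enum, users can opt into the intermediate graph dumps added by this commit, for example:

    #include "trtorch/logging.h"

    int main() {
        // Raise verbosity to the new highest level so the lowering passes'
        // LOG_GRAPH output (intermediate graphs) is printed.
        trtorch::logging::set_reportable_log_level(trtorch::logging::Level::kGRAPH);
        // ... compile modules as usual ...
        return 0;
    }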

cpp/api/src/logging.cpp

Lines changed: 7 additions & 2 deletions
@@ -7,7 +7,7 @@ namespace logging {
 std::string get_logging_prefix() {
   return core::util::logging::get_logger().get_logging_prefix();
 }
-
+
 void set_logging_prefix(std::string prefix) {
   core::util::logging::get_logger().set_logging_prefix(prefix);
 }
@@ -27,6 +27,9 @@ void set_reportable_log_level(Level lvl) {
     case Level::kINFO:
       log_lvl = core::util::logging::LogLevel::kINFO;
       break;
+    case Level::kGRAPH:
+      log_lvl = core::util::logging::LogLevel::kGRAPH;
+      break;
     case Level::kDEBUG:
     default:
       log_lvl = core::util::logging::LogLevel::kDEBUG;
@@ -50,12 +53,14 @@ Level get_reportable_log_level() {
       return Level::kWARNING;
     case core::util::logging::LogLevel::kINFO:
       return Level::kINFO;
+    case core::util::logging::LogLevel::kGRAPH:
+      return Level::kGRAPH;
     case core::util::logging::LogLevel::kDEBUG:
     default:
       return Level::kDEBUG;
   }
 }
-
+
 bool get_is_colored_output_on() {
   return core::util::logging::get_logger().get_is_colored_output_on();
 }

cpp/trtorchexec/main.cpp

Lines changed: 4 additions & 0 deletions
@@ -55,11 +55,13 @@ int main(int argc, const char* argv[]) {
     dims.push_back(v);
   }

+  std::cout << "Checking operator support" << std::endl;
   if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {
     std::cerr << "Method is not currently supported by TRTorch" << std::endl;
     return -1;
   }

+  std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
   auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims);
   std::ofstream out("/tmp/engine_converted_from_jit.trt");
   out << engine;
@@ -75,7 +77,9 @@ int main(int argc, const char* argv[]) {
   std::vector<at::Tensor> jit_results;
   jit_results.push_back(jit_results_ivalues.toTensor());

+  std::cout << "Compiling graph as module" << std::endl;
   auto trt_mod = trtorch::CompileGraph(mod, dims);
+  std::cout << "Running TRT module" << std::endl;
   torch::jit::IValue trt_results_ivalues = trt_mod.forward(trt_inputs_ivalues);
   std::vector<at::Tensor> trt_results;
   trt_results.push_back(trt_results_ivalues.toTensor());
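The new progress messages bracket the four stages the tool already runs. Condensed from the surrounding code in this diff (module loading and dims parsing elided):

    std::cout << "Checking operator support" << std::endl;
    if (!trtorch::CheckMethodOperatorSupport(mod, "forward")) {           // 1. support check
      std::cerr << "Method is not currently supported by TRTorch" << std::endl;
      return -1;
    }

    auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", dims); // 2. standalone engine
    std::ofstream out("/tmp/engine_converted_from_jit.trt");
    out << engine;                                                        //    serialized to disk

    auto trt_mod = trtorch::CompileGraph(mod, dims);                      // 3. compiled module
    auto trt_out = trt_mod.forward(trt_inputs_ivalues).toTensor();        // 4. run it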
