Skip to content

feat: support prim::If in automatic fallback #447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 28, 2021
Merged

feat: support prim::If in automatic fallback #447

merged 12 commits into from
Jul 28, 2021

Conversation

bowang007
Copy link
Collaborator

@bowang007 bowang007 commented May 4, 2021

Description

Support prim::If in automatic fallback.
Previously, all the prim::If operators will be fallback to PyTorch nodes when automatic fallback feature is enabled. This could be optimized since some operators in prim::If block could also be converted to TensorRT engines.

Type of change

  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes

@bowang007 bowang007 requested review from narendasan and peri044 May 4, 2021 23:03
@github-actions github-actions bot added the component: core Issues re: The core compiler label May 4, 2021
Signed-off-by: Bo Wang <[email protected]>
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/compiler.cpp b/tmp/changes.txt
index dfcd69d..2f2f202 100644
--- a/workspace/core/compiler.cpp
+++ b/tmp/changes.txt
@@ -203,8 +203,8 @@ void AddIfBlockToGraph(
    auto cur_block_mapping = graph_and_mapping.second;
    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then it's
-      // mini graph's input
+      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
+      // it's mini graph's input
      if (old_to_new_g.count(i.first)) {
        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
      }
@@ -317,8 +317,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
        input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
      }
      auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
-      auto graph_and_mapping =
-          ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
+      auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
      new_g = graph_and_mapping.first;
      LOG_INFO(*new_g << "(FallbackGraph)\n");

diff --git a/workspace/core/partitioning/partitioning.cpp b/tmp/changes.txt
index 692878a..6590cd5 100644
--- a/workspace/core/partitioning/partitioning.cpp
+++ b/tmp/changes.txt
@@ -2,8 +2,8 @@

#include <queue>
#include "core/conversion/conversion.h"
-#include "torch/csrc/jit/passes/constant_pooling.h"
#include "core/partitioning/shape_analysis.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"

namespace trtorch {
@@ -150,7 +150,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
      if (!updated_segments.count(first_torch_id)) {
        auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
        segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
-        segmented_blocks.insert(segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
+        segmented_blocks.insert(
+            segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
        updated_segments.insert(first_torch_id);
      }
    } else {
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/compiler.cpp b/tmp/changes.txt
index dfcd69d..2f2f202 100644
--- a/workspace/core/compiler.cpp
+++ b/tmp/changes.txt
@@ -203,8 +203,8 @@ void AddIfBlockToGraph(
    auto cur_block_mapping = graph_and_mapping.second;
    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then it's
-      // mini graph's input
+      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
+      // it's mini graph's input
      if (old_to_new_g.count(i.first)) {
        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
      }
@@ -317,8 +317,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
        input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
      }
      auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
-      auto graph_and_mapping =
-          ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
+      auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
      new_g = graph_and_mapping.first;
      LOG_INFO(*new_g << "(FallbackGraph)\n");

diff --git a/workspace/core/partitioning/partitioning.cpp b/tmp/changes.txt
index 692878a..6590cd5 100644
--- a/workspace/core/partitioning/partitioning.cpp
+++ b/tmp/changes.txt
@@ -2,8 +2,8 @@

#include <queue>
#include "core/conversion/conversion.h"
-#include "torch/csrc/jit/passes/constant_pooling.h"
#include "core/partitioning/shape_analysis.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"

namespace trtorch {
@@ -150,7 +150,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
      if (!updated_segments.count(first_torch_id)) {
        auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
        segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
-        segmented_blocks.insert(segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
+        segmented_blocks.insert(
+            segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
        updated_segments.insert(first_torch_id);
      }
    } else {
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/compiler.cpp b/tmp/changes.txt
index dfcd69d..2f2f202 100644
--- a/workspace/core/compiler.cpp
+++ b/tmp/changes.txt
@@ -203,8 +203,8 @@ void AddIfBlockToGraph(
    auto cur_block_mapping = graph_and_mapping.second;
    std::unordered_map<torch::jit::Value*, torch::jit::Value*> block_graph_to_new_g;
    for (auto& i : cur_block_mapping) {
-      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then it's
-      // mini graph's input
+      // for every pair in then_mapping, old_value => mini graph value, if old_value also appears in old_to_new_g, then
+      // it's mini graph's input
      if (old_to_new_g.count(i.first)) {
        block_graph_to_new_g[i.second] = old_to_new_g[i.first];
      }
@@ -317,8 +317,7 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
        input_ranges.insert({g->inputs()[i], cfg.convert_info.input_ranges[i]});
      }
      auto input_ivalues_map = partitioning::generateRandomInputs(input_ranges);
-      auto graph_and_mapping =
-          ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
+      auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, named_params);
      new_g = graph_and_mapping.first;
      LOG_INFO(*new_g << "(FallbackGraph)\n");

diff --git a/workspace/core/partitioning/partitioning.cpp b/tmp/changes.txt
index 692878a..6590cd5 100644
--- a/workspace/core/partitioning/partitioning.cpp
+++ b/tmp/changes.txt
@@ -2,8 +2,8 @@

#include <queue>
#include "core/conversion/conversion.h"
-#include "torch/csrc/jit/passes/constant_pooling.h"
#include "core/partitioning/shape_analysis.h"
+#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"

namespace trtorch {
@@ -150,7 +150,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) {
      if (!updated_segments.count(first_torch_id)) {
        auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]);
        segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
-        segmented_blocks.insert(segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
+        segmented_blocks.insert(
+            segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
        updated_segments.insert(first_torch_id);
      }
    } else {
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

@github-actions github-actions bot added the component: tests Issues re: Tests label May 19, 2021
Signed-off-by: Dheeraj Peri <[email protected]>
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/setup.py
--- /workspace/tests/modules/hub.py	(original)
+++ /workspace/tests/modules/hub.py	(reformatted)
@@ -99,6 +99,7 @@

# Sample Conditional Model (for testing partitioning and fallback in conditionals)
class FallbackIf(torch.nn.Module):
+
    def __init__(self):
        super(FallbackIf, self).__init__()
        self.relu1 = torch.nn.ReLU()
@@ -120,6 +121,7 @@
        x = self.conv1(x)
        return x

+
conditional_model = FallbackIf().eval().cuda()
conditional_script_model = torch.jit.script(conditional_model)
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

@peri044
Copy link
Collaborator

peri044 commented Jul 21, 2021

@narendasan this is ready for your review

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/tests/modules/hub.py	(original)
+++ /workspace/tests/modules/hub.py	(reformatted)
@@ -108,6 +108,7 @@

# Sample Conditional Model (for testing partitioning and fallback in conditionals)
class FallbackIf(torch.nn.Module):
+
    def __init__(self):
        super(FallbackIf, self).__init__()
        self.relu1 = torch.nn.ReLU()
@@ -129,6 +130,7 @@
        x = self.conv1(x)
        return x

+
conditional_model = FallbackIf().eval().cuda()
conditional_script_model = torch.jit.script(conditional_model)
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Signed-off-by: Dheeraj Peri <[email protected]>
Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/tests/modules/hub.py	(original)
+++ /workspace/tests/modules/hub.py	(reformatted)
@@ -108,6 +108,7 @@

# Sample Conditional Model (for testing partitioning and fallback in conditionals)
class FallbackIf(torch.nn.Module):
+
    def __init__(self):
        super(FallbackIf, self).__init__()
        self.relu1 = torch.nn.ReLU()
@@ -129,6 +130,7 @@
        x = self.conv1(x)
        return x

+
conditional_model = FallbackIf().eval().cuda()
conditional_script_model = torch.jit.script(conditional_model)
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_api.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to C++ style guidelines:

diff --git a/workspace/core/compiler.cpp b/tmp/changes.txt
index 4dea4a7..9a63c9c 100644
--- a/workspace/core/compiler.cpp
+++ b/tmp/changes.txt
@@ -332,46 +332,46 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
        return mod;
      }

-// <<<<<<< HEAD
-// =======
-//       std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
-//       // add global graph's input to old_to_new_g mapping
-//       for (auto input : g->inputs()) {
-//         util::getOrAddInputForValue(input, new_g, old_to_new_g);
-//       }
-//       for (auto& seg_block : segmented_blocks) {
-//         std::string cur_block_target =
-//             seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
-//         LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
-//         std::ostringstream trt_engine_id;
-//         trt_engine_id << reinterpret_cast<const int*>(&seg_block);
-//         if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
-//           std::vector<ir::Input> inputs;
-//           for (auto& shape : seg_block.in_shape()) {
-//             inputs.push_back(ir::Input(shape));
-//           }
-//           // update the input ranges for each segments
-//           convert_cfg.inputs = inputs;
-//           auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
-//           auto temp_g = std::make_shared<torch::jit::Graph>();
-//           auto device_spec = convert_cfg.engine_settings.device;
-//           auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
-//           AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
-//
-//           seg_block.update_graph(temp_g);
-//           AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-//         } else {
-//           AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
-//         }
-//       }
-//
-//       for (auto& output : g->outputs()) {
-//         new_g->registerOutput(old_to_new_g[output]);
-//       }
-//
-//       LOG_INFO(*new_g << "(FallbackGraph)\n");
-//
-// >>>>>>> master
+      // <<<<<<< HEAD
+      // =======
+      //       std::unordered_map<torch::jit::Value*, torch::jit::Value*> old_to_new_g;
+      //       // add global graph's input to old_to_new_g mapping
+      //       for (auto input : g->inputs()) {
+      //         util::getOrAddInputForValue(input, new_g, old_to_new_g);
+      //       }
+      //       for (auto& seg_block : segmented_blocks) {
+      //         std::string cur_block_target =
+      //             seg_block.target() == partitioning::SegmentedBlock::kTensorRT ? "TensorRT" : "Torch";
+      //         LOG_INFO(*seg_block.g() << "(Sub Graph" << cur_block_target << "Block)\n");
+      //         std::ostringstream trt_engine_id;
+      //         trt_engine_id << reinterpret_cast<const int*>(&seg_block);
+      //         if (seg_block.target() == partitioning::SegmentedBlock::kTensorRT) {
+      //           std::vector<ir::Input> inputs;
+      //           for (auto& shape : seg_block.in_shape()) {
+      //             inputs.push_back(ir::Input(shape));
+      //           }
+      //           // update the input ranges for each segments
+      //           convert_cfg.inputs = inputs;
+      //           auto engine = conversion::ConvertBlockToEngine(seg_block.block(), convert_cfg, named_params);
+      //           auto temp_g = std::make_shared<torch::jit::Graph>();
+      //           auto device_spec = convert_cfg.engine_settings.device;
+      //           auto cuda_device = runtime::CudaDevice(device_spec.gpu_id, device_spec.device_type);
+      //           AddEngineToGraph(new_mod, temp_g, engine, cuda_device, trt_engine_id.str(), true);
+      //
+      //           seg_block.update_graph(temp_g);
+      //           AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+      //         } else {
+      //           AddSegmentedBlockToGraph(new_g, seg_block, old_to_new_g);
+      //         }
+      //       }
+      //
+      //       for (auto& output : g->outputs()) {
+      //         new_g->registerOutput(old_to_new_g[output]);
+      //       }
+      //
+      //       LOG_INFO(*new_g << "(FallbackGraph)\n");
+      //
+      // >>>>>>> master
      auto new_method = new_mod._ivalue()->compilation_unit()->create_function(method.name(), new_g);
      auto schema = util::GenerateGraphSchema(new_method->name(), new_g);
      new_mod.type()->addMethod(new_method);
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some changes that do not conform to Python style guidelines:

Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/tests/modules/hub.py	(original)
+++ /workspace/tests/modules/hub.py	(reformatted)
@@ -100,6 +100,7 @@

# Sample Conditional Model (for testing partitioning and fallback in conditionals)
class FallbackIf(torch.nn.Module):
+
    def __init__(self):
        super(FallbackIf, self).__init__()
        self.relu1 = torch.nn.ReLU()
@@ -121,6 +122,7 @@
        x = self.conv1(x)
        return x

+
conditional_model = FallbackIf().eval().cuda()
conditional_script_model = torch.jit.script(conditional_model)
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
Reformatting /workspace/tests/modules/hub.py
Reformatting /workspace/tests/py/test_to_backend_api.py
Reformatting /workspace/tests/py/test_api.py
Reformatting /workspace/tests/py/test_api_dla.py
Reformatting /workspace/tests/py/test_trt_intercompatability.py
Reformatting /workspace/tests/py/model_test_case.py
Reformatting /workspace/tests/py/test_multi_gpu.py
Reformatting /workspace/tests/py/test_ptq_to_backend.py
Reformatting /workspace/tests/py/test_ptq_trt_calibrator.py
Reformatting /workspace/tests/py/test_ptq_dataloader_calibrator.py
ERROR: Some files do not conform to style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
Copy link
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to Python style guidelines

Copy link

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code conforms to C++ style guidelines

@narendasan narendasan merged commit 114969b into master Jul 28, 2021
@narendasan narendasan deleted the bowa_primif branch July 28, 2021 16:20
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
component: core Issues re: The core compiler component: tests Issues re: Tests
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants