#include "torch/csrc/jit/ir/constants.h"
#include "torch/csrc/jit/passes/subgraph_rewrite.h"

#include "core/util/prelude.h"

#include <vector>

namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// Presumably this is safe, since torch::jit::EraseNumberTypesOnBlock exists and simply
// removes prim::TensorToNum, aten::Float, aten::Int and prim::NumToTensor nodes outright
void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph) {
  std::string int_cast_pattern = R"IR(
    graph(%1: int):
      %2: Tensor = aten::NumToTensor(%1)
      %3: int = aten::Int(%2)
      return (%3))IR";
  std::string int_clean_pattern = R"IR(
    graph(%1: int):
      return (%1))IR";

  std::string float_cast_pattern = R"IR(
    graph(%1: float):
      %2: Tensor = aten::NumToTensor(%1)
      %3: float = aten::Float(%2)
      return (%3))IR";
  std::string float_clean_pattern = R"IR(
    graph(%1: float):
      return (%1))IR";

  std::string bool_cast_pattern = R"IR(
    graph(%1: bool):
      %2: Tensor = aten::NumToTensor(%1)
      %3: bool = aten::Bool(%2)
      return (%3))IR";
  std::string bool_clean_pattern = R"IR(
    graph(%1: bool):
      return (%1))IR";

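  // Run one rewriter per scalar type; each replaces a matched
  // NumToTensor -> aten::Int/Float/Bool round trip with the incoming scalar value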
  torch::jit::SubgraphRewriter int_cast_rewriter;
  int_cast_rewriter.RegisterRewritePattern(int_cast_pattern, int_clean_pattern);
  int_cast_rewriter.runOnGraph(graph);

  torch::jit::SubgraphRewriter float_cast_rewriter;
  float_cast_rewriter.RegisterRewritePattern(float_cast_pattern, float_clean_pattern);
  float_cast_rewriter.runOnGraph(graph);

  torch::jit::SubgraphRewriter bool_cast_rewriter;
  bool_cast_rewriter.RegisterRewritePattern(bool_cast_pattern, bool_clean_pattern);
  bool_cast_rewriter.runOnGraph(graph);

  LOG_GRAPH("After RemoveUnnecessaryCasts: " << *graph);
}

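// Looks for a single-use 0-dim Tensor constant that feeds an intermediate op alongside a
// prim::NumToTensor value, with the result consumed by aten::Int or aten::Float, e.g.
//   %c : Tensor = prim::Constant()          (0-dim, single use)
//   %t : Tensor = prim::NumToTensor(%x)
//   %r : Tensor = aten::add(%t, %c, %alpha)
//   %s : int = aten::Int(%r)
// and rewrites the chain so the constant, NumToTensor and cast nodes are removed and the
// intermediate op operates directly on scalars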
void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g) {
  for (auto it = g->block()->nodes().begin(), end = g->block()->nodes().end(); it != end; ++it) {
    if (it->kind() == torch::jit::prim::Constant) {
      // A constant with a single use can safely be folded into its user
      if (it->output()->type()->isSubtypeOf(c10::TensorType::get())) {
        // Get the tensor stored in the constant
        at::Tensor t = *torch::jit::constant_as<at::Tensor>(it->output());
        // If the shape is 0D
        if (t.sizes() == std::vector<int64_t>({})) {
          LOG_GRAPH("Found a 0D Tensor: " << it->output()->debugName());
          LOG_GRAPH("Number of uses: " << it->output()->uses().size());
          // If the tensor is only used once
          if (it->output()->uses().size() == 1) {
            auto use = it->output()->uses()[0];
            auto user = use.user;

            // Is a NumToTensor / aten::[Int/Float] case
            if (user->outputs().size() == 1 && user->outputs()[0]->type()->isSubtypeOf(c10::TensorType::get())) {
              if (user->output()->uses().size() == 1) {
                auto potential_cast = user->output()->uses()[0].user;
                // The downstream user is aten::Int or aten::Float
                if (potential_cast->kind() == c10::Symbol::fromQualString("aten::Int") ||
                    potential_cast->kind() == c10::Symbol::fromQualString("aten::Float")) {
                  LOG_GRAPH("Downstream user is aten::Int/aten::Float");
                  auto arg = use.offset;

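                  // Scan the intermediate op's other inputs for a Tensor produced by prim::NumToTensor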
                  for (size_t k = 0; k < user->inputs().size(); ++k) {
                    if (k != arg) {
                      if (user->inputs()[k]->type()->isSubtypeOf(c10::TensorType::get())) {
                        LOG_GRAPH("Input " << k << " is a Tensor");
                        if (user->inputs()[k]->node()->kind() == c10::Symbol::fromQualString("prim::NumToTensor")) {
                          auto num_to_tensor = user->inputs()[k]->node();

                          LOG_GRAPH(
                              "Found a prim::NumToTensor / aten::[Int/Float] pair with an intermediate operation:\n "
                              << *(*it) << *num_to_tensor << *user << *potential_cast);

                          // Replace the Tensor Constant with a scalar constant
                          LOG_GRAPH("Deleting 0-dim Tensor: " << **it);
                          torch::jit::WithInsertPoint guard(*it);

                          auto new_const_val = g->insertConstant(t.item(), c10::nullopt, it->scope());
                          new_const_val->copyMetadata(it->output());
                          // How to determine the internal scalar type instead of assuming?
                          if (potential_cast->kind() == c10::aten::Int) {
                            new_const_val->setType(c10::IntType::get());
                          } else if (potential_cast->kind() == c10::aten::Float) {
                            new_const_val->setType(c10::FloatType::get());
                          }
                          it->output()->replaceAllUsesWith(new_const_val);
                          it.destroyCurrent();

                          LOG_GRAPH("New constant: " << *new_const_val->node());

                          // Delete NumToTensor
                          LOG_GRAPH("Deleting NumToTensor: " << *num_to_tensor);
                          num_to_tensor->output()->replaceAllUsesWith(num_to_tensor->inputs()[0]);
                          num_to_tensor->destroy();

                          // Change intermediate op output type
                          LOG_GRAPH(user->schema());

                          torch::jit::Node* new_node;
                          switch (user->kind()) {
                            // Use this to handle special cases where the scalar version of the intermediate operator
                            // has a different schema than the original
                            case c10::aten::add:
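                              // The scalar overloads of aten::add take only two operands, so rebuild the node
                              // with just the first two inputs, dropping the Tensor overload's alpha argument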
                              new_node = g->create(
                                  user->kind(),
                                  torch::jit::ArrayRef<torch::jit::Value*>({user->inputs()[0], user->inputs()[1]}),
                                  1);
                              new_node->insertAfter(user);
                              new_node->outputs()[0]->setType(c10::IntType::get());
                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
                              user->destroy();
                              break;
                            default:
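                              // NOTE: the rebuilt node's output is assumed to be an int here, mirroring the
                              // scalar-type question above; an aten::Float downstream cast would imply FloatType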
                              new_node = g->create(user->kind(), user->inputs(), 1);
                              new_node->insertAfter(user);
                              new_node->outputs()[0]->setType(c10::IntType::get());
                              user->outputs()[0]->replaceAllUsesWith(new_node->outputs()[0]);
                              user->destroy();
                              break;
                          }

                          LOG_GRAPH("New intermediate operation: " << *new_node);
                          LOG_GRAPH(new_node->schema());

                          // Delete aten::Int / aten::Float
                          LOG_GRAPH("Deleting aten::[Int/Float]: " << *potential_cast);
                          potential_cast->output()->replaceAllUsesWith(potential_cast->inputs()[0]);
                          potential_cast->destroy();
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
  LOG_GRAPH("Post removing single use 0-dim Tensor operations: " << *g);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt