Skip to content

Commit c410fe6

Browse files
committed
Tidy ONNX code
Signed-off-by: cyy <cyyever@outlook.com>
1 parent df55ec7 commit c410fe6

File tree

6 files changed: +22 additions, -24 deletions

torch/csrc/jit/passes/onnx/constant_fold.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ static std::optional<at::Tensor> runTorchSlice_opset9(
7676
if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) {
7777
return std::nullopt;
7878
}
79-
auto startsAttr = node->is(attr::starts);
80-
auto endsAttr = node->is(attr::ends);
79+
auto const& startsAttr = node->is(attr::starts);
80+
auto const& endsAttr = node->is(attr::ends);
8181
if (startsAttr.size() != endsAttr.size()) {
8282
return std::nullopt;
8383
}

torch/csrc/jit/passes/onnx/function_extraction.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ void FunctionExtractor::FunctionContext::SetAttrName(
216216
TORCH_INTERNAL_ASSERT(
217217
v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end());
218218
auto* n_in_def = v_it->second->node();
219-
auto n_attr_it = node_attr_to_name_[n_in_def][attr.toUnqualString()] = name;
219+
node_attr_to_name_[n_in_def][attr.toUnqualString()] = name;
220220
}
221221

222222
std::optional<std::string> FunctionExtractor::FunctionContext::FindAttrName(
@@ -405,7 +405,7 @@ std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
405405
auto common_ancestor = FindCommonAncestor(scopes);
406406
if (common_ancestor.has_value() &&
407407
IsValidScope(common_ancestor.value())) {
408-
return common_ancestor.value();
408+
return common_ancestor;
409409
}
410410
}
411411
}

torch/csrc/jit/passes/onnx/peephole.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ static bool isRNN(const Node* node) {
3535
}
3636

3737
static bool isNopTranspose(const std::vector<int64_t>& perm) {
38-
for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
39-
if (perm[i] != i) {
38+
for (size_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
39+
if (perm[i] != static_cast<int64_t>(i)) {
4040
return false;
4141
}
4242
}

torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
#include <c10/util/irange.h>
1212

13-
#include <limits>
14-
1513
namespace torch::jit {
1614

1715
namespace {
@@ -344,7 +342,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
344342
auto it =
345343
std::find(node->inputs().begin(), node->inputs().end(), input);
346344
if (it != node->inputs().end()) {
347-
int index = std::distance(node->inputs().begin(), it);
345+
auto index = std::distance(node->inputs().begin(), it);
348346
TORCH_WARN(
349347
"ONNX Preprocess - Removing mutation from node ",
350348
node->kind().toQualString(),

torch/csrc/jit/passes/onnx/shape_type_inference.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ void UpdateTorchValueByOnnxValueInfo(
191191
}
192192
}
193193

194-
bool IsValidONNXControlflowNode(const Node* n) {
194+
static bool IsValidONNXControlflowNode(const Node* n) {
195195
// Skip when block size is zero. This is when the node is being created,
196196
// and doesn't have subblocks attached yet. Run shape inference for these
197197
// nodes later, when the subgraph has already completed shape inferencing.
@@ -205,7 +205,7 @@ bool IsValidONNXControlflowNode(const Node* n) {
205205
return true;
206206
}
207207

208-
bool IsValidONNXNode(const Node* n) {
208+
static bool IsValidONNXNode(const Node* n) {
209209
auto node_kind = n->kind();
210210

211211
if (!node_kind.is_onnx()) {
@@ -282,7 +282,7 @@ Value* CloneValueFromListConstruct(
282282
auto input = n_graph->addInput();
283283
if (scalar_type) {
284284
auto v_type = TensorType::create(
285-
scalar_type.value(),
285+
scalar_type,
286286
at::kCPU,
287287
c10::SymbolicShape(),
288288
c10::VaryingShape<c10::Stride>{},
@@ -1244,7 +1244,7 @@ void ProcessUnsqueezeNode(Node* n) {
12441244
void ComputeConstant(Node* n, int opset_version) {
12451245
if (n->kind() == ::c10::onnx::Constant) {
12461246
if (n->kindOf(attr::value) == AttributeKind::t) {
1247-
at::Tensor const_val = n->t(attr::value);
1247+
const at::Tensor& const_val = n->t(attr::value);
12481248
at::Tensor const_val_copy =
12491249
at::empty(const_val.sizes(), const_val.options());
12501250
const_val_copy.copy_(const_val);
@@ -1381,7 +1381,7 @@ void ComputeConstant(Node* n, int opset_version) {
13811381
.value()
13821382
.sizes();
13831383
if (input0_shape_size.has_value()) {
1384-
auto input0_shape_value = input0_shape_size.value();
1384+
const auto& input0_shape_value = input0_shape_size.value();
13851385
if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
13861386
// When value of `shape` is statically known,
13871387
// output shape can be computed.
@@ -1474,7 +1474,7 @@ void ComputeConstant(Node* n, int opset_version) {
14741474
.value()
14751475
.sizes();
14761476
if (input0_shape_size.has_value()) {
1477-
auto input0_shape_value = input0_shape_size.value();
1477+
const auto& input0_shape_value = input0_shape_size.value();
14781478
int64_t total_size = 1;
14791479
auto is_full_static = true;
14801480
for (const auto i : c10::irange(input0_shape_value.size())) {
@@ -1510,7 +1510,7 @@ void ComputeConstant(Node* n, int opset_version) {
15101510
.value()
15111511
.sizes();
15121512
if (input0_shape_size.has_value()) {
1513-
auto input0_shape_value = input0_shape_size.value();
1513+
const auto& input0_shape_value = input0_shape_size.value();
15141514
if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
15151515
auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
15161516
n->input(1)->debugName());
@@ -1659,10 +1659,10 @@ void SpecialPostProcess(Node* n) {
16591659
};
16601660

16611661
auto find_sequence_empty = [](Value* input,
1662-
TensorTypePtr t_type) -> Node* {
1662+
const TensorTypePtr& t_type) -> Node* {
16631663
auto find_sequence_empty_impl =
16641664
[](Value* input,
1665-
TensorTypePtr t_type,
1665+
const TensorTypePtr& t_type,
16661666
auto& find_sequence_empty_ref) -> Node* {
16671667
auto input_node = input->node();
16681668
TORCH_INTERNAL_ASSERT(input_node);
@@ -1708,7 +1708,7 @@ void SpecialPostProcess(Node* n) {
17081708
return nullptr;
17091709
};
17101710
return find_sequence_empty_impl(
1711-
input, std::move(t_type), find_sequence_empty_impl);
1711+
input, t_type, find_sequence_empty_impl);
17121712
};
17131713

17141714
if (seq_node && t_type && t_type->scalarType()) {
@@ -1837,7 +1837,7 @@ void FetchBlockInputMetadataFromParent(Block* b) {
18371837
}
18381838
}
18391839

1840-
void RemoveProcessedInputs(const Node* n) {
1840+
static void RemoveProcessedInputs(const Node* n) {
18411841
// After processing a node for shape inference, remove intermediate tensors
18421842
// that are stored in ConstantValueMap to reduce memory usage.
18431843
// This will only remove tensors that are no longer needed by any other node.
@@ -2213,7 +2213,7 @@ void ONNXSetDynamicInputShape(
22132213
GRAPH_UPDATE("dynamic axes tensor names:", [&]() {
22142214
std::vector<std::string> res(dynamic_axes.size());
22152215
std::transform(
2216-
dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
2216+
dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](const auto& pair) {
22172217
return pair.first;
22182218
});
22192219
return res;

torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ using namespace ::c10::onnx;
2727
// 2. For ops with no output scale in op signature (like quantized::relu)
2828
// we traverse up the graph to get the scale from its input until we hit a node
2929
// where scale is explicitly specified.
30-
double getScaleFromInput(Node* input_node) {
30+
static double getScaleFromInput(Node* input_node) {
3131
std::optional<IValue> scale;
3232
std::string input_name = input_node->kind().toQualString();
3333
std::unordered_set<std::string> noscale_ops = {
@@ -315,7 +315,7 @@ static void unpackQuantizedWeightsHelper(
315315
auto config_vals = elements[1].to<std::vector<int64_t>>();
316316
auto tensors = elements[2].to<std::vector<std::optional<at::Tensor>>>();
317317

318-
std::optional<at::Tensor> weight = tensors[1];
318+
const std::optional<at::Tensor>& weight = tensors[1];
319319
TORCH_INTERNAL_ASSERT(
320320
weight, "Weight should always be present in serialized qconv.");
321321
unpacked_weight = *weight;
@@ -373,7 +373,7 @@ static void unpackQuantizedWeightsHelper(
373373
TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version");
374374
std::vector<at::Tensor> non_optional = elements[1].toTensorVector();
375375

376-
at::Tensor conv_params_packed = non_optional[0];
376+
const at::Tensor& conv_params_packed = non_optional[0];
377377
unpacked_weight = non_optional[1];
378378

379379
const int64_t kSpatialDim = conv_params_packed[0].item<int64_t>();

0 commit comments

Comments (0)