Tidy torch/csrc/jit/passes/onnx code #160262

Closed · wants to merge 5 commits
Changes from 1 commit
Tidy ONNX code
Signed-off-by: cyy <cyyever@outlook.com>
cyyever committed Aug 11, 2025
commit 33da69394fcd15669daef189edfad77c20e8bab0
4 changes: 2 additions & 2 deletions torch/csrc/jit/passes/onnx/constant_fold.cpp
@@ -76,8 +76,8 @@ static std::optional<at::Tensor> runTorchSlice_opset9(
   if (!(node->hasAttributeS("starts") && node->hasAttributeS("ends"))) {
     return std::nullopt;
   }
-  auto startsAttr = node->is(attr::starts);
-  auto endsAttr = node->is(attr::ends);
+  auto const& startsAttr = node->is(attr::starts);
+  auto const& endsAttr = node->is(attr::ends);
   if (startsAttr.size() != endsAttr.size()) {
     return std::nullopt;
   }
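
Why the `const&` here: `Node::is(attr::...)` returns the stored attribute vector by reference, so plain `auto` deduced `std::vector<int64_t>` and deep-copied the attribute on every call, while `auto const&` just aliases the stored data. A minimal sketch, with `FakeNode` as a hypothetical stand-in for `Node`:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for torch::jit::Node's attribute accessor is(...).
struct FakeNode {
  std::vector<int64_t> starts{0, 1, 2};
  const std::vector<int64_t>& is_starts() const { return starts; }
};

int main() {
  FakeNode node;
  auto copied = node.is_starts();         // deduces std::vector<int64_t>: deep copy
  const auto& aliased = node.is_starts(); // binds to the stored vector: no copy
  return copied == aliased ? 0 : 1;       // both see the same elements
}
```
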
4 changes: 2 additions & 2 deletions torch/csrc/jit/passes/onnx/function_extraction.cpp
@@ -216,7 +216,7 @@ void FunctionExtractor::FunctionContext::SetAttrName(
   TORCH_INTERNAL_ASSERT(
       v_it != scope_ctxs_[scope_key_]->env_to_subgraph_.end());
   auto* n_in_def = v_it->second->node();
-  auto n_attr_it = node_attr_to_name_[n_in_def][attr.toUnqualString()] = name;
+  node_attr_to_name_[n_in_def][attr.toUnqualString()] = name;
 }
 
 std::optional<std::string> FunctionExtractor::FunctionContext::FindAttrName(
@@ -405,7 +405,7 @@ std::optional<ScopePtr> FunctionExtractor::InferScope(Node* n) {
       auto common_ancestor = FindCommonAncestor(scopes);
       if (common_ancestor.has_value() &&
           IsValidScope(common_ancestor.value())) {
-        return common_ancestor.value();
+        return common_ancestor;
       }
     }
   }
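
The tidied `return` hands the `std::optional` back as-is instead of unwrapping it with `.value()` only for the compiler to wrap the result into a fresh optional at the return boundary. A standalone sketch of the same shape (`findName` is illustrative, not code from this PR):

```cpp
#include <optional>
#include <string>

// Returning the optional itself avoids unwrap-then-rewrap at the boundary.
std::optional<std::string> findName(bool found) {
  std::optional<std::string> candidate;
  if (found) {
    candidate = "common_ancestor";
  }
  if (candidate.has_value()) {
    return candidate;  // was: return candidate.value();
  }
  return std::nullopt;
}

int main() {
  return findName(true).has_value() ? 0 : 1;
}
```
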
4 changes: 2 additions & 2 deletions torch/csrc/jit/passes/onnx/peephole.cpp
@@ -35,8 +35,8 @@ static bool isRNN(const Node* node) {
 }
 
 static bool isNopTranspose(const std::vector<int64_t>& perm) {
-  for (int64_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
-    if (perm[i] != i) {
+  for (size_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
+    if (perm[i] != static_cast<int64_t>(i)) {
       return false;
     }
   }
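
`perm.size()` is a `size_t`, so the old loop initialized an `int64_t` counter from an unsigned value and then compared the `int64_t` element against it, the classic mixed-signedness pattern that trips `-Wsign-compare`. The tidied loop keeps the index unsigned and casts once at the comparison. A quick standalone check of the predicate as rewritten:

```cpp
#include <cstdint>
#include <vector>

static bool isNopTranspose(const std::vector<int64_t>& perm) {
  for (size_t i = 0, perm_size = perm.size(); i < perm_size; i++) {
    if (perm[i] != static_cast<int64_t>(i)) {
      return false;
    }
  }
  return true;
}

int main() {
  // {0, 1, 2} leaves every axis in place, so the transpose is a no-op.
  bool nop = isNopTranspose({0, 1, 2});   // true
  bool real = isNopTranspose({2, 0, 1});  // false: axes actually move
  return (nop && !real) ? 0 : 1;
}
```
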
@@ -10,8 +10,6 @@

 #include <c10/util/irange.h>
 
-#include <limits>
-
 namespace torch::jit {
 
 namespace {
@@ -344,7 +342,7 @@ static void PrepareForRemoveMutations(MutationRemover& mr, Block* b) {
         auto it =
             std::find(node->inputs().begin(), node->inputs().end(), input);
         if (it != node->inputs().end()) {
-          int index = std::distance(node->inputs().begin(), it);
+          auto index = std::distance(node->inputs().begin(), it);
           TORCH_WARN(
               "ONNX Preprocess - Removing mutation from node ",
               node->kind().toQualString(),
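
`std::distance` returns the iterator's `difference_type` (`std::ptrdiff_t` here, signed 64-bit on typical targets), so storing it in `int` was a potential narrowing; `auto` preserves the exact type. A small sketch:

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <type_traits>
#include <vector>

int main() {
  std::vector<int> inputs{10, 20, 30};
  auto it = std::find(inputs.begin(), inputs.end(), 20);
  // std::distance returns the iterator's difference_type; `auto` keeps it
  // instead of silently narrowing to int.
  auto index = std::distance(inputs.begin(), it);
  static_assert(std::is_same_v<decltype(index), std::ptrdiff_t>);
  return index == 1 ? 0 : 1;
}
```
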
24 changes: 12 additions & 12 deletions torch/csrc/jit/passes/onnx/shape_type_inference.cpp
@@ -191,7 +191,7 @@ void UpdateTorchValueByOnnxValueInfo(
   }
 }
 
-bool IsValidONNXControlflowNode(const Node* n) {
+static bool IsValidONNXControlflowNode(const Node* n) {
   // Skip when block size is zero. This is when the node is being created,
   // and doesn't have subblocks attached yet. Run shape inference for these
   // nodes later, when the subgraph has already completed shape inferencing.
@@ -205,7 +205,7 @@ bool IsValidONNXControlflowNode(const Node* n) {
   return true;
 }
 
-bool IsValidONNXNode(const Node* n) {
+static bool IsValidONNXNode(const Node* n) {
   auto node_kind = n->kind();
 
   if (!node_kind.is_onnx()) {
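
Several free functions in this file gain `static`. In a .cpp file, `static` gives a function internal linkage: the symbol stays private to its translation unit, cannot collide with a same-named function elsewhere, and the compiler can warn when it is unused or inline it more aggressively. A toy illustration (the function below is a hypothetical stub, not the real `IsValidONNXNode`):

```cpp
// With `static`, this definition is private to its translation unit; another
// .cpp file may define its own `isValid` without a one-definition-rule clash
// at link time.
static bool isValid(int block_count) {
  return block_count > 0;
}

int main() {
  return isValid(2) ? 0 : 1;
}
```
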
@@ -282,7 +282,7 @@ Value* CloneValueFromListConstruct(
     auto input = n_graph->addInput();
     if (scalar_type) {
       auto v_type = TensorType::create(
-          scalar_type.value(),
+          scalar_type,
           at::kCPU,
           c10::SymbolicShape(),
           c10::VaryingShape<c10::Stride>{},
@@ -1244,7 +1244,7 @@ void ProcessUnsqueezeNode(Node* n) {
 void ComputeConstant(Node* n, int opset_version) {
   if (n->kind() == ::c10::onnx::Constant) {
     if (n->kindOf(attr::value) == AttributeKind::t) {
-      at::Tensor const_val = n->t(attr::value);
+      const at::Tensor& const_val = n->t(attr::value);
       at::Tensor const_val_copy =
           at::empty(const_val.sizes(), const_val.options());
       const_val_copy.copy_(const_val);
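
`at::Tensor` is a shared handle: copy-constructing one does not copy data, but it does bump an atomic refcount, and binding `const at::Tensor&` skips even that. The explicit `const_val_copy` below it stays, since a deep copy is genuinely wanted there. A sketch of the handle semantics, using `std::shared_ptr` as an analogy rather than real ATen code:

```cpp
#include <memory>
#include <vector>

// Stand-in for tensor storage; at::Tensor wraps something like this.
struct Storage {
  std::vector<float> data;
};
using Handle = std::shared_ptr<Storage>;

int main() {
  Handle t = std::make_shared<Storage>(Storage{{1.0f, 2.0f}});
  const Handle& alias = t; // no refcount traffic: use_count stays 1
  Handle copy = t;         // handle copy: atomic increment, use_count 2
  return (alias->data.size() == copy->data.size()) ? 0 : 1;
}
```
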
@@ -1381,7 +1381,7 @@ void ComputeConstant(Node* n, int opset_version) {
               .value()
               .sizes();
       if (input0_shape_size.has_value()) {
-        auto input0_shape_value = input0_shape_size.value();
+        const auto& input0_shape_value = input0_shape_size.value();
         if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
           // When value of `shape` is statically known,
           // output shape can be computed.
@@ -1474,7 +1474,7 @@ void ComputeConstant(Node* n, int opset_version) {
               .value()
               .sizes();
       if (input0_shape_size.has_value()) {
-        auto input0_shape_value = input0_shape_size.value();
+        const auto& input0_shape_value = input0_shape_size.value();
         int64_t total_size = 1;
         auto is_full_static = true;
         for (const auto i : c10::irange(input0_shape_value.size())) {
@@ -1510,7 +1510,7 @@ void ComputeConstant(Node* n, int opset_version) {
               .value()
               .sizes();
       if (input0_shape_size.has_value()) {
-        auto input0_shape_value = input0_shape_size.value();
+        const auto& input0_shape_value = input0_shape_size.value();
         if (ConstantValueMap::HasValue(n->input(1)->debugName())) {
           auto shape_temp = ConstantValueMap::GetValueInto1DInt64Vector(
               n->input(1)->debugName());
@@ -1659,10 +1659,10 @@ void SpecialPostProcess(Node* n) {
   };
 
   auto find_sequence_empty = [](Value* input,
-                                TensorTypePtr t_type) -> Node* {
+                                const TensorTypePtr& t_type) -> Node* {
     auto find_sequence_empty_impl =
         [](Value* input,
-           TensorTypePtr t_type,
+           const TensorTypePtr& t_type,
            auto& find_sequence_empty_ref) -> Node* {
       auto input_node = input->node();
       TORCH_INTERNAL_ASSERT(input_node);
@@ -1708,7 +1708,7 @@
       return nullptr;
     };
     return find_sequence_empty_impl(
-        input, std::move(t_type), find_sequence_empty_impl);
+        input, t_type, find_sequence_empty_impl);
   };
 
   if (seq_node && t_type && t_type->scalarType()) {
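
With the parameter now `const TensorTypePtr&`, the old `std::move(t_type)` had to go: moving from a const lvalue yields `const T&&`, which overload resolution routes to the copy constructor anyway, so the cast was misleading rather than an optimization. A minimal demonstration with a `shared_ptr` (`TensorTypePtr` is a `std::shared_ptr` under the hood):

```cpp
#include <memory>
#include <utility>

int main() {
  const std::shared_ptr<int> t = std::make_shared<int>(7);
  // std::move on a const lvalue produces `const T&&`, which cannot bind to
  // the move constructor; the copy constructor runs instead.
  std::shared_ptr<int> p = std::move(t);
  // `t` was copied, not moved from: it still owns the object.
  return (t != nullptr && p.use_count() == 2) ? 0 : 1;
}
```
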
@@ -1837,7 +1837,7 @@ void FetchBlockInputMetadataFromParent(Block* b) {
   }
 }
 
-void RemoveProcessedInputs(const Node* n) {
+static void RemoveProcessedInputs(const Node* n) {
   // After processing a node for shape inference, remove intermediate tensors
   // that are stored in ConstantValueMap to reduce memory usage.
   // This will only remove tensors that are no longer needed by any other node.
@@ -2213,7 +2213,7 @@ void ONNXSetDynamicInputShape(
   GRAPH_UPDATE("dynamic axes tensor names:", [&]() {
     std::vector<std::string> res(dynamic_axes.size());
     std::transform(
-        dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
+        dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](const auto& pair) {
           return pair.first;
         });
     return res;
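
Same pattern in the logging lambda: taking the map entry as `auto pair` copies a `std::pair<const std::string, ...>`, including the key string, for every element; `const auto&` reads it in place. A standalone version with an illustrative map:

```cpp
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, int> dynamic_axes{{"input", 0}, {"mask", 1}};
  std::vector<std::string> names(dynamic_axes.size());
  // `const auto&` binds each pair<const std::string, int> in place; taking
  // it by value would copy the pair, and its key string, per element.
  std::transform(
      dynamic_axes.begin(), dynamic_axes.end(), names.begin(),
      [](const auto& pair) { return pair.first; });
  return names.size() == 2 ? 0 : 1;
}
```
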
6 changes: 3 additions & 3 deletions torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp
@@ -27,7 +27,7 @@ using namespace ::c10::onnx;
 // 2. For ops with no output scale in op signature (like quantized::relu)
 // we traverse up the graph to get the scale from its input until we hit a node
 // where scale is explicitly specified.
-double getScaleFromInput(Node* input_node) {
+static double getScaleFromInput(Node* input_node) {
   std::optional<IValue> scale;
   std::string input_name = input_node->kind().toQualString();
   std::unordered_set<std::string> noscale_ops = {
@@ -315,7 +315,7 @@ static void unpackQuantizedWeightsHelper(
     auto config_vals = elements[1].to<std::vector<int64_t>>();
     auto tensors = elements[2].to<std::vector<std::optional<at::Tensor>>>();
 
-    std::optional<at::Tensor> weight = tensors[1];
+    const std::optional<at::Tensor>& weight = tensors[1];
     TORCH_INTERNAL_ASSERT(
         weight, "Weight should always be present in serialized qconv.");
     unpacked_weight = *weight;
@@ -373,7 +373,7 @@ static void unpackQuantizedWeightsHelper(
     TORCH_INTERNAL_ASSERT(version == "2", "Unknown serialization version");
     std::vector<at::Tensor> non_optional = elements[1].toTensorVector();
 
-    at::Tensor conv_params_packed = non_optional[0];
+    const at::Tensor& conv_params_packed = non_optional[0];
     unpacked_weight = non_optional[1];
 
     const int64_t kSpatialDim = conv_params_packed[0].item<int64_t>();