
Commit eac0efa

More tidy

1 parent 4c264a0 commit eac0efa

1 file changed: +12 -11 lines changed

torch/csrc/jit/passes/onnx/shape_type_inference.cpp

Lines changed: 12 additions & 11 deletions
@@ -191,7 +191,7 @@ void UpdateTorchValueByOnnxValueInfo(
   }
 }

-static bool IsValidONNXControlflowNode(const Node* n) {
+bool IsValidONNXControlflowNode(const Node* n) {
   // Skip when block size is zero. This is when the node is being created,
   // and doesn't have subblocks attached yet. Run shape inference for these
   // nodes later, when the subgraph has already completed shape inferencing.
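
Note on the static removals in this commit: dropping static gives these helpers external linkage, so they can be declared in a header and referenced from other translation units. A minimal standalone sketch of that linkage effect, with hypothetical names only (not files from the PyTorch tree):

// Standalone single-file sketch, hypothetical names only: a static function
// has internal linkage and cannot be named from another .cpp, while the
// non-static one can be declared in a header and called elsewhere.
#include <iostream>

static bool internal_only() {  // visible to this translation unit only
  return true;
}

bool visible_elsewhere() {  // external linkage; declarable in a header
  return true;
}

int main() {
  std::cout << internal_only() << visible_elsewhere() << '\n';
  return 0;
}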
@@ -205,7 +205,7 @@ static bool IsValidONNXControlflowNode(const Node* n) {
   return true;
 }

-static bool IsValidONNXNode(const Node* n) {
+bool IsValidONNXNode(const Node* n) {
   auto node_kind = n->kind();

   if (!node_kind.is_onnx()) {
@@ -411,7 +411,9 @@ void ConvertGraphToONNXProto(
   }
 }

-std::optional<at::Tensor> ComputeConstantFolding(Node* n, int opset_version) {
+std::optional<at::Tensor> ComputeConstantFolding(
+    const Node* n,
+    int opset_version) {
   if (n->inputs().empty()) {
     return std::nullopt;
   }
@@ -463,7 +465,7 @@ std::optional<::c10::SymbolicShape> ComputeShapeFromReshape(
   auto it_0 = std::find_if(shape_vector.begin(), shape_vector.end(), is_zero);
   bool shape_has_zero = it_0 != shape_vector.end();

-  int minus_one_pos = -1;
+  int64_t minus_one_pos = -1;
   for (auto i : c10::irange(shape_vector.size())) {
     if (shape_vector[i].value() == -1) {
       minus_one_pos = i;
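
Note on the widening above: the loop index produced by c10::irange(shape_vector.size()) is a size_t, so storing it in a plain int can narrow on 64-bit builds, while int64_t holds any valid index. A minimal standalone sketch of the same pattern, assuming a plain vector of int64_t rather than the symbolic-shape element type used by the pass:

// Standalone sketch, not the PyTorch source: find the position of -1 in a
// shape vector while keeping the index in a 64-bit signed variable.
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> shape_vector = {2, -1, 4};
  int64_t minus_one_pos = -1;  // wide enough for any valid index; -1 = absent
  for (std::size_t i = 0; i < shape_vector.size(); ++i) {
    if (shape_vector[i] == -1) {
      minus_one_pos = static_cast<int64_t>(i);
    }
  }
  return minus_one_pos == 1 ? 0 : 1;  // exits 0 when -1 is found at index 1
}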
@@ -773,7 +775,7 @@ void ProcessBroadcastNode(Node* n) {
 }

 void ProcessShapeForConcatNode(Node* n) {
-  int axis = n->i(attr::axis);
+  auto axis = n->i(attr::axis);
   if (ConstantValueMap::HasRank(n->input(0)->debugName())) {
     auto rank = ConstantValueMap::GetRank(n->input(0)->debugName()).value();
     size_t axis_adjust = 0;
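
Note on the auto above: Node::i() returns int64_t in the TorchScript IR, so auto keeps the attribute's full width instead of truncating it into an int. A minimal standalone sketch, where attr_axis is a hypothetical stand-in for n->i(attr::axis):

// Standalone sketch; attr_axis is a hypothetical stand-in for the JIT
// attribute getter, which returns int64_t.
#include <cstdint>
#include <type_traits>

int64_t attr_axis() {
  return -1;
}

int main() {
  auto axis = attr_axis();  // deduced as int64_t; no narrowing to int
  static_assert(
      std::is_same<decltype(axis), int64_t>::value,
      "auto preserves the 64-bit attribute type");
  return axis == -1 ? 0 : 1;
}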
@@ -1837,7 +1839,7 @@ void FetchBlockInputMetadataFromParent(Block* b) {
   }
 }

-static void RemoveProcessedInputs(const Node* n) {
+void RemoveProcessedInputs(const Node* n) {
   // After processing a node for shape inference, remove intermediate tensors
   // that are stored in ConstantValueMap to reduce memory usage.
   // This will only remove tensors that are no longer needed by any other node.
@@ -2213,10 +2215,9 @@ void ONNXSetDynamicInputShape(
   GRAPH_UPDATE("dynamic axes tensor names:", [&]() {
     std::vector<std::string> res(dynamic_axes.size());
     std::transform(
-        dynamic_axes.begin(),
-        dynamic_axes.end(),
-        res.begin(),
-        [](const auto& pair) { return pair.first; });
+        dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
+          return pair.first;
+        });
     return res;
   }());

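Note on the reflowed call above: it is the same key-extraction idiom, just reformatted, with the lambda now taking its pair by value rather than by const reference. A minimal standalone sketch of the idiom, assuming a flat unordered_map in place of the pass's dynamic_axes argument (illustrative names only):

// Standalone sketch of the key-extraction idiom: copy every map key into a
// vector of strings, as the GRAPH_UPDATE lambda above does for dynamic_axes.
#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  std::unordered_map<std::string, int> dynamic_axes = {{"input", 0}, {"mask", 1}};
  std::vector<std::string> res(dynamic_axes.size());
  std::transform(
      dynamic_axes.begin(), dynamic_axes.end(), res.begin(), [](auto pair) {
        return pair.first;
      });
  return res.size() == 2 ? 0 : 1;
}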
@@ -2256,7 +2257,7 @@ void ONNXSetDynamicInputShape(
   }
 }

-static bool HasSequenceTypeOutput(Node* node) {
+static bool HasSequenceTypeOutput(const Node* node) {
   if (node->kind() == ::c10::onnx::SplitToSequence ||
       node->kind() == ::c10::onnx::SequenceInsert ||
       node->kind() == ::c10::onnx::SequenceEmpty ||
