Freezing Torchscript modules #32178

Closed · wants to merge 3 commits

Conversation

@bzinodev (Contributor) commented Jan 14, 2020

This patch enables folding GetAttr nodes with their corresponding
values. The _jit_pass_freeze_module API returns a new TorchScript module
where all function calls and attribute accesses are inlined.
Usage:

frozen_model = torch._C._freeze_module(scripted_model._c)
frozen_model.forward(...)

This API currently optimizes only the forward method. We will follow up
to preserve and optimize methods and attributes that are annotated as
@torch.jit.interface.

Several further improvements to the JIT optimization passes are needed to fully
clean up/de-sugar the graph and eliminate redundancies.
Ideally, we want to produce a graph that can easily be lowered to
GLOW and other low-level backends.
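For context, here is a minimal end-to-end sketch of the usage above. The module definition, names, and input shapes are illustrative assumptions, not part of the patch; only torch._C._freeze_module and the ._c access come from this PR:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

scripted_model = torch.jit.script(MyModule())
scripted_model.eval()  # freezing expects a module in eval mode
frozen_model = torch._C._freeze_module(scripted_model._c)
out = frozen_model.forward(torch.randn(2, 10))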

@bzinodev bzinodev requested a review from apaszke as a code owner January 14, 2020 19:16
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jan 14, 2020
@kostmo (Member) commented Jan 14, 2020

💊 CircleCI build failures summary and remediations

As of commit 597d63d (more details on the Dr. CI page):


Commit 597d63d was recently pushed. Waiting for builds...


This comment was automatically generated by Dr. CI.

@bzinodev bzinodev force-pushed the freeze_module branch 3 times, most recently from 4c584d3 to beb2e40 on January 15, 2020 22:43
@facebook-github-bot (Contributor) left a comment

@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@bzinodev bzinodev force-pushed the freeze_module branch 2 times, most recently from 35265b1 to 110a426 on January 15, 2020 23:13
@bzinodev bzinodev requested a review from resistor January 15, 2020 23:18
@facebook-github-bot (Contributor) left a comment

@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@eellison (Contributor) left a comment

Only took a brief look at the alias analysis changes. As far as I understand, the semantics of freezing are:

"I will only run this frozen method after freezing" - which means that we must take into account attributes that are mutated within the method, but any attributes which are not mutated we can inline.

If the pass correctly accounts for attributes which aren't mutated and then inlines them, why are changes being made to alias analysis?

Edit: this still doesn't handle interface calls. I think we should just throw if we see one for now, since interfaces still aren't part of the public API.
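To make those semantics concrete, a small hypothetical sketch (the attribute names and the expected folding behavior reflect my reading of the comment above, not code from the patch): scale is only read in forward, so it is a candidate for inlining; counter is mutated, so it must remain a real attribute of the frozen module.

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = torch.ones(3)   # only read in forward -> candidate for inlining
        self.counter = 0             # mutated in forward -> must be preserved

    def forward(self, x):
        self.counter += 1
        return x * self.scale

m = torch.jit.script(Net())
m.eval()
frozen = torch._C._freeze_module(m._c)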

@suo (Member) left a comment

Looking pretty close. I left some commentary inline. We should also do a quick style review after this; but I don't want to clutter up the discussion on the actual logic.

@eellison (Contributor) left a comment

I think in order to inline attributes of a module the following needs to happen.

For a module value %ModVal of type Mod with a Tensor field weights:

  • no value with Mod type has any prim::SetAttr[field="weights"] (*1)

  • all prim::GetAttr["weights"](%ModVal) can be resolved statically
    - For example, if %ModVal is an output of a control flow node we cannot resolve all GetAttr

  • weights is not aliased by any other Tensor contained within the frozen module, including in Lists and Tuples of Tensors (*2)
    - You can check this by checking the Tensor storage of each tensor

  • at this point, you can replace all prim::GetAttr[field="weights"] with a single top-level prim::FrozenGetAttr. This preserves aliasing. If the value doesn't have any writers it can be inlined as a constant.

(*1) You can relax this condition for a specific value of type MyMod if you can prove it doesn't alias any MyMod values that reassign weights; maybe better as a follow-up.

(*2) You can relax this condition if you can prove the other aliases aren't mutated; also maybe better as a follow-up.

Edit: the above correctness condition ensures that there is one single alias for an attribute Tensor, meaning that it will work even if the Tensor is mutated.

A separate correctness condition is to check that all aliases of an attribute Tensor are not mutated.
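As a hypothetical illustration of condition (*2) above: two attributes backed by the same tensor, where mutating one silently changes the other, so weights cannot simply be baked into the graph as a constant. The module and names below are illustrative only:

import torch
import torch.nn as nn

class Shared(nn.Module):
    def __init__(self):
        super().__init__()
        w = torch.randn(4)
        self.weights = w   # candidate for inlining ...
        self.other = w     # ... but it shares storage with another attribute

    def forward(self, x):
        self.other.add_(1.0)   # mutating the alias also mutates `weights`
        return x * self.weights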

jerryzh168 added a commit that referenced this pull request Jan 23, 2020
Summary:
Currently we compile constants directly into the graph, but when we
fold ConvBn, we might need to change a constant bias to some Tensor, and
this is not possible with bias being in the `__constants__` list, so we
need to remove `bias` from the list. This might result in some minor perf
regression, but it should be unnoticeable, and we can restore the perf after
the freeze feature is enabled for production models: #32178

Test Plan:
python test/test_jit.py

Reviewers:
suo, mvz
Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
@bzinodev bzinodev force-pushed the freeze_module branch 2 times, most recently from 3242e64 to 5841e2a on February 12, 2020 22:07
@eellison (Contributor) left a comment

Just did a quick skim. You are right that when we script a module, we construct a new list when converting a Python list to an IValue. However, that doesn't mean that each list is a unique alias:

import torch
import torch.nn as nn

class MyMod(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = [1, 2, 3]
        self.y = [1]

    @torch.jit.export
    def make_alias(self):
        self.x = self.y

    def forward(self):
        return self.x

mod = torch.jit.script(MyMod())
mod.make_alias()
frozen = torch._C._freeze_module(mod._c)  # i.e. freeze(mod)

Agreed this is a pretty unusual case. However, the more important thing here is that aliasing of Tensors is preserved, so two lists can contain the same tensor.

I think we need to work on / finish up the hashing of potentially aliasing IValues still.
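And a small illustration of that last point, where two distinct lists hold the same tensor; the module and names are hypothetical, for discussion only:

import torch
import torch.nn as nn

class ListsShareTensor(nn.Module):
    def __init__(self):
        super().__init__()
        t = torch.zeros(2)
        self.x = [t]                  # two different list objects ...
        self.y = [t, torch.ones(2)]   # ... holding the same tensor

    def forward(self):
        self.y[0].add_(1.0)   # this mutation is also visible through self.x[0]
        return self.x[0]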

@bzinodev bzinodev linked an issue Feb 13, 2020 that may be closed by this pull request
@facebook-github-bot (Contributor) left a comment

@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

jhjun37 pushed a commit to jhjun37/pytorch_copy that referenced this pull request Feb 18, 2020
Summary:
Currently we compile constants directly into the graph, but when we
fold ConvBn, we might need to change a constant bias to some Tensor, and
this is not possible with bias being in the `__constants__` list, so we
need to remove `bias` from the list. This might result in some minor perf
regression, but it should be unnoticeable, and we can restore the perf after
the freeze feature is enabled for production models: pytorch/pytorch#32178

Test Plan:
python test/test_jit.py

Reviewers:
suo, mvz
Subscribers:

Tasks:

Tags:

ghstack-source-id: 8cdd38a
Pull Request resolved: pytorch/pytorch#32543
@bzinodev bzinodev force-pushed the freeze_module branch 2 times, most recently from f2194d6 to 6a86724 on February 26, 2020 07:48
@facebook-github-bot (Contributor) left a comment

@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) left a comment

@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@bzinodev bzinodev force-pushed the freeze_module branch 2 times, most recently from cc8848c to be52a4d on February 28, 2020 19:48
@eellison (Contributor) left a comment

🚢🚢🚢🚢🚢🚢🚢🚢🚢

Looks great, let's ship this! I left some comments that I would like addressed / responded to before landing, but let's merge this. It covers all of the reasonable edge cases; as you have run into, the remaining degenerate cases require some structural changes to Alias Analysis, which shouldn't block this PR, and we safely error out in those cases anyway.

Maybe you can help me here, but let's at least establish what some of the follow-ups are:

  • maybe clean up / move the public APIs on IValue into freezing
  • once alias analysis has fine-grained containment tracking, we can improve freezing here instead of erroring out
  • remove methods, attributes

anything else?

break;
case Tag::Future:
case Tag::Device:
case Tag::Object:
Contributor

You can handle Object like this:

    case Tag::Object: {
      auto obj_type = type()->expect<ClassType>();
      auto obj_value = toObject();
      auto attribute_names = obj_type->attributeNames();
      for (const auto& name: attribute_names) {
        auto attribute = obj_value->getAttr(name);
        attribute.getSubValues(subValues);
      }
    }

Contributor Author

Correct! I wanted to be conservative; at this point freezing does not need this case. I will add it in the future with a nice test case.

case Tag::Object:
case Tag::PyObject:
case Tag::Uninitialized:
case Tag::Capsule:
Contributor

You can remove Device and Uninitialized from this list. Uninitialized only exists as an IR construct and should never be an attribute; Device is immutable.

false, "sub ivalue is nat enabled for: ", this->tagKind());
// Fall through
default:
// don't record scalars.
Contributor

nit: /don't record scalars/ don't record immutable types

struct HashIValue {
size_t operator()(const IValue& val) const {
if (val.isTensor()) {
return 0;
Contributor

If you look at the storage is_alias_of implementation, all it does is check the storage pointer. https://github.com/pytorch/pytorch/blob/master/c10/core/Storage.h#L152 We can do that here too.

#include <torch/csrc/utils/hash.h>
if (val.isTensor()) {
  return reinterpret_cast<size_t>(val.toTensor().storage().unsafeGetStorageImpl());
}
...

paging @smessmer to check that I did this right
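For intuition, the same storage-identity notion is easy to see from Python; this only illustrates the aliasing question the hash is answering and is not code from the patch:

import torch

a = torch.randn(4)
b = a.view(2, 2)   # shares storage with a
c = a.clone()      # fresh storage

print(a.storage().data_ptr() == b.storage().data_ptr())  # True: a and b alias
print(a.storage().data_ptr() == c.storage().data_ptr())  # False: c is independent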

return payload.as_intrusive_ptr;
}

TypePtr type() const;

size_t hash() const {
return payload.as_int;
Contributor

This is hashing for a very specific question - do these IValues alias? You could imagine another user of this API asking, "are these two tensors the same?", and they would be using it wrong. I would prefer if we just moved this all into the freezing files instead of adding APIs to IValue.

Contributor Author

I am renaming HashIValue to HashAliasedIValue and CompIValue to CompAliasedIValue. Hope this is clear enough :)

TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
Value* paramConst = nullptr;
auto I = attrValues.find(attrModule._ivalue());
if (I != attrValues.end()) {
Contributor

nit: the I / II variable names are confusing to read.

Contributor Author

Agreed, sorry, in LLVM we capitalize variables. I am renaming I to iter and II to iter2.

script::Module& module_;

// Contains the attributes names (e.g. {"self", "subModule", "a"}
std::deque<std::string> names_;
Contributor

Not sure why this is a class attribute, probably shouldn't be

Contributor Author

I used it to pretty print (e.g. %self.sub.conv = prim::constant(...)). The time of computing the name and the time of inserting the new node are far apart.

}
}

IValue overrideGradient(IValue attr) {
Contributor

Can we not just call getSubValues() and iterate through tensor subvalues, removing the gradient in place?

Contributor Author

Yes, cleaner, and it handles missing cases.

Contributor Author

Eliminate tensor detach?
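Conceptually, the suggestion above amounts to something like the following Python sketch; the real pass would do this over IValues in C++, so the helper name and structure here are purely illustrative assumptions:

import torch

def strip_gradients(value):
    # Recursively walk a (possibly nested) attribute value and return a
    # copy in which every tensor has been detached from autograd.
    if isinstance(value, torch.Tensor):
        return value.detach()
    if isinstance(value, list):
        return [strip_gradients(v) for v in value]
    if isinstance(value, tuple):
        return tuple(strip_gradients(v) for v in value)
    if isinstance(value, dict):
        return {k: strip_gradients(v) for k, v in value.items()}
    return value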

self.c = (self.a, 10)

def forward(self, x):
self.b[1] += 10
Contributor

nice

void run(std::shared_ptr<Graph>& graph) {
Inline(*graph);
propagateAttributes(graph);
runOptimization(graph, /* unroll? */ false);
Contributor

I know quantization and potentially other parties wanted to run optimizations immediately after freezing. It might be worth removing the runOptimization call here to make freezing more composable.

Contributor Author

The optimization is important for the clean-up: getting rid of all unused attributes. I think it is better to return a clean graph, and the user can call whatever follow-up optimization they want post freezing.

Zino Benaissa added 3 commits February 29, 2020 10:00
This patch enables folding GetAttr nodes with their corresponding
values. The _jit_pass_freeze_module API returns a new TorchScript module
where all function calls and attribute accesses are inlined.
Usage:

frozen_model = torch._C._freeze_module(scripted_model._c)
frozen_model.forward(...)

This API currently optimizes only the forward method. We will follow up
to preserve and optimize methods and attributes on demand.

Several further improvements to the JIT optimization passes are needed to fully
clean up/de-sugar the graph and eliminate redundancies.
This is an important step toward producing a graph that can easily be lowered to
GLOW and other low-level backends.
This patch adds two APIs to delete methods from modules and module types.
NOTE: these APIs are for internal use only and should be used only by freezing,
where a new module and type are being created.
@facebook-github-bot (Contributor) left a comment

@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

v.append(4)
m_s.a = v
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
Contributor

Also, before this pass is made public we need a workflow that does not involve accessing private members (torch._C and m_s._c). This pass should be in torch.jit.script with proper documentation (e.g. https://github.com/pytorch/pytorch/blob/master/torch/jit/__init__.py#L1117).
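For discussion, a documented public entry point might look roughly like the sketch below. The name freeze and its checks are purely hypothetical here; only torch._C._freeze_module is real in this PR:

import torch

def freeze(mod):
    # Hypothetical public wrapper around the private freezing pass.
    if not isinstance(mod, torch.jit.ScriptModule):
        raise TypeError("freeze() expects a ScriptModule; run torch.jit.script() first.")
    if mod.training:
        raise RuntimeError("freeze() expects a module in eval mode; call .eval() first.")
    # The private pass returns a C++ module; a real API would wrap it back
    # into a Python-facing ScriptModule.
    return torch._C._freeze_module(mod._c)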

// folded.
// TODO: Determine if freezing in training mode is useful and further clarify
// its semantics.
TORCH_CHECK(!module.is_training());
Contributor

Could we improve the error message here?

ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary:
This patch enables folding GetAttr nodes with their corresponding
values. The _jit_pass_freeze_module API returns a new TorchScript module
where all function calls and attribute accesses are inlined.
Usage:

frozen_model = torch._C._freeze_module(scripted_model._c)
frozen_model.forward(...)

This API currently optimizes only the forward method. We will follow up
to preserve and optimize methods and attributes that are annotated as
torch.jit.interface.

Several further improvements to the JIT optimization passes are needed to fully
clean up/de-sugar the graph and eliminate redundancies.
Ideally, we want to produce a graph that can easily be lowered to
GLOW and other low-level backends.
Pull Request resolved: pytorch#32178

Differential Revision: D19419640

Pulled By: bzinodev

fbshipit-source-id: 52baffaba9bca2cd60a8e747baa68d57711ad42b
@facebook-github-bot facebook-github-bot deleted the freeze_module branch July 13, 2020 17:55
Labels: oncall: jit (Add this issue/PR to JIT oncall triage queue)

Successfully merging this pull request may close these issues: Freezing TorchScript Modules

8 participants