Freezing Torchscript modules #32178
Conversation
Force-pushed from 4c584d3 to beb2e40
@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 35265b1 to 110a426
@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from 110a426 to bcf1ec7
Force-pushed from bcf1ec7 to 602a1b9
Only took a brief look at alias analysis changes. As far as I understand, the semantics of freezing are:
"I will only run this frozen method after freezing" - which means that we must consider mutable attributes within the method, but any attributes which are not mutated we can inline.
If the pass correctly takes into account which attributes aren't mutated and then inlines them, why are changes being made to alias analysis?
Edit: this still doesn't handle interface calls. I think we should just throw if we see one for now, since interfaces still aren't part of the public api.
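To make those semantics concrete, here is a minimal sketch (illustrative only, not code from this PR; the module and attribute names are made up, and it uses the internal torch._C._freeze_module entry point described later in this thread): an attribute that forward never writes is a candidate for inlining, while an attribute that forward mutates has to survive as a real attribute of the frozen module.

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = 2.0   # never written in forward -> candidate for inlining as a constant
        self.counter = 0   # written in forward (prim::SetAttr) -> must remain a mutable attribute

    def forward(self, x):
        self.counter += 1
        return x * self.scale

m = torch.jit.script(M())
m.eval()  # freezing expects a module in eval mode
frozen = torch._C._freeze_module(m._c)
# Under the semantics above, 'scale' can be folded into the graph while 'counter' is preserved.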
Force-pushed from 602a1b9 to 49e2be1
Looking pretty close. I left some commentary inline. We should also do a quick style review after this, but I don't want to clutter up the discussion of the actual logic.
I think in order to inline attributes of a module the following needs to happen.

For a module value `%ModVal` of type `Mod` with a Tensor field `weights`:

- no value with `Mod` type has any `prim::SetAttr[field="weights"]` (*1)
- all `prim::GetAttr["weights"](%ModVal)` can be resolved statically
  - For example, if `%ModVal` is an output of a control-flow node we cannot resolve all GetAttrs
- `weights` is not aliased by any other Tensor contained within the frozen module, including in Lists and Tuples of Tensors (*2)
  - You can check this by checking the Tensor storage of each tensor (see the sketch after this comment)
- at this point, you can replace all `prim::GetAttr[field="weights"]` with a single top-level `prim::FrozenGetAttr`. This preserves aliasing. If the value doesn't have any writers it can be inlined as a constant.

(*1) You can relax this condition for a specific value of type `MyMod` if you can prove it doesn't alias any `MyMod` values that reassign `weights`; maybe better as a follow-up.

(*2) You can relax this condition if you can prove the other aliases aren't mutated; also maybe better as a follow-up.

Edit: the above correctness condition ensures that there is one single alias for an attribute Tensor, meaning that it will work even if the Tensor is mutated. A separate correctness condition is to check that all aliases of an attribute Tensor are not mutated.
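To illustrate condition (*2) and the storage check, here is a hedged Python sketch (the module and attribute names are made up, not taken from the PR): two attributes that wrap the same storage make it unsafe to fold either one into a constant, and the sharing can be detected by comparing storage pointers.

import torch
import torch.nn as nn

class Shared(nn.Module):
    def __init__(self):
        super().__init__()
        w = torch.randn(3)
        self.weights = w   # 'weights' and 'alias' share one storage,
        self.alias = w     # so condition (*2) is violated for 'weights'

    def forward(self, x):
        self.alias.add_(1.0)      # an in-place write through the alias...
        return x + self.weights   # ...is observable through 'weights'

m = torch.jit.script(Shared()).eval()
# The storage check suggested above: identical storage pointers mean the tensors alias,
# so 'weights' must not be baked into the graph as a constant.
assert m.weights.storage().data_ptr() == m.alias.storage().data_ptr()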
Summary: Currently we compile constants directly into the graph, but when we fold ConvBn, we might need to change a constant bias to some Tensor, and this is not possible with bias being in the `__constants__` list, so we need to remove `bias` from the list. This might result in some minor perf regression but it should be unnoticeable, and we can restore the perf after the freeze feature is enabled for production models: #32178
Test Plan: python test/test_jit.py
Reviewers: suo, mvz
Subscribers:
Tasks:
Tags:
[ghstack-poisoned]
Force-pushed from 3242e64 to 5841e2a
Just did a quick skim. You are right that when we script a module we construct a new list when converting a Python list to an IValue; however, that doesn't mean that each list is a unique alias.
import torch
import torch.nn as nn

class MyMod(nn.Module):
    def __init__(self):
        super().__init__()
        self.x = [1, 2, 3]
        self.y = [1]

    def make_alias(self):
        self.x = self.y
...
mod = torch.jit.script(MyMod())
mod.make_alias()
freeze(mod)
Agreed, this is a pretty unusual case. However, the more important thing here is that aliasing of Tensors is preserved, so two lists can contain the same tensor.
I think we need to work on / finish up the hashing of potentially aliasing ivalues still.
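A short sketch of that last point (names are illustrative, not from the PR): even though scripting builds fresh list containers, two list attributes can still hold the very same tensor, so container identity alone does not make the contents safe to fold.

import torch
import torch.nn as nn
from typing import List

class ListAlias(nn.Module):
    xs: List[torch.Tensor]
    ys: List[torch.Tensor]

    def __init__(self):
        super().__init__()
        t = torch.zeros(2)
        self.xs = [t]   # distinct list objects...
        self.ys = [t]   # ...containing the same tensor

    def forward(self):
        self.ys[0].add_(1.0)   # a write through ys is visible through xs
        return self.xs[0]

m = torch.jit.script(ListAlias())
# Scripting converts each Python list to a new IValue list, but the tensors inside
# still share storage, which is exactly what the aliasing analysis must track.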
Force-pushed from 5841e2a to d171576
Force-pushed from d171576 to 27e0c35
@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Currently we compile constants directly into the graph, but when we fold ConvBn, we might need to change a constant bias to some Tensor, and this is not possible with bias being in the `__constants__` list, so we need to remove `bias` from the list. This might result in some minor perf regression but it should be unnoticeable, and we can restore the perf after the freeze feature is enabled for production models: pytorch/pytorch#32178
Test Plan: python test/test_jit.py
Reviewers: suo, mvz
Subscribers:
Tasks:
Tags:
ghstack-source-id: 8cdd38a
Pull Request resolved: pytorch/pytorch#32543
Force-pushed from f2194d6 to 6a86724
@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Force-pushed from cc8848c to be52a4d
🚢🚢🚢🚢🚢🚢🚢🚢🚢
Looks great, let's ship this! I left some comments that I would like addressed / responded to before landing, but let's merge this. It covers all of the reasonable edge cases, and, as you have run into, the remaining degenerate cases require some structural changes to Alias Analysis; those shouldn't block the PR, and we safely error out in those cases anyway.
Maybe you can help me here, but let's at least establish what some of the follow-ups are:
- maybe clean up / move public apis on ivalue into freezing
- once alias analysis has fine-grained containment tracking, we can improve freezing here instead of erroring out
- remove methods, attributes
Anything else?
break;
case Tag::Future:
case Tag::Device:
case Tag::Object:
You can do object like this:
case Tag::Object: {
  auto obj_type = type()->expect<ClassType>();
  auto obj_value = toObject();
  auto attribute_names = obj_type->attributeNames();
  for (const auto& name : attribute_names) {
    auto attribute = obj_value->getAttr(name);
    attribute.getSubValues(subValues);
  }
  break;
}
Correct! I wanted to be conservative. At this point freezing does not need this case. I will add it in the future with a nice test case.
case Tag::Object:
case Tag::PyObject:
case Tag::Uninitialized:
case Tag::Capsule:
You can remove Device and Uninitialized from this list. Uninitialized only exists as an IR construct and should never be an attribute; Device is immutable.
false, "sub ivalue is nat enabled for: ", this->tagKind()); | ||
// Fall through | ||
default: | ||
// don't record scalars. |
nit: /don't record scalars/ don't record immutable types
struct HashIValue {
  size_t operator()(const IValue& val) const {
    if (val.isTensor()) {
      return 0;
If you look at the storage `is_alias_of` implementation, all it does is check the storage pointer: https://github.com/pytorch/pytorch/blob/master/c10/core/Storage.h#L152 We can do that here too.
#include <torch/csrc/utils/hash.h>

if (val.isTensor()) {
  return reinterpret_cast<size_t>(val.toTensor().storage().unsafeGetStorageImpl());
}
...
paging @smessmer to check that I did this right
  return payload.as_intrusive_ptr;
}

TypePtr type() const;

size_t hash() const {
  return payload.as_int;
This is hashing for a very specific question - do these ivalues alias. You could imagine another user of this api to be asking, "are these two tensors the same", and they would be using this wrong. I would prefer if we just moved this all into the freezing files instead of adding APIs to ivalue.
I am renaming HashIValue to HashAliasedIValue and CompIValue to CompAliasedIValue. Hope this is clear enough :)
TORCH_INTERNAL_ASSERT(attrModule.hasattr(name));
Value* paramConst = nullptr;
auto I = attrValues.find(attrModule._ivalue());
if (I != attrValues.end()) {
nit: the `I` / `II` variable names are confusing to read.
Agreed; sorry, in LLVM we capitalize variables. I am renaming `I` to `iter` and `II` to `iter2`.
script::Module& module_; | ||
|
||
// Contains the attributes names (e.g. {"self", "subModule", "a"} | ||
std::deque<std::string> names_; |
Not sure why this is a class attribute, probably shouldn't be
I used it to pretty print (e.g. %self.sub.conv = prim::Constant(...)). The time of computing the name and the time of inserting the new node are far apart.
  }
}

IValue overrideGradient(IValue attr) {
Can we not just call `getSubValues()`, iterate through the tensor subvalues, and remove the gradient in place?
Yes, that is cleaner and it handles missing cases.
Eliminate tensor detach?
        self.c = (self.a, 10)

    def forward(self, x):
        self.b[1] += 10
nice
void run(std::shared_ptr<Graph>& graph) {
  Inline(*graph);
  propagateAttributes(graph);
  runOptimization(graph, /* unroll? */ false);
I know quantization and potentially other parties wanted to run optimizations immediately after freezing. It might be worth removing the runOptimization call here to make freezing more composable.
The optimization is important for cleanup, i.e. getting rid of all unused attributes. I think it is better to return a clean graph, and the user can run whatever follow-up optimization they want post freezing.
This patch enables folding GetAttr nodes with their corresponding values. The _jit_pass_freeze_module API returns a new TorchScript module where all function calls and get attributes are inlined.
Usage:
frozen_model = torch._C._freeze_module(scripted_model._c)
frozen_model.forward(...)
This API currently optimizes the forward method. We will follow up to preserve and optimize methods and attributes on demand. Several future improvements to JIT optimizations are required to maximally clean up/de-sugar the graph and eliminate redundancies. This is an important step toward producing a graph that can easily be lowered to GLOW and other low-level backends.
This patch adds two APIs to delete methods from modules and module types. NOTE: these APIs are for internal use only and should be used only by freezing, where a new module and type are being created.
Force-pushed from be52a4d to 597d63d
@bzinodev has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
v.append(4)
m_s.a = v
m_s.eval()
m_f = torch._C._freeze_module(m_s._c)
Also, before this pass is made public we need a workflow that does not involve accessing private members (`torch._C` and `m_s._c`). This pass should be in `torch.jit.script` with proper documentation (e.g. https://github.com/pytorch/pytorch/blob/master/torch/jit/__init__.py#L1117).
// folded.
// TODO: Determine if freezing in training mode is useful and further clarify
// its semantics.
TORCH_CHECK(!module.is_training());
Could we improve the error message here?
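For context, the check above means freezing only supports modules in eval mode; here is a minimal sketch of the expected call order, using the internal entry point from this PR (the Linear module is just an illustrative example).

import torch
import torch.nn as nn

m = torch.jit.script(nn.Linear(2, 2))
m.eval()   # required: module.is_training() must be False before freezing
frozen = torch._C._freeze_module(m._c)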
Summary: This patch enables folding GetAttr nodes with their corresponding values. The _jit_pass_freeze_module API returns a new TorchScript module where all function calls and get attributes are inlined.
Usage:
frozen_model = torch._C._freeze_module(scripted_model._c)
frozen_model.forward(...)
This API currently optimizes the forward method. We will follow up to preserve and optimize methods and attributes that are annotated as torch.jit.interface. Several future improvements to JIT optimizations are required to maximally clean up/de-sugar the graph and eliminate redundancies. Ideally, we want to produce a graph that can easily be lowered to GLOW and other low-level backends.
Pull Request resolved: pytorch#32178
Differential Revision: D19419640
Pulled By: bzinodev
fbshipit-source-id: 52baffaba9bca2cd60a8e747baa68d57711ad42b
This patch enables folding GetAttr nodes with their corresponding
values. The _jit_pass_freeze_module API returns a new TorchScript module
where all function calls and get attributes are inlined.
Usage:
frozen_model = torch._C._freeze_module(scripted_model._c)
frozen_model.forward(...)
This API currently optimizes the forward method. We will follow up to
preserve and optimize methods and attributes that are annotated as
@torch.jit.interface.
Several future improvements to JIT optimizations are required to maximally
clean up/de-sugar the graph and eliminate redundancies.
Ideally, we want to produce a graph that can easily be lowered to
GLOW and other low-level backends.