@@ -416,10 +416,11 @@ Source code for torch.distributed.fsdp.fully_sharded_data_parallel
     Any,
     Callable,
     Dict,
-    List,
-    Optional,
     Generator,
+    Iterator,
+    List,
     NamedTuple,
+    Optional,
     Set,
     Tuple,
     Union,
@@ -434,25 +435,28 @@
 from torch.autograd import Variable
 from torch.distributed import ProcessGroup
 from torch.distributed._sharded_tensor import (
-    init_from_local_shards,
     Shard,
     ShardedTensor,
+    init_from_local_shards,
 )
 from torch.distributed.distributed_c10d import _get_default_group
 from torch.nn.parameter import Parameter

-from .flatten_params_wrapper import FlatParameter, FlattenParamsWrapper, FLAT_PARAM
-from .utils import (
-    _apply_to_tensors,
-    _replace_by_prefix,
+from .flatten_params_wrapper import (
+    FLAT_PARAM,
+    FPW_MODULE,
+    FlatParameter,
+    FlattenParamsWrapper,
 )
+from .utils import _apply_to_tensors, _replace_by_prefix
 from .wrap import _recursive_wrap

 if TYPE_CHECKING:
     from collections import OrderedDict  # noqa: F401


 FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module"
+FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." + FPW_MODULE + "."


 class ShardingStrategy(Enum):
@@ -1762,6 +1766,26 @@
         _free_full_params_and_use_local_shard(currently_local_params)
         self.training_state = TrainingState_.IDLE

+    def named_parameters(
+        self,
+        *args,
+        **kwargs,
+    ) -> Iterator[Tuple[str, torch.nn.Parameter]]:
+        """
+        Overrides :meth:`named_parameters()` to intercept parameter names and
+        remove all occurrences of the FSDP-specific flattened parameter prefix
+        when inside the :meth:`summon_full_params` context manager.
+        """
+        # Determine which logic to use based on the context at call time
+        in_summon_full_params = getattr(self, "training_state", None) == \
+            TrainingState_.SUMMON_FULL_PARAMS
+        for param_name, param in super().named_parameters(*args, **kwargs):
+            if in_summon_full_params:
+                # Remove any instances of the FSDP-specific prefix; there can
+                # be multiple in the case of nested FSDP modules
+                param_name = param_name.replace(FSDP_PREFIX, "")
+            yield (param_name, param)
+
     def _register_pre_backward_hooks(self, outputs: Any) -> Any:
         """Register pre-backward hook to run before the wrapped module's
        backward. Hooks should be attached to all outputs from the forward.