Skip to content

Commit 927f031

Browse files
committed
Major module / path restructure, timm.models.layers -> timm.layers, add _ prefix to all non model modules in timm.models
1 parent da6644b commit 927f031

File tree

149 files changed

+1387
-1269
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

149 files changed

+1387
-1269
lines changed

avg_checkpoints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import os
1717
import glob
1818
import hashlib
19-
from timm.models.helpers import load_state_dict
19+
from timm.models import load_state_dict
2020

2121
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
2222
parser.add_argument('--input', default='', type=str, metavar='PATH',

clean_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import hashlib
1414
import shutil
1515
from collections import OrderedDict
16-
from timm.models.helpers import load_state_dict
16+
from timm.models import load_state_dict
1717

1818
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
1919
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',

hubconf.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
dependencies = ['torch']
2-
from timm.models import registry
3-
4-
globals().update(registry._model_entrypoints)
2+
import timm
3+
globals().update(timm.models._registry._model_entrypoints)

inference.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,23 @@
55
66
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
77
"""
8-
import os
9-
import time
108
import argparse
119
import json
1210
import logging
11+
import os
12+
import time
1313
from contextlib import suppress
1414
from functools import partial
1515

1616
import numpy as np
1717
import pandas as pd
1818
import torch
1919

20-
from timm.models import create_model, apply_test_time_pool, load_checkpoint
2120
from timm.data import create_dataset, create_loader, resolve_data_config
21+
from timm.layers import apply_test_time_pool
22+
from timm.models import create_model
2223
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
2324

24-
25-
2625
try:
2726
from apex import amp
2827
has_apex = True

tests/test_layers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
import pytest
21
import torch
32
import torch.nn as nn
4-
import platform
5-
import os
63

7-
from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
4+
from timm.layers import create_act_layer, set_layer_config
85

96

107
class MLP(nn.Module):

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import timm
1616
from timm import list_models, create_model, set_scriptable, get_pretrained_cfg_value
17-
from timm.models.fx_features import _leaf_modules, _autowrap_functions
17+
from timm.models._features_fx import _leaf_modules, _autowrap_functions
1818

1919
if hasattr(torch._C, '_jit_set_profiling_executor'):
2020
# legacy executor is too slow to compile large models for unit tests

timm/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .version import __version__
2+
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable
23
from .models import create_model, list_models, list_pretrained, is_model, list_modules, model_entrypoint, \
3-
is_scriptable, is_exportable, set_scriptable, set_exportable, \
44
is_model_pretrained, get_pretrained_cfg, get_pretrained_cfg_value

timm/data/readers/class_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import pickle
33

4+
45
def load_class_map(map_or_filename, root=''):
56
if isinstance(map_or_filename, dict):
67
assert dict, 'class_map dict must be non-empty'
@@ -14,7 +15,7 @@ def load_class_map(map_or_filename, root=''):
1415
with open(class_map_path) as f:
1516
class_to_idx = {v.strip(): k for k, v in enumerate(f)}
1617
elif class_map_ext == '.pkl':
17-
with open(class_map_path,'rb') as f:
18+
with open(class_map_path, 'rb') as f:
1819
class_to_idx = pickle.load(f)
1920
else:
2021
assert False, f'Unsupported class map file extension ({class_map_ext}).'

timm/layers/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from .activations import *
2+
from .adaptive_avgmax_pool import \
3+
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
4+
from .blur_pool import BlurPool2d
5+
from .classifier import ClassifierHead, create_classifier
6+
from .cond_conv2d import CondConv2d, get_condconv_initializer
7+
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
8+
set_layer_config
9+
from .conv2d_same import Conv2dSame, conv2d_same
10+
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
11+
from .create_act import create_act_layer, get_act_layer, get_act_fn
12+
from .create_attn import get_attn, create_attn
13+
from .create_conv2d import create_conv2d
14+
from .create_norm import get_norm_layer, create_norm_layer
15+
from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
16+
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
17+
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
18+
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
19+
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
20+
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
21+
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
22+
from .gather_excite import GatherExcite
23+
from .global_context import GlobalContext
24+
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple
25+
from .inplace_abn import InplaceAbn
26+
from .linear import Linear
27+
from .mixed_conv2d import MixedConv2d
28+
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
29+
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
30+
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
31+
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
32+
from .padding import get_padding, get_same_padding, pad_same
33+
from .patch_embed import PatchEmbed
34+
from .pool2d_same import AvgPool2dSame, create_pool2d
35+
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
36+
from .selective_kernel import SelectiveKernel
37+
from .separable_conv import SeparableConv2d, SeparableConvNormAct
38+
from .space_to_depth import SpaceToDepthModule
39+
from .split_attn import SplitAttn
40+
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
41+
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
42+
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
43+
from .trace_utils import _assert, _float_to_int
44+
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
File renamed without changes.

0 commit comments

Comments
 (0)