18
18
from contextlib import suppress
19
19
from functools import partial
20
20
21
+ try :
22
+ from deepspeed .profiling .flops_profiler import get_model_profile
23
+ except ImportError as e :
24
+ get_model_profile = None
25
+
21
26
from timm .models import create_model , is_model , list_models
22
27
from timm .optim import create_optimizer_v2
23
28
from timm .data import resolve_data_config
67
72
metavar = 'N' , help = 'Input image dimension, uses model default if empty' )
68
73
parser .add_argument ('--input-size' , default = None , nargs = 3 , type = int ,
69
74
metavar = 'N N N' , help = 'Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty' )
75
+ parser .add_argument ('--use-train-size' , action = 'store_true' , default = False ,
76
+ help = 'Run inference at train size, not test-input-size if it exists.' )
70
77
parser .add_argument ('--num-classes' , type = int , default = None ,
71
78
help = 'Number classes in dataset' )
72
79
parser .add_argument ('--gp' , default = None , type = str , metavar = 'POOL' ,
81
88
help = 'convert model torchscript for inference' )
82
89
83
90
91
+
84
92
# train optimizer parameters
85
93
parser .add_argument ('--opt' , default = 'sgd' , type = str , metavar = 'OPTIMIZER' ,
86
94
help = 'Optimizer (default: "sgd"' )
@@ -139,10 +147,25 @@ def resolve_precision(precision: str):
139
147
return use_amp , model_dtype , data_dtype
140
148
141
149
150
def profile(model, input_size=(3, 224, 224), batch_size=1):
    """Measure model MACs with the DeepSpeed FLOPs profiler.

    Requires deepspeed to be importable; callers should check that
    ``get_model_profile is not None`` before calling (the module-level
    import is guarded by a try/except ImportError).

    Args:
        model: the (already constructed) torch model to profile.
        input_size: per-sample input shape (C, H, W); the batch dimension
            is prepended internally.
        batch_size: number of samples in the profiled input (default 1,
            matching the previous hard-coded behavior).

    Returns:
        Total multiply-accumulate operations (MACs) as a raw number.
    """
    macs, params = get_model_profile(
        model=model,
        # input shape, or the argument passed to input_constructor if one is given
        input_res=(batch_size,) + tuple(input_size),
        input_constructor=None,  # if specified, a constructor taking input_res is used to build model input
        print_profile=False,  # print the model graph with the measured profile attached to each module
        detailed=False,  # print the detailed profile
        warm_up=10,  # number of warm-up passes before measuring each module
        as_string=False,  # return raw numbers (e.g. 1000) rather than human-readable strings (e.g. 1k)
        output_file=None,  # path to an output file; if None, the profiler prints to stdout
        ignore_modules=None,  # list of modules to ignore while profiling
    )
    # params is unused here; param count is reported separately by the benchmark runner
    return macs
163
+
164
+
142
165
class BenchmarkRunner :
143
166
def __init__ (
144
167
self , model_name , detail = False , device = 'cuda' , torchscript = False , precision = 'float32' ,
145
- num_warm_iter = 10 , num_bench_iter = 50 , ** kwargs ):
168
+ num_warm_iter = 10 , num_bench_iter = 50 , use_train_size = False , ** kwargs ):
146
169
self .model_name = model_name
147
170
self .detail = detail
148
171
self .device = device
@@ -166,7 +189,7 @@ def __init__(
166
189
if torchscript :
167
190
self .model = torch .jit .script (self .model )
168
191
169
- data_config = resolve_data_config (kwargs , model = self .model , use_test_size = True )
192
+ data_config = resolve_data_config (kwargs , model = self .model , use_test_size = not use_train_size )
170
193
self .input_size = data_config ['input_size' ]
171
194
self .batch_size = kwargs .pop ('batch_size' , 256 )
172
195
@@ -234,6 +257,10 @@ def _step():
234
257
param_count = round (self .param_count / 1e6 , 2 ),
235
258
)
236
259
260
+ if get_model_profile is not None :
261
+ macs = profile (self .model , self .input_size )
262
+ results ['GMACs' ] = round (macs / 1e9 , 2 )
263
+
237
264
_logger .info (
238
265
f"Inference benchmark of { self .model_name } done. "
239
266
f"{ results ['samples_per_sec' ]:.2f} samples/sec, { results ['step_time' ]:.2f} ms/step" )
0 commit comments