@@ -1098,14 +1098,27 @@ def maybe_mark_profile(*args, **kwargs):
         torch._dynamo.config.repro_tolerance = tolerance
 
     with maybe_profile(args.export_profiler_trace, **args.profile_details) as p:
-        if args.export_aot_inductor:
-            frozen_model_iter_fn = export_aot_inductor(
-                model, example_inputs, args.inductor_compile_mode
-            )
-        elif args.export_nativert:
-            frozen_model_iter_fn = export_nativert(model, example_inputs)
+        use_generate_mode = kwargs.get("use_generate_mode", False)
+        if use_generate_mode:
+            assert not args.training
+
+            if args.export_aot_inductor:
+                model.forward = export_aot_inductor_simple(
+                    model, example_inputs, args.inductor_compile_mode
+                )
+            elif args.export_nativert:
+                frozen_model_iter_fn = export_nativert(model, example_inputs)
+            else:
+                model.forward = torch._dynamo.run(model)
+
+            frozen_model_iter_fn = model_iter_fn
         else:
-            frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
+            if args.export_aot_inductor:
+                frozen_model_iter_fn = export_aot_inductor(
+                    model, example_inputs, args.inductor_compile_mode
+                )
+            else:
+                frozen_model_iter_fn = torch._dynamo.run(model_iter_fn)
 
         for rep in trange(args.repeat, desc="running benchmark"):
             inputs = (
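The new `use_generate_mode` branch moves the optimization onto the model itself: rather than compiling the benchmark driver (`model_iter_fn`), it patches `model.forward`, so code paths the driver never sees directly (e.g. a HuggingFace-style `generate()` loop that calls `forward` repeatedly) still execute the optimized artifact. A minimal sketch of the forward-patching idea, using `torch.compile` in place of the harness's `export_aot_inductor_simple` / `torch._dynamo.run` plumbing (the toy module is an assumption, not part of the PR):

```python
import torch

class TinyLM(torch.nn.Module):  # hypothetical stand-in for a benchmarked model
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.proj(x))

    @torch.no_grad()
    def generate(self, x, steps=4):
        # a generate()-style loop: calls self.forward internally,
        # bypassing whatever wrapper the benchmark driver compiled
        for _ in range(steps):
            x = self.forward(x)
        return x

model = TinyLM().eval()
# Patch the bound forward with a compiled version; generate() now hits it.
model.forward = torch.compile(model.forward)
out = model.generate(torch.randn(2, 8))
```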
@@ -1120,15 +1133,16 @@ def maybe_mark_profile(*args, **kwargs):
 
             # interleave the runs to handle frequency scaling and load changes
             with maybe_mark_profile(p=p, mark="expected"):
-                timings[rep, 0], expected_output = timed(
-                    model,
-                    model_iter_fn,
-                    inputs,
-                    return_result=True,
-                    times=times,
-                    collect_outputs=args.collect_outputs,
-                    batch_size=kwargs.get("batch_size"),
-                )
+                with torch.compiler.set_stance("force_eager"):
+                    timings[rep, 0], expected_output = timed(
+                        model,
+                        model_iter_fn,
+                        inputs,
+                        return_result=True,
+                        times=times,
+                        collect_outputs=args.collect_outputs,
+                        batch_size=kwargs.get("batch_size"),
+                    )
 
             # call mark_step between the 2 calls to make the comparison fair.
             maybe_mark_step(args)
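Wrapping the baseline measurement in `torch.compiler.set_stance("force_eager")` guarantees the "expected" timing really is eager: in generate mode `model.forward` has already been patched with compiled code, and the stance makes any `torch.compile`-ed function in scope fall back to its original Python implementation. A minimal illustration of the API (usable as a context manager in recent PyTorch):

```python
import torch

@torch.compile
def f(x):
    return torch.sin(x) + torch.cos(x)

x = torch.randn(16)
y_compiled = f(x)  # first call compiles, then runs the optimized graph

# Within this scope the compiled artifact is ignored and the original
# eager implementation runs instead -- ideal for baseline timings.
with torch.compiler.set_stance("force_eager"):
    y_eager = f(x)

torch.testing.assert_close(y_compiled, y_eager)
```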
@@ -1518,8 +1532,12 @@ def opt_nativert(_, example_inputs, collect_outputs=False):
     return opt_nativert
 
 
+def export_aot_inductor_simple(model, example_inputs, mode):
+    return AOTInductorModelCache.load(model, example_inputs, mode)
+
+
 def export_aot_inductor(model, example_inputs, mode):
-    optimized = AOTInductorModelCache.load(model, example_inputs, mode)
+    optimized = export_aot_inductor_simple(model, example_inputs, mode)
 
     def opt_aot_inductor(_, example_inputs, collect_outputs=False):
         example_args, example_kwargs = _normalize_bench_inputs(example_inputs)
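`export_aot_inductor_simple` is factored out so the generate-mode path can fetch the raw AOTInductor-compiled callable and assign it to `model.forward`, while `export_aot_inductor` keeps wrapping it for the driver-function path. `AOTInductorModelCache` is the harness's own cache; for reference, a minimal sketch of the export/compile/load cycle it caches, using the public AOTInductor APIs from recent PyTorch (the toy model and package path are assumptions):

```python
import torch

model = torch.nn.Linear(4, 4).eval()
example_args = (torch.randn(2, 4),)

# Export the model, ahead-of-time compile it with Inductor, and package
# the result as a self-contained .pt2 artifact on disk.
ep = torch.export.export(model, example_args)
package_path = torch._inductor.aoti_compile_and_package(
    ep, package_path="/tmp/linear.pt2"
)

# Load it back as a callable that runs the precompiled kernels.
optimized = torch._inductor.aoti_load_package(package_path)
print(optimized(*example_args).shape)
```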
@@ -2200,11 +2218,12 @@ def record_status(accuracy_status, dynamo_start_stats):
         reset_rng_state()
         model_copy = None
         try:
-            model_copy = self.deepcopy_and_maybe_parallelize(model)
-            self.init_optimizer(name, current_device, model_copy.parameters())
-            correct_result = self.run_n_iterations(
-                model_copy, clone_inputs(example_inputs), self.model_iter_fn
-            )
+            with torch.compiler.set_stance("force_eager"):
+                model_copy = self.deepcopy_and_maybe_parallelize(model)
+                self.init_optimizer(name, current_device, model_copy.parameters())
+                correct_result = self.run_n_iterations(
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
+                )
         except Exception as e:
             accuracy_status = (
                 "eager_1st_run_OOM"
@@ -2221,11 +2240,12 @@ def record_status(accuracy_status, dynamo_start_stats):
         reset_rng_state()
         model_copy = None
         try:
-            model_copy = self.deepcopy_and_maybe_parallelize(model)
-            self.init_optimizer(name, current_device, model_copy.parameters())
-            correct_rerun_result = self.run_n_iterations(
-                model_copy, clone_inputs(example_inputs), self.model_iter_fn
-            )
+            with torch.compiler.set_stance("force_eager"):
+                model_copy = self.deepcopy_and_maybe_parallelize(model)
+                self.init_optimizer(name, current_device, model_copy.parameters())
+                correct_rerun_result = self.run_n_iterations(
+                    model_copy, clone_inputs(example_inputs), self.model_iter_fn
+                )
         except Exception as e:
             accuracy_status = (
                 "eager_2nd_run_OOM"
@@ -2274,6 +2294,11 @@ def record_status(accuracy_status, dynamo_start_stats):
         try:
             model_copy = self.deepcopy_and_maybe_parallelize(model)
             self.init_optimizer(name, current_device, model_copy.parameters())
+
+            use_generate_mode = getattr(self, "use_generate_mode", False)
+            if use_generate_mode:
+                assert not self.args.training
+
             if (
                 self.args.export
                 or self.args.export_aot_inductor
@@ -2286,12 +2311,23 @@ def record_status(accuracy_status, dynamo_start_stats):
                 optimized_model_iter_fn = optimize_ctx(
                     model_copy, example_inputs
                 )
-                new_result = optimized_model_iter_fn(model_copy, example_inputs)
+                if use_generate_mode:
+                    new_result = self.model_iter_fn(model_copy, example_inputs)
+                else:
+                    new_result = optimized_model_iter_fn(
+                        model_copy, example_inputs
+                    )
             else:
-                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
-                new_result = self.run_n_iterations(
-                    model_copy, example_inputs, optimized_model_iter_fn
-                )
+                if use_generate_mode:
+                    optimized_model = optimize_ctx(model_copy)
+                    new_result = self.run_n_iterations(
+                        optimized_model, example_inputs, self.model_iter_fn
+                    )
+                else:
+                    optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+                    new_result = self.run_n_iterations(
+                        model_copy, example_inputs, optimized_model_iter_fn
+                    )
         except Exception as e:
             log.exception("")
             print(
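The accuracy check now mirrors the two application styles of `optimize_ctx`: in generate mode the backend is applied to the *module* and the unmodified `self.model_iter_fn` drives it; otherwise the *driver function* is wrapped as before. A minimal sketch of the two call shapes, with `torch.compile` standing in for `optimize_ctx` (names are hypothetical):

```python
import torch

def model_iter_fn(mod, inputs):
    # one benchmark iteration; real drivers also handle loss/backward
    return mod(*inputs)

mod = torch.nn.Linear(8, 8)
inputs = (torch.randn(2, 8),)

# Generate mode: optimize the module, keep the driver eager.
out_a = model_iter_fn(torch.compile(mod), inputs)

# Default mode: optimize the driver, keep the module as-is.
out_b = torch.compile(model_iter_fn)(mod, inputs)

torch.testing.assert_close(out_a, out_b)
```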
@@ -2507,14 +2543,22 @@ def warmup(fn, model, example_inputs, mode, niters=10):
                 self.model_iter_fn, model, example_inputs, "eager", niters=1
             )
 
-        baseline_timings = experiment(
-            model, example_inputs, mark="expected", **experiment_kwargs
-        )
+        with torch.compiler.set_stance("force_eager"):
+            baseline_timings = experiment(
+                model, example_inputs, mark="expected", **experiment_kwargs
+            )
+
+        use_generate_mode = getattr(self, "use_generate_mode", False)
+        if use_generate_mode:
+            assert not self.args.training
+            optimized_model_iter_fn = self.model_iter_fn
+            model = optimize_ctx(model)
 
-        if self.args.export_aot_inductor:
-            optimized_model_iter_fn = optimize_ctx
         else:
-            optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+            if self.args.export_aot_inductor:
+                optimized_model_iter_fn = optimize_ctx
+            else:
+                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
 
         with maybe_snapshot_memory(
             self.args.snapshot_memory, f"compiled_{self.args.only}"
@@ -2662,22 +2706,34 @@ def warmup(fn, model, example_inputs, mode, niters=5):
         with maybe_snapshot_memory(
             self.args.snapshot_memory, f"eager_{self.args.only}"
         ):
-            eager_latency, eager_peak_mem, _ = warmup(
-                self.model_iter_fn, copy.deepcopy(model), example_inputs, "eager"
-            )
-            if self.args.use_warm_peak_memory:
-                _, eager_peak_mem, _ = warmup(
+            with torch.compiler.set_stance("force_eager"):
+                eager_latency, eager_peak_mem, _ = warmup(
                     self.model_iter_fn,
                     copy.deepcopy(model),
                     example_inputs,
                     "eager",
-                    niters=1,
                 )
+                if self.args.use_warm_peak_memory:
+                    _, eager_peak_mem, _ = warmup(
+                        self.model_iter_fn,
+                        copy.deepcopy(model),
+                        example_inputs,
+                        "eager",
+                        niters=1,
+                    )
+
+        use_generate_mode = getattr(self, "use_generate_mode", False)
+        experiment_kwargs["use_generate_mode"] = use_generate_mode
+        if use_generate_mode:
+            assert not self.args.training
+            optimized_model_iter_fn = self.model_iter_fn
+            model = optimize_ctx(model)
 
-        if self.args.export_aot_inductor or self.args.export_nativert:
-            optimized_model_iter_fn = optimize_ctx
         else:
-            optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
+            if self.args.export_aot_inductor or self.args.export_nativert:
+                optimized_model_iter_fn = optimize_ctx
+            else:
+                optimized_model_iter_fn = optimize_ctx(self.model_iter_fn)
 
         with maybe_snapshot_memory(
             self.args.snapshot_memory, f"compiled_{self.args.only}"
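Besides switching how `optimize_ctx` is applied, this hunk threads the flag into `experiment_kwargs`, which is how the `speedup_experiment` hunk at the top of this diff receives `use_generate_mode` through `kwargs`. A condensed, hypothetical view of that plumbing (all names below are illustrative stand-ins for the harness's own functions):

```python
import torch

def model_iter_fn(mod, inputs):
    return mod(*inputs)

def speedup_experiment(model, inputs, **kwargs):
    # mirrors the kwargs.get("use_generate_mode", False) read above
    if kwargs.get("use_generate_mode", False):
        return model_iter_fn(model, inputs)  # model was optimized up front
    return torch.compile(model_iter_fn)(model, inputs)

model = torch.compile(torch.nn.Linear(4, 4))  # optimized once, before measuring
experiment_kwargs = {"use_generate_mode": True}
out = speedup_experiment(model, (torch.randn(2, 4),), **experiment_kwargs)
```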