diff --git a/README.md b/README.md
index b95d5b3..ce9df9e 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
-## VAD: Vectorized Scene Representation for Efficient Autonomous Driving
+## VAD v1 & v2
+
+[project page](https://hgao-cv.github.io/VADv2/)
 
 https://user-images.githubusercontent.com/45144254/229673708-648e8da5-4c70-4346-9da2-423447d1ecde.mp4
 
@@ -7,7 +9,7 @@ https://github.com/hustvl/VAD/assets/45144254/153b9bf0-5159-46b5-9fab-573baf5c61
 
 > [**VAD: Vectorized Scene Representation for Efficient Autonomous Driving**](https://arxiv.org/abs/2303.12077)
 >
-> [Bo Jiang](https://github.com/rb93dett)1\*, [Shaoyu Chen](https://scholar.google.com/citations?user=PIeNN2gAAAAJ&hl=en&oi=sra)1\*, Qing Xu2, [Bencheng Liao](https://github.com/LegendBC)1, Jiajie Chen2, [Helong Zhou](https://scholar.google.com/citations?user=wkhOMMwAAAAJ&hl=en&oi=ao)2, [Qian Zhang](https://scholar.google.com/citations?user=pCY-bikAAAAJ&hl=zh-CN)2, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu/)1, [Chang Huang](https://scholar.google.com/citations?user=IyyEKyIAAAAJ&hl=zh-CN)2, [Xinggang Wang](https://xinggangw.info/)1,†
+> [Bo Jiang](https://github.com/rb93dett)1\*, [Shaoyu Chen](https://scholar.google.com/citations?user=PIeNN2gAAAAJ&hl=en&oi=sra)1\*, Qing Xu2, [Bencheng Liao](https://github.com/LegendBC)1, Jiajie Chen2, [Helong Zhou](https://scholar.google.com/citations?user=wkhOMMwAAAAJ&hl=en&oi=ao)2, [Qian Zhang](https://scholar.google.com/citations?user=pCY-bikAAAAJ&hl=zh-CN)2, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu/)1, [Chang Huang](https://scholar.google.com/citations?user=IyyEKyIAAAAJ&hl=zh-CN)2, [Xinggang Wang](https://xwcv.github.io/)1,†
 >
 > 1 Huazhong University of Science and Technology, 2 Horizon Robotics
 >
@@ -16,6 +18,12 @@ https://github.com/hustvl/VAD/assets/45144254/153b9bf0-5159-46b5-9fab-573baf5c61
 >[arXiv Paper](https://arxiv.org/abs/2303.12077), ICCV 2023
 
 ## News
+* **`27 Feb, 2025`:** Check out our latest work, [DiffusionDrive](https://github.com/hustvl/DiffusionDrive), accepted to CVPR 2025! It explores multi-modal end-to-end driving using diffusion models for real-time, real-world applications.
+* **`19 Feb, 2025`:** Check out our new work [RAD](https://hgao-cv.github.io/RAD) 🥰, end-to-end autonomous driving with large-scale 3DGS-based reinforcement-learning post-training.
+* **`30 Oct, 2024`:** Check out our new work [Senna](https://github.com/hustvl/Senna) 🥰, which combines VAD/VADv2 with large vision-language models to achieve more accurate, robust, and generalizable autonomous driving planning.
+* **`20 Sep, 2024`:** The core code of VADv2 (config and model) is available in the `VADv2` folder. It is easy to integrate into the VADv1 framework for training and inference; see the sketch below.
+* **`17 June, 2024`:** A CARLA implementation of VADv1 is available in [Bench2Drive](https://github.com/Thinklab-SJTU/Bench2Drive?tab=readme-ov-file).
+* **`20 Feb, 2024`:** The VADv2 paper is available on [arXiv](https://arxiv.org/pdf/2402.13243); see also the [project page](https://hgao-cv.github.io/VADv2/).
 * **`1 Aug, 2023`:** Code & models are released!
 * **`14 July, 2023`:** VAD is accepted by ICCV 2023🎉! Code and models will be open-sourced soon!
 * **`21 Mar, 2023`:** We release the VAD paper on [arXiv](https://arxiv.org/abs/2303.12077). Code/Models are coming soon. Please stay tuned!
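For reference, the integration mentioned above amounts to building the model from the released config inside a VADv1-style mmdet3d codebase. A minimal sketch, assuming the `VADv2` folder has been copied into a VADv1 checkout and its modules are importable via the `plugin_dir` mechanism used in the config below (the exact entry point depends on your checkout):

```python
# Minimal sketch, assuming VADv2/ sits inside a VADv1 checkout and the
# projects/mmdet3d_plugin plugin directory is importable.
from mmcv import Config
from mmdet3d.models import build_model

cfg = Config.fromfile('VADv2/VADv2_config_voca4096.py')
# train_cfg/test_cfg are nested inside cfg.model in this config
model = build_model(cfg.model)
```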
☕️
@@ -83,6 +91,13 @@ If you find VAD useful in your research or applications, please consider giving
   journal={ICCV},
   year={2023}
 }
+
+@article{chen2024vadv2,
+  title={VADv2: End-to-End Vectorized Autonomous Driving via Probabilistic Planning},
+  author={Chen, Shaoyu and Jiang, Bo and Gao, Hao and Liao, Bencheng and Xu, Qing and Zhang, Qian and Huang, Chang and Liu, Wenyu and Wang, Xinggang},
+  journal={arXiv preprint arXiv:2402.13243},
+  year={2024}
+}
 ```
 
 ## License
diff --git a/VADv2/VADv2_config_voca4096.py b/VADv2/VADv2_config_voca4096.py
new file mode 100644
index 0000000..1a138cb
--- /dev/null
+++ b/VADv2/VADv2_config_voca4096.py
@@ -0,0 +1,550 @@
+_base_ = [
+    '../datasets/custom_nus-3d.py',
+    '../_base_/default_runtime.py'
+]
+#
+plugin = True
+plugin_dir = 'projects/mmdet3d_plugin/'
+
+# If point cloud range is changed, the models should also change their point
+# cloud range accordingly
+point_cloud_range = [-15.0, -30.0, -2.0, 15.0, 30.0, 2.0]
+voxel_size = [0.15, 0.15, 4]
+
+img_norm_cfg = dict(
+    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+# 4-class detection for the CARLA setting
+class_names = ['car', 'bicycle', 'motorcycle', 'pedestrian']
+num_classes = len(class_names)
+
+# map classes used here: lane_divider, road_edge, crosswalk, centerline
+# map_classes = ['lane_divider','road_edge','crosswalk']
+# map_classes = ['road_edge', 'crosswalk']
+map_classes = ['lane_divider', 'road_edge', 'crosswalk', 'centerline']
+
+# fixed_ptsnum_per_line = 20
+# map_classes = ['divider',]
+map_num_vec = 100
+map_fixed_ptsnum_per_gt_line = 20  # currently only fixed_pts > 0 is supported
+map_fixed_ptsnum_per_pred_line = 20
+map_eval_use_same_gt_sample_num_flag = True
+map_num_classes = len(map_classes)
+
+input_modality = dict(
+    use_lidar=False,
+    use_camera=True,
+    use_radar=False,
+    use_map=False,
+    use_external=True)
+
+_dim_ = 256
+_pos_dim_ = _dim_//2
+_ffn_dim_ = _dim_*2
+_num_levels_ = 4
+bev_h_ = 200
+bev_w_ = 200
+queue_length = 4  # each sequence contains `queue_length` frames.
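As a quick sanity check on the geometry above (a standalone sketch, not part of the config): deriving `real_h`/`real_w` from `point_cloud_range` the same way the head code later in this diff does shows that each BEV cell covers 0.3 m longitudinally and 0.15 m laterally, consistent with `voxel_size`.

```python
# Standalone sketch: BEV cell size implied by point_cloud_range and the
# 200 x 200 BEV grid configured above.
point_cloud_range = [-15.0, -30.0, -2.0, 15.0, 30.0, 2.0]
bev_h_, bev_w_ = 200, 200

real_w = point_cloud_range[3] - point_cloud_range[0]  # 30 m laterally (x)
real_h = point_cloud_range[4] - point_cloud_range[1]  # 60 m longitudinally (y)
print(real_h / bev_h_, real_w / bev_w_)  # 0.3 0.15 -> meters per BEV cell
```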
+# total_epochs = 12 + +# camera view order +view_names = ['rgb_front', 'rgb_front_right', 'rgb_front_left', + 'rgb_rear', 'rgb_rear_left', 'rgb_rear_right'] + +dataset_type = 'v3ADTRCustomCarlaDataset' +data_root = 'data/carladata/v2/pkl/' +file_client_args = dict(backend='disk') + +# open-loop test param +# route_id = 6 +test_data_root = 'data/carladata/v2/pkl/' +test_pkl = 'town05_long_new.pkl' #'carladata_v18_selected.pkl' #'town05_long_new.pkl' # 'carladata_town05long.pkl' #'carladata_town05long.pkl' #'carla_minival_v3.pkl' +gt_anno_file = 'test_record/carla/v2' +map_ann_file = gt_anno_file + '/test_map.json' +agent_ann_file = gt_anno_file + '/test_agent.json' +eval_detection_configs_path = 'projects/mmdet3d_plugin/datasets/carladata_eval_detection_configs.json' + +# open-loop train param +train_pkl = 'carladata_v14_part0.pkl' #'carladata_v11.pkl' # carladata_v6+v9.pkl' + + +model = dict( + type='v116ADTR', + use_grid_mask=True, + video_test_mode=True, + pretrained=dict(img='torchvision://resnet50'), + img_backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + img_neck=dict( + type='FPN', + in_channels=[512, 1024, 2048], + out_channels=_dim_, + start_level=0, + add_extra_convs='on_output', + num_outs=_num_levels_, + relu_before_extra_convs=True), + pts_bbox_head=dict( + type='v116ADTRHead', + mot_map_thresh=0.5, + mot_dis_thresh=0.2, + pe_normalization=True, + plan_fut_mode=256, #1024 + plan_fut_mode_testing=4096, + tot_epoch=None, + ego_query_thresh=0.0, + query_use_fix_pad=False, + ego_lcf_feat_idx=[0,1,4], + valid_fut_ts=6, + plan_anchors_path='carla_plan_vocabulary_4096.npy', #'./gt_trajs.npy', + ego_pv_decoder=dict( + type='v0MotionTransformerDecoder', + num_layers=1, + return_intermediate=False, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('cross_attn', 'norm', 'ffn', 'norm'))), + ego_agent_decoder=dict( + type='v0MotionTransformerDecoder', + num_layers=1, + return_intermediate=False, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('cross_attn', 'norm', 'ffn', 'norm'))), + ego_map_decoder=dict( + type='v0MotionTransformerDecoder', + num_layers=1, + return_intermediate=False, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('cross_attn', 'norm', 'ffn', 'norm'))), + cf_backbone_ckpt='ckpts/resnet50-0676ba61.pth', + cf_backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(3, ), + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch'), + mot_decoder=dict( + type='v0MotionTransformerDecoder', + num_layers=1, + return_intermediate=False, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('cross_attn', 'norm', 'ffn', 'norm'))), + mot_map_decoder=dict( + 
type='v0MotionTransformerDecoder', + num_layers=1, + return_intermediate=False, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('cross_attn', 'norm', 'ffn', 'norm'))), + interaction_pe_type='sine_mlp', + bev_h=bev_h_, + bev_w=bev_w_, + num_query=300, + num_classes=num_classes, + in_channels=_dim_, + sync_cls_avg_factor=True, + with_box_refine=True, + as_two_stage=False, + map_num_vec=map_num_vec, + map_num_classes=map_num_classes, + map_num_pts_per_vec=map_fixed_ptsnum_per_pred_line, + map_num_pts_per_gt_vec=map_fixed_ptsnum_per_gt_line, + map_query_embed_type='instance_pts', + map_transform_method='minmax', + map_gt_shift_pts_pattern='v2', + map_dir_interval=1, + map_code_size=2, + map_code_weights=[1.0, 1.0, 1.0, 1.0], + transformer=dict( + type='v0PerceptionTransformer', + map_num_vec=map_num_vec, + map_num_pts_per_vec=map_fixed_ptsnum_per_pred_line, + rotate_prev_bev=True, + use_shift=True, + use_can_bus=True, + embed_dims=_dim_, + encoder=dict( + type='BEVFormerEncoder', + num_layers=6, + pc_range=point_cloud_range, + num_points_in_pillar=4, + return_intermediate=False, + transformerlayers=dict( + type='BEVFormerLayer', + attn_cfgs=[ + dict( + type='TemporalSelfAttention', + embed_dims=_dim_, + num_levels=1), + dict( + type='SpatialCrossAttention', + pc_range=point_cloud_range, + deformable_attention=dict( + type='MSDeformableAttention3D', + embed_dims=_dim_, + num_points=8, + num_levels=_num_levels_), + embed_dims=_dim_, + ) + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm'))), + decoder=dict( + type='DetectionTransformerDecoder', + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + dict( + type='CustomMSDeformableAttention', + embed_dims=_dim_, + num_levels=1), + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm'))), + map_decoder=dict( + type='v0MapDetectionTransformerDecoder', + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=_dim_, + num_heads=8, + dropout=0.1), + dict( + type='CustomMSDeformableAttention', + embed_dims=_dim_, + num_levels=1), + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm')))), + bbox_coder=dict( + type='v2CustomNMSFreeCoder', + post_center_range=[-20, -35, -10.0, 20, 35, 10.0], + pc_range=point_cloud_range, + max_num=100, + voxel_size=voxel_size, + num_classes=num_classes), + map_bbox_coder=dict( + type='v0MapNMSFreeCoder', + post_center_range=[-20, -35, -20, -35, 20, 35, 20, 35], + pc_range=point_cloud_range, + max_num=50, + voxel_size=voxel_size, + num_classes=map_num_classes), + positional_encoding=dict( + type='LearnedPositionalEncoding', + num_feats=_pos_dim_, + row_num_embed=bev_h_, + col_num_embed=bev_w_, + ), + loss_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.8), + loss_bbox=dict(type='L1Loss', loss_weight=0.1), + loss_mot_reg=dict(type='L1Loss', loss_weight=0.1), + loss_mot_cls=dict( + 
type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.2), + loss_iou=dict(type='GIoULoss', loss_weight=0.0), + loss_map_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.8), + loss_map_bbox=dict(type='L1Loss', loss_weight=0.0), + loss_map_iou=dict(type='GIoULoss', loss_weight=0.0), + loss_map_pts=dict(type='PtsL1Loss', loss_weight=0.4), + loss_map_dir=dict(type='PtsDirCosLoss', loss_weight=0.005), + loss_plan_cls_col=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.0), + loss_plan_cls_bd=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.0), + loss_plan_cls_cl=dict(type='L1Loss', loss_weight=0.0), + # loss_plan_cls_cl=dict( + # type='FocalLoss', + # use_sigmoid=True, + # gamma=2.0, + # alpha=0.25, + # loss_weight=2.0), + loss_plan_cls_expert=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=200.), + loss_plan_reg=dict(type='L1Loss', loss_weight=0.0), + loss_plan_bound=dict(type='PlanMapBoundLoss', loss_weight=0.0, lane_bound_cls_idx=0), + loss_plan_agent_dis=dict(type='PlanAgentDisLoss', loss_weight=0.0), + loss_plan_map_theta=dict(type='PlanMapThetaLoss', loss_weight=0.0, lane_div_cls_idx=0), # fake idx + loss_tl_status_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=0.8, + class_weight=None), + loss_tl_trigger_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=4., + class_weight=None), + loss_stopsign_trigger_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0, + class_weight=None)), + + # model training and testing settings + train_cfg=dict(pts=dict( + grid_size=[512, 512, 1], + voxel_size=voxel_size, + point_cloud_range=point_cloud_range, + out_size_factor=4, + assigner=dict( + type='HungarianAssigner3D', + cls_cost=dict(type='FocalLossCost', weight=0.8), + reg_cost=dict(type='BBox3DL1Cost', weight=0.1), + iou_cost=dict(type='IoUCost', weight=0.0), # Fake cost. This is just to make it compatible with DETR head. + pc_range=point_cloud_range), + map_assigner=dict( + type='v0MapHungarianAssigner3D', + cls_cost=dict(type='FocalLossCost', weight=0.8), + reg_cost=dict(type='BBoxL1Cost', weight=0.0, box_format='xywh'), + # reg_cost=dict(type='BBox3DL1Cost', weight=0.25), + # iou_cost=dict(type='IoUCost', weight=1.0), # Fake cost. This is just to make it compatible with DETR head. 
+ iou_cost=dict(type='IoUCost', iou_mode='giou', weight=0.0), + pts_cost=dict(type='OrderedPtsL1Cost', weight=0.4), + pc_range=point_cloud_range)))) + +train_pipeline = [ + dict(type='LoadMultiViewImageFromCarla', to_float32=True), + dict(type='PhotoMetricDistortionMultiViewImage'), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=True), + dict(type='CustomObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='CustomObjectNameFilter', classes=class_names), + dict(type='NormalizeMultiviewImage', **img_norm_cfg), + dict(type='RandomScaleImageMultiViewImage', scales=[0.5]), + dict(type='PadMultiViewImage', size_divisor=32), + dict(type='CustomDefaultFormatBundle3DFromCarla', class_names=class_names, with_ego=True), + dict(type='CustomCollect3D',\ + keys=['gt_bboxes_3d', 'gt_labels_3d', 'img', 'ego_his_trajs', + 'ego_fut_trajs', 'ego_fut_masks', 'ego_fut_cmd', + 'ego_lcf_feat', 'gt_attr_labels', 'command_id', 'target_point', 'traffic_signal', 'stop_sign_signal']) + +] + +test_pipeline = [ + dict(type='LoadMultiViewImageFromCarla', to_float32=True), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True, with_attr_label=True), + dict(type='CustomObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='CustomObjectNameFilter', classes=class_names), + dict(type='NormalizeMultiviewImage', **img_norm_cfg), + # dict(type='PadMultiViewImage', size_divisor=32), + dict( + type='MultiScaleFlipAug3D', + img_scale=(416, 320), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict(type='RandomScaleImageMultiViewImage', scales=[0.5]), + dict(type='PadMultiViewImage', size_divisor=32), + dict(type='CustomDefaultFormatBundle3DFromCarla', + class_names=class_names, with_label=False, with_ego=True), + dict(type='CustomCollect3D',\ + keys=['gt_bboxes_3d', 'gt_labels_3d', 'img', 'fut_valid_flag', + 'ego_his_trajs', 'ego_fut_trajs', 'ego_fut_masks', 'ego_fut_cmd', + 'ego_lcf_feat', 'gt_attr_labels', 'command_id', 'target_point', 'traffic_signal', 'stop_sign_signal'])]) # 'traffic_signal', 'stop_sign_signal' + +] + +data = dict( + samples_per_gpu=1, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + ann_file=data_root + train_pkl, + pipeline=train_pipeline, + classes=class_names, + modality=input_modality, + test_mode=False, + use_valid_flag=True, + bev_size=(bev_h_, bev_w_), + pc_range=point_cloud_range, + queue_length=queue_length, + map_classes=map_classes, + map_fixed_ptsnum_per_line=map_fixed_ptsnum_per_gt_line, + map_eval_use_same_gt_sample_num_flag=map_eval_use_same_gt_sample_num_flag, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. 
+ box_type_3d='LiDAR', + view_names = view_names, + eval_detection_configs_path=eval_detection_configs_path, + agent_gt_range=point_cloud_range, + map_gt_range=point_cloud_range, + base_path=data_root + ), + val=dict( + type=dataset_type, + data_root=test_data_root, + pc_range=point_cloud_range, + ann_file=test_data_root + test_pkl, + pipeline=test_pipeline, bev_size=(bev_h_, bev_w_), + classes=class_names, modality=input_modality, samples_per_gpu=1, + map_classes=map_classes, + map_ann_file=map_ann_file, + agent_ann_file=agent_ann_file, + map_fixed_ptsnum_per_line=map_fixed_ptsnum_per_gt_line, + map_eval_use_same_gt_sample_num_flag=map_eval_use_same_gt_sample_num_flag, + use_pkl_result=True, + view_names=view_names, + eval_detection_configs_path=eval_detection_configs_path, + agent_gt_range=point_cloud_range, + map_gt_range=point_cloud_range, + base_path=test_data_root + ), + test=dict( + type=dataset_type, + data_root=test_data_root, + pc_range=point_cloud_range, + ann_file=test_data_root + test_pkl, + pipeline=test_pipeline, bev_size=(bev_h_, bev_w_), + classes=class_names, modality=input_modality, samples_per_gpu=1, + map_classes=map_classes, + map_ann_file=map_ann_file, + agent_ann_file=agent_ann_file, + map_fixed_ptsnum_per_line=map_fixed_ptsnum_per_gt_line, + map_eval_use_same_gt_sample_num_flag=map_eval_use_same_gt_sample_num_flag, + use_pkl_result=True, + eval_detection_configs_path=eval_detection_configs_path, + agent_gt_range=point_cloud_range, + map_gt_range=point_cloud_range, + base_path=test_data_root + ), + shuffler_sampler=dict(type='DistributedGroupSampler'), + nonshuffler_sampler=dict(type='DistributedSampler') +) + +optimizer = dict( + type='AdamW', + lr=1e-4, + paramwise_cfg=dict( + custom_keys={ + 'img_backbone': dict(lr_mult=0.1), + }), + weight_decay=0.01) + +optimizer_config = dict(grad_clip=dict(max_norm=35, norm_type=2)) +# learning policy +lr_config = dict( + policy='CosineAnnealing', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + min_lr_ratio=1e-2) + +# evaluation = dict(interval=total_epochs, pipeline=test_pipeline, metric='bbox', map_metric='chamfer') + +# runner = dict(type='EpochBasedRunner', max_epochs=total_epochs) +runner = dict(type='IterBasedRunner', max_iters=120000) + +# load_from = 'ckpts/v0cPIPP_m_sr_20pts_pp_pretrain_e4.pth' +load_from = 'v116_datav18_q12_15000.pth' #'samplingnotlanefollow_iter_40000.pth' #'default_ckpt.pth' +# resume_from = 'ft_datav16_60000.pth' + +log_config = dict( + interval=100, + hooks=[ + dict(type='TextLoggerHook'), + dict(type='TensorboardLoggerHook') + ]) +# fp16 = dict(loss_scale=512.) 
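For intuition, here is a rough, self-contained sketch of the learning-rate curve that `optimizer` and `lr_config` above request: linear warmup from `lr/3` over 500 iterations, then cosine annealing toward `lr * 1e-2` over the 120k-iteration run. The real curve is produced by mmcv's LR updater hooks; this only illustrates the shape.

```python
# Rough sketch of the configured LR schedule; mmcv's CosineAnnealing hook is
# the actual implementation, this only reproduces its shape.
import math

base_lr, max_iters = 1e-4, 120000
warmup_iters, warmup_ratio, min_lr_ratio = 500, 1.0 / 3, 1e-2

def lr_at(i):
    if i < warmup_iters:  # linear warmup from base_lr * warmup_ratio
        return base_lr * (warmup_ratio + (1 - warmup_ratio) * i / warmup_iters)
    t = (i - warmup_iters) / (max_iters - warmup_iters)
    min_lr = base_lr * min_lr_ratio
    return min_lr + (base_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * t))

print(lr_at(0), lr_at(500), lr_at(119999))  # ~3.3e-05, 1e-04, ~1e-06
# img_backbone parameters additionally use lr_mult=0.1, i.e. one tenth of these.
```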
+# find_unused_parameters = True
+checkpoint_config = dict(interval=1000, max_keep_ckpts=2)
+
+
+custom_hooks = [dict(type='CustomSetEpochInfoHook')]
\ No newline at end of file
diff --git a/VADv2/VADv2_head.py b/VADv2/VADv2_head.py
new file mode 100644
index 0000000..c39a1a1
--- /dev/null
+++ b/VADv2/VADv2_head.py
@@ -0,0 +1,2941 @@
+import copy
+from math import pi, cos, sin
+import cv2
+import numpy as np
+import torch
+import random
+import torch.nn as nn
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+from skimage.draw import polygon
+from mmdet.models import HEADS, build_loss
+from mmdet.models.dense_heads import DETRHead
+from mmcv.runner import force_fp32, load_checkpoint
+from mmcv.utils import TORCH_VERSION, digit_version
+from mmdet.core import build_assigner, build_sampler
+from mmdet3d.core.bbox.coders import build_bbox_coder
+from mmdet.models.utils.transformer import inverse_sigmoid
+from mmdet.core.bbox.transforms import bbox_xyxy_to_cxcywh
+from mmcv.cnn import Linear, bias_init_with_prob, xavier_init
+from mmdet.core import multi_apply, reduce_mean
+from mmdet3d.models.builder import build_backbone
+from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
+
+from projects.mmdet3d_plugin.core.bbox.util import normalize_bbox
+from projects.mmdet3d_plugin.PIPP.utils.map_utils import (
+    normalize_2d_pts, normalize_2d_bbox, denormalize_2d_pts, denormalize_2d_bbox
+)
+from projects.mmdet3d_plugin.PIPP.utils.functional import pos2posemb2d
+from projects.mmdet3d_plugin.PIPP.utils.plan_loss import segments_intersect
+from shapely.geometry import LineString
+
+
+# pos_idx_cnt = [0] * 256
+class MLP(nn.Module):
+    def __init__(self, in_channels, hidden_unit, verbose=False):
+        super(MLP, self).__init__()
+        self.mlp = nn.Sequential(
+            nn.Linear(in_channels, hidden_unit),
+            nn.LayerNorm(hidden_unit),
+            nn.ReLU()
+        )
+
+    def forward(self, x):
+        x = self.mlp(x)
+        return x
+
+class LaneNet(nn.Module):
+    def __init__(self, in_channels, hidden_unit, num_subgraph_layers):
+        super(LaneNet, self).__init__()
+        self.num_subgraph_layers = num_subgraph_layers
+        self.layer_seq = nn.Sequential()
+        for i in range(num_subgraph_layers):
+            self.layer_seq.add_module(
+                f'lmlp_{i}', MLP(in_channels, hidden_unit))
+            in_channels = hidden_unit*2
+
+    def forward(self, pts_lane_feats):
+        '''
+        Extract lane_feature from vectorized lane representation
+
+        Args:
+            pts_lane_feats: [batch size, max_pnum, pts, D]
+
+        Returns:
+            inst_lane_feats: [batch size, max_pnum, D]
+        '''
+        x = pts_lane_feats
+        for name, layer in self.layer_seq.named_modules():
+            if isinstance(layer, MLP):
+                # x: [bs, max_lane_num, num_pts, dim]
+                x = layer(x)
+                x_max = torch.max(x, -2)[0]
+                x_max = x_max.unsqueeze(2).repeat(1, 1, x.shape[2], 1)
+                x = torch.cat([x, x_max], dim=-1)
+        x_max = torch.max(x, -2)[0]
+        return x_max
+
+
+@HEADS.register_module()
+class v116ADTRHead(DETRHead):
+    """Head of VADv2: detection, online mapping, motion prediction and planning.
+    Args:
+        with_box_refine (bool): Whether to refine the reference points
+            in the decoder. Defaults to False.
+        as_two_stage (bool) : Whether to generate the proposal from
+            the outputs of encoder.
+        transformer (obj:`ConfigDict`): ConfigDict is used for building
+            the Encoder and Decoder.
+        bev_h, bev_w (int): spatial shape of BEV queries.
+ """ + # NOTE: already support map + def __init__(self, + *args, + with_box_refine=False, + as_two_stage=False, + transformer=None, + bbox_coder=None, + num_cls_fcs=2, + code_weights=None, + bev_h=30, + bev_w=30, + fut_ts=6, + mot_fut_mode=6, + loss_mot_reg=dict(type='L1Loss', loss_weight=0.25), + loss_mot_cls=dict( + type='FocalLoss', + use_sigmoid=True, + gamma=2.0, + alpha=0.25, + loss_weight=0.8), + map_bbox_coder=None, + map_num_query=900, + map_num_classes=3, + map_num_vec=20, + map_num_pts_per_vec=2, + map_num_pts_per_gt_vec=2, + map_query_embed_type='all_pts', + map_transform_method='minmax', + map_gt_shift_pts_pattern='v0', + map_dir_interval=1, + map_code_size=None, + map_code_weights=None, + loss_map_cls=dict( + type='CrossEntropyLoss', + bg_cls_weight=0.1, + use_sigmoid=False, + loss_weight=1.0, + class_weight=1.0), + loss_map_bbox=dict(type='L1Loss', loss_weight=5.0), + loss_map_iou=dict(type='GIoULoss', loss_weight=2.0), + loss_map_pts=dict( + type='ChamferDistance',loss_src_weight=1.0,loss_dst_weight=1.0 + ), + loss_map_dir=dict(type='PtsDirCosLoss', loss_weight=2.0), + tot_epoch=None, + mot_decoder=None, + mot_map_decoder=None, + interaction_pe_type='mlp', + mot_det_score=None, + mot_map_thresh=0.5, + mot_dis_thresh=0.2, + pe_normalization=True, + plan_fut_mode=256, + plan_fut_mode_testing=4096, + loss_plan_cls_col=None, + loss_plan_cls_bd=None, + loss_plan_cls_cl=None, + loss_plan_cls_expert=None, + loss_plan_reg=dict(type='L1Loss', loss_weight=0.), + loss_plan_bound=dict(type='PlanMapBoundLoss', loss_weight=0.), + loss_plan_agent_dis=dict(type='PlanAgentDisLoss', loss_weight=0.), + loss_plan_map_theta=dict(type='PlanMapThetaLoss', loss_weight=0.), + loss_tl_status_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + class_weight=None), + loss_tl_trigger_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0, + class_weight=None), + loss_stopsign_trigger_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0, + class_weight=None), + ego_pv_decoder=None, + ego_agent_decoder=None, + ego_map_decoder=None, + cf_backbone=None, + cf_backbone_ckpt=None, + ego_query_thresh=None, + query_use_fix_pad=None, + ego_lcf_feat_idx=None, + valid_fut_ts=6, + plan_anchors_path='./plan_anchors_endpoint_242.npy', + **kwargs): + + self.bev_h = bev_h + self.bev_w = bev_w + self.fp16_enabled = False + self.fut_ts = fut_ts + self.mot_fut_mode = mot_fut_mode + self.tot_epoch = tot_epoch + self.mot_decoder = mot_decoder + self.mot_map_decoder = mot_map_decoder + self.interaction_pe_type = interaction_pe_type + self.mot_det_score = mot_det_score + self.mot_map_thresh = mot_map_thresh + self.mot_dis_thresh = mot_dis_thresh + self.pe_normalization = pe_normalization + self.plan_fut_mode = plan_fut_mode + self.plan_fut_mode_testing = plan_fut_mode_testing + self.ego_pv_decoder = ego_pv_decoder + self.ego_agent_decoder = ego_agent_decoder + self.ego_map_decoder = ego_map_decoder + self.ego_query_thresh = ego_query_thresh + self.query_use_fix_pad = query_use_fix_pad + self.ego_lcf_feat_idx = ego_lcf_feat_idx + self.valid_fut_ts = valid_fut_ts + self.cf_backbone = cf_backbone + self.cf_backbone_ckpt = cf_backbone_ckpt + self.plan_anchors = np.load(plan_anchors_path) + self.plan_anchors = torch.from_numpy(self.plan_anchors).to(torch.float32).cuda() + + self.traj_selected_cnt = torch.zeros(self.plan_anchors.shape[0]).to(torch.float32).cuda() + + + if loss_mot_cls['use_sigmoid'] == True: + self.mot_num_cls = 1 # dont need to consider cls num 
here + else: + self.mot_num_cls = 2 + + self.tl_status_num_cls = 3 # Green, Red, Yellow + self.tl_trigger_num_cls = 1 + self.stopsign_trigger_num_cls = 5 + + self.with_box_refine = with_box_refine + self.as_two_stage = as_two_stage + if self.as_two_stage: + transformer['as_two_stage'] = self.as_two_stage + if 'code_size' in kwargs: + self.code_size = kwargs['code_size'] + else: + self.code_size = 10 + if code_weights is not None: + self.code_weights = code_weights + else: + self.code_weights = [1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2] + if map_code_size is not None: + self.map_code_size = map_code_size + else: + self.map_code_size = 10 + if map_code_weights is not None: + self.map_code_weights = map_code_weights + else: + self.map_code_weights = [1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2] + self.bbox_coder = build_bbox_coder(bbox_coder) + self.pc_range = self.bbox_coder.pc_range + self.real_w = self.pc_range[3] - self.pc_range[0] + self.real_h = self.pc_range[4] - self.pc_range[1] + self.num_cls_fcs = num_cls_fcs - 1 + + self.map_bbox_coder = build_bbox_coder(map_bbox_coder) + self.map_query_embed_type = map_query_embed_type + self.map_transform_method = map_transform_method + self.map_gt_shift_pts_pattern = map_gt_shift_pts_pattern + map_num_query = map_num_vec * map_num_pts_per_vec + self.map_num_query = map_num_query + self.map_num_classes = map_num_classes + self.map_num_vec = map_num_vec + self.map_num_pts_per_vec = map_num_pts_per_vec + self.map_num_pts_per_gt_vec = map_num_pts_per_gt_vec + self.map_dir_interval = map_dir_interval + + if loss_map_cls['use_sigmoid'] == True: + self.map_cls_out_channels = map_num_classes + else: + self.map_cls_out_channels = map_num_classes + 1 + + self.map_bg_cls_weight = 0 + map_class_weight = loss_map_cls.get('class_weight', None) + if map_class_weight is not None and (self.__class__ is v116ADTRHead): + assert isinstance(map_class_weight, float), 'Expected ' \ + 'class_weight to have type float. Found ' \ + f'{type(map_class_weight)}.' + # NOTE following the official DETR rep0, bg_cls_weight means + # relative classification weight of the no-object class. + map_bg_cls_weight = loss_map_cls.get('bg_cls_weight', map_class_weight) + assert isinstance(map_bg_cls_weight, float), 'Expected ' \ + 'bg_cls_weight to have type float. Found ' \ + f'{type(map_bg_cls_weight)}.' + map_class_weight = torch.ones(map_num_classes + 1) * map_class_weight + # set background class as the last indice + map_class_weight[map_num_classes] = map_bg_cls_weight + loss_map_cls.update({'class_weight': map_class_weight}) + if 'bg_cls_weight' in loss_map_cls: + loss_map_cls.pop('bg_cls_weight') + self.map_bg_cls_weight = map_bg_cls_weight + + self.mot_bg_cls_weight = 0 + + super(v116ADTRHead, self).__init__(*args, transformer=transformer, **kwargs) + self.code_weights = nn.Parameter(torch.tensor( + self.code_weights, requires_grad=False), requires_grad=False) + self.map_code_weights = nn.Parameter(torch.tensor( + self.map_code_weights, requires_grad=False), requires_grad=False) + + if kwargs['train_cfg'] is not None: + assert 'map_assigner' in kwargs['train_cfg'], 'map assigner should be provided '\ + 'when train_cfg is set.' + map_assigner = kwargs['train_cfg']['map_assigner'] + assert loss_map_cls['loss_weight'] == map_assigner['cls_cost']['weight'], \ + 'The classification weight for loss and matcher should be' \ + 'exactly the same.' 
+ assert loss_map_bbox['loss_weight'] == map_assigner['reg_cost'][ + 'weight'], 'The regression L1 weight for loss and matcher ' \ + 'should be exactly the same.' + assert loss_map_iou['loss_weight'] == map_assigner['iou_cost']['weight'], \ + 'The regression iou weight for loss and matcher should be' \ + 'exactly the same.' + assert loss_map_pts['loss_weight'] == map_assigner['pts_cost']['weight'], \ + 'The regression l1 weight for map pts loss and matcher should be' \ + 'exactly the same.' + + self.map_assigner = build_assigner(map_assigner) + # DETR sampling=False, so use PseudoSampler + sampler_cfg = dict(type='PseudoSampler') + self.map_sampler = build_sampler(sampler_cfg, context=self) + + self.loss_mot_reg = build_loss(loss_mot_reg) + self.loss_mot_cls = build_loss(loss_mot_cls) + self.loss_map_bbox = build_loss(loss_map_bbox) + self.loss_map_cls = build_loss(loss_map_cls) + self.loss_map_iou = build_loss(loss_map_iou) + self.loss_map_pts = build_loss(loss_map_pts) + self.loss_map_dir = build_loss(loss_map_dir) + self.loss_plan_cls_col = build_loss(loss_plan_cls_col) + self.loss_plan_cls_bd = build_loss(loss_plan_cls_bd) + self.loss_plan_cls_cl = build_loss(loss_plan_cls_cl) + self.loss_plan_cls_expert = build_loss(loss_plan_cls_expert) + + self.loss_plan_reg = build_loss(loss_plan_reg) + self.loss_plan_bound = build_loss(loss_plan_bound) + self.loss_plan_agent_dis = build_loss(loss_plan_agent_dis) + self.loss_plan_map_theta = build_loss(loss_plan_map_theta) + self.loss_tl_status_cls = build_loss(loss_tl_status_cls) + self.loss_tl_trigger_cls = build_loss(loss_tl_trigger_cls) + self.loss_stopsign_trigger_cls = build_loss(loss_stopsign_trigger_cls) + + + # NOTE: already support map + def _init_layers(self): + """Initialize classification branch and regression branch of head.""" + cls_branch = [] + for _ in range(self.num_reg_fcs): + cls_branch.append(Linear(self.embed_dims, self.embed_dims)) + cls_branch.append(nn.LayerNorm(self.embed_dims)) + cls_branch.append(nn.ReLU(inplace=True)) + cls_branch.append(Linear(self.embed_dims, self.cls_out_channels)) + cls_branch = nn.Sequential(*cls_branch) + + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(Linear(self.embed_dims, self.code_size)) + reg_branch = nn.Sequential(*reg_branch) + + mot_reg_branch = [] + for _ in range(self.num_reg_fcs): + mot_reg_branch.append(Linear(self.embed_dims*2, self.embed_dims*2)) + mot_reg_branch.append(nn.ReLU()) + mot_reg_branch.append(Linear(self.embed_dims*2, self.fut_ts*2)) + mot_reg_branch = nn.Sequential(*mot_reg_branch) + + mot_cls_branch = [] + for _ in range(self.num_reg_fcs): + mot_cls_branch.append(Linear(self.embed_dims*2, self.embed_dims*2)) + mot_cls_branch.append(nn.LayerNorm(self.embed_dims*2)) + mot_cls_branch.append(nn.ReLU(inplace=True)) + mot_cls_branch.append(Linear(self.embed_dims*2, self.mot_num_cls)) + mot_cls_branch = nn.Sequential(*mot_cls_branch) + + map_cls_branch = [] + for _ in range(self.num_reg_fcs): + map_cls_branch.append(Linear(self.embed_dims, self.embed_dims)) + map_cls_branch.append(nn.LayerNorm(self.embed_dims)) + map_cls_branch.append(nn.ReLU(inplace=True)) + map_cls_branch.append(Linear(self.embed_dims, self.map_cls_out_channels)) + map_cls_branch = nn.Sequential(*map_cls_branch) + + map_reg_branch = [] + for _ in range(self.num_reg_fcs): + map_reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + map_reg_branch.append(nn.ReLU()) + 
map_reg_branch.append(Linear(self.embed_dims, self.map_code_size)) + map_reg_branch = nn.Sequential(*map_reg_branch) + + + ego_query_pre_branch = [] + ego_query_pre_branch.append(Linear(self.embed_dims * self.fut_ts, self.embed_dims)) + ego_query_pre_branch.append(nn.ReLU()) + ego_query_pre_branch.append(Linear(self.embed_dims, self.embed_dims)) + self.ego_query_pre_branch = nn.Sequential(*ego_query_pre_branch) + + + def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + # last reg_branch is used to generate proposal from + # encode feature map when as_two_stage is True. + num_decoder_layers = 1 + num_map_decoder_layers = 1 + if self.transformer.decoder is not None: + num_decoder_layers = self.transformer.decoder.num_layers + if self.transformer.map_decoder is not None: + num_map_decoder_layers = self.transformer.map_decoder.num_layers + num_mot_decoder_layers = 1 + num_pred = (num_decoder_layers + 1) if \ + self.as_two_stage else num_decoder_layers + mot_num_pred = (num_mot_decoder_layers + 1) if \ + self.as_two_stage else num_mot_decoder_layers + map_num_pred = (num_map_decoder_layers + 1) if \ + self.as_two_stage else num_map_decoder_layers + + if self.with_box_refine: + self.cls_branches = _get_clones(cls_branch, num_pred) + self.reg_branches = _get_clones(reg_branch, num_pred) + self.mot_reg_branches = _get_clones(mot_reg_branch, mot_num_pred) + self.mot_cls_branches = _get_clones(mot_cls_branch, mot_num_pred) + self.map_cls_branches = _get_clones(map_cls_branch, map_num_pred) + self.map_reg_branches = _get_clones(map_reg_branch, map_num_pred) + else: + self.cls_branches = nn.ModuleList( + [cls_branch for _ in range(num_pred)]) + self.reg_branches = nn.ModuleList( + [reg_branch for _ in range(num_pred)]) + self.mot_reg_branches = nn.ModuleList( + [mot_reg_branch for _ in range(mot_num_pred)]) + self.mot_cls_branches = nn.ModuleList( + [mot_cls_branch for _ in range(mot_num_pred)]) + self.map_cls_branches = nn.ModuleList( + [map_cls_branch for _ in range(map_num_pred)]) + self.map_reg_branches = nn.ModuleList( + [map_reg_branch for _ in range(map_num_pred)]) + + if not self.as_two_stage: + self.bev_embedding = nn.Embedding( + self.bev_h * self.bev_w, self.embed_dims) + self.query_embedding = nn.Embedding(self.num_query, + self.embed_dims * 2) + if self.map_query_embed_type == 'all_pts': + self.map_query_embedding = nn.Embedding(self.map_num_query, + self.embed_dims * 2) + elif self.map_query_embed_type == 'instance_pts': + self.map_query_embedding = None + self.map_instance_embedding = nn.Embedding(self.map_num_vec, self.embed_dims * 2) + self.map_pts_embedding = nn.Embedding(self.map_num_pts_per_vec, self.embed_dims * 2) + + if self.mot_decoder is not None: + self.mot_decoder = build_transformer_layer_sequence(self.mot_decoder) + self.mot_mode_query = nn.Embedding(self.mot_fut_mode, self.embed_dims) + self.mot_mode_query.weight.requires_grad = True + else: + raise NotImplementedError('Not implement yet') + + if self.mot_map_decoder is not None: + self.lane_encoder = LaneNet(self.embed_dims, self.embed_dims // 2, 3) + self.mot_map_decoder = build_transformer_layer_sequence(self.mot_map_decoder) + + # self.ego_query = nn.Embedding(self.plan_fut_mode, self.embed_dims) + # self.ego_query = pos2posemb2d(self.plan_anchors.reshape(1, self.plan_fut_mode * self.fut_ts, -1)) \ + # .reshape(self.plan_fut_mode, self.fut_ts, -1) + + if self.ego_pv_decoder is not None: + self.ego_pv_decoder = build_transformer_layer_sequence(self.ego_pv_decoder) + 
MAXNUM_PV_TOKEN = 800 + self.pv_pos_embedding = nn.Embedding( + MAXNUM_PV_TOKEN, self.embed_dims) + + if self.ego_agent_decoder is not None: + self.ego_agent_decoder = build_transformer_layer_sequence(self.ego_agent_decoder) + + if self.ego_map_decoder is not None: + self.ego_map_decoder = build_transformer_layer_sequence(self.ego_map_decoder) + + plan_reg_branch = [] + # plan_in_dim = self.embed_dims*4 + len(self.ego_lcf_feat_idx) \ + # if self.ego_lcf_feat_idx is not None else self.embed_dims*4 + for _ in range(self.num_reg_fcs): + plan_reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + plan_reg_branch.append(nn.ReLU()) + plan_reg_branch.append(Linear(self.embed_dims, self.fut_ts*2)) + self.plan_reg_branch = nn.Sequential(*plan_reg_branch) + + self.fus_mlp = nn.Sequential( + nn.Linear(self.mot_fut_mode*2*self.embed_dims, self.embed_dims, bias=True), + nn.LayerNorm(self.embed_dims), + nn.ReLU(), + nn.Linear(self.embed_dims, self.embed_dims, bias=True)) + + if self.interaction_pe_type == 'sine_mlp': + pe_embed_mlps = nn.Sequential( + nn.Linear(self.embed_dims, self.embed_dims*2), + nn.ReLU(), + nn.Linear(self.embed_dims*2, self.embed_dims), + ) + elif self.interaction_pe_type == 'mlp': + pe_embed_mlps = nn.Linear(2, self.embed_dims) + else: + raise NotImplementedError('Not implement yet') + + self.pe_embed_mlps = _get_clones(pe_embed_mlps, 4) + + self.ego_feat_projs = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for agent + nn.Sequential( + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for map + nn.Sequential( + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for traffic light + nn.Sequential( + nn.Linear(2, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for target point + nn.Sequential( + nn.Linear(len(self.ego_lcf_feat_idx), self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for lcf feat + nn.Sequential( + nn.Linear(6, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for cmdid + nn.Sequential( + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for pv_feat + nn.Sequential( + nn.Linear(140 * 140, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + nn.ReLU(inplace=True), + nn.Linear(self.embed_dims, self.embed_dims, bias=True), + ), # for target point rasterized + ] + ) + + # NOTE: add front-view feature encoder + if 
self.cf_backbone is not None: + self.cf_backbone = build_backbone(self.cf_backbone) + + tl_feats_branch = [] + tl_feats_branch.append(Linear(2048 * 10 * 13 + 6 * self.embed_dims * 5 * 7, self.embed_dims)) #10 * 13 5 * 7 + tl_feats_branch.append(nn.LayerNorm(self.embed_dims)) + tl_feats_branch.append(nn.ReLU(inplace=True)) + tl_feats_branch.append(Linear(self.embed_dims, self.embed_dims)) + tl_feats_branch.append(nn.LayerNorm(self.embed_dims)) + tl_feats_branch.append(nn.ReLU(inplace=True)) + self.tl_feats_branch = nn.Sequential(*tl_feats_branch) + + tl_status_cls_branch = [] + tl_status_cls_branch.append(Linear(self.embed_dims, self.embed_dims)) + tl_status_cls_branch.append(nn.LayerNorm(self.embed_dims)) + tl_status_cls_branch.append(nn.ReLU(inplace=True)) + tl_status_cls_branch.append(Linear(self.embed_dims, self.tl_status_num_cls)) + self.tl_status_cls_branch = nn.Sequential(*tl_status_cls_branch) + + tl_trigger_cls_branch = [] + tl_trigger_cls_branch.append(Linear(self.embed_dims, self.embed_dims)) + tl_trigger_cls_branch.append(nn.LayerNorm(self.embed_dims)) + tl_trigger_cls_branch.append(nn.ReLU(inplace=True)) + tl_trigger_cls_branch.append(Linear(self.embed_dims, self.tl_trigger_num_cls)) + self.tl_trigger_cls_branch = nn.Sequential(*tl_trigger_cls_branch) + + stopsign_trigger_cls_branch = [] + stopsign_trigger_cls_branch.append(Linear(self.embed_dims, self.embed_dims)) + stopsign_trigger_cls_branch.append(nn.LayerNorm(self.embed_dims)) + stopsign_trigger_cls_branch.append(nn.ReLU(inplace=True)) + stopsign_trigger_cls_branch.append(Linear(self.embed_dims, self.stopsign_trigger_num_cls)) + self.stopsign_trigger_cls_branch = nn.Sequential(*stopsign_trigger_cls_branch) + + plan_cls_col_branch = [] + # plan_cls_col_branch.append(Linear(self.embed_dims, self.embed_dims)) + # plan_cls_col_branch.append(nn.LayerNorm(self.embed_dims)) + # plan_cls_col_branch.append(nn.ReLU(inplace=True)) + plan_cls_col_branch.append(Linear(self.embed_dims, 1)) + self.plan_cls_col_branch = nn.Sequential(*plan_cls_col_branch) + + plan_cls_bd_branch = [] + # plan_cls_bd_branch.append(Linear(self.embed_dims, self.embed_dims)) + # plan_cls_bd_branch.append(nn.LayerNorm(self.embed_dims)) + # plan_cls_bd_branch.append(nn.ReLU(inplace=True)) + plan_cls_bd_branch.append(Linear(self.embed_dims, 1)) + self.plan_cls_bd_branch = nn.Sequential(*plan_cls_bd_branch) + + plan_cls_cl_branch = [] + # plan_cls_cl_branch.append(Linear(self.embed_dims, self.embed_dims)) + # plan_cls_cl_branch.append(nn.LayerNorm(self.embed_dims)) + # plan_cls_cl_branch.append(nn.ReLU(inplace=True)) + plan_cls_cl_branch.append(Linear(self.embed_dims, 1)) + self.plan_cls_cl_branch = nn.Sequential(*plan_cls_cl_branch) + + plan_cls_expert_branch = [] + plan_cls_expert_branch.append(Linear(self.embed_dims, self.embed_dims)) + plan_cls_expert_branch.append(nn.LayerNorm(self.embed_dims)) + plan_cls_expert_branch.append(nn.ReLU(inplace=True)) + plan_cls_expert_branch.append(Linear(self.embed_dims, 1)) + self.plan_cls_expert_branch = nn.Sequential(*plan_cls_expert_branch) + + # NOTE: already support map + def init_weights(self): + """Initialize weights of the DeformDETR head.""" + self.transformer.init_weights() + if self.loss_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.cls_branches: + nn.init.constant_(m[-1].bias, bias_init) + if self.loss_map_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + for m in self.map_cls_branches: + nn.init.constant_(m[-1].bias, bias_init) + if self.loss_mot_cls.use_sigmoid: + bias_init = 
bias_init_with_prob(0.01) + for m in self.mot_cls_branches: + nn.init.constant_(m[-1].bias, bias_init) + if self.loss_tl_status_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.tl_status_cls_branch[-1].bias, bias_init) + if self.loss_tl_trigger_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.tl_trigger_cls_branch[-1].bias, bias_init) + if self.loss_stopsign_trigger_cls.use_sigmoid: + bias_init = bias_init_with_prob(0.01) + nn.init.constant_(self.stopsign_trigger_cls_branch[-1].bias, bias_init) + # if self.plan_cls_col_branch.use_sigmoid: + # bias_init = bias_init_with_prob(0.01) + # nn.init.constant_(self.plan_cls_col_branch[-1].bias, bias_init) + # if self.plan_cls_bd_branch.use_sigmoid: + # bias_init = bias_init_with_prob(0.01) + # nn.init.constant_(self.plan_cls_bd_branch[-1].bias, bias_init) + # for m in self.map_reg_branches: + # constant_init(m[-1], 0, bias=0) + # nn.init.constant_(self.map_reg_branches[0][-1].bias.data[2:], 0.) + if self.mot_decoder is not None: + for p in self.mot_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + nn.init.orthogonal_(self.mot_mode_query.weight) + if self.mot_map_decoder is not None: + for p in self.mot_map_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for p in self.lane_encoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + if self.ego_pv_decoder is not None: + for p in self.ego_pv_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + if self.ego_agent_decoder is not None: + for p in self.ego_agent_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + if self.ego_map_decoder is not None: + for p in self.ego_map_decoder.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + if self.interaction_pe_type is not None: + for emb_mlp in self.pe_embed_mlps: + xavier_init(emb_mlp, distribution='uniform', bias=0.) + # if self.cf_backbone is not None: + # load_checkpoint(self.cf_backbone, self.cf_backbone_ckpt, map_location='cpu') + + # NOTE: already support map + # @auto_fp16(apply_to=('mlvl_feats')) + @force_fp32(apply_to=('mlvl_feats', 'prev_bev')) + def forward(self, + mlvl_feats, + img_metas, + prev_bev=None, + only_bev=False, + ego_his_trajs=None, + ego_lcf_feat=None, + cf_img=None, + command_wp=None, + command_id=None, + target_point=None, + ego_fut_trajs=None + ): + """Forward function. + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 5D-tensor with shape + (B, N, C, H, W). + prev_bev: previous bev featues + only_bev: only compute BEV features with encoder. + Returns: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format (cx, cy, w, l, cz, h, theta, vx, vy). \ + Shape [nb_dec, bs, num_query, 9]. 
+ """ + + if not only_bev and not self.training: + self.plan_fut_mode = self.plan_fut_mode_testing + + bs, num_cam, _, _, _ = mlvl_feats[0].shape + dtype = mlvl_feats[0].dtype + object_query_embeds = self.query_embedding.weight.to(dtype) + + # import pdb;pdb.set_trace() + if self.map_query_embed_type == 'all_pts': + map_query_embeds = self.map_query_embedding.weight.to(dtype) + elif self.map_query_embed_type == 'instance_pts': + map_pts_embeds = self.map_pts_embedding.weight.unsqueeze(0) + map_instance_embeds = self.map_instance_embedding.weight.unsqueeze(1) + map_query_embeds = (map_pts_embeds + map_instance_embeds).flatten(0, 1).to(dtype) + + bev_queries = self.bev_embedding.weight.to(dtype) + + bev_mask = torch.zeros((bs, self.bev_h, self.bev_w), + device=bev_queries.device).to(dtype) + bev_pos = self.positional_encoding(bev_mask).to(dtype) + + if only_bev: # only use encoder to obtain BEV features, TODO: refine the workaround + return self.transformer.get_bev_features( + mlvl_feats, + bev_queries, + self.bev_h, + self.bev_w, + grid_length=(self.real_h / self.bev_h, + self.real_w / self.bev_w), + bev_pos=bev_pos, + img_metas=img_metas, + prev_bev=prev_bev, + ) + else: + outputs = self.transformer( + mlvl_feats, + bev_queries, + object_query_embeds, + map_query_embeds, + self.bev_h, + self.bev_w, + grid_length=(self.real_h / self.bev_h, + self.real_w / self.bev_w), + bev_pos=bev_pos, + reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None, + map_reg_branches=self.map_reg_branches if self.with_box_refine else None, # noqa:E501 + map_cls_branches=self.map_cls_branches if self.as_two_stage else None, + img_metas=img_metas, + prev_bev=prev_bev + ) + + bev_embed, hs, init_reference, inter_references, \ + map_hs, map_init_reference, map_inter_references = outputs + + hs = hs.permute(0, 2, 1, 3) + outputs_classes = [] + outputs_coords = [] + outputs_coords_bev = [] + outputs_mot_trajs = [] + outputs_mot_trajs_classes = [] + + map_hs = map_hs.permute(0, 2, 1, 3) + map_outputs_classes = [] + map_outputs_coords = [] + map_outputs_pts_coords = [] + map_outputs_coords_bev = [] + + for lvl in range(hs.shape[0]): + if lvl == 0: + reference = init_reference + else: + reference = inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp = self.reg_branches[lvl](hs[lvl]) + + # TODO: check the shape of reference + assert reference.shape[-1] == 3 + tmp[..., 0:2] = tmp[..., 0:2] + reference[..., 0:2] + tmp[..., 0:2] = tmp[..., 0:2].sigmoid() + outputs_coords_bev.append(tmp[..., 0:2].clone().detach()) + tmp[..., 4:5] = tmp[..., 4:5] + reference[..., 2:3] + tmp[..., 4:5] = tmp[..., 4:5].sigmoid() + tmp[..., 0:1] = (tmp[..., 0:1] * (self.pc_range[3] - + self.pc_range[0]) + self.pc_range[0]) + tmp[..., 1:2] = (tmp[..., 1:2] * (self.pc_range[4] - + self.pc_range[1]) + self.pc_range[1]) + tmp[..., 4:5] = (tmp[..., 4:5] * (self.pc_range[5] - + self.pc_range[2]) + self.pc_range[2]) + + # TODO: check if using sigmoid + outputs_coord = tmp + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + for lvl in range(map_hs.shape[0]): + if lvl == 0: + reference = map_init_reference + else: + reference = map_inter_references[lvl - 1] + reference = inverse_sigmoid(reference) + map_outputs_class = self.map_cls_branches[lvl]( + map_hs[lvl].view(bs,self.map_num_vec, self.map_num_pts_per_vec,-1).mean(2) + ) + tmp = 
self.map_reg_branches[lvl](map_hs[lvl]) + # TODO: check the shape of reference + assert reference.shape[-1] == 2 + tmp[..., 0:2] += reference[..., 0:2] + tmp = tmp.sigmoid() # cx,cy,w,h + map_outputs_coord, map_outputs_pts_coord = self.map_transform_box(tmp) + map_outputs_coords_bev.append(map_outputs_pts_coord.clone().detach()) + map_outputs_classes.append(map_outputs_class) + map_outputs_coords.append(map_outputs_coord) + map_outputs_pts_coords.append(map_outputs_pts_coord) + + if self.mot_decoder is not None: + batch_size, num_agent = outputs_coords_bev[-1].shape[:2] + # mot_query + mot_query = hs[-1].permute(1, 0, 2) # [A, B, D] + mode_query = self.mot_mode_query.weight # [mot_fut_mode, D] + # [M, B, D], M=A*mot_fut_mode + mot_query = (mot_query[:, None, :, :] + mode_query[None, :, None, :]).flatten(0, 1) + + if self.interaction_pe_type is not None: + mot_coords = outputs_coords_bev[-1] # [B, A, 2] + mot_coords = pos2posemb2d(mot_coords, num_pos_feats=self.embed_dims // 2) if self.interaction_pe_type == 'sine_mlp' else mot_coords + mot_pos = self.pe_embed_mlps[0](mot_coords) # [B, A, D] + mot_pos = mot_pos.unsqueeze(2).repeat(1, 1, self.mot_fut_mode, 1).flatten(1, 2) + mot_pos = mot_pos.permute(1, 0, 2) # [M, B, D] + else: + mot_pos = None + + if self.mot_det_score is not None: + mot_score = outputs_classes[-1] + max_mot_score = mot_score.max(dim=-1)[0] + invalid_mot_idx = max_mot_score < self.mot_det_score # [B, A] + invalid_mot_idx = invalid_mot_idx.unsqueeze(2).repeat(1, 1, self.mot_fut_mode).flatten(1, 2) + else: + invalid_mot_idx = None + + mot_hs = self.mot_decoder( + query=mot_query, + key=mot_query, + value=mot_query, + query_pos=mot_pos, + key_pos=mot_pos, + key_padding_mask=invalid_mot_idx) + + if self.mot_map_decoder is not None: + # map preprocess + mot_coords = outputs_coords_bev[-1] # [B, A, 2] + mot_coords = mot_coords.unsqueeze(2).repeat(1, 1, self.mot_fut_mode, 1).flatten(1, 2) + map_query = map_hs[-1].view(batch_size, self.map_num_vec, self.map_num_pts_per_vec, -1) + map_query = self.lane_encoder(map_query) # [B, P, pts, D] -> [B, P, D] + map_score = map_outputs_classes[-1] + map_pos = map_outputs_coords_bev[-1] + map_query, map_pos, key_padding_mask = self.select_and_pad_pred_map( + mot_coords, map_query, map_score, map_pos, + map_thresh=self.mot_map_thresh, dis_thresh=self.mot_dis_thresh, + pe_normalization=self.pe_normalization, use_fix_pad=True) + map_query = map_query.permute(1, 0, 2) # [P, B*M, D] + ca_mot_query = mot_hs.permute(1, 0, 2).flatten(0, 1).unsqueeze(0) + + # position encoding + if self.interaction_pe_type is not None: + (attn_num_query, attn_batch) = ca_mot_query.shape[:2] + mot_pos = torch.zeros((attn_num_query, attn_batch, 2), device=mot_hs.device) + mot_pos = pos2posemb2d(mot_pos, num_pos_feats=self.embed_dims // 2) if self.interaction_pe_type == 'sine_mlp' else mot_pos + mot_pos = self.pe_embed_mlps[1](mot_pos) + map_pos = map_pos.permute(1, 0, 2) + map_pos = pos2posemb2d(map_pos, num_pos_feats=self.embed_dims // 2) if self.interaction_pe_type == 'sine_mlp' else map_pos + map_pos = self.pe_embed_mlps[1](map_pos) + else: + mot_pos, map_pos = None, None + + ca_mot_query = self.mot_map_decoder( + query=ca_mot_query, + key=map_query, + value=map_query, + query_pos=mot_pos, + key_pos=map_pos, + key_padding_mask=key_padding_mask) + else: + ca_mot_query = mot_hs.permute(1, 0, 2).flatten(0, 1).unsqueeze(0) + + batch_size = outputs_coords_bev[-1].shape[0] + mot_hs = mot_hs.permute(1, 0, 2).unflatten( + dim=1, sizes=(num_agent, self.mot_fut_mode) + ) + 
ca_mot_query = ca_mot_query.squeeze(0).unflatten(
+                dim=0, sizes=(batch_size, num_agent, self.mot_fut_mode)
+            )
+            mot_hs = torch.cat([mot_hs, ca_mot_query], dim=-1)  # [B, A, mot_fut_mode, 2D]
+        else:
+            raise NotImplementedError('Not implemented yet')
+
+        outputs_traj = self.mot_reg_branches[0](mot_hs)
+        outputs_mot_trajs.append(outputs_traj)
+        outputs_mot_class = self.mot_cls_branches[0](mot_hs)
+        outputs_mot_trajs_classes.append(outputs_mot_class.squeeze(-1))
+
+        map_outputs_classes = torch.stack(map_outputs_classes)
+        map_outputs_coords = torch.stack(map_outputs_coords)
+        map_outputs_pts_coords = torch.stack(map_outputs_pts_coords)
+        outputs_classes = torch.stack(outputs_classes)
+        outputs_coords = torch.stack(outputs_coords)
+        outputs_mot_trajs = torch.stack(outputs_mot_trajs)
+        outputs_mot_trajs_classes = torch.stack(outputs_mot_trajs_classes)
+
+        # planning
+        (batch, num_agent) = mot_hs.shape[:2]
+
+
+        # kinodynamic filtering: build a dynamic vocabulary from the ego state
+        # (ego_lcf_feat / ego_his_trajs / ego_fut_trajs)
+        # from carla_simulation.team_code_autopilot.autopilot import EgoModel
+        # self.ego_model = EgoModel(dt=0.5)
+        # vx, vy = ego_lcf_feat[0,0,0,:2]
+        # spds = (vx**2 + vy**2).sqrt()
+        # self.ego_model.forward(locs=np.array([0,0]), yaws=0, spds=spds, acts=np.array([-1, 1, 0]))  # steer, throt, brake
+
+        Dt = 0.5
+        vx, vy, ax, ay = ego_lcf_feat[0, 0, 0, :4]
+        v_xy = torch.sqrt(vx**2 + vy**2)
+        a_xy = torch.sqrt(ax**2 + ay**2)
+        # distance the ego can cover in one step Dt at the current speed/accel
+        pred_dis = v_xy * Dt + 1 / 2 * a_xy * Dt**2
+        # NOTE: with this huge threshold the mask is all-True, i.e. the
+        # kinodynamic filter is effectively disabled and plan_fut_mode anchors
+        # are sampled uniformly at random from the full vocabulary.
+        kinodynamic_mask = (torch.norm(self.plan_anchors[:, 0, :], dim=-1) - pred_dis).abs() < 100000000000
+        used_index = torch.multinomial(kinodynamic_mask.float(), self.plan_fut_mode, replacement=False)
+        # used_index = torch.LongTensor(random.sample(list(range(self.plan_anchors.shape[0]))[kinodynamic_mask], self.plan_fut_mode)).to(mot_hs.device)
+        if self.training:
+            # ensure the anchor closest to the GT future trajectory is in the
+            # sampled subset, so the expert classification target always exists
+            best_match_idx = torch.linalg.norm(ego_fut_trajs[0].cumsum(dim=-2) - self.plan_anchors, dim=-1).sum(dim=-1).argmin()
+            if best_match_idx not in used_index:
+                used_index[-1] = best_match_idx
+
+        self.used_plan_anchors = torch.index_select(self.plan_anchors, 0, used_index)
+
+
+        # set stop trajs to zero
+        self.used_plan_anchors[self.used_plan_anchors[:, 0].norm(dim=-1) < 1e-2] = 0.
+        # fix one stop traj
+        self.used_plan_anchors[0] = 0.
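To make the anchor-matching step above concrete: the ground-truth future is stored as per-step offsets, `cumsum` turns it into absolute waypoints, and the vocabulary entry with the smallest summed per-step L2 distance becomes the classification target. A self-contained sketch with stand-in tensors (the real vocabulary comes from `carla_plan_vocabulary_4096.npy`):

```python
# Self-contained sketch of the best-match lookup; random tensors stand in for
# the real trajectory vocabulary and ground-truth future.
import torch

plan_anchors = torch.randn(4096, 6, 2)  # [num_anchors, fut_ts, 2] waypoints
ego_fut_traj = torch.randn(6, 2)        # [fut_ts, 2] per-step offsets

gt_waypoints = ego_fut_traj.cumsum(dim=-2)  # offsets -> absolute waypoints
dists = torch.linalg.norm(gt_waypoints - plan_anchors, dim=-1).sum(dim=-1)  # [4096]
best_match_idx = dists.argmin()
```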
+ + _tmp = pos2posemb2d(self.used_plan_anchors.reshape(1, self.plan_fut_mode * self.fut_ts, -1), num_pos_feats=self.embed_dims // 2) \ + .reshape(self.plan_fut_mode, self.fut_ts, -1) + ego_query = _tmp.unsqueeze(0).repeat(batch, 1, 1, 1).reshape(batch, self.plan_fut_mode, -1) + + + ego_query = self.ego_query_pre_branch(ego_query) + # ego_query = self.ego_query.weight.unsqueeze(0).repeat(batch, 1, 1) + # ego-environment Interaction + # ego<->agent query & pos + agent_conf = outputs_classes[-1] + agent_query = mot_hs.reshape(batch, num_agent, -1) + agent_query = self.fus_mlp(agent_query) # [B, A, mot_fut_mode*2*D] -> [B, A, D] + agent_pos = outputs_coords_bev[-1] + + agent_query, agent_pos, agent_mask = self.select_and_pad_query( + agent_query, agent_pos, agent_conf, + score_thresh=self.ego_query_thresh, + use_fix_pad=self.query_use_fix_pad) + + if self.interaction_pe_type is not None: + ego_agent_pos = torch.ones((batch, ego_query.shape[1], 2), device=ego_query.device)*0.5 # ego in the center + ego_agent_pos = pos2posemb2d(ego_agent_pos, num_pos_feats=self.embed_dims // 2) if self.interaction_pe_type == 'sine_mlp' else ego_agent_pos + ego_agent_pos = self.pe_embed_mlps[2](ego_agent_pos) + agent_pos = pos2posemb2d(agent_pos, num_pos_feats=self.embed_dims // 2) if self.interaction_pe_type == 'sine_mlp' else agent_pos + agent_pos = self.pe_embed_mlps[2](agent_pos) + ego_agent_pos = ego_agent_pos.permute(1, 0, 2) + agent_pos = agent_pos.permute(1, 0, 2) + else: + ego_agent_pos, agent_pos = None, None + + # ego <-> map query & pos + map_query = map_hs[-1].view(batch_size, self.map_num_vec, self.map_num_pts_per_vec, -1) + map_query = self.lane_encoder(map_query) # [B, P, pts, D] -> [B, P, D] + map_conf = map_outputs_classes[-1] + map_pos = map_outputs_coords_bev[-1] + # use the most close pts pos in each map inst as the inst's pos + batch, num_map = map_pos.shape[:2] + map_dis = torch.sqrt(map_pos[..., 0]**2 + map_pos[..., 1]**2) + min_map_pos_idx = map_dis.argmin(dim=-1).flatten() # [B*P] + min_map_pos = map_pos.flatten(0, 1) # [B*P, pts, 2] + min_map_pos = min_map_pos[range(min_map_pos.shape[0]), min_map_pos_idx] # [B*P, 2] + min_map_pos = min_map_pos.view(batch, num_map, 2) # [B, P, 2] + map_query, map_pos, map_mask = self.select_and_pad_query( + map_query, min_map_pos, map_conf, + score_thresh=self.ego_query_thresh, + use_fix_pad=self.query_use_fix_pad) + + if self.interaction_pe_type is not None: + ego_map_pos = torch.ones((batch, ego_query.shape[1], 2), device=agent_query.device)*0.5 # ego in the center + ego_map_pos = pos2posemb2d(ego_map_pos, num_pos_feats=self.embed_dims // 2) if self.interaction_pe_type == 'sine_mlp' else ego_map_pos + ego_map_pos = self.pe_embed_mlps[3](ego_map_pos) + map_pos = pos2posemb2d(map_pos, num_pos_feats=self.embed_dims // 2) if self.interaction_pe_type == 'sine_mlp' else map_pos + map_pos = self.pe_embed_mlps[3](map_pos) + ego_map_pos = ego_map_pos.permute(1, 0, 2) + map_pos = map_pos.permute(1, 0, 2) + else: + ego_map_pos, map_pos = None, None + + # ego_pv_query = ego_query + # ego <-> pv interaction + batch, _, c_dim, _, _ = mlvl_feats[-1].shape + attn_pv_feats = mlvl_feats[-1].permute(1, 3, 4, 0, 2).reshape(-1, batch, c_dim) + ego_pv_query = self.ego_pv_decoder( + query=ego_query.permute(1, 0, 2), + key=attn_pv_feats, + value=attn_pv_feats, + query_pos=ego_agent_pos, + key_pos=self.pv_pos_embedding.weight[:,None,:].repeat(1, batch, 1)[:attn_pv_feats.shape[0]] + ) + + # ego <-> agent interaction + ego_agent_query = self.ego_agent_decoder( + 
query=ego_pv_query, + key=agent_query.permute(1, 0, 2), + value=agent_query.permute(1, 0, 2), + query_pos=ego_agent_pos, + key_pos=agent_pos, + key_padding_mask=agent_mask) + + # ego <-> map interaction + ego_map_query = self.ego_map_decoder( + query=ego_agent_query, + key=map_query.permute(1, 0, 2), + value=map_query.permute(1, 0, 2), + query_pos=ego_map_pos, + key_pos=map_pos, + key_padding_mask=map_mask) + + # camera front feat -> embedding + assert cf_img is not None + # (B, 3, 320, 416) -> (B, 3, 160, 320) + cf_img_h, cf_img_w = cf_img.shape[2:] + crop_h = int(cf_img_h/2) + crop_w1, crop_w2 = int(cf_img_w/4), int(cf_img_w*3/4) + front_view_img = cf_img[:, :, :crop_h, crop_w1:crop_w2] + cf_img_feats = self.cf_backbone(cf_img) + if isinstance(cf_img_feats, dict): + cf_img_feats = list(cf_img_feats.values()) + cf_img_feats = torch.cat((cf_img_feats[-1].flatten(1, 3), mlvl_feats[-1].flatten(1, 4)), dim=-1) + cf_img_feats = self.tl_feats_branch(cf_img_feats) + cf_img_feats = cf_img_feats.unsqueeze(1) + + # Ego prediction + assert self.ego_lcf_feat_idx is not None + ego_pv_query = ego_pv_query.permute(1, 0, 2) + ego_agent_query = ego_agent_query.permute(1, 0, 2) + ego_map_query = ego_map_query.permute(1, 0, 2) + ego_pv_feat = self.ego_feat_projs[6](ego_pv_query) + ego_agent_feat = self.ego_feat_projs[0](ego_agent_query) + ego_map_feat = self.ego_feat_projs[1](ego_map_query) + ego_cf_feat = self.ego_feat_projs[2](cf_img_feats.clone().detach()) + # ego_wp_feat = self.ego_feat_projs[3](target_point.squeeze(1)) + if isinstance(target_point, torch.Tensor): + _tmp_target_point = target_point.unsqueeze(1) + else: + _tmp_target_point = torch.tensor(target_point[None,None], device=ego_cf_feat.device) + # range (-70m, +70m) grid_size 1m + _tmp_rasterized_feat = torch.zeros((batch, 140, 140), dtype=torch.float32, device=ego_cf_feat.device) + + # TODO no need / 2. + _idx = torch.floor((_tmp_target_point.clip(min=-69., max=69.) - (-70.)) / 2.).long() + for i in range(batch): + _tmp_rasterized_feat[i, _idx[i,0,0], _idx[i,0,1]] = 1. + _tmp_rasterized_feat = _tmp_rasterized_feat.reshape(batch, 1, 140 * 140) + ego_wp_feat = self.ego_feat_projs[3](_tmp_target_point) + ego_wp_feat += 1. * self.ego_feat_projs[7](_tmp_rasterized_feat) + + + # [VOID,LEFT,RIGHT,STRAIGHT,LANEFOLLOW,CHANGELANELEFT,CHANGELANERIGHT] + if isinstance(command_id, torch.Tensor): + cmdid_onehot = torch.zeros((batch, 1, 6), device=ego_cf_feat.device, dtype=torch.float32) + assert command_id.max() <= 6 and command_id.min() >= 1 + for i in range(batch): + cmdid_onehot[i, 0, command_id[i]-1] = 1. + ego_cmdid_feat = self.ego_feat_projs[5](cmdid_onehot) + else: + cmdid_onehot = torch.zeros((batch, 1, 6), device=ego_cf_feat.device, dtype=torch.float32) + assert command_id.max() <= 6 and command_id.min() >= 1 + assert batch == 1 + cmdid_onehot[0, 0, command_id - 1] = 1. + ego_cmdid_feat = self.ego_feat_projs[5](cmdid_onehot) + + + ego_status = ego_lcf_feat.squeeze(1)[..., self.ego_lcf_feat_idx] + ego_status_feat = self.ego_feat_projs[4](ego_status) + ego_feats = ego_agent_feat + ego_map_feat + \ + 1. * ego_wp_feat + 0. * ego_cmdid_feat + ego_status_feat + ego_cf_feat + 0. * ego_pv_feat + + outputs_ego_trajs = self.plan_reg_branch(ego_feats) + outputs_ego_trajs = outputs_ego_trajs.reshape(outputs_ego_trajs.shape[0], + self.plan_fut_mode, self.fut_ts, 2) + # if self.training: + # outputs_ego_trajs = outputs_ego_trajs * 0. + self.plan_anchors[None].to(outputs_ego_trajs.device) + # else: + # outputs_ego_trajs = outputs_ego_trajs * 0. 
+ self.centerline_trajs[None]
+        outputs_ego_trajs = outputs_ego_trajs * 0. + self.used_plan_anchors[None].to(outputs_ego_trajs.device)
+
+        # traffic light classification
+        tl_status_cls_scores = self.tl_status_cls_branch(cf_img_feats)
+        tl_trigger_cls_scores = self.tl_trigger_cls_branch(cf_img_feats)
+        stopsign_trigger_cls_scores = self.stopsign_trigger_cls_branch(cf_img_feats)
+
+        outputs_ego_cls_col = self.plan_cls_col_branch(ego_feats)
+        outputs_ego_cls_bd = self.plan_cls_bd_branch(ego_feats)
+        outputs_ego_cls_cl = self.plan_cls_cl_branch(ego_feats)
+        outputs_ego_cls_expert = self.plan_cls_expert_branch(ego_feats)
+
+        outs = {
+            'bev_embed': bev_embed,
+            'all_cls_scores': outputs_classes,
+            'all_bbox_preds': outputs_coords,
+            'all_mot_preds': outputs_mot_trajs.repeat(outputs_coords.shape[0], 1, 1, 1, 1),
+            'all_mot_cls_scores': outputs_mot_trajs_classes.repeat(outputs_coords.shape[0], 1, 1, 1),
+            'map_all_cls_scores': map_outputs_classes,
+            'map_all_bbox_preds': map_outputs_coords,
+            'map_all_pts_preds': map_outputs_pts_coords,
+            'enc_cls_scores': None,
+            'enc_bbox_preds': None,
+            'map_enc_cls_scores': None,
+            'map_enc_bbox_preds': None,
+            'map_enc_pts_preds': None,
+            'ego_fut_preds': outputs_ego_trajs,
+            'tl_status_cls_scores': tl_status_cls_scores,
+            'tl_trigger_cls_scores': tl_trigger_cls_scores,
+            'stopsign_trigger_cls_scores': stopsign_trigger_cls_scores,
+            'ego_cls_col_preds': outputs_ego_cls_col,
+            'ego_cls_bd_preds': outputs_ego_cls_bd,
+            'ego_cls_cl_preds': outputs_ego_cls_cl,
+            'ego_cls_expert_preds': outputs_ego_cls_expert,
+        }
+
+        return outs
+
+    def map_transform_box(self, pts, y_first=False):
+        """Convert a point set into a bounding box.
+
+        Args:
+            pts: the input point sets (fields), each point
+                set (fields) is represented by 2n scalars.
+            y_first: if y_first=True, the point set is represented as
+                [y1, x1, y2, x2 ... yn, xn], otherwise the point set is
+                represented as [x1, y1, x2, y2 ... xn, yn].
+        Returns:
+            The bbox [cx, cy, w, h] transformed from points.
+        """
+        pts_reshape = pts.view(pts.shape[0], self.map_num_vec,
+                               self.map_num_pts_per_vec, 2)
+        pts_y = pts_reshape[:, :, :, 0] if y_first else pts_reshape[:, :, :, 1]
+        pts_x = pts_reshape[:, :, :, 1] if y_first else pts_reshape[:, :, :, 0]
+        if self.map_transform_method == 'minmax':
+            xmin = pts_x.min(dim=2, keepdim=True)[0]
+            xmax = pts_x.max(dim=2, keepdim=True)[0]
+            ymin = pts_y.min(dim=2, keepdim=True)[0]
+            ymax = pts_y.max(dim=2, keepdim=True)[0]
+            bbox = torch.cat([xmin, ymin, xmax, ymax], dim=2)
+            bbox = bbox_xyxy_to_cxcywh(bbox)
+        else:
+            raise NotImplementedError
+        return bbox, pts_reshape
+
+    def _get_target_single(self,
+                           cls_score,
+                           bbox_pred,
+                           gt_labels,
+                           gt_bboxes,
+                           gt_attr_labels,
+                           gt_bboxes_ignore=None):
+        """Compute regression and classification targets for one image.
+        Outputs from a single decoder layer of a single feature level are used.
+        Args:
+            cls_score (Tensor): Box score logits from a single decoder layer
+                for one image. Shape [num_query, cls_out_channels].
+            bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+                for one image, with normalized coordinate (cx, cy, w, h) and
+                shape [num_query, 10].
+            gt_bboxes (Tensor): Ground truth bboxes for one image with
+                shape (num_gts, 9) in [x,y,z,w,l,h,yaw,vx,vy] format.
+            gt_labels (Tensor): Ground truth class indices for one image
+                with shape (num_gts, ).
+            gt_bboxes_ignore (Tensor, optional): Bounding boxes
+                which can be ignored. Default None.
+ Returns: + tuple[Tensor]: a tuple containing the following for one image. + - labels (Tensor): Labels of each image. + - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + + num_bboxes = bbox_pred.size(0) + # assigner and sampler + gt_mot_trajs = gt_attr_labels[:, :self.fut_ts*2] + gt_mot_masks = gt_attr_labels[:, self.fut_ts*2:self.fut_ts*3] + gt_bbox_c = gt_bboxes.shape[-1] + num_gt_bbox, gt_mot_c = gt_mot_trajs.shape + + if digit_version(TORCH_VERSION) >= digit_version('1.8'): + bbox_pred = torch.nan_to_num(bbox_pred) + + assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes, + gt_labels, gt_bboxes_ignore) + + sampling_result = self.sampler.sample(assign_result, bbox_pred, + gt_bboxes) + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label targets + labels = gt_bboxes.new_full((num_bboxes,), + self.num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_bboxes) + + # bbox targets + bbox_targets = torch.zeros_like(bbox_pred)[..., :gt_bbox_c] + bbox_weights = torch.zeros_like(bbox_pred) + bbox_weights[pos_inds] = 1.0 + + # trajs targets + mot_targets = torch.zeros((num_bboxes, gt_mot_c), dtype=torch.float32, device=bbox_pred.device) + mot_weights = torch.zeros_like(mot_targets) + mot_targets[pos_inds] = gt_mot_trajs[sampling_result.pos_assigned_gt_inds] + mot_weights[pos_inds] = 1.0 + + # Filter out invalid fut trajs + mot_masks = torch.zeros_like(mot_targets) # [num_bboxes, fut_ts*2] + gt_mot_masks = gt_mot_masks.unsqueeze(-1).repeat(1, 1, 2).view(num_gt_bbox, -1) # [num_gt_bbox, fut_ts*2] + mot_masks[pos_inds] = gt_mot_masks[sampling_result.pos_assigned_gt_inds] + mot_weights = mot_weights * mot_masks + + # Extra future timestamp mask for controlling pred horizon + fut_ts_mask = torch.zeros((num_bboxes, self.fut_ts, 2), + dtype=torch.float32, device=bbox_pred.device) + fut_ts_mask[:, :self.valid_fut_ts, :] = 1.0 + fut_ts_mask = fut_ts_mask.view(num_bboxes, -1) + mot_weights = mot_weights * fut_ts_mask + + # DETR + bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes + + return ( + labels, label_weights, bbox_targets, bbox_weights, mot_targets, + mot_weights, mot_masks.view(-1, self.fut_ts, 2)[..., 0], + pos_inds, neg_inds + ) + + def _map_get_target_single(self, + cls_score, + bbox_pred, + pts_pred, + gt_labels, + gt_bboxes, + gt_shifts_pts, + gt_bboxes_ignore=None): + """"Compute regression and classification targets for one image. + Outputs from a single decoder layer of a single feature level are used. + Args: + cls_score (Tensor): Box score logits from a single decoder layer + for one image. Shape [num_query, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from a single decoder layer + for one image, with normalized coordinate (cx, cy, w, h) and + shape [num_query, 4]. + gt_bboxes (Tensor): Ground truth bboxes for one image with + shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels (Tensor): Ground truth class indices for one image + with shape (num_gts, ). + gt_bboxes_ignore (Tensor, optional): Bounding boxes + which can be ignored. Default None. + Returns: + tuple[Tensor]: a tuple containing the following for one image. + - labels (Tensor): Labels of each image. 
+ - label_weights (Tensor]): Label weights of each image. + - bbox_targets (Tensor): BBox targets of each image. + - bbox_weights (Tensor): BBox weights of each image. + - pos_inds (Tensor): Sampled positive indices for each image. + - neg_inds (Tensor): Sampled negative indices for each image. + """ + # import pdb;pdb.set_trace() + num_bboxes = bbox_pred.size(0) + # assigner and sampler + gt_c = gt_bboxes.shape[-1] + # import pdb;pdb.set_trace() + assign_result, order_index = self.map_assigner.assign(bbox_pred, cls_score, pts_pred, + gt_bboxes, gt_labels, gt_shifts_pts, + gt_bboxes_ignore) + + sampling_result = self.map_sampler.sample(assign_result, bbox_pred, + gt_bboxes) + # pts_sampling_result = self.sampler.sample(assign_result, pts_pred, + # gt_pts) + + + # import pdb;pdb.set_trace() + pos_inds = sampling_result.pos_inds + neg_inds = sampling_result.neg_inds + + # label targets + labels = gt_bboxes.new_full((num_bboxes,), + self.map_num_classes, + dtype=torch.long) + labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds] + label_weights = gt_bboxes.new_ones(num_bboxes) + + # bbox targets + bbox_targets = torch.zeros_like(bbox_pred)[..., :gt_c] + bbox_weights = torch.zeros_like(bbox_pred) + bbox_weights[pos_inds] = 1.0 + + # pts targets + # import pdb;pdb.set_trace() + # pts_targets = torch.zeros_like(pts_pred) + # num_query, num_order, num_points, num_coords + if order_index is None: + # import pdb;pdb.set_trace() + assigned_shift = gt_labels[sampling_result.pos_assigned_gt_inds] + else: + assigned_shift = order_index[sampling_result.pos_inds, sampling_result.pos_assigned_gt_inds] + pts_targets = pts_pred.new_zeros((pts_pred.size(0), + pts_pred.size(1), pts_pred.size(2))) + pts_weights = torch.zeros_like(pts_targets) + pts_weights[pos_inds] = 1.0 + + # DETR + bbox_targets[pos_inds] = sampling_result.pos_gt_bboxes + pts_targets[pos_inds] = gt_shifts_pts[sampling_result.pos_assigned_gt_inds,assigned_shift,:,:] + return (labels, label_weights, bbox_targets, bbox_weights, + pts_targets, pts_weights, + pos_inds, neg_inds) + + def get_targets(self, + cls_scores_list, + bbox_preds_list, + gt_bboxes_list, + gt_labels_list, + gt_attr_labels_list, + gt_bboxes_ignore_list=None): + """"Compute regression and classification targets for a batch image. + Outputs from a single decoder layer of a single feature level are used. + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + Returns: + tuple: a tuple containing the following targets. + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all \ + images. + - bbox_targets_list (list[Tensor]): BBox targets for all \ + images. + - bbox_weights_list (list[Tensor]): BBox weights for all \ + images. + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. 
+ """ + assert gt_bboxes_ignore_list is None, \ + 'Only supports for gt_bboxes_ignore setting to None.' + num_imgs = len(cls_scores_list) + gt_bboxes_ignore_list = [ + gt_bboxes_ignore_list for _ in range(num_imgs) + ] + + (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, mot_targets_list, mot_weights_list, + gt_fut_masks_list, pos_inds_list, neg_inds_list) = multi_apply( + self._get_target_single, cls_scores_list, bbox_preds_list, + gt_labels_list, gt_bboxes_list, gt_attr_labels_list, gt_bboxes_ignore_list + ) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + mot_targets_list, mot_weights_list, gt_fut_masks_list, num_total_pos, num_total_neg) + + def map_get_targets(self, + cls_scores_list, + bbox_preds_list, + pts_preds_list, + gt_bboxes_list, + gt_labels_list, + gt_shifts_pts_list, + gt_bboxes_ignore_list=None): + """"Compute regression and classification targets for a batch image. + Outputs from a single decoder layer of a single feature level are used. + Args: + cls_scores_list (list[Tensor]): Box score logits from a single + decoder layer for each image with shape [num_query, + cls_out_channels]. + bbox_preds_list (list[Tensor]): Sigmoid outputs from a single + decoder layer for each image, with normalized coordinate + (cx, cy, w, h) and shape [num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + Returns: + tuple: a tuple containing the following targets. + - labels_list (list[Tensor]): Labels for all images. + - label_weights_list (list[Tensor]): Label weights for all \ + images. + - bbox_targets_list (list[Tensor]): BBox targets for all \ + images. + - bbox_weights_list (list[Tensor]): BBox weights for all \ + images. + - num_total_pos (int): Number of positive samples in all \ + images. + - num_total_neg (int): Number of negative samples in all \ + images. + """ + assert gt_bboxes_ignore_list is None, \ + 'Only supports for gt_bboxes_ignore setting to None.' + num_imgs = len(cls_scores_list) + gt_bboxes_ignore_list = [ + gt_bboxes_ignore_list for _ in range(num_imgs) + ] + + (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, pts_targets_list, pts_weights_list, + pos_inds_list, neg_inds_list) = multi_apply( + self._map_get_target_single, cls_scores_list, bbox_preds_list,pts_preds_list, + gt_labels_list, gt_bboxes_list, gt_shifts_pts_list, gt_bboxes_ignore_list) + num_total_pos = sum((inds.numel() for inds in pos_inds_list)) + num_total_neg = sum((inds.numel() for inds in neg_inds_list)) + return (labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, pts_targets_list, pts_weights_list, + num_total_pos, num_total_neg) + + def loss_planning(self, + ego_fut_preds, + ego_fut_gt, + ego_fut_masks, + ego_fut_cmd, + lane_preds, + lane_score_preds, + agent_preds, + agent_fut_preds, + agent_score_preds, + agent_fut_cls_preds, + ego_cls_col_preds, + ego_cls_bd_preds, + ego_cls_cl_preds, + ego_cls_expert_preds, + gt_agent_boxes, + gt_agent_feats, + gt_map_pts, + gt_map_labels, + ): + """"Loss function for ego vehicle planning. 
+ Args: + ego_fut_preds (Tensor): [B, num_cmd, fut_ts, 2] + ego_fut_gt (Tensor): [B, fut_ts, 2] + ego_fut_masks (Tensor): [B, fut_ts] + ego_fut_cmd (Tensor): [B, num_cmd] + lane_preds (Tensor): [B, num_vec, num_pts, 2] + lane_score_preds (Tensor): [B, num_vec, 3] + agent_preds (Tensor): [B, num_agent, 2] + agent_fut_preds (Tensor): [B, num_agent, mot_fut_mode, fut_ts, 2] + agent_score_preds (Tensor): [B, num_agent, 10] + agent_fut_cls_scores (Tensor): [B, num_agent, mot_fut_mode] + ego_cls_col_preds (Tensor): [B, num_plan_mode, 1] + ego_cls_bd_preds (Tensor): [B, num_plan_mode, 1] + ego_cls_cl_preds (Tensor): [B, num_plan_mode, 1] + ego_cls_expert_preds (Tensor): [B, num_plan_mode, 1] + + Returns: + loss_plan_cls_col (Tensor): cls col loss. + loss_plan_cls_bd (Tensor): cls bd loss. + loss_plan_cls_cl (Tensor): cls cl loss. + loss_plan_cls_expert (Tensor): cls expert loss. + loss_plan_reg (Tensor): planning l1 loss. + loss_plan_bound (Tensor): planning map boundary loss. + loss_plan_agent_dis (Tensor): planning agent distance loss. + loss_plan_map_theta (Tensor): planning map theta loss. + """ + + batch = ego_fut_preds.shape[0] + ego_fut_gt = ego_fut_gt.unsqueeze(1).repeat(1, self.plan_fut_mode, 1, 1) + # loss_plan_l1_weight = ego_fut_cmd[..., None, None] * ego_fut_masks[:, None, :, None] + # loss_plan_l1_weight = loss_plan_l1_weight.repeat(1, 1, 1, 2) + + # get plan cls target + plan_col_labels, plan_bd_labels, plan_cl_labels = [], [], [] + plan_expert_labels, plan_expert_labels_weight = [], [] + for i in range(batch): + plan_col_label = self.get_plan_col_target( + ego_fut_preds[i].detach(), + ego_fut_gt[i], + gt_agent_boxes[i], + gt_agent_feats[i]) + plan_bd_label = self.get_plan_bd_target( + ego_fut_preds[i].detach(), + gt_map_pts[i], + gt_map_labels[i]) + plan_cl_label = self.get_plan_cl_target( + ego_fut_preds[i].detach(), + gt_map_pts[i], + gt_map_labels[i]) + plan_expert_label, plan_expert_label_weight = self.get_plan_expert_target( + ego_fut_preds[i].detach(), + ego_fut_gt[i], + ego_fut_masks[i], + ego_cls_expert_preds[i], + plan_col_label, + plan_bd_label) + + plan_col_labels.append(plan_col_label) + plan_bd_labels.append(plan_bd_label) + plan_cl_labels.append(plan_cl_label) + plan_expert_labels.append(plan_expert_label) + plan_expert_labels_weight.append(plan_expert_label_weight) + + plan_col_labels = torch.stack(plan_col_labels, dim=0).to(ego_fut_preds.device) + plan_bd_labels = torch.stack(plan_bd_labels, dim=0).to(ego_fut_preds.device) + plan_cl_labels = torch.stack(plan_cl_labels, dim=0).to(ego_fut_preds.device) + plan_expert_labels = torch.stack(plan_expert_labels, dim=0).to(ego_fut_preds.device) + plan_expert_labels_weight = torch.stack(plan_expert_labels_weight, + dim=0).to(ego_fut_preds.device) + + # plan collision classification loss + loss_plan_cls_col = self.loss_plan_cls_col( + ego_cls_col_preds.flatten(0, 1), plan_col_labels.flatten(), + plan_col_labels.new_ones(batch*self.plan_fut_mode), + avg_factor=batch*self.plan_fut_mode) + + # plan boundary overstepping classification loss + loss_plan_cls_bd = self.loss_plan_cls_bd( + ego_cls_bd_preds.flatten(0, 1), plan_bd_labels.flatten(), + plan_bd_labels.new_ones(batch*self.plan_fut_mode), + avg_factor=batch*self.plan_fut_mode) + + # plan centerline consistency classification loss + plan_cl_weight = plan_cl_labels.flatten() + plan_cl_labels_weight = (plan_cl_weight > -1.) + (plan_cl_weight == -1.) 
* 0.01 + loss_plan_cls_cl = self.loss_plan_cls_cl( + ego_cls_cl_preds.squeeze(-1).flatten(0, 1), plan_cl_labels.flatten(), + plan_cl_labels_weight, + # avg_factor=(plan_cl_labels > -1.).sum() + ) + + # plan expert driving behavior classification loss + loss_plan_cls_expert = self.loss_plan_cls_expert( + ego_cls_expert_preds.flatten(0, 1), plan_expert_labels.flatten(), + plan_expert_labels_weight.flatten(), + avg_factor=batch*self.plan_fut_mode) + + loss_plan_reg = (0. * ego_fut_preds).sum() + # loss_plan_reg = self.loss_plan_reg( + # ego_fut_preds, + # ego_fut_gt, + # loss_plan_l1_weight + # ) + + loss_plan_bound = (0. * ego_fut_preds).sum() + # loss_plan_bound = self.loss_plan_bound( + # ego_fut_preds[ego_fut_cmd==1], + # lane_preds, + # lane_score_preds, + # weight=ego_fut_masks + # ) + + loss_plan_agent_dis = (0. * ego_fut_preds).sum() + # loss_plan_agent_dis = self.loss_plan_agent_dis( + # ego_fut_preds[ego_fut_cmd==1], + # agent_preds, + # agent_fut_preds, + # agent_score_preds, + # agent_fut_cls_preds, + # weight=ego_fut_masks[:, :, None].repeat(1, 1, 2) + # ) + + loss_plan_map_theta = (0. * ego_fut_preds).sum() + # loss_plan_map_theta = self.loss_plan_map_theta( + # ego_fut_preds[ego_fut_cmd==1], + # lane_preds, + # lane_score_preds, + # weight=ego_fut_masks + # ) + + if digit_version(TORCH_VERSION) >= digit_version('1.8'): + loss_plan_cls_col = torch.nan_to_num(loss_plan_cls_col) + loss_plan_cls_bd = torch.nan_to_num(loss_plan_cls_bd) + loss_plan_cls_cl = torch.nan_to_num(loss_plan_cls_cl) + loss_plan_cls_expert = torch.nan_to_num(loss_plan_cls_expert) + loss_plan_reg = torch.nan_to_num(loss_plan_reg) + loss_plan_bound = torch.nan_to_num(loss_plan_bound) + loss_plan_agent_dis = torch.nan_to_num(loss_plan_agent_dis) + loss_plan_map_theta = torch.nan_to_num(loss_plan_map_theta) + + loss_plan_dict = dict() + loss_plan_dict['loss_plan_cls_col'] = loss_plan_cls_col + loss_plan_dict['loss_plan_cls_bd'] = loss_plan_cls_bd + loss_plan_dict['loss_plan_cls_cl'] = loss_plan_cls_cl + loss_plan_dict['loss_plan_cls_expert'] = loss_plan_cls_expert + + loss_plan_dict['loss_plan_reg'] = loss_plan_reg + loss_plan_dict['loss_plan_bound'] = loss_plan_bound + loss_plan_dict['loss_plan_agent_dis'] = loss_plan_agent_dis + loss_plan_dict['loss_plan_map_theta'] = loss_plan_map_theta + + return loss_plan_dict + + def loss_single(self, + cls_scores, + bbox_preds, + mot_preds, + mot_cls_preds, + gt_bboxes_list, + gt_labels_list, + gt_attr_labels_list, + gt_bboxes_ignore_list=None): + """"Loss function for outputs from a single decoder layer of a single + feature level. + Args: + cls_scores (Tensor): Box score logits from a single decoder layer + for all images. Shape [bs, num_query, cls_out_channels]. + bbox_preds (Tensor): Sigmoid outputs from a single decoder layer + for all images, with normalized coordinate (cx, cy, w, h) and + shape [bs, num_query, 4]. + gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image + with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. + gt_labels_list (list[Tensor]): Ground truth class indices for each + image with shape (num_gts, ). + gt_bboxes_ignore_list (list[Tensor], optional): Bounding + boxes which can be ignored for each image. Default None. + Returns: + dict[str, Tensor]: A dictionary of loss components for outputs from + a single decoder layer. 
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list, + gt_bboxes_list, gt_labels_list, + gt_attr_labels_list, gt_bboxes_ignore_list) + + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + mot_targets_list, mot_weights_list, gt_fut_masks_list, + num_total_pos, num_total_neg) = cls_reg_targets + + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + mot_targets = torch.cat(mot_targets_list, 0) + mot_weights = torch.cat(mot_weights_list, 0) + gt_fut_masks = torch.cat(gt_fut_masks_list, 0) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + + cls_avg_factor = max(cls_avg_factor, 1) + loss_cls = self.loss_cls(cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes accross all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # regression L1 loss + bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1)) + normalized_bbox_targets = normalize_bbox(bbox_targets, self.pc_range) + isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1) + bbox_weights = bbox_weights * self.code_weights + loss_bbox = self.loss_bbox( + bbox_preds[isnotnan, :10], + normalized_bbox_targets[isnotnan, :10], + bbox_weights[isnotnan, :10], + avg_factor=num_total_pos + ) + + # mot regression loss + best_mot_preds = self.get_best_fut_preds( + mot_preds.reshape(-1, self.mot_fut_mode, self.fut_ts, 2), + mot_targets.reshape(-1, self.fut_ts, 2), + gt_fut_masks + ) + + neg_inds = (bbox_weights[:, 0] == 0) + mot_labels = self.get_mot_cls_target( + mot_preds.reshape(-1, self.mot_fut_mode, self.fut_ts, 2), + mot_targets.reshape(-1, self.fut_ts, 2), + gt_fut_masks, + neg_inds + ) + + loss_mot_reg = self.loss_mot_reg( + best_mot_preds[isnotnan], + mot_targets[isnotnan], + mot_weights[isnotnan], + avg_factor=num_total_pos + ) + + # mot classification loss + mot_cls_scores = mot_cls_preds.reshape(-1, self.mot_fut_mode) + # construct weighted avg_factor to match with the official DETR repo + mot_cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.mot_bg_cls_weight + if self.sync_cls_avg_factor: + mot_cls_avg_factor = reduce_mean( + mot_cls_scores.new_tensor([mot_cls_avg_factor])) + + mot_cls_avg_factor = max(mot_cls_avg_factor, 1) + loss_mot_cls = self.loss_mot_cls( + mot_cls_scores, mot_labels, label_weights, avg_factor=mot_cls_avg_factor + ) + + if digit_version(TORCH_VERSION) >= digit_version('1.8'): + loss_cls = torch.nan_to_num(loss_cls) + loss_bbox = torch.nan_to_num(loss_bbox) + loss_mot_reg = torch.nan_to_num(loss_mot_reg) + loss_mot_cls = torch.nan_to_num(loss_mot_cls) + + return loss_cls, loss_bbox, loss_mot_reg, loss_mot_cls + + def get_best_fut_preds(self, + traj_preds, + traj_targets, + gt_fut_masks): + """"Choose best preds among all modes. 
+        Args:
+            traj_preds (Tensor): MultiModal traj preds with shape (num_box_preds, mot_fut_mode, fut_ts, 2).
+            traj_targets (Tensor): Ground truth traj for each pred box with shape (num_box_preds, fut_ts, 2).
+            gt_fut_masks (Tensor): Ground truth traj mask with shape (num_box_preds, fut_ts).
+
+        Returns:
+            best_traj_preds (Tensor): best traj preds (min displacement error with gt)
+                with shape (num_box_preds, fut_ts*2).
+        """
+
+        cum_traj_preds = traj_preds.cumsum(dim=-2)
+        cum_traj_targets = traj_targets.cumsum(dim=-2)
+
+        # Get min pred mode indices.
+        # (num_box_preds, mot_fut_mode, fut_ts)
+        dist = torch.linalg.norm(cum_traj_targets[:, None, :, :] - cum_traj_preds, dim=-1)
+        dist = dist * gt_fut_masks[:, None, :]
+        dist = dist[..., -1]
+        dist[torch.isnan(dist)] = 0
+        min_mode_idxs = torch.argmin(dist, dim=-1).tolist()
+        box_idxs = torch.arange(traj_preds.shape[0]).tolist()
+        best_traj_preds = traj_preds[box_idxs, min_mode_idxs, :, :].reshape(-1, self.fut_ts*2)
+
+        return best_traj_preds
+
+    def get_mot_cls_target(self,
+                           mot_preds,
+                           mot_targets,
+                           gt_fut_masks,
+                           neg_inds):
+        """Get the motion trajectory mode classification target.
+        Args:
+            mot_preds (Tensor): MultiModal traj preds with shape (num_box_preds, mot_fut_mode, fut_ts, 2).
+            mot_targets (Tensor): Ground truth traj for each pred box with shape (num_box_preds, fut_ts, 2).
+            gt_fut_masks (Tensor): Ground truth traj mask with shape (num_box_preds, fut_ts).
+            neg_inds (Tensor): Negative indices with shape (num_box_preds,)
+
+        Returns:
+            mot_labels (Tensor): traj cls labels (num_box_preds,).
+        """
+
+        cum_mot_preds = mot_preds.cumsum(dim=-2)
+        cum_mot_targets = mot_targets.cumsum(dim=-2)
+
+        # Get min pred mode indices.
+        # (num_box_preds, mot_fut_mode, fut_ts)
+        dist = torch.linalg.norm(cum_mot_targets[:, None, :, :] - cum_mot_preds, dim=-1)
+        dist = dist * gt_fut_masks[:, None, :]
+        dist = dist[..., -1]
+        dist[torch.isnan(dist)] = 0
+        mot_labels = torch.argmin(dist, dim=-1)
+        mot_labels[neg_inds] = self.mot_fut_mode
+
+        return mot_labels
+
+    def map_loss_single(self,
+                        cls_scores,
+                        bbox_preds,
+                        pts_preds,
+                        gt_bboxes_list,
+                        gt_labels_list,
+                        gt_shifts_pts_list,
+                        gt_bboxes_ignore_list=None):
+        """Loss function for outputs from a single decoder layer of a single
+        feature level.
+        Args:
+            cls_scores (Tensor): Box score logits from a single decoder layer
+                for all images. Shape [bs, num_query, cls_out_channels].
+            bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+                for all images, with normalized coordinate (cx, cy, w, h) and
+                shape [bs, num_query, 4].
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            gt_shifts_pts_list (list[Tensor]): Ground truth pts for each image
+                with shape (num_gts, fixed_num, 2) in [x,y] format.
+            gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+                boxes which can be ignored for each image. Default None.
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components for outputs from
+                a single decoder layer.
+ """ + num_imgs = cls_scores.size(0) + cls_scores_list = [cls_scores[i] for i in range(num_imgs)] + bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)] + pts_preds_list = [pts_preds[i] for i in range(num_imgs)] + # import pdb;pdb.set_trace() + cls_reg_targets = self.map_get_targets(cls_scores_list, bbox_preds_list,pts_preds_list, + gt_bboxes_list, gt_labels_list,gt_shifts_pts_list, + gt_bboxes_ignore_list) + (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list, + pts_targets_list, pts_weights_list, + num_total_pos, num_total_neg) = cls_reg_targets + # import pdb;pdb.set_trace() + labels = torch.cat(labels_list, 0) + label_weights = torch.cat(label_weights_list, 0) + bbox_targets = torch.cat(bbox_targets_list, 0) + bbox_weights = torch.cat(bbox_weights_list, 0) + pts_targets = torch.cat(pts_targets_list, 0) + pts_weights = torch.cat(pts_weights_list, 0) + + # classification loss + cls_scores = cls_scores.reshape(-1, self.map_cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.map_bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + + cls_avg_factor = max(cls_avg_factor, 1) + loss_cls = self.loss_map_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes accross all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # import pdb;pdb.set_trace() + # regression L1 loss + bbox_preds = bbox_preds.reshape(-1, bbox_preds.size(-1)) + normalized_bbox_targets = normalize_2d_bbox(bbox_targets, self.pc_range) + # normalized_bbox_targets = bbox_targets + isnotnan = torch.isfinite(normalized_bbox_targets).all(dim=-1) + bbox_weights = bbox_weights * self.map_code_weights + + loss_bbox = self.loss_map_bbox( + bbox_preds[isnotnan, :4], + normalized_bbox_targets[isnotnan,:4], + bbox_weights[isnotnan, :4], + avg_factor=num_total_pos) + + # regression pts CD loss + # pts_preds = pts_preds + # import pdb;pdb.set_trace() + + # num_samples, num_order, num_pts, num_coords + normalized_pts_targets = normalize_2d_pts(pts_targets, self.pc_range) + + # num_samples, num_pts, num_coords + pts_preds = pts_preds.reshape(-1, pts_preds.size(-2), pts_preds.size(-1)) + if self.map_num_pts_per_vec != self.map_num_pts_per_gt_vec: + pts_preds = pts_preds.permute(0,2,1) + pts_preds = F.interpolate(pts_preds, size=(self.map_num_pts_per_gt_vec), mode='linear', + align_corners=True) + pts_preds = pts_preds.permute(0,2,1).contiguous() + + # import pdb;pdb.set_trace() + loss_pts = self.loss_map_pts( + pts_preds[isnotnan,:,:], + normalized_pts_targets[isnotnan,:,:], + pts_weights[isnotnan,:,:], + avg_factor=num_total_pos) + + dir_weights = pts_weights[:, :-self.map_dir_interval,0] + denormed_pts_preds = denormalize_2d_pts(pts_preds, self.pc_range) + denormed_pts_preds_dir = denormed_pts_preds[:,self.map_dir_interval:,:] - \ + denormed_pts_preds[:,:-self.map_dir_interval,:] + pts_targets_dir = pts_targets[:, self.map_dir_interval:,:] - pts_targets[:,:-self.map_dir_interval,:] + # dir_weights = pts_weights[:, indice,:-1,0] + # import pdb;pdb.set_trace() + loss_dir = self.loss_map_dir( + denormed_pts_preds_dir[isnotnan,:,:], + pts_targets_dir[isnotnan,:,:], + dir_weights[isnotnan,:], + avg_factor=num_total_pos) + + bboxes = denormalize_2d_bbox(bbox_preds, self.pc_range) + 
+        # regression IoU loss (GIoU loss by default)
+        loss_iou = self.loss_map_iou(
+            bboxes[isnotnan, :4],
+            bbox_targets[isnotnan, :4],
+            bbox_weights[isnotnan, :4],
+            avg_factor=num_total_pos)
+
+        if digit_version(TORCH_VERSION) >= digit_version('1.8'):
+            loss_cls = torch.nan_to_num(loss_cls)
+            loss_bbox = torch.nan_to_num(loss_bbox)
+            loss_iou = torch.nan_to_num(loss_iou)
+            loss_pts = torch.nan_to_num(loss_pts)
+            loss_dir = torch.nan_to_num(loss_dir)
+        return loss_cls, loss_bbox, loss_iou, loss_pts, loss_dir
+
+    # NOTE: map is already supported
+    @force_fp32(apply_to=('preds_dicts'))
+    def loss(self,
+             gt_bboxes_list,
+             gt_labels_list,
+             map_gt_bboxes_list,
+             map_gt_labels_list,
+             preds_dicts,
+             ego_fut_gt,
+             ego_fut_masks,
+             ego_fut_cmd,
+             gt_attr_labels,
+             traffic_signal,
+             stop_sign_signal,
+             gt_bboxes_ignore=None,
+             map_gt_bboxes_ignore=None,
+             img_metas=None):
+        """Loss function.
+        Args:
+            gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+                with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels_list (list[Tensor]): Ground truth class indices for each
+                image with shape (num_gts, ).
+            preds_dicts:
+                all_cls_scores (Tensor): Classification score of all
+                    decoder layers, has shape
+                    [nb_dec, bs, num_query, cls_out_channels].
+                all_bbox_preds (Tensor): Sigmoid regression
+                    outputs of all decode layers. Each is a 4D-tensor with
+                    normalized coordinate format (cx, cy, w, h) and shape
+                    [nb_dec, bs, num_query, 4].
+                enc_cls_scores (Tensor): Classification scores of
+                    points on encode feature map, has shape
+                    (N, h*w, num_classes). Only passed when as_two_stage is
+                    True, otherwise is None.
+                enc_bbox_preds (Tensor): Regression results of each points
+                    on the encode feature map, has shape (N, h*w, 4). Only
+                    passed when as_two_stage is True, otherwise is None.
+            gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+                which can be ignored for each image. Default None.
+        Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+        """
+        assert gt_bboxes_ignore is None, \
+            f'{self.__class__.__name__} only supports ' \
+            f'gt_bboxes_ignore set to None.'
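+        # Annotation (reading aid): this loss fans out into four groups --
+        # per-decoder-layer detection terms (cls/bbox plus motion reg/cls),
+        # per-layer map terms (cls/bbox/IoU/pts/dir), planning
+        # classification terms over the sampled anchor vocabulary
+        # (collision / boundary / centerline / expert), and traffic-light /
+        # stop-sign classification from the front-camera feature.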
+ + map_gt_vecs_list = copy.deepcopy(map_gt_bboxes_list) + + all_cls_scores = preds_dicts['all_cls_scores'] + all_bbox_preds = preds_dicts['all_bbox_preds'] + all_mot_preds = preds_dicts['all_mot_preds'] + all_mot_cls_scores = preds_dicts['all_mot_cls_scores'] + enc_cls_scores = preds_dicts['enc_cls_scores'] + enc_bbox_preds = preds_dicts['enc_bbox_preds'] + map_all_cls_scores = preds_dicts['map_all_cls_scores'] + map_all_bbox_preds = preds_dicts['map_all_bbox_preds'] + map_all_pts_preds = preds_dicts['map_all_pts_preds'] + map_enc_cls_scores = preds_dicts['map_enc_cls_scores'] + map_enc_bbox_preds = preds_dicts['map_enc_bbox_preds'] + map_enc_pts_preds = preds_dicts['map_enc_pts_preds'] + ego_fut_preds = preds_dicts['ego_fut_preds'] + tl_status_cls_scores = preds_dicts['tl_status_cls_scores'] + tl_trigger_cls_scores = preds_dicts['tl_trigger_cls_scores'] + stopsign_trigger_cls_scores = preds_dicts['stopsign_trigger_cls_scores'] + ego_cls_col_preds = preds_dicts['ego_cls_col_preds'] + ego_cls_bd_preds = preds_dicts['ego_cls_bd_preds'] + ego_cls_cl_preds = preds_dicts['ego_cls_cl_preds'] + ego_cls_expert_preds = preds_dicts['ego_cls_expert_preds'] + + + num_dec_layers = len(all_cls_scores) + device = gt_labels_list[0].device + + gt_bboxes_list = [torch.cat( + (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), + dim=1).to(device) for gt_bboxes in gt_bboxes_list] + + all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)] + all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)] + all_gt_attr_labels_list = [gt_attr_labels for _ in range(num_dec_layers)] + all_gt_bboxes_ignore_list = [ + gt_bboxes_ignore for _ in range(num_dec_layers) + ] + + losses_cls, losses_bbox, loss_mot_reg, loss_mot_cls = multi_apply( + self.loss_single, all_cls_scores, all_bbox_preds, all_mot_preds, + all_mot_cls_scores, all_gt_bboxes_list, + all_gt_labels_list, all_gt_attr_labels_list, all_gt_bboxes_ignore_list) + + + num_dec_layers = len(map_all_cls_scores) + device = map_gt_labels_list[0].device + # gt_bboxes_list = [torch.cat( + # (gt_bboxes.gravity_center, gt_bboxes.tensor[:, 3:]), + # dim=1).to(device) for gt_bboxes in gt_bboxes_list] + # import pdb;pdb.set_trace() + # gt_bboxes_list = [ + # gt_bboxes.to(device) for gt_bboxes in gt_bboxes_list] + map_gt_bboxes_list = [ + map_gt_bboxes.bbox.to(device) for map_gt_bboxes in map_gt_vecs_list] + map_gt_pts_list = [ + map_gt_bboxes.fixed_num_sampled_points.to(device) for map_gt_bboxes in map_gt_vecs_list] + if self.map_gt_shift_pts_pattern == 'v0': + map_gt_shifts_pts_list = [ + gt_bboxes.shift_fixed_num_sampled_points.to(device) for gt_bboxes in map_gt_vecs_list] + elif self.map_gt_shift_pts_pattern == 'v1': + map_gt_shifts_pts_list = [ + gt_bboxes.shift_fixed_num_sampled_points_v1.to(device) for gt_bboxes in map_gt_vecs_list] + elif self.map_gt_shift_pts_pattern == 'v2': + map_gt_shifts_pts_list = [ + gt_bboxes.shift_fixed_num_sampled_points_v2.to(device) for gt_bboxes in map_gt_vecs_list] + elif self.map_gt_shift_pts_pattern == 'v3': + map_gt_shifts_pts_list = [ + gt_bboxes.shift_fixed_num_sampled_points_v3.to(device) for gt_bboxes in map_gt_vecs_list] + elif self.map_gt_shift_pts_pattern == 'v4': + map_gt_shifts_pts_list = [ + gt_bboxes.shift_fixed_num_sampled_points_v4.to(device) for gt_bboxes in map_gt_vecs_list] + else: + raise NotImplementedError + map_all_gt_bboxes_list = [map_gt_bboxes_list for _ in range(num_dec_layers)] + map_all_gt_labels_list = [map_gt_labels_list for _ in range(num_dec_layers)] + map_all_gt_pts_list = 
[map_gt_pts_list for _ in range(num_dec_layers)] + map_all_gt_shifts_pts_list = [map_gt_shifts_pts_list for _ in range(num_dec_layers)] + map_all_gt_bboxes_ignore_list = [ + map_gt_bboxes_ignore for _ in range(num_dec_layers) + ] + # import pdb;pdb.set_trace() + map_losses_cls, map_losses_bbox, map_losses_iou, \ + map_losses_pts, map_losses_dir = multi_apply( + self.map_loss_single, map_all_cls_scores, map_all_bbox_preds, + map_all_pts_preds, map_all_gt_bboxes_list, map_all_gt_labels_list, + map_all_gt_shifts_pts_list, map_all_gt_bboxes_ignore_list) + + loss_dict = dict() + # loss from the last decoder layer + loss_dict['loss_cls'] = losses_cls[-1] + loss_dict['loss_bbox'] = losses_bbox[-1] + loss_dict['loss_mot_reg'] = loss_mot_reg[-1] + loss_dict['loss_mot_cls'] = loss_mot_cls[-1] + # loss from the last decoder layer + loss_dict['loss_map_cls'] = map_losses_cls[-1] + loss_dict['loss_map_bbox'] = map_losses_bbox[-1] + loss_dict['loss_map_iou'] = map_losses_iou[-1] + loss_dict['loss_map_pts'] = map_losses_pts[-1] + loss_dict['loss_map_dir'] = map_losses_dir[-1] + + # Planning Loss + ego_fut_gt = ego_fut_gt.squeeze(1) + ego_fut_masks = ego_fut_masks.squeeze(1).squeeze(1) + ego_fut_cmd = ego_fut_cmd.squeeze(1).squeeze(1) + + batch, num_agent = all_mot_preds[-1].shape[:2] + agent_fut_preds = all_mot_preds[-1].view(batch, num_agent, self.mot_fut_mode, self.fut_ts, 2) + agent_fut_cls_preds = all_mot_cls_scores[-1].view(batch, num_agent, self.mot_fut_mode) + loss_plan_input = [ego_fut_preds, ego_fut_gt, ego_fut_masks, ego_fut_cmd, + map_all_pts_preds[-1], map_all_cls_scores[-1].sigmoid(), + all_bbox_preds[-1][..., 0:2], agent_fut_preds, + all_cls_scores[-1].sigmoid(), agent_fut_cls_preds.sigmoid(), + ego_cls_col_preds, ego_cls_bd_preds, ego_cls_cl_preds, ego_cls_expert_preds, + gt_bboxes_list, gt_attr_labels, + map_gt_pts_list, map_gt_labels_list] + + loss_planning_dict = self.loss_planning(*loss_plan_input) + loss_dict['loss_plan_cls_col'] = loss_planning_dict['loss_plan_cls_col'] + loss_dict['loss_plan_cls_bd'] = loss_planning_dict['loss_plan_cls_bd'] + loss_dict['loss_plan_cls_cl'] = loss_planning_dict['loss_plan_cls_cl'] + loss_dict['loss_plan_cls_expert'] = loss_planning_dict['loss_plan_cls_expert'] + loss_dict['loss_plan_reg'] = loss_planning_dict['loss_plan_reg'] + loss_dict['loss_plan_bound'] = loss_planning_dict['loss_plan_bound'] + loss_dict['loss_plan_agent_dis'] = loss_planning_dict['loss_plan_agent_dis'] + loss_dict['loss_plan_map_theta'] = loss_planning_dict['loss_plan_map_theta'] + + + # traffic light trigger classification + tl_trigger_cls_scores = tl_trigger_cls_scores.reshape(-1, self.tl_trigger_num_cls) + tl_trigger_labels = traffic_signal[..., 1].reshape(-1) + tl_trigger_cls_avg_factor = tl_trigger_cls_scores.shape[0] * 1.0 + if self.sync_cls_avg_factor: + tl_trigger_cls_avg_factor = reduce_mean( + tl_trigger_cls_scores.new_tensor([tl_trigger_cls_avg_factor])) + tl_trigger_cls_avg_factor = max(tl_trigger_cls_avg_factor, 1) + loss_tl_trigger_cls = self.loss_tl_trigger_cls( + tl_trigger_cls_scores, tl_trigger_labels, + tl_trigger_cls_scores.new_ones(tl_trigger_labels.shape[0]), + avg_factor=tl_trigger_cls_avg_factor) + + # stop sign trigger classification + stopsign_trigger_cls_scores = stopsign_trigger_cls_scores.reshape(-1, self.stopsign_trigger_num_cls) + stopsign_trigger_labels = stop_sign_signal.reshape(-1) + stopsign_trigger_cls_avg_factor = stopsign_trigger_cls_scores.shape[0] * 1.0 + if self.sync_cls_avg_factor: + stopsign_trigger_cls_avg_factor = reduce_mean( + 
stopsign_trigger_cls_scores.new_tensor([stopsign_trigger_cls_avg_factor])) + stopsign_trigger_cls_avg_factor = max(stopsign_trigger_cls_avg_factor, 1) + loss_stopsign_trigger_cls = self.loss_stopsign_trigger_cls( + stopsign_trigger_cls_scores, stopsign_trigger_labels, + stopsign_trigger_cls_scores.new_ones(stopsign_trigger_labels.shape[0]), + avg_factor=stopsign_trigger_cls_avg_factor) + + # traffic light status classification + tl_status_weights = 1 - tl_trigger_labels + tl_status_cls_scores = tl_status_cls_scores.reshape(-1, self.tl_status_num_cls) + tl_status_labels = traffic_signal[..., 0].reshape(-1) + tl_status_cls_avg_factor = tl_status_cls_scores.shape[0] * 1.0 + if self.sync_cls_avg_factor: + tl_status_cls_avg_factor = reduce_mean( + tl_status_cls_scores.new_tensor([tl_status_cls_avg_factor])) + tl_status_cls_avg_factor = max(tl_status_cls_avg_factor, 1) + loss_tl_status_cls = self.loss_tl_status_cls( + tl_status_cls_scores, tl_status_labels, + tl_status_weights, + avg_factor=tl_status_cls_avg_factor) + + loss_dict['loss_tl_status_cls'] = loss_tl_status_cls + loss_dict['loss_tl_trigger_cls'] = loss_tl_trigger_cls + loss_dict['loss_stopsign_trigger_cls'] = loss_stopsign_trigger_cls + + # det loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i in zip(losses_cls[:-1], losses_bbox[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i + num_dec_layer += 1 + + # map loss from other decoder layers + num_dec_layer = 0 + for map_loss_cls_i, map_loss_bbox_i, map_loss_iou_i, map_loss_pts_i, map_loss_dir_i in zip( + map_losses_cls[:-1], + map_losses_bbox[:-1], + map_losses_iou[:-1], + map_losses_pts[:-1], + map_losses_dir[:-1] + ): + loss_dict[f'd{num_dec_layer}.loss_map_cls'] = map_loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_map_bbox'] = map_loss_bbox_i + loss_dict[f'd{num_dec_layer}.loss_map_iou'] = map_loss_iou_i + loss_dict[f'd{num_dec_layer}.loss_map_pts'] = map_loss_pts_i + loss_dict[f'd{num_dec_layer}.loss_map_dir'] = map_loss_dir_i + num_dec_layer += 1 + + # loss of proposal generated from encode feature map. + if enc_cls_scores is not None: + binary_labels_list = [ + torch.zeros_like(gt_labels_list[i]) + for i in range(len(all_gt_labels_list)) + ] + enc_loss_cls, enc_losses_bbox = \ + self.loss_single(enc_cls_scores, enc_bbox_preds, + gt_bboxes_list, binary_labels_list, gt_bboxes_ignore) + loss_dict['enc_loss_cls'] = enc_loss_cls + loss_dict['enc_loss_bbox'] = enc_losses_bbox + + if map_enc_cls_scores is not None: + map_binary_labels_list = [ + torch.zeros_like(map_gt_labels_list[i]) + for i in range(len(map_all_gt_labels_list)) + ] + # TODO bug here, but we dont care enc_loss now + map_enc_loss_cls, map_enc_loss_bbox, map_enc_loss_iou, \ + map_enc_loss_pts, map_enc_loss_dir = \ + self.map_loss_single( + map_enc_cls_scores, map_enc_bbox_preds, + map_enc_pts_preds, map_gt_bboxes_list, + map_binary_labels_list, map_gt_pts_list, + map_gt_bboxes_ignore + ) + loss_dict['enc_loss_map_cls'] = map_enc_loss_cls + loss_dict['enc_loss_map_bbox'] = map_enc_loss_bbox + loss_dict['enc_loss_map_iou'] = map_enc_loss_iou + loss_dict['enc_loss_map_pts'] = map_enc_loss_pts + loss_dict['enc_loss_map_dir'] = map_enc_loss_dir + + return loss_dict + + # NOTE: already support map + @force_fp32(apply_to=('preds_dicts')) + def get_bboxes(self, preds_dicts, img_metas, rescale=False): + """Generate bboxes from bbox head predictions. + Args: + preds_dicts (tuple[list[dict]]): Prediction results. 
+ img_metas (list[dict]): Point cloud and image's meta info. + Returns: + list[dict]: Decoded bbox, scores and labels after nms. + """ + + det_preds_dicts = self.bbox_coder.decode(preds_dicts) + # map_bboxes: xmin, ymin, xmax, ymax + map_preds_dicts = self.map_bbox_coder.decode(preds_dicts) + + num_samples = len(det_preds_dicts) + assert len(det_preds_dicts) == len(map_preds_dicts), \ + 'len(preds_dict) should be equal to len(map_preds_dicts)' + ret_list = [] + for i in range(num_samples): + preds = det_preds_dicts[i] + bboxes = preds['bboxes'] + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 + code_size = bboxes.shape[-1] + bboxes = img_metas[i]['box_type_3d'](bboxes, code_size) + scores = preds['scores'] + labels = preds['labels'] + trajs = preds['trajs'] + trajs_cls = preds['trajs_cls'] + + map_preds = map_preds_dicts[i] + map_bboxes = map_preds['map_bboxes'] + # bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 + # code_size = bboxes.shape[-1] + # bboxes = img_metas[i]['box_type_3d'](bboxes, code_size) + map_scores = map_preds['map_scores'] + map_labels = map_preds['map_labels'] + map_pts = map_preds['map_pts'] + + ret_list.append([bboxes, scores, labels, trajs, trajs_cls, map_bboxes, + map_scores, map_labels, map_pts]) + + return ret_list + + def select_and_pad_pred_map( + self, + mot_pos, + map_query, + map_score, + map_pos, + map_thresh=0.5, + dis_thresh=None, + pe_normalization=True, + use_fix_pad=False + ): + """select_and_pad_pred_map. + Args: + mot_pos: [B, A, 2] + map_query: [B, P, D]. + map_score: [B, P, 3]. + map_pos: [B, P, pts, 2]. + map_thresh: map confidence threshold for filtering low-confidence preds + dis_thresh: distance threshold for masking far maps for each agent in cross-attn + use_fix_pad: always pad one lane instance for each batch + Returns: + selected_map_query: [B*A, P1(+1), D], P1 is the max inst num after filter and pad. 
+            selected_map_pos: [B*A, P1(+1), 2]
+            selected_padding_mask: [B*A, P1(+1)]
+        """
+
+        if dis_thresh is None:
+            raise NotImplementedError('Not implemented yet')
+
+        # use the closest point in each map instance as the instance's position
+        batch, num_map = map_pos.shape[:2]
+        map_dis = torch.sqrt(map_pos[..., 0]**2 + map_pos[..., 1]**2)
+        min_map_pos_idx = map_dis.argmin(dim=-1).flatten()  # [B*P]
+        min_map_pos = map_pos.flatten(0, 1)  # [B*P, pts, 2]
+        min_map_pos = min_map_pos[range(min_map_pos.shape[0]), min_map_pos_idx]  # [B*P, 2]
+        min_map_pos = min_map_pos.view(batch, num_map, 2)  # [B, P, 2]
+
+        # select & pad map vectors for each sample in the batch using map_thresh
+        map_score = map_score.sigmoid()
+        map_max_score = map_score.max(dim=-1)[0]
+        map_idx = map_max_score > map_thresh
+        batch_max_pnum = 0
+        for i in range(map_score.shape[0]):
+            pnum = map_idx[i].sum()
+            if pnum > batch_max_pnum:
+                batch_max_pnum = pnum
+
+        selected_map_query, selected_map_pos, selected_padding_mask = [], [], []
+        for i in range(map_score.shape[0]):
+            dim = map_query.shape[-1]
+            valid_pnum = map_idx[i].sum()
+            valid_map_query = map_query[i, map_idx[i]]
+            valid_map_pos = min_map_pos[i, map_idx[i]]
+            pad_pnum = batch_max_pnum - valid_pnum
+            padding_mask = torch.tensor([False], device=map_score.device).repeat(batch_max_pnum)
+            if pad_pnum != 0:
+                valid_map_query = torch.cat([valid_map_query, torch.zeros((pad_pnum, dim), device=map_score.device)], dim=0)
+                valid_map_pos = torch.cat([valid_map_pos, torch.zeros((pad_pnum, 2), device=map_score.device)], dim=0)
+                padding_mask[valid_pnum:] = True
+            selected_map_query.append(valid_map_query)
+            selected_map_pos.append(valid_map_pos)
+            selected_padding_mask.append(padding_mask)
+
+        selected_map_query = torch.stack(selected_map_query, dim=0)
+        selected_map_pos = torch.stack(selected_map_pos, dim=0)
+        selected_padding_mask = torch.stack(selected_padding_mask, dim=0)
+
+        # generate a different pe for the map vectors of each agent
+        num_agent = mot_pos.shape[1]
+        selected_map_query = selected_map_query.unsqueeze(1).repeat(1, num_agent, 1, 1)  # [B, A, max_P, D]
+        selected_map_pos = selected_map_pos.unsqueeze(1).repeat(1, num_agent, 1, 1)  # [B, A, max_P, 2]
+        selected_padding_mask = selected_padding_mask.unsqueeze(1).repeat(1, num_agent, 1)  # [B, A, max_P]
+        # move lanes to the per-agent coordinate system
+        selected_map_dist = selected_map_pos - mot_pos[:, :, None, :]  # [B, A, max_P, 2]
+        if pe_normalization:
+            selected_map_pos = selected_map_pos - mot_pos[:, :, None, :]  # [B, A, max_P, 2]
+
+        # filter far map instances for each agent
+        map_dis = torch.sqrt(selected_map_dist[..., 0]**2 + selected_map_dist[..., 1]**2)
+        valid_map_inst = (map_dis <= dis_thresh)  # [B, A, max_P]
+        invalid_map_inst = ~valid_map_inst
+        selected_padding_mask = selected_padding_mask + invalid_map_inst
+
+        selected_map_query = selected_map_query.flatten(0, 1)
+        selected_map_pos = selected_map_pos.flatten(0, 1)
+        selected_padding_mask = selected_padding_mask.flatten(0, 1)
+
+        num_batch = selected_padding_mask.shape[0]
+        feat_dim = selected_map_query.shape[-1]
+        if use_fix_pad:
+            pad_map_query = torch.zeros((num_batch, 1, feat_dim), device=selected_map_query.device)
+            pad_map_pos = torch.ones((num_batch, 1, 2), device=selected_map_pos.device)
+            pad_lane_mask = torch.tensor([False], device=selected_padding_mask.device).unsqueeze(0).repeat(num_batch, 1)
+            selected_map_query = torch.cat([selected_map_query, pad_map_query], dim=1)
+            selected_map_pos = torch.cat([selected_map_pos, pad_map_pos], dim=1)
+            selected_padding_mask =
torch.cat([selected_padding_mask, pad_lane_mask], dim=1) + + return selected_map_query, selected_map_pos, selected_padding_mask + + + def select_and_pad_query( + self, + query, + query_pos, + query_score, + score_thresh=0.5, + use_fix_pad=True + ): + """select_and_pad_query. + Args: + query: [B, Q, D]. + query_pos: [B, Q, 2] + query_score: [B, Q, C]. + score_thresh: confidence threshold for filtering low-confidence query + use_fix_pad: always pad one query instance for each batch + Returns: + selected_query: [B, Q', D] + selected_query_pos: [B, Q', 2] + selected_padding_mask: [B, Q'] + """ + + # select & pad query for different batch using score_thresh + query_score = query_score.sigmoid() + query_score = query_score.max(dim=-1)[0] + query_idx = query_score > score_thresh + batch_max_qnum = 0 + for i in range(query_score.shape[0]): + qnum = query_idx[i].sum() + if qnum > batch_max_qnum: + batch_max_qnum = qnum + + selected_query, selected_query_pos, selected_padding_mask = [], [], [] + for i in range(query_score.shape[0]): + dim = query.shape[-1] + valid_qnum = query_idx[i].sum() + valid_query = query[i, query_idx[i]] + valid_query_pos = query_pos[i, query_idx[i]] + pad_qnum = batch_max_qnum - valid_qnum + padding_mask = torch.tensor([False], device=query_score.device).repeat(batch_max_qnum) + if pad_qnum != 0: + valid_query = torch.cat([valid_query, torch.zeros((pad_qnum, dim), device=query_score.device)], dim=0) + valid_query_pos = torch.cat([valid_query_pos, torch.zeros((pad_qnum, 2), device=query_score.device)], dim=0) + padding_mask[valid_qnum:] = True + selected_query.append(valid_query) + selected_query_pos.append(valid_query_pos) + selected_padding_mask.append(padding_mask) + + selected_query = torch.stack(selected_query, dim=0) + selected_query_pos = torch.stack(selected_query_pos, dim=0) + selected_padding_mask = torch.stack(selected_padding_mask, dim=0) + + num_batch = selected_padding_mask.shape[0] + feat_dim = selected_query.shape[-1] + if use_fix_pad: + pad_query = torch.zeros((num_batch, 1, feat_dim), device=selected_query.device) + pad_query_pos = torch.ones((num_batch, 1, 2), device=selected_query_pos.device) + pad_mask = torch.tensor([False], device=selected_padding_mask.device).unsqueeze(0).repeat(num_batch, 1) + selected_query = torch.cat([selected_query, pad_query], dim=1) + selected_query_pos = torch.cat([selected_query_pos, pad_query_pos], dim=1) + selected_padding_mask = torch.cat([selected_padding_mask, pad_mask], dim=1) + + return selected_query, selected_query_pos, selected_padding_mask + + + def get_plan_col_target(self, + ego_traj_preds, + ego_traj_gts, + agents_boxes_gts, + agents_feats_gts): + """"Get Trajectory mode classification target. + Args: + ego_traj_preds (Tensor): MultiModal traj preds with shape (B, plan_fut_mode, fut_ts, 2). + ego_traj_gts (Tensor): traj gts with shape (B, 1, fut_ts, 2). + agents_boxes_gts (List(Tensor)): Ground truth traj for each agent with shape (N_a, 9). + agents_feats_gts (List(Tensor)): Ground truth feats for each agent with shape (N_a, 34). + Returns: + traj_labels (Tensor): traj cls labels (1, plan_fut_mode). 
+ """ + + planning_metric = PlanningMetric(fut_ts=self.fut_ts) + segmentation, pedestrian = planning_metric.get_label(agents_boxes_gts, agents_feats_gts) + occupancy = torch.logical_or(segmentation, pedestrian) + + label_list = [] + for i in range(self.plan_fut_mode): + label = planning_metric.evaluate_coll( + ego_traj_preds[None, i].detach(), + ego_traj_gts[None, i], + occupancy) + label_list.append(label) + + return torch.cat(label_list, dim=-1).to(agents_feats_gts.device) + + def get_plan_bd_target(self, + ego_traj_preds, + lane_preds, + lane_score_preds, + lane_bound_cls_idx=1, + map_thresh=0.5): + """"Get Trajectory mode classification target. + Args: + ego_traj_preds (Tensor): MultiModal traj preds with shape (mot_fut_mode, fut_ts, 2). + lane_preds (Tensor): map preds/GT with shape (num_vec, num_pts, 2). + lane_score_preds (Tensor): map scores with shape (num_vec, 3). + Returns: + traj_labels (Tensor): traj cls labels (1, mot_fut_mode). + """ + # filter lane element according to confidence score and class + # not_lane_bound_mask = lane_score_preds[..., lane_bound_cls_idx] < map_thresh + not_lane_bound_mask = (lane_score_preds != lane_bound_cls_idx) + + # denormalize map pts + lane_bound_preds = lane_preds.clone() + # lane_bound_preds[..., 0:1] = (lane_bound_preds[..., 0:1] * (self.pc_range[3] - + # self.pc_range[0]) + self.pc_range[0]) + # lane_bound_preds[..., 1:2] = (lane_bound_preds[..., 1:2] * (self.pc_range[4] - + # self.pc_range[1]) + self.pc_range[1]) + # pad not-lane-boundary cls and low confidence preds + lane_bound_preds[not_lane_bound_mask] = 1e6 + + ego_traj_starts = ego_traj_preds[:, :-1, :] + ego_traj_ends = ego_traj_preds + padding_zeros = torch.zeros((self.plan_fut_mode, 1, 2), dtype=ego_traj_preds.dtype, + device=ego_traj_preds.device) # initial position + ego_traj_starts = torch.cat((padding_zeros, ego_traj_starts), dim=1) + V, P, _ = lane_bound_preds.size() + ego_traj_expanded = ego_traj_ends.unsqueeze(2).unsqueeze(3) # [num_plan_mode, T, 1, 1, 2] + maps_expanded = lane_bound_preds.unsqueeze(0).unsqueeze(1) # [1, 1, M, P, 2] + + dist = torch.linalg.norm(ego_traj_expanded - maps_expanded, dim=-1) # [num_plan_mode, T, M, P] + dist = dist.min(dim=-1, keepdim=False)[0] + min_inst_idxs = torch.argmin(dist, dim=-1).tolist() + mode_idxs = [[i] for i in range(dist.shape[0])] + ts_idxs = [[i for i in range(dist.shape[1])] for j in range(dist.shape[0])] + bd_target = lane_bound_preds.unsqueeze(0).unsqueeze(1).repeat(self.plan_fut_mode, self.fut_ts, 1, 1, 1) + min_bd_insts = bd_target[mode_idxs, ts_idxs, min_inst_idxs] # [B, T, P, 2] + bd_inst_starts = min_bd_insts[:, :, :-1, :].flatten(0, 2) + bd_inst_ends = min_bd_insts[:, :, 1:, :].flatten(0, 2) + ego_traj_starts = ego_traj_starts.unsqueeze(2).repeat(1, 1, P-1, 1).flatten(0, 2) + ego_traj_ends = ego_traj_ends.unsqueeze(2).repeat(1, 1, P-1, 1).flatten(0, 2) + + intersect_mask = segments_intersect(ego_traj_starts, ego_traj_ends, + bd_inst_starts, bd_inst_ends) + left_deviation = ego_traj_starts.new_tensor([-0.9, 2.4]) + right_deviation = ego_traj_starts.new_tensor([+0.9, 2.4]) + forward_deviation = ego_traj_starts.new_tensor([0., 2.4]) + intersect_mask_left = segments_intersect(ego_traj_starts + left_deviation, ego_traj_ends + left_deviation, + bd_inst_starts, bd_inst_ends) + intersect_mask_right = segments_intersect(ego_traj_starts + right_deviation, ego_traj_ends + right_deviation, + bd_inst_starts, bd_inst_ends) + intersect_mask_forward = segments_intersect(ego_traj_starts + forward_deviation, ego_traj_ends + 
+
+    def get_plan_bd_target(self,
+                           ego_traj_preds,
+                           lane_preds,
+                           lane_score_preds,
+                           lane_bound_cls_idx=1,
+                           map_thresh=0.5):
+        """Get boundary-violation labels for the planning trajectory candidates.
+        Args:
+            ego_traj_preds (Tensor): multi-modal traj preds with shape (plan_fut_mode, fut_ts, 2).
+            lane_preds (Tensor): map preds/GT with shape (num_vec, num_pts, 2).
+            lane_score_preds (Tensor): per-vector map class labels with shape (num_vec,);
+                the commented-out variant below instead thresholds per-class scores.
+            lane_bound_cls_idx: class index of road-boundary elements.
+        Returns:
+            bd_overstep_labels (Tensor): labels with shape (plan_fut_mode,);
+                1 marks a candidate that crosses a road boundary.
+        """
+        # keep only road-boundary map elements
+        # not_lane_bound_mask = lane_score_preds[..., lane_bound_cls_idx] < map_thresh
+        not_lane_bound_mask = (lane_score_preds != lane_bound_cls_idx)
+
+        # map points are already in ego coordinates, so denormalization is skipped;
+        # non-boundary instances are pushed far away so they never become nearest segments
+        lane_bound_preds = lane_preds.clone()
+        lane_bound_preds[not_lane_bound_mask] = 1e6
+
+        ego_traj_starts = ego_traj_preds[:, :-1, :]
+        ego_traj_ends = ego_traj_preds
+        padding_zeros = torch.zeros((self.plan_fut_mode, 1, 2), dtype=ego_traj_preds.dtype,
+                                    device=ego_traj_preds.device)  # initial position
+        ego_traj_starts = torch.cat((padding_zeros, ego_traj_starts), dim=1)
+        _, P, _ = lane_bound_preds.size()
+        ego_traj_expanded = ego_traj_ends.unsqueeze(2).unsqueeze(3)  # [num_plan_mode, T, 1, 1, 2]
+        maps_expanded = lane_bound_preds.unsqueeze(0).unsqueeze(1)  # [1, 1, M, P, 2]
+
+        dist = torch.linalg.norm(ego_traj_expanded - maps_expanded, dim=-1)  # [num_plan_mode, T, M, P]
+        dist = dist.min(dim=-1, keepdim=False)[0]
+        min_inst_idxs = torch.argmin(dist, dim=-1).tolist()
+        mode_idxs = [[i] for i in range(dist.shape[0])]
+        ts_idxs = [[i for i in range(dist.shape[1])] for j in range(dist.shape[0])]
+        bd_target = lane_bound_preds.unsqueeze(0).unsqueeze(1).repeat(self.plan_fut_mode, self.fut_ts, 1, 1, 1)
+        min_bd_insts = bd_target[mode_idxs, ts_idxs, min_inst_idxs]  # [num_plan_mode, T, P, 2]
+        bd_inst_starts = min_bd_insts[:, :, :-1, :].flatten(0, 2)
+        bd_inst_ends = min_bd_insts[:, :, 1:, :].flatten(0, 2)
+        ego_traj_starts = ego_traj_starts.unsqueeze(2).repeat(1, 1, P-1, 1).flatten(0, 2)
+        ego_traj_ends = ego_traj_ends.unsqueeze(2).repeat(1, 1, P-1, 1).flatten(0, 2)
+
+        # test the ego path and three shifted copies of it against the boundary
+        # segments (the offsets roughly cover the ego footprint: half width
+        # ~0.9 m, front bumper ~2.4 m ahead; cf. W = 1.85, H = 4.084 in PlanningMetric)
+        intersect_mask = segments_intersect(ego_traj_starts, ego_traj_ends,
+                                            bd_inst_starts, bd_inst_ends)
+        left_deviation = ego_traj_starts.new_tensor([-0.9, 2.4])
+        right_deviation = ego_traj_starts.new_tensor([+0.9, 2.4])
+        forward_deviation = ego_traj_starts.new_tensor([0., 2.4])
+        intersect_mask_left = segments_intersect(ego_traj_starts + left_deviation,
+                                                 ego_traj_ends + left_deviation,
+                                                 bd_inst_starts, bd_inst_ends)
+        intersect_mask_right = segments_intersect(ego_traj_starts + right_deviation,
+                                                  ego_traj_ends + right_deviation,
+                                                  bd_inst_starts, bd_inst_ends)
+        intersect_mask_forward = segments_intersect(ego_traj_starts + forward_deviation,
+                                                    ego_traj_ends + forward_deviation,
+                                                    bd_inst_starts, bd_inst_ends)
+        intersect_mask = intersect_mask | intersect_mask_left | intersect_mask_right | intersect_mask_forward
+        intersect_mask = intersect_mask.reshape(self.plan_fut_mode, self.fut_ts, P-1)
+        intersect_mask = intersect_mask.any(dim=-1).any(dim=-1)
+
+        bd_overstep_labels = torch.zeros((self.plan_fut_mode), dtype=torch.long,
+                                         device=ego_traj_preds.device)
+        bd_overstep_labels[intersect_mask] = 1
+
+        return bd_overstep_labels
+
+
+    def get_plan_cl_target(self,
+                           ego_traj_preds,
+                           lane_preds,
+                           lane_score_preds,
+                           lane_bound_cls_idx=3,
+                           map_thresh=0.5):
+        """Get centerline-alignment targets for the planning trajectory candidates.
+        For each candidate, the nearest predicted centerline is matched and the
+        target is 1 minus the mean point-to-centerline distance, clamped at -1.
+        Args:
+            ego_traj_preds (Tensor): multi-modal traj preds with shape (plan_fut_mode, fut_ts, 2).
+            lane_preds (Tensor): map preds/GT with shape (num_vec, num_pts, 2).
+            lane_score_preds (Tensor): per-vector map class labels with shape (num_vec,).
+            lane_bound_cls_idx: class index of centerline elements.
+        Returns:
+            cl_dir_labels (Tensor): alignment targets with shape (plan_fut_mode,).
+        """
+        # keep only centerline map elements
+        # not_lane_bound_mask = lane_score_preds[..., lane_bound_cls_idx] < map_thresh
+        not_lane_bound_mask = (lane_score_preds != lane_bound_cls_idx)
+
+        # map points are already in ego coordinates, so denormalization is skipped;
+        # non-centerline instances are pushed far away so they never get matched
+        lane_centerline_preds = lane_preds.clone()
+        lane_centerline_preds[not_lane_bound_mask] = 1e6
+
+        ego_traj_expanded = ego_traj_preds.unsqueeze(2).unsqueeze(3)  # [num_plan_mode, T, 1, 1, 2]
+
+        # densify each centerline 50x by linear interpolation along the point dim
+        maps_interpolated = F.interpolate(lane_centerline_preds.permute(0, 2, 1),
+                                          scale_factor=50, mode='linear', align_corners=True).permute(0, 2, 1)
+
+        maps_expanded = maps_interpolated.unsqueeze(0).unsqueeze(1)  # [1, 1, M, P, 2]
+
+        dist = torch.linalg.norm(ego_traj_expanded - maps_expanded, dim=-1)  # [num_plan_mode, T, M, P]
+        dist = dist.min(dim=-1)[0]  # over map points
+        # aggregate over the planning horizon: sum approximates the average
+        # deviation (up to a constant factor); max would pick the worst-case deviation
+        dist = dist.sum(dim=1)
+        dist, nearest_map_ins_idx = dist.min(dim=-1)  # over map instances
+        maps_matched = maps_interpolated.index_select(dim=0, index=nearest_map_ins_idx)  # [num_plan_mode, P, 2]
+
+        dist_2 = torch.linalg.norm(ego_traj_preds.unsqueeze(2) - maps_matched.unsqueeze(1), dim=-1)
+
+        mode_idxs = [[i] for i in range(dist.shape[0])]
+        point_idx = dist_2.min(dim=-1)[1]
+        point_idx[point_idx == 0] = 1  # guarantee a valid predecessor for the segment direction
+        map_segment_starts = maps_matched[mode_idxs, point_idx - 1]
+        map_segment_ends = maps_matched[mode_idxs, point_idx]
+        centerline_vector = map_segment_ends - map_segment_starts
+
+        ego_traj_starts = ego_traj_preds[:, :-1, :]
+        ego_traj_ends = ego_traj_preds
+        padding_zeros = torch.zeros((self.plan_fut_mode, 1, 2), dtype=ego_traj_preds.dtype,
+                                    device=ego_traj_preds.device)  # initial position
+        ego_traj_starts = torch.cat((padding_zeros, ego_traj_starts), dim=1)
+        ego_vector = ego_traj_ends - ego_traj_starts
+
+        # heading alignment; only used by the commented-out variant below
+        cos_sim = F.cosine_similarity(ego_vector, centerline_vector, dim=-1)
+
+        cl_dir_labels = 1. - dist_2.min(dim=-1)[0].mean(dim=-1)
+        # cl_dir_labels = cos_sim.mean(dim=-1) \
+        #     - dist_2.min(dim=-1)[0].mean(dim=-1)
+        cl_dir_labels = torch.clamp(cl_dir_labels, min=-1.)
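+        # Illustrative note (values are hypothetical): with the distance-only
+        # target above, a candidate lying exactly on its matched centerline gets
+        # 1.0, a mean deviation of 0.5 m gives 0.5, and anything beyond 2 m is
+        # clamped to -1. Enabling the commented-out variant would additionally
+        # reward heading alignment via cos_sim.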
+
+        # cl_dir_labels = (cl_dir_labels < 0.5).long()
+
+        # TODO p2l cost
+
+        return cl_dir_labels
+
+
+    def get_plan_expert_target(self,
+                               ego_traj_preds,
+                               ego_fut_gt,
+                               ego_fut_masks,
+                               ego_cls_expert_preds,  # (N, 1), only used by the disabled rank-based variant
+                               plan_col_labels,
+                               plan_bd_labels,
+                               ):
+        """Build expert-imitation targets over the planning vocabulary.
+        The candidate closest to the GT trajectory is the positive (label 0);
+        all other candidates are negatives (label 1) weighted by their distance
+        to GT, and colliding or boundary-crossing candidates are always negatives.
+        If the GT mask is invalid, all weights stay zero and no loss is incurred.
+        """
+        plan_expert_labels = torch.zeros((self.plan_fut_mode), dtype=torch.long,
+                                         device=ego_traj_preds.device)
+        plan_expert_labels_weight = torch.zeros((self.plan_fut_mode), dtype=ego_traj_preds.dtype,
+                                                device=ego_traj_preds.device)
+
+        if ego_fut_masks[0] == 1.:
+            # mark every candidate as negative first
+            neg_idx = torch.ones((self.plan_fut_mode), dtype=torch.bool,
+                                 device=ego_traj_preds.device)
+            # earlier variants (v1-v10) differed only in the clip range and the
+            # scale factor of the distance-based weight; v11 below is the one in use.
+            # summed L2 distance between each candidate and the GT trajectory
+            # (GT offsets are accumulated into absolute coordinates first)
+            traj_dis = torch.linalg.norm(ego_traj_preds - ego_fut_gt.cumsum(dim=-2), dim=-1).sum(dim=-1)
+            plan_expert_labels[neg_idx] = 1
+            plan_expert_labels_weight[neg_idx] = torch.clip(traj_dis, min=0, max=100.) * 100.
+
+            plan_expert_labels[plan_col_labels == 1] = 1
+            plan_expert_labels[plan_bd_labels == 1] = 1
+            plan_expert_labels_weight[plan_col_labels == 1] = 100.
+            plan_expert_labels_weight[plan_bd_labels == 1] = 100.
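+            # Worked example (hypothetical numbers): a candidate 0.5 m from GT
+            # gets negative weight 50, one 2 m away gets 200, capped at 10000;
+            # the single positive below and the collision/boundary negatives
+            # above all receive a fixed weight of 100.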
+
+            # pos_idx = torch.linalg.norm(ego_traj_preds[:,:1,:] - ego_fut_gt.cumsum(dim=-2)[:,:1,:], dim=-1).mean(dim=-1).argmin()
+            pos_idx = traj_dis.argmin()
+            plan_expert_labels[pos_idx] = 0
+
+            # count how often each vocabulary entry is selected as positive;
+            # the count-based rebalancing via scaling_rate is currently disabled
+            self.traj_selected_cnt[pos_idx] += 1.
+            scaling_rate = self.traj_selected_cnt.sum() / self.traj_selected_cnt[pos_idx] / self.plan_fut_mode
+            scaling_rate = torch.clamp(scaling_rate, 0.5, 2.)
+            plan_expert_labels_weight[pos_idx] = 100.  # * scaling_rate
+
+            # a rank-based positive/negative weighting scheme driven by
+            # ego_cls_expert_preds was also explored and is intentionally disabled
+
+        return plan_expert_labels, plan_expert_labels_weight
+
+
+class PlanningMetric():
+    def __init__(self, fut_ts=6):
+        super().__init__()
+        self.X_BOUND = [-50.0, 50.0, 0.5]  # Forward
+        self.Y_BOUND = [-50.0, 50.0, 0.5]  # Sides
+        self.Z_BOUND = [-10.0, 10.0, 20.0]  # Height
+        self.fut_ts = fut_ts
+        dx, bx, _ = self.gen_dx_bx(self.X_BOUND, self.Y_BOUND, self.Z_BOUND)
+        self.dx, self.bx = dx[:2], bx[:2]
+
+        bev_resolution, bev_start_position, bev_dimension = self.calculate_birds_eye_view_parameters(
+            self.X_BOUND, self.Y_BOUND, self.Z_BOUND
+        )
+        self.bev_resolution = bev_resolution.numpy()
+        self.bev_start_position = bev_start_position.numpy()
+        self.bev_dimension = bev_dimension.numpy()
+
+        # ego footprint used for rasterization: width / length in meters
+        self.W = 1.85
+        self.H = 4.084
+
+        # NOTE: both groups currently contain all four classes, so every agent
+        # is rasterized into both the vehicle and the pedestrian map
+        self.category_index = {
+            'human': [0, 1, 2, 3],
+            'vehicle': [0, 1, 2, 3]
+        }
+
+    def gen_dx_bx(self, xbound, ybound, zbound):
+        dx = torch.Tensor([row[2] for row in [xbound, ybound, zbound]])  # cell size
+        bx = torch.Tensor([row[0] + row[2]/2.0 for row in [xbound, ybound, zbound]])  # first cell center
+        nx = torch.LongTensor([(row[1] - row[0]) / row[2] for row in [xbound, ybound, zbound]])  # cell count
+
+        return dx, bx, nx
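+
+    # Worked example of the grid parameters (follows directly from the bounds
+    # above): with X_BOUND = [-50.0, 50.0, 0.5] each BEV cell is dx = 0.5 m,
+    # the first cell center is bx = -49.75 m, and there are nx = 200 cells,
+    # i.e. a 200 x 200 BEV grid over +/-50 m.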
+
+    def calculate_birds_eye_view_parameters(self, x_bounds, y_bounds, z_bounds):
+        """
+        Parameters
+        ----------
+        x_bounds: (min, max, step) along the ego forward axis.
+        y_bounds: (min, max, step) along the lateral axis.
+        z_bounds: (min, max, step) along the vertical axis.
+
+        Returns
+        -------
+        bev_resolution: BEV grid cell size per axis.
+        bev_start_position: center of the first BEV grid cell per axis.
+        bev_dimension: number of BEV grid cells per axis.
+        """
+        bev_resolution = torch.tensor([row[2] for row in [x_bounds, y_bounds, z_bounds]])
+        bev_start_position = torch.tensor([row[0] + row[2] / 2.0 for row in [x_bounds, y_bounds, z_bounds]])
+        bev_dimension = torch.tensor([(row[1] - row[0]) / row[2] for row in [x_bounds, y_bounds, z_bounds]],
+                                     dtype=torch.long)
+
+        return bev_resolution, bev_start_position, bev_dimension
+
+    def get_label(
+        self,
+        gt_agent_boxes,
+        gt_agent_feats
+    ):
+        segmentation_np, pedestrian_np = self.get_birds_eye_view_label(gt_agent_boxes, gt_agent_feats)
+        segmentation = torch.from_numpy(segmentation_np).long().unsqueeze(0)
+        pedestrian = torch.from_numpy(pedestrian_np).long().unsqueeze(0)
+
+        return segmentation, pedestrian
+
+    def get_birds_eye_view_label(
+        self,
+        gt_agent_boxes,
+        gt_agent_feats
+    ):
+        '''
+        gt_agent_boxes (LiDARInstance3DBoxes): list of GT bboxes.
+            dim 9 = (x,y,z) + (w,l,h) + yaw + (vx,vy)
+        gt_agent_feats: (A, 4*T+10)
+            dim 4*T+10 = fut_traj(T*2) + fut_mask(T) + goal(1) + lcf_feat(9) + fut_yaw(T)
+            lcf_feat = (x, y, yaw, vx, vy, width, length, height, type)
+        '''
+        segmentation = np.zeros((self.fut_ts, self.bev_dimension[0], self.bev_dimension[1]))
+        pedestrian = np.zeros((self.fut_ts, self.bev_dimension[0], self.bev_dimension[1]))
+        agent_num = gt_agent_feats.shape[0]
+
+        gt_agent_boxes = gt_agent_boxes.cpu().numpy()  # (N, 9)
+        gt_agent_feats = gt_agent_feats.cpu().numpy()
+
+        gt_agent_fut_trajs = gt_agent_feats[..., :self.fut_ts*2].reshape(-1, self.fut_ts, 2)
+        gt_agent_fut_mask = gt_agent_feats[..., self.fut_ts*2:self.fut_ts*3].reshape(-1, self.fut_ts)
+        # gt_agent_lcf_feat = gt_agent_feats[..., self.fut_ts*3+1:self.fut_ts*3+10].reshape(-1, 9)
+        gt_agent_fut_yaw = gt_agent_feats[..., self.fut_ts*3+10:self.fut_ts*4+10].reshape(-1, self.fut_ts, 1)
+        # accumulate per-step offsets into absolute trajectories and headings
+        gt_agent_fut_trajs = np.cumsum(gt_agent_fut_trajs, axis=1)
+        gt_agent_fut_yaw = np.cumsum(gt_agent_fut_yaw, axis=1)
+
+        gt_agent_boxes[:, 6:7] = -1 * (gt_agent_boxes[:, 6:7] + np.pi/2)  # NOTE: convert yaw to lidar frame
+        gt_agent_fut_trajs = gt_agent_fut_trajs + gt_agent_boxes[:, np.newaxis, 0:2]
+        gt_agent_fut_yaw = gt_agent_fut_yaw + gt_agent_boxes[:, np.newaxis, 6:7]
+
+        for t in range(self.fut_ts):
+            for i in range(agent_num):
+                if gt_agent_fut_mask[i][t] == 1:
+                    # rasterize vehicle and pedestrian instances into separate maps
+                    category_index = int(gt_agent_feats[i, 3*self.fut_ts+9])
+                    agent_length, agent_width = gt_agent_boxes[i][4], gt_agent_boxes[i][3]
+                    x_a = gt_agent_fut_trajs[i, t, 0]
+                    y_a = gt_agent_fut_trajs[i, t, 1]
+                    yaw_a = gt_agent_fut_yaw[i, t, 0]
+                    param = [x_a, y_a, yaw_a, agent_length, agent_width]
+                    if (category_index in self.category_index['vehicle']):
+                        poly_region = self._get_poly_region_in_image(param)
+                        cv2.fillPoly(segmentation[t], [poly_region], 1.0)
+                    if (category_index in self.category_index['human']):
+                        poly_region = self._get_poly_region_in_image(param)
+                        cv2.fillPoly(pedestrian[t], [poly_region], 1.0)
+
+        return segmentation, pedestrian
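+
+    # Sketch of the rasterization below (summary of existing behavior, nothing
+    # new): each box is expanded to its four corners in the agent frame, rotated
+    # by yaw, translated to its lidar position, flipped into the cv2 image frame,
+    # and converted from meters to pixel indices before cv2.fillPoly draws it.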
+
+    def _get_poly_region_in_image(self, param):
+        lidar2cv_rot = np.array([[1, 0], [0, -1]])  # flip y when going to the cv2 frame
+        x_a, y_a, yaw_a, agent_length, agent_width = param
+        trans_a = np.array([[x_a, y_a]]).T
+        rot_mat_a = np.array([[np.cos(yaw_a), -np.sin(yaw_a)],
+                              [np.sin(yaw_a), np.cos(yaw_a)]])
+        agent_corner = np.array([
+            [agent_length/2, -agent_length/2, -agent_length/2, agent_length/2],
+            [agent_width/2, agent_width/2, -agent_width/2, -agent_width/2]])  # (2,4)
+        agent_corner_lidar = np.matmul(rot_mat_a, agent_corner) + trans_a  # (2,4)
+        # convert to the cv2 frame and from meters to pixel indices
+        agent_corner_cv2 = (np.matmul(lidar2cv_rot, agent_corner_lidar) \
+            - self.bev_start_position[:2, None] + self.bev_resolution[:2, None] / 2.0).T / self.bev_resolution[:2]  # (4,2)
+        agent_corner_cv2 = np.round(agent_corner_cv2).astype(np.int32)
+
+        return agent_corner_cv2
+
+    def evaluate_single_coll(self, traj, segmentation, input_gt):
+        '''
+        traj: torch.Tensor (n_future, 2), given in the ego lidar frame:
+            ^ y
+            |
+            |
+            0-------> x
+        segmentation: torch.Tensor (n_future, 200, 200)
+        input_gt is kept for interface compatibility and is unused here.
+        '''
+        # rasterize the ego footprint once; rc holds its occupied pixel offsets
+        pts = np.array([
+            [-self.H / 2. + 0.5, self.W / 2.],
+            [self.H / 2. + 0.5, self.W / 2.],
+            [self.H / 2. + 0.5, -self.W / 2.],
+            [-self.H / 2. + 0.5, -self.W / 2.],
+        ])
+        pts = (pts - self.bx.cpu().numpy()) / (self.dx.cpu().numpy())
+        pts[:, [0, 1]] = pts[:, [1, 0]]
+        rr, cc = polygon(pts[:, 1], pts[:, 0])
+        rc = np.concatenate([rr[:, None], cc[:, None]], axis=-1)
+
+        n_future, _ = traj.shape
+        trajs = traj.view(n_future, 1, 2)
+        # convert the trajectory frame to:
+        # ^ x
+        # |
+        # |
+        # 0-------> y
+        trajs_ = trajs.clone()  # clone so the axis swap does not modify the input
+        trajs_[:, :, [0, 1]] = trajs_[:, :, [1, 0]]
+        trajs_ = trajs_ / self.dx.to(trajs.device)
+        trajs_ = trajs_.cpu().numpy() + rc  # (n_future, 32, 2)
+
+        r = (self.bev_dimension[0] - trajs_[:, :, 0]).astype(np.int32)
+        r = np.clip(r, 0, self.bev_dimension[0] - 1)
+
+        c = trajs_[:, :, 1].astype(np.int32)
+        c = np.clip(c, 0, self.bev_dimension[1] - 1)
+
+        collision = np.full(n_future, False)
+        for t in range(n_future):
+            rr = r[t]
+            cc = c[t]
+            I = np.logical_and(
+                np.logical_and(rr >= 0, rr < self.bev_dimension[0]),
+                np.logical_and(cc >= 0, cc < self.bev_dimension[1]),
+            )
+            collision[t] = np.any(segmentation[t, rr[I], cc[I]].cpu().numpy())
+
+        return torch.from_numpy(collision).to(device=traj.device)
+
+    def evaluate_coll(
+        self,
+        trajs,
+        gt_trajs,
+        segmentation
+    ):
+        '''
+        trajs: torch.Tensor (B, n_future, 2), given in the ego lidar frame:
+            ^ y
+            |
+            |
+            0-------> x
+        gt_trajs: torch.Tensor (B, n_future, 2)
+        segmentation: torch.Tensor (B, n_future, 200, 200)
+        '''
+        B, n_future, _ = trajs.shape
+
+        obj_box_coll_sum = torch.zeros(n_future, device=segmentation.device)
+
+        for i in range(B):
+            gt_box_coll = self.evaluate_single_coll(gt_trajs[i], segmentation[i], input_gt=True)
+
+            xx, yy = trajs[i, :, 0], trajs[i, :, 1]
+            # transform the trajectory from the lidar frame to BEV pixel coordinates
+            xi = ((-self.bx[0] / 2 - yy) / self.dx[0]).long()
+            yi = ((-self.bx[1] / 2 + xx) / self.dx[1]).long()
+
+            # in-range & not-colliding-with-GT mask; computed but unused in this variant
+            m1 = torch.logical_and(
+                torch.logical_and(xi >= 0, xi < self.bev_dimension[0]),
+                torch.logical_and(yi >= 0, yi < self.bev_dimension[1]),
+            ).to(gt_box_coll.device)
+            m1 = torch.logical_and(m1, torch.logical_not(gt_box_coll))
+
+            ti = torch.arange(n_future)
+            # m2 = torch.logical_not(gt_box_coll)
+            m2 = torch.ones_like(gt_box_coll)
+            box_coll = self.evaluate_single_coll(trajs[i], segmentation[i], input_gt=False).to(ti.device)
+            obj_box_coll_sum[ti[m2]] += (box_coll[ti[m2]]).long()
+
+        # a candidate counts as colliding if any timestep hits occupied space
+        if obj_box_coll_sum.max() > 0:
+            return torch.ones((1), dtype=torch.long)
+        else:
+            return torch.zeros((1), dtype=torch.long)
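+
+
+# ---------------------------------------------------------------------------
+# Minimal usage sketch of PlanningMetric (illustrative only; the shapes and
+# variable names below are assumptions, not part of the released pipeline):
+#
+#   metric = PlanningMetric(fut_ts=6)
+#   seg, ped = metric.get_label(gt_boxes, gt_feats)   # each (1, 6, 200, 200)
+#   occupancy = torch.logical_or(seg, ped)
+#   # pred_traj, gt_traj: (1, 6, 2) trajectories in the ego lidar frame
+#   collided = metric.evaluate_coll(pred_traj, gt_traj, occupancy)  # 1 if any hit
+# ---------------------------------------------------------------------------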