diff --git a/README.md b/README.md
index 2f3738c3..b05022cc 100644
--- a/README.md
+++ b/README.md
@@ -118,6 +118,7 @@ python run_net.py --config-file=configs/base.py --task=test
| RSDet-R50-FPN | DOTA1.0|1024/200|Flip|-| SGD | 1x | 68.41 | [arxiv](https://arxiv.org/abs/1911.08299) | [config](configs/rotated_retinanet/rsdet_obb_r50_fpn_1x_dota_lmr5p.py) | [model](https://cloud.tsinghua.edu.cn/f/642e200f5a8a420eb726/?dl=1) |
| ATSS-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 72.44 | [arxiv](https://arxiv.org/abs/1912.02424) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/5168189dcd364eaebce5/?dl=1) |
| Reppoints-R50-FPN|DOTA1.0|1024/200| flip|-| SGD | 1x | 56.34 | [arxiv](https://arxiv.org/abs/1904.11490) | [config](configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_atss.py) | [model](https://cloud.tsinghua.edu.cn/f/be359ac932c84f9c839e/?dl=1) |
+| R3Det-R50-FPN | DOTA1.0|1024/200| flip|-| SGD | 1x | 64.41 | [arxiv](https://arxiv.org/pdf/1908.05612.pdf)| [config](projects/r3det/configs/r3det_r50_fpn_1x_dota.py) | [model]() |
**Notice**:
@@ -153,7 +154,7 @@ python run_net.py --config-file=configs/base.py --task=test
- :heavy_check_mark: Reppoints
- :heavy_check_mark: RSDet
- :heavy_check_mark: ATSS
-- :clock3: R3Det
+- :heavy_check_mark: R3Det
- :clock3: Cascade R-CNN
- :clock3: Oriented Reppoints
- :heavy_plus_sign: DCL
diff --git a/configs/r3det_r50_fpn_1x_dota.py b/configs/r3det_r50_fpn_1x_dota.py
index 876e7270..5fdb7445 100644
--- a/configs/r3det_r50_fpn_1x_dota.py
+++ b/configs/r3det_r50_fpn_1x_dota.py
@@ -2,165 +2,177 @@
model = dict(
type='R3Det',
backbone=dict(
- type='ResNet50',
- num_stages=4,
- out_indices=(0, 1, 2, 3),
+ type='Resnet50',
frozen_stages=1,
- norm_cfg=dict(type='BN', requires_grad=True),
- norm_eval=True,
- pretrained=True),
+ return_stages=["layer1","layer2","layer3","layer4"],
+ pretrained= True),
neck=dict(
type='FPN',
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
- add_extra_convs='on_input',
+ add_extra_convs="on_input",
num_outs=5),
bbox_head=dict(
- type='RRetinaHead',
- num_classes=15,
+ type='R3Head',
+ num_classes=16,
in_channels=256,
- stacked_convs=4,
- use_h_gt=True,
feat_channels=256,
- anchor_generator=dict(
- type='RAnchorGenerator',
- octave_base_scale=4,
- scales_per_octave=3,
- ratios=[1.0, 0.5, 2.0, 1.0 / 3.0, 3.0, 0.2, 5.0],
- angles=None,
- strides=[8, 16, 32, 64, 128]),
- bbox_coder=dict(
- type='DeltaXYWHABBoxCoder',
- target_means=(.0, .0, .0, .0, .0),
- target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
- loss_cls=dict(
+ stacked_convs=4,
+ octave_base_scale=4,
+ scales_per_octave=3,
+ anchor_ratios=[1.0, 0.5, 2.0],
+ anchor_strides=[8, 16, 32, 64, 128],
+ target_means=[.0, .0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0, 1.0],
+ loss_init_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
- loss_bbox=dict(
- type='SmoothL1Loss',
- beta=0.11,
- loss_weight=1.0)),
- frm_cfgs=[
- dict(
- in_channels=256,
- featmap_strides=[8, 16, 32, 64, 128]),
- dict(
- in_channels=256,
- featmap_strides=[8, 16, 32, 64, 128])
- ],
- num_refine_stages=2,
- refine_heads=[
- dict(
- type='RRetinaRefineHead',
- num_classes=15,
- in_channels=256,
- stacked_convs=4,
- feat_channels=256,
- anchor_generator=dict(
- type='PseudoAnchorGenerator',
- strides=[8, 16, 32, 64, 128]),
- bbox_coder=dict(
- type='DeltaXYWHABBoxCoder',
- target_means=(.0, .0, .0, .0, .0),
- target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
- loss_cls=dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=1.0),
- loss_bbox=dict(
- type='SmoothL1Loss',
- beta=0.11,
- loss_weight=1.0)),
- dict(
- type='RRetinaRefineHead',
- num_classes=15,
- in_channels=256,
- stacked_convs=4,
- feat_channels=256,
- anchor_generator=dict(
- type='PseudoAnchorGenerator',
- strides=[8, 16, 32, 64, 128]),
- bbox_coder=dict(
- type='DeltaXYWHABBoxCoder',
- target_means=(.0, .0, .0, .0, .0),
- target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
- loss_cls=dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=1.0),
- loss_bbox=dict(
- type='SmoothL1Loss',
- beta=0.11,
- loss_weight=1.0)),
- ]
-)
-# training and testing settings
-train_cfg = dict(
- s0=dict(
- assigner=dict(
- type='MaxIoUAssigner',
- pos_iou_thr=0.5,
- neg_iou_thr=0.4,
- min_pos_iou=0,
- ignore_iof_thr=-1,
- iou_calculator=dict(type='RBboxOverlaps2D')),
- allowed_border=-1,
- pos_weight=-1,
- debug=False),
- sr=[
- dict(
- assigner=dict(
- type='MaxIoUAssigner',
- pos_iou_thr=0.6,
- neg_iou_thr=0.5,
- min_pos_iou=0,
- ignore_iof_thr=-1,
- iou_calculator=dict(type='RBboxOverlaps2D')),
- allowed_border=-1,
- pos_weight=-1,
- debug=False),
- dict(
- assigner=dict(
- type='MaxIoUAssigner',
- pos_iou_thr=0.7,
- neg_iou_thr=0.6,
- min_pos_iou=0,
- ignore_iof_thr=-1,
- iou_calculator=dict(type='RBboxOverlaps2D')),
- allowed_border=-1,
- pos_weight=-1,
- debug=False
+ loss_init_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ loss_refine_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_refine_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(type='nms_rotated', iou_thr=0.1),
+ max_per_img=2000),
+ train_cfg=dict(
+ init_cfg=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.4,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+ bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+ target_means=(0., 0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1., 1.),
+ clip_border=True),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ refine_cfg=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.4,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+ bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+ target_means=(0., 0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1., 1.),
+ clip_border=True),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False))
)
- ],
- stage_loss_weights=[1.0, 1.0]
+ )
+dataset = dict(
+ train=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(type='RotatedRandomFlip', prob=0.5),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,)
+
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=True,
+ filter_empty_gt=False
+ ),
+ val=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False),
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=False
+ ),
+ test=dict(
+ type="ImageDataset",
+ images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,),
+ ],
+ num_workers=4,
+ batch_size=1,
+ )
)
-merge_nms_iou_thr_dict = {
- 'roundabout': 0.1, 'tennis-court': 0.3, 'swimming-pool': 0.1, 'storage-tank': 0.1,
- 'soccer-ball-field': 0.3, 'small-vehicle': 0.05, 'ship': 0.05, 'plane': 0.3,
- 'large-vehicle': 0.05, 'helicopter': 0.2, 'harbor': 0.0001, 'ground-track-field': 0.3,
- 'bridge': 0.0001, 'basketball-court': 0.3, 'baseball-diamond': 0.3
-}
+optimizer = dict(
+ type='SGD',
+ lr=0.01/4., #0.0,#0.01*(1/8.),
+ momentum=0.9,
+ weight_decay=0.0001,
+ grad_clip=dict(
+ max_norm=35,
+ norm_type=2))
-merge_cfg = dict(
- nms_pre=2000,
- score_thr=0.1,
- nms=dict(type='rnms', iou_thr=merge_nms_iou_thr_dict),
- max_per_img=1000,
-)
+scheduler = dict(
+ type='StepLR',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ milestones=[10])
+
+
+logger = dict(
+ type="RunLogger")
-test_cfg = dict(
- nms_pre=1000,
- score_thr=0.1,
- nms=dict(type='rnms', iou_thr=0.05),
- max_per_img=100,
- merge_cfg=merge_cfg
-)
\ No newline at end of file
+# NOTE: when using the trained model from cshuan, the input image is RGB
+max_epoch = 12
+eval_interval = 1
+checkpoint_interval = 1
+log_interval = 50
\ No newline at end of file
diff --git a/projects/r3det/README.md b/projects/r3det/README.md
index c208e890..4ba058ad 100644
--- a/projects/r3det/README.md
+++ b/projects/r3det/README.md
@@ -1 +1,39 @@
-# TODO: this model is not finished.
\ No newline at end of file
+## R3Det
+> [R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating Object](https://arxiv.org/pdf/1908.05612.pdf)
+
+
+### Abstract
+
+
+

+
+
+Rotation detection is a challenging task due to the difficulties of locating the multi-angle objects and separating them effectively from the background. Though considerable progress has been made, for practical settings, there still exist challenges for rotating objects with large aspect ratio, dense distribution and category extremely imbalance. In this paper, we propose an end-to-end refined single-stage rotation detector for fast and accurate object detection by using a progressive regression approach from coarse to fine granularity. Considering the shortcoming of feature misalignment in existing refined single stage detector, we design a feature refinement module to improve detection performance by getting more accurate features. The key idea of feature refinement module is to re-encode the position information of the current refined bounding box to the corresponding feature points through pixel-wise feature interpolation to realize feature reconstruction and alignment. For more accurate rotation estimation, an approximate SkewIoU loss is proposed to solve the problem that the calculation of SkewIoU is not derivable. Experiments on three popular remote sensing public datasets DOTA, HRSC2016, UCAS-AOD as well as one scene text dataset ICDAR2015 show the effectiveness of our approach.
+
+### Training
+```sh
+python run_net.py --config-file=projects/r3det/configs/r3det_r50_fpn_1x_dota.py --task=train
+```
+
+### Testing
+```sh
+python run_net.py --config-file=projects/r3det/configs/r3det_r50_fpn_1x_dota.py --task=test
+```
+
+### Performance
+| Models | Dataset| Sub_Image_Size/Overlap |Train Aug | Test Aug | Optim | Lr schd | mAP | Paper | Config | Download |
+|:-----------:| :-----: |:-----:|:-----:| :-----: | :-----:| :-----:| :----: |:--------:|:--------------------------------------------------------------:| :--------: |
+| R3Det-R50-FPN | DOTA1.0|1024/200| flip|-| SGD | 1x | 64.41 | [arxiv](https://arxiv.org/pdf/1908.05612.pdf)| [config](configs/r3det_r50_fpn_1x_dota.py) | [model]() |
+
+### Citation
+```
+@inproceedings{yang2021r3det,
+ title={R3Det: Refined Single-Stage Detector with Feature Refinement for Rotating Object},
+ author={Yang, Xue and Yan, Junchi and Feng, Ziming and He, Tao},
+ booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
+ volume={35},
+ number={4},
+ pages={3163--3171},
+ year={2021}
+}
+```
\ No newline at end of file
diff --git a/projects/r3det/configs/r3det_r50_fpn_1x_dota.py b/projects/r3det/configs/r3det_r50_fpn_1x_dota.py
index 7887a229..38ba9cbd 100644
--- a/projects/r3det/configs/r3det_r50_fpn_1x_dota.py
+++ b/projects/r3det/configs/r3det_r50_fpn_1x_dota.py
@@ -11,153 +11,168 @@
in_channels=[256, 512, 1024, 2048],
out_channels=256,
start_level=1,
- add_extra_convs='on_input',
+ add_extra_convs="on_input",
num_outs=5),
bbox_head=dict(
- type='RRetinaHead',
- num_classes=15,
+ type='R3Head',
+ num_classes=16,
in_channels=256,
- stacked_convs=4,
- use_h_gt=True,
feat_channels=256,
- anchor_generator=dict(
- type='RAnchorGenerator',
- octave_base_scale=4,
- scales_per_octave=3,
- ratios=[1.0, 0.5, 2.0, 1.0 / 3.0, 3.0, 0.2, 5.0],
- angles=None,
- strides=[8, 16, 32, 64, 128]),
- bbox_coder=dict(
- type='DeltaXYWHABBoxCoder',
- target_means=(.0, .0, .0, .0, .0),
- target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
- loss_cls=dict(
+ stacked_convs=4,
+ octave_base_scale=4,
+ scales_per_octave=3,
+ anchor_ratios=[1.0, 0.5, 2.0],
+ anchor_strides=[8, 16, 32, 64, 128],
+ target_means=[.0, .0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0, 1.0],
+ loss_init_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
- loss_bbox=dict(
- type='SmoothL1Loss',
- beta=0.11,
- loss_weight=1.0)),
- frm_cfgs=[
- dict(
- in_channels=256,
- featmap_strides=[8, 16, 32, 64, 128]),
- dict(
- in_channels=256,
- featmap_strides=[8, 16, 32, 64, 128])
- ],
- num_refine_stages=2,
- refine_heads=[
- dict(
- type='RRetinaRefineHead',
- num_classes=15,
- in_channels=256,
- stacked_convs=4,
- feat_channels=256,
- anchor_generator=dict(
- type='PseudoAnchorGenerator',
- strides=[8, 16, 32, 64, 128]),
- bbox_coder=dict(
- type='DeltaXYWHABBoxCoder',
- target_means=(.0, .0, .0, .0, .0),
- target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
- loss_cls=dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=1.0),
- loss_bbox=dict(
- type='SmoothL1Loss',
- beta=0.11,
- loss_weight=1.0)),
- dict(
- type='RRetinaRefineHead',
- num_classes=15,
- in_channels=256,
- stacked_convs=4,
- feat_channels=256,
- anchor_generator=dict(
- type='PseudoAnchorGenerator',
- strides=[8, 16, 32, 64, 128]),
- bbox_coder=dict(
- type='DeltaXYWHABBoxCoder',
- target_means=(.0, .0, .0, .0, .0),
- target_stds=(1.0, 1.0, 1.0, 1.0, 1.0)),
- loss_cls=dict(
- type='FocalLoss',
- use_sigmoid=True,
- gamma=2.0,
- alpha=0.25,
- loss_weight=1.0),
- loss_bbox=dict(
- type='SmoothL1Loss',
- beta=0.11,
- loss_weight=1.0)),
- ]
-)
-# training and testing settings
-train_cfg = dict(
- s0=dict(
- assigner=dict(
- type='MaxIoUAssigner',
- pos_iou_thr=0.5,
- neg_iou_thr=0.4,
- min_pos_iou=0,
- ignore_iof_thr=-1,
- iou_calculator=dict(type='RBboxOverlaps2D')),
- allowed_border=-1,
- pos_weight=-1,
- debug=False),
- sr=[
- dict(
- assigner=dict(
- type='MaxIoUAssigner',
- pos_iou_thr=0.6,
- neg_iou_thr=0.5,
- min_pos_iou=0,
- ignore_iof_thr=-1,
- iou_calculator=dict(type='RBboxOverlaps2D')),
- allowed_border=-1,
- pos_weight=-1,
- debug=False),
- dict(
- assigner=dict(
- type='MaxIoUAssigner',
- pos_iou_thr=0.7,
- neg_iou_thr=0.6,
- min_pos_iou=0,
- ignore_iof_thr=-1,
- iou_calculator=dict(type='RBboxOverlaps2D')),
- allowed_border=-1,
- pos_weight=-1,
- debug=False
+ loss_init_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ loss_refine_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_refine_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(type='nms_rotated', iou_thr=0.1),
+ max_per_img=2000),
+ train_cfg=dict(
+ init_cfg=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.4,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+ bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+ target_means=(0., 0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1., 1.),
+ clip_border=True),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ refine_cfg=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.4,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+ bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+ target_means=(0., 0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1., 1.),
+ clip_border=True),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False))
)
- ],
- stage_loss_weights=[1.0, 1.0]
+ )
+dataset = dict(
+ train=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(type='RotatedRandomFlip', prob=0.5),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,)
+
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=True,
+ filter_empty_gt=False
+ ),
+ val=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False),
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=False
+ ),
+ test=dict(
+ type="ImageDataset",
+ images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,),
+ ],
+ num_workers=4,
+ batch_size=1,
+ )
)
-merge_nms_iou_thr_dict = {
- 'roundabout': 0.1, 'tennis-court': 0.3, 'swimming-pool': 0.1, 'storage-tank': 0.1,
- 'soccer-ball-field': 0.3, 'small-vehicle': 0.05, 'ship': 0.05, 'plane': 0.3,
- 'large-vehicle': 0.05, 'helicopter': 0.2, 'harbor': 0.0001, 'ground-track-field': 0.3,
- 'bridge': 0.0001, 'basketball-court': 0.3, 'baseball-diamond': 0.3
-}
+optimizer = dict(
+ type='SGD',
+ lr=0.01/4., #0.0,#0.01*(1/8.),
+ momentum=0.9,
+ weight_decay=0.0001,
+ grad_clip=dict(
+ max_norm=35,
+ norm_type=2))
-merge_cfg = dict(
- nms_pre=2000,
- score_thr=0.1,
- nms=dict(type='rnms', iou_thr=merge_nms_iou_thr_dict),
- max_per_img=1000,
-)
+scheduler = dict(
+ type='StepLR',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ milestones=[10])
+
+
+logger = dict(
+ type="RunLogger")
-test_cfg = dict(
- nms_pre=1000,
- score_thr=0.1,
- nms=dict(type='rnms', iou_thr=0.05),
- max_per_img=100,
- merge_cfg=merge_cfg
-)
\ No newline at end of file
+# NOTE: when using the trained model from cshuan, the input image is RGB
+max_epoch = 12
+eval_interval = 1
+checkpoint_interval = 1
+log_interval = 50
diff --git a/projects/r3det/configs/r3det_r50_fpn_1x_test.py b/projects/r3det/configs/r3det_r50_fpn_1x_test.py
new file mode 100644
index 00000000..dd6404e2
--- /dev/null
+++ b/projects/r3det/configs/r3det_r50_fpn_1x_test.py
@@ -0,0 +1,178 @@
+# model settings
+model = dict(
+ type='R3Det',
+ backbone=dict(
+ type='Resnet50',
+ frozen_stages=1,
+ return_stages=["layer1","layer2","layer3","layer4"],
+ pretrained= True),
+ neck=dict(
+ type='FPN',
+ in_channels=[256, 512, 1024, 2048],
+ out_channels=256,
+ start_level=1,
+ add_extra_convs="on_input",
+ num_outs=5),
+ bbox_head=dict(
+ type='R3Head',
+ num_classes=16,
+ in_channels=256,
+ feat_channels=256,
+ stacked_convs=4,
+ octave_base_scale=4,
+ scales_per_octave=3,
+ anchor_ratios=[1.0, 0.5, 2.0],
+ anchor_strides=[8, 16, 32, 64, 128],
+ target_means=[.0, .0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0, 1.0],
+ loss_init_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_init_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ loss_refine_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_refine_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ test_cfg=dict(
+ nms_pre=2000,
+ min_bbox_size=0,
+ score_thr=0.05,
+ nms=dict(type='nms_rotated', iou_thr=0.1),
+ max_per_img=2000),
+ train_cfg=dict(
+ init_cfg=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.4,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+ bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+ target_means=(0., 0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1., 1.),
+ clip_border=True),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ refine_cfg=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.4,
+ min_pos_iou=0,
+ ignore_iof_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+ bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+ target_means=(0., 0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1., 1.),
+ clip_border=True),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False))
+ )
+ )
+dataset = dict(
+ train=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ # dict(type='RotatedRandomFlip', prob=0.5),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,)
+
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=True,
+ filter_empty_gt=False
+ ),
+ val=dict(
+ type="DOTADataset",
+ dataset_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/trainval_1024_200_1.0',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False),
+ ],
+ batch_size=2,
+ num_workers=4,
+ shuffle=False
+ ),
+ test=dict(
+ type="ImageDataset",
+ images_dir='/home/cxjyxx_me/workspace/JAD/datasets/processed_DOTA/test_1024_200_1.0/images',
+ transforms=[
+ dict(
+ type="RotatedResize",
+ min_size=1024,
+ max_size=1024
+ ),
+ dict(
+ type = "Pad",
+ size_divisor=32),
+ dict(
+ type = "Normalize",
+ mean = [123.675, 116.28, 103.53],
+ std = [58.395, 57.12, 57.375],
+ to_bgr=False,),
+ ],
+ num_workers=4,
+ batch_size=1,
+ )
+)
+
+optimizer = dict(
+ type='SGD',
+ lr=0.01/4., #0.0,#0.01*(1/8.),
+ momentum=0.9,
+ weight_decay=0.0001,
+ grad_clip=dict(
+ max_norm=35,
+ norm_type=2))
+
+scheduler = dict(
+ type='StepLR',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=1.0 / 3,
+ milestones=[10])
+
+
+logger = dict(
+ type="RunLogger")
+
+# NOTE: when using the trained model from cshuan, the input image is RGB
+max_epoch = 12
+eval_interval = 1
+checkpoint_interval = 1
+log_interval = 50
diff --git a/projects/r3det/test_r3det.py b/projects/r3det/test_r3det.py
new file mode 100644
index 00000000..a85ab230
--- /dev/null
+++ b/projects/r3det/test_r3det.py
@@ -0,0 +1,72 @@
+import jittor as jt
+from jdet.config import init_cfg, get_cfg
+from jdet.utils.general import parse_losses
+from jdet.utils.registry import build_from_cfg,MODELS,DATASETS,OPTIMS
+import argparse
+import os
+import pickle as pk
+import jdet
+
+def main():
+    """Record reference R3Det training losses (``--set_data``) or replay them and verify."""
+    parser = argparse.ArgumentParser(description="Jittor Object Detection Training")
+    parser.add_argument("--set_data", action='store_true')
+    args = parser.parse_args()
+
+    # Fixed global seed (666) so that runs are repeatable.
+    jt.flags.use_cuda=1
+    jt.set_global_seed(666)
+    init_cfg("projects/r3det/configs/r3det_r50_fpn_1x_test.py")
+    cfg = get_cfg()
+
+    model = build_from_cfg(cfg.model,MODELS)
+    optimizer = build_from_cfg(cfg.optimizer,OPTIMS,params= model.parameters())
+    model.train()
+    data_file = "test_datas_r3det/test_data.pk"
+    if args.set_data:
+        # Record mode: train on the first 11 batches and pickle the inputs and losses.
+        imagess, targetss, correct_loss = [], [], []
+        train_dataset = build_from_cfg(cfg.dataset.train,DATASETS)
+        for batch_idx,(images,targets) in enumerate(train_dataset):
+            if batch_idx > 10:
+                break
+            print(batch_idx)
+            imagess.append(jdet.utils.general.sync(images))
+            targetss.append(jdet.utils.general.sync(targets))
+            losses = model(images,targets)
+            all_loss,losses = parse_losses(losses)
+            optimizer.step(all_loss)
+            correct_loss.append(all_loss.item())
+        data = {
+            "imagess": imagess,
+            "targetss": targetss,
+            "correct_loss": correct_loss,
+        }
+        os.makedirs(os.path.dirname(data_file), exist_ok=True)
+        with open(data_file, "wb") as f:  # context manager closes the handle
+            pk.dump(data, f)
+        print(correct_loss)
+    else:
+        # Replay mode. NOTE: unpickling is only safe for files this script wrote itself.
+        with open(data_file, "rb") as f:
+            data = pk.load(f)
+        imagess = jdet.utils.general.to_jt_var(data["imagess"])
+        targetss = jdet.utils.general.to_jt_var(data["targetss"])
+        correct_loss = data["correct_loss"]
+        for batch_idx in range(len(imagess)):
+            images = imagess[batch_idx]
+            targets = targetss[batch_idx]
+            losses = model(images,targets)
+            all_loss,losses = parse_losses(losses)
+            optimizer.step(all_loss)
+            loss_val = all_loss.item()
+            c_l = correct_loss[batch_idx]
+            # Relative error; the floor avoids ZeroDivisionError when a loss is exactly 0.
+            err_rate = abs(c_l-loss_val)/max(min(c_l,loss_val),1e-12)
+            print(f"correct loss is {c_l:.4f}, runtime loss is {loss_val:.4f}, err rate is {err_rate*100:.2f}%")
+            assert err_rate<1e-3,"LOSS is not correct, please check it"
+        print(f"Loss is correct with err_rate<{1e-3}")
+    print("success!")
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/python/jdet/models/boxes/anchor_generator.py b/python/jdet/models/boxes/anchor_generator.py
index 5a047cf6..3fd035dd 100644
--- a/python/jdet/models/boxes/anchor_generator.py
+++ b/python/jdet/models/boxes/anchor_generator.py
@@ -1,8 +1,53 @@
+import collections.abc
from calendar import c
import jittor as jt
import numpy as np
from jittor.misc import _pair
from jdet.utils.registry import BOXES
+from itertools import repeat
+
+
+def to_2tuple(x):
+    """Return ``x`` as-is when it is iterable; otherwise return the pair ``(x, x)``."""
+    is_iterable = isinstance(x, collections.abc.Iterable)
+    return x if is_iterable else tuple(repeat(x, 2))
+
+@BOXES.register_module()
+class PseudoAnchorGenerator:  # defines no anchor-generation method here: only strides, a constant anchor count of 1, and valid_flags
+ def __init__(self, base_size):
+ self.strides = to_2tuple(base_size)  # a non-iterable base_size becomes (s, s); an iterable is stored unchanged
+
+ @property
+ def num_base_anchors(self):
+ """int: number of base anchors per feature-grid location (always 1)."""
+ return 1
+
+ def _meshgrid(self, x, y, row_major=True):  # pairs every x with every y as two flat vectors
+ xx = x.repeat(len(y))  # x tiled len(y) times -- assumes torch-like Var.repeat semantics, TODO confirm
+ yy = y.view(-1, 1).repeat(1, len(x)).view(-1)  # each y entry repeated len(x) times in a row (same assumption)
+ if row_major:
+ return xx, yy
+ else:
+ return yy, xx  # same two vectors, returned in swapped order
+
+ def valid_flags(self, featmap_size, valid_size):  # flat boolean mask of feature-map cells inside the valid region
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w  # the valid region must fit inside the feature map
+ valid_x = jt.zeros((feat_w,)).bool()
+ valid_y = jt.zeros((feat_h,)).bool()
+ valid_x[:valid_w] = 1  # first valid_w columns are valid
+ valid_y[:valid_h] = 1  # first valid_h rows are valid
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy  # a cell is valid only if both its column and its row are valid
+ valid = valid[:, None].expand((valid.size(0), self.num_base_anchors)).view(-1)  # one flag per base anchor, flattened
+ return valid
+
+ def __repr__(self):  # "<ClassName>(" on the first line, the strides on the second
+ indent_str = ' '
+ repr_str = self.__class__.__name__ + '(\n'
+ repr_str += f'{indent_str}strides={self.strides})'
+ return repr_str
@BOXES.register_module()
class AnchorGeneratorRotatedRetinaNet:
diff --git a/python/jdet/models/networks/__init__.py b/python/jdet/models/networks/__init__.py
index d7bc69ad..003cf0a6 100644
--- a/python/jdet/models/networks/__init__.py
+++ b/python/jdet/models/networks/__init__.py
@@ -1,6 +1,7 @@
from .rcnn import RCNN
from .retinanet import RetinaNet
from .rotated_retinanet import RotatedRetinaNet
+from .r3det import R3Det
from .s2anet import S2ANet
from .gliding_vertex import GlidingVertex
from .oriented_rcnn import OrientedRCNN
diff --git a/python/jdet/models/networks/r3det.py b/python/jdet/models/networks/r3det.py
index 79115f27..95732499 100644
--- a/python/jdet/models/networks/r3det.py
+++ b/python/jdet/models/networks/r3det.py
@@ -1,160 +1,37 @@
-from jdet.ops.fr import FeatureRefineModule
-from jittor import nn,init
import jittor as jt
-from jdet.utils.registry import build_from_cfg,BACKBONES,HEADS,NECKS
+from jittor import nn
+from jdet.utils.registry import MODELS,build_from_cfg,BACKBONES,HEADS,NECKS
+
+
+@MODELS.register_module()
class R3Det(nn.Module):
"""
- Rotated Refinement RetinaNet
"""
- def __init__(self,
- num_refine_stages,
- backbone,
- neck=None,
- bbox_head=None,
- refine_heads=None):
- super(R3Det, self).__init__()
- self.num_refine_stages = num_refine_stages
+ def __init__(self,backbone,neck=None,bbox_head=None):
+ super().__init__()
self.backbone = build_from_cfg(backbone,BACKBONES)
self.neck = build_from_cfg(neck,NECKS)
self.bbox_head = build_from_cfg(bbox_head,HEADS)
- self.feat_refine_module = nn.ModuleList()
- self.refine_head = nn.ModuleList()
- for i, (frm_cfg, refine_head) in enumerate(zip(frm_cfgs, refine_heads)):
- self.feat_refine_module.append(FeatureRefineModule(**frm_cfg))
- self.refine_head.append(build_from_cfg(refine_head,HEADS))
+
+ def train(self):
+ super().train()
+ self.backbone.train()
def execute(self,images,targets):
- pass
- x = self.backbone(images)
+ '''
+ Args:
+ images (jt.Var): image tensors, shape is [N,C,H,W]
+ targets (list[dict]): targets for each image
+ Rets:
+ outputs: train mode will be losses val mode will be results
+ '''
+ features = self.backbone(images)
+
if self.neck:
- x = self.neck(x)
-
- outs = self.bbox_head(x)
- rois = self.bbox_head.filter_bboxes(*outs)
- # rois: list(indexed by images) of list(indexed by levels)
- for i in range(self.num_refine_stages):
- x_refine = self.feat_refine_module[i](x, rois)
- outs = self.refine_head[i](x_refine)
- if i + 1 in range(self.num_refine_stages):
- rois = self.refine_head[i].refine_bboxes(*outs, rois=rois)
- return outs
-
- def forward_train(self,
- img,
- img_metas,
- gt_bboxes,
- gt_labels,
- gt_bboxes_ignore=None):
- losses = dict()
- x = self.extract_feat(img)
-
- outs = self.bbox_head(x)
-
- train_cfg = self.train_cfg['s0']
- loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
- loss_base = self.bbox_head.loss(
- *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
- for name, value in loss_base.items():
- losses['s0.{}'.format(name)] = value
-
- rois = self.bbox_head.filter_bboxes(*outs)
- # rois: list(indexed by images) of list(indexed by levels)
- for i in range(self.num_refine_stages):
- lw = self.train_cfg.stage_loss_weights[i]
- train_cfg = self.train_cfg['sr'][i]
-
- x_refine = self.feat_refine_module[i](x, rois)
- outs = self.refine_head[i](x_refine)
- loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
- loss_refine = self.refine_head[i].loss(
- *loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore, rois=rois)
- for name, value in loss_refine.items():
- losses['sr{}.{}'.format(i, name)] = (
- [v * lw for v in value] if 'loss' in name else value)
-
- if i + 1 in range(self.num_refine_stages):
- rois = self.refine_head[i].refine_bboxes(*outs, rois=rois)
-
- return losses
-
- def simple_test(self,
- img,
- img_meta,
- rescale=False):
- if 'tile_offset' in img_meta[0]:
- # using tile-cropped TTA. force using aug_test instead of simple_test
- return self.aug_test(imgs=[img], img_metas=[img_meta], rescale=True)
-
- x = self.extract_feat(img)
- outs = self.bbox_head(x)
- rois = self.bbox_head.filter_bboxes(*outs)
- # rois: list(indexed by images) of list(indexed by levels)
- for i in range(self.num_refine_stages):
- x_refine = self.feat_refine_module[i](x, rois)
- outs = self.refine_head[i](x_refine)
- if i + 1 in range(self.num_refine_stages):
- rois = self.refine_head[i].refine_bboxes(*outs, rois=rois)
-
- bbox_inputs = outs + (img_meta, self.test_cfg, rescale)
- bbox_list = self.refine_head[-1].get_bboxes(*bbox_inputs, rois=rois)
- bbox_results = [
- rbbox2result(det_bboxes, det_labels, self.refine_head[-1].num_classes)
- for det_bboxes, det_labels in bbox_list
- ]
- return bbox_results[0]
-
- def aug_test(self, imgs, img_metas, rescale=True):
- AUG_BS = 8
- assert rescale, '''while r3det uses overlapped cropping augmentation by default,
- the result should be rescaled to input images sizes to simplify the test pipeline'''
- if 'tile_offset' in img_metas[0][0]:
- assert imgs[0].size(0) == 1, '''when using cropped tiles augmentation,
- image batch size must be set to 1'''
- aug_det_bboxes, aug_det_labels = [], []
- num_augs = len(imgs)
- for idx in range(0, num_augs, AUG_BS):
- img = imgs[idx:idx + AUG_BS]
- img_meta = img_metas[idx:idx + AUG_BS]
- act_num_augs = len(img_meta)
- img = torch.cat(img, dim=0)
- img_meta = sum(img_meta, [])
- # for img, img_meta in zip(imgs, img_metas):
- x = self.extract_feat(img)
- outs = self.bbox_head(x)
- rois = self.bbox_head.filter_bboxes(*outs)
- # rois: list(indexed by images) of list(indexed by levels)
- det_bbox_bs = [[] for _ in range(act_num_augs)]
- det_label_bs = [[] for _ in range(act_num_augs)]
- for i in range(self.num_refine_stages):
- x_refine = self.feat_refine_module[i](x, rois)
- outs = self.refine_head[i](x_refine)
- if i + 1 in range(self.num_refine_stages):
- rois = self.refine_head[i].refine_bboxes(*outs, rois=rois)
-
- bbox_inputs = outs + (img_meta, self.test_cfg, False)
- bbox_bs = self.refine_head[i].get_bboxes(*bbox_inputs, rois=rois)
- # [(rbbox_aug0, class_aug0), (rbbox_aug1, class_aug1), (rbbox_aug2, class_aug2), ...]
- for j in range(act_num_augs):
- det_bbox_bs[j].append(bbox_bs[j][0])
- det_label_bs[j].append(bbox_bs[j][1])
-
- for j in range(act_num_augs):
- det_bbox_bs[j] = torch.cat(det_bbox_bs[j])
- det_label_bs[j] = torch.cat(det_label_bs[j])
-
- aug_det_bboxes += det_bbox_bs
- aug_det_labels += det_label_bs
-
- aug_det_bboxes, aug_det_labels = merge_tiles_aug_rbboxes(
- aug_det_bboxes,
- aug_det_labels,
- img_metas,
- self.test_cfg.merge_cfg,
- self.CLASSES)
-
- return rbbox2result(aug_det_bboxes, aug_det_labels, self.refine_head[-1].num_classes)
-
- else:
- raise NotImplementedError
\ No newline at end of file
+ features = self.neck(features)
+
+ outputs = self.bbox_head(features, targets)
+
+ return outputs
\ No newline at end of file
diff --git a/python/jdet/models/roi_heads/__init__.py b/python/jdet/models/roi_heads/__init__.py
index c82b747c..01d37e39 100644
--- a/python/jdet/models/roi_heads/__init__.py
+++ b/python/jdet/models/roi_heads/__init__.py
@@ -5,6 +5,7 @@
from . import rotated_retina_head
from . import rotated_retina_distribution_head
from . import ld_rotated_retina_head
+from . import r3_head
from . import s2anet_head
from . import rpn_head
from . import oriented_rpn_head
diff --git a/python/jdet/models/roi_heads/r3_head.py b/python/jdet/models/roi_heads/r3_head.py
new file mode 100644
index 00000000..223faa02
--- /dev/null
+++ b/python/jdet/models/roi_heads/r3_head.py
@@ -0,0 +1,660 @@
+import numpy as np
+import jittor as jt
+from jittor import nn
+
+from jdet.ops.fr import FeatureRefineModule
+from jdet.models.utils.weight_init import normal_init,bias_init_with_prob
+from jdet.models.utils.modules import ConvModule
+from jdet.utils.general import multi_apply
+from jdet.utils.registry import HEADS,LOSSES,BOXES,build_from_cfg
+
+
+from jdet.ops.nms_rotated import multiclass_nms_rotated
+from jdet.models.boxes.box_ops import delta2bbox_rotated, rotated_box_to_poly
+from jdet.models.boxes.anchor_target import images_to_levels,anchor_target
+from jdet.models.boxes.anchor_generator import PseudoAnchorGenerator
+from jdet.models.boxes.anchor_generator import AnchorGeneratorRotatedRetinaNet
+
+
+@HEADS.register_module()
+class R3Head(nn.Module):
+
+    def __init__(self,
+                 num_classes,
+                 in_channels,
+                 feat_channels=256,
+                 stacked_convs=2,
+                 octave_base_scale=4,
+                 scales_per_octave=3,
+                 anchor_ratios=[1.0],
+                 anchor_strides=[8, 16, 32, 64, 128],
+                 anchor_base_sizes=None,
+                 target_means=(.0, .0, .0, .0, .0),
+                 target_stds=(1.0, 1.0, 1.0, 1.0, 1.0),
+                 loss_init_cls=dict(
+                     type='FocalLoss',
+                     use_sigmoid=True,
+                     gamma=2.0,
+                     alpha=0.25,
+                     loss_weight=1.0),
+                 loss_init_bbox=dict(
+                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+                 loss_refine_cls=dict(
+                     type='FocalLoss',
+                     use_sigmoid=True,
+                     gamma=2.0,
+                     alpha=0.25,
+                     loss_weight=1.0),
+                 loss_refine_bbox=dict(
+                     type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+                 test_cfg=dict(
+                     nms_pre=2000,
+                     min_bbox_size=0,
+                     score_thr=0.05,
+                     nms=dict(type='nms_rotated', iou_thr=0.1),
+                     max_per_img=2000),
+                 train_cfg=dict(
+                     init_cfg=dict(
+                         assigner=dict(
+                             type='MaxIoUAssigner',
+                             pos_iou_thr=0.5,
+                             neg_iou_thr=0.4,
+                             min_pos_iou=0,
+                             ignore_iof_thr=-1,
+                             iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+                         bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+                                         target_means=(0., 0., 0., 0., 0.),
+                                         target_stds=(1., 1., 1., 1., 1.),
+                                         clip_border=True),
+                         allowed_border=-1,
+                         pos_weight=-1,
+                         debug=False),
+                     refine_cfg=dict(
+                         assigner=dict(
+                             type='MaxIoUAssigner',
+                             pos_iou_thr=0.6,
+                             neg_iou_thr=0.5,
+                             min_pos_iou=0,
+                             ignore_iof_thr=-1,
+                             iou_calculator=dict(type='BboxOverlaps2D_rotated')),
+                         bbox_coder=dict(type='DeltaXYWHABBoxCoder',
+                                         target_means=(0., 0., 0., 0., 0.),
+                                         target_stds=(1., 1., 1., 1., 1.),
+                                         clip_border=True),
+                         allowed_border=-1,
+                         pos_weight=-1,
+                         debug=False))):
+        """Store the configuration and build losses, anchor generators and layers.
+
+        Args:
+            num_classes (int): Number of classes. When the refine
+                classification loss uses sigmoid, ``num_classes - 1`` output
+                channels are predicted per anchor, otherwise ``num_classes``.
+            in_channels (int): Channels of the incoming feature maps.
+            feat_channels (int): Channels produced by the stacked convs.
+            stacked_convs (int): Number of convs in every tower.
+            octave_base_scale (int): Forwarded to the initial-stage anchor
+                generators.
+            scales_per_octave (int): Forwarded to the initial-stage anchor
+                generators.
+            anchor_ratios (list[float]): Forwarded to the initial-stage
+                anchor generators.
+            anchor_strides (list[int]): One stride per feature level.
+            anchor_base_sizes (list[int] | None): Base size per level;
+                a copy of ``anchor_strides`` is used when None.
+            target_means (tuple[float]): Means used to (de)normalise deltas.
+            target_stds (tuple[float]): Stds used to (de)normalise deltas.
+            loss_init_cls (dict): Config of the initial classification loss.
+            loss_init_bbox (dict): Config of the initial regression loss.
+            loss_refine_cls (dict): Config of the refine classification loss;
+                its ``use_sigmoid`` and ``type`` entries also decide the
+                output channel count and whether sampling is used.
+            loss_refine_bbox (dict): Config of the refine regression loss.
+            test_cfg (dict): Test-time settings, stored as ``self.test_cfg``.
+            train_cfg (dict): Training settings, stored as ``self.train_cfg``.
+
+        Raises:
+            ValueError: If the resulting number of classification output
+                channels is not positive.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.in_channels = in_channels
+        self.feat_channels = feat_channels
+        self.stacked_convs = stacked_convs
+        self.anchor_ratios = anchor_ratios
+        self.anchor_strides = anchor_strides
+        # Fall back to the strides when no explicit base sizes are given.
+        if anchor_base_sizes is None:
+            anchor_base_sizes = list(anchor_strides)
+        self.anchor_base_sizes = anchor_base_sizes
+        self.target_means = target_means
+        self.target_stds = target_stds
+
+        # The refine classification loss drives the head-wide settings.
+        self.use_sigmoid_cls = loss_refine_cls.get('use_sigmoid', False)
+        self.sampling = loss_refine_cls['type'] not in ['FocalLoss', 'GHMC']
+        self.cls_out_channels = (
+            num_classes - 1 if self.use_sigmoid_cls else num_classes)
+        if self.cls_out_channels <= 0:
+            raise ValueError('num_classes={} is too small'.format(num_classes))
+
+        self.loss_init_cls = build_from_cfg(loss_init_cls, LOSSES)
+        self.loss_init_bbox = build_from_cfg(loss_init_bbox, LOSSES)
+        self.loss_refine_cls = build_from_cfg(loss_refine_cls, LOSSES)
+        self.loss_refine_bbox = build_from_cfg(loss_refine_bbox, LOSSES)
+        self.feat_refine_module = FeatureRefineModule(
+            in_channels=in_channels, featmap_strides=anchor_strides)
+
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+
+        # One generator per level for each of the two stages.
+        self.anchor_generators = [
+            AnchorGeneratorRotatedRetinaNet(
+                base_size, None, anchor_ratios,
+                octave_base_scale=octave_base_scale,
+                scales_per_octave=scales_per_octave)
+            for base_size in self.anchor_base_sizes
+        ]
+        self.refine_anchor_generators = [
+            PseudoAnchorGenerator(base_size)
+            for base_size in self.anchor_base_sizes
+        ]
+        self.num_anchors = self.anchor_generators[0].num_base_anchors
+        # anchor cache
+        self.base_anchors = dict()
+        self._init_layers()
+
+    def _init_layers(self):
+        """Create the conv towers and prediction layers of both stages.
+
+        Every tower holds ``self.stacked_convs`` 3x3 ``ConvModule`` blocks.
+        The first block of a tower consumes ``self.in_channels`` channels and
+        the following ones ``self.feat_channels``. The initial stage predicts
+        ``self.num_anchors`` boxes per location, the refine stage a single
+        box per location. ``init_weights`` is called at the end.
+        """
+        self.relu = nn.ReLU()
+
+        # Initial stage: regression and classification towers.
+        self.init_reg_convs = nn.ModuleList()
+        self.init_cls_convs = nn.ModuleList()
+        for i in range(self.stacked_convs):
+            chn = self.in_channels if i == 0 else self.feat_channels
+            self.init_reg_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1))
+            self.init_cls_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1))
+
+        self.init_reg = nn.Conv2d(
+            self.feat_channels, self.num_anchors * 5, 3, padding=1)
+        self.init_cls = nn.Conv2d(
+            self.feat_channels,
+            self.num_anchors * self.cls_out_channels,
+            3,
+            padding=1)
+
+        # Refine stage: same tower layout as the initial stage.
+        self.refine_reg_convs = nn.ModuleList()
+        self.refine_cls_convs = nn.ModuleList()
+        for i in range(self.stacked_convs):
+            chn = self.in_channels if i == 0 else self.feat_channels
+            # The first regression conv must accept ``in_channels`` just like
+            # the classification tower; using ``feat_channels`` here fails
+            # whenever the two channel counts differ.
+            self.refine_reg_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1))
+            self.refine_cls_convs.append(
+                ConvModule(
+                    chn,
+                    self.feat_channels,
+                    3,
+                    stride=1,
+                    padding=1))
+
+        self.refine_cls = nn.Conv2d(
+            self.feat_channels, self.cls_out_channels, 3, padding=1)
+        self.refine_reg = nn.Conv2d(self.feat_channels, 5, 3, padding=1)
+
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialise all conv layers of the head with ``normal_init``.
+
+        Every layer uses ``std=0.01``. The two classification output layers
+        additionally get the bias returned by ``bias_init_with_prob(0.01)``.
+        The layers are visited in a fixed order: initial towers, initial
+        outputs, refine towers, refine outputs.
+        """
+        cls_bias = bias_init_with_prob(0.01)
+
+        for tower in (self.init_reg_convs, self.init_cls_convs):
+            for block in tower:
+                normal_init(block.conv, std=0.01)
+        normal_init(self.init_reg, std=0.01)
+        normal_init(self.init_cls, std=0.01, bias=cls_bias)
+
+        for tower in (self.refine_reg_convs, self.refine_cls_convs):
+            for block in tower:
+                normal_init(block.conv, std=0.01)
+        normal_init(self.refine_cls, std=0.01, bias=cls_bias)
+        normal_init(self.refine_reg, std=0.01)
+
+    def init_forward_single(self, x):
+        """Run the initial-stage towers on the feature map of one level.
+
+        Args:
+            x: Feature map of a single level.
+
+        Returns:
+            tuple: ``(cls_score, bbox_pred)`` produced by ``self.init_cls``
+            and ``self.init_reg`` on top of their respective towers.
+        """
+        cls_feat = reg_feat = x
+        for layer in self.init_cls_convs:
+            cls_feat = layer(cls_feat)
+        for layer in self.init_reg_convs:
+            reg_feat = layer(reg_feat)
+        return self.init_cls(cls_feat), self.init_reg(reg_feat)
+
+    def refine_forward_single(self, x):
+        """Run the refine-stage towers on the feature map of one level.
+
+        Args:
+            x: Refined feature map of a single level.
+
+        Returns:
+            tuple: ``(cls_score, bbox_pred)`` produced by ``self.refine_cls``
+            and ``self.refine_reg`` on top of their respective towers.
+        """
+        cls_feat = reg_feat = x
+        for layer in self.refine_cls_convs:
+            cls_feat = layer(cls_feat)
+        for layer in self.refine_reg_convs:
+            reg_feat = layer(reg_feat)
+        return self.refine_cls(cls_feat), self.refine_reg(reg_feat)
+
+    def filter_bboxes(self, cls_scores, bbox_preds):
+        """Keep the best-scoring predicted box at every feature-map position.
+
+        For each location the anchor whose maximum class score is highest is
+        selected, and its regression output is decoded against that anchor
+        with ``delta2bbox_rotated``. In R3Det this runs before the first
+        feature refinement stage.
+
+        Args:
+            cls_scores (list[Tensor]): Box scores for each scale level,
+                shape (N, num_anchors * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box deltas for each scale level,
+                shape (N, num_anchors * 5, H, W).
+
+        Returns:
+            list[list[Tensor]]: For every image, one detached tensor of
+            selected rotated boxes per level.
+        """
+        num_levels = len(cls_scores)
+        assert num_levels == len(bbox_preds)
+
+        num_imgs = cls_scores[0].size(0)
+        for lvl_scores, lvl_preds in zip(cls_scores, bbox_preds):
+            assert num_imgs == lvl_scores.size(0) == lvl_preds.size(0)
+
+        bboxes_list = [[] for _ in range(num_imgs)]
+
+        for lvl, (cls_score, bbox_pred) in enumerate(zip(cls_scores, bbox_preds)):
+            featmap_size = cls_score.shape[-2:]
+            anchors = self.anchor_generators[lvl].grid_anchors(
+                featmap_size, self.anchor_strides[lvl])
+
+            # (N, C, H, W) -> (N, H*W, num_anchors, cls_out_channels)
+            cls_score = cls_score.permute(0, 2, 3, 1).reshape(
+                num_imgs, -1, self.num_anchors, self.cls_out_channels)
+            # Best class score per anchor, then the best anchor per location.
+            cls_score = cls_score.max(dim=3, keepdims=True)
+            best_ind, _ = cls_score.argmax(dim=2, keepdims=True)
+            best_ind = best_ind.expand(-1, -1, -1, 5)
+
+            bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(
+                num_imgs, -1, self.num_anchors, 5)
+            best_pred = bbox_pred.gather(dim=-2, index=best_ind)
+            best_pred = best_pred.squeeze(dim=-2)
+
+            anchors = anchors.reshape(-1, self.num_anchors, 5)
+
+            for img_id in range(num_imgs):
+                chosen_anchor = anchors.gather(
+                    dim=-2, index=best_ind[img_id]).squeeze(dim=-2)
+                decoded = delta2bbox_rotated(
+                    chosen_anchor, best_pred[img_id], self.target_means,
+                    self.target_stds, wh_ratio_clip=1e-6)
+                bboxes_list[img_id].append(decoded.detach())
+
+        return bboxes_list
+
+    def get_init_anchors(self,
+                         featmap_sizes,
+                         img_metas):
+        """Generate the initial-stage anchors and their valid flags.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+            img_metas (list[dict]): Image meta info; ``pad_shape`` is read
+                from every entry.
+
+        Returns:
+            tuple: ``(anchor_list, valid_flag_list)``. ``anchor_list`` holds,
+            for every image, the same list of per-level anchors;
+            ``valid_flag_list`` holds per-image lists of per-level flags.
+        """
+        num_imgs = len(img_metas)
+
+        # Feature map sizes are identical for all images, so the anchors are
+        # generated once and the same list is referenced for every image.
+        multi_level_anchors = [
+            self.anchor_generators[lvl].grid_anchors(size, self.anchor_strides[lvl])
+            for lvl, size in enumerate(featmap_sizes)
+        ]
+        anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+        # Valid flags depend on each image's padded shape.
+        valid_flag_list = []
+        for img_meta in img_metas:
+            flags_per_level = []
+            for lvl, size in enumerate(featmap_sizes):
+                stride = self.anchor_strides[lvl]
+                feat_h, feat_w = size
+                w, h = img_meta['pad_shape'][:2]
+                valid_feat_h = min(int(np.ceil(h / stride)), feat_h)
+                valid_feat_w = min(int(np.ceil(w / stride)), feat_w)
+                flags_per_level.append(
+                    self.anchor_generators[lvl].valid_flags(
+                        (feat_h, feat_w), (valid_feat_h, valid_feat_w)))
+            valid_flag_list.append(flags_per_level)
+        return anchor_list, valid_flag_list
+
+    def get_refine_anchors(self,
+                           featmap_sizes,
+                           refine_anchors,
+                           img_metas,
+                           is_train=True):
+        """Copy the refined boxes as anchors and compute their valid flags.
+
+        Args:
+            featmap_sizes (list[tuple]): Multi-level feature map sizes.
+            refine_anchors (list[list[Tensor]]): Boxes per image and level,
+                used as anchors of the refine stage.
+            img_metas (list[dict]): Image meta info; ``pad_shape`` is read
+                from every entry when ``is_train`` is True.
+            is_train (bool): Valid flags are only computed when True.
+
+        Returns:
+            tuple: ``(refine_anchors_list, valid_flag_list)``. The anchors
+            are detached clones of the input; the flag list is empty when
+            ``is_train`` is False.
+        """
+        refine_anchors_list = []
+        for bboxes_img in refine_anchors:
+            refine_anchors_list.append(
+                [bboxes_img_lvl.clone().detach() for bboxes_img_lvl in bboxes_img])
+
+        valid_flag_list = []
+        if not is_train:
+            return refine_anchors_list, valid_flag_list
+
+        for img_meta in img_metas:
+            flags_per_level = []
+            for lvl, size in enumerate(featmap_sizes):
+                stride = self.anchor_strides[lvl]
+                feat_h, feat_w = size
+                w, h = img_meta['pad_shape'][:2]
+                valid_feat_h = min(int(np.ceil(h / stride)), feat_h)
+                valid_feat_w = min(int(np.ceil(w / stride)), feat_w)
+                flags_per_level.append(
+                    self.refine_anchor_generators[lvl].valid_flags(
+                        (feat_h, feat_w), (valid_feat_h, valid_feat_w)))
+            valid_flag_list.append(flags_per_level)
+        return refine_anchors_list, valid_flag_list
+
+    def loss(self,
+             init_cls_scores,
+             init_bbox_preds,
+             refine_anchors,
+             refine_cls_scores,
+             refine_bbox_preds,
+             gt_bboxes,
+             gt_labels,
+             img_metas,
+             gt_bboxes_ignore=None):
+        """Compute the classification and regression losses of both stages.
+
+        The initial stage is matched against the generated anchors using
+        ``train_cfg.init_cfg``; the refine stage is matched against
+        ``refine_anchors`` using ``train_cfg.refine_cfg``.
+
+        Args:
+            init_cls_scores (list[Tensor]): Initial-stage scores per level.
+            init_bbox_preds (list[Tensor]): Initial-stage deltas per level.
+            refine_anchors (list[list[Tensor]]): Boxes per image and level
+                serving as refine-stage anchors.
+            refine_cls_scores (list[Tensor]): Refine-stage scores per level.
+            refine_bbox_preds (list[Tensor]): Refine-stage deltas per level.
+            gt_bboxes (list): Ground-truth boxes per image.
+            gt_labels (list): Ground-truth labels per image.
+            img_metas (list[dict]): Image meta info.
+            gt_bboxes_ignore (list | None): Boxes to ignore per image.
+
+        Returns:
+            dict | None: Losses under the keys ``loss_init_cls``,
+            ``loss_init_bbox``, ``loss_refine_cls`` and ``loss_refine_bbox``
+            (each the per-level result of ``multi_apply``), or None when
+            ``anchor_target`` returns None for either stage.
+        """
+        cfg = self.train_cfg.copy()
+        featmap_sizes = [score_map.size()[-2:] for score_map in refine_cls_scores]
+        assert len(featmap_sizes) == len(self.anchor_generators)
+        label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+        # ---------------- initial stage ----------------
+        anchor_list, valid_flag_list = self.get_init_anchors(featmap_sizes, img_metas)
+
+        # Anchor count per level, then per-level anchors stacked over images.
+        # NOTE(review): this is done before calling anchor_target on purpose;
+        # anchor_target may modify the lists it receives -- TODO confirm.
+        num_level_anchors = [lvl_anchors.size(0) for lvl_anchors in anchor_list[0]]
+        concat_anchor_list = [
+            jt.contrib.concat(img_anchors) for img_anchors in anchor_list]
+        all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors)
+
+        init_targets = anchor_target(
+            anchor_list,
+            valid_flag_list.copy(),
+            gt_bboxes,
+            img_metas,
+            self.target_means,
+            self.target_stds,
+            cfg.init_cfg,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels,
+            sampling=self.sampling)
+        if init_targets is None:
+            return None
+        (labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, num_total_pos, num_total_neg) = init_targets
+
+        if self.sampling:
+            num_total_samples = num_total_pos + num_total_neg
+        else:
+            num_total_samples = num_total_pos
+
+        losses_init_cls, losses_init_bbox = multi_apply(
+            self.loss_init_single,
+            init_cls_scores,
+            init_bbox_preds,
+            all_anchor_list,
+            labels_list,
+            label_weights_list,
+            bbox_targets_list,
+            bbox_weights_list,
+            num_total_samples=num_total_samples,
+            cfg=cfg.init_cfg)
+
+        # ---------------- refine stage ----------------
+        refine_anchors_list, valid_flag_list = self.get_refine_anchors(
+            featmap_sizes, refine_anchors, img_metas)
+
+        num_level_anchors = [
+            lvl_anchors.size(0) for lvl_anchors in refine_anchors_list[0]]
+        concat_anchor_list = [
+            jt.contrib.concat(img_anchors) for img_anchors in refine_anchors_list]
+        all_anchor_list = images_to_levels(concat_anchor_list, num_level_anchors)
+
+        refine_targets = anchor_target(
+            refine_anchors_list,
+            valid_flag_list,
+            gt_bboxes,
+            img_metas,
+            self.target_means,
+            self.target_stds,
+            cfg.refine_cfg,
+            gt_bboxes_ignore_list=gt_bboxes_ignore,
+            gt_labels_list=gt_labels,
+            label_channels=label_channels,
+            sampling=self.sampling)
+        if refine_targets is None:
+            return None
+        (labels_list, label_weights_list, bbox_targets_list,
+         bbox_weights_list, num_total_pos, num_total_neg) = refine_targets
+
+        if self.sampling:
+            num_total_samples = num_total_pos + num_total_neg
+        else:
+            num_total_samples = num_total_pos
+
+        losses_refine_cls, losses_refine_bbox = multi_apply(
+            self.loss_refine_single,
+            refine_cls_scores,
+            refine_bbox_preds,
+            all_anchor_list,
+            labels_list,
+            label_weights_list,
+            bbox_targets_list,
+            bbox_weights_list,
+            num_total_samples=num_total_samples,
+            cfg=cfg.refine_cfg)
+
+        return {
+            'loss_init_cls': losses_init_cls,
+            'loss_init_bbox': losses_init_bbox,
+            'loss_refine_cls': losses_refine_cls,
+            'loss_refine_bbox': losses_refine_bbox,
+        }
+
+    def loss_init_single(self,
+                         init_cls_score,
+                         init_bbox_pred,
+                         anchors,
+                         labels,
+                         label_weights,
+                         bbox_targets,
+                         bbox_weights,
+                         num_total_samples,
+                         cfg):
+        """Compute the initial-stage losses of a single feature level.
+
+        Args:
+            init_cls_score (Tensor): Class scores of this level.
+            init_bbox_pred (Tensor): Box deltas of this level.
+            anchors (Tensor): Anchors of this level; only used when
+                ``cfg['reg_decoded_bbox']`` is truthy.
+            labels (Tensor): Classification targets.
+            label_weights (Tensor): Weights of the classification targets.
+            bbox_targets (Tensor): Regression targets.
+            bbox_weights (Tensor): Weights of the regression targets.
+            num_total_samples: Averaging factor passed to both losses.
+            cfg (dict): Stage config; ``reg_decoded_bbox`` and ``bbox_coder``
+                are read from it.
+
+        Returns:
+            tuple: ``(loss_init_cls, loss_init_bbox)``.
+        """
+        # classification loss
+        flat_labels = labels.reshape(-1)
+        flat_label_weights = label_weights.reshape(-1)
+        flat_scores = init_cls_score.permute(0, 2, 3, 1)
+        flat_scores = flat_scores.reshape(-1, self.cls_out_channels)
+        loss_init_cls = self.loss_init_cls(
+            flat_scores, flat_labels, flat_label_weights,
+            avg_factor=num_total_samples)
+
+        # regression loss
+        flat_targets = bbox_targets.reshape(-1, 5)
+        flat_bbox_weights = bbox_weights.reshape(-1, 5)
+        flat_preds = init_bbox_pred.permute(0, 2, 3, 1).reshape(-1, 5)
+
+        if cfg.get('reg_decoded_bbox', False):
+            # When the regression loss (e.g. `IouLoss`, `GIouLoss`) works on
+            # decoded boxes, turn the predicted deltas into absolute boxes.
+            coder_cfg = cfg.get('bbox_coder', '')
+            if coder_cfg == '':
+                coder_cfg = dict(type='DeltaXYWHBBoxCoder')
+            coder = build_from_cfg(coder_cfg, BOXES)
+            anchors = anchors.reshape(-1, 5)
+            flat_preds = coder.decode(anchors, flat_preds)
+
+        loss_init_bbox = self.loss_init_bbox(
+            flat_preds,
+            flat_targets,
+            flat_bbox_weights,
+            avg_factor=num_total_samples)
+        return loss_init_cls, loss_init_bbox
+
+    def loss_refine_single(self,
+                           refine_cls_score,
+                           refine_bbox_pred,
+                           anchors,
+                           labels,
+                           label_weights,
+                           bbox_targets,
+                           bbox_weights,
+                           num_total_samples,
+                           cfg):
+        """Compute the refine-stage losses of a single feature level.
+
+        Args:
+            refine_cls_score (Tensor): Class scores of this level.
+            refine_bbox_pred (Tensor): Box deltas of this level.
+            anchors (Tensor): Anchors of this level; only used when
+                ``cfg['reg_decoded_bbox']`` is truthy.
+            labels (Tensor): Classification targets.
+            label_weights (Tensor): Weights of the classification targets.
+            bbox_targets (Tensor): Regression targets.
+            bbox_weights (Tensor): Weights of the regression targets.
+            num_total_samples: Averaging factor passed to both losses.
+            cfg (dict): Stage config; ``reg_decoded_bbox`` and ``bbox_coder``
+                are read from it.
+
+        Returns:
+            tuple: ``(loss_refine_cls, loss_refine_bbox)``.
+        """
+        # classification loss
+        flat_labels = labels.reshape(-1)
+        flat_label_weights = label_weights.reshape(-1)
+        flat_scores = refine_cls_score.permute(0, 2, 3, 1)
+        flat_scores = flat_scores.reshape(-1, self.cls_out_channels)
+        loss_refine_cls = self.loss_refine_cls(
+            flat_scores, flat_labels, flat_label_weights,
+            avg_factor=num_total_samples)
+
+        # regression loss
+        flat_targets = bbox_targets.reshape(-1, 5)
+        flat_bbox_weights = bbox_weights.reshape(-1, 5)
+        flat_preds = refine_bbox_pred.permute(0, 2, 3, 1).reshape(-1, 5)
+
+        if cfg.get('reg_decoded_bbox', False):
+            # When the regression loss (e.g. `IouLoss`, `GIouLoss`) works on
+            # decoded boxes, turn the predicted deltas into absolute boxes.
+            coder_cfg = cfg.get('bbox_coder', '')
+            if coder_cfg == '':
+                coder_cfg = dict(type='DeltaXYWHBBoxCoder')
+            coder = build_from_cfg(coder_cfg, BOXES)
+            anchors = anchors.reshape(-1, 5)
+            flat_preds = coder.decode(anchors, flat_preds)
+
+        loss_refine_bbox = self.loss_refine_bbox(
+            flat_preds,
+            flat_targets,
+            flat_bbox_weights,
+            avg_factor=num_total_samples)
+        return loss_refine_cls, loss_refine_bbox
+
+    def get_bboxes(self,
+                   refine_anchors,
+                   refine_cls_scores,
+                   refine_bbox_preds,
+                   img_metas,
+                   rescale=True):
+        """Turn the refine-stage outputs into detections for every image.
+
+        Args:
+            refine_anchors (list[list[Tensor]]): Boxes per image and level
+                serving as refine-stage anchors.
+            refine_cls_scores (list[Tensor]): Refine-stage scores per level.
+            refine_bbox_preds (list[Tensor]): Refine-stage deltas per level.
+            img_metas (list[dict]): Image meta info; ``img_shape`` and
+                ``scale_factor`` are read from every entry.
+            rescale (bool): Forwarded to ``get_bboxes_single``.
+
+        Returns:
+            list: One ``get_bboxes_single`` result per image.
+        """
+        assert len(refine_cls_scores) == len(refine_bbox_preds)
+        cfg = self.test_cfg.copy()
+
+        featmap_sizes = [score_map.size()[-2:] for score_map in refine_cls_scores]
+
+        # Valid flags are not needed at test time, only the anchor copies.
+        anchors_per_img, _ = self.get_refine_anchors(
+            featmap_sizes, refine_anchors, img_metas, is_train=False)
+
+        result_list = []
+        for img_id, img_meta in enumerate(img_metas):
+            cls_score_list = [
+                lvl_scores[img_id].detach() for lvl_scores in refine_cls_scores]
+            bbox_pred_list = [
+                lvl_preds[img_id].detach() for lvl_preds in refine_bbox_preds]
+            proposals = self.get_bboxes_single(
+                cls_score_list, bbox_pred_list, anchors_per_img[img_id],
+                img_meta['img_shape'], img_meta['scale_factor'], cfg, rescale)
+            result_list.append(proposals)
+        return result_list
+
+    def get_bboxes_single(self,
+                          cls_score_list,
+                          bbox_pred_list,
+                          mlvl_anchors,
+                          img_shape,
+                          scale_factor,
+                          cfg,
+                          rescale=False):
+        """Transform the outputs of a single image into labeled boxes.
+
+        Args:
+            cls_score_list (list[Tensor]): Class scores per level.
+            bbox_pred_list (list[Tensor]): Box deltas per level.
+            mlvl_anchors (list[Tensor]): Anchors per level.
+            img_shape: Image shape passed to ``delta2bbox_rotated``.
+            scale_factor: Divisor applied to the first four box values when
+                ``rescale`` is True.
+            cfg: Test config; ``nms_pre``, ``score_thr``, ``nms`` and
+                ``max_per_img`` are read from it.
+            rescale (bool): Whether to divide the boxes by ``scale_factor``.
+
+        Returns:
+            tuple: ``(polys, scores, det_labels)`` where ``polys`` is the
+            result of ``rotated_box_to_poly`` on the boxes kept by NMS.
+        """
+        assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
+        nms_pre = cfg.get('nms_pre', -1)
+
+        mlvl_bboxes = []
+        mlvl_scores = []
+        for cls_score, bbox_pred, anchors in zip(cls_score_list,
+                                                 bbox_pred_list,
+                                                 mlvl_anchors):
+            assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+            cls_score = cls_score.permute(1, 2, 0)
+            cls_score = cls_score.reshape(-1, self.cls_out_channels)
+            if self.use_sigmoid_cls:
+                scores = cls_score.sigmoid()
+            else:
+                scores = cls_score.softmax(-1)
+            bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 5)
+
+            if nms_pre > 0 and scores.shape[0] > nms_pre:
+                # Rank by the best foreground score; with softmax the first
+                # column is skipped.
+                fg_scores = scores if self.use_sigmoid_cls else scores[:, 1:]
+                max_scores = fg_scores.max(dim=1)
+                _, topk_inds = max_scores.topk(nms_pre)
+                anchors = anchors[topk_inds, :]
+                bbox_pred = bbox_pred[topk_inds, :]
+                scores = scores[topk_inds, :]
+
+            mlvl_bboxes.append(
+                delta2bbox_rotated(anchors, bbox_pred, self.target_means,
+                                   self.target_stds, img_shape))
+            mlvl_scores.append(scores)
+
+        mlvl_bboxes = jt.contrib.concat(mlvl_bboxes)
+        if rescale:
+            mlvl_bboxes[..., :4] /= scale_factor
+        mlvl_scores = jt.contrib.concat(mlvl_scores)
+        if self.use_sigmoid_cls:
+            # Add a dummy background class to the front when using sigmoid
+            padding = jt.zeros((mlvl_scores.shape[0], 1), dtype=mlvl_scores.dtype)
+            mlvl_scores = jt.contrib.concat([padding, mlvl_scores], dim=1)
+
+        det_bboxes, det_labels = multiclass_nms_rotated(
+            mlvl_bboxes, mlvl_scores, cfg.score_thr, cfg.nms, cfg.max_per_img)
+        # Columns 0-4 hold the rotated box, column 5 the score.
+        polys = rotated_box_to_poly(det_bboxes[:, :5])
+        scores = det_bboxes[:, 5]
+        return polys, scores, det_labels
+
+
+    def parse_targets(self, targets, is_train=True):
+        """Split the per-image target dicts into the lists the head consumes.
+
+        Args:
+            targets (list[dict]): One dict per image. ``img_size``,
+                ``scale_factor`` and ``pad_shape`` are always read;
+                ``rboxes``, ``labels`` and ``rboxes_ignore`` only when
+                ``is_train`` is True.
+            is_train (bool): Whether ground truth should be collected.
+
+        Returns:
+            list[dict] | tuple: Only ``img_metas`` when ``is_train`` is
+            False, otherwise ``(gt_bboxes, gt_labels, img_metas,
+            gt_bboxes_ignore)``. Each meta stores ``img_size`` reversed as
+            ``img_shape``.
+        """
+        metas, boxes, labels, ignored = [], [], [], []
+
+        for tgt in targets:
+            if is_train:
+                boxes.append(tgt["rboxes"])
+                labels.append(tgt["labels"])
+                ignored.append(tgt["rboxes_ignore"])
+            metas.append({
+                "img_shape": tgt["img_size"][::-1],
+                "scale_factor": tgt["scale_factor"],
+                "pad_shape": tgt["pad_shape"],
+            })
+
+        if is_train:
+            return boxes, labels, metas, ignored
+        return metas
+
+    def execute(self, feats, targets):
+        """Run both stages and return losses (training) or detections.
+
+        Args:
+            feats (list[Tensor]): Multi-level feature maps.
+            targets (list[dict]): Per-image targets, see ``parse_targets``.
+
+        Returns:
+            The result of ``self.loss`` while training, otherwise the result
+            of ``self.get_bboxes``.
+        """
+        # Initial predictions, reduced to one box per location.
+        init_outs = multi_apply(self.init_forward_single, feats)
+        rois = self.filter_bboxes(*init_outs)
+
+        # Refine the features with those boxes and predict again.
+        refined_feats = self.feat_refine_module(feats, rois)
+        refine_outs = multi_apply(self.refine_forward_single, refined_feats)
+
+        if not self.is_training():
+            img_metas = self.parse_targets(targets, is_train=False)
+            return self.get_bboxes(rois, *refine_outs, img_metas)
+
+        parsed = self.parse_targets(targets)
+        return self.loss(*init_outs, rois, *refine_outs, *parsed)
diff --git a/python/jdet/ops/fr.py b/python/jdet/ops/fr.py
index 2a8ccf8a..086d25fa 100644
--- a/python/jdet/ops/fr.py
+++ b/python/jdet/ops/fr.py
@@ -1,3 +1,4 @@
+import numpy as np
import jittor as jt
from jittor import nn
from jdet.models.utils.weight_init import normal_init
@@ -328,6 +329,15 @@ def init_weights(self):
normal_init(self.conv_1_5, std=0.01)
normal_init(self.conv_1_1, std=0.01)
+    def le135_to_oc(self, boxes):
+        """Convert rotated boxes from the 'le135' to the 'oc' angle convention.
+
+        The angle is shifted by +pi/2 and wrapped into [0, pi). Where the
+        wrapped angle is at least pi/2, width and height are swapped and
+        pi/2 is subtracted. The shift is then undone, so the returned angle
+        lies in [-pi/2, 0).
+
+        Args:
+            boxes: Boxes whose last dimension is (x, y, w, h, angle).
+
+        Returns:
+            Boxes with the same layout, stacked along the last dimension.
+        """
+        half_pi = np.pi / 2
+        start_angle = -0.5 * np.pi
+
+        x, y, w, h, t = boxes.unbind(dim=-1)
+        t = (t - start_angle) % np.pi
+        keeps_edges = t < half_pi
+
+        w_ = jt.where(keeps_edges, w, h)
+        h_ = jt.where(keeps_edges, h, w)
+        t = jt.where(keeps_edges, t, t - half_pi) + start_angle
+        return jt.stack([x, y, w_, h_, t], dim=-1)
+
def execute(self, x, best_rbboxes):
"""
Args:
@@ -342,6 +352,8 @@ def execute(self, x, best_rbboxes):
feat_scale_1 = self.conv_5_1(self.conv_1_5(x_scale))
feat_scale_2 = self.conv_1_1(x_scale)
feat_scale = feat_scale_1 + feat_scale_2
+ # Convert 'le135' boxes to 'oc', because the feature_refine op only supports 'oc'.
+ best_rbboxes_scale = self.le135_to_oc(best_rbboxes_scale)
feat_refined_scale = fr_scale(feat_scale, best_rbboxes_scale)
out.append(x_scale + feat_refined_scale)
return out