Source Code Interpretation of FCOS Loss Computation

Recently I felt that the FCOS paper alone is not specific enough, so I debugged the source code and wrote down this interpretation for my own future review. It contains many techniques that would be hard to come up with without reading the author's code.
It covers the following:

  • How to generate the box format required by the loss function from the box coordinates in the original data (see the sketch right after this list)
  • How to assign boxes of different sizes to feature maps of different levels
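
Before diving into the file, here is a minimal sketch of the first point (my own illustration, not the repository's code): turning xyxy ground-truth boxes into the per-location (l, t, r, b) regression targets that the loss consumes. `locations` is assumed to be an (N, 2) tensor of image coordinates and `bboxes` an (M, 4) tensor in xyxy format:

import torch

def ltrb_targets(locations, bboxes):
    # locations: (N, 2) point coordinates on the input image
    # bboxes:    (M, 4) ground-truth boxes in xyxy format
    xs, ys = locations[:, 0], locations[:, 1]
    l = xs[:, None] - bboxes[None, :, 0]   # distance from each point to the left edge
    t = ys[:, None] - bboxes[None, :, 1]   # distance to the top edge
    r = bboxes[None, :, 2] - xs[:, None]   # distance to the right edge
    b = bboxes[None, :, 3] - ys[:, None]   # distance to the bottom edge
    # (N, M, 4); a point lies inside box j iff all four distances are > 0
    return torch.stack([l, t, r, b], dim=2)

The full file below does exactly this inside compute_targets_for_locations, and then filters the result by center sampling and per-level size ranges.
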
"""
This file contains specific functions for computing losses of FCOS
file
"""

import torch
from torch.nn import functional as F
from torch import nn
import os
from ..utils import concat_box_prediction_layers
from fcos_core.layers import IOULoss
from fcos_core.layers import SigmoidFocalLoss
from fcos_core.modeling.matcher import Matcher
from fcos_core.modeling.utils import cat
from fcos_core.structures.boxlist_ops import boxlist_iou
from fcos_core.structures.boxlist_ops import cat_boxlist


INF = 100000000


def get_num_gpus():
    return int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1


def reduce_sum(tensor):
    if get_num_gpus() <= 1:
        return tensor
    import torch.distributed as dist
    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.reduce_op.SUM)
    return tensor


class FCOSLossComputation(object):
    """
    This class computes the FCOS losses.
    """

    def __init__(self, cfg):
        self.cls_loss_func = SigmoidFocalLoss(
            cfg.MODEL.FCOS.LOSS_GAMMA,
            cfg.MODEL.FCOS.LOSS_ALPHA
        )
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        self.center_sampling_radius = cfg.MODEL.FCOS.CENTER_SAMPLING_RADIUS
        self.iou_loss_type = cfg.MODEL.FCOS.IOU_LOSS_TYPE
        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS

        # we make use of IOU Loss for bounding boxes regression,
        # but we found that L1 in log scale can yield a similar performance
        self.box_reg_loss_func = IOULoss(self.iou_loss_type)
        self.centerness_loss_func = nn.BCEWithLogitsLoss(reduction="sum")

    def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0):
        '''
        This code is from
        https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/
        maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42
        '''
        num_gts = gt.shape[0]
        K = len(gt_xs)
        gt = gt[None].expand(K, num_gts, 4)
        center_x = (gt[..., 0] + gt[..., 2]) / 2
        center_y = (gt[..., 1] + gt[..., 3]) / 2
        center_gt = gt.new_zeros(gt.shape)
        # no gt
        if center_x[..., 0].sum() == 0:
            return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8)
        beg = 0
        for level, n_p in enumerate(num_points_per):
            end = beg + n_p
            stride = strides[level] * radius
            xmin = center_x[beg:end] - stride
            ymin = center_y[beg:end] - stride
            xmax = center_x[beg:end] + stride
            ymax = center_y[beg:end] + stride
            # limit sample region in gt
            center_gt[beg:end, :, 0] = torch.where(
                xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0]
            )
            center_gt[beg:end, :, 1] = torch.where(
                ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1]
            )
            center_gt[beg:end, :, 2] = torch.where(
                xmax > gt[beg:end, :, 2],
                gt[beg:end, :, 2], xmax
            )
            center_gt[beg:end, :, 3] = torch.where(
                ymax > gt[beg:end, :, 3],
                gt[beg:end, :, 3], ymax
            )
            beg = end
        left = gt_xs[:, None] - center_gt[..., 0]
        right = center_gt[..., 2] - gt_xs[:, None]
        top = gt_ys[:, None] - center_gt[..., 1]
        bottom = center_gt[..., 3] - gt_ys[:, None]
        center_bbox = torch.stack((left, top, right, bottom), -1)
        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        return inside_gt_bbox_mask
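        # Illustrative note (an addition for this post, not in the original file):
        # for a P3 location (stride 8) with center_sampling_radius = 1.5, the
        # square above spans (cx - 12, cy - 12) to (cx + 12, cy + 12) around the
        # gt center (8 * 1.5 = 12 pixels per side), and the torch.where calls
        # clip it so it never extends outside the gt box. A location counts as a
        # positive sample only if its (left, top, right, bottom) distances to
        # this clipped square are all positive.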

    def prepare_targets(self, points, targets):
        # The size range of target boxes assigned to each FPN level's feature map
        object_sizes_of_interest = [
            [-1, 64],
            [64, 128],
            [128, 256],
            [256, 512],
            [512, INF],
        ]
        # Create 5 tensors corresponding to the 5 feature maps
        # The number of rows in each tensor is the number of points in the corresponding feature map, and the number of columns is 2
        # For example, the first tensor has size (len(points_per_level), 2) and its elements are repeated [-1, 64]
        expanded_object_sizes_of_interest = []
        for l, points_per_level in enumerate(points):
            object_sizes_of_interest_per_level = \
                points_per_level.new_tensor(object_sizes_of_interest[l])
            expanded_object_sizes_of_interest.append(
                object_sizes_of_interest_per_level[None].expand(
                    len(points_per_level), -1)
            )
        # Dimensions of expanded_object_sizes_of_interest:
        # expanded_object_sizes_of_interest[0] is repeated [-1, 64]
        # expanded_object_sizes_of_interest[1] is repeated [64, 128]
        # [e.shape for e in expanded_object_sizes_of_interest]
        # [torch.Size([16128, 2]), torch.Size([4032, 2]), torch.Size([1008, 2]), torch.Size([252, 2]), torch.Size([66, 2])]

        # Concatenate along the row dimension
        expanded_object_sizes_of_interest = torch.cat(
            expanded_object_sizes_of_interest, dim=0)
        # Number of points in each layer of feature map
        num_points_per_level = [len(points_per_level)
                                for points_per_level in points]
        self.num_points_per_level = num_points_per_level
        points_all_level = torch.cat(points, dim=0)

        # Box allocation strategy from the paper:
        # according to the preset size ranges, boxes of different sizes are assigned to feature maps of different levels
        # labels: the label of each point across the 5 feature maps, size: [torch.Size([21486])]
        # reg_targets: the four box coordinates for each point across the 5 feature maps, size: [torch.Size([21486, 4])]
        labels, reg_targets = self.compute_targets_for_locations(
            points_all_level, targets, expanded_object_sizes_of_interest
        )

        # batch_size is len(labels)
        for i in range(len(labels)):
            labels[i] = torch.split(labels[i], num_points_per_level, dim=0)
            # Split labels[i] into [torch.Size([16128]), torch.Size([4032]), torch.Size([1008]), torch.Size([252]), torch.Size([66])],
            # corresponding to the five feature maps
            # reg_targets is split the same way
            reg_targets[i] = torch.split(
                reg_targets[i], num_points_per_level, dim=0)

        labels_level_first = []
        reg_targets_level_first = []
        # Combine the labels and reg_targets of the same level from multiple images into one tensor
        for level in range(len(points)):
            # Concatenate all points of all images in the batch along the row dimension
            # and append to labels_level_first
            labels_level_first.append(
                torch.cat([labels_per_im[level]
                           for labels_per_im in labels], dim=0)
            )

            # Concatenate all points of all images in the batch along the row dimension
            reg_targets_per_level = torch.cat([
                reg_targets_per_im[level]
                for reg_targets_per_im in reg_targets
            ], dim=0)

            if self.norm_reg_targets:
                # self.fpn_strides: [8, 16, 32, 64, 128]
                reg_targets_per_level = reg_targets_per_level / \
                    self.fpn_strides[level]
            # Append to reg_targets_level_first
            reg_targets_level_first.append(reg_targets_per_level)
        # labels_level_first[0]: the labels of the same level for all images in the batch
        # reg_targets_level_first: the box coordinates of the same level for all images in the batch
        return labels_level_first, reg_targets_level_first
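        # Illustrative note (an addition for this post, not in the original file):
        # with a batch of 2 images, labels starts as [labels_im0, labels_im1],
        # each covering all points of all levels. torch.split turns each entry
        # into a 5-tuple (one chunk per FPN level), and the loop above regroups
        # them level-first, e.g.
        #   labels_level_first[0] = cat([labels_im0_level0, labels_im1_level0])
        # so each FPN level can be matched against the corresponding head output.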

    def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest):
        labels = []
        reg_targets = []
        xs, ys = locations[:, 0], locations[:, 1]

        for im_i in range(len(targets)):
            targets_per_im = targets[im_i]
            assert targets_per_im.mode == "xyxy"
            bboxes = targets_per_im.bbox
            labels_per_im = targets_per_im.get_field("labels")
            area = targets_per_im.area()

            l = xs[:, None] - bboxes[:, 0][None]
            t = ys[:, None] - bboxes[:, 1][None]
            r = bboxes[:, 2][None] - xs[:, None]
            b = bboxes[:, 3][None] - ys[:, None]
            reg_targets_per_im = torch.stack([l, t, r, b], dim=2)

            if self.center_sampling_radius > 0:
                is_in_boxes = self.get_sample_region(
                    bboxes,
                    self.fpn_strides,
                    self.num_points_per_level,
                    xs, ys,
                    radius=self.center_sampling_radius
                )
            else:
                # no center sampling, it will use all the locations within a ground-truth box
                is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0

            max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
            # limit the regression range for each location
            is_cared_in_the_level = \
                (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \
                (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]])

            locations_to_gt_area = area[None].repeat(len(locations), 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_the_level == 0] = INF

            # if there are still more than one objects for a location,
            # we choose the one with minimal area
            locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(
                dim=1)

            reg_targets_per_im = reg_targets_per_im[range(
                len(locations)), locations_to_gt_inds]
            labels_per_im = labels_per_im[locations_to_gt_inds]
            labels_per_im[locations_to_min_area == INF] = 0

            labels.append(labels_per_im)
            reg_targets.append(reg_targets_per_im)

        return labels, reg_targets
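        # Illustrative note (an addition for this post, not in the original file):
        # if a location survives both the center-sampling mask and the size-range
        # filter for more than one ground-truth box, the .min(dim=1) above assigns
        # it to the box with the smallest area. E.g. for one location with
        # candidate areas [INF, 5000., 20000.], the minimum is 5000., so
        # locations_to_gt_inds picks index 1 and that location regresses to the
        # smaller object; locations matching nothing (min area == INF) get label 0.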

    def compute_centerness_targets(self, reg_targets):
        left_right = reg_targets[:, [0, 2]]
        top_bottom = reg_targets[:, [1, 3]]
        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
            (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness)
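        # Note (an addition for this post): this implements the centerness target
        # from the FCOS paper,
        #   centerness = sqrt( min(l, r) / max(l, r) * min(t, b) / max(t, b) )
        # A location at the exact box center (l == r, t == b) gets 1.0, while a
        # location near a box edge (one distance close to 0) gets close to 0.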

    def __call__(self, locations, box_cls, box_regression, centerness, targets):
        """
        Arguments:
            locations (list[Tensor])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            centerness (list[Tensor])
            targets (list[BoxList])

        Returns:
            cls_loss (Tensor)
            reg_loss (Tensor)
            centerness_loss (Tensor)
        """
        N = box_cls[0].size(0)
        num_classes = box_cls[0].size(1)

        # prepare_targets constructs the ground truth for categories and box coordinates from locations and targets
        # Dimensions of the elements in locations:
        # [torch.Size([15200, 2]), torch.Size([3800, 2]), torch.Size([950, 2]), torch.Size([247, 2]), torch.Size([70, 2])]
        # torch.Size([15200, 2]): the level-0 feature map has 15200 points, and the values are the original image coordinates of each point
        # targets
        # [BoxList(num_boxes=3, image_width=1201, image_height=800, mode=xyxy)]
        # labels
        # [torch.Size([15200]), torch.Size([3800]), torch.Size([950]), torch.Size([247]), torch.Size([70])]
        # reg_targets
        # [torch.Size([15200, 4]), torch.Size([3800, 4]), torch.Size([950, 4]), torch.Size([247, 4]), torch.Size([70, 4])]
        labels, reg_targets = self.prepare_targets(locations, targets)

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []

        # Reshape the predicted boxes and cls scores, and the GT boxes and labels, for each level's feature map
        for l in range(len(labels)):
            box_cls_flatten.append(box_cls[l].permute(
                0, 2, 3, 1).reshape(-1, num_classes))
            box_regression_flatten.append(
                box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[l].reshape(-1))
            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        # Concatenate along the row dimension
        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)
        # Dimensions (20267 = 15200 + 3800 + 950 + 247 + 70, i.e. all points of all levels concatenated):
        # box_cls_flatten: torch.Size([20267, 80])
        # box_regression_flatten: torch.Size([20267, 4])
        # labels_flatten: torch.Size([20267])
        # reg_targets_flatten: torch.Size([20267, 4])
        # centerness_flatten: torch.Size([20267])

        # Extract the indices of positive samples (points whose label > 0)
        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
        # pos_inds : tensor([ 7971,  7972,  7973,  8123,  8124,  8125,  8275,  8276,  8277, 17133,
        # 17134, 17135, 20057, 20058, 20059, 20068, 20069, 20070, 20076, 20077,
        # 20078, 20087, 20088, 20089, 20095, 20096, 20097, 20106, 20107, 20108,
        # 20243, 20244], device='cuda:0')

        # Extracting positive samples according to pos_inds
        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        num_gpus = get_num_gpus()
        # sync num_pos from all gpus
        total_num_pos = reduce_sum(
            pos_inds.new_tensor([pos_inds.numel()])).item()
        num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)
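        # Illustrative example (an addition for this post, not in the original
        # file): with 4 GPUs and 32 positive locations on each, total_num_pos
        # is 128 and num_pos_avg_per_gpu is 32.0, so the focal loss below is
        # normalized by the average number of positives per GPU rather than the
        # local count.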

        # Category loss: SigmoidFocalLoss
        cls_loss = self.cls_loss_func(
            box_cls_flatten,
            labels_flatten.int()
        ) / num_pos_avg_per_gpu

        if pos_inds.numel() > 0:
            # Compute the centerness targets from the regression targets
            centerness_targets = self.compute_centerness_targets(
                reg_targets_flatten)

            # average sum_centerness_targets from all gpus,
            # which is used to normalize centerness-weighed reg loss
            sum_centerness_targets_avg_per_gpu = \
                reduce_sum(centerness_targets.sum()).item() / float(num_gpus)

            # Compute the box regression loss, weighted by the centerness targets
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten,
                reg_targets_flatten,
                centerness_targets
            ) / sum_centerness_targets_avg_per_gpu
            # Compute the centerness loss
            centerness_loss = self.centerness_loss_func(
                centerness_flatten,
                centerness_targets
            ) / num_pos_avg_per_gpu
        else:
            # If there are no positive samples (e.g. the images have no boxes)
            reg_loss = box_regression_flatten.sum()
            reduce_sum(centerness_flatten.new_tensor([0.0]))
            centerness_loss = centerness_flatten.sum()

        return cls_loss, reg_loss, centerness_loss


def make_fcos_loss_evaluator(cfg):
    loss_evaluator = FCOSLossComputation(cfg)
    return loss_evaluator
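
To close the walkthrough, here is a rough usage sketch (an assumption for illustration, not code from the repository): the evaluator is built once from the config and then called with the per-level locations, the FCOS head outputs, and the ground-truth targets.

# Hypothetical usage, assuming a maskrcnn-benchmark style cfg object and
# FCOS head outputs already computed elsewhere.
loss_evaluator = make_fcos_loss_evaluator(cfg)
cls_loss, reg_loss, centerness_loss = loss_evaluator(
    locations,       # list of (H_l * W_l, 2) point coordinates, one tensor per FPN level
    box_cls,         # list of (N, num_classes, H_l, W_l) classification logits
    box_regression,  # list of (N, 4, H_l, W_l) regression predictions
    centerness,      # list of (N, 1, H_l, W_l) centerness logits
    targets,         # list of BoxList, one per image
)
total_loss = cls_loss + reg_loss + centerness_loss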
