Skip to content

Commit

Permalink
AC: Add detection labels shift postprocessing for Efficientdet (openv…
Browse files Browse the repository at this point in the history
…inotoolkit#1746)

* support efficientdet

* update postprocessor

* remove default value
  • Loading branch information
Julia Kamelina authored Nov 5, 2020
1 parent 72b32b1 commit 0476ba9
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,5 @@ Accuracy Checker supports following set of postprocessors:
* `size` - size of model input for recovering YCrCb image.
* `dst_width` and `dst_height` - width and height of model input respectively for recovering YCrCb image.
* `argmax_segmentation_mask` - translates categorical annotation segmentation mask to numerical. Supported representations: `SegmentationAnnotation`, `SegmentationPrediction`.
* `shift_labels` - shifts predicted detection labels. Supported representation: `DetectionPrediction`.
* `offset` - value for shift.
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from .correct_yolo_v2_boxes import CorrectYoloV2Boxes
from .resize_segmentation_mask import ResizeSegmentationMask
from .encode_segmentation_mask import EncodeSegMask
from .shift import Shift
from .shift import Shift, ShiftLabels
from .normalize_landmarks_points import NormalizeLandmarksPoints
from .clip_points import ClipPoints
from .extend_segmentation_mask import ExtendSegmentationMask
Expand Down Expand Up @@ -84,6 +84,7 @@
'ResizeSegmentationMask',
'EncodeSegMask',
'Shift',
'ShiftLabels',
'ExtendSegmentationMask',
'ZoomSegMask',
'CropSegmentationMask',
Expand Down
32 changes: 30 additions & 2 deletions tools/accuracy_checker/accuracy_checker/postprocessor/shift.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@

import numpy as np
from ..config import NumberField
from .postprocessor import PostprocessorWithSpecificTargets
from ..representation import SegmentationAnnotation, SegmentationPrediction
from .postprocessor import PostprocessorWithSpecificTargets, Postprocessor
from ..representation import SegmentationAnnotation, SegmentationPrediction, DetectionPrediction, DetectionAnnotation


class Shift(PostprocessorWithSpecificTargets):
Expand Down Expand Up @@ -55,3 +55,31 @@ def process_image(self, annotation, prediction):
prediction_.mask = update_mask.astype(np.int16)

return annotation, prediction


class ShiftLabels(Postprocessor):
"""
Shift predicted detection labels.
"""

__provider__ = 'shift_labels'

prediction_types = (DetectionPrediction, )
annotation_types = (DetectionAnnotation, )

@classmethod
def parameters(cls):
parameters = super().parameters()
parameters.update({
'offset': NumberField(value_type=int, optional=False, description="Value for shift.")
})
return parameters

def configure(self):
self.offset = self.get_value_from_config('offset')

def process_image(self, annotation, prediction):
for prediction_ in prediction:
prediction_.labels += self.offset

return annotation, prediction

0 comments on commit 0476ba9

Please sign in to comment.