From aa6d38e5fee94c1870bedfb8ddf3adb2923bef4f Mon Sep 17 00:00:00 2001
From: Anna Mironova
Date: Wed, 21 Oct 2020 10:32:10 +0300
Subject: [PATCH] AC: Added custom evaluator for two-stream I3D model (#1674)

* Updated converter

* Added custom evaluator for two-stream I3D model

* Existing model has been updated in accordance with the changes

* Fix pylint

* Fix comments in converter, revert changes in dataset_definitions.yml

* Fix comments in evaluator
---
 models/public/i3d-rgb-tf/accuracy-check.yml  |   2 +-
 .../action_recognition.py                    | 191 +++++---
 .../evaluators/custom_evaluators/README.md   |  12 +-
 .../custom_evaluators/i3d_evaluator.py       | 413 ++++++++++++++++++
 .../metrics/classification.py                |   5 +-
 5 files changed, 556 insertions(+), 67 deletions(-)
 create mode 100644 tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/i3d_evaluator.py

diff --git a/models/public/i3d-rgb-tf/accuracy-check.yml b/models/public/i3d-rgb-tf/accuracy-check.yml
index 23e25f3a3e9..d042ddaa715 100644
--- a/models/public/i3d-rgb-tf/accuracy-check.yml
+++ b/models/public/i3d-rgb-tf/accuracy-check.yml
@@ -1,5 +1,5 @@
 models:
-  - name: i3d-rgb-tf
+  - name: i3d-rgb
     launchers:
       - framework: dlsdk
         adapter: classification
diff --git a/tools/accuracy_checker/accuracy_checker/annotation_converters/action_recognition.py b/tools/accuracy_checker/accuracy_checker/annotation_converters/action_recognition.py
index a1ad0154f7c..2bccb08ae7a 100644
--- a/tools/accuracy_checker/accuracy_checker/annotation_converters/action_recognition.py
+++ b/tools/accuracy_checker/accuracy_checker/annotation_converters/action_recognition.py
@@ -15,11 +15,12 @@
 """
 
 from collections import OrderedDict
+import warnings
 
 from ..utils import read_json, read_txt, check_file_existence
 from ..representation import ClassificationAnnotation
 from ..data_readers import ClipIdentifier
-from ..config import PathField, NumberField, StringField, BoolField
+from ..config import PathField, NumberField, StringField, BoolField, ConfigError
 
 from .format_converter import BaseFormatConverter, ConverterReturn, verify_label_map
 
@@ -50,7 +51,11 @@ def parameters(cls):
             'dataset_meta_file': PathField(
                 description='path to json file with dataset meta (e.g. 
label_map)', optional=True ), - 'numpy_input': BoolField(description='use numpy arrays instead of images', optional=True, default=False), + 'numpy_input': BoolField(description='use numpy arrays as input', optional=True, default=False), + 'two_stream_input': BoolField(description='use two streams: images and numpy arrays as input', + optional=True, default=False), + 'image_subpath': StringField(description="sub-directory for images", optional=True), + 'numpy_subpath': StringField(description="sub-directory for numpy arrays", optional=True), 'num_samples': NumberField( description='number of samples used for annotation', optional=True, value_type=int, min_value=1 ) @@ -67,11 +72,38 @@ def configure(self): self.subset = self.get_value_from_config('subset') self.dataset_meta = self.get_value_from_config('dataset_meta_file') self.numpy_input = self.get_value_from_config('numpy_input') + self.two_stream_input = self.get_value_from_config('two_stream_input') + self.numpy_subdir = self.get_value_from_config('numpy_subpath') + self.image_subdir = self.get_value_from_config('image_subpath') self.num_samples = self.get_value_from_config('num_samples') + if self.numpy_subdir and (self.numpy_input or + self.two_stream_input) and not (self.data_dir / self.numpy_subdir).exists(): + raise ConfigError('Please check numpy_subpath or data_dir. ' + 'Path {} does not exist'.format(self.data_dir / self.numpy_subdir)) + + if self.image_subdir and (not self.numpy_input or + self.two_stream_input) and not (self.data_dir / self.image_subdir).exists(): + raise ConfigError('Please check image_subpath or data_dir. ' + 'Path {} does not exist'.format(self.data_dir / self.image_subdir)) + + if self.two_stream_input: + if not self.numpy_subdir: + raise ConfigError('numpy_subpath should be provided in case of using two streams') + + if not self.image_subdir: + raise ConfigError('image_subpath should be provided in case of using two streams') + else: + if self.numpy_input and self.numpy_subdir: + warnings.warn("numpy_subpath is provided. " + "Make sure that data_source is {}".format(self.data_dir / self.numpy_subdir)) + if not self.numpy_input and self.image_subdir: + warnings.warn("image_subpath is provided. 
" + "Make sure that data_source is {}".format(self.data_dir / self.image_subdir)) + def convert(self, check_content=False, progress_callback=None, progress_interval=100, **kwargs): full_annotation = read_json(self.annotation_file, object_pairs_hook=OrderedDict) - data_ext = 'jpg' if not self.numpy_input else 'npy' + data_ext, data_dir = self.get_ext_and_dir() label_map = dict(enumerate(full_annotation['labels'])) if self.dataset_meta: dataset_meta = read_json(self.dataset_meta) @@ -83,41 +115,13 @@ def convert(self, check_content=False, progress_callback=None, progress_interval video_names, annotations = self.get_video_names_and_annotations(full_annotation['database'], self.subset) class_to_idx = {v: k for k, v in label_map.items()} - videos = [] - for video_name, annotation in zip(video_names, annotations): - video_path = self.data_dir / video_name - if not video_path.exists(): - continue - - n_frames_file = video_path / 'n_frames' - n_frames = ( - int(read_txt(n_frames_file)[0].rstrip('\n\r')) if n_frames_file.exists() - else len(list(video_path.glob('*.{}'.format(data_ext)))) - ) - if n_frames <= 0: - continue - - begin_t = 1 - end_t = n_frames - sample = { - 'video': video_path, - 'video_name': video_name, - 'segment': [begin_t, end_t], - 'n_frames': n_frames, - 'video_id': video_name, - 'label': class_to_idx[annotation['label']] - } - - videos.append(sample) - if self.num_samples and len(videos) == self.num_samples: - break - + videos = self.get_videos(video_names, annotations, class_to_idx, data_dir, data_ext) videos = sorted(videos, key=lambda v: v['video_id'].split('/')[-1]) clips = [] for video in videos: - for clip in self.get_clips(video, self.clips_per_video, self.clip_duration, self.temporal_stride, data_ext): - clips.append(clip) + clips.extend(self.get_clips(video, self.clips_per_video, + self.clip_duration, self.temporal_stride, data_ext)) annotations = [] num_iterations = len(clips) @@ -125,43 +129,108 @@ def convert(self, check_content=False, progress_callback=None, progress_interval for clip_idx, clip in enumerate(clips): if progress_callback is not None and clip_idx % progress_interval: progress_callback(clip_idx * 100 / num_iterations) - identifier = ClipIdentifier(clip['video_name'], clip_idx, clip['frames']) + identifier = [] + for ext in data_ext: + identifier.append(ClipIdentifier(clip['video_name'], clip_idx, clip['frames_{}'.format(ext)])) if check_content: - content_errors.extend([ - '{}: does not exist'.format(self.data_dir / frame) - for frame in clip['frames'] if not check_file_existence(self.data_dir / frame) - ]) + for ext, dir_ in zip(data_ext, data_dir): + content_errors.extend([ + '{}: does not exist'.format(dir_ / frame) + for frame in clip['frames_{}'.format(ext)] if not check_file_existence(dir_ / frame) + ]) + if len(identifier) == 1: + identifier = identifier[0] + annotations.append(ClassificationAnnotation(identifier, clip['label'])) return ConverterReturn(annotations, {'label_map': label_map}, content_errors) - @staticmethod - def get_clips(video, clips_per_video, clip_duration, temporal_stride=1, file_ext='jpg'): - shift = int(file_ext == 'npy') - num_frames = video['n_frames'] - shift - clip_duration *= temporal_stride - - if clips_per_video == 0: - step = clip_duration - else: - step = max(1, (num_frames - clip_duration) // (clips_per_video - 1)) + def get_ext_and_dir(self): + if self.two_stream_input: + return ['jpg', 'npy'], [self.data_dir / self.image_subdir, self.data_dir / self.numpy_subdir] - for clip_start in range(1, 1 + clips_per_video 
* step, step): - clip_end = min(clip_start + clip_duration, num_frames + 1) + if self.numpy_input: + return ['npy'], [self.data_dir / self.numpy_subdir if self.numpy_subdir else self.data_dir] - clip_idxs = list(range(clip_start, clip_end)) + return ['jpg'], [self.data_dir / self.image_subdir if self.image_subdir else self.data_dir] - if not clip_idxs: - return + def get_videos(self, video_names, annotations, class_to_idx, data_dir, data_ext): + videos = [] + for video_name, annotation in zip(video_names, annotations): + video_info = { + 'video_name': video_name, + 'video_id': video_name, + 'label': class_to_idx[annotation['label']] + } + for dir_, ext in zip(data_dir, data_ext): + video_path = dir_ / video_name + if not video_path.exists(): + video_info.clear() + continue + + n_frames_file = video_path / 'n_frames' + n_frames = ( + int(read_txt(n_frames_file)[0].rstrip('\n\r')) if n_frames_file.exists() + else len(list(video_path.glob('*.{}'.format(ext)))) + ) + if n_frames <= 0: + video_info.clear() + continue + + begin_t = 1 + end_t = n_frames + sample = { + 'video_{}'.format(ext): video_path, + 'segment_{}'.format(ext): [begin_t, end_t], + 'n_frames_{}'.format(ext): n_frames, + } + video_info.update(sample) + + if video_info: + videos.append(video_info) + if self.num_samples and len(videos) == self.num_samples: + break + return videos - # loop clip if it is shorter than clip_duration - while len(clip_idxs) < clip_duration: - clip_idxs = (clip_idxs * 2)[:clip_duration] + @staticmethod + def get_clips(video, clips_per_video, clip_duration, temporal_stride=1, file_ext='jpg'): + clip_duration *= temporal_stride + frames_ext = {} + for ext in file_ext: + frames = [] + shift = int(ext == 'npy') + num_frames = video['n_frames_{}'.format(ext)] - shift + + if clips_per_video == 0: + step = clip_duration + else: + step = max(1, (num_frames - clip_duration) // (clips_per_video - 1)) + for clip_start in range(1, 1 + clips_per_video * step, step): + clip_end = min(clip_start + clip_duration, num_frames + 1) + + clip_idxs = list(range(clip_start, clip_end)) + + if not clip_idxs: + return [] + + # loop clip if it is shorter than clip_duration + while len(clip_idxs) < clip_duration: + clip_idxs = (clip_idxs * 2)[:clip_duration] + + frames_idx = clip_idxs[::temporal_stride] + frames.append(['image_{:05d}.{}'.format(frame_idx, ext) for frame_idx in frames_idx]) + frames_ext.update({ + ext: frames + }) - clip = dict(video) - frames_idx = clip_idxs[::temporal_stride] - clip['frames'] = ['image_{:05d}.{}'.format(frame_idx, file_ext) for frame_idx in frames_idx] - yield clip + clips = [] + for key, value in frames_ext.items(): + if not clips: + for _ in range(len(value)): + clips.append(dict(video)) + for val, clip in zip(value, clips): + clip['frames_{}'.format(key)] = val + return clips @staticmethod def get_video_names_and_annotations(data, subset): diff --git a/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/README.md b/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/README.md index 50e576a6bc6..1cd8358a901 100644 --- a/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/README.md +++ b/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/README.md @@ -35,7 +35,11 @@ Optionally you can provide `module_config` section which contains config for cus Configuration file example: text-spotting-0002. * **Automatic Speech Recognition Evaluator** shows how to evaluate speech recognition pipeline (encoder + decoder). -Evaluator code. 
-* **Im2latex formula recognition** demonstrates how to run encoder-decoder model for extractring latex formula from image
- Evaluator code
- Configuration file example: im2latex-medium-0002
+ Evaluator code.
+
+* **Im2latex formula recognition** demonstrates how to run an encoder-decoder model for extracting a LaTeX formula from an image.
+ Evaluator code.
+ Configuration file example: im2latex-medium-0002.
+
+* **I3D Evaluator** demonstrates how to evaluate a two-stream I3D model (RGB + Flow).
+ Evaluator code.
diff --git a/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/i3d_evaluator.py b/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/i3d_evaluator.py
new file mode 100644
index 00000000000..a6ad046c2f3
--- /dev/null
+++ b/tools/accuracy_checker/accuracy_checker/evaluators/custom_evaluators/i3d_evaluator.py
@@ -0,0 +1,413 @@
+"""
+Copyright (c) 2018-2020 Intel Corporation
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from pathlib import Path
+from collections import OrderedDict
+import warnings
+import numpy as np
+
+from ..base_evaluator import BaseEvaluator
+from ..quantization_model_evaluator import create_dataset_attributes
+from ...adapters import create_adapter
+from ...config import ConfigError
+from ...launcher import create_launcher
+from ...data_readers import create_reader
+from ...utils import extract_image_representations, contains_all, get_path
+from ...progress_reporters import ProgressReporter
+from ...logging import print_info
+from ...preprocessor import Crop, Resize
+
+
+class I3DEvaluator(BaseEvaluator):
+    def __init__(self, dataset_config, launcher, adapter, rgb_model, flow_model):
+        self.dataset_config = dataset_config
+        self.preprocessor = None
+        self.dataset = None
+        self.postprocessor = None
+        self.metric_executor = None
+        self.launcher = launcher
+        self.adapter = adapter
+        self.rgb_model = rgb_model
+        self.flow_model = flow_model
+        self._metrics_results = []
+
+    @classmethod
+    def from_configs(cls, config, delayed_model_loading=False):
+        dataset_config = config['datasets']
+        launcher_settings = config['launchers'][0]
+        supported_frameworks = ['dlsdk']
+        if not launcher_settings['framework'] in supported_frameworks:
+            raise ConfigError('{} framework not supported'.format(launcher_settings['framework']))
+        if 'device' not in launcher_settings:
+            launcher_settings['device'] = 'CPU'
+        launcher = create_launcher(launcher_settings, delayed_model_loading=True)
+        adapter = create_adapter(launcher_settings['adapter'])
+        network_info = config.get('network_info', {})
+        data_source = dataset_config[0].get('data_source', None)
+        if not delayed_model_loading:
+            flow_network = network_info.get('flow', {})
+            rgb_network = network_info.get('rgb', {})
+            model_args = config.get('_models', [])
+            models_is_blob = config.get('_model_is_blob')
+            if 'model' not in flow_network and model_args:
+                flow_network['model'] = model_args[0]
+                flow_network['_model_is_blob'] = models_is_blob
+            if 'model' not in rgb_network and model_args:
+                rgb_network['model'] = model_args[1 if len(model_args) > 1 else 0]
+                rgb_network['_model_is_blob'] = models_is_blob
+            network_info.update({
+                'flow': flow_network,
+                'rgb': rgb_network
+            })
+            if not contains_all(network_info, ['flow', 'rgb']):
+                raise ConfigError('configuration for flow/rgb does not exist')
+
+        flow_model = I3DFlowModel(
+            network_info.get('flow', {}), launcher, data_source, delayed_model_loading
+        )
+        rgb_model = I3DRGBModel(
+            network_info.get('rgb', {}), launcher, data_source, delayed_model_loading
+        )
+        if rgb_model.output_blob != flow_model.output_blob:
+            warnings.warn("Outputs for rgb and flow models have different names. "
+                          "rgb model's output name: {}. flow model's output name: {}. Output name of rgb model "
+                          "will be used in combined output".format(rgb_model.output_blob, flow_model.output_blob))
+        adapter.output_blob = rgb_model.output_blob
+        return cls(dataset_config, launcher, adapter, rgb_model, flow_model)
+
+    @staticmethod
+    def get_dataset_info(dataset):
+        annotation = dataset.annotation_reader.annotation
+        identifiers = dataset.annotation_reader.identifiers
+
+        return annotation, identifiers
+
+    @staticmethod
+    def combine_predictions(output_rgb, output_flow):
+        output = {}
+        for key_rgb, key_flow in zip(output_rgb.keys(), output_flow.keys()):
+            data_rgb = np.asarray(output_rgb[key_rgb])
+            data_flow = np.asarray(output_flow[key_flow])
+
+            if data_rgb.shape != data_flow.shape:
+                raise ValueError("Calculation of combined output is not possible. Outputs for rgb and flow models have "
+                                 "different shapes. rgb model's output shape: {}. "
+                                 "flow model's output shape: {}.".format(data_rgb.shape, data_flow.shape))
+
+            result_data = (data_rgb + data_flow) / 2
+            output[key_rgb] = result_data
+
+        return output
+
+    def process_dataset(
+            self, subset=None,
+            num_images=None,
+            check_progress=False,
+            dataset_tag='',
+            allow_pairwise_subset=False,
+            **kwargs):
+
+        if self.dataset is None or (dataset_tag and self.dataset.tag != dataset_tag):
+            self.select_dataset(dataset_tag)
+        self._create_subset(subset, num_images, allow_pairwise_subset)
+
+        self._annotations, self._predictions = [], []
+
+        if 'progress_reporter' in kwargs:
+            _progress_reporter = kwargs['progress_reporter']
+            _progress_reporter.reset(self.dataset.size)
+        else:
+            _progress_reporter = None if not check_progress else self._create_progress_reporter(
+                check_progress, self.dataset.size
+            )
+
+        compute_intermediate_metric_res = kwargs.get('intermediate_metrics_results', False)
+        if compute_intermediate_metric_res:
+            metric_interval = kwargs.get('metrics_interval', 1000)
+            ignore_results_formatting = kwargs.get('ignore_results_formatting', False)
+
+        annotation, identifiers = self.get_dataset_info(self.dataset)
+        for batch_id, (batch_annotation, batch_identifiers) in enumerate(zip(annotation, identifiers)):
+            batch_inputs_images = self.rgb_model.prepare_data(batch_identifiers)
+            batch_inputs_flow = self.flow_model.prepare_data(batch_identifiers)
+
+            extr_batch_inputs_images, _ = extract_image_representations([batch_inputs_images])
+            extr_batch_inputs_flow, _ = extract_image_representations([batch_inputs_flow])
+
+            batch_raw_prediction_rgb = self.rgb_model.predict(extr_batch_inputs_images)
+            batch_raw_prediction_flow = self.flow_model.predict(extr_batch_inputs_flow)
+            batch_raw_out = self.combine_predictions(batch_raw_prediction_rgb, batch_raw_prediction_flow)
+
+            batch_prediction = self.adapter.process([batch_raw_out], identifiers, [{}])
+
+            if self.metric_executor.need_store_predictions:
+                self._annotations.extend([batch_annotation])
+
self._predictions.extend(batch_prediction) + + if _progress_reporter: + _progress_reporter.update(batch_id, len(batch_prediction)) + if compute_intermediate_metric_res and _progress_reporter.current % metric_interval == 0: + self.compute_metrics( + print_results=True, ignore_results_formatting=ignore_results_formatting + ) + + if _progress_reporter: + _progress_reporter.finish() + + def compute_metrics(self, print_results=True, ignore_results_formatting=False): + if self._metrics_results: + del self._metrics_results + self._metrics_results = [] + + for result_presenter, evaluated_metric in self.metric_executor.iterate_metrics( + self._annotations, self._predictions + ): + self._metrics_results.append(evaluated_metric) + if print_results: + result_presenter.write_result(evaluated_metric, ignore_results_formatting) + + return self._metrics_results + + def print_metrics_results(self, ignore_results_formatting=False): + if not self._metrics_results: + self.compute_metrics(True, ignore_results_formatting) + return + result_presenters = self.metric_executor.get_metric_presenters() + for presenter, metric_result in zip(result_presenters, self._metrics_results): + presenter.write_results(metric_result, ignore_results_formatting) + + def extract_metrics_results(self, print_results=True, ignore_results_formatting=False): + if not self._metrics_results: + self.compute_metrics(False, ignore_results_formatting) + + result_presenters = self.metric_executor.get_metric_presenters() + extracted_results, extracted_meta = [], [] + for presenter, metric_result in zip(result_presenters, self._metrics_results): + result, metadata = presenter.extract_result(metric_result) + if isinstance(result, list): + extracted_results.extend(result) + extracted_meta.extend(metadata) + else: + extracted_results.append(result) + extracted_meta.append(metadata) + if print_results: + presenter.write_result(metric_result, ignore_results_formatting) + + return extracted_results, extracted_meta + + def release(self): + self.rgb_model.release() + self.flow_model.release() + self.launcher.release() + + def reset(self): + if self.metric_executor: + self.metric_executor.reset() + if hasattr(self, '_annotations'): + del self._annotations + del self._predictions + del self._metrics_results + self._annotations = [] + self._predictions = [] + self._metrics_results = [] + if self.dataset: + self.dataset.reset(self.postprocessor.has_processors) + + @staticmethod + def get_processing_info(config): + module_specific_params = config.get('module_config') + model_name = config['name'] + dataset_config = module_specific_params['datasets'][0] + launcher_config = module_specific_params['launchers'][0] + return ( + model_name, launcher_config['framework'], launcher_config['device'], launcher_config.get('tags'), + dataset_config['name'] + ) + + def select_dataset(self, dataset_tag): + if self.dataset is not None and isinstance(self.dataset_config, list): + return + dataset_attributes = create_dataset_attributes(self.dataset_config, dataset_tag) + self.dataset, self.metric_executor, self.preprocessor, self.postprocessor = dataset_attributes + + @staticmethod + def _create_progress_reporter(check_progress, dataset_size): + pr_kwargs = {} + if isinstance(check_progress, int) and not isinstance(check_progress, bool): + pr_kwargs = {"print_interval": check_progress} + + return ProgressReporter.provide('print', dataset_size, **pr_kwargs) + + def _create_subset(self, subset=None, num_images=None, allow_pairwise=False): + if self.dataset.batch is None: + 
self.dataset.batch = 1 + if subset is not None: + self.dataset.make_subset(ids=subset, accept_pairs=allow_pairwise) + elif num_images is not None: + self.dataset.make_subset(end=num_images, accept_pairs=allow_pairwise) + + +class BaseModel: + def __init__(self, network_info, launcher, data_source, delayed_model_loading=False): + self.input_blob = None + self.output_blob = None + self.with_prefix = False + reader_config = network_info.get('reader', {}) + source_prefix = reader_config.get('source_prefix', '') + reader_config.update({ + 'data_source': data_source / source_prefix + }) + self.reader = create_reader(reader_config) + if not delayed_model_loading: + self.load_model(network_info, launcher, log=True) + + @staticmethod + def auto_model_search(network_info, net_type): + model = Path(network_info['model']) + is_blob = network_info.get('_model_is_blob') + if model.is_dir(): + if is_blob: + model_list = list(model.glob('*.blob')) + else: + model_list = list(model.glob('*.xml')) + if not model_list and is_blob is None: + model_list = list(model.glob('*.blob')) + if not model_list: + raise ConfigError('Suitable model not found') + if len(model_list) > 1: + raise ConfigError('Several suitable models found') + model = model_list[0] + print_info('{} - Found model: {}'.format(net_type, model)) + if model.suffix == '.blob': + return model, None + weights = get_path(network_info.get('weights', model.parent / model.name.replace('xml', 'bin'))) + print_info('{} - Found weights: {}'.format(net_type, weights)) + + return model, weights + + def predict(self, input_data): + return self.exec_network.infer(inputs=input_data[0]) + + def release(self): + del self.network + del self.exec_network + + def load_model(self, network_info, launcher, log=False): + model, weights = self.auto_model_search(network_info, self.net_type) + if weights: + self.network = launcher.read_network(str(model), str(weights)) + self.network.batch_size = 1 + self.exec_network = launcher.ie_core.load_network(self.network, launcher.device) + else: + self.network = None + launcher.ie_core.import_network(str(model)) + self.set_input_and_output() + if log: + self.print_input_output_info() + + def set_input_and_output(self): + has_info = hasattr(self.exec_network, 'input_info') + input_info = self.exec_network.input_info if has_info else self.exec_network.inputs + input_blob = next(iter(input_info)) + with_prefix = input_blob.startswith('{}_'.format(self.net_type)) + if self.input_blob is None or with_prefix != self.with_prefix: + if self.input_blob is None: + output_blob = next(iter(self.exec_network.outputs)) + else: + output_blob = ( + '_'.join([self.net_type, self.output_blob]) + if with_prefix else self.output_blob.split('{}_'.format(self.net_type))[-1] + ) + self.input_blob = input_blob + self.output_blob = output_blob + self.with_prefix = with_prefix + + def print_input_output_info(self): + print_info('{} - Input info:'.format(self.net_type)) + has_info = hasattr(self.network if self.network is not None else self.exec_network, 'input_info') + if self.network: + if has_info: + network_inputs = OrderedDict( + [(name, data.input_data) for name, data in self.network.input_info.items()] + ) + else: + network_inputs = self.network.inputs + network_outputs = self.network.outputs + else: + if has_info: + network_inputs = OrderedDict([ + (name, data.input_data) for name, data in self.exec_network.input_info.items() + ]) + else: + network_inputs = self.exec_network.inputs + network_outputs = self.exec_network.outputs + for name, input_info 
in network_inputs.items(): + print_info('\tLayer name: {}'.format(name)) + print_info('\tprecision: {}'.format(input_info.precision)) + print_info('\tshape {}\n'.format(input_info.shape)) + print_info('{} - Output info'.format(self.net_type)) + for name, output_info in network_outputs.items(): + print_info('\tLayer name: {}'.format(name)) + print_info('\tprecision: {}'.format(output_info.precision)) + print_info('\tshape: {}\n'.format(output_info.shape)) + + def fit_to_input(self, input_data): + has_info = hasattr(self.exec_network, 'input_info') + input_info = ( + self.exec_network.input_info[self.input_blob].input_data + if has_info else self.exec_network.inputs[self.input_blob] + ) + input_data = np.array(input_data) + input_data = np.transpose(input_data, (3, 0, 1, 2)) + input_data = np.reshape(input_data, input_info.shape) + return {self.input_blob: input_data} + + def prepare_data(self, data): + pass + + +class I3DRGBModel(BaseModel): + def __init__(self, network_info, launcher, data_source, delayed_model_loading=False): + self.net_type = 'rgb' + super().__init__(network_info, launcher, data_source, delayed_model_loading) + + def prepare_data(self, data): + image_data = data[0] + prepared_data = self.reader(image_data) + prepared_data = self.preprocessing(prepared_data) + prepared_data.data = self.fit_to_input(prepared_data.data) + return prepared_data + + @staticmethod + def preprocessing(image): + resizer_config = {'type': 'resize', 'size': 256, 'aspect_ratio_scale': 'fit_to_window'} + resizer = Resize(resizer_config) + image = resizer.process(image) + for i, frame in enumerate(image.data): + image.data[i] = Crop.process_data(frame, 224, 224, None, False, True, {}) + return image + + +class I3DFlowModel(BaseModel): + def __init__(self, network_info, launcher, data_source, delayed_model_loading=False): + self.net_type = 'flow' + super().__init__(network_info, launcher, data_source, delayed_model_loading) + + def prepare_data(self, data): + numpy_data = data[1] + prepared_data = self.reader(numpy_data) + prepared_data.data = self.fit_to_input(prepared_data.data) + return prepared_data diff --git a/tools/accuracy_checker/accuracy_checker/metrics/classification.py b/tools/accuracy_checker/accuracy_checker/metrics/classification.py index 9641c89a9b6..cc074755e40 100644 --- a/tools/accuracy_checker/accuracy_checker/metrics/classification.py +++ b/tools/accuracy_checker/accuracy_checker/metrics/classification.py @@ -201,7 +201,10 @@ def __init__(self, *args, **kwargs): self.previous_video_label = None def update(self, annotation, prediction): - video_id = annotation.identifier.video + if isinstance(annotation.identifier, list): + video_id = annotation.identifier[0].video + else: + video_id = annotation.identifier.video if self.previous_video_id is not None and video_id != self.previous_video_id: video_top_label = np.argmax(self.video_avg_prob.evaluate())
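---
Note for readers wiring up the new evaluator: the sketch below shows how the pieces introduced in this patch could be put together in an accuracy-checker configuration, following the `module_config` layout described in the custom-evaluators README above. It is illustrative only: the `module` path, reader types, converter name, dataset name and all file paths are placeholders or assumptions, not values taken from this patch.

    evaluations:
      - name: i3d                                              # placeholder entry name
        module: custom_evaluators.i3d_evaluator.I3DEvaluator   # assumed module path
        module_config:
          network_info:
            rgb:
              model: <path to RGB I3D model .xml or .blob>     # placeholder
              reader:
                type: opencv_imread                            # assumption: an image reader for RGB frames
                source_prefix: rgb                             # sub-folder with RGB frames
            flow:
              model: <path to flow I3D model .xml or .blob>    # placeholder
              reader:
                type: numpy_reader                             # assumption: reads the .npy optical-flow inputs
                source_prefix: flow                            # sub-folder with flow arrays
          launchers:
            - framework: dlsdk                                 # the only framework the evaluator accepts
              adapter: classification
          datasets:
            - name: <dataset name>                             # placeholder
              data_source: <root directory with video clips>   # placeholder
              annotation_conversion:
                converter: <action recognition converter>      # the converter extended in this patch
                annotation_file: <path to annotation .json>    # placeholder
                two_stream_input: True
                image_subpath: rgb
                numpy_subpath: flow

At inference time the evaluator averages the raw outputs of the two networks ((rgb + flow) / 2) before the classification adapter runs, which is why both models must produce outputs of the same shape and why a shared output name is preferred.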