From 5a5a403fd9ae515b5703496be5cd57431edd2c2a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 13 Nov 2023 11:24:05 -0800 Subject: [PATCH] Add inference function for darknet image classification model PiperOrigin-RevId: 582028410 --- .../yolo/dataloaders/classification_input.py | 18 ++++++++++++++++++ .../yolo/serving/export_module_factory.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/official/projects/yolo/dataloaders/classification_input.py b/official/projects/yolo/dataloaders/classification_input.py index 39d9f784eb7..88dca0647e0 100644 --- a/official/projects/yolo/dataloaders/classification_input.py +++ b/official/projects/yolo/dataloaders/classification_input.py @@ -13,6 +13,8 @@ # limitations under the License. """Classification decoder and parser.""" +from typing import List + import tensorflow as tf, tf_keras from official.vision.dataloaders import classification_input from official.vision.ops import preprocess_ops @@ -90,3 +92,19 @@ def _parse_eval_image(self, decoded_tensors): image = tf.image.convert_image_dtype(image, self._dtype) image = image / 255.0 return image + + @classmethod + def inference_fn( + cls, image: tf.Tensor, input_image_size: List[int], num_channels: int = 3 + ) -> tf.Tensor: + """Builds image model inputs for serving.""" + + image = tf.cast(image, dtype=tf.float32) + image = preprocess_ops.center_crop_image(image) + image = tf.image.resize( + image, input_image_size, method=tf.image.ResizeMethod.BILINEAR + ) + + image.set_shape(input_image_size + [num_channels]) + image = image / 255.0 + return image diff --git a/official/projects/yolo/serving/export_module_factory.py b/official/projects/yolo/serving/export_module_factory.py index dc97b13fac8..74a35202694 100644 --- a/official/projects/yolo/serving/export_module_factory.py +++ b/official/projects/yolo/serving/export_module_factory.py @@ -23,11 +23,11 @@ from official.projects.yolo.configs import darknet_classification from official.projects.yolo.configs import yolo from official.projects.yolo.configs import yolov7 +from official.projects.yolo.dataloaders import classification_input from official.projects.yolo.modeling import factory as yolo_factory from official.projects.yolo.modeling.backbones import darknet # pylint: disable=unused-import from official.projects.yolo.modeling.decoders import yolo_decoder # pylint: disable=unused-import from official.projects.yolo.serving import model_fn as yolo_model_fn -from official.vision.dataloaders import classification_input from official.vision.modeling import factory from official.vision.serving import export_utils