Skip to content

Commit

Permalink
Merge pull request MPolaris#12 from MPolaris/dev
Browse files Browse the repository at this point in the history
dev merge request
  • Loading branch information
MPolaris authored Oct 12, 2022
2 parents e6df9cb + d3c0c88 commit 1150d9e
Show file tree
Hide file tree
Showing 9 changed files with 67 additions and 54 deletions.
2 changes: 1 addition & 1 deletion converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def onnx_converter(onnx_model_path:str, output_path:str=None,
input_node_names:list=None, output_node_names:list=None,
need_simplify:bool=True, target_formats:list = ['keras', 'tflite'],
weight_quant:bool=False, int8_model:bool=False, image_root:str=None,
int8_mean:list or float = [0.485, 0.456, 0.406], int8_std:list or float = [0.229, 0.224, 0.225]):
int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375]):
if not isinstance(target_formats, list) and 'keras' not in target_formats and 'tflite' not in target_formats:
raise KeyError("'keras' or 'tflite' should in list")

Expand Down
2 changes: 1 addition & 1 deletion layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from utils import OPERATOR
from .conv_layers import *
from .shape_axis_utils import *
from .dimension_utils import *
from .common_layers import *
from .activations_layers import *
from .calculations_layers import *
Expand Down
4 changes: 2 additions & 2 deletions layers/activations_layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import tensorflow as tf
from tensorflow import keras

from .shape_axis_utils import Torch2TFAxis
from .dimension_utils import channel_to_last_dimension
from . import OPERATOR

@OPERATOR.register_operator("Relu")
Expand Down Expand Up @@ -122,7 +122,7 @@ def __call__(self, inputs):
class TFSoftmax():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
super().__init__()
self.axis = Torch2TFAxis(node_attribute.get('axis', -1))
self.axis = channel_to_last_dimension(node_attribute.get('axis', -1))

def __call__(self, inputs):
return keras.activations.softmax(inputs, axis=self.axis)
Expand Down
16 changes: 8 additions & 8 deletions layers/calculations_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import tensorflow as tf

from . import OPERATOR
from . import shape_axis_utils
from . import dimension_utils

LOG = logging.getLogger("calculations_layers :")

Expand Down Expand Up @@ -31,7 +31,7 @@ def get_number(tensor_grap, node_weights, node_inputs):
for _ in range(len(first_operand.shape) - 2):
second_operand = second_operand[..., np.newaxis]
else:
second_operand = shape_axis_utils.TorchWeights2TF(second_operand)
second_operand = dimension_utils.tensor_NCD_to_NDC_format(second_operand)
elif (not first_operand_flg) and second_operand_flg:
# 当second_operand为计算得出的,first_operand来自weight时
if len(first_operand.shape) == 1:
Expand All @@ -40,7 +40,7 @@ def get_number(tensor_grap, node_weights, node_inputs):
for _ in range(len(second_operand.shape) - 2):
first_operand = first_operand[..., np.newaxis]
else:
first_operand = shape_axis_utils.TorchWeights2TF(first_operand)
first_operand = dimension_utils.tensor_NCD_to_NDC_format(first_operand)

return first_operand, second_operand

Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
super().__init__()
self.keep_dims = node_attribute.get("keepdims", 1) == 1
input_shape_len = len(tensor_grap[node_inputs[0]].shape)
self.axes = [shape_axis_utils.Torch2TFAxis(i) if i >=0 else shape_axis_utils.Torch2TFAxis(input_shape_len + i) for i in node_attribute.get("axes", [-1])]
self.axes = [dimension_utils.channel_to_last_dimension(i) if i >=0 else dimension_utils.channel_to_last_dimension(input_shape_len + i) for i in node_attribute.get("axes", [-1])]

def __call__(self, inputs, *args, **kwargs):
return tf.math.reduce_mean(inputs, axis=self.axes, keepdims=self.keep_dims)
Expand All @@ -159,7 +159,7 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
super().__init__()
self.keep_dims = node_attribute.get("keepdims", 1) == 1
input_shape_len = len(tensor_grap[node_inputs[0]].shape)
self.axes = [shape_axis_utils.Torch2TFAxis(i) if i >=0 else shape_axis_utils.Torch2TFAxis(input_shape_len + i) for i in node_attribute.get("axes", [-1])]
self.axes = [dimension_utils.channel_to_last_dimension(i) if i >=0 else dimension_utils.channel_to_last_dimension(input_shape_len + i) for i in node_attribute.get("axes", [-1])]

def __call__(self, inputs, *args, **kwargs):
return tf.math.reduce_max(inputs, axis=self.axes, keepdims=self.keep_dims)
Expand All @@ -170,7 +170,7 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
super().__init__()
self.keep_dims = node_attribute.get("keepdims", 1) == 1
input_shape_len = len(tensor_grap[node_inputs[0]].shape)
self.axes = [shape_axis_utils.Torch2TFAxis(i) if i >=0 else input_shape_len + i for i in node_attribute.get("axes", [-1])]
self.axes = [dimension_utils.channel_to_last_dimension(i) if i >=0 else input_shape_len + i for i in node_attribute.get("axes", [-1])]

def __call__(self, inputs, *args, **kwargs):
return tf.math.reduce_min(inputs, axis=self.axes, keepdims=self.keep_dims)
Expand All @@ -179,7 +179,7 @@ def __call__(self, inputs, *args, **kwargs):
class TFArgMax():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
super().__init__()
self.axis = shape_axis_utils.Torch2TFAxis(node_attribute.get('axis', 0))
self.axis = dimension_utils.channel_to_last_dimension(node_attribute.get('axis', 0))
self.keepdims = node_attribute.get("keepdims", 1) == 1

def __call__(self, inputs, *args, **kwargs):
Expand All @@ -192,7 +192,7 @@ def __call__(self, inputs, *args, **kwargs):
class TFArgMin():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
super().__init__()
self.axis = shape_axis_utils.Torch2TFAxis(node_attribute.get('axis', 0))
self.axis = dimension_utils.channel_to_last_dimension(node_attribute.get('axis', 0))
self.keepdims = node_attribute.get("keepdims", 1) == 1

def __call__(self, inputs, *args, **kwargs):
Expand Down
22 changes: 11 additions & 11 deletions layers/deformation_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tensorflow as tf

from . import OPERATOR
from . import shape_axis_utils
from . import dimension_utils

LOG = logging.getLogger("deformation_layers :")

Expand All @@ -16,11 +16,11 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
elif len(node_attribute['perm']) > 4:
self.perm_list = []
for axis in node_attribute['perm']:
new_axis = shape_axis_utils.Torch2TFAxis(axis)
new_axis = dimension_utils.channel_to_last_dimension(axis)
if new_axis == -1:
new_axis = max(node_attribute['perm'])
self.perm_list.append(new_axis)
self.perm_list = shape_axis_utils.TorchShape2TF(self.perm_list)
self.perm_list = dimension_utils.shape_NCD_to_NDC_format(self.perm_list)
else:
self.perm_list = [i for i in node_attribute['perm']]
LOG.info("Transpose will process tensor after change back to NCHW format.")
Expand All @@ -44,12 +44,12 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
if len(node_inputs) == 1:
self.starts = node_attribute['starts'][0]
self.ends = node_attribute['ends'][0]
self.axis = shape_axis_utils.Torch2TFAxis(node_attribute['axes'][0])
self.axis = dimension_utils.channel_to_last_dimension(node_attribute['axes'][0])
self.steps = 1
else:
self.starts = node_weights[node_inputs[1]][0] if node_inputs[1] in node_weights else tensor_grap[node_inputs[1]][0]
self.axis = node_weights[node_inputs[3]][0] if node_inputs[3] in node_weights else tensor_grap[node_inputs[3]][0]
self.axis = shape_axis_utils.Torch2TFAxis(self.axis)
self.axis = dimension_utils.channel_to_last_dimension(self.axis)
self.ends = node_weights[node_inputs[2]][0] if node_inputs[2] in node_weights else tensor_grap[node_inputs[2]][0]
self.ends = min(self.ends, tensor_grap[node_inputs[0]].shape[self.axis])
if len(node_inputs) < 5:
Expand All @@ -65,7 +65,7 @@ def __call__(self, inputs):
class TFGather():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs) -> None:
super().__init__()
self.axis = shape_axis_utils.Torch2TFAxis(node_attribute.get('axis', 0))
self.axis = dimension_utils.channel_to_last_dimension(node_attribute.get('axis', 0))
self.indices = tensor_grap[node_inputs[1]] if node_inputs[1] in tensor_grap else node_weights[node_inputs[1]]

def __call__(self, inputs):
Expand All @@ -75,7 +75,7 @@ def __call__(self, inputs):
class TFConcat():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs):
super().__init__()
_axis = shape_axis_utils.Torch2TFAxis(node_attribute['axis'])
_axis = dimension_utils.channel_to_last_dimension(node_attribute['axis'])
_gather = [tensor_grap[x] for x in node_inputs]
self.out = tf.concat(_gather, axis=_axis)

Expand Down Expand Up @@ -129,7 +129,7 @@ def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args
start += int(node_attribute['split'][i])
end = start + node_attribute['split'][index]
self.indices = tf.keras.backend.arange(start, end, 1)
self.axis = shape_axis_utils.Torch2TFAxis(node_attribute.get("axis", 0))
self.axis = dimension_utils.channel_to_last_dimension(node_attribute.get("axis", 0))

def __call__(self, inputs):
return tf.gather(inputs, indices=self.indices, axis=self.axis)
Expand All @@ -138,7 +138,7 @@ def __call__(self, inputs):
class TFExpand():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs)->None:
super().__init__()
self.shape = shape_axis_utils.TorchShape2TF(node_weights[node_inputs[1]])
self.shape = dimension_utils.shape_NCD_to_NDC_format(node_weights[node_inputs[1]])

def __call__(self, inputs):
for i in range(len(self.shape)):
Expand All @@ -152,7 +152,7 @@ def __call__(self, inputs):
class TFUnsqueeze():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs)->None:
super().__init__()
self.axis = shape_axis_utils.Torch2TFAxis(node_attribute['axes'][0])
self.axis = dimension_utils.channel_to_last_dimension(node_attribute['axes'][0])

def __call__(self, inputs):
return tf.expand_dims(inputs, self.axis)
Expand All @@ -161,7 +161,7 @@ def __call__(self, inputs):
class TFSqueeze():
def __init__(self, tensor_grap, node_weights, node_inputs, node_attribute, *args, **kwargs)->None:
super().__init__()
self.axis = shape_axis_utils.Torch2TFAxis(node_attribute['axes'][0])
self.axis = dimension_utils.channel_to_last_dimension(node_attribute['axes'][0])

def __call__(self, inputs):
return tf.squeeze(inputs, self.axis)
33 changes: 33 additions & 0 deletions layers/dimension_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
'''
shape and axis transform utils func.
'''
def channel_to_last_dimension(axis):
'''
make channel first to channel last
'''
if axis == 0:
axis = 0
elif axis == 1:
axis = -1
else:
axis -= 1
return axis

def shape_NCD_to_NDC_format(shape:list or tuple):
'''
make shape format from channel first to channel last
'''
if len(shape) <= 2:
return tuple(shape)
new_shape = [shape[0], *shape[2:], shape[1]]
return tuple(new_shape)

def tensor_NCD_to_NDC_format(tensor):
'''
make tensor format from channel first to channel last
'''
if(len(tensor.shape) > 2):
shape = [i for i in range(len(tensor.shape))]
shape = shape_NCD_to_NDC_format(shape)
tensor = tensor.transpose(*shape)
return tensor
22 changes: 0 additions & 22 deletions layers/shape_axis_utils.py

This file was deleted.

13 changes: 8 additions & 5 deletions readme.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# ONNX->Keras and ONNX->TFLite tools

## How to use
```cmd
pip install -r requirements.txt
```
```python
# base
python converter.py --weights "./your_model.onnx"
Expand All @@ -17,12 +20,12 @@ python converter.py --weights "./your_model.onnx" --outpath "./save_path" --form
# cutoff model, redefine inputs and outputs, support middle layers
python converter.py --weights "./your_model.onnx" --outpath "./save_path" --formats "tflite" --input-node-names "layer_name" --output-node-names "layer_name1" "layer_name2"

# quantitative model weight, only weight
# quantify model weight, only weight
python converter.py --weights "./your_model.onnx" --formats "tflite" --weigthquant

# quantitative model weight, include input and output
# quantify model weight, include input and output
## recommend
python converter.py --weights "./your_model.onnx" --formats "tflite" --int8 --imgroot "./dataset_path" --int8mean 0 0 0 --int8std 1 1 1
python converter.py --weights "./your_model.onnx" --formats "tflite" --int8 --imgroot "./dataset_path" --int8mean 0 0 0 --int8std 255 255 255
## generate random data, instead of read from image file
python converter.py --weights "./your_model.onnx" --formats "tflite" --int8
```
Expand Down Expand Up @@ -95,8 +98,8 @@ onnx_converter(
target_formats = ['tflite'], #or ['keras'], ['keras', 'tflite']
weight_quant = False,
int8_model = True, # do quantification
int8_mean = [0.485, 0.456, 0.406], # give mean of image preprocessing
int8_std = [0.229, 0.224, 0.225], # give std of image preprocessing
int8_mean = [123.675, 116.28, 103.53], # give mean of image preprocessing
int8_std = [58.395, 57.12, 57.375], # give std of image preprocessing
image_root = "./dataset/train" # give image folder of train
)
```
Expand Down
7 changes: 3 additions & 4 deletions utils/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from onnx import numpy_helper
from .op_registry import OPERATOR

def representative_dataset_gen(img_root, img_size, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
def representative_dataset_gen(img_root, img_size, mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]):
if isinstance(mean, list):
mean = np.array(mean, dtype=np.float32)
if isinstance(std, list):
Expand All @@ -24,11 +24,10 @@ def representative_dataset_gen(img_root, img_size, mean=[0.485, 0.456, 0.406], s
else:
VALID_FORMAT = ['jpg', 'png', 'jpeg']
for i, fn in enumerate(os.listdir(img_root)):
if fn.split(".")[-1] not in VALID_FORMAT:
if fn.split(".")[-1].lower() not in VALID_FORMAT:
continue
_input = cv2.imread(os.path.join(img_root, fn))
_input = cv2.resize(_input, (img_size[1], img_size[0]))[:, :, ::-1]
_input = _input/255
if mean is not None:
_input = (_input - mean)
if std is not None:
Expand Down Expand Up @@ -115,7 +114,7 @@ def keras_builder(onnx_model, new_input_nodes:list=None, new_output_nodes:list=N
return keras_model

def tflite_builder(keras_model, weight_quant:bool=False, int8_model:bool=False, image_root:str=None,
int8_mean:list or float = [0.485, 0.456, 0.406], int8_std:list or float = [0.229, 0.224, 0.225]):
int8_mean:list or float = [123.675, 116.28, 103.53], int8_std:list or float = [58.395, 57.12, 57.375]):
converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
if weight_quant or int8_model:
Expand Down

0 comments on commit 1150d9e

Please sign in to comment.