
Allow using ONNX for testing (WIP)
mdw771 committed Apr 26, 2024
1 parent de1a5e4 commit adf8753
Showing 2 changed files with 121 additions and 5 deletions.
69 changes: 69 additions & 0 deletions generic_trainer/inference_util.py
@@ -0,0 +1,69 @@
import logging

try:
import pycuda.driver as cuda
import tensorrt as trt
except ImportError:
print('Unable to import pycuda and tensorrt. If you do not intend to use the ONNX inferencer, ignore '
'this message. ')


def engine_build_from_onnx(onnx_mdl):
EXPLICIT_BATCH = 1 << (int)(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
builder = trt.Builder(TRT_LOGGER)
config = builder.create_builder_config()
# config.set_flag(trt.BuilderFlag.FP16)
config.set_flag(trt.BuilderFlag.TF32)
# config.max_workspace_size = 1 * (1 << 30) # the maximum size that any layer in the network can use

network = builder.create_network(EXPLICIT_BATCH)
parser = trt.OnnxParser(network, TRT_LOGGER)
# Load the Onnx model and parse it in order to populate the TensorRT network.
success = parser.parse_from_file(onnx_mdl)
for idx in range(parser.num_errors):
print(parser.get_error(idx))

if not success:
return None

return builder.build_engine(network, config)


def mem_allocation(engine):
"""
Determine dimensions and create page-locked memory buffers (i.e. won't be swapped to disk) to hold host
inputs/outputs.
"""
logging.info('Expected input node shape is {}'.format(engine.get_binding_shape(0)))
in_sz = trt.volume(engine.get_binding_shape(0)) * engine.max_batch_size
logging.info('Input size: {}'.format(in_sz))
h_input = cuda.pagelocked_empty(in_sz, dtype='float32')

out_sz = trt.volume(engine.get_binding_shape(1)) * engine.max_batch_size
h_output = cuda.pagelocked_empty(out_sz, dtype='float32')

# Allocate device memory for inputs and outputs.
d_input = cuda.mem_alloc(h_input.nbytes)
d_output = cuda.mem_alloc(h_output.nbytes)

# Create a stream in which to copy inputs/outputs and run inference.
stream = cuda.Stream()

return h_input, h_output, d_input, d_output, stream


def inference(context, h_input, h_output, d_input, d_output, stream):
# Transfer input data to the GPU.
cuda.memcpy_htod_async(d_input, h_input, stream)

# Run inference.
context.execute_async_v2(bindings=[int(d_input), int(d_output)], stream_handle=stream.handle)

# Transfer predictions back from the GPU.
cuda.memcpy_dtoh_async(h_output, d_output, stream)

# Synchronize the stream
stream.synchronize()
# Return the host
return h_output
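
Taken together, these three helpers cover the full TensorRT path: build an engine from an ONNX file, allocate page-locked host and device buffers, and run a single inference. The sketch below shows how they might be chained outside the Tester; it is an illustration only, assuming an explicit-batch engine with a fixed (non-dynamic) input shape, and 'model.onnx' is a placeholder path. `pycuda.autoinit` is imported to create a CUDA context, mirroring what `Tester.build_onnx_model` does.

import numpy as np
import pycuda.autoinit  # creates a CUDA context, as Tester.build_onnx_model does

from generic_trainer.inference_util import engine_build_from_onnx, mem_allocation, inference

# Build a TensorRT engine from the ONNX file and create an execution context.
engine = engine_build_from_onnx('model.onnx')  # placeholder path
context = engine.create_execution_context()

# Allocate pinned host buffers and device buffers sized from the engine's bindings.
h_in, h_out, d_in, d_out, stream = mem_allocation(engine)

# Copy a flattened float32 input into the pinned buffer and run one inference,
# the same way Tester.run_onnx_inference does.
x = np.zeros(tuple(engine.get_binding_shape(0)), dtype=np.float32)  # dummy input
np.copyto(h_in, x.ravel())
y = inference(context, h_in, h_out, d_in, d_out, stream)
print(y.shape)  # flat output buffer; reshape as needed
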
57 changes: 52 additions & 5 deletions generic_trainer/tester.py
@@ -1,10 +1,13 @@
import logging
import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

import generic_trainer.trainer as trainer
from generic_trainer.configs import *
from generic_trainer.inference_util import *


class Tester(trainer.Trainer):
@@ -18,6 +21,20 @@ def __init__(self, configs: InferenceConfig):
self.sampler = None
self.dataloader = None
self.parallelization_type = self.configs.parallelization_params.parallelization_type
self.mode = 'state_dict'

# Attributes below are used for ONNX
self.onnx_mdl = None

self.trt_hin = None
self.trt_din = None
self.trt_hout = None
self.trt_dout = None

self.trt_engine = None
self.trt_stream = None
self.trt_context = None
self.context = None

def build(self):
self.build_ranks()
@@ -28,6 +45,22 @@ def build(self):
self.build_model()
self.build_dir()

def build_model(self):
if self.configs.pretrained_model_path.endswith('onnx'):
logging.info('An ONNX model is given. This model will be loaded and run with TensorRT.')
self.build_onnx_model()
self.mode = 'onnx'
else:
super().build_model()

def build_onnx_model(self):
import pycuda.autoinit
self.context = pycuda.autoinit.context
self.onnx_mdl = self.configs.pretrained_model_path
self.trt_engine = engine_build_from_onnx(self.onnx_mdl)
self.trt_hin, self.trt_hout, self.trt_din, self.trt_dout, self.trt_stream = mem_allocation(self.trt_engine)
self.trt_context = self.trt_engine.create_execution_context()

def build_scalable_parameters(self):
self.all_proc_batch_size = self.configs.batch_size_per_process * self.num_processes

@@ -59,11 +92,25 @@ def build_dir(self):
self.barrier()

def run(self):
-        self.model.eval()
+        if self.mode == 'state_dict':
+            self.model.eval()
         for j, data_and_labels in enumerate(self.dataloader):
-            data, _ = self.process_data_loader_yield(data_and_labels)
-            preds = self.model(*data)
-            self.save_predictions(preds)
+            data, labels = self.process_data_loader_yield(data_and_labels)
+            if self.mode == 'state_dict':
+                preds = self.model(*data)
+            else:
+                preds = self.run_onnx_inference(*data)
+            self.update_result_holders(preds, labels)

def run_onnx_inference(self, data):
data = data.cpu().numpy()
orig_shape = data.shape
np.copyto(self.trt_hin, data.astype(np.float32).ravel())
pred = np.array(inference(self.trt_context, self.trt_hin, self.trt_hout,
self.trt_din, self.trt_dout, self.trt_stream))

pred = pred.reshape(orig_shape)
return pred

-    def save_predictions(self, preds):
+    def update_result_holders(self, preds, *args, **kwargs):
         pass
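
With these changes, the ONNX path is selected purely by the extension of `pretrained_model_path`. A minimal usage sketch is below; only `pretrained_model_path` and `batch_size_per_process` appear in this diff, so the other `InferenceConfig` fields (dataset, parallelization parameters, output directory, and so on) are omitted and would need to be supplied, and keyword construction of the config is assumed.

from generic_trainer.configs import InferenceConfig
from generic_trainer.tester import Tester

# Assumed minimal config: other required fields are not shown in this diff.
configs = InferenceConfig(
    pretrained_model_path='model.onnx',  # a *.onnx path selects the TensorRT route in build_model()
    batch_size_per_process=1,
)

tester = Tester(configs)
tester.build()  # builds the TRT engine, pinned/device buffers, and execution context
tester.run()    # each batch is routed through run_onnx_inference() instead of self.model
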
