Running Multi-Model GPU Inference in SageMaker
This post shows how to run inference for multiple models on a single GPU instance in SageMaker via Triton Inference Server. We will showcase it with two classification models: one for German and the other for English.
Model
First, compile your model to TensorRT format for faster inference. Once you have trained your model, there are several ways to convert it. Direct PyTorch-to-TensorRT conversion via Torch-TensorRT didn't work out in this case, so the model had to be converted to ONNX first and then to TensorRT.
Torch to ONNX
This conversion can be done inside a Python script. In the snippet below, en_model and en_tokenizer are the trained Hugging Face model and tokenizer, and the feature is sequence classification, e.g.
from pathlib import Path

import transformers
from transformers.onnx import FeaturesManager

# Check that the architecture supports ONNX export and build its export config
en_model_kind, en_model_onnx_config = FeaturesManager.check_supported_model_or_raise(
    en_model, feature="sequence-classification"
)
en_onnx_config = en_model_onnx_config(en_model.config)

# Export model and tokenizer to ONNX
onnx_inputs, onnx_outputs = transformers.onnx.export(
    preprocessor=en_tokenizer,
    model=en_model,
    config=en_onnx_config,
    opset=13,
    output=Path("model-en.onnx"),
)
ONNX to TensorRT
This conversion must be run on the same GPU model that you plan to use for inference. You can use the container nvcr.io/nvidia/pytorch:22.07-py3 for this as well. We will use the trtexec tool; in the command below, --shapes pins both inputs to a single 512-token sequence, --fp16 builds a half-precision engine, and --workspace caps the builder memory in megabytes.
trtexec --onnx=model-en.onnx --saveEngine=model_en.plan --shapes=input_ids:1x512,attention_mask:1x512 --fp16 --verbose --workspace=14000 | tee conversion_en.txt
Code
At the time of writing this, Triton Inference Server containers can only be deployed via boto3 (the lower-level AWS API).
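The Triton image is hosted in an AWS-managed ECR account that varies by region. A minimal sketch for building the URI (the account ID and tag shown are assumptions for us-east-1 and the 22.07 release; check the AWS documentation for your region):

import boto3

region = boto3.Session().region_name
# AWS-managed account that hosts SageMaker Triton images; differs per region
account_id = "785573368785"  # us-east-1 (assumption; look up your region)
triton_image_uri = f"{account_id}.dkr.ecr.{region}.amazonaws.com/sagemaker-tritonserver:22.07-py3"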
Place your model and config files in the following structure.
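The layout below is a typical Triton model repository for this setup; it assumes each TensorRT engine is saved as model.plan inside a numbered version directory, with a config.pbtxt per model:

models
├── english
│   ├── config.pbtxt
│   └── 1
│       └── model.plan
└── german
    ├── config.pbtxt
    └── 1
        └── model.plan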
Generate two separate .tar.gz archives via
tar -czvf english_model.tar.gz -C models/ english
tar -czvf german_model.tar.gz -C models/ german
Upload them to S3.
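For example, with boto3 (bucket name and prefix are illustrative; the mme_data_path_s3 used below must point at this prefix):

import boto3

s3 = boto3.client("s3")
# Both archives must sit under the same S3 prefix for the multi-model endpoint
for archive in ["english_model.tar.gz", "german_model.tar.gz"]:
    s3.upload_file(archive, "my-bucket", f"triton-mme/{archive}")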
For the Triton container, create a container definition that marks this as a multi-model endpoint.
container = {
    'Image': triton_image_uri,         # SageMaker Triton image for your region
    'ContainerHostname': 'MultiModel',
    'Mode': 'MultiModel',              # serve several models from one endpoint
    'ModelDataUrl': mme_data_path_s3,  # S3 prefix holding the .tar.gz archives
}
and create the model!
sm_client.create_model(
    ModelName=model_name,
    PrimaryContainer=container,
    ExecutionRoleArn=role_arn,
)
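The model alone is not servable yet; you also need an endpoint configuration and the endpoint itself. A minimal sketch (names and instance type are illustrative):

endpoint_config_name = 'triton-mme-config'
sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[{
        'VariantName': 'AllTraffic',
        'ModelName': model_name,
        'InstanceType': 'ml.g4dn.xlarge',  # any GPU instance type supported for MME
        'InitialInstanceCount': 1,
    }],
)
sm_client.create_endpoint(
    EndpointName='triton-mme-endpoint',
    EndpointConfigName=endpoint_config_name,
)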
The inference code for Triton is quite long and can be found in the notebook. Make sure that the input shapes and data types match the ones specified during the TensorRT conversion.
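As a rough sketch of what an invocation can look like (the endpoint name is illustrative, and the INT32 datatype is an assumption that must match your config.pbtxt):

import json
import boto3

runtime = boto3.client('sagemaker-runtime')

# Tokenize to the same 1x512 shape that was fixed during the trtexec conversion
encoded = en_tokenizer("some english text", padding="max_length", max_length=512, truncation=True)
payload = {
    "inputs": [
        {"name": "input_ids", "shape": [1, 512], "datatype": "INT32", "data": encoded["input_ids"]},
        {"name": "attention_mask", "shape": [1, 512], "datatype": "INT32", "data": encoded["attention_mask"]},
    ]
}

response = runtime.invoke_endpoint(
    EndpointName='triton-mme-endpoint',
    ContentType='application/octet-stream',
    TargetModel='english_model.tar.gz',  # picks the model that handles this request
    Body=json.dumps(payload),
)
print(json.loads(response['Body'].read()))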
For the complete code, visit the repo.