ML Inference Workloads on the Triton Inference Server

This blog was written by Ashwin Kannan.

 

As we continue to scale, both in traffic and in the number of APIs (models/services) we host, being cross-compatible across cloud platforms (AWS/GCP/Azure) becomes more of a priority. When we had only a few models, and only one of them needed to run on an accelerator, AWS Inferentia was the best choice in terms of both cost and latency. As we added more APIs to our offerings, however, Inferentia's main drawback became harder to ignore: AWS is the only cloud provider that offers Inferentia instances, so relying on them tied us to AWS, and the flexibility to move across cloud providers became a more pressing priority. That left us with one viable option: moving to GPU instances as our accelerated instances.

 

Why moving over to the GPU without any optimization is not an option

 

  1. Cost

Inferentia instances start at 22 cents/hour on AWS, while a T4 GPU instance, the cheapest GPU option, starts at 75 cents/hour. Since the GPU instance is more than 3x as expensive as the Inferentia instance, simply moving our workload from Inferentia to GPU instances as-is was not an option.

 

  2. Latency

For the same input size, latency on the GPU was 8x that of Inferentia. With an optimized runtime such as ONNX Runtime, we brought it down to 7x. Because each instance can hold multiple copies of the model in memory, the metric that actually captures an instance's throughput is RPS (requests per second), the number of requests the instance can handle per second. At the maximum file size we allow, an Inferentia instance handles 1 RPS, while a GPU instance with ONNX Runtime handles 0.2 RPS. Processing the same number of requests would therefore require 5x as many GPU instances as Inferentia instances.
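Putting the price and throughput numbers together gives cost per request, which is the number that actually matters. The sketch below is back-of-the-envelope arithmetic using the figures above, and it assumes every instance runs fully loaded at those RPS values; real traffic is burstier.

```python
# Figures from this post: on-demand price per instance-hour and measured RPS
# per instance at the maximum file size we allow.
INSTANCES = {
    "inferentia": {"usd_per_hour": 0.22, "rps": 1.0},
    "t4_gpu_onnx_runtime": {"usd_per_hour": 0.75, "rps": 0.2},
}


def usd_per_million_requests(usd_per_hour: float, rps: float) -> float:
    """Cost of serving one million requests on fully utilized instances."""
    requests_per_hour = rps * 3600
    return usd_per_hour / requests_per_hour * 1_000_000


if __name__ == "__main__":
    for name, spec in INSTANCES.items():
        cost = usd_per_million_requests(spec["usd_per_hour"], spec["rps"])
        print(f"{name}: ${cost:,.2f} per million requests")
```

With these numbers, ONNX Runtime on a T4 comes out at roughly 17x the cost per request of Inferentia (5x the instances at about 3.4x the hourly price).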

 

How can we optimize the model

 

  1. Model optimizations such as quantization/pruning

We tried several model optimization techniques, including quantization and pruning, but the quantized and pruned models did not perform well enough. On a sample dataset, accuracy dropped significantly, and the logits (the raw outputs of the model) were not even close to those returned by the original model.

 

  2. TensorRT (TRT)

In our workflow, converting a model to TensorRT requires an ONNX version of the model first. The NVIDIA documentation on converting a PyTorch or TensorFlow model to a TRT model can be found here: (LINK)

Since we have a PyTorch model, these are the steps we followed.

 

IMPORTANT: Although we initially tried to build our own Docker container with all of the requirements for TensorRT, it is much simpler to use a prebuilt container from NVIDIA that already has them installed. You can find the containers here: https://catalog.ngc.nvidia.com/orgs/nvidia/containers/tritonserver/tags

 

The tag that we are using is: 

 nvcr.io/nvidia/tritonserver:23.01-py3

 

To get the same tag, you can run: 

docker pull nvcr.io/nvidia/tritonserver:23.01-py3

 

  1. Convert the model to ONNX (the opset matters; opset 11 worked for us)
  2. Convert the ONNX model to a TRT model

 

To convert the PyTorch model to ONNX, we ran:

import torch
# model is the loaded PyTorch model; tokens is a sample input, e.g. an (input_ids, attention_mask) tuple.
torch.onnx.export(model, tokens, "pytorch_phi_fully_loaded_with_input_names_with_dynamic_batch_size_opset_11.onnx",
                  verbose=True, opset_version=11, input_names=['input_ids', 'attention_mask'],
                  output_names=['logits'],
                  dynamic_axes={'input_ids': {0: 'batch', 1: 'sequence'},
                                'attention_mask': {0: 'batch', 1: 'sequence'},
                                'logits': {0: 'batch', 1: 'sequence'}})

 

Here, model is the loaded PyTorch model and tokens is a sample input. As the command shows, we specify the opset version as well as the input and output names. We also mark input_ids, attention_mask and logits as dynamic, which matters when converting the ONNX model to a TRT model. In our case, the number of tokens going into the model is not fixed: it can be anything up to 512, depending on how we partition the incoming document. Our input and output sizes are therefore dynamic.
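A minimal sketch of that partitioning, assuming the document has already been tokenized into a list of token IDs. The helper name chunk_token_ids is hypothetical, and the sketch ignores tokenizer-specific details such as special tokens or overlapping windows, which this post does not describe.

```python
from typing import List, Tuple

MAX_TOKENS = 512  # matches the maximum sequence length the engine is built for


def chunk_token_ids(token_ids: List[int],
                    max_len: int = MAX_TOKENS) -> List[Tuple[List[int], List[int]]]:
    """Split a token sequence into consecutive windows of at most max_len tokens.

    Returns (input_ids, attention_mask) pairs. Every window except possibly the
    last is exactly max_len long, so no padding is needed.
    """
    if max_len <= 0:
        raise ValueError("max_len must be positive")
    chunks = []
    for start in range(0, len(token_ids), max_len):
        window = token_ids[start:start + max_len]
        chunks.append((window, [1] * len(window)))
    return chunks


if __name__ == "__main__":
    for ids, mask in chunk_token_ids(list(range(1100))):
        print(len(ids), sum(mask))
```

Each chunk then becomes one request of shape 1 x length, which is why only the sequence axis varies in practice.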

 

To convert the ONNX model to a TRT model, the following command was run:

trtexec --onnx=pytorch_phi_fully_loaded_with_input_names_with_dynamic_batch_size_opset_11.onnx \
  --saveEngine=/workspace/new_pytorch_classifier_head_defined_variable_size.plan --verbose \
  --minShapes=input_ids:1x1,attention_mask:1x1 \
  --optShapes=input_ids:1x512,attention_mask:1x512 \
  --maxShapes=input_ids:1x512,attention_mask:1x512 --workspace=14000

 

In the command above, we pass in the path to the ONNX file and the path where the TRT engine should be saved. Because our input sizes are not known in advance, we specify a minimum shape of 1x1 and a maximum shape of 1x512 (a batch size of 1 and a sequence length between 1 and 512). We set the optimal shape to 1x512 because any text longer than 512 tokens is split into chunks of exactly 512 tokens, so full-length chunks are the case we optimize for.
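If you build engines for several models, generating the shape flags is less error-prone than typing them. The sketch below uses hypothetical helper names (and placeholder file names in the usage) to produce the same style of command as above for any list of input names, with a fixed batch size of 1. Newer TensorRT releases deprecate --workspace in favor of --memPoolSize, so check trtexec --help in your container.

```python
import shlex
from typing import Dict, List, Tuple


def shape_spec(shapes: Dict[str, Tuple[int, ...]]) -> str:
    """Format {"input_ids": (1, 512)} as "input_ids:1x512", comma-separating inputs."""
    return ",".join(f"{name}:{'x'.join(str(d) for d in dims)}" for name, dims in shapes.items())


def trtexec_args(onnx_path: str, engine_path: str, inputs: List[str],
                 min_len: int = 1, opt_len: int = 512, max_len: int = 512,
                 batch: int = 1, workspace_mib: int = 14000) -> List[str]:
    """Build a trtexec argument list with one dynamic-sequence optimization profile."""
    if not 1 <= min_len <= opt_len <= max_len:
        raise ValueError("expected 1 <= min_len <= opt_len <= max_len")

    def profile(length: int) -> str:
        return shape_spec({name: (batch, length) for name in inputs})

    return [
        "trtexec",
        f"--onnx={onnx_path}",
        f"--saveEngine={engine_path}",
        "--verbose",
        f"--minShapes={profile(min_len)}",
        f"--optShapes={profile(opt_len)}",
        f"--maxShapes={profile(max_len)}",
        f"--workspace={workspace_mib}",
    ]


if __name__ == "__main__":
    args = trtexec_args("model.onnx", "model.plan", ["input_ids", "attention_mask"])
    print(shlex.join(args))
    # Inside the container, run it with: subprocess.run(args, check=True)
```

Returning a list rather than a single string lets subprocess.run execute it without shell quoting issues.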

 

Running Inference of the TRT Model

 

When running a TRT engine, CUDA device memory for the inputs and outputs is allocated up front, and the input shapes must be set on the execution context before inference. If the sizes do not match, an error is thrown, so it is important to be careful with this step. In the code below, I set the input and output shapes after the request comes in and before running inference.

 

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context on import)
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)

# data is the incoming request payload and path_to_engine points at the .plan file.
input_ids = np.array(data['input_ids'][0], dtype="int32").reshape(1, -1)
attn_mask = np.array(data['attention_mask'][0], dtype="int32").reshape(1, -1)
data = {'input_ids': input_ids, 'attention_mask': attn_mask}

with open(path_to_engine, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime, \
        runtime.deserialize_cuda_engine(f.read()) as engine, \
        engine.create_execution_context() as context:
    input_shape = input_ids.shape
    input_nbytes = trt.volume(input_shape) * trt.int32.itemsize
    # Allocate device memory for the two inputs (bindings 0 and 1).
    d_inputs = [cuda.mem_alloc(input_nbytes) for _ in range(2)]
    # Create a stream in which to copy inputs/outputs and run inference.
    stream = cuda.Stream()
    # Specify input shapes. These must be within the min/max bounds of the active
    # optimization profile (profile 0 here). Shapes can change on every inference.
    for binding in range(2):
        context.set_binding_shape(binding, input_shape)
    assert context.all_binding_shapes_specified
    # Allocate the output buffer (binding 2) by querying its shape from the context;
    # the output shape depends on the input shapes.
    h_output = cuda.pagelocked_empty(tuple(context.get_binding_shape(2)), dtype=np.float32)
    d_output = cuda.mem_alloc(h_output.nbytes)
    # Pin the host input arrays so the copies below can actually run asynchronously.
    input_ids = cuda.register_host_memory(data['input_ids'])
    attention_mask = cuda.register_host_memory(data['attention_mask'])
    cuda.memcpy_htod_async(d_inputs[0], input_ids, stream)
    cuda.memcpy_htod_async(d_inputs[1], attention_mask, stream)
    # Run inference
    context.execute_async_v2(bindings=[int(d_inp) for d_inp in d_inputs] + [int(d_output)],
                             stream_handle=stream.handle)
    # Transfer predictions back from the GPU, then wait for all work on the stream.
    cuda.memcpy_dtoh_async(h_output, d_output, stream)
    stream.synchronize()
    trt_output = h_output

 

Most of the code above is adapted from several samples. A few key things to note:

 

  1. Ensure that the data type (int32, int64, etc.) of the inputs you pass in is what the model expects
  2. You can check what the model expects by running:

polygraphy inspect model new_pytorch_classifier_head_defined_variable_size.plan

 

Since the input shape is dynamic, I read in the input_ids first and compute the shape on the fly. This is critical: the memory for a TRT engine's inputs and outputs is allocated for specific shapes, and if the allocated memory does not match exactly, an error is thrown.
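Both points can be enforced in the request handler before any CUDA memory is allocated. A minimal sketch, assuming the engine expects int32 input_ids and attention_mask of shape (1, length) with length between 1 and 512, matching the build command earlier. The names EXPECTED_DTYPES and prepare_inputs are illustrative; in practice the expected dtypes would come from polygraphy inspect model.

```python
import numpy as np

# Illustrative expectations; in practice read them from `polygraphy inspect model`.
EXPECTED_DTYPES = {"input_ids": np.int32, "attention_mask": np.int32}
MIN_LEN, MAX_LEN = 1, 512  # sequence-length bounds of the engine's optimization profile


def prepare_inputs(raw: dict) -> dict:
    """Cast each input to the engine's dtype as a C-contiguous (1, length) array and
    check that its length fits the profile and that all inputs share one shape."""
    prepared = {}
    for name, dtype in EXPECTED_DTYPES.items():
        arr = np.ascontiguousarray(np.asarray(raw[name], dtype=dtype).reshape(1, -1))
        if not MIN_LEN <= arr.shape[1] <= MAX_LEN:
            raise ValueError(f"{name} has length {arr.shape[1]}, outside [{MIN_LEN}, {MAX_LEN}]")
        prepared[name] = arr
    if len({arr.shape for arr in prepared.values()}) != 1:
        raise ValueError("input_ids and attention_mask must have the same shape")
    return prepared


if __name__ == "__main__":
    ids = list(range(512))
    inputs = prepare_inputs({"input_ids": ids, "attention_mask": [1] * 512})
    # The same tokens as int64 would need twice the bytes of the int32 buffer.
    print(inputs["input_ids"].nbytes, np.asarray(ids, dtype=np.int64).nbytes)  # 2048 4096
```

The int64 comparison shows why a dtype mismatch matters: the buffer sized for int32 would be half as large as the data being copied into it.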

 

Results with TRT

 

Running inference on the TRT model gave us a response time 2x faster than the ONNX Runtime model on the same instance with the same input size. When we compared the outputs of the TRT and ONNX models, the logits were almost identical, so very little is lost in the conversion.
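One way to put a number on "almost identical" (and, equally, to catch the kind of divergence we saw with quantization and pruning) is to compare the two sets of logits numerically. In this sketch, compare_logits is a hypothetical helper and the tolerances are placeholders to tune for your model.

```python
import numpy as np


def compare_logits(reference, candidate, rtol: float = 1e-3, atol: float = 1e-3) -> dict:
    """Summarize how closely candidate logits (e.g. TRT) track reference logits (e.g. ONNX)."""
    reference = np.asarray(reference, dtype=np.float32)
    candidate = np.asarray(candidate, dtype=np.float32)
    if reference.shape != candidate.shape:
        raise ValueError(f"shape mismatch: {reference.shape} vs {candidate.shape}")
    abs_diff = np.abs(reference - candidate)
    return {
        "max_abs_diff": float(abs_diff.max()),
        "mean_abs_diff": float(abs_diff.mean()),
        "allclose": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
        # Fraction of positions whose predicted class (argmax over the last axis) agrees.
        "argmax_agreement": float(np.mean(reference.argmax(-1) == candidate.argmax(-1))),
    }
```

For example, compare_logits(onnx_logits, trt_output), where onnx_logits is the ONNX Runtime output for the same input. Argmax agreement is worth tracking alongside the raw differences, since it reflects whether the final predictions change.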

 

This is very promising given that we can run multiple instances of the TRT model on one instance meaning that we could get much closer to competing with the RPS per instance that we were getting from Inferentia.

 

Running the TRT model on the Triton Inference Server

 

Before we hop into running the TRT model on Triton Inference Server, I would like to give a little background on the Triton Inference Server.

 

In simple terms, the Triton Inference Server is a Docker container that can host various kinds of models, such as TRT, PyTorch, ONNX and TensorFlow models. Along with the models, it exposes a metrics endpoint with information such as CPU usage, the number of requests served per model, latency per request, etc. It also comes with other useful features, such as hosting multiple versions of the same model for A/B testing, loading and unloading model versions, rate limiting, and running ensemble workflows seamlessly by hosting all of the models on the same server.

 

In order to run the TRT model or any model on the Triton Inference Server, there are a few things that are always going to be needed:

 

  1. The model
  2. A corresponding config.pbtxt file

 

The model folder structure differs from one kind of model to another, so I will give an example for both a TRT model and a TensorFlow model. The config.pbtxt file tells the server the input and output shapes of the model, which version(s) of the model to serve, what dynamic batching to use (if any), and whether the model should run on the GPU or the CPU. Below I also provide an example of both a GPU config.pbtxt and a CPU config.pbtxt.

 

The folder structure for our TRT model (converted from PyTorch) looks like this:

└── pytorch_model
    └── pytorch
        ├── 1
        │   └── model.plan
        └── config.pbtxt

 

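In this layout, 1 is the version number, and config.pbtxt sits next to the version folders, so it stays the same no matter which version you serve. Note that if the name field is set in config.pbtxt, Triton requires it to match the name of the model's directory (pytorch in this layout).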
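The layout can be created with a short script when deploying a new engine. A minimal standard-library sketch; stage_trt_model and its arguments are hypothetical, while model.plan is Triton's default file name for the tensorrt_plan platform.

```python
import shutil
import tempfile
from pathlib import Path


def stage_trt_model(repo: Path, model_name: str, version: int,
                    plan_file: Path, config_text: str) -> Path:
    """Copy an engine into <repo>/<model_name>/<version>/model.plan and write config.pbtxt."""
    model_dir = repo / model_name
    version_dir = model_dir / str(version)
    version_dir.mkdir(parents=True, exist_ok=True)
    dest = version_dir / "model.plan"
    shutil.copyfile(plan_file, dest)
    # config.pbtxt lives next to the version folders and is shared by all versions.
    (model_dir / "config.pbtxt").write_text(config_text)
    return dest


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp)
        engine = root / "engine.plan"
        engine.write_bytes(b"placeholder, not a real engine")
        stage_trt_model(root / "models", "pytorch", 1, engine, 'platform: "tensorrt_plan"\n')
        for path in sorted((root / "models").rglob("*")):
            print(path.relative_to(root))
```

Adding a new version is then a matter of calling it again with version=2; Triton's version policy decides which versions are served.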

 

The TensorFlow folder structure looks like this:

├── tf_model
│   └── tf
│       ├── 1
│       │   └── model.savedmodel
│       │       ├── saved_model.pb
│       │       └── variables
│       │           ├── variables.data-00000-of-00001
│       │           └── variables.index
│       └── config.pbtxt

 

The config.pbtxt for our TRT model (converted from PyTorch), which runs on the GPU, looks like this:

name: "pytorch_model"
platform: "tensorrt_plan"
dynamic_batching { }
instance_group [
  {
    count: <how many instances of the model you want spun up>
    kind: KIND_GPU
  }
]
input [
  {
    name: "<input_1>"
    data_type: TYPE_INT32
    dims: [-1]
  },
  {
    name: "<input_2>"
    data_type: TYPE_INT32
    dims: [-1]
  }
]
output [
  {
    name: "<output_1>"
    data_type: TYPE_FP32
    dims: [-1, 25]
  }
]

 

As shown, input_ids and attention_mask are dynamic, so they are marked with shape -1. You may be wondering whether the input shape was [1 x Y], where Y is between 1 and 512. It was, but when specifying the dimensions in config.pbtxt you can leave out the first (batch) dimension. Triton treats the first dimension as an implicit batch dimension only when max_batch_size is greater than 0; with max_batch_size set to 0, dims must describe the full shape.

 

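Each of our requests has a batch size of 1, which gave us the fastest response time from the Triton Inference Server, and the engine itself was built with a batch dimension of exactly 1 (the 1x1 to 1x512 shapes above). Note that the dynamic_batching { } entry in the config above enables Triton's dynamic batcher; remove it if you want batching disabled entirely.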

 

The config.pbtxt for the TensorFlow model, which runs on the CPU, looks like this:

name: "tensorflow_model_name"
platform: "tensorflow_savedmodel"
dynamic_batching { }
instance_group [
  {
    count: <how many instances of the model you want spun up>
    kind: KIND_CPU
  }
]
input [
  {
    name: "<input_1>"
    data_type: <data_type>
    dims: [256, 256, 3]
  }
]
output [
  {
    name: "<output_1>"
    data_type: <data_type>
    dims: [6]
  }
]

 

You can find out more about the config.pbtxt file here: https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md

 

Leveraging Auto Configs Built By Triton Inference Server

 

One of the nice things about the Triton Inference Server is that you can give it a model and it will read the model's input and output information to create a config.pbtxt for you. This removes much of the configuration that would otherwise be needed before serving a model on the Triton Inference Server.

 

Once the model is hosted on the Triton Inference Server, you can view the auto-generated configuration with:

curl localhost:8000/v2/models/<model name>/config
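The same endpoint can be queried from Python, for example to check the auto-generated input and output shapes before wiring up a client. A minimal standard-library sketch, assuming Triton's HTTP endpoint is on localhost:8000 as in the curl command; the helper names are hypothetical. Depending on how the JSON is serialized, dims may come back as strings, so they are converted to int.

```python
import json
import urllib.request


def fetch_model_config(model: str, host: str = "localhost:8000") -> dict:
    """GET the (possibly auto-completed) configuration Triton is using for a model."""
    url = f"http://{host}/v2/models/{model}/config"
    with urllib.request.urlopen(url, timeout=10) as resp:
        return json.load(resp)


def describe_tensors(config: dict) -> list:
    """Return (kind, name, data_type, dims) for every input and output in a config."""
    rows = []
    for kind in ("input", "output"):
        for tensor in config.get(kind, []):
            dims = [int(d) for d in tensor.get("dims", [])]
            rows.append((kind, tensor["name"], tensor["data_type"], dims))
    return rows
```

For example, describe_tensors(fetch_model_config("pytorch_model")) returns one tuple per input and output, which makes it easy to confirm the data types match what the client sends.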

 

Running the models on the Triton Inference Server

 

With our models hosted, we ran a load test on the instance and measured 0.8 RPS for the same input size. Compared with running the ONNX model on the GPU (0.2 RPS), this is a 4x improvement in RPS and puts us much closer to Inferentia's 1 RPS per instance.
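This post does not name the load-testing tool; NVIDIA's perf_analyzer, shipped with the Triton client SDK, is purpose-built for this. For a quick check, a minimal sketch of measuring RPS is below, where send_request is a hypothetical callable that sends one inference request and waits for the response.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Callable


def measure_rps(send_request: Callable[[], object], total: int, concurrency: int = 1) -> float:
    """Send `total` requests using `concurrency` worker threads; return requests per second."""
    if total <= 0 or concurrency <= 0:
        raise ValueError("total and concurrency must be positive")
    start = time.perf_counter()
    with ThreadPoolExecutor(max_workers=concurrency) as pool:
        # list() waits for every request and re-raises an exception from a failed one.
        list(pool.map(lambda _: send_request(), range(total)))
    elapsed = time.perf_counter() - start
    return total / elapsed
```

Use inputs of the same size as in production, and set concurrency to at least the number of model instances configured in config.pbtxt so that all of them can be kept busy.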

 

Benefits of the Triton Inference Server

 

The Triton Inference Server removes much of the complexity of running a TRT model while providing a fairly simple interface. You can also run multiple instances of the same model by specifying a count in config.pbtxt; although this does not guarantee that all of the model instances execute on the GPU at the same time, it does let you make full use of the GPU memory. The biggest advantage of the Triton Inference Server is that CPU usage for a GPU workload is minimal.

 

We also noticed that CPU-only models achieve about 20% higher RPS on the Triton Inference Server than when served outside it with the same CPU resources.

 

Drawbacks of the Triton Inference Server

 

The only drawback of the Triton Inference Server is that it does not scale individual models independently across server replicas. Suppose you have 2 services, A and B, hosted on the same server: A receives 1 request per second and B receives 10 requests per second. Ideally, B would have 10x as many model instances as A. Say the first server starts with 10 instances of each model. When it is time to scale, the second server also gets 10 instances of each model, rather than more instances of model B, which receives far more traffic. This means that in a production environment, each model would need its own Triton Inference Server deployment, which is the direction we have decided to go in.

 

Current Infrastructure

 

Most ML services look something like this:

 

  1. Receive request
  2. Do pre-processing
  3. Call the model
  4. Do post-processing
  5. Return response

Since we load multiple instances of the model into memory on the accelerated instance, each ML service takes up the whole instance's CPU, accelerator and memory. A T4 GPU instance has 1 GPU, 8 CPUs and 32 GB of RAM, and our model service uses the GPU, 6 CPUs and 28 GB of RAM. The other 2 CPUs are taken up by DaemonSets and miscellaneous processes, so a single service occupies the whole instance.

 

[Figure: Services example]

 

To keep things simple, assume we have 2 services: one ML service that runs on an accelerated instance (GPU/Inferentia) and another ML service that runs on a CPU. Both services get traffic from a service X, which also runs on a CPU.

 

Our current auto-scaling strategy looks something like this:

[Figure: Auto-scaling strategy]

 

If a new GPU pod needs to be spun up, we create a new instance, because the pod takes up all of the resources on the instance. If a new CPU pod needs to be spun up, we check whether the existing CPU instances have enough resources to host it; if so, we place it there, and if not, we spin up a new CPU instance.

 

Future Infrastructure with Triton

 

The need for a logic server with each Triton Inference Server

Because in our setup the Triton Inference Server handles only inference and none of the pre/post-processing, every service that currently performs all 5 steps above in a single container needs a separate service, which we will call the Logic Server from here on. The logic server is a minimal CPU compute server that takes care of the following:

 

  1. Handle the incoming request
  2. Do the pre-processing
  3. Call and wait for the response from the model endpoint
  4. Do the post processing once the response is received
  5. Return the response
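A minimal sketch of steps 2 to 4 in a logic server, assuming the model is served by Triton over HTTP using the KServe v2 inference protocol (POST /v2/models/<model name>/infer) on localhost:8000, and that the engine's tensors are named input_ids, attention_mask and logits as in the ONNX export above. The tokenizer, the model name pytorch_model and the argmax post-processing are placeholders; the web framework that receives the request (step 1) and returns the response (step 5) is left out.

```python
import json
import urllib.request
from typing import Callable, Dict, List

MAX_TOKENS = 512


def build_infer_request(input_ids: List[int], attention_mask: List[int]) -> Dict:
    """KServe v2 JSON body for a single (1, seq_len) INT32 example."""
    shape = [1, len(input_ids)]
    return {
        "inputs": [
            {"name": "input_ids", "shape": shape, "datatype": "INT32", "data": input_ids},
            {"name": "attention_mask", "shape": shape, "datatype": "INT32", "data": attention_mask},
        ],
        "outputs": [{"name": "logits"}],
    }


def parse_logits(response: Dict, name: str = "logits") -> List[List[float]]:
    """Split the named output's flat, row-major data into rows of its last dimension."""
    output = next(o for o in response["outputs"] if o["name"] == name)
    width = output["shape"][-1]
    flat = output["data"]
    return [flat[i:i + width] for i in range(0, len(flat), width)]


def call_model(body: Dict, model: str, host: str = "localhost:8000") -> Dict:
    """POST the body to Triton's HTTP inference endpoint and decode the JSON reply."""
    request = urllib.request.Request(
        f"http://{host}/v2/models/{model}/infer",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=30) as response:
        return json.load(response)


def handle(text: str, tokenize: Callable[[str], List[int]],
           model: str = "pytorch_model") -> List[int]:
    """Pre-process, call the model, post-process (argmax per row)."""
    token_ids = tokenize(text)[:MAX_TOKENS]  # pre-processing; truncation only for brevity
    body = build_infer_request(token_ids, [1] * len(token_ids))
    rows = parse_logits(call_model(body, model))  # model call
    return [max(range(len(row)), key=row.__getitem__) for row in rows]  # post-processing
```

Keeping the request building and response parsing as pure functions makes them easy to unit-test without a running Triton server.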

 

Running CPU compute containers on the excess CPUs on GPU instances

 

As noted above in the benefits of the Triton Inference Server, there are excess CPUs when a GPU workload runs on the Triton Inference Server on a GPU instance.

 

We will utilize those CPUs to run other Triton Inference Servers which only have a CPU workload or other services which are CPU based such as the logic servers.

 

Autoscaling with Triton Inference Server

 

We use the KEDA autoscaler that comes with our current Kubernetes setup to handle our autoscaling needs. GPU Triton Inference Server services spin up new GPU instances when they need to scale. CPU Triton Inference Server services, logic services and other CPU-based services are first placed on the excess CPUs of GPU instances that are already running; if those instances do not have enough resources, the services are placed on additional CPU instances. We do this with node affinity, a Kubernetes concept you can read more about here: https://kubernetes.io/docs/concepts/scheduling-eviction/assign-pod-node/
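The placement preference described above can be illustrated with a toy first-fit simulation. This is only a model of the policy, not how it is implemented: in the cluster, the Kubernetes scheduler places pods according to the node-affinity rules and KEDA scales the workloads. The Node class, the CPU counts and the 8-CPU size of a new CPU node are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Node:
    name: str
    kind: str  # "gpu" or "cpu"
    free_cpus: float  # CPUs not yet claimed by pods
    pods: List[str] = field(default_factory=list)


def place_cpu_pod(nodes: List[Node], pod: str, cpus: float,
                  new_cpu_node_size: float = 8.0) -> Node:
    """Prefer spare CPUs on GPU nodes, then existing CPU nodes, else add a CPU node."""
    if cpus > new_cpu_node_size:
        raise ValueError("pod does not fit on a fresh CPU node")
    for kind in ("gpu", "cpu"):
        for node in nodes:
            if node.kind == kind and node.free_cpus >= cpus:
                node.free_cpus -= cpus
                node.pods.append(pod)
                return node
    cpu_nodes = sum(1 for n in nodes if n.kind == "cpu")
    new_node = Node(f"cpu-{cpu_nodes + 1}", "cpu", new_cpu_node_size - cpus, [pod])
    nodes.append(new_node)
    return new_node


if __name__ == "__main__":
    cluster = [Node("gpu-1", "gpu", free_cpus=5)]  # hypothetical spare CPUs beside a GPU Triton pod
    for pod, cpus in [("logic-1", 2), ("logic-2", 2), ("cpu-triton-1", 4)]:
        print(pod, "->", place_cpu_pod(cluster, pod, cpus).name)
```

With 5 spare CPUs on the GPU node, the two logic pods land on gpu-1, and the 4-CPU pod no longer fits there, so a new cpu-1 node is added.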

 

Our overall auto-scaling strategy changes to be something like:

[Figure: Overall auto-scaling strategy]

 

A sample request that is coming to a “router” that calls all of the models that we have may look like:

[Figure: Sample outline of instances and allocation]

 

A sample outline of instances and the allocation of CPU and GPU pods to them. In the image above, we have 2 instances of ML Service 1, which runs on the GPU (green), and 2 instances of ML Service 2, which runs on the CPU (blue). The ELB decides which instance of an ML service to hit based on the traffic being sent to each service. For demonstration purposes, the image also shows that logic server pods and ML service pods do not need to be on the same instance, since the ELB handles the calls no matter which instance the pods are on.

 

Impact on our team with Triton Inference Server

  1. Cost

Our current operating cost with Inferentia and all of the pods we run for our services is about $37 an hour. This covers 1 Inferentia-based service and 5 CPU-based services. If we were to move the Inferentia-based service to GPUs as-is, based on our current Inferentia instance count, the cost would rise to $55 an hour, an increase of approximately 50%.

 

With the Triton Inference Server running our accelerated model on the GPU and our other services using the excess CPUs, our cost is $33 an hour. This is significant considering that Inferentia instances are 3x cheaper than GPU instances. An additional benefit is that, as the number of services that require a GPU grows, the cost advantage of converting models to TRT and running them on the Triton Inference Server with GPUs, over converting them for Inferentia and running them on Inferentia instances, should keep increasing, because the excess CPUs can be used by other services at essentially no additional cost.

 

  2. Latency

For CPU-based models, latency on the Triton Inference Server is about 20% lower than running the model on a regular CPU instance. On the GPU, converting the model to TRT and serving it with the Triton Inference Server gives a 50+% performance boost compared with running it on the GPU with PyTorch or TensorFlow. Latency is similar to that on Inferentia.

 

Impact on PAN with Triton Inference Server

 

The impact on PAN can be widespread, and the Triton Inference Server solution is applicable to many individuals and teams as PAN goes through its AI transformation. The Triton Inference Server can host many kinds of models, including LLMs, which have been the latest craze since ChatGPT took off a few months ago. As PANW moves down that path, many such models will need to be hosted on GPUs and CPUs cost-efficiently without hurting latency.

 

Additionally, because models on the Triton Inference Server use CPU efficiently, we can run other CPU workloads on the same instances, which could result in large cost savings.

 

Conclusion

 

Moving to the Triton Inference Server is a logical next step for us as we plan to potentially move our workloads across multiple cloud providers. It also makes sense because, for us, it is cheaper than running the accelerated workloads on Inferentia: we have many models that can run on the CPU and make use of the excess CPUs on the GPU instances. The cost benefit of moving to the Triton Inference Server is marginal for now but is expected to grow as we spin up new GPU- and CPU-based services.
