Running a Vitis AI XModel (Python)
This example walks you through making an inference request to a custom XModel in Python.
If you haven’t already, go through the Hello World (Python) example first.
This example glosses over some of the details covered there.
The complete script used here is available at examples/python/custom_processing.py.
User variables
Making an inference request to an actual XModel that accepts an image requires some additional data from the user. These variables are pulled out into a separate block to highlight them.
Batch size: The DPU that your XModel targets may have a preferred batch size, so we use this value to create an optimally sized request.
XModel path: The XModel you want to run must exist at a path that is accessible where the server runs. Here, we use ResNet50, an image classification model trained on the ImageNet dataset.
Image path: To test this model, we need an image. Here, we use a sample image that is included for testing. The image path is given relative to the root of the AMD Inference Server repository.
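```python
import os

# Assumes PROTEUS_ROOT points at the root of the AMD Inference Server repository
root = os.environ["PROTEUS_ROOT"]

batch_size = 4
# Python does not expand ${PROTEUS_ROOT} here; the string is passed to the server unchanged
path_to_xmodel = "${PROTEUS_ROOT}/external/artifacts/u200_u250/resnet_v1_50_tf/resnet_v1_50_tf.xmodel"
path_to_image = root + "/tests/assets/dog-3619020_640.jpg"
```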
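Python does not expand ${PROTEUS_ROOT} inside a string literal. Before loading a worker, it can help to confirm on the client side that the paths resolve. The sketch below does that check with os.path.expandvars. check_paths is a hypothetical helper. The check is only meaningful if the client and the server see the same filesystem, which is an assumption about your setup.

```python
import os
import tempfile
from typing import Dict


def check_paths(*paths: str) -> Dict[str, bool]:
    """Hypothetical helper: expand $VAR / ${VAR} references and report which files exist."""
    result = {}
    for path in paths:
        # expandvars leaves unset variables untouched, so such a path is usually not found
        expanded = os.path.expandvars(path)
        result[expanded] = os.path.isfile(expanded)
    return result


if __name__ == "__main__":
    # Demo with a throwaway directory standing in for the repository root
    with tempfile.TemporaryDirectory() as fake_root:
        os.environ["PROTEUS_ROOT"] = fake_root
        open(os.path.join(fake_root, "model.xmodel"), "wb").close()
        print(check_paths("${PROTEUS_ROOT}/model.xmodel", "${PROTEUS_ROOT}/missing.jpg"))
```

In this example, you could call check_paths(path_to_xmodel, path_to_image) and stop early if either value is False.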
Load a Worker
As in the previous example, we can start the server and set up our client object from RestClient.
To make an inference request to a custom XModel, we use the Xmodel worker, which we have to load first.
Some workers accept load-time parameters that configure different options.
The Xmodel worker is one such worker.
Here, we use the "model" parameter to pass the path to the XModel that we want to run.
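```python
import proteus

# Load-time parameters for the Xmodel worker: the path to the XModel to run
parameters = proteus.RequestParameters()
parameters.put("model", path_to_xmodel)
worker_name = client.workerLoad("Xmodel", parameters)

# Block until the worker is ready to accept requests
ready = False
while not ready:
    ready = client.modelReady(worker_name)
```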
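The polling loop above checks as fast as the client can send requests and never gives up. The sketch below is a slightly more defensive variant that waits between checks and stops after a timeout. wait_until_ready and its timeout and interval defaults are hypothetical choices. The readiness check is passed in as a callable, so the sketch assumes nothing about the client API beyond what this example already uses.

```python
import time
from typing import Callable


def wait_until_ready(
    is_ready: Callable[[], bool], timeout: float = 60.0, interval: float = 0.5
) -> bool:
    """Poll is_ready() until it returns True or `timeout` seconds elapse.

    Returns True once ready, or False if the timeout expires first.
    """
    deadline = time.monotonic() + timeout
    while True:
        if is_ready():
            return True
        if time.monotonic() >= deadline:
            return False
        # Sleep between checks instead of busy-waiting
        time.sleep(interval)
```

With the client from this example, the call would be wait_until_ready(lambda: client.modelReady(worker_name)), raising an error if it returns False.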
Get Images
Now we can prepare our request. In this example, we use one image and duplicate it batch_size times to fill one whole batch.
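```python
import cv2

# Read the test image once and copy it to fill one batch
image = cv2.imread(path_to_image)
assert image is not None, f"could not read {path_to_image}"
images = [image.copy() for _ in range(batch_size)]
```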
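Duplicating one image is enough for this test. With several different images, you would split them into requests of batch_size images each. The sketch below shows one way to do that split. to_batches is a hypothetical helper. Padding a short final batch by repeating its last image is only one option; you could send a smaller final request instead.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def to_batches(items: Sequence[T], batch_size: int, pad: bool = True) -> List[List[T]]:
    """Split items into lists of batch_size; optionally pad the last one by repetition."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    batches = [list(items[i:i + batch_size]) for i in range(0, len(items), batch_size)]
    if pad and batches and len(batches[-1]) < batch_size:
        last = batches[-1]
        # Padded entries refer to the same object as the last real item, not copies
        last.extend([last[-1]] * (batch_size - len(last)))
    return batches
```

Calling to_batches([image], batch_size) reproduces the single batch used here. The difference is that the padded entries refer to the same array rather than to copies.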
Inference
Using our images, we can construct a request to AMD Inference Server.
The ImageInferenceRequest class is a helper class that simplifies creating a request in the right format.
It accepts an image or a list of images, plus an optional boolean parameter. This parameter controls whether the request stores the images directly as a tensor of pixel values.
By default, images are sent as base64-encoded strings. However, the Xmodel worker requires tensor data, so we pass True.
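To illustrate the default encoding, the sketch below base64-encodes raw pixel bytes using only the standard library. It does not reproduce what ImageInferenceRequest does internally; the helper may encode the image differently before the base64 step. It only shows the base64 round trip and its size overhead of about 4/3 relative to the raw bytes. encode_pixels and decode_pixels are hypothetical names.

```python
import base64


def encode_pixels(pixels: bytes) -> str:
    """Encode raw bytes as an ASCII base64 string."""
    return base64.b64encode(pixels).decode("ascii")


def decode_pixels(encoded: str) -> bytes:
    """Recover the original bytes from a base64 string."""
    return base64.b64decode(encoded)


if __name__ == "__main__":
    # One 224x224 image with 3 channels of uint8 data, all zeros
    raw = bytes(224 * 224 * 3)
    encoded = encode_pixels(raw)
    assert decode_pixels(encoded) == raw
    print(len(raw), len(encoded))  # 150528 200704
```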
```python
# Pass True so the images are stored as tensors, as the Xmodel worker requires
request = proteus.ImageInferenceRequest(images, True)
response = client.modelInfer(worker_name, request)
assert not response.isError(), response.getError()

outputs = response.getOutputs()
for output in outputs:
    # This model's outputs arrive as raw INT8 values
    assert output.datatype == proteus.DataType.INT8
    recv_data = output.getInt8Data()
```
The data we receive from each output is a list of raw INT8 numbers, not class labels.
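These INT8 numbers are quantized. In Vitis AI, a tensor's fix_point attribute gives the position of the binary point. An INT8 value q therefore corresponds approximately to the real value q * 2^(-fix_point). The sketch below turns raw scores into probabilities with a softmax. It assumes you know the output tensor's fix_point; the value used here is made up for illustration. dequantize and softmax are hypothetical helpers, not part of the client API.

```python
import math
from typing import List, Sequence


def dequantize(values: Sequence[int], fix_point: int) -> List[float]:
    """Map INT8 values to real values, assuming a scale of 2**(-fix_point)."""
    scale = 2.0 ** -fix_point
    return [v * scale for v in values]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max before exponentiating."""
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    raw = [12, -3, 40, 7]  # made-up INT8 scores
    probs = softmax(dequantize(raw, fix_point=2))  # fix_point=2 is an assumption
    print([round(p, 3) for p in probs])
```

Both steps preserve ordering, so they do not change which class ranks first. They matter only when you need probability-like scores.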
Adding pre- and post-processing
Depending on the model, you may need to add pre- and post-processing to get a useful inference result. This model needs both to produce accurate classifications. To double-check our inference, we can compare it against the results we expect. For this test image and XModel, the expected top-5 classifications are listed below. Index 259 is the most probable, and it corresponds to "Pomeranian" in the ImageNet labels.
```python
# Expected top-5 ImageNet class indices for the test image, most probable first
gold_response_output = [259, 261, 260, 157, 154]
```
For pre-processing, you can add custom logic that transforms the images before you construct the request. Similarly, you can apply post-processing after the data is received. The preprocess and postprocess functions used in this example are defined in the source file, examples/python/custom_processing.py.
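The preprocess function is not reproduced here. The sketch below gives a rough idea of what pre-processing for an INT8 DPU model can involve. It subtracts a per-channel mean from each pixel and quantizes the result to INT8 using the input tensor's fix_point. Several things are assumptions:

- the mean values, which are commonly used with TF-Slim ResNet models and are listed here in BGR order because cv2.imread returns BGR;
- the absence of any other scaling;
- the fix_point value.

Resizing and cropping to the model's input size are omitted. MEANS_BGR, quantize, and preprocess_pixels are hypothetical names.

```python
from typing import List, Sequence, Tuple

# Assumed per-channel means in BGR order; verify against your model's training pipeline
MEANS_BGR = (103.94, 116.78, 123.68)


def quantize(value: float, fix_point: int) -> int:
    """Scale by 2**fix_point, round, and clamp to the INT8 range [-128, 127].

    Python's round() rounds halves to even; the DPU toolchain may round differently.
    """
    q = round(value * (2 ** fix_point))
    return max(-128, min(127, q))


def preprocess_pixels(
    pixels: Sequence[Tuple[int, int, int]], fix_point: int
) -> List[Tuple[int, ...]]:
    """Mean-subtract and quantize a flat sequence of (B, G, R) pixels."""
    return [
        tuple(quantize(channel - mean, fix_point) for channel, mean in zip(pixel, MEANS_BGR))
        for pixel in pixels
    ]
```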
```python
# preprocess() and postprocess() are defined in examples/python/custom_processing.py
images = preprocess(images)

# Construct the request and send it
request = proteus.ImageInferenceRequest(images, True)
response = client.modelInfer(worker_name, request)
assert not response.isError(), response.getError()

outputs = response.getOutputs()
for output in outputs:
    assert output.datatype == proteus.DataType.INT8
    recv_data = output.getInt8Data()
    # Can optionally post-process the result: here, get the top-5 class indices
    k = postprocess(recv_data, 5)
    assert k == gold_response_output
```
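The postprocess function is also defined in the source file and is not shown here. The assertion compares its result with a list of class indices, so one plausible core for it is a top-k selection over the output scores. The sketch below implements that selection with a hypothetical top_k helper built on heapq.nlargest. Dequantization with a positive scale and softmax both preserve ordering, so top_k can be applied to the raw INT8 values directly.

```python
import heapq
from typing import List, Sequence


def top_k(scores: Sequence[float], k: int) -> List[int]:
    """Return the indices of the k largest scores, highest first (ties keep index order)."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)


if __name__ == "__main__":
    # Made-up scores over 1000 ImageNet classes with a known top-5 ordering
    scores = [0] * 1000
    for rank, index in enumerate([259, 261, 260, 157, 154]):
        scores[index] = 100 - rank
    print(top_k(scores, 5))  # [259, 261, 260, 157, 154]
```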