A.N.T

Serving an ONNX Model using FastAPI

Learn how to serve an ONNX model with FastAPI.

August 26, 2024

06m Read

By: Abhilaksh Singh Reen

Table of Contents

Server Setup

Requirements

Configuration

FastAPI App and ORT Inference

Testing

Configuration

API Tests

Conclusion

In a previous post, we saw how to train a CNN in PyTorch and export it to ONNX. Today, we'll create a FastAPI server that wraps an ONNX Runtime inference session and exposes an API endpoint to which we can send images and get back predictions.

This GitHub repository contains the code that we'll write today.

Server Setup

Our codebase will have a directory structure similar to the following:

│   requirements.txt
│
├───models
├───src
│       app.py
│       config.py
│       __init__.py
│
└───tests
        config.py
        requirements.txt
        test_apis.py

Requirements

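Inside a virtual environment, install the required packages. FastAPI needs python-multipart to parse file uploads such as the one our endpoint accepts.

pip install fastapi numpy onnxruntime opencv-python pydantic python-multipart uvicorn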

Configuration

Let's work on src/config.py first.

from os.path import dirname, join as path_join


models_dir = path_join(dirname(dirname(__file__)), "models")

This file only contains a variable to store the path to our models directory. That's about it for the configuration.
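If you prefer pathlib, the same "go up two levels from src/config.py to reach the project root" calculation can be written as below. This is only a sketch; get_models_dir is a hypothetical helper, not part of the original code.

```python
from pathlib import Path


def get_models_dir(config_file: str) -> Path:
    """Return <project root>/models given the path of src/config.py.

    Hypothetical helper: .parent is src/, .parent.parent is the project root.
    """
    return Path(config_file).parent.parent / "models"


# Inside src/config.py this would be: models_dir = get_models_dir(__file__)
print(get_models_dir("/home/me/onnx-fastapi/src/config.py"))  # /home/me/onnx-fastapi/models on Linux/macOS
```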

FastAPI App and ORT Inference

Inside the src folder, create a file called app.py. Inside this file, we first import the required packages and create our InferenceSession.

from os.path import join as path_join

import cv2
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
import numpy as np
import onnxruntime

from .config import models_dir


app = FastAPI()

# Replace "training_id" with the folder of your exported model,
# e.g. "torch---2024-04-23-08-25-15".
ONNX_MODEL_FILE_PATH = path_join(models_dir, "training_id", "model.onnx")

ort_session = onnxruntime.InferenceSession(ONNX_MODEL_FILE_PATH)
ort_session_input_name = ort_session.get_inputs()[0].name
ort_session_output_name = ort_session.get_outputs()[0].name

We look up the names of the model's first input and output now, because ort_session.run needs them when we run inference.
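ONNX Runtime also reports the expected shape of each input via ort_session.get_inputs()[0].shape. Dimensions exported as dynamic (a batch axis, for example) appear as a string or None instead of an int. A small helper like the hypothetical shape_matches below could catch a preprocessing mistake before calling run; it is a sketch, not part of the original code.

```python
from typing import Sequence, Union

# An ONNX input dimension: a fixed int, or a str/None for a dynamic axis.
Dim = Union[int, str, None]


def shape_matches(expected: Sequence[Dim], actual: Sequence[int]) -> bool:
    """Compare an ONNX input shape against a NumPy array's shape.

    Dynamic dimensions (str such as "batch_size", or None) accept any size.
    """
    if len(expected) != len(actual):
        return False
    return all(
        not isinstance(exp, int) or exp == act
        for exp, act in zip(expected, actual)
    )


# Assumed usage inside app.py:
# expected = ort_session.get_inputs()[0].shape   # e.g. ["batch_size", 1, 28, 28]
# assert shape_matches(expected, preprocessed_image.shape)
```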

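Next, we define two helper functions: normalize, which L2-normalizes an array along a given axis (for our single-channel input, this maps every pixel to a float between 0 and 1), and preprocess_image, which resizes the image to the model's 28x28 input, adds the batch and channel dimensions, and normalizes it.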

def normalize(x, axis=-1, order=2):
    "Taken from: https://github.com/keras-team/keras/blob/v3.2.1/keras/utils/numerical_utils.py#L7-L34"

    norm = np.atleast_1d(np.linalg.norm(x, order, axis))
    norm[norm == 0] = 1

    axis = axis or -1
    return x / np.expand_dims(norm, axis)


def preprocess_image(image):
    image = cv2.resize(image, (28, 28))
    image = image.astype(np.float32)
    image = image[np.newaxis, np.newaxis, :, :]
    image = normalize(image, axis=1)
    return image

The normalize function is a NumPy port of Keras's normalize utility (linked in its docstring). Writing it ourselves means the inference server has no dependency on Torch or TensorFlow.
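It helps to know what normalize(image, axis=1) actually computes for this input. It divides each value by the L2 norm along axis 1, and for an array of shape (1, 1, 28, 28) that axis has length 1, so each pixel's norm is simply its absolute value: every non-zero pixel becomes 1.0 and black pixels stay 0.0. The sketch below re-creates normalize and checks this on a fake image; scale_to_unit_range is a hypothetical alternative shown only for comparison. Whichever scheme you use has to match the preprocessing the model was trained with.

```python
import numpy as np


def normalize(x, axis=-1, order=2):
    # Same logic as the normalize helper in app.py (ported from Keras).
    norm = np.atleast_1d(np.linalg.norm(x, order, axis))
    norm[norm == 0] = 1
    axis = axis or -1
    return x / np.expand_dims(norm, axis)


def scale_to_unit_range(x):
    # Hypothetical alternative: map uint8 [0, 255] linearly onto [0.0, 1.0].
    return x.astype(np.float32) / 255.0


# A fake 28x28 grayscale image: black background with one mid-gray stroke.
image = np.zeros((28, 28), dtype=np.uint8)
image[10:18, 14] = 128
batch = image.astype(np.float32)[np.newaxis, np.newaxis, :, :]  # (1, 1, 28, 28)

l2 = normalize(batch, axis=1)
print(np.unique(l2))                          # [0. 1.] -> effectively binarized
print(np.unique(scale_to_unit_range(image)))  # gray level 128 maps to about 0.502
```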

With that done, we can move on to creating our inference route.

@app.post("/api/run-inference")
async def run_inference(file: UploadFile = File(...)):
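    file_contents = await file.read()
    np_arr = np.frombuffer(file_contents, np.uint8)
    image = cv2.imdecode(np_arr, cv2.IMREAD_GRAYSCALE)

    # cv2.imdecode returns None when the bytes are not a valid image.
    if image is None:
        return JSONResponse(
            content={"detail": "Could not decode the uploaded file as an image."},
            status_code=400
        )

    # preprocess_image resizes to 28x28, so no separate resize is needed here.
    preprocessed_image = preprocess_image(image)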

    prediction = ort_session.run(
        [ort_session_output_name],
        {
            ort_session_input_name: preprocessed_image,
        }
    )

    predicted_label = int(np.argmax(prediction))

    return JSONResponse(
        content={
            "predicted_label": predicted_label,
        },
        status_code=200
    )

The route accepts an uploaded image file. We read its bytes into a NumPy array, decode that into a grayscale OpenCV image, and pass it to preprocess_image to get the input tensor for the model. ort_session.run then returns the model's scores for all 10 classes, and numpy.argmax gives the index of the highest-scoring class. Since the classes are the digits 0 to 9, this index is the predicted label, which we return in a JSONResponse.
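ort_session.run returns a list with one array per requested output, so prediction in the route is a list holding a single array of shape (1, 10); np.argmax works on it because NumPy converts the list to an array and flattens it. If you also want to return a confidence score, a post-processing sketch could look like the one below. It assumes the model's output is raw logits (check your export; if the network already ends in a softmax, skip that step), and postprocess is a hypothetical helper.

```python
import numpy as np


def postprocess(outputs):
    """Turn the list returned by ort_session.run into (label, confidence).

    Assumes a single output of shape (1, num_classes) containing logits.
    """
    logits = np.asarray(outputs[0])[0]      # shape: (num_classes,)
    shifted = logits - logits.max()         # subtract the max for numerical stability
    probabilities = np.exp(shifted) / np.exp(shifted).sum()
    label = int(np.argmax(probabilities))
    return label, float(probabilities[label])


# Fake model output standing in for ort_session.run(...); class 7 has the highest logit.
fake_outputs = [
    np.array([[0.1, -1.2, 0.3, 0.0, 0.5, -0.7, 0.2, 4.0, 0.9, -0.3]], dtype=np.float32)
]
label, confidence = postprocess(fake_outputs)
print(label, round(confidence, 3))
```

In the route, this would replace the np.argmax line, and the confidence could be added to the JSON response next to predicted_label.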

That's it for the code. From the project root (the folder that contains src), start the server using:

uvicorn src.app:app --port=8000 --host=0.0.0.0

You should see a message that says uvicorn is running on 0.0.0.0:8000.

Testing

Inside the project directory, at the same level as the src folder, create a new folder called tests.

Configuration

In the tests folder, create a new file called config.py.

from os.path import dirname, join as path_join


data_dir = path_join(dirname(dirname(__file__)), "data")

Here, we define the path to the data directory from which we'll load our test images.

API Tests

Great, now let's create another file in the same directory called test_apis.py. We'll use Python's unittest module along with requests to test our API. FastAPI provides a TestClient that could be used for this purpose. However, in later posts, we'll deploy the same ONNX model on an Express server in Node.js, so I'll write API tests that send real HTTP requests and are therefore independent of the web server.

First things first, we have to install the required packages.

pip install requests

Now, inside test_apis.py, we import the required packages and set up our unittest TestCase.

from json import load as json_load
from os import listdir
from os.path import join as path_join
import unittest

import requests

from config import data_dir


class TestAPIInference(unittest.TestCase):
    def setUp(self):
        self.url = "http://localhost:8000/api/run-inference"
        self.images_dir = path_join(data_dir, "test_images")
        self.labels_json_file_path = path_join(data_dir, "test_images_labels.json")

        with open(self.labels_json_file_path, 'r') as labels_json_file:
            self.labels = json_load(labels_json_file)

For this to work, inside the specified data directory, I have created a folder called test_images that contains some digits drawn in MS Paint (white brush on a black background). Also, in the data directory, I have a file called test_images_labels.json that contains the labels corresponding to the test images.

Here's the labels file:

{
    "1.png": 1,
    "2.png": 7,
    "3.png": 2,
    "4.png": 9,
    "5.png": 8,
    "6.png": 5,
    "7.png": 1,
    "8.png": 7,
    "9.png": 1,
    "10.png": 7,
    "11.png": 7,
    "12.png": 0,
    "13.png": 5,
    "14.png": 3,
    "15.png": 2,
    "16.png": 1,
    "17.png": 0,
    "18.png": 8,
    "19.png": 7,
    "20.png": 4
}
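Because the test looks up every file returned by listdir in this JSON, a stray file in test_images (or a missing entry) would raise a KeyError in the middle of a run. A small sanity check like the hypothetical check_labels below could be run first; it is a sketch and not part of the original code.

```python
import json
import os


def check_labels(images_dir: str, labels_json_path: str) -> list:
    """Return a list of mismatches between the images folder and the labels file."""
    with open(labels_json_path, "r") as labels_file:
        labels = json.load(labels_file)

    problems = []
    image_names = set(os.listdir(images_dir))
    label_names = set(labels)

    for name in sorted(image_names - label_names):
        problems.append(f"{name}: no entry in the labels file")
    for name in sorted(label_names - image_names):
        problems.append(f"{name}: labelled but missing from {images_dir}")
    for name, label in sorted(labels.items()):
        # MNIST-style digits: every label should be an int from 0 to 9.
        if not isinstance(label, int) or not 0 <= label <= 9:
            problems.append(f"{name}: label {label!r} is not a digit 0-9")
    return problems


# Assumed usage from the tests folder, with data_dir from tests/config.py:
# print(check_labels(path_join(data_dir, "test_images"),
#                    path_join(data_dir, "test_images_labels.json")))
```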

Now, inside test_apis.py, in the TestAPIInference class, we define a method to test our API.

def test_inference(self):
    num_correct_predictions = 0
    num_bad_responses = 0
    all_image_names = listdir(self.images_dir)

    for image_name in all_image_names:
        image_file_path = path_join(self.images_dir, image_name)

        with open(image_file_path, "rb") as file:
            response = requests.post(self.url, files={"file": file}, timeout=30)

        # Count failed requests instead of stopping at the first one,
        # so the summary below covers every image.
        if response.status_code != 200:
            num_bad_responses += 1
            print(f"Image: {image_name}, Status Code: {response.status_code}")
            continue

        predicted_label = response.json()["predicted_label"]
        is_correct = predicted_label == self.labels[image_name]
        if is_correct:
            num_correct_predictions += 1

        print(
            f"Image: {image_name}, "
            f"Predicted Label: {predicted_label}, "
            f"Correct: {is_correct}"
        )

    num_inferred = len(all_image_names) - num_bad_responses
    accuracy = num_correct_predictions / num_inferred if num_inferred else 0.0

    print(f"Failed to infer: {num_bad_responses} / {len(all_image_names)}")
    print(f"Correct Predictions: {num_correct_predictions} / {num_inferred}, Accuracy: {accuracy}")

    # Fail only on bad responses; wrong predictions are reported, not asserted.
    self.assertEqual(num_bad_responses, 0)

Note that the job of this test is only to check that the API is functioning correctly, not to measure the model's accuracy. Although we compute the accuracy, a wrong prediction does not cause the test to fail.

Finally, in the main block, we can call unittest.main.

if __name__ == '__main__':
    unittest.main()

Here's the entire test_apis.py file:

from json import load as json_load
from os import listdir
from os.path import join as path_join
import unittest

import requests

from config import data_dir


class TestAPIInference(unittest.TestCase):
    def setUp(self):
        self.url = "http://localhost:8000/api/run-inference"
        self.images_dir = path_join(data_dir, "test_images")
        self.labels_json_file_path = path_join(data_dir, "test_images_labels.json")

        with open(self.labels_json_file_path, 'r') as labels_json_file:
            self.labels = json_load(labels_json_file)

    def test_inference(self):
        num_correct_predictions = 0
        num_bad_responses = 0
        all_image_names = listdir(self.images_dir)

        for image_name in all_image_names:
            image_file_path = path_join(self.images_dir, image_name)

            with open(image_file_path, "rb") as file:
                response = requests.post(self.url, files={"file": file}, timeout=30)

            # Count failed requests instead of stopping at the first one,
            # so the summary below covers every image.
            if response.status_code != 200:
                num_bad_responses += 1
                print(f"Image: {image_name}, Status Code: {response.status_code}")
                continue

            predicted_label = response.json()["predicted_label"]
            is_correct = predicted_label == self.labels[image_name]
            if is_correct:
                num_correct_predictions += 1

            print(
                f"Image: {image_name}, "
                f"Predicted Label: {predicted_label}, "
                f"Correct: {is_correct}"
            )

        num_inferred = len(all_image_names) - num_bad_responses
        accuracy = num_correct_predictions / num_inferred if num_inferred else 0.0

        print(f"Failed to infer: {num_bad_responses} / {len(all_image_names)}")
        print(f"Correct Predictions: {num_correct_predictions} / {num_inferred}, Accuracy: {accuracy}")

        # Fail only on bad responses; wrong predictions are reported, not asserted.
        self.assertEqual(num_bad_responses, 0)


if __name__ == '__main__':
    unittest.main()

With the server running, let's run this file in another terminal:

python test_apis.py

You should see an output similar to the following:

Image: 1.png, Predicted Label: 6, Correct: False
Image: 10.png, Predicted Label: 7, Correct: True
Image: 11.png, Predicted Label: 7, Correct: True
Image: 12.png, Predicted Label: 0, Correct: True
Image: 13.png, Predicted Label: 5, Correct: True
Image: 14.png, Predicted Label: 3, Correct: True
Image: 15.png, Predicted Label: 2, Correct: True
Image: 16.png, Predicted Label: 1, Correct: True
Image: 17.png, Predicted Label: 0, Correct: True
Image: 18.png, Predicted Label: 6, Correct: False
Image: 19.png, Predicted Label: 7, Correct: True
Image: 2.png, Predicted Label: 7, Correct: True
Image: 20.png, Predicted Label: 4, Correct: True
Image: 3.png, Predicted Label: 2, Correct: True
Image: 4.png, Predicted Label: 9, Correct: True
Image: 5.png, Predicted Label: 8, Correct: True
Image: 6.png, Predicted Label: 5, Correct: True
Image: 7.png, Predicted Label: 1, Correct: True
Image: 8.png, Predicted Label: 7, Correct: True
Image: 9.png, Predicted Label: 1, Correct: True
Failed to infer: 0 / 20
Correct Predictions: 18 / 20, Accuracy: 0.9
.
----------------------------------------------------------------------
Ran 1 test in 40.873s

OK

Great, our APIs are functioning as intended.

Conclusion

Today, we've learned a simple way of creating an inference API for our own ML model. Most models exported to ONNX can be deployed behind a similar API, with the corresponding changes to pre- and post-processing. Today's tutorial was not even a hundred lines of code!

Deployment of ONNX models is not restricted to Python; it can be done in any language for which ONNX Runtime is available. In this article, we deploy the same model in an Express.js server.

But what if you don't want to go the ONNX route? If you wish to run inference with PyTorch or TensorFlow, you can check out this article (PyTorch) or this one (TensorFlow), where we use FastAPI to build the exact same inference API, but with the respective framework powering the ML backend.

See you next time :)
