Using Docker to Generate Machine Learning Predictions in Real Time

Figure 1. A REST API serves as the communication layer between a machine learning model and incoming data.

Introduction

In Part III of our Docker for Machine Learning series, we learned how to use Docker to perform model training and batch inference. Batch inference is a simple and effective approach for generating predictions when latency requirements are on the order of hours or days, but it can't be used when predictions are required in real time. When we need to generate predictions on the fly, we need online inference. In this post, we'll learn how to use Docker to perform online inference.

Online inference is more complex than batch inference due to tighter latency requirements. A production machine learning system that needs to respond to requests with predictions within 100 milliseconds is generally harder to implement than a system that has 24 hours to produce predictions. Within those 100 milliseconds the ML system needs to retrieve input data, perform inference, validate the model output, and then (typically) return the results over a network.
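One way to make that budget concrete is to time each stage of a request and compare the total against the 100 millisecond target. The sketch below is a minimal, hypothetical illustration: the helper name timed_stages, the four stand-in stages, and the toy linear "model" are all assumptions, not part of the service built in this post.

```python
import time

# Latency budget from the example above: 100 ms per request.
BUDGET_MS = 100.0


def timed_stages(stages, budget_ms=BUDGET_MS):
    """Run (name, fn) pairs in order, feeding each stage's output to the next.

    Returns the final value, a dict of per-stage timings in milliseconds,
    and whether the total stayed within budget_ms.
    """
    timings = {}
    value = None
    for name, fn in stages:
        start = time.perf_counter()
        value = fn(value)
        timings[name] = (time.perf_counter() - start) * 1000.0
    return value, timings, sum(timings.values()) <= budget_ms


# Hypothetical stand-in stages; a real service would query data stores,
# call a trained model, and so on.
stages = [
    ("retrieve", lambda _: {"RM": 5.3, "LSTAT": 24.9}),   # fetch input data
    ("infer", lambda x: 30.0 - 0.5 * x["LSTAT"]),          # toy linear "model"
    ("validate", lambda y: max(y, 0.0)),                   # sanity-check output
    ("serialize", lambda y: {"prediction": y}),            # build the response
]

result, timings, within_budget = timed_stages(stages)
print(result, within_budget)
```

In a real system the per-stage timings, not just the total, are what tell you where to optimize.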

Although online inference is more complex, it's not black magic, especially if you have any web development experience. In this post, we're going to learn how to implement online inference using Docker and Flask-RESTful. Specifically, you'll learn:

  1. What online inference is and how UberEats uses it to estimate delivery times.
  2. How online inference is implemented using a REST API.
  3. How to implement a REST API using Docker and Flask-RESTful.

What is Online Inference?

Online inference (a.k.a. real-time or dynamic inference) is the process of generating machine learning predictions upon request. These predictions are typically generated on the fly for a single instance of data, so they are available at any time, as long as the service exposing the model is available.


An Industry Example: UberEats

As an example, consider the estimated-time-to-delivery feature in the UberEats app. Every time a user orders food through the app, a machine learning model predicts the delivery time of the order based on features including the time of day, the delivery location, and the average meal prep time at the restaurant. It wouldn't be feasible to generate these predictions in batch because users want to know approximately when their food will arrive as soon as they place an order. Therefore, Uber needs to deploy its models so that they return this information to users as they order food.

Let's take a moment to appreciate the sheer complexity of this feature. From a modeling perspective, Uber data scientists have to build machine learning models to predict the duration of a complex multi-stage process. To accurately predict the time of delivery, the models need to know about the complexity of the meal being prepared, the load the restaurant is currently facing, and the time required for a driver to be dispatched, stop to pick up the food, and then drive to the final delivery location. From an infrastructure perspective, the Uber team needs to generate dozens (if not hundreds) of features, including historical averages, that very likely consume data from a variety of data stores. From a product perspective, the predictions must be returned to millions of users, across devices, within a couple of seconds (my guess is 1 to 2 seconds at most). All over the world :O

Figure 2. Image via Uber blog

What is a REST API?

If we want to generate predictions in real time, we need to expose our trained models. One way to expose our models is to implement a REST API. REST (representational state transfer) is a software architecture style that defines a way of creating web services. A REST API allows clients to request resources from other hosts, known as servers. In a RESTful web service, the client requests a specific resource by referencing the resource's Uniform Resource Identifier (URI). The server then responds to the request with a representation of the resource formatted as HTML, JSON, or some other format. These request-response pairs are typically transferred over HTTP.
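To see what such a request looks like on the wire, the sketch below composes the raw text of an HTTP/1.1 POST carrying a JSON body. It is a simplified illustration under stated assumptions: the helper name raw_post_request is hypothetical, and real clients such as curl send additional headers (User-Agent, Accept, and so on).

```python
import json


def raw_post_request(host, path, payload):
    """Compose the raw bytes of a simplified HTTP/1.1 JSON POST request.

    The request line names the method and the resource's URI path; the
    headers describe the body; a blank line separates headers from body.
    """
    body = json.dumps(payload).encode("utf-8")
    head = (
        "POST {path} HTTP/1.1\r\n"
        "Host: {host}\r\n"
        "Content-Type: application/json\r\n"
        "Content-Length: {length}\r\n"
        "\r\n"
    ).format(path=path, host=host, length=len(body))
    return head.encode("ascii") + body


request_bytes = raw_post_request("127.0.0.1:5000", "/predict", {"LSTAT": 24.9})
print(request_bytes.decode("utf-8"))
```

The server's reply has the same shape: a status line such as "HTTP/1.0 200 OK", headers, a blank line, and the JSON body.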

Defining our Prediction API

We will define a very simple REST API to serve our model predictions. We define a Prediction resource, which accepts the features of an individual instance of data and responds with a model prediction. The API is as follows:

HTTP Method    URI                          Action
POST           http://[hostname]/predict    Make a prediction

Implementing Online Inference with Docker and Flask-RESTful

For context, let’s examine the files that we’re going to work with in this post. Our directory structure looks like this:

code
├── api.py
├── Dockerfile
├── requirements.txt
└── train.py

Here's a rundown of each file:

  • api.py – This script will be called to perform online inference. It will define and launch our REST API.
  • Dockerfile – The Dockerfile to build our Docker image.
  • requirements.txt – Python dependencies.
  • train.py – This file is the same as the train.py file found in Part III.

Here is the requirements.txt file.

flask
flask-restful
joblib

Flask-RESTful API Implementation

Now that we have discussed what a REST API is, let's implement one using Flask-RESTful. Flask-RESTful provides a Resource base class for the resources in our API. We subclass Resource, register the subclass at a URL, and define a method named after each HTTP method we want that URL to support (for example, get or post).
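The idea of "one Python method per HTTP method" can be sketched without Flask at all. The toy base class below is a simplified illustration only, not Flask-RESTful's source: the names MiniResource and Echo are hypothetical, and the 405 response for unimplemented methods mirrors standard HTTP semantics.

```python
HTTP_METHODS = {"GET", "POST", "PUT", "PATCH", "DELETE"}


class MiniResource:
    """Toy base class: send each HTTP method to a same-named lowercase method.

    A simplified illustration of the idea only, not Flask-RESTful's code.
    """

    def dispatch(self, http_method, payload=None):
        method = http_method.upper()
        handler = getattr(self, method.lower(), None) if method in HTTP_METHODS else None
        if handler is None:
            # Unknown or unimplemented verbs map to 405 Method Not Allowed.
            return {"message": "Method Not Allowed"}, 405
        return handler(payload), 200


class Echo(MiniResource):
    # Only POST is implemented, so GET, PUT, and so on return 405.
    def post(self, payload):
        return {"received": payload}


resource = Echo()
print(resource.dispatch("POST", {"x": 1}))  # ({'received': {'x': 1}}, 200)
print(resource.dispatch("GET"))             # ({'message': 'Method Not Allowed'}, 405)
```

Restricting lookups to known HTTP verbs keeps arbitrary strings from reaching other attributes, such as dispatch itself.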

To define our Prediction resource, we define a Prediction subclass of Resource, implement the post method that generates and returns a prediction, and register the resource with the /predict endpoint. Here is the class:

class Prediction(Resource):
    def __init__(self):
        self._required_features = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM',
                                   'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B',
                                   'LSTAT']
        self.reqparse = reqparse.RequestParser()
        for feature in self._required_features:
            self.reqparse.add_argument(
                feature, type=float, required=True, location='json',
                help='No {} provided'.format(feature))
        super(Prediction, self).__init__()

    def post(self):
        args = self.reqparse.parse_args()
        X = np.array([args[f] for f in self._required_features]).reshape(1, -1)
        y_pred = clf.predict(X)
        return {'prediction': y_pred.tolist()[0]}

We first define an instance attribute, self._required_features, a list containing the required input features. Next, we instantiate another attribute, self.reqparse, of type flask_restful.reqparse.RequestParser. This class is responsible for declaring the arguments a request must contain and parsing them from the incoming request. We add each feature in self._required_features to the list of expected arguments using the RequestParser.add_argument method. The additional keyword arguments passed to this method (type, required, location, and help) control how the incoming request data is validated: each feature must be present (required=True), is read from the JSON body (location='json'), is converted with float (type=float), and produces the help message in the error response if it fails validation.
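To make the validation rules concrete, here is a rough, standard-library-only approximation of what the parser does for required float arguments. It is a sketch, not Flask-RESTful's implementation: the function name parse_required_floats is hypothetical, and it assumes (as the 400 response later in this post suggests) that only the first failing field is reported, using the help text as the message.

```python
REQUIRED_FEATURES = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE',
                     'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']


def parse_required_floats(payload, required=REQUIRED_FEATURES):
    """Rough approximation of RequestParser with required float arguments.

    Returns (args, None) on success, or (None, error_body) for the first
    missing or non-numeric field, reusing the 'No {} provided' help text.
    """
    if not isinstance(payload, dict):
        payload = {}  # e.g. an empty or non-object JSON body
    args = {}
    for name in required:
        try:
            args[name] = float(payload[name])  # mirrors type=float
        except (KeyError, TypeError, ValueError):
            return None, {"message": {name: "No {} provided".format(name)}}
    return args, None


complete = {f: 1.0 for f in REQUIRED_FEATURES}
print(parse_required_floats(complete)[1])       # None
missing_lstat = {f: 1.0 for f in REQUIRED_FEATURES if f != "LSTAT"}
print(parse_required_floats(missing_lstat)[1])  # {'message': {'LSTAT': 'No LSTAT provided'}}
```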

The post method defines what happens when a client issues an HTTP POST request to the Prediction resource. First, we collect the incoming data by calling self.reqparse.parse_args(). This automatically validates the request data according to the parameters we passed to self.reqparse.add_argument in the __init__ method. Next, we create a two-dimensional NumPy array of shape (1, 13) from the request data, because scikit-learn models expect 2D input of shape (n_samples, n_features), even for a single sample. We then generate a prediction. Finally, we return a dict containing a single prediction key and the numeric prediction. Flask-RESTful automatically converts this dict to JSON for us.
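The same flow (ordered features, a single row, predict, return a dict) can be traced with plain Python lists. In this sketch, AverageModel is a hypothetical stand-in for the trained scikit-learn model and handle_post is an illustrative name; a list containing one row plays the role of .reshape(1, -1). The original code also calls .tolist() so that NumPy values become native Python types before serialization.

```python
import json

REQUIRED_FEATURES = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE',
                     'DIS', 'RAD', 'TAX', 'PTRATIO', 'B', 'LSTAT']


class AverageModel:
    """Hypothetical stand-in for the trained scikit-learn model."""

    def predict(self, X):
        # One output per input row, like a scikit-learn regressor's predict().
        return [sum(row) / len(row) for row in X]


def handle_post(args, model, required=REQUIRED_FEATURES):
    # Put the features in a fixed order and wrap them in a single row,
    # giving shape (1, n_features): the list equivalent of .reshape(1, -1).
    X = [[args[f] for f in required]]
    y_pred = model.predict(X)
    # Take the only prediction and return a JSON-serializable dict.
    return {"prediction": y_pred[0]}


args = {f: 2.0 for f in REQUIRED_FEATURES}
response = handle_post(args, AverageModel())
print(json.dumps(response))  # {"prediction": 2.0}
```

Keeping the feature order in one list is what guarantees that each value lands in the column the model was trained on.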

Here is the complete api.py file, including the Prediction class.

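import os

from flask import Flask
from flask_restful import Resource, Api, reqparse
from joblib import load
import numpy as np

# Model locations are injected via ENV instructions in the Dockerfile.
MODEL_DIR = os.environ["MODEL_DIR"]
MODEL_FILE = os.environ["MODEL_FILE"]
METADATA_FILE = os.environ["METADATA_FILE"]
MODEL_PATH = os.path.join(MODEL_DIR, MODEL_FILE)
METADATA_PATH = os.path.join(MODEL_DIR, METADATA_FILE)

# Load the trained model once at startup rather than on every request.
print("Loading model from: {}".format(MODEL_PATH))
clf = load(MODEL_PATH)

app = Flask(__name__)
api = Api(app)

class Prediction(Resource):
    def __init__(self):
        # Required input features; the order must match the columns used in train.py.
        self._required_features = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM',
                                   'AGE', 'DIS', 'RAD', 'TAX', 'PTRATIO', 'B',
                                   'LSTAT']
        # Each feature must appear in the JSON body and be convertible to float.
        self.reqparse = reqparse.RequestParser()
        for feature in self._required_features:
            self.reqparse.add_argument(
                feature, type=float, required=True, location='json',
                help='No {} provided'.format(feature))
        super(Prediction, self).__init__()

    def post(self):
        # Invalid or missing fields produce an HTTP 400 before the model is called.
        args = self.reqparse.parse_args()
        # A single instance becomes a 2D array of shape (1, n_features).
        X = np.array([args[f] for f in self._required_features]).reshape(1, -1)
        y_pred = clf.predict(X)
        # Flask-RESTful serializes the returned dict to JSON.
        return {'prediction': y_pred.tolist()[0]}

# Route requests for /predict to the Prediction resource.
api.add_resource(Prediction, '/predict')

if __name__ == '__main__':
    # Development server; host='0.0.0.0' makes it reachable through Docker's
    # published port, and debug=True enables the reloader and debugger.
    app.run(debug=True, host='0.0.0.0')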

If you read my post on using Docker for batch inference, you'll recognize parts of this file. We first retrieve the environment variables containing the path to the trained model, and then we deserialize the model using the joblib package. Next, we create a Flask app object and use it to instantiate a Flask-RESTful Api object.
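Because api.py reads its paths with os.environ["..."], running it outside the docker-api image (where the ENV instructions are not set) fails with a bare KeyError. A small helper can turn that into a clearer message. This is a hypothetical sketch: load_config and its error text are illustrative, and the variable names are the ones set in the Dockerfile.

```python
import os


def load_config(environ=os.environ):
    """Build model and metadata paths from the Dockerfile's ENV variables.

    Raises a descriptive RuntimeError instead of a bare KeyError when a
    variable is missing, e.g. when api.py runs outside the docker-api image.
    """
    names = ("MODEL_DIR", "MODEL_FILE", "METADATA_FILE")
    missing = [name for name in names if name not in environ]
    if missing:
        raise RuntimeError(
            "Missing environment variable(s): {}. Are you running inside the "
            "docker-api container?".format(", ".join(missing)))
    return {
        "model_path": os.path.join(environ["MODEL_DIR"], environ["MODEL_FILE"]),
        "metadata_path": os.path.join(environ["MODEL_DIR"], environ["METADATA_FILE"]),
    }


# Example with the values from the Dockerfile, passed explicitly.
print(load_config({"MODEL_DIR": "/home/jovyan/model",
                   "MODEL_FILE": "clf.joblib",
                   "METADATA_FILE": "metadata.json"})["model_path"])
```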

One thing I want to call out is api.add_resource(Prediction, '/predict'). This line of code is responsible for associating the Prediction class with the /predict endpoint. This way, when client code issues a POST request to /predict, the Prediction.post method is called to generate the prediction.
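Conceptually, add_resource fills in a routing table: a URL path maps to a resource class, and the HTTP method picks which of its methods runs. The toy router below sketches that idea under stated assumptions; MiniApi and FakePrediction are hypothetical names, and this is not how Flask-RESTful is implemented internally.

```python
HTTP_METHODS = {"GET", "POST", "PUT", "PATCH", "DELETE"}


class MiniApi:
    """Toy router illustrating what add_resource sets up (not Flask-RESTful's code)."""

    def __init__(self):
        self._routes = {}

    def add_resource(self, resource_cls, path):
        # Associate a URL path with a resource class.
        self._routes[path] = resource_cls

    def handle(self, http_method, path, payload=None):
        resource_cls = self._routes.get(path)
        if resource_cls is None:
            return {"message": "Not Found"}, 404
        method = http_method.upper()
        # Create a fresh resource instance for each request.
        resource = resource_cls()
        handler = getattr(resource, method.lower(), None) if method in HTTP_METHODS else None
        if handler is None:
            return {"message": "Method Not Allowed"}, 405
        return handler(payload), 200


class FakePrediction:
    """Hypothetical stand-in for the Prediction resource."""

    def post(self, payload):
        return {"prediction": 0.0}


api = MiniApi()
api.add_resource(FakePrediction, "/predict")
print(api.handle("POST", "/predict", {}))     # ({'prediction': 0.0}, 200)
print(api.handle("POST", "/prediction", {}))  # ({'message': 'Not Found'}, 404)
```

The second call shows why the exact path matters: a client posting to /prediction never reaches the model.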

Building a Docker Image for Online Inference

With our Python implementation complete, let’s build our Docker image. As described in Part II of the series, we need a Dockerfile that defines how to build the image. Here it is.

FROM jupyter/scipy-notebook

COPY requirements.txt ./requirements.txt
RUN pip install -r requirements.txt

RUN mkdir model
ENV MODEL_DIR=/home/jovyan/model
ENV MODEL_FILE=clf.joblib
ENV METADATA_FILE=metadata.json

COPY train.py ./train.py
COPY api.py ./api.py

RUN python3 train.py

Similar to the Dockerfile used in Part III, we start with the jupyter/scipy-notebook base image, install our Python dependencies, set a few environment variables, copy two files into the image, and then train our model. Because train.py runs during the build, the trained model is stored inside the image. To build this image, run the following command:

docker build -t docker-api -f Dockerfile .

Below is the output of running the command:

$ docker build -t docker-api -f Dockerfile .
Sending build context to Docker daemon  8.704kB
Step 1/10 : FROM jupyter/scipy-notebook
 ---> 2fb85d5904cc
Step 2/10 : COPY requirements.txt ./requirements.txt
 ---> Using cache
 ---> f7d3df033bb4
Step 3/10 : RUN pip install -r requirements.txt
 ---> Using cache
 ---> 53c7c84910fd
Step 4/10 : RUN mkdir model
 ---> Using cache
 ---> c6f10f206379
Step 5/10 : ENV MODEL_DIR=/home/jovyan/model
 ---> Using cache
 ---> b42d1d794ca7
Step 6/10 : ENV MODEL_FILE=clf.joblib
 ---> Using cache
 ---> b53a23eddebf
Step 7/10 : ENV METADATA_FILE=metadata.json
 ---> Using cache
 ---> 63de64423761
Step 8/10 : COPY train.py ./train.py
 ---> Using cache
 ---> a1c900a04e0e
Step 9/10 : COPY api.py ./api.py
 ---> 43164e1e6ae9
Step 10/10 : RUN python3 train.py
 ---> Running in e0a2bca441c1
Loading data...
Splitting data...
Fitting model...
Serializing model to: /home/jovyan/model/clf.joblib
Serializing metadata to: /home/jovyan/model/metadata.json
Removing intermediate container e0a2bca441c1
 ---> ec6a8b813b78
Successfully built ec6a8b813b78
Successfully tagged docker-api:latest

Running Online Inference

Let's summarize what we've accomplished. We implemented a basic REST API in Python using the Flask-RESTful package. Each time a client issues a POST request to the /predict endpoint with the appropriate request data, our Prediction class generates a prediction using a pre-trained model.

Now we will launch the web server, which lets us issue POST requests to the endpoint and retrieve predictions. To do so, we run a Docker container that executes the api.py script. The command is:

docker run -it -p 5000:5000 docker-api python3 api.py

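The -i and -t flags keep STDIN open and allocate a pseudo-terminal, so the container runs interactively: we see the server's logs in our terminal and can stop it with CTRL+C. The -p 5000:5000 flag publishes port 5000 in the container on port 5000 of our host machine, which is the port Flask listens on by default. Finally, python3 api.py overrides the image's default command, so the container starts the Flask development server.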

Here is the output from the command.

$ docker run -it -p 5000:5000 docker-api python3 api.py
Loading model from: /home/jovyan/model/clf.joblib
 * Serving Flask app "api" (lazy loading)
 * Environment: production
   WARNING: Do not use the development server in a production environment.
   Use a production WSGI server instead.
 * Debug mode: on
 * Running on http://0.0.0.0:5000/ (Press CTRL+C to quit)
 * Restarting with stat
Loading model from: /home/jovyan/model/clf.joblib
 * Debugger is active!
 * Debugger PIN: 340-626-715
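If you script this step (for example, starting the container in the background with -d instead of -it), the server needs a moment to load the model before it can answer. The hypothetical helper below polls the endpoint until any HTTP response comes back. It sends a real HTTP request rather than only opening a TCP connection, on the assumption that Docker's port publishing may accept connections before Flask is listening. A GET to /predict is expected to return 405 (Method Not Allowed), since only POST is implemented, which still shows the server is up.

```python
import http.client
import time
import urllib.error
import urllib.request


def wait_for_http(url, timeout_s=30.0, interval_s=0.5, request_timeout_s=2.0):
    """Poll url until the server sends any HTTP response or timeout_s elapses.

    Returns True as soon as a response arrives (any status code), else False.
    """
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url, timeout=request_timeout_s):
                return True
        except urllib.error.HTTPError:
            return True  # e.g. 405 Method Not Allowed: the server answered
        except (urllib.error.URLError, http.client.HTTPException, OSError):
            time.sleep(interval_s)  # refused, reset, or timed out; try again
    return False
```

For example, wait_for_http("http://127.0.0.1:5000/predict") returns once the container is ready to accept prediction requests.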

To actually make a request, we'll use curl. Open another terminal window and run the following command:

curl -i -H "Content-Type: application/json" -X POST \
	-d '{"CRIM": 15.02, "ZN": 0.0, "INDUS": 18.1, "CHAS": 0.0, "NOX": 0.614,
		 "RM": 5.3, "AGE": 97.3, "DIS": 2.1, "RAD": 24.0, "TAX": 666.0,
		 "PTRATIO": 20.2, "B": 349.48, "LSTAT": 24.9}' \
	    127.0.0.1:5000/predict

Here we use curl to issue a POST request to 127.0.0.1:5000/predict, which reaches the container through the published port, and pass every feature required for inference in the JSON request body. And here is the moment of truth…

$ curl -i -H "Content-Type: application/json" -X POST \
	-d '{"CRIM": 15.02, "ZN": 0.0, "INDUS": 18.1, "CHAS": 0.0, "NOX": 0.614,
		 "RM": 5.3, "AGE": 97.3, "DIS": 2.1, "RAD": 24.0, "TAX": 666.0,
		 "PTRATIO": 20.2, "B": 349.48, "LSTAT": 24.9}' \
	    127.0.0.1:5000/predict
HTTP/1.0 200 OK
Content-Type: application/json
Content-Length: 41
Server: Werkzeug/0.15.0 Python/3.6.8
Date: Wed, 27 Mar 2019 22:19:35 GMT

{
    "prediction": 12.273424794987879
}

Tada! Our API server returned a response containing a prediction. If you scroll back to the terminal running the web server, you should see a line resembling the following:

172.17.0.1 - - [27/Mar/2019 22:19:35] "POST /predict HTTP/1.1" 200 -

which indicates that a POST request was made to the /predict endpoint and returned successfully.
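Access lines like this one are also a simple source of usage data, for example for counting successful versus rejected requests. The parser below is a sketch: parse_access_line is a hypothetical name, and its regular expression is inferred from the sample line above, so other Werkzeug versions or logging configurations may need a different pattern.

```python
import re

# Pattern inferred from the sample access line above; other Werkzeug
# versions or log configurations may format lines differently.
ACCESS_LINE = re.compile(
    r'^(?P<client>\S+) - - \[(?P<time>[^\]]+)\] '
    r'"(?P<method>[A-Z]+) (?P<path>\S+) (?P<protocol>[^"]+)" '
    r'(?P<status>\d{3}) (?P<size>\S+)$'
)


def parse_access_line(line):
    """Return a dict of fields from one access-log line, or None if it doesn't match."""
    match = ACCESS_LINE.match(line.strip())
    if match is None:
        return None
    entry = match.groupdict()
    entry["status"] = int(entry["status"])
    return entry


entry = parse_access_line(
    '172.17.0.1 - - [27/Mar/2019 22:19:35] "POST /predict HTTP/1.1" 200 -')
print(entry["method"], entry["path"], entry["status"])  # POST /predict 200
```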

As I mentioned earlier, Flask-RESTful takes care of data validation for us. If we make the same POST request but remove one of the required arguments, an error is returned to the client.

$ curl -i -H "Content-Type: application/json" -X POST \
	-d '{"CRIM": 15.02, "ZN": 0.0, "INDUS": 18.1, "CHAS": 0.0, "NOX": 0.614,
		 "RM": 5.3, "AGE": 97.3, "DIS": 2.1, "RAD": 24.0, "TAX": 666.0,
		 "PTRATIO": 20.2, "B": 349.48}' \
	127.0.0.1:5000/predict
HTTP/1.0 400 BAD REQUEST
Content-Type: application/json
Content-Length: 64
Server: Werkzeug/0.15.0 Python/3.6.8
Date: Wed, 27 Mar 2019 23:37:00 GMT

{
    "message": {
        "LSTAT": "No LSTAT provided"
    }
}

This time the server responds with HTTP status code 400 (Bad Request), indicating a client error, and the server log records the failed request:

172.17.0.1 - - [27/Mar/2019 23:32:55] "POST /predict HTTP/1.1" 400 -
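The same two requests can be made from Python with only the standard library. In this sketch, predict is a hypothetical client function and SAMPLE repeats the instance from the curl examples. urllib raises HTTPError for error statuses such as 400, but the JSON error body is still readable, so the function returns either the prediction or the server's validation message.

```python
import json
import urllib.error
import urllib.request

# The same instance used in the curl examples.
SAMPLE = {"CRIM": 15.02, "ZN": 0.0, "INDUS": 18.1, "CHAS": 0.0, "NOX": 0.614,
          "RM": 5.3, "AGE": 97.3, "DIS": 2.1, "RAD": 24.0, "TAX": 666.0,
          "PTRATIO": 20.2, "B": 349.48, "LSTAT": 24.9}


def predict(features, url="http://127.0.0.1:5000/predict", timeout_s=5.0):
    """POST one instance as JSON.

    Returns (prediction, None) on success, or (None, message) when the server
    answers with an error status such as 400 Bad Request.
    """
    request = urllib.request.Request(
        url,
        data=json.dumps(features).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(request, timeout=timeout_s) as response:
            return json.loads(response.read().decode("utf-8"))["prediction"], None
    except urllib.error.HTTPError as err:
        # Error statuses raise HTTPError, but the response body is still readable.
        raw = err.read().decode("utf-8", errors="replace")
        try:
            parsed = json.loads(raw)
        except ValueError:
            return None, raw  # non-JSON error page, returned as plain text
        if isinstance(parsed, dict):
            return None, parsed.get("message", parsed)
        return None, parsed
```

With the docker-api container running, predict(SAMPLE) should return a (prediction, None) pair, and calling it with "LSTAT" removed should return None together with the validation message shown above.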

Conclusion

If you've made it this far, great job! You now know how to use Docker and Flask-RESTful to perform online inference! This architecture is widely used to build production machine learning products. We've laid the foundation for a full-scale machine learning service, but our model is only available locally, and we've left much to be desired in terms of logging and monitoring.

I've put together a guide demonstrating how to take this API to the next level so that it can be properly productionized. I cover how to implement proper logging so that you can track your ML model usage and improve the model over time, and how to use a proper web server rather than Flask's development server.
