Flash Serve¶
Flash Serve is a library to easily serve models in production.
Terminology¶
Here are common terms you need to be familiar with:
| Term | Definition |
|---|---|
| de-serialization | Transform data encoded as text into tensors. |
| inference function | A function that takes the decoded tensors and forwards them through the model to produce predictions. |
| serialization | Transform the prediction tensors back into a text encoding. |
| ModelComponent | The base class to override when implementing the inference function for your model. |
| Servable | The wrapper that points to a TorchScripted model file and is passed to a ModelComponent. |
| Composition | The class that groups one or more ModelComponent instances; calling its serve() method launches the server. |
| expose() | The decorator that extends an inference method with the de-serialization and serialization steps. |
Example¶
In this tutorial, we will serve a ResNet18 from the torchvision library in 3 steps.
The entire tutorial can be found under examples/serve/generic.
Introduction¶
Traditionally, an inference pipeline is made up of 3 steps:

- de-serialization: Transform data encoded as text into tensors.
- inference function: A function that takes the decoded tensors and forwards them through the model to produce predictions.
- serialization: Transform the prediction tensors back into a text encoding.
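To make these steps concrete, here is a toy, framework-free sketch of such a pipeline. The helpers below are illustrative stand-ins and are not part of the Flash Serve API:

import torch

# Illustrative stand-ins for the three steps; none of these are Flash Serve APIs.
def deserialize(text: str) -> torch.Tensor:
    # de-serialization: text -> tensors
    return torch.tensor([float(v) for v in text.split(",")])

def inference(x: torch.Tensor) -> torch.Tensor:
    # inference function: forward the decoded tensors through a (toy) model
    return x * 2

def serialize(preds: torch.Tensor) -> str:
    # serialization: tensors -> text
    return ",".join(str(float(v)) for v in preds.tolist())

print(serialize(inference(deserialize("1,2,3"))))  # 2.0,4.0,6.0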
In this example, we will implement only the inference function, as Flash Serve already provides built-in de-serialization and serialization functions such as Image and Label.
Step 1 - Create a ModelComponent¶
Inside inference_server.py, we will implement a ClassificationInference class, which overrides ModelComponent.
First, we need to make the following imports:
import torch
import torchvision
from flash.core.serve import Composition, Servable, ModelComponent, expose
from flash.core.serve.types import Image, Label
To implement ClassificationInference, we need to write a method responsible for the inference function and decorate it with the expose() decorator.
The name of the inference method isn't constrained, but we will use classify, as it is appropriate for this example.
Our classify function will take a tensor image, apply some normalization to it, and forward it through the model.
def classify(self, img):
    # Scale pixel values from [0, 255] to [0, 1]
    img = img.float() / 255
    # Normalize with the standard ImageNet mean and std
    mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
    std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
    img = (img - mean) / std
    # Move the channel dimension first, as expected by the model
    img = img.permute(0, 3, 2, 1)
    out = self.model(img)
    # Return the index of the highest-scoring class
    return out.argmax()
The expose() function is a Python decorator that extends the decorated function with the de-serialization and serialization steps.
Note
Flash Serve was designed this way so that several models can be chained together simply by removing the decorator.
The expose() function takes 2 arguments:

- inputs: Dictionary mapping the decorated function's inputs to BaseType objects.
- outputs: Dictionary mapping the decorated function's outputs to BaseType objects.
A BaseType is a Python dataclass that implements serialize and deserialize functions.
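For illustration, a custom BaseType could look roughly like the following. This is a minimal sketch: the import path and the exact method signatures are assumptions, not taken from the Flash Serve documentation.

import torch
from dataclasses import dataclass
from flash.core.serve.types.base import BaseType  # assumed import path

# Hypothetical custom BaseType that decodes "v1,v2,..." strings into a tensor
# and encodes prediction tensors back into the same text format.
@dataclass
class CommaSeparatedNumbers(BaseType):
    def deserialize(self, text: str) -> torch.Tensor:
        return torch.tensor([float(v) for v in text.split(",")])

    def serialize(self, data: torch.Tensor) -> str:
        return ",".join(str(float(v)) for v in data.tolist())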
Note
Flash Serve already provides several built-in BaseType objects, such as Image or Text.
class ClassificationInference(ModelComponent):
def __init__(self, model: Servable):
self.model = model
@expose(
inputs={"img": Image()},
outputs={"prediction": Label(path="imagenet_labels.txt")},
)
def classify(self, img):
img = img.float() / 255
mean = torch.tensor([[[0.485, 0.456, 0.406]]]).float()
std = torch.tensor([[[0.229, 0.224, 0.225]]]).float()
img = (img - mean) / std
img = img.permute(0, 3, 2, 1)
out = self.model(img)
return out.argmax()
Step 2 - Create a scripted Model¶
Using the torchvision library, we create a resnet18 and use torch.jit.script() to script the model.
Note
TorchScript is a way to create serializable and optimizable models from PyTorch code. Any TorchScript program can be saved from a Python process and loaded in a process where there is no Python dependency.
model = torchvision.models.resnet18(pretrained=True).eval()
torch.jit.script(model).save("resnet.pt")
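As an optional sanity check (not part of the original example), you can confirm the scripted model loads and runs on a dummy batch before serving it:

import torch

# Load the TorchScript archive produced above and run a dummy forward pass.
scripted = torch.jit.load("resnet.pt")
with torch.no_grad():
    out = scripted(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])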
Step 3 - Serve the model¶
The Servable takes as an argument the path to the TorchScripted model, and the resulting object is then passed to our ClassificationInference class.
The ClassificationInference instance is in turn passed as an argument to a Composition class.
Once the Composition class is instantiated, just call its serve() method.
resnet = Servable("resnet.pt")
comp = ClassificationInference(resnet)
composition = Composition(classification=comp)
composition.serve()
Launching the server¶
In Terminal 1¶
Just run:
python inference_server.py
You should see the server start up in your terminal.
You should also see a Swagger UI already built for you at http://127.0.0.1:8000/docs
In Terminal 2¶
Run this script from another terminal:
import base64
from pathlib import Path
import requests
with Path("fish.jpg").open("rb") as f:
imgstr = base64.b64encode(f.read()).decode("UTF-8")
body = {"session": "UUID", "payload": {"img": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
# {'session': 'UUID', 'result': {'prediction': 'goldfish, Carassius auratus'}}
Credits to @rlizzo, @hhsecond, @lantiga, and @luiscape for building the Flash Serve Engine.