Predictions (inference)

You can use Flash to get predictions from pretrained or finetuned models.

First create a DataModule with some predict data, then pass it to the Trainer.predict method.

from flash import Trainer
from flash.core.data.utils import download_data
from flash.image import ImageClassifier, ImageClassificationData

# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)

# 3. Predict whether the image contains an ant or a bee
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [["bees"]]
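The printed result is nested: the outer list appears to hold one entry per predict batch, and each inner list one prediction per sample in that batch. A minimal sketch, assuming that layout, of collapsing it into a single flat list. The helper name `flatten_predictions` is hypothetical and not part of Flash.

```python
from itertools import chain


def flatten_predictions(batched_predictions):
    """Collapse a list of per-batch prediction lists into one flat list."""
    return list(chain.from_iterable(batched_predictions))


# Stand-in for the value printed above (one batch containing one sample).
predictions = [["bees"]]
print(flatten_predictions(predictions))  # prints ['bees']
```

With several predict files, the flat list should then line up one-to-one with the files, provided the dataloader keeps their order.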

Serializing predictions

To change the output format of predictions you can attach an Output to your Task. For example, you can choose to output probabilities (for more options see the API reference below).

from flash import Trainer
from flash.core.classification import ProbabilitiesOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassifier, ImageClassificationData


# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.7.0/image_classification_model.pt"
)

# 3. Attach the Output
model.output = ProbabilitiesOutput()

# 4. Predict whether the image contains an ant or a bee
trainer = Trainer()
datamodule = ImageClassificationData.from_files(
    predict_files=["data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg"]
)
predictions = trainer.predict(model, datamodule=datamodule)
print(predictions)
# out: [[[0.5926494598388672, 0.40735048055648804]]]
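Once probabilities are returned, picking the most likely class is left to the caller. The following sketch assumes the nesting shown above (batches, then samples, then one probability per class) and returns the index of the highest probability for each sample. The helper name `top_class_indices` is hypothetical. Mapping an index back to a label name depends on the label order the model was trained with, which is not assumed here.

```python
def top_class_indices(batched_probabilities):
    """Return (class_index, probability) for every sample across all batches."""
    results = []
    for batch in batched_probabilities:
        for sample_probs in batch:
            # Index of the largest probability for this sample.
            best = max(range(len(sample_probs)), key=sample_probs.__getitem__)
            results.append((best, sample_probs[best]))
    return results


# Stand-in for the value printed above (one batch, one sample, two classes).
predictions = [[[0.5926494598388672, 0.40735048055648804]]]
print(top_class_indices(predictions))
```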

Note

PyTorch Lightning does not return predictions directly from predict when using a multi-GPU configuration (DDP). Instead, use a pytorch_lightning.callbacks.BasePredictionWriter callback to write the predictions from each process to disk.
