
Predictions (inference)

You can use Flash to get predictions from pretrained or finetuned models.

Predict on a single sample of data

You can pass a single sample of data (an image file path, a string of text, etc.) to the predict() method.

from flash.core.data.utils import download_data
from flash.image import ImageClassifier


# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
)

# 3. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)

Predict on a CSV file

from flash.core.data.utils import download_data
from flash.tabular import TabularClassifier

# 1. Download the data
download_data("https://pl-flash-data.s3.amazonaws.com/titanic.zip", "data/")

# 2. Load the model from a checkpoint
model = TabularClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.6.0/tabular_classification_model.pt"
)

# 3. Generate predictions from a csv file! Who would survive?
predictions = model.predict("data/titanic/titanic.csv")
print(predictions)

Serializing predictions

To change the output format of predictions, attach an Output to your Task. For example, you can choose to output probabilities (for more options, see the API reference below).

from flash.core.classification import ProbabilitiesOutput
from flash.core.data.utils import download_data
from flash.image import ImageClassifier


# 1. Download the data set
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

# 2. Load the model from a checkpoint
model = ImageClassifier.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.6.0/image_classification_model.pt"
)

# 3. Attach the Output
model.output = ProbabilitiesOutput()

# 4. Predict whether the image contains an ant or a bee
predictions = model.predict("data/hymenoptera_data/val/bees/65038344_52a45d090d.jpg")
print(predictions)
# out: [[0.5926494598388672, 0.40735048055648804]]
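
The probabilities above come back as one list per input sample, so a small amount of post-processing is usually needed to turn them into a decision. The following is a minimal sketch in plain Python, not part of the Flash API: `top_prediction` is a hypothetical helper, and the optional `labels` argument assumes you already know the class order used by your model (this page does not state it).

```python
from typing import Optional, Sequence, Tuple, Union


def top_prediction(
    probabilities: Sequence[float],
    labels: Optional[Sequence[str]] = None,
) -> Tuple[Union[int, str], float]:
    """Return the most likely class (index or label) and its probability.

    Hypothetical helper for illustration only; it is not provided by Flash.
    """
    if not probabilities:
        raise ValueError("probabilities must not be empty")
    # Index of the largest probability in this row.
    best_index = max(range(len(probabilities)), key=lambda i: probabilities[i])
    # Map the index to a label only if the caller supplied the class order.
    best = labels[best_index] if labels is not None else best_index
    return best, probabilities[best_index]


# Sample output copied from the example above: one row per input sample.
predictions = [[0.5926494598388672, 0.40735048055648804]]

for row in predictions:
    index, confidence = top_prediction(row)
    print(f"class index {index} with probability {confidence:.3f}")
# prints: class index 0 with probability 0.593
```

Passing `labels` only makes sense once you have confirmed the class order from your own data module; without it, the helper reports the raw class index.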