The Tests¶
Our next step is to create some tests for our Task. For the TemplateSKLearnClassifier, we will just create some basic tests. You should expand on these to include tests for any specific functionality you have in your Task.
Smoke tests¶
We use smoke tests, usually called test_smoke, throughout. These just instantiate the class we are testing, to check that it can be created without raising any errors.
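The body of test_smoke isn't reproduced in this guide, but as a rough sketch it could look like the following, assuming only the constructor arguments used in the tests later in this section:

def test_smoke():
    """Tests that the classifier can be instantiated without raising any errors."""
    # num_features and num_classes here are arbitrary illustrative values
    model = TemplateSKLearnClassifier(num_features=16, num_classes=10)
    assert model is not None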
tests/examples/test_scripts.py¶
Before we write our custom tests, we should add our examples to the CI. To do this, add a line for each example (finetuning and predict) to the annotation of test_example in tests/examples/test_scripts.py. Here's how those lines look for our template.py examples:
pytest.param(
    "finetuning", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
...
pytest.param(
    "predict", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
test_data.py¶
The most important tests in test_data.py check that the from_* methods work correctly. In the class TestTemplateData, we have two of these: test_from_numpy and test_from_sklearn. In general, there should be one test_from_* method for each input you have configured. Here's the code for test_from_numpy:
def test_from_numpy(self):
    """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method."""
    data = np.random.rand(10, self.num_features)
    targets = np.random.randint(0, self.num_classes, (10,))

    # instantiate the data module
    dm = TemplateData.from_numpy(
        train_data=data,
        train_targets=targets,
        val_data=data,
        val_targets=targets,
        test_data=data,
        test_targets=targets,
        batch_size=2,
        num_workers=0,
    )
    assert dm is not None
    assert dm.train_dataloader() is not None
    assert dm.val_dataloader() is not None
    assert dm.test_dataloader() is not None

    # check training data
    data = next(iter(dm.train_dataloader()))
    rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert rows.shape == (2, self.num_features)
    assert targets.shape == (2,)

    # check val data
    data = next(iter(dm.val_dataloader()))
    rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert rows.shape == (2, self.num_features)
    assert targets.shape == (2,)

    # check test data
    data = next(iter(dm.test_dataloader()))
    rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert rows.shape == (2, self.num_features)
    assert targets.shape == (2,)
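The code for test_from_sklearn isn't reproduced here, but it follows the same pattern. As a sketch, assuming from_sklearn mirrors from_numpy with train_bunch, val_bunch, and test_bunch arguments (suggested by the predict_bunch argument used in test_predict_sklearn below), it might look like this:

def test_from_sklearn(self):
    """Tests that ``TemplateData`` is properly created when using the ``from_sklearn`` method."""
    bunch = datasets.load_iris()

    # instantiate the data module from ``Bunch`` objects (the argument names are
    # assumed to mirror ``from_numpy``; check your ``from_sklearn`` signature)
    dm = TemplateData.from_sklearn(
        train_bunch=bunch,
        val_bunch=bunch,
        test_bunch=bunch,
        batch_size=2,
        num_workers=0,
    )
    assert dm is not None

    # check that a training batch has the expected shapes
    data = next(iter(dm.train_dataloader()))
    rows, targets = data[DataKeys.INPUT], data[DataKeys.TARGET]
    assert rows.shape == (2, bunch.data.shape[1])
    assert targets.shape == (2,)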
test_model.py¶
In test_model.py, we first have test_forward and test_train. These test that tensors can be passed to the forward and that the Task can be trained. Here's the code for test_forward and test_train:
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
@pytest.mark.parametrize("num_classes", [4, 256])
@pytest.mark.parametrize("shape", [(1, 3), (2, 128)])
def test_forward(num_classes, shape):
    """Tests that a tensor can be given to the model forward and gives the correct output size."""
    model = TemplateSKLearnClassifier(
        num_features=shape[1],
        num_classes=num_classes,
    )
    model.eval()
    row = torch.rand(*shape)
    out = model(row)
    assert out.shape == (shape[0], num_classes)


@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
def test_train(tmpdir):
    """Tests that the model can be trained on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, train_dl)
We also include tests for validating and testing: test_val and test_test. These tests are very similar to test_train, but here they are for completeness:
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
def test_val(tmpdir):
    """Tests that the model can be validated on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    val_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.validate(model, val_dl)


@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
def test_test(tmpdir):
    """Tests that the model can be tested on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    test_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model, test_dl)
We also include tests for prediction, named test_predict_* for each of our data sources. In our case, we have test_predict_numpy and test_predict_sklearn. These tests should load the data with a DataModule and generate predictions with Trainer.predict. Here's test_predict_sklearn as an example:
@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
def test_predict_sklearn():
    """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
    bunch = datasets.load_iris()
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    datamodule = TemplateData.from_sklearn(predict_bunch=bunch, batch_size=1)
    trainer = Trainer()
    out = trainer.predict(model, datamodule=datamodule, output="classes")
    assert isinstance(out[0][0], int)
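test_predict_numpy follows the same pattern. Here's a sketch, assuming from_numpy accepts a predict_data argument analogous to the predict_bunch argument above:

@pytest.mark.skipif(not _TOPIC_CORE_AVAILABLE, reason="Not testing core.")
def test_predict_numpy():
    """Tests that we can generate predictions from a numpy array."""
    row = np.random.rand(1, DummyDataset.num_features)
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)

    # ``predict_data`` is assumed to mirror the ``train_data`` argument of ``from_numpy``
    datamodule = TemplateData.from_numpy(predict_data=row, batch_size=1)
    trainer = Trainer()
    out = trainer.predict(model, datamodule=datamodule, output="classes")
    assert isinstance(out[0][0], int)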
Now that you’ve written the tests, it’s time to add some docs!