
The Tests

Our next step is to create some tests for our Task. For the TemplateSKLearnClassifier, we will only create some basic tests. You should extend these to cover any functionality specific to your Task.

Smoke tests

We use smoke tests, usually called test_smoke, throughout. These simply instantiate the class under test to check that it can be created without raising any errors.
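As an illustration, a smoke test can be as small as the sketch below. `ToyClassifier` is a hypothetical stand-in so that the snippet runs on its own; in a real Flash test you would construct your own Task (for example `TemplateSKLearnClassifier(num_features=..., num_classes=...)`) instead.

```python
class ToyClassifier:
    """Hypothetical stand-in for a Task such as TemplateSKLearnClassifier."""

    def __init__(self, num_features: int, num_classes: int) -> None:
        # Validate arguments so that a broken constructor fails loudly.
        if num_features < 1:
            raise ValueError("num_features must be positive")
        if num_classes < 2:
            raise ValueError("num_classes must be at least 2")
        self.num_features = num_features
        self.num_classes = num_classes


def test_smoke():
    """Only checks that the class can be constructed without raising."""
    model = ToyClassifier(num_features=4, num_classes=3)
    assert model is not None
```

A smoke test deliberately asserts almost nothing: its value is that an import error or a broken constructor shows up as a clearly named failure.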

tests/examples/test_scripts.py

Before we write our custom tests, we should add our examples to the CI. To do this, add a pytest.param entry for each example (finetuning and predict) to the @pytest.mark.parametrize list of test_example in tests/examples/test_scripts.py. Here's how those entries look for our template.py examples:

pytest.param(
    "finetuning", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),
...
pytest.param(
    "predict", "template.py", marks=pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
),

test_data.py

The most important tests in test_data.py check that the from_* methods work correctly. In the class TestTemplateData, we have two of these: test_from_numpy and test_from_sklearn. In general, there should be one test_from_* method for each data_source you have configured.
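To keep this one-to-one rule from drifting as you add data sources, you could add a check that every configured data source name has a matching test. A minimal sketch, assuming you can list the data source names yourself (hard-coded here as "numpy" and "sklearn"); `TestToyData` and `missing_from_tests` are hypothetical names.

```python
from typing import Iterable, List


class TestToyData:
    """Hypothetical test class with one ``test_from_*`` method per data source."""

    def test_from_numpy(self):
        pass

    def test_from_sklearn(self):
        pass


def missing_from_tests(test_cls: type, data_sources: Iterable[str]) -> List[str]:
    """Return the data sources that have no matching ``test_from_<name>`` method."""
    return [
        name
        for name in data_sources
        if not callable(getattr(test_cls, f"test_from_{name}", None))
    ]
```

For example, `missing_from_tests(TestToyData, ["numpy", "sklearn", "csv"])` returns `["csv"]`, pointing at the data source that still needs a test.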

Here’s the code for test_from_numpy:

    def test_from_numpy(self):
        """Tests that ``TemplateData`` is properly created when using the ``from_numpy`` method."""
        data = np.random.rand(10, self.num_features)
        targets = np.random.randint(0, self.num_classes, (10,))

        # instantiate the data module
        dm = TemplateData.from_numpy(
            train_data=data,
            train_targets=targets,
            val_data=data,
            val_targets=targets,
            test_data=data,
            test_targets=targets,
            batch_size=2,
            num_workers=0,
        )
        assert dm is not None
        assert dm.train_dataloader() is not None
        assert dm.val_dataloader() is not None
        assert dm.test_dataloader() is not None

        # check training data
        data = next(iter(dm.train_dataloader()))
        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

        # check val data
        data = next(iter(dm.val_dataloader()))
        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)

        # check test data
        data = next(iter(dm.test_dataloader()))
        rows, targets = data[DefaultDataKeys.INPUT], data[DefaultDataKeys.TARGET]
        assert rows.shape == (2, self.num_features)
        assert targets.shape == (2,)
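The three "check ... data" blocks in test_from_numpy repeat the same assertions, so you may prefer to factor them into a helper. The sketch below runs without Flash: plain lists, a generator (`toy_dataloader`) and the string keys `INPUT`/`TARGET` are hypothetical stand-ins for tensors, the `dm.*_dataloader()` methods and `DefaultDataKeys`. In the real test you would pass `dm.train_dataloader()`, `dm.val_dataloader()` and `dm.test_dataloader()` and compare `.shape` instead of lengths.

```python
import random
from typing import Dict, Iterator, List

# Hypothetical stand-ins for DefaultDataKeys.INPUT and DefaultDataKeys.TARGET.
INPUT, TARGET = "input", "target"


def toy_dataloader(data: List[List[float]], targets: List[int], batch_size: int) -> Iterator[Dict[str, list]]:
    """Hypothetical stand-in for a DataLoader that yields dict batches."""
    for start in range(0, len(data), batch_size):
        yield {
            INPUT: data[start:start + batch_size],
            TARGET: targets[start:start + batch_size],
        }


def check_first_batch(loader, batch_size: int, num_features: int) -> None:
    """The assertions repeated for the train, val and test loaders."""
    batch = next(iter(loader))
    rows, targets = batch[INPUT], batch[TARGET]
    # With real tensors these would be ``rows.shape == (batch_size, num_features)``
    # and ``targets.shape == (batch_size,)``.
    assert len(rows) == batch_size
    assert all(len(row) == num_features for row in rows)
    assert len(targets) == batch_size


num_features, num_classes = 4, 3
data = [[random.random() for _ in range(num_features)] for _ in range(10)]
targets = [random.randrange(num_classes) for _ in range(10)]

# One helper call per split replaces three copies of the same block.
loaders = {split: toy_dataloader(data, targets, batch_size=2) for split in ("train", "val", "test")}
for split, loader in loaders.items():
    check_first_batch(loader, batch_size=2, num_features=num_features)
```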

test_model.py

In test_model.py, we first have test_forward and test_train. These check that tensors can be passed to the model's forward method and that the Task can be trained. Here's the code for test_forward and test_train:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
@pytest.mark.parametrize("num_classes", [4, 256])
@pytest.mark.parametrize("shape", [(1, 3), (2, 128)])
def test_forward(num_classes, shape):
    """Tests that a tensor can be given to the model forward and gives the correct output size."""
    model = TemplateSKLearnClassifier(
        num_features=shape[1],
        num_classes=num_classes,
    )
    model.eval()

    row = torch.rand(*shape)

    out = model(row)
    assert out.shape == (shape[0], num_classes)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_train(tmpdir):
    """Tests that the model can be trained on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    train_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.fit(model, train_dl)
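Stacking two @pytest.mark.parametrize decorators makes pytest run test_forward once for every combination of their values, so test_forward runs four cases: two values of num_classes times two shapes. The standard-library sketch below enumerates the same combinations with itertools.product and applies the same output-shape check; `toy_forward` is a hypothetical stand-in for the Task's forward, not Flash code.

```python
import itertools
import random
from typing import List


def toy_forward(rows: List[List[float]], num_classes: int) -> List[List[float]]:
    """Hypothetical stand-in for ``model(row)``: a randomly initialised linear layer."""
    num_features = len(rows[0])
    weights = [[random.random() for _ in range(num_features)] for _ in range(num_classes)]
    # One score per class for every input row.
    return [[sum(x * w for x, w in zip(row, w_row)) for w_row in weights] for row in rows]


# Two stacked parametrize decorators give every combination of their values.
cases = list(itertools.product([4, 256], [(1, 3), (2, 128)]))

for num_classes, shape in cases:
    rows = [[random.random() for _ in range(shape[1])] for _ in range(shape[0])]
    out = toy_forward(rows, num_classes)
    # Same check as ``out.shape == (shape[0], num_classes)`` in test_forward.
    assert (len(out), len(out[0])) == (shape[0], num_classes)
```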

We also include tests for validation and testing: test_val and test_test. These are very similar to test_train, but we include them here for completeness:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_val(tmpdir):
    """Tests that the model can be validated on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    val_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.validate(model, val_dl)


@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_test(tmpdir):
    """Tests that the model can be tested on our ``DummyDataset``."""
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    test_dl = torch.utils.data.DataLoader(DummyDataset(), batch_size=4)
    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
    trainer.test(model, test_dl)
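Because test_train, test_val and test_test differ only in which Trainer method they call, one option is to parametrize a single test over the method name (for example with `@pytest.mark.parametrize("stage", ["fit", "validate", "test"])`) and dispatch with getattr. Keeping three separate tests, as the template does, gives clearer names in the test report. The sketch below only illustrates the dispatch pattern, with `ToyTrainer` and `run_stage` as hypothetical stand-ins so it runs without Lightning.

```python
from typing import List


class ToyTrainer:
    """Hypothetical stand-in for a Lightning ``Trainer``; it only records calls."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def fit(self, model, dataloader) -> None:
        self.calls.append("fit")

    def validate(self, model, dataloader) -> None:
        self.calls.append("validate")

    def test(self, model, dataloader) -> None:
        self.calls.append("test")


def run_stage(stage: str, trainer: ToyTrainer, model, dataloader) -> None:
    """Look up the trainer method by name so one test body covers all three stages."""
    getattr(trainer, stage)(model, dataloader)


trainer = ToyTrainer()
for stage in ("fit", "validate", "test"):
    run_stage(stage, trainer, model=object(), dataloader=[])
```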

We also include a prediction test, named test_predict_*, for each of our data sources; in our case, these are test_predict_numpy and test_predict_sklearn. Each of these tests should pass the data_source argument to predict() to select the matching DataSource. Here's test_predict_sklearn as an example:

@pytest.mark.skipif(not _SKLEARN_AVAILABLE, reason="sklearn isn't installed")
def test_predict_sklearn():
    """Tests that we can generate predictions from a scikit-learn ``Bunch``."""
    bunch = datasets.load_iris()
    model = TemplateSKLearnClassifier(num_features=DummyDataset.num_features, num_classes=DummyDataset.num_classes)
    data_pipe = DataPipeline(preprocess=TemplatePreprocess())
    out = model.predict(bunch, data_source="sklearn", data_pipeline=data_pipe)
    assert isinstance(out[0], int)
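The data_source argument works as a name lookup: the string selects which DataSource turns the raw input into samples before the model runs. The standard-library sketch below mimics that flow with a plain dict registry followed by an argmax, which is also one way predictions can end up as plain ints, as test_predict_sklearn expects. Every name in it (`DATA_SOURCES`, `toy_predict`, the loaders and the scorer) is hypothetical and does not reflect Flash's internal implementation.

```python
from typing import Any, Callable, Dict, List, Sequence

Rows = List[List[float]]


def load_numpy_like(data: Sequence[Sequence[float]]) -> Rows:
    """Hypothetical loader for array-like input (one row per sample)."""
    return [list(row) for row in data]


def load_sklearn_like(bunch: Dict[str, Any]) -> Rows:
    """Hypothetical loader for a scikit-learn ``Bunch``, which keeps features under "data"."""
    return [list(row) for row in bunch["data"]]


# The ``data_source`` string picks which loader handles the raw input.
DATA_SOURCES: Dict[str, Callable[[Any], Rows]] = {
    "numpy": load_numpy_like,
    "sklearn": load_sklearn_like,
}


def toy_predict(data: Any, data_source: str, score: Callable[[List[float]], List[float]]) -> List[int]:
    """Hypothetical predict(): load rows, score them, return the argmax class per row."""
    rows = DATA_SOURCES[data_source](data)  # KeyError for an unknown data source
    predictions = []
    for row in rows:
        scores = score(row)
        # ``max`` over indices returns a plain Python int.
        predictions.append(max(range(len(scores)), key=scores.__getitem__))
    return predictions


def reversed_scores(row: List[float]) -> List[float]:
    """Toy scorer for 4 features and 3 classes: the last three features in reverse."""
    return [row[3], row[2], row[1]]


# Two iris-like rows in a dict shaped like a Bunch.
bunch = {"data": [[5.1, 3.5, 1.4, 0.2], [6.7, 3.0, 5.2, 2.3]]}
preds = toy_predict(bunch, data_source="sklearn", score=reversed_scores)
```

Here `preds` is `[2, 1]`, and passing an unregistered name such as `data_source="csv"` fails immediately with a KeyError, which is the behaviour a test_predict_* test would surface.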

Now that you’ve written the tests, it’s time to add some docs!
