Testing our Code
This is the most important section in the guide. Unit tests are short tests that assert the output of some function or method equals the expected value for some given input. They allow us to ensure our code does what we think it does, reducing bugs and making it dramatically easier to refactor our code down the line.
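To make this concrete, here is a minimal, self-contained sketch of a unit test. The `add` function is purely illustrative and not part of our project; it just shows the pattern of calling a function with a known input and asserting on the result:

```python
# Hypothetical function, used only to illustrate the unit-test pattern.
def add(a: int, b: int) -> int:
    return a + b


def test_add():
    assert add(2, 3) == 5
```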
There are multiple unit testing frameworks in Python, but we will use a very popular one, `pytest`.
Typically, tests are grouped under a `tests` directory (Poetry created this for us automatically). Let's start by creating a file, `test_main.py`, under `tests`.

```bash
touch tests/test_main.py
```
As is common, we will group our tests for a class under another class, like so:

```python
from se_best_practices_ml_perspective.main import LitClassifier


class TestLitClassifier:
    model = LitClassifier()

    def test_forward(self):
        assert False

    def test_training_step(self):
        assert False

    def test_validation_step(self):
        assert False

    def test_configure_optimizers(self):
        assert False
```
We start with one unit test per method of `LitClassifier`. A test passes when every `assert` statement it contains evaluates to `True` (and no exception is raised). To run our unit tests, we simply call

```bash
poetry run pytest
```
Notice that all four tests fail, as expected.
```
================================================== short test summary info ==================================================
FAILED tests/test_main.py::TestLitClassifier::test_forward - assert False
FAILED tests/test_main.py::TestLitClassifier::test_training_step - assert False
FAILED tests/test_main.py::TestLitClassifier::test_validation_step - assert False
FAILED tests/test_main.py::TestLitClassifier::test_configure_optimizers - assert False
```
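While we work through the failures one at a time, it can help to run only the test in question. pytest lets us select a single test by its node ID, or filter tests by name with `-k`:

```bash
# Run one test by its node ID
poetry run pytest tests/test_main.py::TestLitClassifier::test_forward

# Or run every test whose name matches a pattern
poetry run pytest -k "test_forward"
```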
Let's get each test to pass, one-by-one, starting with `test_forward`.
```python
from se_best_practices_ml_perspective.main import LitClassifier
import torch


class TestLitClassifier:
    model = LitClassifier()

    def test_forward(self):
        """Assert that the output shape of `LitClassifier.forward` is as expected."""
        inputs = torch.randn(1, 28, 28)
        outputs = self.model.forward(inputs)

        expected_size = (1, 10)
        actual_size = outputs.size()

        assert actual_size == expected_size

    def test_training_step(self):
        assert False

    def test_validation_step(self):
        assert False

    def test_configure_optimizers(self):
        assert False
```
Here, we add a simple test asserting that, for some random input, the output has the expected shape. Running the tests again, we will notice that `test_forward` is now passing.
```
================================================== short test summary info ==================================================
FAILED tests/test_main.py::TestLitClassifier::test_training_step - assert False
FAILED tests/test_main.py::TestLitClassifier::test_validation_step - assert False
FAILED tests/test_main.py::TestLitClassifier::test_configure_optimizers - assert False
```
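As an aside, if we wanted more confidence in the shape check, `pytest.mark.parametrize` lets us run the same test over several batch sizes. A minimal sketch, assuming `LitClassifier.forward` accepts arbitrary batch sizes:

```python
import pytest
import torch

from se_best_practices_ml_perspective.main import LitClassifier


class TestLitClassifier:
    model = LitClassifier()

    # Run the same shape check for a few batch sizes; each value becomes its own test case.
    @pytest.mark.parametrize("batch_size", [1, 2, 8])
    def test_forward_batch_sizes(self, batch_size):
        inputs = torch.randn(batch_size, 28, 28)
        outputs = self.model.forward(inputs)
        assert outputs.size() == (batch_size, 10)
```

Each parameter value shows up as a separate test in pytest's report, so a failure immediately pinpoints the offending batch size.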
Continuing with the remaining tests:
```python
from se_best_practices_ml_perspective.main import LitClassifier
import torch


class TestLitClassifier:
    model = LitClassifier()

    def test_forward(self):
        """Assert that the output shape of `LitClassifier.forward` is as expected."""
        inputs = torch.randn(1, 28, 28)
        outputs = self.model.forward(inputs)

        expected_size = (1, 10)
        actual_size = outputs.size()

        assert actual_size == expected_size

    def test_training_step(self):
        """Assert that `LitClassifier.training_step` returns a non-empty dictionary."""
        inputs = (torch.randn(1, 28, 28), torch.randint(10, (1,)))
        results = self.model.training_step(batch=inputs, batch_idx=0)

        assert isinstance(results, dict)
        assert results

    def test_validation_step(self):
        """Assert that `LitClassifier.validation_step` returns a non-empty dictionary."""
        inputs = (torch.randn(1, 28, 28), torch.randint(10, (1,)))
        results = self.model.validation_step(batch=inputs, batch_idx=0)

        assert isinstance(results, dict)
        assert results

    def test_configure_optimizers(self):
        """Assert that `LitClassifier.configure_optimizers` returns an Adam optimizer with the
        expected learning rate.
        """
        optimizer = self.model.configure_optimizers()

        assert isinstance(optimizer, torch.optim.Adam)
        assert optimizer.param_groups[0]["lr"] == 0.02
```
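One design note: the class attribute `model = LitClassifier()` shares a single model instance across all tests. An alternative, sketched below with a standard pytest fixture, gives each test a fresh instance so that side effects (such as weight updates from a training step) cannot leak between tests:

```python
import pytest
import torch

from se_best_practices_ml_perspective.main import LitClassifier


@pytest.fixture
def model():
    # Each test that requests the `model` fixture receives a freshly initialized classifier.
    return LitClassifier()


def test_forward(model):
    outputs = model.forward(torch.randn(1, 28, 28))
    assert outputs.size() == (1, 10)
```

For the small, read-only tests in this guide the shared class attribute is perfectly fine and slightly faster, since the model is only constructed once.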
In this quick overview of unit testing, we wrote simple tests for each method of our neural network classifier. We then checked that all tests are passing, increasing our confidence that our code works as expected. In the future, if we were to refactor our code, we could re-run our tests to ensure we didn't break anything.
In the next and final section, we will see how to tie everything we have learned together and automate the process of linting, formatting and testing.