
Commit 0f978be

Author: Max Keller (committed)

Add tests for dropout query strategies

1 parent 0a2d24a · commit 0f978be

File tree

2 files changed (+166, -20 lines)


modAL/dropout.py

Lines changed: 4 additions & 1 deletion
@@ -251,7 +251,7 @@ def mc_dropout_max_variationRatios(classifier: BaseEstimator, X: modALinput, n_i
     return shuffled_argmax(variationRatios, n_instances=n_instances)


-def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_indexes: list,
+def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_indexes: list = [],
                     num_predictions: int = 50, sample_per_forward_pass: int = 1000,
                     logits_adaptor: Callable[[torch.tensor, modALinput], torch.tensor] = default_logits_adaptor):
     """
@@ -273,6 +273,9 @@ def get_predictions(classifier: BaseEstimator, X: modALinput, dropout_layer_inde
        prediction: list with all predictions
     """

+    assert num_predictions > 0, 'num_predictions must be larger than zero'
+    assert sample_per_forward_pass > 0, 'sample_per_forward_pass must be larger than zero'
+
     predictions = []
     # set dropout layers to train mode
     set_dropout_mode(classifier.estimator.module_,
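For context, here is a minimal usage sketch of the changed signature, modeled on the new tests in tests/core_tests.py below. The IdentityNet module, the data shapes, and the learner setup are illustrative assumptions rather than part of this commit; only the get_predictions call pattern and the assertion message come from the code above.

import torch
from torch import nn
from skorch import NeuralNetClassifier

import modAL.dropout
import modAL.models

# Hypothetical stand-in for the test model below: the layers exist only so that
# skorch has parameters and a dropout layer to toggle; forward is the identity,
# exactly as in the Torch_Model used by the tests.
class IdentityNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 4)
        self.drop = nn.Dropout(0.5)

    def forward(self, x):
        return x

learner = modAL.models.learners.DeepActiveLearner(
    estimator=NeuralNetClassifier(IdentityNet,
                                  criterion=torch.nn.CrossEntropyLoss,
                                  optimizer=torch.optim.Adam,
                                  train_split=None,
                                  verbose=0),
    query_strategy=modAL.dropout.mc_dropout_bald,
)
X_pool = torch.randn(8, 4)

# dropout_layer_indexes can now be omitted; it defaults to [] (all dropout layers).
predictions = modAL.dropout.get_predictions(learner, X_pool,
                                            num_predictions=5,
                                            sample_per_forward_pass=4)
assert len(predictions) == 5

# The new assertions reject non-positive counts before any forward pass runs.
try:
    modAL.dropout.get_predictions(learner, X_pool, num_predictions=0)
except AssertionError as err:
    print(err)  # num_predictions must be larger than zero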

tests/core_tests.py

Lines changed: 162 additions & 19 deletions
@@ -18,6 +18,7 @@
 import modAL.expected_error
 import modAL.multilabel
 import modAL.uncertainty
+import modAL.dropout

 from copy import deepcopy
 from itertools import chain, product
@@ -39,6 +40,8 @@
 import torch
 from torch import nn

+from skorch import NeuralNetClassifier
+
 Test = namedtuple('Test', ['input', 'output'])


@@ -718,18 +721,170 @@ def test_entropy_sampling(self):
                 shuffled_query_idx, true_query_idx)


+# PyTorch model for test cases --> Do not change the layers
+class Torch_Model(nn.Module):
+    def __init__(self,):
+        super(Torch_Model, self).__init__()
+        self.convs = nn.Sequential(
+            nn.Conv2d(1, 32, 3),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, 3),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Dropout(0.25)
+        )
+        self.fcs = nn.Sequential(
+            nn.Linear(12*12*64, 128),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(128, 10),
+        )
+
+    def forward(self, x):
+        return x
+
+
 class TestDropout(unittest.TestCase):
-    def test_mc_dropout_bald(self): pass
-    def test_mc_dropout_mean_st(self): pass
-    def test_mc_dropout_max_entropy(self): pass
-    def test_mc_dropout_max_variationRatios(self): pass
-    def test_get_predictions(self): pass
-    def test_set_dropout_mode(self): pass
+    def setUp(self):
+        self.skorch_classifier = NeuralNetClassifier(Torch_Model,
+                                                     criterion=torch.nn.CrossEntropyLoss,
+                                                     optimizer=torch.optim.Adam,
+                                                     train_split=None,
+                                                     verbose=1)
+
+    def test_mc_dropout_bald(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_bald,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_bald(learner, X_pool, n_instances, random_tie_break, [],
+                                                      num_cycles, sample_per_forward_pass)
+
+    def test_mc_dropout_mean_st(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_mean_st,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_mean_st(learner, X_pool, n_instances, random_tie_break, [],
+                                                         num_cycles, sample_per_forward_pass)
+
+    def test_mc_dropout_max_entropy(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_max_entropy,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_max_entropy(learner, X_pool, n_instances, random_tie_break, [],
+                                                             num_cycles, sample_per_forward_pass)
+
+    def test_mc_dropout_max_variationRatios(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_max_variationRatios,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_max_variationRatios(learner, X_pool, n_instances, random_tie_break, [],
+                                                                     num_cycles, sample_per_forward_pass)
+
+    def test_get_predictions(self):
+        X = torch.randn(100, 1)
+
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=mock.MockFunction(return_val=None),
+        )
+
+        # num_predictions tests
+        for num_predictions in range(1, 20):
+            for samples_per_forward_pass in range(1, 10):
+
+                predictions = modAL.dropout.get_predictions(
+                    learner, X, dropout_layer_indexes=[],
+                    num_predictions=num_predictions,
+                    sample_per_forward_pass=samples_per_forward_pass)
+
+                self.assertEqual(len(predictions), num_predictions)
+
+        self.assertRaises(AssertionError, modAL.dropout.get_predictions,
+                          learner, X, dropout_layer_indexes=[],
+                          num_predictions=-1,
+                          sample_per_forward_pass=0)
+
+        self.assertRaises(AssertionError, modAL.dropout.get_predictions,
+                          learner, X, dropout_layer_indexes=[],
+                          num_predictions=10,
+                          sample_per_forward_pass=-5)
+
+        # logits adaptor function test
+        for samples, classes, subclasses in product(range(1, 10), range(1, 10), range(1, 10)):
+            input_shape = (samples, classes, subclasses)
+            desired_shape = (input_shape[0], np.prod(input_shape[1:]))
+            X_adaption_needed = torch.randn(input_shape)
+
+            def logits_adaptor(input_tensor, data): return torch.flatten(
+                input_tensor, start_dim=1)
+
+            predictions = modAL.dropout.get_predictions(
+                learner, X_adaption_needed, dropout_layer_indexes=[],
+                num_predictions=num_predictions,
+                sample_per_forward_pass=samples_per_forward_pass,
+                logits_adaptor=logits_adaptor)
+
+            self.assertEqual(predictions[0].shape, desired_shape)
+
+    def test_set_dropout_mode(self):
+        # set dropout mode for all dropout layers
+        for train_mode in [True, False]:
+            model = Torch_Model()
+            modules = list(model.modules())
+
+            for module in modules:
+                self.assertEqual(module.training, True)
+
+            modAL.dropout.set_dropout_mode(model, [], train_mode)
+
+            self.assertEqual(modules[7].training, train_mode)
+            self.assertEqual(modules[11].training, train_mode)
+
+        # set dropout mode only for selected layers:
+        for train_mode in [True, False]:
+            model = Torch_Model()
+            modules = list(model.modules())
+            modAL.dropout.set_dropout_mode(model, [7], train_mode)
+            self.assertEqual(modules[7].training, train_mode)
+            self.assertEqual(modules[11].training, True)
+
+            modAL.dropout.set_dropout_mode(model, [], True)
+            modAL.dropout.set_dropout_mode(model, [11], train_mode)
+            self.assertEqual(modules[11].training, train_mode)
+            self.assertEqual(modules[7].training, True)
+
+        # no dropout layer at the requested index
+        self.assertRaises(KeyError, modAL.dropout.set_dropout_mode,
+                          model, [5], train_mode)


 class TestDeepActiveLearner(unittest.TestCase):
     """
-    Tests for the base class methods of the BaseLearner (base.py) are provided in 
+    Tests for the base class methods of the BaseLearner (base.py) are provided in
     the TestActiveLearner.
     """


@@ -1535,17 +1690,5 @@ def test_examples(self):
         import example_tests.ranked_batch_mode


-# Empty PyTorch model for test cases
-class Torch_Model(nn.Module):
-    def __init__(self,):
-        super(Torch_Model, self).__init__()
-        self.convs = nn.Sequential(
-            nn.Conv2d(1, 5, 3),
-        )
-
-    def forward(self, x):
-        return x
-
-
 if __name__ == '__main__':
     unittest.main(verbosity=2)
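A note on the hard-coded indices in test_set_dropout_mode: 7 and 11 are positions in list(model.modules()) for the Torch_Model defined above, which is why the test file warns against changing the layers. The sketch below illustrates that mapping; it assumes the Torch_Model class from the hunk above is available in the current scope.

from torch import nn

model = Torch_Model()            # test model from the diff above (assumed in scope)
modules = list(model.modules())

# modules[0] is the model itself; modules[1] and modules[8] are the two
# nn.Sequential containers, so the dropout layers land at indices 7 and 11.
assert isinstance(modules[7], nn.Dropout)    # nn.Dropout(0.25) closing the conv block
assert isinstance(modules[11], nn.Dropout)   # nn.Dropout(0.5) between the linear layers

Passing [] to set_dropout_mode toggles both of these layers, while [7] or [11] targets a single one, matching the assertions in the test.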
