@@ -18,6 +18,7 @@
 import modAL.expected_error
 import modAL.multilabel
 import modAL.uncertainty
+import modAL.dropout
 
 from copy import deepcopy
 from itertools import chain, product
@@ -39,6 +40,8 @@
 import torch
 from torch import nn
 
+from skorch import NeuralNetClassifier
+
 Test = namedtuple('Test', ['input', 'output'])
 
 
@@ -718,18 +721,170 @@ def test_entropy_sampling(self):
                 shuffled_query_idx, true_query_idx)
 
 
+# PyTorch model for the test cases -- do not change the layers; the dropout tests index specific modules by position
+class Torch_Model(nn.Module):
+    def __init__(self,):
+        super(Torch_Model, self).__init__()
+        self.convs = nn.Sequential(
+            nn.Conv2d(1, 32, 3),
+            nn.ReLU(),
+            nn.Conv2d(32, 64, 3),
+            nn.ReLU(),
+            nn.MaxPool2d(2),
+            nn.Dropout(0.25)
+        )
+        self.fcs = nn.Sequential(
+            nn.Linear(12*12*64, 128),
+            nn.ReLU(),
+            nn.Dropout(0.5),
+            nn.Linear(128, 10),
+        )
+
+    def forward(self, x):
+        return x  # identity pass-through; only the layer structure above matters for these tests
+
+
 class TestDropout(unittest.TestCase):
-    def test_mc_dropout_bald(self): pass
-    def test_mc_dropout_mean_st(self): pass
-    def test_mc_dropout_max_entropy(self): pass
-    def test_mc_dropout_max_variationRatios(self): pass
-    def test_get_predictions(self): pass
-    def test_set_dropout_mode(self): pass
+    def setUp(self):
+        self.skorch_classifier = NeuralNetClassifier(Torch_Model,
+                                                     criterion=torch.nn.CrossEntropyLoss,
+                                                     optimizer=torch.optim.Adam,
+                                                     train_split=None,
+                                                     verbose=1)
+
+    def test_mc_dropout_bald(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_bald,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_bald(learner, X_pool, n_instances, random_tie_break, [],
+                                                      num_cycles, sample_per_forward_pass)
+
+    def test_mc_dropout_mean_st(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_mean_st,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_mean_st(learner, X_pool, n_instances, random_tie_break, [],
+                                                         num_cycles, sample_per_forward_pass)
+
+    def test_mc_dropout_max_entropy(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_max_entropy,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_max_entropy(learner, X_pool, n_instances, random_tie_break, [],
+                                                             num_cycles, sample_per_forward_pass)
+
+    def test_mc_dropout_max_variationRatios(self):
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=modAL.dropout.mc_dropout_max_variationRatios,
+        )
+        for random_tie_break in [True, False]:
+            for num_cycles, sample_per_forward_pass in product(range(1, 5), range(1, 5)):
+                for n_samples, n_classes in product(range(1, 5), range(1, 5)):
+                    for n_instances in range(1, n_samples):
+                        X_pool = torch.randn(n_samples, n_classes)
+                        modAL.dropout.mc_dropout_max_variationRatios(learner, X_pool, n_instances, random_tie_break, [],
+                                                                     num_cycles, sample_per_forward_pass)
+
+    def test_get_predictions(self):
+        X = torch.randn(100, 1)
+
+        learner = modAL.models.learners.DeepActiveLearner(
+            estimator=self.skorch_classifier,
+            query_strategy=mock.MockFunction(return_val=None),
+        )
+
+        # num_predictions tests
+        for num_predictions in range(1, 20):
+            for samples_per_forward_pass in range(1, 10):
+
+                predictions = modAL.dropout.get_predictions(
+                    learner, X, dropout_layer_indexes=[],
+                    num_predictions=num_predictions,
+                    sample_per_forward_pass=samples_per_forward_pass)
+
+                self.assertEqual(len(predictions), num_predictions)
+
+        self.assertRaises(AssertionError, modAL.dropout.get_predictions,
+                          learner, X, dropout_layer_indexes=[],
+                          num_predictions=-1,
+                          sample_per_forward_pass=0)
+
+        self.assertRaises(AssertionError, modAL.dropout.get_predictions,
+                          learner, X, dropout_layer_indexes=[],
+                          num_predictions=10,
+                          sample_per_forward_pass=-5)
+
+        # logits adaptor test; reuses the last num_predictions / samples_per_forward_pass values from the loops above
+        for samples, classes, subclasses in product(range(1, 10), range(1, 10), range(1, 10)):
+            input_shape = (samples, classes, subclasses)
+            desired_shape = (input_shape[0], np.prod(input_shape[1:]))
+            X_adaption_needed = torch.randn(input_shape)
+
+            def logits_adaptor(input_tensor, data):
+                return torch.flatten(input_tensor, start_dim=1)
+
+            predictions = modAL.dropout.get_predictions(
+                learner, X_adaption_needed, dropout_layer_indexes=[],
+                num_predictions=num_predictions,
+                sample_per_forward_pass=samples_per_forward_pass,
+                logits_adaptor=logits_adaptor)
+
+            self.assertEqual(predictions[0].shape, desired_shape)
+
+    def test_set_dropout_mode(self):
+        # set dropout mode for all dropout layers
+        for train_mode in [True, False]:
+            model = Torch_Model()
+            modules = list(model.modules())
+
+            for module in modules:
+                self.assertEqual(module.training, True)
+
+            modAL.dropout.set_dropout_mode(model, [], train_mode)
+
+            self.assertEqual(modules[7].training, train_mode)
+            self.assertEqual(modules[11].training, train_mode)
+
+        # set dropout mode only for specific layers:
+        for train_mode in [True, False]:
+            model = Torch_Model()
+            modules = list(model.modules())
+            modAL.dropout.set_dropout_mode(model, [7], train_mode)
+            self.assertEqual(modules[7].training, train_mode)
+            self.assertEqual(modules[11].training, True)
+
+            modAL.dropout.set_dropout_mode(model, [], True)
+            modAL.dropout.set_dropout_mode(model, [11], train_mode)
+            self.assertEqual(modules[11].training, train_mode)
+            self.assertEqual(modules[7].training, True)
+
+        # index 5 (a ReLU) is not a dropout layer, so a KeyError is expected
+        self.assertRaises(KeyError, modAL.dropout.set_dropout_mode,
+                          model, [5], train_mode)
 
 
 class TestDeepActiveLearner(unittest.TestCase):
     """
-    Tests for the base class methods of the BaseLearner (base.py) are provided in 
+    Tests for the base class methods of the BaseLearner (base.py) are provided in
     the TestActiveLearner.
     """
 
@@ -1535,17 +1690,5 @@ def test_examples(self):
         import example_tests.ranked_batch_mode
 
 
-# Empty PyTorch model for test cases
-class Torch_Model(nn.Module):
-    def __init__(self,):
-        super(Torch_Model, self).__init__()
-        self.convs = nn.Sequential(
-            nn.Conv2d(1, 5, 3),
-        )
-
-    def forward(self, x):
-        return x
-
-
 if __name__ == '__main__':
     unittest.main(verbosity=2)
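For completeness, a minimal end-to-end sketch of the query-strategy path exercised by TestDropout. It only uses calls that appear in the tests above (a skorch NeuralNetClassifier around Torch_Model, wrapped in a DeepActiveLearner, plus a direct call to mc_dropout_bald); the meaning of the positional arguments after n_instances (random_tie_break, dropout_layer_indexes, num_cycles, sample_per_forward_pass) is assumed from those calls rather than from separate documentation, and no teach() step is shown because the tests query the untrained wrapper directly.

import torch
import modAL.dropout
import modAL.models.learners
from skorch import NeuralNetClassifier

# Same wrapper TestDropout.setUp builds (Torch_Model as defined in the diff above).
classifier = NeuralNetClassifier(Torch_Model,
                                 criterion=torch.nn.CrossEntropyLoss,
                                 optimizer=torch.optim.Adam,
                                 train_split=None,
                                 verbose=0)

learner = modAL.models.learners.DeepActiveLearner(
    estimator=classifier,
    query_strategy=modAL.dropout.mc_dropout_bald,
)

# Unlabelled pool; Torch_Model.forward is an identity, so a 2-D pool suffices here.
X_pool = torch.randn(100, 10)

# Direct strategy call with the same positional argument order as the tests:
# (learner, X, n_instances, random_tie_break, dropout_layer_indexes,
#  num_cycles, sample_per_forward_pass)
bald_query = modAL.dropout.mc_dropout_bald(learner, X_pool, 5, False, [], 10, 32)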