|
9 | 9 |
|
10 | 10 | from mljar.client.project import ProjectClient
|
11 | 11 | from project_based_test import ProjectBasedTest
|
| 12 | +from mljar.exceptions import BadValueException, IncorrectInputDataException |
12 | 13 | from mljar import Mljar
|
13 | 14 |
|
14 | 15 | class MljarTest(ProjectBasedTest):
|
@@ -50,6 +51,29 @@ def test_basic_usage(self):
|
50 | 51 | score = self.mse(pred, self.y)
|
51 | 52 | self.assertTrue(score < 0.1)
|
52 | 53 |
|
| 54 | + def test_empty_project_title(self): |
| 55 | + with self.assertRaises(BadValueException) as context: |
| 56 | + model = Mljar(project = '', experiment = '') |
| 57 | + |
| 58 | + def test_wrong_tuning_mode(self): |
| 59 | + with self.assertRaises(BadValueException) as context: |
| 60 | + model = Mljar(project = self.proj_title, experiment = self.expt_title, |
| 61 | + tuning_mode = 'Crazy') |
| 62 | + |
| 63 | + def test_default_tuning_mode(self): |
| 64 | + model = Mljar(project = self.proj_title, experiment = self.expt_title) |
| 65 | + self.assertEqual(model.tuning_mode, 'Sport') |
| 66 | + |
| 67 | + def test_wrong_input_dim(self): |
| 68 | + with self.assertRaises(IncorrectInputDataException) as context: |
| 69 | + model = Mljar(project = self.proj_title, experiment = self.expt_title) |
| 70 | + samples = 100 |
| 71 | + columns = 10 |
| 72 | + X = np.random.rand(samples, columns) |
| 73 | + y = np.random.choice([0,1], samples+1, replace = True) |
| 74 | + model.fit(X, y) |
| 75 | + |
| 76 | + |
53 | 77 | def test_non_wait_fit(self):
|
54 | 78 | '''
|
55 | 79 | Test the non wait fit.
|
|
0 commit comments