Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit c631a59

Browse files
committed
add compute_prediction
1 parent 78bf675 commit c631a59

File tree

3 files changed

+65
-5
lines changed

3 files changed

+65
-5
lines changed

mljar/mljar.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ class Mljar(object):
2424
This is a wrapper over MLJAR API - it does all the stuff.
2525
'''
2626

27-
def __init__(self, project, experiment,
27+
def __init__(self, project,
28+
experiment,
2829
metric = '',
2930
algorithms = [],
3031
validation_kfolds = MLJAR_DEFAULT_FOLDS,
@@ -316,3 +317,40 @@ def predict(self, X):
316317
logger.error('Sorry, there was some problem with computing prediction for your dataset. \
317318
Please login to mljar.com to your account and check details.')
318319
return None
320+
321+
322+
@staticmethod
323+
def compute_prediction(X, model_id, project_id):
324+
325+
326+
# chack if dataset exists in mljar if not upload dataset for prediction
327+
dataset = DatasetClient(project_id).add_dataset_if_not_exists(X, y = None)
328+
329+
# check if prediction is available
330+
total_checks = 100
331+
for i in xrange(total_checks):
332+
prediction = PredictionClient(project_id).\
333+
get_prediction(dataset.hid, model_id)
334+
335+
# prediction is not available, first check so submit job
336+
if i == 0 and prediction is None:
337+
# create prediction job
338+
submitted = PredictJobClient().submit(project_id, dataset.hid,
339+
model_id)
340+
if not submitted:
341+
logger.error('Problem with prediction for your dataset')
342+
return None
343+
344+
if prediction is not None:
345+
pred = PredictionDownloadClient().download(prediction.hid)
346+
#sys.stdout.write('\r\n')
347+
return pred
348+
349+
#sys.stdout.write('\rFetch predictions: {0}%'.format(round(i/(total_checks*0.01))))
350+
#sys.stdout.flush()
351+
time.sleep(5)
352+
353+
#sys.stdout.write('\r\n')
354+
logger.error('Sorry, there was some problem with computing prediction for your dataset. \
355+
Please login to mljar.com to your account and check details.')
356+
return None

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name='mljar',
13-
version='0.0.6',
13+
version='0.0.7',
1414
description='Python wrapper over MLJAR API',
1515
long_description=long_description,
1616
url='https://github.com/mljar/mljar-api-python',

tests/mljar_test.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,38 @@ def setUp(self):
2626
self.X = df[cols]
2727
self.y = df[target]
2828

29-
def tearDown(self):
30-
# clean
31-
ProjectBasedTest.clean_projects()
29+
#def tearDown(self):
30+
# # clean
31+
# ProjectBasedTest.clean_projects()
3232

3333
def mse(self, predictions, targets):
3434
predictions = np.array(predictions)
3535
targets = np.array(targets)
3636
targets = targets.reshape((targets.shape[0],1))
3737
return ((predictions - targets) ** 2).mean()
3838

39+
40+
def test_compute_prediction(self):
41+
'''
42+
Test the most common usage.
43+
'''
44+
model = Mljar(project = self.proj_title, experiment = self.expt_title,
45+
algorithms = ['rfc'], metric='logloss',
46+
validation_kfolds=3, tuning_mode='Normal')
47+
self.assertTrue(model is not None)
48+
# fit models and wait till all models are trained
49+
model.fit(X = self.X, y = self.y)
50+
51+
# get project id
52+
project_id = model.project.hid
53+
# get model id
54+
model_id = model.selected_algorithm.hid
55+
# compute predictions
56+
pred = Mljar.compute_prediction(self.X, model_id, project_id)
57+
# compute score
58+
score = self.mse(pred, self.y)
59+
self.assertTrue(score < 0.1)
60+
3961
def test_basic_usage(self):
4062
'''
4163
Test the most common usage.

0 commit comments

Comments
 (0)