22
22
23
23
24
24
MODEL = 'census'
25
- VERSION = 'v1 '
26
- TF_RECORDS_VERSION = 'v1tfrecord '
25
+ JSON_VERSION = 'v1json '
26
+ EXAMPLES_VERSION = 'v1example '
27
27
PROJECT = 'python-docs-samples-tests'
28
28
JSON = {
29
29
'age' : 25 ,
41
41
'native_country' : ' United-States'
42
42
}
43
43
EXPECTED_OUTPUT = {
44
- u'probabilities' : [0.9942260384559631 , 0.005774002522230148 ],
45
- u'logits' : [- 5.148599147796631 ],
46
- u'classes' : 0 ,
47
- u'logistic' : [0.005774001590907574 ]
44
+ u'confidence' : 0.7760371565818787 ,
45
+ u'predictions' : u' <=50K'
48
46
}
49
47
50
48
51
49
def test_predict_json ():
52
50
result = predict .predict_json (
53
- PROJECT , MODEL , [JSON , JSON ], version = VERSION )
51
+ PROJECT , MODEL , [JSON , JSON ], version = JSON_VERSION )
54
52
assert [EXPECTED_OUTPUT , EXPECTED_OUTPUT ] == result
55
53
56
54
57
55
def test_predict_json_error ():
58
56
with pytest .raises (RuntimeError ):
59
- predict .predict_json (PROJECT , MODEL , [{"foo" : "bar" }], version = VERSION )
57
+ predict .predict_json (
58
+ PROJECT , MODEL , [{"foo" : "bar" }], version = JSON_VERSION )
60
59
61
60
62
61
@pytest .mark .slow
@@ -66,9 +65,8 @@ def test_census_example_to_bytes():
66
65
67
66
68
67
@pytest .mark .slow
69
- @pytest .mark .xfail ('Single placeholder inputs broken in service b/35778449' )
70
- def test_predict_tfrecords ():
68
+ def test_predict_examples ():
71
69
b = predict .census_to_example_bytes (JSON )
72
- result = predict .predict_tfrecords (
73
- PROJECT , MODEL , [b , b ], version = TF_RECORDS_VERSION )
70
+ result = predict .predict_examples (
71
+ PROJECT , MODEL , [b , b ], version = EXAMPLES_VERSION )
74
72
assert [EXPECTED_OUTPUT , EXPECTED_OUTPUT ] == result
0 commit comments