-
Notifications
You must be signed in to change notification settings - Fork 211
/
Copy pathtest_types.py
88 lines (72 loc) · 2.63 KB
/
test_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import pytest
from lorax.types import Parameters, Request, MergedAdapters
from lorax.errors import ValidationError
def test_parameters_validation():
    """Check the per-field validation rules of ``Parameters``.

    Each section constructs ``Parameters`` with values that are expected to be
    accepted (a bare call: any exception fails the test) and with values that
    are expected to be rejected with ``lorax.errors.ValidationError``.
    """
    # best_of: expected to be strictly positive; a value > 1 additionally
    # requires sampling and must not be combined with a fixed seed.
    Parameters(best_of=1)
    with pytest.raises(ValidationError):
        Parameters(best_of=0)
    with pytest.raises(ValidationError):
        Parameters(best_of=-1)
    Parameters(best_of=2, do_sample=True)
    with pytest.raises(ValidationError):
        Parameters(best_of=2)
    # do_sample=True satisfies the sampling requirement, so the seed is the
    # only remaining reason for rejection. Without it, this call would already
    # fail the "must use sampling" check above and the seed rule would go
    # untested.
    with pytest.raises(ValidationError):
        Parameters(best_of=2, do_sample=True, seed=1)

    # repetition_penalty: expected to be strictly positive.
    Parameters(repetition_penalty=1)
    with pytest.raises(ValidationError):
        Parameters(repetition_penalty=0)
    with pytest.raises(ValidationError):
        Parameters(repetition_penalty=-1)

    # seed: negative values are expected to be rejected.
    Parameters(seed=1)
    with pytest.raises(ValidationError):
        Parameters(seed=-1)

    # temperature: zero is accepted, negative values are rejected.
    Parameters(temperature=1)
    Parameters(temperature=0)
    with pytest.raises(ValidationError):
        Parameters(temperature=-1)

    # top_k: expected to be strictly positive.
    Parameters(top_k=1)
    with pytest.raises(ValidationError):
        Parameters(top_k=0)
    with pytest.raises(ValidationError):
        Parameters(top_k=-1)

    # top_p: expected to lie strictly between 0 and 1.
    Parameters(top_p=0.5)
    with pytest.raises(ValidationError):
        Parameters(top_p=0)
    with pytest.raises(ValidationError):
        Parameters(top_p=-1)
    with pytest.raises(ValidationError):
        Parameters(top_p=1)

    # truncate: expected to be strictly positive.
    Parameters(truncate=1)
    with pytest.raises(ValidationError):
        Parameters(truncate=0)
    with pytest.raises(ValidationError):
        Parameters(truncate=-1)

    # typical_p: expected to lie strictly between 0 and 1.
    Parameters(typical_p=0.5)
    with pytest.raises(ValidationError):
        Parameters(typical_p=0)
    with pytest.raises(ValidationError):
        Parameters(typical_p=-1)
    with pytest.raises(ValidationError):
        Parameters(typical_p=1)

    # adapter_id and merged_adapters: each is accepted on its own, but the
    # two are expected to be mutually exclusive.
    merged_adapters = MergedAdapters(
        ids=["test/adapter-id-1", "test/adapter-id-2"],
        weights=[0.5, 0.5],
        density=0.5,
    )
    Parameters(adapter_id="test/adapter-id")
    Parameters(merged_adapters=merged_adapters)
    with pytest.raises(ValidationError):
        Parameters(adapter_id="test/adapter-id", merged_adapters=merged_adapters)
def test_request_validation():
    """Check that ``Request`` rejects empty inputs and streamed best-of runs.

    Bare constructor calls are the accepted cases (any exception fails the
    test); the ``pytest.raises`` blocks are the expected rejections.
    """

    def best_of_sampling():
        # A fresh best_of > 1 / sampling parameter set for each request.
        return Parameters(best_of=2, do_sample=True)

    # Accepted: plain, streamed, and best-of (non-streamed) requests.
    Request(inputs="test")

    # Rejected: empty input text.
    with pytest.raises(ValidationError):
        Request(inputs="")

    Request(inputs="test", stream=True)
    Request(inputs="test", parameters=best_of_sampling())

    # Rejected: streaming combined with best_of > 1.
    with pytest.raises(ValidationError):
        Request(inputs="test", parameters=best_of_sampling(), stream=True)