generate.proto
syntax = "proto3";
package generate.v1;
service LoraxService {
/// Model Info
rpc Info (InfoRequest) returns (InfoResponse) {}
/// Service discovery
rpc ServiceDiscovery (ServiceDiscoveryRequest) returns (ServiceDiscoveryResponse) {}
/// Empties batch cache
rpc ClearCache (ClearCacheRequest) returns (ClearCacheResponse);
/// Remove requests from a cached batch
rpc FilterBatch (FilterBatchRequest) returns (FilterBatchResponse);
/// Warmup the model and compute max cache size
rpc Warmup (WarmupRequest) returns (WarmupResponse);
/// Prefill batch and decode first token
rpc Prefill (PrefillRequest) returns (PrefillResponse);
/// Embed
rpc Embed (EmbedRequest) returns (EmbedResponse);
/// Classify
rpc Classify (ClassifyRequest) returns (ClassifyResponse);
/// Decode token for a list of prefilled batches
rpc Decode (DecodeRequest) returns (DecodeResponse);
/// Health check
rpc Health (HealthRequest) returns (HealthResponse);
/// Download adapter
rpc DownloadAdapter (DownloadAdapterRequest) returns (DownloadAdapterResponse);
/// Load adapter
rpc LoadAdapter (LoadAdapterRequest) returns (LoadAdapterResponse);
/// Offload adapter
rpc OffloadAdapter (OffloadAdapterRequest) returns (OffloadAdapterResponse);
}
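// A hedged sketch of how a router might drive this service during continuous
// batching; the ordering below is inferred from the RPC comments in this file
// rather than specified here:
//
//   Info            -> query dtype, block size, and preloaded adapters
//   Warmup          -> estimate the maximum supported token budget
//   DownloadAdapter -> fetch adapter weights (HUB / S3 / LOCAL / PBASE)
//   LoadAdapter     -> place an adapter on the GPU under an adapter_index
//   Prefill         -> first forward pass over a new Batch
//   Decode          -> repeated forward passes over cached batches
//   FilterBatch     -> drop finished request ids from a cached batch
//   OffloadAdapter  -> evict an adapter when memory pressure requires it
//   ClearCache      -> drop cached batches entirely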
message HealthRequest {}
message HealthResponse {}
message PreloadedAdapter {
/// Adapter params
AdapterParameters adapter_parameters = 1;
/// Adapter source
AdapterSource adapter_source = 2;
/// Adapter index
uint32 adapter_index = 3;
}
/// Empty request
message InfoRequest {}
message InfoResponse {
bool requires_padding = 1;
string dtype = 2;
string device_type = 3;
optional uint32 window_size = 4;
uint32 block_size = 5;
uint32 speculate = 6;
repeated PreloadedAdapter preloaded_adapters = 7;
bool supports_generation = 8;
bool supports_embeddings = 9;
bool supports_classification = 10;
bool chunked_prefill = 11;
bool requires_block_allocator = 12;
}
/// Empty request
message ServiceDiscoveryRequest {}
message ServiceDiscoveryResponse {
/// Other shards urls
repeated string urls = 1;
}
message ClearCacheRequest {
/// Optional batch id
optional uint64 id = 1;
}
/// Empty response
message ClearCacheResponse {}
message NextTokenChooserParameters {
/// exponentially scales the output probability distribution
float temperature = 1;
/// restricting to the k highest probability elements
uint32 top_k = 2;
/// restricting to the top tokens whose cumulative probability is <= top_p (nucleus sampling)
float top_p = 3;
/// restricting to tokens within the typical_p probability mass (locally typical sampling)
float typical_p = 4;
/// apply sampling on the logits
bool do_sample = 5;
/// random seed for sampling
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 8;
/// presence penalty
float presence_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 10;
/// adapter to use with LoRA eXchange (LoRAX)
string adapter_id = 11;
/// JSON schema used for constrained decoding (Outlines)
optional string schema = 12;
/// returning the k highest probability alternatives
uint32 return_k_alternatives = 13;
}
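// A hedged, illustrative NextTokenChooserParameters payload in protobuf text
// format (the values and adapter name below are made up, not defaults):
//
//   temperature: 0.7
//   top_k: 50
//   top_p: 0.9
//   typical_p: 1.0
//   do_sample: true
//   seed: 42
//   repetition_penalty: 1.1
//   adapter_id: "org/customer-lora"
//   return_k_alternatives: 0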
message StoppingCriteriaParameters {
/// Maximum number of generated tokens
uint32 max_new_tokens = 1;
/// Optional stopping sequences
repeated string stop_sequences = 2;
/// Ignore end of sequence token
/// used for benchmarking
bool ignore_eos_token = 3;
}
message Image {
/// Binary image data.
bytes data = 1;
/// Image MIME type.
string mimetype = 2;
}
message InputChunk {
oneof chunk {
/// Plain text data
string text = 1;
/// Image data
Image image = 2;
}
}
message TokenizedInputs {
/// Token IDs
repeated uint32 ids = 1;
/// Chunks
repeated InputChunk input_chunks = 2;
}
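// A hedged example of multimodal TokenizedInputs in text format; the token ids
// and image bytes are placeholders, not real values:
//
//   ids: [101, 2023, 2003]
//   input_chunks { text: "Describe this image: " }
//   input_chunks { image { data: "<raw PNG bytes>" mimetype: "image/png" } }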
message Request {
/// Request ID
uint64 id = 1;
/// The generation context
string inputs = 2;
/// Tokenized inputs
TokenizedInputs tokenized_inputs = 3;
/// Context truncation
uint32 truncate = 4;
/// Next Token Chooser Parameters
NextTokenChooserParameters parameters = 5;
/// Stopping Criteria Parameters
StoppingCriteriaParameters stopping_parameters = 6;
/// Return prefill logprobs
bool prefill_logprobs = 7;
/// Adapter index
uint32 adapter_index = 8;
/// Paged attention blocks
repeated uint32 blocks = 9;
/// Paged attention slots
repeated uint32 slots = 10;
/// Tokens that can be retrieved from the KV cache.
/// This value is set for the first prefill and never reset
uint32 cache_len = 11;
/// Chunk of tokens that must be computed for the first prefill
/// This value is set for the first prefill and never reset
optional uint32 chunk_len = 12;
}
message Batch {
/// Batch ID
uint64 id = 1;
/// Individual requests
repeated Request requests = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Maximum number of Paged Attention blocks
uint32 max_blocks = 5;
}
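// A hedged example of a single-request prefill Batch in text format; ids,
// blocks, slots, and token budgets are illustrative only:
//
//   id: 1
//   size: 1
//   max_tokens: 2048
//   max_blocks: 4
//   requests {
//     id: 7
//     inputs: "What is the capital of France?"
//     truncate: 1024
//     parameters { temperature: 1.0 do_sample: false }
//     stopping_parameters { max_new_tokens: 64 }
//     prefill_logprobs: false
//     adapter_index: 0
//     blocks: [0, 1, 2, 3]
//     slots: [0, 1, 2, 3]
//     cache_len: 0
//   }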
message CachedBatch {
/// Batch ID
uint64 id = 1;
/// Individual requests ids
repeated uint64 request_ids = 2;
/// Batch size (==len(requests))
uint32 size = 3;
/// Maximum number of tokens this batch will grow to
uint32 max_tokens = 4;
/// Number of tokens in the next forward
uint32 current_tokens = 5;
}
enum FinishReason {
FINISH_REASON_LENGTH = 0;
FINISH_REASON_EOS_TOKEN = 1;
FINISH_REASON_STOP_SEQUENCE = 2;
}
message GeneratedText {
/// Output
string text = 1;
/// Number of generated tokens
uint32 generated_tokens = 2;
/// Number of skipped tokens due to speculative decoding hits
uint32 skipped_tokens = 3;
/// Finish reason
FinishReason finish_reason = 4;
/// Seed
optional uint64 seed = 5;
}
message AlternativeTokens {
/// Alternative Token IDs
repeated uint32 ids = 1;
/// Alternative Logprobs
repeated float logprobs = 2;
/// Alternative tokens
repeated string texts = 3;
}
message NextTokens {
/// Token IDs
repeated uint32 ids = 1;
/// Logprobs
repeated float logprobs = 2;
/// Decoded text for each token
repeated string texts = 3;
/// Whether each token is a special token
repeated bool is_special = 4;
/// Alternative tokens (optional)
repeated AlternativeTokens alternative_tokens = 5;
}
message Generation {
/// Request ID
uint64 request_id = 1;
/// Prefill tokens (optional)
NextTokens prefill_tokens = 2;
/// Next tokens
NextTokens next_tokens = 3;
/// Complete generated text
optional GeneratedText generated_text = 4;
/// Prefill tokens length
uint32 prefill_tokens_length = 5;
}
message FilterBatchRequest {
/// Batch ID
uint64 batch_id = 1;
/// Requests to keep
repeated uint64 request_ids = 2;
}
message FilterBatchResponse {
/// Filtered Batch (cached)
CachedBatch batch = 1;
}
message PrefillRequest {
/// Batch
Batch batch = 1;
/// Optional cached batch
CachedBatch cached_batch = 2;
}
message PrefillResponse {
/// Generation
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
// TODO(travis): add timings
// /// Forward elapsed time in nanoseconds
// uint64 forward_ns = 3;
// /// Decode elapsed time in nanoseconds
// uint64 decode_ns = 4;
// /// Total elapsed time in nanoseconds
// uint64 total_ns = 5;
// /// Concatenate elapsed time in nanoseconds
// optional uint64 concat_ns = 6;
}
message DecodeRequest {
/// Cached batches
repeated CachedBatch batches = 1;
}
message DecodeResponse {
/// Decodes
repeated Generation generations = 1;
/// Next batch (cached)
optional CachedBatch batch = 2;
}
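// A hedged sketch of the decode step: the router can pass several cached
// batches in one DecodeRequest so the shard may concatenate them before the
// next forward pass. Illustrative text-format request (ids are made up):
//
//   batches { id: 1 request_ids: [7] size: 1 max_tokens: 2048 current_tokens: 33 }
//   batches { id: 2 request_ids: [9, 11] size: 2 max_tokens: 4096 current_tokens: 87 }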
message Embedding {
/// Request ID
uint64 request_id = 1;
/// Embedding values
repeated float values = 2;
}
message EmbedRequest {
/// Batch
Batch batch = 1;
}
message EmbedResponse {
/// Embeddings
repeated Embedding embeddings = 1;
/// Error message on failure
string errorMsg = 2;
}
message Entity {
string entity = 1;
float score = 2;
uint32 index = 3;
string word = 4;
uint32 start = 5;
uint32 end = 6;
}
message EntityList {
/// Request ID
uint64 request_id = 1;
/// Entities
repeated Entity entities = 2;
/// Input token IDs
repeated uint32 input_ids = 4;
}
message ClassifyPredictionList {
/// Request ID
uint64 request_id = 1;
/// Classifications
repeated string predictions = 2;
repeated float scores = 3;
}
message ClassifyRequest {
/// Batch
Batch batch = 1;
}
message ClassifyResponse {
/// Classifications
repeated ClassifyPredictionList classify_prediction_lists = 1;
/// Error message on failure
string errorMsg = 2;
}
message WarmupRequest {
/// Batch to warmup on
Batch batch = 1;
uint32 max_input_length = 2;
uint32 max_prefill_tokens = 3;
uint32 max_new_tokens = 4;
}
message WarmupResponse {
/// Maximum number of tokens supported by the model
optional uint32 max_supported_total_tokens = 1;
}
enum AdapterSource {
/// Adapters loaded using the Hugging Face Hub
HUB = 0;
/// Adapters loaded via remote filesystem path
S3 = 1;
/// Adapters loaded via local filesystem path
LOCAL = 2;
/// Adapters loaded via Predibase
PBASE = 3;
}
enum MergeStrategy {
/// Linear combination of adapters
LINEAR = 0;
/// TIES method for combining adapters
TIES = 1;
/// DARE method for combining adapters
DARE_LINEAR = 2;
/// DARE + TIES method for combining adapters
DARE_TIES = 3;
}
enum MajoritySignMethod {
/// Total method
TOTAL = 0;
/// Frequency method
FREQUENCY = 1;
}
message AdapterParameters {
/// Adapter IDs
repeated string adapter_ids = 1;
/// Adapter weights for merging
repeated float weights = 2;
/// Merge strategy (default: linear)
MergeStrategy merge_strategy = 3;
/// [0, 1], 0: full pruning, 1: no pruning
float density = 4;
/// Majority sign method (default: total)
MajoritySignMethod majority_sign_method = 5;
}
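// A hedged example of AdapterParameters for merging two adapters with the TIES
// strategy, in text format; the adapter names and weights are illustrative:
//
//   adapter_ids: ["org/adapter-a", "org/adapter-b"]
//   weights: [0.7, 0.3]
//   merge_strategy: TIES
//   density: 0.5
//   majority_sign_method: TOTAL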
message DownloadAdapterRequest {
/// Adapter Parameters
AdapterParameters adapter_parameters = 1;
/// Adapter source
AdapterSource adapter_source = 2;
/// Token for external API (Predibase / Hugging Face)
optional string api_token = 3;
}
message DownloadAdapterResponse {
/// True if download occurred, false if skipped
bool downloaded = 1;
/// Fraction of the adapter memory limit consumed by the adapter.
/// If no limit is set, will return 0.
/// When the total across all loaded adapters exceeds
/// the adapter_memory_fraction limit, no more adapters
/// will be loaded to GPU and LoRAX will begin swapping.
float memory_fraction = 2;
}
message LoadAdapterRequest {
/// Adapter Parameters
AdapterParameters adapter_parameters = 1;
/// Adapter source
AdapterSource adapter_source = 2;
/// Adapter index
uint32 adapter_index = 3;
/// Token for external API (Predibase / Hugging Face)
optional string api_token = 4;
}
message LoadAdapterResponse {
/// True if load occurred, false if skipped
bool loaded = 1;
}
message OffloadAdapterRequest {
/// Adapter Parameters
AdapterParameters adapter_parameters = 1;
/// Adapter source
AdapterSource adapter_source = 2;
/// Adapter index
uint32 adapter_index = 3;
}
message OffloadAdapterResponse {
/// True if offload occurred, false if skipped
bool offloaded = 1;
}
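// A hedged sketch of the adapter lifecycle implied by the three RPCs above,
// with illustrative text-format payloads (the adapter name and index are made up):
//
//   DownloadAdapterRequest:
//     adapter_parameters { adapter_ids: ["org/customer-lora"] }
//     adapter_source: HUB
//   LoadAdapterRequest:
//     adapter_parameters { adapter_ids: ["org/customer-lora"] }
//     adapter_source: HUB
//     adapter_index: 3
//   OffloadAdapterRequest:
//     adapter_parameters { adapter_ids: ["org/customer-lora"] }
//     adapter_source: HUB
//     adapter_index: 3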