The LoRAX Python client provides a convenient way of interfacing with a LoRAX instance running in your environment. Install it with pip:

pip install lorax-client
from lorax import Client
endpoint_url = "http://127.0.0.1:8080"
client = Client(endpoint_url)
text = client.generate("Why is the sky blue?", adapter_id="some/adapter").generated_text
print(text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
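If you want to surface tokens to the user as they arrive instead of accumulating them into a string, a minimal variant of the loop above (using only the `token.text` and `token.special` fields already shown) is:

for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        # Print each token as soon as it streams in, without a trailing newline
        print(response.token.text, end="", flush=True)
print()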
Or, with the asynchronous client:
from lorax import AsyncClient
endpoint_url = "http://127.0.0.1:8080"
client = AsyncClient(endpoint_url)
response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
print(response.generated_text)
# ' Rayleigh scattering'
# Token Streaming
text = ""
async for response in client.generate_stream("Why is the sky blue?", adapter_id="some/adapter"):
    if not response.token.special:
        text += response.token.text
print(text)
# ' Rayleigh scattering'
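Note that the `await` and `async for` snippets above assume an event loop is already running (for example, in a Jupyter notebook). In a plain Python script you would wrap them in an `async` function and drive it with `asyncio.run`, roughly like this:

import asyncio
from lorax import AsyncClient

async def main():
    client = AsyncClient("http://127.0.0.1:8080")
    response = await client.generate("Why is the sky blue?", adapter_id="some/adapter")
    print(response.generated_text)

asyncio.run(main())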
See the API reference for full details.
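Both `Client.generate` and `AsyncClient.generate` accept generation parameters alongside `adapter_id`; for example, `max_new_tokens` (used in the batching example below) caps the length of the completion. A minimal sketch with the synchronous client:

from lorax import Client

client = Client("http://127.0.0.1:8080")
response = client.generate(
    "Why is the sky blue?",
    adapter_id="some/adapter",
    max_new_tokens=64,  # cap the number of generated tokens
)
print(response.generated_text)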
In some cases you may have a list of prompts that you wish to process in bulk ("batch processing").
Rather than process each prompt one at a time, you can take advantage of the AsyncClient and LoRAX's native parallelism to submit your prompts at once and await the results:
import asyncio
import time
from lorax import AsyncClient
# Batch of prompts to submit
prompts = [
"The quick brown fox",
"The rain in Spain",
"What comes up",
]
# Initialize the async client
endpoint_url = "http://127.0.0.1:8080"
async_client = AsyncClient(endpoint_url)
# Create a generate coroutine for each prompt; nothing blocks until they are gathered below
t0 = time.time()
futures = []
for prompt in prompts:
    resp = async_client.generate(prompt, max_new_tokens=64)
    futures.append(resp)
# Await the completion of all the prompt requests
responses = await asyncio.gather(*futures)
# Print responses
# Responses will always come back in the same order as the original list
for resp in responses:
    print(resp.generated_text)
# Print duration to process all requests in batch
print("duration (s):", time.time() - t0)
Output:
duration (s): 2.9093329906463623
Compare this against the duration of submitting the prompts one at a time. For 3 prompts, you should find that the async approach is about 2.5 - 3x faster than serial processing:
from lorax import Client
client = Client(endpoint_url)
t0 = time.time()
responses = []
for prompt in prompts:
    resp = client.generate(prompt, max_new_tokens=64)
    responses.append(resp)
for resp in responses:
    print(resp.generated_text)
print("duration (s):", time.time() - t0)
Output:
duration (s): 8.385080099105835
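For very large batches you may not want every request in flight at once. One option is to cap concurrency with a plain `asyncio.Semaphore`; this is an illustrative sketch built only on the `AsyncClient.generate` call shown above, not a built-in LoRAX feature:

import asyncio
from lorax import AsyncClient

async def generate_all(prompts, endpoint_url, max_concurrency=8):
    client = AsyncClient(endpoint_url)
    # Allow at most max_concurrency requests in flight at a time
    semaphore = asyncio.Semaphore(max_concurrency)

    async def generate_one(prompt):
        async with semaphore:
            return await client.generate(prompt, max_new_tokens=64)

    # asyncio.gather preserves the order of the input prompts
    return await asyncio.gather(*(generate_one(p) for p in prompts))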
The LoRAX client can also be used to connect to Predibase managed LoRAX endpoints (including Predibase's serverless endpoints).
You need only make the following changes to the above examples:
- Change the `endpoint_url` to match the endpoint of your Predibase LLM of choice.
- Provide your Predibase API token in the `headers` provided to the client.
Example:
from lorax import Client
# You can get your Predibase API token by going to Settings > My Profile > Generate API Token
# You can get your Predibase Tenant short code by going to Settings > My Profile > Overview > Tenant ID
endpoint_url = f"https://serving.app.predibase.com/{predibase_tenant_short_code}/deployments/v2/llms/{llm_deployment_name}"
headers = {
"Authorization": f"Bearer {api_token}"
}
client = Client(endpoint_url, headers=headers)
# same as above from here ...
response = client.generate("Why is the sky blue?", adapter_id=f"{model_repo}/{model_version}")
Note that Predibase will use its internal model repos as the default `adapter_source`. To use an adapter from the Hugging Face Hub instead, set `adapter_source="hub"`:
response = client.generate("Why is the sky blue?", adapter_id="some/adapter", adapter_source="hub")