A Single Inference Wrapper for OpenAI, Together AI, Hugging Face Inference TGI, Ollama, etc.
Until recently I thought that the openai library was only for connecting to OpenAI endpoints. It was not until I was testing out LLM inference with together.ai that I came across a section in their documentation on OpenAI API compatibility. The idea of using the openai client to do inference with open source models was completely new to me. In the together.ai documentation example they use the openai library to connect to an open source model.
import os
import openai
system_content = "You are a travel agent. Be descriptive and helpful."
user_content = "Tell me about San Francisco"
client = openai.OpenAI(
    api_key=os.environ.get("TOGETHER_API_KEY"),
    base_url="https://api.together.xyz/v1",
)
chat_completion = client.chat.completions.create(
    model="mistralai/Mixtral-8x7B-Instruct-v0.1",
    messages=[
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
    ],
    temperature=0.7,
    max_tokens=1024,
)
response = chat_completion.choices[0].message.content
print("Together response:\n", response)
Then a week later I saw that Hugging Face had also released support for OpenAI compatibility with Text Generation Inference (TGI) and Inference Endpoints. Again, you simply modify the base_url, api_key, and model, as seen in this example from their blog post announcement.
from openai import OpenAI
# initialize the client but point it to TGI
client = OpenAI(
    base_url="<ENDPOINT_URL>" + "/v1/",  # replace with your endpoint url
    api_key="<HF_API_TOKEN>",  # replace with your token
)

chat_completion = client.chat.completions.create(
    model="tgi",
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Why is open-source software important?"},
    ],
    stream=True,
    max_tokens=500,
)

# iterate and print stream
for message in chat_completion:
    print(message.choices[0].delta.content, end="")
What about working with LLMs locally? Two such options are Ollama and LM Studio. Ollama recently added support for the openai client, and LM Studio supports it too. For example, here is how one can use mistral-7b locally with Ollama to run inference with the openai client:
ollama pull mistral
from openai import OpenAI
client = OpenAI(
    base_url="http://localhost:11434/v1",
    api_key="ollama",  # required, but unused
)

response = client.chat.completions.create(
    model="mistral",
    messages=[
        {"role": "system", "content": "You are a helpful assistant and always talk like a pirate."},
        {"role": "user", "content": "Write a haiku."},
    ],
)
print(response.choices[0].message.content)
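LM Studio works the same way through its local server. I have not pulled this from an official example, so treat the port and model name below as assumptions; by default LM Studio's local server listens on http://localhost:1234/v1 and ignores the API key:
from openai import OpenAI

# Minimal LM Studio sketch (assumptions noted above): point the client at the
# local server and use whatever model identifier you have loaded in LM Studio.
client = OpenAI(
    base_url="http://localhost:1234/v1",  # LM Studio's default local server address
    api_key="lm-studio",  # required by the client, but unused by LM Studio
)

response = client.chat.completions.create(
    model="local-model",  # placeholder; replace with the model you have loaded
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku."},
    ],
)
print(response.choices[0].message.content)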
There are other services and libraries for running LLM inference that are compatible with the openai library too. I find it all very exciting because it means less code for me to write and maintain for running inference with LLMs. All I need to change is the base_url, the api_key, and the name of the model.
At the same time that I was learning about openai client compatibility, I was also looking into the instructor library. Since it patches some additional functionality into the openai client, I thought it would be fun to discuss it here too.
Start by creating a virtual environment:
python3 -m venv env
source env/bin/activate
Then install:
pip install openai
pip install instructor # only if you want to try out instructor library
pip install python-dotenv # or define your environment variables differently
I also pulled two models for Ollama:
ollama pull gemma:2b-instruct
ollama pull llama2
In my .env file I have the following:
OPENAI_API_KEY=your_key
HUGGING_FACE_ACCESS_TOKEN=your_key
TOGETHER_API_KEY=your_key
import os
from dotenv import load_dotenv
load_dotenv()
You could go ahead and just start using client.chat.completions.create directly, as in the examples from the introduction. However, I do like wrapping third party services into classes for reusability, maintainability, etc.
The class below, OpenAIChatCompletion, does several things:
- It stores instantiated clients in the clients dict, keyed by (base_url, api_key), so each endpoint is only constructed once.
- It wraps client.chat.completions.create in the __call__ method.
- For running many requests at once, the openai library does provide an AsyncOpenAI client, but sometimes I prefer simply using futures.ThreadPoolExecutor, as seen in the function create_chat_completions_async.
- It patches the OpenAI client with the instructor library. If you don't want to play around with the instructor library then simply remove the instructor.patch code.

I also added some logging functionality which keeps track of every outgoing LLM request. This was inspired by the awesome blog post by Hamel Husain, Fuck You, Show Me The Prompt. In that post, Hamel writes about how various LLM tools can often hide the prompts, making it tricky to see what requests are actually sent to the LLM behind the scenes. I created a simple logger class, OpenAIMessagesLogger, which keeps track of all the requests sent to the openai client. Later, when we try out the instructor library for getting structured output, we will utilize this debugging logger to see some additional messages that were sent to the client.
import ast
import logging
import re
from concurrent import futures
from typing import Any, Dict, List, Optional, Union
import instructor
from openai import APITimeoutError, OpenAI
from openai._streaming import Stream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
class OpenAIChatCompletion:
    clients: Dict = dict()

    @classmethod
    def _load_client(cls, base_url: Optional[str] = None, api_key: Optional[str] = None) -> OpenAI:
        client_key = (base_url, api_key)
        if OpenAIChatCompletion.clients.get(client_key) is None:
            OpenAIChatCompletion.clients[client_key] = instructor.patch(OpenAI(base_url=base_url, api_key=api_key))
        return OpenAIChatCompletion.clients[client_key]

    def __call__(
        self,
        model: str,
        messages: list,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatCompletion, Stream[ChatCompletionChunk]]:
        # https://platform.openai.com/docs/api-reference/chat/create
        # https://github.com/openai/openai-python
        client = self._load_client(base_url, api_key)
        return client.chat.completions.create(model=model, messages=messages, **kwargs)

    @classmethod
    def create_chat_completions_async(
        cls, task_args_list: List[Dict], concurrency: int = 10
    ) -> List[Union[ChatCompletion, Stream[ChatCompletionChunk]]]:
        """
        Make a series of calls to the chat.completions.create endpoint in parallel and collect back
        the results.

        :param task_args_list: A list of dictionaries where each dictionary contains the keyword
            arguments required for the __call__ method.
        :param concurrency: the max number of workers
        """

        def create_chat_task(
            task_args: Dict,
        ) -> Union[None, ChatCompletion, Stream[ChatCompletionChunk]]:
            try:
                return cls().__call__(**task_args)
            except APITimeoutError:
                return None

        with futures.ThreadPoolExecutor(max_workers=concurrency) as executor:
            results = list(executor.map(create_chat_task, task_args_list))
        return results


class OpenAIMessagesLogger(logging.Handler):
    def __init__(self):
        super().__init__()
        self.log_messages = []

    def emit(self, record):
        # Append the log message to the list
        log_record_str = self.format(record)
        match = re.search(r"Request options: (.+)", log_record_str, re.DOTALL)
        if match:
            text = match[1].replace("\n", "")
            log_obj = ast.literal_eval(text)
            self.log_messages.append(log_obj)


def debug_messages():
    msg = OpenAIMessagesLogger()
    openai_logger = logging.getLogger("openai")
    openai_logger.setLevel(logging.DEBUG)
    openai_logger.addHandler(msg)
    return msg
Here is how you use the inference class to call the LLM. If you have ever used the openai client you will be familiar with the input and output format.
llm = OpenAIChatCompletion()
message_logger = debug_messages() # optional for keeping track of all outgoing requests
print(llm(model="gpt-3.5-turbo-0125", messages=[dict(role="user", content="Hello!")]))
And our logger is keeping track of all the outgoing requests:
message_logger.log_messages
Now we can define some different models that can all be accessed through the same inference class.
class Models:
    # OpenAI GPT Models
    GPT4 = dict(model="gpt-4-0125-preview", base_url=None, api_key=None)
    GPT3 = dict(model="gpt-3.5-turbo-0125", base_url=None, api_key=None)

    # Hugging Face Inference Endpoints
    OPENHERMES2_5_MISTRAL_7B = dict(
        model="tgi",
        base_url="https://xofunqxk66baupmf.us-east-1.aws.endpoints.huggingface.cloud" + "/v1/",
        api_key=os.environ["HUGGING_FACE_ACCESS_TOKEN"],
    )

    # Ollama Models
    LLAMA2 = dict(
        model="llama2",
        base_url="http://localhost:11434/v1",
        api_key="ollama",
    )
    GEMMA2B = dict(
        model="gemma:2b-instruct",
        base_url="http://localhost:11434/v1",
        api_key="ollama",
    )

    # Together AI endpoints
    GEMMA7B = dict(
        model="google/gemma-7b-it",
        base_url="https://api.together.xyz/v1",
        api_key=os.environ.get("TOGETHER_API_KEY"),
    )
    MISTRAL7B = dict(
        model="mistralai/Mistral-7B-Instruct-v0.1",
        base_url="https://api.together.xyz/v1",
        api_key=os.environ.get("TOGETHER_API_KEY"),
    )
all_models = [
    (model_name, model_config)
    for model_name, model_config in Models.__dict__.items()
    if not model_name.startswith("__")
]
messages = [
    {"role": "system", "content": "You are a helpful assistant. Your replies are short, brief and to the point."},
    {"role": "user", "content": "Who was the first person to walk on the Moon, and in what year did it happen?"},
]
for model_name, model_config in all_models:
    resp = llm(messages=messages, **model_config)
    print(f"Model: {model_name}")
    print(f"Response: {resp.choices[0].message.content}")
We can also send the same requests in parallel like this:
task_args_list = []
for model_name, model_config in all_models:
    task_args_list.append(dict(messages=messages, **model_config))

# execute the same calls in parallel
model_names = [m[0] for m in all_models]
resps = llm.create_chat_completions_async(task_args_list)

for model_name, resp in zip(model_names, resps):
    print(f"Model: {model_name}")
    print(f"Response: {resp.choices[0].message.content}")
assert len(message_logger.log_messages) == 15
message_logger.log_messages[-1]
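As an aside, the openai library also ships an AsyncOpenAI client, so asyncio is another way to fan these requests out. The wrapper class above sticks with ThreadPoolExecutor; the sketch below is just my rough illustration of the asyncio alternative and is not part of the class:
import asyncio

from openai import AsyncOpenAI


# Rough sketch only: send the same task_args_list concurrently with AsyncOpenAI.
# A client is created per task and error handling is omitted for brevity.
async def fan_out(task_args_list):
    async def one_call(task_args):
        client = AsyncOpenAI(
            base_url=task_args.get("base_url"),
            api_key=task_args.get("api_key"),
        )
        return await client.chat.completions.create(
            model=task_args["model"],
            messages=task_args["messages"],
        )

    return await asyncio.gather(*(one_call(args) for args in task_args_list))


# resps = asyncio.run(fan_out(task_args_list))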
There are various approaches to getting structured output from LLMs. For example see JSON mode and Function calling. Some open source models and inference providers are also starting to offer these capabilities. For example see the together.ai docs. The instructor blog also has lots of examples and tips for getting structured output from LLMs. See this recent blog post for getting structured output from open source and Local LLMs.
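As a quick illustration of JSON mode through the same wrapper, here is a hedged sketch. The response_format parameter is part of the OpenAI chat completions API, but not every provider or model above supports it, and I am assuming the instructor-patched client passes it straight through when no response_model is given:
import json

# Sketch: ask gpt-3.5-turbo for JSON directly via OpenAI's JSON mode.
# JSON mode requires that the word "JSON" appears somewhere in the messages.
resp = llm(
    messages=[
        {"role": "system", "content": "You only reply with valid JSON."},
        {"role": "user", "content": "Return the name of the first person on the Moon and the year, as JSON."},
    ],
    response_format={"type": "json_object"},
    **Models.GPT3,
)
print(json.loads(resp.choices[0].message.content))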
One thing that is neat about the instructor library is that you can define a Pydantic schema and then pass it to the patched openai client. It also adds schema validation and retry logic.
First we will clear out our debugging log messages.
message_logger.log_messages = []
from typing import List
from pydantic import BaseModel, field_validator
class Character(BaseModel):
    name: str
    race: str
    fun_fact: str
    favorite_food: str
    skills: List[str]
    weapons: List[str]


class Characters(BaseModel):
    characters: List[Character]

    @field_validator("characters")
    @classmethod
    def validate_characters(cls, v):
        if len(v) < 20:
            raise ValueError(f"The number of characters must be at least 20, but it is {len(v)}")
        return v
res = llm(
    messages=[dict(role="user", content="Who are the main characters from Lord of the Rings?")],
    response_model=Characters,
    max_retries=4,
    **Models.GPT4,
)
for character in res.characters:
    for k, v in character.model_dump().items():
        print(f"{k}: {v}")
    print()
It is quite likely that GPT will not return 20 characters in the first request. If max_retries=0, it would likely raise a Pydantic validation error. But since we have max_retries=4, the instructor library sends the validation error back as a message and asks again. How exactly does it do that? We can look at the messages that we have logged for debugging.
assert len(message_logger.log_messages) > 1
len(message_logger.log_messages)
message_logger.log_messages
If you look through the above messages carefully you can see the retry logic in action:
Recall the function correctly, fix the errors and exceptions found\n1 validation error for Characters\ncharacters\n Value error, The number of characters must be at least 20, ...
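To make the retries easier to spot, here is a small sketch that prints just the chat messages from each logged request. The json_data and messages keys reflect what the openai client's debug output looked like in my runs, so treat those key names as an assumption:
# Sketch: walk the captured request options and print the outgoing messages,
# truncating long content so the retry prompts are easy to scan.
for i, request in enumerate(message_logger.log_messages):
    print(f"--- request {i} ---")
    for m in request.get("json_data", {}).get("messages", []):
        print(m["role"], ":", str(m["content"])[:200])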
You can even use structured output with some of the open source models. I would refer to the instructor blog or documentation for further information on that. I have not fully looked into the different patching modes yet, but here is a simple example of using MISTRAL7B through together.ai.
res = llm(
    messages=[dict(role="user", content="Give me a character from a movie or book.")],
    response_model=Character,
    max_retries=2,
    **Models.MISTRAL7B,
)
print(res.model_dump())
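On those patching modes: instructor exposes a Mode enum (for example instructor.Mode.JSON or instructor.Mode.MD_JSON) that controls how the structured output is extracted. I have not tested this here, so the sketch below is only my reading of the instructor docs and may need adjusting for your instructor version:
import instructor
from openai import OpenAI

# Sketch (untested assumption): patch a Together AI client in JSON mode so
# instructor asks for raw JSON instead of relying on tool/function calling.
together_json_client = instructor.patch(
    OpenAI(
        base_url="https://api.together.xyz/v1",
        api_key=os.environ.get("TOGETHER_API_KEY"),
    ),
    mode=instructor.Mode.JSON,
)

character = together_json_client.chat.completions.create(
    model="mistralai/Mistral-7B-Instruct-v0.1",
    messages=[{"role": "user", "content": "Give me a character from a movie or book."}],
    response_model=Character,
    max_retries=2,
)
print(character.model_dump())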
Again, I really like the idea of using a single interface for interacting with multiple LLMs. I hope the space continues to mature so that more open source models and services support JSON mode and function calling. I think instructor is a cool library and the corresponding blog is interesting too. I also like the idea of logging all the outgoing prompts/messages just to make sure I fully understand what is happening under the hood.