
Why use Chat GPT when you can build your own?

I've been toying with this idea for a while and I think I have some of the pieces in place to make it happen. Disclaimer: this version has significantly reduced capabilities compared to the real thing.


Introduction

Code is available here

If you're new to the blog, I'm a full stack software engineer looking to break into the machine learning and data space. I don't really have much experience so I'm documenting my journey along the way to make it easier for others to follow along and learn from my mistakes. I figure it'll be nice to have a record of my progress as well.

Recently, someone linked me to a fantastic video by David Shapiro where he built his own version of Chat GPT locally. I thought that it was a pretty cool project and had some initial ideas on how to improve upon his design so I thought I would work on implementing it in parts.

Today's article is going to be the first part - we're just going to get a chatbot up and running that can respond to user input. It'll be able to query a list of relevant chat messages using embeddings and the most recent messages to generate a good answer for you.

Here's a quick diagram of what we're eventually going to build

Setting up the environment

We first need to create a new python3 virtual environment. We can do so with the command

python3 -m venv .virtual

This will create a new virtual environment in the current directory. We can then activate it with the command

source .virtual/bin/activate

Building a chatbot

Here is the list of packages that our chatbot needs to run:

  • openai - to call OpenAI's completion and embedding APIs
  • pydantic - to model and validate our messages
  • numpy - to compute similarity scores between embeddings

We can install these with the command

pip install openai pydantic numpy

Getting User Input

We can start by first writing a simple loop that grabs some input from our user in the terminal. We'll use the input function to do so.

main.py

if __name__ == "__main__":
    while True:
        user_message = input('User: ')
        print(f"....User responded with: {user_message}")

We can execute this with the command

python3 main.py

which in turn gives the following output

(.virtual) ➜  notes git:(main) ✗ python3 main.py
User: This is a response
....User responded with: This is a response
User: What's up lol
....User responded with: What's up lol

Saving User Messages

Now that we've successfully managed to grab user input, let's work on saving it to our local disk.

For now, we'll save each message to a simple JSON file that we can query later; we can work on improving the efficiency of this down the line. When it comes to user messages, there are a few things we should probably save:

  1. The message
  2. The time that we received the message
  3. A uuid to identify this message by
  4. The speaker - is it a message from a user or an agent

We can use pydantic to create a simple model for this.

model.py

from enum import Enum
from pydantic import BaseModel, Field
import uuid
 
class SpeakerEnum(str, Enum):
    USER = "USER"
    AGENT = "AGENT"
 
def generate_uuid():
    return str(uuid.uuid4())
 
class Message(BaseModel):
    timestamp: float
    message: str
    uuid: str = Field(default_factory=generate_uuid)
    speaker: SpeakerEnum

A few things here to point out

  • We ensure that we are assigning values in a type-safe way by using Pydantic's built-in type validation
  • We use specific Enums so that we don't have any unsafe access to our data
  • We use a factory function to generate UUIDs automatically when we create a new message, so we never have to set them by hand
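As a quick illustration of that type safety, here's a minimal example (the exact error text may vary across Pydantic versions):

Validating A Message

import time
 
from model import Message, SpeakerEnum
 
# The uuid field is filled in automatically by the default factory
msg = Message(timestamp=time.time(), message="hello", speaker=SpeakerEnum.USER)
print(msg.uuid)  # a freshly generated UUID4 string
 
# An invalid speaker raises a ValidationError instead of silently
# storing bad data
try:
    Message(timestamp=time.time(), message="hello", speaker="ROBOT")
except Exception as e:
    print(e)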

Let's now update our main.py to use this model

main.py

import time

from model import Message, SpeakerEnum

if __name__ == "__main__":
    while True:
        user_message = input('User: ')
        userMessage = Message(
            timestamp=time.time(),
            message=user_message,
            speaker=SpeakerEnum.USER
        )
        print(userMessage)

This in turn gives the following output when we run it

(.virtual) ➜  notes git:(main) ✗ python3 main.py                              
User: This is a test
timestamp=1686325696.7861178 message='This is a test' uuid='7bfb6e76-8657-4f1a-9523-b7ef2c171d38' speaker=<SpeakerEnum.USER: 'USER'>
User: What's up amigo
timestamp=1686325703.707817 message="What's up amigo" uuid='88eab7a9-8821-42a1-804f-d1b2c9ecb477' speaker=<SpeakerEnum.USER: 'USER'>

Now that we've successfully created our messages, let's work on serializing them and saving them to disk. We can do so with a simple save_message function, as seen below in our updated main.py.

main.py

import os
import time
from model import Message, SpeakerEnum
import json
from pathlib import Path
 
SAVE_LOCATION = Path.cwd() / "data"
 
def save_message(message: Message):
    filename = f"{message.uuid}.json"
    with open(SAVE_LOCATION / filename, "w") as f:
        json.dump(message.dict(), f)
 
if __name__ == "__main__":
    if not os.path.exists(SAVE_LOCATION):
        os.mkdir(SAVE_LOCATION)
 
    while True:
        user_message = input('User: ')
        userMessage = Message(
            timestamp=time.time(),
            message=user_message,
            speaker=SpeakerEnum.USER
        )
        save_message(userMessage)

We can now run this and see that it works as expected

(.virtual) ➜  notes git:(main) ✗ python3 main.py
User: This is a test to see if the file successfully serializes

which corresponds to the following JSON that we observed

{
  "timestamp": 1686326379.212329,
  "message": "This is a test to see if the file succesfully serializes",
  "uuid": "1f60676a-a950-4eae-b07f-807163e5c8b8",
  "speaker": "USER"
}

Getting A Response

Now let's work on getting a response from OpenAI using their API. We already installed the openai package earlier, but we'll also need the python-dotenv package to load our API key from a .env file. We can install it with the command

pip install python-dotenv

Don't forget to create a .env file to save your OpenAI API Key.

.env

OPENAI_API_KEY=YOUR_API_KEY

The way OpenAI's completion model works is that we provide a prompt and it generates a response. There are a variety of different prompts that you might be able to use, but in my case I decided to just use this

config.py

RESPONSE_PROMPT = """You are now TutorGPT. You will be given a user prompt to respond to in some time. Make sure to answer in a polite, friendly and helpful manner.
 
Here is an example of a prior conversation for reference
USER: What is a statistical model?
AGENT: A language model is a statistical model that describes the probability of a word given the previous words.
 
User Question: {prompt}
"""

In this case we're using a Python template string. We can insert the user prompt into it later on by using the .format() method.
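Here's a quick example of that substitution:

Formatting The Prompt

from config import RESPONSE_PROMPT
 
# {prompt} in the template is replaced with the user's question
filled = RESPONSE_PROMPT.format(prompt="What is a neural network?")
print(filled)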

Let's update our main.py file so that it uses this completion API.

main.py

import os
import time
 
import openai
from config import RESPONSE_PROMPT
from model import Message, SpeakerEnum
import json
from pathlib import Path
 
SAVE_LOCATION = Path.cwd() / "data"
 
from dotenv import load_dotenv
 
# Load variables from .env file
load_dotenv()
 
openai.api_key = os.getenv("OPENAI_API_KEY")
 
if not openai.api_key:
    raise Exception("OPENAI_API_KEY not found in .env file")
 
def get_response_from_openai(user_prompt: str):
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=RESPONSE_PROMPT.format(prompt=user_prompt),
        temperature=0.2,
        max_tokens=300,
        top_p=1,
        frequency_penalty=0.0,
        presence_penalty=0.0,
        stop=["USER:", "AGENT:"],
    )
    return f'AGENT: {response["choices"][0].text}'
 
def save_message(message: Message):
    filename = f"{message.uuid}.json"
    with open(SAVE_LOCATION / filename, "w") as f:
        json.dump(message.dict(), f)
 
if __name__ == "__main__":
    if not os.path.exists(SAVE_LOCATION):
        os.mkdir(SAVE_LOCATION)
 
    while True:
        user_message = input("User: ")
        userMessage = Message(
            timestamp=time.time(), message=user_message, speaker=SpeakerEnum.USER
        )
        save_message(userMessage)
        response = get_response_from_openai(user_message)
        print(response)
        modelResponse = Message(
            timestamp=time.time(), message=response, speaker=SpeakerEnum.AGENT
        )
        save_message(modelResponse)
 

With this, we can now converse with our agent. Bear in mind the agent does not have any memory at this point in time.

(.virtual) ➜  notes git:(main) ✗ python main.py
User: What is the fastest way to travel between Malaysia and Brunei?
AGENT: The fastest way to travel between Malaysia and Brunei is by air. There are several airlines that offer direct flights between the two countries, so you can find a flight that fits your schedule and budget.
User: Between Singapore and Malaysia, which has the better food?
AGENT: It's hard to say which country has the better food, as both Singapore and Malaysia have a variety of delicious dishes. Singapore is known for its unique fusion of Chinese, Malay, and Indian flavors, while Malaysia is known for its spicy and flavorful dishes. Ultimately, it comes down to personal preference!

Adding memory

Now that we've got a basic agent that can respond to us, let's work on adding memory to our agent. We already implemented a function to help us save messages to disk. Let's now work on loading messages from disk.

Here's a high level overview of what this will look like

  1. We first load all the messages from disc
  2. We convert them to a list of Message objects
  3. We sort the messages using the timestamp property. These messages are now in chronological order
  4. We then take a slice of the n most recent messages

In order to do so, we'll have to modify our prompt. This is what I ended up working with.

The new chat history is now implemented using a {chat_log} variable. This variable will be replaced with the chat history at some point using the format function.

You are now TutorGPT. You are an expert in multiple domains. You will be given a user prompt to respond to in some time. Make sure to answer in a polite, friendly and helpful manner.
 
Here is an example of a prior conversation for reference
 
USER: What is a statistical model?
AGENT: A language model is a statistical model that describes the probability of a word given the previous words.
 
Here are the most recent chat messages between you and the user for reference
{chat_log}
 
Question: {prompt}
 
Please respond in the following format
 
AGENT: <response>

We can then implement a function to read in the messages from disk as seen below.

Read Prior Messages From Disk

def get_latest_messages(number: int = 4) -> list[Message]:
    messages = []
    for file in os.listdir(SAVE_LOCATION):
        with open(SAVE_LOCATION / file, "r") as f:
            message = json.load(f)
            messages.append(message)
    messages = sorted(messages, key=lambda x: x["timestamp"])
    messages = [Message(**message) for message in messages]
    if number > len(messages):
        return messages
    return messages[-number:]

This function retrieves the most recent messages from a collection of JSON files stored in a specified location. It reads the files, loads the messages as JSON objects, sorts them based on their "timestamp" field, and converts them into Message objects. The function then returns the number most recent messages, or all of them if fewer than number are available.

Note that as the number of messages grows, we'll incur a large time penalty, because we're reading all the messages from disk and then sorting them. This is highly inefficient given that we only need the n most recent messages.

Therefore, what we can do is simple:

  1. We generate a list of the n most recent messages when our bot boots up - in the event that our bot has no prior messages or fewer than n messages, we simply use all of the messages that we have
  2. We then pop off the first two elements in our list ( since each user input generates two messages - one from the user and another response from the bot )
  3. We then append the two new messages to the end of the list
  4. We then feed this list into our prompt and repeat the process

This means that at each step, we perform a simple pop and append operation. This is much more efficient than reading all the messages from disk and sorting them.

I implemented step 2 using a function I called remove_enqueued_message. This function removes the two oldest messages once our history exceeds its limit.

Remove Enqueued Messages

def remove_enqueued_message(curr_messages: list[Message], curr_limit: int = 4):
    # Drop the two oldest messages (one user message and one agent
    # response) once our history exceeds its limit
    if len(curr_messages) > curr_limit:
        curr_messages.pop(0)
        curr_messages.pop(0)
    return curr_messages

Append is an O(1) operation in Python, but popping from the front of a list is O(n). Since n is fixed at our history limit, the whole of step 2 is effectively O(1).
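Our main loop below passes the conversation history into get_response_from_openai, so the function needs a second parameter. Here's a minimal sketch of the updated function, assuming the new template above replaces RESPONSE_PROMPT in config.py and that each message is rendered as a "SPEAKER: text" line - both of these are choices on my part; any consistent rendering should work.

Updated Get Response From OpenAI

def format_chat_log(messages: list[Message]) -> str:
    # Render each message as e.g. "USER: hello" for the {chat_log} slot
    return "\n".join(f"{m.speaker.value}: {m.message}" for m in messages)
 
def get_response_from_openai(user_prompt: str, convo_history: list[Message]) -> str:
    response = openai.Completion.create(
        model="text-davinci-003",
        prompt=RESPONSE_PROMPT.format(
            chat_log=format_chat_log(convo_history),
            prompt=user_prompt,
        ),
        temperature=0.2,
        max_tokens=300,
        stop=["USER:", "AGENT:"],
    )
    return f'AGENT: {response["choices"][0].text}'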

With this, we now have a new main loop which looks like this

Main Chatbot Loop

if __name__ == "__main__":
    if not os.path.exists(SAVE_LOCATION):
        os.mkdir(SAVE_LOCATION)
 
    convo_history = get_latest_messages(4)
 
    while True:
        user_message = input("User: ")
        userMessage = Message(
            timestamp=time.time(), message=user_message, speaker=SpeakerEnum.USER
        )
        save_message(userMessage)
        response = get_response_from_openai(user_message, convo_history)
        print(response)
 
        modelResponse = Message(
            timestamp=time.time(), message=response, speaker=SpeakerEnum.AGENT
        )
        save_message(modelResponse)
 
        convo_history.append(userMessage)
        convo_history.append(modelResponse)
        remove_enqueued_message(convo_history)

We can see evidence that this new model has memory as seen below

User: Explain the concept of an embedding
AGENT: An embedding is a mathematical representation of a word or phrase in a vector space. It is used to capture semantic and syntactic relationships between words in a corpus. Embeddings are typically used in natural language processing (NLP) tasks such as sentiment analysis, text classification, and machine translation.
User: What did I just ask you to explain?
AGENT: You asked me to explain the concept of an embedding.
Our model was asked about the concept of an embedding, and when we then asked what it had previously been asked to explain, it correctly identified that the topic was embeddings.

Adding some relevant context

Right now we've got a bot with the attention span of a goldfish - it'll only remember what you tell it in the last 4 messages. Let's work on improving this by adding some relevant context to our bot.

We'll do this by adding a new function called get_similar_messages. This function will take in a list of prior messages and a new message, and return the prior messages most relevant to it. To do this, we'll be using something called embeddings - these are vector representations of text. We'll be using the openai library to generate these embeddings.

In our case, OpenAI's embeddings are simply a list of 1,536 numbers.

Each number represents a different dimension. The embeddings are generated by a neural network trained on a large corpus of text, which learns to place words that are similar to each other close together in the embedding space.

Let's update our Message model so that we can store the embeddings for each message. We can do so by adding an additional field to our Message model.

Message

class Message(BaseModel):
    timestamp: float
    message: str
    uuid: str = Field(default_factory=generate_uuid)
    speaker: SpeakerEnum
    embeddings: list[float]

We then need to update how we construct our messages so that each one carries its embeddings. We can generate these with the openai library's openai.Embedding.create method and store them on our Message model.

Generate Embedding

def generate_embedding(message: str):
    response = openai.Embedding.create(input=message, model="text-embedding-ada-002")
    return response["data"][0]["embedding"]

The next step is to use embeddings to identify related messages that we've sent before. Fortunately, David Shapiro's video provides an easy cosine-similarity function to do so with

import numpy as np
 
def similarity(v1: list[float], v2: list[float]) -> float:
    # Cosine similarity: dot product scaled by the vectors' magnitudes
    return np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))

Note that the similarity function is bounded: the maximum similarity between two embeddings is 1, and the minimum is -1.
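A quick sanity check of those bounds, assuming the similarity function above is in scope:

a = [1.0, 0.0]
b = [0.0, 1.0]
 
print(similarity(a, a))            # 1.0  - identical vectors
print(similarity(a, b))            # 0.0  - orthogonal, no similarity
print(similarity(a, [-1.0, 0.0]))  # -1.0 - opposite directions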

We can then use it in a small helper function to help to identify the most relevant messages that exist in our database.

def get_similar_messages(
    messages: list[Message], newMessage: Message, desiredLength: int
):
    scores = [
        (similarity(newMessage.embeddings, message.embeddings), message)
        for message in messages
    ]
    # Keep only messages above our minimum similarity threshold
    filtered = [score for score in scores if score[0] > 0.9]
    # Sort so the most similar message comes first
    ordered = sorted(filtered, key=lambda x: x[0], reverse=True)
    ordered = [message[1] for message in ordered]
 
    # Slicing handles the case where we have fewer than desiredLength matches
    return ordered[:desiredLength]

Here's a quick breakdown of what's happening in the above function

  1. We calculate the similarity between the new message and all the messages in our database
  2. We then filter out all messages that don't meet a minimum similarity score. In my case, I've specified an arbitrary minimum score of 0.9
  3. We then sort the messages by their similarity score such that the item with the greatest similarity score is at the front.
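To wire this into our main loop, we need every prior message loaded along with its embeddings. Here's a rough sketch of how this plugs in - get_all_messages is a hypothetical helper of my own, and the print format simply mirrors the output below:

Finding Relevant Context

def get_all_messages() -> list[Message]:
    # Hypothetical helper: load every saved message (and its embeddings)
    # from disk so we can search over them
    messages = []
    for file in os.listdir(SAVE_LOCATION):
        with open(SAVE_LOCATION / file, "r") as f:
            messages.append(Message(**json.load(f)))
    return messages
 
# Inside the main loop, after constructing userMessage:
relevant = get_similar_messages(get_all_messages(), userMessage, desiredLength=2)
print(f"Relevant responses found {len(relevant)}:")
for m in relevant:
    print(f"---{m.speaker.value}: {m.message}")
# These matches can then be injected into the prompt alongside {chat_log}

Note that if userMessage has already been saved to disk, it will match itself with a score of 1.0, so you may want to filter it out of the results.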

We can see that we get the following output with this new approach

(.virtual) ➜  notes git:(main) ✗ python3 main.py
User: What can I do to make my tulip flower grow faster?
Relevant responses found 2:
---USER: What are some of the best flowers to have in your garden?
---AGENT: Some of the best flowers to have in your garden are roses, daisies, tulips, lilies, and sunflowers. These flowers are all beautiful and easy to maintain. You can also add some herbs and vegetables to your garden for a more diverse and interesting look.
AGENT: To make your tulip flower grow faster, you can make sure it is planted in a sunny spot and that it is getting enough water. You can also add some fertilizer to the soil to help it grow faster. Additionally, you can trim off any dead or wilted leaves to help the plant focus its energy on blooming.
User: What are some useful programming languages to learn for data analysis?
Relevant responses found 2:
---USER: What can I do with a language like Python?
---AGENT: You can do a lot of things with Python! It is a powerful programming language that can be used for web development, data analysis, artificial intelligence, and more. Python is also a great choice for scripting and automation.
AGENT: Some of the most useful programming languages to learn for data analysis are Python, R, and SQL. Python is a great language for data analysis as it has a wide range of libraries and packages that make data analysis easier. R is also a great language for data analysis as it has a wide range of statistical and graphical capabilities. SQL is also a great language for data analysis as it allows you to easily query and manipulate data.

We can see that our bot is now able to identify relevant context and respond accordingly. Let's now try to give the bot some context and see how it responds.

User: I would like to know the number of mountains that Moribund has
---Considering tidbit of USER: How many mountains does Moribund have?
---Considering tidbit of USER: How many mountains does Moribund have?
---Considering tidbit of AGENT: I'm not sure how many mountains Moribund has, but I'm sure it's a beautiful place with plenty of natural wonders to explore.
---Considering tidbit of USER: There exists a kingdom in the highland called moribund. It has three large mountains and a great king to protect it
AGENT: From what you've told me, it sounds like Moribund has three large mountains. If you're looking for more information, you could try doing some research online or contacting local experts who may know more about the area.

We can see that the bot identifies queries and responses that are close in intent to our current query, injects them into the prompt, and uses them to generate a valid response. This is a great way to add context to our bot. For now, since we're feeding it nonsensical information, it can't generate a particularly stunning response; with valid information, however, it would be able to produce genuinely useful answers.

Conclusion

Today's build was a quick one, and we implemented some of the basic ideas behind Chat GPT and other GPT-powered chatbots that you see today:

  1. We used the openai library to generate embeddings for our messages - this is a useful way to help us to identify related bits of information in our database. In our case, these corresponded to prior messages but for other applications, these could correspond to other bits of information or proprietary data lakes.

  2. We used role-based prompts - Andrej Karpathy mentions this in his State Of GPT talk and it's a great way to get our bot to adhere to certain guidelines. Note the use of the phrase You are an expert in multiple domains. This is a useful way to nudge the LLM towards more expert-like responses, which in turn improves the quality of its answers.

  3. We were able to refer to a list of documents and data that was constantly evolving using embeddings and prompt time injection of data.

  4. We used pydantic as a way to ensure some level of type safety when working with serialization and deserialization.

Overall, I'd say it's not too bad for a weekend build. Some potential areas for improvement we'd need to consider would be

  • Using an actual database built for vector similarity search - this could be pgvector, Weaviate, Pinecone or even Chroma.
  • Using a database to store information instead of reading and deserializing individual files on disk.