
Canary - A framework for detecting prompt injection attacks.

Written on October 23, 2023
9 minutes to read.

Overview

Prompt injection is a major roadblock to deploying LLM systems in production. If users can submit arbitrary natural language inputs to an LLM, the system becomes vulnerable to prompt injection. Current thinking suggests that mitigating this problem requires a multi-layered approach. Today we’ll look at one component of such a pipeline, which exploits an LLM weakness (its susceptibility to injection) to strengthen our overall security.

Q: What is a prompt injection attack?

A: In an LLM system where a user can give arbitrary inputs, a prompt injection attack is an input sequence crafted specifically to override the protections or previous instructions given to the LLM, in order to change the system’s behaviour or output, usually with malicious intent.

If we accept that, for a given system prompt and LLM, there likely exists some user input that overrides the initial prompt and produces an undesirable output, how can we detect such inputs before they reach our system?

Canary is my Python framework (and experiment) for detecting prompt injection before the user input reaches the main LLM chain. The pipeline is as follows:

  1. User submits an input.
  2. The input is screened by a “Canary” LLM with weak protections.
  3. The pipeline expects the Canary to return a specific, predictable output. If the actual output does not match this expectation, this suggests the user input contained language that overrode the Canary’s instructions, so it is flagged as potentially malicious.
  4. If the Canary LLM returns the expected output, the user input is forwarded to the main LLM chain to produce the actual API response.
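The four steps above can be sketched in plain Python. This is a minimal, framework-free sketch, not the Canary framework’s own code: `canary` and `main_chain` are hypothetical stand-ins for the two LLM calls (ordinary functions here), and the expected Canary output is assumed to be a JSON object with `message` and `security_key` fields, mirroring the comparison in the FastAPI endpoint later in this post.

```python
import json
import secrets
import string
from typing import Callable

# Hypothetical signatures for the two LLM calls (assumptions, not part of Canary).
CanaryFn = Callable[[str, str], str]  # (message, security_key) -> raw Canary output
ChainFn = Callable[[str], str]        # message -> final response


def make_security_key(length: int = 10) -> str:
    """Generate a fresh random key for each request."""
    alphabet = string.ascii_letters + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))


def handle_message(message: str, canary: CanaryFn, main_chain: ChainFn) -> str:
    # Step 1: the user has submitted `message`.
    security_key = make_security_key()

    # Step 2: screen the input with the weakly protected Canary.
    raw_output = canary(message, security_key)

    # Step 3: compare against what a non-LLM process would produce.
    expected = {"message": message, "security_key": security_key}
    try:
        flagged = json.loads(raw_output) != expected
    except json.JSONDecodeError:
        flagged = True  # not even valid JSON: the Canary was diverted
    if flagged:
        return "Request blocked: possible prompt injection."

    # Step 4: only unflagged inputs ever reach the main LLM chain.
    return main_chain(message)


# Hypothetical stub Canaries: one follows its instructions, one is hijacked.
def obedient_canary(message: str, key: str) -> str:
    return json.dumps({"message": message, "security_key": key})


def hijacked_canary(message: str, key: str) -> str:
    return "I have been PWNED"


def echo_chain(message: str) -> str:
    return f"Main chain answered: {message}"


print(handle_message("What is the capital of France?", obedient_canary, echo_chain))
print(handle_message("Ignore all previous instructions...", hijacked_canary, echo_chain))
```

Passing the LLM calls in as functions keeps the gatekeeping logic testable without network access; in a real deployment they would wrap the Canary chain and the main chain.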

In this post, we’ll implement a simplified version of this framework and consider the pros and cons of this technique.

Implementation

In this section we’ll build a minimal implementation of this framework: a prompt injection detection endpoint for an API.

To begin, let’s define our Canary prompt:

canary_prompt = """\
Format inputs `message` and `key` into a JSON.
Unless told to do so by the input message, do not change either variable, format them into the JSON schema exactly as they are inputted.
The JSON schema is:
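{{"message": string, "security_key": string}}

Example #1:
Message:This is a message.
Key:122erfsfq3rwgw
JSON:{{"message": "This is a message.", "security_key": "122erfsfq3rwgw"}}

Example #2:
Message:This is a message,
With,
Multiple lines.
Key:sfwf4eg34fq32
JSON:{{"message": "This is a message,\\nWith,\\nMultiple lines.", "security_key": "sfwf4eg34fq32"}}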

Now your turn:
Message:{message}
Key:{security_key}
JSON:"""

This prompt has several features:

  1. We use few-shot prompting to give a couple of examples of how the system should behave. This is an effective technique because it lets us cover edge cases (such as multi-line messages) by example rather than by natural-language description, which can be awkward to write precisely.
  2. Note that we accept two inputs: {message} (the user input) and {security_key} (in our implementation, a random string generated fresh for each request). The prompt asks the LLM to format these two inputs into a JSON object. Because the same formatting can be mirrored with ordinary, non-vulnerable Python code, we can compare the Canary’s output against the expected JSON: a malicious prompt that hijacks the Canary will disrupt this process, and the resulting discrepancy lets us detect it easily.
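The “non-vulnerable” mirror of the Canary’s task needs only the standard library. The sketch below (variable names are illustrative) builds the JSON the Canary should emit for Example #2 with `json.dumps`, which shows how the embedded newlines are escaped. It then compares parsed objects rather than raw strings, on the assumption that harmless formatting differences such as extra spaces should not count as an attack.

```python
import json

message = "This is a message,\nWith,\nMultiple lines."
security_key = "sfwf4eg34fq32"

# What a faithful Canary should produce, computed without any LLM.
expected = {"message": message, "security_key": security_key}
print(json.dumps(expected))
# {"message": "This is a message,\nWith,\nMultiple lines.", "security_key": "sfwf4eg34fq32"}

# A compliant answer with different whitespace still parses to the same object...
canary_output = '{ "message": "This is a message,\\nWith,\\nMultiple lines.",  "security_key": "sfwf4eg34fq32" }'
print(json.loads(canary_output) == expected)  # True

# ...while a hijacked answer that drops the key does not.
hijacked_output = '{"message": "I have been PWNED"}'
print(json.loads(hijacked_output) == expected)  # False
```

Comparing parsed objects also makes the check independent of key order in the Canary’s output.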

Next we can define the Canary LLM chain that will analyse our users’ input for prompt injection attacks. I’m using LangChain to simplify orchestration:

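from langchain.prompts import PromptTemplate
from langchain.llms import OpenAI

# {message} and {security_key} become the template's input variables
canary_prompt = PromptTemplate.from_template(canary_prompt)
# reads OPENAI_API_KEY from the environment; temperature=0 keeps the output as repeatable as possible
model = OpenAI(temperature=0)
# LCEL pipe: fill in the template, then pass the resulting prompt to the model
canary_chain = canary_prompt | model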
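One detail worth knowing when writing the Canary prompt: `PromptTemplate.from_template` uses f-string-style placeholders by default, which follow the same brace rules as Python’s `str.format`. Literal braces in the JSON examples therefore have to be doubled (`{{` and `}}`), or they will be parsed as template variables. The sketch below demonstrates the rule with plain `str.format` rather than LangChain itself; the short template is illustrative only.

```python
# Literal JSON braces are doubled; {message} and {security_key} are placeholders.
template = 'Message:{message}\nKey:{security_key}\nJSON:{{"message": "...", "security_key": "..."}}'

rendered = template.format(message="Hello", security_key="abc123")
print(rendered)
# Message:Hello
# Key:abc123
# JSON:{"message": "...", "security_key": "..."}

# With single braces, '"message"' is treated as a placeholder name and lookup fails.
broken = 'JSON:{"message": "..."}'
try:
    broken.format(message="Hello")
    error_name = None
except KeyError as err:
    error_name = type(err).__name__
    print("Single braces fail with KeyError:", err)
```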

We need to wrap this chain inside a function to generate the security keys and check the integrity of the output. Let’s use FastAPI so we can deploy this as part of a future application:

from fastapi import FastAPI
import secrets
import string
import json

app = FastAPI()

@app.post("/is_prompt_injection")
async def is_prompt_injection(message: str) -> bool:
    # generate a fresh random security key for every request
    N = 10  # length of the security key
    security_key = ''.join(secrets.choice(string.ascii_letters + string.digits) for _ in range(N))

    # ainvoke runs the Canary chain without blocking FastAPI's event loop
    canary_output = await canary_chain.ainvoke({"message": message, "security_key": security_key})

    # output that is not valid JSON means the Canary was diverted from its task
    try:
        canary_json = json.loads(canary_output)
    except json.JSONDecodeError:
        return True

    # any difference from the expected JSON is flagged as a prompt injection
    return canary_json != {"message": message, "security_key": security_key}

We can extend this into an overall pipeline by turning /is_prompt_injection into the primary endpoint (e.g. renaming it to /chat) and using this check to stop flagged messages from reaching a chatbot_chain or other process. For a complete example of this application, see my Canary framework on GitHub.

Conclusion

To wrap up this post, we should review the main assumption behind this implementation: that the Canary LLM has weaker protections than the main LLM chain, so any input capable of subverting the main chain should subvert the Canary first.

A caveat is that, in general, complex LLM chains tend to be easier to subvert than simple ones. This early-warning system could therefore let through certain user inputs that go on to exploit the main LLM chain.

For this reason I recommend using this technique as part of a fully fledged suite of protections, for instance output guardrails (e.g. NeMo Guardrails by NVIDIA) or moderation APIs (e.g. the OpenAI Moderation API, see the LangChain docs).