Example
This page provides a complete working example of how to use dotjson to constrain an LLM to generate valid JSON.
The example below assumes you have followed the installation guide.
This example demonstrates:
- Creating a `Vocabulary` for gpt2
- Building an `Index` for a simple JSON schema
- Using the `LogitsProcessor` to mask out invalid tokens during generation
- Simulating an LLM's token generation with random logits
- Performing weighted sampling from the valid tokens
- Continuing generation until either the EOS token is generated or a maximum token limit is reached
This example is a standalone implementation without an actual LLM backend. It uses random values to simulate token logits. In a real implementation, these would come from your language model’s forward pass.
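The core mechanism, masking disallowed logits to 0 and then sampling proportionally, can be sketched without dotjson at all. The vocabulary size and allowed-token set below are made-up toy values, not anything the library produces:

```python
import random

random.seed(0)  # deterministic for illustration

vocab_size = 10
mask_value = 0
allowed_tokens = {2, 5, 7}  # toy stand-in for the tokens the schema permits

# Simulate a forward pass: random logit scores in 1-100 (never the mask value)
logits = [random.randint(1, 100) for _ in range(vocab_size)]

# Mask: set the score of every disallowed token to the mask value
masked = [score if tok in allowed_tokens else mask_value
          for tok, score in enumerate(logits)]

# Weighted sampling: masked tokens have weight 0 and can never be drawn
sampled = random.choices(range(vocab_size), weights=masked, k=1)[0]
assert sampled in allowed_tokens
```

Because the random logits are always at least 1, a score of 0 unambiguously marks a masked token, which is why the real example keeps 0 out of its logit range.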
Example code
Save the following code as example.py.
```python
from dotjson import Vocabulary, Index, Guide, LogitsProcessor
import random

import torch


def main():
    # The vocabulary for gpt2 has 50257 tokens
    model = "gpt2"
    eos_token_id = 50256  # EOS token to check if we have completed sampling

    # Simple schema with one integer property
    schema = '{"type":"object","properties":{"x":{"type":"integer"}}}'

    # Create the vocabulary and index
    vocabulary = Vocabulary.from_pretrained(model)
    index = Index(schema, vocabulary)

    # The mask value is 0, batch size is 1
    mask_value = 0
    batch_size = 1
    vocab_size = vocabulary.max_token_id + 1

    # Instead of performing inference, we'll use randomly assigned
    # logit scores between 1 and 100. Note that 0 is not in this range,
    # as our mask_value is 0.
    def random_logit() -> int:
        return random.randint(1, 100)

    # Create the logits initialized with random numbers
    logits_base = torch.randint(1, 101, (vocab_size,), dtype=torch.float32)

    # Construct a guide
    guide = Guide(index, batch_size=batch_size)

    # Create the set of initially permitted tokens
    allowed_tokens = guide.get_start_tokensets()

    # Construct the processor, a function that masks out invalid tokens
    processor = LogitsProcessor(guide, mask_value, dtype="float32")

    # Initialize the sequence of previously generated token ids
    context = [torch.empty(0, dtype=torch.int64) for _ in range(batch_size)]

    # The following loop will:
    #
    # 1. Randomly assign logits to all tokens (to simulate an LLM forward pass)
    # 2. Mask out tokens inconsistent with the schema by setting their values
    #    to the mask_value (0 in this case)
    # 3. Select a token randomly, proportional to its logit
    # 4. Add that token to the context
    # 5. Check if the new token is the EOS token -- if so, stop sampling.

    # Set a maximum number of tokens to sample
    max_tokens_to_sample = 100

    # Begin sampling loop
    for _ in range(max_tokens_to_sample):
        # Reset logits_base with new random values before processing
        for i in range(vocab_size):
            logits_base[i] = random_logit()

        # Mask out any tokens that are inconsistent with the schema
        processor.update_logits(logits_base, allowed_tokens)

        # Sample proportionally from the logits tensor
        sampled_token_id = random.choices(
            range(vocab_size), weights=logits_base.tolist(), k=1
        )[0]

        # Add the sampled token to the context for the current batch entry
        context[0] = torch.cat(
            [context[0], torch.tensor([sampled_token_id], dtype=torch.int64)]
        )

        # Update the set of allowed tokens for the next iteration
        allowed_tokens = guide.get_next_tokensets([sampled_token_id])

        # Check if the sampled token is the EOS token; if so, exit
        if sampled_token_id == eos_token_id:
            print(f"EOS token ({eos_token_id}) sampled. Stopping generation.")
            break

    # Print the final tokens
    print("Final tokens:", " ".join(str(int(t)) for t in context[0].tolist()))


if __name__ == "__main__":
    main()
```

Run the example
Run the following command:
```shell
python example.py
```

You should see output similar to the following:
```
EOS token (50256) sampled. Stopping generation.
Final tokens: 90 197 628 198 92 628 197 198 201 197 201 628 628 197 198 197 198 50256
```
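In the gpt2 vocabulary, id 90 decodes to `{`, id 92 to `}`, id 50256 is the EOS marker, and the remaining ids here are whitespace tokens, so the decoded output is a whitespace-filled empty object. Since the schema does not list `x` under `required`, an object without it still validates. You can confirm this kind of output parses with the standard `json` module; the string below is a hand-written stand-in for the decoded text, not the exact decoding:

```python
import json

# Stand-in for the decoded output: an opening brace, assorted whitespace,
# and a closing brace (the EOS id is dropped before decoding)
decoded = "{\t\n\n}\n\n\t\n"

value = json.loads(decoded)
assert value == {}  # a valid, empty JSON object
```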