Example

This page provides a complete working example of using dotjson to constrain an LLM so that it generates only JSON that is valid under a given schema.

The example below assumes you have followed the installation guide.

This example demonstrates:

  1. Creating a Vocabulary for gpt2
  2. Building an Index for a simple JSON schema
  3. Using the LogitsProcessor to mask out invalid tokens during generation
  4. Simulating an LLM’s token generation with random logits
  5. Performing weighted sampling from the valid tokens
  6. Continuing generation until either the EOS token is generated or a maximum token limit is reached

This example is a standalone implementation without an actual LLM backend. It uses random values to simulate token logits. In a real implementation, these would come from your language model’s forward pass.
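Before walking through the full script, the core mask-then-sample step can be sketched in isolation. The sketch below uses a toy five-token vocabulary and a hand-written allowed set in place of dotjson's Guide; mask_logits, sample_token, and toy_allowed are illustrative names for this sketch, not part of the dotjson API.

```python
import random

MASK_VALUE = 0   # weight assigned to disallowed tokens
VOCAB_SIZE = 5   # toy vocabulary with token ids 0..4


def mask_logits(logits, allowed):
    """Set every disallowed token's weight to MASK_VALUE."""
    return [score if token_id in allowed else MASK_VALUE
            for token_id, score in enumerate(logits)]


def sample_token(logits):
    """Draw a token id with probability proportional to its logit."""
    return random.choices(range(len(logits)), weights=logits, k=1)[0]


random.seed(0)
logits = [random.randint(1, 100) for _ in range(VOCAB_SIZE)]  # fake forward pass
toy_allowed = {1, 3}  # stand-in for the token set a real guide would return
masked = mask_logits(logits, toy_allowed)

# Disallowed ids carry zero weight, so they can never be drawn.
draws = {sample_token(masked) for _ in range(1000)}
print(sorted(draws))
```

Because the masked entries have weight 0, every sampled id is guaranteed to come from the allowed set; this is exactly the invariant the LogitsProcessor enforces in the full example below.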

Example code

Save the following code as example.py.

example.py
from dotjson import Vocabulary, Index, Guide, LogitsProcessor
import random
import torch


def main():
    # The vocabulary for gpt2 has 50257 tokens
    model = "gpt2"
    eos_token_id = 50256  # EOS token to check if we have completed sampling

    # Simple schema with one integer property
    schema = '{"type":"object","properties":{"x":{"type":"integer"}}}'

    # Create the vocabulary and index
    vocabulary = Vocabulary.from_pretrained(model)
    index = Index(schema, vocabulary)

    # The mask value is 0, batch size is 1
    mask_value = 0
    batch_size = 1
    # Token ids run from 0 through max_token_id, so the size is max_token_id + 1
    vocab_size = vocabulary.max_token_id + 1

    # Instead of running inference, we'll use randomly assigned
    # logit scores between 1 and 100. Note that 0 is excluded from
    # this range, since 0 is our mask_value.
    def random_logit() -> int:
        return random.randint(1, 100)

    # Create the logits tensor, initialized with random values
    logits_base = torch.randint(1, 101, (vocab_size,), dtype=torch.float32)

    # Construct a guide
    guide = Guide(index, batch_size=batch_size)
    # Create the set of initial permitted tokens
    allowed_tokens = guide.get_start_tokensets()
    # Construct the processor, a function that masks out invalid tokens
    processor = LogitsProcessor(guide, mask_value, dtype="float32")
    # Initialize the sequence of prior generated tokens (token ids are integers)
    context = [torch.empty(0, dtype=torch.int64) for _ in range(batch_size)]

    # The following loop will:
    #
    # 1. Randomly assign logits to all tokens (to simulate an LLM forward pass)
    # 2. Mask out tokens inconsistent with the schema by setting their values
    #    to the mask_value (0 in this case)
    # 3. Select a token randomly, proportional to its logit
    # 4. Add that token to the context
    # 5. Check if the new token is the EOS token -- if so, stop sampling.

    # Set a maximum number of tokens to sample
    max_tokens_to_sample = 100

    # Begin sampling loop
    for _ in range(max_tokens_to_sample):

        # Reset logits_base with new random values before processing
        for i in range(vocab_size):
            logits_base[i] = random_logit()

        # Mask out any tokens that are inconsistent with the schema
        processor.update_logits(logits_base, allowed_tokens)

        # Sample proportionally from the masked logits
        sampled_token_id = random.choices(
            range(vocab_size), weights=logits_base.tolist(), k=1
        )[0]

        # Add the sampled token to the context for the current batch entry
        context[0] = torch.cat(
            [context[0], torch.tensor([sampled_token_id], dtype=torch.int64)]
        )

        # Update the set of allowed tokens for the next iteration
        allowed_tokens = guide.get_next_tokensets([sampled_token_id])

        # Check whether the sampled token is the EOS token; if so, stop
        if sampled_token_id == eos_token_id:
            print(f"EOS token ({eos_token_id}) sampled. Stopping generation.")
            break

    # Print the final tokens
    print("Final tokens:", " ".join(str(int(t)) for t in context[0].tolist()))

if __name__ == "__main__":
    main()
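The only simulated piece of the loop above is the logit source: in a real deployment, logits_base would come from your model's forward pass over the current context rather than from random values. A minimal sketch of that seam, where forward is a hypothetical placeholder for your model and not part of the dotjson API:

```python
import torch

vocab_size = 50257  # gpt2 vocabulary size, matching the example above


def forward(context: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for an LLM forward pass.

    A real implementation would run the model on `context` and return
    the next-token logits for the final position; here we just return
    random scores with the right shape and dtype.
    """
    return torch.rand(vocab_size, dtype=torch.float32)


context = torch.tensor([50256], dtype=torch.int64)  # prior generated tokens
logits_base = forward(context)
print(tuple(logits_base.shape))  # (50257,)
```

Everything downstream of this call (masking, sampling, updating the guide) stays the same regardless of which backend produces the logits.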

Run the example

Run the following command:

python example.py

You should see output similar to the following:

EOS token (50256) sampled. Stopping generation.
Final tokens: 90 197 628 198 92 628 197 198 201 197 201 628 628 197 198 197 198 50256

The exact tokens will vary from run to run. In the gpt2 vocabulary, token 90 is "{" and token 92 is "}", and the other ids decode to whitespace: because the schema does not mark x as required, an object with no properties is valid, so the guide can keep accepting whitespace tokens until EOS is sampled.