There’s been a lot written about Large Language Models (LLMs) and ways to adjust them to your particular needs. Sampling is one of the easier techniques to apply, and it lets you vary the output of an LLM such as Gemma. This post describes some common sampling techniques available in many different LLMs, using a small example text as a stand-in for a training set. It also shows code for doing the same with Gemma, which is trained on a much larger corpus.
Why sample?
LLMs work much like predictive text in an editor: given a preceding group of words (the context), an LLM selects the next word or phrase based on the data it was trained on. For example, if the context is “have a happy”, it is likely to be followed by “birthday” or “new year” and very unlikely to be followed by “xylophone”, since “have a happy xylophone” is not a frequently occurring phrase.
To make an LLM seem more creative, the most likely prediction is not always the one selected. There are a number of ways to guide the LLM in making this selection. This process is known as sampling.
We’ll look at a number of ways to do sampling, both in general and with the LLM Gemma, using the notebook available on GitHub. Gemma is a family of open, lightweight generative AI models built on the same technology as Gemini, Google’s largest and most capable LLM, and it is designed to be easy to customize.
Getting some data
To make the ideas here more concrete, let’s analyze a real text, The Little Red Hen. This is an English folk tale with just under 1400 words.
Yes, this is a massive simplification. But sometimes starting simple works better.
We can look at how often certain word pairs occur in the text. Not surprisingly, the word “red” appears 16 times and each time is followed by the word “hen”. The word “little” also appears 16 times, but is only followed by “red” 14 times. The other times it’s followed by “fluff-balls” and “body,” one time each.
The word “the” is more interesting. It appears 109 times and is followed by 58 different words. 45 of these words only appear after “the” once and 6 appear after “the” 2 times each. The other 7 words are:
| word | count | frequency |
| --- | --- | --- |
| little | 10 | 9.2% |
| pig | 9 | 8.3% |
| cat | 9 | 8.3% |
| rat | 8 | 7.3% |
| wheat | 8 | 7.3% |
| barnyard | 4 | 3.7% |
| bread | 4 | 3.7% |
Together, these 7 words follow “the” almost 48% of the time. Before looking at how the different sampling methods choose among these options, here is how you could compute bigram counts like these yourself.
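This is a minimal sketch, assuming the story has been saved locally as the_little_red_hen.txt (a made-up filename); the exact counts depend on how punctuation, capitalization, and hyphens are handled.

import re
from collections import Counter, defaultdict

# Split the story into lowercase words, keeping hyphens and apostrophes.
with open("the_little_red_hen.txt") as f:
    words = re.findall(r"[a-z'-]+", f.read().lower())

# Count how often each word follows each preceding word.
followers = defaultdict(Counter)
for prev, nxt in zip(words, words[1:]):
    followers[prev][nxt] += 1

# Show the 7 most common words that follow "the", with their frequencies.
total = sum(followers["the"].values())
for word, count in followers["the"].most_common(7):
    print(f"{word:10s} {count:3d}  {count / total:.1%}")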
Greedy Sampling
Greedy sampling is quite simple–for the next token, pick the one with the highest frequency. In our example, the word selected to follow “the” will always be “little” since it has the highest frequency.
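Using the toy counts from the table above (a dictionary name like the_counts is just for illustration), greedy selection is simply a max over the follower counts:

# Toy greedy selection: always pick the most frequent word that follows "the".
the_counts = {"little": 10, "pig": 9, "cat": 9, "rat": 8,
              "wheat": 8, "barnyard": 4, "bread": 4}
print(max(the_counts, key=the_counts.get))  # always prints "little"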
This is the default in Gemma. Using KerasNLP, a natural language processing library built on Keras (a deep learning API for Python), you can create a Gemma model and set its sampler to greedy with the code below:
import keras_nlp

# Load the pretrained 2B-parameter Gemma model and use greedy sampling.
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma_2b_en")
gemma_lm.compile(sampler="greedy")
Greedy sampling is simple enough, but doesn’t lead to much variety in responses. If you’re working in the Gemma notebook, you can try it out using the code below. You should notice that you get the same response each time you run it.
print(gemma_lm.generate('Are cats or dogs better?', max_length=32))
Top K Sampling
Top K sampling can give more variety. Instead of selecting the response with the highest frequency, the user can specify a K value–the number of tokens with the highest frequencies to select from. So, for k = 5, in our example above, one of the tokens “little”, “pig”, “cat”, “rat”, or “wheat” would be selected. Since “little” is most frequent, it would be selected most often, but all 5 would be possible responses. The relative percentages of these 5 are shown below.
| word | count | frequency |
| --- | --- | --- |
| little | 10 | 22.7% |
| pig | 9 | 20.5% |
| cat | 9 | 20.5% |
| rat | 8 | 18.2% |
| wheat | 8 | 18.2% |
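Here’s a minimal sketch that reproduces those percentages and draws one token from the top-k pool, again using the toy counts (the the_counts dictionary and other names are just for illustration):

import random

# Keep only the k most frequent followers of "the".
the_counts = {"little": 10, "pig": 9, "cat": 9, "rat": 8, "wheat": 8,
              "barnyard": 4, "bread": 4}
k = 5
top_k = sorted(the_counts.items(), key=lambda item: item[1], reverse=True)[:k]
words, counts = zip(*top_k)

# Renormalize within the pool (matches the table above), then sample one word.
total = sum(counts)
print({word: f"{count / total:.1%}" for word, count in top_k})
print(random.choices(words, weights=counts)[0])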
To use the Top K sampler in Gemma, you first create the sampler object, since you need to specify a value for k. Once that’s done, you can recompile the model without having to recreate it, making it easier to experiment with different samplers:
# Sample from the 5 highest-probability tokens at each step.
sampler = keras_nlp.samplers.TopKSampler(k=5)
gemma_lm.compile(sampler=sampler)
If you try that, you should notice some variety in the responses you get. You can also try different values of k to see how they affect the output.
Top P Sampling
Top P sampling is similar to Top K, but instead of specifying how many tokens to include in the pool, Top P specifies how much of the total probability mass the pool should cover, taking tokens from most frequent to least frequent. So, for a Top P sample of our data with p = 25%, tokens would be taken from the most frequent (“little”) through the next most frequent (“pig” and “cat”) until a combined frequency of at least 25% has been reached. These three words have a combined frequency of 25.8%.
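Here’s a small sketch of that cutoff computation over the toy counts (again, the_counts and the other names are just for illustration; the remaining followers of “the” are omitted since they aren’t needed to reach the cutoff):

import random

# Toy top-p (nucleus) sampling with p = 0.25.
# Frequencies are relative to all 109 occurrences of "the" in the story.
the_counts = {"little": 10, "pig": 9, "cat": 9, "rat": 8, "wheat": 8,
              "barnyard": 4, "bread": 4}
total_the = 109
p = 0.25

# Add tokens, most frequent first, until the pool covers at least p of the probability mass.
pool, cumulative = [], 0.0
for word, count in sorted(the_counts.items(), key=lambda item: item[1], reverse=True):
    pool.append((word, count))
    cumulative += count / total_the
    if cumulative >= p:
        break

words, counts = zip(*pool)
print(words)                                     # ('little', 'pig', 'cat')
print(random.choices(words, weights=counts)[0])  # one word drawn from the pool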
Using a Top P sampler in Gemma works much like a Top K sampler.
# Sample from the smallest set of tokens whose combined probability reaches 25%.
sampler = keras_nlp.samplers.TopPSampler(p=0.25)
gemma_lm.compile(sampler=sampler)
Try it again and see how the output changes as you rerun the code and change the value of p.
Random Sampling
Random sampling includes all possible values in determining the next token, using the probability of each token as the chance of selecting it. So there would be a 9.2% chance of selecting “little” after “the” and a 0.9% chance of selecting “big” since the combination “the big” appears once in the text.
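Gemma supports this strategy as well. Here’s a minimal sketch using KerasNLP’s random sampler, reusing the gemma_lm model created earlier:

# Sample from the full distribution at each step.
sampler = keras_nlp.samplers.RandomSampler()
gemma_lm.compile(sampler=sampler)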
Temperature
In addition to changing the type of sampling, a function can be applied to the counts to make the differences between the values either more or less significant. The temperature is a value used in this function, and it can be passed as a parameter to almost all of the samplers. You can think of this function as raising a base (determined by the temperature) to the power of the count.
Consider what happens with just the 7 most frequent values. When you use the counts as the exponent on a base of 2, the values end up widely spread apart, with the largest (at 39.5%) more than 60 times the frequency of the smallest (at 0.6%). But when the base is 1.2, the largest frequency is only about 3 times the frequency of the smallest.
| word | count | 2^count | distribution | 1.2^count | distribution |
| --- | --- | --- | --- | --- | --- |
| little | 10 | 1024 | 39.5% | 6.19 | 21.2% |
| pig | 9 | 512 | 19.8% | 5.16 | 17.6% |
| cat | 9 | 512 | 19.8% | 5.16 | 17.6% |
| rat | 8 | 256 | 9.9% | 4.30 | 14.7% |
| wheat | 8 | 256 | 9.9% | 4.30 | 14.7% |
| barnyard | 4 | 16 | 0.6% | 2.07 | 7.1% |
| bread | 4 | 16 | 0.6% | 2.07 | 7.1% |
| sum | 52 | 2592 | | 29.26 | |
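Here is a minimal sketch that reproduces the table: raise a base to each count, then normalize. A larger base exaggerates the differences between counts, while a base close to 1 flattens them.

# Reproduce the table above for two different bases.
the_counts = {"little": 10, "pig": 9, "cat": 9, "rat": 8, "wheat": 8,
              "barnyard": 4, "bread": 4}

for base in (2, 1.2):
    weights = {word: base ** count for word, count in the_counts.items()}
    total = sum(weights.values())
    print(f"base = {base}, sum of weights = {total:.2f}")
    for word, weight in weights.items():
        print(f"  {word:10s} {weight:8.2f}  {weight / total:.1%}")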
Temperature can be added as a parameter to any of the samplers we’ve used thus far as in:
# A temperature below 1.0 sharpens the distribution; above 1.0 flattens it.
sampler = keras_nlp.samplers.TopPSampler(p=0.25, temperature=0.7)
gemma_lm.compile(sampler=sampler)
What’s Next
In this tutorial, you learned how to modify the output of Gemma by using different sampling techniques. Here are a few suggestions for what to learn next:
- Learn how to fine-tune a Gemma model.
- Learn how to perform distributed fine-tuning and inference on a Gemma model.
- Learn about Gemma integration with Vertex AI.
- Learn how to use Gemma models with Vertex AI.
- Learn the details of how temperature is really computed by the softmax function.