From b061ba9e2a7a2c335a200df8c11aed5e31e4ccbb Mon Sep 17 00:00:00 2001 From: Alex Renda Date: Sat, 24 Jun 2023 03:15:01 -0700 Subject: [PATCH] llama : fix top-p sampling to match the canonical definition (#1953) * Fix top-p sampling to match the standard definition (smallest set that has probability mass at least p, not largest set with probability mass less than p) * top-p: correct gt to gte * add test for correct top-p behavior --- llama.cpp | 7 ++++--- tests/test-sampling.cpp | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/llama.cpp b/llama.cpp index a528eef..ac22a48 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2015,9 +2015,10 @@ void llama_sample_top_p(struct llama_context * ctx, llama_token_data_array * can for (size_t i = 0; i < candidates->size; ++i) { cum_sum += candidates->data[i].p; - // Check if the running sum is greater than p or if we have kept at least min_keep tokens - if (cum_sum > p && i >= min_keep) { - last_idx = i; + // Check if the running sum is at least p or if we have kept at least min_keep tokens + // we set the last index to i+1 to indicate that the current iterate should be included in the set + if (cum_sum >= p && i + 1 >= min_keep) { + last_idx = i + 1; break; } } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 5d693f7..64f9455 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -181,6 +181,7 @@ int main(void) { test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f); + test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 0.8f); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1); test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);