Fix convert script, warnings alpaca instructions, default params

This commit is contained in:
Georgi Gerganov 2023-03-21 17:59:16 +02:00
parent 715d292ee0
commit 3bfa3b43b7
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 23 additions and 17 deletions

View file

@ -193,15 +193,15 @@ First, download the `ggml` Alpaca model into the `./models` folder:
``` ```
# use one of these # use one of these
# TODO: add a script to simplify the download # TODO: add a script to simplify the download
curl -o ggml2-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1 curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://gateway.estuary.tech/gw/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
curl -o ggml2-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1 curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://ipfs.io/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
curl -o ggml2-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1 curl -o ./models/ggml-alpaca-7b-q4.bin -C - https://cloudflare-ipfs.com/ipfs/QmUp1UGeQFDqJKvtjbSYPBiZZKRjLp8shVP9hT8ZB9Ynv1
``` ```
Now run the `main` tool like this: Now run the `main` tool like this:
``` ```
./main -m ./models/ggml2-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins ./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins
``` ```
Sample run: Sample run:
@ -218,7 +218,7 @@ Sample run:
There 26 letters in the English Alphabet There 26 letters in the English Alphabet
> What is the most common way of transportation in Amsterdam? > What is the most common way of transportation in Amsterdam?
The majority (54%) are using public transit. This includes buses, trams and metros with over 100 lines throughout the city which make it very accessible for tourists to navigate around town as well as locals who commute by tram or metro on a daily basis The majority (54%) are using public transit. This includes buses, trams and metros with over 100 lines throughout the city which make it very accessible for tourists to navigate around town as well as locals who commute by tram or metro on a daily basis
> List 5 words that start with "ca". > List 5 words that start with "ca".
cadaver, cauliflower, cabbage (vegetable), catalpa (tree) and Cailleach. cadaver, cauliflower, cabbage (vegetable), catalpa (tree) and Cailleach.
> >
``` ```

View file

@ -3,4 +3,4 @@
# Temporary script - will be removed in the future # Temporary script - will be removed in the future
# #
./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.96 --repeat_penalty 1 -t 7 ./main -m ./models/ggml-alpaca-7b-q4.bin --color -f ./prompts/alpaca.txt -ins --top_k 10000 --temp 0.2 --repeat_penalty 1 -t 7

View file

@ -27,9 +27,9 @@ from sentencepiece import SentencePieceProcessor
def parse_args(): def parse_args():
parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file') parser = argparse.ArgumentParser(description='Convert a LLaMA model checkpoint to a ggml compatible file')
parser.add_argument('dir_model', help='directory containing the model checkpoint') parser.add_argument('dir_model', help='directory containing the model checkpoint')
parser.add_argument('ftype', type=int, choices=[0, 1], default=1, help='file type (0: float32, 1: float16)') parser.add_argument('ftype', help='file type (0: float32, 1: float16)', type=int, choices=[0, 1], default=1)
parser.add_argument('vocab_only', type=bool, default=False, help='only write vocab to file') parser.add_argument('vocab_only', help='only write vocab to file', type=int, default=0, nargs='?')
return parser.parse_args() return parser.parse_args()
def get_n_parts(dim): def get_n_parts(dim):
@ -135,6 +135,8 @@ def main():
hparams, tokenizer = load_hparams_and_tokenizer(dir_model) hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
print(args)
# if only writing vocab to file # if only writing vocab to file
if args.vocab_only: if args.vocab_only:

View file

@ -165,12 +165,20 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
// load vocab // load vocab
{ {
std::string word; std::string word;
std::vector<char> tmp(64);
for (int i = 0; i < model.hparams.n_vocab; i++) { for (int i = 0; i < model.hparams.n_vocab; i++) {
uint32_t len; uint32_t len;
fin.read((char *) &len, sizeof(len)); fin.read((char *) &len, sizeof(len));
word.resize(len); word.resize(len);
fin.read((char *) word.data(), len); if (len > 0) {
tmp.resize(len);
fin.read(tmp.data(), len);
word.assign(tmp.data(), len);
} else {
word.clear();
}
float score; float score;
fin.read((char *) &score, sizeof(score)); fin.read((char *) &score, sizeof(score));
@ -178,10 +186,6 @@ bool llama_model_load(const std::string & fname, llama_model & model, llama_voca
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;
vocab.id_to_token[i] = word; vocab.id_to_token[i] = word;
vocab.score[i] = score; vocab.score[i] = score;
//if (i < 30000) {
// fprintf(stderr, "%s: vocab[%d] = '%s'\n", __func__, i, word.c_str());
//}
} }
} }
@ -974,7 +978,7 @@ int main(int argc, char ** argv) {
n_past += embd.size(); n_past += embd.size();
embd.clear(); embd.clear();
if (embd_inp.size() <= input_consumed) { if ((int) embd_inp.size() <= input_consumed) {
// out of user input, sample next token // out of user input, sample next token
const float top_k = params.top_k; const float top_k = params.top_k;
const float top_p = params.top_p; const float top_p = params.top_p;
@ -1011,7 +1015,7 @@ int main(int argc, char ** argv) {
--remaining_tokens; --remaining_tokens;
} else { } else {
// some user input remains from prompt or interaction, forward it to processing // some user input remains from prompt or interaction, forward it to processing
while (embd_inp.size() > input_consumed) { while ((int) embd_inp.size() > input_consumed) {
embd.push_back(embd_inp[input_consumed]); embd.push_back(embd_inp[input_consumed]);
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(embd_inp[input_consumed]); last_n_tokens.push_back(embd_inp[input_consumed]);
@ -1036,7 +1040,7 @@ int main(int argc, char ** argv) {
// in interactive mode, and not currently processing queued inputs; // in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more // check if we should prompt the user for more
if (params.interactive && embd_inp.size() <= input_consumed) { if (params.interactive && (int) embd_inp.size() <= input_consumed) {
// check for reverse prompt // check for reverse prompt
for (auto antiprompt_inp : antipromptv_inp) { for (auto antiprompt_inp : antipromptv_inp) {
if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) { if (antiprompt_inp.size() && std::equal(antiprompt_inp.rbegin(), antiprompt_inp.rend(), last_n_tokens.rbegin())) {