quantize: make output filename optional, default to ggml-model-<ftype>.bin (#1301)

This commit is contained in:
slaren 2023-05-05 00:58:56 +02:00 committed by GitHub
parent 34d9f22f44
commit 94c5652fc0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -6,7 +6,7 @@
#include <map> #include <map>
#include <string> #include <string>
static const std::map<std::string, enum llama_ftype> LLAMA_FTYPE_MAP = { static const std::map<std::string, llama_ftype> LLAMA_FTYPE_MAP = {
{"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0}, {"q4_0", LLAMA_FTYPE_MOSTLY_Q4_0},
{"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1}, {"q4_1", LLAMA_FTYPE_MOSTLY_Q4_1},
{"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2}, {"q4_2", LLAMA_FTYPE_MOSTLY_Q4_2},
@ -15,14 +15,38 @@ static const std::map<std::string, enum llama_ftype> LLAMA_FTYPE_MAP = {
{"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0}, {"q8_0", LLAMA_FTYPE_MOSTLY_Q8_0},
}; };
bool try_parse_ftype(const std::string & ftype_str, llama_ftype & ftype, std::string & ftype_str_out) {
auto it = LLAMA_FTYPE_MAP.find(ftype_str);
if (it != LLAMA_FTYPE_MAP.end()) {
ftype = it->second;
ftype_str_out = it->first;
return true;
}
// try to parse as an integer
try {
int ftype_int = std::stoi(ftype_str);
for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
if (it->second == ftype_int) {
ftype = it->second;
ftype_str_out = it->first;
return true;
}
}
}
catch (...) {
// stoi failed
}
return false;
}
// usage: // usage:
// ./quantize models/llama/ggml-model.bin models/llama/ggml-model-quant.bin type // ./quantize models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads]
// //
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_time_init(); ggml_time_init();
if (argc < 4) { if (argc < 3) {
fprintf(stderr, "usage: %s model-f32.bin model-quant.bin type [nthread]\n", argv[0]); fprintf(stderr, "usage: %s model-f32.bin [model-quant.bin] type [nthreads]\n", argv[0]);
for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) { for (auto it = LLAMA_FTYPE_MAP.begin(); it != LLAMA_FTYPE_MAP.end(); it++) {
fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second); fprintf(stderr, " type = \"%s\" or %d\n", it->first.c_str(), it->second);
} }
@ -36,24 +60,62 @@ int main(int argc, char ** argv) {
ggml_free(ctx); ggml_free(ctx);
} }
// parse command line arguments
const std::string fname_inp = argv[1]; const std::string fname_inp = argv[1];
const std::string fname_out = argv[2]; std::string fname_out;
int nthread;
llama_ftype ftype;
enum llama_ftype ftype; int arg_idx = 2;
if (argv[3][0] == 'q') { std::string ftype_str;
auto it = LLAMA_FTYPE_MAP.find(argv[3]); if (try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
if (it == LLAMA_FTYPE_MAP.end()) { // argv[2] is the ftype
fprintf(stderr, "%s: unknown ftype '%s'\n", __func__, argv[3]); std::string fpath;
const size_t pos = fname_inp.find_last_of('/');
if (pos != std::string::npos) {
fpath = fname_inp.substr(0, pos + 1);
}
// export as [inp path]/ggml-model-[ftype].bin
fname_out = fpath + "ggml-model-" + ftype_str + ".bin";
arg_idx++;
}
else {
// argv[2] is the output path
fname_out = argv[arg_idx];
arg_idx++;
if (argc <= arg_idx) {
fprintf(stderr, "%s: missing ftype\n", __func__);
return 1;
}
// argv[3] is the ftype
if (!try_parse_ftype(argv[arg_idx], ftype, ftype_str)) {
fprintf(stderr, "%s: invalid ftype '%s'\n", __func__, argv[3]);
return 1;
}
arg_idx++;
}
// parse nthreads
if (argc > arg_idx) {
try {
nthread = std::stoi(argv[arg_idx]);
}
catch (const std::exception & e) {
fprintf(stderr, "%s: invalid nthread '%s' (%s)\n", __func__, argv[arg_idx], e.what());
return 1; return 1;
} }
ftype = it->second;
} else { } else {
ftype = (enum llama_ftype)atoi(argv[3]); nthread = 0;
} }
fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
int nthread = argc > 4 ? atoi(argv[4]) : 0; fprintf(stderr, "%s: quantizing '%s' to '%s' as %s", __func__, fname_inp.c_str(), fname_out.c_str(), ftype_str.c_str());
if (nthread > 0) {
fprintf(stderr, " using %d threads", nthread);
}
fprintf(stderr, "\n");
const int64_t t_main_start_us = ggml_time_us(); const int64_t t_main_start_us = ggml_time_us();