Commit e752031

finetune.cpp command-line arg
Add a learning rate (AdamW alpha) command-line argument for ggml-opt, and an optimizer enum defaulting to AdamW, including a string->id mapping, in preparation for SGD support. These live in common args as a set of optimizer options active only for the new FINETUNE example (the previous finetune.cpp PERPLEXITY options, which we're told were unused/accidental, are dropped). Perhaps breaking with precedent, the ggml_opt_optimizer_params struct is included directly in the args; if desired, we can instead add just the learning rate and optimizer type to a struct independent of ggml-opt.h, as proposed in #13835.
1 parent e0e3aa2 commit e752031
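With these flags in place, a rough usage sketch for the finetune example might look like the following (the model and data file names are placeholders, and -m/-f are the usual common args, not part of this commit):

    llama-finetune -m model.gguf -f train.txt -lr 1e-7 --optimizer adamw

Passing --optimizer sgd is rejected with "TODO: implement SGD" for now, and any other optimizer name is reported as invalid.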

File tree

common/arg.cpp
common/common.h
examples/training/finetune.cpp
ggml/include/ggml-opt.h
ggml/src/ggml-opt.cpp

5 files changed: +63 −5 lines

common/arg.cpp

Lines changed: 24 additions & 2 deletions

@@ -1095,6 +1095,7 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-embedding",
         "llama-eval-callback",
         "llama-export-lora",
+        "llama-finetune",
         "llama-gen-docs",
         "llama-gguf",
         "llama-gguf-hash",
@@ -1239,6 +1240,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     sampler_type_names.pop_back();
 
 
+    params.optimize = ggml_opt_get_default_optimizer_params(NULL);
+    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+
     /**
      * filter options by example
      * rules:
@@ -1472,14 +1476,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.ctx_shift = false;
         }
-    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
+    ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE}).set_env("LLAMA_ARG_NO_CONTEXT_SHIFT"));
     add_opt(common_arg(
         {"--chunks"}, "N",
         string_format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
         [](common_params & params, int value) {
             params.n_chunks = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
+    ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_FINETUNE, LLAMA_EXAMPLE_RETRIEVAL}));
     add_opt(common_arg(
         {"-fa", "--flash-attn"},
         string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"),
@@ -2181,6 +2185,24 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.ppl_output_type = value;
         }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
+    add_opt(common_arg(
+        {"-lr", "-alpha", "--alpha", "--learning-rate"}, "ALPHA",
+        string_format("adamw optimizer alpha (default: %.1f)", (double)params.optimize.adamw.alpha),
+        [](common_params & params, const std::string & value) {
+            params.optimize.adamw.alpha = std::stof(value);
+        }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
+    add_opt(common_arg(
+        {"-opt", "--optimizer"}, "sgd|adamw",
+        "adamw or //TODO:sgd",
+        [](common_params & params, std::string const& name) {
+            params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());
+            if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT)
+                throw std::invalid_argument("invalid --optimizer N (try 0)");
+            else if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_SGD)
+                throw std::invalid_argument("TODO: implement SGD");
+        }
+    ).set_examples({LLAMA_EXAMPLE_FINETUNE}));
     add_opt(common_arg(
         {"-dt", "--defrag-thold"}, "N",
         string_format("KV cache defragmentation threshold (default: %.1f, < 0 - disabled)", (double)params.defrag_thold),

common/common.h

Lines changed: 4 additions & 0 deletions

@@ -3,6 +3,7 @@
 #pragma once
 
 #include "llama-cpp.h"
+#include "ggml-opt.h"
 
 #include <set>
 #include <string>
@@ -80,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -349,6 +351,8 @@ struct common_params {
     bool no_mmproj = false; // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct ggml_opt_optimizer_params optimize;
     // embedding
     bool embedding = false; // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)

examples/training/finetune.cpp

Lines changed: 3 additions & 3 deletions

@@ -18,7 +18,7 @@ int main(int argc, char ** argv) {
 
     params.escape = false;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
         return 1;
     }
 
@@ -60,8 +60,8 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
     ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
 
-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
+    struct ggml_opt_optimizer_params &optimizer_params = params.optimize;
+    LOG_INF("-optimizer %d -lr: %.1f", optimizer_params.optimizer, (double)optimizer_params.adamw.alpha);
 
     struct llama_opt_params lopt_params {
         /*n_ctx_train =*/ 0,

ggml/include/ggml-opt.h

Lines changed: 12 additions & 0 deletions

@@ -74,6 +74,17 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum ggml_opt_optimizer {
+        GGML_OPT_OPTIMIZER_ADAMW,
+        GGML_OPT_OPTIMIZER_SGD,
+
+        GGML_OPT_OPTIMIZER_COUNT
+    };
+
+    // "adamw" or "sgd" (case insensitive)
+    GGML_API char const* ggml_opt_optimizer_name (enum ggml_opt_optimizer);
+    GGML_API enum ggml_opt_optimizer named_ggml_opt_optimizer(char const*);
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
         // AdamW optimizer parameters
@@ -84,6 +95,7 @@ extern "C" {
             float eps; // epsilon for numerical stability
             float wd;  // weight decay for AdamW, use 0.0f to disable
         } adamw;
+        enum ggml_opt_optimizer optimizer;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass

ggml/src/ggml-opt.cpp

Lines changed: 20 additions & 0 deletions

@@ -228,10 +228,30 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.beta2 = 0.999f;
     result.adamw.eps   = 1e-8f;
     result.adamw.wd    = 0.0f;
+    result.optimizer   = GGML_OPT_OPTIMIZER_ADAMW;
 
     return result;
 }
 
+GGML_API char const* ggml_opt_optimizer_name (enum ggml_opt_optimizer o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
+
+
+GGML_API enum ggml_opt_optimizer named_ggml_opt_optimizer(char const* n) {
+    if (!strcasecmp("adamw", n)) return GGML_OPT_OPTIMIZER_ADAMW;
+    else if (!strcasecmp("sgd", n)) return GGML_OPT_OPTIMIZER_SGD;
+    else return GGML_OPT_OPTIMIZER_COUNT;
+}
+
+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
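
Not part of the commit, but as a minimal sketch of how the two new helpers round-trip (assuming only the declarations added to ggml-opt.h above):

    #include <cstdio>
    #include "ggml-opt.h"

    int main() {
        // parse a user-supplied optimizer name; matching is case insensitive
        enum ggml_opt_optimizer opt = named_ggml_opt_optimizer("AdamW");
        if (opt == GGML_OPT_OPTIMIZER_COUNT) {
            fprintf(stderr, "unknown optimizer\n");
            return 1;
        }
        // map the enum back to its canonical lowercase name
        printf("optimizer: %s\n", ggml_opt_optimizer_name(opt)); // prints "optimizer: adamw"
        return 0;
    }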
