
Commit fcc241e

examples/finetune -opt SGD (stochastic gradient descent) memory opt
Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml; it avoids allocating the AdamW m and v tensors. Support the finetune.cpp argument -opt SGD (or sgd); the default is adamw, as before.

Llama 3.2-1B-F32 result (Wikipedia, 100-line finetune): observed 11 GB GPU RAM (41 sec/epoch) with SGD vs. 19 GB (55 sec/epoch) with adamw.

Using the same GPU memory, adamw can only fit a 512 batch/context before OOM, reaching:

train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD is superior here: although it converges more slowly, it fits a 1728 batch/context before OOM (note especially the better validation performance):

train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00

Note: when finetuning long enough (or with a large enough -lr), validation accuracy *eventually* drops ("catastrophic forgetting").

The -lr-half (half-life) option is useful for SGD to avoid oscillation or very slow, underdamped learning; it makes setting -lr more forgiving. The terminal -lr is currently set via -lr-halvings, i.e. if you want at most 1/8 of the initial -lr, set -lr-halvings 3.

Note: the objective loss may not be directly comparable between adamw and sgd; check perplexity or accuracy, or compare relative improvements, when judging convergence.

New finetune args: -wd 1e-9 enables weight decay in sgd or adamw, and -epochs N caps the number of epochs (default 2, as before).

Caching (1 - wd*alpha) in the 'adamw' opt struct showed no noticeable perf benefit and is disabled (it is still done for the new SGD).

Since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params callback could probably switch between SGD and AdamW each epoch, but the first epoch would need to use adamw (unconfirmed; there is no command-line argument to set such a policy yet).

test-opt checks adamw as before and now also sgd (except for a few tests disabled for sgd only; these probably just need their values logged and alternate reference values added); the tolerance on the 'regression' test is broader for sgd so that many more epochs are not needed.
1 parent fb85a28 commit fcc241e

22 files changed: +631, -180 lines
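
The commit message's claim that -lr-halvings 3 caps the decay at 1/8 of the initial -lr follows directly from the half-life schedule. The standalone sketch below mirrors the formula implemented by lr_opt::get_lr in common/common.h further down; the helper name and the flag values in the comments are illustrative, not defaults.

// Standalone sketch of the -lr-half / -lr-halvings schedule (mirrors lr_opt::get_lr below).
#include <algorithm>
#include <cmath>
#include <cstdio>

// lr(epoch) = lr0 * 0.5^(min(epoch, halvings*halflife) / halflife)
static float decayed_lr(float lr0, float halflife, float halvings, float epoch) {
    const float max_epoch = halvings * halflife; // past this point the lr stays at its floor
    return lr0 * std::pow(0.5f, std::min(epoch, max_epoch) / halflife);
}

int main() {
    // e.g. -lr 1e-4 -lr-half 10 -lr-halvings 3 (illustrative values)
    const float lr0 = 1e-4f, halflife = 10.0f, halvings = 3.0f;
    for (int epoch = 0; epoch <= 50; epoch += 10) {
        // bottoms out at lr0 / 2^3 = 1.25e-5 once epoch >= 30
        printf("epoch %2d: lr = %g\n", epoch, (double) decayed_lr(lr0, halflife, halvings, (float) epoch));
    }
    return 0;
}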

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
 

common/arg.cpp

Lines changed: 51 additions & 2 deletions
@@ -1200,6 +1200,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
            common_params_print_completion(ctx_arg);
            exit(0);
        }
+       params.lr.init();
    } catch (const std::invalid_argument & ex) {
        fprintf(stderr, "%s\n", ex.what());
        ctx_arg.params = params_org;
@@ -2613,9 +2614,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        {"-o", "--output", "--output-file"}, "FNAME",
        string_format("output file (default: '%s')", params.out_file.c_str()),
        [](common_params & params, const std::string & value) {
-           params.out_file = value;
+           params.out_file = value;
        }
-   ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS}));
+   ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
    add_opt(common_arg(
        {"-ofreq", "--output-frequency"}, "N",
        string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@@ -3376,5 +3377,53 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        }
    ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
+   add_opt(
+       common_arg({ "-lr", "--learning-rate-initial" }, "ALPHA",
+                  string_format(
+                      "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
+                      (double) params.lr.lr0),
+                  [](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
+           .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+   add_opt(
+       common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
+                  string_format(
+                      "(if >0) final learning rate (default=%.2g)",
+                      (double) params.lr.lr_min),
+                  [](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
+           .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+   add_opt(common_arg(
+       { "-lr-half", "--learning-rate-halflife" }, "N",
+       string_format("half-life (in epochs) for -lr (default: %.3g)", (double) params.lr.halflife_epochs),
+       [](common_params & params, const std::string & value) { params.lr.halflife_epochs = std::stof(value); })
+       .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+   add_opt(common_arg({ "-lr-halvings", "--learning-rate-halvings" }, "N",
+                      string_format("max N lr halvings (default: %.3g)", (double) params.lr.halvings),
+                      [](common_params & params, const std::string & value) { params.lr.halvings = std::stof(value); })
+           .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+   add_opt(common_arg(
+       { "-wd", "--weight-decay" }, "WD",
+       string_format(
+           "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+           (double) params.lr.wd),
+       [](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
+       .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+   add_opt(common_arg({ "-val", "--val-split" }, "FRACTION",
+                      string_format("fraction of data to use as validation set for training (default: %.2g).",
+                                    (double) params.val_split),
+                      [](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
+           .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+   add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+                      string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
+                      [](common_params & params, int epochs) { params.lr.epochs = epochs; })
+           .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+   add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
+                      [](common_params & params, const std::string & name) {
+                          params.optimizer = common_opt_get_optimizer(name.c_str());
+                          if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
+                              throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
+                          }
+                      })
+           .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+
    return ctx_arg;
 }

common/common.cpp

Lines changed: 18 additions & 0 deletions
@@ -1535,3 +1535,21 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
    return result;
 }
+
+ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+    const lr_opt & d = *(lr_opt *) userdata;
+    result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
+    result.sgd.wd = result.adamw.wd = d.wd;
+    return result;
+}
+
+GGML_API enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
+    if (!strcasecmp("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    } else if (!strcasecmp("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_SGD;
+    } else {
+        return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+    }
+}
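
A minimal usage sketch of the two helpers added above (illustrative only: the hyperparameter values are arbitrary, and advancing lr_opt::epoch between epochs is assumed here rather than shown in this hunk):

#include "common.h" // for lr_opt, common_opt_lr_pars, common_opt_get_optimizer

void optimizer_params_example() {
    // map a user-supplied string to the enum; GGML_OPT_OPTIMIZER_TYPE_COUNT signals an invalid name
    enum ggml_opt_optimizer_type opt = common_opt_get_optimizer("sgd");

    lr_opt schedule;                   // normally filled in from -lr/-wd/-lr-half/-lr-halvings/-epochs
    schedule.lr0             = 1e-4f;
    schedule.wd              = 1e-9f;
    schedule.halflife_epochs = 10;
    schedule.init();                   // called once after argument parsing (see common_params_parse)
    schedule.epoch           = 3;      // assumed to be advanced by the training loop

    // the callback fills both the adamw and the sgd fields, so it serves either optimizer
    ggml_opt_optimizer_params pars = common_opt_lr_pars(&schedule);
    // pars.sgd.alpha == pars.adamw.alpha == schedule.get_lr(3), pars.sgd.wd == pars.adamw.wd == 1e-9f
    (void) opt;
    (void) pars;
}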

common/common.h

Lines changed: 51 additions & 3 deletions
@@ -2,13 +2,15 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
-#include <sstream>
+#include <cmath>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -80,6 +82,7 @@ enum llama_example {
    LLAMA_EXAMPLE_LOOKUP,
    LLAMA_EXAMPLE_PARALLEL,
    LLAMA_EXAMPLE_TTS,
+   LLAMA_EXAMPLE_FINETUNE,
 
    LLAMA_EXAMPLE_COUNT,
 };
@@ -219,6 +222,42 @@ enum common_reasoning_format {
    COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
+
+static float constexpr k_log_2 = std::log(2.f);
+
+struct lr_opt {
+    float lr0 = 1e-5; // learning rate at first epoch
+    float lr_min = 0;
+    float min_epochs = -1; // if >0, constant (lr_min) after this many epochs
+    float halflife_epochs = 100;
+    float halvings = 10;
+    float epoch = 0;
+    float wd = 0;
+    unsigned epochs = 2;
+
+    float get_lr(float epoch) const {
+        float maxepoch = halvings * halflife_epochs;
+        return lr0 * std::pow(.5, (epoch > maxepoch ? maxepoch : epoch) / halflife_epochs);
+    }
+
+    void init() {
+        if (lr_min > 0 && lr_min < lr0) {
+            float nhalf = std::log(lr0 / lr_min) / k_log_2;
+            halvings = nhalf;
+            float e = epoch;
+            if (min_epochs > 0 && min_epochs < e)
+                e = min_epochs;
+            halflife_epochs = e / nhalf;
+        } else if (min_epochs > 0) {
+            float h = min_epochs / halflife_epochs;
+            if (h < halvings)
+                halvings = h;
+        }
+    }
+};
+
+struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);
+
 struct common_params {
    int32_t n_predict = -1; // new tokens to predict
    int32_t n_ctx = 4096; // context size
@@ -350,6 +389,12 @@ struct common_params {
    bool no_mmproj = false; // explicitly disable multimodal model
    std::vector<std::string> image; // path to image file(s)
 
+   // finetune
+   struct lr_opt lr;
+   enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+   float val_split = 0.05f; // fraction of the data used for the validation set
+   std::string opt_save_model_to = "finetuned-model.gguf";
+
    // embedding
    bool embedding = false; // get only sentence embedding
    int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@@ -671,3 +716,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 //
 
 ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+
+// "adamw" or "sgd" (case insensitive)
+enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);

examples/training/finetune.cpp

Lines changed: 36 additions & 33 deletions
@@ -1,29 +1,31 @@
-#include "arg.h"
-#include "common.h"
-#include "log.h"
-#include "llama.h"
-
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 #include <ctime>
 #include <vector>
 
+#include "arg.h"
+#include "common.h"
+#include "llama.h"
+#include "log.h"
+
 #if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
+# pragma warning(disable : 4244 4267) // possible loss of data
 #endif
 
+
+
 int main(int argc, char ** argv) {
    common_params params;
-
    params.escape = false;
 
-   if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+   if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }
 
    if (params.use_mmap) {
-       LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+       LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
+               __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +40,11 @@ int main(int argc, char ** argv) {
    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);
-
    // load the model and apply lora adapter, if any
-   common_init_result llama_init = common_init_from_params(params);
-   llama_model_ptr & model = llama_init.model;
-   llama_context_ptr & ctx = llama_init.context;
+   common_init_result llama_init = common_init_from_params(params);
+   llama_model_ptr & model = llama_init.model;
+   llama_context_ptr & ctx = llama_init.context;
+   auto pctx = ctx.get();
 
    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +57,32 @@ int main(int argc, char ** argv) {
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }
 
-   constexpr float val_split = 0.05f;
-
-   std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
-   ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
-
-   struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-   optimizer_params.adamw.alpha = 1e-7f; // learning rate
-
-   struct llama_opt_params lopt_params {
-       /*n_ctx_train =*/ 0,
-       /*param_filter =*/ llama_opt_param_filter_all,
-       /*param_filter_ud =*/ nullptr,
-       /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
-       /*get_opt_pars_ud =*/ &optimizer_params,
+   std::vector<llama_token> tokens = common_tokenize(pctx, params.prompt, true);
+   ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);
+
+   struct lr_opt & lr = params.lr;
+   LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -lr-half %.2g -epochs %d -period %.2g -val %.2g\n",
+           ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.halflife_epochs,
+           (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);
+
+   struct llama_opt_params lopt_params{
+       /*n_ctx_train =*/0,
+       /*param_filter =*/llama_opt_param_filter_all,
+       /*param_filter_ud =*/nullptr,
+       /*get_opt_pars =*/common_opt_lr_pars,
+       /*get_opt_pars_ud =*/&params.lr,
+       /*optimizer_type =*/params.optimizer,
    };
-   llama_opt_init(ctx.get(), model.get(), lopt_params);
+   llama_opt_init(pctx, model.get(), lopt_params);
 
-   const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+   const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
 
    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-   for (int epoch = 0; epoch < 2; ++epoch) {
-       llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
-                       ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+   for (unsigned epoch = 0; epoch < lr.epochs; ++epoch) {
+       llama_opt_epoch(pctx, dataset, result_train, result_eval, idata_split,
+                       ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");
 
        ggml_opt_result_reset(result_train);
@@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);
 
-   llama_model_save_to_file(model.get(), "finetuned-model.gguf");
+   llama_model_save_to_file(model.get(), params.out_file.c_str());
 
    llama_backend_free();
 
ggml/include/ggml-opt.h

Lines changed: 26 additions & 7 deletions
@@ -74,16 +74,26 @@ extern "C" {
        GGML_OPT_BUILD_TYPE_OPT = 30,
    };
 
+   enum ggml_opt_optimizer_type {
+       GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+       GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+       GGML_OPT_OPTIMIZER_TYPE_COUNT
+   };
+
    // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
    struct ggml_opt_optimizer_params {
-       // AdamW optimizer parameters
        struct {
-           float alpha; // learning rate
-           float beta1;
-           float beta2;
-           float eps; // epsilon for numerical stability
-           float wd; // weight decay for AdamW, use 0.0f to disable
+           float alpha; // learning rate
+           float beta1; // first AdamW momentum
+           float beta2; // second AdamW momentum
+           float eps; // epsilon for numerical stability
+           float wd; // weight decay - 0.0f to disable
        } adamw;
+       struct {
+           float alpha; // learning rate
+           float wd; // weight decay
+       } sgd;
    };
 
    // callback to calculate optimizer parameters prior to a backward pass
@@ -113,7 +123,10 @@ extern "C" {
        int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
        ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-       void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+       void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+
+       // only GGML_OPT_OPTIMIZER_TYPE_ADAMW allocates m, v per parameter
+       enum ggml_opt_optimizer_type optimizer;
    };
 
    // get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
    // get the gradient accumulator for a node from the forward graph
    GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);
 
+   GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t);
+
+   GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+
    // ====== Optimization Result ======
 
    GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
            struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
            ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
            enum ggml_opt_loss_type loss_type, // loss to minimize
+           enum ggml_opt_optimizer_type optimizer, // sgd or adamw
            ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
            int64_t nepoch, // how many times the dataset should be iterated over
            int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
            float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
            bool silent); // whether or not info prints to stderr should be suppressed
 
+
 #ifdef __cplusplus
 }
 #endif
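
At the ggml level, the new optimizer field and sgd parameter block added in this header would be wired up roughly as follows. This is a sketch, assuming a struct ggml_opt_params has already been obtained from the library's defaults elsewhere; the function names and constants here are illustrative, not part of the commit.

#include "ggml-opt.h"

// callback with the ggml_opt_get_optimizer_params signature, returning constant SGD hyperparameters
static struct ggml_opt_optimizer_params sgd_const_pars(void * userdata) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(nullptr);
    p.sgd.alpha = 1e-4f; // roughly 10x a typical adamw alpha, per the -lr help text above
    p.sgd.wd    = 1e-9f; // weight decay, 0.0f to disable
    (void) userdata;
    return p;
}

static void select_sgd(struct ggml_opt_params & opt_params) {
    // opt_params is assumed to be initialized elsewhere with the library defaults
    opt_params.optimizer       = GGML_OPT_OPTIMIZER_TYPE_SGD; // no m, v tensors are allocated
    opt_params.get_opt_pars    = sgd_const_pars;
    opt_params.get_opt_pars_ud = nullptr;
}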
