
Commit 3a41cf4

examples/finetune -opt SGD (stochastic gradient descent) memory opt
Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - it avoids allocating the AdamW m and v tensors. finetune.cpp supports the new arg -opt SGD (or sgd); the default is adamw as before.

llama 3.2-1b-F32 result: 11 GB GPU RAM observed (41 sec/epoch) when using SGD instead of 19 GB (55 sec/epoch) using adamw. With a larger batch/context of 1728 (possible due to the memory savings), finetune (SGD) on 500 lines of wikipedia reaches:

train: data=0000039/0000039 loss=0.01601±0.00086 acc=99.73±0.02% t=00:00:40
val:   data=0000003/0000003 loss=1.99405±1.09012 acc=72.18±0.66% t=00:00:01

Using the same GPU memory, adamw can only do 512 batch/context, reaching (100 wikipedia lines quickly memorized exactly):

train: ... loss=0.00231±0.00032 acc=99.99±0.01% t=00:00:05
val:   ... loss=3.91926±nan acc=58.40±2.18%

Notes:
- When finetuning long enough (or with a large enough -lr), validation accuracy eventually drops ('catastrophic forgetting'). The -lr-half (halflife) option is useful for SGD to avoid oscillation or very slow underdamped learning, and makes setting -lr more forgiving.
- Objective loss may not be directly comparable between adamw and sgd - check perplexity or accuracy, or consider relative improvements, when judging convergence.
- A logical batch size larger than the physical batch (gradient accumulation) seems unsupported for optimization (it is limited to the physical batch, unlike in perplexity, and also limited to ctx-size). Training quality/convergence could be improved by implementing it, at the cost of some memory, but that can be made up by using a much smaller physical batch for a net memory savings. Presumably it is the physical batch that should be limited to ctx-size; see llama_context::opt_epoch. (opt_period > 1 may already be implemented and would give multiples of the physical batch - an option was added for this, and oddly no increased GPU memory usage was observed via nvidia-smi when > 1.)

New finetune args: -wd 1e-9 to enable weight decay in sgd or adamw, and -epochs N to set the max number of epochs (default 2, as before).

Caching (1 - wd*alpha) in the 'adamw' opt struct gave no noticeable perf benefit and is disabled there (it is still done for the new SGD, though).

Since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params could probably switch between SGD and AdamW with each epoch, but would need to use adamw for the first (unconfirmed - no cmdline arg to set such a policy yet).

test-opt checks adamw as before and now also sgd (except for a few checks that just need values collected and added); tolerance on the 'regression' test is lower (weight decay is enabled, but a 1st-order method generally converges more slowly than a 2nd-order one).
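
As a rough illustration only (a sketch, not the actual ggml kernel), the per-element SGD step with the decoupled weight decay described above amounts to the following; the keep factor (1 - alpha*wd) is computed once per step, as noted:

#include <cstdint>

// Sketch: scalar reference of an SGD step with decoupled weight decay.
// x: parameters, g: gradients, alpha: learning rate, wd: weight decay.
static void sgd_step_reference(float * x, const float * g, int64_t n, float alpha, float wd) {
    const float keep = 1.0f - alpha * wd;  // cached once per step
    for (int64_t i = 0; i < n; ++i) {
        x[i] = x[i] * keep - alpha * g[i]; // decay the weight, then take the gradient step
    }
}

Unlike AdamW, no per-parameter m/v moment tensors are needed, which is where the reported GPU memory savings come from.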
1 parent 7f4fbe5 commit 3a41cf4

22 files changed: +608 / -179 lines

CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")
 

common/arg.cpp

Lines changed: 47 additions & 0 deletions
@@ -3376,5 +3376,52 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
+    add_opt(common_arg({ "-save", "--opt-save-model-to" }, "PATH",
+                       string_format(
+                           "write optimized model to this filename (default: %s)",
+                           params.opt_save_model_to.c_str()),
+                       [](common_params & params, const std::string & value) { params.opt_save_model_to = value; })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(
+        common_arg({ "-lr", "--learning-rate" }, "ALPHA",
+                   string_format(
+                       "adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
+                       (double) params.lr.lr),
+                   [](common_params & params, const std::string & value) { params.lr.lr = std::stof(value); })
+            .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+                { "-lr-half", "--learning-rate-halflife-epochs" }, "N",
+                string_format("reduce lr in half every N epochs (default: %.3g)", (double) params.lr.halflife_epochs),
+                [](common_params & params, const std::string & value) { params.lr.halflife_epochs = std::stof(value); })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-lr-halvings", "--learning-rate-halvings" }, "N",
+                       string_format("max N lr halvings (default: %.3g)", (double) params.lr.halvings),
+                       [](common_params & params, const std::string & value) { params.lr.halvings = std::stof(value); })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg(
+                { "-wd", "--weight-decay" }, "WD",
+                string_format(
+                    "adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
+                    (double) params.lr.wd),
+                [](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-val", "--val-split" }, "FRACTION",
+                       string_format("portion of data to use as validation when optimizing (default: %.2g).",
+                                     (double) params.val_split),
+                       [](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-epochs", "--epochs" }, "N",
+                       string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
+                       [](common_params & params, int epochs) { params.lr.epochs = epochs; })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+    add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
+                       [](common_params & params, const std::string & name) {
+                           params.optimizer = ggml_opt_get_optimizer(name.c_str());
+                           if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
+                               throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
+                           }
+                       })
+                .set_examples({ LLAMA_EXAMPLE_FINETUNE }));
+
     return ctx_arg;
 }

common/common.cpp

Lines changed: 29 additions & 0 deletions
@@ -1535,3 +1535,32 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std
 
     return result;
 }
+
+ggml_opt_optimizer_params common_lr_opt_pars(void * userdata) {
+    ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
+    const lr_opt & d = *(lr_opt *) userdata;
+    result.adamw.alpha = result.sgd.alpha = d.decayed(d.epoch);
+    result.sgd.wd = result.adamw.wd = d.wd;
+    return result;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
+
+GGML_API enum ggml_opt_optimizer_type ggml_opt_get_optimizer(const char * n) {
+    if (!strcasecmp("adamw", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    } else if (!strcasecmp("sgd", n)) {
+        return GGML_OPT_OPTIMIZER_TYPE_SGD;
+    } else {
+        return GGML_OPT_OPTIMIZER_TYPE_COUNT;
+    }
+}
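
For context, a hedged usage sketch of the two helpers above (the main() wrapper and its output are illustrative, not part of the commit):

#include <cstdio>

#include "common.h"

int main() {
    // parsing is case-insensitive: "SGD", "sgd", "Sgd" all resolve the same way
    enum ggml_opt_optimizer_type opt = ggml_opt_get_optimizer("SGD");
    if (opt == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
        fprintf(stderr, "unknown optimizer name\n");
        return 1;
    }
    printf("optimizer: %s\n", ggml_opt_optimizer_name(opt)); // prints "optimizer: sgd"
    return 0;
}

GGML_OPT_OPTIMIZER_TYPE_COUNT doubles as the 'invalid' return value, which is what the new -opt argument handler in common/arg.cpp checks before throwing.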

common/common.h

Lines changed: 35 additions & 3 deletions
@@ -2,13 +2,15 @@
 
 #pragma once
 
-#include "llama-cpp.h"
-
 #include <set>
+#include <sstream>
 #include <string>
 #include <string_view>
 #include <vector>
-#include <sstream>
+#include <cmath>
+
+#include "ggml-opt.h"
+#include "llama-cpp.h"
 
 #ifdef _WIN32
 #define DIRECTORY_SEPARATOR '\\'
@@ -80,6 +82,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_FINETUNE,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -219,6 +222,25 @@ enum common_reasoning_format {
     COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
 };
 
+struct lr_decay {
+    float lr = 1e-5;
+    float halflife_epochs = 100;
+    float halvings = 10;
+
+    float decayed(float epoch) const {
+        float maxepoch = halvings * halflife_epochs;
+        return lr * std::pow(.5, (epoch > maxepoch ? maxepoch : epoch) / halflife_epochs);
+    }
+};
+
+struct lr_opt : lr_decay {
+    float epoch = 0;
+    float wd = 0;
+    unsigned epochs = 2;
+};
+
+struct ggml_opt_optimizer_params common_lr_opt_pars(void * userdata);
+
 struct common_params {
     int32_t n_predict = -1; // new tokens to predict
     int32_t n_ctx = 4096; // context size
@@ -350,6 +372,12 @@ struct common_params {
     bool no_mmproj = false; // explicitly disable multimodal model
     std::vector<std::string> image; // path to image file(s)
 
+    // finetune
+    struct lr_opt lr;
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+    float val_split = 0.05f; // fraction of data used for validation when optimizing
+    std::string opt_save_model_to = "finetuned-model.gguf";
+
     // embedding
     bool embedding = false; // get only sentence embedding
     int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@@ -671,3 +699,7 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
 //
 
 ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);
+
+// "adamw" or "sgd" (case insensitive)
+const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);
+enum ggml_opt_optimizer_type ggml_opt_get_optimizer(const char *);
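
A short worked example of the lr_decay schedule added above; the numbers follow directly from the default member values (illustrative snippet, not part of the commit):

#include <cstdio>

#include "common.h"

int main() {
    lr_decay d; // defaults: lr = 1e-5, halflife_epochs = 100, halvings = 10
    // lr(epoch) = lr * 0.5^(min(epoch, halvings * halflife_epochs) / halflife_epochs)
    printf("%g\n", d.decayed(0));    // 1e-05   (no decay yet)
    printf("%g\n", d.decayed(100));  // 5e-06   (one halving)
    printf("%g\n", d.decayed(200));  // 2.5e-06 (two halvings)
    printf("%g\n", d.decayed(5000)); // ~9.8e-09, clamped after 10 halvings
    return 0;
}

The clamp at halvings * halflife_epochs keeps the learning rate from decaying toward zero on very long runs.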

examples/training/finetune.cpp

Lines changed: 36 additions & 33 deletions
@@ -1,29 +1,31 @@
-#include "arg.h"
-#include "common.h"
-#include "log.h"
-#include "llama.h"
-
 #include <cmath>
 #include <cstdio>
 #include <cstring>
 #include <ctime>
 #include <vector>
 
+#include "arg.h"
+#include "common.h"
+#include "llama.h"
+#include "log.h"
+
 #if defined(_MSC_VER)
-#pragma warning(disable: 4244 4267) // possible loss of data
+# pragma warning(disable : 4244 4267) // possible loss of data
 #endif
 
+
+
 int main(int argc, char ** argv) {
     common_params params;
-
     params.escape = false;
 
-    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
         return 1;
     }
 
     if (params.use_mmap) {
-        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
+        LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
+                __func__);
         params.use_mmap = false;
     }
     if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +40,11 @@ int main(int argc, char ** argv) {
     common_init();
     llama_backend_init();
    llama_numa_init(params.numa);
-
     // load the model and apply lora adapter, if any
-    common_init_result llama_init = common_init_from_params(params);
-    llama_model_ptr & model = llama_init.model;
-    llama_context_ptr & ctx = llama_init.context;
+    common_init_result  llama_init = common_init_from_params(params);
+    llama_model_ptr &   model      = llama_init.model;
+    llama_context_ptr & ctx        = llama_init.context;
+    auto                pctx       = ctx.get();
 
     if (model == NULL) {
         LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +57,32 @@ int main(int argc, char ** argv) {
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
 
-    constexpr float val_split = 0.05f;
-
-    std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
-    ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);
-
-    struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
-    optimizer_params.adamw.alpha = 1e-7f; // learning rate
-
-    struct llama_opt_params lopt_params {
-        /*n_ctx_train =*/ 0,
-        /*param_filter =*/ llama_opt_param_filter_all,
-        /*param_filter_ud =*/ nullptr,
-        /*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
-        /*get_opt_pars_ud =*/ &optimizer_params,
+    std::vector<llama_token> tokens = common_tokenize(pctx, params.prompt, true);
+    ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);
+
+    auto & lr = params.lr;
+    LOG_INF("-optimizer %s -lr %.2g -wd %.2g -lr-half %.2g -epochs %d -period %.2g -val %.2g\n",
+            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr, (double) lr.wd, (double) lr.halflife_epochs,
+            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);
+
+    struct llama_opt_params lopt_params{
+        /*n_ctx_train =*/0,
+        /*param_filter =*/llama_opt_param_filter_all,
+        /*param_filter_ud =*/nullptr,
+        /*get_opt_pars =*/common_lr_opt_pars,
+        /*get_opt_pars_ud =*/&params.lr,
+        /*optimizer_type =*/params.optimizer,
     };
-    llama_opt_init(ctx.get(), model.get(), lopt_params);
+    llama_opt_init(pctx, model.get(), lopt_params);
 
-    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
+    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);
 
     ggml_opt_result_t result_train = ggml_opt_result_init();
     ggml_opt_result_t result_eval = ggml_opt_result_init();
 
-    for (int epoch = 0; epoch < 2; ++epoch) {
-        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
-                ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
+    for (unsigned epoch = 0; epoch < lr.epochs; ++epoch) {
+        llama_opt_epoch(pctx, dataset, result_train, result_eval, idata_split,
+                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
         fprintf(stderr, "\n");
 
         ggml_opt_result_reset(result_train);
@@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_eval);
 
-    llama_model_save_to_file(model.get(), "finetuned-model.gguf");
+    llama_model_save_to_file(model.get(), params.opt_save_model_to.c_str());
 
     llama_backend_free();
 
ggml/include/ggml-opt.h

Lines changed: 23 additions & 8 deletions
@@ -74,16 +74,26 @@ extern "C" {
         GGML_OPT_BUILD_TYPE_OPT = 30,
     };
 
+    enum ggml_opt_optimizer_type {
+        GGML_OPT_OPTIMIZER_TYPE_ADAMW,
+        GGML_OPT_OPTIMIZER_TYPE_SGD,
+
+        GGML_OPT_OPTIMIZER_TYPE_COUNT
+    };
+
     // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
     struct ggml_opt_optimizer_params {
-        // AdamW optimizer parameters
         struct {
-            float alpha; // learning rate
-            float beta1;
-            float beta2;
-            float eps;   // epsilon for numerical stability
-            float wd;    // weight decay for AdamW, use 0.0f to disable
+            float alpha; // learning rate
+            float beta1; // first AdamW momentum
+            float beta2; // second AdamW momentum
+            float eps;   // epsilon for numerical stability
+            float wd;    // weight decay - 0.0f to disable
        } adamw;
+        struct {
+            float alpha; // learning rate
+            float wd;    // weight decay
+        } sgd;
     };
 
     // callback to calculate optimizer parameters prior to a backward pass
@@ -113,7 +123,10 @@ extern "C" {
         int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
-        void * get_opt_pars_ud; // userdata for calculating optimizer parameters
+        void * get_opt_pars_ud;                     // userdata for calculating optimizer parameters
+
+        // only GGML_OPT_OPTIMIZER_TYPE_ADAMW allocates m, v per parameter
+        enum ggml_opt_optimizer_type optimizer;
     };
 
     // get parameters for an optimization context with defaults set where possible
@@ -186,7 +199,7 @@ extern "C" {
     // The second context should contain all other tensors and will be (re)allocated automatically.
     // Due to this automated allocation the data of the second context is not defined when accessed in user code.
     // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors.
-    // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead.
+    // 4. Call ggml_opt_fit. If you need more control (e.g. optimizer sgd) you can use ggml_opt_epoch instead.
 
     // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation
     typedef void (*ggml_opt_epoch_callback)(
@@ -226,12 +239,14 @@ extern "C" {
            struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
            ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
            enum ggml_opt_loss_type loss_type, // loss to minimize
+           enum ggml_opt_optimizer_type optimizer, // sgd or adamw
            ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
            int64_t nepoch, // how many times the dataset should be iterated over
            int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
            float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
            bool silent); // whether or not info prints to stderr should be suppressed
 
+    GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t);
 #ifdef __cplusplus
 }
 #endif
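
A hedged sketch of a get_opt_pars callback that fills the new sgd sub-struct; the constant hyperparameter values are illustrative (finetune.cpp instead uses common_lr_opt_pars with a decaying schedule):

#include "ggml-opt.h"

// Constant SGD hyperparameters: start from the library defaults, override the sgd fields.
static struct ggml_opt_optimizer_params my_sgd_pars(void * userdata) {
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(userdata);
    p.sgd.alpha = 1e-4f; // learning rate; roughly 10x a typical adamw alpha (no momentum)
    p.sgd.wd    = 1e-9f; // weight decay; 0.0f disables it
    return p;
}

Such a callback would be passed as get_opt_pars, together with GGML_OPT_OPTIMIZER_TYPE_SGD as the optimizer, to ggml_opt_fit or through llama_opt_params.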

ggml/include/ggml.h

Lines changed: 11 additions & 2 deletions
@@ -450,7 +450,7 @@ extern "C" {
        GGML_OP_REPEAT_BACK,
        GGML_OP_CONCAT,
        GGML_OP_SILU_BACK,
-       GGML_OP_NORM, // normalize
+       GGML_OP_NORM,        // normalize
        GGML_OP_RMS_NORM,
        GGML_OP_RMS_NORM_BACK,
        GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
        GGML_OP_POOL_1D,
        GGML_OP_POOL_2D,
        GGML_OP_POOL_2D_BACK,
-       GGML_OP_UPSCALE, // nearest interpolate
+       GGML_OP_UPSCALE,     // nearest interpolate
        GGML_OP_PAD,
        GGML_OP_PAD_REFLECT_1D,
        GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
        GGML_OP_CROSS_ENTROPY_LOSS,
        GGML_OP_CROSS_ENTROPY_LOSS_BACK,
        GGML_OP_OPT_STEP_ADAMW,
+       GGML_OP_OPT_STEP_SGD,
 
        GGML_OP_COUNT,
    };
@@ -2063,6 +2064,14 @@ extern "C" {
            struct ggml_tensor * v,
            struct ggml_tensor * adamw_params); // parameters such a the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            // params: alpha (learning rate), wd (weight decay)
+            struct ggml_context * ctx,
+            struct ggml_tensor * a,
+            struct ggml_tensor * grad,
+            struct ggml_tensor * adamw_params);
+
    //
    // automatic differentiation
    //