
finetune.cpp command-line arg #13873

Open · wants to merge 1 commit into base: master

2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")

# Add path to modules
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

51 changes: 49 additions & 2 deletions common/arg.cpp
@@ -1196,6 +1196,7 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
common_params_print_completion(ctx_arg);
exit(0);
}
params.lr.init();
} catch (const std::invalid_argument & ex) {
fprintf(stderr, "%s\n", ex.what());
ctx_arg.params = params_org;
@@ -2609,9 +2610,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
{"-o", "--output", "--output-file"}, "FNAME",
string_format("output file (default: '%s')", params.out_file.c_str()),
[](common_params & params, const std::string & value) {
params.out_file = value;
params.out_file = value;
}
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS}));
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_CVECTOR_GENERATOR, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_TTS, LLAMA_EXAMPLE_FINETUNE}));
add_opt(common_arg(
{"-ofreq", "--output-frequency"}, "N",
string_format("output the imatrix every N iterations (default: %d)", params.n_out_freq),
@@ -3373,5 +3374,51 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
}
).set_examples({LLAMA_EXAMPLE_SERVER}));

add_opt(
common_arg({ "-lr", "--learning-rate-initial" }, "ALPHA",
string_format(
"adamw or sgd optimizer alpha (default: %.2g); note: sgd alpha recommended ~10x (no momentum)",
(double) params.lr.lr0),
[](common_params & params, const std::string & value) { params.lr.lr0 = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-lr-min", "--learning-rate-min" }, "ALPHA",
string_format(
"(if >0) final learning rate (default=%.2g)",
(double) params.lr.lr_min),
[](common_params & params, const std::string & value) { params.lr.lr_min = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(
common_arg({ "-min-epochs", "--learning-rate-min-epochs" }, "ALPHA",
string_format(
"(if >0) reach -lr-min after this many epochs (instead of only at the last) (default=%.2g)",
(double) params.lr.min_epochs),
[](common_params & params, const std::string & value) { params.lr.min_epochs = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg(
{ "-wd", "--weight-decay" }, "WD",
string_format(
"adamw or sgd optimizer weight decay (0 is off; recommend very small e.g. 1e-9) (default: %.2g).",
(double) params.lr.wd),
[](common_params & params, const std::string & value) { params.lr.wd = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-val", "--val-split" }, "FRACTION",
string_format("fraction of data to use as validation set for training (default: %.2g).",
(double) params.val_split),
[](common_params & params, const std::string & value) { params.val_split = std::stof(value); })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-epochs", "--epochs" }, "N",
string_format("optimizer max # of epochs (default: %d)", params.lr.epochs),
[](common_params & params, int epochs) { params.lr.epochs = epochs; })
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));
add_opt(common_arg({ "-opt", "--optimizer" }, "sgd|adamw", "adamw or sgd",
[](common_params & params, const std::string & name) {
params.optimizer = common_opt_get_optimizer(name.c_str());
if (params.optimizer == GGML_OPT_OPTIMIZER_TYPE_COUNT) {
throw std::invalid_argument("invalid --optimizer, valid options: adamw, sgd");
}
})
.set_examples({ LLAMA_EXAMPLE_FINETUNE }));

return ctx_arg;
}
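For context, a minimal sketch of how these new flags might be supplied to a finetune-style tool. The binary name and all option values below are hypothetical, and the call mirrors the `common_params_parse(..., LLAMA_EXAMPLE_FINETUNE)` usage in `examples/training/finetune.cpp` further down.

```cpp
#include "arg.h"
#include "common.h"

int main() {
    // Hypothetical invocation (model/data flags omitted); the values are illustrative only.
    const char * demo_argv[] = {
        "llama-finetune",
        "-opt", "sgd", "-lr", "1e-4", "-lr-min", "1e-5", "-min-epochs", "2",
        "-epochs", "4", "-wd", "1e-9", "-val", "0.05", "-o", "finetuned.gguf",
    };
    common_params params;
    // The flags land in params.optimizer, params.lr, params.val_split and params.out_file;
    // params.lr.init() is called inside the parser (see the hunk above).
    if (!common_params_parse((int) (sizeof(demo_argv) / sizeof(demo_argv[0])),
                             const_cast<char **>(demo_argv), params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }
    return 0;
}
```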
38 changes: 38 additions & 0 deletions common/common.cpp
@@ -1539,3 +1539,41 @@ ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std

return result;
}

ggml_opt_optimizer_params common_opt_lr_pars(void * userdata) {
ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(nullptr);
const lr_opt & d = *(lr_opt *) userdata;
result.adamw.alpha = result.sgd.alpha = d.get_lr(d.epoch);
result.sgd.wd = result.adamw.wd = d.wd;
return result;
}

GGML_API enum ggml_opt_optimizer_type common_opt_get_optimizer(const char * n) {
if (!strcasecmp("adamw", n)) {
return GGML_OPT_OPTIMIZER_TYPE_ADAMW;
} else if (!strcasecmp("sgd", n)) {
return GGML_OPT_OPTIMIZER_TYPE_SGD;
} else {
return GGML_OPT_OPTIMIZER_TYPE_COUNT;
}
}

void lr_opt::init() {
if (lr_min > 0 && lr_min < lr0) {
float nhalf = std::log(lr0 / lr_min) / k_log_2;
float e = epochs;
if (min_epochs > 0 && min_epochs < e)
e = min_epochs;
else
min_epochs = e;
scale_epoch = nhalf / e;
}
}

float lr_opt::get_lr(float epoch) const {
float r = lr_min <= 0 ? lr0 :
epoch >= min_epochs ? lr_min :
lr0 * std::pow(.5, epoch * scale_epoch);
LOG_INF("epoch %.2g lr=%.2g\n", epoch, r);
return r;
}
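Put differently, `init()` precomputes a per-epoch half-life so that `get_lr()` interpolates geometrically between `lr0` and `lr_min`. A sketch of the implied schedule (the concrete numbers are illustrative, not defaults):

$$
\text{scale\_epoch} = \frac{\log_2(lr_0 / lr_{\min})}{\text{min\_epochs}}, \qquad
lr(e) = lr_0 \cdot 2^{-e \cdot \text{scale\_epoch}} = lr_0 \left(\frac{lr_{\min}}{lr_0}\right)^{e / \text{min\_epochs}}
$$

For example, with lr0 = 1e-4, lr_min = 1e-5 and min_epochs = 4: scale_epoch = log2(10)/4 ≈ 0.83, so lr(2) ≈ 3.2e-5 and lr(4) = 1e-5. Epochs past min_epochs stay clamped at lr_min, and if lr_min <= 0 the rate is simply constant at lr0.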
39 changes: 36 additions & 3 deletions common/common.h
@@ -2,13 +2,15 @@

#pragma once

#include "llama-cpp.h"

#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>
#include <sstream>
#include <cmath>

#include "ggml-opt.h"
#include "llama-cpp.h"

#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
@@ -80,6 +82,7 @@ enum llama_example {
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_FINETUNE,

LLAMA_EXAMPLE_COUNT,
};
@@ -219,6 +222,27 @@ enum common_reasoning_format {
COMMON_REASONING_FORMAT_DEEPSEEK, // Extract thinking tag contents and return as `message.reasoning_content`, including in streaming deltas.
};


static const float k_log_2 = std::log(2.f);

struct lr_opt {
float lr0 = 1e-5; // learning rate at first epoch
float lr_min = -1;
float min_epochs = -1; // if >0, constant (lr_min) after this many epochs
float scale_epoch = 0;
float wd = 0;
unsigned epochs = 2;

unsigned epoch; // set by optimizer outer (epochs) loop
// learning rate decay - constant LR per epoch only for now
float get_lr(float e) const;
float get_lr() const { return get_lr(epoch); }
// must call after arg parse, before get_lr
void init();
};

struct ggml_opt_optimizer_params common_opt_lr_pars(void * userdata);

struct common_params {
int32_t n_predict = -1; // new tokens to predict
int32_t n_ctx = 4096; // context size
@@ -350,6 +374,12 @@ struct common_params {
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s)

// finetune
struct lr_opt lr;
enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
float val_split = 0.05f; // fraction of the data used for the validation set
std::string opt_save_model_to = "finetuned-model.gguf";

// embedding
bool embedding = false; // get only sentence embedding
int32_t embd_normalize = 2; // normalisation for embeddings (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
@@ -670,3 +700,6 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
//

ggml_opt_dataset_t common_opt_dataset_init(struct llama_context * ctx, const std::vector<llama_token> & tokens, int64_t stride);

// "adamw" or "sgd" (case insensitive)
enum ggml_opt_optimizer_type common_opt_get_optimizer(const char *);
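A short sketch of the `lr_opt` contract declared here, with made-up values: `init()` must run once after argument parsing, and the trainer's outer loop advances `epoch` before each `get_lr()` call.

```cpp
#include "common.h"  // lr_opt

// Illustrative only: driving lr_opt by hand; the values are examples, not recommendations.
static void lr_schedule_demo() {
    lr_opt lr;
    lr.lr0    = 1e-4f;  // -lr
    lr.lr_min = 1e-5f;  // -lr-min
    lr.epochs = 4;      // -epochs; min_epochs stays -1, so the decay spans the full run
    lr.init();          // precomputes scale_epoch; required before any get_lr() call

    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
        const float alpha = lr.get_lr();  // approx. 1e-4, 5.6e-5, 3.2e-5, 1.8e-5 over the four epochs
        (void) alpha;  // one training epoch would run here with this learning rate
    }
}
```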
69 changes: 36 additions & 33 deletions examples/training/finetune.cpp
@@ -1,29 +1,31 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>

#include "arg.h"
#include "common.h"
#include "llama.h"
#include "log.h"

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
# pragma warning(disable : 4244 4267) // possible loss of data
#endif



int main(int argc, char ** argv) {
common_params params;

params.escape = false;

if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_PERPLEXITY)) {
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
return 1;
}

if (params.use_mmap) {
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n", __func__);
LOG_INF("%s: force disabling memory mapping because it would result in-read-only pointers to the weights\n",
__func__);
params.use_mmap = false;
}
if (params.cache_type_k != GGML_TYPE_F32) {
@@ -38,11 +40,11 @@ int main(int argc, char ** argv) {
common_init();
llama_backend_init();
llama_numa_init(params.numa);

// load the model and apply lora adapter, if any
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
common_init_result llama_init = common_init_from_params(params);
llama_model_ptr & model = llama_init.model;
llama_context_ptr & ctx = llama_init.context;
auto pctx = ctx.get();

if (model == NULL) {
LOG_ERR("%s: unable to load model\n", __func__);
@@ -55,31 +57,32 @@ int main(int argc, char ** argv) {
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
}

constexpr float val_split = 0.05f;

std::vector<llama_token> tokens = common_tokenize(ctx.get(), params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get())/2);

struct ggml_opt_optimizer_params optimizer_params = ggml_opt_get_default_optimizer_params(nullptr);
optimizer_params.adamw.alpha = 1e-7f; // learning rate

struct llama_opt_params lopt_params {
/*n_ctx_train =*/ 0,
/*param_filter =*/ llama_opt_param_filter_all,
/*param_filter_ud =*/ nullptr,
/*get_opt_pars =*/ ggml_opt_get_constant_optimizer_params,
/*get_opt_pars_ud =*/ &optimizer_params,
std::vector<llama_token> tokens = common_tokenize(pctx, params.prompt, true);
ggml_opt_dataset_t dataset = common_opt_dataset_init(pctx, tokens, llama_n_ctx(pctx) / 2);

struct lr_opt & lr = params.lr;
LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.min_epochs,
(unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);

struct llama_opt_params lopt_params{
/*n_ctx_train =*/0,
/*param_filter =*/llama_opt_param_filter_all,
/*param_filter_ud =*/nullptr,
/*get_opt_pars =*/common_opt_lr_pars,
/*get_opt_pars_ud =*/&params.lr,
/*optimizer_type =*/params.optimizer,
};
llama_opt_init(ctx.get(), model.get(), lopt_params);
llama_opt_init(pctx, model.get(), lopt_params);

const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - val_split);
const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);

ggml_opt_result_t result_train = ggml_opt_result_init();
ggml_opt_result_t result_eval = ggml_opt_result_init();

for (int epoch = 0; epoch < 2; ++epoch) {
llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
llama_opt_epoch(pctx, dataset, result_train, result_eval, idata_split,
ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
fprintf(stderr, "\n");

ggml_opt_result_reset(result_train);
Expand All @@ -88,7 +91,7 @@ int main(int argc, char ** argv) {
ggml_opt_result_free(result_train);
ggml_opt_result_free(result_eval);

llama_model_save_to_file(model.get(), "finetuned-model.gguf");
llama_model_save_to_file(model.get(), params.out_file.c_str());

llama_backend_free();

33 changes: 26 additions & 7 deletions ggml/include/ggml-opt.h
@@ -74,16 +74,26 @@ extern "C" {
GGML_OPT_BUILD_TYPE_OPT = 30,
};

enum ggml_opt_optimizer_type {
GGML_OPT_OPTIMIZER_TYPE_ADAMW,
GGML_OPT_OPTIMIZER_TYPE_SGD,

GGML_OPT_OPTIMIZER_TYPE_COUNT
};

// parameters that control which optimizer is used and how said optimizer tries to find the minimal loss
struct ggml_opt_optimizer_params {
// AdamW optimizer parameters
struct {
float alpha; // learning rate
float beta1;
float beta2;
float eps; // epsilon for numerical stability
float wd; // weight decay for AdamW, use 0.0f to disable
float alpha; // learning rate
float beta1; // first AdamW momentum
float beta2; // second AdamW momentum
float eps; // epsilon for numerical stability
float wd; // weight decay - 0.0f to disable
} adamw;
struct {
float alpha; // learning rate
float wd; // weight decay
} sgd;
};

// callback to calculate optimizer parameters prior to a backward pass
@@ -113,7 +123,10 @@ extern "C" {
int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done

ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters
void * get_opt_pars_ud; // userdata for calculating optimizer parameters

// only GGML_OPT_OPTIMIZER_TYPE_ADAMW allocates m, v per parameter
enum ggml_opt_optimizer_type optimizer;
};

// get parameters for an optimization context with defaults set where possible
@@ -142,6 +155,10 @@ extern "C" {
// get the gradient accumulator for a node from the forward graph
GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node);

GGML_API enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t);

GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type);

// ====== Optimization Result ======

GGML_API ggml_opt_result_t ggml_opt_result_init(void);
@@ -226,12 +243,14 @@ extern "C" {
struct ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used
ggml_opt_dataset_t dataset, // dataset with data and optionally also labels
enum ggml_opt_loss_type loss_type, // loss to minimize
enum ggml_opt_optimizer_type optimizer, // sgd or adamw
ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t)
int64_t nepoch, // how many times the dataset should be iterated over
int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs
float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f)
bool silent); // whether or not info prints to stderr should be suppressed


#ifdef __cplusplus
}
#endif
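As a usage sketch of the extended callback API (not part of this diff): a caller-defined `ggml_opt_get_optimizer_params` callback can fill both branches, along the lines of `common_opt_lr_pars` in common.cpp above. The constants here are placeholders.

```cpp
#include "ggml-opt.h"

// Sketch only: a user-supplied optimizer-parameter callback with placeholder values.
static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
    const float lr = *(const float *) userdata;  // e.g. userdata points at a float owned by the caller
    struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(nullptr);
    // Only the branch matching the optimizer type the caller selected (the new `optimizer`
    // field / ggml_opt_fit argument) is read, so filling both keeps the callback reusable.
    p.adamw.alpha = p.sgd.alpha = lr;
    p.adamw.wd    = p.sgd.wd    = 0.0f;  // weight decay disabled
    return p;
}
```

Starting from `ggml_opt_get_default_optimizer_params` keeps the AdamW beta1/beta2/eps defaults intact, the same pattern `common_opt_lr_pars` uses.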