Quantization AWQ GEMM + GEMV (OpenNMT#1727)
* quantization awq gemm + gemv

* fix pipeline

* fix pipeline

* fix pipeline

* fix dequantize awq

* remove duplicated code
minhthuc2502 authored Jul 4, 2024
1 parent 451c27b commit 39f48f2
Showing 32 changed files with 1,964 additions and 77 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -191,6 +191,13 @@ set(SOURCES
src/ops/transpose.cc
src/ops/nccl_ops.cc
src/ops/nccl_ops_cpu.cc
src/ops/awq/dequantize.cc
src/ops/awq/dequantize_cpu.cc
src/ops/awq/gemm.cc
src/ops/awq/gemm_cpu.cc
src/ops/awq/gemv.cc
src/ops/awq/gemv_cpu.cc
src/ops/sum.cc
src/padder.cc
src/profiler.cc
src/random.cc
@@ -595,6 +602,9 @@ if (WITH_CUDA)
src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
src/ops/awq/gemm_gpu.cu
src/ops/awq/gemv_gpu.cu
src/ops/awq/dequantize_gpu.cu
)

set_source_files_properties(
2 changes: 1 addition & 1 deletion README.md
@@ -26,7 +26,7 @@ The project is production-oriented and comes with [backward compatibility guaran
## Key features

* **Fast and efficient execution on CPU and GPU**<br/>The execution [is significantly faster and requires less resources](#benchmarks) than general-purpose deep learning frameworks on supported models and tasks thanks to many advanced optimizations: layer fusion, padding removal, batch reordering, in-place operations, caching mechanism, etc.
* **Quantization and reduced precision**<br/>The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), and 8-bit integers (INT8).
* **Quantization and reduced precision**<br/>The model serialization and computation support weights with [reduced precision](https://opennmt.net/CTranslate2/quantization.html): 16-bit floating points (FP16), 16-bit brain floating points (BF16), 16-bit integers (INT16), 8-bit integers (INT8), and 4-bit AWQ quantization (INT4).
* **Multiple CPU architectures support**<br/>The project supports x86-64 and AArch64/ARM64 processors and integrates multiple backends that are optimized for these platforms: [Intel MKL](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html), [oneDNN](https://github.com/oneapi-src/oneDNN), [OpenBLAS](https://www.openblas.net/), [Ruy](https://github.com/google/ruy), and [Apple Accelerate](https://developer.apple.com/documentation/accelerate).
* **Automatic CPU detection and code dispatch**<br/>One binary can include multiple backends (e.g. Intel MKL and oneDNN) and instruction set architectures (e.g. AVX, AVX2) that are automatically selected at runtime based on the CPU information.
* **Parallel and asynchronous execution**<br/>Multiple batches can be processed in parallel and asynchronously using multiple GPUs or CPU cores.
19 changes: 19 additions & 0 deletions docs/quantization.md
@@ -6,6 +6,7 @@ Quantization is a technique that can reduce the model size and accelerate its ex
* 16-bit integers (INT16)
* 16-bit floating points (FP16)
* 16-bit brain floating points (BF16)
* 4-bit AWQ quantization

```{tip}
See the benchmark results in the main [README](https://github.com/OpenNMT/CTranslate2#benchmarks) to compare the performance and memory usage with and without quantization.
@@ -161,3 +162,21 @@ In this mode, all model weights are stored in half precision and all layers are
* NVIDIA GPU with Compute Capability >= 8.0

In this mode, all model weights are stored in BF16 and all layers are run with this type.

### 4-bit AWQ

The compute type is `int32_float16`.

**Supported on:**

* NVIDIA GPU with Compute Capability >= 7.5

In this mode, the model weights are stored in half precision and all layers are run in half precision. Other parameters, such as the quantization scales and zero points, are stored in ``int32``.

The model must first be quantized with AWQ, then converted to the CTranslate2 format. For example:

```bash
ct2-transformers-converter --model TheBloke/Llama-2-7B-AWQ --copy_files tokenizer.model --output_dir ct2_model
```
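The converter expects a checkpoint that has already been quantized with AWQ (such as `TheBloke/Llama-2-7B-AWQ` above). As a rough sketch of that prior step, assuming the third-party AutoAWQ package and a hypothetical source checkpoint and output directory; the `version` field selects between the GEMM and GEMV kernels, which correspond to the `AWQ_GEMM`/`AWQ_GEMV` quantization types introduced in this commit:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "meta-llama/Llama-2-7b-hf"  # hypothetical source checkpoint
quant_path = "llama-2-7b-awq"            # hypothetical output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Quantize with AWQ, then save the quantized checkpoint for conversion.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```

Once converted, the model is used like any other CTranslate2 model. A minimal generation sketch with the Python API, assuming a CUDA device and the `ct2_model` directory produced above:

```python
import ctranslate2
import transformers

generator = ctranslate2.Generator("ct2_model", device="cuda")
tokenizer = transformers.AutoTokenizer.from_pretrained("TheBloke/Llama-2-7B-AWQ")

prompt = "What is model quantization?"
tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

# Greedy generation with the converted AWQ model.
results = generator.generate_batch([tokens], max_length=64, sampling_topk=1)
print(tokenizer.decode(results[0].sequences_ids[0]))
```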
3 changes: 3 additions & 0 deletions include/ctranslate2/layers/common.h
@@ -138,16 +138,19 @@ namespace ctranslate2 {
const StorageView& _weight;
const StorageView* _bias;
const StorageView* _qscale;
const StorageView* _qzero;
const StorageView* _u8_shift_compensation;
StorageView _partial_weight;
StorageView _partial_bias;
StorageView _partial_qscale;
StorageView _partial_u8_shift_compensation;
const DataType _output_type;
const models::QUANTIZATION_TYPE _quant_method;
const bool _quantized_gemm;
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
const ops::ActivationType* _activation_type;
const bool _is_layer_out;
};

17 changes: 16 additions & 1 deletion include/ctranslate2/models/model.h
@@ -11,6 +11,12 @@
namespace ctranslate2 {
namespace models {

enum class QUANTIZATION_TYPE {
CT2,
AWQ_GEMM,
AWQ_GEMV
};

static const size_t current_binary_version = 6;

// Checks whether the provided path could contain a CTranslate2 model.
@@ -90,6 +96,14 @@ namespace ctranslate2 {
return _use_flash_attention;
}

QUANTIZATION_TYPE quant_method() const {
return _quant_method;
}

void set_quant_method(QUANTIZATION_TYPE type) {
_quant_method = type;
}

virtual bool use_global_int16_scale() const {
return true;
}
@@ -160,7 +174,7 @@ namespace ctranslate2 {

private:
void process_linear_weights();
void set_compute_type(ComputeType type, Device device, int device_index);
void set_compute_type(ComputeType type, Device device, int device_index, bool update_weight=true);
void ensure_dtype(const std::string& name,
StorageView& variable,
const DataType target_dtype);
@@ -177,6 +191,7 @@ namespace ctranslate2 {
std::unordered_map<std::string, std::shared_ptr<StorageView>> _variable_index;
bool _use_flash_attention = false;
bool _tensor_parallel = false;
QUANTIZATION_TYPE _quant_method = QUANTIZATION_TYPE::CT2;
};

template<>
26 changes: 26 additions & 0 deletions include/ctranslate2/ops/awq/dequantize_awq.h
@@ -0,0 +1,26 @@
#pragma once

#include "../op.h"

namespace ctranslate2 {
namespace ops {

class DequantizeAwq : public Op {
public:
DequantizeAwq();

void operator()(const StorageView& input,
const StorageView& scale,
const StorageView& zeros,
StorageView& output) const;

private:
template <Device D, typename InT, typename OutT>
void dequantize(const StorageView& input,
const StorageView& scale,
const StorageView& zeros,
StorageView& output) const;
};

}
}
27 changes: 27 additions & 0 deletions include/ctranslate2/ops/awq/gemm.h
@@ -0,0 +1,27 @@
#pragma once

#include "../activation.h"
#include "../gemm.h"

namespace ctranslate2 {
namespace ops {
class GemmAwq : public Gemm {
public:
using Gemm::Gemm;
void operator()(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c,
const StorageView* bias = nullptr) const;

private:
template <Device D, typename In, typename Out>
void compute(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c) const;
};
}
}
33 changes: 33 additions & 0 deletions include/ctranslate2/ops/awq/gemv.h
@@ -0,0 +1,33 @@
#pragma once

#include "../activation.h"
#include "../gemm.h"

namespace ctranslate2 {
namespace ops {
class GemvAwq : public Gemm {
public:
using Gemm::Gemm;
void operator()(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c,
const StorageView* bias = nullptr) const;

private:
template <Device D, typename In, typename Out>
void compute_gemv(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c) const;
template <Device D, typename In, typename Out>
void compute_gemv2(const StorageView& a,
const StorageView& b,
const StorageView& scale,
const StorageView& zero,
StorageView& c) const;
};
}
}
3 changes: 2 additions & 1 deletion include/ctranslate2/ops/gemm.h
@@ -39,6 +39,8 @@ namespace ctranslate2 {
const dim_t k,
const dim_t n,
const float alpha);
protected:
const ActivationType* _activation_type;

private:
float _alpha;
@@ -47,7 +49,6 @@
bool _trans_b;
bool _a_is_packed;
bool _b_is_packed;
const ActivationType* _activation_type;

template <Device D, typename In, typename Out>
void compute(const StorageView& a,
3 changes: 2 additions & 1 deletion include/ctranslate2/ops/mean.h
@@ -11,12 +11,13 @@ namespace ctranslate2 {

void operator()(const StorageView& input, StorageView& output) const override;

private:
protected:
template <Device D, typename T>
void compute(const StorageView& input,
const dim_t outer_size,
const dim_t axis_size,
const dim_t inner_size,
const bool get_sum,
StorageView& output) const;

const dim_t _axis;
4 changes: 4 additions & 0 deletions include/ctranslate2/ops/ops.h
@@ -39,3 +39,7 @@
#include "slide.h"
#include "nccl_ops.h"
#include "flash_attention.h"
#include "awq/gemm.h"
#include "awq/gemv.h"
#include "awq/dequantize_awq.h"
#include "sum.h"
17 changes: 17 additions & 0 deletions include/ctranslate2/ops/sum.h
@@ -0,0 +1,17 @@
#pragma once

#include "op.h"
#include "mean.h"

namespace ctranslate2 {
namespace ops {

class Sum : public Mean {
public:
Sum(const dim_t axis);

void operator()(const StorageView& input, StorageView& output) const override;
};

}
}