forked from OpenNMT/CTranslate2
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Quantzation AWQ GEMM + GEMV (OpenNMT#1727)
* quantzation awq gemm + gemv * fix pipeline * fix pipeline * fix pipeline * fix dequantize awq * remove duplicated code
- Loading branch information
1 parent
451c27b
commit 39f48f2
Showing
32 changed files
with
1,964 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
#pragma once | ||
|
||
#include "../op.h" | ||
|
||
namespace ctranslate2 { | ||
namespace ops { | ||
|
||
class DequantizeAwq : public Op { | ||
public: | ||
DequantizeAwq(); | ||
|
||
void operator()(const StorageView& input, | ||
const StorageView& scale, | ||
const StorageView& zeros, | ||
StorageView& output) const; | ||
|
||
private: | ||
template <Device D, typename InT, typename OutT> | ||
void dequantize(const StorageView& input, | ||
const StorageView& scale, | ||
const StorageView& zeros, | ||
StorageView& output) const; | ||
}; | ||
|
||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
#pragma once | ||
|
||
#include "../activation.h" | ||
#include "../gemm.h" | ||
|
||
namespace ctranslate2 { | ||
namespace ops { | ||
class GemmAwq : public Gemm { | ||
public: | ||
using Gemm::Gemm; | ||
void operator()(const StorageView& a, | ||
const StorageView& b, | ||
const StorageView& scale, | ||
const StorageView& zero, | ||
StorageView& c, | ||
const StorageView* bias = nullptr) const; | ||
|
||
private: | ||
template <Device D, typename In, typename Out> | ||
void compute(const StorageView& a, | ||
const StorageView& b, | ||
const StorageView& scale, | ||
const StorageView& zero, | ||
StorageView& c) const; | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
#pragma once | ||
|
||
#include "../activation.h" | ||
#include "../gemm.h" | ||
|
||
namespace ctranslate2 { | ||
namespace ops { | ||
class GemvAwq : public Gemm { | ||
public: | ||
using Gemm::Gemm; | ||
void operator()(const StorageView& a, | ||
const StorageView& b, | ||
const StorageView& scale, | ||
const StorageView& zero, | ||
StorageView& c, | ||
const StorageView* bias = nullptr) const; | ||
|
||
private: | ||
template <Device D, typename In, typename Out> | ||
void compute_gemv(const StorageView& a, | ||
const StorageView& b, | ||
const StorageView& scale, | ||
const StorageView& zero, | ||
StorageView& c) const; | ||
template <Device D, typename In, typename Out> | ||
void compute_gemv2(const StorageView& a, | ||
const StorageView& b, | ||
const StorageView& scale, | ||
const StorageView& zero, | ||
StorageView& c) const; | ||
}; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
#pragma once | ||
|
||
#include "op.h" | ||
#include "mean.h" | ||
|
||
namespace ctranslate2 { | ||
namespace ops { | ||
|
||
class Sum : public Mean { | ||
public: | ||
Sum(const dim_t axis); | ||
|
||
void operator()(const StorageView& input, StorageView& output) const override; | ||
}; | ||
|
||
} | ||
} |
Oops, something went wrong.