Add LeakyReLU activation function
zhongkaifu committed Sep 5, 2023
1 parent 54945f1 commit 89b9b25
Showing 13 changed files with 300 additions and 51 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -193,7 +193,7 @@ You can also keep all parameters into a json file and run Seq2SeqConsole.exe -Co
"ShuffleType": "NoPadding",
"Task": "Train",
"TooLongSequence": "Ignore",
"ActivateFunc": "Relu",
"ActivateFunc": "ReLU",
"LogVerbose": "Normal",
"TgtLang": "TGT",
"TrainCorpusPath": ".\\data\\train",
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/Options.cs
@@ -168,7 +168,7 @@ public class Options
[Arg("How to deal with too long sequence. It can be Ignore or Truncation", nameof(TooLongSequence))]
public TooLongSequence TooLongSequence = TooLongSequence.Ignore;

[Arg("Activate function used in the model. It can be Relu or Swish", nameof(ActivateFunc))]
[Arg("Activate function used in the model. It can be ReLU, SiLU and LeakyReLU", nameof(ActivateFunc))]
public ActivateFuncEnums ActivateFunc = ActivateFuncEnums.ReLU;

[Arg("The level of log to output", nameof(LogVerbose))]
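With this change, the new activation is selected through the existing ActivateFunc option, for example "ActivateFunc": "LeakyReLU" in a JSON config file. Below is a minimal, hypothetical C# sketch of the programmatic equivalent; ActivateFuncEnums and its Seq2SeqSharp.Utils namespace are confirmed by this diff, while the Options namespace and the use of plain enum parsing for the config string are assumptions.

// Illustrative sketch only, not part of the commit.
using System;
using Seq2SeqSharp.Applications;   // assumed namespace for Options
using Seq2SeqSharp.Utils;          // ActivateFuncEnums (shown in this diff)

static class ActivationConfigExample
{
    static void Main()
    {
        // Pick the new activation directly on the options object.
        var opts = new Options { ActivateFunc = ActivateFuncEnums.LeakyReLU };

        // A JSON config presumably carries the same value as the string "LeakyReLU",
        // which maps onto the enum with ordinary enum parsing.
        var parsed = Enum.Parse<ActivateFuncEnums>("LeakyReLU");

        Console.WriteLine(opts.ActivateFunc == parsed); // True
    }
}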
24 changes: 22 additions & 2 deletions Seq2SeqSharp/Layers/MoEFeedForward.cs
@@ -139,7 +139,9 @@ public IWeightTensor Process(IWeightTensor input, int batchSize, IComputeGraph g


tokenEmbs = gExp.Mul(tokenEmbs, m_Whd1_i);
tokenEmbs = ((m_activateFunc == ActivateFuncEnums.SiLU) ? gExp.SiLU(tokenEmbs) : gExp.Relu(tokenEmbs, inPlace: true));

tokenEmbs = RunActivateFunction(gExp, tokenEmbs);

tokenEmbs = gExp.Mul(tokenEmbs, m_Whd2_i);
tokenEmbs = g.EltMul(tokenEmbs, topValue_eI);

@@ -156,7 +158,25 @@ public IWeightTensor Process(IWeightTensor input, int batchSize, IComputeGraph g

return input;
}


private IWeightTensor RunActivateFunction(IComputeGraph gExp, IWeightTensor tokenEmbs)
{
if (m_activateFunc == ActivateFuncEnums.SiLU)
{
tokenEmbs = gExp.SiLU(tokenEmbs);
}
else if (m_activateFunc == ActivateFuncEnums.ReLU)
{
tokenEmbs = gExp.ReLU(tokenEmbs, inPlace: true);
}
else if (m_activateFunc == ActivateFuncEnums.LeakyReLU)
{
tokenEmbs = gExp.LeakyReLU(tokenEmbs, inPlace: true);
}

return tokenEmbs;
}

public virtual List<IWeightTensor> GetParams()
{
List<IWeightTensor> response = new List<IWeightTensor>();
22 changes: 21 additions & 1 deletion Seq2SeqSharp/Layers/PositionwiseFeedForward.cs
@@ -89,7 +89,8 @@ public IWeightTensor Process(IWeightTensor input, int batchSize, IComputeGraph g
//Feed forward
var ffnResult = feedForwardLayer1.Process(inputNorm, batchSize, g);
// Activate function
var actFFNResult = ((m_activateFunc == ActivateFuncEnums.SiLU) ? g.SiLU(ffnResult) : g.Relu(ffnResult, inPlace: true));
var actFFNResult = RunActivateFunction(g, ffnResult);

var ffn2Result = feedForwardLayer2.Process(actFFNResult, batchSize, g); // Shape: [batchSize * newTokenIdx, input_dim]

//Skip connection and layer normalization
@@ -119,6 +120,25 @@ public IWeightTensor Process(IWeightTensor input, int batchSize, IComputeGraph g

}

private IWeightTensor RunActivateFunction(IComputeGraph g, IWeightTensor tokenEmbs)
{
if (m_activateFunc == ActivateFuncEnums.SiLU)
{
tokenEmbs = g.SiLU(tokenEmbs);
}
else if (m_activateFunc == ActivateFuncEnums.ReLU)
{
tokenEmbs = g.ReLU(tokenEmbs, inPlace: true);
}
else if (m_activateFunc == ActivateFuncEnums.LeakyReLU)
{
tokenEmbs = g.LeakyReLU(tokenEmbs, inPlace: true);
}

return tokenEmbs;
}


public virtual List<IWeightTensor> GetParams()
{
List<IWeightTensor> response = new List<IWeightTensor>();
47 changes: 45 additions & 2 deletions Seq2SeqSharp/Tools/ComputeGraphTensor.cs
@@ -151,7 +151,7 @@ void backward()
if (m.NeedGradient)
{
res.ReleaseWeight();
Ops.AddSwishD(m.TGradient, m.TGradient, mTWeight, res.TGradient);
Ops.AddSiLUD(m.TGradient, m.TGradient, mTWeight, res.TGradient);

}
res.Dispose();
@@ -941,8 +941,51 @@ void backward()
return res;
}

public IWeightTensor LeakyReLU(IWeightTensor w, bool inPlace = false)
{
WeightTensor m = w as WeightTensor;
WeightTensor res = null;
if (inPlace)
{
res = m.CopyWeightsRef($"{GetHashString(w.Name)}.LeakyReLU", needGradient: m.NeedGradient, graphToBind: this);
}
else
{
res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.LeakyReLU", graphToBind: this, needGradient: m.NeedGradient, dtype: m.ElementType);
}
VisualizeNodes(w, res);


Ops.LeakyReLU(res.TWeight, m.TWeight);
if (m_needsBackprop)
{
Tensor mTWeight = m.TWeight.CopyRef();
void backward()
{
if (m.NeedGradient)
{
res.ReleaseWeight();

if (inPlace && m.IsGradientNull() && res.TGradient.IsOwnerExclusive())
{
m.TGradient = res.TGradient.CopyRef();
Ops.LeakyReLUD(m.TGradient, mTWeight, m.TGradient);
}
else
{
Ops.AddLeakyReLUD(m.TGradient, m.TGradient, mTWeight, res.TGradient);
}
}
mTWeight.Dispose();
res.Dispose();
}
m_backprop.Add(backward);
}

return res;
}

public IWeightTensor Relu(IWeightTensor w, bool inPlace = false)
public IWeightTensor ReLU(IWeightTensor w, bool inPlace = false)
{
WeightTensor m = w as WeightTensor;
WeightTensor res = null;
3 changes: 2 additions & 1 deletion Seq2SeqSharp/Tools/IComputeGraph.cs
@@ -28,8 +28,9 @@ public interface IComputeGraph : IDisposable
IWeightTensor Add(IWeightTensor w1, float v);
IWeightTensor Tanh(IWeightTensor w);
IWeightTensor Sigmoid(IWeightTensor w);
IWeightTensor Relu(IWeightTensor w, bool inPlace = false);
IWeightTensor ReLU(IWeightTensor w, bool inPlace = false);
IWeightTensor SiLU(IWeightTensor w);
IWeightTensor LeakyReLU(IWeightTensor w, bool inPlace = false);

IWeightTensor Affine(IWeightTensor m1, IWeightTensor m2, IWeightTensor mbias, float alpha = 1.0f);
IWeightTensor EltMulMulAdd(IWeightTensor w1, IWeightTensor w2, IWeightTensor w3, IWeightTensor w4);
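To show how the new graph operation is meant to be used, here is a small, hedged sketch of a feed-forward step that mirrors the call pattern in the layers above; IComputeGraph, IWeightTensor, Mul and LeakyReLU appear in this diff, while the Seq2SeqSharp.Tools namespace, the helper class and the weight names are illustrative assumptions.

// Illustrative sketch only, not part of the commit.
using Seq2SeqSharp.Tools;   // assumed namespace for IComputeGraph / IWeightTensor

static class LeakyReluUsageSketch
{
    public static IWeightTensor FeedForward(IComputeGraph g,
                                            IWeightTensor input,
                                            IWeightTensor w1,
                                            IWeightTensor w2)
    {
        // First projection.
        var hidden = g.Mul(input, w1);

        // inPlace: true reuses the activation buffer, matching how the
        // ReLU/LeakyReLU branches are invoked in the layers above.
        hidden = g.LeakyReLU(hidden, inPlace: true);

        // Second projection.
        return g.Mul(hidden, w2);
    }
}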
3 changes: 2 additions & 1 deletion Seq2SeqSharp/Utils/ActivateFuncEnums.cs
@@ -13,7 +13,8 @@ namespace Seq2SeqSharp.Utils
public enum ActivateFuncEnums
{
ReLU,
SiLU
SiLU,
LeakyReLU

}
}
27 changes: 21 additions & 6 deletions TensorSharp.CUDA/CudaBasicOps.cs
@@ -421,14 +421,29 @@ public static Tensor AddmmBatch(Tensor result, float beta, Tensor src, float alp



[RegisterOpStorageType("Swish", typeof(CudaStorage))]
public Tensor Swish(Tensor result, Tensor src) { return ElementwiseTTOp.Invoke(elementwiseActKernels, "Swish", result, src); }

[RegisterOpStorageType("SwishD", typeof(CudaStorage))]
public Tensor SwishD(Tensor result, Tensor srcW, Tensor resG) { return ElementwiseTTTOp.Invoke(elementwiseActKernels, "SwishD", result, srcW, resG); }
[RegisterOpStorageType("LeakyReLU", typeof(CudaStorage))]
public Tensor LeakyReLU(Tensor result, Tensor src) { return ElementwiseTTOp.Invoke(elementwiseActKernels, "LeakyReLU", result, src); }

[RegisterOpStorageType("AddSwishD", typeof(CudaStorage))]
public Tensor AddSwishD(Tensor result, Tensor srcG, Tensor srcW, Tensor resG) { return ElementwiseTTTTOp.Invoke(elementwiseActKernels, "AddSwishD", result, srcG, srcW, resG); }
[RegisterOpStorageType("LeakyReLUD", typeof(CudaStorage))]
public Tensor LeakyReLUD(Tensor result, Tensor w, Tensor g) { return ElementwiseTTTOp.Invoke(elementwiseActKernels, "LeakyReLUD", result, w, g); }

[RegisterOpStorageType("AddLeakyReLUD", typeof(CudaStorage))]
public Tensor AddLeakyReLUD(Tensor result, Tensor t, Tensor w, Tensor g) { return ElementwiseTTTTOp.Invoke(elementwiseActKernels, "AddLeakyReLUD", result, t, w, g); }






[RegisterOpStorageType("SiLU", typeof(CudaStorage))]
public Tensor SiLU(Tensor result, Tensor src) { return ElementwiseTTOp.Invoke(elementwiseActKernels, "SiLU", result, src); }

[RegisterOpStorageType("SiLUD", typeof(CudaStorage))]
public Tensor SiLUD(Tensor result, Tensor srcW, Tensor resG) { return ElementwiseTTTOp.Invoke(elementwiseActKernels, "SiLUD", result, srcW, resG); }

[RegisterOpStorageType("AddSiLUD", typeof(CudaStorage))]
public Tensor AddSiLUD(Tensor result, Tensor srcG, Tensor srcW, Tensor resG) { return ElementwiseTTTTOp.Invoke(elementwiseActKernels, "AddSiLUD", result, srcG, srcW, resG); }



22 changes: 16 additions & 6 deletions TensorSharp.CUDA/DeviceCode/ElementwiseActKernels.cs
@@ -32,9 +32,14 @@ private static string GetFullCode()
AppendTTTFunc(result, "relud", "relud");
AppendTTTTFunc(result, "addrelud", "addrelud");

AppendTTFunc(result, "Swish", "Swish");
AppendTTTFunc(result, "SwishD", "SwishD");
AppendTTTTFunc(result, "AddSwishD", "AddSwishD");
AppendTTFunc(result, "SiLU", "SiLU");
AppendTTTFunc(result, "SiLUD", "SiLUD");
AppendTTTTFunc(result, "AddSiLUD", "AddSiLUD");


AppendTTFunc(result, "LeakyReLU", "LeakyReLU");
AppendTTTFunc(result, "LeakyReLUD", "LeakyReLUD");
AppendTTTTFunc(result, "AddLeakyReLUD", "AddLeakyReLUD");

if (TSCudaContext.ElementType == DType.Float16)
{
@@ -44,9 +49,14 @@ private static string GetFullCode()
AppendTTTFunc(result, "relud", "relud", DType.Float16);
AppendTTTTFunc(result, "addrelud", "addreludhalf", DType.Float16);

AppendTTFunc(result, "Swish", "SwishHalf", DType.Float16);
AppendTTTFunc(result, "SwishD", "SwishDHalf", DType.Float16);
AppendTTTTFunc(result, "AddSwishD", "AddSwishDHalf", DType.Float16);
AppendTTFunc(result, "SiLU", "SiLUHalf", DType.Float16);
AppendTTTFunc(result, "SiLUD", "SiLUDHalf", DType.Float16);
AppendTTTTFunc(result, "AddSiLUD", "AddSiLUDHalf", DType.Float16);


AppendTTFunc(result, "LeakyReLU", "LeakyReLUHalf", DType.Float16);
AppendTTTFunc(result, "LeakyReLUD", "LeakyReLUDHalf", DType.Float16);
AppendTTTTFunc(result, "AddLeakyReLUD", "AddLeakyReLUDHalf", DType.Float16);
}

return result.ToString();
58 changes: 52 additions & 6 deletions TensorSharp.CUDA/DeviceCode/Headers/Math.cs
@@ -69,18 +69,18 @@ template<typename T> INLINE_FUNC T Lerp(T a, T b, T weight) {
}
template<typename T> INLINE_FUNC T Swish(T w) {
template<typename T> INLINE_FUNC T SiLU(T w) {
return w / (T(1) + expf(-w));
}
template<typename T> INLINE_FUNC T SwishD(T w, T resG) {
template<typename T> INLINE_FUNC T SiLUD(T w, T resG) {
T sig = T(1) / (T(1) + expf(-w));
T grad = sig * (T(1) + w * (T(1) - sig));
return resG * grad;
}
template<typename T> INLINE_FUNC T AddSwishD(T t, T w, T resG) {
template<typename T> INLINE_FUNC T AddSiLUD(T t, T w, T resG) {
T sig = T(1) / (T(1) + expf(-w));
T grad = sig * (T(1) + w * (T(1) - sig));
Expand Down Expand Up @@ -151,6 +151,29 @@ template<typename T> INLINE_FUNC T AddTanh3(T x, T y, T z) {
return t;
}
template <typename T> INLINE_FUNC T LeakyReLU(T w) {
if (w < T(0))
return T(0.01) * w;
return w;
}
template <typename T> INLINE_FUNC T LeakyReLUD(T w, T g) {
if (w >= T(0))
return g;
return T(0.01) * g;
}
template <typename T> INLINE_FUNC T AddLeakyReLUD(T t, T w, T g) {
if (w >= T(0))
return t + g;
return t + T(0.01) * g;
}
template <typename T> INLINE_FUNC T Clamp(T val, T min, T max) {
if (val < min)
return min;
@@ -171,13 +194,13 @@ template<typename T> INLINE_FUNC T AddTanh3(T x, T y, T z) {
public const string Code16 = @"
#include <cuda_fp16.h>
template<typename T> INLINE_FUNC T SwishHalf(T wh) {
template<typename T> INLINE_FUNC T SiLUHalf(T wh) {
float w = __half2float(wh);
float res = w / (1.0 + expf(-w));
return __float2half(res);
}
template<typename T> INLINE_FUNC T SwishDHalf(T wh, T resGh) {
template<typename T> INLINE_FUNC T SiLUDHalf(T wh, T resGh) {
float w = __half2float(wh);
float resG = __half2float(resGh);
Expand All @@ -187,7 +210,7 @@ template<typename T> INLINE_FUNC T SwishDHalf(T wh, T resGh) {
return __float2half(resG * grad);
}
template<typename T> INLINE_FUNC T AddSwishDHalf(T th, T wh, T resGh) {
template<typename T> INLINE_FUNC T AddSiLUDHalf(T th, T wh, T resGh) {
float t = __half2float(th);
float w = __half2float(wh);
@@ -205,6 +228,29 @@ template<typename T> INLINE_FUNC T AddSwishDHalf(T th, T wh, T resGh) {
return t;
}
template <typename T> INLINE_FUNC T LeakyReLUHalf(T w) {
if (w < T(0))
return __hmul(T(0.01), w);
return w;
}
template <typename T> INLINE_FUNC T LeakyReLUDHalf(T w, T g) {
if (w >= T(0))
return g;
return __hmul(T(0.01), g);
}
template <typename T> INLINE_FUNC T AddLeakyReLUDHalf(T t, T w, T g) {
if (w >= T(0))
return __hadd(t, g);
return __hadd(t, __hmul(T(0.01), g));
}
";
}
}
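For reference, the kernels above implement the standard leaky ReLU with the slope hard-coded to 0.01. Written out (matching the templates above, not an addition to the commit), the forward function and the derivative applied by LeakyReLUD, with AddLeakyReLUD accumulating into the incoming gradient t, are:

\[
\mathrm{LeakyReLU}(w) =
\begin{cases}
w, & w \ge 0 \\
0.01\,w, & w < 0
\end{cases}
\qquad
\mathrm{LeakyReLU}'(w) =
\begin{cases}
1, & w \ge 0 \\
0.01, & w < 0
\end{cases}
\]

\[
\mathrm{LeakyReLUD}(w, g) = g \cdot \mathrm{LeakyReLU}'(w),
\qquad
\mathrm{AddLeakyReLUD}(t, w, g) = t + g \cdot \mathrm{LeakyReLU}'(w).
\]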