Skip to content

Commit

Permalink
Add an inPlace option to the SiLU activation; it takes effect only when the graph runs forward-only (no back-propagation).
Browse files (browse the repository at this point in the history)
  • Loading branch information
zhongkaifu committed Sep 7, 2023
1 parent 89b9b25 commit f1d5a6d
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Layers/MoEFeedForward.cs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ private IWeightTensor RunActivateFunction(IComputeGraph gExp, IWeightTensor toke
{
if (m_activateFunc == ActivateFuncEnums.SiLU)
{
tokenEmbs = gExp.SiLU(tokenEmbs);
tokenEmbs = gExp.SiLU(tokenEmbs, inPlace: true);
}
else if (m_activateFunc == ActivateFuncEnums.ReLU)
{
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Layers/PositionwiseFeedForward.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ private IWeightTensor RunActivateFunction(IComputeGraph g, IWeightTensor tokenEm
{
if (m_activateFunc == ActivateFuncEnums.SiLU)
{
tokenEmbs = g.SiLU(tokenEmbs);
tokenEmbs = g.SiLU(tokenEmbs, inPlace: true);
}
else if (m_activateFunc == ActivateFuncEnums.ReLU)
{
Expand Down
14 changes: 12 additions & 2 deletions Seq2SeqSharp/Tools/ComputeGraphTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,20 @@ void backward()
}


public IWeightTensor SiLU(IWeightTensor w)
public IWeightTensor SiLU(IWeightTensor w, bool inPlace = false)
{
WeightTensor m = w as WeightTensor;
WeightTensor res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.SiLU", graphToBind: this, needGradient: m.NeedGradient, dtype: m.ElementType);

WeightTensor res = null;
// We only enable in-place when we don't need to run back-prop
if (inPlace && !m_needsBackprop)
{
res = m.CopyWeightsRef($"{GetHashString(w.Name)}.SiLU", needGradient: m.NeedGradient, graphToBind: this);
}
else
{
res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.SiLU", graphToBind: this, needGradient: m.NeedGradient, dtype: m.ElementType);
}

VisualizeNodes(w, res);
Ops.SiLU(res.TWeight, m.TWeight);
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Tools/IComputeGraph.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public interface IComputeGraph : IDisposable
IWeightTensor Tanh(IWeightTensor w);
IWeightTensor Sigmoid(IWeightTensor w);
IWeightTensor ReLU(IWeightTensor w, bool inPlace = false);
IWeightTensor SiLU(IWeightTensor w);
IWeightTensor SiLU(IWeightTensor w, bool inPlace = false);
IWeightTensor LeakyReLU(IWeightTensor w, bool inPlace = false);

IWeightTensor Affine(IWeightTensor m1, IWeightTensor m2, IWeightTensor mbias, float alpha = 1.0f);
Expand Down

0 comments on commit f1d5a6d

Please sign in to comment.