From f1d5a6d67f88f0fd8ff8c02bf8c103d548adb200 Mon Sep 17 00:00:00 2001
From: Zhongkai Fu
Date: Wed, 6 Sep 2023 20:27:14 -0700
Subject: [PATCH] Add inPlace for SiLU activation when it runs forward only.

---
 Seq2SeqSharp/Layers/MoEFeedForward.cs          |  2 +-
 Seq2SeqSharp/Layers/PositionwiseFeedForward.cs |  2 +-
 Seq2SeqSharp/Tools/ComputeGraphTensor.cs       | 14 ++++++++++++--
 Seq2SeqSharp/Tools/IComputeGraph.cs            |  2 +-
 4 files changed, 15 insertions(+), 5 deletions(-)

diff --git a/Seq2SeqSharp/Layers/MoEFeedForward.cs b/Seq2SeqSharp/Layers/MoEFeedForward.cs
index 014d8b8a..cb7503f2 100644
--- a/Seq2SeqSharp/Layers/MoEFeedForward.cs
+++ b/Seq2SeqSharp/Layers/MoEFeedForward.cs
@@ -163,7 +163,7 @@ private IWeightTensor RunActivateFunction(IComputeGraph gExp, IWeightTensor toke
         {
             if (m_activateFunc == ActivateFuncEnums.SiLU)
             {
-                tokenEmbs = gExp.SiLU(tokenEmbs);
+                tokenEmbs = gExp.SiLU(tokenEmbs, inPlace: true);
             }
             else if (m_activateFunc == ActivateFuncEnums.ReLU)
             {
diff --git a/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs b/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs
index affb7938..bd39ae35 100644
--- a/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs
+++ b/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs
@@ -124,7 +124,7 @@ private IWeightTensor RunActivateFunction(IComputeGraph g, IWeightTensor tokenEm
         {
             if (m_activateFunc == ActivateFuncEnums.SiLU)
             {
-                tokenEmbs = g.SiLU(tokenEmbs);
+                tokenEmbs = g.SiLU(tokenEmbs, inPlace: true);
             }
             else if (m_activateFunc == ActivateFuncEnums.ReLU)
             {
diff --git a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
index 47140921..c38f0c3d 100644
--- a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
+++ b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs
@@ -136,10 +136,20 @@ void backward()
         }
 
-        public IWeightTensor SiLU(IWeightTensor w)
+        public IWeightTensor SiLU(IWeightTensor w, bool inPlace = false)
         {
             WeightTensor m = w as WeightTensor;
-            WeightTensor res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.SiLU", graphToBind: this, needGradient: m.NeedGradient, dtype: m.ElementType);
+
+            WeightTensor res = null;
+            // We only enable in-place when we don't need to run back-prop
+            if (inPlace && !m_needsBackprop)
+            {
+                res = m.CopyWeightsRef($"{GetHashString(w.Name)}.SiLU", needGradient: m.NeedGradient, graphToBind: this);
+            }
+            else
+            {
+                res = m_weightTensorFactory.CreateWeightTensor(m.Sizes, m_deviceId, name: $"{GetHashString(w.Name)}.SiLU", graphToBind: this, needGradient: m.NeedGradient, dtype: m.ElementType);
+            }
 
             VisualizeNodes(w, res);
 
             Ops.SiLU(res.TWeight, m.TWeight);
diff --git a/Seq2SeqSharp/Tools/IComputeGraph.cs b/Seq2SeqSharp/Tools/IComputeGraph.cs
index 9a39dccf..0272b5a5 100644
--- a/Seq2SeqSharp/Tools/IComputeGraph.cs
+++ b/Seq2SeqSharp/Tools/IComputeGraph.cs
@@ -29,7 +29,7 @@ public interface IComputeGraph : IDisposable
         IWeightTensor Tanh(IWeightTensor w);
         IWeightTensor Sigmoid(IWeightTensor w);
         IWeightTensor ReLU(IWeightTensor w, bool inPlace = false);
-        IWeightTensor SiLU(IWeightTensor w);
+        IWeightTensor SiLU(IWeightTensor w, bool inPlace = false);
        IWeightTensor LeakyReLU(IWeightTensor w, bool inPlace = false);
 
         IWeightTensor Affine(IWeightTensor m1, IWeightTensor m2, IWeightTensor mbias, float alpha = 1.0f);
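
Note (not part of the patch): the change above lets SiLU reuse the input tensor's buffer when the graph is built forward-only, falling back to a fresh allocation whenever back-propagation is enabled, since in-place overwriting would destroy values needed for gradients. The patch routes the in-place case through CopyWeightsRef, which appears to return a tensor bound to the same underlying storage. Below is a minimal, standalone C# sketch of that pattern; it is not Seq2SeqSharp code, and every name in it (TinyTensor, TinyGraph, m_needsBackprop as a constructor flag) is hypothetical and used only to illustrate the idea.

// Standalone illustration of "in-place activation when forward-only".
// All types and names here are assumptions, not the library's API.
using System;

class TinyTensor
{
    public float[] Data;
    public TinyTensor(int n) { Data = new float[n]; }
}

class TinyGraph
{
    private readonly bool m_needsBackprop;
    public TinyGraph(bool needsBackprop) { m_needsBackprop = needsBackprop; }

    public TinyTensor SiLU(TinyTensor w, bool inPlace = false)
    {
        // Reuse the input storage only when the caller asks for it AND the
        // original values will never be needed for back-propagation.
        TinyTensor res = (inPlace && !m_needsBackprop) ? w : new TinyTensor(w.Data.Length);

        for (int i = 0; i < w.Data.Length; i++)
        {
            float x = w.Data[i];
            res.Data[i] = x / (1.0f + (float)Math.Exp(-x));   // SiLU(x) = x * sigmoid(x)
        }
        return res;
    }
}

class Program
{
    static void Main()
    {
        var g = new TinyGraph(needsBackprop: false);          // forward-only, e.g. inference
        var t = new TinyTensor(4) { Data = new float[] { -2f, -1f, 1f, 2f } };
        var y = g.SiLU(t, inPlace: true);
        Console.WriteLine(ReferenceEquals(t, y));             // True: the buffer was reused
    }
}

In training (needsBackprop == true) the same call silently takes the allocating branch, which mirrors how the patch keeps the existing CreateWeightTensor path as the default.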