From caff84b57f6163a28c847a5b06eabcfd46c016f2 Mon Sep 17 00:00:00 2001 From: Zhongkai Fu Date: Sat, 16 Sep 2023 13:49:20 -0700 Subject: [PATCH] Add RMSNorm --- README.md | 5 +- Seq2SeqSharp/Applications/Decoder.cs | 4 +- Seq2SeqSharp/Applications/Encoder.cs | 2 +- Seq2SeqSharp/Applications/Options.cs | 3 + .../Applications/SeqClassification.cs | 2 +- Seq2SeqSharp/Applications/SeqLabel.cs | 2 +- Seq2SeqSharp/Applications/SeqSimilarity.cs | 2 +- Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs | 4 +- Seq2SeqSharp/Layers/AttentionUnit.cs | 8 +- Seq2SeqSharp/Layers/FeedForwardLayer.cs | 2 +- Seq2SeqSharp/Layers/INormalization.cs | 9 + .../Layers/LSTMAttentionDecoderCell.cs | 2 +- Seq2SeqSharp/Layers/LSTMCell.cs | 2 +- Seq2SeqSharp/Layers/LayerNormalization.cs | 13 +- Seq2SeqSharp/Layers/MoEFeedForward.cs | 6 +- Seq2SeqSharp/Layers/MultiHeadAttention.cs | 28 +- .../Layers/PositionwiseFeedForward.cs | 17 +- Seq2SeqSharp/Layers/RMSNormalization.cs | 73 ++++ Seq2SeqSharp/Models/IModel.cs | 1 + Seq2SeqSharp/Models/Model.cs | 3 + .../Models/Model_4_ProtoBufSerializer.cs | 2 + Seq2SeqSharp/Networks/GPTDecoder.cs | 28 +- Seq2SeqSharp/Networks/TransformerDecoder.cs | 32 +- Seq2SeqSharp/Networks/TransformerEncoder.cs | 26 +- Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs | 8 +- Seq2SeqSharp/Tools/ComputeGraphTensor.cs | 34 ++ Seq2SeqSharp/Tools/IComputeGraph.cs | 1 + Seq2SeqSharp/Tools/IWeightFactory.cs | 2 +- Seq2SeqSharp/Tools/WeightTensor.cs | 14 +- Seq2SeqSharp/Tools/WeightTensorFactory.cs | 8 +- Seq2SeqSharp/Utils/Misc.cs | 2 +- Seq2SeqSharp/Utils/ModeEnums.cs | 6 + TensorSharp.CUDA/CudaBasicOps.cs | 9 + TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs | 398 ++++++++++++++++++ TensorSharp/Cpu/CpuBasicOps.cs | 25 +- TensorSharp/Ops.cs | 5 + TensorSharp/TensorApplyCPU.cs | 180 +++++++- Tools/Seq2SeqConsole/Program.cs | 2 +- 38 files changed, 887 insertions(+), 83 deletions(-) create mode 100644 Seq2SeqSharp/Layers/INormalization.cs create mode 100644 Seq2SeqSharp/Layers/RMSNormalization.cs diff --git a/README.md b/README.md index 827aa8b0..488fa1ed 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ Mixture of Experts network that could easily train huge model with less computin Support Automatic Mixed Precesion (FP16) Built-in SentencePiece supported Rotary Positional Embeddings +Layer Norm and RMS Norm Python package supported Tags embeddings mechanism Prompted Decoders @@ -193,7 +194,9 @@ You can also keep all parameters into a json file and run Seq2SeqConsole.exe -Co "ShuffleType": "NoPadding", "Task": "Train", "TooLongSequence": "Ignore", - "ActivateFunc": "ReLU", + "ActivateFunc": "LeakyReLU", + "PEType": "RoPE", + "NormType": "LayerNorm", "LogVerbose": "Normal", "TgtLang": "TGT", "TrainCorpusPath": ".\\data\\train", diff --git a/Seq2SeqSharp/Applications/Decoder.cs b/Seq2SeqSharp/Applications/Decoder.cs index dd378976..806a29e6 100644 --- a/Seq2SeqSharp/Applications/Decoder.cs +++ b/Seq2SeqSharp/Applications/Decoder.cs @@ -37,14 +37,14 @@ public static MultiProcessorNetworkWrapper CreateDecoders(IModel model decoder = new MultiProcessorNetworkWrapper( new GPTDecoder("GPTDecoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.DecoderEmbeddingDim, model.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(), isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, - expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: 
elementType, peType:model.PEType), raDeviceIds.ToArray()); + expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType, normType: model.NormType), raDeviceIds.ToArray()); } else { decoder = new MultiProcessorNetworkWrapper( new TransformerDecoder("TransformerDecoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.DecoderEmbeddingDim, model.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(), isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, - expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType), raDeviceIds.ToArray()); + expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType, normType: model.NormType), raDeviceIds.ToArray()); } return decoder; diff --git a/Seq2SeqSharp/Applications/Encoder.cs b/Seq2SeqSharp/Applications/Encoder.cs index 348c4dfb..1a3b7640 100644 --- a/Seq2SeqSharp/Applications/Encoder.cs +++ b/Seq2SeqSharp/Applications/Encoder.cs @@ -52,7 +52,7 @@ public static MultiProcessorNetworkWrapper CreateEncoders(IModel model encoder = new MultiProcessorNetworkWrapper( new TransformerEncoder("TransformerEncoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.EncoderEmbeddingDim, model.EncoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(), isTrainable: options.IsEncoderTrainable, learningRateFactor: options.EncoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, expertsPerTokenFactor: model.ExpertsPerTokenFactor, - elementType, peType: model.PEType), raDeviceIds.ToArray()); + elementType, peType: model.PEType, normType: model.NormType), raDeviceIds.ToArray()); } return encoder; diff --git a/Seq2SeqSharp/Applications/Options.cs b/Seq2SeqSharp/Applications/Options.cs index bf191424..6cba174e 100644 --- a/Seq2SeqSharp/Applications/Options.cs +++ b/Seq2SeqSharp/Applications/Options.cs @@ -246,6 +246,9 @@ public class Options [Arg("The Positional Embeddings Type. It supports APE, NoPE and RoPE", nameof(PEType))] public PositionEmbeddingEnums PEType = PositionEmbeddingEnums.APE; + [Arg("The type of normalization. It supports LayerNorm and RMSNorm", nameof(NormType))] + public NormEnums NormType = NormEnums.LayerNorm; + public void ValidateOptions() { if (AMP == true && ProcessorType != ProcessorTypeEnums.GPU) diff --git a/Seq2SeqSharp/Applications/SeqClassification.cs b/Seq2SeqSharp/Applications/SeqClassification.cs index 67f5a73c..2f7d5e23 100644 --- a/Seq2SeqSharp/Applications/SeqClassification.cs +++ b/Seq2SeqSharp/Applications/SeqClassification.cs @@ -83,7 +83,7 @@ private bool CreateTrainableParameters(IModel model) (m_positionalEmbeddings, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, m_options.MaxSentLength, model, createAPE: (model.PEType == PositionEmbeddingEnums.APE)); Logger.WriteLine($"Creating embeddings. 
Shape = '({model.SrcVocab.Count} ,{model.EncoderEmbeddingDim})'"); - m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", + m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: m_options.IsEmbeddingTrainable), DeviceIds); return true; diff --git a/Seq2SeqSharp/Applications/SeqLabel.cs b/Seq2SeqSharp/Applications/SeqLabel.cs index 418c98a7..c733b766 100644 --- a/Seq2SeqSharp/Applications/SeqLabel.cs +++ b/Seq2SeqSharp/Applications/SeqLabel.cs @@ -93,7 +93,7 @@ private bool CreateTrainableParameters(IModel model) m_encoder = Encoder.CreateEncoders(model, m_options, raDeviceIds); m_ffLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward", model.HiddenDim, model.ClsVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: true), DeviceIds); - m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, name: "SrcEmbeddings", + m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "SrcEmbeddings", isTrainable: true), DeviceIds); (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, m_options.MaxSentLength, model, createAPE: (model.PEType == PositionEmbeddingEnums.APE)); diff --git a/Seq2SeqSharp/Applications/SeqSimilarity.cs b/Seq2SeqSharp/Applications/SeqSimilarity.cs index 1c3e174a..d16d5512 100644 --- a/Seq2SeqSharp/Applications/SeqSimilarity.cs +++ b/Seq2SeqSharp/Applications/SeqSimilarity.cs @@ -86,7 +86,7 @@ private bool CreateTrainableParameters(IModel model) (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTrainSentLength, m_options.MaxTestSentLength), model, createAPE: (model.PEType == PositionEmbeddingEnums.APE)); Logger.WriteLine($"Creating embeddings. 
Shape = '({model.SrcVocab.Count} ,{model.EncoderEmbeddingDim})'"); - m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", + m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: m_options.IsEmbeddingTrainable), DeviceIds); return true; diff --git a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs index 9b68d12d..8a0f3838 100644 --- a/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs +++ b/Seq2SeqSharp/Corpus/Seq2SeqCorpus.cs @@ -18,8 +18,8 @@ namespace Seq2SeqSharp.Corpus public class Seq2SeqCorpus : ParallelCorpus { - public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore) - :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence) + public Seq2SeqCorpus(string corpusFilePath, string srcLangName, string tgtLangName, int maxTokenSizePerBatch, int maxSrcSentLength = 32, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = null) + :base (corpusFilePath, srcLangName, tgtLangName, maxTokenSizePerBatch, maxSrcSentLength, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath) { } diff --git a/Seq2SeqSharp/Layers/AttentionUnit.cs b/Seq2SeqSharp/Layers/AttentionUnit.cs index c61df40e..b0193bb7 100644 --- a/Seq2SeqSharp/Layers/AttentionUnit.cs +++ b/Seq2SeqSharp/Layers/AttentionUnit.cs @@ -56,15 +56,15 @@ public AttentionUnit(string name, int hiddenDim, int contextDim, int deviceId, b Logger.WriteLine($"Creating attention unit '{name}' HiddenDim = '{hiddenDim}', ContextDim = '{contextDim}', DeviceId = '{deviceId}', EnableCoverageModel = '{enableCoverageModel}'"); - m_Ua = new WeightTensor(new long[2] { contextDim, hiddenDim }, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Ua)}", isTrainable: isTrainable, dtype: elementType); - m_Wa = new WeightTensor(new long[2] { hiddenDim, hiddenDim }, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Wa)}", isTrainable: isTrainable, dtype: elementType); + m_Ua = new WeightTensor(new long[2] { contextDim, hiddenDim }, deviceId, initType: RandomInitType.Uniform, name: $"{name}.{nameof(m_Ua)}", isTrainable: isTrainable, dtype: elementType); + m_Wa = new WeightTensor(new long[2] { hiddenDim, hiddenDim }, deviceId, initType: RandomInitType.Uniform, name: $"{name}.{nameof(m_Wa)}", isTrainable: isTrainable, dtype: elementType); m_bUa = new WeightTensor(new long[2] { 1, hiddenDim }, 0, deviceId, name: $"{name}.{nameof(m_bUa)}", isTrainable: isTrainable, dtype: elementType); m_bWa = new WeightTensor(new long[2] { 1, hiddenDim }, 0, deviceId, name: $"{name}.{nameof(m_bWa)}", isTrainable: isTrainable, dtype: elementType); - m_V = new WeightTensor(new long[2] { hiddenDim, 1 }, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_V)}", isTrainable: isTrainable, dtype: elementType); + m_V 
= new WeightTensor(new long[2] { hiddenDim, 1 }, deviceId, initType: RandomInitType.Uniform, name: $"{name}.{nameof(m_V)}", isTrainable: isTrainable, dtype: elementType); if (m_enableCoverageModel) { - m_Wc = new WeightTensor(new long[2] { k_coverageModelDim, hiddenDim }, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Wc)}", isTrainable: isTrainable, dtype: elementType); + m_Wc = new WeightTensor(new long[2] { k_coverageModelDim, hiddenDim }, deviceId, initType: RandomInitType.Uniform, name: $"{name}.{nameof(m_Wc)}", isTrainable: isTrainable, dtype: elementType); m_bWc = new WeightTensor(new long[2] { 1, hiddenDim }, 0, deviceId, name: $"{name}.{nameof(m_bWc)}", isTrainable: isTrainable, dtype: elementType); m_coverage = new LSTMCell(name: $"{name}.{nameof(m_coverage)}", hdim: k_coverageModelDim, inputDim: 1 + contextDim + hiddenDim, deviceId: deviceId, isTrainable: isTrainable, elementType: elementType); } diff --git a/Seq2SeqSharp/Layers/FeedForwardLayer.cs b/Seq2SeqSharp/Layers/FeedForwardLayer.cs index a30844d0..0e352eae 100644 --- a/Seq2SeqSharp/Layers/FeedForwardLayer.cs +++ b/Seq2SeqSharp/Layers/FeedForwardLayer.cs @@ -40,7 +40,7 @@ public FeedForwardLayer(string name, int inputDim, int outputDim, float dropoutR m_isTrainable = isTrainable; m_elementType = elementType; - m_Whd = new WeightTensor(new long[2] { inputDim, outputDim }, deviceId, name: $"{name}.{nameof(m_Whd)}", normType: NormType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); + m_Whd = new WeightTensor(new long[2] { inputDim, outputDim }, deviceId, name: $"{name}.{nameof(m_Whd)}", initType: RandomInitType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); m_Bd = new WeightTensor(new long[2] { 1, outputDim }, 0, deviceId, name: $"{name}.{nameof(m_Bd)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); } diff --git a/Seq2SeqSharp/Layers/INormalization.cs b/Seq2SeqSharp/Layers/INormalization.cs new file mode 100644 index 00000000..69881b46 --- /dev/null +++ b/Seq2SeqSharp/Layers/INormalization.cs @@ -0,0 +1,9 @@ +using Seq2SeqSharp.Tools; + +namespace Seq2SeqSharp.Layers +{ + internal interface INormalization : INeuralUnit + { + IWeightTensor Norm(IWeightTensor input, IComputeGraph g); + } +} diff --git a/Seq2SeqSharp/Layers/LSTMAttentionDecoderCell.cs b/Seq2SeqSharp/Layers/LSTMAttentionDecoderCell.cs index 40b42bad..c7a207ef 100644 --- a/Seq2SeqSharp/Layers/LSTMAttentionDecoderCell.cs +++ b/Seq2SeqSharp/Layers/LSTMAttentionDecoderCell.cs @@ -40,7 +40,7 @@ public LSTMAttentionDecoderCell(string name, int hiddenDim, int inputDim, int co Logger.WriteLine($"Create LSTM attention decoder cell '{name}' HiddemDim = '{hiddenDim}', InputDim = '{inputDim}', ContextDim = '{contextDim}', DeviceId = '{deviceId}'"); - m_Wxhc = new WeightTensor(new long[2] { inputDim + hiddenDim + contextDim, hiddenDim * 4 }, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Wxhc)}", isTrainable: isTrainable, dtype: elementType); + m_Wxhc = new WeightTensor(new long[2] { inputDim + hiddenDim + contextDim, hiddenDim * 4 }, deviceId, initType: RandomInitType.Uniform, name: $"{name}.{nameof(m_Wxhc)}", isTrainable: isTrainable, dtype: elementType); m_b = new WeightTensor(new long[2] { 1, hiddenDim * 4 }, 0, deviceId, name: $"{name}.{nameof(m_b)}", isTrainable: isTrainable, dtype: elementType); m_layerNorm1 = new LayerNormalization($"{name}.{nameof(m_layerNorm1)}", hiddenDim * 4, deviceId, 
isTrainable, elementType: elementType); diff --git a/Seq2SeqSharp/Layers/LSTMCell.cs b/Seq2SeqSharp/Layers/LSTMCell.cs index b0a6155a..cc33a046 100644 --- a/Seq2SeqSharp/Layers/LSTMCell.cs +++ b/Seq2SeqSharp/Layers/LSTMCell.cs @@ -34,7 +34,7 @@ public LSTMCell(string name, int hdim, int inputDim, int deviceId, bool isTraina { m_name = name; - m_Wxh = new WeightTensor(new long[2] { inputDim + hdim, hdim * 4 }, deviceId, normType: NormType.Uniform, name: $"{name}.{nameof(m_Wxh)}", isTrainable: isTrainable, dtype: elementType); + m_Wxh = new WeightTensor(new long[2] { inputDim + hdim, hdim * 4 }, deviceId, initType: RandomInitType.Uniform, name: $"{name}.{nameof(m_Wxh)}", isTrainable: isTrainable, dtype: elementType); m_b = new WeightTensor(new long[2] { 1, hdim * 4 }, 0, deviceId, name: $"{name}.{nameof(m_b)}", isTrainable: isTrainable, dtype: elementType); m_hdim = hdim; diff --git a/Seq2SeqSharp/Layers/LayerNormalization.cs b/Seq2SeqSharp/Layers/LayerNormalization.cs index 2e90c57d..415c8024 100644 --- a/Seq2SeqSharp/Layers/LayerNormalization.cs +++ b/Seq2SeqSharp/Layers/LayerNormalization.cs @@ -8,6 +8,7 @@ // Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. +using Seq2SeqSharp.Layers; using Seq2SeqSharp.Tools; using System; using System.Collections.Generic; @@ -16,7 +17,7 @@ namespace Seq2SeqSharp { [Serializable] - internal class LayerNormalization + internal class LayerNormalization : INormalization { private readonly IWeightTensor m_alpha; private readonly IWeightTensor m_beta; @@ -70,5 +71,15 @@ public void Load(IModel stream) m_alpha.Load(stream); m_beta.Load(stream); } + + public INeuralUnit CloneToDeviceAt(int deviceId) + { + throw new NotImplementedException(); + } + + public int GetDeviceId() + { + throw new NotImplementedException(); + } } } diff --git a/Seq2SeqSharp/Layers/MoEFeedForward.cs b/Seq2SeqSharp/Layers/MoEFeedForward.cs index cb7503f2..f2821f61 100644 --- a/Seq2SeqSharp/Layers/MoEFeedForward.cs +++ b/Seq2SeqSharp/Layers/MoEFeedForward.cs @@ -47,10 +47,10 @@ public MoEFeedForward(string name, int expertNum, int hiddenDim, float dropoutRa layerNorm = new LayerNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); - m_Whd1 = new WeightTensor(new long[3] { expertNum, hiddenDim, hiddenDim * 4 }, deviceId, name: $"{name}.{nameof(m_Whd1)}", normType: NormType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); - m_Whd2 = new WeightTensor(new long[3] { expertNum, hiddenDim * 4, hiddenDim }, deviceId, name: $"{name}.{nameof(m_Whd2)}", normType: NormType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); + m_Whd1 = new WeightTensor(new long[3] { expertNum, hiddenDim, hiddenDim * 4 }, deviceId, name: $"{name}.{nameof(m_Whd1)}", initType: RandomInitType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); + m_Whd2 = new WeightTensor(new long[3] { expertNum, hiddenDim * 4, hiddenDim }, deviceId, name: $"{name}.{nameof(m_Whd2)}", initType: RandomInitType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); - m_Router = new WeightTensor(new long[2] { hiddenDim, expertNum }, deviceId, name: $"{name}.{nameof(m_Router)}", normType: 
NormType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); + m_Router = new WeightTensor(new long[2] { hiddenDim, expertNum }, deviceId, name: $"{name}.{nameof(m_Router)}", initType: RandomInitType.Uniform, isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); m_RouterBias = new WeightTensor(new long[2] { 1, expertNum }, 0, deviceId, name: $"{name}.{nameof(m_RouterBias)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); } diff --git a/Seq2SeqSharp/Layers/MultiHeadAttention.cs b/Seq2SeqSharp/Layers/MultiHeadAttention.cs index c3d96b7c..358616cb 100644 --- a/Seq2SeqSharp/Layers/MultiHeadAttention.cs +++ b/Seq2SeqSharp/Layers/MultiHeadAttention.cs @@ -8,6 +8,9 @@ // Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. +using AdvUtils; +using Seq2SeqSharp.Enums; +using Seq2SeqSharp.Layers; using Seq2SeqSharp.Tools; using Seq2SeqSharp.Utils; using System; @@ -33,7 +36,7 @@ internal class MultiHeadAttention private readonly IWeightTensor QKVb; - private readonly LayerNormalization layerNormQ; + private readonly INormalization layerNormQ; private readonly int m_hiddenDim; private readonly int m_d; @@ -46,7 +49,7 @@ internal class MultiHeadAttention private readonly PositionEmbeddingEnums m_PEType; public MultiHeadAttention(string name, int multiHeadNum, int hiddenDim, int inputDim, float dropoutRatio, int deviceId, bool isTrainable, - bool sharedQKV = false, float learningRateFactor = 1.0f, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE) + bool sharedQKV = false, float learningRateFactor = 1.0f, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE, NormEnums normType = NormEnums.LayerNorm) { m_name = name; m_hiddenDim = hiddenDim; @@ -56,27 +59,36 @@ public MultiHeadAttention(string name, int multiHeadNum, int hiddenDim, int inpu m_sharedQKV = sharedQKV; m_PEType = peType; - W0 = new WeightTensor(new long[2] { hiddenDim, hiddenDim }, deviceId, name: $"{name}.{nameof(W0)}", isTrainable: isTrainable, normType: NormType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); + Logger.WriteLine($"Creating multi-head attention layer. 
Name = '{name}', HiddenDim = '{hiddenDim}', multi-head dim = '{multiHeadNum}', DeviceId = '{deviceId}', Dropout ratio = '{dropoutRatio}', IsTrainable = '{isTrainable}', Learning rate factor = '{learningRateFactor}', PE = '{peType}', Norm = '{normType}'"); + + W0 = new WeightTensor(new long[2] { hiddenDim, hiddenDim }, deviceId, name: $"{name}.{nameof(W0)}", isTrainable: isTrainable, initType: RandomInitType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); b0 = new WeightTensor(new long[2] { 1, hiddenDim }, 0, deviceId, name: $"{name}.{nameof(b0)}", isTrainable: isTrainable, dtype: elementType); if (m_sharedQKV == false) { - Q = new WeightTensor(new long[2] { inputDim, hiddenDim }, deviceId, name: $"{name}.{nameof(Q)}", isTrainable: isTrainable, normType: NormType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); + Q = new WeightTensor(new long[2] { inputDim, hiddenDim }, deviceId, name: $"{name}.{nameof(Q)}", isTrainable: isTrainable, initType: RandomInitType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); Qb = new WeightTensor(new long[2] { 1, hiddenDim }, 0, deviceId, name: $"{name}.{nameof(Qb)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); - K = new WeightTensor(new long[2] { inputDim, hiddenDim }, deviceId, name: $"{name}.{nameof(K)}", isTrainable: isTrainable, normType: NormType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); + K = new WeightTensor(new long[2] { inputDim, hiddenDim }, deviceId, name: $"{name}.{nameof(K)}", isTrainable: isTrainable, initType: RandomInitType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); Kb = new WeightTensor(new long[2] { 1, hiddenDim }, 0, deviceId, name: $"{name}.{nameof(Kb)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); - V = new WeightTensor(new long[2] { inputDim, hiddenDim }, deviceId, name: $"{name}.{nameof(V)}", isTrainable: isTrainable, normType: NormType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); + V = new WeightTensor(new long[2] { inputDim, hiddenDim }, deviceId, name: $"{name}.{nameof(V)}", isTrainable: isTrainable, initType: RandomInitType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); Vb = new WeightTensor(new long[2] { 1, hiddenDim }, 0, deviceId, name: $"{name}.{nameof(Vb)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); } else { - QKV = new WeightTensor(new long[2] { inputDim, hiddenDim * 3 }, deviceId, name: $"{name}.{nameof(Q)}", isTrainable: isTrainable, normType: NormType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); + QKV = new WeightTensor(new long[2] { inputDim, hiddenDim * 3 }, deviceId, name: $"{name}.{nameof(Q)}", isTrainable: isTrainable, initType: RandomInitType.Uniform, learningRateFactor: learningRateFactor, dtype: elementType); QKVb = new WeightTensor(new long[2] { 1, hiddenDim * 3 }, 0, deviceId, name: $"{name}.{nameof(Qb)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); } - layerNormQ = new LayerNormalization($"{name}.{nameof(layerNormQ)}", m_hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + if (normType == NormEnums.LayerNorm) + { + layerNormQ = new LayerNormalization($"{name}.{nameof(layerNormQ)}", m_hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } + else + { + 
layerNormQ = new RMSNormalization($"{name}.{nameof(layerNormQ)}", m_hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } } /// diff --git a/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs b/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs index bd39ae35..ab7d0134 100644 --- a/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs +++ b/Seq2SeqSharp/Layers/PositionwiseFeedForward.cs @@ -9,6 +9,7 @@ // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. using AdvUtils; +using Seq2SeqSharp.Enums; using Seq2SeqSharp.Layers; using Seq2SeqSharp.Tools; using Seq2SeqSharp.Utils; @@ -19,7 +20,7 @@ namespace Seq2SeqSharp { internal class PositionwiseFeedForward : IFeedForwardLayer { - private readonly LayerNormalization layerNorm2; + private readonly INormalization layerNorm2; private readonly FeedForwardLayer feedForwardLayer1; private readonly FeedForwardLayer feedForwardLayer2; @@ -29,16 +30,24 @@ internal class PositionwiseFeedForward : IFeedForwardLayer private ActivateFuncEnums m_activateFunc; - public PositionwiseFeedForward(string name, int hiddenDim, int intermediateDim, float dropoutRatio, int deviceId, bool isTrainable, float learningRateFactor = 1.0f, ActivateFuncEnums activateFunc = ActivateFuncEnums.ReLU, DType elementType = DType.Float32) + public PositionwiseFeedForward(string name, int hiddenDim, int intermediateDim, float dropoutRatio, int deviceId, bool isTrainable, float learningRateFactor = 1.0f, ActivateFuncEnums activateFunc = ActivateFuncEnums.ReLU, DType elementType = DType.Float32, NormEnums normType = NormEnums.LayerNorm) { m_name = name; m_dropoutRatio = dropoutRatio; m_activateFunc = activateFunc; m_elementType= elementType; - Logger.WriteLine($"Creating positionwise feed forward layer. Name = '{name}', HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim}', DeviceId = '{deviceId}', Dropout ratio = '{dropoutRatio}', IsTrainable = '{isTrainable}', Learning rate factor = '{learningRateFactor}', Activate Function = '{activateFunc}'"); + Logger.WriteLine($"Creating positionwise feed forward layer. 
Name = '{name}', HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim}', DeviceId = '{deviceId}', Dropout ratio = '{dropoutRatio}', IsTrainable = '{isTrainable}', Learning rate factor = '{learningRateFactor}', Activate Function = '{activateFunc}', Norm = '{normType}'"); + + if (normType == NormEnums.LayerNorm) + { + layerNorm2 = new LayerNormalization($"{name}.{nameof(layerNorm2)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } + else + { + layerNorm2 = new RMSNormalization($"{name}.{nameof(layerNorm2)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } - layerNorm2 = new LayerNormalization($"{name}.{nameof(layerNorm2)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); feedForwardLayer1 = new FeedForwardLayer($"{name}.{nameof(feedForwardLayer1)}", hiddenDim, intermediateDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); feedForwardLayer2 = new FeedForwardLayer($"{name}.{nameof(feedForwardLayer2)}", intermediateDim, hiddenDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); } diff --git a/Seq2SeqSharp/Layers/RMSNormalization.cs b/Seq2SeqSharp/Layers/RMSNormalization.cs new file mode 100644 index 00000000..2dd2605b --- /dev/null +++ b/Seq2SeqSharp/Layers/RMSNormalization.cs @@ -0,0 +1,73 @@ +// Copyright (c) Zhongkai Fu. All rights reserved. +// https://github.com/zhongkaifu/Seq2SeqSharp +// +// This file is part of Seq2SeqSharp. +// +// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree. +// +// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. 
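+
+// RMSNormalization (root-mean-square layer normalization) rescales each row of the
+// input by the reciprocal of its root mean square rather than standardizing it:
+//
+//     y = alpha * x / sqrt(mean(x^2) + eps) + beta
+//
+// Unlike LayerNormalization it does not subtract the per-row mean, so it needs one
+// fewer reduction per row while keeping a learnable gain (alpha) and bias (beta).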
+ +using Seq2SeqSharp.Layers; +using Seq2SeqSharp.Tools; +using System; +using System.Collections.Generic; +using TensorSharp; + +namespace Seq2SeqSharp +{ + [Serializable] + internal class RMSNormalization : INormalization + { + private readonly IWeightTensor m_alpha; + private readonly IWeightTensor m_beta; + private readonly float m_epsilon; + + public RMSNormalization(string name, int dim, int deviceId, bool isTrainable, float learningRateFactor = 1.0f, float epsilon = 1e-06f, DType elementType = DType.Float32) + { + m_alpha = new WeightTensor(new long[2] { 1, dim }, 1.0f, deviceId, name: $"{name}.{nameof(m_alpha)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); + m_beta = new WeightTensor(new long[2] { 1, dim }, 0, deviceId, name: $"{name}.{nameof(m_beta)}", isTrainable: isTrainable, learningRateFactor: learningRateFactor, dtype: elementType); + m_epsilon = epsilon; + } + + public IWeightTensor Norm(IWeightTensor input, IComputeGraph g) + { + var result = g.RMSNorm(input, m_alpha, m_beta, m_epsilon); + return result; + } + + public virtual List GetParams() + { + List response = new List + { + m_alpha, + m_beta + }; + + return response; + } + + public void Save(IModel stream) + { + m_alpha.Save(stream); + m_beta.Save(stream); + } + + + public void Load(IModel stream) + { + m_alpha.Load(stream); + m_beta.Load(stream); + } + + public INeuralUnit CloneToDeviceAt(int deviceId) + { + throw new NotImplementedException(); + } + + public int GetDeviceId() + { + throw new NotImplementedException(); + } + } +} diff --git a/Seq2SeqSharp/Models/IModel.cs b/Seq2SeqSharp/Models/IModel.cs index f6a5c2cd..a898433b 100644 --- a/Seq2SeqSharp/Models/IModel.cs +++ b/Seq2SeqSharp/Models/IModel.cs @@ -23,6 +23,7 @@ public interface IModel public int EncoderLayerDepth { get; set; } public VQTypeEnums VQType { get; set; } public PositionEmbeddingEnums PEType { get; set; } + public NormEnums NormType { get; set; } public ActivateFuncEnums ActivateFunc { get; set; } public int ExpertNum { get; set; } public int ExpertsPerTokenFactor { get; set; } diff --git a/Seq2SeqSharp/Models/Model.cs b/Seq2SeqSharp/Models/Model.cs index c4989f36..e35bf42a 100644 --- a/Seq2SeqSharp/Models/Model.cs +++ b/Seq2SeqSharp/Models/Model.cs @@ -90,6 +90,7 @@ public Vocab ClsVocab public Dictionary Name2CodeBook { get; set; } public PositionEmbeddingEnums PEType { get; set; } + public NormEnums NormType { get; set; } public Model() { } public Model(Options opts,Vocab srcVocab) @@ -109,6 +110,7 @@ public Model(Options opts,Vocab srcVocab) ActivateFunc = opts.ActivateFunc; VQType = opts.VQType; PEType = opts.PEType; + NormType = opts.NormType; Name2Weights = new Dictionary(); Name2WeightsHalf= new Dictionary(); @@ -139,6 +141,7 @@ public Model(Model_4_ProtoBufSerializer m) Name2WeightsVQ = m.Name2WeightsVQ; Name2CodeBook = m.Name2CodeBook; PEType = m.PEType; + NormType = m.NormType; if (Name2Weights == null) { diff --git a/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs b/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs index 1d1c31f2..11a227c6 100644 --- a/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs +++ b/Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs @@ -283,6 +283,7 @@ public Model_4_ProtoBufSerializer(Model m) ExpertNum = m.ExpertNum; ExpertsPerTokenFactor = m.ExpertsPerTokenFactor; PEType= m.PEType; + NormType = m.NormType; } public static Model_4_ProtoBufSerializer Create(Model m) => new Model_4_ProtoBufSerializer(m); @@ -315,5 +316,6 @@ public Model_4_ProtoBufSerializer(Model 
m) [ProtoMember(27)] public Dictionary Name2WeightsHalf { get; set; } [ProtoMember(28)] public PositionEmbeddingEnums PEType { get; set; } + [ProtoMember(29)] public NormEnums NormType { get; set; } } } diff --git a/Seq2SeqSharp/Networks/GPTDecoder.cs b/Seq2SeqSharp/Networks/GPTDecoder.cs index 87948ee3..19e7155c 100644 --- a/Seq2SeqSharp/Networks/GPTDecoder.cs +++ b/Seq2SeqSharp/Networks/GPTDecoder.cs @@ -9,6 +9,7 @@ // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. using AdvUtils; +using Seq2SeqSharp.Enums; using Seq2SeqSharp.Layers; using Seq2SeqSharp.Tools; using Seq2SeqSharp.Utils; @@ -33,18 +34,19 @@ public class GPTDecoder : IDecoder private readonly int m_deviceId; private readonly bool m_isTrainable; private readonly float m_learningRateFactor; - private readonly LayerNormalization layerNorm; + private readonly INormalization layerNorm; private readonly ActivateFuncEnums m_activateFunc; private readonly int m_expertNum; private readonly int m_expertsPerTokenFactor; private readonly DType m_elementType; private readonly PositionEmbeddingEnums m_peType; + private readonly NormEnums m_normType; public GPTDecoder(string name, int multiHeadNum, int hiddenDim, int intermediateDim, int inputDim, int depth, float dropoutRatio, int deviceId, bool isTrainable, float learningRateFactor = 1.0f, ActivateFuncEnums activateFunc = ActivateFuncEnums.ReLU, int expertNum = 1, - int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE) + int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE, NormEnums normType = NormEnums.LayerNorm) { - Logger.WriteLine($"Creating transformer decoder at device '{deviceId}'. HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim}', InputDim = '{inputDim}', Depth = '{depth}', MultiHeadNum = '{multiHeadNum}', ElementType = '{elementType}', Positional Embedding = '{peType}'"); + Logger.WriteLine($"Creating transformer decoder at device '{deviceId}'. 
HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim}', InputDim = '{inputDim}', Depth = '{depth}', MultiHeadNum = '{multiHeadNum}', ElementType = '{elementType}', Positional Embedding = '{peType}', Norm = '{normType}'"); m_name = name; m_multiHeadNum = multiHeadNum; @@ -61,6 +63,7 @@ public GPTDecoder(string name, int multiHeadNum, int hiddenDim, int intermediate m_expertsPerTokenFactor = expertsPerTokenFactor; m_elementType= elementType; m_peType = peType; + m_normType = normType; if (hiddenDim != inputDim) { @@ -68,11 +71,11 @@ public GPTDecoder(string name, int multiHeadNum, int hiddenDim, int intermediate } m_selfAttns.Add(new MultiHeadAttention($"{name}.SelfAttn_0", multiHeadNum, hiddenDim, inputDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType)); + isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType, normType: normType)); for (int i = 1; i < depth; i++) { m_selfAttns.Add(new MultiHeadAttention($"{name}.SelfAttn_{i}", multiHeadNum, hiddenDim, hiddenDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType)); + isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType, normType: normType)); } for (int i = 0; i < depth; i++) @@ -83,13 +86,18 @@ public GPTDecoder(string name, int multiHeadNum, int hiddenDim, int intermediate } else { - m_feedForwards.Add(new PositionwiseFeedForward($"{name}.PosFFN_{i}", hiddenDim, intermediateDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, elementType: elementType)); + m_feedForwards.Add(new PositionwiseFeedForward($"{name}.PosFFN_{i}", hiddenDim, intermediateDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, elementType: elementType, normType: normType)); } } - - layerNorm = new LayerNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); - + if (normType == NormEnums.LayerNorm) + { + layerNorm = new LayerNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } + else + { + layerNorm = new RMSNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } } public int GetDeviceId() @@ -152,7 +160,7 @@ public void Reset(IWeightFactory weightFactory, int batchSize) public INeuralUnit CloneToDeviceAt(int deviceId) { return new GPTDecoder(m_name, m_multiHeadNum, m_hiddenDim, m_intermediateDim, m_inputDim, m_depth, m_dropoutRatio, deviceId, m_isTrainable, learningRateFactor: m_learningRateFactor, activateFunc: m_activateFunc, expertNum: m_expertNum, - expertsPerTokenFactor: m_expertsPerTokenFactor, elementType: m_elementType, peType: m_peType); + expertsPerTokenFactor: m_expertsPerTokenFactor, elementType: m_elementType, peType: m_peType, normType: m_normType); } public List GetParams() diff --git a/Seq2SeqSharp/Networks/TransformerDecoder.cs b/Seq2SeqSharp/Networks/TransformerDecoder.cs index 35a169de..9b3c6140 100644 --- a/Seq2SeqSharp/Networks/TransformerDecoder.cs +++ b/Seq2SeqSharp/Networks/TransformerDecoder.cs @@ 
-9,6 +9,7 @@ // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. using AdvUtils; +using Seq2SeqSharp.Enums; using Seq2SeqSharp.Layers; using Seq2SeqSharp.Tools; using Seq2SeqSharp.Utils; @@ -34,18 +35,19 @@ public class TransformerDecoder : IDecoder private readonly int m_deviceId; private readonly bool m_isTrainable; private readonly float m_learningRateFactor; - private readonly LayerNormalization layerNorm; + private readonly INormalization layerNorm; private readonly ActivateFuncEnums m_activateFunc; private readonly int m_expertNum; private readonly int m_expertsPerTokenFactor; private readonly DType m_elementType; private readonly PositionEmbeddingEnums m_peType; + private readonly NormEnums m_normType; public TransformerDecoder(string name, int multiHeadNum, int hiddenDim, int intermediateDim, int inputDim, int depth, float dropoutRatio, int deviceId, bool isTrainable, float learningRateFactor = 1.0f, ActivateFuncEnums activateFunc = ActivateFuncEnums.ReLU, - int expertNum = 1, int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE) + int expertNum = 1, int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE, NormEnums normType = NormEnums.LayerNorm) { - Logger.WriteLine($"Creating transformer decoder at device '{deviceId}'. HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim}', InputDim = '{inputDim}', Depth = '{depth}', MultiHeadNum = '{multiHeadNum}', ElementType = '{elementType}', Positional Embedding = '{peType}'"); + Logger.WriteLine($"Creating transformer decoder at device '{deviceId}'. HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim}', InputDim = '{inputDim}', Depth = '{depth}', MultiHeadNum = '{multiHeadNum}', ElementType = '{elementType}', Positional Embedding = '{peType}' normType = '{normType}'"); m_name = name; m_multiHeadNum = multiHeadNum; @@ -62,6 +64,7 @@ public TransformerDecoder(string name, int multiHeadNum, int hiddenDim, int inte m_expertsPerTokenFactor = expertsPerTokenFactor; m_elementType = elementType; m_peType = peType; + m_normType = normType; if (hiddenDim != inputDim) { @@ -69,19 +72,19 @@ public TransformerDecoder(string name, int multiHeadNum, int hiddenDim, int inte } m_selfAttns.Add(new MultiHeadAttention($"{name}.SelfAttn_0", multiHeadNum, hiddenDim, inputDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType)); + isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType, normType: normType)); for (int i = 1; i < depth; i++) { m_selfAttns.Add(new MultiHeadAttention($"{name}.SelfAttn_{i}", multiHeadNum, hiddenDim, hiddenDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType)); + isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType, normType: normType)); } m_encAttns.Add(new MultiHeadAttention($"{name}.EncAttn_0", multiHeadNum, hiddenDim, inputDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType)); + isTrainable: isTrainable, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType, normType: 
normType)); for (int i = 1; i < depth; i++) { m_encAttns.Add(new MultiHeadAttention($"{name}.EncAttn_{i}", multiHeadNum, hiddenDim, hiddenDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType)); + isTrainable: isTrainable, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType, normType: normType)); } for (int i = 0; i < depth; i++) @@ -92,13 +95,18 @@ public TransformerDecoder(string name, int multiHeadNum, int hiddenDim, int inte } else { - m_feedForwards.Add(new PositionwiseFeedForward($"{name}.PosFFN_{i}", hiddenDim, intermediateDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, elementType: elementType)); + m_feedForwards.Add(new PositionwiseFeedForward($"{name}.PosFFN_{i}", hiddenDim, intermediateDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, elementType: elementType, normType: normType)); } } - - layerNorm = new LayerNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); - + if (normType == NormEnums.LayerNorm) + { + layerNorm = new LayerNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } + else + { + layerNorm = new RMSNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } } public int GetDeviceId() @@ -178,7 +186,7 @@ public void Reset(IWeightFactory weightFactory, int batchSize) public INeuralUnit CloneToDeviceAt(int deviceId) { return new TransformerDecoder(m_name, m_multiHeadNum, m_hiddenDim, m_intermediateDim, m_inputDim, m_depth, m_dropoutRatio, deviceId, m_isTrainable, learningRateFactor: m_learningRateFactor, activateFunc: m_activateFunc, - expertNum: m_expertNum, expertsPerTokenFactor: m_expertsPerTokenFactor, elementType: m_elementType, peType: m_peType); + expertNum: m_expertNum, expertsPerTokenFactor: m_expertsPerTokenFactor, elementType: m_elementType, peType: m_peType, normType: m_normType); } public List GetParams() diff --git a/Seq2SeqSharp/Networks/TransformerEncoder.cs b/Seq2SeqSharp/Networks/TransformerEncoder.cs index 2476f004..9e38636b 100644 --- a/Seq2SeqSharp/Networks/TransformerEncoder.cs +++ b/Seq2SeqSharp/Networks/TransformerEncoder.cs @@ -9,6 +9,7 @@ // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. 
using AdvUtils; +using Seq2SeqSharp.Enums; using Seq2SeqSharp.Layers; using Seq2SeqSharp.Tools; using Seq2SeqSharp.Utils; @@ -33,18 +34,19 @@ internal class TransformerEncoder : IEncoder private readonly int m_deviceId; private readonly bool m_isTrainable; private readonly float m_learningRateFactor; - private readonly LayerNormalization layerNorm; + private readonly INormalization layerNorm; private readonly ActivateFuncEnums m_activateFunc; private readonly int m_expertNum; private readonly int m_expertsPerTokenFactor; private readonly DType m_elementType; private readonly PositionEmbeddingEnums m_peType; + private readonly NormEnums m_normType; public TransformerEncoder(string name, int multiHeadNum, int hiddenDim, int intermediateDim, int inputDim, int depth, float dropoutRatio, int deviceId, bool isTrainable, float learningRateFactor = 1.0f, ActivateFuncEnums activateFunc = ActivateFuncEnums.ReLU, - int expertNum = 1, int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE) + int expertNum = 1, int expertsPerTokenFactor = 1, DType elementType = DType.Float32, PositionEmbeddingEnums peType = PositionEmbeddingEnums.APE, NormEnums normType = NormEnums.LayerNorm) { - Logger.WriteLine($"Creating transformer encoder at device '{deviceId}'. HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim},' InputDim = '{inputDim}', Depth = '{depth}', MultiHeadNum = '{multiHeadNum}', ElementType = '{elementType}', Positional Embedding = '{peType}'"); + Logger.WriteLine($"Creating transformer encoder at device '{deviceId}'. HiddenDim = '{hiddenDim}', IntermediateDim = '{intermediateDim},' InputDim = '{inputDim}', Depth = '{depth}', MultiHeadNum = '{multiHeadNum}', ElementType = '{elementType}', Positional Embedding = '{peType}' NormType = '{normType}'"); m_name = name; m_multiHeadNum = multiHeadNum; @@ -61,6 +63,7 @@ public TransformerEncoder(string name, int multiHeadNum, int hiddenDim, int inte m_expertsPerTokenFactor = expertsPerTokenFactor; m_elementType = elementType; m_peType= peType; + m_normType = normType; if (hiddenDim != inputDim) { @@ -68,11 +71,11 @@ public TransformerEncoder(string name, int multiHeadNum, int hiddenDim, int inte } m_encoders.Add(new MultiHeadAttention($"{name}.SelfAttn_0", multiHeadNum, hiddenDim, inputDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType:peType)); + isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType:peType, normType: normType)); for (int i = 1; i < depth; i++) { m_encoders.Add(new MultiHeadAttention($"{name}.SelfAttn_{i}", multiHeadNum, hiddenDim, hiddenDim, m_dropoutRatio, deviceId, - isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType)); + isTrainable: isTrainable, sharedQKV: true, learningRateFactor: learningRateFactor, elementType: elementType, peType: peType, normType: normType)); } for (int i = 0; i < depth; i++) @@ -83,11 +86,18 @@ public TransformerEncoder(string name, int multiHeadNum, int hiddenDim, int inte } else { - m_feedForwards.Add(new PositionwiseFeedForward($"{name}.PosFFN_{i}", hiddenDim, intermediateDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, elementType: elementType)); + m_feedForwards.Add(new PositionwiseFeedForward($"{name}.PosFFN_{i}", hiddenDim, 
intermediateDim, m_dropoutRatio, deviceId, isTrainable, learningRateFactor: learningRateFactor, activateFunc: activateFunc, elementType: elementType, normType: normType)); } } - layerNorm = new LayerNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + if (normType == NormEnums.LayerNorm) + { + layerNorm = new LayerNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } + else + { + layerNorm = new RMSNormalization($"{name}.{nameof(layerNorm)}", hiddenDim, deviceId, isTrainable, learningRateFactor: learningRateFactor, elementType: elementType); + } } @@ -148,7 +158,7 @@ public IWeightTensor Encode(IWeightTensor inputs, int batchSize, IComputeGraph g public INeuralUnit CloneToDeviceAt(int deviceId) { return new TransformerEncoder(m_name, m_multiHeadNum, m_hiddenDim, m_intermediateDim, m_inputDim, m_depth, m_dropoutRatio, deviceId, m_isTrainable, learningRateFactor: m_learningRateFactor, activateFunc: m_activateFunc, - expertNum: m_expertNum, expertsPerTokenFactor: m_expertsPerTokenFactor, elementType: m_elementType, peType: m_peType); + expertNum: m_expertNum, expertsPerTokenFactor: m_expertsPerTokenFactor, elementType: m_elementType, peType: m_peType, normType: m_normType); } public List GetParams() diff --git a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs index 3be7aa70..d311081e 100644 --- a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs +++ b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs @@ -325,7 +325,7 @@ protected T LoadModelRoutine(Func initializeParametersFunc, { Logger.WriteLine($"Creating shared embeddings for both source side and target side. Shape = '({modelMetaData.SrcVocab.Count} ,{modelMetaData.EncoderEmbeddingDim})'"); srcEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.SrcVocab.Count, modelMetaData.EncoderEmbeddingDim }, - raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SharedEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds); + raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SharedEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds); tgtEmbeddings = null; } @@ -333,11 +333,11 @@ protected T LoadModelRoutine(Func initializeParametersFunc, { Logger.WriteLine($"Creating embeddings for source side. Shape = '({modelMetaData.SrcVocab.Count} ,{modelMetaData.EncoderEmbeddingDim})'"); srcEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.SrcVocab.Count, modelMetaData.EncoderEmbeddingDim }, - raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds); + raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: isSrcEmbeddingTrainable, learningRateFactor: encoderStartLearningRateFactor, dtype: elementType), DeviceIds); Logger.WriteLine($"Creating embeddings for target side. 
Shape = '({modelMetaData.TgtVocab.Count} ,{modelMetaData.DecoderEmbeddingDim})'"); tgtEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.TgtVocab.Count, modelMetaData.DecoderEmbeddingDim }, - raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds); + raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds); } return (srcEmbeddings, tgtEmbeddings); @@ -347,7 +347,7 @@ internal MultiProcessorNetworkWrapper CreateTgtEmbeddings(IModel { Logger.WriteLine($"Creating embeddings for target side. Shape = '({modelMetaData.TgtVocab.Count} ,{modelMetaData.DecoderEmbeddingDim})'"); var tgtEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.TgtVocab.Count, modelMetaData.DecoderEmbeddingDim }, - raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds); + raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "TgtEmbeddings", isTrainable: isTgtEmbeddingTrainable, learningRateFactor: decoderStartLearningRateFactor, dtype: elementType), DeviceIds); return tgtEmbeddings; } diff --git a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs index c38f0c3d..7d453369 100644 --- a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs +++ b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs @@ -2184,6 +2184,40 @@ void backward() } + public IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, IWeightTensor beta, float eps = 1e-9f) + { + WeightTensor srcT = src as WeightTensor; + WeightTensor alphaT = alpha as WeightTensor; + WeightTensor betaT = beta as WeightTensor; + + WeightTensor res = m_weightTensorFactory.CreateWeightTensor(srcT.Sizes, m_deviceId, name: $"{GetHashString(src.Name, alpha.Name, beta.Name)}.RMSNorm", graphToBind: this, needGradient: srcT.NeedGradient, dtype: src.ElementType); + VisualizeNodes(new IWeightTensor[] { src, alpha, beta }, res); + + Ops.RMSNorm(res.TWeight, srcT.TWeight, alphaT.TWeight, betaT.TWeight, eps); + if (m_needsBackprop) + { + var srcTWeight = srcT.TWeight.CopyRef(); + var resTWeight = res.TWeight.CopyRef(); + void backward() + { + if (srcT.NeedGradient) + { + Ops.RMSNormGrad(srcT.TGradient, alphaT.TGradient, betaT.TGradient, res.TGradient, resTWeight, srcTWeight, alphaT.TWeight, betaT.TWeight, eps); + } + srcTWeight.Dispose(); + resTWeight.Dispose(); + + res.Dispose(); + } + m_backprop.Add(backward); + + alphaT.UnbindFromComputeGraph(); + betaT.UnbindFromComputeGraph(); + } + + return res; + } + ///// ///// LayerNorm (src1 + src2) diff --git a/Seq2SeqSharp/Tools/IComputeGraph.cs b/Seq2SeqSharp/Tools/IComputeGraph.cs index 0272b5a5..3a21aba2 100644 --- a/Seq2SeqSharp/Tools/IComputeGraph.cs +++ b/Seq2SeqSharp/Tools/IComputeGraph.cs @@ -53,6 +53,7 @@ public interface IComputeGraph : IDisposable IWeightTensor Transpose(IWeightTensor w); IWeightTensor Mul(IWeightTensor w, float v, bool inPlace = false); IWeightTensor LayerNorm(IWeightTensor src, IWeightTensor alpha, IWeightTensor beta, float eps = 1e-9f); + IWeightTensor RMSNorm(IWeightTensor src, IWeightTensor alpha, IWeightTensor beta, float 
eps = 1e-9f); IWeightTensor Select(IWeightTensor src, int dim, int index); void Backward(); diff --git a/Seq2SeqSharp/Tools/IWeightFactory.cs b/Seq2SeqSharp/Tools/IWeightFactory.cs index a84313e6..cd7c0cee 100644 --- a/Seq2SeqSharp/Tools/IWeightFactory.cs +++ b/Seq2SeqSharp/Tools/IWeightFactory.cs @@ -15,6 +15,6 @@ namespace Seq2SeqSharp.Tools { public interface IWeightFactory : IDisposable { - WeightTensor CreateWeightTensor(int row, int column, int deviceId, bool cleanWeights = false, string name = "", bool isTrainable = false, IComputeGraph graphToBind = null, NormType normType = NormType.None, bool needGradient = true, DType dtype = DType.Float32); + WeightTensor CreateWeightTensor(int row, int column, int deviceId, bool cleanWeights = false, string name = "", bool isTrainable = false, IComputeGraph graphToBind = null, RandomInitType normType = RandomInitType.None, bool needGradient = true, DType dtype = DType.Float32); } } diff --git a/Seq2SeqSharp/Tools/WeightTensor.cs b/Seq2SeqSharp/Tools/WeightTensor.cs index f0c27278..8086ad4f 100644 --- a/Seq2SeqSharp/Tools/WeightTensor.cs +++ b/Seq2SeqSharp/Tools/WeightTensor.cs @@ -20,7 +20,7 @@ namespace Seq2SeqSharp.Tools { - public enum NormType + public enum RandomInitType { None, Uniform, @@ -66,7 +66,7 @@ public int Columns private readonly bool m_fanIn = false; private readonly bool m_fanOut = false; - private readonly NormType m_normType = NormType.None; + private readonly RandomInitType m_normType = RandomInitType.None; private readonly DType m_elementType = DType.Float32; @@ -166,7 +166,7 @@ public Tensor TGradient public DType ElementType => m_elementType; - public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainable = false, NormType normType = NormType.None, bool fanIn = false, bool fanOut = false, float learningRateFactor = 1.0f, IComputeGraph graphToBind = null, bool needGradient = true, DType dtype = DType.Float32) + public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainable = false, RandomInitType initType = RandomInitType.None, bool fanIn = false, bool fanOut = false, float learningRateFactor = 1.0f, IComputeGraph graphToBind = null, bool needGradient = true, DType dtype = DType.Float32) { Name = name; DeviceId = deviceId; @@ -177,7 +177,7 @@ public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainab Sizes = sizes; m_fanIn = fanIn; m_fanOut = fanOut; - m_normType = normType; + m_normType = initType; m_elementType= dtype; if (graphToBind != null) @@ -188,7 +188,7 @@ public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainab if (isTrainable) { - if (normType == NormType.Uniform) + if (initType == RandomInitType.Uniform) { var scale = (float)Math.Sqrt(6.0 / (double)(Rows + Columns)); @@ -204,7 +204,7 @@ public WeightTensor(long[] sizes, int deviceId, string name = "", bool isTrainab float[] w = TensorSharp.RandomGenerator.BuildRandomUniformWeight(Sizes, -scale, scale); SetWeightArray(w); } - else if (normType == NormType.Normal) + else if (initType == RandomInitType.Normal) { float[] w = TensorSharp.RandomGenerator.BuildRandomUniformWeight(Sizes, -1.0f, 1.0f); SetWeightArray(w); @@ -245,7 +245,7 @@ public int GetDeviceId() public INeuralUnit CloneToDeviceAt(int deviceId) { - return new WeightTensor(Sizes, deviceId, Name, IsTrainable, normType: m_normType, fanIn: m_fanIn, fanOut: m_fanOut, needGradient: NeedGradient, dtype: m_elementType); + return new WeightTensor(Sizes, deviceId, Name, IsTrainable, initType: m_normType, fanIn: m_fanIn, 
fanOut: m_fanOut, needGradient: NeedGradient, dtype: m_elementType); } public void ZeroGradient() diff --git a/Seq2SeqSharp/Tools/WeightTensorFactory.cs b/Seq2SeqSharp/Tools/WeightTensorFactory.cs index 63659648..7229c0f7 100644 --- a/Seq2SeqSharp/Tools/WeightTensorFactory.cs +++ b/Seq2SeqSharp/Tools/WeightTensorFactory.cs @@ -17,9 +17,9 @@ public class WeightTensorFactory : IWeightFactory { private readonly List weights = new List(); - public WeightTensor CreateWeightTensor(int row, int column, int deviceId, bool cleanWeights = false, string name = "", bool isTrainable = false, IComputeGraph graphToBind = null, NormType normType = NormType.None, bool needGradient = true, DType dtype = DType.Float32) + public WeightTensor CreateWeightTensor(int row, int column, int deviceId, bool cleanWeights = false, string name = "", bool isTrainable = false, IComputeGraph graphToBind = null, RandomInitType normType = RandomInitType.None, bool needGradient = true, DType dtype = DType.Float32) { - WeightTensor r = new WeightTensor(new long[2] { row, column }, deviceId, name: name, isTrainable: isTrainable, normType: normType, graphToBind: graphToBind, needGradient: needGradient, dtype: dtype); + WeightTensor r = new WeightTensor(new long[2] { row, column }, deviceId, name: name, isTrainable: isTrainable, initType: normType, graphToBind: graphToBind, needGradient: needGradient, dtype: dtype); if (cleanWeights) { @@ -31,9 +31,9 @@ public WeightTensor CreateWeightTensor(int row, int column, int deviceId, bool c return r; } - public WeightTensor CreateWeightTensor(long[] sizes, int deviceId, bool cleanWeights = false, string name = "", IComputeGraph graphToBind = null, NormType normType = NormType.None, bool needGradient = true, DType dtype = DType.Float32) + public WeightTensor CreateWeightTensor(long[] sizes, int deviceId, bool cleanWeights = false, string name = "", IComputeGraph graphToBind = null, RandomInitType normType = RandomInitType.None, bool needGradient = true, DType dtype = DType.Float32) { - WeightTensor r = new WeightTensor(sizes, deviceId, name, normType: normType, graphToBind: graphToBind, needGradient: needGradient, dtype: dtype); + WeightTensor r = new WeightTensor(sizes, deviceId, name, initType: normType, graphToBind: graphToBind, needGradient: needGradient, dtype: dtype); if (cleanWeights) { diff --git a/Seq2SeqSharp/Utils/Misc.cs b/Seq2SeqSharp/Utils/Misc.cs index fed74519..cad834e7 100644 --- a/Seq2SeqSharp/Utils/Misc.cs +++ b/Seq2SeqSharp/Utils/Misc.cs @@ -132,7 +132,7 @@ public static (MultiProcessorNetworkWrapper, MultiProcessorNetwor if (modelMetaData.EnableSegmentEmbeddings) { - segmentEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.MaxSegmentNum, modelMetaData.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, name: "SegmentEmbedding", + segmentEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.MaxSegmentNum, modelMetaData.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "SegmentEmbedding", isTrainable: isTrainable, dtype: elementType), raDeviceIds.ToArray()); } } diff --git a/Seq2SeqSharp/Utils/ModeEnums.cs b/Seq2SeqSharp/Utils/ModeEnums.cs index 2e804ea2..c726289a 100644 --- a/Seq2SeqSharp/Utils/ModeEnums.cs +++ b/Seq2SeqSharp/Utils/ModeEnums.cs @@ -44,4 +44,10 @@ public enum VQTypeEnums INT8 = 256, INT4 = 16 } + + public enum NormEnums + { + LayerNorm = 0, + RMSNorm = 1 + } } diff --git a/TensorSharp.CUDA/CudaBasicOps.cs 
b/TensorSharp.CUDA/CudaBasicOps.cs index aa11a4d8..7043a10c 100644 --- a/TensorSharp.CUDA/CudaBasicOps.cs +++ b/TensorSharp.CUDA/CudaBasicOps.cs @@ -661,6 +661,15 @@ public Tensor BuildTriMask(Tensor result, float value, float maskedValue) public Tensor LayerNormGrad(Tensor outGrad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x, Tensor alpha, Tensor beta, float eps = 1e-09f) { return advFuncKernels.LayerNormGrad(outGrad, alphaGrad, betaGrad, inGrad, y, x, alpha, beta, eps); } + + [RegisterOpStorageType("rmsnorm", typeof(CudaStorage))] + public Tensor RMSNorm(Tensor result, Tensor src, Tensor alpha, Tensor beta, float eps = 1e-09f) { return advFuncKernels.RMSNorm(result, src, alpha, beta, eps); } + [RegisterOpStorageType("rmsnormgrad", typeof(CudaStorage))] + public Tensor RMSNormGrad(Tensor outGrad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x, Tensor alpha, Tensor beta, float eps = 1e-09f) { return advFuncKernels.RMSNormGrad(outGrad, alphaGrad, betaGrad, inGrad, y, x, alpha, beta, eps); } + + + + [RegisterOpStorageType("addlayernorm", typeof(CudaStorage))] public Tensor AddLayerNorm(Tensor result, Tensor src1, Tensor src2, Tensor alpha, Tensor beta, float eps = 1e-09f) { return advFuncKernels.AddLayerNorm(result, src1, src2, alpha, beta, eps); } [RegisterOpStorageType("addlayernormgrad", typeof(CudaStorage))] diff --git a/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs b/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs index 8ab5f496..9dd198c9 100644 --- a/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs +++ b/TensorSharp.CUDA/DeviceCode/AdvFuncKernels.cs @@ -202,6 +202,156 @@ __global__ void gLayerNormalizationGrad(float* gradX, } } + +__global__ void RMSNorm(float* out, + const float* in, + const float* alpha, + const float* beta, + int rows, + int cols, + float eps = 1e-9) { + extern __shared__ float _share[]; + + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + float* so = out + j * cols; + const float* sp = in + j * cols; + + float* _sqSum = _share; + _sqSum[threadIdx.x] = 0.0; + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float ex = sp[id]; + _sqSum[threadIdx.x] += ex * ex; + } + } + __syncthreads(); + int len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) + _sqSum[threadIdx.x] += _sqSum[threadIdx.x + skip]; + len = (len + 1) >> 1; + } + __syncthreads(); + float sigma = sqrtf(eps + (_sqSum[0] / cols)); + __syncthreads(); + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float t = alpha[id] * sp[id] / sigma; + if(beta) + t += beta[id]; + so[id] = t; + } + } + } + __syncthreads(); + } +} + + +__global__ void RMSNormGrad(float* gradX, + float* gradGamma, + float* gradBeta, + float* adj, + float* y, + float* x, + float* gamma, + float* beta, + int rows, + int cols, + float eps = 1e-9) { + extern __shared__ float shared[]; + + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + float* sum_adj = shared; + float* sum_adj_x = shared + blockDim.x; + float* sum_sqr = shared + 2 * blockDim.x; + + const float* xRow = x + j * cols; + const float* yRow = y + j * cols; + const float* adjRow = adj + j * cols; + float* gradXRow = gradX + j * cols; + + sum_adj[threadIdx.x] = 0.0f; + sum_adj_x[threadIdx.x] = 0.0f; + sum_sqr[threadIdx.x] = 0.0f; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + 
int id = tid + threadIdx.x; + if(id < cols) { + sum_adj_x[threadIdx.x] + += adjRow[id] * (yRow[id] - ((beta) ? beta[id] : 0)) / gamma[id]; + sum_adj[threadIdx.x] += adjRow[id]; + } + } + __syncthreads(); + int len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) { + sum_adj[threadIdx.x] += sum_adj[threadIdx.x + skip]; + sum_adj_x[threadIdx.x] += sum_adj_x[threadIdx.x + skip]; + } + len = (len + 1) >> 1; + } + __syncthreads(); + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float ex = xRow[id]; + sum_sqr[threadIdx.x] += ex * ex; + } + } + + __syncthreads(); + len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) + sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip]; + len = (len + 1) >> 1; + } + __syncthreads(); + float sigma = sqrtf(eps + (sum_sqr[0] / cols)); + __syncthreads(); + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float grad_x = 0.0f; + float x_hat = (yRow[id] - ((beta) ? beta[id] : 0)) / gamma[id]; + grad_x += cols * adjRow[id]; + grad_x -= sum_adj[0]; + grad_x -= sum_adj_x[0] * x_hat; + grad_x /= (cols * sigma); + + float valX = gamma[id] * grad_x; + float sign = (0.f < valX) - (valX < 0.f); + valX = fabs(valX) > 1000.0f ? sign * 1000.0f : valX; + + gradXRow[id] += valX; + atomicAdd(gradGamma + id, adjRow[id] * x_hat); + if(beta) { + atomicAdd(gradBeta + id, adjRow[id]); + } + } + } + } + __syncthreads(); + } +} + __global__ void gAddLNormalization(float* out, const float* in1, const float* in2, @@ -1109,6 +1259,160 @@ __global__ void gLayerNormalizationGradHalf(__half* gradX, } } + + +__global__ void RMSNormHalf(__half* out, + const __half* in, + const __half* alpha, + const __half* beta, + int rows, + int cols, + float eps = 1e-9) { + extern __shared__ float _share[]; + + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + __half* so = out + j * cols; + const __half* sp = in + j * cols; + + float* _sqSum = _share; + _sqSum[threadIdx.x] = 0.0; + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float ex = __half2float(sp[id]); + _sqSum[threadIdx.x] += ex * ex; + } + } + __syncthreads(); + int len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) + _sqSum[threadIdx.x] += _sqSum[threadIdx.x + skip]; + len = (len + 1) >> 1; + } + __syncthreads(); + float sigma = sqrtf(eps + (_sqSum[0] / cols)); + __syncthreads(); + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float t = __half2float(alpha[id]) * __half2float(sp[id]) / sigma; + if(beta) + t += __half2float(beta[id]); + so[id] = __float2half(t); + } + } + } + __syncthreads(); + } +} + + + +__global__ void RMSNormGradHalf(__half* gradX, + __half* gradGamma, + __half* gradBeta, + __half* adj, + __half* y, + __half* x, + __half* gamma, + __half* beta, + int rows, + int cols, + float eps = 1e-9) { + extern __shared__ float shared[]; + + for(int bid = 0; bid < rows; bid += gridDim.x) { + int j = bid + blockIdx.x; + if(j < rows) { + float* sum_adj = shared; + float* sum_adj_x = shared + blockDim.x; + float* sum_sqr = shared + 2 * blockDim.x; + + const __half* xRow = x + j * cols; + const __half* yRow = y + j * cols; + const __half* adjRow = adj + j * cols; + __half* gradXRow = gradX + j * 
cols; + + sum_adj[threadIdx.x] = 0.0f; + sum_adj_x[threadIdx.x] = 0.0f; + sum_sqr[threadIdx.x] = 0.0f; + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + sum_adj_x[threadIdx.x] + += __half2float(adjRow[id]) * (__half2float(yRow[id]) - ((beta) ? __half2float(beta[id]) : 0)) / __half2float(gamma[id]); + sum_adj[threadIdx.x] += __half2float(adjRow[id]); + } + } + __syncthreads(); + int len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) { + sum_adj[threadIdx.x] += sum_adj[threadIdx.x + skip]; + sum_adj_x[threadIdx.x] += sum_adj_x[threadIdx.x + skip]; + } + len = (len + 1) >> 1; + } + __syncthreads(); + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float ex = __half2float(xRow[id]); + sum_sqr[threadIdx.x] += ex * ex; + } + } + + __syncthreads(); + len = blockDim.x; + while(len != 1) { + __syncthreads(); + int skip = (len + 1) >> 1; + if(threadIdx.x < (len >> 1)) + sum_sqr[threadIdx.x] += sum_sqr[threadIdx.x + skip]; + len = (len + 1) >> 1; + } + __syncthreads(); + float sigma = sqrtf(eps + (sum_sqr[0] / cols)); + __syncthreads(); + + for(int tid = 0; tid < cols; tid += blockDim.x) { + int id = tid + threadIdx.x; + if(id < cols) { + float grad_x = 0.0f; + float x_hat = (__half2float(yRow[id]) - ((beta) ? __half2float(beta[id]) : 0)) / __half2float(gamma[id]); + grad_x += cols * __half2float(adjRow[id]); + grad_x -= sum_adj[0]; + grad_x -= sum_adj_x[0] * x_hat; + grad_x /= (cols * sigma); + + float valX = __half2float(gamma[id]) * grad_x; + float sign = (0.f < valX) - (valX < 0.f); + valX = fabs(valX) > 1000.0f ? sign * 1000.0f : valX; + + gradXRow[id] = __hadd(gradXRow[id], __float2half(valX)); + atomicAdd(gradGamma + id, __float2half(__half2float(adjRow[id]) * x_hat)); + if(beta) { + atomicAdd(gradBeta + id, adjRow[id]); + } + } + } + } + __syncthreads(); + } +} + + + __global__ void RoPEGradHalf(__half* __restrict__ grad, __half* __restrict__ adj, int rows, int cols, int seqLen) { for(int bid = 0; bid < rows; bid += gridDim.x) @@ -1632,6 +1936,53 @@ private void LayerNormGrad(TSCudaContext context, Tensor outGrad, Tensor alphaGr } + public Tensor RMSNormGrad(Tensor outGrad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x, Tensor alpha, Tensor beta, float eps = 1e-9f) + { + TSCudaContext context = CudaHelpers.TSContextForTensor(inGrad); + Tensor writeTarget = TensorResultBuilder.GetWriteTarget(outGrad, inGrad, false, inGrad.Sizes); + RMSNormGrad(context, writeTarget, alphaGrad, betaGrad, inGrad, y, x, alpha, beta, eps); + + return writeTarget; + } + + + private void RMSNormGrad(TSCudaContext context, Tensor outGrad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x, Tensor alpha, Tensor beta, float eps = 1e-9f) + { + CudaContext cudaContext = context.CudaContextForTensor(inGrad); + + cudaContext.SetCurrent(); + + int ndim = inGrad.DimensionCount; + long storageSize = TensorDimensionHelpers.GetStorageSize(inGrad.Sizes, inGrad.Strides); + long cols = inGrad.Sizes[ndim - 1]; + + if (storageSize % cols != 0) + { + throw new Exception($"Invalid tensor storage size = '{storageSize}', and cols = '{cols}'"); + } + + long rows = storageSize / cols; + + dim3 block = new dim3((uint)Math.Min(512, cols)); + dim3 grid = new dim3((uint)Math.Min(1024, ApplyUtils.CeilDiv(rows, block.y))); + + CUdeviceptr outGradPtr = CudaHelpers.GetBufferStart(outGrad); + CUdeviceptr alphaGradPtr = 
CudaHelpers.GetBufferStart(alphaGrad); + CUdeviceptr betaGradPtr = CudaHelpers.GetBufferStart(betaGrad); + CUdeviceptr inGradPtr = CudaHelpers.GetBufferStart(inGrad); + CUdeviceptr yPtr = CudaHelpers.GetBufferStart(y); + CUdeviceptr xPtr = CudaHelpers.GetBufferStart(x); + CUdeviceptr alphaPtr = CudaHelpers.GetBufferStart(alpha); + CUdeviceptr betaPtr = CudaHelpers.GetBufferStart(beta); + + string kernelName = "RMSNormGrad"; + if (outGrad.ElementType == DType.Float16) + { + kernelName = "RMSNormGradHalf"; + } + Invoke(context, cudaContext, kernelName, grid, block, block.x * sizeof(float) * 4, CUstream.NullStream, outGradPtr, alphaGradPtr, betaGradPtr, inGradPtr, yPtr, xPtr, alphaPtr, betaPtr, rows, cols, eps); + } + public void AddLayerNormGrad(Tensor out1Grad, Tensor out2Grad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x1, Tensor x2, Tensor alpha, Tensor beta, float eps = 1e-9f) { TSCudaContext context = CudaHelpers.TSContextForTensor(inGrad); @@ -1722,6 +2073,53 @@ private void LayerNorm(TSCudaContext context, Tensor result, Tensor src, Tensor } + + public Tensor RMSNorm(Tensor result, Tensor src, Tensor alpha, Tensor beta, float eps = 1e-9f) + { + TSCudaContext context = CudaHelpers.TSContextForTensor(src); + Tensor writeTarget = TensorResultBuilder.GetWriteTarget(result, src, false, src.Sizes); + RMSNorm(context, writeTarget, src, alpha, beta, eps); + + return writeTarget; + } + + + private void RMSNorm(TSCudaContext context, Tensor result, Tensor src, Tensor alpha, Tensor beta, float eps = 1e-9f) + { + CudaContext cudaContext = context.CudaContextForTensor(src); + + cudaContext.SetCurrent(); + + int ndim = src.DimensionCount; + long storageSize = TensorDimensionHelpers.GetStorageSize(src.Sizes, src.Strides); + long cols = src.Sizes[ndim - 1]; + + if (storageSize % cols != 0) + { + throw new Exception($"Invalid tensor storage size = '{storageSize}', and cols = '{cols}'"); + } + + long rows = storageSize / cols; + + + dim3 block = new dim3((uint)Math.Min(512, cols)); + dim3 grid = new dim3((uint)Math.Min(1024, ApplyUtils.CeilDiv(rows, block.y))); + + CUdeviceptr resultPtr = CudaHelpers.GetBufferStart(result); + CUdeviceptr srcPtr = CudaHelpers.GetBufferStart(src); + CUdeviceptr alphaPtr = CudaHelpers.GetBufferStart(alpha); + CUdeviceptr betaPtr = CudaHelpers.GetBufferStart(beta); + + string kernelName = "RMSNorm"; + if (src.ElementType == DType.Float16) + { + kernelName = "RMSNormHalf"; + } + + Invoke(context, cudaContext, kernelName, grid, block, block.x * sizeof(float), CUstream.NullStream, resultPtr, srcPtr, alphaPtr, betaPtr, rows, cols, eps); + + } + public Tensor AddLayerNorm(Tensor result, Tensor src1, Tensor src2, Tensor alpha, Tensor beta, float eps = 1e-9f) { TSCudaContext context = CudaHelpers.TSContextForTensor(src1); diff --git a/TensorSharp/Cpu/CpuBasicOps.cs b/TensorSharp/Cpu/CpuBasicOps.cs index c956cee0..dcd28ab8 100644 --- a/TensorSharp/Cpu/CpuBasicOps.cs +++ b/TensorSharp/Cpu/CpuBasicOps.cs @@ -915,12 +915,35 @@ public Tensor LayerNormGrad(Tensor result, Tensor gradGamma_, Tensor gradBeta_, } catch (Exception err) { - Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"LayerNormGrad exception: '{err.Message}', CallStack:'{err.StackTrace}'"); + Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"{nameof(LayerNormGrad)} exception: '{err.Message}', CallStack:'{err.StackTrace}'"); throw; } } + [RegisterOpStorageType("rmsnorm", typeof(CpuStorage))] + public Tensor RMSNorm(Tensor result, Tensor src, Tensor gamma_, Tensor beta_, float eps) + { + 
Tensor writeTarget = TensorResultBuilder.GetWriteTarget(result, src, true, src.Sizes); + TensorApplyCPU.RMSNorm(writeTarget, src, gamma_, beta_, eps, (int)src.Sizes[0], (int)src.Sizes[1]); + return writeTarget; + } + + [RegisterOpStorageType("rmsnormgrad", typeof(CpuStorage))] + public Tensor RMSNormGrad(Tensor result, Tensor gradGamma_, Tensor gradBeta_, Tensor adj_, Tensor y_, Tensor x_, Tensor gamma_, Tensor beta_, float eps) + { + try + { + Tensor writeTarget = TensorResultBuilder.GetWriteTarget(result, adj_, false, adj_.Sizes); + TensorApplyCPU.RMSNormGrad(writeTarget, gradGamma_, gradBeta_, adj_, y_, x_, gamma_, beta_, (int)adj_.Sizes[0], (int)adj_.Sizes[1], eps); + return writeTarget; + } + catch (Exception err) + { + Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"{nameof(RMSNormGrad)} exception: '{err.Message}', CallStack:'{err.StackTrace}'"); + throw; + } + } private readonly MethodInfo addlayerNorm_func = NativeWrapper.GetMethod("TS_AddLayerNorm"); [RegisterOpStorageType("addlayernorm", typeof(CpuStorage))] diff --git a/TensorSharp/Ops.cs b/TensorSharp/Ops.cs index 4ac5485a..b76fc016 100644 --- a/TensorSharp/Ops.cs +++ b/TensorSharp/Ops.cs @@ -226,6 +226,11 @@ public static Tensor LayerNormGrad(Tensor outGrad, Tensor alphaGrad, Tensor beta return (Tensor)OpRegistry.Invoke("layernormgrad", outGrad, alphaGrad, betaGrad, inGrad, y, x, alpha, beta, eps); } + public static Tensor RMSNorm(Tensor result, Tensor src, Tensor alpha, Tensor beta, float eps = 1e-09f) { return (Tensor)OpRegistry.Invoke("rmsnorm", result, src, alpha, beta, eps); } + public static Tensor RMSNormGrad(Tensor outGrad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x, Tensor alpha, Tensor beta, float eps = 1e-09f) + { + return (Tensor)OpRegistry.Invoke("rmsnormgrad", outGrad, alphaGrad, betaGrad, inGrad, y, x, alpha, beta, eps); + } public static Tensor AddLayerNorm(Tensor result, Tensor src1, Tensor src2, Tensor alpha, Tensor beta, float eps = 1e-09f) { return (Tensor)OpRegistry.Invoke("addlayernorm", result, src1, src2, alpha, beta, eps); } public static Tensor AddLayerNormGrad(Tensor out1Grad, Tensor out2Grad, Tensor alphaGrad, Tensor betaGrad, Tensor inGrad, Tensor y, Tensor x1, Tensor x2, Tensor alpha, Tensor beta, float eps = 1e-09f) { return (Tensor)OpRegistry.Invoke("addlayernormgrad", out1Grad, out2Grad, alphaGrad, betaGrad, inGrad, y, x1, x2, alpha, beta, eps); } diff --git a/TensorSharp/TensorApplyCPU.cs b/TensorSharp/TensorApplyCPU.cs index fb03a323..a673490d 100644 --- a/TensorSharp/TensorApplyCPU.cs +++ b/TensorSharp/TensorApplyCPU.cs @@ -1405,7 +1405,77 @@ unsafe static public void LayerNorm(Tensor out_, } - unsafe static public void LayerNormGrad(Tensor gradX_, + + unsafe static public void RMSNorm(Tensor out_, + Tensor in_, + Tensor gamma_, + Tensor beta_, + float eps, + int rows, + int cols) + { + float* outPtr = (float*)CpuNativeHelpers.GetBufferStart(out_); + float* inPtr = (float*)CpuNativeHelpers.GetBufferStart(in_); + float* alpha = (float*)CpuNativeHelpers.GetBufferStart(gamma_); + float* beta = (beta_ != null) ? 
(float*)CpuNativeHelpers.GetBufferStart(beta_) : null; + + for (int j = 0; j < rows; ++j) + { + float* so = outPtr + j * cols; + float* sp = inPtr + j * cols; + + Span spanSP = new Span(sp, cols); + int vectorSize = Vector.Count; + int i = 0; + float sqSum = 0.0f; + + for (i = 0; i < cols - vectorSize; i += vectorSize) + { + Vector vecSp = new Vector(spanSP.Slice(i)); + sqSum += Vector.Dot(vecSp, vecSp); + } + for (; i < cols; ++i) + { + float ex = sp[i]; + sqSum += ex * ex; + } + + float sigma = (float)Math.Sqrt(eps + sqSum / cols); + + Span spanSO = new Span(so, cols); + Span spanAlpha = new Span(alpha, cols); + Span spanBeta = (beta != null) ? new Span(beta, cols) : null; + Vector vecSigma = new Vector(sigma); + + for (i = 0; i < cols - vectorSize; i += vectorSize) + { + Vector vecSp = new Vector(spanSP.Slice(i)); + Vector vecAlpha = new Vector(spanAlpha.Slice(i)); + + Vector vecT = vecAlpha * (vecSp / vecSigma); + + if (spanBeta != null) + { + Vector vecBeta = new Vector(spanBeta.Slice(i)); + vecT += vecBeta; + } + + vecT.CopyTo(spanSO.Slice(i)); + } + for (; i < cols; ++i) + { + float t = alpha[i] * (sp[i] / sigma); + if (beta != null) + { + t += beta[i]; + } + + so[i] = t; + } + } + } + + unsafe static public void LayerNormGrad(Tensor gradX_, Tensor gradGamma_, Tensor gradBeta_, Tensor adj_, @@ -1518,7 +1588,113 @@ unsafe static public void LayerNormGrad(Tensor gradX_, } - unsafe static public void Adam(Tensor tw, Tensor tg, Tensor tv, Tensor tm, int rows, int cols, int batchSize, float step_size, float clipval, float regc, float decay_rate_v, float decay_rate_m, int iter, float eps) + unsafe static public void RMSNormGrad(Tensor gradX_, + Tensor gradGamma_, + Tensor gradBeta_, + Tensor adj_, + Tensor y_, + Tensor x_, + Tensor gamma_, + Tensor beta_, + int rows, + int cols, + float eps) + { + float* gradX = (float*)CpuNativeHelpers.GetBufferStart(gradX_); + float* gradGamma = (float*)CpuNativeHelpers.GetBufferStart(gradGamma_); + float* gradBeta = gradBeta_ != null ? (float*)CpuNativeHelpers.GetBufferStart(gradBeta_) : null; + float* adj = (float*)CpuNativeHelpers.GetBufferStart(adj_); + float* y = (float*)CpuNativeHelpers.GetBufferStart(y_); + float* x = (float*)CpuNativeHelpers.GetBufferStart(x_); + float* gamma = (float*)CpuNativeHelpers.GetBufferStart(gamma_); + float* beta = beta_ != null ? (float*)CpuNativeHelpers.GetBufferStart(beta_) : null; + + if (beta != null) + { + for (int j = 0; j < rows; ++j) + { + float* xRow = x + j * cols; + float* yRow = y + j * cols; + float* adjRow = adj + j * cols; + float* gradXRow = gradX + j * cols; + + float sum_adj = 0.0f; + float sum_adj_x = 0.0f; + float sum_sqr = 0.0f; + + for (int i = 0; i < cols; ++i) + { + sum_adj_x += adjRow[i] * (yRow[i] - (beta != null ? 
beta[i] : 0.0f)) / gamma[i]; + sum_adj += adjRow[i]; + } + + for (int i = 0; i < cols; ++i) + { + float ex = xRow[i]; + sum_sqr += ex * ex; + } + + float sigma = (float)Math.Sqrt(eps + sum_sqr / cols); + for (int i = 0; i < cols; ++i) + { + float grad_x = 0.0f; + float x_hat = (yRow[i] - beta[i]) / gamma[i]; + grad_x += cols * adjRow[i]; + grad_x -= sum_adj; + grad_x -= sum_adj_x * x_hat; + grad_x /= cols * sigma; + + gradXRow[i] += gamma[i] * grad_x; + gradGamma[i] += adjRow[i] * x_hat; + gradBeta[i] += adjRow[i]; + } + } + } + else + { + for (int j = 0; j < rows; ++j) + { + float* xRow = x + j * cols; + float* yRow = y + j * cols; + float* adjRow = adj + j * cols; + float* gradXRow = gradX + j * cols; + + float sum_adj = 0.0f; + float sum_adj_x = 0.0f; + float sum_sqr = 0.0f; + + for (int i = 0; i < cols; ++i) + { + sum_adj_x += adjRow[i] * (yRow[i] - (beta != null ? beta[i] : 0.0f)) / gamma[i]; + sum_adj += adjRow[i]; + } + + for (int i = 0; i < cols; ++i) + { + float ex = xRow[i]; + sum_sqr += ex * ex; + } + + float sigma = (float)Math.Sqrt(eps + sum_sqr / cols); + + for (int i = 0; i < cols; ++i) + { + float grad_x = 0.0f; + float x_hat = yRow[i] / gamma[i]; + grad_x += cols * adjRow[i]; + grad_x -= sum_adj; + grad_x -= sum_adj_x * x_hat; + grad_x /= cols * sigma; + + gradXRow[i] += gamma[i] * grad_x; + gradGamma[i] += adjRow[i] * x_hat; + } + } + } + } + + + unsafe static public void Adam(Tensor tw, Tensor tg, Tensor tv, Tensor tm, int rows, int cols, int batchSize, float step_size, float clipval, float regc, float decay_rate_v, float decay_rate_m, int iter, float eps) { float* w = (float*)CpuNativeHelpers.GetBufferStart(tw); float* g = (float*)CpuNativeHelpers.GetBufferStart(tg); diff --git a/Tools/Seq2SeqConsole/Program.cs b/Tools/Seq2SeqConsole/Program.cs index a109be3b..345777d3 100644 --- a/Tools/Seq2SeqConsole/Program.cs +++ b/Tools/Seq2SeqConsole/Program.cs @@ -65,7 +65,7 @@ private static void Main(string[] args) { // Load train corpus var trainCorpus = new Seq2SeqCorpus(corpusFilePath: opts.TrainCorpusPath, srcLangName: opts.SrcLang, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch, - maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence); + maxSrcSentLength: opts.MaxSrcSentLength, maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath); // Load valid corpus var validCorpusList = new List();
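
Note: the RMS normalization added by this patch computes, per row, y = alpha * x / sqrt(eps + mean(x^2)) (plus beta when a bias tensor is supplied), i.e. layer normalization without mean centering. As a quick reference, the following standalone C# sketch (not part of the patch; the type and method names are illustrative) mirrors what the new CPU/CUDA forward kernels and the IComputeGraph.RMSNorm op compute for a single row:

    using System;

    static class RmsNormSketch
    {
        // Per-row RMSNorm, matching the forward kernels above:
        // sigma = sqrt(eps + mean(x^2)); y[i] = alpha[i] * x[i] / sigma (+ beta[i] if present).
        public static float[] Forward(float[] x, float[] alpha, float[] beta, float eps = 1e-9f)
        {
            float sqSum = 0.0f;
            for (int i = 0; i < x.Length; i++)
                sqSum += x[i] * x[i];

            float sigma = (float)Math.Sqrt(eps + sqSum / x.Length);

            var y = new float[x.Length];
            for (int i = 0; i < x.Length; i++)
            {
                y[i] = alpha[i] * (x[i] / sigma);
                if (beta != null)
                    y[i] += beta[i];
            }
            return y;
        }
    }

To train with it, set "NormType": "RMSNorm" in the options/JSON config (the new NormEnums value, defaulting to LayerNorm); the setting is carried through model.NormType into the encoder and decoder constructors earlier in this patch.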