From 54945f14e2bb4084a7a773fa87f134aaad1f3107 Mon Sep 17 00:00:00 2001 From: Zhongkai Fu Date: Mon, 4 Sep 2023 10:54:12 -0700 Subject: [PATCH] Only keep RoPE for self-attention layer --- Seq2SeqSharp/Applications/Encoder.cs | 20 +++-- Seq2SeqSharp/Applications/GPT.cs | 2 +- Seq2SeqSharp/Applications/Seq2Seq.cs | 13 +-- .../Applications/Seq2SeqClassification.cs | 11 +-- .../Applications/SeqClassification.cs | 12 +-- Seq2SeqSharp/Applications/SeqLabel.cs | 12 +-- Seq2SeqSharp/Applications/SeqSimilarity.cs | 19 ++-- Seq2SeqSharp/Layers/MultiHeadAttention.cs | 16 +--- Seq2SeqSharp/Utils/Misc.cs | 14 ++- Seq2SeqSharp/Utils/PositionEmbedding.cs | 86 +++++++++++++++++++ 10 files changed, 151 insertions(+), 54 deletions(-) create mode 100644 Seq2SeqSharp/Utils/PositionEmbedding.cs diff --git a/Seq2SeqSharp/Applications/Encoder.cs b/Seq2SeqSharp/Applications/Encoder.cs index 7f176d0f..2d521570 100644 --- a/Seq2SeqSharp/Applications/Encoder.cs +++ b/Seq2SeqSharp/Applications/Encoder.cs @@ -58,22 +58,22 @@ public static MultiProcessorNetworkWrapper CreateEncoders(IModel model } static public IWeightTensor Run(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, IEncoder encoder, IModel modelMetaData, ShuffleEnums shuffleType, - IWeightTensor srcEmbedding, IWeightTensor segmentEmbedding, List> srcSntsIds, float[] originalSrcLengths, bool amp = false) + IWeightTensor srcEmbedding, IWeightTensor posEmbeddings, IWeightTensor segmentEmbedding, List> srcSntsIds, float[] originalSrcLengths, bool amp = false) { // Reset networks encoder.Reset(computeGraph.GetWeightFactory(), srcSntsIds.Count); - IWeightTensor encOutput = InnerRunner(computeGraph, srcSntsIds, originalSrcLengths, shuffleType, encoder, modelMetaData, srcEmbedding, segmentEmbedding, amp); + IWeightTensor encOutput = InnerRunner(computeGraph, srcSntsIds, originalSrcLengths, shuffleType, encoder, modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding, amp); return encOutput; } - public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData, IWeightTensor srcEmbedding, IWeightTensor segmentEmbedding, int groupId) + public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData, IWeightTensor srcEmbedding, IWeightTensor posEmbeddings, IWeightTensor segmentEmbedding, int groupId) { var contextTokens = InsertCLSToken(sntPairBatch.GetSrcTokens(groupId)); var originalSrcContextLength = BuildInTokens.PadSentences(contextTokens); var contextTokenIds = modelMetaData.SrcVocab.GetWordIndex(contextTokens); - IWeightTensor encContextOutput = InnerRunner(computeGraph, contextTokenIds, originalSrcContextLength, shuffleType, encoder, modelMetaData, srcEmbedding, segmentEmbedding); + IWeightTensor encContextOutput = InnerRunner(computeGraph, contextTokenIds, originalSrcContextLength, shuffleType, encoder, modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding); int contextPaddedLen = contextTokens[0].Count; float[] contextCLSIdxs = new float[sntPairBatch.BatchSize]; @@ -88,14 +88,14 @@ public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph compu } static private IWeightTensor InnerRunner(IComputeGraph computeGraph, List> srcTokensList, float[] originalSrcLengths, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData, - IWeightTensor srcEmbedding, IWeightTensor segmentEmbedding, bool amp = false) + IWeightTensor srcEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding, bool amp = false) { int batchSize = srcTokensList.Count; int srcSeqPaddedLen = srcTokensList[0].Count; IWeightTensor srcSelfMask = (shuffleType == ShuffleEnums.NoPaddingInSrc || shuffleType == ShuffleEnums.NoPadding || batchSize == 1) ? null : computeGraph.BuildPadSelfMask(srcSeqPaddedLen, originalSrcLengths, elementType: amp ? DType.Float16 : DType.Float32); // The length of source sentences are same in a single mini-batch, so we don't have source mask. // Encoding input source sentences - var encOutput = RunEncoder(computeGraph, srcTokensList, encoder, modelMetaData, srcEmbedding, srcSelfMask, segmentEmbedding, amp: amp); + var encOutput = RunEncoder(computeGraph, srcTokensList, encoder, modelMetaData, srcEmbedding, srcSelfMask, posEmbedding, segmentEmbedding, amp: amp); if (srcSelfMask != null) { srcSelfMask.Dispose(); @@ -114,11 +114,17 @@ static private IWeightTensor InnerRunner(IComputeGraph computeGraph, List /// /// - static private IWeightTensor RunEncoder(IComputeGraph g, List> seqs, IEncoder encoder, IModel modelMetaData, IWeightTensor embeddings, IWeightTensor selfMask, + static private IWeightTensor RunEncoder(IComputeGraph g, List> seqs, IEncoder encoder, IModel modelMetaData, IWeightTensor embeddings, IWeightTensor selfMask, IWeightTensor posEmbeddings, IWeightTensor segmentEmbeddings, bool amp = false) { int batchSize = seqs.Count; var inputEmbs = TensorUtils.CreateTokensEmbeddings(seqs, g, embeddings, segmentEmbeddings, modelMetaData.SrcVocab, (float)Math.Sqrt(embeddings.Columns), enableTagEmbedding: modelMetaData.EnableTagEmbeddings, amp: amp); + + if (modelMetaData.EncoderType == EncoderTypeEnums.Transformer && posEmbeddings != null) + { + inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, 0.0f); + } + return encoder.Encode(inputEmbs, batchSize, g, selfMask); } } diff --git a/Seq2SeqSharp/Applications/GPT.cs b/Seq2SeqSharp/Applications/GPT.cs index 2aca8324..d7e71a52 100644 --- a/Seq2SeqSharp/Applications/GPT.cs +++ b/Seq2SeqSharp/Applications/GPT.cs @@ -125,7 +125,7 @@ private bool CreateTrainableParameters(IModel model) m_decoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: (m_options.Task == ModeEnums.Train), learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds); - m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType, isTrainable: (m_options.Task == ModeEnums.Train)); + (_, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType, isTrainable: (m_options.Task == ModeEnums.Train)); m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable && (m_options.Task == ModeEnums.Train), m_options.DecoderStartLearningRateFactor, elementType); return (true); diff --git a/Seq2SeqSharp/Applications/Seq2Seq.cs b/Seq2SeqSharp/Applications/Seq2Seq.cs index bd05c073..6db2d42f 100644 --- a/Seq2SeqSharp/Applications/Seq2Seq.cs +++ b/Seq2SeqSharp/Applications/Seq2Seq.cs @@ -33,6 +33,7 @@ public class Seq2Seq : BaseSeq2SeqFramework private MultiProcessorNetworkWrapper m_decoder; //The decoders over devices private MultiProcessorNetworkWrapper m_decoderFFLayer; //The feed forward layers over devices after all layers in decoder + private MultiProcessorNetworkWrapper m_posEmbedding = null; private MultiProcessorNetworkWrapper m_segmentEmbedding; private MultiProcessorNetworkWrapper m_pointerGenerator; @@ -109,7 +110,7 @@ private bool CreateTrainableParameters(IModel model) m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, elementType: elementType); m_decoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds); - m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model, elementType: elementType); + (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model, elementType: elementType, createAPE: false); (m_srcEmbedding, m_tgtEmbedding) = CreateSrcTgtEmbeddings(model, raDeviceIds, m_options.IsSrcEmbeddingTrainable, m_options.IsTgtEmbeddingTrainable, m_options.EncoderStartLearningRateFactor, m_options.DecoderStartLearningRateFactor, elementType: elementType); @@ -141,7 +142,7 @@ public void VQModel() /// /// Get networks on specific devices /// - private (IEncoder, IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IFeedForwardLayer) GetNetworksOnDeviceAt(int deviceId) + private (IEncoder, IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) { var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId); return (m_encoder.GetNetworkOnDevice(deviceIdIdx), @@ -149,7 +150,7 @@ public void VQModel() m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx), m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx), m_modelMetaData.SharedEmbeddings ? m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx) : m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx), - m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx)); + m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx)); } private string GenerateCacheKey(List> strs) @@ -175,7 +176,7 @@ private string GenerateCacheKey(List> strs) /// The cost of forward part public override List RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining) { - (var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator) = GetNetworksOnDeviceAt(computeGraph.DeviceId); + (var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId); var srcSnts = sntPairBatch.GetSrcTokens(0); var originalSrcLengths = BuildInTokens.PadSentences(srcSnts); @@ -193,7 +194,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu string cacheKey = GenerateCacheKey(srcSnts); if (!m_memoryCache.TryGetValue(cacheKey, out encOutput)) { - encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths); // Shape: [batchsize * seqLen, embedding_dim] + encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths); // Shape: [batchsize * seqLen, embedding_dim] var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1); m_memoryCache.Set(cacheKey, encOutput.CopyWeightsRef($"cache_{encOutput.Name}", false, graphToBind: null), cacheEntryOptions); @@ -202,7 +203,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu else { // Compute src tensor - encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths, amp:m_options.AMP); + encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths, amp:m_options.AMP); } List nrs = new List(); diff --git a/Seq2SeqSharp/Applications/Seq2SeqClassification.cs b/Seq2SeqSharp/Applications/Seq2SeqClassification.cs index d951d794..f4eaa6bf 100644 --- a/Seq2SeqSharp/Applications/Seq2SeqClassification.cs +++ b/Seq2SeqSharp/Applications/Seq2SeqClassification.cs @@ -35,6 +35,7 @@ public class Seq2SeqClassification : BaseSeq2SeqFramework m_encoderFFLayer; //The feed forward layers over devices after all layers in encoder private MultiProcessorNetworkWrapper m_decoderFFLayer; //The feed forward layers over devices after all layers in decoder + private MultiProcessorNetworkWrapper m_posEmbedding = null; private MultiProcessorNetworkWrapper m_segmentEmbedding; private readonly ShuffleEnums m_shuffleType = ShuffleEnums.Random; @@ -91,7 +92,7 @@ private bool CreateTrainableParameters(IModel model) m_decoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: true), DeviceIds); - m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model); + (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model, createAPE: false); (m_srcEmbedding, m_tgtEmbedding) = CreateSrcTgtEmbeddings(model, raDeviceIds, m_options.IsSrcEmbeddingTrainable, m_options.IsTgtEmbeddingTrainable, m_options.EncoderStartLearningRateFactor, m_options.DecoderStartLearningRateFactor); return true; } @@ -101,7 +102,7 @@ private bool CreateTrainableParameters(IModel model) /// /// /// - private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) + private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) { var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId); return (m_encoder.GetNetworkOnDevice(deviceIdIdx), @@ -110,7 +111,7 @@ private bool CreateTrainableParameters(IModel model) m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx), m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx), m_modelMetaData.SharedEmbeddings ? m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx) : m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx), - m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx)); + m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx)); } /// @@ -123,13 +124,13 @@ private bool CreateTrainableParameters(IModel model) /// The cost of forward part public override List RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining) { - (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(computeGraph.DeviceId); + (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor segmentEmbedding, IWeightTensor posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId); var srcSnts = sntPairBatch.GetSrcTokens(0); var originalSrcLengths = BuildInTokens.PadSentences(srcSnts); var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts); - IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths); + IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths); List nrs = new List(); int srcSeqPaddedLen = srcSnts[0].Count; diff --git a/Seq2SeqSharp/Applications/SeqClassification.cs b/Seq2SeqSharp/Applications/SeqClassification.cs index e53d4ca4..bb6a34d8 100644 --- a/Seq2SeqSharp/Applications/SeqClassification.cs +++ b/Seq2SeqSharp/Applications/SeqClassification.cs @@ -32,6 +32,8 @@ public class SeqClassification : BaseSeq2SeqFramework private MultiProcessorNetworkWrapper m_encoder; //The encoders over devices. private MultiProcessorNetworkWrapper m_segmentEmbedding; + private MultiProcessorNetworkWrapper m_positionalEmbeddings = null; + private readonly ShuffleEnums m_shuffleType = ShuffleEnums.Random; private readonly SeqClassificationOptions m_options; @@ -78,7 +80,7 @@ private bool CreateTrainableParameters(IModel model) m_encoderFFLayer[i] = new MultiProcessorNetworkWrapper(new FeedForwardLayer($"FeedForward_Encoder_{i}", model.HiddenDim, model.ClsVocabs[i].Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: true), DeviceIds); } - m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, m_options.MaxSentLength, model); + (m_positionalEmbeddings, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, m_options.MaxSentLength, model, createAPE: false); Logger.WriteLine($"Creating embeddings. Shape = '({model.SrcVocab.Count} ,{model.EncoderEmbeddingDim})'"); m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", @@ -92,7 +94,7 @@ private bool CreateTrainableParameters(IModel model) /// /// /// - private (IEncoder, IWeightTensor, List, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) + private (IEncoder, IWeightTensor, List, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) { var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId); @@ -105,7 +107,7 @@ private bool CreateTrainableParameters(IModel model) return (m_encoder.GetNetworkOnDevice(deviceIdIdx), m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx), feedForwardLayers, - m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx)); + m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_positionalEmbeddings?.GetNetworkOnDevice(deviceIdIdx)); } /// @@ -120,11 +122,11 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu { List nrs = new List(); - (IEncoder encoder, IWeightTensor srcEmbedding, List encoderFFLayer, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(computeGraph.DeviceId); + (IEncoder encoder, IWeightTensor srcEmbedding, List encoderFFLayer, IWeightTensor segmentEmbedding, IWeightTensor posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId); var srcSnts = sntPairBatch.GetSrcTokens(0); var originalSrcLengths = BuildInTokens.PadSentences(srcSnts); var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts); - IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths); + IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths); int srcSeqPaddedLen = srcSnts[0].Count; int batchSize = srcSnts.Count; diff --git a/Seq2SeqSharp/Applications/SeqLabel.cs b/Seq2SeqSharp/Applications/SeqLabel.cs index 97362034..b808a668 100644 --- a/Seq2SeqSharp/Applications/SeqLabel.cs +++ b/Seq2SeqSharp/Applications/SeqLabel.cs @@ -30,6 +30,8 @@ public class SeqLabel : BaseSeq2SeqFramework private MultiProcessorNetworkWrapper m_ffLayer; //The feed forward layers over over devices. private MultiProcessorNetworkWrapper m_segmentEmbedding; + private MultiProcessorNetworkWrapper m_posEmbedding = null; + private readonly ShuffleEnums m_shuffleType = ShuffleEnums.Random; private readonly SeqLabelOptions m_options; @@ -93,7 +95,7 @@ private bool CreateTrainableParameters(IModel model) m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, name: "SrcEmbeddings", isTrainable: true), DeviceIds); - m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, m_options.MaxSentLength, model); + (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, m_options.MaxSentLength, model, createAPE: false); return true; } @@ -101,12 +103,12 @@ private bool CreateTrainableParameters(IModel model) /// /// Get networks on specific devices /// - private (IEncoder, IWeightTensor, IWeightTensor, FeedForwardLayer) GetNetworksOnDeviceAt(int deviceId) + private (IEncoder, IWeightTensor, IWeightTensor, FeedForwardLayer, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) { var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId); return (m_encoder.GetNetworkOnDevice(deviceIdIdx), m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx), - m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_ffLayer.GetNetworkOnDevice(deviceIdIdx)); + m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_ffLayer.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx)); } /// @@ -124,7 +126,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph g, IS var srcSnts = sntPairBatch.GetSrcTokens(0); var tgtSnts = sntPairBatch.GetTgtTokens(0); - (var encoder, var srcEmbedding, var segmentEmbedding, var decoderFFLayer) = GetNetworksOnDeviceAt(g.DeviceId); + (var encoder, var srcEmbedding, var segmentEmbedding, var decoderFFLayer, var posEmbeddings) = GetNetworksOnDeviceAt(g.DeviceId); // Reset networks encoder.Reset(g.GetWeightFactory(), srcSnts.Count); @@ -156,7 +158,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph g, IS int batchSize = srcSnts.Count; // Encoding input source sentences - IWeightTensor encOutput = Encoder.Run(g, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths); + IWeightTensor encOutput = Encoder.Run(g, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths); IWeightTensor ffLayer = decoderFFLayer.Process(encOutput, batchSize, g); float cost = 0.0f; diff --git a/Seq2SeqSharp/Applications/SeqSimilarity.cs b/Seq2SeqSharp/Applications/SeqSimilarity.cs index 7c68eb50..9aaf3dd4 100644 --- a/Seq2SeqSharp/Applications/SeqSimilarity.cs +++ b/Seq2SeqSharp/Applications/SeqSimilarity.cs @@ -34,6 +34,8 @@ public class SeqSimilarity : BaseSeq2SeqFramework private MultiProcessorNetworkWrapper m_encoder; //The encoders over devices. private MultiProcessorNetworkWrapper m_segmentEmbedding; + private MultiProcessorNetworkWrapper m_posEmbedding = null; + private readonly ShuffleEnums m_shuffleType = ShuffleEnums.Random; private readonly SeqSimilarityOptions m_options; private MemoryCache m_memoryCache; @@ -81,7 +83,7 @@ private bool CreateTrainableParameters(IModel model) var raDeviceIds = new RoundArray(DeviceIds); m_encoder = Encoder.CreateEncoders(model, m_options, raDeviceIds); m_encoderFFLayer = new MultiProcessorNetworkWrapper(new FeedForwardLayer($"FeedForward_Encoder", model.HiddenDim, model.ClsVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(), isTrainable: true), DeviceIds); - m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTrainSentLength, m_options.MaxTestSentLength), model); + (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTrainSentLength, m_options.MaxTestSentLength), model, createAPE: false); Logger.WriteLine($"Creating embeddings. Shape = '({model.SrcVocab.Count} ,{model.EncoderEmbeddingDim})'"); m_srcEmbedding = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { model.SrcVocab.Count, model.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, fanOut: true, name: "SrcEmbeddings", @@ -93,13 +95,14 @@ private bool CreateTrainableParameters(IModel model) /// /// Get networks on specific devices /// - private (IEncoder, IWeightTensor, IFeedForwardLayer, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) + private (IEncoder, IWeightTensor, IFeedForwardLayer, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId) { var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId); return (m_encoder.GetNetworkOnDevice(deviceIdIdx), m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx), m_encoderFFLayer.GetNetworkOnDevice(deviceIdIdx), - m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx)); + m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), + m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx)); } private string GenerateCacheKey(List> strs) @@ -130,7 +133,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu var nrs = new List(); var nr = new NetworkResult { Output = new List>>() }; - (IEncoder encoder, IWeightTensor srcEmbedding, IFeedForwardLayer encoderFFLayer, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(computeGraph.DeviceId); + (IEncoder encoder, IWeightTensor srcEmbedding, IFeedForwardLayer encoderFFLayer, IWeightTensor segmentEmbedding, IWeightTensor posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId); IWeightTensor encOutput1; IWeightTensor encOutput2; @@ -140,7 +143,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu string cacheKey1 = GenerateCacheKey(sntPairBatch.GetSrcTokens(0)); if (!m_memoryCache.TryGetValue(cacheKey1, out encOutput1)) { - encOutput1 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, segmentEmbedding, 0); // output shape: [batch_size, dim] + encOutput1 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding, 0); // output shape: [batch_size, dim] var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1); m_memoryCache.Set(cacheKey1, encOutput1.CopyWeightsRef($"cache_{encOutput1.Name}", false, graphToBind: null), cacheEntryOptions); @@ -149,7 +152,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu string cacheKey2 = GenerateCacheKey(sntPairBatch.GetSrcTokens(1)); if (!m_memoryCache.TryGetValue(cacheKey2, out encOutput2)) { - encOutput2 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, segmentEmbedding, 1); // output_shape: [batch_size, dim] + encOutput2 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding, 1); // output_shape: [batch_size, dim] var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1); m_memoryCache.Set(cacheKey2, encOutput2.CopyWeightsRef($"cache_{encOutput2.Name}", false, graphToBind: null), cacheEntryOptions); @@ -158,8 +161,8 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu else { //We always run encoder network during training time or using GPUs - encOutput1 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, segmentEmbedding, 0); // output shape: [batch_size, dim] - encOutput2 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, segmentEmbedding, 1); // output_shape: [batch_size, dim] + encOutput1 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding, 0); // output shape: [batch_size, dim] + encOutput2 = Encoder.BuildTensorForSourceTokenGroupAt(computeGraph, sntPairBatch, m_shuffleType, encoder, m_modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding, 1); // output_shape: [batch_size, dim] } if (m_modelMetaData.SimilarityType.Equals("Continuous", StringComparison.InvariantCultureIgnoreCase)) diff --git a/Seq2SeqSharp/Layers/MultiHeadAttention.cs b/Seq2SeqSharp/Layers/MultiHeadAttention.cs index 8562d5c3..bf825189 100644 --- a/Seq2SeqSharp/Layers/MultiHeadAttention.cs +++ b/Seq2SeqSharp/Layers/MultiHeadAttention.cs @@ -229,7 +229,6 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba //Multi-head attentions IWeightTensor Qs = g.View(g.AsContiguous(g.Transpose(allQ, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, newTokensIdx, m_d }); - Qs = g.RoPE(Qs, newTokensIdx); IWeightTensor Ks = null; IWeightTensor Vs = null; @@ -238,13 +237,7 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba { IWeightTensor allK = g.View(g.Affine(inputK, K, Kb), dims: new long[] { batchSize, seqLenK, m_multiHeadNum, m_d }); IWeightTensor allV = g.View(g.Affine(inputV, V, Vb), dims: new long[] { batchSize, seqLenV, m_multiHeadNum, m_d }); - //Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK }); - - Ks = g.View(g.AsContiguous(g.Transpose(allK, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenK, m_d }); - Ks = g.RoPE(Ks, seqLenK); - Ks = g.View(g.AsContiguous(g.Transpose(Ks, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK }); - - + Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK }); Vs = g.View(g.AsContiguous(g.Transpose(allV, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenV, m_d }); } else @@ -255,12 +248,7 @@ public IWeightTensor Perform(IWeightTensor inputQ, IWeightTensor keyMask, int ba if (cachedTensors.ContainsKey(KsCacheName) == false) { IWeightTensor allK = g.View(g.Affine(inputK, K, Kb), dims: new long[] { batchSize, seqLenK, m_multiHeadNum, m_d }); - //Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK }); - - Ks = g.View(g.AsContiguous(g.Transpose(allK, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, seqLenK, m_d }); - Ks = g.RoPE(Ks, seqLenK); - Ks = g.View(g.AsContiguous(g.Transpose(Ks, 1, 2)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK }); - + Ks = g.View(g.AsContiguous(g.Transpose(g.Transpose(allK, 1, 2), 2, 3)), dims: new long[] { batchSize * m_multiHeadNum, m_d, seqLenK }); cachedTensors.Add(KsCacheName, Ks.CopyWeightsRef(KsCacheName, Ks.NeedGradient, graphToBind: null)); } else diff --git a/Seq2SeqSharp/Utils/Misc.cs b/Seq2SeqSharp/Utils/Misc.cs index 363b5cd6..fed74519 100644 --- a/Seq2SeqSharp/Utils/Misc.cs +++ b/Seq2SeqSharp/Utils/Misc.cs @@ -116,20 +116,28 @@ public static IOptimizer CreateOptimizer(Options opts) return optimizer; } - public static MultiProcessorNetworkWrapper CreateAuxEmbeddings(RoundArray raDeviceIds, int hiddenDim, int maxSentLength, IModel modelMetaData, DType elementType = DType.Float32, bool isTrainable = true) + public static (MultiProcessorNetworkWrapper, MultiProcessorNetworkWrapper) CreateAuxEmbeddings(RoundArray raDeviceIds, int hiddenDim, int maxSentLength, IModel modelMetaData, DType elementType = DType.Float32, bool isTrainable = true, bool createAPE = false) { + MultiProcessorNetworkWrapper posEmbeddings = null; MultiProcessorNetworkWrapper segmentEmbeddings = null; if (modelMetaData.EncoderType != EncoderTypeEnums.BiLSTM || modelMetaData.DecoderType != DecoderTypeEnums.AttentionLSTM) { + if (createAPE) + { + posEmbeddings = new MultiProcessorNetworkWrapper(PositionEmbedding.BuildPositionWeightTensor( + maxSentLength + 2, + hiddenDim, raDeviceIds.GetNextItem(), "PosEmbedding", false, elementType: elementType), raDeviceIds.ToArray(), true); + } + if (modelMetaData.EnableSegmentEmbeddings) { - segmentEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.MaxSegmentNum, modelMetaData.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, name: "SegmentEmbedding", + segmentEmbeddings = new MultiProcessorNetworkWrapper(new WeightTensor(new long[2] { modelMetaData.MaxSegmentNum, modelMetaData.EncoderEmbeddingDim }, raDeviceIds.GetNextItem(), normType: NormType.Uniform, name: "SegmentEmbedding", isTrainable: isTrainable, dtype: elementType), raDeviceIds.ToArray()); } } - return segmentEmbeddings; + return (posEmbeddings, segmentEmbeddings); } diff --git a/Seq2SeqSharp/Utils/PositionEmbedding.cs b/Seq2SeqSharp/Utils/PositionEmbedding.cs new file mode 100644 index 00000000..1d3ec622 --- /dev/null +++ b/Seq2SeqSharp/Utils/PositionEmbedding.cs @@ -0,0 +1,86 @@ +// Copyright (c) Zhongkai Fu. All rights reserved. +// https://github.com/zhongkaifu/Seq2SeqSharp +// +// This file is part of Seq2SeqSharp. +// +// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree. +// +// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details. + +using AdvUtils; +using Seq2SeqSharp.Tools; +using System; +using TensorSharp; + +namespace Seq2SeqSharp.Utils +{ + public enum PositionEmbeddingEnums + { + APE, + RoPE, + } + + public class PositionEmbedding + { + + public static IWeightTensor AddPositionEmbedding(IComputeGraph g, IWeightTensor posEmbedding, int batchSize, IWeightTensor inputEmbs, float dropoutRatio) + { + var Column = posEmbedding.Columns; + int seqLen = inputEmbs.Rows / batchSize; + + IWeightTensor posEmbeddingPeek = g.Peek(posEmbedding, 0, 0, seqLen); + using (var posEmbeddingPeekView = g.View(posEmbeddingPeek, dims: new long[] { 1, seqLen, Column })) + { + using (var posEmbeddingPeekViewExp = g.Expand(posEmbeddingPeekView, dims: new long[] { batchSize, seqLen, Column })) + { + inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize, seqLen, Column }); + inputEmbs = g.Add(inputEmbs, posEmbeddingPeekViewExp, inPlace: true); + inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize * seqLen, Column }); + } + } + + posEmbeddingPeek.Dispose(); + + inputEmbs = g.Dropout(inputEmbs, batchSize, dropoutRatio, inPlace: true); + + return inputEmbs; + } + + public static WeightTensor BuildPositionWeightTensor(int row, int column, int deviceId, string name = "", bool isTrainable = false, DType elementType = DType.Float32) + { + Logger.WriteLine($"Building position weights tensor. Row = '{row}', Column = '{column}', DeviceId = '{deviceId}', Name = '{name}', Trainable = '{isTrainable}'"); + + WeightTensor t = new WeightTensor(new long[2] { row, column }, deviceId, name: name, isTrainable: isTrainable, needGradient: isTrainable, dtype: elementType); + float[] posWeights = new float[row * column]; + + float numTimescales = (float)column / 2; + float logTimescaleIncrement = (float)(Math.Log(10000.0f) / (numTimescales - 1.0f)); + + for (int p = 0; p < row; ++p) + { + for (int i = 0; i < numTimescales; i++) + { + float v = (float)(p * Math.Exp(i * -logTimescaleIncrement)); + + posWeights[p * column + i] = (float)Math.Sin(v); + posWeights[p * column + (int)numTimescales + i] = (float)Math.Cos(v); + } + } + + if (elementType == DType.Float16) + { + Tensor tmp = new Tensor(t.Allocator, DType.Float32, t.Sizes); + tmp.CopyFrom(posWeights); + Ops.Float2Half(t.TWeight, tmp); + tmp.Dispose(); + } + else + { + t.TWeight.CopyFrom(posWeights); + } + + return t; + } + } +} \ No newline at end of file