Only keep RoPE for self-attention layer
zhongkaifu committed Sep 4, 2023
1 parent 4db69a5 commit 54945f1
Showing 10 changed files with 151 additions and 54 deletions.
20 changes: 13 additions & 7 deletions Seq2SeqSharp/Applications/Encoder.cs
@@ -58,22 +58,22 @@ public static MultiProcessorNetworkWrapper<IEncoder> CreateEncoders(IModel model
}

static public IWeightTensor Run(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, IEncoder encoder, IModel modelMetaData, ShuffleEnums shuffleType,
- IWeightTensor srcEmbedding, IWeightTensor segmentEmbedding, List<List<int>> srcSntsIds, float[] originalSrcLengths, bool amp = false)
+ IWeightTensor srcEmbedding, IWeightTensor posEmbeddings, IWeightTensor segmentEmbedding, List<List<int>> srcSntsIds, float[] originalSrcLengths, bool amp = false)
{
// Reset networks
encoder.Reset(computeGraph.GetWeightFactory(), srcSntsIds.Count);

- IWeightTensor encOutput = InnerRunner(computeGraph, srcSntsIds, originalSrcLengths, shuffleType, encoder, modelMetaData, srcEmbedding, segmentEmbedding, amp);
+ IWeightTensor encOutput = InnerRunner(computeGraph, srcSntsIds, originalSrcLengths, shuffleType, encoder, modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding, amp);
return encOutput;
}

- public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData, IWeightTensor srcEmbedding, IWeightTensor segmentEmbedding, int groupId)
+ public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData, IWeightTensor srcEmbedding, IWeightTensor posEmbeddings, IWeightTensor segmentEmbedding, int groupId)
{
var contextTokens = InsertCLSToken(sntPairBatch.GetSrcTokens(groupId));
var originalSrcContextLength = BuildInTokens.PadSentences(contextTokens);
var contextTokenIds = modelMetaData.SrcVocab.GetWordIndex(contextTokens);

- IWeightTensor encContextOutput = InnerRunner(computeGraph, contextTokenIds, originalSrcContextLength, shuffleType, encoder, modelMetaData, srcEmbedding, segmentEmbedding);
+ IWeightTensor encContextOutput = InnerRunner(computeGraph, contextTokenIds, originalSrcContextLength, shuffleType, encoder, modelMetaData, srcEmbedding, posEmbeddings, segmentEmbedding);

int contextPaddedLen = contextTokens[0].Count;
float[] contextCLSIdxs = new float[sntPairBatch.BatchSize];
@@ -88,14 +88,14 @@ public static IWeightTensor BuildTensorForSourceTokenGroupAt(IComputeGraph compu
}

static private IWeightTensor InnerRunner(IComputeGraph computeGraph, List<List<int>> srcTokensList, float[] originalSrcLengths, ShuffleEnums shuffleType, IEncoder encoder, IModel modelMetaData,
- IWeightTensor srcEmbedding, IWeightTensor segmentEmbedding, bool amp = false)
+ IWeightTensor srcEmbedding, IWeightTensor posEmbedding, IWeightTensor segmentEmbedding, bool amp = false)
{
int batchSize = srcTokensList.Count;
int srcSeqPaddedLen = srcTokensList[0].Count;
IWeightTensor srcSelfMask = (shuffleType == ShuffleEnums.NoPaddingInSrc || shuffleType == ShuffleEnums.NoPadding || batchSize == 1) ? null : computeGraph.BuildPadSelfMask(srcSeqPaddedLen, originalSrcLengths, elementType: amp ? DType.Float16 : DType.Float32); // The length of source sentences are same in a single mini-batch, so we don't have source mask.

// Encoding input source sentences
- var encOutput = RunEncoder(computeGraph, srcTokensList, encoder, modelMetaData, srcEmbedding, srcSelfMask, segmentEmbedding, amp: amp);
+ var encOutput = RunEncoder(computeGraph, srcTokensList, encoder, modelMetaData, srcEmbedding, srcSelfMask, posEmbedding, segmentEmbedding, amp: amp);
if (srcSelfMask != null)
{
srcSelfMask.Dispose();
@@ -114,11 +114,17 @@ static private IWeightTensor RunEncoder(IComputeGraph g, List<List<i
/// <param name="reversEncoder"></param>
/// <param name="embeddings"></param>
/// <returns></returns>
- static private IWeightTensor RunEncoder(IComputeGraph g, List<List<int>> seqs, IEncoder encoder, IModel modelMetaData, IWeightTensor embeddings, IWeightTensor selfMask,
+ static private IWeightTensor RunEncoder(IComputeGraph g, List<List<int>> seqs, IEncoder encoder, IModel modelMetaData, IWeightTensor embeddings, IWeightTensor selfMask, IWeightTensor posEmbeddings,
IWeightTensor segmentEmbeddings, bool amp = false)
{
int batchSize = seqs.Count;
var inputEmbs = TensorUtils.CreateTokensEmbeddings(seqs, g, embeddings, segmentEmbeddings, modelMetaData.SrcVocab, (float)Math.Sqrt(embeddings.Columns), enableTagEmbedding: modelMetaData.EnableTagEmbeddings, amp: amp);

+ if (modelMetaData.EncoderType == EncoderTypeEnums.Transformer && posEmbeddings != null)
+ {
+     inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, 0.0f);
+ }

return encoder.Encode(inputEmbs, batchSize, g, selfMask);
}
}
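Note on the new gate in RunEncoder above: absolute position embeddings are now added only when the encoder is a Transformer and a position table was actually created; when the callers below pass createAPE: false, no absolute position table exists, posEmbeddings arrives here as null, and positional information comes solely from RoPE inside the self-attention layers. As a rough, self-contained illustration of the arithmetic such an absolute-position step performs (a sinusoidal sketch over a flat buffer, not Seq2SeqSharp's PositionEmbedding.AddPositionEmbedding implementation):

using System;

static class ApeSketch
{
    // Adds sinusoidal absolute position embeddings in place to a flat
    // [batchSize * seqLen, dim] embedding buffer (the same layout noted in Seq2Seq.cs below).
    public static void AddSinusoidalPositions(float[] inputEmbs, int batchSize, int seqLen, int dim)
    {
        for (int b = 0; b < batchSize; b++)
        {
            for (int pos = 0; pos < seqLen; pos++)
            {
                int rowOffset = (b * seqLen + pos) * dim;
                for (int i = 0; i < dim; i++)
                {
                    // Pair index i/2 selects the frequency; even slots get sin, odd slots get cos.
                    double angle = pos / Math.Pow(10000.0, (2 * (i / 2)) / (double)dim);
                    inputEmbs[rowOffset + i] += (i % 2 == 0) ? (float)Math.Sin(angle) : (float)Math.Cos(angle);
                }
            }
        }
    }
}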
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/GPT.cs
@@ -125,7 +125,7 @@ private bool CreateTrainableParameters(IModel model)
m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: (m_options.Task == ModeEnums.Train), learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds);

- m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType, isTrainable: (m_options.Task == ModeEnums.Train));
+ (_, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType, isTrainable: (m_options.Task == ModeEnums.Train));
m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable && (m_options.Task == ModeEnums.Train), m_options.DecoderStartLearningRateFactor, elementType);

return (true);
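For context, the "(_, m_segmentEmbedding)" assignment above reflects the new tuple-returning shape of Misc.CreateAuxEmbeddings: GPT simply discards the position-embedding slot, since its decoder takes positional information from RoPE in self-attention, while Seq2Seq and Seq2SeqClassification below pass createAPE: false so no absolute position table is built at all. A minimal, hypothetical sketch of that pattern (plain arrays standing in for the IWeightTensor wrappers; not the repository's Misc.cs):

static class AuxEmbeddingSketch
{
    // createAPE == false leaves the position slot null, which later disables the
    // AddPositionEmbedding branch in Encoder.RunEncoder and leaves RoPE as the only
    // source of positional information.
    public static (float[,] posTable, float[,] segTable) CreateAuxTables(
        int hiddenDim, int maxSeqLen, int segmentTypes, bool createAPE)
    {
        float[,] posTable = createAPE ? new float[maxSeqLen, hiddenDim] : null;
        float[,] segTable = segmentTypes > 0 ? new float[segmentTypes, hiddenDim] : null;
        return (posTable, segTable);
    }
}

// Usage (placeholder values): a caller can discard the position slot,
//   var (_, segOnly) = AuxEmbeddingSketch.CreateAuxTables(512, 1024, 2, createAPE: true);
// or skip building the table entirely, as Seq2Seq does with createAPE: false:
//   var (pos, seg) = AuxEmbeddingSketch.CreateAuxTables(512, 1024, 2, createAPE: false);  // pos == null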
13 changes: 7 additions & 6 deletions Seq2SeqSharp/Applications/Seq2Seq.cs
@@ -33,6 +33,7 @@ public class Seq2Seq : BaseSeq2SeqFramework<Seq2SeqModel>
private MultiProcessorNetworkWrapper<IDecoder> m_decoder; //The decoders over devices
private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_decoderFFLayer; //The feed forward layers over devices after all layers in decoder

+ private MultiProcessorNetworkWrapper<IWeightTensor> m_posEmbedding = null;
private MultiProcessorNetworkWrapper<IWeightTensor> m_segmentEmbedding;

private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_pointerGenerator;
@@ -109,7 +110,7 @@ private bool CreateTrainableParameters(IModel model)
m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, elementType: elementType);
m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds);
- m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model, elementType: elementType);
+ (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model, elementType: elementType, createAPE: false);
(m_srcEmbedding, m_tgtEmbedding) = CreateSrcTgtEmbeddings(model, raDeviceIds, m_options.IsSrcEmbeddingTrainable, m_options.IsTgtEmbeddingTrainable, m_options.EncoderStartLearningRateFactor, m_options.DecoderStartLearningRateFactor, elementType: elementType);


@@ -141,15 +142,15 @@ public void VQModel()
/// <summary>
/// Get networks on specific devices
/// </summary>
- private (IEncoder, IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IFeedForwardLayer) GetNetworksOnDeviceAt(int deviceId)
+ private (IEncoder, IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
{
var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
return (m_encoder.GetNetworkOnDevice(deviceIdIdx),
m_decoder.GetNetworkOnDevice(deviceIdIdx),
m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx),
m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_modelMetaData.SharedEmbeddings ? m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx) : m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
- m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx));
+ m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx));
}

private string GenerateCacheKey(List<List<string>> strs)
@@ -175,7 +176,7 @@ private string GenerateCacheKey(List<List<string>> strs)
/// <returns>The cost of forward part</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
{
- (var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
+ (var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId);

var srcSnts = sntPairBatch.GetSrcTokens(0);
var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
@@ -193,7 +194,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
string cacheKey = GenerateCacheKey(srcSnts);
if (!m_memoryCache.TryGetValue(cacheKey, out encOutput))
{
- encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths); // Shape: [batchsize * seqLen, embedding_dim]
+ encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths); // Shape: [batchsize * seqLen, embedding_dim]

var cacheEntryOptions = new MemoryCacheEntryOptions().SetSize(1);
m_memoryCache.Set(cacheKey, encOutput.CopyWeightsRef($"cache_{encOutput.Name}", false, graphToBind: null), cacheEntryOptions);
@@ -202,7 +203,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
else
{
// Compute src tensor
- encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths, amp:m_options.AMP);
+ encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths, amp:m_options.AMP);
}

List<NetworkResult> nrs = new List<NetworkResult>();
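Why passing createAPE: false is sufficient on the Seq2Seq side: the m_posEmbedding wrapper stays null, the null-conditional lookup in GetNetworksOnDeviceAt then yields a null tensor on every device, and the "posEmbeddings != null" check in Encoder.RunEncoder skips AddPositionEmbedding, leaving RoPE inside self-attention as the only positional signal. A tiny sketch of that null propagation (generic stand-ins, not the MultiProcessorNetworkWrapper class itself):

sealed class PerDeviceSketch<T> where T : class
{
    private readonly T[] _perDevice;
    public PerDeviceSketch(T[] perDevice) => _perDevice = perDevice;
    public T GetNetworkOnDevice(int idx) => _perDevice[idx];
}

static class NullPropagationSketch
{
    // A null wrapper short-circuits to a null tensor, which downstream code treats as
    // "no absolute position embeddings".
    public static T Resolve<T>(PerDeviceSketch<T> wrapper, int deviceIdx) where T : class
        => wrapper?.GetNetworkOnDevice(deviceIdx);
}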
11 changes: 6 additions & 5 deletions Seq2SeqSharp/Applications/Seq2SeqClassification.cs
@@ -35,6 +35,7 @@ public class Seq2SeqClassification : BaseSeq2SeqFramework<Seq2SeqClassificationM
private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_encoderFFLayer; //The feed forward layers over devices after all layers in encoder
private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_decoderFFLayer; //The feed forward layers over devices after all layers in decoder

+ private MultiProcessorNetworkWrapper<IWeightTensor> m_posEmbedding = null;
private MultiProcessorNetworkWrapper<IWeightTensor> m_segmentEmbedding;

private readonly ShuffleEnums m_shuffleType = ShuffleEnums.Random;
@@ -91,7 +92,7 @@ private bool CreateTrainableParameters(IModel model)
m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: true), DeviceIds);

- m_segmentEmbedding = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model);
+ (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model, createAPE: false);
(m_srcEmbedding, m_tgtEmbedding) = CreateSrcTgtEmbeddings(model, raDeviceIds, m_options.IsSrcEmbeddingTrainable, m_options.IsTgtEmbeddingTrainable, m_options.EncoderStartLearningRateFactor, m_options.DecoderStartLearningRateFactor);
return true;
}
@@ -101,7 +102,7 @@ private bool CreateTrainableParameters(IModel model)
/// </summary>
/// <param name="deviceIdIdx"></param>
/// <returns></returns>
- private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
+ private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
{
var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
return (m_encoder.GetNetworkOnDevice(deviceIdIdx),
@@ -110,7 +111,7 @@ private bool CreateTrainableParameters(IModel model)
m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx),
m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_modelMetaData.SharedEmbeddings ? m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx) : m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
- m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx));
+ m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx));
}

/// <summary>
@@ -123,13 +124,13 @@ private bool CreateTrainableParameters(IModel model)
/// <returns>The cost of forward part</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
{
- (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor segmentEmbedding) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
+ (IEncoder encoder, IDecoder decoder, IFeedForwardLayer encoderFFLayer, IFeedForwardLayer decoderFFLayer, IWeightTensor srcEmbedding, IWeightTensor tgtEmbedding, IWeightTensor segmentEmbedding, IWeightTensor posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId);

var srcSnts = sntPairBatch.GetSrcTokens(0);
var originalSrcLengths = BuildInTokens.PadSentences(srcSnts);
var srcTokensList = m_modelMetaData.SrcVocab.GetWordIndex(srcSnts);

- IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, segmentEmbedding, srcTokensList, originalSrcLengths);
+ IWeightTensor encOutput = Encoder.Run(computeGraph, sntPairBatch, encoder, m_modelMetaData, m_shuffleType, srcEmbedding, posEmbeddings, segmentEmbedding, srcTokensList, originalSrcLengths);

List<NetworkResult> nrs = new List<NetworkResult>();
int srcSeqPaddedLen = srcSnts[0].Count;
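The attention-layer changes that actually apply the rotary embeddings are in the remaining files of this commit, which are not shown in this excerpt. For orientation only, the standard RoPE rotation of a query or key vector at a given position looks like the sketch below; this is the textbook formulation, not a copy of Seq2SeqSharp's attention code.

using System;

static class RoPESketch
{
    // Rotates each consecutive 2-D pair of the vector by pos * theta_i, where
    // theta_i = 10000^(-2i/dim) and dim must be even. Applying the same rotation to
    // queries and keys makes their dot product depend only on relative position.
    public static void ApplyRoPE(float[] vec, int pos, int dim)
    {
        for (int i = 0; i < dim; i += 2)
        {
            double theta = pos * Math.Pow(10000.0, -(double)i / dim);
            float cos = (float)Math.Cos(theta);
            float sin = (float)Math.Sin(theta);

            float x0 = vec[i];
            float x1 = vec[i + 1];
            vec[i]     = x0 * cos - x1 * sin;
            vec[i + 1] = x0 * sin + x1 * cos;
        }
    }
}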