Skip to content

Commit

Permalink
Make positional embedding configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Sep 12, 2023
1 parent a61fd87 commit e972fa7
Show file tree
Hide file tree
Showing 19 changed files with 156 additions and 70 deletions.
24 changes: 20 additions & 4 deletions AdvUtils/Logger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,16 @@ public static void WriteLine(Level level, string s, params object[] args)
else
Console.WriteLine(sLine);

if (s_sw != null)
s_sw.WriteLine(sLine);
try
{
if (s_sw != null)
s_sw.WriteLine(sLine);
}
catch (Exception err)
{
Console.Error.WriteLine($"Failed to output log to file '{LogFile}'. Error = '{err.Message}'");
s_sw = null;
}

}

Expand Down Expand Up @@ -76,8 +84,16 @@ public static void WriteLine(Level level, ConsoleColor color, string s, params o

Console.ResetColor();

if (s_sw != null)
s_sw.WriteLine(sLine);
try
{
if (s_sw != null)
s_sw.WriteLine(sLine);
}
catch (Exception err)
{
Console.Error.WriteLine($"Failed to output log to file '{LogFile}'. Error = '{err.Message}'");
s_sw = null;
}

}

Expand Down
36 changes: 25 additions & 11 deletions Seq2SeqSharp/Applications/Decoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,29 @@ namespace Seq2SeqSharp.Applications
{
public class Decoder
{
public static MultiProcessorNetworkWrapper<IDecoder> CreateDecoders(IModel modelMetaData, Seq2SeqOptions options, RoundArray<int> raDeviceIds, DType elementType = DType.Float32)
public static MultiProcessorNetworkWrapper<IDecoder> CreateDecoders(IModel model, Seq2SeqOptions options, RoundArray<int> raDeviceIds, DType elementType = DType.Float32)
{
MultiProcessorNetworkWrapper<IDecoder> decoder;
if (modelMetaData.DecoderType == DecoderTypeEnums.AttentionLSTM)
if (model.DecoderType == DecoderTypeEnums.AttentionLSTM)
{
decoder = new MultiProcessorNetworkWrapper<IDecoder>(
new AttentionDecoder("AttnLSTMDecoder", modelMetaData.HiddenDim, modelMetaData.DecoderEmbeddingDim, modelMetaData.HiddenDim,
options.DropoutRatio, modelMetaData.DecoderLayerDepth, raDeviceIds.GetNextItem(), modelMetaData.EnableCoverageModel,
new AttentionDecoder("AttnLSTMDecoder", model.HiddenDim, model.DecoderEmbeddingDim, model.HiddenDim,
options.DropoutRatio, model.DecoderLayerDepth, raDeviceIds.GetNextItem(), model.EnableCoverageModel,
isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), elementType: elementType), raDeviceIds.ToArray());
}
else if (modelMetaData.DecoderType == DecoderTypeEnums.GPTDecoder)
else if (model.DecoderType == DecoderTypeEnums.GPTDecoder)
{
decoder = new MultiProcessorNetworkWrapper<IDecoder>(
new GPTDecoder("GPTDecoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.IntermediateDim, modelMetaData.DecoderEmbeddingDim, modelMetaData.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(),
isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: modelMetaData.ActivateFunc, expertNum: modelMetaData.ExpertNum, expertsPerTokenFactor: modelMetaData.ExpertsPerTokenFactor, elementType: elementType), raDeviceIds.ToArray());
new GPTDecoder("GPTDecoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.DecoderEmbeddingDim, model.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(),
isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum,
expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType), raDeviceIds.ToArray());
}
else
{
decoder = new MultiProcessorNetworkWrapper<IDecoder>(
new TransformerDecoder("TransformerDecoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.IntermediateDim, modelMetaData.DecoderEmbeddingDim, modelMetaData.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(),
isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: modelMetaData.ActivateFunc, expertNum: modelMetaData.ExpertNum, expertsPerTokenFactor: modelMetaData.ExpertsPerTokenFactor, elementType: elementType), raDeviceIds.ToArray());
new TransformerDecoder("TransformerDecoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.DecoderEmbeddingDim, model.DecoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(),
isTrainable: options.IsDecoderTrainable && (options.Task == ModeEnums.Train), learningRateFactor: options.DecoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum,
expertsPerTokenFactor: model.ExpertsPerTokenFactor, elementType: elementType, peType:model.PEType), raDeviceIds.ToArray());
}

return decoder;
Expand Down Expand Up @@ -265,7 +267,8 @@ public static List<List<BeamSearchStatus>> CombineBeamSearchResults(List<List<Be
public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<int>> tgtSeqs, IComputeGraph g, IWeightTensor encOutputs, TransformerDecoder decoder, IFeedForwardLayer decoderFFLayer,
IWeightTensor tgtEmbedding, float[] srcOriginalLenghts, Vocab tgtVocab, ShuffleEnums shuffleType, float dropoutRatio, DecodingOptions decodingOptions, bool isTraining = true,
bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, IFeedForwardLayer pointerGenerator = null, List<List<int>> srcSeqs = null, Dictionary<string, IWeightTensor> cachedTensors = null,
List<List<int>> alignmentsToSrc = null, List<List<float>> alignmentScoresToSrc = null, bool teacherForcedAlignment = false, LossEnums lossType = LossEnums.CrossEntropy, float focalLossGamma = 0.0f, float lossSmooth = 1e-9f, List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = false)
List<List<int>> alignmentsToSrc = null, List<List<float>> alignmentScoresToSrc = null, bool teacherForcedAlignment = false, LossEnums lossType = LossEnums.CrossEntropy, float focalLossGamma = 0.0f, float lossSmooth = 1e-9f,
List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = false, IWeightTensor posEmbeddings = null)
{
int eosTokenId = tgtVocab.GetWordIndex(BuildInTokens.EOS, logUnk: true);
int batchSize = tgtSeqs.Count;
Expand All @@ -291,6 +294,11 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<
}

IWeightTensor inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
if (posEmbeddings != null)
{
inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio);
}

IWeightTensor decOutput;
IWeightTensor decEncAttnProbs;
(decOutput, decEncAttnProbs) = decoder.Decode(inputEmbs, encOutputs, tgtSelfTriMask, srcTgtMask, batchSize, g, outputAttnWeights: pointerGenerator != null, cachedTensors: cachedTensors);
Expand Down Expand Up @@ -470,7 +478,8 @@ public static (float, List<List<BeamSearchStatus>>) DecodeTransformer(List<List<
public static (float, List<List<BeamSearchStatus>>) GPTDecode(List<List<int>> tgtSeqs, IComputeGraph g, GPTDecoder decoder, IFeedForwardLayer decoderFFLayer,
IWeightTensor tgtEmbedding, Vocab tgtVocab, ShuffleEnums shuffleType, float dropoutRatio, DecodingOptions decodingOptions, bool isTraining = true,
bool outputSentScore = true, List<BeamSearchStatus> previousBeamSearchResults = null, Dictionary<string, IWeightTensor> cachedTensors = null,
LossEnums lossType = LossEnums.CrossEntropy, float focalLossGamma = 0.0f, float lossSmooth = 1e-9f, List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = true)
LossEnums lossType = LossEnums.CrossEntropy, float focalLossGamma = 0.0f, float lossSmooth = 1e-9f, List<int> blockedTokens = null, IWeightTensor segmentEmbeddings = null, bool amp = true,
IWeightTensor posEmbeddings = null)
{
int eosTokenId = tgtVocab.GetWordIndex(BuildInTokens.EOS, logUnk: true);
int batchSize = tgtSeqs.Count;
Expand All @@ -490,6 +499,11 @@ public static (float, List<List<BeamSearchStatus>>) GPTDecode(List<List<int>> tg
}

IWeightTensor inputEmbs = TensorUtils.CreateTokensEmbeddings(tgtSeqs, g, tgtEmbedding, segmentEmbeddings, tgtVocab, scaleFactor: (float)Math.Sqrt(tgtEmbedding.Columns), amp: amp);
if (posEmbeddings != null)
{
inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, dropoutRatio);
}

IWeightTensor decOutput;
(decOutput, _) = decoder.Decode(inputEmbs, tgtSelfTriMask, batchSize, g, cachedTensors: cachedTensors);

Expand Down
11 changes: 6 additions & 5 deletions Seq2SeqSharp/Applications/Encoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,20 @@ static List<List<string>> InsertCLSToken(List<List<string>> tokens)
return newTokens;
}

public static MultiProcessorNetworkWrapper<IEncoder> CreateEncoders(IModel modelMetaData, Options options, RoundArray<int> raDeviceIds, DType elementType = DType.Float32)
public static MultiProcessorNetworkWrapper<IEncoder> CreateEncoders(IModel model, Options options, RoundArray<int> raDeviceIds, DType elementType = DType.Float32)
{
MultiProcessorNetworkWrapper<IEncoder> encoder = null;
if (modelMetaData.EncoderType == EncoderTypeEnums.BiLSTM)
if (model.EncoderType == EncoderTypeEnums.BiLSTM)
{
encoder = new MultiProcessorNetworkWrapper<IEncoder>(
new BiEncoder("BiLSTMEncoder", modelMetaData.HiddenDim, modelMetaData.EncoderEmbeddingDim, modelMetaData.EncoderLayerDepth, raDeviceIds.GetNextItem(), isTrainable: options.IsEncoderTrainable), raDeviceIds.ToArray());
new BiEncoder("BiLSTMEncoder", model.HiddenDim, model.EncoderEmbeddingDim, model.EncoderLayerDepth, raDeviceIds.GetNextItem(), isTrainable: options.IsEncoderTrainable), raDeviceIds.ToArray());
}
else
{
encoder = new MultiProcessorNetworkWrapper<IEncoder>(
new TransformerEncoder("TransformerEncoder", modelMetaData.MultiHeadNum, modelMetaData.HiddenDim, modelMetaData.IntermediateDim, modelMetaData.EncoderEmbeddingDim, modelMetaData.EncoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(),
isTrainable: options.IsEncoderTrainable, learningRateFactor: options.EncoderStartLearningRateFactor, activateFunc: modelMetaData.ActivateFunc, expertNum: modelMetaData.ExpertNum, expertsPerTokenFactor: modelMetaData.ExpertsPerTokenFactor, elementType), raDeviceIds.ToArray());
new TransformerEncoder("TransformerEncoder", model.MultiHeadNum, model.HiddenDim, model.IntermediateDim, model.EncoderEmbeddingDim, model.EncoderLayerDepth, options.DropoutRatio, raDeviceIds.GetNextItem(),
isTrainable: options.IsEncoderTrainable, learningRateFactor: options.EncoderStartLearningRateFactor, activateFunc: model.ActivateFunc, expertNum: model.ExpertNum, expertsPerTokenFactor: model.ExpertsPerTokenFactor,
elementType, peType: model.PEType), raDeviceIds.ToArray());
}

return encoder;
Expand Down
17 changes: 11 additions & 6 deletions Seq2SeqSharp/Applications/GPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public class GPT : BaseSeq2SeqFramework<Seq2SeqModel>
private MultiProcessorNetworkWrapper<IDecoder> m_decoder = null; //The decoders over devices
private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_decoderFFLayer = null ; //The feed forward layers over devices after all layers in decoder
private MultiProcessorNetworkWrapper<IWeightTensor> m_segmentEmbedding = null;
private MultiProcessorNetworkWrapper<IWeightTensor> m_posEmbedding = null;

private readonly ShuffleEnums m_shuffleType = ShuffleEnums.Random;
readonly Seq2SeqOptions m_options = null;
Expand Down Expand Up @@ -125,7 +126,8 @@ private bool CreateTrainableParameters(IModel model)
m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: (m_options.Task == ModeEnums.Train), learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType), DeviceIds);

(_, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType, isTrainable: (m_options.Task == ModeEnums.Train));
(m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength), model, elementType,
isTrainable: (m_options.Task == ModeEnums.Train), createAPE: (model.PEType == PositionEmbeddingEnums.APE));
m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable && (m_options.Task == ModeEnums.Train), m_options.DecoderStartLearningRateFactor, elementType);

return (true);
Expand All @@ -134,13 +136,14 @@ private bool CreateTrainableParameters(IModel model)
/// <summary>
/// Get networks on specific devices
/// </summary>
private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
private (IDecoder, IFeedForwardLayer, IWeightTensor, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
{
var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
return (m_decoder.GetNetworkOnDevice(deviceIdIdx),
m_decoderFFLayer.GetNetworkOnDevice(deviceIdIdx),
m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx));
m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx),
m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx));
}

private string GenerateCacheKey(List<List<string>> strs)
Expand Down Expand Up @@ -194,7 +197,7 @@ private List<List<List<string>>> CombineInputOutput(List<List<string>> input, Li
/// <returns>The cost of forward part</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, ISntPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
{
(var decoder, var decoderFFLayer, var tgtEmbedding, var segmentEmbedding) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
(var decoder, var decoderFFLayer, var tgtEmbedding, var segmentEmbedding, var posEmbeddings) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
List<NetworkResult> nrs = new List<NetworkResult>();

// Generate output decoder sentences
Expand All @@ -209,7 +212,8 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
if (isTraining)
{
(var c, _) = Decoder.GPTDecode(tgtTokensList, computeGraph, decoder as GPTDecoder, decoderFFLayer, tgtEmbedding, m_modelMetaData.TgtVocab, m_shuffleType,
m_options.DropoutRatio, decodingOptions, isTraining, lossType: m_options.LossType, focalLossGamma: m_options.FocalLossGamma, segmentEmbeddings: segmentEmbedding, amp: m_options.AMP);
m_options.DropoutRatio, decodingOptions, isTraining, lossType: m_options.LossType, focalLossGamma: m_options.FocalLossGamma,
segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: posEmbeddings);
nr.Cost = c;
nr.Output = null;
}
Expand Down Expand Up @@ -248,7 +252,8 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
(var cost2, var bssSeqList) = Decoder.GPTDecode(batch2tgtTokens, g, decoder as GPTDecoder, decoderFFLayer, tgtEmbedding,
m_modelMetaData.TgtVocab, m_shuffleType, 0.0f, decodingOptions, isTraining,
outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus,
blockedTokens: decodingOptions.BlockedTokens, segmentEmbeddings: segmentEmbedding, cachedTensors: cachedTensors, amp: m_options.AMP);
blockedTokens: decodingOptions.BlockedTokens, segmentEmbeddings: segmentEmbedding,
cachedTensors: cachedTensors, amp: m_options.AMP, posEmbeddings: posEmbeddings);

bssSeqList = Decoder.SwapBeamAndBatch(bssSeqList); // Swap shape: (beam_search_size, batch_size) -> (batch_size, beam_search_size)
batch2beam2seq = Decoder.CombineBeamSearchResults(batch2beam2seq, bssSeqList);
Expand Down
3 changes: 3 additions & 0 deletions Seq2SeqSharp/Applications/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ public class Options
[Arg("The seed value of random generator", nameof(RandomSeed))]
public int RandomSeed = -1;

[Arg("The Positional Embeddings Type. It supports APE, NoPE and RoPE", nameof(PEType))]
public PositionEmbeddingEnums PEType = PositionEmbeddingEnums.APE;

public void ValidateOptions()
{
if (AMP == true && ProcessorType != ProcessorTypeEnums.GPU)
Expand Down
Loading

0 comments on commit e972fa7

Please sign in to comment.