Skip to content

Commit

Permalink
Bug fix for learning rate factor in Image2Seq
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Oct 17, 2023
1 parent d69d48c commit a828018
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions Seq2SeqSharp/Applications/Image2Seq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,16 +108,16 @@ private bool CreateTrainableParameters(IModel model)
isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds);

m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 1024, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "PositionalEmbedding",
isTrainable: true, dtype: elementType), raDeviceIds.ToArray());
learningRateFactor: m_options.EncoderStartLearningRateFactor, isTrainable: true, dtype: elementType), raDeviceIds.ToArray());

m_cls = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 1, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "CLS",
m_cls = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 1, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "CLS", learningRateFactor: m_options.EncoderStartLearningRateFactor,
isTrainable: true, dtype: elementType), raDeviceIds.ToArray());


m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable, m_options.DecoderStartLearningRateFactor, elementType: elementType);

m_srcEmbedding = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("SrcEmbedding_Decoder_0", model.HiddenDim, model.HiddenDim, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds);
isTrainable: true, learningRateFactor: m_options.EncoderStartLearningRateFactor, elementType: elementType), DeviceIds);

if (model.PointerGenerator)
{
Expand Down Expand Up @@ -215,7 +215,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
{
(var c, _) = Decoder.DecodeTransformer(tgtTokensList, computeGraph, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding, originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType,
m_options.DropoutRatio, decodingOptions, isTraining, pointerGenerator: pointerGenerator, srcSeqs: null, lossType: m_options.LossType, focalLossGamma: m_options.FocalLossGamma,
segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: posEmbeddings);
segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: null);
nr.Cost = c;
nr.Output = null;
}
Expand All @@ -233,7 +233,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
(var cost2, var bssSeqList) = Decoder.DecodeTransformer(tgtTokensList, g, encOutput, decoder as TransformerDecoder, decoderFFLayer, tgtEmbedding,
originalSrcLengths, m_modelMetaData.TgtVocab, m_shuffleType, 0.0f, decodingOptions, isTraining,
outputSentScore: decodingOptions.BeamSearchSize > 1, pointerGenerator: pointerGenerator,
srcSeqs: null, teacherForcedAlignment: true, lossType: m_options.LossType, segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: posEmbeddings);
srcSeqs: null, teacherForcedAlignment: true, lossType: m_options.LossType, segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: null);
nr.Cost = 0.0f;
nr.Output = m_modelMetaData.TgtVocab.CovertToWords(bssSeqList);
if (decodingOptions.OutputAligmentsToSrc)
Expand Down Expand Up @@ -271,7 +271,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
outputSentScore: decodingOptions.BeamSearchSize > 1, previousBeamSearchResults: batchStatus,
pointerGenerator: pointerGenerator, srcSeqs: null,
cachedTensors: cachedTensors, alignmentsToSrc: alignmentsToSrc, alignmentScoresToSrc: alignmentScores,
blockedTokens: decodingOptions.BlockedTokens, segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: posEmbeddings);
blockedTokens: decodingOptions.BlockedTokens, segmentEmbeddings: segmentEmbedding, amp: m_options.AMP, posEmbeddings: null);

bssSeqList = Decoder.SwapBeamAndBatch(bssSeqList); // Swap shape: (beam_search_size, batch_size) -> (batch_size, beam_search_size)
batch2beam2seq = Decoder.CombineBeamSearchResults(batch2beam2seq, bssSeqList);
Expand Down

0 comments on commit a828018

Please sign in to comment.