Skip to content

Commit

Permalink
Add layernorm for source embeddings in image caption
Browse files: browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Oct 18, 2023
1 parent 6022102 commit 53d517b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
16 changes: 12 additions & 4 deletions Seq2SeqSharp/Applications/Image2Seq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
using Seq2SeqSharp.Tools;
using Seq2SeqSharp.Utils;
using TensorSharp;
using System.Xml.Linq;

namespace Seq2SeqSharp.Applications
{
Expand All @@ -37,6 +38,8 @@ public class Image2Seq : BaseSeq2SeqFramework<Seq2SeqModel>
private MultiProcessorNetworkWrapper<IWeightTensor> m_posEmbedding = null;
private MultiProcessorNetworkWrapper<IWeightTensor> m_segmentEmbedding;

private MultiProcessorNetworkWrapper<INormalization> m_layerNorm = null;

private MultiProcessorNetworkWrapper<IWeightTensor> m_cls = null;

private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_pointerGenerator;
Expand Down Expand Up @@ -114,11 +117,16 @@ private bool CreateTrainableParameters(IModel model)
isTrainable: true, dtype: elementType), raDeviceIds.ToArray());


m_layerNorm = new MultiProcessorNetworkWrapper<INormalization>(new LayerNormalization($"Src_LayerNorm", 768, raDeviceIds.GetNextItem(), isTrainable: true, learningRateFactor: m_options.EncoderStartLearningRateFactor, elementType: elementType), raDeviceIds.ToArray());


m_tgtEmbedding = CreateTgtEmbeddings(model, raDeviceIds, m_options.IsTgtEmbeddingTrainable, m_options.DecoderStartLearningRateFactor, elementType: elementType);

m_srcEmbedding = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("SrcEmbedding_Decoder_0", 768, model.HiddenDim, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: true, learningRateFactor: m_options.EncoderStartLearningRateFactor, elementType: elementType), DeviceIds);



if (model.PointerGenerator)
{
if (model.SharedEmbeddings == false)
Expand Down Expand Up @@ -147,7 +155,7 @@ public void VQModel()
/// <summary>
/// Gets the network instances that reside on the specified device
/// </summary>
private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor, IWeightTensor, INormalization) GetNetworksOnDeviceAt(int deviceId)
{
var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
return (m_encoder.GetNetworkOnDevice(deviceIdIdx),
Expand All @@ -156,7 +164,7 @@ public void VQModel()
m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx),
m_cls.GetNetworkOnDevice(deviceIdIdx));
m_cls.GetNetworkOnDevice(deviceIdIdx), m_layerNorm.GetNetworkOnDevice(deviceIdIdx));
}

private string GenerateCacheKey(List<List<string>> strs)
Expand All @@ -182,10 +190,10 @@ private string GenerateCacheKey(List<List<string>> strs)
/// <returns>The cost of forward part</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, IPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
{
(var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings, var cls) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
(var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings, var cls, var layerNorm) = GetNetworksOnDeviceAt(computeGraph.DeviceId);

var srcSnts = sntPairBatch.GetSrcTokens();
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, cls, m_modelMetaData.HiddenDim);
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, cls, m_modelMetaData.HiddenDim, layerNorm);

List<NetworkResult> nrs = new List<NetworkResult>();

Expand Down
5 changes: 4 additions & 1 deletion Seq2SeqSharp/Applications/ImgEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,14 @@ static private IWeightTensor InnerEncode(IComputeGraph g, List<string> imgPaths)
return res;
}

static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls, int dim)
static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls, int dim, INormalization layernorm)
{
int batchSize = imgPaths.Count;
var inputEmbs = InnerEncode(g, imgPaths);

inputEmbs = layernorm.Norm(inputEmbs, g);
inputEmbs = srcEmbeddings.Process(inputEmbs, batchSize, g);
inputEmbs = g.SiLU(inputEmbs);

inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize, -1, dim });

Expand Down

0 comments on commit 53d517b

Please sign in to comment.