diff --git a/Seq2SeqSharp/Applications/Image2Seq.cs b/Seq2SeqSharp/Applications/Image2Seq.cs index d5561c3b..2d0183e0 100644 --- a/Seq2SeqSharp/Applications/Image2Seq.cs +++ b/Seq2SeqSharp/Applications/Image2Seq.cs @@ -185,7 +185,7 @@ public override List RunForwardOnSingleDevice(IComputeGraph compu (var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings, var cls) = GetNetworksOnDeviceAt(computeGraph.DeviceId); var srcSnts = sntPairBatch.GetSrcTokens(); - IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, cls); + IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, cls, m_modelMetaData.HiddenDim); List nrs = new List(); diff --git a/Seq2SeqSharp/Applications/ImgEncoder.cs b/Seq2SeqSharp/Applications/ImgEncoder.cs index b59e8c94..44c09361 100644 --- a/Seq2SeqSharp/Applications/ImgEncoder.cs +++ b/Seq2SeqSharp/Applications/ImgEncoder.cs @@ -104,22 +104,20 @@ static private IWeightTensor InnerEncode(IComputeGraph g, List imgPaths) return res; } - static public IWeightTensor Run(IComputeGraph g, List imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls) + static public IWeightTensor Run(IComputeGraph g, List imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls, int dim) { int batchSize = imgPaths.Count; var inputEmbs = InnerEncode(g, imgPaths); inputEmbs = srcEmbeddings.Process(inputEmbs, batchSize, g); - inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize, -1, 768 }); + inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize, -1, dim }); - cls = g.View(cls, dims: new long[] { 1, 1, 768 }); - cls = g.Expand(cls, dims: new long[] { batchSize, 1, 768 }); + cls = g.View(cls, dims: new long[] { 1, 1, dim }); + cls = g.Expand(cls, dims: new long[] { batchSize, 1, dim }); inputEmbs = g.Concate(1, cls, inputEmbs); - inputEmbs = g.View(inputEmbs, dims: new long[] { -1, 768 }); - - + inputEmbs = g.View(inputEmbs, dims: new long[] { -1, dim }); inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, 0.0f); return encoder.Encode(inputEmbs, batchSize, g, null);