Skip to content

Commit

Permalink
minor bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Oct 18, 2023
1 parent 4588a7e commit 6022102
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/Image2Seq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph compu
(var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings, var cls) = GetNetworksOnDeviceAt(computeGraph.DeviceId);

var srcSnts = sntPairBatch.GetSrcTokens();
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, cls);
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, cls, m_modelMetaData.HiddenDim);

List<NetworkResult> nrs = new List<NetworkResult>();

Expand Down
12 changes: 5 additions & 7 deletions Seq2SeqSharp/Applications/ImgEncoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,22 +104,20 @@ static private IWeightTensor InnerEncode(IComputeGraph g, List<string> imgPaths)
return res;
}

static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls)
static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls, int dim)
{
int batchSize = imgPaths.Count;
var inputEmbs = InnerEncode(g, imgPaths);
inputEmbs = srcEmbeddings.Process(inputEmbs, batchSize, g);

inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize, -1, 768 });
inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize, -1, dim });

cls = g.View(cls, dims: new long[] { 1, 1, 768 });
cls = g.Expand(cls, dims: new long[] { batchSize, 1, 768 });
cls = g.View(cls, dims: new long[] { 1, 1, dim });
cls = g.Expand(cls, dims: new long[] { batchSize, 1, dim });

inputEmbs = g.Concate(1, cls, inputEmbs);

inputEmbs = g.View(inputEmbs, dims: new long[] { -1, 768 });


inputEmbs = g.View(inputEmbs, dims: new long[] { -1, dim });

inputEmbs = PositionEmbedding.AddPositionEmbedding(g, posEmbeddings, batchSize, inputEmbs, 0.0f);
return encoder.Encode(inputEmbs, batchSize, g, null);
Expand Down

0 comments on commit 6022102

Please sign in to comment.