Commit: Code refactoring for image caption
zhongkaifu committed Oct 11, 2023
1 parent 69e644f commit d69d48c
Showing 6 changed files with 22 additions and 119 deletions.
20 changes: 9 additions & 11 deletions README.md
@@ -2,11 +2,12 @@

[![.NET](https://github.com/zhongkaifu/Seq2SeqSharp/actions/workflows/dotnet.yml/badge.svg)](https://github.com/zhongkaifu/Seq2SeqSharp/actions/workflows/dotnet.yml)
# Seq2SeqSharp
Seq2SeqSharp is a tensor-based, fast and flexible encoder-decoder deep neural network framework written in .NET (C#). It can be used for sequence-to-sequence, sequence-labeling, sequence-classification and other NLP tasks. Seq2SeqSharp supports both CPUs and GPUs and runs cross-platform, such as on Windows and Linux (x86, x64 and ARM), without any modification or recompilation.
Seq2SeqSharp is a tensor-based, fast and flexible deep neural network framework written in .NET (C#). It can be used for sequence-to-sequence, sequence-labeling, sequence-classification and other tasks on text and images. Seq2SeqSharp supports both CPUs and GPUs and runs cross-platform, such as on Windows and Linux (x86, x64 and ARM), without any modification or recompilation.

# Features
Pure C# framework
Transformer encoder and decoder with pointer generator
Vision Transformer encoder for images
GPTDecoder
Attention based LSTM decoder with coverage model
Bi-directional LSTM encoder
@@ -247,6 +248,9 @@ For example: given the input sentence "▁i ▁would ▁like ▁to ▁drink ▁w
GPTConsole is a command line tool for GPT-style model training and testing. Given a line of text from the input file, the model will continue generating the rest of the text.
This tool is quite similar to Seq2SeqConsole, and most parameters are reusable. The main difference is that GPTConsole has no settings for the source side or encoders; all of its settings are for the target side and decoder only.
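For example, a hypothetical invocation might look like the following (the flag names mirror the other console tools in this repository and are assumptions; run the tool without arguments to see its actual usage):

```
GPTConsole.exe -Task Train -ConfigFilePath train_opts.json
GPTConsole.exe -Task Test -ConfigFilePath test_opts.json
```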

## ImgSeqConsole is for the image caption task
ImgSeqConsole is a command line tool for the image caption task. Given a list of image file paths, the model will generate a description for each image.
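For example, a hypothetical input file could contain one image path per line (these paths are illustrative only), and the tool would then write one generated caption per image:

```
/data/images/cat_on_sofa.jpg
/data/images/mountain_lake.png
```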

## SeqClassification for sequence-classification task
SeqClassification is used to classify an input sequence into a certain category. Given an input sequence, the tool adds a [CLS] tag at the beginning of the sequence and then sends it to the encoder. At the top layer of the encoder, it runs softmax over the [CLS] representation and decides which category the sequence belongs to.
This tool can be used to train a model for the sequence-classification task, and to test the model.
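As a minimal, self-contained sketch of that last step (the logits and the three-category setup below are hypothetical, not taken from this repository), the softmax over the [CLS] logits works like this:

```csharp
using System;
using System.Linq;

class ClsSoftmaxSketch
{
    static void Main()
    {
        // Hypothetical logits projected from the [CLS] representation, one per category.
        double[] logits = { 2.1, 0.3, -1.0 };

        // Numerically stable softmax: subtract the max before exponentiating.
        double max = logits.Max();
        double[] exps = logits.Select(x => Math.Exp(x - max)).ToArray();
        double sum = exps.Sum();
        double[] probs = exps.Select(e => e / sum).ToArray();

        // The predicted category is the arg-max of these probabilities.
        Console.WriteLine(string.Join(", ", probs.Select(p => p.ToString("F3"))));
    }
}
```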
@@ -299,7 +303,6 @@ Here is the configuration file for model training.

### Data format for the SeqClassificationConsole tool
It also uses two files for each pair of data and follows the same naming convention as the Seq2SeqConsole tool above. The source file includes the tokens given as input to the model, and the target file includes the corresponding tags that the model will predict. Each line contains one record.
The model supports multiple classifiers, so tags in the target file are separated by tab characters, such as [Tag1] \t [Tag2] \t ... \t [TagN]. Each classifier predicts one tag.

Here is an example:
| Tag | Tokens in Sequence |
@@ -309,11 +312,6 @@

"Otorhinolaryngology" and "Orthopedics" are tags for classification and the rest of the tokens in each line are tokens for input sequence. This is an example that given title and description in medical domain, asking model to predict which specialty it should be classified. [SEP] is used to split title and description in the sequence, but it's not required in other tasks.

## Seq2SeqClassificationConsole for sequence-to-sequence and classification multi-tasks
Here is a diagram of what the model looks like:
![](https://raw.githubusercontent.com/zhongkaifu/Seq2SeqSharp/master/Images/Seq2SeqClassificationModel.jpeg)


## SeqLabelConsole for sequence-labeling task
The usage of **SeqLabelConsole.exe** is similar to that of **Seq2SeqConsole.exe** above; just type it in the console and it will show you its usage.

@@ -389,9 +387,6 @@ Here is the configuration file for model training.
}
```

## SeqSimilarityConsole for sequence similarity calculation
Each line in the data set contains two sequences separated by a tab character, and the tool calculates their similarity.

# Demos and released models
From version 2.7.0 on, Seq2SeqSharp models are deployed on Hugging Face, and you can also try the demos there.
| Demo | Hugging Face Space Url | Hugging Face Model Url | Model Parameters |
@@ -838,5 +833,8 @@ In Seq2SeqConsole project, it shows you how to initialize and train your network

# Todo List
If you are interested in the items below, please let me know, because as an African proverb says, "If you want to go fast, go alone. If you want to go far, go together." :)
Multimodal models
Support Mac Devices
And More...

## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=zhongkaifu/Seq2SeqSharp&type=Date)](https://star-history.com/#zhongkaifu/Seq2SeqSharp)
3 changes: 1 addition & 2 deletions Seq2SeqSharp.sln
@@ -16,7 +16,6 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "Solution Items", "Solution
NetworkViz.png = NetworkViz.png
Overview.jpg = Overview.jpg
README.md = README.md
Seq2SeqClassificationModel.jpeg = Seq2SeqClassificationModel.jpeg
Seq2SeqModel.jpeg = Seq2SeqModel.jpeg
SeqClassificationModel.jpeg = SeqClassificationModel.jpeg
TagEmbeddings.jpeg = TagEmbeddings.jpeg
@@ -70,7 +69,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "PythonPackage", "PythonPack
PyPackage\Seq2SeqSharp\__init__.py = PyPackage\Seq2SeqSharp\__init__.py
EndProjectSection
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "ImgSeqConsole", "Tools\ImgSeqConsole\ImgSeqConsole.csproj", "{D5B59E92-8BFF-4B30-844B-E95E67D5A68B}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "ImgSeqConsole", "Tools\ImgSeqConsole\ImgSeqConsole.csproj", "{D5B59E92-8BFF-4B30-844B-E95E67D5A68B}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
58 changes: 5 additions & 53 deletions Seq2SeqSharp/Applications/Image2Seq.cs
@@ -41,14 +41,9 @@ public class Image2Seq : BaseSeq2SeqFramework<Seq2SeqModel>

private MultiProcessorNetworkWrapper<IFeedForwardLayer> m_pointerGenerator;

private MultiProcessorNetworkWrapper<INormalization> m_layerNorm1;
private MultiProcessorNetworkWrapper<INormalization> m_layerNorm2;

private readonly ShuffleEnums m_shuffleType = ShuffleEnums.Random;
readonly Seq2SeqOptions m_options = null;

private MemoryCache m_memoryCache;

public Image2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab = null)
: base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio,
compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
@@ -61,11 +56,6 @@ public Image2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab =
// Check if options are valid.
m_options.ValidateOptions();

m_memoryCache = new MemoryCache(new MemoryCacheOptions
{
SizeLimit = 1024
});

if (File.Exists(m_options.ModelFilePath))
{
if (srcVocab != null || tgtVocab != null)
@@ -116,11 +106,8 @@ private bool CreateTrainableParameters(IModel model)
m_decoder = Decoder.CreateDecoders(model, m_options, raDeviceIds, elementType: elementType);
m_decoderFFLayer = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("FeedForward_Decoder_0", model.HiddenDim, model.TgtVocab.Count, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds);
// (m_posEmbedding, m_segmentEmbedding) = Misc.CreateAuxEmbeddings(raDeviceIds, model.HiddenDim, Math.Max(Math.Max(m_options.MaxSrcSentLength, m_options.MaxValidSrcSentLength), Math.Max(m_options.MaxTgtSentLength, m_options.MaxValidTgtSentLength)), model,
// elementType: elementType, createAPE: (model.PEType == PositionEmbeddingEnums.APE));


m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { m_options.MaxSrcSentLength, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "PositionalEmbedding",
m_posEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 1024, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "PositionalEmbedding",
isTrainable: true, dtype: elementType), raDeviceIds.ToArray());
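// The 1024 here is a generous fixed upper bound on sequence positions: each image
// yields 256 tokens plus one CLS token (257 in total), well under this limit.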

m_cls = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 1, model.HiddenDim }, raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, name: "CLS",
@@ -132,13 +119,6 @@ private bool CreateTrainableParameters(IModel model)
m_srcEmbedding = new MultiProcessorNetworkWrapper<IFeedForwardLayer>(new FeedForwardLayer("SrcEmbedding_Decoder_0", model.HiddenDim, model.HiddenDim, dropoutRatio: 0.0f, deviceId: raDeviceIds.GetNextItem(),
isTrainable: true, learningRateFactor: m_options.DecoderStartLearningRateFactor, elementType: elementType), DeviceIds);


// m_srcEmbedding = new MultiProcessorNetworkWrapper<IWeightTensor>(new WeightTensor(new long[2] { 768, 256 },
//raDeviceIds.GetNextItem(), initType: RandomInitType.Uniform, fanOut: true, name: "SrcEmbeddings", isTrainable: true, learningRateFactor: 1.0f, dtype: elementType), DeviceIds);

m_layerNorm1 = new MultiProcessorNetworkWrapper<INormalization>(new LayerNormalization("LayerNorm1", model.HiddenDim, deviceId: raDeviceIds.GetNextItem(), isTrainable: true), DeviceIds);
m_layerNorm2 = new MultiProcessorNetworkWrapper<INormalization>(new LayerNormalization("LayerNorm2", model.HiddenDim, deviceId: raDeviceIds.GetNextItem(), isTrainable: true), DeviceIds);

if (model.PointerGenerator)
{
if (model.SharedEmbeddings == false)
@@ -167,7 +147,7 @@ public void VQModel()
/// <summary>
/// Get networks on specific devices
/// </summary>
private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor, INormalization, INormalization, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
private (IEncoder, IDecoder, IFeedForwardLayer, IFeedForwardLayer, IWeightTensor, IWeightTensor, IFeedForwardLayer, IWeightTensor, IWeightTensor) GetNetworksOnDeviceAt(int deviceId)
{
var deviceIdIdx = TensorAllocator.GetDeviceIdIndex(deviceId);
return (m_encoder.GetNetworkOnDevice(deviceIdIdx),
@@ -176,8 +156,6 @@ public void VQModel()
m_srcEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_tgtEmbedding.GetNetworkOnDevice(deviceIdIdx),
m_segmentEmbedding?.GetNetworkOnDevice(deviceIdIdx), m_pointerGenerator?.GetNetworkOnDevice(deviceIdIdx), m_posEmbedding?.GetNetworkOnDevice(deviceIdIdx),
m_layerNorm1.GetNetworkOnDevice(deviceIdIdx),
m_layerNorm2.GetNetworkOnDevice(deviceIdIdx),
m_cls.GetNetworkOnDevice(deviceIdIdx));
}

@@ -204,45 +182,19 @@ private string GenerateCacheKey(List<List<string>> strs)
/// <returns>The cost of forward part</returns>
public override List<NetworkResult> RunForwardOnSingleDevice(IComputeGraph computeGraph, IPairBatch sntPairBatch, DecodingOptions decodingOptions, bool isTraining)
{
(var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings, var layerNorm1, var layerNorm2, var cls) = GetNetworksOnDeviceAt(computeGraph.DeviceId);
(var encoder, var decoder, var decoderFFLayer, var srcEmbedding, var tgtEmbedding, var segmentEmbedding, var pointerGenerator, var posEmbeddings, var cls) = GetNetworksOnDeviceAt(computeGraph.DeviceId);

var srcSnts = sntPairBatch.GetSrcTokens();
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, layerNorm1, layerNorm2, cls, false);

// encOutput.PrintWeights();
IWeightTensor encOutput = ImgEncoder.Run(computeGraph, srcSnts[0], encoder, srcEmbedding, posEmbeddings, cls);

List<NetworkResult> nrs = new List<NetworkResult>();

// Generate output decoder sentences
int batchSize = srcSnts[0].Count;

//if (isTraining)
//{
// batchSize = batchSize * 6;
//}

float[] originalSrcLengths = new float[batchSize];
Array.Fill(originalSrcLengths, 257);


Array.Fill(originalSrcLengths, 257); // 256 tokens from the image + 1 cls token
var tgtSnts = sntPairBatch.GetTgtTokens();

//if (isTraining)
//{
// List<List<string>> newTgtSnts = new List<List<string>>();

// foreach (var item in tgtSnts)
// {
// for (int i = 0; i < 6; i++)
// {
// newTgtSnts.Add(item);
// }
// }

// tgtSnts = newTgtSnts;
//}


var tgtTokensList = m_modelMetaData.TgtVocab.GetWordIndex(tgtSnts);
NetworkResult nr = new NetworkResult();
nr.Status = NetworkResultStatus.SUCCEED;
54 changes: 4 additions & 50 deletions Seq2SeqSharp/Applications/ImgEncoder.cs
@@ -22,41 +22,14 @@ public static class ImgEncoder

static int TOKEN_W = 16;
static int TOKEN_H = 16;

static int TOTAL_TOKEN_NUM_PER_IMG = (IMAGE_W / TOKEN_W) * (IMAGE_H / TOKEN_H);
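// For example, assuming square 256x256 input images (the resize/crop size is defined elsewhere):
// (256 / 16) * (256 / 16) = 16 * 16 = 256 tokens per image; the CLS token added later makes 257 positions.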

static private IWeightTensor LoadImageToTokens(IComputeGraph g, string filePath, bool rotation = false, bool horizFlip = false, bool vertFlip = false, bool brigtness = false, bool contrast = false)
static private IWeightTensor LoadImageToTokens(IComputeGraph g, string filePath)
{

List<float[]> tokens = new List<float[]>();

using (Image<Rgb24> image = Image.Load<Rgb24>(filePath))
{
if (rotation == true)
{
image.Mutate(x => x.Rotate(90.0f));
}

if (horizFlip == true)
{
image.Mutate(x => x.Flip(FlipMode.Horizontal));
}

if (vertFlip == true)
{
image.Mutate(x => x.Flip(FlipMode.Vertical));
}

if (brigtness == true)
{
image.Mutate(x => x.Brightness(0.5f));
}

if (contrast == true)
{
image.Mutate(x => x.Contrast(0.5f));
}

int newWidth = 0;
int newHeight = 0;
if (image.Width < image.Height)
@@ -117,44 +90,25 @@ static private IWeightTensor LoadImageToTokens(IComputeGraph g, string filePath,
//Size(token) = TOTAL_TOKEN_NUM_PER_IMG
//Size(embedding_dim) = 768
//Shape: [batchsize, TOTAL_TOKEN_NUM_PER_IMG, 768]
static private IWeightTensor InnerEncode(IComputeGraph g, List<string> imgPaths, bool trainMode = false)
static private IWeightTensor InnerEncode(IComputeGraph g, List<string> imgPaths)
{
int batchSize = imgPaths.Count;
List<IWeightTensor> batchTokens = new List<IWeightTensor>();

foreach (var picPath in imgPaths)
{
batchTokens.Add(LoadImageToTokens(g, picPath)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]

if (trainMode)
{
batchTokens.Add(LoadImageToTokens(g, picPath, rotation: true)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]
batchTokens.Add(LoadImageToTokens(g, picPath, horizFlip: true)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]
batchTokens.Add(LoadImageToTokens(g, picPath, vertFlip: true)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]
batchTokens.Add(LoadImageToTokens(g, picPath, brigtness: true)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]
batchTokens.Add(LoadImageToTokens(g, picPath, contrast: true)); //shape: [TOTAL_TOKEN_NUM_PER_IMG, 768]
}
}

var res = g.Concate(batchTokens, 0);
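// Concate joins the per-image token tensors along dim 0 into a flat
// [batchSize * TOTAL_TOKEN_NUM_PER_IMG, 768] tensor; Run() later reshapes it per batch.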
return res;
}

static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, INormalization layerNorm1, INormalization layerNorm2, IWeightTensor cls, bool trainMode = false)
static public IWeightTensor Run(IComputeGraph g, List<string> imgPaths, IEncoder encoder, IFeedForwardLayer srcEmbeddings, IWeightTensor posEmbeddings, IWeightTensor cls)
{
int batchSize = imgPaths.Count;
var inputEmbs = InnerEncode(g, imgPaths, trainMode);

if (trainMode)
{
batchSize = batchSize * 6;
}


// inputEmbs = layerNorm1.Norm(inputEmbs, g);
var inputEmbs = InnerEncode(g, imgPaths);
inputEmbs = srcEmbeddings.Process(inputEmbs, batchSize, g);
// inputEmbs = g.ReLU (inputEmbs);
// inputEmbs = layerNorm2.Norm(inputEmbs, g);

inputEmbs = g.View(inputEmbs, dims: new long[] { batchSize, -1, 768 });
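// The View above reshapes the flat embeddings to [batchSize, tokensPerImage, 768]; -1 lets the framework infer the token dimension.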

4 changes: 2 additions & 2 deletions Seq2SeqSharp/Applications/Options.cs
@@ -162,8 +162,8 @@ public class Options
[Arg("The decay steps of learning rate", nameof(LearningRateDecaySteps))]
public int LearningRateDecaySteps = 200000; // 200K

[Arg("The type of learning rate", nameof(LearnRateType))]
public LearningRateTypeEnums LearnRateType = LearningRateTypeEnums.Decay;
[Arg("The type of learning rate", nameof(LearningRateType))]
public LearningRateTypeEnums LearningRateType = LearningRateTypeEnums.Decay;

[Arg("Shuffle Type. It could be NoPaddingInSrc, NoPaddingInTgt and Random", nameof(ShuffleType))]
public ShuffleEnums ShuffleType = ShuffleEnums.Random;
2 changes: 1 addition & 1 deletion Tools/GPTConsole/Program.cs
@@ -66,7 +66,7 @@ private static void Main(string[] args)
// Create learning rate
ILearningRate learningRate = null;

if (opts.LearnRateType == LearningRateTypeEnums.CosineDecay)
if (opts.LearningRateType == LearningRateTypeEnums.CosineDecay)
{
learningRate = new CosineDecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.LearningRateDecaySteps, opts.WeightsUpdateCount);
}
