Skip to content

Commit

Permalink
Add option for start batch id
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Sep 4, 2023
1 parent 352fa69 commit 4db69a5
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 6 deletions.
5 changes: 4 additions & 1 deletion Seq2SeqSharp/Applications/Options.cs
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,12 @@ public class Options
[Arg("Training corpus folder path", nameof(TrainCorpusPath))]
public string TrainCorpusPath = null;

[Arg("Indexed data set file paht", nameof(IndexedCorpusPath))]
/// <summary>
/// Path of the indexed data set file. Defaults to <c>null</c>.
/// </summary>
[Arg("Indexed data set file path. The default value is empty.", nameof(IndexedCorpusPath))]
public string IndexedCorpusPath = null;

/// <summary>
/// The batch id at which processing starts. Defaults to 0.
/// </summary>
[Arg("The batch id that the tool will start to process. The default value is 0", nameof(StartBatchId))]
public int StartBatchId = 0;

/// <summary>
/// The maximum degree of parallelism in task. Defaults to 1.
/// </summary>
[Arg("The max degree of parallelism in task. Default is 1", nameof(TaskParallelism))]
public int TaskParallelism = 1;

Expand Down
11 changes: 9 additions & 2 deletions Seq2SeqSharp/Corpus/MonoCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ namespace Seq2SeqSharp.Tools

private string m_indexedDataSetFilePath = "";
private int m_batchNumInTotal = 0;
private int m_startBatchId = 0;

public List<Dictionary<string, int>> CountTokenFreqs()
{
Expand Down Expand Up @@ -340,6 +341,11 @@ public IEnumerator<T> GetEnumerator()
string[] tgtLines = br.ReadString().Split("\n");
batchIdx++;

if (batchIdx < m_startBatchId)
{
continue;
}

T batch;
int currentTokenCountsInBatch = 0;
for (int i = 0; i < sizeInBatch; i++)
Expand All @@ -353,7 +359,7 @@ public IEnumerator<T> GetEnumerator()
Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'."); // The '{i}th' record in this batch is: Target = '{tgtLine}'");
currentBatchPercent++;
}
}
}

SntPair sntPair = new SntPair(tgtLine, tgtLine);
currentTokenCountsInBatch += sntPair.GetTgtTokenCount();
Expand Down Expand Up @@ -393,7 +399,7 @@ public MonoCorpus()

}

public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "")
public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0)
{
Logger.WriteLine($"Loading mono corpus from '{corpusFilePath}' Files search pattern '*.{tgtLangName}.snt' MaxTgtSentLength = '{maxTgtSentLength}', aggregateLengthForShuffle = '{shuffleEnums}', TooLongSequence = '{tooLongSequence}'");
m_maxTokenSizePerBatch = maxTokenSizePerBatch;
Expand All @@ -402,6 +408,7 @@ public MonoCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePer
m_shuffleEnums = shuffleEnums;
CorpusName = corpusFilePath;
m_indexedDataSetFilePath = indexedFilePath;
m_startBatchId = startBatchId;

m_tgtFileList = new List<string>();
string[] files = Directory.GetFiles(corpusFilePath, $"*.{tgtLangName}.snt", SearchOption.TopDirectoryOnly);
Expand Down
4 changes: 2 additions & 2 deletions Seq2SeqSharp/Corpus/SeqCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ namespace Seq2SeqSharp.Corpus
public class SeqCorpus : MonoCorpus<SeqCorpusBatch>
{

public SeqCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "")
: base(corpusFilePath, tgtLangName, maxTokenSizePerBatch, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath)
/// <summary>
/// Creates a <see cref="SeqCorpus"/> by forwarding every argument unchanged to the
/// <c>MonoCorpus&lt;SeqCorpusBatch&gt;</c> base constructor. The constructor body itself is empty.
/// </summary>
/// <param name="corpusFilePath">Forwarded to the base constructor as <c>corpusFilePath</c>.</param>
/// <param name="tgtLangName">Forwarded to the base constructor as <c>tgtLangName</c>.</param>
/// <param name="maxTokenSizePerBatch">Forwarded to the base constructor as <c>maxTokenSizePerBatch</c>.</param>
/// <param name="maxTgtSentLength">Forwarded to the base constructor; defaults to 32.</param>
/// <param name="shuffleEnums">Forwarded to the base constructor; defaults to <c>ShuffleEnums.Random</c>.</param>
/// <param name="tooLongSequence">Forwarded to the base constructor; defaults to <c>TooLongSequence.Ignore</c>.</param>
/// <param name="indexedFilePath">Forwarded to the base constructor; defaults to an empty string.</param>
/// <param name="startBatchId">Forwarded to the base constructor as <c>startBatchId</c>; defaults to 0.</param>
public SeqCorpus(string corpusFilePath, string tgtLangName, int maxTokenSizePerBatch, int maxTgtSentLength = 32, ShuffleEnums shuffleEnums = ShuffleEnums.Random, TooLongSequence tooLongSequence = TooLongSequence.Ignore, string indexedFilePath = "", int startBatchId = 0)
: base(corpusFilePath, tgtLangName, maxTokenSizePerBatch, maxTgtSentLength, shuffleEnums: shuffleEnums, tooLongSequence: tooLongSequence, indexedFilePath: indexedFilePath, startBatchId: startBatchId)
{

}
Expand Down
2 changes: 1 addition & 1 deletion Tools/GPTConsole/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ private static void Main(string[] args)
{
// Load train corpus
var trainCorpus = new SeqCorpus(corpusFilePath: opts.TrainCorpusPath, tgtLangName: opts.TgtLang, maxTokenSizePerBatch: opts.MaxTokenSizePerBatch,
maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath);
maxTgtSentLength: opts.MaxTgtSentLength, shuffleEnums: opts.ShuffleType, tooLongSequence: opts.TooLongSequence, indexedFilePath: opts.IndexedCorpusPath, startBatchId: opts.StartBatchId);

// Create learning rate
ILearningRate learningRate = new DecayLearningRate(opts.StartLearningRate, opts.WarmUpSteps, opts.WeightsUpdateCount, opts.LearningRateStepDownFactor, opts.UpdateNumToStepDownLearningRate);
Expand Down

0 comments on commit 4db69a5

Please sign in to comment.