Skip to content

Commit

Permalink
Bug fix: Pass SaveModelEveryUpdates to the framework
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Oct 10, 2023
1 parent f61642f commit 69e644f
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/GPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public GPT(Seq2SeqOptions options, Vocab tgtVocab = null)
: base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio,
compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
startToRunValidAfterUpdates: options.StartValidAfterUpdates, maxDegressOfParallelism: options.TaskParallelism, mklInstructions: options.MKLInstructions, weightsUpdateCount: options.WeightsUpdateCount,
enableTensorCore: options.EnableTensorCore, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, randomSeed: options.RandomSeed)
enableTensorCore: options.EnableTensorCore, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, randomSeed: options.RandomSeed, saveModelEveryUpdats: options.SaveModelEveryUpdates)
{
m_shuffleType = options.ShuffleType;
m_options = options;
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/Image2Seq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public Image2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab =
: base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio,
compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
startToRunValidAfterUpdates: options.StartValidAfterUpdates, maxDegressOfParallelism: options.TaskParallelism, mklInstructions: options.MKLInstructions,
weightsUpdateCount: options.WeightsUpdateCount, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32)
weightsUpdateCount: options.WeightsUpdateCount, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, saveModelEveryUpdats: options.SaveModelEveryUpdates)
{
m_shuffleType = options.ShuffleType;
m_options = options;
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/Seq2Seq.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Seq2Seq(Seq2SeqOptions options, Vocab srcVocab = null, Vocab tgtVocab = n
: base(deviceIds: options.DeviceIds, processorType: options.ProcessorType, modelFilePath: options.ModelFilePath, memoryUsageRatio: options.MemoryUsageRatio,
compilerOptions: options.CompilerOptions, runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq,
startToRunValidAfterUpdates: options.StartValidAfterUpdates, maxDegressOfParallelism: options.TaskParallelism, mklInstructions: options.MKLInstructions,
weightsUpdateCount: options.WeightsUpdateCount, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32)
weightsUpdateCount: options.WeightsUpdateCount, cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, saveModelEveryUpdats: options.SaveModelEveryUpdates)
{
m_shuffleType = options.ShuffleType;
m_options = options;
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/SeqClassification.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public class SeqClassification : BaseSeq2SeqFramework<SeqClassificationModel>
public SeqClassification(SeqClassificationOptions options, Vocab srcVocab = null, Vocab tgtVocab = null)
: base(options.DeviceIds, options.ProcessorType, options.ModelFilePath, options.MemoryUsageRatio, options.CompilerOptions,
runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq, maxDegressOfParallelism: options.TaskParallelism,
cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32)
cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, saveModelEveryUpdats: options.SaveModelEveryUpdates)
{
m_shuffleType = options.ShuffleType;
m_options = options;
Expand Down
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/SeqLabel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class SeqLabel : BaseSeq2SeqFramework<SeqLabelModel>
public SeqLabel(SeqLabelOptions options, Vocab srcVocab = null, Vocab clsVocab = null)
: base(options.DeviceIds, options.ProcessorType, options.ModelFilePath, options.MemoryUsageRatio, options.CompilerOptions, startToRunValidAfterUpdates: options.StartValidAfterUpdates,
runValidEveryUpdates: options.RunValidEveryUpdates, updateFreq: options.UpdateFreq, maxDegressOfParallelism: options.TaskParallelism,
cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32)
cudaMemoryAllocatorType: options.CudaMemoryAllocatorType, elementType: options.AMP ? DType.Float16 : DType.Float32, saveModelEveryUpdats: options.SaveModelEveryUpdates)
{
m_shuffleType = options.ShuffleType;
m_options = options;
Expand Down
4 changes: 2 additions & 2 deletions Tools/GPTConsole/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
using Seq2SeqSharp.Utils;
using Seq2SeqSharp.Applications;

namespace Seq2SeqConsole
namespace GPTConsole
{
internal static class Program
{
Expand Down Expand Up @@ -51,7 +51,7 @@ private static void Main(string[] args)
}

Logger.Verbose = opts.LogVerbose;
Logger.LogFile = $"{nameof(Seq2SeqConsole)}_{opts.Task}_{Utils.GetTimeStamp(DateTime.Now)}.log";
Logger.LogFile = $"{nameof(GPTConsole)}_{opts.Task}_{Utils.GetTimeStamp(DateTime.Now)}.log";

ShowOptions(args, opts);

Expand Down

0 comments on commit 69e644f

Please sign in to comment.