Skip to content

Commit

Permalink
Check weights corrupted while loading/saveing models
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Sep 26, 2023
1 parent af9fb59 commit bd152bf
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
13 changes: 13 additions & 0 deletions Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@ public IComputeGraph CreateComputGraph(int deviceIdIdx, bool needBack = true)
protected abstract T LoadModel(string suffix = "");
protected bool SaveModelRoutine<ProtoBuf_T>(T model, Func<T, ProtoBuf_T> createModel4SerializeFunc, bool createBackupPrevious = false, string suffix = "")
{
Logger.WriteLine("Checking if all weights are normal.");
if (IsWeightsCorrupted())
{
throw new WeightsCorruptedException($"The weights has been corrupted. Abort training and please check checkpoint files.");
}

string modelFilePath = m_modelFilePath + suffix;
var fn = Path.GetFullPath(modelFilePath);
var dir = Path.GetDirectoryName(fn); if (!Directory.Exists(dir)) Directory.CreateDirectory(dir);
Expand Down Expand Up @@ -312,6 +318,13 @@ protected T LoadModelRoutine<ProtoBuf_T>(Func<T, bool> initializeParametersFunc,

model.ClearWeights();

Logger.WriteLine("Checking if all loaded weights are normal.");
if (IsWeightsCorrupted())
{
throw new WeightsCorruptedException($"The weights has been corrupted. Abort training and please check checkpoint files.");
}


return (model);
}

Expand Down
7 changes: 6 additions & 1 deletion Seq2SeqSharp/Tools/ComputeGraphTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2386,7 +2386,12 @@ void backward()
}



/// <summary>
/// </summary>
/// <param name="src">shape: [D1, D2, D3, ... Dn-1, Dn, Dn+1, ...Dm]</param>
/// <param name="dim"></param>
/// <param name="index"></param>
/// <returns>Remove "dim" from the output tensor. shape: [D1, D2, D3, ... Dn-1, Dn+1, ...Dm]</returns>
public IWeightTensor Select(IWeightTensor src, int dim, int index)
{
WeightTensor s = src as WeightTensor;
Expand Down

0 comments on commit bd152bf

Please sign in to comment.