diff --git a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs index 9e8a9519..0537b674 100644 --- a/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs +++ b/Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs @@ -192,6 +192,12 @@ public IComputeGraph CreateComputGraph(int deviceIdIdx, bool needBack = true) protected abstract T LoadModel(string suffix = ""); protected bool SaveModelRoutine(T model, Func createModel4SerializeFunc, bool createBackupPrevious = false, string suffix = "") { + Logger.WriteLine("Checking if all weights are normal."); + if (IsWeightsCorrupted()) + { + throw new WeightsCorruptedException($"The weights has been corrupted. Abort training and please check checkpoint files."); + } + string modelFilePath = m_modelFilePath + suffix; var fn = Path.GetFullPath(modelFilePath); var dir = Path.GetDirectoryName(fn); if (!Directory.Exists(dir)) Directory.CreateDirectory(dir); @@ -312,6 +318,13 @@ protected T LoadModelRoutine(Func initializeParametersFunc, model.ClearWeights(); + Logger.WriteLine("Checking if all loaded weights are normal."); + if (IsWeightsCorrupted()) + { + throw new WeightsCorruptedException($"The weights has been corrupted. Abort training and please check checkpoint files."); + } + + return (model); } diff --git a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs index e728df03..64535927 100644 --- a/Seq2SeqSharp/Tools/ComputeGraphTensor.cs +++ b/Seq2SeqSharp/Tools/ComputeGraphTensor.cs @@ -2386,7 +2386,12 @@ void backward() } - + /// + /// + /// shape: [D1, D2, D3, ... Dn-1, Dn, Dn+1, ...Dm] + /// + /// + /// Remove "dim" from the output tensor. shape: [D1, D2, D3, ... Dn-1, Dn+1, ...Dm] public IWeightTensor Select(IWeightTensor src, int dim, int index) { WeightTensor s = src as WeightTensor;