From fffd57ab3ec130c0ae935e93bf1088fcd430292a Mon Sep 17 00:00:00 2001
From: Zhongkai Fu
Date: Sun, 24 Sep 2023 19:57:24 -0700
Subject: [PATCH] Check Nan value when loading and saving weights

NaN never compares equal to any value, including itself, so a test of
the form `x == float.NaN` is always false and would never detect a
corrupted weight. The checks therefore use float.IsNaN where the element
type is known to be float, and double.IsNaN (which accepts any value
implicitly convertible to double and preserves NaN) where the element
type is not fixed by the surrounding code.

---
 Seq2SeqSharp/Models/Model.cs | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/Seq2SeqSharp/Models/Model.cs b/Seq2SeqSharp/Models/Model.cs
index e35bf42a..030c6f15 100644
--- a/Seq2SeqSharp/Models/Model.cs
+++ b/Seq2SeqSharp/Models/Model.cs
@@ -168,6 +168,14 @@ public void AddWeights(string name, float[] weights)
         {
             Logger.WriteLine($"Adding weights '{name}' to the model.");
 
+            for (int i = 0; i < weights.Length; i++)
+            {
+                if (float.IsNaN(weights[i]))
+                {
+                    throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
+                }
+            }
+
             if (VQType == VQTypeEnums.FLOAT16)
             {
                 var weightsHalf = new ushort[weights.Length];
@@ -251,6 +259,14 @@ public float[] GetWeights(string name)
 
             if (Name2Weights.ContainsKey(name))
             {
+                for (int i = 0; i < Name2Weights[name].Length; i++)
+                {
+                    if (float.IsNaN(Name2Weights[name][i]))
+                    {
+                        throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
+                    }
+                }
+
                 weight = Name2Weights[name];
             }
             else if (Name2WeightsHalf.ContainsKey(name))
@@ -270,6 +286,11 @@ public float[] GetWeights(string name)
                 weight = new float[Name2WeightsVQ[name].Length];
                 for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
                 {
+                    if (double.IsNaN(codeBook[Name2WeightsVQ[name][i]]))
+                    {
+                        throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
+                    }
+
                     weight[i] = (float)codeBook[Name2WeightsVQ[name][i]];
                 }
             }
@@ -311,6 +332,10 @@ public half[] GetWeightsHalfType(string name)
                 weights = new half[values.Length];
                 for (int i = 0; i < values.Length; i++)
                 {
+                    if (double.IsNaN(values[i]))
+                    {
+                        throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
+                    }
                     weights[i] = new half(values[i]);
                 }
             }
@@ -320,6 +345,10 @@ public half[] GetWeightsHalfType(string name)
                 weights = new half[values.Length];
                 for (int i = 0; i < values.Length; i++)
                 {
+                    if (double.IsNaN(values[i]))
+                    {
+                        throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
+                    }
                     weights[i] = new half(values[i]);
                 }
             }
@@ -336,6 +365,11 @@ public half[] GetWeightsHalfType(string name)
                 weights = new half[Name2WeightsVQ[name].Length];
                 for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
                 {
+                    if (double.IsNaN(codeBook[Name2WeightsVQ[name][i]]))
+                    {
+                        throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
+                    }
+
                     weights[i] = new half(codeBook[Name2WeightsVQ[name][i]]);
                 }
             }