Skip to content

Commit

Permalink
Check Nan value when loading and saving weights
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Sep 25, 2023
1 parent 5aa6eb5 commit fffd57a
Showing 1 changed file with 34 additions and 0 deletions.
34 changes: 34 additions & 0 deletions Seq2SeqSharp/Models/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,14 @@ public void AddWeights(string name, float[] weights)
{
Logger.WriteLine($"Adding weights '{name}' to the model.");

for (int i = 0; i < weights.Length; i++)
{
if (weights[i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
}

if (VQType == VQTypeEnums.FLOAT16)
{
var weightsHalf = new ushort[weights.Length];
Expand Down Expand Up @@ -251,6 +259,14 @@ public float[] GetWeights(string name)

if (Name2Weights.ContainsKey(name))
{
for (int i = 0; i < Name2Weights[name].Length; i++)
{
if (Name2Weights[name][i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
}

weight = Name2Weights[name];
}
else if (Name2WeightsHalf.ContainsKey(name))
Expand All @@ -270,6 +286,11 @@ public float[] GetWeights(string name)
weight = new float[Name2WeightsVQ[name].Length];
for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
{
if (codeBook[Name2WeightsVQ[name][i]] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}

weight[i] = (float)codeBook[Name2WeightsVQ[name][i]];
}
}
Expand Down Expand Up @@ -311,6 +332,10 @@ public half[] GetWeightsHalfType(string name)
weights = new half[values.Length];
for (int i = 0; i < values.Length; i++)
{
if (values[i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
weights[i] = new half(values[i]);
}
}
Expand All @@ -320,6 +345,10 @@ public half[] GetWeightsHalfType(string name)
weights = new half[values.Length];
for (int i = 0; i < values.Length; i++)
{
if (values[i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
weights[i] = new half(values[i]);
}
}
Expand All @@ -336,6 +365,11 @@ public half[] GetWeightsHalfType(string name)
weights = new half[Name2WeightsVQ[name].Length];
for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
{
if (codeBook[Name2WeightsVQ[name][i]] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}

weights[i] = new half(codeBook[Name2WeightsVQ[name][i]]);
}
}
Expand Down

0 comments on commit fffd57a

Please sign in to comment.