Skip to content

Commit

Permalink
Update strategy for corrupted weights
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Sep 25, 2023
1 parent fffd57a commit af9fb59
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 84 deletions.
35 changes: 0 additions & 35 deletions Seq2SeqSharp/Models/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -167,15 +167,6 @@ public Model(Model_4_ProtoBufSerializer m)
public void AddWeights(string name, float[] weights)
{
Logger.WriteLine($"Adding weights '{name}' to the model.");

for (int i = 0; i < weights.Length; i++)
{
if (weights[i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
}

if (VQType == VQTypeEnums.FLOAT16)
{
var weightsHalf = new ushort[weights.Length];
Expand Down Expand Up @@ -259,14 +250,6 @@ public float[] GetWeights(string name)

if (Name2Weights.ContainsKey(name))
{
for (int i = 0; i < Name2Weights[name].Length; i++)
{
if (Name2Weights[name][i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
}

weight = Name2Weights[name];
}
else if (Name2WeightsHalf.ContainsKey(name))
Expand All @@ -286,11 +269,6 @@ public float[] GetWeights(string name)
weight = new float[Name2WeightsVQ[name].Length];
for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
{
if (codeBook[Name2WeightsVQ[name][i]] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}

weight[i] = (float)codeBook[Name2WeightsVQ[name][i]];
}
}
Expand Down Expand Up @@ -332,10 +310,6 @@ public half[] GetWeightsHalfType(string name)
weights = new half[values.Length];
for (int i = 0; i < values.Length; i++)
{
if (values[i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
weights[i] = new half(values[i]);
}
}
Expand All @@ -345,10 +319,6 @@ public half[] GetWeightsHalfType(string name)
weights = new half[values.Length];
for (int i = 0; i < values.Length; i++)
{
if (values[i] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}
weights[i] = new half(values[i]);
}
}
Expand All @@ -365,11 +335,6 @@ public half[] GetWeightsHalfType(string name)
weights = new half[Name2WeightsVQ[name].Length];
for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
{
if (codeBook[Name2WeightsVQ[name][i]] == float.NaN)
{
throw new InvalidOperationException($"The weights '{name}' is corrupted due to Nan value.");
}

weights[i] = new half(codeBook[Name2WeightsVQ[name][i]]);
}
}
Expand Down
44 changes: 26 additions & 18 deletions Seq2SeqSharp/Tools/BaseSeq2SeqFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -416,9 +416,13 @@ internal void TrainOneEpoch(int ep, ICorpus<ISntPairBatch> trainCorpus, ICorpus<

if (float.IsNaN(cost))
{
Logger.WriteLine(Logger.Level.warn, "The cost result is Nan, so it seems weights are corrupted. Let's roll back to the previous best checkpoint.");
Logger.WriteLine(Logger.Level.warn, "The cost result is Nan, so we won't update weights at this time.");

if (IsWeightsCorrupted())
{
throw new WeightsCorruptedException($"The weights has been corrupted. Abort training and please check checkpoint files.");
}

LoadModel();
break;
}

Expand Down Expand Up @@ -461,7 +465,6 @@ internal void TrainOneEpoch(int ep, ICorpus<ISntPairBatch> trainCorpus, ICorpus<
{
string oomMessage = string.Empty;
bool isOutOfMemException = false;
bool isArithmeticException = false;
foreach (var excep in err.InnerExceptions)
{
if (excep is OutOfMemoryException)
Expand All @@ -471,12 +474,6 @@ internal void TrainOneEpoch(int ep, ICorpus<ISntPairBatch> trainCorpus, ICorpus<
oomMessage = excep.Message;
break;
}
else if (excep is ArithmeticException)
{
isArithmeticException = true;
oomMessage = excep.Message;
break;
}
else
{
Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Inner Exception: {excep.Message}, Call stack: {excep.StackTrace}");
Expand All @@ -492,14 +489,9 @@ internal void TrainOneEpoch(int ep, ICorpus<ISntPairBatch> trainCorpus, ICorpus<
break;
}
}
else if (isArithmeticException)
{
Logger.WriteLine($"Arithmetic exception: '{err.Message}'");
break;
}
else
{
Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Inner Exception: {err.Message}, Call stack: {err.StackTrace}");
throw err;
}
}
Expand All @@ -519,10 +511,10 @@ internal void TrainOneEpoch(int ep, ICorpus<ISntPairBatch> trainCorpus, ICorpus<
break;
}
}
catch (ArithmeticException err)
catch (WeightsCorruptedException err)
{
Logger.WriteLine($"Arithmetic exception: '{err.Message}'");
break;
Logger.WriteLine(Logger.Level.err, ConsoleColor.Red, $"Exception: {err.Message}, Call stack: {err.StackTrace}");
throw;
}
catch (Exception err)
{
Expand Down Expand Up @@ -1208,6 +1200,22 @@ protected virtual void LoadParameters(IModel model)
}
}


/// <summary>
/// Checks every parameter tensor returned by <c>GetParametersFromDefaultDevice()</c>
/// and reports whether at least one of them flags its weights as corrupted.
/// </summary>
/// <returns>
/// <c>true</c> as soon as one tensor's <c>IsWeightsCorrupted()</c> returns <c>true</c>;
/// <c>false</c> when none of them does.
/// </returns>
internal bool IsWeightsCorrupted()
{
    bool corrupted = false;

    foreach (var tensor in GetParametersFromDefaultDevice())
    {
        if (tensor.IsWeightsCorrupted())
        {
            // One bad tensor is enough; no need to inspect the remaining ones.
            corrupted = true;
            break;
        }
    }

    return corrupted;
}

/// <summary>
/// Copy weights from default device to all other devices
/// </summary>
Expand Down
24 changes: 2 additions & 22 deletions Seq2SeqSharp/Tools/IWeightTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,53 +20,33 @@ public interface IWeightTensor : INeuralUnit, IDisposable
int Rows { get; set; }
int Columns { get; set; }
string Name { get; set; }

bool IsTrainable { get; set; }

bool NeedGradient { get; set; }

int DeviceId { get; set; }

float LearningRateFactor { get; set; }

DType ElementType { get;}
IAllocator Allocator { get; }

float GetWeightAt(long[] indices);
float GetGradientAt(long[] indices);

void SetWeightAt(float val, long[] indices);

void CopyWeightsToGradients(IWeightTensor src);

List<int> GetTopNMaxWeightIdx(int topN);

void SetWeightArray(float[] v);

void ReleaseWeight();
void ReleaseGradient();

void ZeroGradient();
void CleanWeight();

WeightTensor CopyWeightsRef(string name, bool needGradient, IComputeGraph graphToBind);

void CopyWeightsFrom(IWeightTensor src);
void AddGradientFrom(IWeightTensor src);

float[] ToWeightArray();

void UnbindFromComputeGraph();

bool IsGradientNull();

IAllocator Allocator { get; }

void FillGradient(float val);

void Clamp(float min, float max);

long ElementCount { get; }

void PrintWeights();
bool IsWeightsCorrupted();
}
}
26 changes: 17 additions & 9 deletions Seq2SeqSharp/Tools/WeightTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ public int Columns
private Tensor m_TWeight = null;
private Tensor m_TGradient = null;
private static readonly object locker = new object();

// private bool releasedWeight = false;
private readonly IComputeGraph m_computeGraphToBind;

private string m_GradientSetName = "None";
Expand All @@ -83,11 +81,6 @@ public Tensor TWeight
{
get
{
//if (releasedWeight)
//{
// throw new Exception($"The weight '{Name}' has been released, you cannot access it.");
//}

if (m_TWeight == null)
{
m_TWeight = new Tensor(m_allocator, m_elementType, Sizes);
Expand Down Expand Up @@ -118,7 +111,6 @@ public Tensor TWeight
}
}
}
// releasedWeight = false;
}
}
}
Expand Down Expand Up @@ -248,6 +240,23 @@ public INeuralUnit CloneToDeviceAt(int deviceId)
return new WeightTensor(Sizes, deviceId, Name, IsTrainable, initType: m_normType, fanIn: m_fanIn, fanOut: m_fanOut, needGradient: NeedGradient, dtype: m_elementType);
}


/// <summary>
/// Scans the values returned by <c>ToWeightArray()</c> for non-finite numbers.
/// </summary>
/// <returns>
/// <c>true</c> if any value is NaN or positive/negative infinity; otherwise <c>false</c>.
/// </returns>
public bool IsWeightsCorrupted()
{
    foreach (float value in ToWeightArray())
    {
        bool isFinite = !float.IsNaN(value) && !float.IsInfinity(value);
        if (!isFinite)
        {
            return true;
        }
    }

    return false;
}


public void ZeroGradient()
{
Ops.Fill(TGradient, 0.0f);
Expand Down Expand Up @@ -544,7 +553,6 @@ public void ReleaseWeight()
{
m_TWeight.Dispose();
m_TWeight = null;
// releasedWeight = true;
}
}

Expand Down
21 changes: 21 additions & 0 deletions Seq2SeqSharp/Utils/WeightsCorruptedException.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) Zhongkai Fu. All rights reserved.
// https://github.com/zhongkaifu/Seq2SeqSharp
//
// This file is part of Seq2SeqSharp.
//
// Seq2SeqSharp is licensed under the BSD-3-Clause license found in the LICENSE file in the root directory of this source tree.
//
// Seq2SeqSharp is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the BSD-3-Clause License for more details.

using System;

namespace Seq2SeqSharp.Utils
{
/// <summary>
/// Signals that model weights were found to be corrupted (for example, they
/// contain NaN or infinite values) and that the current operation cannot continue.
/// </summary>
public class WeightsCorruptedException : Exception
{
    /// <summary>
    /// Creates the exception with the default message.
    /// </summary>
    public WeightsCorruptedException() { }

    /// <summary>
    /// Creates the exception with a caller-supplied message.
    /// </summary>
    /// <param name="message">Text describing which weights are corrupted and why.</param>
    public WeightsCorruptedException(string message) : base(message) { }

    /// <summary>
    /// Creates the exception with a caller-supplied message and the underlying cause,
    /// so the original failure is preserved when this exception wraps another one.
    /// </summary>
    /// <param name="message">Text describing which weights are corrupted and why.</param>
    /// <param name="innerException">The exception that led to this one, or <c>null</c>.</param>
    public WeightsCorruptedException(string message, Exception innerException) : base(message, innerException) { }
}
}

0 comments on commit af9fb59

Please sign in to comment.