Skip to content

Commit

Permalink
Add VQType for Float16
Browse files Browse the repository at this point in the history
  • Loading branch information
zhongkaifu committed Aug 22, 2023
1 parent af2f4f4 commit 777cb1d
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Seq2SeqSharp/Applications/GPT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void UpdateVocabs(Vocab tgtVocab)
public void VQModel()
{
m_modelMetaData.VQType = m_options.VQType;
SaveModel(createBackupPrevious: true, suffix: ".vq");
SaveModel(createBackupPrevious: true, suffix: $".{m_modelMetaData.VQType.ToString()}");

// SaveModel_As_BinaryFormatter(suffix: ".vq.bin");
}
Expand Down
54 changes: 43 additions & 11 deletions Seq2SeqSharp/Models/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ public Vocab ClsVocab

/// <summary>Single-precision weights, keyed by weight name.</summary>
public Dictionary<string, float[]> Name2Weights { get; set; }

/// <summary>Float16 weights, keyed by weight name; populated by AddWeights when VQType is FLOAT16.</summary>
public Dictionary<string, half[]> Name2WeightsHalf { get; set; }

/// <summary>Storage format used for weights; set from the options or from the deserialized model in the constructors.</summary>
public VQTypeEnums VQType { get; set; }
/// <summary>Quantized weights, keyed by weight name: each byte holds one codebook index (INT8) or two packed 4-bit indices (INT4).</summary>
public Dictionary<string, byte[]> Name2WeightsVQ { get; set; }
/// <summary>Per-weight codebook, keyed by weight name, that maps a quantized index back to its value.</summary>
public Dictionary<string, double[]> Name2CodeBook { get; set; }
Expand All @@ -106,6 +108,7 @@ public Model(Options opts,Vocab srcVocab)
VQType = opts.VQType;

Name2Weights = new Dictionary<string, float[]>();
Name2WeightsHalf= new Dictionary<string, half[]>();
Name2WeightsVQ = new Dictionary<string, byte[]>();
Name2CodeBook = new Dictionary<string, double[]>();
}
Expand All @@ -129,6 +132,7 @@ public Model(Model_4_ProtoBufSerializer m)
VQType = m.VQType;

Name2Weights = m.Name2Weights;
Name2WeightsHalf = m.Name2WeightsHalf;
Name2WeightsVQ = m.Name2WeightsVQ;
Name2CodeBook = m.Name2CodeBook;

Expand All @@ -137,6 +141,11 @@ public Model(Model_4_ProtoBufSerializer m)
Name2Weights = new Dictionary<string, float[]>();
}

if (Name2WeightsHalf == null)
{
Name2WeightsHalf = new Dictionary<string, half[]>();
}

if (Name2WeightsVQ == null)
{
Name2WeightsVQ = new Dictionary<string, byte[]>();
Expand All @@ -152,7 +161,16 @@ public void AddWeights(string name, float[] weights)
{
Logger.WriteLine($"Adding weights '{name}' to the model.");

if (VQType == VQTypeEnums.INT8)
if (VQType == VQTypeEnums.FLOAT16)
{
var weightsHalf = new half[weights.Length];
for (int i = 0; i < weights.Length; i++)
{
weightsHalf[i] = new half(weights[i]);
}
Name2WeightsHalf.Add(name, weightsHalf);
}
else if (VQType == VQTypeEnums.INT8)
{
int vqSize = 256;
VectorQuantization vq = new VectorQuantization();
Expand Down Expand Up @@ -226,7 +244,11 @@ public float[] GetWeights(string name)

if (Name2Weights.ContainsKey(name))
{
weight = Name2Weights[name];
weight = Name2Weights[name];
}
else if (Name2WeightsHalf.ContainsKey(name))
{
throw new InvalidCastException($"The model is saved as Float16 type, so please enable AMP for model loading.");
}
else if (VQType == VQTypeEnums.INT8)
{
Expand Down Expand Up @@ -275,16 +297,20 @@ public float[] GetWeights(string name)

public half[] GetWeightsHalfType(string name)
{
half[] weight = null;
half[] weights = null;
if (Name2Weights.ContainsKey(name))
{
var values = Name2Weights[name];
weight = new half[values.Length];
weights = new half[values.Length];
for (int i = 0; i < values.Length; i++)
{
weight[i] = new half(values[i]);
weights[i] = new half(values[i]);
}
}
else if (Name2WeightsHalf.ContainsKey(name))
{
weights = Name2WeightsHalf[name];
}
else if (VQType == VQTypeEnums.INT8)
{
if (Name2WeightsVQ.ContainsKey(name) == false)
Expand All @@ -295,10 +321,10 @@ public half[] GetWeightsHalfType(string name)

var codeBook = Name2CodeBook[name];

weight = new half[Name2WeightsVQ[name].Length];
weights = new half[Name2WeightsVQ[name].Length];
for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
{
weight[i] = new half(codeBook[Name2WeightsVQ[name][i]]);
weights[i] = new half(codeBook[Name2WeightsVQ[name][i]]);
}
}
else if (VQType == VQTypeEnums.INT4)
Expand All @@ -311,22 +337,22 @@ public half[] GetWeightsHalfType(string name)

var codeBook = Name2CodeBook[name];

weight = new half[Name2WeightsVQ[name].Length * 2];
weights = new half[Name2WeightsVQ[name].Length * 2];
for (int i = 0; i < Name2WeightsVQ[name].Length; i++)
{
double highWeight = codeBook[Name2WeightsVQ[name][i] / 16];
double lowWeight = codeBook[Name2WeightsVQ[name][i] & 0x0F];

weight[i * 2] = new half(lowWeight);
weight[i * 2 + 1] = new half(highWeight);
weights[i * 2] = new half(lowWeight);
weights[i * 2 + 1] = new half(highWeight);
}
}
else
{
Logger.WriteLine(Logger.Level.warn, ConsoleColor.Yellow, $"Weight '{name}' doesn't exist in the model.");
}

return weight;
return weights;
}

public void DeleteWeights(string name)
Expand All @@ -345,13 +371,19 @@ public void DeleteWeights(string name)
{
Name2Weights.Remove(name);
}

if (Name2WeightsHalf != null && Name2WeightsHalf.ContainsKey(name))
{
Name2WeightsHalf.Remove(name);
}
}

/// <summary>
/// Removes every stored weight from the model by emptying all four weight
/// dictionaries: quantized indices, codebooks, float weights and Float16 weights.
/// </summary>
/// <remarks>
/// The dictionaries are cleared in place rather than replaced, and no null
/// checks are performed, so each one must already be initialized.
/// </remarks>
public void ClearWeights()
{
Name2WeightsVQ.Clear();
Name2CodeBook.Clear();
Name2Weights.Clear();
Name2WeightsHalf.Clear();
}

public void ShowModelInfo()
Expand Down
8 changes: 8 additions & 0 deletions Seq2SeqSharp/Models/Model_4_ProtoBufSerializer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using Seq2SeqSharp.Corpus;
using Seq2SeqSharp.Utils;
using Seq2SeqSharp.Enums;
using ManagedCuda.BasicTypes;

namespace Seq2SeqSharp.Models
{
Expand Down Expand Up @@ -240,6 +241,12 @@ public Model_4_ProtoBufSerializer(Model m)
Name2Weights = new Dictionary<string, float[]>();
}

Name2WeightsHalf = m.Name2WeightsHalf;
if (Name2WeightsHalf == null)
{
Name2WeightsHalf = new Dictionary<string, half[]>();
}

VQType = m.VQType;
Name2WeightsVQ = m.Name2WeightsVQ;
if (Name2WeightsVQ == null)
Expand Down Expand Up @@ -304,5 +311,6 @@ public Model_4_ProtoBufSerializer(Model m)
// Serialized weight-storage fields; each value is copied from the Model instance passed to the constructor.
// NOTE(review): the ProtoMember numbers identify fields in the saved model, so presumably they must stay stable and not be reused — confirm against protobuf-net guidance.
/// <summary>Storage format the weights were saved with.</summary>
[ProtoMember(24)] public VQTypeEnums VQType { get; set; }
/// <summary>Quantized weights (codebook indices packed into bytes), keyed by weight name.</summary>
[ProtoMember(25)] public Dictionary<string, byte[]> Name2WeightsVQ { get; set; }
/// <summary>Codebook for each quantized weight, keyed by weight name.</summary>
[ProtoMember(26)] public Dictionary<string, double[]> Name2CodeBook { get; set; }
/// <summary>Float16 weights, keyed by weight name.</summary>
[ProtoMember(27)] public Dictionary<string, half[]> Name2WeightsHalf { get; set; }
}
}
1 change: 1 addition & 0 deletions Seq2SeqSharp/Utils/ModeEnums.cs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public enum DecoderTypeEnums
/// <summary>
/// Selects how model weights are stored; Model.AddWeights and the weight getters branch on this value.
/// </summary>
public enum VQTypeEnums
{
// No reduced-precision storage selected (presumably weights stay as plain floats — confirm in Model.AddWeights).
None = 0,
// Weights are converted to half precision and kept in Name2WeightsHalf.
// NOTE(review): 65536 looks like 2^16, matching the 256/16 pattern below — confirm intent.
FLOAT16 = 65536,
// Vector quantization with one byte-sized codebook index per weight (AddWeights uses a size of 256 for this mode).
INT8 = 256,
// Vector quantization with two 4-bit codebook indices packed into each byte.
INT4 = 16
}
Expand Down

0 comments on commit 777cb1d

Please sign in to comment.