Skip to content

Commit

Permalink
Remove some vector allocations / copies (milvus-io#35)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephentoub authored Jul 5, 2023
1 parent e6f4b4d commit 40c724b
Show file tree
Hide file tree
Showing 8 changed files with 191 additions and 114 deletions.
46 changes: 32 additions & 14 deletions src/IO.Milvus/BinaryVectorField.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using Google.Protobuf;
using IO.Milvus.Diagnostics;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Diagnostics;

namespace IO.Milvus;

Expand All @@ -23,32 +25,48 @@ public BinaryVectorField(string fieldName, IList<byte[]> bytes)
/// <inheritdoc />
public override Grpc.FieldData ToGrpcFieldData()
{
Grpc.FloatArray floatArray = new();

int dim = Data.First().Length;
if (!Data.All(p => p.Length == dim))
int dataCount = Data.Count;
if (dataCount <= 0)
{
throw new Diagnostics.MilvusException("Row count of fields must be equal");
throw new MilvusException("Number of rows must be positive.");
}

using MemoryStream stream = new();
using BinaryWriter writer = new(stream);
foreach (byte[] value in Data)
int dim = Data[0].Length;
int lengthSum = 0;
for (int i = 1; i < dataCount; i++)
{
writer.Write(value);
int rowLength = Data[i].Length;
if (rowLength != dim)
{
throw new MilvusException("Row count of fields must be equal.");
}

checked { lengthSum += rowLength; }
}

ByteString byteString = ByteString.CopyFrom(stream.ToArray());
byte[] bytes = ArrayPool<byte>.Shared.Rent(lengthSum);
int pos = 0;
for (int i = 0; i < dataCount; i++)
{
byte[] row = Data[i];
Array.Copy(row, 0, bytes, pos, row.Length);
pos += row.Length;
}
Debug.Assert(pos == lengthSum);

return new Grpc.FieldData()
var result = new Grpc.FieldData()
{
FieldName = FieldName,
Type = (Grpc.DataType)DataType,
Vectors = new Grpc.VectorField()
{
BinaryVector = byteString,
BinaryVector = ByteString.CopyFrom(bytes.AsSpan(0, lengthSum)),
Dim = dim,
},
};

ArrayPool<byte>.Shared.Return(bytes);

return result;
}
}
44 changes: 15 additions & 29 deletions src/IO.Milvus/Field.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,7 @@ public static Field FromGrpcFieldData(Grpc.FieldData fieldData)
}
else if (fieldData.Vectors.DataCase == Grpc.VectorField.DataOneofCase.BinaryVector)
{
byte[] bytes = fieldData.Vectors.BinaryVector.ToByteArray();

List<byte[]> byteArray = new();

using MemoryStream stream = new(bytes);
using BinaryReader reader = new(stream);

Byte[] subBytes = reader.ReadBytes(dim);
while (subBytes.Length > 0)
{
byteArray.Add(subBytes);
subBytes = reader.ReadBytes(dim);
}

return Field.CreateBinaryVectors(fieldData.FieldName, byteArray);
return CreateFromBytes(fieldData.FieldName, fieldData.Vectors.BinaryVector.Span, dim);
}
else
{
Expand Down Expand Up @@ -224,7 +210,7 @@ internal static MilvusDataType EnsureDataType<TDataType>()
/// <list type="bullet">
/// <item><see cref="bool"/> : bool <see cref="MilvusDataType.Bool"/></item>
/// <item><see cref="sbyte"/> : int8 <see cref="MilvusDataType.Int8"/></item>
/// <item><see cref="Int16"/> : int16 <see cref="MilvusDataType.Int16"/></item>
/// <item><see cref="short"/> : int16 <see cref="MilvusDataType.Int16"/></item>
/// <item><see cref="int"/> : int32 <see cref="MilvusDataType.Int32"/></item>
/// <item><see cref="long"/> : int64 <see cref="MilvusDataType.Int64"/></item>
/// <item><see cref="float"/> : float <see cref="MilvusDataType.Float"/></item>
Expand Down Expand Up @@ -265,24 +251,25 @@ public static Field<string> CreateVarChar(
/// <param name="bytes">Byte array data.</param>
/// <param name="dimension">Dimension of data.</param>
/// <returns></returns>
public static BinaryVectorField CreateFromBytes(string fieldName, byte[] bytes, long dimension)
public static BinaryVectorField CreateFromBytes(string fieldName, ReadOnlySpan<byte> bytes, long dimension)
{
Verify.NotNullOrWhiteSpace(fieldName);
Verify.GreaterThan(dimension, 0);

List<byte[]> byteArray = new();
List<byte[]> byteArray = new((int)Math.Ceiling((double)bytes.Length / dimension));

using MemoryStream stream = new(bytes);
using BinaryReader reader = new(stream);
while (bytes.Length > dimension)
{
byteArray.Add(bytes.Slice(0, (int)dimension).ToArray());
bytes = bytes.Slice((int)dimension);
}

Byte[] subBytes = reader.ReadBytes((int)dimension);
while (subBytes.Length > 0)
if (!bytes.IsEmpty)
{
byteArray.Add(subBytes);
subBytes = reader.ReadBytes((int)dimension);
byteArray.Add(bytes.ToArray());
}

BinaryVectorField field = new(fieldName, byteArray);
return field;
return new BinaryVectorField(fieldName, byteArray);
}

/// <summary>
Expand Down Expand Up @@ -322,8 +309,7 @@ internal static FloatVectorField CreateFloatVector(string fieldName, List<float>

for (int i = 0; i < floatVector.Count; i += (int)dimension)
{
List<float> subVector = floatVector.GetRange(i, (int)dimension);
floatVectors.Add(subVector);
floatVectors.Add(floatVector.GetRange(i, (int)dimension));
}

return new FloatVectorField(fieldName, floatVectors);
Expand Down Expand Up @@ -469,7 +455,7 @@ public override Grpc.FieldData ToGrpcFieldData()
case MilvusDataType.Int16:
{
Grpc.IntArray intData = new();
intData.Data.AddRange((Data as IEnumerable<Int16>).Select(static p => (int)p));
intData.Data.AddRange((Data as IEnumerable<short>).Select(static p => (int)p));

fieldData.Scalars = new Grpc.ScalarField()
{
Expand Down
2 changes: 1 addition & 1 deletion src/IO.Milvus/FieldType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static FieldType Create(
/// <list type="bullet">
/// <item><see cref="bool"/> : bool <see cref="MilvusDataType.Bool"/></item>
/// <item><see cref="sbyte"/> : int8 <see cref="MilvusDataType.Int8"/></item>
/// <item><see cref="Int16"/> : int16 <see cref="MilvusDataType.Int16"/></item>
/// <item><see cref="short"/> : int16 <see cref="MilvusDataType.Int16"/></item>
/// <item><see cref="int"/> : int32 <see cref="MilvusDataType.Int32"/></item>
/// <item><see cref="long"/> : int64 <see cref="MilvusDataType.Int64"/></item>
/// <item><see cref="float"/> : float <see cref="MilvusDataType.Float"/></item>
Expand Down
1 change: 1 addition & 0 deletions src/IO.Milvus/IO.Milvus.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
<Version>2.2.1-alpha.5</Version>
<GenerateDocumentationFile>True</GenerateDocumentationFile>
<PackageReadmeFile>readme.md</PackageReadmeFile>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<ItemGroup>
Expand Down
107 changes: 57 additions & 50 deletions src/IO.Milvus/MilvusFieldConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
using System.Text.Json.Serialization;

Expand All @@ -16,7 +17,7 @@ public sealed class MilvusFieldConverter : JsonConverter<IList<Field>>
public override IList<Field> Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options)
{
List<Field> list = new();
DeserializePropertyList<Field>(ref reader, list);
DeserializePropertyList(ref reader, list);
return list;
}

Expand All @@ -36,9 +37,9 @@ private static void DeserializePropertyList<T>(ref Utf8JsonReader reader, IList<

while (reader.Read() && reader.TokenType != JsonTokenType.EndArray)
{
if (TryCastValue(ref reader, typeof(T), out Object item))
if (TryCastValue(ref reader, out T item))
{
list.Add((T)item);
list.Add(item);
}

System.Diagnostics.Debug.Assert(reader.TokenType != JsonTokenType.StartArray);
Expand All @@ -51,41 +52,40 @@ private static void DeserializePropertyList<T>(ref Utf8JsonReader reader, IList<
System.Diagnostics.Debug.Assert(reader.TokenType == JsonTokenType.EndArray);
}

private static bool TryCastValue(ref Utf8JsonReader reader, Type vtype, out Object value)
private static bool TryCastValue<T>(ref Utf8JsonReader reader, out T value)
{
value = null;

if (reader.TokenType == JsonTokenType.EndArray) return false;
if (reader.TokenType == JsonTokenType.EndObject) return false;
// if (reader.TokenType == JsonToken.EndConstructor) return false;
if (reader.TokenType is JsonTokenType.EndArray or JsonTokenType.EndObject)
{
value = default;
return false;
}

if (reader.TokenType == JsonTokenType.PropertyName) reader.Read();
if (reader.TokenType == JsonTokenType.PropertyName)
{
reader.Read();
}

// untangle nullable
Type ntype = Nullable.GetUnderlyingType(vtype);
if (ntype != null) vtype = ntype;

if (vtype == typeof(String)) { value = reader.GetString(); return true; }
if (vtype == typeof(Boolean)) { value = reader.GetBoolean(); return true; }
if (vtype == typeof(SByte)) { value = reader.GetInt16(); return true; }
if (vtype == typeof(Int16)) { value = reader.GetInt16(); return true; }
if (vtype == typeof(Int32)) { value = reader.GetInt32(); return true; }
if (vtype == typeof(Int64)) { value = reader.GetInt64(); return true; }
if (vtype == typeof(UInt16)) { value = reader.GetUInt16(); return true; }
if (vtype == typeof(UInt32)) { value = reader.GetUInt32(); return true; }
if (vtype == typeof(UInt64)) { value = reader.GetUInt64(); return true; }
if (vtype == typeof(Single)) { value = reader.GetSingle(); return true; }
if (vtype == typeof(Double)) { value = reader.GetDouble(); return true; }
if (vtype == typeof(Decimal)) { value = reader.GetDecimal(); return true; }
if (vtype == typeof(byte)) { value = reader.GetByte(); return true; }

if (vtype == typeof(Field))
if (typeof(T) == typeof(string)) { value = (T)(object)reader.GetString(); return true; }
if (typeof(T) == typeof(bool)) { value = (T)(object)reader.GetBoolean(); return true; }
if (typeof(T) == typeof(sbyte)) { value = (T)(object)reader.GetInt16(); return true; }
if (typeof(T) == typeof(short)) { value = (T)(object)reader.GetInt16(); return true; }
if (typeof(T) == typeof(int)) { value = (T)(object)reader.GetInt32(); return true; }
if (typeof(T) == typeof(long)) { value = (T)(object)reader.GetInt64(); return true; }
if (typeof(T) == typeof(ushort)) { value = (T)(object)reader.GetUInt16(); return true; }
if (typeof(T) == typeof(uint)) { value = (T)(object)reader.GetUInt32(); return true; }
if (typeof(T) == typeof(ulong)) { value = (T)(object)reader.GetUInt64(); return true; }
if (typeof(T) == typeof(float)) { value = (T)(object)reader.GetSingle(); return true; }
if (typeof(T) == typeof(double)) { value = (T)(object)reader.GetDouble(); return true; }
if (typeof(T) == typeof(decimal)) { value = (T)(object)reader.GetDecimal(); return true; }
if (typeof(T) == typeof(byte)) { value = (T)(object)reader.GetByte(); return true; }

if (typeof(T) == typeof(Field))
{
value = DeserializeField(ref reader);
value = (T)(object)DeserializeField(ref reader);
return true;
}

throw new NotImplementedException($"Can't deserialize {vtype}");
throw new NotImplementedException($"Can't deserialize {typeof(T)}");
}

private static Field DeserializeField(ref Utf8JsonReader reader)
Expand Down Expand Up @@ -148,25 +148,25 @@ private static Field DeserializeField(ref Utf8JsonReader reader)
switch (scalarTypeName)
{
case "BoolData":
DeserializePropertyList<bool>(ref reader, boolData);
DeserializePropertyList(ref reader, boolData);
break;
case "BytesData":
DeserializePropertyList<byte>(ref reader, bytesData);
DeserializePropertyList(ref reader, bytesData);
break;
case "IntData":
DeserializePropertyList<int>(ref reader, intData);
DeserializePropertyList(ref reader, intData);
break;
case "FloatData":
DeserializePropertyList<float>(ref reader, floatData);
DeserializePropertyList(ref reader, floatData);
break;
case "DoubleData":
DeserializePropertyList<double>(ref reader, doubleData);
DeserializePropertyList(ref reader, doubleData);
break;
case "StringData":
DeserializePropertyList<string>(ref reader, stringData);
DeserializePropertyList(ref reader, stringData);
break;
case "LongData":
DeserializePropertyList<long>(ref reader, longData);
DeserializePropertyList(ref reader, longData);
break;
default:
throw new JsonException($"Unexpected property {scalarTypeName}");
Expand Down Expand Up @@ -197,31 +197,38 @@ private static Field DeserializeField(ref Utf8JsonReader reader)
switch (dataType)
{
case MilvusDataType.Bool:
field = Field.Create<bool>(fieldName, boolData);
field = Field.Create(fieldName, boolData);
break;
case MilvusDataType.Int8:
field = Field.Create<sbyte>(fieldName, intData.Select(static s => (sbyte)s).ToList());
field = Field.Create(fieldName, intData.Select(static s => (sbyte)s).ToList());
break;
case MilvusDataType.Int16:
field = Field.Create<Int16>(fieldName, intData.Select(static s => (short)s).ToList()); ;
field = Field.Create(fieldName, intData.Select(static s => (short)s).ToList()); ;
break;
case MilvusDataType.Int32:
field = Field.Create<int>(fieldName, intData);
field = Field.Create(fieldName, intData);
break;
case MilvusDataType.Int64:
field = Field.Create<long>(fieldName, longData);
field = Field.Create(fieldName, longData);
break;
case MilvusDataType.Float:
field = Field.Create<float>(fieldName, floatData);
field = Field.Create(fieldName, floatData);
break;
case MilvusDataType.Double:
field = Field.Create<double>(fieldName, doubleData);
field = Field.Create(fieldName, doubleData);
break;
case MilvusDataType.VarChar:
field = Field.CreateVarChar(fieldName, stringData);
break;
case MilvusDataType.BinaryVector:
field = Field.CreateFromBytes(fieldName, binaryVector.ToArray(), dim);
field = Field.CreateFromBytes(
fieldName,
#if NET6_0_OR_GREATER
CollectionsMarshal.AsSpan(binaryVector),
#else
binaryVector.ToArray(),
#endif
dim);
break;
case MilvusDataType.FloatVector:
field = Field.CreateFloatVector(fieldName, floatVector, dim);
Expand Down Expand Up @@ -264,12 +271,12 @@ private static long DeserializeVector(ref Utf8JsonReader reader, List<float> flo
{
case "FloatVector":
{
DeserializePropertyList<float>(ref reader, floatVector);
DeserializePropertyList(ref reader, floatVector);
}
break;
case "BinaryVector":
{
DeserializePropertyList<byte>(ref reader, binaryVector);
DeserializePropertyList(ref reader, binaryVector);
}
break;
default:
Expand All @@ -289,7 +296,7 @@ private static long DeserializeVector(ref Utf8JsonReader reader, List<float> flo
return dim;
}

private static Object DeserializeUnknownObject(ref Utf8JsonReader reader)
private static object DeserializeUnknownObject(ref Utf8JsonReader reader)
{
if (reader.TokenType == JsonTokenType.PropertyName) reader.Read();

Expand Down Expand Up @@ -336,7 +343,7 @@ private static Object DeserializeUnknownObject(ref Utf8JsonReader reader)

internal static class JsonConverterExtension
{
public static Object GetAnyValue(this in Utf8JsonReader reader)
public static object GetAnyValue(this in Utf8JsonReader reader)
{
return reader.TokenType switch
{
Expand Down
2 changes: 1 addition & 1 deletion src/IO.Milvus/MilvusMutationResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ internal static MilvusIds From(Grpc.IDs ids)

if (ids.StrId?.Data?.Count > 0)
{
idField.StrId = new MilvusId<String>
idField.StrId = new MilvusId<string>
{
Data = ids.StrId.Data.ToList(),
};
Expand Down
Loading

0 comments on commit 40c724b

Please sign in to comment.