Skip to content

Commit

Permalink
Merge pull request #976 from martindevans/Fixed_UnmanagedMemoryStream
Browse files Browse the repository at this point in the history
Fixed `ChatSession.LoadSession`
  • Loading branch information
martindevans authored Nov 14, 2024
2 parents 4304dce + 8f8a9b2 commit f340f31
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 8 deletions.
30 changes: 30 additions & 0 deletions LLama.Unittest/LLamaContextTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,5 +88,35 @@ public void TokenizeEmpty()

Assert.Equal(Array.Empty<LLamaToken>(), tokens);
}

[Fact]
public void SaveLoadState()
{
using var state1 = _context.GetState();

var stream = new MemoryStream();
state1.Save(stream);

stream.Position = 0;

using var state2 = LLamaContext.State.Load(stream);

Assert.Equal(state1.Size, state2.Size);
}

[Fact]
public async Task SaveLoadStateAsync()
{
using var state1 = _context.GetState();

var stream = new MemoryStream();
await state1.SaveAsync(stream);

stream.Position = 0;

using var state2 = await LLamaContext.State.LoadAsync(stream);

Assert.Equal(state1.Size, state2.Size);
}
}
}
20 changes: 12 additions & 8 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -466,16 +466,15 @@ public void Dispose()
public class State
: SafeLLamaHandleBase
{
private readonly nuint _size;
/// <summary>
/// Get the size in bytes of this state object
/// </summary>
public nuint Size => _size;
public nuint Size { get; }

internal State(IntPtr memory, nuint size)
: base(memory, true)
{
_size = size;
Size = size;
}

/// <inheritdoc />
Expand All @@ -494,7 +493,8 @@ public async Task SaveAsync(Stream stream)
UnmanagedMemoryStream from;
unsafe
{
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
var length = (long)Size;
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), length, length, FileAccess.Read);
}
await from.CopyToAsync(stream);
}
Expand All @@ -508,7 +508,8 @@ public void Save(Stream stream)
UnmanagedMemoryStream from;
unsafe
{
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
var length = (long)Size;
from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), length, length, FileAccess.Read);
}
from.CopyTo(stream);
}
Expand All @@ -526,7 +527,8 @@ public static async Task<State> LoadAsync(Stream stream)
UnmanagedMemoryStream dest;
unsafe
{
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
var length = stream.Length;
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), length, length, FileAccess.Write);
}
await stream.CopyToAsync(dest);

Expand All @@ -543,11 +545,13 @@ public static State Load(Stream stream)
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, (nuint)stream.Length);

UnmanagedMemoryStream dest;
unsafe
{
var dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
stream.CopyTo(dest);
var length = stream.Length;
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), length, length, FileAccess.Write);
}
stream.CopyTo(dest);

return state;
}
Expand Down

0 comments on commit f340f31

Please sign in to comment.