From 467ee13aee3e7e7ddf2701dcaeb224c308ba4fcd Mon Sep 17 00:00:00 2001
From: Ian Johnson
Date: Tue, 15 Oct 2024 19:41:29 -0400
Subject: [PATCH] feat: support alternate encodings

Closes #37
---
 src/Encoding.zig |  81 +++++++++++++++++++++++++++++
 src/Reader.zig   |  11 +++--
 src/xml.zig      | 132 ++++++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 220 insertions(+), 4 deletions(-)
 create mode 100644 src/Encoding.zig

diff --git a/src/Encoding.zig b/src/Encoding.zig
new file mode 100644
index 0000000..d381cdd
--- /dev/null
+++ b/src/Encoding.zig
@@ -0,0 +1,81 @@
+const std = @import("std");
+
+context: *const anyopaque,
+guessFn: *const fn (context: *const anyopaque, text: []const u8) bool,
+checkEncodingFn: *const fn (context: *const anyopaque, xml_encoding: []const u8) bool,
+transcodeFn: *const fn (context: *const anyopaque, noalias dest: []u8, noalias src: []const u8) TranscodeResult,
+
+const Encoding = @This();
+
+/// Progress made by one `transcode` call.
+pub const TranscodeResult = struct {
+    /// Number of bytes written to `dest`.
+    dest_written: usize,
+    /// Number of bytes consumed from `src`.
+    src_read: usize,
+};
+
+/// Forwards `text` to `guessFn` together with the stored context.
+pub fn guess(encoding: Encoding, text: []const u8) bool {
+    return encoding.guessFn(encoding.context, text);
+}
+
+/// Forwards `xml_encoding` to `checkEncodingFn` together with the stored context.
+pub fn checkEncoding(encoding: Encoding, xml_encoding: []const u8) bool {
+    return encoding.checkEncodingFn(encoding.context, xml_encoding);
+}
+
+/// Forwards `dest` and `src` to `transcodeFn` together with the stored context.
+pub fn transcode(encoding: Encoding, noalias dest: []u8, noalias src: []const u8) TranscodeResult {
+    return encoding.transcodeFn(encoding.context, dest, src);
+}
+
+/// UTF-8 pass-through encoding; none of its callbacks read `context`.
+pub const utf8: Encoding = .{
+    .context = undefined,
+    .guessFn = &utf8Guess,
+    .checkEncodingFn = &utf8CheckEncoding,
+    .transcodeFn = &utf8Transcode,
+};
+
+fn utf8Guess(context: *const anyopaque, text: []const u8) bool {
+    _ = context;
+    _ = text;
+    return true;
+}
+
+fn utf8CheckEncoding(context: *const anyopaque, encoding: []const u8) bool {
+    _ = context;
+    return std.ascii.eqlIgnoreCase(encoding, "UTF-8");
+}
+
+/// Copies complete, valid UTF-8 sequences from `src` to `dest`. Stops at the
+/// first invalid or truncated sequence, or when `dest` cannot hold the next one.
+fn utf8Transcode(context: *const anyopaque, noalias dest: []u8, noalias src: []const u8) TranscodeResult {
+    _ = context;
+    var dest_written: usize = 0;
+    var src_read: usize = 0;
+    while (src_read < src.len) {
+        const cp_len = std.unicode.utf8ByteSequenceLength(src[src_read]) catch break;
+        if (src_read + cp_len > src.len or dest_written + cp_len > dest.len) break;
+        switch (cp_len) {
+            1 => {
+                dest[dest_written] = src[src_read];
+                dest_written += 1;
+                src_read += 1;
+            },
+            2, 3, 4 => {
+                const slice = src[src_read..][0..cp_len];
+                if (!std.unicode.utf8ValidateSlice(slice)) break;
+                @memcpy(dest[dest_written..][0..cp_len], slice);
+                dest_written += cp_len;
+                src_read += cp_len;
+            },
+            else => unreachable,
+        }
+    }
+    return .{
+        .dest_written = dest_written,
+        .src_read = src_read,
+    };
+}
diff --git a/src/Reader.zig b/src/Reader.zig
index fea9106..fa05f73 100644
--- a/src/Reader.zig
+++ b/src/Reader.zig
@@ -155,17 +155,22 @@ pub const ErrorCode = enum {
     expected_equals,
     expected_quote,
     missing_end_quote,
-    invalid_utf8,
+    invalid_encoding,
     illegal_character,
 };
 
 pub const Source = struct {
     context: *const anyopaque,
     moveFn: *const fn (context: *const anyopaque, advance: usize, len: usize) anyerror![]const u8,
+    checkEncodingFn: *const fn (context: *const anyopaque, encoding: []const u8) bool,
 
     pub fn move(source: Source, advance: usize, len: usize) anyerror![]const u8 {
         return source.moveFn(source.context, advance, len);
     }
+
+    pub fn checkEncoding(source: Source, encoding: []const u8) bool {
+        return source.checkEncodingFn(source.context, encoding);
+    }
 };
 
 const State = enum {
@@ -1565,7 +1570,7 @@ fn checkXmlVersion(reader: *Reader, version: []const u8, n_attr: usize) !void {
 }
 
 fn checkXmlEncoding(reader: *Reader, encoding: []const u8, n_attr: usize) !void {
-    if (!std.ascii.eqlIgnoreCase(encoding, "utf-8")) {
+    if (!reader.source.checkEncoding(encoding)) {
         return reader.fatal(.xml_declaration_encoding_unsupported, reader.attributeValuePos(n_attr));
     }
 }
@@ -2149,7 +2154,7 @@ fn fatalInvalidUtf8(reader: *Reader, s: []const u8, pos: usize) error{MalformedX
         if (!std.unicode.utf8ValidateSlice(s[invalid_pos..][0..cp_len])) break;
         invalid_pos += cp_len;
     }
-    return reader.fatal(.invalid_utf8, pos + invalid_pos);
+    return reader.fatal(.invalid_encoding, pos + invalid_pos);
 }
 
 const base_read_size = 4096;
diff --git a/src/xml.zig b/src/xml.zig
index 504c7ac..661cb7f 100644
--- a/src/xml.zig
+++ b/src/xml.zig
@@ -77,6 +77,8 @@ pub const predefined_namespace_uris = std.StaticStringMap([]const u8).initCompti
     .{ "xmlns", ns_xmlns },
 });
 
+pub const Encoding = @import("Encoding.zig");
+
 pub const Reader = @import("Reader.zig");
 
 pub fn GenericReader(comptime SourceError: type) type {
@@ -324,6 +326,7 @@ pub const StaticDocument = struct {
         return .{
             .context = doc,
             .moveFn = &move,
+            .checkEncodingFn = Encoding.utf8.checkEncodingFn,
         };
     }
 
@@ -368,6 +371,7 @@ pub fn StreamingDocument(comptime ReaderType: type) type {
             return .{
                 .context = doc,
                 .moveFn = &move,
+                .checkEncodingFn = Encoding.utf8.checkEncodingFn,
             };
         }
 
@@ -393,7 +397,7 @@ pub fn StreamingDocument(comptime ReaderType: type) type {
                 const new_buf_len = @min(min_buf_len, std.math.ceilPowerOfTwoAssert(usize, target_len));
                 doc.buf = try doc.gpa.realloc(doc.buf, new_buf_len);
             }
-            doc.avail += try doc.stream.read(doc.buf[doc.avail..]);
+            doc.avail += try doc.stream.readAll(doc.buf[doc.avail..]);
         }
     };
 }
@@ -428,6 +432,132 @@ test streamingDocument {
     try expectEqual(.eof, try reader.read());
 }
 
+/// Returns a document type that reads raw bytes from a `ReaderType` stream and
+/// passes them through an `Encoding` before exposing them as a `Reader.Source`.
+pub fn EncodedDocument(comptime ReaderType: type) type {
+    return struct {
+        stream: ReaderType,
+        encoding: Encoding,
+        buf: []u8,
+        pos: usize,
+        avail: usize,
+        raw_buf: [4096]u8,
+        raw_buf_len: usize,
+        gpa: Allocator,
+
+        pub const Error = ReaderType.Error || Allocator.Error || error{InvalidEncoding};
+
+        pub fn init(gpa: Allocator, stream: ReaderType, encoding: Encoding) @This() {
+            return .{
+                .stream = stream,
+                .encoding = encoding,
+                .buf = &.{},
+                .pos = 0,
+                .avail = 0,
+                .raw_buf = undefined,
+                .raw_buf_len = 0,
+                .gpa = gpa,
+            };
+        }
+
+        pub fn deinit(doc: *@This()) void {
+            doc.gpa.free(doc.buf);
+            doc.* = undefined;
+        }
+
+        /// Creates a reader over this document. `assume_valid_utf8` is forced on
+        /// regardless of what `options` requests.
+        pub fn reader(doc: *@This(), gpa: Allocator, options: Reader.Options) GenericReader(Error) {
+            var modified_options = options;
+            modified_options.assume_valid_utf8 = true;
+            return .{ .reader = Reader.init(gpa, doc.source(), modified_options) };
+        }
+
+        pub fn source(doc: *@This()) Reader.Source {
+            return .{
+                .context = doc,
+                .moveFn = &move,
+                // Ask the document's own encoding, not the UTF-8 default, whether
+                // the encoding named in the XML declaration is acceptable.
+                .checkEncodingFn = &checkEncoding,
+            };
+        }
+
+        fn move(context: *const anyopaque, advance: usize, len: usize) anyerror![]const u8 {
+            const doc: *@This() = @alignCast(@constCast(@ptrCast(context)));
+            doc.pos += advance;
+            if (len <= doc.avail - doc.pos) return doc.buf[doc.pos..][0..len];
+            doc.discardRead();
+            try doc.fillBuffer(len);
+            return doc.buf[0..@min(len, doc.avail)];
+        }
+
+        fn discardRead(doc: *@This()) void {
+            doc.avail -= doc.pos;
+            std.mem.copyForwards(u8, doc.buf[0..doc.avail], doc.buf[doc.pos..][0..doc.avail]);
+            doc.pos = 0;
+        }
+
+        const min_buf_len = 4096;
+
+        /// Transcodes stream data into `buf` until `target_len` bytes are available
+        /// or the stream and `raw_buf` are both exhausted.
+        fn fillBuffer(doc: *@This(), target_len: usize) !void {
+            if (target_len > doc.buf.len) {
+                // `@max`, not `@min`: the buffer must hold at least `target_len` bytes.
+                const new_buf_len = @max(min_buf_len, std.math.ceilPowerOfTwoAssert(usize, target_len));
+                doc.buf = try doc.gpa.realloc(doc.buf, new_buf_len);
+            }
+            while (doc.avail < target_len) {
+                doc.raw_buf_len += try doc.stream.readAll(doc.raw_buf[doc.raw_buf_len..]);
+                if (doc.raw_buf_len == 0) return;
+                const res = doc.encoding.transcode(doc.buf[doc.avail..], doc.raw_buf[0..doc.raw_buf_len]);
+                // NOTE(review): zero output is always reported as invalid input here;
+                // confirm a transcoder cannot also stall because `buf` is nearly full.
+                if (res.dest_written == 0) return error.InvalidEncoding;
+                doc.avail += res.dest_written;
+                std.mem.copyForwards(u8, doc.raw_buf[0..doc.raw_buf_len], doc.raw_buf[res.src_read..doc.raw_buf_len]);
+                doc.raw_buf_len -= res.src_read;
+            }
+        }
+
+        fn checkEncoding(context: *const anyopaque, encoding: []const u8) bool {
+            const doc: *const @This() = @alignCast(@ptrCast(context));
+            return doc.encoding.checkEncoding(encoding);
+        }
+    };
+}
+
+pub fn encodedDocument(gpa: Allocator, reader: anytype, encoding: Encoding) EncodedDocument(@TypeOf(reader)) {
+    return EncodedDocument(@TypeOf(reader)).init(gpa, reader, encoding);
+}
+
+test encodedDocument {
+    var fbs = std.io.fixedBufferStream(
+        \\<?xml version="1.0"?>
+        \\<root>Hello, world!</root>
+        \\
+    );
+    var doc = encodedDocument(std.testing.allocator, fbs.reader(), Encoding.utf8);
+    defer doc.deinit();
+    var reader = doc.reader(std.testing.allocator, .{});
+    defer reader.deinit();
+
+    try expectEqual(.xml_declaration, try reader.read());
+    try expectEqualStrings("1.0", reader.xmlDeclarationVersion());
+
+    try expectEqual(.element_start, try reader.read());
+    try expectEqualStrings("root", reader.elementName());
+
+    try expectEqual(.text, try reader.read());
+    try expectEqualStrings("Hello, world!", reader.textRaw());
+
+    try expectEqual(.element_end, try reader.read());
+    try expectEqualStrings("root", reader.elementName());
+
+    try expectEqual(.eof, try reader.read());
+}
+
 pub const Writer = @import("Writer.zig");
 
 pub fn GenericWriter(comptime SinkError: type) type {