diff --git a/src/Init/System/IO.lean b/src/Init/System/IO.lean index cbcc705e7952..94a3f74a86e4 100644 --- a/src/Init/System/IO.lean +++ b/src/Init/System/IO.lean @@ -464,31 +464,23 @@ def withFile (fn : FilePath) (mode : Mode) (f : Handle → IO α) : IO α := def Handle.putStrLn (h : Handle) (s : String) : IO Unit := h.putStr (s.push '\n') -partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do +partial def Handle.readBinToEndInto (h : Handle) (buf : ByteArray) : IO ByteArray := do let rec loop (acc : ByteArray) : IO ByteArray := do let buf ← h.read 1024 if buf.isEmpty then return acc else loop (acc ++ buf) - loop ByteArray.empty + loop buf -partial def Handle.readToEnd (h : Handle) : IO String := do - let rec loop (s : String) := do - let line ← h.getLine - if line.isEmpty then - return s - else - loop (s ++ line) - loop "" - -def readBinFile (fname : FilePath) : IO ByteArray := do - let h ← Handle.mk fname Mode.read - h.readBinToEnd +partial def Handle.readBinToEnd (h : Handle) : IO ByteArray := do + h.readBinToEndInto .empty -def readFile (fname : FilePath) : IO String := do - let h ← Handle.mk fname Mode.read - h.readToEnd +def Handle.readToEnd (h : Handle) : IO String := do + let data ← h.readBinToEnd + match String.fromUTF8? data with + | some s => return s + | none => throw <| .userError s!"Tried to read from handle containing non UTF-8 data." partial def lines (fname : FilePath) : IO (Array String) := do let h ← Handle.mk fname Mode.read @@ -594,6 +586,28 @@ end System.FilePath namespace IO +namespace FS + +def readBinFile (fname : FilePath) : IO ByteArray := do + -- Requires metadata so defined after metadata + let mdata ← fname.metadata + let size := mdata.byteSize.toUSize + let handle ← IO.FS.Handle.mk fname .read + let buf ← + if size > 0 then + handle.read mdata.byteSize.toUSize + else + pure <| ByteArray.mkEmpty 0 + handle.readBinToEndInto buf + +def readFile (fname : FilePath) : IO String := do + let data ← readBinFile fname + match String.fromUTF8? data with + | some s => return s + | none => throw <| .userError s!"Tried to read file '{fname}' containing non UTF-8 data." + +end FS + def withStdin [Monad m] [MonadFinally m] [MonadLiftT BaseIO m] (h : FS.Stream) (x : m α) : m α := do let prev ← setStdin h try x finally discard <| setStdin prev diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp index 136612e2fee5..552e0ca062ff 100644 --- a/src/runtime/io.cpp +++ b/src/runtime/io.cpp @@ -485,43 +485,30 @@ extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_write(b_obj_arg h, b_obj_arg } } -/* - Handle.getLine : (@& Handle) → IO Unit - The line returned by `lean_io_prim_handle_get_line` - is truncated at the first '\0' character and the - rest of the line is discarded. */ +/* Handle.getLine : (@& Handle) → IO Unit */ extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_get_line(b_obj_arg h, obj_arg /* w */) { FILE * fp = io_get_handle(h); - const int buf_sz = 64; - char buf_str[buf_sz]; // NOLINT - std::string result; - bool first = true; - while (true) { - char * out = std::fgets(buf_str, buf_sz, fp); - if (out != nullptr) { - if (strlen(buf_str) < buf_sz-1 || buf_str[buf_sz-2] == '\n') { - if (first) { - return io_result_mk_ok(mk_string(out)); - } else { - result.append(out); - return io_result_mk_ok(mk_string(result)); - } - } - result.append(out); - } else if (std::feof(fp)) { - clearerr(fp); - return io_result_mk_ok(mk_string(result)); - } else { - return io_result_mk_error(decode_io_error(errno, nullptr)); - } - first = false; + char* buf = NULL; + size_t n = 0; + ssize_t read = getline(&buf, &n, fp); + if (read != -1) { + obj_res ret = io_result_mk_ok(mk_string_from_bytes(buf, read)); + free(buf); + return ret; + } else if (std::feof(fp)) { + clearerr(fp); + return io_result_mk_ok(mk_string("")); + } else { + return io_result_mk_error(decode_io_error(errno, nullptr)); } } /* Handle.putStr : (@& Handle) → (@& String) → IO Unit */ extern "C" LEAN_EXPORT obj_res lean_io_prim_handle_put_str(b_obj_arg h, b_obj_arg s, obj_arg /* w */) { FILE * fp = io_get_handle(h); - if (std::fputs(lean_string_cstr(s), fp) != EOF) { + usize n = lean_string_size(s) - 1; // - 1 to ignore the terminal NULL byte. + usize m = std::fwrite(lean_string_cstr(s), 1, n, fp); + if (m == n) { return io_result_mk_ok(box(0)); } else { return io_result_mk_error(decode_io_error(errno, nullptr)); diff --git a/src/runtime/object.h b/src/runtime/object.h index b5e5a6375817..e5d6181d7909 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -238,6 +238,7 @@ inline size_t string_capacity(object * o) { return lean_string_capacity(o); } inline uint32 char_default_value() { return lean_char_default_value(); } inline obj_res alloc_string(size_t size, size_t capacity, size_t len) { return lean_alloc_string(size, capacity, len); } inline obj_res mk_string(char const * s) { return lean_mk_string(s); } +inline obj_res mk_string_from_bytes(char const * s, size_t sz) { return lean_mk_string_from_bytes(s, sz); } LEAN_EXPORT obj_res mk_ascii_string_unchecked(std::string const & s); LEAN_EXPORT obj_res mk_string(std::string const & s); LEAN_EXPORT std::string string_to_std(b_obj_arg o); diff --git a/tests/lean/run/3546.lean b/tests/lean/run/3546.lean new file mode 100644 index 000000000000..4491b72cd67a --- /dev/null +++ b/tests/lean/run/3546.lean @@ -0,0 +1,14 @@ +def test : IO Unit := do + let tmpFile := "3546.tmp" + let firstLine := "foo\u0000bar\n" + let content := firstLine ++ "hello world\nbye" + IO.FS.writeFile tmpFile content + let handle ← IO.FS.Handle.mk tmpFile .read + let firstReadLine ← handle.getLine + let cond := firstLine == firstReadLine && firstReadLine.length == 8 -- paranoid + IO.println cond + IO.FS.removeFile tmpFile + +/-- info: true -/ +#guard_msgs in +#eval test