diff --git a/src/pipemode_op/RecordReader/TFRecordReader.cpp b/src/pipemode_op/RecordReader/TFRecordReader.cpp index 9c83c2c..3d72212 100644 --- a/src/pipemode_op/RecordReader/TFRecordReader.cpp +++ b/src/pipemode_op/RecordReader/TFRecordReader.cpp @@ -12,6 +12,7 @@ // language governing permissions and limitations under the License. #include #include +#include #include "tensorflow/core/lib/hash/crc32c.h" #include "TFRecordReader.hpp" @@ -20,7 +21,16 @@ using sagemaker::tensorflow::TFRecordReader; inline void ValidateLength(const std::uint64_t& length, const std::uint32_t masked_crc32_of_length) { if (tensorflow::crc32c::Unmask(masked_crc32_of_length) != tensorflow::crc32c::Value(reinterpret_cast(&(length)), sizeof(length))) { - throw std::runtime_error("Invalid header crc"); + throw std::runtime_error("CRC check on header failed."); + } +} + +inline void ValidateData(const std::string* storage, const std::uint64_t& length, + const std::uint32_t masked_crc32_of_data) { + auto unmasked_crc = tensorflow::crc32c::Unmask(masked_crc32_of_data); + auto data_crc = tensorflow::crc32c::Value(storage->data(), length); + if (unmasked_crc != data_crc) { + throw std::runtime_error("CRC check on data failed."); } } @@ -36,5 +46,6 @@ bool TFRecordReader::ReadRecord(std::string* storage) { Read(&(storage->at(0)), length); std::uint32_t footer; Read(&footer, sizeof(footer)); + ValidateData(storage, length, footer); return true; } diff --git a/src/pipemode_op/test/testRecordReader/TestTFRecordReader.cpp b/src/pipemode_op/test/testRecordReader/TestTFRecordReader.cpp index cedd8c6..864fb58 100644 --- a/src/pipemode_op/test/testRecordReader/TestTFRecordReader.cpp +++ b/src/pipemode_op/test/testRecordReader/TestTFRecordReader.cpp @@ -57,12 +57,15 @@ std::string ToTFRecord(const std::string& data) { result.push_back(header[i]); } result += data; + auto data_crc = tensorflow::crc32c::Mask(tensorflow::crc32c::Value(data.c_str(), length)); + masked_crc_ptr = reinterpret_cast(&data_crc); for (int i = 0; i < 4; i++) { - result.push_back('f'); + result.push_back(masked_crc_ptr[i]); } return result; } + TEST_F(TFRecordReaderTest, ReadRecord) { std::string encoded = ToTFRecord("hello"); std::unique_ptr reader = MakeTFRecordReader(