From 58785a3c17b6c814f075d60bd91a3d2337dca9a3 Mon Sep 17 00:00:00 2001 From: Sai Asish Y Date: Mon, 18 May 2026 23:50:36 -0700 Subject: [PATCH] Fix enum fields in NDArray of records on the Python numpy serialization path --- cpp/test/generated/binary/protocols.cc | 8 ++++ cpp/test/generated/binary/protocols.h | 2 + cpp/test/generated/hdf5/protocols.cc | 8 ++++ cpp/test/generated/hdf5/protocols.h | 4 ++ cpp/test/generated/mocks.cc | 24 ++++++++++++ cpp/test/generated/model.json | 8 ++++ cpp/test/generated/ndjson/protocols.cc | 8 ++++ cpp/test/generated/ndjson/protocols.h | 2 + cpp/test/generated/protocols.cc | 38 ++++++++++++++++--- cpp/test/generated/protocols.h | 8 ++++ cpp/test/roundtrip_test.cc | 5 +++ .../+test_model/+binary/EnumsReader.m | 6 +++ .../+test_model/+binary/EnumsWriter.m | 6 +++ .../+test_model/+testing/MockEnumsWriter.m | 13 +++++++ .../+test_model/+testing/TestEnumsWriter.m | 5 +++ .../generated/+test_model/EnumsReaderBase.m | 16 +++++++- .../generated/+test_model/EnumsWriterBase.m | 17 ++++++++- matlab/test/RoundTripTest.m | 6 +++ models/test/unittests.yml | 1 + python/test_model/binary.py | 6 +++ python/test_model/ndjson.py | 10 +++++ python/test_model/protocols.py | 38 +++++++++++++++++-- python/tests/test_generated_types.py | 12 ++++++ python/tests/test_protocol_roundtrip.py | 22 +++++++++++ .../internal/python/static_files/_binary.py | 15 +++++++- 25 files changed, 274 insertions(+), 14 deletions(-) diff --git a/cpp/test/generated/binary/protocols.cc b/cpp/test/generated/binary/protocols.cc index 5fc3efc9..613fab6f 100644 --- a/cpp/test/generated/binary/protocols.cc +++ b/cpp/test/generated/binary/protocols.cc @@ -4139,6 +4139,10 @@ void EnumsWriter::WriteRecImpl(test_model::RecordWithEnums const& value) { test_model::binary::WriteRecordWithEnums(stream_, value); } +void EnumsWriter::WriteRecArrayImpl(yardl::DynamicNDArray const& value) { + yardl::binary::WriteDynamicNDArray(stream_, value); +} + void EnumsWriter::Flush() { stream_.Flush(); } @@ -4163,6 +4167,10 @@ void EnumsReader::ReadRecImpl(test_model::RecordWithEnums& value) { test_model::binary::ReadRecordWithEnums(stream_, value); } +void EnumsReader::ReadRecArrayImpl(yardl::DynamicNDArray& value) { + yardl::binary::ReadDynamicNDArray(stream_, value); +} + void EnumsReader::CloseImpl() { if (!skip_completed_check_) { stream_.VerifyFinished(); diff --git a/cpp/test/generated/binary/protocols.h b/cpp/test/generated/binary/protocols.h index bac9ff8b..d036a3a5 100644 --- a/cpp/test/generated/binary/protocols.h +++ b/cpp/test/generated/binary/protocols.h @@ -1098,6 +1098,7 @@ class EnumsWriter : public test_model::EnumsWriterBase, yardl::binary::BinaryWri void WriteVecImpl(std::vector const& value) override; void WriteSizeImpl(test_model::SizeBasedEnum const& value) override; void WriteRecImpl(test_model::RecordWithEnums const& value) override; + void WriteRecArrayImpl(yardl::DynamicNDArray const& value) override; void CloseImpl() override; Version version_; @@ -1119,6 +1120,7 @@ class EnumsReader : public test_model::EnumsReaderBase, yardl::binary::BinaryRea void ReadVecImpl(std::vector& value) override; void ReadSizeImpl(test_model::SizeBasedEnum& value) override; void ReadRecImpl(test_model::RecordWithEnums& value) override; + void ReadRecArrayImpl(yardl::DynamicNDArray& value) override; void CloseImpl() override; Version version_; diff --git a/cpp/test/generated/hdf5/protocols.cc b/cpp/test/generated/hdf5/protocols.cc index 58d2006d..80825859 100644 --- a/cpp/test/generated/hdf5/protocols.cc +++ b/cpp/test/generated/hdf5/protocols.cc @@ -3304,6 +3304,10 @@ void EnumsWriter::WriteRecImpl(test_model::RecordWithEnums const& value) { yardl::hdf5::WriteScalarDataset(group_, "rec", test_model::hdf5::GetRecordWithEnumsHdf5Ddl(), value); } +void EnumsWriter::WriteRecArrayImpl(yardl::DynamicNDArray const& value) { + yardl::hdf5::WriteScalarDataset, yardl::DynamicNDArray>(group_, "recArray", yardl::hdf5::DynamicNDArrayDdl(test_model::hdf5::GetRecordWithEnumsHdf5Ddl()), value); +} + EnumsReader::EnumsReader(std::string path, bool skip_completed_check) : test_model::EnumsReaderBase(skip_completed_check), yardl::hdf5::Hdf5Reader::Hdf5Reader(path, "Enums", schema_) { } @@ -3324,6 +3328,10 @@ void EnumsReader::ReadRecImpl(test_model::RecordWithEnums& value) { yardl::hdf5::ReadScalarDataset(group_, "rec", test_model::hdf5::GetRecordWithEnumsHdf5Ddl(), value); } +void EnumsReader::ReadRecArrayImpl(yardl::DynamicNDArray& value) { + yardl::hdf5::ReadScalarDataset, yardl::DynamicNDArray>(group_, "recArray", yardl::hdf5::DynamicNDArrayDdl(test_model::hdf5::GetRecordWithEnumsHdf5Ddl()), value); +} + FlagsWriter::FlagsWriter(std::string path) : yardl::hdf5::Hdf5Writer::Hdf5Writer(path, "Flags", schema_) { } diff --git a/cpp/test/generated/hdf5/protocols.h b/cpp/test/generated/hdf5/protocols.h index 58658ea9..a8aae5c2 100644 --- a/cpp/test/generated/hdf5/protocols.h +++ b/cpp/test/generated/hdf5/protocols.h @@ -856,6 +856,8 @@ class EnumsWriter : public test_model::EnumsWriterBase, public yardl::hdf5::Hdf5 void WriteRecImpl(test_model::RecordWithEnums const& value) override; + void WriteRecArrayImpl(yardl::DynamicNDArray const& value) override; + private: }; @@ -872,6 +874,8 @@ class EnumsReader : public test_model::EnumsReaderBase, public yardl::hdf5::Hdf5 void ReadRecImpl(test_model::RecordWithEnums& value) override; + void ReadRecArrayImpl(yardl::DynamicNDArray& value) override; + private: }; diff --git a/cpp/test/generated/mocks.cc b/cpp/test/generated/mocks.cc index 4bef75d9..b86a8491 100644 --- a/cpp/test/generated/mocks.cc +++ b/cpp/test/generated/mocks.cc @@ -3085,6 +3085,22 @@ class MockEnumsWriter : public EnumsWriterBase { WriteRecImpl_expected_values_.push(value); } + void WriteRecArrayImpl (yardl::DynamicNDArray const& value) override { + if (WriteRecArrayImpl_expected_values_.empty()) { + throw std::runtime_error("Unexpected call to WriteRecArrayImpl"); + } + if (WriteRecArrayImpl_expected_values_.front() != value) { + throw std::runtime_error("Unexpected argument value for call to WriteRecArrayImpl"); + } + WriteRecArrayImpl_expected_values_.pop(); + } + + std::queue> WriteRecArrayImpl_expected_values_; + + void ExpectWriteRecArrayImpl (yardl::DynamicNDArray const& value) { + WriteRecArrayImpl_expected_values_.push(value); + } + void Verify() { if (!WriteSingleImpl_expected_values_.empty()) { throw std::runtime_error("Expected call to WriteSingleImpl was not received"); @@ -3098,6 +3114,9 @@ class MockEnumsWriter : public EnumsWriterBase { if (!WriteRecImpl_expected_values_.empty()) { throw std::runtime_error("Expected call to WriteRecImpl was not received"); } + if (!WriteRecArrayImpl_expected_values_.empty()) { + throw std::runtime_error("Expected call to WriteRecArrayImpl was not received"); + } } }; @@ -3133,6 +3152,11 @@ class TestEnumsWriterBase : public EnumsWriterBase { mock_writer_.ExpectWriteRecImpl(value); } + void WriteRecArrayImpl(yardl::DynamicNDArray const& value) override { + writer_->WriteRecArray(value); + mock_writer_.ExpectWriteRecArrayImpl(value); + } + void CloseImpl() override { close_called_ = true; writer_->Close(); diff --git a/cpp/test/generated/model.json b/cpp/test/generated/model.json index 53365d7a..e549f0d2 100644 --- a/cpp/test/generated/model.json +++ b/cpp/test/generated/model.json @@ -5229,6 +5229,14 @@ { "name": "rec", "type": "TestModel.RecordWithEnums" + }, + { + "name": "recArray", + "type": { + "array": { + "items": "TestModel.RecordWithEnums" + } + } } ] }, diff --git a/cpp/test/generated/ndjson/protocols.cc b/cpp/test/generated/ndjson/protocols.cc index d37346a5..2e788196 100644 --- a/cpp/test/generated/ndjson/protocols.cc +++ b/cpp/test/generated/ndjson/protocols.cc @@ -3745,6 +3745,10 @@ void EnumsWriter::WriteRecImpl(test_model::RecordWithEnums const& value) { ordered_json json_value = value; yardl::ndjson::WriteProtocolValue(stream_, "rec", json_value);} +void EnumsWriter::WriteRecArrayImpl(yardl::DynamicNDArray const& value) { + ordered_json json_value = value; + yardl::ndjson::WriteProtocolValue(stream_, "recArray", json_value);} + void EnumsWriter::Flush() { stream_.flush(); } @@ -3769,6 +3773,10 @@ void EnumsReader::ReadRecImpl(test_model::RecordWithEnums& value) { yardl::ndjson::ReadProtocolValue(stream_, line_, "rec", true, unused_step_, value); } +void EnumsReader::ReadRecArrayImpl(yardl::DynamicNDArray& value) { + yardl::ndjson::ReadProtocolValue(stream_, line_, "recArray", true, unused_step_, value); +} + void EnumsReader::CloseImpl() { if (!skip_completed_check_) { VerifyFinished(); diff --git a/cpp/test/generated/ndjson/protocols.h b/cpp/test/generated/ndjson/protocols.h index 0bc2ad4b..1ba7f8aa 100644 --- a/cpp/test/generated/ndjson/protocols.h +++ b/cpp/test/generated/ndjson/protocols.h @@ -993,6 +993,7 @@ class EnumsWriter : public test_model::EnumsWriterBase, yardl::ndjson::NDJsonWri void WriteVecImpl(std::vector const& value) override; void WriteSizeImpl(test_model::SizeBasedEnum const& value) override; void WriteRecImpl(test_model::RecordWithEnums const& value) override; + void WriteRecArrayImpl(yardl::DynamicNDArray const& value) override; void CloseImpl() override; }; @@ -1012,6 +1013,7 @@ class EnumsReader : public test_model::EnumsReaderBase, yardl::ndjson::NDJsonRea void ReadVecImpl(std::vector& value) override; void ReadSizeImpl(test_model::SizeBasedEnum& value) override; void ReadRecImpl(test_model::RecordWithEnums& value) override; + void ReadRecArrayImpl(yardl::DynamicNDArray& value) override; void CloseImpl() override; }; diff --git a/cpp/test/generated/protocols.cc b/cpp/test/generated/protocols.cc index 5c7f7622..d7fd1aeb 100644 --- a/cpp/test/generated/protocols.cc +++ b/cpp/test/generated/protocols.cc @@ -4780,6 +4780,7 @@ void EnumsWriterBaseInvalidState(uint8_t attempted, [[maybe_unused]] bool end, u case 1: expected_method = "WriteVec()"; break; case 2: expected_method = "WriteSize()"; break; case 3: expected_method = "WriteRec()"; break; + case 4: expected_method = "WriteRecArray()"; break; } std::string attempted_method; switch (attempted) { @@ -4787,7 +4788,8 @@ void EnumsWriterBaseInvalidState(uint8_t attempted, [[maybe_unused]] bool end, u case 1: attempted_method = "WriteVec()"; break; case 2: attempted_method = "WriteSize()"; break; case 3: attempted_method = "WriteRec()"; break; - case 4: attempted_method = "Close()"; break; + case 4: attempted_method = "WriteRecArray()"; break; + case 5: attempted_method = "Close()"; break; } throw std::runtime_error("Expected call to " + expected_method + " but received call to " + attempted_method + " instead."); } @@ -4799,7 +4801,8 @@ void EnumsReaderBaseInvalidState(uint8_t attempted, uint8_t current) { case 1: return "ReadVec()"; case 2: return "ReadSize()"; case 3: return "ReadRec()"; - case 4: return "Close()"; + case 4: return "ReadRecArray()"; + case 5: return "Close()"; default: return ""; } }; @@ -4808,7 +4811,7 @@ void EnumsReaderBaseInvalidState(uint8_t attempted, uint8_t current) { } // namespace -std::string EnumsWriterBase::schema_ = R"({"protocol":{"name":"Enums","sequence":[{"name":"single","type":"TestModel.Fruits"},{"name":"vec","type":{"vector":{"items":"TestModel.Fruits"}}},{"name":"size","type":"TestModel.SizeBasedEnum"},{"name":"rec","type":"TestModel.RecordWithEnums"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"TextFormat","base":"uint64","values":[{"symbol":"regular","value":0},{"symbol":"bold","value":1},{"symbol":"italic","value":2},{"symbol":"underline","value":4},{"symbol":"strikethrough","value":8}]},{"name":"DaysOfWeek","type":"BasicTypes.DaysOfWeek"},{"name":"Fruits","type":"BasicTypes.Fruits"},{"name":"RecordWithEnums","fields":[{"name":"enum","type":"TestModel.Fruits"},{"name":"flags","type":"TestModel.DaysOfWeek"},{"name":"flags2","type":"TestModel.TextFormat"},{"name":"rec","type":"TestModel.RecordWithNoDefaultEnum"}]},{"name":"RecordWithNoDefaultEnum","fields":[{"name":"enum","type":"TestModel.Fruits"}]},{"name":"SizeBasedEnum","base":"size","values":[{"symbol":"a","value":0},{"symbol":"b","value":1},{"symbol":"c","value":2}]},{"name":"TextFormat","type":"BasicTypes.TextFormat"}]})"; +std::string EnumsWriterBase::schema_ = R"({"protocol":{"name":"Enums","sequence":[{"name":"single","type":"TestModel.Fruits"},{"name":"vec","type":{"vector":{"items":"TestModel.Fruits"}}},{"name":"size","type":"TestModel.SizeBasedEnum"},{"name":"rec","type":"TestModel.RecordWithEnums"},{"name":"recArray","type":{"array":{"items":"TestModel.RecordWithEnums"}}}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"TextFormat","base":"uint64","values":[{"symbol":"regular","value":0},{"symbol":"bold","value":1},{"symbol":"italic","value":2},{"symbol":"underline","value":4},{"symbol":"strikethrough","value":8}]},{"name":"DaysOfWeek","type":"BasicTypes.DaysOfWeek"},{"name":"Fruits","type":"BasicTypes.Fruits"},{"name":"RecordWithEnums","fields":[{"name":"enum","type":"TestModel.Fruits"},{"name":"flags","type":"TestModel.DaysOfWeek"},{"name":"flags2","type":"TestModel.TextFormat"},{"name":"rec","type":"TestModel.RecordWithNoDefaultEnum"}]},{"name":"RecordWithNoDefaultEnum","fields":[{"name":"enum","type":"TestModel.Fruits"}]},{"name":"SizeBasedEnum","base":"size","values":[{"symbol":"a","value":0},{"symbol":"b","value":1},{"symbol":"c","value":2}]},{"name":"TextFormat","type":"BasicTypes.TextFormat"}]})"; std::vector EnumsWriterBase::previous_schemas_ = { }; @@ -4856,11 +4859,20 @@ void EnumsWriterBase::WriteRec(test_model::RecordWithEnums const& value) { state_ = 4; } -void EnumsWriterBase::Close() { +void EnumsWriterBase::WriteRecArray(yardl::DynamicNDArray const& value) { if (unlikely(state_ != 4)) { EnumsWriterBaseInvalidState(4, false, state_); } + WriteRecArrayImpl(value); + state_ = 5; +} + +void EnumsWriterBase::Close() { + if (unlikely(state_ != 5)) { + EnumsWriterBaseInvalidState(5, false, state_); + } + CloseImpl(); } @@ -4910,11 +4922,20 @@ void EnumsReaderBase::ReadRec(test_model::RecordWithEnums& value) { state_ = 8; } -void EnumsReaderBase::Close() { - if (!skip_completed_check_ && unlikely(state_ != 8)) { +void EnumsReaderBase::ReadRecArray(yardl::DynamicNDArray& value) { + if (unlikely(state_ != 8)) { EnumsReaderBaseInvalidState(8, state_); } + ReadRecArrayImpl(value); + state_ = 10; +} + +void EnumsReaderBase::Close() { + if (!skip_completed_check_ && unlikely(state_ != 10)) { + EnumsReaderBaseInvalidState(10, state_); + } + CloseImpl(); } void EnumsReaderBase::CopyTo(EnumsWriterBase& writer) { @@ -4938,6 +4959,11 @@ void EnumsReaderBase::CopyTo(EnumsWriterBase& writer) { ReadRec(value); writer.WriteRec(value); } + { + yardl::DynamicNDArray value; + ReadRecArray(value); + writer.WriteRecArray(value); + } } namespace { diff --git a/cpp/test/generated/protocols.h b/cpp/test/generated/protocols.h index 7d1bc420..2ea44acd 100644 --- a/cpp/test/generated/protocols.h +++ b/cpp/test/generated/protocols.h @@ -2124,6 +2124,9 @@ class EnumsWriterBase { // Ordinal 3. void WriteRec(test_model::RecordWithEnums const& value); + // Ordinal 4. + void WriteRecArray(yardl::DynamicNDArray const& value); + // Optionaly close this writer before destructing. Validates that all steps were completed. void Close(); @@ -2137,6 +2140,7 @@ class EnumsWriterBase { virtual void WriteVecImpl(std::vector const& value) = 0; virtual void WriteSizeImpl(test_model::SizeBasedEnum const& value) = 0; virtual void WriteRecImpl(test_model::RecordWithEnums const& value) = 0; + virtual void WriteRecArrayImpl(yardl::DynamicNDArray const& value) = 0; virtual void CloseImpl() {} static std::string schema_; @@ -2168,6 +2172,9 @@ class EnumsReaderBase { // Ordinal 3. void ReadRec(test_model::RecordWithEnums& value); + // Ordinal 4. + void ReadRecArray(yardl::DynamicNDArray& value); + // Optionaly close this writer before destructing. Validates that all steps were completely read. void Close(); @@ -2180,6 +2187,7 @@ class EnumsReaderBase { virtual void ReadVecImpl(std::vector& value) = 0; virtual void ReadSizeImpl(test_model::SizeBasedEnum& value) = 0; virtual void ReadRecImpl(test_model::RecordWithEnums& value) = 0; + virtual void ReadRecArrayImpl(yardl::DynamicNDArray& value) = 0; virtual void CloseImpl() {} static std::string schema_; diff --git a/cpp/test/roundtrip_test.cc b/cpp/test/roundtrip_test.cc index ae1a4678..2956434f 100644 --- a/cpp/test/roundtrip_test.cc +++ b/cpp/test/roundtrip_test.cc @@ -410,6 +410,11 @@ TEST_P(RoundTripTests, Enums) { tw->WriteSize(SizeBasedEnum::kC); tw->WriteRec(RecordWithEnums{Fruits::kBanana, DaysOfWeek::kMonday, TextFormat::kBold, RecordWithNoDefaultEnum{Fruits::kPear}}); + tw->WriteRecArray({ + RecordWithEnums{Fruits::kPear, DaysOfWeek::kMonday, TextFormat::kBold, RecordWithNoDefaultEnum{Fruits::kBanana}}, + RecordWithEnums{Fruits::kApple, DaysOfWeek(), TextFormat::kRegular, RecordWithNoDefaultEnum{Fruits::kApple}}, + }); + tw->Close(); } diff --git a/matlab/generated/+test_model/+binary/EnumsReader.m b/matlab/generated/+test_model/+binary/EnumsReader.m index 6d54b243..df20d4da 100644 --- a/matlab/generated/+test_model/+binary/EnumsReader.m +++ b/matlab/generated/+test_model/+binary/EnumsReader.m @@ -7,6 +7,7 @@ vec_serializer size_serializer rec_serializer + rec_array_serializer end methods @@ -21,6 +22,7 @@ self.vec_serializer = yardl.binary.VectorSerializer(yardl.binary.EnumSerializer('basic_types.Fruits', @basic_types.Fruits, yardl.binary.Int32Serializer)); self.size_serializer = yardl.binary.EnumSerializer('test_model.SizeBasedEnum', @test_model.SizeBasedEnum, yardl.binary.SizeSerializer); self.rec_serializer = test_model.binary.RecordWithEnumsSerializer(); + self.rec_array_serializer = yardl.binary.DynamicNDArraySerializer(test_model.binary.RecordWithEnumsSerializer()); end end @@ -40,5 +42,9 @@ function value = read_rec_(self) value = self.rec_serializer.read(self.stream_); end + + function value = read_rec_array_(self) + value = self.rec_array_serializer.read(self.stream_); + end end end diff --git a/matlab/generated/+test_model/+binary/EnumsWriter.m b/matlab/generated/+test_model/+binary/EnumsWriter.m index 60b41ffb..f33b3752 100644 --- a/matlab/generated/+test_model/+binary/EnumsWriter.m +++ b/matlab/generated/+test_model/+binary/EnumsWriter.m @@ -7,6 +7,7 @@ vec_serializer size_serializer rec_serializer + rec_array_serializer end methods @@ -17,6 +18,7 @@ self.vec_serializer = yardl.binary.VectorSerializer(yardl.binary.EnumSerializer('basic_types.Fruits', @basic_types.Fruits, yardl.binary.Int32Serializer)); self.size_serializer = yardl.binary.EnumSerializer('test_model.SizeBasedEnum', @test_model.SizeBasedEnum, yardl.binary.SizeSerializer); self.rec_serializer = test_model.binary.RecordWithEnumsSerializer(); + self.rec_array_serializer = yardl.binary.DynamicNDArraySerializer(test_model.binary.RecordWithEnumsSerializer()); end end @@ -36,5 +38,9 @@ function write_size_(self, value) function write_rec_(self, value) self.rec_serializer.write(self.stream_, value); end + + function write_rec_array_(self, value) + self.rec_array_serializer.write(self.stream_, value); + end end end diff --git a/matlab/generated/+test_model/+testing/MockEnumsWriter.m b/matlab/generated/+test_model/+testing/MockEnumsWriter.m index 54f4ec00..9c40aa9d 100644 --- a/matlab/generated/+test_model/+testing/MockEnumsWriter.m +++ b/matlab/generated/+test_model/+testing/MockEnumsWriter.m @@ -7,6 +7,7 @@ expected_vec expected_size expected_rec + expected_rec_array end methods @@ -16,6 +17,7 @@ self.expected_vec = yardl.None; self.expected_size = yardl.None; self.expected_rec = yardl.None; + self.expected_rec_array = yardl.None; end function expect_write_single_(self, value) @@ -34,11 +36,16 @@ function expect_write_rec_(self, value) self.expected_rec = yardl.Optional(value); end + function expect_write_rec_array_(self, value) + self.expected_rec_array = yardl.Optional(value); + end + function verify(self) self.testCase_.verifyEqual(self.expected_single, yardl.None, "Expected call to write_single_ was not received"); self.testCase_.verifyEqual(self.expected_vec, yardl.None, "Expected call to write_vec_ was not received"); self.testCase_.verifyEqual(self.expected_size, yardl.None, "Expected call to write_size_ was not received"); self.testCase_.verifyEqual(self.expected_rec, yardl.None, "Expected call to write_rec_ was not received"); + self.testCase_.verifyEqual(self.expected_rec_array, yardl.None, "Expected call to write_rec_array_ was not received"); end end @@ -67,6 +74,12 @@ function write_rec_(self, value) self.expected_rec = yardl.None; end + function write_rec_array_(self, value) + self.testCase_.verifyTrue(self.expected_rec_array.has_value(), "Unexpected call to write_rec_array_"); + self.testCase_.verifyEqual(value, self.expected_rec_array.value, "Unexpected argument value for call to write_rec_array_"); + self.expected_rec_array = yardl.None; + end + function close_(self) end function end_stream_(self) diff --git a/matlab/generated/+test_model/+testing/TestEnumsWriter.m b/matlab/generated/+test_model/+testing/TestEnumsWriter.m index 999db3ee..38b5edb8 100644 --- a/matlab/generated/+test_model/+testing/TestEnumsWriter.m +++ b/matlab/generated/+test_model/+testing/TestEnumsWriter.m @@ -50,6 +50,11 @@ function write_rec_(self, value) self.mock_writer_.expect_write_rec_(value); end + function write_rec_array_(self, value) + self.writer_.write_rec_array(value); + self.mock_writer_.expect_write_rec_array_(value); + end + function close_(self) self.close_called_ = true; self.writer_.close(); diff --git a/matlab/generated/+test_model/EnumsReaderBase.m b/matlab/generated/+test_model/EnumsReaderBase.m index a830e9c5..7b2c5cbb 100644 --- a/matlab/generated/+test_model/EnumsReaderBase.m +++ b/matlab/generated/+test_model/EnumsReaderBase.m @@ -17,7 +17,7 @@ function close(self) self.close_(); - if ~self.skip_completed_check_ && self.state_ ~= 4 + if ~self.skip_completed_check_ && self.state_ ~= 5 expected_method = self.state_to_method_name_(self.state_); throw(yardl.ProtocolError("Protocol reader closed before all data was consumed. Expected call to '%s'.", expected_method)); end @@ -63,11 +63,22 @@ function close(self) self.state_ = 4; end + % Ordinal 4 + function value = read_rec_array(self) + if self.state_ ~= 4 + self.raise_unexpected_state_(4); + end + + value = self.read_rec_array_(); + self.state_ = 5; + end + function copy_to(self, writer) writer.write_single(self.read_single()); writer.write_vec(self.read_vec()); writer.write_size(self.read_size()); writer.write_rec(self.read_rec()); + writer.write_rec_array(self.read_rec_array()); end end @@ -82,6 +93,7 @@ function copy_to(self, writer) read_vec_(self) read_size_(self) read_rec_(self) + read_rec_array_(self) close_(self) end @@ -102,6 +114,8 @@ function raise_unexpected_state_(self, actual) name = "read_size"; elseif state == 3 name = "read_rec"; + elseif state == 4 + name = "read_rec_array"; else name = ""; end diff --git a/matlab/generated/+test_model/EnumsWriterBase.m b/matlab/generated/+test_model/EnumsWriterBase.m index e0d5288d..2384270e 100644 --- a/matlab/generated/+test_model/EnumsWriterBase.m +++ b/matlab/generated/+test_model/EnumsWriterBase.m @@ -13,7 +13,7 @@ function close(self) self.close_(); - if self.state_ ~= 4 + if self.state_ ~= 5 expected_method = self.state_to_method_name_(self.state_); throw(yardl.ProtocolError("Protocol writer closed before all steps were called. Expected call to '%s'.", expected_method)); end @@ -58,11 +58,21 @@ function write_rec(self, value) self.write_rec_(value); self.state_ = 4; end + + % Ordinal 4 + function write_rec_array(self, value) + if self.state_ ~= 4 + self.raise_unexpected_state_(4); + end + + self.write_rec_array_(value); + self.state_ = 5; + end end methods (Static) function res = schema() - res = string('{"protocol":{"name":"Enums","sequence":[{"name":"single","type":"TestModel.Fruits"},{"name":"vec","type":{"vector":{"items":"TestModel.Fruits"}}},{"name":"size","type":"TestModel.SizeBasedEnum"},{"name":"rec","type":"TestModel.RecordWithEnums"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"TextFormat","base":"uint64","values":[{"symbol":"regular","value":0},{"symbol":"bold","value":1},{"symbol":"italic","value":2},{"symbol":"underline","value":4},{"symbol":"strikethrough","value":8}]},{"name":"DaysOfWeek","type":"BasicTypes.DaysOfWeek"},{"name":"Fruits","type":"BasicTypes.Fruits"},{"name":"RecordWithEnums","fields":[{"name":"enum","type":"TestModel.Fruits"},{"name":"flags","type":"TestModel.DaysOfWeek"},{"name":"flags2","type":"TestModel.TextFormat"},{"name":"rec","type":"TestModel.RecordWithNoDefaultEnum"}]},{"name":"RecordWithNoDefaultEnum","fields":[{"name":"enum","type":"TestModel.Fruits"}]},{"name":"SizeBasedEnum","base":"size","values":[{"symbol":"a","value":0},{"symbol":"b","value":1},{"symbol":"c","value":2}]},{"name":"TextFormat","type":"BasicTypes.TextFormat"}]}'); + res = string('{"protocol":{"name":"Enums","sequence":[{"name":"single","type":"TestModel.Fruits"},{"name":"vec","type":{"vector":{"items":"TestModel.Fruits"}}},{"name":"size","type":"TestModel.SizeBasedEnum"},{"name":"rec","type":"TestModel.RecordWithEnums"},{"name":"recArray","type":{"array":{"items":"TestModel.RecordWithEnums"}}}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"TextFormat","base":"uint64","values":[{"symbol":"regular","value":0},{"symbol":"bold","value":1},{"symbol":"italic","value":2},{"symbol":"underline","value":4},{"symbol":"strikethrough","value":8}]},{"name":"DaysOfWeek","type":"BasicTypes.DaysOfWeek"},{"name":"Fruits","type":"BasicTypes.Fruits"},{"name":"RecordWithEnums","fields":[{"name":"enum","type":"TestModel.Fruits"},{"name":"flags","type":"TestModel.DaysOfWeek"},{"name":"flags2","type":"TestModel.TextFormat"},{"name":"rec","type":"TestModel.RecordWithNoDefaultEnum"}]},{"name":"RecordWithNoDefaultEnum","fields":[{"name":"enum","type":"TestModel.Fruits"}]},{"name":"SizeBasedEnum","base":"size","values":[{"symbol":"a","value":0},{"symbol":"b","value":1},{"symbol":"c","value":2}]},{"name":"TextFormat","type":"BasicTypes.TextFormat"}]}'); end end @@ -71,6 +81,7 @@ function write_rec(self, value) write_vec_(self, value) write_size_(self, value) write_rec_(self, value) + write_rec_array_(self, value) end_stream_(self) close_(self) @@ -92,6 +103,8 @@ function raise_unexpected_state_(self, actual) name = "write_size"; elseif state == 3 name = "write_rec"; + elseif state == 4 + name = "write_rec_array"; else name = ''; end diff --git a/matlab/test/RoundTripTest.m b/matlab/test/RoundTripTest.m index f822b20f..6d59c383 100644 --- a/matlab/test/RoundTripTest.m +++ b/matlab/test/RoundTripTest.m @@ -364,6 +364,12 @@ function testEnums(testCase, format) w.write_rec(test_model.RecordWithEnums(... enum=test_model.Fruits.PEAR, ... rec=test_model.RecordWithNoDefaultEnum(enum=test_model.Fruits.BANANA))); + RE = @test_model.RecordWithEnums; + RN = @test_model.RecordWithNoDefaultEnum; + w.write_rec_array([... + RE(enum=test_model.Fruits.PEAR, rec=RN(enum=test_model.Fruits.BANANA)), ... + RE(enum=test_model.Fruits.APPLE, rec=RN(enum=test_model.Fruits.APPLE)) ... + ]); w.close(); end diff --git a/models/test/unittests.yml b/models/test/unittests.yml index 7cd2f4a6..629d0b83 100644 --- a/models/test/unittests.yml +++ b/models/test/unittests.yml @@ -367,6 +367,7 @@ Enums: !protocol vec: Fruits* size: SizeBasedEnum rec: RecordWithEnums + recArray: RecordWithEnums[] DaysOfWeek: BasicTypes.DaysOfWeek diff --git a/python/test_model/binary.py b/python/test_model/binary.py index 0c6afe9b..273d4b53 100644 --- a/python/test_model/binary.py +++ b/python/test_model/binary.py @@ -908,6 +908,9 @@ def _write_size(self, value: SizeBasedEnum) -> None: def _write_rec(self, value: RecordWithEnums) -> None: RecordWithEnumsSerializer().write(self._stream, value) + def _write_rec_array(self, value: npt.NDArray[np.void]) -> None: + _binary.DynamicNDArraySerializer(RecordWithEnumsSerializer()).write(self._stream, value) + class BinaryEnumsReader(_binary.BinaryProtocolReader, EnumsReaderBase): """Binary writer for the Enums protocol.""" @@ -929,6 +932,9 @@ def _read_size(self) -> SizeBasedEnum: def _read_rec(self) -> RecordWithEnums: return RecordWithEnumsSerializer().read(self._stream) + def _read_rec_array(self) -> npt.NDArray[np.void]: + return _binary.DynamicNDArraySerializer(RecordWithEnumsSerializer()).read(self._stream) + class BinaryFlagsWriter(_binary.BinaryProtocolWriter, FlagsWriterBase): """Binary writer for the Flags protocol.""" diff --git a/python/test_model/ndjson.py b/python/test_model/ndjson.py index a0a02387..7714d328 100644 --- a/python/test_model/ndjson.py +++ b/python/test_model/ndjson.py @@ -3818,6 +3818,11 @@ def _write_rec(self, value: RecordWithEnums) -> None: json_value = converter.to_json(value) self._write_json_line({"rec": json_value}) + def _write_rec_array(self, value: npt.NDArray[np.void]) -> None: + converter = _ndjson.DynamicNDArrayConverter(RecordWithEnumsConverter()) + json_value = converter.to_json(value) + self._write_json_line({"recArray": json_value}) + class NDJsonEnumsReader(_ndjson.NDJsonProtocolReader, EnumsReaderBase): """NDJson writer for the Enums protocol.""" @@ -3847,6 +3852,11 @@ def _read_rec(self) -> RecordWithEnums: converter = RecordWithEnumsConverter() return converter.from_json(json_object) + def _read_rec_array(self) -> npt.NDArray[np.void]: + json_object = self._read_json_line("recArray", True) + converter = _ndjson.DynamicNDArrayConverter(RecordWithEnumsConverter()) + return converter.from_json(json_object) + class NDJsonFlagsWriter(_ndjson.NDJsonProtocolWriter, FlagsWriterBase): """NDJson writer for the Flags protocol.""" diff --git a/python/test_model/protocols.py b/python/test_model/protocols.py index c1edc8b0..2fe3ec26 100644 --- a/python/test_model/protocols.py +++ b/python/test_model/protocols.py @@ -4836,11 +4836,11 @@ class EnumsWriterBase(abc.ABC): def __init__(self) -> None: self._state = 0 - schema = r"""{"protocol":{"name":"Enums","sequence":[{"name":"single","type":"TestModel.Fruits"},{"name":"vec","type":{"vector":{"items":"TestModel.Fruits"}}},{"name":"size","type":"TestModel.SizeBasedEnum"},{"name":"rec","type":"TestModel.RecordWithEnums"}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"TextFormat","base":"uint64","values":[{"symbol":"regular","value":0},{"symbol":"bold","value":1},{"symbol":"italic","value":2},{"symbol":"underline","value":4},{"symbol":"strikethrough","value":8}]},{"name":"DaysOfWeek","type":"BasicTypes.DaysOfWeek"},{"name":"Fruits","type":"BasicTypes.Fruits"},{"name":"RecordWithEnums","fields":[{"name":"enum","type":"TestModel.Fruits"},{"name":"flags","type":"TestModel.DaysOfWeek"},{"name":"flags2","type":"TestModel.TextFormat"},{"name":"rec","type":"TestModel.RecordWithNoDefaultEnum"}]},{"name":"RecordWithNoDefaultEnum","fields":[{"name":"enum","type":"TestModel.Fruits"}]},{"name":"SizeBasedEnum","base":"size","values":[{"symbol":"a","value":0},{"symbol":"b","value":1},{"symbol":"c","value":2}]},{"name":"TextFormat","type":"BasicTypes.TextFormat"}]}""" + schema = r"""{"protocol":{"name":"Enums","sequence":[{"name":"single","type":"TestModel.Fruits"},{"name":"vec","type":{"vector":{"items":"TestModel.Fruits"}}},{"name":"size","type":"TestModel.SizeBasedEnum"},{"name":"rec","type":"TestModel.RecordWithEnums"},{"name":"recArray","type":{"array":{"items":"TestModel.RecordWithEnums"}}}]},"types":[{"name":"DaysOfWeek","values":[{"symbol":"monday","value":1},{"symbol":"tuesday","value":2},{"symbol":"wednesday","value":4},{"symbol":"thursday","value":8},{"symbol":"friday","value":16},{"symbol":"saturday","value":32},{"symbol":"sunday","value":64}]},{"name":"Fruits","values":[{"symbol":"apple","value":1},{"symbol":"banana","value":2},{"symbol":"pear","value":3}]},{"name":"TextFormat","base":"uint64","values":[{"symbol":"regular","value":0},{"symbol":"bold","value":1},{"symbol":"italic","value":2},{"symbol":"underline","value":4},{"symbol":"strikethrough","value":8}]},{"name":"DaysOfWeek","type":"BasicTypes.DaysOfWeek"},{"name":"Fruits","type":"BasicTypes.Fruits"},{"name":"RecordWithEnums","fields":[{"name":"enum","type":"TestModel.Fruits"},{"name":"flags","type":"TestModel.DaysOfWeek"},{"name":"flags2","type":"TestModel.TextFormat"},{"name":"rec","type":"TestModel.RecordWithNoDefaultEnum"}]},{"name":"RecordWithNoDefaultEnum","fields":[{"name":"enum","type":"TestModel.Fruits"}]},{"name":"SizeBasedEnum","base":"size","values":[{"symbol":"a","value":0},{"symbol":"b","value":1},{"symbol":"c","value":2}]},{"name":"TextFormat","type":"BasicTypes.TextFormat"}]}""" def close(self) -> None: self._close() - if self._state != 8: + if self._state != 10: expected_method = self._state_to_method_name((self._state + 1) & ~1) raise ProtocolError(f"Protocol writer closed before all steps were called. Expected to call to '{expected_method}'.") @@ -4890,6 +4890,15 @@ def write_rec(self, value: RecordWithEnums) -> None: self._write_rec(value) self._state = 8 + def write_rec_array(self, value: npt.NDArray[np.void]) -> None: + """Ordinal 4""" + + if self._state != 8: + self._raise_unexpected_state(8) + + self._write_rec_array(value) + self._state = 10 + @abc.abstractmethod def _write_single(self, value: Fruits) -> None: raise NotImplementedError() @@ -4906,6 +4915,10 @@ def _write_size(self, value: SizeBasedEnum) -> None: def _write_rec(self, value: RecordWithEnums) -> None: raise NotImplementedError() + @abc.abstractmethod + def _write_rec_array(self, value: npt.NDArray[np.void]) -> None: + raise NotImplementedError() + @abc.abstractmethod def _close(self) -> None: pass @@ -4928,6 +4941,8 @@ def _state_to_method_name(self, state: int) -> str: return 'write_size' if state == 6: return 'write_rec' + if state == 8: + return 'write_rec_array' return "" class EnumsReaderBase(abc.ABC): @@ -4940,7 +4955,7 @@ def __init__(self, skip_completed_check: bool = False) -> None: def close(self) -> None: self._close() - if not self._skip_completed_check and self._state != 8: + if not self._skip_completed_check and self._state != 10: if self._state % 2 == 1: previous_method = self._state_to_method_name(self._state - 1) raise ProtocolError(f"Protocol reader closed before all data was consumed. The iterable returned by '{previous_method}' was not fully consumed.") @@ -5005,11 +5020,22 @@ def read_rec(self) -> RecordWithEnums: self._state = 8 return value + def read_rec_array(self) -> npt.NDArray[np.void]: + """Ordinal 4""" + + if self._state != 8: + self._raise_unexpected_state(8) + + value = self._read_rec_array() + self._state = 10 + return value + def copy_to(self, writer: EnumsWriterBase) -> None: writer.write_single(self.read_single()) writer.write_vec(self.read_vec()) writer.write_size(self.read_size()) writer.write_rec(self.read_rec()) + writer.write_rec_array(self.read_rec_array()) @abc.abstractmethod def _read_single(self) -> Fruits: @@ -5027,6 +5053,10 @@ def _read_size(self) -> SizeBasedEnum: def _read_rec(self) -> RecordWithEnums: raise NotImplementedError() + @abc.abstractmethod + def _read_rec_array(self) -> npt.NDArray[np.void]: + raise NotImplementedError() + T = typing.TypeVar('T') def _wrap_iterable(self, iterable: collections.abc.Iterable[T], final_state: int) -> collections.abc.Iterable[T]: yield from iterable @@ -5050,6 +5080,8 @@ def _state_to_method_name(self, state: int) -> str: return 'read_size' if state == 6: return 'read_rec' + if state == 8: + return 'read_rec_array' return "" class FlagsWriterBase(abc.ABC): diff --git a/python/tests/test_generated_types.py b/python/tests/test_generated_types.py index 486e456c..4173d2b4 100644 --- a/python/tests/test_generated_types.py +++ b/python/tests/test_generated_types.py @@ -161,6 +161,18 @@ def test_get_dtype(): align=True, ) + # Enum fields map to their integer base dtype, so iterating a record + # array yields bare numpy scalars for those fields. + assert tm.get_dtype(tm.RecordWithEnums) == np.dtype( + [ + ("enum", " None: - self._integer_serializer.write(stream, value.value) + int_value = value.value if isinstance(value, Enum) else value + self._integer_serializer.write(stream, int_value) def write_numpy(self, stream: CodedOutputStream, value: T_NP) -> None: return self._integer_serializer.write_numpy(stream, value) @@ -1325,7 +1326,17 @@ def _read(self, stream: CodedInputStream) -> tuple[Any, ...]: ) def read_numpy(self, stream: CodedInputStream) -> np.void: - return cast(np.void, self._read(stream)) + # Enum and nested-record fields must be read via read_numpy so they + # yield numpy-assignable values rather than Enum or record objects. + return cast( + np.void, + tuple( + serializer.read_numpy(stream) + if isinstance(serializer, (EnumSerializer, RecordSerializer)) + else serializer.read(stream) + for _, serializer in self._field_serializers + ), + ) # Only used in the header