Skip to content

Commit

Permalink
keep disable_fp16_compression rt_info (openvinotoolkit#20625)
Browse files Browse the repository at this point in the history
* keep disable_fp16_compression rt_info

* style fix

* style fix 2

* cleanup init_node_info.cpp; redefining a class for rt_info in Serialize

* move rt_info refreshing inside serialize.cpp

* rename rt_info name in IR

* add rt_info serialize test

* add ticket number

* updated comments

* code style fix

---------

Co-authored-by: Andrei Kochin <[email protected]>
  • Loading branch information
pavel-esir and andrei-kochin authored Oct 27, 2023
1 parent 1d4520e commit 539b5a8
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@ TRANSFORMATIONS_API void do_not_postpone_fp16_compression(RTMap& rt_info);
/**
* @ingroup ie_runtime_attr_api
* @brief DisableFP16Compression class represents runtime info attribute that marks operation
* as prohibitted to convert to FP16 as part of Compressed Only format.
* as prohibited to convert to lower precision (e.g. to FP16) and they should be inferred precisely in the original
* precision.
*/
class TRANSFORMATIONS_API DisableFP16Compression : public RuntimeAttribute {
public:
OPENVINO_RTTI("disable_fp16_compression", "0");
OPENVINO_RTTI("precise", "0");

DisableFP16Compression() = default;

bool visit_attributes(AttributeVisitor& visitor) override {
return true;
}

bool is_copyable() const override {
return false;
}
Expand Down
8 changes: 8 additions & 0 deletions src/core/src/pass/serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,14 @@ void serializeFunc(std::ostream& xml_file,
namespace ov {
bool pass::Serialize::run_on_model(const std::shared_ptr<ov::Model>& model) {
RUN_ON_FUNCTION_SCOPE(Serialize);

// TODO xxx-105807: if rt_info is set in python api as a string ['precise_0'] = '',
// we need to convert value to a class in order to have rt_info in the IR. The code below will convert
// ['precise_0'] = '' into => rt_info['precise_0'] = DisableFP16Compression{}
for (auto& node : model->get_ops())
if (fp16_compression_is_disabled(node))
disable_fp16_compression(node);

if (m_xmlFile && m_binFile) {
serializeFunc(*m_xmlFile, *m_binFile, model, m_version, m_custom_opsets);
} else {
Expand Down
28 changes: 28 additions & 0 deletions src/core/tests/pass/serialization/rt_info_serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,34 @@ TEST_F(RTInfoSerializationTest, all_attributes_latest) {
check_info(add->output(0).get_rt_info());
}

TEST_F(RTInfoSerializationTest, rt_info_precise_test) {
auto init_info = [](ov::RTMap& info) {
info[ov::DisableFP16Compression::get_type_info_static()] = ov::DisableFP16Compression{};
};
auto check_info = [](const ov::RTMap& info) {
const std::string& key = ov::DisableFP16Compression::get_type_info_static();
ASSERT_TRUE(info.count(key));
};

std::shared_ptr<ov::Model> function;
{
auto data_1 = std::make_shared<ov::opset8::Parameter>(ov::element::Type_t::f32, ov::Shape{1, 10});
auto data_2 = std::make_shared<ov::opset8::Parameter>(ov::element::Type_t::f32, ov::Shape{10, 1});
auto matmul_1 = std::make_shared<ov::opset8::MatMul>(data_1, data_2);
init_info(matmul_1->get_rt_info());
auto result = std::make_shared<ov::opset8::Result>(matmul_1);
function = std::make_shared<ov::Model>(ov::ResultVector{result}, ov::ParameterVector{data_1, data_2});
}
ov::pass::Manager m;
m.register_pass<ov::pass::Serialize>(m_out_xml_path, m_out_bin_path);
m.run_passes(function);
auto f = getWithIRFrontend(m_out_xml_path, m_out_bin_path);
ASSERT_NE(nullptr, f);

auto matmul = f->get_results()[0]->get_input_node_ptr(0);
check_info(matmul->get_rt_info());
}

TEST_F(RTInfoSerializationTest, all_attributes_v10) {
auto init_info = [](ov::RTMap& info) {
info[ov::FusedNames::get_type_info_static()] = ov::FusedNames("add");
Expand Down

0 comments on commit 539b5a8

Please sign in to comment.