diff --git a/HugeCTR/include/hps/hash_map_backend.hpp b/HugeCTR/include/hps/hash_map_backend.hpp
index 881afad2ed..9ec274827f 100644
--- a/HugeCTR/include/hps/hash_map_backend.hpp
+++ b/HugeCTR/include/hps/hash_map_backend.hpp
@@ -27,6 +27,7 @@
 #include
 #include
 #include
+#include
 
 namespace HugeCTR {
 
@@ -115,6 +116,35 @@ class HashMapBackend final : public VolatileBackend<Key> {
       uint64_t access_count;
     };
     ValuePtr value;
+    std::atomic<int8_t> lck;
+
+    Payload() { lck = -1; }
+
+    explicit Payload(const Payload& p) {
+      value = p.value;
+      lck.store(p.lck.load());
+    }
+
+    Payload(Payload&& p) {
+      value = std::exchange(p.value, nullptr);
+      lck.store(p.lck.load());
+    }
+
+    Payload& operator=(const Payload& p) {
+      if (this != &p) {
+        value = p.value;
+        lck.store(p.lck.load());
+      }
+      return *this;
+    }
+
+    Payload& operator=(Payload&& p) {
+      if (this != &p) {
+        value = std::exchange(p.value, nullptr);
+        lck.store(p.lck.load());
+      }
+      return *this;
+    }
   };
 
   using Entry = std::pair<const Key, Payload>;
@@ -126,13 +156,45 @@ class HashMapBackend final : public VolatileBackend<Key> {
     std::vector<ValuePage> value_pages;
     std::vector<ValuePtr> value_slots;
 
+    mutable std::shared_mutex read_write_lck;
+
     // Key -> Payload map.
-    phmap::flat_hash_map<Key, Payload> entries;
+    phmap::parallel_flat_hash_map<Key, Payload, phmap::priv::hash_default_hash<Key>,
+                                  phmap::priv::hash_default_eq<Key>,
+                                  phmap::priv::Allocator<std::pair<const Key, Payload>>,
+                                  4, std::shared_mutex> entries;
 
     Partition() = delete;
 
     Partition(const uint32_t value_size, const HashMapBackendParams& params)
         : value_size{value_size}, allocation_rate{params.allocation_rate} {}
+
+    explicit Partition(const Partition& p) {
+      value_size = p.value_size;
+      allocation_rate = p.allocation_rate;
+
+      value_pages = p.value_pages;
+      value_slots = p.value_slots;
+
+      entries = p.entries;
+    }
+
+    Partition& operator=(Partition&& p) {
+      if (this != &p) {
+        // TODO(robertzhu)
+        // std::scoped_lock lock(read_write_lck, p.read_write_lck);
+
+        value_size = p.value_size;
+        allocation_rate = p.allocation_rate;
+
+        value_pages = std::move(p.value_pages);
+        value_slots = std::move(p.value_slots);
+
+        entries = std::move(p.entries);
+      }
+
+      return *this;
+    }
   };
 
   // Actual data.
@@ -140,7 +202,7 @@ class HashMapBackend final : public VolatileBackend<Key> {
   std::unordered_map<std::string, std::vector<Partition>> tables_;
 
   // Access control.
-  mutable std::shared_mutex read_write_guard_;
+  // mutable std::shared_mutex read_write_guard_;
 
   // Overflow resolution.
   size_t resolve_overflow_(const std::string& table_name, size_t part_index, Partition& part);
diff --git a/HugeCTR/include/hps/hash_map_backend_detail.hpp b/HugeCTR/include/hps/hash_map_backend_detail.hpp
index b32c034c28..da3e7d3246 100644
--- a/HugeCTR/include/hps/hash_map_backend_detail.hpp
+++ b/HugeCTR/include/hps/hash_map_backend_detail.hpp
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 
 namespace HugeCTR {
 
@@ -55,7 +56,9 @@ namespace HugeCTR {
       const Payload& payload{it->second}; \
                                           \
       /* Stash pointer and reference in map. */ \
+      std::unique_lock lock(part.read_write_lck); \
       part.value_slots.emplace_back(payload.value); \
+      lock.unlock(); \
       part.entries.erase(it); \
       ++num_deletions; \
     } \
@@ -91,6 +94,7 @@ namespace HugeCTR {
                                                                                        \
       /* Race-conditions here are deliberately ignored because insignificant in practice. */ \
       __VA_ARGS__; \
+      while (payload.lck.load(std::memory_order_relaxed) != 0); \
       std::copy_n(payload.value, part.value_size, &values[(k - keys) * value_stride]); \
     } else { \
       on_miss(k - keys); \
@@ -135,6 +139,7 @@ namespace HugeCTR {
                   std::is_same_v); \
     static_assert(std::is_same_v); \
                                    \
+    /* TODO(robertzhu): use thread safe api */ \
     const auto& res{part.entries.try_emplace(*k)}; \
     Payload& payload{res.first->second}; \
                                          \
@@ -142,6 +147,7 @@ namespace HugeCTR {
                             \
     /* If new insertion. */ \
     if (res.second) { \
+      std::unique_lock lock(part.read_write_lck); \
       /* If no free space, allocate another buffer, and fill pointer queue. */ \
       if (part.value_slots.empty()) { \
         const size_t stride{(value_size + value_page_alignment - 1) / value_page_alignment * \
@@ -152,6 +158,7 @@ namespace HugeCTR {
         /* Get more memory. */ \
         part.value_pages.emplace_back(num_values* stride, char_allocator_); \
         ValuePage& value_page{part.value_pages.back()}; \
+        /*HCTR_LOG_C(DEBUG, WORLD, "insert value_page: num_values ", num_values, "; stride ", stride, "; value_page ", value_page.capacity(), ".\n"); \*/ \
                                                         \
         /* Stock up slot references. */ \
         part.value_slots.reserve(part.value_slots.size() + num_values); \
@@ -165,9 +172,16 @@ namespace HugeCTR {
       payload.value = part.value_slots.back(); \
       part.value_slots.pop_back(); \
       ++num_inserts; \
+      lock.unlock(); \
     } \
       \
+    if (payload.lck.load(std::memory_order_relaxed) != -1) { \
+      int8_t expected = 0; \
+      while (!payload.lck.compare_exchange_weak(expected, 1, std::memory_order_release, \
+                                                std::memory_order_relaxed)) expected = 0; \
+    } \
     std::copy_n(&values[(k - keys) * value_stride], value_size, payload.value); \
+    payload.lck.store(0, std::memory_order_relaxed); \
   } while (0)
 
 /**
@@ -198,4 +212,4 @@ namespace HugeCTR {
 // TODO: Remove me!
 #pragma GCC diagnostic pop
 
-}  // namespace HugeCTR
\ No newline at end of file
+}  // namespace HugeCTR
diff --git a/HugeCTR/src/hps/hash_map_backend.cpp b/HugeCTR/src/hps/hash_map_backend.cpp
index cba9d7cb1d..4a3f71ed9c 100644
--- a/HugeCTR/src/hps/hash_map_backend.cpp
+++ b/HugeCTR/src/hps/hash_map_backend.cpp
@@ -36,7 +36,7 @@ HashMapBackend<Key>::HashMapBackend(const HashMapBackendParams& params) : Base(p
 
 template <typename Key>
 size_t HashMapBackend<Key>::size(const std::string& table_name) const {
-  const std::shared_lock lock(read_write_guard_);
+  // const std::shared_lock lock(read_write_guard_);
 
   // Locate the partitions.
   const auto& tables_it{tables_.find(table_name)};
@@ -54,7 +54,7 @@ size_t HashMapBackend<Key>::contains(const std::string& table_name, const size_t
                                      const Key* const keys,
                                      const std::chrono::nanoseconds& time_budget) const {
   const auto begin{std::chrono::high_resolution_clock::now()};
-  const std::shared_lock lock(read_write_guard_);
+  // const std::shared_lock lock(read_write_guard_);
 
   // Locate partitions.
   const auto& tables_it{tables_.find(table_name)};
@@ -132,7 +132,7 @@ size_t HashMapBackend<Key>::insert(const std::string& table_name, const size_t n
                                    const uint32_t value_size, const size_t value_stride) {
   HCTR_CHECK(value_size <= value_stride);
 
-  const std::unique_lock lock(read_write_guard_);
+  // const std::unique_lock lock(read_write_guard_);
 
   // Locate the partitions, or create them, if they do not exist yet.
   const auto& tables_it{tables_.try_emplace(table_name).first};
@@ -222,7 +222,7 @@ size_t HashMapBackend<Key>::fetch(const std::string& table_name, const size_t nu
                                   const size_t value_stride, const DatabaseMissCallback& on_miss,
                                   const std::chrono::nanoseconds& time_budget) {
   const auto begin{std::chrono::high_resolution_clock::now()};
-  const std::shared_lock lock(read_write_guard_);
+  // const std::shared_lock lock(read_write_guard_);
 
   // Locate the partitions.
   const auto& tables_it{tables_.find(table_name)};
@@ -306,7 +306,7 @@ size_t HashMapBackend<Key>::fetch(const std::string& table_name, const size_t nu
                                   const DatabaseMissCallback& on_miss,
                                   const std::chrono::nanoseconds& time_budget) {
   const auto begin{std::chrono::high_resolution_clock::now()};
-  const std::shared_lock lock(read_write_guard_);
+  // const std::shared_lock lock(read_write_guard_);
 
   // Locate the partitions.
   const auto& tables_it{tables_.find(table_name)};
@@ -386,7 +386,7 @@ size_t HashMapBackend<Key>::fetch(const std::string& table_name, const size_t nu
 
 template <typename Key>
 size_t HashMapBackend<Key>::evict(const std::string& table_name) {
-  const std::unique_lock lock(read_write_guard_);
+  // const std::unique_lock lock(read_write_guard_);
 
   // Locate the partitions.
   const auto& tables_it{tables_.find(table_name)};
@@ -410,7 +410,7 @@ size_t HashMapBackend<Key>::evict(const std::string& table_name) {
 template <typename Key>
 size_t HashMapBackend<Key>::evict(const std::string& table_name, const size_t num_keys,
                                   const Key* const keys) {
-  const std::unique_lock lock(read_write_guard_);
+  // const std::unique_lock lock(read_write_guard_);
 
   // Locate the partitions.
   const auto& tables_it{tables_.find(table_name)};
@@ -476,7 +476,7 @@ template <typename Key>
 std::vector<std::string> HashMapBackend<Key>::find_tables(const std::string& model_name) {
   const std::string& tag_prefix{HierParameterServerBase::make_tag_name(model_name, "", false)};
 
-  const std::shared_lock lock(read_write_guard_);
+  // const std::shared_lock lock(read_write_guard_);
 
   std::vector<std::string> matches;
   for (const auto& pair : tables_) {
@@ -489,7 +489,7 @@ std::vector<std::string> HashMapBackend<Key>::find_tables(const std::string& mod
 
 template <typename Key>
 size_t HashMapBackend<Key>::dump_bin(const std::string& table_name, std::ofstream& file) {
-  const std::shared_lock lock(read_write_guard_);
+  // const std::shared_lock lock(read_write_guard_);
 
   // Locate the partitions.
   const auto& tables_it{tables_.find(table_name)};
@@ -519,7 +519,7 @@ size_t HashMapBackend<Key>::dump_bin(const std::string& table_name, std::ofstrea
 #ifdef HCTR_USE_ROCKS_DB
 template <typename Key>
 size_t HashMapBackend<Key>::dump_sst(const std::string& table_name, rocksdb::SstFileWriter& file) {
-  const std::shared_lock lock(read_write_guard_);
+  // const std::shared_lock lock(read_write_guard_);
 
   // Locate the partitions.
   const auto& tables_it{tables_.find(table_name)};
diff --git a/HugeCTR/src/hps/hier_parameter_server.cpp b/HugeCTR/src/hps/hier_parameter_server.cpp
index 631a0a4c4c..bf28369bf0 100644
--- a/HugeCTR/src/hps/hier_parameter_server.cpp
+++ b/HugeCTR/src/hps/hier_parameter_server.cpp
@@ -371,7 +371,7 @@ void HierParameterServer<TypeHashKey>::update_database_per_model(
   };
 
   char host_name[HOST_NAME_MAX + 1];
-  HCTR_CHECK_HINT(!gethostname(host_name, sizeof(host_name)), "Unable to determine hostname.\n");
+  HCTR_CHECK_HINT(!::gethostname(host_name, sizeof(host_name)), "Unable to determine hostname.\n");
 
   switch (inference_params.update_source.type) {
     case UpdateSourceType_t::Null:
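
For reference, a minimal self-contained sketch of the per-entry locking protocol that the new Payload::lck field implements in hash_map_backend_detail.hpp: -1 means freshly created and never written, 0 means the value is valid and unlocked, 1 means a writer is currently copying value bytes. The SpinSlot name is hypothetical, and the acquire/release memory orders are a conventional pairing rather than the relaxed ordering the macros above keep on the reader spin and the final store; the sketch also shows why the CAS retry loop resets 'expected' after a failed exchange.

#include <atomic>
#include <cstdint>

// Hypothetical stand-in for the role Payload::lck plays in the macros above.
struct SpinSlot {
  std::atomic<int8_t> state{-1};  // -1 = never written, 0 = unlocked, 1 = writer active

  // Writer side: claim the slot before overwriting the value bytes.
  void lock_for_write() {
    if (state.load(std::memory_order_relaxed) == -1) {
      return;  // Fresh entry: no reader can observe the value yet, skip the CAS.
    }
    int8_t expected = 0;
    while (!state.compare_exchange_weak(expected, 1, std::memory_order_acquire,
                                        std::memory_order_relaxed)) {
      expected = 0;  // compare_exchange_weak overwrites 'expected' on failure.
    }
  }

  // Writer side: publish the freshly written value bytes.
  void unlock_after_write() { state.store(0, std::memory_order_release); }

  // Reader side: spin until no writer is active, then copy the value out.
  void wait_until_readable() const {
    while (state.load(std::memory_order_acquire) != 0) {
    }
  }
};

int main() {
  SpinSlot slot;
  slot.lock_for_write();       // first write takes the fast path (state == -1)
  slot.unlock_after_write();   // value becomes visible to readers
  slot.wait_until_readable();  // a reader would pass through immediately now
  return 0;
}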
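
The switch from phmap::flat_hash_map to phmap::parallel_flat_hash_map is what allows the global read_write_guard_ to be commented out: with the submap-shift parameter set to 4, each partition's map is split into 2^4 = 16 submaps, each guarded by its own std::shared_mutex, so a lookup or insert only locks the submap that owns the key. The sketch below shows the instantiation plus one of the library's lock-holding accessors, which is the kind of thread-safe API the TODO(robertzhu) comment refers to; the Payload stand-in, the include path, and the uint64_t key type are assumptions for illustration, not part of this patch.

#include <cstdint>
#include <shared_mutex>
#include <utility>

#include <parallel_hashmap/phmap.h>  // assumed include path for the phmap headers

struct Payload {  // stand-in for HugeCTR's Payload
  char* value{nullptr};
};

template <typename Key>
using PartitionMap =
    phmap::parallel_flat_hash_map<Key, Payload, phmap::priv::hash_default_hash<Key>,
                                  phmap::priv::hash_default_eq<Key>,
                                  phmap::priv::Allocator<std::pair<const Key, Payload>>,
                                  4,                   // 2^4 = 16 internally locked submaps
                                  std::shared_mutex>;  // lock type used for each submap

int main() {
  PartitionMap<uint64_t> entries;
  entries.try_emplace(42, Payload{});

  // if_contains() invokes the lambda while the owning submap's mutex is held,
  // so the read cannot race with a concurrent insert into that submap.
  entries.if_contains(42, [](const auto& kv) {
    (void)kv.second.value;  // copy out whatever is needed while the lock is held
  });
  return 0;
}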
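
Payload (and Partition) need hand-written copy and move members because std::atomic, like the new std::shared_mutex member, is neither copyable nor movable; the compiler-generated special members would be deleted and the containers holding these structs could no longer copy or relocate their elements. A minimal sketch of the pattern follows; the Slot name and the noexcept qualifiers on the move operations are illustrative additions, not code from the patch.

#include <atomic>
#include <cstdint>
#include <utility>
#include <vector>

struct Slot {
  char* value{nullptr};
  std::atomic<int8_t> lck{-1};

  Slot() = default;

  // The atomic itself cannot be copied or moved, so transfer its current
  // value explicitly; the pointer is shallow-copied, as in the patch.
  Slot(const Slot& other) : value{other.value} { lck.store(other.lck.load()); }

  Slot(Slot&& other) noexcept : value{std::exchange(other.value, nullptr)} {
    lck.store(other.lck.load());
  }

  Slot& operator=(const Slot& other) {
    if (this != &other) {
      value = other.value;
      lck.store(other.lck.load());
    }
    return *this;
  }

  Slot& operator=(Slot&& other) noexcept {
    if (this != &other) {
      value = std::exchange(other.value, nullptr);
      lck.store(other.lck.load());
    }
    return *this;
  }
};

int main() {
  std::vector<Slot> slots(2);
  slots.emplace_back();      // vector growth relocates elements via the move members above
  Slot copy{slots.front()};  // explicit copy uses the hand-written copy constructor
  (void)copy;
  return 0;
}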