// Program Listing for File sharded_map.hpp
// (Return to documentation for file sharded_map/sharded_map.hpp)
#pragma once

#include <algorithm>
#include <atomic>
#include <barrier>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <ranges>
#include <span>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <omp.h>
namespace sharded_map {
namespace util {
// Finalizer-style bit mixer (xor-shift / odd-multiply rounds).
// Spreads the entropy of `key` across all 64 bits so that taking the
// result modulo a small shard count still distributes keys evenly.
// Every round is invertible, so the mapping is a bijection on uint64_t.
constexpr uint64_t mix_select(uint64_t key) {
  uint64_t h = key;
  h = (h ^ (h >> 31)) * 0x7fb5d329728ea185;
  h = (h ^ (h >> 27)) * 0x81dadef4bc2dd44d;
  return h ^ (h >> 33);
}
// Ceiling division: smallest n with n * y >= x.
// The `1 + (x - 1) / y` form avoids the overflow of `(x + y - 1) / y`,
// but underflows for x == 0, so that case is handled explicitly
// (the original returned a huge value for x == 0).
constexpr size_t ceil_div(std::integral auto x, std::integral auto y) {
  const auto xv = static_cast<size_t>(x);
  const auto yv = static_cast<size_t>(y);
  return xv == 0 ? 0 : 1 + (xv - 1) / yv;
}
template<std::integral L, std::integral R>
struct unify_ints {
private:
using temp_type = std::conditional_t<(sizeof(L) > sizeof(R)), L, R>;
public:
using type = std::conditional_t<(std::signed_integral<L> || std::signed_integral<R>),
std::make_signed_t<temp_type>,
std::make_unsigned_t<temp_type>>;
};
template<std::integral L, std::integral R>
using unify_ints_t = unify_ints<L, R>::type;
} // namespace util
// A static policy type that tells the map how to turn input values into
// stored values and how to merge an input into an existing entry.
template<typename Fn, typename K, typename V>
concept UpdateFunction = requires(const K &key, V &stored, typename Fn::InputValue input) {
  // The type of the values fed into the map (may differ from V).
  typename Fn::InputValue;
  // Fn::update merges an input value into an entry that already exists in
  // the map. It receives the key, a reference to the stored value, and the
  // input value (consumed).
  { Fn::update(key, stored, std::move(input)) } -> std::same_as<void>;
  // Fn::init builds the stored value for a key seen for the first time from
  // the key and the input value (consumed).
  { Fn::init(key, std::move(input)) } -> std::convertible_to<V>;
};
// A callable that builds one thread-local state object from a thread id.
// Uses std::invoke_result_t: std::result_of was removed in C++20, which
// this header already requires (concepts, std::barrier).
template<typename Fn>
concept StateCreatorFunction =
    std::invocable<Fn, size_t> && !std::is_void_v<std::invoke_result_t<Fn, size_t>>;
// A callable that produces the next (key, value) pair to insert from an
// iterated element and a mutable thread-local state.
// Uses std::invoke_result_t: std::result_of was removed in C++20, which
// this header already requires.
template<typename Fn, typename K, typename V, typename IterVal, typename State>
concept StatefulGeneratorFunction =
    !std::is_void_v<State> &&
    std::invocable<Fn, std::add_rvalue_reference_t<IterVal>, std::add_lvalue_reference_t<State>> &&
    std::same_as<std::invoke_result_t<Fn,
                                      std::add_lvalue_reference_t<IterVal>,
                                      std::add_lvalue_reference_t<State>>,
                 std::pair<K, V>>;
// A callable that produces the next (key, value) pair to insert from an
// iterated element alone (no per-thread state).
// Uses std::invoke_result_t: std::result_of was removed in C++20, which
// this header already requires.
template<typename Fn, typename K, typename V, typename IterVal>
concept StatelessGeneratorFunction =
    std::invocable<Fn, std::add_lvalue_reference_t<IterVal>> &&
    std::same_as<std::invoke_result_t<Fn, std::add_lvalue_reference_t<IterVal>>, std::pair<K, V>>;
namespace update_functions {
// Last-write-wins policy: an update replaces the stored value.
template<std::copy_constructible K, std::move_constructible V>
struct Overwrite {
  // The values we insert are the same as the ones saved in the map
  using InputValue = V;
  // The input arrives as an rvalue; move it into place. The original copy
  // (`value = input_value`) silently required copy-assignable V and failed
  // to compile for move-only value types.
  inline static void update(const K &, V &value, V &&input_value) { value = std::move(input_value); }
  // `input_value` is a function parameter, so the return implicitly moves.
  inline static V init(const K &, V &&input_value) { return input_value; }
};
// First-write-wins policy: an existing entry is left untouched.
template<std::copy_constructible K, std::move_constructible V>
struct Keep {
  // Input values share the type stored in the map.
  using InputValue = V;
  // Updating an existing entry is a deliberate no-op.
  inline static void update(const K & /*key*/, V & /*value*/, V && /*input*/) {}
  // A fresh key simply adopts the input value.
  inline static V init(const K & /*key*/, V &&input) { return input; }
};
} // namespace update_functions
/// @brief A hash map sharded over multiple threads. Each thread owns one
///        sequential map plus a fixed-capacity queue of pending insertions;
///        inserts for foreign shards are enqueued and later applied by the
///        owning thread, so each underlying map is only mutated by its owner.
template<std::copy_constructible K,
         std::move_constructible V,
         template<typename, typename, typename...> typename SeqHashMapType = std::unordered_map,
         UpdateFunction<K, V> UpdateFn = update_functions::Overwrite<K, V>>
  requires std::movable<typename UpdateFn::InputValue>
class ShardedMap {
  using SeqHashMap = SeqHashMapType<K, V>;
  using Hasher = typename SeqHashMap::hasher;
  using InputValue = typename UpdateFn::InputValue;
  using QueueStoredValue = std::pair<K, InputValue>;
  // Owning fixed-capacity task buffer, one per shard. RAII replaces the
  // previous raw new[]/delete[] ownership.
  using Queue = std::unique_ptr<QueueStoredValue[]>;
  using mem = std::memory_order;

  // Number of shards == number of participating threads.
  const size_t thread_count_;
  // One sequential map per shard; shard i is only written through thread i.
  std::vector<SeqHashMap> map_;
  // Per-shard pending-insertion buffers, each with queue_capacity_ slots.
  std::vector<Queue> task_queue_;
  // Pending-task counters; may transiently exceed the capacity, see Shard::insert.
  std::unique_ptr<std::atomic_size_t[]> task_count_;
  // Number of threads currently forcing a synchronous queue-handling round.
  std::atomic_size_t threads_handling_queue_;
  // Capacity of every task queue.
  const size_t queue_capacity_;
  // No-op barrier completion function.
  constexpr static std::invocable auto FN = []() noexcept {};
  // Rendezvous for synchronous queue handling; recreated by reset_barrier().
  std::unique_ptr<std::barrier<decltype(FN)>> barrier_;

public:
  /// @brief Creates a map sharded over `thread_count` threads, with one
  ///        pending-insert queue of `queue_capacity` entries per shard.
  ShardedMap(size_t thread_count, size_t queue_capacity) :
      thread_count_(thread_count),
      map_(),
      task_queue_(),
      // make_unique<T[]> value-initializes, so every counter starts at zero
      // (C++20 value-initializes std::atomic with a zero-initialized value).
      task_count_(std::make_unique<std::atomic_size_t[]>(thread_count)),
      threads_handling_queue_(0),
      queue_capacity_(queue_capacity),
      barrier_(
          std::make_unique<std::barrier<decltype(FN)>>(static_cast<ptrdiff_t>(thread_count), FN)) {
    map_.reserve(thread_count);
    task_queue_.reserve(thread_count);
    for (size_t i = 0; i < thread_count; ++i) {
      map_.emplace_back();
      task_queue_.push_back(std::make_unique<QueueStoredValue[]>(queue_capacity));
    }
  }

  // The std::atomic members already make this type non-copyable and
  // non-movable; deleting the operations documents that explicitly.
  ShardedMap(const ShardedMap &) = delete;
  ShardedMap &operator=(const ShardedMap &) = delete;

  // All storage is owned by vectors / unique_ptrs -> rule of zero.
  ~ShardedMap() = default;

  /// @brief Re-creates the barrier after batch_insert dropped every thread
  ///        from it, so the map can be used for another batch.
  void reset_barrier() {
    barrier_ =
        std::make_unique<std::barrier<decltype(FN)>>(static_cast<ptrdiff_t>(thread_count_), FN);
  }

  /// @brief Per-thread handle through which one thread inserts into the map.
  class Shard {
    // The sharded map this shard belongs to.
    ShardedMap &sharded_map_;
    // This thread's id.
    const size_t thread_id_;
    // Shortcuts into the owning map's per-shard structures.
    SeqHashMap &map_;
    Queue &task_queue_;
    std::atomic_size_t &task_count_;

  public:
    Shard(ShardedMap &sharded_map, size_t thread_id) :
        sharded_map_(sharded_map),
        thread_id_(thread_id),
        map_(sharded_map_.map_[thread_id]),
        task_queue_(sharded_map_.task_queue_[thread_id]),
        task_count_(sharded_map.task_count_[thread_id]) {}

    /// @brief Inserts or updates a key in this shard's local map, bypassing
    ///        the queues. Must only be called by the owning thread.
    inline void insert_or_update_direct(const K &k, InputValue &&in_value) {
      auto res = map_.find(k);
      if (res == map_.end()) {
        // Key not present: materialize the initial value and insert it.
        map_.emplace(k, UpdateFn::init(k, std::move(in_value)));
      } else {
        // Key present: merge the input value into the stored value.
        UpdateFn::update(k, res->second, std::move(in_value));
      }
    }

    /// @brief Drains this thread's queue in lockstep with the other threads.
    /// @param cause_wait If true, other threads observing the raised counter
    ///        will join this round from within insert().
    void handle_queue_sync(bool cause_wait = true) {
      // Raising the counter makes insert() on other threads enter this
      // function, so that everybody reaches the barrier.
      if (cause_wait) {
        sharded_map_.threads_handling_queue_.fetch_add(1, mem::acq_rel);
      }
      sharded_map_.barrier_->arrive_and_wait();
      handle_queue_async();
      if (cause_wait) {
        sharded_map_.threads_handling_queue_.fetch_sub(1, mem::acq_rel);
      }
      sharded_map_.barrier_->arrive_and_wait();
    }

    /// @brief Drains this thread's queue without synchronizing with others.
    void handle_queue_async() {
      // The counter can exceed the capacity when producers over-incremented
      // it on a full queue, so cap it at the number of usable slots.
      const size_t num_tasks_uncapped = task_count_.exchange(0, mem::acq_rel);
      const size_t num_tasks = std::min(num_tasks_uncapped, sharded_map_.queue_capacity_);
      if (num_tasks == 0) {
        return;
      }
      // Apply every queued task to the local map.
      for (size_t i = 0; i < num_tasks; ++i) {
        auto &entry = task_queue_[i];
        insert_or_update_direct(entry.first, std::move(entry.second));
      }
    }

    /// @brief Enqueues a key/value pair on the queue of the shard that owns
    ///        the key; triggers a synchronous handling round when full.
    void insert(QueueStoredValue &&pair) {
      // If some thread is waiting for a synchronous round, join it first.
      if (sharded_map_.threads_handling_queue_.load(mem::acquire) > 0) {
        handle_queue_sync();
      }
      // Pick the target shard from the mixed hash of the key.
      const size_t hash = Hasher{}(pair.first);
      const size_t target_thread_id = util::mix_select(hash) % sharded_map_.thread_count_;
      // Reserve a slot in the target thread's queue.
      std::atomic_size_t &target_task_count = sharded_map_.task_count_[target_thread_id];
      const size_t task_idx = target_task_count.fetch_add(1, mem::acq_rel);
      if (task_idx >= sharded_map_.queue_capacity_) {
        // Queue full. Undo the reservation (nothing was written) so the
        // counter stays accurate, then force a synchronous round and retry.
        target_task_count.fetch_sub(1, mem::acq_rel);
        handle_queue_sync();
        // The queues were just drained, so the retry will find room.
        insert(std::move(pair));
        return;
      }
      // Publish the value into the reserved slot.
      // NOTE(review): the slot is written after the counter was incremented;
      // a concurrent handle_queue_async on the target shard could observe the
      // count before this write completes — verify the intended synchronization.
      sharded_map_.task_queue_[target_thread_id][task_idx] = std::move(pair);
    }

    /// @brief Convenience overload; moves the value into the queue entry
    ///        (the original copied it, breaking move-only input types).
    inline void insert(const K &key, InputValue value) {
      insert(QueueStoredValue(key, std::move(value)));
    }
  };

  /// @brief Returns the shard handle for the given thread id.
  Shard get_shard(const size_t thread_id) { return Shard(*this, thread_id); }

  /// @brief Number of elements across all shard maps (queued, not-yet-applied
  ///        entries are not counted).
  [[nodiscard]] size_t size() const {
    size_t total = 0;
    for (const SeqHashMap &map : map_) {
      total += map.size();
    }
    return total;
  }

  /// @brief Inserts, in parallel, one generated (key, value) pair per element
  ///        of `range`, threading a per-thread state through the generator.
  /// @return The final state object of every thread.
  template<std::ranges::random_access_range Range,
           StateCreatorFunction CreateStateFn,
           std::copyable State = std::invoke_result_t<CreateStateFn, size_t>,
           StatefulGeneratorFunction<K, InputValue, std::ranges::range_reference_t<Range>, State>
               GeneratorFn>
  std::vector<State>
  batch_insert(Range &range, CreateStateFn gen_state, GeneratorFn generate_next) {
    using Iter = std::ranges::iterator_t<Range>;
    // Create the per-thread state objects up front.
    std::vector<State> state;
    state.reserve(thread_count_);
    for (size_t i = 0; i < thread_count_; i++) {
      state.push_back(gen_state(i));
    }
    // If the range is empty, there is nothing to do.
    if (std::ranges::empty(range)) {
      return state;
    }
    // The size of the entire range and of the segment each thread handles.
    const size_t range_size = std::ranges::size(range);
    const size_t segment_size = util::ceil_div(range_size, thread_count_);
    // Explicitly zero-initialized: threads poll it below.
    std::atomic_size_t threads_done{0};
#pragma omp parallel num_threads(thread_count_)
    {
      const size_t thread_id = omp_get_thread_num();
      // This thread covers elements [id * segment, (id + 1) * segment).
      const Iter thread_start = std::ranges::begin(range) + (thread_id * segment_size);
      const Iter thread_end = (thread_id + 1) * segment_size < range_size
                                  ? std::ranges::begin(range) + (thread_id + 1) * segment_size
                                  : std::ranges::end(range);
      // Get a shard and the state for this thread.
      Shard shard = get_shard(thread_id);
      State &local_state = state[thread_id];
      for (Iter it = thread_start; it != thread_end; ++it) {
        // Generate the next element and move its value into the queues.
        std::pair<K, InputValue> res = generate_next(*it, local_state);
        shard.insert(res.first, std::move(res.second));
      }
      // This thread is done producing. Others may still be inserting, so stay
      // responsive to synchronous rounds without causing any ourselves.
      threads_done++;
      while (threads_done.load() < thread_count_) {
        shard.handle_queue_sync(false);
      }
      // All producers finished; drain whatever is left in this queue.
      shard.handle_queue_async();
      // Leave the barrier so later rounds no longer wait for this thread.
      barrier_->arrive_and_drop();
    }
    // Make the barrier usable again for the next batch.
    reset_barrier();
    // Return the final state of each thread.
    return state;
  }

  /// @brief Stateless range overload of batch_insert.
  template<
      std::ranges::random_access_range Range,
      StatelessGeneratorFunction<K, InputValue, std::ranges::range_reference_t<Range>> GeneratorFn>
  void batch_insert(Range &range, GeneratorFn generate_next) {
    batch_insert(
        range,
        [](size_t) -> int { return 0; },
        [&](std::ranges::range_reference_t<Range> val, int &) { return generate_next(val); });
  }

  /// @brief Iterator-pair overload of batch_insert.
  template<StateCreatorFunction CreateStateFn,
           std::copyable State = std::invoke_result_t<CreateStateFn, size_t>,
           std::random_access_iterator Iter,
           StatefulGeneratorFunction<K, InputValue, std::iter_reference_t<Iter>, State> GeneratorFn>
  std::vector<State>
  batch_insert(Iter begin, Iter end, CreateStateFn gen_state, GeneratorFn generate_next) {
    // Bind the subrange to an lvalue: the range overload takes `Range &`, so a
    // temporary subrange would not bind (this failed to compile before).
    auto sub = std::ranges::subrange(begin, end);
    return batch_insert(sub, gen_state, generate_next);
  }

  /// @brief Stateless iterator-pair overload of batch_insert.
  template<std::random_access_iterator Iter,
           StatelessGeneratorFunction<K, InputValue, std::iter_reference_t<Iter>> GeneratorFn>
  void batch_insert(Iter begin, Iter end, GeneratorFn generate_next) {
    batch_insert(
        begin,
        end,
        [](size_t) -> int { return 0; },
        [&](std::iter_reference_t<Iter> val, int &) { return generate_next(val); });
  }

  /// @brief Integral-interval overload: inserts for every i in [begin, end).
  template<
      StateCreatorFunction CreateStateFn,
      std::copyable State = std::invoke_result_t<CreateStateFn, size_t>,
      std::integral StartIntType,
      std::integral EndIntType,
      StatefulGeneratorFunction<K, InputValue, util::unify_ints_t<StartIntType, EndIntType>, State>
          GeneratorFn>
  std::vector<State> batch_insert(StartIntType begin,
                                  EndIntType end,
                                  CreateStateFn gen_state,
                                  GeneratorFn generate_next) {
    using IntType = util::unify_ints_t<StartIntType, EndIntType>;
    const auto range =
        std::ranges::iota_view(static_cast<IntType>(begin), static_cast<IntType>(end));
    return batch_insert(range, gen_state, generate_next);
  }

  /// @brief Stateless integral-interval overload of batch_insert.
  template<StatelessGeneratorFunction<K, InputValue, size_t> GeneratorFn>
  void batch_insert(std::integral auto begin, std::integral auto end, GeneratorFn generate_next) {
    batch_insert(
        begin,
        end,
        [](size_t) -> int { return 0; },
        [&](size_t i, int &) { return generate_next(i); });
  }

  enum class Whereabouts { NOWHERE, IN_MAP, IN_QUEUE };

  /// @brief Reports whether `k` currently lives in a shard map, a pending
  ///        queue, or nowhere. Only meaningful while no inserts are running.
  [[maybe_unused]] Whereabouts where(const K &k) {
    const size_t hash = Hasher{}(k);
    const size_t target_thread_id = util::mix_select(hash) % thread_count_;
    SeqHashMap &map = map_[target_thread_id];
    if (map.find(k) != map.end()) {
      return Whereabouts::IN_MAP;
    }
    // The counter may transiently exceed the capacity (see Shard::insert);
    // clamp it so the scan never runs past the end of the buffer (the
    // original could read out of bounds here).
    Queue &queue = task_queue_[target_thread_id];
    const size_t pending = std::min<size_t>(task_count_[target_thread_id], queue_capacity_);
    for (size_t i = 0; i < pending; ++i) {
      if (queue[i].first == k) {
        return Whereabouts::IN_QUEUE;
      }
    }
    return Whereabouts::NOWHERE;
  }

  /// @brief Applies `f` to every key/value pair across all shard maps.
  void for_each(std::invocable<const K &, const V &> auto f) const {
    for (const SeqHashMap &map : map_) {
      for (const auto &[k, v] : map) {
        f(k, v);
      }
    }
  }

  /// @brief Sentinel returned by find(): the last shard's end iterator.
  ///        Assumes thread_count_ >= 1.
  typename SeqHashMap::iterator end() { return map_.back().end(); }

  /// @brief Looks up `key` in its owning shard; returns end() if absent.
  typename SeqHashMap::iterator find(const K &key) {
    const size_t hash = Hasher{}(key);
    const size_t target_thread_id = util::mix_select(hash) % thread_count_;
    SeqHashMap &map = map_[target_thread_id];
    typename SeqHashMap::iterator it = map.find(key);
    return it == map.end() ? end() : it;
  }

  /// @brief Snapshot of the pending task count of every queue.
  [[nodiscard]] std::vector<size_t> queue_loads() const {
    std::vector<size_t> loads;
    loads.reserve(thread_count_);
    for (size_t i = 0; i < thread_count_; ++i) {
      loads.push_back(task_count_[i].load());
    }
    return loads;
  }

  /// @brief Snapshot of the element count of every shard map.
  [[nodiscard]] std::vector<size_t> map_loads() const {
    std::vector<size_t> loads;
    loads.reserve(thread_count_);
    for (size_t i = 0; i < thread_count_; ++i) {
      loads.push_back(map_[i].size());
    }
    return loads;
  }

  /// @brief Access to the internal synchronization barrier.
  std::barrier<decltype(FN)> &barrier() { return *barrier_; }
}; // class ShardedMap
} // namespace sharded_map