// Program Listing for File sharded_map.hpp
//
// Return to documentation for file (sharded_map/sharded_map.hpp)

#pragma once

#include <algorithm>
#include <atomic>
#include <barrier>
#include <concepts>
#include <cstddef>
#include <cstdint>
#include <iterator>
#include <memory>
#include <ranges>
#include <span>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <omp.h>

namespace sharded_map {

namespace util {

// Bit mixer used to spread (possibly weak) hash values evenly across shards.
// This is a xorshift-multiply finalizer: every step is invertible, so the
// mapping is a bijection on 64-bit values.
constexpr uint64_t mix_select(uint64_t key) {
  uint64_t h = key;
  h = (h ^ (h >> 31)) * 0x7fb5d329728ea185;
  h = (h ^ (h >> 27)) * 0x81dadef4bc2dd44d;
  h = h ^ (h >> 33);
  return h;
}

// Integer ceiling division: the smallest z with z * y >= x.
// The previous form `1 + (x - 1) / y` underflows for x == 0 (size_t wraps
// to SIZE_MAX), producing a huge result; this form is correct for x == 0
// and cannot overflow for large x. Behavior for x > 0 is unchanged.
constexpr size_t ceil_div(std::integral auto x, std::integral auto y) {
  const auto a = static_cast<size_t>(x);
  const auto b = static_cast<size_t>(y);
  return a / b + static_cast<size_t>(a % b != 0);
}

template<std::integral L, std::integral R>
struct unify_ints {
private:
  using temp_type = std::conditional_t<(sizeof(L) > sizeof(R)), L, R>;

public:
  using type = std::conditional_t<(std::signed_integral<L> || std::signed_integral<R>),
                                  std::make_signed_t<temp_type>,
                                  std::make_unsigned_t<temp_type>>;
};

template<std::integral L, std::integral R>
using unify_ints_t = unify_ints<L, R>::type;

} // namespace util

// Concept describing the update policy used by ShardedMap. An UpdateFunction
// ties together how a value is first created in the map (init) and how an
// existing entry is combined with new input (update). It must expose an
// InputValue type — the type callers pass to insert — which need not be the
// stored value type V itself.
template<typename Fn, typename K, typename V>
concept UpdateFunction = requires(const K &k, V &v_lv, typename Fn::InputValue in_v_rv) {
  typename Fn::InputValue;
  // Updates a pre-existing value in the map.
  // Arguments are the key, the value in the map,
  // and the input value used to update the value in
  // the map.
  { Fn::update(k, v_lv, std::move(in_v_rv)) } -> std::same_as<void>;
  // Initialize a value from an input value.
  // Arguments are the key, and the value used to
  // initialize the value in the map. This returns the
  // value to be inserted into the map.
  { Fn::init(k, std::move(in_v_rv)) } -> std::convertible_to<V>;
};

// Concept for a factory creating per-thread state: invoked with the thread
// id (size_t) and returning a non-void state object.
// Uses std::invoke_result_t because std::result_of_t was removed in C++20,
// which this header already requires (<barrier>, concepts).
template<typename Fn>
concept StateCreatorFunction =
    std::invocable<Fn, size_t> && !std::is_void_v<std::invoke_result_t<Fn, size_t>>;

// Concept for a generator turning one range element (IterVal) plus mutable
// per-thread State into a (key, value) pair to insert.
// std::result_of_t was removed in C++20; std::invoke_result_t is the
// supported replacement. NOTE(review): the invocability check uses an rvalue
// reference to IterVal while the result check uses an lvalue reference —
// preserved as-is, since both collapse to the same type for the reference
// IterVals this header instantiates.
template<typename Fn, typename K, typename V, typename IterVal, typename State>
concept StatefulGeneratorFunction =
    !std::is_void_v<State> &&
    std::invocable<Fn, std::add_rvalue_reference_t<IterVal>, std::add_lvalue_reference_t<State>> &&
    std::same_as<std::invoke_result_t<Fn,
                                      std::add_lvalue_reference_t<IterVal>,
                                      std::add_lvalue_reference_t<State>>,
                 std::pair<K, V>>;

// Concept for a generator turning one range element (IterVal) into a
// (key, value) pair, without any per-thread state.
// std::result_of_t was removed in C++20; use std::invoke_result_t.
template<typename Fn, typename K, typename V, typename IterVal>
concept StatelessGeneratorFunction =
    std::invocable<Fn, std::add_lvalue_reference_t<IterVal>> &&
    std::same_as<std::invoke_result_t<Fn, std::add_lvalue_reference_t<IterVal>>, std::pair<K, V>>;

namespace update_functions {

// Update policy that always replaces the stored value with the new input.
template<std::copy_constructible K, std::move_constructible V>
struct Overwrite {
  // The values we insert are the same as the ones saved in the map.
  using InputValue = V;

  // Replace the stored value. std::move is required: input_value is an
  // lvalue inside the body, so assigning it directly would copy — and fail
  // to compile entirely for move-only V (which the map otherwise permits,
  // requiring only std::move_constructible).
  inline static void update(const K &, V &value, V &&input_value) {
    value = std::move(input_value);
  }

  // First insertion for a key: the input value becomes the stored value.
  inline static V init(const K &, V &&input_value) { return std::move(input_value); }
};

// Update policy that never modifies an existing entry: the first value
// inserted for a key wins and later inserts for that key are discarded.
template<std::copy_constructible K, std::move_constructible V>
struct Keep {
  // Inserted values have the same type as the values stored in the map.
  using InputValue = V;

  // Existing entries are left untouched on purpose.
  static inline void update(const K &, V &, V &&) {}

  // First insertion for a key: store the provided value.
  static inline V init(const K &, V &&v) { return v; }
};

} // namespace update_functions

// A hash map sharded across `thread_count` sequential hash maps, one shard
// per thread. Each thread owns one shard plus a fixed-capacity task queue;
// inserting a key that hashes to another shard enqueues the operation into
// that shard's queue instead of touching the foreign map. Queues are drained
// cooperatively at barrier synchronization points, so every underlying
// SeqHashMap is only ever modified by its owning thread.
//
// @tparam K              The key type.
// @tparam V              The stored value type.
// @tparam SeqHashMapType Sequential map template used for each shard.
// @tparam UpdateFn       Policy combining new input values with stored ones.
template<std::copy_constructible K,
         std::move_constructible V,
         template<typename, typename, typename...> typename SeqHashMapType = std::unordered_map,
         UpdateFunction<K, V> UpdateFn = update_functions::Overwrite<K, V>>
  requires std::movable<typename UpdateFn::InputValue>
class ShardedMap {
  using SeqHashMap = SeqHashMapType<K, V>;

  using Hasher = typename SeqHashMap::hasher;

  using InputValue = typename UpdateFn::InputValue;

  // A queued insertion: key plus the not-yet-applied input value.
  using QueueStoredValue = std::pair<K, InputValue>;

  using Queue = std::span<QueueStoredValue>;

  using mem = std::memory_order;

  // Number of shards / participating threads.
  const size_t thread_count_;

  // One sequential map per thread; map_[i] is modified only by thread i.
  std::vector<SeqHashMap> map_;

  // Owning storage behind the task queues (RAII; replaces raw new[]/delete[]
  // so construction failure cannot leak and no manual destructor is needed).
  std::vector<std::unique_ptr<QueueStoredValue[]>> queue_storage_;

  // Fixed-capacity per-thread task queues, viewing queue_storage_.
  std::vector<Queue> task_queue_;

  // Number of pending tasks in each thread's queue.
  std::vector<std::atomic_size_t> task_count_;

  // Number of threads currently forcing a synchronized queue handling;
  // insert() rendezvous at the barrier whenever this is non-zero.
  std::atomic_size_t threads_handling_queue_;

  // Capacity of each task queue.
  const size_t queue_capacity_;

  // No-op completion function for the barrier.
  constexpr static std::invocable auto FN = []() noexcept {};

  // Barrier used to rendezvous all threads when queues are drained.
  std::unique_ptr<std::barrier<decltype(FN)>> barrier_;

public:
  // Creates a map sharded over `thread_count` threads, each owning a task
  // queue of `queue_capacity` entries.
  // NOTE(review): thread_count is assumed to be >= 1 (end() dereferences
  // map_.back()) — confirm against callers.
  ShardedMap(size_t thread_count, size_t queue_capacity) :
      thread_count_(thread_count),
      map_(thread_count),
      queue_storage_(),
      task_queue_(),
      // The atomics are value-initialized, i.e. each counter starts at 0.
      task_count_(thread_count),
      threads_handling_queue_(0),
      queue_capacity_(queue_capacity),
      barrier_(
          std::make_unique<std::barrier<decltype(FN)>>(static_cast<ptrdiff_t>(thread_count), FN)) {
    queue_storage_.reserve(thread_count);
    task_queue_.reserve(thread_count);
    for (size_t i = 0; i < thread_count; i++) {
      queue_storage_.push_back(std::make_unique<QueueStoredValue[]>(queue_capacity));
      task_queue_.emplace_back(queue_storage_[i].get(), queue_capacity);
    }
  }

  // Rule of Zero: all storage is owned by vectors/unique_ptr, so no
  // user-declared destructor is needed. Copy and move remain implicitly
  // deleted because of the atomic members, which keeps the back-references
  // held by Shard objects valid.

  // Re-creates the barrier so the map can be reused, e.g. after
  // batch_insert() made every thread arrive_and_drop().
  void reset_barrier() {
    barrier_ =
        std::make_unique<std::barrier<decltype(FN)>>(static_cast<ptrdiff_t>(thread_count_), FN);
  }

  // A thread's view of the sharded map. All insertions go through a Shard;
  // each thread must use the Shard matching its own thread id.
  class Shard {
    // @brief The sharded map this shard belongs to.
    ShardedMap &sharded_map_;
    // @brief This thread's id.
    const size_t thread_id_;
    // @brief This thread's own sequential map.
    SeqHashMap &map_;
    // @brief This thread's task queue.
    Queue &task_queue_;
    // @brief Number of pending tasks in this thread's queue.
    std::atomic_size_t &task_count_;

  public:
    Shard(ShardedMap &sharded_map, size_t thread_id) :
        sharded_map_(sharded_map),
        thread_id_(thread_id),
        map_(sharded_map_.map_[thread_id]),
        task_queue_(sharded_map_.task_queue_[thread_id]),
        task_count_(sharded_map.task_count_[thread_id]) {}

    // Inserts or updates a value in this thread's own map, bypassing the
    // queue. Must only be called by the owning thread.
    inline void insert_or_update_direct(const K &k, InputValue &&in_value) {
      auto res = map_.find(k);
      if (res == map_.end()) {
        // If the value does not exist, create and insert it.
        K key     = k;
        V initial = UpdateFn::init(key, std::move(in_value));
        map_.emplace(key, std::move(initial));
      } else {
        // Otherwise, combine it with the existing entry.
        V &val = res->second;
        UpdateFn::update(k, val, std::move(in_value));
      }
    }

    // Drains this thread's queue in lockstep with all other threads: two
    // barrier waits bracket the drain so no thread writes into a queue that
    // is being consumed.
    void handle_queue_sync(bool cause_wait = true) {
      // If we want to cause other threads to wait, we increment the number
      // of threads handling the queue. This will cause other threads to wait
      // when they call insert if >0.
      if (cause_wait) {
        sharded_map_.threads_handling_queue_.fetch_add(1, mem::acq_rel);
      }
      sharded_map_.barrier_->arrive_and_wait();

      handle_queue_async();

      if (cause_wait) {
        sharded_map_.threads_handling_queue_.fetch_sub(1, mem::acq_rel);
      }
      sharded_map_.barrier_->arrive_and_wait();
    }

    // Applies every pending task in this thread's queue to this thread's
    // map, without any synchronization. Only safe when no other thread can
    // be enqueueing into this queue (inside handle_queue_sync, or after all
    // producers finished).
    void handle_queue_async() {
      // The count may exceed capacity transiently (a producer increments
      // before noticing the queue is full), so cap it.
      const size_t num_tasks_uncapped = task_count_.exchange(0, mem::acq_rel);
      const size_t num_tasks          = std::min(num_tasks_uncapped, sharded_map_.queue_capacity_);
      if (num_tasks == 0) {
        return;
      }

      // Handle all tasks in the queue.
      for (size_t i = 0; i < num_tasks; ++i) {
        auto &entry = task_queue_[i];
        insert_or_update_direct(entry.first, std::move(entry.second));
      }
    }

    // Routes one (key, input value) pair to the queue of the shard the key
    // hashes to. May block in handle_queue_sync when another thread requested
    // a synchronized drain or when the target queue is full.
    void insert(QueueStoredValue &&pair) {
      if (sharded_map_.threads_handling_queue_.load(mem::acquire) > 0) {
        handle_queue_sync();
      }
      // mix_select spreads weak hashes before picking the target shard.
      const size_t hash             = Hasher{}(pair.first);
      const size_t target_thread_id = util::mix_select(hash) % sharded_map_.thread_count_;

      // Otherwise enqueue the new value in the target thread.
      std::atomic_size_t &target_task_count = sharded_map_.task_count_[target_thread_id];

      size_t task_idx = target_task_count.fetch_add(1, mem::acq_rel);
      // If the target queue is full, signal to the other threads that they
      // need to handle their queue and handle this thread's queue.
      if (task_idx >= sharded_map_.queue_capacity_) {
        // Since we incremented that thread's task count, but didn't insert
        // anything, we need to decrement it again so that it has the correct
        // value.
        target_task_count.fetch_sub(1, mem::acq_rel);
        handle_queue_sync();
        // Since the queue was handled, the task count is now 0.
        insert(std::move(pair));
        return;
      }
      // Insert the value into the queue.
      sharded_map_.task_queue_[target_thread_id][task_idx] = std::move(pair);
    }

    // Convenience overload. The value is moved into the queue entry instead
    // of copied; the previous copy both cost a copy per insert and failed to
    // compile for move-only InputValue types.
    inline void insert(const K &key, InputValue value) {
      insert(QueueStoredValue(key, std::move(value)));
    }
  };

  // Returns the shard belonging to `thread_id`.
  Shard get_shard(const size_t thread_id) { return Shard(*this, thread_id); }

  // Total number of entries across all shards. Only meaningful when no
  // insertions are in flight (queued-but-unapplied entries are not counted).
  [[nodiscard]] size_t size() const {
    size_t size = 0;
    for (const SeqHashMap &map : map_) {
      size += map.size();
    }
    return size;
  }

  // Inserts elements generated from `range` in parallel using OpenMP.
  //
  // @param range         Random-access range of raw elements.
  // @param gen_state     Called once per thread with the thread id to create
  //                      that thread's local state.
  // @param generate_next Turns one range element plus the thread-local state
  //                      into the (key, input value) pair to insert.
  // @return The final per-thread states.
  //
  // std::result_of_t (removed in C++20) is replaced by std::invoke_result_t.
  template<std::ranges::random_access_range Range,
           StateCreatorFunction             CreateStateFn,
           std::copyable                    State = std::invoke_result_t<CreateStateFn, size_t>,
           StatefulGeneratorFunction<K, InputValue, std::ranges::range_reference_t<Range>, State>
               GeneratorFn>
  std::vector<State>
  batch_insert(Range &range, CreateStateFn gen_state, GeneratorFn generate_next) {
    using Iter = std::ranges::iterator_t<Range>;
    // Create thread-local state.
    std::vector<State> state;
    state.reserve(thread_count_);
    for (size_t i = 0; i < thread_count_; i++) {
      state.push_back(gen_state(i));
    }
    // If the range is empty, there is nothing to do.
    if (std::ranges::empty(range)) {
      return state;
    }

    // The size of the entire range and the size of the segment each thread
    // handles.
    const size_t range_size   = std::ranges::size(range);
    const size_t segment_size = util::ceil_div(range_size, thread_count_);

    // Explicitly zero-initialized: before C++20 the atomic's default
    // constructor left the counter indeterminate.
    std::atomic_size_t threads_done{0};

#pragma omp parallel num_threads(thread_count_)
    {
      const size_t thread_id = omp_get_thread_num();
      // Determine this thread's segment. Both indices are clamped to the
      // range size: for small ranges (range_size < thread_id * segment_size)
      // the unclamped start iterator would point past the end, which is
      // undefined behavior for random-access iterators.
      const size_t start_idx    = std::min(thread_id * segment_size, range_size);
      const size_t end_idx      = std::min((thread_id + 1) * segment_size, range_size);
      const Iter   thread_start = std::ranges::begin(range) + start_idx;
      const Iter   thread_end   = std::ranges::begin(range) + end_idx;
      // Get a shard and the state for this thread.
      Shard  shard       = get_shard(thread_id);
      State &local_state = state[thread_id];

      for (Iter it = thread_start; it != thread_end; it++) {
        // Insert the generated element. Moving the pair avoids copying the
        // input value (and supports move-only InputValue types).
        std::pair<K, InputValue> res = generate_next(*it, local_state);
        shard.insert(std::move(res));
      }

      // This thread is done inserting its elements.
      threads_done++;

      // Threads might still be stuck in the insert loop. This thread must
      // stay ready to handle its queue, but not *cause* other threads to
      // handle their queues.
      while (threads_done.load() < thread_count_) {
        shard.handle_queue_sync(false);
      }
      // Empty this thread's queue. All threads are done and no more
      // insertions are coming.
      shard.handle_queue_async();
      // This thread does not participate in the insert step anymore. The
      // barrier won't wait for it anymore.
      barrier_->arrive_and_drop();
    }

    // We're done with the insertion. We reset the barrier such that it can
    // be used later.
    reset_barrier();

    // Return the final state of each thread.
    return state;
  }

  // Stateless overload: wraps the generator with a dummy per-thread state.
  template<
      std::ranges::random_access_range                                                 Range,
      StatelessGeneratorFunction<K, InputValue, std::ranges::range_reference_t<Range>> GeneratorFn>
  void batch_insert(Range &range, GeneratorFn generate_next) {
    batch_insert(
        range,
        [](size_t) -> int { return 0; },
        [&](std::ranges::range_reference_t<Range> val, int &) { return generate_next(val); });
  }

  // Iterator-pair overload of the stateful batch_insert.
  template<StateCreatorFunction        CreateStateFn,
           std::copyable               State = std::invoke_result_t<CreateStateFn, size_t>,
           std::random_access_iterator Iter,
           StatefulGeneratorFunction<K, InputValue, std::iter_reference_t<Iter>, State> GeneratorFn>
  std::vector<State>
  batch_insert(Iter begin, Iter end, CreateStateFn gen_state, GeneratorFn generate_next) {
    return batch_insert(std::ranges::subrange(begin, end), gen_state, generate_next);
  }

  // Iterator-pair overload of the stateless batch_insert.
  template<std::random_access_iterator                                            Iter,
           StatelessGeneratorFunction<K, InputValue, std::iter_reference_t<Iter>> GeneratorFn>
  void batch_insert(Iter begin, Iter end, GeneratorFn generate_next) {
    batch_insert(
        begin,
        end,
        [](size_t) -> int { return 0; },
        [&](std::iter_reference_t<Iter> val, int &) { return generate_next(val); });
  }

  // Integer-interval overload: inserts one element per index in [begin, end).
  template<
      StateCreatorFunction CreateStateFn,
      std::copyable        State = std::invoke_result_t<CreateStateFn, size_t>,
      std::integral        StartIntType,
      std::integral        EndIntType,
      StatefulGeneratorFunction<K, InputValue, util::unify_ints_t<StartIntType, EndIntType>, State>
          GeneratorFn>
  std::vector<State> batch_insert(StartIntType  begin,
                                  EndIntType    end,
                                  CreateStateFn gen_state,
                                  GeneratorFn   generate_next) {
    using IntType = util::unify_ints_t<StartIntType, EndIntType>;

    const auto range =
        std::ranges::iota_view(static_cast<IntType>(begin), static_cast<IntType>(end));
    return batch_insert(range, gen_state, generate_next);
  }

  // Stateless integer-interval overload.
  template<StatelessGeneratorFunction<K, InputValue, size_t> GeneratorFn>
  void batch_insert(std::integral auto begin, std::integral auto end, GeneratorFn generate_next) {
    batch_insert(
        begin,
        end,
        [](size_t) -> int { return 0; },
        [&](size_t i, int &) { return generate_next(i); });
  }

  // Where a key currently lives: nowhere, applied in a shard's map, or still
  // pending in a shard's queue.
  enum class Whereabouts { NOWHERE, IN_MAP, IN_QUEUE };

  // Debugging helper reporting where key `k` currently is. NOTE(review):
  // this reads maps and queues without synchronization — only meaningful
  // while no insertions are in flight; confirm callers respect that.
  [[maybe_unused]] Whereabouts where(const K &k) {
    const size_t                  hash             = Hasher{}(k);
    const size_t                  target_thread_id = util::mix_select(hash) % thread_count_;
    SeqHashMap                   &map              = map_[target_thread_id];
    typename SeqHashMap::iterator it               = map.find(k);
    if (it != map.end()) {
      return Whereabouts::IN_MAP;
    }
    Queue &queue = task_queue_[target_thread_id];
    for (size_t i = 0; i < task_count_[target_thread_id]; ++i) {
      if (queue[i].first == k) {
        return Whereabouts::IN_QUEUE;
      }
    }
    return Whereabouts::NOWHERE;
  }

  // Invokes f(key, value) for every entry in every shard, shard by shard.
  void for_each(std::invocable<const K &, const V &> auto f) const {
    for (const SeqHashMap &map : map_) {
      for (const auto &[k, v] : map) {
        f(k, v);
      }
    }
  }

  // Sentinel iterator used by find(): the end iterator of the last shard.
  typename SeqHashMap::iterator end() { return map_.back().end(); }

  // Looks up `key` in the shard it hashes to. Returns end() if absent.
  typename SeqHashMap::iterator find(const K &key) {
    const size_t                  hash             = Hasher{}(key);
    const size_t                  target_thread_id = util::mix_select(hash) % thread_count_;
    SeqHashMap                   &map              = map_[target_thread_id];
    typename SeqHashMap::iterator it               = map.find(key);
    if (it == map.end()) {
      return end();
    }
    return it;
  }

  // Snapshot of the number of pending tasks in each thread's queue.
  [[nodiscard]] std::vector<size_t> queue_loads() const {
    std::vector<size_t> loads;
    loads.reserve(thread_count_);
    for (size_t i = 0; i < thread_count_; ++i) {
      loads.push_back(task_count_[i].load());
    }
    return loads;
  }

  // Snapshot of the number of applied entries in each shard's map.
  [[nodiscard]] std::vector<size_t> map_loads() const {
    std::vector<size_t> loads;
    loads.reserve(thread_count_);
    for (size_t i = 0; i < thread_count_; ++i) {
      loads.push_back(map_[i].size());
    }
    return loads;
  }

  // Access to the synchronization barrier shared by all shards.
  std::barrier<decltype(FN)> &barrier() { return *barrier_; }

}; // class ShardedMap

} // namespace sharded_map