diff --git a/lib_vector_search/include/bucket_finder.h b/lib_vector_search/include/bucket_finder.h index b97ecbc..4b2dddf 100644 --- a/lib_vector_search/include/bucket_finder.h +++ b/lib_vector_search/include/bucket_finder.h @@ -10,34 +10,17 @@ private: std::unordered_map directory_; public: - /// Inserts references to all words from word_list between first_index and - /// last_index, including first_index, excluding last_index. void insert(const WordList &word_list, size_t first_index, size_t last_index); - /// Find all words that start with search_term - /// @return A list with references to the results. WordRefList find_prefix(std::string_view search_term) const; }; -/** This class provides efficient, parallel search over a list of strings. - * - * References to all input strings are stored in a tree-like structure that - * provides fast and lock-free parallel insertion and parallel (with minimal - * synchronization) search for all words that start with a given term. - */ class BucketFinder : public Finder { private: std::vector buckets_; public: - /// Creates a BucketFinder over all words in word_list. BucketFinder(const WordList &word_list); - /// Find all words that start with search_term - /// @return A list with references to the results. WordRefList find_prefix(std::string_view search_term) const override; - -private: - /// Inserts references to all words from word_list. - void insert(const WordList &word_list); }; diff --git a/lib_vector_search/include/word_list.h b/lib_vector_search/include/word_list.h index 3894a15..f197170 100644 --- a/lib_vector_search/include/word_list.h +++ b/lib_vector_search/include/word_list.h @@ -1,9 +1,12 @@ #pragma once #include +#include #include #include +class WordRefList; + class WordList : public std::vector { public: WordList &multiply(size_t factor); @@ -12,10 +15,22 @@ public: static WordList oneCap(); static WordList fourCaps(); static WordList fromFile(const std::filesystem::path &path); + + static void find_prefix_in_range(const WordList &word_list, + const std::string_view &search_prefix, + size_t start_index, size_t end_index, + WordRefList &result, + std::mutex &result_mutex); }; class WordRefList : public std::vector { public: WordRefList() = default; WordRefList(const WordList &source); + + static void find_prefix_in_range(const WordRefList &word_list, + const std::string_view &search_prefix, + size_t start_index, size_t end_index, + WordRefList &result, + std::mutex &result_mutex); }; diff --git a/lib_vector_search/src/bucket_finder.cpp b/lib_vector_search/src/bucket_finder.cpp index 03949e1..8fd7500 100644 --- a/lib_vector_search/src/bucket_finder.cpp +++ b/lib_vector_search/src/bucket_finder.cpp @@ -1,11 +1,8 @@ #include "bucket_finder.h" -#include -#include #include -#include +#include #include -#include void Bucket::insert(const WordList &word_list, size_t first_index, size_t last_index) { @@ -31,35 +28,34 @@ WordRefList Bucket::find_prefix(std::string_view search_term) const { return result; } -BucketFinder::BucketFinder(const WordList &word_list) { insert(word_list); } - -void BucketFinder::insert(const WordList &word_list) { +BucketFinder::BucketFinder(const WordList &word_list) { if (word_list.empty()) { return; } - const size_t max_threads = std::thread::hardware_concurrency(); const size_t word_list_size = word_list.size(); - const size_t bucket_count = std::min(max_threads, word_list_size); + const size_t bucket_count = + std::min(std::thread::hardware_concurrency(), word_list_size); const size_t bucket_size = word_list_size / bucket_count; buckets_.resize(bucket_count); - std::vector insert_threads; + std::vector threads; for (auto bucket_index = 0; bucket_index < bucket_count; ++bucket_index) { auto &bucket = buckets_[bucket_index]; - const bool is_last_bucket = bucket_index == bucket_count - 1; - const size_t first_word_index = bucket_index * bucket_size; - const size_t last_word_index = - is_last_bucket ? word_list_size : first_word_index + bucket_size; + bool is_last_bucket = bucket_index == bucket_count - 1; - insert_threads.emplace_back([&, first_word_index, last_word_index] { - bucket.insert(word_list, first_word_index, last_word_index); + const size_t first_index = bucket_index * bucket_size; + const size_t last_index = + is_last_bucket ? word_list_size : first_index + bucket_size; + + threads.emplace_back([&, first_index, last_index] { + bucket.insert(word_list, first_index, last_index); }); } - for (auto &thread : insert_threads) { + for (auto &thread : threads) { thread.join(); } } @@ -73,7 +69,7 @@ WordRefList BucketFinder::find_prefix(std::string_view search_term) const { threads.emplace_back([&] { auto thread_search_results = bucket.find_prefix(search_term); if (!thread_search_results.empty()) { - const std::lock_guard result_lock(search_results_mutex); + std::lock_guard result_lock(search_results_mutex); std::move(thread_search_results.begin(), thread_search_results.end(), std::back_inserter(search_results)); } diff --git a/lib_vector_search/src/grouped_finder.cpp b/lib_vector_search/src/grouped_finder.cpp index d901dda..990b95c 100644 --- a/lib_vector_search/src/grouped_finder.cpp +++ b/lib_vector_search/src/grouped_finder.cpp @@ -19,43 +19,31 @@ WordRefList GroupedFinder::find_prefix(string_view search_prefix) const { } const auto word_list = group->second; - const size_t word_list_size = word_list.size(); - const size_t thread_count = - std::min(std::thread::hardware_concurrency(), word_list_size); - const size_t words_per_thread = word_list_size / thread_count; + const auto word_list_size = word_list.size(); - WordRefList search_results; - mutex search_results_mutex; + const auto thread_count = + std::min(std::thread::hardware_concurrency(), word_list_size); + + WordRefList result; + mutex result_mutex; vector search_threads; for (size_t thread_index = 0; thread_index < thread_count; ++thread_index) { - const bool is_last_thread = thread_index == thread_count - 1; - const size_t start_index = thread_index * words_per_thread; - const size_t end_index = - is_last_thread ? word_list_size : start_index + words_per_thread; + const size_t first_index = thread_index * (word_list_size / thread_count); - search_threads.emplace_back([&, start_index, end_index] { - WordRefList local_results; + const size_t last_index = + (thread_index == thread_count - 1) + ? word_list_size + : (thread_index + 1) * (word_list_size / thread_count); - for (size_t word_index = start_index; word_index < end_index; - ++word_index) { - const auto *current_word = word_list[word_index]; - if (current_word->starts_with(search_prefix)) { - local_results.push_back(current_word); - } - } - - if (!local_results.empty()) { - const std::lock_guard lock(search_results_mutex); - std::move(local_results.begin(), local_results.end(), - std::back_inserter(search_results)); - } - }); + search_threads.emplace_back( + WordRefList::find_prefix_in_range, cref(word_list), cref(search_prefix), + first_index, last_index, ref(result), ref(result_mutex)); } for (auto &thread : search_threads) { thread.join(); } - return search_results; + return result; } diff --git a/lib_vector_search/src/parallel_finder.cpp b/lib_vector_search/src/parallel_finder.cpp index 8d3367a..b398b2e 100644 --- a/lib_vector_search/src/parallel_finder.cpp +++ b/lib_vector_search/src/parallel_finder.cpp @@ -9,43 +9,30 @@ ParallelFinder::ParallelFinder(const WordList &word_list) : word_list_(word_list) {} WordRefList ParallelFinder::find_prefix(string_view search_prefix) const { - WordRefList search_results; - mutex search_results_mutex; + WordRefList result; + mutex result_mutex; const size_t word_list_size = word_list_.size(); const size_t thread_count = std::min(thread::hardware_concurrency(), word_list_size); - const size_t words_per_thread = word_list_size / thread_count; vector search_threads; for (size_t thread_index = 0; thread_index < thread_count; ++thread_index) { - const bool is_last_thread = thread_index == thread_count - 1; - const size_t start_index = thread_index * words_per_thread; - const size_t end_index = - is_last_thread ? word_list_size : start_index + words_per_thread; + const size_t first_index = thread_index * (word_list_size / thread_count); - search_threads.emplace_back([&, start_index, end_index] { - WordRefList local_results; + const size_t last_index = + (thread_index == thread_count - 1) + ? word_list_size + : (thread_index + 1) * (word_list_size / thread_count); - for (size_t word_index = start_index; word_index < end_index; - ++word_index) { - const auto ¤t_word = word_list_[word_index]; - if (current_word.starts_with(search_prefix)) { - local_results.push_back(¤t_word); - } - } - - if (!local_results.empty()) { - const std::lock_guard lock(search_results_mutex); - std::move(local_results.begin(), local_results.end(), - std::back_inserter(search_results)); - } - }); + search_threads.emplace_back( + WordList::find_prefix_in_range, cref(word_list_), cref(search_prefix), + first_index, last_index, ref(result), ref(result_mutex)); } for (auto &thread : search_threads) { thread.join(); } - return search_results; + return result; } diff --git a/lib_vector_search/src/word_list.cpp b/lib_vector_search/src/word_list.cpp index e1baf7c..039ed83 100644 --- a/lib_vector_search/src/word_list.cpp +++ b/lib_vector_search/src/word_list.cpp @@ -71,8 +71,50 @@ WordList WordList::fromFile(const std::filesystem::path &path) { return word_list; } +void WordList::find_prefix_in_range(const WordList &word_list, + const std::string_view &search_prefix, + size_t start_index, size_t end_index, + WordRefList &result, + std::mutex &result_mutex) { + WordRefList local_results; + + for (size_t index = start_index; index < end_index; ++index) { + const auto ¤t_word = word_list[index]; + if (current_word.starts_with(search_prefix)) { + local_results.push_back(¤t_word); + } + } + + if (!local_results.empty()) { + const std::lock_guard lock(result_mutex); + std::move(local_results.begin(), local_results.end(), + std::back_inserter(result)); + } +}; + WordRefList::WordRefList(const WordList &source) { for (const auto &word : source) { push_back(&word); } } + +void WordRefList::find_prefix_in_range(const WordRefList &word_list, + const std::string_view &search_prefix, + size_t start_index, size_t end_index, + WordRefList &result, + std::mutex &result_mutex) { + WordRefList local_results; + + for (size_t index = start_index; index < end_index; ++index) { + const auto *current_word = word_list[index]; + if (current_word->starts_with(search_prefix)) { + local_results.push_back(current_word); + } + } + + if (!local_results.empty()) { + const std::lock_guard lock(result_mutex); + std::move(local_results.begin(), local_results.end(), + std::back_inserter(result)); + } +};