Compute related tag counts in SQL instead of in Ruby.

RelatedTagCalculator.calculate_from_sample no longer plucks tag strings and
tallies them in a Ruby hash. It takes a sample of matching posts (Post.sample),
expands each tag_string into one row per tag (Post.with_unflattened_tags), and
lets the database group, count, order and limit the result. The default result
limit is shared through the MAX_RESULTS constant.

diff --git a/app/logical/related_tag_calculator.rb b/app/logical/related_tag_calculator.rb
index d99d287c3..9905d21aa 100644
--- a/app/logical/related_tag_calculator.rb
+++ b/app/logical/related_tag_calculator.rb
@@ -1,11 +1,5 @@
 class RelatedTagCalculator
-  def self.find_tags(tag, limit)
-    CurrentUser.without_safe_mode do
-      Post.with_timeout(5_000, [], {:tags => tag}) do
-        Post.tag_match(tag).limit(limit).reorder("posts.md5").pluck(:tag_string)
-      end
-    end
-  end
+  MAX_RESULTS = 25
 
   def self.calculate_from_sample_to_array(tags, category_constraint = nil)
     convert_hash_to_array(calculate_from_sample(tags, Danbooru.config.post_sample_size, category_constraint))
@@ -60,29 +54,25 @@ class RelatedTagCalculator
     convert_hash_to_array(similar_counts)
   end
 
-  def self.calculate_from_sample(tags, limit, category_constraint = nil)
-    counts = Hash.new {|h, k| h[k] = 0}
+  # Counts how often each tag occurs on a sample of up to sample_size posts
+  # matching tags, and returns the max_results most frequent ones as a
+  # {tag name => count} hash. When category_constraint is given, only tags
+  # in that category are counted.
+  def self.calculate_from_sample(tags, sample_size, category_constraint = nil, max_results = MAX_RESULTS)
+    # second argument is presumably the value returned on timeout (TODO confirm); keep it a Hash like the query result.
+    Post.with_timeout(5_000, {}, {:tags => tags}) do
+      sample = Post.sample(tags, sample_size)
+      posts_with_tags = Post.from(sample).with_unflattened_tags
 
-    find_tags(tags, limit).each do |tags|
-      tag_array = Tag.scan_tags(tags)
       if category_constraint
-        tag_array.each do |tag|
-          category = Tag.category_for(tag)
-          if category == category_constraint
-            counts[tag] += 1
-          end
-        end
-      else
-        tag_array.each do |tag|
-          counts[tag] += 1
-        end
+        posts_with_tags = posts_with_tags.joins("JOIN tags ON tags.name = tag").where("tags.category" => category_constraint)
       end
-    end
 
-    counts
+      posts_with_tags.order("count(*) DESC").limit(max_results).group("tag").count
+    end
   end
 
-  def self.convert_hash_to_array(hash, limit = 25)
+  def self.convert_hash_to_array(hash, limit = MAX_RESULTS)
     hash.to_a.sort_by {|x| -x[1]}.slice(0, limit)
   end
 
diff --git a/app/models/post.rb b/app/models/post.rb
index 10c3a863b..f36c7514e 100644
--- a/app/models/post.rb
+++ b/app/models/post.rb
@@ -1593,6 +1593,19 @@ class Post < ActiveRecord::Base
       where("md5 >= ?", key).reorder("md5 asc").first
     end
 
+    # returns up to sample_size posts matching query, ordered by md5.
+    # the relation is built with safe mode disabled.
+    def sample(query, sample_size)
+      CurrentUser.without_safe_mode do
+        tag_match(query).reorder(:md5).limit(sample_size)
+      end
+    end
+
+    # unflattens the tag_string into one tag per row.
+    def with_unflattened_tags
+      joins("CROSS JOIN unnest(string_to_array(tag_string, ' ')) AS tag")
+    end
+
     def pending
       where("is_pending = ?", true)
     end