discord: update /tagme command to use new autotagger service.

This commit is contained in:
evazion
2022-06-27 01:39:43 -05:00
parent ee57ada33b
commit 04359d67f4
3 changed files with 28 additions and 42 deletions

View File

@@ -43,7 +43,6 @@ gem 'google-cloud-bigquery', require: "google/cloud/bigquery"
gem 'google-cloud-storage', require: "google/cloud/storage"
gem 'ed25519'
gem 'bcrypt_pbkdf' # https://github.com/net-ssh/net-ssh/issues/565
gem 'terminal-table'
gem 'clockwork'
gem 'puma-metrics'
gem 'puma_worker_killer'

View File

@@ -478,8 +478,6 @@ GEM
multi_json (~> 1.0)
stripe (> 5, < 6)
strscan (3.0.3)
terminal-table (3.0.2)
unicode-display_width (>= 1.1.1, < 3)
thor (1.2.1)
tilt (2.0.10)
timeout (0.3.0)
@@ -591,7 +589,6 @@ DEPENDENCIES
stackprof
stripe
stripe-ruby-mock
terminal-table
tzinfo-data
view_component
webpacker (= 6.0.0.rc.6)

View File

@@ -10,34 +10,45 @@ class DiscordSlashCommand
required: false,
type: ApplicationCommandOptionType::String
}, {
name: "table",
description: "Format the output as a table",
name: "confidence",
description: "The minimum tag confidence level (default: 1%)",
required: false,
type: ApplicationCommandOptionType::Boolean
type: ApplicationCommandOptionType::Integer
}]
def call
table = params.fetch(:table, false)
confidence = params.fetch(:confidence, 1).to_i / 100.0
# Use the given URL, if present, or the last message with an attachment, if not.
if params[:url].present?
respond_later { tagme(params[:url], table: table) }
respond_later { tagme(params[:url], confidence) }
elsif result = get_last_message_with_url
message, url = result
respond_later { tagme(url, table: table) }
respond_later { tagme(url, confidence) }
else
respond_with("No image found. Post an image or provide a URL.")
end
end
def tagme(url, table: false)
tags = get_tags(url)
def tagme(url, confidence, limit: 50, size: 500)
response, file = http.download_media(url)
preview = file.preview(size, size)
tags = autotagger.evaluate(preview, limit: limit, confidence: confidence).to_a
tags = tags.sort_by { |tag, confidence| [TagCategory.split_header_list.index(tag.category_name.downcase), -confidence] }.to_h
if table
build_tag_table(tags)
else
build_tag_list(tags)
end
return {
embeds: [{
description: build_tag_list(tags),
author: {
name: "#{Danbooru.config.app_name} Autotagger",
url: "https://github.com/danbooru/autotagger",
icon_url: "https://danbooru.donmai.us/images/danbooru-logo-96x96.png",
},
image: {
url: url,
},
}]
}
end
def get_last_message_with_url(limit: 10)
@@ -56,32 +67,11 @@ class DiscordSlashCommand
nil
end
def get_tags(url, size: 500, minimum_confidence: 0.5)
response, file = http.download_media(url)
preview = file.preview(size, size)
tags = deep_danbooru.tags!(preview).to_a
tags = tags.reject { |tag, confidence| confidence < minimum_confidence }
tags = tags.sort_by { |tag, confidence| [tag.general? ? 1 : 0, tag.name] }.to_h
tags
end
def build_tag_table(tags)
table = Terminal::Table.new
table.headings = ["Tag", "Count", "Confidence"]
tags.each do |tag, confidence|
table << [tag.name, tag.post_count, "%.f%%" % (100 * confidence)]
break if table.to_s.size >= DiscordApiClient::MAX_MESSAGE_LENGTH
end
"```\n#{table}\n```"
end
def build_tag_list(tags)
msg = ""
tags.keys.each do |tag|
msg << "[#{tag.name}](#{Routes.posts_url(tags: tag.name)}) "
tags.each do |tag, confidence|
msg += "#{(100*confidence).to_i}% [#{tag.name}](#{Routes.posts_url(tags: tag.name)})\n"
break if msg.size >= DiscordApiClient::MAX_MESSAGE_LENGTH
end
@@ -92,8 +82,8 @@ class DiscordSlashCommand
@http ||= Danbooru::Http.timeout(15)
end
def deep_danbooru
@deep_danbooru ||= DeepDanbooruClient.new(http: http)
def autotagger
@autotagger ||= AutotaggerClient.new(http: http)
end
end
end