discord: update /tagme command to use new autotagger service.

This commit is contained in:
evazion
2022-06-27 01:39:43 -05:00
parent ee57ada33b
commit 04359d67f4
3 changed files with 28 additions and 42 deletions

View File

@@ -43,7 +43,6 @@ gem 'google-cloud-bigquery', require: "google/cloud/bigquery"
gem 'google-cloud-storage', require: "google/cloud/storage" gem 'google-cloud-storage', require: "google/cloud/storage"
gem 'ed25519' gem 'ed25519'
gem 'bcrypt_pbkdf' # https://github.com/net-ssh/net-ssh/issues/565 gem 'bcrypt_pbkdf' # https://github.com/net-ssh/net-ssh/issues/565
gem 'terminal-table'
gem 'clockwork' gem 'clockwork'
gem 'puma-metrics' gem 'puma-metrics'
gem 'puma_worker_killer' gem 'puma_worker_killer'

View File

@@ -478,8 +478,6 @@ GEM
multi_json (~> 1.0) multi_json (~> 1.0)
stripe (> 5, < 6) stripe (> 5, < 6)
strscan (3.0.3) strscan (3.0.3)
terminal-table (3.0.2)
unicode-display_width (>= 1.1.1, < 3)
thor (1.2.1) thor (1.2.1)
tilt (2.0.10) tilt (2.0.10)
timeout (0.3.0) timeout (0.3.0)
@@ -591,7 +589,6 @@ DEPENDENCIES
stackprof stackprof
stripe stripe
stripe-ruby-mock stripe-ruby-mock
terminal-table
tzinfo-data tzinfo-data
view_component view_component
webpacker (= 6.0.0.rc.6) webpacker (= 6.0.0.rc.6)

View File

@@ -10,34 +10,45 @@ class DiscordSlashCommand
required: false, required: false,
type: ApplicationCommandOptionType::String type: ApplicationCommandOptionType::String
}, { }, {
name: "table", name: "confidence",
description: "Format the output as a table", description: "The minimum tag confidence level (default: 1%)",
required: false, required: false,
type: ApplicationCommandOptionType::Boolean type: ApplicationCommandOptionType::Integer
}] }]
def call def call
table = params.fetch(:table, false) confidence = params.fetch(:confidence, 1).to_i / 100.0
# Use the given URL, if present, or the last message with an attachment, if not. # Use the given URL, if present, or the last message with an attachment, if not.
if params[:url].present? if params[:url].present?
respond_later { tagme(params[:url], table: table) } respond_later { tagme(params[:url], confidence) }
elsif result = get_last_message_with_url elsif result = get_last_message_with_url
message, url = result message, url = result
respond_later { tagme(url, table: table) } respond_later { tagme(url, confidence) }
else else
respond_with("No image found. Post an image or provide a URL.") respond_with("No image found. Post an image or provide a URL.")
end end
end end
def tagme(url, table: false) def tagme(url, confidence, limit: 50, size: 500)
tags = get_tags(url) response, file = http.download_media(url)
preview = file.preview(size, size)
tags = autotagger.evaluate(preview, limit: limit, confidence: confidence).to_a
tags = tags.sort_by { |tag, confidence| [TagCategory.split_header_list.index(tag.category_name.downcase), -confidence] }.to_h
if table return {
build_tag_table(tags) embeds: [{
else description: build_tag_list(tags),
build_tag_list(tags) author: {
end name: "#{Danbooru.config.app_name} Autotagger",
url: "https://github.com/danbooru/autotagger",
icon_url: "https://danbooru.donmai.us/images/danbooru-logo-96x96.png",
},
image: {
url: url,
},
}]
}
end end
def get_last_message_with_url(limit: 10) def get_last_message_with_url(limit: 10)
@@ -56,32 +67,11 @@ class DiscordSlashCommand
nil nil
end end
def get_tags(url, size: 500, minimum_confidence: 0.5)
response, file = http.download_media(url)
preview = file.preview(size, size)
tags = deep_danbooru.tags!(preview).to_a
tags = tags.reject { |tag, confidence| confidence < minimum_confidence }
tags = tags.sort_by { |tag, confidence| [tag.general? ? 1 : 0, tag.name] }.to_h
tags
end
def build_tag_table(tags)
table = Terminal::Table.new
table.headings = ["Tag", "Count", "Confidence"]
tags.each do |tag, confidence|
table << [tag.name, tag.post_count, "%.f%%" % (100 * confidence)]
break if table.to_s.size >= DiscordApiClient::MAX_MESSAGE_LENGTH
end
"```\n#{table}\n```"
end
def build_tag_list(tags) def build_tag_list(tags)
msg = "" msg = ""
tags.keys.each do |tag| tags.each do |tag, confidence|
msg << "[#{tag.name}](#{Routes.posts_url(tags: tag.name)}) " msg += "#{(100*confidence).to_i}% [#{tag.name}](#{Routes.posts_url(tags: tag.name)})\n"
break if msg.size >= DiscordApiClient::MAX_MESSAGE_LENGTH break if msg.size >= DiscordApiClient::MAX_MESSAGE_LENGTH
end end
@@ -92,8 +82,8 @@ class DiscordSlashCommand
@http ||= Danbooru::Http.timeout(15) @http ||= Danbooru::Http.timeout(15)
end end
def deep_danbooru def autotagger
@deep_danbooru ||= DeepDanbooruClient.new(http: http) @autotagger ||= AutotaggerClient.new(http: http)
end end
end end
end end