fix thread unsafety (#63)

main
adienes 2024-04-07 15:38:56 -04:00 committed by GitHub
parent b384824741
commit b7dd39f5e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 22 deletions

View File

@ -8,8 +8,8 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
[compat]
Distances = "0.8.1, 0.9, 0.10"
julia = "1.3"
StatsAPI = "1"
julia = "1.3"
[extras]
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

View File

@ -2,6 +2,7 @@ module StringDistances
using Distances: Distances, SemiMetric, Metric, evaluate, result_type
using StatsAPI: StatsAPI, pairwise, pairwise!
# Distances API
abstract type StringSemiMetric <: SemiMetric end
abstract type StringMetric <: Metric end

View File

@ -12,7 +12,7 @@ julia> compare("martha", "marhta", Levenshtein())
"""
function compare(s1, s2, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
1 - Normalized(dist)(s1, s2; max_dist = 1 - min_score)
end
end
"""
findnearest(s, itr, dist::Union{StringMetric, StringSemiMetric}) -> (x, index)
@ -35,22 +35,34 @@ julia> findnearest(s, iter, Levenshtein(); min_score = 0.9)
```
"""
function findnearest(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
_citr = collect(itr)
isempty(_citr) && return (nothing, nothing)
_preprocessed_s = _preprocess(dist, s)
min_score_atomic = Threads.Atomic{Float64}(min_score)
scores = [0.0 for _ in 1:Threads.nthreads()]
is = [0 for _ in 1:Threads.nthreads()]
s = _preprocess(dist, s)
# need collect since @threads requires a length method
Threads.@threads for i in collect(eachindex(itr))
score = compare(s, _preprocess(dist, itr[i]), dist; min_score = min_score_atomic[])
score_old = Threads.atomic_max!(min_score_atomic, score)
if score >= score_old
scores[Threads.threadid()] = score
is[Threads.threadid()] = i
chunk_size = max(1, length(_citr) ÷ (2 * Threads.nthreads()))
data_chunks = Iterators.partition(_citr, chunk_size)
chunk_score_tasks = map(data_chunks) do chunk
Threads.@spawn begin
map(chunk) do x
score = compare(_preprocessed_s, _preprocess(dist, x), dist; min_score = min_score)
Threads.atomic_max!(min_score_atomic, score)
score
end
end
end
imax = is[argmax(scores)]
imax == 0 ? (nothing, nothing) : (itr[imax], imax)
# retrieve return type of `compare` for type stability in task
_self_cmp = compare(_preprocessed_s, _preprocessed_s, dist; min_score = min_score)
chunk_scores = fetch.(chunk_score_tasks)::Vector{Vector{typeof(_self_cmp)}}
scores = reduce(vcat, fetch.(chunk_scores))
imax = argmax(scores)
iszero(scores) ? (nothing, nothing) : (_citr[imax], imax)
end
_preprocess(dist::AbstractQGramDistance, ::Missing) = missing
_preprocess(dist::AbstractQGramDistance, s) = QGramSortedVector(s, dist.q)
_preprocess(dist::Union{StringSemiMetric, StringMetric}, s) = s
@ -83,14 +95,25 @@ julia> findall(s, iter, Levenshtein(); min_score = 0.9)
```
"""
function Base.findall(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.8)
out = [Int[] for _ in 1:Threads.nthreads()]
s = _preprocess(dist, s)
# need collect since @threads requires a length method
Threads.@threads for i in collect(eachindex(itr))
score = compare(s, _preprocess(dist, itr[i]), dist; min_score = min_score)
if score >= min_score
push!(out[Threads.threadid()], i)
_citr = collect(itr)
_preprocessed_s = _preprocess(dist, s)
chunk_size = max(1, length(_citr) ÷ (2 * Threads.nthreads()))
data_chunks = Iterators.partition(itr, chunk_size)
isempty(data_chunks) && return empty(eachindex(_citr))
chunk_score_tasks = map(data_chunks) do chunk
Threads.@spawn begin
map(chunk) do x
compare(_preprocessed_s, _preprocess(dist, x), dist; min_score = min_score)
end
end
end
vcat(out...)
# retrieve return type of `compare` for type stability in task
_self_cmp = compare(_preprocessed_s, _preprocessed_s, dist; min_score = min_score)
chunk_scores::Vector{Vector{typeof(_self_cmp)}} = fetch.(chunk_score_tasks)
scores = reduce(vcat, fetch.(chunk_scores))
return findall(>=(min_score), scores)
end