fix thread unsafety (#63)
parent
b384824741
commit
b7dd39f5e7
|
@ -8,8 +8,8 @@ StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0"
|
|||
|
||||
[compat]
|
||||
Distances = "0.8.1, 0.9, 0.10"
|
||||
julia = "1.3"
|
||||
StatsAPI = "1"
|
||||
julia = "1.3"
|
||||
|
||||
[extras]
|
||||
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
|
||||
|
|
|
@ -2,6 +2,7 @@ module StringDistances
|
|||
|
||||
using Distances: Distances, SemiMetric, Metric, evaluate, result_type
|
||||
using StatsAPI: StatsAPI, pairwise, pairwise!
|
||||
|
||||
# Distances API
|
||||
abstract type StringSemiMetric <: SemiMetric end
|
||||
abstract type StringMetric <: Metric end
|
||||
|
|
65
src/find.jl
65
src/find.jl
|
@ -12,7 +12,7 @@ julia> compare("martha", "marhta", Levenshtein())
|
|||
"""
|
||||
function compare(s1, s2, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
|
||||
1 - Normalized(dist)(s1, s2; max_dist = 1 - min_score)
|
||||
end
|
||||
end
|
||||
|
||||
"""
|
||||
findnearest(s, itr, dist::Union{StringMetric, StringSemiMetric}) -> (x, index)
|
||||
|
@ -35,22 +35,34 @@ julia> findnearest(s, iter, Levenshtein(); min_score = 0.9)
|
|||
```
|
||||
"""
|
||||
function findnearest(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.0)
|
||||
_citr = collect(itr)
|
||||
isempty(_citr) && return (nothing, nothing)
|
||||
|
||||
_preprocessed_s = _preprocess(dist, s)
|
||||
min_score_atomic = Threads.Atomic{Float64}(min_score)
|
||||
scores = [0.0 for _ in 1:Threads.nthreads()]
|
||||
is = [0 for _ in 1:Threads.nthreads()]
|
||||
s = _preprocess(dist, s)
|
||||
# need collect since @threads requires a length method
|
||||
Threads.@threads for i in collect(eachindex(itr))
|
||||
score = compare(s, _preprocess(dist, itr[i]), dist; min_score = min_score_atomic[])
|
||||
score_old = Threads.atomic_max!(min_score_atomic, score)
|
||||
if score >= score_old
|
||||
scores[Threads.threadid()] = score
|
||||
is[Threads.threadid()] = i
|
||||
|
||||
chunk_size = max(1, length(_citr) ÷ (2 * Threads.nthreads()))
|
||||
data_chunks = Iterators.partition(_citr, chunk_size)
|
||||
|
||||
chunk_score_tasks = map(data_chunks) do chunk
|
||||
Threads.@spawn begin
|
||||
map(chunk) do x
|
||||
score = compare(_preprocessed_s, _preprocess(dist, x), dist; min_score = min_score)
|
||||
Threads.atomic_max!(min_score_atomic, score)
|
||||
score
|
||||
end
|
||||
end
|
||||
end
|
||||
imax = is[argmax(scores)]
|
||||
imax == 0 ? (nothing, nothing) : (itr[imax], imax)
|
||||
|
||||
# retrieve return type of `compare` for type stability in task
|
||||
_self_cmp = compare(_preprocessed_s, _preprocessed_s, dist; min_score = min_score)
|
||||
chunk_scores = fetch.(chunk_score_tasks)::Vector{Vector{typeof(_self_cmp)}}
|
||||
scores = reduce(vcat, fetch.(chunk_scores))
|
||||
|
||||
imax = argmax(scores)
|
||||
iszero(scores) ? (nothing, nothing) : (_citr[imax], imax)
|
||||
end
|
||||
|
||||
_preprocess(dist::AbstractQGramDistance, ::Missing) = missing
|
||||
_preprocess(dist::AbstractQGramDistance, s) = QGramSortedVector(s, dist.q)
|
||||
_preprocess(dist::Union{StringSemiMetric, StringMetric}, s) = s
|
||||
|
@ -83,14 +95,25 @@ julia> findall(s, iter, Levenshtein(); min_score = 0.9)
|
|||
```
|
||||
"""
|
||||
function Base.findall(s, itr, dist::Union{StringSemiMetric, StringMetric}; min_score = 0.8)
|
||||
out = [Int[] for _ in 1:Threads.nthreads()]
|
||||
s = _preprocess(dist, s)
|
||||
# need collect since @threads requires a length method
|
||||
Threads.@threads for i in collect(eachindex(itr))
|
||||
score = compare(s, _preprocess(dist, itr[i]), dist; min_score = min_score)
|
||||
if score >= min_score
|
||||
push!(out[Threads.threadid()], i)
|
||||
_citr = collect(itr)
|
||||
_preprocessed_s = _preprocess(dist, s)
|
||||
|
||||
chunk_size = max(1, length(_citr) ÷ (2 * Threads.nthreads()))
|
||||
data_chunks = Iterators.partition(itr, chunk_size)
|
||||
isempty(data_chunks) && return empty(eachindex(_citr))
|
||||
|
||||
chunk_score_tasks = map(data_chunks) do chunk
|
||||
Threads.@spawn begin
|
||||
map(chunk) do x
|
||||
compare(_preprocessed_s, _preprocess(dist, x), dist; min_score = min_score)
|
||||
end
|
||||
end
|
||||
end
|
||||
vcat(out...)
|
||||
|
||||
# retrieve return type of `compare` for type stability in task
|
||||
_self_cmp = compare(_preprocessed_s, _preprocessed_s, dist; min_score = min_score)
|
||||
chunk_scores::Vector{Vector{typeof(_self_cmp)}} = fetch.(chunk_score_tasks)
|
||||
|
||||
scores = reduce(vcat, fetch.(chunk_scores))
|
||||
return findall(>=(min_score), scores)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue