"""
    findmax(s::AbstractString, iter::AbstractVector, dist::StringDistance; min_score = 0.0)

Return `(x, i)`, where `x` is the element of `iter` that has the highest similarity
score with `s` according to the distance `dist`, and `i` is the index of `x` in `iter`.

Return `(nothing, nothing)` if `iter` is empty or if none of the elements has a
similarity score higher than or equal to `min_score` (default `0.0`). If several
elements reach the highest score, the one with the lowest index is preferred.

Throw an `ArgumentError` if `min_score` is negative.

The function is optimized for `Levenshtein` and `DamerauLevenshtein` distances
(potentially modified by `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`).
"""
function Base.findmax(s::AbstractString, iter::AbstractVector, dist::StringDistance; min_score = 0.0)
    min_score >= 0 || throw(ArgumentError("min_score should be positive"))
    # Running maximum of the scores seen so far, shared by every thread. It is handed
    # to `compare` as `min_score` for each new element. Kept as a `Float64` so that
    # an integer `min_score` (e.g. `0`) is accepted as well.
    min_score_atomic = Threads.Atomic{Float64}(Float64(min_score))
    # Best candidate found so far, guarded by `best_lock`. A score of -1.0 means
    # "nothing recorded yet": recorded scores are always >= `min_score` >= 0.
    best_score = Ref(-1.0)
    best_index = Ref(0)
    best_lock = ReentrantLock()
    Threads.@threads for i in firstindex(iter):lastindex(iter)
        score = Float64(compare(s, iter[i], dist; min_score = min_score_atomic[]))
        previous_max = Threads.atomic_max!(min_score_atomic, score)
        # Only candidates that match or beat the running maximum can win, so the
        # lock is taken rarely. No early exit on a perfect score: a `return` inside
        # `Threads.@threads` would only abort the current thread's share of the loop.
        if score >= previous_max
            lock(best_lock) do
                if score > best_score[] || (score == best_score[] && i < best_index[])
                    best_score[] = score
                    best_index[] = i
                end
            end
        end
    end
    return best_score[] < 0 ? (nothing, nothing) : (iter[best_index[]], best_index[])
end
"""
    findall(s::AbstractString, iter::AbstractVector, dist::StringDistance; min_score = 0.8)

Return the vector of indices of the elements of `iter` that have a similarity score
with `s` higher than or equal to `min_score` (default `0.8`) according to the distance
`dist`. Indices are returned in increasing order; the vector is empty if no element
qualifies.

The function is optimized for `Levenshtein` and `DamerauLevenshtein` distances
(potentially modified by `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`).
"""
function Base.findall(s::AbstractString, iter::AbstractVector, dist::StringDistance; min_score = 0.8)
    indices = firstindex(iter):lastindex(iter)
    # One flag per element: every iteration writes only to its own slot, so the
    # threads share no mutable state. A `Vector{Bool}` is used on purpose instead of
    # a `BitVector`, whose packed storage is not safe for concurrent writes.
    keep = fill(false, length(indices))
    Threads.@threads for k in 1:length(indices)
        score = compare(s, iter[indices[k]], dist; min_score = min_score)
        keep[k] = score >= min_score
    end
    # Collect sequentially so the result order does not depend on thread scheduling.
    return Int[indices[k] for k in 1:length(indices) if keep[k]]
end