# StringDistances.jl/src/find.jl

"""
    findmax(s, itr, dist::StringDistance; min_score = 0.0) -> (x, index)

`findmax` returns the value and index of the element of `itr` that has the
highest similarity score with `s` according to the distance `dist`.
It returns `(nothing, nothing)` if no element has a similarity score
greater than or equal to `min_score` (defaults to 0.0).

It is particularly optimized for [`Levenshtein`](@ref) and [`DamerauLevenshtein`](@ref) distances
(as well as their modifications via [`Partial`](@ref), [`TokenSort`](@ref), [`TokenSet`](@ref), or [`TokenMax`](@ref)).

### Examples
```julia-repl
julia> using StringDistances
julia> s = "Newark"
julia> iter = ["New York", "Princeton", "San Francisco"]
julia> findmax(s, iter, Levenshtein())
("New York", 1)
julia> findmax(s, iter, Levenshtein(); min_score = 0.9)
(nothing, nothing)
```
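
The contract above can be illustrated with a plain sequential loop. This is a minimal sketch, not the package's threaded implementation: `findmax_sequential` and `lev_similarity` are hypothetical names, and `lev_similarity` assumes the Levenshtein similarity is normalized as `1 - d / max(length(a), length(b))`, which may not match what `compare` returns in every case.

```julia
# Hypothetical helper: 1 - d / max(length(a), length(b)), where d is the
# Levenshtein distance (assumed normalization, for illustration only).
function lev_similarity(a::AbstractString, b::AbstractString)
    x, y = collect(a), collect(b)
    m, n = length(x), length(y)
    max(m, n) == 0 && return 1.0
    prev = collect(0:n)                  # distances from "" to each prefix of b
    for i in 1:m
        cur = [i; zeros(Int, n)]         # cur[1]: distance from a[1:i] to ""
        for j in 1:n
            cost = x[i] == y[j] ? 0 : 1
            # deletion, insertion, substitution (or match)
            cur[j+1] = min(prev[j+1] + 1, cur[j] + 1, prev[j] + cost)
        end
        prev = cur
    end
    return 1 - prev[n+1] / max(m, n)
end

# Hypothetical sequential version of the `findmax` contract: the running best
# score doubles as the threshold that later elements must beat.
function findmax_sequential(s, itr, score; min_score = 0.0)
    best, imax = float(min_score), 0
    for i in eachindex(itr)
        sc = score(s, itr[i])
        # the first hit only needs to reach min_score; later hits must beat
        # the current best, so ties keep the earliest element
        if imax == 0 ? sc >= best : sc > best
            best, imax = sc, i
        end
    end
    imax == 0 ? (nothing, nothing) : (itr[imax], imax)
end

iter = ["New York", "Princeton", "San Francisco"]
findmax_sequential("Newark", iter, lev_similarity)  # ("New York", 1)
```

Keeping the running best as the threshold mirrors the threaded method, which passes `min_score_atomic[]` to `compare`. Ties here go to the first element; the threaded method does not promise which of several tied elements it returns.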
"""
function Base.findmax(s, itr, dist::StringDistance; min_score = 0.0)
    # best score seen so far by any thread; also passed to `compare` as the cutoff.
    # Float64 so that an integer `min_score` does not break `atomic_max!`.
    min_score_atomic = Threads.Atomic{Float64}(min_score)
    # per-thread best score and index; -Inf marks a slot that never recorded a match,
    # so `argmax` cannot pick an empty slot when the best score is 0.0
    scores = fill(-Inf, Threads.nthreads())
    is = zeros(Int, Threads.nthreads())
    # need collect since @threads requires a length method
    Threads.@threads for i in collect(eachindex(itr))
        score = compare(s, itr[i], dist; min_score = min_score_atomic[])
        score_old = Threads.atomic_max!(min_score_atomic, score)
        if score >= score_old
            scores[Threads.threadid()] = score
            is[Threads.threadid()] = i
        end
    end
    imax = is[argmax(scores)]
    imax == 0 ? (nothing, nothing) : (itr[imax], imax)
end

"""
    findall(s, itr, dist::StringDistance; min_score = 0.8)

`findall` returns the vector of indices of the elements of `itr` whose
similarity score with `s` according to the distance `dist` is greater than or
equal to `min_score` (defaults to 0.8).
If there are no such elements, it returns an empty array.

It is particularly optimized for [`Levenshtein`](@ref) and [`DamerauLevenshtein`](@ref) distances
(as well as their modifications via [`Partial`](@ref), [`TokenSort`](@ref), [`TokenSet`](@ref), or [`TokenMax`](@ref)).

### Examples
```julia-repl
julia> using StringDistances
julia> s = "Newark"
julia> iter = ["Newwark", "Princeton", "San Francisco"]
julia> findall(s, iter, Levenshtein())
1-element Array{Int64,1}:
 1
julia> findall(s, iter, Levenshtein(); min_score = 0.9)
0-element Array{Int64,1}
```
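
A threaded `findall` can also be written without buffers indexed by `threadid()`, by giving each spawned task its own chunk of indices. The sketch below is an illustration under assumptions, not the package's code: `threaded_findall` and `jaccard` are hypothetical names, and the character-set Jaccard index merely stands in for `compare` so the example runs without StringDistances.

```julia
using Base.Threads: @spawn

# Hypothetical stand-in score: Jaccard index of the two character sets.
function jaccard(a, b)
    A, B = Set(a), Set(b)
    u = length(union(A, B))
    return u == 0 ? 1.0 : length(intersect(A, B)) / u
end

# Hypothetical threaded findall: each task scans one contiguous chunk of
# indices and returns its own vector, so no shared buffer is mutated.
function threaded_findall(s, itr, score; min_score = 0.8)
    idx = collect(eachindex(itr))        # indexable list of indices to split
    isempty(idx) && return idx
    chunksize = cld(length(idx), Threads.nthreads())
    tasks = map(Iterators.partition(idx, chunksize)) do chunk
        @spawn [i for i in chunk if score(s, itr[i]) >= min_score]
    end
    # tasks are fetched in chunk order, so the indices come out ascending
    return reduce(vcat, fetch.(tasks))
end

threaded_findall("Newark", ["Newwark", "Princeton", "San Francisco"], jaccard)  # [1]
```

Because results are gathered per chunk rather than per thread, concatenating them in chunk order already yields ascending indices, whatever thread ran each task.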
"""
function Base.findall(s, itr, dist::StringDistance; min_score = 0.8)
    # one index buffer per thread, so threads never push! to the same vector
    out = [Int[] for _ in 1:Threads.nthreads()]
    # need collect since @threads requires a length method
    Threads.@threads for i in collect(eachindex(itr))
        score = compare(s, itr[i], dist; min_score = min_score)
        if score >= min_score
            push!(out[Threads.threadid()], i)
        end
    end
    # which thread handles which chunk is not guaranteed to follow index order,
    # so sort to return the indices in ascending order
    sort!(reduce(vcat, out))
end