StringDistances.jl/src/find.jl

42 lines
1.9 KiB
Julia
Raw Normal View History

2019-08-20 19:21:31 +02:00
"""
    findmax(s::AbstractString, itr, dist::StringDistance; min_score = 0.0)

Return `(element, key)` for the element of `itr` that has the highest similarity score
with `s` according to the distance `dist`.

Return `(nothing, nothing)` if no element has a similarity score greater than or equal
to `min_score` (defaults to `0.0`). When several elements share the highest score, the
one whose key comes first in `keys(itr)` is returned, as with `Base.findmax`.

`itr` must support `keys` and indexing by those keys (e.g. a vector or a dictionary).

The function is optimized for `Levenshtein` and `DamerauLevenshtein` distances
(potentially modified by `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`): the best
score found so far is passed to `compare` as its `min_score` cutoff.
"""
function Base.findmax(s::AbstractString, itr, dist::StringDistance; min_score = 0.0)
    ks = collect(keys(itr))
    n = length(ks)
    # Split the keys into contiguous chunks. Each chunk owns its own result slot, so no
    # slot is ever shared between tasks (slots indexed by `Threads.threadid()` are not
    # safe when tasks migrate between threads).
    nchunks = max(1, min(n, Threads.nthreads()))
    ranges = [((c - 1) * n ÷ nchunks + 1):(c * n ÷ nchunks) for c in 1:nchunks]
    # Best score seen by any chunk so far. It is stored as Float64 so that an integer
    # `min_score` (e.g. `0`) does not make `atomic_max!` fail on a float score.
    vmax = Threads.Atomic{Float64}(min_score)
    scores = fill(-Inf, nchunks)     # best score recorded by each chunk
    positions = zeros(Int, nchunks)  # position in `ks` of that score; 0 means none
    Threads.@threads for c in 1:nchunks
        for j in ranges[c]
            v = Float64(compare(s, itr[ks[j]], dist; min_score = vmax[]))
            v_old = Threads.atomic_max!(vmax, v)
            # `v >= v_old` means v is at least every score seen so far and at least
            # `min_score`; the strict `>` keeps the first such key within the chunk.
            if v >= v_old && v > scores[c]
                scores[c] = v
                positions[c] = j
            end
        end
    end
    # `argmax` returns the first maximal chunk, i.e. the earliest key among ties.
    best = argmax(scores)
    p = positions[best]
    return p == 0 ? (nothing, nothing) : (itr[ks[p]], ks[p])
end
2019-08-20 21:38:14 +02:00
2019-08-20 19:21:31 +02:00
"""
    findall(s::AbstractString, itr, dist::StringDistance; min_score = 0.8)

Return the vector of keys of `itr` whose elements have a similarity score with `s`
greater than or equal to `min_score` according to the distance `dist`.

Keys are returned in the order of `keys(itr)`. The element type of the result is the
key type of `itr` (e.g. `Int` for a vector, the key type of a dictionary).

The function is optimized for `Levenshtein` and `DamerauLevenshtein` distances
(potentially modified by `Partial`, `TokenSort`, `TokenSet`, or `TokenMax`).
"""
function Base.findall(s::AbstractString, itr, dist::StringDistance; min_score = 0.8)
    ks = collect(keys(itr))
    n = length(ks)
    # Split the keys into contiguous chunks, each with its own output buffer, so no
    # buffer is shared between tasks (no reliance on `Threads.threadid()`).
    nchunks = max(1, min(n, Threads.nthreads()))
    ranges = [((c - 1) * n ÷ nchunks + 1):(c * n ÷ nchunks) for c in 1:nchunks]
    # Buffers typed by the key type, so non-Int keys (e.g. Dict keys) are supported.
    out = [eltype(ks)[] for _ in 1:nchunks]
    Threads.@threads for c in 1:nchunks
        for j in ranges[c]
            k = ks[j]
            if compare(s, itr[k], dist; min_score = min_score) >= min_score
                push!(out[c], k)
            end
        end
    end
    # Chunks are contiguous and concatenated in order, so the result follows `keys(itr)`.
    return reduce(vcat, out)
end