rmv NormalizedStringMetric

pull/57/head
matthieugomez 2021-09-13 11:43:23 -04:00
parent 87e9817ddf
commit 04a3fc587e
2 changed files with 8 additions and 9 deletions

View File

@ -3,9 +3,7 @@ module StringDistances
using Distances
import StatsAPI: pairwise, pairwise!
# Abstract roots of the string-distance type hierarchy. `SemiMetric` and `Metric`
# are presumably the abstract types brought in by `using Distances` — confirm.
abstract type StringSemiMetric <: SemiMetric end
# NOTE(review): this type and its fallback method below look like the removed side
# of this diff hunk (the hunk shrinks from 9 to 7 lines) — confirm before relying on them.
abstract type NormalizedStringSemiMetric <: StringSemiMetric end
abstract type StringMetric <: Metric end
# Keyword fallbacks: accept a `max_dist` keyword and forward to the two-argument
# call, dropping the keyword. The two methods differ only in the keyword default
# (1.0 vs `nothing`).
(dist::NormalizedStringSemiMetric)(s1, s2; max_dist = 1.0) = dist(s1, s2)
(dist::Union{StringSemiMetric, StringMetric})(s1, s2; max_dist = nothing) = dist(s1, s2)
function Distances.result_type(dist::Union{StringSemiMetric, StringMetric}, s1::Type, s2::Type)

View File

@ -15,24 +15,25 @@ julia> StringDistances.Normalized(Levenshtein())(s1, s2)
0.8064
```
"""
struct Normalized{T <: Union{StringSemiMetric, StringMetric}} <: NormalizedStringSemiMetric
struct Normalized{T <: Union{StringSemiMetric, StringMetric}} <: StringSemiMetric
# NOTE(review): two `struct` headers share a single `end`; they look like the
# before/after sides of one diff line (only the supertype differs) — confirm and
# keep exactly one header in real source.
# The wrapped distance; the type parameter `T` records its type.
dist::T
end
# Outer constructor: wrap `dist`, setting the type parameter to `typeof(dist)`.
Normalized(dist::Union{StringSemiMetric, StringMetric}) = Normalized{typeof(dist)}(dist)
# Wrapping an already-`Normalized` distance returns it unchanged (no double wrapping).
Normalized(dist::Normalized) = dist
# Generic fallback for `Normalized`: the wrapped distance is treated as already
# normalized, so it is called directly and `max_dist` is forwarded to it unchanged.
function (dist::Normalized)(s1, s2; max_dist = 1.0)
# NOTE(review): the result of this first call is discarded; it looks like the
# removed side of a diff line (the next line is its replacement) — confirm.
dist.dist(s1, s2; max_dist = max_dist)
out = dist.dist(s1, s2; max_dist = max_dist)
# If a threshold is set and the value exceeds it, return 1.0 instead of the value.
max_dist !== nothing && out > max_dist && return 1.0
return out
end
# Normalized call for wrapped `Hamming` / `DamerauLevenshtein` distances: the raw
# distance is divided by the length of `s2` (after `reorder`).
function (dist::Normalized{<:Union{Hamming, DamerauLevenshtein}})(s1, s2; max_dist = 1.0)
# A `missing` argument on either side yields `missing`.
(s1 === missing) | (s2 === missing) && return missing
# Two empty inputs return 0.0 before any division by a length takes place.
isempty(s1) && isempty(s2) && return 0.0
# NOTE(review): this `out` is overwritten three lines later without being read;
# it looks like the removed side of a diff line — confirm.
out = dist.dist(s1, s2) / length(s2)
# `reorder` is defined outside this view; presumably it puts the longer input
# in `s2` — TODO confirm against its definition.
s1, s2 = reorder(s1, s2)
len1, len2 = length(s1), length(s2)
out = dist.dist(s1, s2) / len2
# If a threshold is set and the value exceeds it, return 1.0 instead of the value.
max_dist !== nothing && out > max_dist && return 1.0
return out
end
@ -42,7 +43,7 @@ function (dist::Normalized{<:Union{Levenshtein, OptimalStringAlignement}})(s1, s
isempty(s1) && isempty(s2) && return 0.0
s1, s2 = reorder(s1, s2)
len1, len2 = length(s1), length(s2)
if max_dist === nothing || max_dist == 1.0
if max_dist == 1.0
d = dist.dist(s1, s2)
else
d = dist.dist(s1, s2; max_dist = ceil(Int, len2 * max_dist))