export pairwise!

pull/39/head
matthieugomez 2020-11-08 17:05:14 -08:00
parent a0e5347d8c
commit 3040be7ced
3 changed files with 48 additions and 41 deletions

View File

@ -9,7 +9,9 @@ include("normalize.jl")
const StringDistance = Union{Jaro, Levenshtein, DamerauLevenshtein, RatcliffObershelp, QGramDistance, Winkler, Partial, TokenSort, TokenSet, TokenMax, Normalize}
# Distances API
# Result type of `dist`, obtained as the type of the value returned by `dist("", "")`.
# NOTE(review): this untyped `(s1, s2)` method has the same signature as the definition
# two lines below — confirm which of the two is meant to remain.
Distances.result_type(dist::StringDistance, s1, s2) = typeof(dist("", ""))
# Same probe on two empty strings, for arguments that are passed as types.
Distances.result_type(dist::StringDistance, s1::Type, s2::Type) = typeof(dist("", ""))
# For other arguments, call `result_type` again with the element types of `s1` and `s2`
# (presumably reaching the `Type` method above — verify `result_type` is Distances').
Distances.result_type(dist::StringDistance, s1, s2) = result_type(dist, eltype(s1), eltype(s2))
include("find.jl")
include("pairwise.jl")
@ -43,6 +45,7 @@ result_type,
qgrams,
normalize,
findnearest,
pairwise
pairwise,
pairwise!
end

View File

@ -1,13 +1,9 @@
# Uninitialised matrix with element type `T`, one row per element of `X` and one
# column per element of `Y`.
function _allocmatrix(X, Y, T)
    return Array{T, 2}(undef, (length(X), length(Y)))
end

# Square variant: both dimensions are `length(X)`.
function _allocmatrix(X, T)
    n = length(X)
    return Array{T, 2}(undef, (n, n))
end
@doc """
pairwise(dist::StringDistance, itr; eltype = Float64, preprocess = nothing)
pairwise(dist::StringDistance, itr1, itr2; eltype = Float64, preprocess = nothing)
pairwise(dist::StringDistance, itr; preprocess = nothing)
pairwise(dist::StringDistance, itr1, itr2; preprocess = nothing)
Compute distances between all pairs of elements in `itr` according to the
`StringDistance` `dist`. The element type of the returned distance matrix
can be set via `eltype`.
`StringDistance` `dist`.
For QGramDistances preprocessing will be used either if `preprocess` is set
to true or if there are at least 5 elements in `itr`. Set `preprocess` to
@ -32,39 +28,52 @@ julia> pairwise(Levenshtein(), iter, iter2) # asymmetric
"""
Distances.pairwise
# Allocate the output with `_allocmatrix` using element type `eltype`, then fill it
# in place with `pairwise!`.
function Distances.pairwise(
    dist::StringDistance, X, Y; eltype = Float64, preprocess = nothing
)
    out = _allocmatrix(X, Y, eltype)
    return pairwise!(out, dist, X, Y; preprocess = preprocess)
end
function Distances.pairwise(dist::StringDistance, X, Y; preprocess = nothing)
    # The output element type is whatever `result_type` reports for the two inputs.
    elt = result_type(dist, eltype(X), eltype(Y))
    out = Array{elt, 2}(undef, length(X), length(Y))
    return pairwise!(out, dist, X, Y; preprocess = preprocess)
end
# Single-collection form: allocate a square output with `_allocmatrix` using element
# type `eltype`, then fill it in place with `pairwise!`.
function Distances.pairwise(
    dist::StringDistance, X; eltype = Float64, preprocess = nothing
)
    out = _allocmatrix(X, eltype)
    return pairwise!(out, dist, X; preprocess = preprocess)
end
function Distances.pairwise(dist::StringDistance, X; preprocess = nothing)
    # The output is square, with the element type `result_type` reports for `X` vs `X`.
    elt = result_type(dist, eltype(X), eltype(X))
    n = length(X)
    out = Array{elt, 2}(undef, n, n)
    return pairwise!(out, dist, X; preprocess = preprocess)
end
pairwise!(R::AbstractMatrix{N}, dist::StringDistance, X; preprocess = nothing) where {N<:Number} =
@doc """
pairwise!(r::AbstractMatrix, dist::StringDistance, itr; preprocess = nothing)
pairwise!(r::AbstractMatrix, dist::StringDistance, itr1, itr2; preprocess = nothing)
Compute distances between all pairs of elements in `itr` according to the
`StringDistance` `dist` and write the result in `r`.
For QGramDistances preprocessing will be used either if `preprocess` is set
to true or if there are at least 5 elements in `itr`. Set `preprocess` to
false if no preprocessing should be used, regardless of length.
"""
Distances.pairwise!
# Two-collection form: delegates directly to `_asymmetric_pairwise!`.
function Distances.pairwise!(
    R::AbstractMatrix{<:Number}, dist::StringDistance, X, Y; preprocess = nothing
)
    return _asymmetric_pairwise!(R, dist, X, Y; preprocess = preprocess)
end
function Distances.pairwise!(R::AbstractMatrix{<:Number}, dist::StringDistance, X; preprocess = nothing)
(dist isa SemiMetric) ?
_symmetric_pairwise!(R, dist, X; preprocess = preprocess) :
_asymmetric_pairwise!(R, dist, X, X; preprocess = preprocess)
pairwise!(R::AbstractMatrix{N}, dist::StringDistance, X, Y; preprocess = nothing) where {N<:Number} =
_asymmetric_pairwise!(R, dist, X, Y; preprocess = preprocess)
# Build an array with element type `PT` by calling `PT(x, q)` on each element `x` of `X`.
# Iterates `X` directly instead of indexing over `1:length(X)`, so it also accepts
# iterables that have no 1-based `getindex` (generators, sets, offset-indexed arrays).
_preprocess(X, PT, q) = PT[PT(x, q) for x in X]
const PrecalcMinLength = 5 # Only precalc if length >= 5
# Generic method: no preprocessing, `X` is returned unchanged.
function preprocess_if_needed(X, dist::StringDistance, preprocess, preprocessType)
    return X
end

# `QGramDistance` method: convert `X` with `_preprocess` when `preprocess === true`, or
# when `preprocess` is `nothing` and `X` has at least `PrecalcMinLength` elements;
# otherwise return `X` unchanged.
function preprocess_if_needed(X, dist::QGramDistance, preprocess, preprocessType)
    forced = preprocess === true
    automatic = preprocess === nothing && length(X) >= PrecalcMinLength
    if forced || automatic
        return _preprocess(X, preprocessType, dist.q)
    end
    return X
end
function _symmetric_pairwise!(R, dist::StringDistance, X;
preprocess = nothing, preprocessType = QGramSortedVector)
# `QGramDistance` method: replace each element of `X` by a `QGramSortedVector` built
# with `dist.q` when `preprocess === true`, or when `preprocess` is `nothing` and `X`
# has at least 5 elements; otherwise return `X` itself.
function _preprocess(X, dist::QGramDistance, preprocess)
    wanted = preprocess === true || (preprocess === nothing && length(X) >= 5)
    if !wanted
        return X
    end
    return [QGramSortedVector(X[k], dist.q) for k in 1:length(X)]
end

# Generic method: no preprocessing, the collection is returned as is.
function _preprocess(X, dist::StringDistance, preprocess)
    return X
end
objs = preprocess_if_needed(X, dist, preprocess, preprocessType)
function _symmetric_pairwise!(R, dist::StringDistance, X; preprocess = nothing)
objs = _preprocess(X, dist, preprocess)
for i in 1:length(objs)
R[i, i] = 0
Threads.@threads for j in (i+1):length(objs)
@ -76,14 +85,12 @@ end
# Fill `R[i, j]` with `evaluate(dist, ·, ·)` applied to the i-th (possibly preprocessed)
# element of `X` and the j-th (possibly preprocessed) element of `Y`, then return `R`.
function _asymmetric_pairwise!(R, dist::StringDistance, X, Y;
preprocess = nothing, preprocessType = QGramSortedVector)
# NOTE(review): `objsX`/`objsY` are each assigned twice in a row; the values from
# `preprocess_if_needed` are overwritten by the `_preprocess` results before they are
# ever read — confirm which pair of assignments is meant to stay.
objsX = preprocess_if_needed(X, dist, preprocess, preprocessType)
objsY = preprocess_if_needed(Y, dist, preprocess, preprocessType)
objsX = _preprocess(X, dist, preprocess)
objsY = _preprocess(Y, dist, preprocess)
for i in 1:length(objsX)
# Only the inner loop over `j` is threaded; each iteration writes its own `R[i, j]`.
Threads.@threads for j in 1:length(objsY)
R[i, j] = evaluate(dist, objsX[i], objsY[j])
end
end
return R
end
end

View File

@ -13,7 +13,6 @@ TestStrings2 = ["mew", "ab"]
d = (DT <: QGramDistance) ? DT(2) : DT()
R = pairwise(d, TestStrings1)
@test R isa Matrix{Float64}
@test size(R) == (4, 4)
# No distance on the diagonal, since comparing strings to themselves
@ -46,7 +45,6 @@ TestStrings2 = ["mew", "ab"]
# Test also the asymmetric version
R2 = pairwise(d, TestStrings1, TestStrings2)
@test R2 isa Matrix{Float64}
@test size(R2) == (4, 2)
@test equalorNaN(R2[1, 1], evaluate(d, "", "mew"))
@ -62,7 +60,6 @@ TestStrings2 = ["mew", "ab"]
@test equalorNaN(R2[4, 2], evaluate(d, "kitten", "ab"))
R3 = pairwise(d, TestStrings2, TestStrings1)
@test R3 isa Matrix{Float64}
@test size(R3) == (2, 4)
for i in 1:length(TestStrings1)