From f270e18310bac2cfa72487d1113bae22408660d2 Mon Sep 17 00:00:00 2001 From: matthieugomez Date: Wed, 30 Dec 2020 14:39:18 +0100 Subject: [PATCH] simplify pairwise --- src/pairwise.jl | 58 +++++++++++++++++-------------------------------- 1 file changed, 20 insertions(+), 38 deletions(-) diff --git a/src/pairwise.jl b/src/pairwise.jl index 3fd6aa2..ad677b1 100644 --- a/src/pairwise.jl +++ b/src/pairwise.jl @@ -1,9 +1,8 @@ @doc """ - pairwise(dist::StringDistance, xs::AbstractVector; preprocess = nothing) - pairwise(dist::StringDistance, xs::AbstractVector, ys::AbstractVector; preprocess = nothing) + pairwise(dist::StringDistance, xs::AbstractVector, ys::AbstractVector = xs; preprocess = nothing) Compute distances between all pairs of elements in `xs` and `ys` according to the -`StringDistance` `dist`. +`StringDistance` `dist`. Returns a matrix R such that `R[i, j]` corrresponds to the distance between `xs[i]` and `ys[j]`. For AbstractQGramDistances preprocessing will be used either if `preprocess` is set to true or if there are more than 5 elements in `xs`. Set `preprocess` to @@ -15,12 +14,12 @@ Both symmetric and asymmetric versions are available. ```julia-repl julia> using StringDistances julia> iter = ["New York", "Princeton"] -julia> pairwise(Levenshtein(), iter) # symmetric +julia> pairwise(Levenshtein(), iter) 2×2 Array{Float64,2}: 0.0 9.0 9.0 0.0 julia> iter2 = ["San Francisco"] -julia> pairwise(Levenshtein(), iter, iter2) # asymmetric +julia> pairwise(Levenshtein(), iter, iter2) 2×1 Array{Float64,2}: 12.0 10.0 @@ -28,7 +27,7 @@ julia> pairwise(Levenshtein(), iter, iter2) # asymmetric """ Distances.pairwise -function Distances.pairwise(dist::StringDistance, xs::AbstractVector, ys::AbstractVector; preprocess = nothing) +function Distances.pairwise(dist::StringDistance, xs::AbstractVector, ys::AbstractVector = xs; preprocess = nothing) T = result_type(dist, eltype(xs), eltype(ys)) if Missing <: Union{eltype(xs), eltype(ys)} T = Union{T, Missing} @@ -37,21 +36,11 @@ function Distances.pairwise(dist::StringDistance, xs::AbstractVector, ys::Abstra pairwise!(R, dist, xs, ys; preprocess = preprocess) end -function Distances.pairwise(dist::StringDistance, xs::AbstractVector; preprocess = nothing) - T = result_type(dist, eltype(xs), eltype(xs)) - if Missing <: eltype(xs) - T = Union{T, Missing} - end - R = Matrix{T}(undef, length(xs), length(xs)) - pairwise!(R, dist, xs; preprocess = preprocess) -end - @doc """ - pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector; preprocess = nothing) - pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector, ys::AbstractVector; preprocess = nothing) + pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector, ys::AbstractVector = xs; preprocess = nothing) Compute distances between all pairs of elements in `xs` and `ys` according to the -`StringDistance` `dist` and write the result in `R`. +`StringDistance` `dist` and write the result in `R`. `R[i, j]` corrresponds to the distance between `xs[i]` and `ys[j]`. For AbstractQGramDistances preprocessing will be used either if `preprocess` is set to true or if there are more than 5 elements in `xs`. Set `preprocess` to @@ -59,30 +48,14 @@ false if no preprocessing should be used, regardless of length. """ Distances.pairwise! -function Distances.pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector, ys::AbstractVector; preprocess = nothing) +function Distances.pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector, ys::AbstractVector = xs; preprocess = nothing) length(xs) == size(R, 1) || throw(DimensionMismatch("inconsistent length")) length(ys) == size(R, 2) || throw(DimensionMismatch("inconsistent length")) - _asymmetric_pairwise!(R, dist, xs, ys; preprocess = preprocess) -end - -function Distances.pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector; preprocess = nothing) - length(xs) == size(R, 1) || throw(DimensionMismatch("inconsistent length")) - length(xs) == size(R, 2) || throw(DimensionMismatch("inconsistent length")) - (dist isa SemiMetric) ? + ((xs === ys) & (dist isa SemiMetric)) ? _symmetric_pairwise!(R, dist, xs; preprocess = preprocess) : - _asymmetric_pairwise!(R, dist, xs, xs; preprocess = preprocess) + _asymmetric_pairwise!(R, dist, xs, ys; preprocess = preprocess) end -function _preprocess(xs, dist::AbstractQGramDistance, preprocess) - if preprocess === nothing ? length(xs) >= 5 : preprocess - return map(x -> x === missing ? x : QGramSortedVector(x, dist.q), xs) - else - return xs - end -end -_preprocess(xs, dist::StringDistance, preprocess) = xs - - function _symmetric_pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector; preprocess = nothing) objs = _preprocess(xs, dist, preprocess) for i in 1:length(objs) @@ -97,7 +70,7 @@ end function _asymmetric_pairwise!(R::AbstractMatrix, dist::StringDistance, xs::AbstractVector, ys::AbstractVector; preprocess = nothing) objsxs = _preprocess(xs, dist, preprocess) - objsys = _preprocess(ys, dist, preprocess) + objsys = xs === ys ? objsxs : _preprocess(ys, dist, preprocess) for i in 1:length(objsxs) Threads.@threads for j in 1:length(objsys) R[i, j] = evaluate(dist, objsxs[i], objsys[j]) @@ -105,3 +78,12 @@ function _asymmetric_pairwise!(R::AbstractMatrix, dist::StringDistance, xs::Abst end return R end + +function _preprocess(xs, dist::AbstractQGramDistance, preprocess) + if preprocess === nothing ? length(xs) >= 5 : preprocess + return map(x -> x === missing ? x : QGramSortedVector(x, dist.q), xs) + else + return xs + end +end +_preprocess(xs, dist::StringDistance, preprocess) = xs