pairwise for calculating distance matrices (#38)
parent
c7728160bf
commit
a0e5347d8c
|
@ -11,7 +11,7 @@ const StringDistance = Union{Jaro, Levenshtein, DamerauLevenshtein, RatcliffOber
|
|||
# Distances API
|
||||
Distances.result_type(dist::StringDistance, s1, s2) = typeof(dist("", ""))
|
||||
include("find.jl")
|
||||
|
||||
include("pairwise.jl")
|
||||
|
||||
##############################################################################
|
||||
##
|
||||
|
@ -42,6 +42,7 @@ compare,
|
|||
result_type,
|
||||
qgrams,
|
||||
normalize,
|
||||
findnearest
|
||||
findnearest,
|
||||
pairwise
|
||||
end
|
||||
|
||||
|
|
|
@ -0,0 +1,89 @@
|
|||
_allocmatrix(X, Y, T) = Matrix{T}(undef, length(X), length(Y))
|
||||
_allocmatrix(X, T) = Matrix{T}(undef, length(X), length(X))
|
||||
|
||||
@doc """
|
||||
pairwise(dist::StringDistance, itr; eltype = Float64, preprocess = nothing)
|
||||
pairwise(dist::StringDistance, itr1, itr2; eltype = Float64, preprocess = nothing)
|
||||
|
||||
Compute distances between all pairs of elements in `itr` according to the
|
||||
`StringDistance` `dist`. The element type of the returned distance matrix
|
||||
can be set via `eltype`.
|
||||
|
||||
For QGramDistances preprocessing will be used either if `preprocess` is set
|
||||
to true or if there are more than 5 elements in `itr`. Set `preprocess` to
|
||||
false if no preprocessing should be used, regardless of length.
|
||||
|
||||
Both symmetric and asymmetric versions are available.
|
||||
|
||||
### Examples
|
||||
```julia-repl
|
||||
julia> using StringDistances
|
||||
julia> iter = ["New York", "Princeton"]
|
||||
julia> pairwise(Levenshtein(), iter) # symmetric
|
||||
2×2 Array{Float64,2}:
|
||||
0.0 9.0
|
||||
9.0 0.0
|
||||
julia> iter2 = ["San Francisco"]
|
||||
julia> pairwise(Levenshtein(), iter, iter2) # asymmetric
|
||||
2×1 Array{Float64,2}:
|
||||
12.0
|
||||
10.0
|
||||
```
|
||||
"""
|
||||
Distances.pairwise
|
||||
|
||||
Distances.pairwise(dist::StringDistance, X, Y; eltype = Float64, preprocess = nothing) =
|
||||
pairwise!(_allocmatrix(X, Y, eltype), dist, X, Y; preprocess = preprocess)
|
||||
|
||||
Distances.pairwise(dist::StringDistance, X; eltype = Float64, preprocess = nothing) =
|
||||
pairwise!(_allocmatrix(X, eltype), dist, X; preprocess = preprocess)
|
||||
|
||||
pairwise!(R::AbstractMatrix{N}, dist::StringDistance, X; preprocess = nothing) where {N<:Number} =
|
||||
(dist isa SemiMetric) ?
|
||||
_symmetric_pairwise!(R, dist, X; preprocess = preprocess) :
|
||||
_asymmetric_pairwise!(R, dist, X, X; preprocess = preprocess)
|
||||
|
||||
pairwise!(R::AbstractMatrix{N}, dist::StringDistance, X, Y; preprocess = nothing) where {N<:Number} =
|
||||
_asymmetric_pairwise!(R, dist, X, Y; preprocess = preprocess)
|
||||
|
||||
_preprocess(X, PT, q) = PT[PT(X[i], q) for i in 1:length(X)]
|
||||
|
||||
const PrecalcMinLength = 5 # Only precalc if length >= 5
|
||||
|
||||
preprocess_if_needed(X, dist::StringDistance, preprocess, preprocessType) = X
|
||||
|
||||
function preprocess_if_needed(X, dist::QGramDistance, preprocess, preprocessType)
|
||||
# preprocess only if a QGramDistance and
|
||||
# if precalc set to true or if isnothing and length is at least min length
|
||||
cond = (preprocess === true) ||
|
||||
(isnothing(preprocess) && length(X) >= PrecalcMinLength)
|
||||
cond ? _preprocess(X, preprocessType, dist.q) : X
|
||||
end
|
||||
|
||||
function _symmetric_pairwise!(R, dist::StringDistance, X;
|
||||
preprocess = nothing, preprocessType = QGramSortedVector)
|
||||
|
||||
objs = preprocess_if_needed(X, dist, preprocess, preprocessType)
|
||||
|
||||
for i in 1:length(objs)
|
||||
R[i, i] = 0
|
||||
Threads.@threads for j in (i+1):length(objs)
|
||||
R[i, j] = R[j, i] = evaluate(dist, objs[i], objs[j])
|
||||
end
|
||||
end
|
||||
return R
|
||||
end
|
||||
|
||||
function _asymmetric_pairwise!(R, dist::StringDistance, X, Y;
|
||||
preprocess = nothing, preprocessType = QGramSortedVector)
|
||||
|
||||
objsX = preprocess_if_needed(X, dist, preprocess, preprocessType)
|
||||
objsY = preprocess_if_needed(Y, dist, preprocess, preprocessType)
|
||||
|
||||
for i in 1:length(objsX)
|
||||
Threads.@threads for j in 1:length(objsY)
|
||||
R[i, j] = evaluate(dist, objsX[i], objsY[j])
|
||||
end
|
||||
end
|
||||
return R
|
||||
end
|
|
@ -0,0 +1,88 @@
|
|||
using StringDistances, Unicode, Test, Random
|
||||
using StringDistances: pairwise, pairwise!, QGramDistance
|
||||
|
||||
@testset "pairwise" begin
|
||||
|
||||
TestStrings1 = ["", "abc", "bc", "kitten"]
|
||||
TestStrings2 = ["mew", "ab"]
|
||||
|
||||
@testset "pairwise" begin
|
||||
for DT in [Jaro, Levenshtein, DamerauLevenshtein, RatcliffObershelp,
|
||||
QGram, Cosine, Jaccard, SorensenDice, Overlap]
|
||||
|
||||
d = (DT <: QGramDistance) ? DT(2) : DT()
|
||||
R = pairwise(d, TestStrings1)
|
||||
|
||||
@test R isa Matrix{Float64}
|
||||
@test size(R) == (4, 4)
|
||||
|
||||
# No distance on the diagonal, since comparing strings to themselves
|
||||
@test R[1, 1] == 0.0
|
||||
@test R[2, 2] == 0.0
|
||||
@test R[3, 3] == 0.0
|
||||
@test R[4, 4] == 0.0
|
||||
|
||||
# Since the distance might be NaN:
|
||||
equalorNaN(x, y) = (x == y) || (isnan(x) && isnan(y))
|
||||
|
||||
# First row is comparing "" to the other strings, so:
|
||||
@test equalorNaN(R[1, 2], evaluate(d, "", "abc"))
|
||||
@test equalorNaN(R[1, 3], evaluate(d, "", "bc"))
|
||||
@test equalorNaN(R[1, 4], evaluate(d, "", "kitten"))
|
||||
|
||||
# Second row is comparing "abc" to the other strings, so:
|
||||
@test equalorNaN(R[2, 3], evaluate(d, "abc", "bc"))
|
||||
@test equalorNaN(R[2, 4], evaluate(d, "abc", "kitten"))
|
||||
|
||||
# Third row row is comparing "bc" to the other strings, so:
|
||||
@test equalorNaN(R[3, 4], evaluate(d, "bc", "kitten"))
|
||||
|
||||
# Matrix is symmetric
|
||||
for i in 1:4
|
||||
for j in (i+1):4
|
||||
@test equalorNaN(R[i, j], R[j, i])
|
||||
end
|
||||
end
|
||||
|
||||
# Test also the assymetric version
|
||||
R2 = pairwise(d, TestStrings1, TestStrings2)
|
||||
@test R2 isa Matrix{Float64}
|
||||
@test size(R2) == (4, 2)
|
||||
|
||||
@test equalorNaN(R2[1, 1], evaluate(d, "", "mew"))
|
||||
@test equalorNaN(R2[1, 2], evaluate(d, "", "ab"))
|
||||
|
||||
@test equalorNaN(R2[2, 1], evaluate(d, "abc", "mew"))
|
||||
@test equalorNaN(R2[2, 2], evaluate(d, "abc", "ab"))
|
||||
|
||||
@test equalorNaN(R2[3, 1], evaluate(d, "bc", "mew"))
|
||||
@test equalorNaN(R2[3, 2], evaluate(d, "bc", "ab"))
|
||||
|
||||
@test equalorNaN(R2[4, 1], evaluate(d, "kitten", "mew"))
|
||||
@test equalorNaN(R2[4, 2], evaluate(d, "kitten", "ab"))
|
||||
|
||||
R3 = pairwise(d, TestStrings2, TestStrings1)
|
||||
@test R3 isa Matrix{Float64}
|
||||
@test size(R3) == (2, 4)
|
||||
|
||||
for i in 1:length(TestStrings1)
|
||||
for j in 1:length(TestStrings2)
|
||||
@test equalorNaN(R2[i, j], R3[j, i])
|
||||
end
|
||||
end
|
||||
|
||||
# Ensure same result if preprocessing for QGramDistances
|
||||
if DT <: QGramDistance
|
||||
R4 = pairwise(d, TestStrings1; preprocess = true)
|
||||
@test typeof(R4) == typeof(R)
|
||||
@test size(R4) == size(R)
|
||||
for i in 1:size(R4, 1)
|
||||
for j in 1:size(R4, 2)
|
||||
@test equalorNaN(R4[i, j], R[i, j])
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
end
|
|
@ -0,0 +1,65 @@
|
|||
using StringDistances, Random
|
||||
using BenchmarkTools
|
||||
|
||||
N = if length(ARGS) > 0
|
||||
try
|
||||
parse(Int, ARGS[1])
|
||||
catch _
|
||||
100
|
||||
end
|
||||
else
|
||||
100 # default value
|
||||
end
|
||||
|
||||
Maxlength = if length(ARGS) > 1
|
||||
try
|
||||
parse(Int, ARGS[2])
|
||||
catch _
|
||||
100
|
||||
end
|
||||
else
|
||||
100 # default value
|
||||
end
|
||||
|
||||
# If there are strings already cached to disk we start with them and only
|
||||
# add new ones if needed.
|
||||
using Serialization
|
||||
const CacheFile = joinpath(@__DIR__(), "perfteststrings_$(Maxlength).juliabin")
|
||||
SaveCache = false
|
||||
|
||||
S = if isfile(CacheFile)
|
||||
try
|
||||
res = deserialize(CacheFile)
|
||||
println("Read $(length(res)) strings from cache file: $CacheFile")
|
||||
res
|
||||
catch err
|
||||
String[]
|
||||
end
|
||||
else
|
||||
println("Creating $N random strings.")
|
||||
SaveCache = true
|
||||
String[randstring(rand(3:Maxlength)) for _ in 1:N]
|
||||
end
|
||||
|
||||
if length(S) < N
|
||||
for i in (length(S)+1):N
|
||||
push!(S, randstring(rand(3:Maxlength)))
|
||||
end
|
||||
SaveCache = true
|
||||
end
|
||||
|
||||
if SaveCache
|
||||
println("Saving cache file with $(length(S)) strings: $CacheFile")
|
||||
serialize(CacheFile, S)
|
||||
end
|
||||
|
||||
|
||||
println("For ", Threads.nthreads(), " threads and ", N, " strings of max length ", Maxlength, ":")
|
||||
|
||||
dist = Cosine(2)
|
||||
t1 = @belapsed dm1 = pairwise(dist, S; preprocess = false)
|
||||
t2 = @belapsed dm2 = pairwise(dist, S; preprocess = true)
|
||||
|
||||
println(" - time WITHOUT pre-calculation: ", round(t1, digits = 3))
|
||||
println(" - time WITH pre-calculation: ", round(t2, digits = 3))
|
||||
println(" - speedup with pre-calculation: ", round(t1/t2, digits = 1))
|
|
@ -3,3 +3,4 @@ using Test
|
|||
|
||||
include("distances.jl")
|
||||
include("modifiers.jl")
|
||||
include("pairwise.jl")
|
Loading…
Reference in New Issue