parametric
matthieugomez 2021-09-06 10:24:16 -04:00
parent 2a8c0b97ef
commit 3751aa2be5
1 changed files with 56 additions and 56 deletions

View File

@ -18,26 +18,26 @@ evaluate(Overlap(2), qd1, qd2)
``` ```
""" """
struct QGramDict{S, K} struct QGramDict{S, K}
s::S s::S
q::Int q::Int
counts::Dict{K, Int} counts::Dict{K, Int}
end end
Base.length(s::QGramDict) = length(s.s) Base.length(s::QGramDict) = length(s.s)
Base.iterate(s::QGramDict) = iterate(s.s) Base.iterate(s::QGramDict) = iterate(s.s)
Base.iterate(s::QGramDict, state) = iterate(s.s, state) Base.iterate(s::QGramDict, state) = iterate(s.s, state)
function QGramDict(s, q::Integer = 2) function QGramDict(s, q::Integer = 2)
(s isa QGramDict) && (s.q == q) && return s (s isa QGramDict) && (s.q == q) && return s
qgs = qgrams(s, q) qgs = qgrams(s, q)
QGramDict{typeof(s), eltype(qgs)}(s, q, countdict(qgs)) QGramDict{typeof(s), eltype(qgs)}(s, q, countdict(qgs))
end end
# Turn a sequence of qgrams to a count dict for them, i.e. map each # Turn a sequence of qgrams to a count dict for them, i.e. map each
# qgram to the number of times it has been seen. # qgram to the number of times it has been seen.
function countdict(qgrams) function countdict(qgrams)
d = Dict{eltype(qgrams), Int}() d = Dict{eltype(qgrams), Int}()
for qg in qgrams for qg in qgrams
index = Base.ht_keyindex2!(d, qg) index = Base.ht_keyindex2!(d, qg)
if index > 0 if index > 0
d.age += 1 d.age += 1
@inbounds d.keys[index] = qg @inbounds d.keys[index] = qg
@ -45,29 +45,29 @@ function countdict(qgrams)
else else
@inbounds Base._setindex!(d, 1, qg, -index) @inbounds Base._setindex!(d, 1, qg, -index)
end end
end end
d d
end end
function (dist::AbstractQGramDistance)(qc1::QGramDict, qc2::QGramDict) function (dist::AbstractQGramDistance)(qc1::QGramDict, qc2::QGramDict)
dist.q == qc1.q == qc2.q || throw(ArgumentError("The distance and the QGramDict must have the same qgram length")) dist.q == qc1.q == qc2.q || throw(ArgumentError("The distance and the QGramDict must have the same qgram length"))
counter = eval_start(dist) counter = eval_start(dist)
d1, d2 = qc1.counts, qc2.counts d1, d2 = qc1.counts, qc2.counts
for (s1, n1) in d1 for (s1, n1) in d1
index = Base.ht_keyindex2!(d2, s1) index = Base.ht_keyindex2!(d2, s1)
if index > 0 if index > 0
counter = eval_op(dist, counter, n1, d2.vals[index]) counter = eval_op(dist, counter, n1, d2.vals[index])
else else
counter = eval_op(dist, counter, n1, 0) counter = eval_op(dist, counter, n1, 0)
end end
end end
for (s2, n2) in d2 for (s2, n2) in d2
index = Base.ht_keyindex2!(d1, s2) index = Base.ht_keyindex2!(d1, s2)
if index <= 0 if index <= 0
counter = eval_op(dist, counter, 0, n2) counter = eval_op(dist, counter, 0, n2)
end end
end end
eval_reduce(dist, counter) eval_reduce(dist, counter)
end end
""" """
@ -94,20 +94,20 @@ evaluate(Jaccard(2), qs1, qs2)
``` ```
""" """
struct QGramSortedVector{S, K} struct QGramSortedVector{S, K}
s::S s::S
q::Int q::Int
counts::Vector{Pair{K, Int}} counts::Vector{Pair{K, Int}}
end end
Base.length(s::QGramSortedVector) = length(s.s) Base.length(s::QGramSortedVector) = length(s.s)
Base.iterate(s::QGramSortedVector) = iterate(s.s) Base.iterate(s::QGramSortedVector) = iterate(s.s)
Base.iterate(s::QGramSortedVector, state) = iterate(s.s, state) Base.iterate(s::QGramSortedVector, state) = iterate(s.s, state)
function QGramSortedVector(s, q::Integer = 2) function QGramSortedVector(s, q::Integer = 2)
(s isa QGramSortedVector) && (s.q == q) && return s (s isa QGramSortedVector) && (s.q == q) && return s
qgs = qgrams(s, q) qgs = qgrams(s, q)
countpairs = collect(countdict(qgs)) countpairs = collect(countdict(qgs))
sort!(countpairs, by = first) sort!(countpairs, by = first)
QGramSortedVector{typeof(s), eltype(qgs)}(s, q, countpairs) QGramSortedVector{typeof(s), eltype(qgs)}(s, q, countpairs)
end end
@ -117,38 +117,38 @@ end
# The abstract type defines different fallback versions which can be # The abstract type defines different fallback versions which can be
# specialied by subtypes for best performance. # specialied by subtypes for best performance.
function (dist::AbstractQGramDistance)(qc1::QGramSortedVector, qc2::QGramSortedVector) function (dist::AbstractQGramDistance)(qc1::QGramSortedVector, qc2::QGramSortedVector)
dist.q == qc1.q == qc2.q || throw(ArgumentError("The distance and the QGramSortedVectors must have the same qgram length")) dist.q == qc1.q == qc2.q || throw(ArgumentError("The distance and the QGramSortedVectors must have the same qgram length"))
counter = eval_start(dist) counter = eval_start(dist)
d1, d2 = qc1.counts, qc2.counts d1, d2 = qc1.counts, qc2.counts
i1 = i2 = 1 i1 = i2 = 1
while true while true
# length can be zero # length can be zero
if i2 > length(d2) if i2 > length(d2)
for i in i1:length(d1) for i in i1:length(d1)
@inbounds counter = eval_op(dist, counter, d1[i][2], 0) @inbounds counter = eval_op(dist, counter, d1[i][2], 0)
end end
break break
elseif i1 > length(d1) elseif i1 > length(d1)
for i in i2:length(d2) for i in i2:length(d2)
@inbounds counter = eval_op(dist, counter, 0, d2[i][2]) @inbounds counter = eval_op(dist, counter, 0, d2[i][2])
end end
break break
end end
@inbounds s1, n1 = d1[i1] @inbounds s1, n1 = d1[i1]
@inbounds s2, n2 = d2[i2] @inbounds s2, n2 = d2[i2]
cmpval = Base.cmp(s1, s2) cmpval = Base.cmp(s1, s2)
if cmpval == -1 # k1 < k2 if cmpval == -1 # k1 < k2
counter = eval_op(dist, counter, n1, 0) counter = eval_op(dist, counter, n1, 0)
i1 += 1 i1 += 1
elseif cmpval == 1 # k2 < k1 elseif cmpval == 1 # k2 < k1
counter = eval_op(dist, counter, 0, n2) counter = eval_op(dist, counter, 0, n2)
i2 += 1 i2 += 1
else else
counter = eval_op(dist, counter, n1, n2) counter = eval_op(dist, counter, n1, n2)
i1 += 1 i1 += 1
i2 += 1 i2 += 1
end end
end end
eval_reduce(dist, counter) eval_reduce(dist, counter)
end end