Base LSH and Multi Probe LSH example

Download the color histograms of the Flickr30k dataset here.

[1]:
import numpy as np
from scipy.spatial.distance import cdist
from floky import L2
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import time

Data preparation

First we load the data into NumPy. Next we compute the exact nearest neighbors of the first \(N\) data points (used as queries) with scipy.spatial.distance.cdist.

From these distances we compute the mean distance and determine the exact top-\(k\) neighbors of each query (here \(k = 1\), skipping the query point itself). Next we choose a radius \(R\) (a fraction of the mean distance) and scale the data by \(\frac{1}{R}\). This makes it easier to verify whether the LSH algorithm can find the nearest neighbors: after scaling, we expect the exact nearest neighbor to have a distance smaller than 1. If this isn’t the case for most queries, we need to choose another radius \(R\).

[2]:
# Load the color histograms as a 2-D float array; each row is treated as one
# data point by the cells below.
a = np.loadtxt("flickr30k_histograms.csv", delimiter=",")
[3]:
# Number of query points: the first N rows of the data are used as queries,
# both for the exact search and for measuring LSH recall and query time.
N = 100
[23]:
# Brute-force exact search: distances from the first N points (the queries)
# to every point. Serves as ground truth for recall and as a timing baseline.
# perf_counter_ns is a monotonic clock intended for measuring durations,
# unlike time.time_ns which follows (and can jump with) the wall clock.
t0 = time.perf_counter_ns()
dist = cdist(a[:N], a)
exact_duration = (time.perf_counter_ns() - t0) / 1e6  # ns -> ms
exact_duration
[23]:
1808.042179
[5]:
# Exact (non-trivial) top-1 neighbour per query. Column 0 of the sorted
# order is skipped because it is assumed to be the query point itself
# (distance 0) -- TODO confirm the dataset has no exact duplicates.
top_k = np.argsort(dist, axis=1)[:, 1:2]
# Mean query-to-point distance, used below to pick the LSH radius.
mean = dist.mean()
# Distance to each query's exact neighbour, shape (N, 1) like top_k.
top_k_dist = np.take_along_axis(dist, top_k, axis=1)
[6]:
# Pick the LSH radius R as a fraction (1/2.5) of the mean distance, then
# express data and distances in units of R so that R itself becomes 1.
# NOTE: the division is in place, so re-running this cell rescales again.
R = mean / 2.5
for _arr in (a, dist, top_k_dist):
    _arr /= R  # in place: rescales a, dist and top_k_dist themselves
del _arr
R
[6]:
12717.77025887411
[7]:
# Sanity check: share of exact nearest neighbours that fall within the
# scaled radius R = 1 (ideally most of them), followed by a preview of the
# first ten exact-neighbour distances.
print("{}% < R".format((top_k_dist < 1).mean() * 100))
top_k_dist[:10]
83.0% < R
[7]:
array([[0.99372539],
       [0.45435497],
       [0.79676334],
       [1.14787659],
       [0.78890876],
       [0.63275089],
       [0.58949666],
       [0.99201873],
       [1.52371323],
       [1.61113221]])

Comparison of query / preprocessing duration and recall

Below we examine the trade-off between query duration and recall.

We look at two values of \(k\) (the number of values in the hash):

* 15
* 30

For base LSH we increase the number of hash tables to increase the recall. For multi-probe LSH we increase the number of probes we execute, while keeping the number of hash tables constant at 5.

[8]:
def cum_mov_avg(x, avg, n):
    """Fold a new value into a cumulative moving average.

    ``avg`` is the mean of the ``n`` values seen so far and ``x`` is the
    next value; the result is the mean of all ``n + 1`` values.
    """
    total_so_far = n * avg
    return (total_so_far + x) / (n + 1)

def recall(k, L):
    """Fit a base L2 LSH index and measure recall on the first N points.

    Reads the module-level globals ``a`` (scaled data), ``N`` (number of
    queries) and ``top_k`` (exact nearest-neighbour indices per query).

    Parameters
    ----------
    k : int
        Number of values in each hash (first argument of ``L2``).
    L : int
        Number of hash tables.

    Returns
    -------
    tuple
        ``(recall, avg_collisions, fit_duration, predict_duration)``:
        the fraction of the N queries whose exact neighbour appears among
        the returned candidates; the mean ``n_collisions`` over queries
        with more than one collision; the fit time in seconds; and the
        total predict time for all N queries in nanoseconds.
    """
    dim = a.shape[1]
    lsh = L2(k, L, dim, in_mem=True)

    # perf_counter is monotonic and meant for measuring durations.
    t0 = time.perf_counter()
    lsh.fit(a)
    fit_duration = time.perf_counter() - t0

    t0 = time.perf_counter_ns()
    p = lsh.predict(a[:N], only_index=True, top_k=6)
    predict_duration = time.perf_counter_ns() - t0

    hits = 0
    avg_collisions = 0
    n_averaged = 0
    for i, pi in enumerate(p):
        # One collision presumably means only the query point itself was
        # found: counted as a miss and excluded from the collision average.
        if pi.n_collisions == 1:
            continue
        # index[0] is assumed to be the query point itself, so skip it.
        if set(pi.index[1:]).intersection(top_k[i]):
            hits += 1
        # Count only the values actually averaged; using the enumerate index
        # would bias the running mean whenever earlier queries were skipped.
        avg_collisions = cum_mov_avg(pi.n_collisions, avg_collisions, n_averaged)
        n_averaged += 1

    return hits / N, avg_collisions, fit_duration, predict_duration

# Base LSH sweep: for each hash length k, vary the number of hash tables L
# and record recall, mean collisions and fit / query durations per run.
ks = []
Ls = []
recalls = []
avg_cs = []
duration_fit = []
duration_predict = []
for k in (15, 30):
    for L in (5, 10, 15, 20, 50, 100):
        # Configuration is recorded before the (slow) run, as one row per run.
        ks.append(k)
        Ls.append(L)

        rec, n_coll, t_fit, t_pred = recall(k, L)
        duration_fit.append(t_fit)
        duration_predict.append(t_pred)
        recalls.append(rec)
        avg_cs.append(n_coll)
32000it [00:00, 60119.95it/s]
32000it [00:01, 30086.65it/s]
32000it [00:01, 21003.48it/s]
32000it [00:01, 16029.99it/s]
32000it [00:04, 6494.46it/s]
32000it [00:10, 2953.71it/s]
32000it [00:01, 17412.80it/s]
32000it [00:03, 9178.54it/s]
32000it [00:05, 6099.98it/s]
32000it [00:06, 5050.17it/s]
32000it [00:15, 2131.61it/s]
32000it [00:31, 1020.02it/s]
[9]:
# One row per (K, L) configuration of the base-LSH sweep.
df = pd.DataFrame(
    dict(
        recall=recalls,
        avg_collisions=avg_cs,
        L=Ls,
        K=ks,
        duration_fit=duration_fit,
        duration_predict=duration_predict,
    )
)
df
[9]:
recall avg_collisions L K duration_fit duration_predict
0 0.37 715.617084 5 15 0.559304 58695551
1 0.58 1187.012150 10 15 1.106947 79070180
2 0.76 2370.829281 15 15 1.546288 148368772
3 0.74 1914.947429 20 15 2.016144 135169832
4 0.91 4319.069349 50 15 4.948727 256704612
5 0.95 6013.273606 100 15 10.858258 407418596
6 0.14 30.292026 5 30 1.858710 12629968
7 0.26 188.690972 10 30 3.517005 31498599
8 0.30 77.511316 15 30 5.282044 21663692
9 0.37 176.957555 20 30 6.364206 26374739
10 0.44 226.344361 50 30 15.053674 55477450
11 0.63 409.212477 100 30 31.411575 85154231
[10]:
def recall_multi_probe(k, budget, lsh):
    """Set a probe budget on a fitted index and measure recall on the first N points.

    Reads the module-level globals ``a``, ``N`` and ``top_k``.

    Parameters
    ----------
    k : int
        Hash length of ``lsh``. Not used in the computation; kept so the
        signature mirrors the sweep parameters.
    budget : int
        Probe budget passed to ``lsh.multi_probe``.
    lsh : L2
        An already fitted LSH index. ``lsh.multi_probe(budget)`` is called on
        it (return value unused), presumably reconfiguring it in place.

    Returns
    -------
    tuple
        ``(recall, avg_collisions, predict_duration)``: the fraction of the
        N queries whose exact neighbour appears among the candidates; the
        mean ``n_collisions`` over queries with more than one collision;
        and the total predict time for all N queries in nanoseconds.
    """
    lsh.multi_probe(budget)

    # perf_counter_ns is monotonic and meant for measuring durations.
    t0 = time.perf_counter_ns()
    p = lsh.predict(a[:N], only_index=True, top_k=6)
    predict_duration = time.perf_counter_ns() - t0

    hits = 0
    avg_collisions = 0
    n_averaged = 0
    for i, pi in enumerate(p):
        # One collision presumably means only the query point itself was
        # found: counted as a miss and excluded from the collision average.
        if pi.n_collisions == 1:
            continue
        # index[0] is assumed to be the query point itself, so skip it.
        if set(pi.index[1:]).intersection(top_k[i]):
            hits += 1
        # Count only the values actually averaged; using the enumerate index
        # would bias the running mean whenever earlier queries were skipped.
        avg_collisions = cum_mov_avg(pi.n_collisions, avg_collisions, n_averaged)
        n_averaged += 1

    return hits / N, avg_collisions, predict_duration

# Multi-probe LSH sweep: for each hash length k, fit one index with a fixed
# 5 hash tables, then vary only the probe budget at query time.
ks = []
recalls = []
avg_cs = []
probes = []
duration_fit = []
duration_predict = []
dim = a.shape[1]
for k in [15, 30]:
    lsh = L2(k, 5, dim, in_mem=True)

    # Time lsh.fit itself (not just the constructor), matching how the
    # base-LSH sweep measures fit duration.
    t0 = time.perf_counter()
    lsh.fit(a)
    fit_duration = time.perf_counter() - t0

    # Sorted and de-duplicated so each budget is measured once and the
    # resulting curve is monotone in the number of probes.
    for probe in [10, 15, 20, 50, 100]:
        ks.append(k)
        probes.append(probe)

        r, avg_collision, predict_duration = recall_multi_probe(k, probe, lsh)
        duration_predict.append(predict_duration)
        # The index is fitted once per k, so every probe budget shares it.
        duration_fit.append(fit_duration)
        recalls.append(r)
        avg_cs.append(avg_collision)
32000it [00:00, 37000.65it/s]
32000it [00:01, 19196.08it/s]
[11]:
# One row per (K, probe budget) configuration of the multi-probe sweep.
df_mp = pd.DataFrame(
    dict(
        recall=recalls,
        avg_collisions=avg_cs,
        probes=probes,
        K=ks,
        duration_fit=duration_fit,
        duration_predict=duration_predict,
    )
)
df_mp
[11]:
recall avg_collisions probes K duration_fit duration_predict
0 0.84 3568.653046 10 15 0.000322 218376250
1 0.88 4608.842829 20 15 0.000322 295503006
2 0.85 4171.139471 15 15 0.000322 261069746
3 0.88 4608.842829 20 15 0.000322 297343177
4 0.91 6259.458000 50 15 0.000322 720279473
5 0.93 7812.191200 100 15 0.000322 1141625211
6 0.36 263.895373 10 30 0.006047 27857108
7 0.46 359.867660 20 30 0.006047 31135252
8 0.43 311.102990 15 30 0.006047 28399008
9 0.46 359.867660 20 30 0.006047 32754926
10 0.60 554.988570 50 30 0.006047 69406715
11 0.65 785.715304 100 30 0.006047 128893417
[38]:
# Top row: base LSH swept over the number of hash tables L.
# Bottom row: multi-probe LSH (5 hash tables) swept over the probe budget.
# Columns: recall, total query duration for the N queries, fit duration.
fig, ax = plt.subplots(figsize=(20, 6), nrows=2, ncols=3)
marker = "^"  # shared by both rows

for i, (k, df_) in enumerate(df.groupby("K")):
    color = f"C{i}"
    # Sort by the swept variable so each line is drawn left to right.
    df_ = df_.sort_values("L")
    ax[0, 0].plot(df_.L, df_.recall, c=color, marker=marker, label=f"k = {k}")
    # duration_predict is in ns; plot in ms to match exact_duration.
    ax[0, 1].plot(df_.L, df_.duration_predict / 1e6, c=color, marker=marker)
    ax[0, 2].plot(df_.L, df_.duration_fit, c=color, marker=marker)

for i, (k, df_) in enumerate(df_mp.groupby("K")):
    color = f"C{i}"
    df_ = df_.sort_values("probes")
    ax[1, 0].plot(df_.probes, df_.recall, c=color, marker=marker, label=f"k = {k}")
    ax[1, 1].plot(df_.probes, df_.duration_predict / 1e6, c=color, marker=marker)
    ax[1, 2].plot(df_.probes, df_.duration_fit, c=color, marker=marker)

ax[0, 0].legend()
ax[1, 0].legend()

# Baseline: brute-force (cdist) search time for the same N queries, in ms.
ax[0, 1].axhline(exact_duration, c="black", label="exact search")
ax[1, 1].axhline(exact_duration, c="black", label="exact search")
ax[0, 1].legend()
ax[1, 1].legend()

# Top row labels.
ax[0, 0].set_ylabel("recall")
ax[0, 0].set_ylim(0, 1)
ax[0, 0].set_xlabel("L hashtables")
ax[0, 1].set_xlabel("L hashtables")
ax[0, 1].set_ylabel("query duration [ms]")
ax[0, 1].set_ylim(0, exact_duration * 1.05)
ax[0, 2].set_ylabel("fit duration [s]")
ax[0, 2].set_xlabel("L hashtables")

# Bottom row: the same log-scaled probe axis in every panel. The formatter
# must be set after set_xscale, which installs its own log formatter.
for col in range(3):
    ax[1, col].set_xlabel("# probes")
    ax[1, col].set_xscale("log")
    ax[1, col].xaxis.set_major_formatter(ScalarFormatter())
ax[1, 0].set_ylabel("recall")
ax[1, 0].set_ylim(0, 1)
ax[1, 1].set_ylabel("query duration [ms]")
ax[1, 1].set_ylim(0, exact_duration * 1.05)
ax[1, 2].set_ylabel("fit duration [s]")

# Keep labels of the 2x3 grid from overlapping.
fig.tight_layout()
plt.show()
../_images/sections_LSH_recall_14_0.png