教程 | 如何在Python中快速进行语料库搜索:近似最近邻算法

选自Medium

作者:Kevin Yang

机器之心编译

参与:路雪



最近,我一直在研究在 GloVe 词嵌入中做加减法。例如,我们可以把「king」的词嵌入向量减去「man」的词嵌入向量,随后加入「woman」的词嵌入得到一个结果向量。随后,如果我们有这些词嵌入对应的语料库,那么我们可以通过搜索找到最相似的嵌入并检索相应的词。如果我们做了这样的查询,我们会得到:

  • King + (Woman - Man) = Queen

  • 我们有很多方法来搜索语料库中词嵌入对作为最近邻查询方式。绝对可以确保找到最优向量的方式是遍历你的语料库,比较每个对与查询需求的相似程度——这当然是耗费时间且不推荐的。一个更好的技术是使用向量化余弦距离方式,如下所示:

  • vectors = np.array(embeddingmodel.embeddings)

  • ranks = np.dot(query,vectors.T)/np.sqrt(np.sum(vectors**2,1))

  • mostSimilar = []

  • [mostSimilar.append(idx) for idx in ranks.argsort()[::-1]]





  • 想要了解余弦距离,可以看看这篇文章:http://masongallo.github.io/machine/learning,/python/2016/07/29/cosine-similarity.html

    矢量化的余弦距离比迭代法快得多,但速度可能太慢。是近似最近邻搜索算法该出现时候了:它可以快速返回近似结果。很多时候你并不需要准确的最佳结果,例如:「Queen」这个单词的同义词是什么?在这种情况下,你只需要快速得到足够好的结果,你需要使用近似最近邻搜索算法。

    在本文中,我们将会介绍一个简单的 Python 脚本来快速找到近似最近邻。我们会使用的 Python 库是 Annoy 和 Imdb。对于我的语料库,我会使用词嵌入对,但该说明实际上适用于任何类型的嵌入:如音乐推荐引擎需要用到的歌曲嵌入,甚至以图搜图中的图片嵌入。

    制作一个索引

    让我们创建一个名为:「make_annoy_index」的 Python 脚本。首先我们需要加入用得到的依赖项:

  • """

  • Usage: python2 make_annoy_index.py \

  •    --embeddings=<embedding path> \

  •    --num_trees=<int> \

  •    --verbose

  • Generate an Annoy index and lmdb map given an embedding file

  • Embedding file can be

  •  1. A .bin file that is compatible with word2vec binary formats.

  •     There are pre-trained vectors to download at http://code.google.com/p/word2vec/

  •  2. A .gz file with the GloVe format (item then a list of floats in plaintext)

  •  3. A plain text file with the same format as above

  • """

  • import annoy

  • import lmdb

  • import os

  • import sys

  • import argparse

  • from vector_utils import get_vectors

  • 最后一行里非常重要的是「vector_utils」。稍后我们会写「vector_utils」,所以不必担心。

    接下来,让我们丰富这个脚本:加入「creat_index」函数。这里我们将生成 lmdb 图和 Annoy 索引。

    1. 首先需要找到嵌入的长度,它会被用来做实例化 Annoy 的索引。

    2. 接下来实例化一个 Imdb 图,使用:「env = lmdb.open(fn_lmdb, map_size=int(1e9))」。

    3. 确保我们在当前路径中没有 Annoy 索引或 lmdb 图。

    4. 将嵌入文件中的每一个 key 和向量添加至 lmdb 图和 Annoy 索引。

    5. 构建和保存 Annoy 索引。

  • """

  • function create_index(fn, num_trees=30, verbose=False)

  • -------------------------------

  • Creates an Annoy index and lmdb map given an embedding file fn

  • Input:

  •    fn              - filename of the embedding file

  •    num_trees       - number of trees to build Annoy index with

  •    verbose         - log status

  • Return:

  •    Void

  • """

  • def create_index(fn, num_trees=30, verbose=False):

  •    fn_annoy = fn + ".annoy"

  •    fn_lmdb = fn + ".lmdb" # stores word <-> id mapping

  •    word, vec = get_vectors(fn).next()

  •    size = len(vec)

  •    if verbose:

  •        print("Vector size: {}".format(size))

  •    env = lmdb.open(fn_lmdb, map_size=int(1e9))

  •    if not os.path.exists(fn_annoy) or not os.path.exists(fn_lmdb):

  •        i = 0

  •        a = annoy.AnnoyIndex(size)

  •        with env.begin(write=True) as txn:

  •            for word, vec in get_vectors(fn):

  •                a.add_item(i, vec)

  •                id = "i%d" % i

  •                word = "w" + word

  •                txn.put(id, word)

  •                txn.put(word, id)

  •                i += 1

  •                if verbose:

  •                    if i % 1000 == 0:

  •                        print(i, "...")

  •        if verbose:

  •            print("Starting to build")

  •        a.build(num_trees)

  •        if verbose:

  •            print("Finished building")

  •        a.save(fn_annoy)

  •        if verbose:

  •            print("Annoy index saved to: {}".format(fn_annoy))

  •            print("lmdb map saved to: {}".format(fn_lmdb))

  •    else:

  •        print("Annoy index and lmdb map already in path")

  • 我已经推断出 argparse,因此,我们可以利用命令行启用我们的脚本:

  • """

  • private function _create_args()

  • -------------------------------

  • Creates an argeparse object for CLI for create_index() function

  • Input:

  •    Void

  • Return:

  •    args object with required arguments for threshold_image() function

  • """

  • def _create_args():

  •    parser = argparse.ArgumentParser()

  •    parser.add_argument("--embeddings", help="filename of the embeddings", type=str)

  •    parser.add_argument("--num_trees", help="number of trees to build index with", type=int)

  •    parser.add_argument("--verbose", help="print logging", action="store_true")

  •    args = parser.parse_args()

  •    return args

  • 添加主函数以启用脚本,得到 make_annoy_index.py:

  • if __name__ == "__main__":

  •    args = _create_args()

  •    create_index(args.embeddings, num_trees=args.num_trees, verbose=args.verbose)



  • 现在我们可以仅利用命令行启用新脚本,以生成 Annoy 索引和对应的 lmdb 图!

  • python2 make_annoy_index.py \

  •    --embeddings=<embedding path> \

  •    --num_trees=<int> \

  •    --verbose

  • 写向 量Utils

    我们在 make_annoy_index.py 中推导出 Python 脚本 vector_utils。现在要写该脚本,Vector_utils 用于帮助读取.txt, .bin 和 .pkl 文件中的向量。

    写该脚本与我们现在在做的不那么相关,因此我已经推导出整个脚本,如下:

  • """

  • Vector Utils

  • Utils to read in vectors from txt, .bin, or .pkl.

  • Taken from Erik Bernhardsson

  • Source: http://github.com/erikbern/ann-presentation/blob/master/util.py

  • """

  • import gzip

  • import struct

  • import cPickle

  • def _get_vectors(fn):

  •    if fn.endswith(".gz"):

  •        f = gzip.open(fn)

  •        fn = fn[:-3]

  •    else:

  •        f = open(fn)

  •    if fn.endswith(".bin"): # word2vec format

  •        words, size = (int(x) for x in f.readline().strip().split())

  •        t = "f" * size

  •        while True:

  •            pos = f.tell()

  •            buf = f.read(1024)

  •            if buf == "" or buf == "\n": return

  •            i = buf.index(" ")

  •            word = buf[:i]

  •            f.seek(pos + i + 1)

  •            vec = struct.unpack(t, f.read(4 * size))

  •            yield word.lower(), vec

  •    elif fn.endswith(".txt"): # Assume simple text format

  •        for line in f:

  •            items = line.strip().split()

  •            yield items[0], [float(x) for x in items[1:]]

  •    elif fn.endswith(".pkl"): # Assume pickle (MNIST)

  •        i = 0

  •        for pics, labels in cPickle.load(f):

  •            for pic in pics:

  •                yield i, pic

  •                i += 1

  • def get_vectors(fn, n=float("inf")):

  •    i = 0

  •    for line in _get_vectors(fn):

  •        yield line

  •        i += 1

  •        if i >= n:

  •            break



  • 测试 Annoy 索引和 lmdb 图

    我们已经生成了 Annoy 索引和 lmdb 图,现在我们来写一个脚本使用它们进行推断。

    将我们的文件命名为 annoy_inference.py,得到下列依赖项:

  • """

  • Usage: python2 annoy_inference.py \

  •    --token="hello" \

  •    --num_results=<int> \

  •    --verbose

  • Query an Annoy index to find approximate nearest neighbors

  • """

  • import annoy

  • import lmdb

  • import argparse



  • 现在我们需要在 Annoy 索引和 lmdb 图中加载依赖项,我们将进行全局加载,以方便访问。注意,这里设置的 VEC_LENGTH 为 50。确保你的 VEC_LENGTH 与嵌入长度匹配,否则 Annoy 会不开心的哦~

  • VEC_LENGTH = 50

  • FN_ANNOY = "glove.6B.50d.txt.annoy"

  • FN_LMDB = "glove.6B.50d.txt.lmdb"

  • a = annoy.AnnoyIndex(VEC_LENGTH)

  • a.load(FN_ANNOY)

  • env = lmdb.open(FN_LMDB, map_size=int(1e9))

  • 有趣的部分在于「calculate」函数。

    1. 从 lmdb 图中获取查询索引;

    2. 用 get_item_vector(id) 获取 Annoy 对应的向量;

    3. 用 a.get_nns_by_vector(v, num_results) 获取 Annoy 的最近邻。

  • """

  • private function calculate(query, num_results)

  • -------------------------------

  • Queries a given Annoy index and lmdb map for num_results nearest neighbors

  • Input:

  •    query           - query to be searched

  •    num_results     - the number of results

  • Return:

  •    ret_keys        - list of num_results nearest neighbors keys

  • """

  • def calculate(query, num_results, verbose=False):

  •    ret_keys = []

  •    with env.begin() as txn:

  •        id = int(txn.get("w" + query)[1:])

  •        if verbose:

  •            print("Query: {}, with id: {}".format(query, id))

  •        v = a.get_item_vector(id)

  •        for id in a.get_nns_by_vector(v, num_results):

  •            key = txn.get("i%d" % id)[1:]

  •            ret_keys.append(key)

  •    if verbose:

  •        print("Found: {} results".format(len(ret_keys)))

  •    return ret_keys



  • 再次,这里使用 argparse 来使读取命令行参数更加简单。

  • """

  • private function _create_args()

  • -------------------------------

  • Creates an argeparse object for CLI for calculate() function

  • Input:

  •    Void

  • Return:

  •    args object with required arguments for threshold_image() function

  • """

  • def _create_args():

  •    parser = argparse.ArgumentParser()

  •    parser.add_argument("--token", help="query word", type=str)

  •    parser.add_argument("--num_results", help="number of results to return", type=int)

  •    parser.add_argument("--verbose", help="print logging", action="store_true")

  •    args = parser.parse_args()

  •    return args

  • 主函数从命令行中启用 annoy_inference.py。

  • if __name__ == "__main__":

  •    args = _create_args()

  •    print(calculate(args.token, args.num_results, args.verbose))



  • 现在我们可以使用 Annoy 索引和 lmdb 图,获取查询的最近邻!

  • python2 annoy_inference.py --token="test" --num_results=30

  • ["test", "tests", "determine", "for", "crucial", "only", "preparation", "needed", "positive", "guided", "time", "performance", "one", "fitness", "replacement", "stages", "made", "both", "accuracy", "deliver", "put", "standardized", "best", "discovery", ".", "a", "diagnostic", "delayed", "while", "side"]

  • 代码

    本教程所有代码的 GitHub 地址:http://github.com/kyang6/annoy_tutorial

    原文地址:http://medium.com/@kevin_yang/simple-approximate-nearest-neighbors-in-python-with-annoy-and-lmdb-e8a701baf905

    本文为机器之心编译,转载请联系本公众号获得授权

    ?------------------------------------------------

    加入机器之心(全职记者/实习生):hr@jiqizhixin.com

    投稿或寻求报道:editor@jiqizhixin.com

    广告&商务合作:bd@jiqizhixin.com