diff --git a/README.md b/README.md
index a9c76bf..0db952f 100644
--- a/README.md
+++ b/README.md
@@ -27,6 +27,8 @@ models.
- spaCy **pipeline component** and **extension attributes**.
- Fully **serializable** so you can easily ship your sense2vec vectors with your
spaCy model packages.
+- Use [`annoy`](https://github.com/spotify/annoy) to build an index for **super
+  fast approximate calculations** of the most similar vectors.
- **Train your own vectors** using a pretrained spaCy model, raw text and
[GloVe](https://github.com/stanfordnlp/GloVe) or Word2Vec via
[fastText](https://github.com/facebookresearch/fastText)
@@ -417,7 +419,11 @@ assert s2v.similarity("machine_learning|NOUN", "machine_learning|NOUN") == 1.0
#### method `Sense2Vec.most_similar`
Get the most similar entries in the table. If more than one key is provided, the
-average of the vectors is used.
+average of the vectors is used. To make this faster, you can run
+`Sense2Vec.build_index`, which uses the
+[annoy](https://github.com/spotify/annoy) library to build an index of the
+vectors. Building and loading the index adds some startup time, but speeds up
+the most-similar calculations significantly.
| Argument | Type | Description |
| ------------ | ------------------------- | ------------------------------------------------------- |
@@ -466,6 +472,17 @@ assert s2v.get_best_sense("duck") == "duck|NOUN"
assert s2v.get_best_sense("duck", ["VERB", "ADJ"]) == "duck|VERB"
```
+#### method `Sense2Vec.build_index`
+
+Build an `AnnoyIndex` from the vectors. Used for faster calculation of the
+approximate nearest neighbors in `Sense2Vec.most_similar`. See the
+[`annoy` docs](https://github.com/spotify/annoy) for more details.
+
+| Argument | Type | Description |
+| --------- | ------- | ------------------------------------------------------------------------------------------------- |
+| `metric` | unicode | The [metric](https://github.com/spotify/annoy#full-python-api) to use. Defaults to `"euclidean"`. |
+| `n_trees` | int | The number of trees to build. Defaults to `100`. |
+
#### method `Sense2Vec.to_bytes`
Serialize a `Sense2Vec` object to a bytestring.
diff --git a/requirements.txt b/requirements.txt
index a7827ff..5a7ccbf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,6 +3,7 @@ spacy>=2.2.2,<3.0.0
srsly>=0.2.0
catalogue>=0.0.4
# Third-party dependencies
+annoy==1.15.2
numpy>=1.15.0
importlib_metadata>=0.20; python_version < "3.8"
# Development requirements
diff --git a/sense2vec/sense2vec.py b/sense2vec/sense2vec.py
index 93127ec..d9d7c93 100644
--- a/sense2vec/sense2vec.py
+++ b/sense2vec/sense2vec.py
@@ -2,10 +2,11 @@
from pathlib import Path
from spacy.vectors import Vectors
from spacy.strings import StringStore
+from annoy import AnnoyIndex
import numpy
import srsly
-from .util import registry, SimpleFrozenDict
+from .util import registry, SimpleFrozenDict, get_similarity
class Sense2Vec(object):
@@ -31,8 +32,14 @@ def __init__(
"""
self.vectors = Vectors(shape=shape, name=vectors_name)
self.strings = StringStore() if strings is None else strings
+ self.index = None
self.freqs: Dict[int, int] = {}
- self.cfg = {"senses": senses, "make_key": "default", "split_key": "default"}
+ self.cfg = {
+ "senses": senses,
+ "annoy_metric": "euclidean",
+ "make_key": "default",
+ "split_key": "default",
+ }
self.cfg.update(overrides)
@property
@@ -171,13 +178,7 @@ def similarity(
keys_b = [keys_b]
average_a = numpy.vstack([self[key] for key in keys_a]).mean(axis=0)
average_b = numpy.vstack([self[key] for key in keys_b]).mean(axis=0)
- if average_a.all() == 0 or average_b.all() == 0:
- return 0.0
- norm_a = numpy.linalg.norm(average_a)
- norm_b = numpy.linalg.norm(average_b)
- if norm_a == norm_b:
- return 1.0
- return numpy.dot(average_a, average_b) / (norm_a * norm_b)
+ return get_similarity(average_a, average_b)
def most_similar(
self,
@@ -186,7 +187,8 @@ def most_similar(
batch_size: int = 16,
) -> List[Tuple[str, float]]:
"""Get the most similar entries in the table. If more than one key is
- provided, the average of the vectors is used.
+ provided, the average of the vectors is used. To make this faster,
+ you can run Sense2Vec.build_index, which uses the annoy library.
keys (unicode / int / iterable): The string or integer key(s) to compare to.
n (int): The number of similar keys to return.
@@ -203,14 +205,24 @@ def most_similar(
if len(self.vectors) < n_similar:
n_similar = len(self.vectors)
vecs = numpy.vstack([self[key] for key in keys])
- average = vecs.mean(axis=0, keepdims=True)
- result_keys, _, scores = self.vectors.most_similar(
- average, n=n_similar, batch_size=batch_size
- )
- result = list(zip(result_keys.flatten(), scores.flatten()))
- result = [(self.strings[key], score) for key, score in result if key]
- result = [(key, score) for key, score in result if key not in keys]
- return result
+ if self.index is None: # use the less efficient default way
+ avg = vecs.mean(axis=0, keepdims=True)
+ result_keys, _, scores = self.vectors.most_similar(
+ avg, n=n_similar, batch_size=batch_size
+ )
+ result = list(zip(result_keys.flatten(), scores.flatten()))
+ result = [(self.strings[key], score) for key, score in result if key]
+ return [(key, score) for key, score in result if key not in keys]
+ else: # index is built, use annoy
+ avg = vecs.mean(axis=0, keepdims=False)
+ nns = self.index.get_nns_by_vector(avg, n_similar, include_distances=True)
+ result = []
+ for row, dist in zip(*nns):
+ key = self.strings[self.vectors.find(row=row)[0]]
+ if key not in keys:
+ score = 1.0 if dist == 0.0 else get_similarity(avg, self[key])
+ result.append((key, score))
+ return result
def get_other_senses(
self, key: Union[str, int], ignore_case: bool = True
@@ -258,6 +270,22 @@ def get_best_sense(
freqs.append((freq, key))
return max(freqs)[1] if freqs else None
+ def build_index(self, metric: str = "euclidean", n_trees: int = 100):
+ """Build an AnnoyIndex from the vectors. Used for faster calculation of
+ the approximate nearest neighbors in Sense2Vec.most_similar. See the
+ annoy docs for more details: https://github.com/spotify/annoy
+
+ metric (unicode): The metric to use.
+ n_trees (int): The number of trees to build.
+ """
+ self.cfg["annoy_metric"] = metric
+ self.index = AnnoyIndex(self.vectors.shape[1], metric)
+ for key, vector in self.vectors.items():
+ # The key ints are too big so use the row for annoy
+ row = self.vectors.find(key=key)
+ self.index.add_item(row, vector)
+ self.index.build(n_trees)
+
def to_bytes(self, exclude: Sequence[str] = tuple()) -> bytes:
"""Serialize a Sense2Vec object to a bytestring.
@@ -298,6 +326,8 @@ def to_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
srsly.write_json(path / "freqs.json", list(self.freqs.items()))
if "strings" not in exclude:
self.strings.to_disk(path / "strings.json")
+ if "index" not in exclude and self.index is not None:
+ self.index.save(str(path / "index.ann"))
def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
"""Load a Sense2Vec object from a directory.
@@ -308,6 +338,7 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
"""
path = Path(path)
strings_path = path / "strings.json"
+ index_path = path / "index.ann"
freqs_path = path / "freqs.json"
self.vectors = Vectors().from_disk(path)
self.cfg.update(srsly.read_json(path / "cfg"))
@@ -315,4 +346,7 @@ def from_disk(self, path: Union[Path, str], exclude: Sequence[str] = tuple()):
self.freqs = dict(srsly.read_json(freqs_path))
if "strings" not in exclude and strings_path.exists():
self.strings = StringStore().from_disk(strings_path)
+ if "index" not in exclude and index_path.exists():
+ self.index = AnnoyIndex(self.vectors.shape[1], self.cfg["annoy_metric"])
+ self.index.load(str(index_path))
return self
diff --git a/sense2vec/util.py b/sense2vec/util.py
index 496a7b4..4a653ed 100644
--- a/sense2vec/util.py
+++ b/sense2vec/util.py
@@ -3,6 +3,7 @@
from spacy.tokens import Doc, Token, Span
from spacy.util import filter_spans
import catalogue
+import numpy
try:
import importlib.metadata as importlib_metadata # Python 3.8
@@ -167,6 +168,22 @@ def merge_phrases(doc: Doc) -> Doc:
return doc
+def get_similarity(vec_a: numpy.ndarray, vec_b: numpy.ndarray) -> float:
+ """Calculate the similarity of two vectors.
+
+ vec_a (numpy.ndarray): The vector.
+ vec_b (numpy.ndarray): The other vector.
+ RETURNS (float): The similarity score.
+ """
+ if vec_a.all() == 0 or vec_b.all() == 0:
+ return 0.0
+ norm_a = numpy.linalg.norm(vec_a)
+ norm_b = numpy.linalg.norm(vec_b)
+ if norm_a == norm_b:
+ return 1.0
+ return numpy.dot(vec_a, vec_b) / (norm_a * norm_b)
+
+
class SimpleFrozenDict(dict):
"""Simplified implementation of a frozen dict, mainly used as default
function or method argument (for arguments that should default to empty
diff --git a/setup.cfg b/setup.cfg
index 2a11bf0..4fc73cb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -31,6 +31,7 @@ install_requires =
srsly>=0.2.0
catalogue>=0.0.4
wasabi>=0.4.0,<1.1.0
+ annoy==1.15.2
numpy>=1.15.0
importlib_metadata>=0.20; python_version < "3.8"
diff --git a/tests/test_sense2vec.py b/tests/test_sense2vec.py
index fbf4657..68ae168 100644
--- a/tests/test_sense2vec.py
+++ b/tests/test_sense2vec.py
@@ -81,7 +81,8 @@ def test_sense2vec_similarity():
assert s2v.similarity("a", "e") == 0.0
-def test_sense2vec_most_similar():
+@pytest.mark.parametrize("build_index", [True, False])
+def test_sense2vec_most_similar(build_index):
s2v = Sense2Vec(shape=(6, 4))
s2v.add("a", numpy.asarray([4, 2, 2, 2], dtype=numpy.float32))
s2v.add("b", numpy.asarray([4, 4, 2, 2], dtype=numpy.float32))
@@ -89,15 +90,17 @@ def test_sense2vec_most_similar():
s2v.add("d", numpy.asarray([4, 4, 4, 4], dtype=numpy.float32))
s2v.add("x", numpy.asarray([4, 2, 2, 2], dtype=numpy.float32))
s2v.add("y", numpy.asarray([0.1, 1, 1, 1], dtype=numpy.float32))
+ if build_index:
+ s2v.build_index()
result1 = s2v.most_similar(["x"], n=2)
assert len(result1) == 2
assert result1[0][0] == "a"
assert result1[0][1] == 1.0
assert result1[0][1] == pytest.approx(1.0)
assert result1[1][0] == "b"
- result2 = s2v.most_similar(["a", "x"], n=2)
- assert len(result2) == 2
- assert sorted([key for key, _ in result2]) == ["b", "d"]
+ result2 = s2v.most_similar(["a", "x"], n=3)
+ assert len(result2) == 3
+ assert sorted([key for key, _ in result2]) == ["b", "c", "d"]
result3 = s2v.most_similar(["a", "b"], n=3)
assert len(result3) == 3
assert "y" not in [key for key, _ in result3]