Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ dev = [
"pandas",
"mapbox-vector-tile",
"jinja2",
"nltk",
"sentence_transformers",
"tqdm",
"mypy",
"pyright",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@
# specific language governing permissions and limitations
# under the License.

import sys
from hashlib import md5
from typing import Any, List, Tuple
from unittest import SkipTest
from unittest.mock import Mock, patch

import pytest

from elasticsearch import AsyncElasticsearch

from ..async_examples import vectors


@pytest.mark.asyncio
async def test_vector_search(
async_write_client: AsyncElasticsearch, es_version: Tuple[int, ...], mocker: Any
async_write_client: AsyncElasticsearch, es_version: Tuple[int, ...]
) -> None:
# this test only runs on Elasticsearch >= 8.11 because the example uses
# a dense vector without specifying an explicit size
if es_version < (8, 11):
raise SkipTest("This test requires Elasticsearch 8.11 or newer")

class MockModel:
class MockSentenceTransformer:
def __init__(self, model: Any):
pass

Expand All @@ -44,9 +44,22 @@ def encode(self, text: str) -> List[float]:
total = sum(vector)
return [float(v) / total for v in vector]

mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
def mock_nltk_tokenize(content: str):
return content.split("\n")

# mock sentence_transformers and nltk, because they are quite big and
# irrelevant for testing the example logic
with patch.dict(
sys.modules,
{
"sentence_transformers": Mock(SentenceTransformer=MockSentenceTransformer),
"nltk": Mock(sent_tokenize=mock_nltk_tokenize),
},
):
# import the example after the dependencies are mocked
from ..async_examples import vectors

await vectors.create()
await vectors.WorkplaceDoc._index.refresh()
results = await (await vectors.search("Welcome to our team!")).execute()
assert results[0].name == "New Employee Onboarding Guide"
await vectors.create()
await vectors.WorkplaceDoc._index.refresh()
results = await (await vectors.search("Welcome to our team!")).execute()
assert results[0].name == "Intellectual Property Policy"
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,27 @@
# specific language governing permissions and limitations
# under the License.

import sys
from hashlib import md5
from typing import Any, List, Tuple
from unittest import SkipTest
from unittest.mock import Mock, patch

import pytest

from elasticsearch import Elasticsearch

from ..examples import vectors


@pytest.mark.sync
def test_vector_search(
write_client: Elasticsearch, es_version: Tuple[int, ...], mocker: Any
write_client: Elasticsearch, es_version: Tuple[int, ...]
) -> None:
# this test only runs on Elasticsearch >= 8.11 because the example uses
# a dense vector without specifying an explicit size
if es_version < (8, 11):
raise SkipTest("This test requires Elasticsearch 8.11 or newer")

class MockModel:
class MockSentenceTransformer:
def __init__(self, model: Any):
pass

Expand All @@ -44,9 +44,22 @@ def encode(self, text: str) -> List[float]:
total = sum(vector)
return [float(v) / total for v in vector]

mocker.patch.object(vectors, "SentenceTransformer", new=MockModel)
def mock_nltk_tokenize(content: str):
return content.split("\n")

# mock sentence_transformers and nltk, because they are quite big and
# irrelevant for testing the example logic
with patch.dict(
sys.modules,
{
"sentence_transformers": Mock(SentenceTransformer=MockSentenceTransformer),
"nltk": Mock(sent_tokenize=mock_nltk_tokenize),
},
):
# import the example after the dependencies are mocked
from ..examples import vectors

vectors.create()
vectors.WorkplaceDoc._index.refresh()
results = (vectors.search("Welcome to our team!")).execute()
assert results[0].name == "New Employee Onboarding Guide"
vectors.create()
vectors.WorkplaceDoc._index.refresh()
results = (vectors.search("Welcome to our team!")).execute()
assert results[0].name == "Intellectual Property Policy"