Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit a757bad

Browse files
Support new VAMANA vector type (#3702)
* Support new vector type * Skip VAMANA tests if Redis version is not 8.2 * Add async tests * Fix RESP3 errors
1 parent ce56d1c commit a757bad

File tree

3 files changed

+1020
-6
lines changed

3 files changed

+1020
-6
lines changed

‎redis/commands/search/field.py‎

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -181,7 +181,7 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
181181
182182
``name`` is the name of the field.
183183
184-
``algorithm`` can be "FLAT"or "HNSW".
184+
``algorithm`` can be "FLAT", "HNSW", or "SVS-VAMANA".
185185
186186
``attributes`` each algorithm can have specific attributes. Some of them
187187
are mandatory and some of them are optional. See
@@ -194,10 +194,10 @@ def __init__(self, name: str, algorithm: str, attributes: dict, **kwargs):
194194
if sort or noindex:
195195
raise DataError("Cannot set 'sortable' or 'no_index' in Vector fields.")
196196

197-
if algorithm.upper() not in ["FLAT", "HNSW"]:
197+
if algorithm.upper() not in ["FLAT", "HNSW", "SVS-VAMANA"]:
198198
raise DataError(
199-
"Realtime vector indexing supporting 2 Indexing Methods:"
200-
"'FLAT'and 'HNSW'."
199+
"Realtime vector indexing supporting 3 Indexing Methods:"
200+
"'FLAT', 'HNSW', and 'SVS-VAMANA'."
201201
)
202202

203203
attr_li = []

‎tests/test_asyncio/test_search.py‎

Lines changed: 178 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1815,3 +1815,181 @@ async def test_binary_and_text_fields(decoded_r: redis.Redis):
18151815
assert docs[0]["first_name"] == mixed_data["first_name"], (
18161816
"The text field is not decoded correctly"
18171817
)
1818+
1819+
1820+
# SVS-VAMANA Async Tests
1821+
@pytest.mark.redismod
1822+
@skip_if_server_version_lt("8.1.224")
1823+
async def test_async_svs_vamana_basic_functionality(decoded_r: redis.Redis):
1824+
await decoded_r.ft().create_index(
1825+
(
1826+
VectorField(
1827+
"v",
1828+
"SVS-VAMANA",
1829+
{"TYPE": "FLOAT32", "DIM": 4, "DISTANCE_METRIC": "L2"},
1830+
),
1831+
)
1832+
)
1833+
1834+
vectors = [
1835+
[1.0, 2.0, 3.0, 4.0],
1836+
[2.0, 3.0, 4.0, 5.0],
1837+
[3.0, 4.0, 5.0, 6.0],
1838+
[10.0, 11.0, 12.0, 13.0],
1839+
]
1840+
1841+
for i, vec in enumerate(vectors):
1842+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1843+
1844+
query = "*=>[KNN 3 @v $vec]"
1845+
q = Query(query).return_field("__v_score").sort_by("__v_score", True)
1846+
res = await decoded_r.ft().search(
1847+
q, query_params={"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1848+
)
1849+
1850+
if is_resp2_connection(decoded_r):
1851+
assert res.total == 3
1852+
assert "doc0" == res.docs[0].id
1853+
else:
1854+
assert res["total_results"] == 3
1855+
assert "doc0" == res["results"][0]["id"]
1856+
1857+
1858+
@pytest.mark.redismod
1859+
@skip_if_server_version_lt("8.1.224")
1860+
async def test_async_svs_vamana_distance_metrics(decoded_r: redis.Redis):
1861+
# Test COSINE distance
1862+
await decoded_r.ft().create_index(
1863+
(
1864+
VectorField(
1865+
"v",
1866+
"SVS-VAMANA",
1867+
{"TYPE": "FLOAT32", "DIM": 3, "DISTANCE_METRIC": "COSINE"},
1868+
),
1869+
)
1870+
)
1871+
1872+
vectors = [[1.0, 0.0, 0.0], [0.707, 0.707, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]]
1873+
1874+
for i, vec in enumerate(vectors):
1875+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1876+
1877+
query = Query("*=>[KNN 2 @v $vec as score]").sort_by("score").no_content()
1878+
query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1879+
1880+
res = await decoded_r.ft().search(query, query_params=query_params)
1881+
if is_resp2_connection(decoded_r):
1882+
assert res.total == 2
1883+
assert "doc0" == res.docs[0].id
1884+
else:
1885+
assert res["total_results"] == 2
1886+
assert "doc0" == res["results"][0]["id"]
1887+
1888+
1889+
@pytest.mark.redismod
1890+
@skip_if_server_version_lt("8.1.224")
1891+
async def test_async_svs_vamana_vector_types(decoded_r: redis.Redis):
1892+
# Test FLOAT16
1893+
await decoded_r.ft("idx16").create_index(
1894+
(
1895+
VectorField(
1896+
"v16",
1897+
"SVS-VAMANA",
1898+
{"TYPE": "FLOAT16", "DIM": 4, "DISTANCE_METRIC": "L2"},
1899+
),
1900+
)
1901+
)
1902+
1903+
vectors = [[1.5, 2.5, 3.5, 4.5], [2.5, 3.5, 4.5, 5.5], [3.5, 4.5, 5.5, 6.5]]
1904+
1905+
for i, vec in enumerate(vectors):
1906+
await decoded_r.hset(
1907+
f"doc16_{i}", "v16", np.array(vec, dtype=np.float16).tobytes()
1908+
)
1909+
1910+
query = Query("*=>[KNN 2 @v16 $vec as score]").no_content()
1911+
query_params = {"vec": np.array(vectors[0], dtype=np.float16).tobytes()}
1912+
1913+
res = await decoded_r.ft("idx16").search(query, query_params=query_params)
1914+
if is_resp2_connection(decoded_r):
1915+
assert res.total == 2
1916+
assert "doc16_0" == res.docs[0].id
1917+
else:
1918+
assert res["total_results"] == 2
1919+
assert "doc16_0" == res["results"][0]["id"]
1920+
1921+
1922+
@pytest.mark.redismod
1923+
@skip_if_server_version_lt("8.1.224")
1924+
async def test_async_svs_vamana_compression(decoded_r: redis.Redis):
1925+
await decoded_r.ft().create_index(
1926+
(
1927+
VectorField(
1928+
"v",
1929+
"SVS-VAMANA",
1930+
{
1931+
"TYPE": "FLOAT32",
1932+
"DIM": 8,
1933+
"DISTANCE_METRIC": "L2",
1934+
"COMPRESSION": "LVQ8",
1935+
"TRAINING_THRESHOLD": 1024,
1936+
},
1937+
),
1938+
)
1939+
)
1940+
1941+
vectors = []
1942+
for i in range(20):
1943+
vec = [float(i + j) for j in range(8)]
1944+
vectors.append(vec)
1945+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1946+
1947+
query = Query("*=>[KNN 5 @v $vec as score]").no_content()
1948+
query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1949+
1950+
res = await decoded_r.ft().search(query, query_params=query_params)
1951+
if is_resp2_connection(decoded_r):
1952+
assert res.total == 5
1953+
assert "doc0" == res.docs[0].id
1954+
else:
1955+
assert res["total_results"] == 5
1956+
assert "doc0" == res["results"][0]["id"]
1957+
1958+
1959+
@pytest.mark.redismod
1960+
@skip_if_server_version_lt("8.1.224")
1961+
async def test_async_svs_vamana_build_parameters(decoded_r: redis.Redis):
1962+
await decoded_r.ft().create_index(
1963+
(
1964+
VectorField(
1965+
"v",
1966+
"SVS-VAMANA",
1967+
{
1968+
"TYPE": "FLOAT32",
1969+
"DIM": 6,
1970+
"DISTANCE_METRIC": "COSINE",
1971+
"CONSTRUCTION_WINDOW_SIZE": 300,
1972+
"GRAPH_MAX_DEGREE": 64,
1973+
"SEARCH_WINDOW_SIZE": 20,
1974+
"EPSILON": 0.05,
1975+
},
1976+
),
1977+
)
1978+
)
1979+
1980+
vectors = []
1981+
for i in range(15):
1982+
vec = [float(i + j) for j in range(6)]
1983+
vectors.append(vec)
1984+
await decoded_r.hset(f"doc{i}", "v", np.array(vec, dtype=np.float32).tobytes())
1985+
1986+
query = Query("*=>[KNN 3 @v $vec as score]").no_content()
1987+
query_params = {"vec": np.array(vectors[0], dtype=np.float32).tobytes()}
1988+
1989+
res = await decoded_r.ft().search(query, query_params=query_params)
1990+
if is_resp2_connection(decoded_r):
1991+
assert res.total == 3
1992+
assert "doc0" == res.docs[0].id
1993+
else:
1994+
assert res["total_results"] == 3
1995+
assert "doc0" == res["results"][0]["id"]

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /