Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit d24c9c9

Browse files
Adding WITHATTRIBS option to vector set's vsim command. (#3746)
1 parent 78fb85e commit d24c9c9

File tree

5 files changed

+293
-17
lines changed

5 files changed

+293
-17
lines changed

‎redis/commands/vectorset/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ def __init__(self, client, **kwargs):
2424
# Set the module commands' callbacks
2525
self._MODULE_CALLBACKS = {
2626
VEMB_CMD: parse_vemb_result,
27+
VSIM_CMD: parse_vsim_result,
2728
VGETATTR_CMD: lambda r: r and json.loads(r) or None,
2829
}
2930

3031
self._RESP2_MODULE_CALLBACKS = {
3132
VINFO_CMD: lambda r: r and pairs_to_dict(r) or None,
32-
VSIM_CMD: parse_vsim_result,
3333
VLINKS_CMD: parse_vlinks_result,
3434
}
3535
self._RESP3_MODULE_CALLBACKS = {}

‎redis/commands/vectorset/commands.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import json
22
from enum import Enum
3-
from typing import Awaitable, Dict, List, Optional, Union
3+
from typing import Any, Awaitable, Dict, List, Optional, Union
44

55
from redis.client import NEVER_DECODE
66
from redis.commands.helpers import get_protocol_version
@@ -19,6 +19,15 @@
1919
VGETATTR_CMD = "VGETATTR"
2020
VRANDMEMBER_CMD = "VRANDMEMBER"
2121

22+
# Return type for vsim command
23+
VSimResult = Optional[
24+
List[
25+
Union[
26+
List[EncodableT], Dict[EncodableT, Number], Dict[EncodableT, Dict[str, Any]]
27+
]
28+
]
29+
]
30+
2231

2332
class QuantizationOptions(Enum):
2433
"""Quantization options for the VADD command."""
@@ -33,6 +42,7 @@ class CallbacksOptions(Enum):
3342

3443
RAW = "RAW"
3544
WITHSCORES = "WITHSCORES"
45+
WITHATTRIBS = "WITHATTRIBS"
3646
ALLOW_DECODING = "ALLOW_DECODING"
3747
RESP3 = "RESP3"
3848

@@ -123,22 +133,22 @@ def vsim(
123133
key: KeyT,
124134
input: Union[List[float], bytes, str],
125135
with_scores: Optional[bool] = False,
136+
with_attribs: Optional[bool] = False,
126137
count: Optional[int] = None,
127138
ef: Optional[Number] = None,
128139
filter: Optional[str] = None,
129140
filter_ef: Optional[str] = None,
130141
truth: Optional[bool] = False,
131142
no_thread: Optional[bool] = False,
132143
epsilon: Optional[Number] = None,
133-
) -> Union[
134-
Awaitable[Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]]],
135-
Optional[List[Union[List[EncodableT], Dict[EncodableT, Number]]]],
136-
]:
144+
) -> Union[Awaitable[VSimResult], VSimResult]:
137145
"""
138146
Compare a vector or element ``input`` with the other vectors in a vector set ``key``.
139147
140-
``with_scores`` sets if the results should be returned with the
141-
similarity scores of the elements in the result.
148+
``with_scores`` sets if similarity scores should be returned for each element in the result.
149+
150+
``with_attribs`` ``with_attribs`` sets if the results should be returned with the
151+
attributes of the elements in the result, or None when no attributes are present.
142152
143153
``count`` sets the number of results to return.
144154
@@ -173,9 +183,17 @@ def vsim(
173183
else:
174184
pieces.extend(["ELE", input])
175185

176-
if with_scores:
177-
pieces.append("WITHSCORES")
178-
options[CallbacksOptions.WITHSCORES.value] = True
186+
if with_scores or with_attribs:
187+
if get_protocol_version(self.client) in ["3", 3]:
188+
options[CallbacksOptions.RESP3.value] = True
189+
190+
if with_scores:
191+
pieces.append("WITHSCORES")
192+
options[CallbacksOptions.WITHSCORES.value] = True
193+
194+
if with_attribs:
195+
pieces.append("WITHATTRIBS")
196+
options[CallbacksOptions.WITHATTRIBS.value] = True
179197

180198
if count:
181199
pieces.extend(["COUNT", count])

‎redis/commands/vectorset/utils.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
from redis._parsers.helpers import pairs_to_dict
24
from redis.commands.vectorset.commands import CallbacksOptions
35

@@ -75,19 +77,53 @@ def parse_vsim_result(response, **options):
7577
structures depending on input options.
7678
Parsing VSIM result into:
7779
- List[List[str]]
78-
- List[Dict[str, Number]]
80+
- List[Dict[str, Number]] - when with_scores is used (without attributes)
81+
- List[Dict[str, Mapping[str, Any]]] - when with_attribs is used (without scores)
82+
- List[Dict[str, Union[Number, Mapping[str, Any]]]] - when with_scores and with_attribs are used
83+
7984
"""
8085
if response is None:
8186
return response
8287

83-
if options.get(CallbacksOptions.WITHSCORES.value):
88+
withscores = bool(options.get(CallbacksOptions.WITHSCORES.value))
89+
withattribs = bool(options.get(CallbacksOptions.WITHATTRIBS.value))
90+
91+
# Exactly one of withscores or withattribs is True
92+
if (withscores and not withattribs) or (not withscores and withattribs):
8493
# Redis will return a list of list of pairs.
8594
# This list have to be transformed to dict
8695
result_dict = {}
87-
for key, value in pairs_to_dict(response).items():
88-
value = float(value)
96+
if options.get(CallbacksOptions.RESP3.value):
97+
resp_dict = response
98+
else:
99+
resp_dict = pairs_to_dict(response)
100+
for key, value in resp_dict.items():
101+
if withscores:
102+
value = float(value)
103+
else:
104+
value = json.loads(value) if value else None
105+
89106
result_dict[key] = value
90107
return result_dict
108+
elif withscores and withattribs:
109+
it = iter(response)
110+
result_dict = {}
111+
if options.get(CallbacksOptions.RESP3.value):
112+
for elem, data in response.items():
113+
if data[1] is not None:
114+
attribs_dict = json.loads(data[1])
115+
else:
116+
attribs_dict = None
117+
result_dict[elem] = {"score": data[0], "attributes": attribs_dict}
118+
else:
119+
for elem, score, attribs in zip(it, it, it):
120+
if attribs is not None:
121+
attribs_dict = json.loads(attribs)
122+
else:
123+
attribs_dict = None
124+
125+
result_dict[elem] = {"score": float(score), "attributes": attribs_dict}
126+
return result_dict
91127
else:
92128
# return the list of elements for each level
93129
# list of lists

‎tests/test_asyncio/test_vsets.py

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,80 @@ async def test_vsim_with_scores(d_client):
262262
assert 0 <= vsim["elem1"] <= 1
263263

264264

265+
@skip_if_server_version_lt("8.2.0")
266+
async def test_vsim_with_attribs_attribs_set(d_client):
267+
elements_count = 5
268+
vector_dim = 10
269+
attrs_dict = {"key1": "value1", "key2": "value2"}
270+
for i in range(elements_count):
271+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
272+
await d_client.vset().vadd(
273+
"myset",
274+
float_array,
275+
f"elem{i}",
276+
numlinks=64,
277+
attributes=attrs_dict if i % 2 == 0 else None,
278+
)
279+
280+
vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True)
281+
assert len(vsim) == 5
282+
assert isinstance(vsim, dict)
283+
assert vsim["elem1"] is None
284+
assert vsim["elem2"] == attrs_dict
285+
286+
287+
@skip_if_server_version_lt("8.2.0")
288+
async def test_vsim_with_scores_and_attribs_attribs_set(d_client):
289+
elements_count = 5
290+
vector_dim = 10
291+
attrs_dict = {"key1": "value1", "key2": "value2"}
292+
for i in range(elements_count):
293+
float_array = [random.uniform(0, 5) for x in range(vector_dim)]
294+
await d_client.vset().vadd(
295+
"myset",
296+
float_array,
297+
f"elem{i}",
298+
numlinks=64,
299+
attributes=attrs_dict if i % 2 == 0 else None,
300+
)
301+
302+
vsim = await d_client.vset().vsim(
303+
"myset", input="elem1", with_scores=True, with_attribs=True
304+
)
305+
assert len(vsim) == 5
306+
assert isinstance(vsim, dict)
307+
assert isinstance(vsim["elem1"], dict)
308+
assert "score" in vsim["elem1"]
309+
assert "attributes" in vsim["elem1"]
310+
assert isinstance(vsim["elem1"]["score"], float)
311+
assert vsim["elem1"]["attributes"] is None
312+
313+
assert isinstance(vsim["elem2"], dict)
314+
assert "score" in vsim["elem2"]
315+
assert "attributes" in vsim["elem2"]
316+
assert isinstance(vsim["elem2"]["score"], float)
317+
assert vsim["elem2"]["attributes"] == attrs_dict
318+
319+
320+
@skip_if_server_version_lt("8.2.0")
321+
async def test_vsim_with_attribs_attribs_not_set(d_client):
322+
elements_count = 20
323+
vector_dim = 50
324+
for i in range(elements_count):
325+
float_array = [random.uniform(0, 10) for x in range(vector_dim)]
326+
await d_client.vset().vadd(
327+
"myset",
328+
float_array,
329+
f"elem{i}",
330+
numlinks=64,
331+
)
332+
333+
vsim = await d_client.vset().vsim("myset", input="elem1", with_attribs=True)
334+
assert len(vsim) == 10
335+
assert isinstance(vsim, dict)
336+
assert vsim["elem1"] is None
337+
338+
265339
@skip_if_server_version_lt("7.9.0")
266340
async def test_vsim_with_different_vector_input_types(d_client):
267341
elements_count = 10
@@ -785,13 +859,51 @@ async def test_vrandmember(d_client):
785859
assert members_list == []
786860

787861

862+
@skip_if_server_version_lt("8.2.0")
863+
async def test_8_2_new_vset_features_without_decoding_responces(client):
864+
# test vadd
865+
elements = ["elem1", "elem2", "elem3"]
866+
attrs_dict = {"key1": "value1", "key2": "value2"}
867+
for elem in elements:
868+
float_array = [random.uniform(0.5, 10) for x in range(0, 8)]
869+
resp = await client.vset().vadd(
870+
"myset", float_array, element=elem, attributes=attrs_dict
871+
)
872+
assert resp == 1
873+
874+
# test vsim with attributes
875+
vsim_with_attribs = await client.vset().vsim(
876+
"myset", input="elem1", with_attribs=True
877+
)
878+
assert len(vsim_with_attribs) == 3
879+
assert isinstance(vsim_with_attribs, dict)
880+
assert isinstance(vsim_with_attribs[b"elem1"], dict)
881+
assert vsim_with_attribs[b"elem1"] == attrs_dict
882+
883+
# test vsim with score and attributes
884+
vsim_with_scores_and_attribs = await client.vset().vsim(
885+
"myset", input="elem1", with_scores=True, with_attribs=True
886+
)
887+
assert len(vsim_with_scores_and_attribs) == 3
888+
assert isinstance(vsim_with_scores_and_attribs, dict)
889+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"], dict)
890+
assert "score" in vsim_with_scores_and_attribs[b"elem1"]
891+
assert "attributes" in vsim_with_scores_and_attribs[b"elem1"]
892+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["score"], float)
893+
assert isinstance(vsim_with_scores_and_attribs[b"elem1"]["attributes"], dict)
894+
assert vsim_with_scores_and_attribs[b"elem1"]["attributes"] == attrs_dict
895+
896+
788897
@skip_if_server_version_lt("7.9.0")
789898
async def test_vset_commands_without_decoding_responces(client):
790899
# test vadd
791900
elements = ["elem1", "elem2", "elem3"]
901+
attrs_dict = {"key1": "value1", "key2": "value2"}
792902
for elem in elements:
793903
float_array = [random.uniform(0.5, 10) for x in range(0, 8)]
794-
resp = await client.vset().vadd("myset", float_array, element=elem)
904+
resp = await client.vset().vadd(
905+
"myset", float_array, element=elem, attributes=attrs_dict
906+
)
795907
assert resp == 1
796908

797909
# test vemb

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /