Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit f7601ae

Browse files
csirmazbendeguzsarahboyce
authored andcommitted
Refs #373 -- Added TupleIn subqueries.
1 parent 611bf6c commit f7601ae

File tree

2 files changed

+79
-3
lines changed

2 files changed

+79
-3
lines changed

‎django/db/models/fields/tuple_lookups.py‎

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
LessThan,
1313
LessThanOrEqual,
1414
)
15+
from django.db.models.sql import Query
1516
from django.db.models.sql.where import AND, OR, WhereNode
1617

1718

@@ -211,9 +212,14 @@ def as_oracle(self, compiler, connection):
211212

212213
class TupleIn(TupleLookupMixin, In):
213214
def get_prep_lookup(self):
214-
self.check_rhs_is_tuple_or_list()
215-
self.check_rhs_is_collection_of_tuples_or_lists()
216-
self.check_rhs_elements_length_equals_lhs_length()
215+
if self.rhs_is_direct_value():
216+
self.check_rhs_is_tuple_or_list()
217+
self.check_rhs_is_collection_of_tuples_or_lists()
218+
self.check_rhs_elements_length_equals_lhs_length()
219+
else:
220+
self.check_rhs_is_query()
221+
self.check_rhs_select_length_equals_lhs_length()
222+
217223
return self.rhs # skip checks from mixin
218224

219225
def check_rhs_is_collection_of_tuples_or_lists(self):
@@ -233,6 +239,25 @@ def check_rhs_elements_length_equals_lhs_length(self):
233239
f"must have {len_lhs} elements each"
234240
)
235241

242+
def check_rhs_is_query(self):
243+
if not isinstance(self.rhs, Query):
244+
lhs_str = self.get_lhs_str()
245+
rhs_cls = self.rhs.__class__.__name__
246+
raise ValueError(
247+
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
248+
f"must be a Query object (received {rhs_cls!r})"
249+
)
250+
251+
def check_rhs_select_length_equals_lhs_length(self):
252+
len_rhs = len(self.rhs.select)
253+
len_lhs = len(self.lhs)
254+
if len_rhs != len_lhs:
255+
lhs_str = self.get_lhs_str()
256+
raise ValueError(
257+
f"{self.lookup_name!r} subquery lookup of {lhs_str} "
258+
f"must have {len_lhs} fields (received {len_rhs})"
259+
)
260+
236261
def process_rhs(self, compiler, connection):
237262
rhs = self.rhs
238263
if not rhs:
@@ -255,10 +280,17 @@ def process_rhs(self, compiler, connection):
255280

256281
return Tuple(*result).as_sql(compiler, connection)
257282

283+
def as_sql(self, compiler, connection):
284+
if not self.rhs_is_direct_value():
285+
return self.as_subquery(compiler, connection)
286+
return super().as_sql(compiler, connection)
287+
258288
def as_sqlite(self, compiler, connection):
259289
rhs = self.rhs
260290
if not rhs:
261291
raise EmptyResultSet
292+
if not self.rhs_is_direct_value():
293+
return self.as_subquery(compiler, connection)
262294

263295
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
264296
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
@@ -271,6 +303,9 @@ def as_sqlite(self, compiler, connection):
271303

272304
return root.as_sql(compiler, connection)
273305

306+
def as_subquery(self, compiler, connection):
307+
return compiler.compile(In(self.lhs, self.rhs))
308+
274309

275310
tuple_lookups = {
276311
"exact": TupleExact,

‎tests/foreign_object/test_tuple_lookups.py‎

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
TupleLessThan,
1212
TupleLessThanOrEqual,
1313
)
14+
from django.db.models.lookups import In
1415
from django.test import TestCase, skipUnlessDBFeature
1516

1617
from .models import Contact, Customer
@@ -126,6 +127,46 @@ def test_in_subquery(self):
126127
(self.contact_1, self.contact_2, self.contact_5),
127128
)
128129

130+
def test_tuple_in_subquery_must_be_query(self):
131+
lhs = (F("customer_code"), F("company_code"))
132+
# If rhs is any non-Query object with an as_sql() function.
133+
rhs = In(F("customer_code"), [1, 2, 3])
134+
with self.assertRaisesMessage(
135+
ValueError,
136+
"'in' subquery lookup of ('customer_code', 'company_code') "
137+
"must be a Query object (received 'In')",
138+
):
139+
TupleIn(lhs, rhs)
140+
141+
def test_tuple_in_subquery_must_have_2_fields(self):
142+
lhs = (F("customer_code"), F("company_code"))
143+
rhs = Customer.objects.values_list("customer_id").query
144+
with self.assertRaisesMessage(
145+
ValueError,
146+
"'in' subquery lookup of ('customer_code', 'company_code') "
147+
"must have 2 fields (received 1)",
148+
):
149+
TupleIn(lhs, rhs)
150+
151+
def test_tuple_in_subquery(self):
152+
customers = Customer.objects.values_list("customer_id", "company")
153+
test_cases = (
154+
(self.customer_1, (self.contact_1, self.contact_2, self.contact_5)),
155+
(self.customer_2, (self.contact_3,)),
156+
(self.customer_3, (self.contact_4,)),
157+
(self.customer_4, ()),
158+
(self.customer_5, (self.contact_6,)),
159+
)
160+
161+
for customer, contacts in test_cases:
162+
lhs = (F("customer_code"), F("company_code"))
163+
rhs = customers.filter(id=customer.id).query
164+
lookup = TupleIn(lhs, rhs)
165+
qs = Contact.objects.filter(lookup).order_by("id")
166+
167+
with self.subTest(customer=customer.id, query=str(qs.query)):
168+
self.assertSequenceEqual(qs, contacts)
169+
129170
def test_tuple_in_rhs_must_be_collection_of_tuples_or_lists(self):
130171
test_cases = (
131172
(1, 2, 3),

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /