12
12
LessThan ,
13
13
LessThanOrEqual ,
14
14
)
15
+ from django .db .models .sql import Query
15
16
from django .db .models .sql .where import AND , OR , WhereNode
16
17
17
18
@@ -211,9 +212,14 @@ def as_oracle(self, compiler, connection):
211
212
212
213
class TupleIn (TupleLookupMixin , In ):
213
214
def get_prep_lookup (self ):
214
- self .check_rhs_is_tuple_or_list ()
215
- self .check_rhs_is_collection_of_tuples_or_lists ()
216
- self .check_rhs_elements_length_equals_lhs_length ()
215
+ if self .rhs_is_direct_value ():
216
+ self .check_rhs_is_tuple_or_list ()
217
+ self .check_rhs_is_collection_of_tuples_or_lists ()
218
+ self .check_rhs_elements_length_equals_lhs_length ()
219
+ else :
220
+ self .check_rhs_is_query ()
221
+ self .check_rhs_select_length_equals_lhs_length ()
222
+
217
223
return self .rhs # skip checks from mixin
218
224
219
225
def check_rhs_is_collection_of_tuples_or_lists (self ):
@@ -233,6 +239,25 @@ def check_rhs_elements_length_equals_lhs_length(self):
233
239
f"must have { len_lhs } elements each"
234
240
)
235
241
242
+ def check_rhs_is_query (self ):
243
+ if not isinstance (self .rhs , Query ):
244
+ lhs_str = self .get_lhs_str ()
245
+ rhs_cls = self .rhs .__class__ .__name__
246
+ raise ValueError (
247
+ f"{ self .lookup_name !r} subquery lookup of { lhs_str } "
248
+ f"must be a Query object (received { rhs_cls !r} )"
249
+ )
250
+
251
+ def check_rhs_select_length_equals_lhs_length (self ):
252
+ len_rhs = len (self .rhs .select )
253
+ len_lhs = len (self .lhs )
254
+ if len_rhs != len_lhs :
255
+ lhs_str = self .get_lhs_str ()
256
+ raise ValueError (
257
+ f"{ self .lookup_name !r} subquery lookup of { lhs_str } "
258
+ f"must have { len_lhs } fields (received { len_rhs } )"
259
+ )
260
+
236
261
def process_rhs (self , compiler , connection ):
237
262
rhs = self .rhs
238
263
if not rhs :
@@ -255,10 +280,17 @@ def process_rhs(self, compiler, connection):
255
280
256
281
return Tuple (* result ).as_sql (compiler , connection )
257
282
283
+ def as_sql (self , compiler , connection ):
284
+ if not self .rhs_is_direct_value ():
285
+ return self .as_subquery (compiler , connection )
286
+ return super ().as_sql (compiler , connection )
287
+
258
288
def as_sqlite (self , compiler , connection ):
259
289
rhs = self .rhs
260
290
if not rhs :
261
291
raise EmptyResultSet
292
+ if not self .rhs_is_direct_value ():
293
+ return self .as_subquery (compiler , connection )
262
294
263
295
# e.g.: (a, b, c) in [(x1, y1, z1), (x2, y2, z2)] as SQL:
264
296
# WHERE (a = x1 AND b = y1 AND c = z1) OR (a = x2 AND b = y2 AND c = z2)
@@ -271,6 +303,9 @@ def as_sqlite(self, compiler, connection):
271
303
272
304
return root .as_sql (compiler , connection )
273
305
306
+ def as_subquery (self , compiler , connection ):
307
+ return compiler .compile (In (self .lhs , self .rhs ))
308
+
274
309
275
310
tuple_lookups = {
276
311
"exact" : TupleExact ,
0 commit comments