|  | 
|  | 1 | +from __future__ import annotations | 
|  | 2 | + | 
|  | 3 | +import asyncio | 
|  | 4 | +import contextvars | 
|  | 5 | +import functools | 
|  | 6 | +import json | 
|  | 7 | +import urllib.parse | 
|  | 8 | +from io import BytesIO | 
|  | 9 | +from typing import Any, Optional, Union | 
|  | 10 | +from typing_extensions import Literal | 
|  | 11 | + | 
|  | 12 | +from graphql import ExecutionResult | 
|  | 13 | +from webob import Request, Response | 
|  | 14 | + | 
|  | 15 | +from graphql_server.http import GraphQLHTTPResponse | 
|  | 16 | +from graphql_server.http.ides import GraphQL_IDE | 
|  | 17 | +from graphql_server.webob import GraphQLView as BaseGraphQLView | 
|  | 18 | +from tests.http.context import get_context | 
|  | 19 | +from tests.views.schema import Query, schema | 
|  | 20 | + | 
|  | 21 | +from .base import JSON, HttpClient, Response as ClientResponse, ResultOverrideFunction | 
|  | 22 | + | 
|  | 23 | + | 
|  | 24 | +class GraphQLView(BaseGraphQLView[dict[str, object], object]): | 
|  | 25 | + result_override: ResultOverrideFunction = None | 
|  | 26 | + | 
|  | 27 | + def get_root_value(self, request: Request) -> Query: | 
|  | 28 | + super().get_root_value(request) # for coverage | 
|  | 29 | + return Query() | 
|  | 30 | + | 
|  | 31 | + def get_context(self, request: Request, response: Response) -> dict[str, object]: | 
|  | 32 | + context = super().get_context(request, response) | 
|  | 33 | + return get_context(context) | 
|  | 34 | + | 
|  | 35 | + def process_result( | 
|  | 36 | + self, request: Request, result: ExecutionResult, strict: bool = False | 
|  | 37 | + ) -> GraphQLHTTPResponse: | 
|  | 38 | + if self.result_override: | 
|  | 39 | + return self.result_override(result) | 
|  | 40 | + return super().process_result(request, result, strict) | 
|  | 41 | + | 
|  | 42 | + | 
|  | 43 | +class WebobHttpClient(HttpClient): | 
|  | 44 | + def __init__( | 
|  | 45 | + self, | 
|  | 46 | + graphiql: Optional[bool] = None, | 
|  | 47 | + graphql_ide: Optional[GraphQL_IDE] = "graphiql", | 
|  | 48 | + allow_queries_via_get: bool = True, | 
|  | 49 | + result_override: ResultOverrideFunction = None, | 
|  | 50 | + multipart_uploads_enabled: bool = False, | 
|  | 51 | + ) -> None: | 
|  | 52 | + self.view = GraphQLView( | 
|  | 53 | + schema=schema, | 
|  | 54 | + graphiql=graphiql, | 
|  | 55 | + graphql_ide=graphql_ide, | 
|  | 56 | + allow_queries_via_get=allow_queries_via_get, | 
|  | 57 | + multipart_uploads_enabled=multipart_uploads_enabled, | 
|  | 58 | + ) | 
|  | 59 | + self.view.result_override = result_override | 
|  | 60 | + | 
|  | 61 | + async def _graphql_request( | 
|  | 62 | + self, | 
|  | 63 | + method: Literal["get", "post"], | 
|  | 64 | + query: Optional[str] = None, | 
|  | 65 | + operation_name: Optional[str] = None, | 
|  | 66 | + variables: Optional[dict[str, object]] = None, | 
|  | 67 | + files: Optional[dict[str, BytesIO]] = None, | 
|  | 68 | + headers: Optional[dict[str, str]] = None, | 
|  | 69 | + extensions: Optional[dict[str, Any]] = None, | 
|  | 70 | + **kwargs: Any, | 
|  | 71 | + ) -> ClientResponse: | 
|  | 72 | + body = self._build_body( | 
|  | 73 | + query=query, | 
|  | 74 | + operation_name=operation_name, | 
|  | 75 | + variables=variables, | 
|  | 76 | + files=files, | 
|  | 77 | + method=method, | 
|  | 78 | + extensions=extensions, | 
|  | 79 | + ) | 
|  | 80 | + | 
|  | 81 | + data: Union[dict[str, object], str, None] = None | 
|  | 82 | + | 
|  | 83 | + url = "/graphql" | 
|  | 84 | + | 
|  | 85 | + if body and files: | 
|  | 86 | + body.update({name: (file, name) for name, file in files.items()}) | 
|  | 87 | + | 
|  | 88 | + if method == "get": | 
|  | 89 | + body_encoded = urllib.parse.urlencode(body or {}) | 
|  | 90 | + url = f"{url}?{body_encoded}" | 
|  | 91 | + else: | 
|  | 92 | + if body: | 
|  | 93 | + data = body if files else json.dumps(body) | 
|  | 94 | + kwargs["body"] = data | 
|  | 95 | + | 
|  | 96 | + headers = self._get_headers(method=method, headers=headers, files=files) | 
|  | 97 | + | 
|  | 98 | + return await self.request(url, method, headers=headers, **kwargs) | 
|  | 99 | + | 
|  | 100 | + def _do_request( | 
|  | 101 | + self, | 
|  | 102 | + url: str, | 
|  | 103 | + method: Literal["get", "post", "patch", "put", "delete"], | 
|  | 104 | + headers: Optional[dict[str, str]] = None, | 
|  | 105 | + **kwargs: Any, | 
|  | 106 | + ) -> ClientResponse: | 
|  | 107 | + body = kwargs.get("body", None) | 
|  | 108 | + req = Request.blank( | 
|  | 109 | + url, method=method.upper(), headers=headers or {}, body=body | 
|  | 110 | + ) | 
|  | 111 | + resp = self.view.dispatch_request(req) | 
|  | 112 | + return ClientResponse( | 
|  | 113 | + status_code=resp.status_code, data=resp.body, headers=resp.headers | 
|  | 114 | + ) | 
|  | 115 | + | 
|  | 116 | + async def request( | 
|  | 117 | + self, | 
|  | 118 | + url: str, | 
|  | 119 | + method: Literal["head", "get", "post", "patch", "put", "delete"], | 
|  | 120 | + headers: Optional[dict[str, str]] = None, | 
|  | 121 | + **kwargs: Any, | 
|  | 122 | + ) -> ClientResponse: | 
|  | 123 | + loop = asyncio.get_running_loop() | 
|  | 124 | + ctx = contextvars.copy_context() | 
|  | 125 | + func_call = functools.partial( | 
|  | 126 | + ctx.run, self._do_request, url=url, method=method, headers=headers, **kwargs | 
|  | 127 | + ) | 
|  | 128 | + return await loop.run_in_executor(None, func_call) # type: ignore | 
|  | 129 | + | 
|  | 130 | + async def get( | 
|  | 131 | + self, url: str, headers: Optional[dict[str, str]] = None | 
|  | 132 | + ) -> ClientResponse: | 
|  | 133 | + return await self.request(url, "get", headers=headers) | 
|  | 134 | + | 
|  | 135 | + async def post( | 
|  | 136 | + self, | 
|  | 137 | + url: str, | 
|  | 138 | + data: Optional[bytes] = None, | 
|  | 139 | + json: Optional[JSON] = None, | 
|  | 140 | + headers: Optional[dict[str, str]] = None, | 
|  | 141 | + ) -> ClientResponse: | 
|  | 142 | + body = json if json is not None else data | 
|  | 143 | + return await self.request(url, "post", headers=headers, body=body) | 
0 commit comments