Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

Commit 7f60d30

Browse files
authored
feat(pipelines): get API token from cog's current_scope, if available (#48)
* feat(pipelines): get API token from cog's current_scope, if available * lint more
1 parent 22f4fe8 commit 7f60d30

File tree

4 files changed

+265
-3
lines changed

4 files changed

+265
-3
lines changed

‎src/replicate/_client.py‎

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import httpx
2222

23+
from replicate.lib.cog import _get_api_token_from_environment
2324
from replicate.lib._files import FileEncodingStrategy
2425
from replicate.lib._predictions_run import Model, Version, ModelVersionIdentifier
2526
from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion
@@ -108,7 +109,7 @@ def __init__(
108109
This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
109110
"""
110111
if bearer_token is None:
111-
bearer_token = os.environ.get("REPLICATE_API_TOKEN")
112+
bearer_token = _get_api_token_from_environment()
112113
if bearer_token is None:
113114
raise ReplicateError(
114115
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"
@@ -419,7 +420,7 @@ def __init__(
419420
This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided.
420421
"""
421422
if bearer_token is None:
422-
bearer_token = os.environ.get("REPLICATE_API_TOKEN")
423+
bearer_token = _get_api_token_from_environment()
423424
if bearer_token is None:
424425
raise ReplicateError(
425426
"The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable"

‎src/replicate/lib/cog.py‎

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Cog integration utilities for Replicate."""
2+
3+
import os
4+
from typing import Any, Union, Iterator, cast
5+
6+
from replicate._utils._logs import logger
7+
8+
9+
def _get_api_token_from_environment() -> Union[str, None]:
10+
"""Get API token from cog current scope if available, otherwise from environment."""
11+
try:
12+
import cog # type: ignore[import-untyped, import-not-found]
13+
14+
# Get the current scope - this might return None or raise an exception
15+
scope = getattr(cog, "current_scope", lambda: None)()
16+
if scope is None:
17+
return os.environ.get("REPLICATE_API_TOKEN")
18+
19+
# Get the context from the scope
20+
context = getattr(scope, "context", None)
21+
if context is None:
22+
return os.environ.get("REPLICATE_API_TOKEN")
23+
24+
# Get the items method and call it
25+
items_method = getattr(context, "items", None)
26+
if not callable(items_method):
27+
return os.environ.get("REPLICATE_API_TOKEN")
28+
29+
# Iterate through context items looking for the API token
30+
items = cast(Iterator["tuple[Any, Any]"], items_method())
31+
for key, value in items:
32+
if str(key).upper() == "REPLICATE_API_TOKEN":
33+
return str(value) if value is not None else value
34+
35+
except Exception as e: # Catch all exceptions to ensure robust fallback
36+
logger.debug("Failed to retrieve API token from cog.current_scope(): %s", e)
37+
38+
return os.environ.get("REPLICATE_API_TOKEN")
39+
40+
41+
__all__ = ["_get_api_token_from_environment"]

‎src/replicate/types/prediction_create_params.py‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class PredictionCreateParamsWithoutVersion(TypedDict, total=False):
3535
- you don't want to upload and host the file somewhere
3636
- you don't need to use the file again (Replicate will not store it)
3737
"""
38-
38+
3939
stream: bool
4040
"""**This field is deprecated.**
4141

‎tests/test_current_scope.py‎

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
"""Tests for current_scope token functionality."""
2+
3+
import os
4+
import sys
5+
from unittest import mock
6+
7+
import pytest
8+
9+
from replicate import Replicate, AsyncReplicate
10+
from replicate.lib.cog import _get_api_token_from_environment
11+
from replicate._exceptions import ReplicateError
12+
13+
14+
class TestGetApiTokenFromEnvironment:
15+
"""Test the _get_api_token_from_environment function."""
16+
17+
def test_cog_no_current_scope_method_falls_back_to_env(self):
18+
"""Test fallback when cog exists but has no current_scope method."""
19+
mock_cog = mock.MagicMock()
20+
del mock_cog.current_scope # Remove the method
21+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
22+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
23+
token = _get_api_token_from_environment()
24+
assert token == "env-token"
25+
26+
def test_cog_current_scope_returns_none_falls_back_to_env(self):
27+
"""Test fallback when current_scope() returns None."""
28+
mock_cog = mock.MagicMock()
29+
mock_cog.current_scope.return_value = None
30+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
31+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
32+
token = _get_api_token_from_environment()
33+
assert token == "env-token"
34+
35+
def test_cog_scope_no_context_attr_falls_back_to_env(self):
36+
"""Test fallback when scope has no context attribute."""
37+
mock_scope = mock.MagicMock()
38+
del mock_scope.context # Remove the context attribute
39+
mock_cog = mock.MagicMock()
40+
mock_cog.current_scope.return_value = mock_scope
41+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
42+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
43+
token = _get_api_token_from_environment()
44+
assert token == "env-token"
45+
46+
def test_cog_scope_context_not_dict_falls_back_to_env(self):
47+
"""Test fallback when scope.context is not a dictionary."""
48+
mock_scope = mock.MagicMock()
49+
mock_scope.context = "not a dict"
50+
mock_cog = mock.MagicMock()
51+
mock_cog.current_scope.return_value = mock_scope
52+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
53+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
54+
token = _get_api_token_from_environment()
55+
assert token == "env-token"
56+
57+
def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self):
58+
"""Test fallback when replicate_api_token key is missing from context."""
59+
mock_scope = mock.MagicMock()
60+
mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token
61+
mock_cog = mock.MagicMock()
62+
mock_cog.current_scope.return_value = mock_scope
63+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
64+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
65+
token = _get_api_token_from_environment()
66+
assert token == "env-token"
67+
68+
def test_cog_scope_replicate_api_token_valid_string(self):
69+
"""Test successful retrieval of non-empty token from cog."""
70+
mock_scope = mock.MagicMock()
71+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
72+
mock_cog = mock.MagicMock()
73+
mock_cog.current_scope.return_value = mock_scope
74+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
75+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
76+
token = _get_api_token_from_environment()
77+
assert token == "cog-token"
78+
79+
def test_cog_scope_replicate_api_token_case_insensitive(self):
80+
"""Test successful retrieval of non-empty token from cog ignoring case."""
81+
mock_scope = mock.MagicMock()
82+
mock_scope.context = {"replicate_api_token": "cog-token"}
83+
mock_cog = mock.MagicMock()
84+
mock_cog.current_scope.return_value = mock_scope
85+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
86+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
87+
token = _get_api_token_from_environment()
88+
assert token == "cog-token"
89+
90+
def test_cog_scope_replicate_api_token_empty_string(self):
91+
"""Test that empty string from cog is returned (not falling back to env)."""
92+
mock_scope = mock.MagicMock()
93+
mock_scope.context = {"replicate_api_token": ""} # Empty string
94+
mock_cog = mock.MagicMock()
95+
mock_cog.current_scope.return_value = mock_scope
96+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
97+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
98+
token = _get_api_token_from_environment()
99+
assert token == "" # Should return empty string, not env token
100+
101+
def test_cog_scope_replicate_api_token_none(self):
102+
"""Test that None from cog is returned (not falling back to env)."""
103+
mock_scope = mock.MagicMock()
104+
mock_scope.context = {"replicate_api_token": None}
105+
mock_cog = mock.MagicMock()
106+
mock_cog.current_scope.return_value = mock_scope
107+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
108+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
109+
token = _get_api_token_from_environment()
110+
assert token is None # Should return None, not env token
111+
112+
def test_cog_current_scope_raises_exception_falls_back_to_env(self):
113+
"""Test fallback when current_scope() raises an exception."""
114+
mock_cog = mock.MagicMock()
115+
mock_cog.current_scope.side_effect = RuntimeError("Scope error")
116+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
117+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
118+
token = _get_api_token_from_environment()
119+
assert token == "env-token"
120+
121+
def test_no_env_token_returns_none(self):
122+
"""Test that None is returned when no environment token is set and cog unavailable."""
123+
with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars
124+
with mock.patch.dict(sys.modules, {"cog": None}):
125+
token = _get_api_token_from_environment()
126+
assert token is None
127+
128+
def test_env_token_empty_string(self):
129+
"""Test that empty string from environment is returned."""
130+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}):
131+
with mock.patch.dict(sys.modules, {"cog": None}):
132+
token = _get_api_token_from_environment()
133+
assert token == ""
134+
135+
def test_env_token_valid_string(self):
136+
"""Test that valid token from environment is returned."""
137+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
138+
with mock.patch.dict(sys.modules, {"cog": None}):
139+
token = _get_api_token_from_environment()
140+
assert token == "env-token"
141+
142+
143+
class TestClientCurrentScopeIntegration:
144+
"""Test that the client uses current_scope functionality."""
145+
146+
def test_sync_client_uses_current_scope_token(self):
147+
"""Test that sync client retrieves token from current_scope."""
148+
mock_scope = mock.MagicMock()
149+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
150+
mock_cog = mock.MagicMock()
151+
mock_cog.current_scope.return_value = mock_scope
152+
153+
# Clear environment variable to ensure we're using cog
154+
with mock.patch.dict(os.environ, {}, clear=True):
155+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
156+
client = Replicate(base_url="http://test.example.com")
157+
assert client.bearer_token == "cog-token"
158+
159+
def test_async_client_uses_current_scope_token(self):
160+
"""Test that async client retrieves token from current_scope."""
161+
mock_scope = mock.MagicMock()
162+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
163+
mock_cog = mock.MagicMock()
164+
mock_cog.current_scope.return_value = mock_scope
165+
166+
# Clear environment variable to ensure we're using cog
167+
with mock.patch.dict(os.environ, {}, clear=True):
168+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
169+
client = AsyncReplicate(base_url="http://test.example.com")
170+
assert client.bearer_token == "cog-token"
171+
172+
def test_sync_client_falls_back_to_env_when_cog_unavailable(self):
173+
"""Test that sync client falls back to env when cog is unavailable."""
174+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
175+
with mock.patch.dict(sys.modules, {"cog": None}):
176+
client = Replicate(base_url="http://test.example.com")
177+
assert client.bearer_token == "env-token"
178+
179+
def test_async_client_falls_back_to_env_when_cog_unavailable(self):
180+
"""Test that async client falls back to env when cog is unavailable."""
181+
with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}):
182+
with mock.patch.dict(sys.modules, {"cog": None}):
183+
client = AsyncReplicate(base_url="http://test.example.com")
184+
assert client.bearer_token == "env-token"
185+
186+
def test_sync_client_raises_error_when_no_token_available(self):
187+
"""Test that sync client raises error when no token is available."""
188+
with mock.patch.dict(os.environ, {}, clear=True):
189+
with mock.patch.dict(sys.modules, {"cog": None}):
190+
with pytest.raises(ReplicateError, match="bearer_token client option must be set"):
191+
Replicate(base_url="http://test.example.com")
192+
193+
def test_async_client_raises_error_when_no_token_available(self):
194+
"""Test that async client raises error when no token is available."""
195+
with mock.patch.dict(os.environ, {}, clear=True):
196+
with mock.patch.dict(sys.modules, {"cog": None}):
197+
with pytest.raises(ReplicateError, match="bearer_token client option must be set"):
198+
AsyncReplicate(base_url="http://test.example.com")
199+
200+
def test_explicit_token_overrides_current_scope(self):
201+
"""Test that explicitly provided token overrides current_scope."""
202+
mock_scope = mock.MagicMock()
203+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
204+
mock_cog = mock.MagicMock()
205+
mock_cog.current_scope.return_value = mock_scope
206+
207+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
208+
client = Replicate(bearer_token="explicit-token", base_url="http://test.example.com")
209+
assert client.bearer_token == "explicit-token"
210+
211+
def test_explicit_async_token_overrides_current_scope(self):
212+
"""Test that explicitly provided token overrides current_scope for async client."""
213+
mock_scope = mock.MagicMock()
214+
mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"}
215+
mock_cog = mock.MagicMock()
216+
mock_cog.current_scope.return_value = mock_scope
217+
218+
with mock.patch.dict(sys.modules, {"cog": mock_cog}):
219+
client = AsyncReplicate(bearer_token="explicit-token", base_url="http://test.example.com")
220+
assert client.bearer_token == "explicit-token"

0 commit comments

Comments
(0)

AltStyle によって変換されたページ (->オリジナル) /