|
| 1 | +import time |
| 2 | +import urllib.parse as urlparse |
| 3 | + |
| 4 | +import urequests as requests |
| 5 | + |
| 6 | + |
| 7 | +class DeviceAuth: |
| 8 | + ''' |
| 9 | + Helps with authenticating devices with limited input capabilities |
| 10 | + per the OAuth2 device flow specification. |
| 11 | + ''' |
| 12 | + |
| 13 | + def __init__(self, client_id, client_secret, discovery_endpoint, scopes=list()): |
| 14 | + self.client_id = client_id |
| 15 | + self.client_secret = client_secret |
| 16 | + self.discovery_endpoint = discovery_endpoint |
| 17 | + self.scopes = scopes |
| 18 | + |
| 19 | + self.user_code = None |
| 20 | + self.verification_url = None |
| 21 | + |
| 22 | + self._discovered = False |
| 23 | + self._authorization_started = False |
| 24 | + self._authorization_completed = False |
| 25 | + |
| 26 | + self._device_auth_endpoint = None |
| 27 | + self._token_endpoint = None |
| 28 | + self._device_code = None |
| 29 | + self._interval = None |
| 30 | + self._code_expires_in = None |
| 31 | + |
| 32 | + self._access_token = None |
| 33 | + self._token_acquired_at = None |
| 34 | + self._token_expires_in = None |
| 35 | + self._token_scope = None |
| 36 | + self._token_type = None |
| 37 | + self._refresh_token = None |
| 38 | + |
| 39 | + def discover(self): |
| 40 | + ''' |
| 41 | + Performs OAuth2 device endpoint discovery. |
| 42 | + ''' |
| 43 | + |
| 44 | + if not self._discovered: |
| 45 | + r = requests.request('GET', self.discovery_endpoint) |
| 46 | + j = r.json() |
| 47 | + self._device_auth_endpoint = j['device_authorization_endpoint'] |
| 48 | + self._token_endpoint = j['token_endpoint'] |
| 49 | + self._discovered = True |
| 50 | + r.close() |
| 51 | + |
| 52 | + def authorize(self): |
| 53 | + ''' |
| 54 | + Makes an authorization request. |
| 55 | + ''' |
| 56 | + |
| 57 | + if not self._discovered: |
| 58 | + print('Need to discover authorization and token endpoints.') |
| 59 | + return |
| 60 | + |
| 61 | + headers = {'Content-Type': 'application/x-www-form-urlencoded'} |
| 62 | + payload = { |
| 63 | + 'client_id': self.client_id, |
| 64 | + 'scope': ' '.join(self.scopes) |
| 65 | + } |
| 66 | + encoded = urlparse.urlencode(payload) |
| 67 | + r = requests.request( |
| 68 | + 'POST', |
| 69 | + self._device_auth_endpoint, |
| 70 | + data=encoded, |
| 71 | + headers=headers |
| 72 | + ) |
| 73 | + j = r.json() |
| 74 | + r.close() |
| 75 | + |
| 76 | + if 'error' in j: |
| 77 | + raise RuntimeError(j['error']) |
| 78 | + |
| 79 | + self._device_code = j['device_code'] |
| 80 | + self.user_code = j['user_code'] |
| 81 | + self.verification_url = j['verification_url'] |
| 82 | + self._interval = j['interval'] |
| 83 | + self._code_expires_in = j['expires_in'] |
| 84 | + self._authorization_started = True |
| 85 | + message = 'Use code %s at %s to authorize the device.' % ( |
| 86 | + self.user_code, |
| 87 | + self.verification_url |
| 88 | + ) |
| 89 | + print(message) |
| 90 | + |
| 91 | + def check_authorization_complete(self, sleep_duration_seconds=5, max_attempts=10): |
| 92 | + ''' |
| 93 | + Polls until completion of an authorization request. |
| 94 | + ''' |
| 95 | + |
| 96 | + if not self._authorization_started: |
| 97 | + print('Start an authorization request.') |
| 98 | + return |
| 99 | + |
| 100 | + headers = {'Content-Type': 'application/x-www-form-urlencoded'} |
| 101 | + payload = { |
| 102 | + 'client_id': self.client_id, |
| 103 | + 'client_secret': self.client_secret, |
| 104 | + 'device_code': self._device_code, |
| 105 | + 'grant_type': 'urn:ietf:params:oauth:grant-type:device_code' |
| 106 | + } |
| 107 | + encoded = urlparse.urlencode(payload) |
| 108 | + |
| 109 | + current_attempt = 0 |
| 110 | + while not self.authorized and current_attempt < max_attempts: |
| 111 | + current_attempt = current_attempt + 1 |
| 112 | + r = requests.request( |
| 113 | + 'POST', |
| 114 | + self._token_endpoint, |
| 115 | + data=encoded, |
| 116 | + headers=headers |
| 117 | + ) |
| 118 | + j = r.json() |
| 119 | + r.close() |
| 120 | + if 'error' in j: |
| 121 | + if j['error'] == 'authorization_pending': |
| 122 | + print('Pending authorization. ') |
| 123 | + time.sleep(sleep_duration_seconds) |
| 124 | + elif j['error'] == 'access_denied': |
| 125 | + print('Access denied') |
| 126 | + raise RuntimeError(j['error']) |
| 127 | + else: |
| 128 | + self._access_token = j['access_token'] |
| 129 | + self._token_acquired_at = int(time.time()) |
| 130 | + self._token_expires_in = j['expires_in'] |
| 131 | + self._token_scope = j['scope'] |
| 132 | + self._token_type = j['token_type'] |
| 133 | + self._refresh_token = j['refresh_token'] |
| 134 | + print('Completed authorization') |
| 135 | + self._authorization_completed = True |
| 136 | + |
| 137 | + @property |
| 138 | + def authorized(self): |
| 139 | + return self._authorization_completed |
| 140 | + |
| 141 | + def token(self, force_refresh=False): |
| 142 | + ''' |
| 143 | + Fetches a valid access token. |
| 144 | + ''' |
| 145 | + |
| 146 | + if not self._authorization_completed: |
| 147 | + print('Complete an authorization request') |
| 148 | + return |
| 149 | + |
| 150 | + buffer = 10 * 60 * -1 # 10 min in seconds |
| 151 | + now = int(time.time()) |
| 152 | + is_valid = now < ( |
| 153 | + self._token_acquired_at + |
| 154 | + self._token_expires_in + |
| 155 | + buffer |
| 156 | + ) |
| 157 | + if not is_valid or force_refresh: |
| 158 | + print('Token expired. Refreshing access tokens.') |
| 159 | + headers = {'Content-Type': 'application/x-www-form-urlencoded'} |
| 160 | + payload = { |
| 161 | + 'client_id': self.client_id, |
| 162 | + 'client_secret': self.client_secret, |
| 163 | + 'refresh_token': self._refresh_token, |
| 164 | + 'grant_type': 'refresh_token' |
| 165 | + } |
| 166 | + encoded = urlparse.urlencode(payload) |
| 167 | + r = requests.request( |
| 168 | + 'POST', |
| 169 | + self._token_endpoint, |
| 170 | + data=encoded, |
| 171 | + headers=headers |
| 172 | + ) |
| 173 | + status_code = r.status_code |
| 174 | + j = r.json() |
| 175 | + r.close() |
| 176 | + |
| 177 | + if status_code == 400: |
| 178 | + print('Unable to refresh tokens.') |
| 179 | + raise(RuntimeError('Unable to refresh tokens.')) |
| 180 | + |
| 181 | + print('Updated access tokens.') |
| 182 | + self._access_token = j['access_token'] |
| 183 | + self._token_acquired_at = int(time.time()) |
| 184 | + self._token_expires_in = j['expires_in'] |
| 185 | + self._token_scope = j['scope'] |
| 186 | + self._token_type = j['token_type'] |
| 187 | + |
| 188 | + return self._access_token |
0 commit comments