|  | 
|  | 1 | +import json | 
|  | 2 | +import time | 
|  | 3 | +import urllib.parse as urlparse | 
|  | 4 | + | 
|  | 5 | +import urequests as requests | 
|  | 6 | + | 
|  | 7 | + | 
|  | 8 | +class DeviceAuth: | 
|  | 9 | + ''' | 
|  | 10 | + Helps with authenticating devices with limited input capabilities | 
|  | 11 | + per the OAuth2 device flow specification. | 
|  | 12 | + ''' | 
|  | 13 | + | 
|  | 14 | + def __init__(self, client_id, client_secret, discovery_endpoint, scopes=list()): | 
|  | 15 | + self.client_id = client_id | 
|  | 16 | + self.client_secret = client_secret | 
|  | 17 | + self.discovery_endpoint = discovery_endpoint | 
|  | 18 | + self.scopes = scopes | 
|  | 19 | + | 
|  | 20 | + self.user_code = None | 
|  | 21 | + self.verification_url = None | 
|  | 22 | + | 
|  | 23 | + self._discovered = False | 
|  | 24 | + self._authorization_started = False | 
|  | 25 | + self._authorization_completed = False | 
|  | 26 | + | 
|  | 27 | + self._device_auth_endpoint = None | 
|  | 28 | + self._token_endpoint = None | 
|  | 29 | + self._device_code = None | 
|  | 30 | + self._interval = None | 
|  | 31 | + self._code_expires_in = None | 
|  | 32 | + | 
|  | 33 | + self._access_token = None | 
|  | 34 | + self._token_acquired_at = None | 
|  | 35 | + self._token_expires_in = None | 
|  | 36 | + self._token_scope = None | 
|  | 37 | + self._token_type = None | 
|  | 38 | + self._refresh_token = None | 
|  | 39 | + | 
|  | 40 | + | 
|  | 41 | + def discover(self): | 
|  | 42 | + ''' | 
|  | 43 | + Performs OAuth2 device endpoint discovery. | 
|  | 44 | + ''' | 
|  | 45 | + | 
|  | 46 | + if not self._discovered: | 
|  | 47 | + r = requests.request('GET', self.discovery_endpoint) | 
|  | 48 | + j = r.json() | 
|  | 49 | + self._device_auth_endpoint = j['device_authorization_endpoint'] | 
|  | 50 | + self._token_endpoint = j['token_endpoint'] | 
|  | 51 | + self._discovered = True | 
|  | 52 | + r.close() | 
|  | 53 | + | 
|  | 54 | + | 
|  | 55 | + def authorize(self): | 
|  | 56 | + ''' | 
|  | 57 | + Makes an authorization request. | 
|  | 58 | + ''' | 
|  | 59 | + | 
|  | 60 | + if not self._discovered: | 
|  | 61 | + print('Need to discover authorization and token endpoints.') | 
|  | 62 | + return | 
|  | 63 | + | 
|  | 64 | + headers = {'Content-Type': 'application/x-www-form-urlencoded'} | 
|  | 65 | + payload = { | 
|  | 66 | + 'client_id': self.client_id, | 
|  | 67 | + 'scope': ' '.join(self.scopes) | 
|  | 68 | + } | 
|  | 69 | + encoded = urlparse.urlencode(payload) | 
|  | 70 | + r = requests.request('POST', self._device_auth_endpoint, data=encoded, headers=headers) | 
|  | 71 | + j = r.json() | 
|  | 72 | + r.close() | 
|  | 73 | + | 
|  | 74 | + if 'error' in j: | 
|  | 75 | + raise RuntimeError(j['error']) | 
|  | 76 | + | 
|  | 77 | + self._device_code = j['device_code'] | 
|  | 78 | + self.user_code = j['user_code'] | 
|  | 79 | + self.verification_url = j['verification_url'] | 
|  | 80 | + self._interval = j['interval'] | 
|  | 81 | + self._code_expires_in = j['expires_in'] | 
|  | 82 | + self._authorization_started = True | 
|  | 83 | + message = 'Use code %s at %s to authorize the device.' % (self.user_code, self.verification_url) | 
|  | 84 | + print(message) | 
|  | 85 | + | 
|  | 86 | + | 
|  | 87 | + def check_authorization_complete(self, sleep_duration_seconds=5, max_attempts=10): | 
|  | 88 | + ''' | 
|  | 89 | + Polls until completion of an authorization request. | 
|  | 90 | + ''' | 
|  | 91 | + | 
|  | 92 | + if not self._authorization_started: | 
|  | 93 | + print('Start an authorization request.') | 
|  | 94 | + return | 
|  | 95 | + | 
|  | 96 | + headers = {'Content-Type': 'application/x-www-form-urlencoded'} | 
|  | 97 | + payload = { | 
|  | 98 | + 'client_id': self.client_id, | 
|  | 99 | + 'client_secret': self.client_secret, | 
|  | 100 | + 'device_code': self._device_code, | 
|  | 101 | + 'grant_type': 'urn:ietf:params:oauth:grant-type:device_code' | 
|  | 102 | + } | 
|  | 103 | + encoded = urlparse.urlencode(payload) | 
|  | 104 | + | 
|  | 105 | + current_attempt = 0 | 
|  | 106 | + while not self._authorization_completed and current_attempt < max_attempts: | 
|  | 107 | + current_attempt = current_attempt + 1 | 
|  | 108 | + r = requests.request('POST', self._token_endpoint, data=encoded, headers=headers) | 
|  | 109 | + j = r.json() | 
|  | 110 | + r.close() | 
|  | 111 | + if 'error' in j: | 
|  | 112 | + if j['error'] == 'authorization_pending': | 
|  | 113 | + print('Pending authorization. ') | 
|  | 114 | + time.sleep(sleep_duration_seconds) | 
|  | 115 | + elif j['error'] == 'access_denied': | 
|  | 116 | + print('Access denied') | 
|  | 117 | + raise RuntimeError(j['error']) | 
|  | 118 | + else: | 
|  | 119 | + self._access_token = j['access_token'] | 
|  | 120 | + self._token_acquired_at = int(time.time()) | 
|  | 121 | + self._token_expires_in = j['expires_in'] | 
|  | 122 | + self._token_scope = j['scope'] | 
|  | 123 | + self._token_type = j['token_type'] | 
|  | 124 | + self._refresh_token = j['refresh_token'] | 
|  | 125 | + print('Completed authorization') | 
|  | 126 | + self._authorization_completed = True | 
|  | 127 | + | 
|  | 128 | + | 
|  | 129 | + def token(self, force_refresh=False): | 
|  | 130 | + ''' | 
|  | 131 | + Fetches a valid access token. | 
|  | 132 | + ''' | 
|  | 133 | + | 
|  | 134 | + if not self._authorization_completed: | 
|  | 135 | + print('Complete an authorization request') | 
|  | 136 | + return | 
|  | 137 | + | 
|  | 138 | + buffer = 10 * 60 * -1 # 10 min in seconds | 
|  | 139 | + now = int(time.time()) | 
|  | 140 | + is_valid = now < (self._token_acquired_at + self._token_expires_in + buffer) | 
|  | 141 | + | 
|  | 142 | + if not is_valid or force_refresh: | 
|  | 143 | + print('Token expired. Refreshing access tokens.') | 
|  | 144 | + headers = {'Content-Type': 'application/x-www-form-urlencoded'} | 
|  | 145 | + payload = { | 
|  | 146 | + 'client_id': self.client_id, | 
|  | 147 | + 'client_secret': self.client_secret, | 
|  | 148 | + 'refresh_token': self._refresh_token, | 
|  | 149 | + 'grant_type': 'refresh_token' | 
|  | 150 | + } | 
|  | 151 | + encoded = urlparse.urlencode(payload) | 
|  | 152 | + r = requests.request('POST', self._token_endpoint, data=encoded, headers=headers) | 
|  | 153 | + status_code = r.status_code | 
|  | 154 | + j = r.json() | 
|  | 155 | + r.close() | 
|  | 156 | + | 
|  | 157 | + if status_code == 400: | 
|  | 158 | + print('Unable to refresh tokens.') | 
|  | 159 | + raise(RuntimeError('Unable to refresh tokens.')) | 
|  | 160 | + | 
|  | 161 | + print('Updated access tokens.') | 
|  | 162 | + self._access_token = j['access_token'] | 
|  | 163 | + self._token_acquired_at = int(time.time()) | 
|  | 164 | + self._token_expires_in = j['expires_in'] | 
|  | 165 | + self._token_scope = j['scope'] | 
|  | 166 | + self._token_type = j['token_type'] | 
|  | 167 | + | 
|  | 168 | + return self._access_token | 
0 commit comments