import cgi | |
import urllib | |
import time | |
import random | |
import urlparse | |
import hmac | |
import binascii | |
VERSION = '1.0'  # OAuth protocol version implemented by this module
HTTP_METHOD = 'GET'  # default HTTP method for OAuthRequest objects
SIGNATURE_METHOD = 'PLAINTEXT'  # method assumed by OAuthServer when a request names none
# Generic exception class
class OAuthError(RuntimeError):
    '''Raised for any OAuth protocol, parsing or validation failure.

    The message is available both as `.message` and via str(exc).
    '''
    def __init__(self, message='OAuth error occured.'):
        # initialise the base exception so str(exc) and exc.args carry the
        # message (previously they were empty)
        RuntimeError.__init__(self, message)
        self.message = message
# Optional WWW-Authenticate header, for use with a 401 response.
def build_authenticate_header(realm=''):
    '''Return a one-entry header dict advertising OAuth for `realm`.'''
    challenge = 'OAuth realm="%s"' % realm
    return {'WWW-Authenticate': challenge}
# url escape | |
def escape(s): | |
# escape '/' too | |
return urllib.quote(s, safe='~') | |
# util function: current timestamp
# seconds since epoch (UTC)
def generate_timestamp():
    '''Return the current Unix time, truncated to whole seconds.'''
    seconds = time.time()
    return int(seconds)
# util function: nonce
# random string of decimal digits
def generate_nonce(length=8):
    '''Return a string of `length` random decimal digits.

    Digits come from random.SystemRandom (the OS entropy source) rather
    than the module-level Mersenne Twister.  The Twister's future output is
    predictable from past output, and nonces exist to defeat replays, so
    they should not be guessable.
    '''
    rng = random.SystemRandom()
    return ''.join([str(rng.randint(0, 9)) for _ in range(length)])
# OAuthConsumer is a data type that represents the identity of the Consumer
# via its shared secret with the Service Provider.
class OAuthConsumer(object):
    '''A Consumer's key together with the secret it shares with the Service Provider.'''
    key = None     # consumer key
    secret = None  # consumer secret

    def __init__(self, key, secret):
        self.key, self.secret = key, secret
# OAuthToken is a data type that represents an End User via either an access
# or request token.
class OAuthToken(object):
    '''
    An access token or a request token.

    key = the token
    secret = the token secret
    '''
    key = None
    secret = None

    def __init__(self, key, secret):
        self.key, self.secret = key, secret

    # serialize as a urlencoded oauth_token / oauth_token_secret query string
    def to_string(self):
        fields = {'oauth_token': self.key, 'oauth_token_secret': self.secret}
        return urllib.urlencode(fields)

    # build a token from a query string such as
    # oauth_token_secret=digg&oauth_token=digg
    # (raises KeyError if either field is missing or blank)
    def from_string(s):
        parsed = cgi.parse_qs(s, keep_blank_values=False)
        return OAuthToken(parsed['oauth_token'][0], parsed['oauth_token_secret'][0])
    from_string = staticmethod(from_string)

    def __str__(self):
        return self.to_string()
# OAuthRequest represents the request and can be serialized
class OAuthRequest(object):
    '''
    OAuth parameters:
        - oauth_consumer_key
        - oauth_token
        - oauth_signature_method
        - oauth_signature
        - oauth_timestamp
        - oauth_nonce
        - oauth_version
        ... any additional parameters, as defined by the Service Provider.
    '''
    parameters = None # oauth parameters
    http_method = HTTP_METHOD
    http_url = None
    version = VERSION

    def __init__(self, http_method=HTTP_METHOD, http_url=None, parameters=None):
        self.http_method = http_method
        self.http_url = http_url
        # a falsy `parameters` (None or an empty dict) is replaced by a fresh dict
        self.parameters = parameters or {}

    def set_parameter(self, parameter, value):
        self.parameters[parameter] = value

    # raises OAuthError (not KeyError) when the parameter is absent
    def get_parameter(self, parameter):
        try:
            return self.parameters[parameter]
        except KeyError:
            raise OAuthError('Parameter not found: %s' % parameter)

    # -> (oauth_timestamp, oauth_nonce); raises OAuthError if either is missing
    def _get_timestamp_nonce(self):
        return self.get_parameter('oauth_timestamp'), self.get_parameter('oauth_nonce')

    # get any non-oauth parameters: those whose name does not START with
    # 'oauth_' (a name that merely contains 'oauth_' is kept)
    def get_nonoauth_parameters(self):
        parameters = {}
        for k, v in self.parameters.iteritems():
            if not k.startswith('oauth_'):
                parameters[k] = v
        return parameters

    # serialize as a header for an HTTPAuth request;
    # only oauth_* parameters are included, values escaped and double-quoted
    def to_header(self, realm=''):
        auth_header = 'OAuth realm="%s"' % realm
        # add the oauth parameters
        if self.parameters:
            for k, v in self.parameters.iteritems():
                if k[:6] == 'oauth_':
                    auth_header += ', %s="%s"' % (k, escape(str(v)))
        return {'Authorization': auth_header}

    # serialize as post data for a POST request (all parameters, escaped)
    def to_postdata(self):
        return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) for k, v in self.parameters.iteritems()])

    # serialize as a url for a GET request
    def to_url(self):
        return '%s?%s' % (self.get_normalized_http_url(), self.to_postdata())

    # return a string that consists of all the parameters that need to be signed:
    # every parameter except oauth_signature, sorted by key then value, each key
    # and value escaped, joined as k=v pairs with '&'.
    # oauth_signature is skipped rather than deleted, so the request itself is
    # left unmodified.
    def get_normalized_parameters(self):
        key_values = [(k, v) for k, v in self.parameters.iteritems()
                      if k != 'oauth_signature']
        # sort lexicographically, first after key, then after value
        key_values.sort()
        # combine key value pairs in string and escape
        return '&'.join(['%s=%s' % (escape(str(k)), escape(str(v))) for k, v in key_values])

    # just uppercases the http method
    def get_normalized_http_method(self):
        return self.http_method.upper()

    # parses the url and rebuilds it as scheme://authority/path (query and
    # fragment dropped), lowercasing scheme and authority and removing the
    # port only when it is the default for the scheme (80 for http, 443 for https)
    def get_normalized_http_url(self):
        parts = urlparse.urlparse(self.http_url)
        scheme = parts[0].lower()
        netloc = parts[1].lower()
        default_port = {'http': ':80', 'https': ':443'}.get(scheme)
        if default_port and netloc.endswith(default_port):
            # slice off only the port suffix; keeps any userinfo/IPv6 colons intact
            netloc = netloc[:-len(default_port)]
        return '%s://%s%s' % (scheme, netloc, parts[2]) # scheme, netloc, path

    # set oauth_signature_method and oauth_signature using `signature_method`
    def sign_request(self, signature_method, consumer, token):
        # set the signature method
        self.set_parameter('oauth_signature_method', signature_method.get_name())
        # set the signature
        self.set_parameter('oauth_signature', self.build_signature(signature_method, consumer, token))

    def build_signature(self, signature_method, consumer, token):
        # call the build signature method within the signature method
        return signature_method.build_signature(self, consumer, token)

    # Build a request from an incoming HTTP request.  Later sources override
    # earlier ones: `parameters`, then the OAuth Authorization header, then
    # `query_string`, then the query part of `http_url`.
    # Returns None when no parameters were found at all.
    def from_request(http_method, http_url, headers=None, parameters=None, query_string=None):
        # combine multiple parameter sources into a copy, so the caller's
        # dict is never modified
        if parameters is None:
            parameters = {}
        else:
            parameters = dict(parameters)
        # headers
        if headers and 'Authorization' in headers:
            auth_header = headers['Authorization']
            # only parse OAuth headers; other schemes (e.g. Basic) are ignored
            # (str.index would raise ValueError instead of returning -1)
            if 'OAuth' in auth_header:
                try:
                    # get the parameters from the header
                    header_params = OAuthRequest._split_header(auth_header)
                    parameters.update(header_params)
                except Exception:
                    # any parse failure is reported as an OAuth error
                    raise OAuthError('Unable to parse OAuth parameters from Authorization header.')
        # GET or POST query string
        if query_string:
            query_params = OAuthRequest._split_url_string(query_string)
            parameters.update(query_params)
        # URL parameters
        param_str = urlparse.urlparse(http_url)[4] # query
        url_params = OAuthRequest._split_url_string(param_str)
        parameters.update(url_params)
        if parameters:
            return OAuthRequest(http_method, http_url, parameters)
        return None
    from_request = staticmethod(from_request)

    # Build a consumer-side request pre-filled with consumer key, timestamp,
    # nonce and version; explicit `parameters` override those defaults, and
    # oauth_token is set from `token` when one is given.
    def from_consumer_and_token(oauth_consumer, token=None, http_method=HTTP_METHOD, http_url=None, parameters=None):
        if not parameters:
            parameters = {}
        defaults = {
            'oauth_consumer_key': oauth_consumer.key,
            'oauth_timestamp': generate_timestamp(),
            'oauth_nonce': generate_nonce(),
            'oauth_version': OAuthRequest.version,
        }
        defaults.update(parameters)
        parameters = defaults
        if token:
            parameters['oauth_token'] = token.key
        return OAuthRequest(http_method, http_url, parameters)
    from_consumer_and_token = staticmethod(from_consumer_and_token)

    # Build a request carrying oauth_token (and oauth_callback when given).
    # The caller's `parameters` dict is copied, not modified.
    def from_token_and_callback(token, callback=None, http_method=HTTP_METHOD, http_url=None, parameters=None):
        if parameters:
            parameters = dict(parameters)
        else:
            parameters = {}
        parameters['oauth_token'] = token.key
        if callback:
            parameters['oauth_callback'] = callback
        return OAuthRequest(http_method, http_url, parameters)
    from_token_and_callback = staticmethod(from_token_and_callback)

    # util function: turn Authorization: header into parameters, has to do some unescaping.
    # Assumes the header starts with the 6-character prefix 'OAuth '.
    # The realm parameter is skipped; a part without '=' raises IndexError.
    def _split_header(header):
        params = {}
        parts = header[6:].split(',')
        for param in parts:
            # remove whitespace
            param = param.strip()
            # split key-value
            param_parts = param.split('=', 1)
            # ignore the realm parameter (matched on the key only, so values
            # that happen to contain 'realm' are kept)
            if param_parts[0] == 'realm':
                continue
            # remove quotes and unescape the value
            params[param_parts[0]] = urllib.unquote(param_parts[1].strip('\"'))
        return params
    _split_header = staticmethod(_split_header)

    # util function: turn url string into parameters (first value per key);
    # even empty values should be included
    def _split_url_string(param_str):
        parsed = cgi.parse_qs(param_str, keep_blank_values=True)
        parameters = {}
        for k, v in parsed.iteritems():
            # parse_qs already percent-decodes; decoding a second time would
            # corrupt values containing an encoded '%' (e.g. '%2541' -> 'A')
            parameters[k] = v[0]
        return parameters
    _split_url_string = staticmethod(_split_url_string)
# OAuthServer is a worker to check a requests validity against a data store
class OAuthServer(object):
    timestamp_threshold = 300 # in seconds, five minutes
    version = VERSION
    signature_methods = None
    data_store = None

    def __init__(self, data_store=None, signature_methods=None):
        self.data_store = data_store
        # maps signature method name (e.g. 'HMAC-SHA1') -> method object
        self.signature_methods = signature_methods or {}

    def set_data_store(self, oauth_data_store):
        self.data_store = oauth_data_store

    def get_data_store(self):
        return self.data_store

    # register a signature method under its get_name(); returns the registry
    def add_signature_method(self, signature_method):
        self.signature_methods[signature_method.get_name()] = signature_method
        return self.signature_methods

    # process a request_token request
    # returns the request token on success
    def fetch_request_token(self, oauth_request):
        try:
            # get the request token for authorization
            # NOTE(review): when the request already carries a valid request
            # token, that token is returned WITHOUT a signature check -- confirm
            # this is intended.
            token = self._get_token(oauth_request, 'request')
        except OAuthError:
            # no token required for the initial token request
            self._get_version(oauth_request)
            consumer = self._get_consumer(oauth_request)
            self._check_signature(oauth_request, consumer, None)
            # fetch a new token
            token = self.data_store.fetch_request_token(consumer)
        return token

    # process an access_token request
    # returns the access token on success
    def fetch_access_token(self, oauth_request):
        self._get_version(oauth_request)  # raises OAuthError on an unsupported version
        consumer = self._get_consumer(oauth_request)
        # get the request token
        token = self._get_token(oauth_request, 'request')
        self._check_signature(oauth_request, consumer, token)
        new_token = self.data_store.fetch_access_token(consumer, token)
        return new_token

    # verify an api call, checks all the parameters
    # -> (consumer, access token, non-oauth parameters)
    def verify_request(self, oauth_request):
        self._get_version(oauth_request)  # raises OAuthError on an unsupported version
        consumer = self._get_consumer(oauth_request)
        # get the access token
        token = self._get_token(oauth_request, 'access')
        self._check_signature(oauth_request, consumer, token)
        parameters = oauth_request.get_nonoauth_parameters()
        return consumer, token, parameters

    # authorize a request token
    def authorize_token(self, token, user):
        return self.data_store.authorize_request_token(token, user)

    # get the callback url (raises OAuthError if oauth_callback is absent)
    def get_callback(self, oauth_request):
        return oauth_request.get_parameter('oauth_callback')

    # optional support for the authenticate header
    def build_authenticate_header(self, realm=''):
        return {'WWW-Authenticate': 'OAuth realm="%s"' % realm}

    # verify the correct version request for this server;
    # a missing oauth_version is treated as VERSION
    def _get_version(self, oauth_request):
        try:
            version = oauth_request.get_parameter('oauth_version')
        except OAuthError:
            version = VERSION
        if version and version != self.version:
            raise OAuthError('OAuth version %s not supported.' % str(version))
        return version

    # figure out the signature method object; a request that names none
    # falls back to SIGNATURE_METHOD
    def _get_signature_method(self, oauth_request):
        try:
            signature_method = oauth_request.get_parameter('oauth_signature_method')
        except OAuthError:
            signature_method = SIGNATURE_METHOD
        try:
            # get the signature method object
            return self.signature_methods[signature_method]
        except KeyError:
            signature_method_names = ', '.join(self.signature_methods.keys())
            raise OAuthError('Signature method %s not supported try one of the following: %s' % (signature_method, signature_method_names))

    # look up the consumer named by oauth_consumer_key; raises OAuthError if unknown
    def _get_consumer(self, oauth_request):
        consumer_key = oauth_request.get_parameter('oauth_consumer_key')
        if not consumer_key:
            raise OAuthError('Invalid consumer key.')
        consumer = self.data_store.lookup_consumer(consumer_key)
        if not consumer:
            raise OAuthError('Invalid consumer.')
        return consumer

    # try to find the token for the provided request token key
    def _get_token(self, oauth_request, token_type='access'):
        token_field = oauth_request.get_parameter('oauth_token')
        consumer = self._get_consumer(oauth_request)
        token = self.data_store.lookup_token(consumer, token_type, token_field)
        if not token:
            raise OAuthError('Invalid %s token: %s' % (token_type, token_field))
        return token

    # validate timestamp, nonce and signature; raises OAuthError on any failure
    def _check_signature(self, oauth_request, consumer, token):
        timestamp, nonce = oauth_request._get_timestamp_nonce()
        self._check_timestamp(timestamp)
        self._check_nonce(consumer, token, nonce)
        signature_method = self._get_signature_method(oauth_request)
        try:
            signature = oauth_request.get_parameter('oauth_signature')
        except OAuthError:
            raise OAuthError('Missing signature.')
        # validate the signature
        if signature_method.check_signature(oauth_request, consumer, token, signature):
            return
        base = signature_method.build_signature_base_string(oauth_request, consumer, token)
        # Only (key, raw) results carry a reportable base string.  Methods such
        # as PLAINTEXT return the secrets themselves as a plain string; those
        # must never be echoed back (and cannot be unpacked as a pair anyway).
        if isinstance(base, tuple) and len(base) == 2:
            raise OAuthError('Invalid signature. Expected signature base string: %s' % base[1])
        raise OAuthError('Invalid signature.')

    # verify that timestamp is within timestamp_threshold seconds of now,
    # in either direction, so far-future timestamps are rejected as well
    def _check_timestamp(self, timestamp):
        try:
            timestamp = int(timestamp)
        except (TypeError, ValueError):
            raise OAuthError('Invalid timestamp: %s' % timestamp)
        now = int(time.time())
        lapsed = now - timestamp
        if abs(lapsed) > self.timestamp_threshold:
            raise OAuthError('Expired timestamp: given %d and now %s has a greater difference than threshold %d' % (timestamp, now, self.timestamp_threshold))

    # verify that the nonce is uniqueish: the data store returns a true value
    # for a nonce it has already seen
    def _check_nonce(self, consumer, token, nonce):
        nonce = self.data_store.lookup_nonce(consumer, token, nonce)
        if nonce:
            raise OAuthError('Nonce already used: %s' % str(nonce))
# OAuthClient is a worker to attempt to execute a request
class OAuthClient(object):
    '''Holds a consumer and token; the request operations raise
    NotImplementedError here and are meant to be overridden.'''
    consumer = None
    token = None

    def __init__(self, oauth_consumer, oauth_token):
        self.consumer, self.token = oauth_consumer, oauth_token

    def get_consumer(self):
        return self.consumer

    def get_token(self):
        return self.token

    # -> OAuthToken
    def fetch_request_token(self, oauth_request):
        raise NotImplementedError

    # -> OAuthToken
    def fetch_access_token(self, oauth_request):
        raise NotImplementedError

    # -> some protected resource
    def access_resource(self, oauth_request):
        raise NotImplementedError
# OAuthDataStore is a database abstraction used to lookup consumers and tokens
class OAuthDataStore(object):
    '''Storage interface consulted by OAuthServer.

    Every method here raises NotImplementedError; a concrete store overrides them.
    '''

    def lookup_consumer(self, key):
        # -> OAuthConsumer, or a false value if the key is unknown
        raise NotImplementedError

    def lookup_token(self, oauth_consumer, token_type, token_token):
        # -> OAuthToken, or a false value if not found
        # OAuthServer passes token_type 'request' or 'access'
        raise NotImplementedError

    def lookup_nonce(self, oauth_consumer, oauth_token, nonce, timestamp=None):
        # -> a true value if this nonce was already used, otherwise a false value
        # OAuthServer._check_nonce calls this with three arguments, so
        # `timestamp` is optional to keep this signature compatible with that call.
        raise NotImplementedError

    def fetch_request_token(self, oauth_consumer):
        # -> OAuthToken
        raise NotImplementedError

    def fetch_access_token(self, oauth_consumer, oauth_token):
        # -> OAuthToken
        raise NotImplementedError

    def authorize_request_token(self, oauth_token, user):
        # -> OAuthToken
        raise NotImplementedError
# OAuthSignatureMethod is a strategy class that implements a signature method
class OAuthSignatureMethod(object):
    def get_name(self):
        # -> str
        raise NotImplementedError

    def build_signature_base_string(self, oauth_request, oauth_consumer, oauth_token):
        # -> str key, str raw
        raise NotImplementedError

    def build_signature(self, oauth_request, oauth_consumer, oauth_token):
        # -> str
        raise NotImplementedError

    def check_signature(self, oauth_request, consumer, token, signature):
        '''Return True iff `signature` equals the signature built for this request.

        Strings are compared in constant time (for equal lengths), so response
        timing does not reveal how many leading characters of a forged
        signature were correct.
        '''
        built = self.build_signature(oauth_request, consumer, token)
        try:
            if len(built) != len(signature):
                return False
            difference = 0
            for expected, given in zip(built, signature):
                difference |= ord(expected) ^ ord(given)
        except TypeError:
            # not a pair of character strings; fall back to plain equality
            return built == signature
        return difference == 0
class OAuthSignatureMethod_HMAC_SHA1(OAuthSignatureMethod):
    def get_name(self):
        return 'HMAC-SHA1'

    # -> (key, raw): key is 'escaped consumer secret&escaped token secret'
    # (token part empty when there is no token); raw is the escaped
    # method, URL and normalized parameters joined with '&'
    def build_signature_base_string(self, oauth_request, consumer, token):
        sig = (
            escape(oauth_request.get_normalized_http_method()),
            escape(oauth_request.get_normalized_http_url()),
            escape(oauth_request.get_normalized_parameters()),
        )
        key = '%s&' % escape(consumer.secret)
        if token:
            key += escape(token.secret)
        raw = '&'.join(sig)
        return key, raw

    # -> base64 HMAC-SHA1 of raw keyed with key, without the trailing newline
    def build_signature(self, oauth_request, consumer, token):
        # build the base signature string
        key, raw = self.build_signature_base_string(oauth_request, consumer, token)
        # pick the SHA-1 implementation; only a missing module should trigger
        # the fallback, so hashing errors below are no longer swallowed
        try:
            import hashlib # 2.5
            digestmod = hashlib.sha1
        except ImportError:
            import sha # deprecated
            digestmod = sha
        hashed = hmac.new(key, raw, digestmod)
        # calculate the digest base 64 (b2a_base64 appends '\n'; drop it)
        return binascii.b2a_base64(hashed.digest())[:-1]
class OAuthSignatureMethod_PLAINTEXT(OAuthSignatureMethod):
    def get_name(self):
        return 'PLAINTEXT'

    # -> str: escaped consumer secret, '&', then the escaped token secret
    # (empty when there is no token).  Note this returns a single string,
    # not the (key, raw) pair the base class comment describes.
    def build_signature_base_string(self, oauth_request, consumer, token):
        pieces = [escape(consumer.secret), '']
        if token:
            pieces[1] = escape(token.secret)
        return '&'.join(pieces)

    # the PLAINTEXT signature is the secrets string itself
    def build_signature(self, oauth_request, consumer, token):
        secrets = self.build_signature_base_string(oauth_request, consumer, token)
        return secrets