diff --git a/ecnuopenapi/auth.py b/ecnuopenapi/auth.py new file mode 100644 index 0000000..cb4328a --- /dev/null +++ b/ecnuopenapi/auth.py @@ -0,0 +1,37 @@ +import secrets +import requests +from requests_oauthlib import OAuth2Session +from ecnuopenapi.oauth_init import GetOpenAPIClient + + +def getAuthorizationEndpoint(state): + c = GetOpenAPIClient() + oauth2_session = OAuth2Session(c.client_id, redirect_uri=c.redirect_url, scope=c.scopes) + authorization_url, state = oauth2_session.authorization_url(c.auth_url, state=state) + return authorization_url + + +def generateState(): + return secrets.token_hex(16) + + +def getToken(code, state): + c = GetOpenAPIClient() + oauth2_session = OAuth2Session(c.client_id, redirect_uri=c.redirect_url, state=state) + try: + token = oauth2_session.fetch_token(c.token_url, client_secret=c.client_secret, code=code) + return token + except Exception as e: + print("获取token失败:", e) + return None + + +def getUserInfo(token): + c = GetOpenAPIClient() + headers = {'Authorization': f'Bearer {token["access_token"]}'} + response = requests.get(c.user_info_url, headers=headers) + if response.status_code == 200: + return response.json() + else: + print("获取用户信息失败") + return None diff --git a/ecnuopenapi/oauth_init.py b/ecnuopenapi/oauth_init.py index f510ee9..6b15b3d 100644 --- a/ecnuopenapi/oauth_init.py +++ b/ecnuopenapi/oauth_init.py @@ -6,10 +6,17 @@ DEFAULT_SCOPE = "ECNU-Basic" DEFAULT_BASE_URL = "https://api.ecnu.edu.cn" DEFAULT_TIMEOUT = 10 +DEFAULT_USER_INFO_URL = "https://api.ecnu.edu.cn/oauth2/userinfo" +DEFAULT_AUTH_URL = "https://api.ecnu.edu.cn/oauth2/authorize" +DEFAULT_TOKEN_URL = "https://api.ecnu.edu.cn/oauth2/token" + open_api_client = None + class OAuth2Config: - def __init__(self, client_id, client_secret, base_url=DEFAULT_BASE_URL, scopes=[DEFAULT_SCOPE], timeout=DEFAULT_TIMEOUT, debug=False): + def __init__(self, client_id, client_secret, redirect_url=None, base_url=DEFAULT_BASE_URL, scopes=[DEFAULT_SCOPE], + timeout=DEFAULT_TIMEOUT, auth_url=DEFAULT_AUTH_URL, token_url=DEFAULT_TOKEN_URL, + user_info_url=DEFAULT_USER_INFO_URL, debug=False): self.client_id = client_id self.client_secret = client_secret self.base_url = base_url @@ -17,33 +24,48 @@ def __init__(self, client_id, client_secret, base_url=DEFAULT_BASE_URL, scopes=[ self.timeout = timeout self.debug = debug + self.redirect_url = redirect_url + self.auth_url = auth_url + self.token_url = token_url + self.user_info_url = user_info_url + + class OAuth2Client: def __init__(self, config): self.client_id = config.client_id self.client_secret = config.client_secret - self.token_url = config.base_url + '/oauth2/token' + self.token_url = config.token_url if config.base_url is None else config.base_url + "/oauth2/token" self.base_url = config.base_url self.token_expiration = None self.oauth2_session = self.createOauth2Session() self.debug = config.debug self.RetryCount = 0 + self.auth_url = config.auth_url + self.auth_url = config.auth_url + self.user_info_url = config.user_info_url + self.redirect_url = config.redirect_url + self.scopes = config.scopes + def createOauth2Session(self): client = BackendApplicationClient(client_id=self.client_id) oauth2_session = OAuth2Session(client=client) return oauth2_session def getAccessToken(self): - if self.token_expiration is None : - token = self.oauth2_session.fetch_token(self.token_url, client_id=self.client_id, client_secret=self.client_secret) + if self.token_expiration is None: + token = self.oauth2_session.fetch_token(self.token_url, client_id=self.client_id, + client_secret=self.client_secret) self.token_expiration = datetime.datetime.now() + datetime.timedelta(seconds=token.get("expires_in", 3600)) # 小于600秒刷新--- if (self.token_expiration - datetime.datetime.now()).total_seconds() <= 600: - token = self.oauth2_session.fetch_token(self.token_url, client_id=self.client_id, client_secret=self.client_secret) + token = self.oauth2_session.fetch_token(self.token_url, client_id=self.client_id, + client_secret=self.client_secret) return self.oauth2_session.access_token def refreshAccessToken(self): - token = self.oauth2_session.fetch_token(self.token_url, client_id=self.client_id, client_secret=self.client_secret) + token = self.oauth2_session.fetch_token(self.token_url, client_id=self.client_id, + client_secret=self.client_secret) self.token_expiration = datetime.datetime.now() + datetime.timedelta(seconds=token.get("expires_in", 3600)) def retryAdd(self): @@ -52,10 +74,16 @@ def retryAdd(self): def retryReset(self): self.RetryCount = 0 + def initOauth2ClientCredentials(config): global open_api_client open_api_client = OAuth2Client(config) + +def initOAuth2AuthorizationCode(config): + global open_api_client + open_api_client = OAuth2Client(config) + + def GetOpenAPIClient(): return open_api_client - diff --git a/example/example_auth_code.py b/example/example_auth_code.py new file mode 100644 index 0000000..873cc30 --- /dev/null +++ b/example/example_auth_code.py @@ -0,0 +1,53 @@ +from functools import wraps + +from flask import Flask, redirect, request, jsonify +import urllib.parse as urlparse +import webbrowser +from ecnuopenapi import auth +from ecnuopenapi import oauth_init + +app = Flask(__name__) +# 初始化OAuth客户端配置 +config = oauth_init.OAuth2Config( + client_id='client_id', + client_secret='client_secret', + redirect_url='http://localhost:8080/user', +) +oauth_init.initOAuth2AuthorizationCode(config) + + +def oauth_required(f): + @wraps(f) + def decorated_function(*args, **kwargs): + code = request.args.get('code') + state = request.args.get('state') + + token = auth.getToken(code, state) + if not token: + return jsonify({'error': 'Failed to fetch token.'}), 500 + + user_info = auth.getUserInfo(token) + if not user_info: + return jsonify({'error': 'Failed to fetch user info.'}), 500 + + kwargs['user_info'] = user_info + return f(*args, **kwargs) + + return decorated_function + + +@app.route('/login') +def login(): + state = auth.generateState() + auth_url = auth.getAuthorizationEndpoint(state) + return redirect(auth_url) + + +@app.route('/user') +@oauth_required +def user_info(user_info): + return jsonify(user_info) + + +if __name__ == '__main__': + app.run(port=8080) diff --git a/example/example_auth_code_simple.py b/example/example_auth_code_simple.py new file mode 100644 index 0000000..4b0cd07 --- /dev/null +++ b/example/example_auth_code_simple.py @@ -0,0 +1,59 @@ +from http.server import BaseHTTPRequestHandler, HTTPServer +import urllib.parse as urlparse +import webbrowser +from ecnuopenapi import auth +from ecnuopenapi import oauth_init + +# 初始化OAuth客户端配置 +config = oauth_init.OAuth2Config( + client_id='client_id', + client_secret='client_secret', + redirect_url='http://localhost:8080/user', +) +oauth_init.initOAuth2AuthorizationCode(config) + + +class SimpleServer(BaseHTTPRequestHandler): + def do_GET(self): + if self.path.startswith('/login'): + state = auth.generateState() + authorization_url = auth.getAuthorizationEndpoint(state) + self.send_response(302) + self.send_header('Location', authorization_url) + self.end_headers() + elif self.path.startswith('/user'): + parsed_path = urlparse.urlparse(self.path) + query = urlparse.parse_qs(parsed_path.query) + code = query.get('code', [None])[0] + state = query.get('state', [None])[0] + + token = auth.getToken(code, state) + if not token: + self._send_text('Failed to fetch token', 400) + return + + user_info = auth.getUserInfo(token) + if not user_info: + self._send_text('Failed to fetch user info', 400) + return + + self._send_text(f'User Info: {user_info}', 200) + else: + self._send_text('Not Found', 404) + + def _send_text(self, text, status_code=200): + self.send_response(status_code) + self.send_header('Content-type', 'text/plain; charset=utf-8') + self.end_headers() + self.wfile.write(text.encode()) + + +def run(server_class=HTTPServer, handler_class=SimpleServer, port=8080): + server_address = ('', port) + httpd = server_class(server_address, handler_class) + print(f'Starting httpd on port {port}...') + httpd.serve_forever() + + +if __name__ == "__main__": + run()