diff options
Diffstat (limited to 'oidc.py')
| -rw-r--r-- | oidc.py | 206 |
1 files changed, 206 insertions, 0 deletions
@@ -0,0 +1,206 @@ +import time +import secrets +import hashlib +import base64 +import jwt +import json +import urllib.parse +from flask import Blueprint, request, redirect, jsonify +from CTFd.utils.user import authed, get_current_user +from CTFd.models import Users +from CTFd.plugins import bypass_csrf_protection +from cryptography.hazmat.primitives import serialization +from .models import db, OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCAccessToken, OIDCKey +from .crypto import sign_jwt +from .config import get_config + +oidc_blueprint = Blueprint( + "oidc", + __name__, + url_prefix="/oidc" +) + + +@oidc_blueprint.route("/.well-known/openid-configuration") +def discovery(): + issuer = get_config("base_url", "") + return jsonify({ + "issuer": f"{issuer}/oidc/", + "authorization_endpoint": f"{issuer}/oidc/authorize", + "token_endpoint": f"{issuer}/oidc/token", + "userinfo_endpoint": f"{issuer}/oidc/userinfo", + "jwks_uri": f"{issuer}/oidc/jwks", + "response_types_supported": ["code"], + "grant_types_supported": ["authorization_code", "refresh_token"], + "code_challenge_methods_supported": ["S256"], + "token_endpoint_auth_methods_supported": ["client_secret_basic"], + "id_token_signing_alg_values_supported": ["RS256"], + "scopes_supported": ["openid", "profile", "email"], #TODO maybe support these for real + }) + + +@oidc_blueprint.route("/jwks") +def jwks(): + keys = [] + + for key in OIDCKey.query.all(): + public_key = serialization.load_pem_public_key(key.public_pem.encode("utf-8")) + jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key)) + jwk["kid"] = key.kid + keys.append(jwk) + + return jsonify({"keys": keys}) + + +@oidc_blueprint.route("/authorize") +def authorize(): + if not authed(): + return redirect(f"/login?next={urllib.parse.quote(request.full_path)}") + + client_id = request.args.get("client_id") + redirect_uri = request.args.get("redirect_uri") + state = request.args.get("state") + code_challenge = request.args.get("code_challenge") + + client = OIDCClient.query.filter_by(client_id=client_id).first() + if not client or redirect_uri not in str.splitlines(client.redirect_uris): + return "Invalid client", 400 + + if not code_challenge and client.pkce: + return "PKCE required", 400 + + code = secrets.token_urlsafe(32) + + db.session.add(OIDCAuthCode( + code=code, + user_id=get_current_user().id, + client_id=client.client_id, + redirect_uri=redirect_uri, + code_challenge=code_challenge, + exp=int(time.time()) + 300, + )) + db.session.commit() + + return redirect(f"{redirect_uri}?code={code}&state={state}") + + +@oidc_blueprint.route("/token", methods=["POST"]) +@bypass_csrf_protection +def token(): + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Basic "): + return jsonify({"error": "invalid_client"}), 401 + + try: + b64 = auth_header.split(" ", 1)[1] + decoded = base64.b64decode(b64).decode("utf-8") + client_id, client_secret = decoded.split(":", 1) + except Exception: + return jsonify({"error": "invalid_client"}), 401 + + client = OIDCClient.query.filter_by(client_id=client_id).first() + if not client or client.client_secret != client_secret: + return jsonify({"error": "invalid_client"}), 401 + + issuer = get_config("base_url", "") + "/oidc/" + + def create_tokens(thing): + user = thing.user + db.session.delete(thing) + + now = int(time.time()) + + refresh_token = secrets.token_urlsafe(48) + db.session.add(OIDCRefreshToken( + refresh_token=refresh_token, + user_id=user.id, + client_id=client.client_id, + exp=now + 86400, + )) + + access_token = secrets.token_urlsafe(48) + db.session.add(OIDCAccessToken( + access_token=access_token, + user_id=user.id, + client_id=client.client_id, + exp=now + 3600, + )) + + db.session.commit() + + id_token = sign_jwt( + {"sub": str(user.id), "email": user.email, "name": user.name}, + issuer, + client.client_id, + ) + + return jsonify({ + "access_token": access_token, + "id_token": id_token, + "refresh_token": refresh_token, + "token_type": "Bearer", + "expires_in": 3600, + }) + + if request.form["grant_type"] == "authorization_code": + code = request.form.get("code") + redirect_uri = request.form.get("redirect_uri") + + authcode = OIDCAuthCode.query.filter_by(code=code, client_id=client.client_id, redirect_uri=redirect_uri).first() + if not authcode or authcode.exp < time.time(): + return jsonify({"error": "invalid_grant"}), 400 + + if client.pkce: + verifier = request.form.get("code_verifier") + digest = hashlib.sha256(verifier.encode()).digest() + challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + + if challenge != authcode.code_challenge: + return jsonify({"error": "invalid_grant"}), 400 + + return create_tokens(authcode) + + if request.form["grant_type"] == "refresh_token": + old_refresh = request.form["refresh_token"] + + token = OIDCRefreshToken.query.filter_by(refresh_token=old_refresh, client_id=client.client_id).first() + + if (not token or token.exp < time.time()): + return jsonify({"error": "invalid_grant"}), 400 + + return create_tokens(token) + + return jsonify({"error": "unsupported_grant_type"}), 400 + + +@oidc_blueprint.route("/userinfo") +def userinfo(): + auth = request.headers.get("Authorization", "") + if not auth.startswith("Bearer "): + return jsonify({"error": "invalid_token"}), 401 + + access_token = auth.split(" ", 1)[1] + + token = OIDCAccessToken.query.filter_by(access_token=access_token).first() + + if not token or token.exp < time.time(): + return jsonify({"error": "invalid_token"}), 401 + + user = Users.query.get(token.user_id) + if not user: + return jsonify({"error": "invalid_token"}), 401 + + team_id = None + team_name = None + + if hasattr(user, "team") and user.team: + team_id = str(user.team.id) + team_name = user.team.name + + return jsonify({ + "sub": str(user.id), + "email": user.email, + "name": user.name, + "team_id": team_id, + "team_name": team_name, + }) |
