summaryrefslogtreecommitdiffstats
path: root/oidc.py
diff options
context:
space:
mode:
Diffstat (limited to 'oidc.py')
-rw-r--r--oidc.py206
1 files changed, 206 insertions, 0 deletions
diff --git a/oidc.py b/oidc.py
new file mode 100644
index 0000000..71a187c
--- /dev/null
+++ b/oidc.py
@@ -0,0 +1,206 @@
+import time
+import secrets
+import hashlib
+import base64
+import jwt
+import json
+import urllib.parse
+from flask import Blueprint, request, redirect, jsonify
+from CTFd.utils.user import authed, get_current_user
+from CTFd.models import Users
+from CTFd.plugins import bypass_csrf_protection
+from cryptography.hazmat.primitives import serialization
+from .models import db, OIDCClient, OIDCAuthCode, OIDCRefreshToken, OIDCAccessToken, OIDCKey
+from .crypto import sign_jwt
+from .config import get_config
+
+oidc_blueprint = Blueprint(
+ "oidc",
+ __name__,
+ url_prefix="/oidc"
+)
+
+
+@oidc_blueprint.route("/.well-known/openid-configuration")
+def discovery():
+ issuer = get_config("base_url", "")
+ return jsonify({
+ "issuer": f"{issuer}/oidc/",
+ "authorization_endpoint": f"{issuer}/oidc/authorize",
+ "token_endpoint": f"{issuer}/oidc/token",
+ "userinfo_endpoint": f"{issuer}/oidc/userinfo",
+ "jwks_uri": f"{issuer}/oidc/jwks",
+ "response_types_supported": ["code"],
+ "grant_types_supported": ["authorization_code", "refresh_token"],
+ "code_challenge_methods_supported": ["S256"],
+ "token_endpoint_auth_methods_supported": ["client_secret_basic"],
+ "id_token_signing_alg_values_supported": ["RS256"],
+ "scopes_supported": ["openid", "profile", "email"], #TODO maybe support these for real
+ })
+
+
+@oidc_blueprint.route("/jwks")
+def jwks():
+ keys = []
+
+ for key in OIDCKey.query.all():
+ public_key = serialization.load_pem_public_key(key.public_pem.encode("utf-8"))
+ jwk = json.loads(jwt.algorithms.RSAAlgorithm.to_jwk(public_key))
+ jwk["kid"] = key.kid
+ keys.append(jwk)
+
+ return jsonify({"keys": keys})
+
+
+@oidc_blueprint.route("/authorize")
+def authorize():
+ if not authed():
+ return redirect(f"/login?next={urllib.parse.quote(request.full_path)}")
+
+ client_id = request.args.get("client_id")
+ redirect_uri = request.args.get("redirect_uri")
+ state = request.args.get("state")
+ code_challenge = request.args.get("code_challenge")
+
+ client = OIDCClient.query.filter_by(client_id=client_id).first()
+ if not client or redirect_uri not in str.splitlines(client.redirect_uris):
+ return "Invalid client", 400
+
+ if not code_challenge and client.pkce:
+ return "PKCE required", 400
+
+ code = secrets.token_urlsafe(32)
+
+ db.session.add(OIDCAuthCode(
+ code=code,
+ user_id=get_current_user().id,
+ client_id=client.client_id,
+ redirect_uri=redirect_uri,
+ code_challenge=code_challenge,
+ exp=int(time.time()) + 300,
+ ))
+ db.session.commit()
+
+ return redirect(f"{redirect_uri}?code={code}&state={state}")
+
+
+@oidc_blueprint.route("/token", methods=["POST"])
+@bypass_csrf_protection
+def token():
+ auth_header = request.headers.get("Authorization")
+ if not auth_header or not auth_header.startswith("Basic "):
+ return jsonify({"error": "invalid_client"}), 401
+
+ try:
+ b64 = auth_header.split(" ", 1)[1]
+ decoded = base64.b64decode(b64).decode("utf-8")
+ client_id, client_secret = decoded.split(":", 1)
+ except Exception:
+ return jsonify({"error": "invalid_client"}), 401
+
+ client = OIDCClient.query.filter_by(client_id=client_id).first()
+ if not client or client.client_secret != client_secret:
+ return jsonify({"error": "invalid_client"}), 401
+
+ issuer = get_config("base_url", "") + "/oidc/"
+
+ def create_tokens(thing):
+ user = thing.user
+ db.session.delete(thing)
+
+ now = int(time.time())
+
+ refresh_token = secrets.token_urlsafe(48)
+ db.session.add(OIDCRefreshToken(
+ refresh_token=refresh_token,
+ user_id=user.id,
+ client_id=client.client_id,
+ exp=now + 86400,
+ ))
+
+ access_token = secrets.token_urlsafe(48)
+ db.session.add(OIDCAccessToken(
+ access_token=access_token,
+ user_id=user.id,
+ client_id=client.client_id,
+ exp=now + 3600,
+ ))
+
+ db.session.commit()
+
+ id_token = sign_jwt(
+ {"sub": str(user.id), "email": user.email, "name": user.name},
+ issuer,
+ client.client_id,
+ )
+
+ return jsonify({
+ "access_token": access_token,
+ "id_token": id_token,
+ "refresh_token": refresh_token,
+ "token_type": "Bearer",
+ "expires_in": 3600,
+ })
+
+ if request.form["grant_type"] == "authorization_code":
+ code = request.form.get("code")
+ redirect_uri = request.form.get("redirect_uri")
+
+ authcode = OIDCAuthCode.query.filter_by(code=code, client_id=client.client_id, redirect_uri=redirect_uri).first()
+ if not authcode or authcode.exp < time.time():
+ return jsonify({"error": "invalid_grant"}), 400
+
+ if client.pkce:
+ verifier = request.form.get("code_verifier")
+ digest = hashlib.sha256(verifier.encode()).digest()
+ challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
+
+ if challenge != authcode.code_challenge:
+ return jsonify({"error": "invalid_grant"}), 400
+
+ return create_tokens(authcode)
+
+ if request.form["grant_type"] == "refresh_token":
+ old_refresh = request.form["refresh_token"]
+
+ token = OIDCRefreshToken.query.filter_by(refresh_token=old_refresh, client_id=client.client_id).first()
+
+ if (not token or token.exp < time.time()):
+ return jsonify({"error": "invalid_grant"}), 400
+
+ return create_tokens(token)
+
+ return jsonify({"error": "unsupported_grant_type"}), 400
+
+
+@oidc_blueprint.route("/userinfo")
+def userinfo():
+ auth = request.headers.get("Authorization", "")
+ if not auth.startswith("Bearer "):
+ return jsonify({"error": "invalid_token"}), 401
+
+ access_token = auth.split(" ", 1)[1]
+
+ token = OIDCAccessToken.query.filter_by(access_token=access_token).first()
+
+ if not token or token.exp < time.time():
+ return jsonify({"error": "invalid_token"}), 401
+
+ user = Users.query.get(token.user_id)
+ if not user:
+ return jsonify({"error": "invalid_token"}), 401
+
+ team_id = None
+ team_name = None
+
+ if hasattr(user, "team") and user.team:
+ team_id = str(user.team.id)
+ team_name = user.team.name
+
+ return jsonify({
+ "sub": str(user.id),
+ "email": user.email,
+ "name": user.name,
+ "team_id": team_id,
+ "team_name": team_name,
+ })