summaryrefslogtreecommitdiffstats
path: root/crypto.py
blob: a278c25bf339b61991028ac8199da29c45f362c3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives import serialization
from .models import db, OIDCKey, OIDCClient
import jwt
import datetime


def generate_rsa_key():
    """Create a new RSA key pair and return both halves PEM-encoded.

    The key uses a 2048-bit modulus with public exponent 65537.

    Returns:
        A ``(private_pem, public_pem)`` tuple. The private key is
        serialized as unencrypted PKCS#8 and the public key as
        SubjectPublicKeyInfo, both in PEM encoding.
    """
    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    public_key = private_key.public_key()

    pem = serialization.Encoding.PEM

    public_pem = public_key.public_bytes(
        encoding=pem,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    # The private key is written without a passphrase; callers are
    # responsible for storing it securely.
    private_pem = private_key.private_bytes(
        encoding=pem,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    return private_pem, public_pem


def sign_jwt(payload, issuer, client_id):
    """Build and RS256-sign a JWT for ``client_id`` using its newest key.

    Args:
        payload: Mapping with the user data to embed. ``"sub"`` is
            required; ``"email"`` and ``"name"`` are optional and default
            to the empty string.
        issuer: Value placed in the ``iss`` claim.
        client_id: Identifies whose ``OIDCKey`` is used for signing; it is
            also the single entry of the ``aud`` claim.

    Returns:
        The encoded token as returned by ``jwt.encode``.

    Raises:
        ValueError: If no ``OIDCKey`` row exists for ``client_id``.
        KeyError: If ``payload`` has no ``"sub"`` entry.
    """
    # Most recently created key for this client wins.
    key = (
        OIDCKey.query.filter_by(client_id=client_id)
        .order_by(OIDCKey.created.desc())
        .first()
    )
    if not key:
        raise ValueError("No RSA key is active for this client")

    # Read the clock once, as a timezone-aware UTC value, so that ``exp`` is
    # exactly one hour after ``iat``. ``datetime.utcnow()`` is deprecated and
    # returns a naive datetime.
    issued_at = datetime.datetime.now(datetime.timezone.utc)

    claims = {
        "sub": payload["sub"],
        "iss": issuer,
        "aud": [client_id],
        "iat": issued_at,
        "exp": issued_at + datetime.timedelta(hours=1),
        "email": payload.get("email", ""),
        "name": payload.get("name", ""),
    }

    # NOTE(review): ``.encode`` implies the stored PEM is a ``str`` — confirm
    # against the OIDCKey model.
    token = jwt.encode(claims, key.private_pem.encode("utf-8"), algorithm="RS256")

    return token