274 lines
10 KiB
Python
274 lines
10 KiB
Python
# app.py
|
|
# A lightweight OIDC Provider that uses Discord as an upstream identity provider.
|
|
# This application is designed to work with the PeerTube OpenID Connect plugin.
|
|
|
|
import os
|
|
import time
|
|
import json
|
|
import base64
|
|
from urllib.parse import urlencode
|
|
from collections import deque
|
|
from datetime import datetime
|
|
|
|
import requests
|
|
from flask import Flask, request, jsonify, redirect, render_template
|
|
# JWT encoding (jwt) and JWK export (jwk) come from python-jose.
|
|
from jose import jwt, jwk
|
|
from cryptography.hazmat.primitives import serialization
|
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
from dotenv import load_dotenv
|
|
|
|
# Load variables from a .env file (when one is found) into the process
# environment before the configuration values are read.
load_dotenv()

app = Flask(__name__)

# --- In-memory store for recent requests ---
# Bounded to the 20 newest entries: once the deque is full, each new entry
# pushes the oldest one out of the opposite end, so memory use stays constant.
# Contents live only in this process and are lost on restart.
recent_requests = deque([], maxlen=20)
|
|
|
|
# --- Configuration from Environment Variables ---
|
|
# Load all settings from the .env file or system environment.
|
|
# Read once at import time; missing values stay None (the __main__ guard
# refuses to start the development server in that case).
BASE_URL = os.getenv('OIDC_PROVIDER_URL')
DISCORD_CLIENT_ID = os.getenv('DISCORD_CLIENT_ID')
DISCORD_CLIENT_SECRET = os.getenv('DISCORD_CLIENT_SECRET')
# NOTE(review): these fallbacks are publicly known values; configure private
# ones for any real deployment.
OIDC_CLIENT_ID = os.getenv('OIDC_CLIENT_ID', 'peertube')
OIDC_CLIENT_SECRET = os.getenv('OIDC_CLIENT_SECRET', 'peertube-secret')
PEERTUBE_CALLBACK_URL = os.getenv('PEERTUBE_CALLBACK_URL')

# Where Discord sends the browser after consent. It must exactly match a
# redirect URI registered in the Discord Developer Portal; None when no base
# URL is configured.
if BASE_URL:
    DISCORD_REDIRECT_URI = f"{BASE_URL}/discord/callback"
else:
    DISCORD_REDIRECT_URI = None

# RSA-2048 key pair for signing ID Tokens, generated fresh in memory at import.
# Consequence: ID Tokens signed before a restart no longer verify, and worker
# processes that each import this module would each hold a different key.
private_key = rsa.generate_private_key(key_size=2048, public_exponent=65537)
public_key = private_key.public_key()
|
|
|
|
# --- Logging Helper ---
|
|
def log_request(endpoint, status, error_message=None):
    """Record one request summary in the in-memory ``recent_requests`` buffer.

    Entries are added at the left end, so the buffer reads newest-first. Must
    be called while a Flask request is being handled, because the client
    address is read from ``request.remote_addr``.

    Args:
        endpoint: Path being reported, e.g. ``'/token'``.
        status: Short outcome label, e.g. ``'OK'`` or ``'FAIL'``.
        error_message: Optional failure detail; stored as ``None`` when omitted.
    """
    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # renders to exactly the same text with this format string.
    from datetime import timezone

    recent_requests.appendleft({
        "timestamp": datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC'),
        "endpoint": endpoint,
        # NOTE(review): behind a reverse proxy this is the proxy's address
        # unless the app is wrapped with something like werkzeug's ProxyFix.
        "ip": request.remote_addr,
        "status": status,
        "error": error_message,
    })
|
|
|
|
# --- OIDC and JWT Helper Functions ---
|
|
def get_jwks():
    """Return the JSON Web Key Set that publishes the ID-token verification key.

    The public half of the in-memory RSA key is exported as PEM, turned into a
    JWK by python-jose, and wrapped in the ``{"keys": [...]}`` envelope served
    at the ``jwks_uri``.

    Returns:
        dict: A JWKS holding a single RSA signing key with ``kid`` ``"1"`` (the
        same key id that ``create_id_token`` writes into its JWT header).
    """
    public_pem = public_key.public_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    jwk_fields = jwk.construct(public_pem, algorithm='RS256').to_dict()

    # Fixed metadata first; fields produced by python-jose win on any overlap.
    signing_key = {"kty": "RSA", "use": "sig", "kid": "1"}
    signing_key.update(jwk_fields)
    return {"keys": [signing_key]}
|
|
|
|
def create_id_token(user_info, discord_guilds):
    """Mint an RS256-signed OIDC ID Token for a Discord user.

    Args:
        user_info: Discord user object. The keys ``id``, ``email``,
            ``username`` and ``discriminator`` are required. ``verified``
            (default ``False``) and ``global_name`` (falls back to
            ``username``) are optional.
        discord_guilds: Iterable of Discord guild objects; each guild's ``id``
            becomes one entry of the ``groups`` claim.

    Returns:
        str: The encoded JWT, which expires one hour after issuance, is signed
        with the in-memory private key, and carries ``kid`` ``"1"`` in its header.
    """
    issued_at = int(time.time())
    lifetime_seconds = 3600

    claims = {
        "iss": BASE_URL,
        "sub": user_info['id'],
        "aud": OIDC_CLIENT_ID,
        "exp": issued_at + lifetime_seconds,
        "iat": issued_at,
        "email": user_info['email'],
        "email_verified": user_info.get('verified', False),
        # NOTE(review): Discord accounts on the newer unique-username system
        # reportedly report discriminator "0", giving names like "alice_0".
        "preferred_username": f"{user_info['username']}_{user_info['discriminator']}",
        "name": user_info.get('global_name') or user_info['username'],
        "groups": [guild['id'] for guild in discord_guilds],
    }

    # The RSA key is handed to python-jose as unencrypted PKCS#8 PEM bytes.
    signing_key_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    jwt_headers = {"kid": "1"}  # must match the kid published by get_jwks()
    return jwt.encode(claims, signing_key_pem, algorithm="RS256", headers=jwt_headers)
|
|
|
|
# --- Discord API Interaction ---
|
|
def exchange_discord_code(code, timeout=10):
    """Exchange a Discord authorization code for Discord OAuth2 tokens.

    Args:
        code: Authorization code Discord issued to ``DISCORD_REDIRECT_URI``.
        timeout: Seconds to wait for Discord. ``requests`` never times out on
            its own, so without this a stalled connection would block the
            worker indefinitely.

    Returns:
        dict: Discord's parsed token response (the caller reads ``access_token``).

    Raises:
        requests.HTTPError: If Discord answers with a non-2xx status.
        requests.RequestException: On connection failure or timeout.
    """
    data = {
        'client_id': DISCORD_CLIENT_ID,
        'client_secret': DISCORD_CLIENT_SECRET,
        'grant_type': 'authorization_code',
        'code': code,
        'redirect_uri': DISCORD_REDIRECT_URI,
    }
    headers = {'Content-Type': 'application/x-www-form-urlencoded'}
    response = requests.post(
        'https://discord.com/api/oauth2/token',
        data=data,
        headers=headers,
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()
|
|
|
|
def get_discord_user_info(access_token, timeout=10):
    """Fetch the authenticated user's profile from Discord (``/users/@me``).

    Args:
        access_token: The user's Discord OAuth2 access token.
        timeout: Seconds to wait for Discord. ``requests`` has no default
            timeout, so without one a stalled connection blocks the worker.

    Returns:
        dict: The parsed Discord user object.

    Raises:
        requests.HTTPError: If Discord answers with a non-2xx status.
        requests.RequestException: On connection failure or timeout.
    """
    headers = {'Authorization': f'Bearer {access_token}'}
    response = requests.get('https://discord.com/api/users/@me', headers=headers, timeout=timeout)
    response.raise_for_status()
    return response.json()
|
|
|
|
def get_discord_user_guilds(access_token, timeout=10):
    """Fetch the servers (guilds) the authenticated user belongs to.

    Args:
        access_token: The user's Discord OAuth2 access token (needs the
            ``guilds`` scope requested by ``/authorize``).
        timeout: Seconds to wait for Discord. ``requests`` has no default
            timeout, so without one a stalled connection blocks the worker.

    Returns:
        list: The parsed list of guild objects returned by Discord.

    Raises:
        requests.HTTPError: If Discord answers with a non-2xx status.
        requests.RequestException: On connection failure or timeout.
    """
    headers = {'Authorization': f'Bearer {access_token}'}
    response = requests.get('https://discord.com/api/users/@me/guilds', headers=headers, timeout=timeout)
    response.raise_for_status()
    return response.json()
|
|
|
|
# --- Root and Status Routes ---
|
|
@app.route('/')
def root_redirect():
    """Send visitors of the bare root URL on to the diagnostics page."""
    status_page = '/status'
    return redirect(status_page)
|
|
|
|
@app.route('/status')
def status_endpoint():
    """Render the diagnostics page (``status.html``).

    Template context:
      * ``env_checks``: maps each required setting's name to ``True`` when it
        is configured and ``False`` otherwise (secret values are never exposed).
      * ``discord_api_status``: ``'OK'``, an HTTP status error string, or a
        connection-failure message from probing Discord's gateway endpoint.
      * ``requests``: a snapshot list of ``recent_requests``.

    NOTE(review): this page requires no authentication yet shows client IPs
    and error details, and every view makes an outbound call to Discord.
    """
    # Every entry is a plain bool so the template can treat all rows alike.
    # A string such as 'Not Set' is truthy and could make a missing secret
    # look configured to a template that tests truthiness.
    env_checks = {
        'OIDC_PROVIDER_URL': bool(BASE_URL),
        'DISCORD_CLIENT_ID': bool(DISCORD_CLIENT_ID),
        'DISCORD_CLIENT_SECRET': bool(DISCORD_CLIENT_SECRET),
        'PEERTUBE_CALLBACK_URL': bool(PEERTUBE_CALLBACK_URL),
    }

    # Probe Discord's API reachability with a short timeout.
    try:
        response = requests.get('https://discord.com/api/v10/gateway', timeout=5)
    except requests.exceptions.RequestException as e:
        discord_api_status = f"Failed to connect: {e}"
    else:
        if response.status_code == 200:
            discord_api_status = 'OK'
        else:
            discord_api_status = f"Error - Status Code: {response.status_code}"

    return render_template(
        'status.html',
        env_checks=env_checks,
        discord_api_status=discord_api_status,
        requests=list(recent_requests),
    )
|
|
|
|
|
|
# --- OIDC Provider Endpoints ---
|
|
@app.route('/.well-known/openid-configuration')
def discovery_endpoint():
    """Serve the OpenID Connect discovery document for this provider.

    ``BASE_URL`` is both the ``issuer`` and the prefix of every endpoint URL.
    Both ``client_secret_post`` and ``client_secret_basic`` are advertised, so
    the token endpoint must accept client credentials either in the form body
    or in an HTTP Basic Authorization header.
    """
    log_request('/.well-known/openid-configuration', 'OK')

    endpoint_urls = {
        "authorization_endpoint": f"{BASE_URL}/authorize",
        "token_endpoint": f"{BASE_URL}/token",
        "userinfo_endpoint": f"{BASE_URL}/userinfo",
        "jwks_uri": f"{BASE_URL}/jwks.json",
    }
    capabilities = {
        "response_types_supported": ["code"],
        "subject_types_supported": ["public"],
        "id_token_signing_alg_values_supported": ["RS256"],
        "scopes_supported": ["openid", "profile", "email", "groups"],
        "token_endpoint_auth_methods_supported": ["client_secret_post", "client_secret_basic"],
        "claims_supported": [
            "sub", "iss", "aud", "exp", "iat", "email", "email_verified",
            "preferred_username", "name", "groups",
        ],
    }
    return jsonify({"issuer": BASE_URL, **endpoint_urls, **capabilities})
|
|
|
|
@app.route('/jwks.json')
def jwks_endpoint():
    """Publish the JWKS that relying parties use to verify ID-token signatures."""
    log_request('/jwks.json', 'OK')
    key_set = get_jwks()
    return jsonify(key_set)
|
|
|
|
@app.route('/authorize')
def authorize_endpoint():
    """Begin the login flow by redirecting the browser to Discord's consent page.

    The relying party's ``state`` is passed through to Discord so it comes
    back on ``/discord/callback`` and can be handed on to PeerTube. Other
    incoming parameters are not read: Discord always returns to
    ``DISCORD_REDIRECT_URI``, and the requested Discord scopes are fixed.

    NOTE(review): PKCE parameters (``code_challenge``) from the client are not
    read or enforced anywhere in this flow.
    """
    log_request('/authorize', 'Redirecting to Discord')
    discord_auth_params = {
        'client_id': DISCORD_CLIENT_ID,
        'redirect_uri': DISCORD_REDIRECT_URI,
        'response_type': 'code',
        'scope': 'identify email guilds',
    }
    # urlencode() renders None as the literal text "None", so only forward a
    # state the client actually sent.
    state = request.args.get('state')
    if state is not None:
        discord_auth_params['state'] = state
    return redirect(f"https://discord.com/api/oauth2/authorize?{urlencode(discord_auth_params)}")
|
|
|
|
@app.route('/discord/callback')
def discord_callback_endpoint():
    """Relay Discord's authorization response to PeerTube's callback URL.

    Discord redirects here after its consent screen. On success, the Discord
    ``code`` (which ``/token`` later redeems) and the echoed ``state`` are
    forwarded to ``PEERTUBE_CALLBACK_URL``. If Discord reports an ``error``
    instead (for example, the user declined), that error is forwarded so the
    relying party can end its login attempt, as RFC 6749 section 4.1.2.1
    describes for authorization error responses.

    Returns:
        A redirect to PeerTube, or a 400 response when Discord sent neither
        ``code`` nor ``error``, or a 500 response when
        ``PEERTUBE_CALLBACK_URL`` is not configured.
    """
    if not PEERTUBE_CALLBACK_URL:
        log_request('/discord/callback', 'FAIL', 'PEERTUBE_CALLBACK_URL is not configured')
        return "Error: PEERTUBE_CALLBACK_URL is not configured.", 500

    code = request.args.get('code')
    state = request.args.get('state')
    error = request.args.get('error')

    if error:
        forward_params = {'error': error}
        error_description = request.args.get('error_description')
        if error_description:
            forward_params['error_description'] = error_description
        log_request('/discord/callback', 'FAIL', f"Discord returned error: {error}")
    elif code:
        forward_params = {'code': code}
        log_request('/discord/callback', 'Redirecting to PeerTube')
    else:
        log_request('/discord/callback', 'FAIL', 'Discord callback missing code')
        return "Error: Discord callback missing code.", 400

    # Echo state only when one was received; urlencode() would otherwise emit
    # the literal text "state=None".
    if state is not None:
        forward_params['state'] = state

    # urlencode() percent-escapes the values; join with '&' if the configured
    # callback URL already carries a query string.
    separator = '&' if '?' in PEERTUBE_CALLBACK_URL else '?'
    return redirect(f"{PEERTUBE_CALLBACK_URL}{separator}{urlencode(forward_params)}")
|
|
|
|
def _client_credential_matches(supplied, expected):
    """Return True if ``supplied`` equals ``expected``, compared in constant time.

    ``hmac.compare_digest`` keeps response timing from revealing how much of a
    guessed credential was right. ``None`` on either side never matches.
    """
    import hmac

    if supplied is None or expected is None:
        return False
    return hmac.compare_digest(supplied.encode('utf-8'), expected.encode('utf-8'))


@app.route('/token', methods=['POST'])
def token_endpoint():
    """Redeem an authorization code for an access token and a signed ID Token.

    Client authentication: the client may send its credentials in the form
    body (``client_secret_post``) or in an HTTP Basic ``Authorization`` header
    (``client_secret_basic``). Both methods are advertised in the discovery
    document. The Basic header takes precedence when present.

    The ``code`` is the Discord authorization code relayed by
    ``/discord/callback``. It is exchanged with Discord, the user's profile and
    guild list are fetched, an RS256 ID Token is minted, and a random opaque
    access token is issued. The user data is saved to
    ``/tmp/<access_token>.json`` for ``/userinfo`` to read back.

    Returns:
        200 with the token response on success; 401 ``invalid_client`` for bad
        client credentials; 400 ``invalid_request`` when ``code`` is missing;
        500 ``server_error`` if the Discord exchange or token creation fails.
    """
    from urllib.parse import unquote_plus

    # client_secret_post: credentials in the form body.
    client_id = request.form.get('client_id')
    client_secret = request.form.get('client_secret')

    # client_secret_basic: credentials in an HTTP Basic Authorization header.
    basic_auth = request.authorization
    if basic_auth is not None and basic_auth.type == 'basic':
        # RFC 6749 section 2.3.1: the id and secret are form-urlencoded before
        # being base64-encoded into the header, so undo that encoding here.
        client_id = unquote_plus(basic_auth.username or '')
        client_secret = unquote_plus(basic_auth.password or '')

    if not (_client_credential_matches(client_id, OIDC_CLIENT_ID)
            and _client_credential_matches(client_secret, OIDC_CLIENT_SECRET)):
        log_request('/token', 'FAIL', 'Invalid client_id or client_secret')
        return jsonify({"error": "invalid_client"}), 401

    code = request.form.get('code')
    if not code:
        log_request('/token', 'FAIL', 'Missing authorization code')
        return jsonify({"error": "invalid_request"}), 400

    try:
        discord_tokens = exchange_discord_code(code)
        discord_access_token = discord_tokens['access_token']

        user_info = get_discord_user_info(discord_access_token)
        user_guilds = get_discord_user_guilds(discord_access_token)

        id_token = create_id_token(user_info, user_guilds)
        access_token = base64.urlsafe_b64encode(os.urandom(32)).decode('utf-8')

        # The file holds the user's email and guild list: create it readable by
        # this process's user only. Mode 'x' (O_CREAT|O_EXCL) also refuses to
        # reuse an existing path, including a planted symlink.
        with open(f"/tmp/{access_token}.json", "x",
                  opener=lambda path, flags: os.open(path, flags, 0o600)) as f:
            json.dump({**user_info, "groups": [g['id'] for g in user_guilds]}, f)
    except Exception as e:
        # Top-level boundary: the client gets a generic error; details stay local.
        error_msg = str(e)
        log_request('/token', 'FAIL', error_msg)
        print(f"Error in /token endpoint: {e}")
        return jsonify({"error": "server_error"}), 500

    log_request('/token', 'OK')
    response = jsonify({
        'access_token': access_token,
        'token_type': 'Bearer',
        'expires_in': 3600,
        'id_token': id_token,
    })
    # RFC 6749 section 5.1: responses carrying tokens must not be cached.
    response.headers['Cache-Control'] = 'no-store'
    response.headers['Pragma'] = 'no-cache'
    return response
|
|
|
|
@app.route('/userinfo', methods=['GET', 'POST'])
def userinfo_endpoint():
    """Return OIDC user claims for a Bearer access token issued by ``/token``.

    The user data that ``/token`` saved is read back from
    ``/tmp/<access_token>.json``. The request is answered with 401
    ``invalid_token`` when:

    * the Authorization header is missing or is not a Bearer credential;
    * the token is not URL-safe base64 text (the only shape ``/token`` issues);
    * no data is stored for the token;
    * the stored data is older than the 3600-second ``expires_in`` that
      ``/token`` advertises.
    """
    import re

    scheme, _, access_token = request.headers.get('Authorization', '').partition(' ')
    access_token = access_token.strip()
    if scheme.lower() != 'bearer' or not access_token:
        log_request('/userinfo', 'FAIL', 'Missing or malformed Authorization header')
        return jsonify({"error": "invalid_token"}), 401

    # The token becomes part of a filesystem path below, so accept only the
    # URL-safe base64 alphabet; this rules out '/' and '..' (path traversal).
    if re.fullmatch(r'[A-Za-z0-9_-]+={0,2}', access_token) is None:
        log_request('/userinfo', 'FAIL', 'Invalid access token')
        return jsonify({"error": "invalid_token"}), 401

    token_path = f"/tmp/{access_token}.json"
    try:
        token_age = time.time() - os.path.getmtime(token_path)
        with open(token_path, "r") as f:
            user_data = json.load(f)
    except FileNotFoundError:
        log_request('/userinfo', 'FAIL', 'Invalid access token')
        return jsonify({"error": "invalid_token"}), 401

    # Enforce the lifetime /token advertises (expires_in 3600) and remove the
    # stale file so the stored profile data does not linger.
    if token_age > 3600:
        try:
            os.remove(token_path)
        except OSError:
            pass  # best-effort cleanup; the token is rejected either way
        log_request('/userinfo', 'FAIL', 'Expired access token')
        return jsonify({"error": "invalid_token"}), 401

    log_request('/userinfo', 'OK')
    return jsonify({
        "sub": user_data['id'],
        "email": user_data['email'],
        # Same source and default as the ID Token's email_verified claim.
        "email_verified": user_data.get('verified', False),
        "preferred_username": f"{user_data['username']}_{user_data['discriminator']}",
        "name": user_data.get('global_name') or user_data['username'],
        "groups": user_data.get('groups', []),
    })
|
|
|
|
if __name__ == '__main__':
    import sys

    # These settings have no usable defaults; the login flow cannot work without them.
    required_settings = {
        'OIDC_PROVIDER_URL': BASE_URL,
        'DISCORD_CLIENT_ID': DISCORD_CLIENT_ID,
        'DISCORD_CLIENT_SECRET': DISCORD_CLIENT_SECRET,
        'PEERTUBE_CALLBACK_URL': PEERTUBE_CALLBACK_URL,
    }
    missing = [name for name, value in required_settings.items() if not value]
    if missing:
        print("FATAL ERROR: One or more required environment variables are not set.", file=sys.stderr)
        print(f"Missing: {', '.join(missing)}", file=sys.stderr)
        print("Please check your .env file.", file=sys.stderr)
        # Exit non-zero so process supervisors and container runtimes see the failure.
        sys.exit(1)

    if not OIDC_CLIENT_SECRET or OIDC_CLIENT_SECRET == 'peertube-secret':
        print("WARNING: OIDC_CLIENT_SECRET is empty or uses the built-in default; "
              "set a private value.", file=sys.stderr)

    app.run(host='0.0.0.0', port=5000)
|