diff --git a/app/blueprints/github/__init__.py b/app/blueprints/github/__init__.py index f5962b6..8b68a43 100644 --- a/app/blueprints/github/__init__.py +++ b/app/blueprints/github/__init__.py @@ -69,19 +69,17 @@ def callback(oauth_token): if userByGithub is None: flash("Unable to find an account for that Github user", "danger") return redirect(url_for("users.claim_forums")) - elif login_user_set_active(userByGithub, remember=True): - addAuditLog(AuditSeverity.USER, userByGithub, "Logged in using GitHub OAuth", - url_for("users.profile", username=userByGithub.username)) - db.session.commit() - if not current_user.password: - return redirect(next_url or url_for("users.set_password", optional=True)) - else: - return redirect(next_url or url_for("homepage.home")) - else: + ret = login_user_set_active(userByGithub, remember=True) + if ret is None: flash("Authorization failed [err=gh-login-failed]", "danger") return redirect(url_for("users.login")) + addAuditLog(AuditSeverity.USER, userByGithub, "Logged in using GitHub OAuth", + url_for("users.profile", username=userByGithub.username)) + db.session.commit() + return ret + @bp.route("/github/webhook/", methods=["POST"]) @csrf.exempt diff --git a/app/blueprints/users/account.py b/app/blueprints/users/account.py index 2841bdf..6c4de4d 100644 --- a/app/blueprints/users/account.py +++ b/app/blueprints/users/account.py @@ -15,6 +15,7 @@ # along with this program. If not, see . + from flask import * from flask_login import current_user, login_required, logout_user, login_user from flask_wtf import FlaskForm @@ -24,7 +25,7 @@ from wtforms.validators import * from app.models import * from app.tasks.emails import send_verify_email, send_anon_email, send_unsubscribe_verify, send_user_email -from app.utils import randomString, make_flask_login_password, is_safe_url, check_password_hash, addAuditLog, nonEmptyOrNone +from app.utils import randomString, make_flask_login_password, is_safe_url, check_password_hash, addAuditLog, nonEmptyOrNone, post_login from passlib.pwd import genphrase from . import bp @@ -61,14 +62,11 @@ def handle_login(form): url_for("users.profile", username=user.username)) db.session.commit() - login_user(user, remember=form.remember_me.data) - flash("Logged in successfully.", "success") + if not login_user(user, remember=form.remember_me.data): + flash("Login failed", "danger") + return - next = request.args.get("next") - if next and not is_safe_url(next): - abort(400) - - return redirect(next or url_for("homepage.home")) + return post_login(user, request.args.get("next")) @bp.route("/user/login/", methods=["GET", "POST"]) diff --git a/app/blueprints/users/claim.py b/app/blueprints/users/claim.py index f067435..5ba3a08 100644 --- a/app/blueprints/users/claim.py +++ b/app/blueprints/users/claim.py @@ -105,12 +105,13 @@ def claim_forums(): db.session.add(user) db.session.commit() - if login_user_set_active(user, remember=True): - return redirect(url_for("users.set_password")) - else: + ret = login_user_set_active(user, remember=True) + if ret is None: flash("Unable to login as user", "danger") return redirect(url_for("users.claim_forums", username=username)) + return ret + else: flash("Could not find the key in your signature!", "danger") return redirect(url_for("users.claim_forums", username=username)) diff --git a/app/utils/user.py b/app/utils/user.py index 8a8ea69..c3ff511 100644 --- a/app/utils/user.py +++ b/app/utils/user.py @@ -19,9 +19,10 @@ from functools import wraps from flask_login import login_user, current_user from passlib.handlers.bcrypt import bcrypt -from flask import redirect, url_for, abort +from flask import redirect, url_for, abort, flash from app.models import User, UserRank, UserNotificationPreferences, db +from app.utils import is_safe_url def check_password_hash(stored, given): @@ -35,14 +36,33 @@ def make_flask_login_password(plaintext): return bcrypt.hash(plaintext.encode("UTF-8")) -def login_user_set_active(user: User, *args, **kwargs): +def post_login(user: User, next_url): + if next_url and is_safe_url(next_url): + return redirect(next_url) + + if not current_user.password: + return redirect(url_for("users.set_password", optional=True)) + + notif_count = len(user.notifications) + if notif_count > 0: + if notif_count >= 10: + flash("You have a lot of notifications, you should either read or clear them", "info") + return redirect(url_for("notifications.list_all")) + + return redirect(url_for("homepage.home")) + + +def login_user_set_active(user: User, next_url: str = None, *args, **kwargs): if user.rank == UserRank.NOT_JOINED and user.email is None: user.rank = UserRank.MEMBER user.notification_preferences = UserNotificationPreferences(user) user.is_active = True db.session.commit() - return login_user(user, *args, **kwargs) + if login_user(user, *args, **kwargs): + return post_login(user, next_url) + + return None def rank_required(rank):