-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathauth.py
More file actions
123 lines (97 loc) · 3.86 KB
/
Copy pathauth.py
File metadata and controls
123 lines (97 loc) · 3.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations

import uuid as uuid_mod
from datetime import UTC, datetime, timedelta
from typing import Any

import structlog
from argon2 import PasswordHasher
from argon2.exceptions import InvalidHashError, VerificationError, VerifyMismatchError
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt  # type: ignore[import-untyped]
from sqlalchemy import select
from sqlalchemy.orm import joinedload

from python_api.db.models import User
from python_api.utils.config import get_settings
from python_api.utils.errors import InvalidTokenTypeError, TokenDecodeError
# Module-level singletons shared by the helpers in this file.

# FastAPI dependency that extracts the bearer token from the Authorization
# header; tokenUrl points clients (and the OpenAPI docs) at the sign-in route.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/signin")

# Argon2 hasher constructed with the library's default parameters; used by
# both get_password_hash and verify_password.
ph = PasswordHasher()

logger = structlog.get_logger()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check ``plain_password`` against a stored Argon2 ``hashed_password``.

    Args:
        plain_password: The password supplied by the user.
        hashed_password: The stored hash string (as produced by
            ``get_password_hash``).

    Returns:
        ``True`` if the password matches the hash. ``False`` if it does not,
        and also ``False`` when the stored hash is not a valid Argon2 hash or
        verification fails for another reason. An unusable stored hash
        therefore rejects the login instead of raising out of this function.
    """
    try:
        # argon2's verify takes (hash, password), the reverse of this
        # function's parameter order.
        return ph.verify(hashed_password, plain_password)
    except VerifyMismatchError:
        # Well-formed hash, wrong password.
        return False
    except (InvalidHashError, VerificationError):
        # Stored hash is malformed / not Argon2 (e.g. a legacy hash format), or
        # verification failed for another reason: treat as a failed check.
        return False
def get_password_hash(password: str) -> str:
    """Hash ``password`` with the module-level Argon2 hasher ``ph``.

    The returned string embeds the salt and parameters, so it can be passed
    directly to ``verify_password`` later.
    """
    digest = ph.hash(password)
    return str(digest)
def create_access_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed JWT access token.

    Args:
        data: Claims to embed. ``get_current_user`` expects a UUID string in
            the ``sub`` claim. The mapping is copied; the caller's dict is not
            mutated.
        expires_delta: Token lifetime measured from now. ``None`` falls back to
            ``settings.jwt_access_token_expire_hours``. Any explicit value,
            including ``timedelta(0)``, is used as given.

    Returns:
        The encoded token. The ``exp`` and ``type`` (``"access"``) claims are
        set here and override any same-named keys in ``data``.
    """
    settings = get_settings()
    # Compare against None explicitly: timedelta(0) is falsy, so an
    # ``expires_delta or default`` check would silently ignore it.
    if expires_delta is None:
        expires_delta = timedelta(hours=settings.jwt_access_token_expire_hours)
    to_encode = data.copy()
    to_encode.update({"exp": datetime.now(UTC) + expires_delta, "type": "access"})
    return str(jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
# Default lifetime of refresh tokens when the caller does not supply one.
_DEFAULT_REFRESH_TOKEN_TTL = timedelta(days=30)


def create_refresh_token(
    data: dict[str, Any],
    expires_delta: timedelta | None = None,
) -> str:
    """Create a signed JWT refresh token.

    Args:
        data: Claims to embed. The mapping is copied; the caller's dict is not
            mutated.
        expires_delta: Token lifetime measured from now. ``None`` (the
            default) uses ``_DEFAULT_REFRESH_TOKEN_TTL`` (30 days). Any explicit
            value, including ``timedelta(0)``, is used as given.

    Returns:
        The encoded token. The ``exp`` and ``type`` (``"refresh"``) claims are
        set here and override any same-named keys in ``data``.
    """
    settings = get_settings()
    if expires_delta is None:
        expires_delta = _DEFAULT_REFRESH_TOKEN_TTL
    to_encode = data.copy()
    to_encode.update({"exp": datetime.now(UTC) + expires_delta, "type": "refresh"})
    return str(jwt.encode(to_encode, settings.jwt_secret_key, algorithm=settings.jwt_algorithm))
def decode_token(token: str, expected_type: str = "access") -> dict[str, Any]:
    """Verify and decode a JWT, then enforce its ``type`` claim.

    Args:
        token: The encoded JWT.
        expected_type: Required value of the ``type`` claim. Tokens minted in
            this module use ``"access"`` or ``"refresh"``.

    Returns:
        The decoded claims as a plain ``dict``.

    Raises:
        TokenDecodeError: ``jwt.decode`` raised ``JWTError``, for example a bad
            signature, a disallowed algorithm, or an expired or invalid claim.
        InvalidTokenTypeError: The token decoded, but its ``type`` claim is not
            ``expected_type``.
    """
    settings = get_settings()
    # Keep the try body to the decode call only, so the type check below is
    # never caught or remapped by this handler.
    try:
        payload = jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])
    except JWTError:
        # ``from None`` drops the underlying jose exception from the chain.
        raise TokenDecodeError(message="Could not validate credentials") from None

    token_type = payload.get("type")
    if token_type != expected_type:
        raise InvalidTokenTypeError(
            message="Invalid token type",
            detail={"expected": expected_type, "received": token_type},
        )
    return dict(payload)
def _bearer_401(detail: str) -> HTTPException:
    """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


async def get_current_user(
    request: Request,
    token: str = Depends(oauth2_scheme),
) -> User:
    """Resolve the authenticated ``User`` from the request's bearer token.

    The checks run in this order:

    1. Decode the token as an access token.
    2. Require a non-empty ``sub`` claim.
    3. Reject the token if the Redis key ``blacklist:<token>`` exists.
    4. Parse ``sub`` as a UUID.
    5. Load that user, eagerly joining ``User.role``.

    Raises:
        HTTPException: 401 with ``WWW-Authenticate: Bearer`` when ``sub`` is
            missing, the token is blacklisted, ``sub`` is not a valid UUID, or
            no user has that id.
        TokenDecodeError, InvalidTokenTypeError: Propagated from
            ``decode_token``.
    """
    claims = decode_token(token)
    subject: str | None = claims.get("sub")
    if not subject:
        raise _bearer_401("Could not validate credentials")

    # Imported lazily, presumably to avoid an import cycle with the API
    # layer — TODO confirm.
    from python_api.api.dependencies import get_redis

    redis_client = get_redis(request)
    if await redis_client.exists(f"blacklist:{token}"):
        raise _bearer_401("Token has been revoked")

    from python_api.db.session import get_session_factory

    try:
        user_uuid = uuid_mod.UUID(subject)
    except ValueError:
        raise _bearer_401("Invalid user ID") from None

    make_session = get_session_factory()
    async with make_session() as session:
        query = select(User).options(joinedload(User.role)).where(User.id == user_uuid)
        outcome = await session.execute(query)
        user = outcome.scalar_one_or_none()

    if not user:
        raise _bearer_401("User not found")
    return user