Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/).

## [Unreleased]

### Added

- VK ID OAuth2 backend.

### Removed

- Discontinued OAuth backends: AppsFuel, Beats Music, ChangeTip, Clef,
Expand Down
160 changes: 158 additions & 2 deletions social_core/backends/vk.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@

from __future__ import annotations

import base64
import json
from hashlib import md5
from hashlib import md5, sha256
from time import time
from typing import Any, cast

from social_core.exceptions import AuthException, AuthTokenRevoked
from social_core.exceptions import (
AuthException,
AuthFailed,
AuthMissingParameter,
AuthTokenRevoked,
)
from social_core.utils import parse_qs

from .base import BaseAuth
Expand Down Expand Up @@ -172,6 +178,156 @@
return None


class VKIDOAuth2(BaseOAuth2):
"""VK ID OAuth2 authentication backend"""

name = "vk-id"
ID_KEY = "id"
AUTHORIZATION_URL = "https://id.vk.ru/authorize"
ACCESS_TOKEN_URL = "https://id.vk.ru/oauth2/auth"
USER_INFO_URL = "https://id.vk.ru/oauth2/user_info"
REDIRECT_STATE = False
STATE_PARAMETER = True
SCOPE_SEPARATOR = " "
EXTRA_DATA = [
("id", "id"),
("user_id", "user_id"),
("expires_in", "expires_in"),
("refresh_token", "refresh_token"),
("id_token", "id_token"),
("scope", "scope"),
("device_id", "device_id"),
]

def code_verifier_session_key(self, state: str | None) -> str:
return f"{self.name}_code_verifier_{state or 'default'}"

def generate_code_verifier(self) -> str:
return self.strategy.random_string(128)

def generate_code_challenge(self, code_verifier: str) -> str:
digest = sha256(code_verifier.encode("ascii")).digest()
return base64.urlsafe_b64encode(digest).decode("ascii").rstrip("=")

def auth_params(self, state: str | None = None) -> dict[str, str]:
params = super().auth_params(state)
code_verifier = self.generate_code_verifier()
self.strategy.session_set(self.code_verifier_session_key(state), code_verifier)
params.update(
{
"code_challenge": self.generate_code_challenge(code_verifier),
"code_challenge_method": "S256",
}
)
return params

def callback_data(self) -> dict[str, Any]:
data = dict(self.data.items())
payload = data.get("payload")
if isinstance(payload, list):
payload = payload[0] if payload else None
if isinstance(payload, str) and payload:
try:
parsed_payload = json.loads(payload)
except json.JSONDecodeError as exc:
raise AuthFailed(self, "Invalid VK ID payload") from exc
if isinstance(parsed_payload, dict):
data.update(parsed_payload)
return data

def get_request_state(self):
request_state = self.callback_data().get("state")
if request_state and isinstance(request_state, list):
request_state = request_state[0]
return request_state

def auth_complete_params(self, state=None):
data = self.callback_data()
code = data.get("code")
device_id = data.get("device_id")
if not code:
raise AuthMissingParameter(self, "code")
if not device_id:
raise AuthMissingParameter(self, "device_id")
self._callback_device_id = device_id

Check warning on line 252 in social_core/backends/vk.py

View workflow job for this annotation

GitHub Actions / pylint

W0201

Attribute '_callback_device_id' defined outside __init__

code_verifier = self.strategy.session_pop(self.code_verifier_session_key(state))
if not code_verifier:
raise AuthMissingParameter(self, "code_verifier")

client_id, _client_secret = self.get_key_and_secret()
params = {
"grant_type": "authorization_code",
"code": code,
"code_verifier": code_verifier,
"client_id": client_id,
"device_id": device_id,
"redirect_uri": self.get_redirect_uri(state),
}
if state:
params["state"] = state
return params

def user_data(self, access_token: str, *args, **kwargs) -> dict[str, Any] | None:
response = kwargs.get("response") or {}
client_id, _client_secret = self.get_key_and_secret()
data = self.get_json(
self.USER_INFO_URL,
method="POST",
headers=self.auth_headers(),
data={
"access_token": access_token,
"client_id": client_id,
},
)
self.process_error(data)

user = data.get("user") if isinstance(data.get("user"), dict) else data
if not isinstance(user, dict):
return {}

user_id = (
user.get("user_id")
or user.get("id")
or response.get("user_id")
or response.get("id")
)
first_name = user.get("first_name") or user.get("firstName") or ""
last_name = user.get("last_name") or user.get("lastName") or ""
avatar = (
user.get("avatar")
or user.get("photo")
or user.get("photo_200")
or user.get("picture")
)

return {
**user,
"id": str(user_id) if user_id is not None else None,
"user_id": user_id,
"first_name": first_name,
"last_name": last_name,
"email": user.get("email") or response.get("email", ""),
"user_photo": avatar,
"photo": avatar,
"device_id": response.get("device_id")
or getattr(self, "_callback_device_id", None),
}

def get_user_details(self, response):
fullname, first_name, last_name = self.get_user_names(
first_name=response.get("first_name"),
last_name=response.get("last_name"),
)
return {
"username": "",
"email": response.get("email", ""),
"fullname": fullname,
"first_name": first_name,
"last_name": last_name,
}


class VKAppOAuth2(VKOAuth2):
"""VK.com Application Authentication support"""

Expand Down
118 changes: 118 additions & 0 deletions social_core/tests/backends/test_vk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import json
from typing import TYPE_CHECKING, Any, cast

import responses

from social_core.utils import get_querystring, parse_qs

from .oauth import BaseAuthUrlTestMixin, OAuth2Test

if TYPE_CHECKING:
from social_core.tests.models import User


class VKOAuth2Test(OAuth2Test, BaseAuthUrlTestMixin):
backend_path = "social_core.backends.vk.VKOAuth2"
Expand All @@ -28,3 +36,113 @@ def test_login(self) -> None:

def test_partial_pipeline(self) -> None:
self.do_partial_pipeline()


class VKIDOAuth2Test(OAuth2Test, BaseAuthUrlTestMixin):
backend_path = "social_core.backends.vk.VKIDOAuth2"
raw_complete_url = "/complete/{0}/?code=foobar&device_id=device-id"
user_data_url = "https://id.vk.ru/oauth2/user_info"
user_data_url_post = True
expected_username = "pavel@example.com"
access_token_body = json.dumps(
{
"access_token": "foobar",
"token_type": "bearer",
"expires_in": 3600,
"refresh_token": "refresh",
"id_token": "id-token",
"user_id": 1,
"device_id": "token-device-id",
}
)
user_data_body = json.dumps(
{
"user": {
"user_id": 1,
"first_name": "Павел",
"last_name": "Дуров",
"email": "pavel@example.com",
"avatar": "https://example.com/avatar.jpg",
}
}
)

def extra_settings(self) -> dict[str, Any]:
settings: dict[str, Any] = super().extra_settings()
settings[f"SOCIAL_AUTH_{self.name}_USERNAME_IS_FULL_EMAIL"] = True
return settings

def test_login(self) -> None:
user = self.do_login()
social = user.social[0]

auth_request = next(
r.request
for r in responses.calls
if cast("str", r.request.url).startswith(self.backend.authorization_url())
)
auth_query = get_querystring(cast("str", auth_request.url))

token_request = next(
r.request
for r in responses.calls
if cast("str", r.request.url).startswith(self.backend.access_token_url())
)
token_data = parse_qs(token_request.body)
self.assertEqual(token_data["client_id"], "a-key")
self.assertNotIn("client_secret", token_data)
self.assertEqual(token_data["device_id"], "device-id")
self.assertEqual(token_data["state"], auth_query["state"])
self.assertEqual(
self.backend.generate_code_challenge(token_data["code_verifier"]),
auth_query["code_challenge"],
)

user_info_request = next(
r.request
for r in responses.calls
if cast("str", r.request.url).startswith(self.backend.USER_INFO_URL)
)
user_info_data = parse_qs(user_info_request.body)
self.assertEqual(user_info_data["access_token"], "foobar")
self.assertEqual(user_info_data["client_id"], "a-key")
self.assertEqual(social.extra_data["device_id"], "token-device-id")

def test_partial_pipeline(self) -> None:
self.do_partial_pipeline()

def test_login_with_payload_callback(self) -> None:
start_url = self.backend.start().url
state = get_querystring(start_url)["state"]
payload = json.dumps(
{
"code": "foobar",
"device_id": "payload-device-id",
"state": state,
}
)
self.strategy.set_request_data({"payload": payload}, self.backend)
responses.add(
self._method(self.backend.ACCESS_TOKEN_METHOD),
self.backend.access_token_url(),
status=200,
body=self.access_token_body,
content_type="application/json",
)
responses.add(
responses.POST,
self.user_data_url,
body=self.user_data_body,
content_type=self.user_data_content_type,
)

user = cast("User", self.backend.complete())

token_request = next(
r.request
for r in responses.calls
if cast("str", r.request.url).startswith(self.backend.access_token_url())
)
token_data = parse_qs(token_request.body)
self.assertEqual(token_data["device_id"], "payload-device-id")
self.assertEqual(user.username, self.expected_username)
Loading