Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from apns2 to aioapns #721

Draft
wants to merge 6 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.7', '3.8', '3.9']
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ repos:
rev: v3.15.2
hooks:
- id: pyupgrade
args:
- --keep-mock # for AsyncMock in 3.7
260 changes: 163 additions & 97 deletions push_notifications/apns.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏼

Original file line number Diff line number Diff line change
Expand Up @@ -4,100 +4,165 @@
https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/APNSOverview.html
"""

import asyncio
import contextlib
import tempfile
import time

from apns2 import client as apns2_client
from apns2 import credentials as apns2_credentials
from apns2 import errors as apns2_errors
from apns2 import payload as apns2_payload
import aioapns
from aioapns.common import APNS_RESPONSE_CODE, PRIORITY_HIGH, PRIORITY_NORMAL
from asgiref.sync import async_to_sync

from . import models
from .conf import get_manager
from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError

from .exceptions import APNSError, APNSServerError, APNSUnsupportedPriority


SUCCESS_RESULT = "Success"
UNREGISTERED_RESULT = "Unregistered"


@contextlib.contextmanager
def _apns_path_for_cert(cert):
if cert is None:
yield None
with tempfile.NamedTemporaryFile("w") as cert_file:
cert_file.write(cert)
cert_file.flush()
yield cert_file.name


def _apns_create_client(application_id=None):
cert = None
key_path = None
key_id = None
team_id = None

if not get_manager().has_auth_token_creds(application_id):
cert = get_manager().get_apns_certificate(application_id)
with _apns_path_for_cert(cert) as cert_path:
client = aioapns.APNs(
client_cert=cert_path,
team_id=team_id,
topic=get_manager().get_apns_topic(application_id),
use_sandbox=get_manager().get_apns_use_sandbox(application_id),
)
else:
key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id)
client = aioapns.APNs(
key=key_path,
key_id=key_id,
team_id=team_id,
topic=get_manager().get_apns_topic(application_id),
use_sandbox=get_manager().get_apns_use_sandbox(application_id),
)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found two errors. The topic isn't being sent when creating the client and the context processor raises an exception if there's no credential path. I tested it afterwards and works great with Django 4.2 and Python 3.11

if not get_manager().has_auth_token_creds(application_id): cert = get_manager().get_apns_certificate(application_id) with _apns_path_for_cert(cert) as cert_path: client = aioapns.APNs( client_cert=cert_path, team_id=team_id, topic=get_manager().get_apns_topic(application_id), use_sandbox=get_manager().get_apns_use_sandbox(application_id), ) else: key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id) client = aioapns.APNs( key=key_path, key_id=key_id, team_id=team_id, topic=get_manager().get_apns_topic(application_id), use_sandbox=get_manager().get_apns_use_sandbox(application_id), )

Copy link

@aalbinati aalbinati May 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome. Thank you @aalbinati. Doing some testing with the service was on my to-do, I'll pull in the changes you're suggesting soon.


def _apns_create_socket(creds=None, application_id=None):
if creds is None:
if not get_manager().has_auth_token_creds(application_id):
cert = get_manager().get_apns_certificate(application_id)
creds = apns2_credentials.CertificateCredentials(cert)
else:
keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id)
# No use getting a lifetime because this credential is
# ephemeral, but if you're looking at this to see how to
# create a credential, you could also pass the lifetime and
# algorithm. Neither of those settings are exposed in the
# settings API at the moment.
creds = creds or apns2_credentials.TokenCredentials(keyPath, keyId, teamId)
client = apns2_client.APNsClient(
creds,
use_sandbox=get_manager().get_apns_use_sandbox(application_id),
use_alternative_port=get_manager().get_apns_use_alternative_port(application_id)
)
client.connect()
return client


def _apns_prepare(
token, alert, application_id=None, badge=None, sound=None, category=None,
content_available=False, action_loc_key=None, loc_key=None, loc_args=[],
extra={}, mutable_content=False, thread_id=None, url_args=None):
if action_loc_key or loc_key or loc_args:
apns2_alert = apns2_payload.PayloadAlert(
body=alert if alert else {}, body_localized_key=loc_key,
body_localized_args=loc_args, action_localized_key=action_loc_key)
else:
apns2_alert = alert

if callable(badge):
badge = badge(token)

return apns2_payload.Payload(
alert=apns2_alert, badge=badge, sound=sound, category=category,
url_args=url_args, custom=extra, thread_id=thread_id,
content_available=content_available, mutable_content=mutable_content)


def _apns_send(
registration_id, alert, batch=False, application_id=None, creds=None, **kwargs
token,
alert,
application_id=None,
badge=None,
sound=None,
category=None,
content_available=False,
action_loc_key=None,
loc_key=None,
loc_args=[],
extra={},
mutable_content=False,
thread_id=None,
url_args=None,
):
client = _apns_create_socket(creds=creds, application_id=application_id)

notification_kwargs = {}

# if expiration isn"t specified use 1 month from now
notification_kwargs["expiration"] = kwargs.pop("expiration", None)
if not notification_kwargs["expiration"]:
notification_kwargs["expiration"] = int(time.time()) + 2592000

priority = kwargs.pop("priority", None)
if priority:
try:
notification_kwargs["priority"] = apns2_client.NotificationPriority(str(priority))
except ValueError:
raise APNSUnsupportedPriority("Unsupported priority %d" % (priority))

notification_kwargs["collapse_id"] = kwargs.pop("collapse_id", None)

if batch:
data = [apns2_client.Notification(
token=rid, payload=_apns_prepare(rid, alert, **kwargs)) for rid in registration_id]
# returns a dictionary mapping each token to its result. That
# result is either "Success" or the reason for the failure.
return client.send_notification_batch(
data, get_manager().get_apns_topic(application_id=application_id),
**notification_kwargs
if action_loc_key or loc_key or loc_args:
alert_payload = {
"body": alert if alert else {},
"body_localized_key": loc_key,
"body_localized_args": loc_args,
"action_localized_key": action_loc_key,
}
else:
alert_payload = alert

if callable(badge):
badge = badge(token)

return {
"alert": alert_payload,
"badge": badge,
"sound": sound,
"category": category,
"url_args": url_args,
"custom": extra,
"thread_id": thread_id,
"content_available": content_available,
"mutable_content": mutable_content,
}


@async_to_sync
async def _apns_send(
registration_ids,
alert,
application_id=None,
*,
priority=None,
expiration=None,
collapse_id=None,
**kwargs,
):
"""Make calls to APNs for each device token (registration_id) provided.

Since the underlying library (aioapns) is asynchronous, we are
taking advantage of that here and making the requests in parallel.
"""
client = _apns_create_client(application_id=application_id)

# if expiration isn't specified use 1 month from now
# converting to ttl for underlying library
if expiration:
time_to_live = expiration - int(time.time())
else:
time_to_live = 2592000

Comment on lines +124 to +129
Copy link
Member

@50-Course 50-Course May 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice catch here, Looks good! The code handles the case of an unspecified expiration clearly. I might have gone with the ternary approach as an improvement:

time_to_live = expiration - int(time.timeit()) if expiration else 2592000

if priority is not None:
if str(priority) not in [PRIORITY_HIGH, PRIORITY_NORMAL]:
raise APNSUnsupportedPriority(f"Unsupported priority {priority}")

# track which device token belongs to each coroutine.
# this allows us to stitch the results back together later
coro_registration_ids = {}
for registration_id in set(registration_ids):
coro = client.send_notification(
aioapns.NotificationRequest(
device_token=registration_id,
message={"aps": _apns_prepare(registration_id, alert, **kwargs)},
time_to_live=time_to_live,
priority=priority,
collapse_key=collapse_id,
)
)
coro_registration_ids[asyncio.create_task(coro)] = registration_id

# run all of the tasks. this will resolve once all requests are complete
done, _ = await asyncio.wait(coro_registration_ids.keys())

# recombine task results with their device tokens
results = {}
for coro in done:
registration_id = coro_registration_ids[coro]
result = await coro
if result.is_successful:
results[registration_id] = SUCCESS_RESULT
else:
results[registration_id] = result.description

data = _apns_prepare(registration_id, alert, **kwargs)
client.send_notification(
registration_id, data,
get_manager().get_apns_topic(application_id=application_id),
**notification_kwargs
)
return results


def apns_send_message(registration_id, alert, application_id=None, creds=None, **kwargs):
def apns_send_message(registration_id, alert, application_id=None, **kwargs):
"""
Sends an APNS notification to a single registration_id.
This will send the notification as form data.
Expand All @@ -109,23 +174,21 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, *
to this for silent notifications.
"""

try:
_apns_send(
registration_id, alert, application_id=application_id,
creds=creds, **kwargs
)
except apns2_errors.APNsException as apns2_exception:
if isinstance(apns2_exception, apns2_errors.Unregistered):
device = models.APNSDevice.objects.get(registration_id=registration_id)
device.active = False
device.save()
results = _apns_send(
[registration_id], alert, application_id=application_id, **kwargs
)
result = results[registration_id]

raise APNSServerError(status=apns2_exception.__class__.__name__)
if result == SUCCESS_RESULT:
return
if result == UNREGISTERED_RESULT:
models.APNSDevice.objects.filter(registration_id=registration_id).update(
active=False
)
raise APNSServerError(status=result)


def apns_send_bulk_message(
registration_ids, alert, application_id=None, creds=None, **kwargs
):
def apns_send_bulk_message(registration_ids, alert, application_id=None, **kwargs):
"""
Sends an APNS notification to one or more registration_ids.
The registration_ids argument needs to be a list.
Expand All @@ -136,9 +199,12 @@ def apns_send_bulk_message(
"""

results = _apns_send(
registration_ids, alert, batch=True, application_id=application_id,
creds=creds, **kwargs
registration_ids, alert, application_id=application_id, **kwargs
)
inactive_tokens = [
token for token, result in results.items() if result == UNREGISTERED_RESULT
]
models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update(
active=False
)
inactive_tokens = [token for token, result in results.items() if result == "Unregistered"]
models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update(active=False)
return results
9 changes: 5 additions & 4 deletions push_notifications/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from django.db import models
from django.utils.translation import gettext_lazy as _

from .apns import apns_send_bulk_message
from .fields import HexIntegerField
from .settings import PUSH_NOTIFICATIONS_SETTINGS as SETTINGS

Expand Down Expand Up @@ -133,7 +134,7 @@ def get_queryset(self):


class APNSDeviceQuerySet(models.query.QuerySet):
def send_message(self, message, creds=None, **kwargs):
def send_message(self, message, **kwargs):
if self.exists():
from .apns import apns_send_bulk_message

Expand All @@ -146,7 +147,7 @@ def send_message(self, message, creds=None, **kwargs):
)
r = apns_send_bulk_message(
registration_ids=reg_ids, alert=message, application_id=app_id,
creds=creds, **kwargs
**kwargs
)
if hasattr(r, "keys"):
res += [r]
Expand All @@ -169,13 +170,13 @@ class APNSDevice(Device):
class Meta:
verbose_name = _("APNS device")

def send_message(self, message, creds=None, **kwargs):
def send_message(self, message, **kwargs):
from .apns import apns_send_message

return apns_send_message(
registration_id=self.registration_id,
alert=message,
application_id=self.application_id, creds=creds,
application_id=self.application_id,
**kwargs
)

Expand Down
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ setup_requires =

[options.extras_require]
APNS =
apns2>=0.3.0
aioapns
asgiref>=2.0
importlib-metadata;python_version < "3.8"
Django>=2.2

WP = pywebpush>=1.3.0

Expand Down
Loading
Loading