-
-
Notifications
You must be signed in to change notification settings - Fork 623
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate from apns2 to aioapns #721
base: master
Are you sure you want to change the base?
Changes from 5 commits
2813e70
6b4d226
fef1bab
8d79718
445a216
e116a90
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,3 +12,5 @@ repos: | |
rev: v3.15.2 | ||
hooks: | ||
- id: pyupgrade | ||
args: | ||
- --keep-mock # for AsyncMock in 3.7 |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,100 +4,165 @@ | |
https://developer.apple.com/library/content/documentation/NetworkingInternet/Conceptual/RemoteNotificationsPG/APNSOverview.html | ||
""" | ||
|
||
import asyncio | ||
import contextlib | ||
import tempfile | ||
import time | ||
|
||
from apns2 import client as apns2_client | ||
from apns2 import credentials as apns2_credentials | ||
from apns2 import errors as apns2_errors | ||
from apns2 import payload as apns2_payload | ||
import aioapns | ||
from aioapns.common import APNS_RESPONSE_CODE, PRIORITY_HIGH, PRIORITY_NORMAL | ||
from asgiref.sync import async_to_sync | ||
|
||
from . import models | ||
from .conf import get_manager | ||
from .exceptions import APNSError, APNSUnsupportedPriority, APNSServerError | ||
|
||
from .exceptions import APNSError, APNSServerError, APNSUnsupportedPriority | ||
|
||
|
||
SUCCESS_RESULT = "Success" | ||
UNREGISTERED_RESULT = "Unregistered" | ||
|
||
|
||
@contextlib.contextmanager | ||
def _apns_path_for_cert(cert): | ||
if cert is None: | ||
yield None | ||
with tempfile.NamedTemporaryFile("w") as cert_file: | ||
cert_file.write(cert) | ||
cert_file.flush() | ||
yield cert_file.name | ||
|
||
|
||
def _apns_create_client(application_id=None): | ||
cert = None | ||
key_path = None | ||
key_id = None | ||
team_id = None | ||
|
||
if not get_manager().has_auth_token_creds(application_id): | ||
cert = get_manager().get_apns_certificate(application_id) | ||
else: | ||
key_path, key_id, team_id = get_manager().get_apns_auth_creds(application_id) | ||
# No use getting a lifetime because this credential is | ||
# ephemeral, but if you're looking at this to see how to | ||
# create a credential, you could also pass the lifetime and | ||
# algorithm. Neither of those settings are exposed in the | ||
# settings API at the moment. | ||
|
||
with _apns_path_for_cert(cert) as cert_path: | ||
client = aioapns.APNs( | ||
client_cert=cert_path, | ||
key=key_path, | ||
key_id=key_id, | ||
team_id=team_id, | ||
use_sandbox=get_manager().get_apns_use_sandbox(application_id), | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I found two errors. The topic isn't being sent when creating the client and the context processor raises an exception if there's no credential path. I tested it afterwards and works great with Django 4.2 and Python 3.11
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here's a fork with the fix There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Awesome. Thank you @aalbinati. Doing some testing with the service was on my to-do, I'll pull in the changes you're suggesting soon. |
||
|
||
def _apns_create_socket(creds=None, application_id=None): | ||
if creds is None: | ||
if not get_manager().has_auth_token_creds(application_id): | ||
cert = get_manager().get_apns_certificate(application_id) | ||
creds = apns2_credentials.CertificateCredentials(cert) | ||
else: | ||
keyPath, keyId, teamId = get_manager().get_apns_auth_creds(application_id) | ||
# No use getting a lifetime because this credential is | ||
# ephemeral, but if you're looking at this to see how to | ||
# create a credential, you could also pass the lifetime and | ||
# algorithm. Neither of those settings are exposed in the | ||
# settings API at the moment. | ||
creds = creds or apns2_credentials.TokenCredentials(keyPath, keyId, teamId) | ||
client = apns2_client.APNsClient( | ||
creds, | ||
use_sandbox=get_manager().get_apns_use_sandbox(application_id), | ||
use_alternative_port=get_manager().get_apns_use_alternative_port(application_id) | ||
) | ||
client.connect() | ||
return client | ||
|
||
|
||
def _apns_prepare( | ||
token, alert, application_id=None, badge=None, sound=None, category=None, | ||
content_available=False, action_loc_key=None, loc_key=None, loc_args=[], | ||
extra={}, mutable_content=False, thread_id=None, url_args=None): | ||
if action_loc_key or loc_key or loc_args: | ||
apns2_alert = apns2_payload.PayloadAlert( | ||
body=alert if alert else {}, body_localized_key=loc_key, | ||
body_localized_args=loc_args, action_localized_key=action_loc_key) | ||
else: | ||
apns2_alert = alert | ||
|
||
if callable(badge): | ||
badge = badge(token) | ||
|
||
return apns2_payload.Payload( | ||
alert=apns2_alert, badge=badge, sound=sound, category=category, | ||
url_args=url_args, custom=extra, thread_id=thread_id, | ||
content_available=content_available, mutable_content=mutable_content) | ||
|
||
|
||
def _apns_send( | ||
registration_id, alert, batch=False, application_id=None, creds=None, **kwargs | ||
token, | ||
alert, | ||
application_id=None, | ||
badge=None, | ||
sound=None, | ||
category=None, | ||
content_available=False, | ||
action_loc_key=None, | ||
loc_key=None, | ||
loc_args=[], | ||
extra={}, | ||
mutable_content=False, | ||
thread_id=None, | ||
url_args=None, | ||
): | ||
client = _apns_create_socket(creds=creds, application_id=application_id) | ||
|
||
notification_kwargs = {} | ||
|
||
# if expiration isn"t specified use 1 month from now | ||
notification_kwargs["expiration"] = kwargs.pop("expiration", None) | ||
if not notification_kwargs["expiration"]: | ||
notification_kwargs["expiration"] = int(time.time()) + 2592000 | ||
|
||
priority = kwargs.pop("priority", None) | ||
if priority: | ||
try: | ||
notification_kwargs["priority"] = apns2_client.NotificationPriority(str(priority)) | ||
except ValueError: | ||
raise APNSUnsupportedPriority("Unsupported priority %d" % (priority)) | ||
|
||
notification_kwargs["collapse_id"] = kwargs.pop("collapse_id", None) | ||
|
||
if batch: | ||
data = [apns2_client.Notification( | ||
token=rid, payload=_apns_prepare(rid, alert, **kwargs)) for rid in registration_id] | ||
# returns a dictionary mapping each token to its result. That | ||
# result is either "Success" or the reason for the failure. | ||
return client.send_notification_batch( | ||
data, get_manager().get_apns_topic(application_id=application_id), | ||
**notification_kwargs | ||
if action_loc_key or loc_key or loc_args: | ||
alert_payload = { | ||
"body": alert if alert else {}, | ||
"body_localized_key": loc_key, | ||
"body_localized_args": loc_args, | ||
"action_localized_key": action_loc_key, | ||
} | ||
else: | ||
alert_payload = alert | ||
|
||
if callable(badge): | ||
badge = badge(token) | ||
|
||
return { | ||
"alert": alert_payload, | ||
"badge": badge, | ||
"sound": sound, | ||
"category": category, | ||
"url_args": url_args, | ||
"custom": extra, | ||
"thread_id": thread_id, | ||
"content_available": content_available, | ||
"mutable_content": mutable_content, | ||
} | ||
|
||
|
||
@async_to_sync | ||
async def _apns_send( | ||
registration_ids, | ||
alert, | ||
application_id=None, | ||
*, | ||
priority=None, | ||
expiration=None, | ||
collapse_id=None, | ||
**kwargs, | ||
): | ||
"""Make calls to APNs for each device token (registration_id) provided. | ||
|
||
Since the underlying library (aioapns) is asynchronous, we are | ||
taking advantage of that here and making the requests in parallel. | ||
""" | ||
client = _apns_create_client(application_id=application_id) | ||
|
||
# if expiration isn't specified use 1 month from now | ||
# converting to ttl for underlying library | ||
if expiration: | ||
time_to_live = expiration - int(time.time()) | ||
else: | ||
time_to_live = 2592000 | ||
|
||
Comment on lines
+124
to
+129
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice catch here, Looks good! The code handles the case of an unspecified expiration clearly. I might have gone with the ternary approach as an improvement: time_to_live = expiration - int(time.timeit()) if expiration else 2592000 |
||
if priority is not None: | ||
if str(priority) not in [PRIORITY_HIGH, PRIORITY_NORMAL]: | ||
raise APNSUnsupportedPriority(f"Unsupported priority {priority}") | ||
|
||
# track which device token belongs to each coroutine. | ||
# this allows us to stitch the results back together later | ||
coro_registration_ids = {} | ||
for registration_id in set(registration_ids): | ||
coro = client.send_notification( | ||
aioapns.NotificationRequest( | ||
device_token=registration_id, | ||
message={"aps": _apns_prepare(registration_id, alert, **kwargs)}, | ||
time_to_live=time_to_live, | ||
priority=priority, | ||
collapse_key=collapse_id, | ||
) | ||
) | ||
coro_registration_ids[asyncio.create_task(coro)] = registration_id | ||
|
||
# run all of the tasks. this will resolve once all requests are complete | ||
done, _ = await asyncio.wait(coro_registration_ids.keys()) | ||
|
||
# recombine task results with their device tokens | ||
results = {} | ||
for coro in done: | ||
registration_id = coro_registration_ids[coro] | ||
result = await coro | ||
if result.is_successful: | ||
results[registration_id] = SUCCESS_RESULT | ||
else: | ||
results[registration_id] = result.description | ||
|
||
data = _apns_prepare(registration_id, alert, **kwargs) | ||
client.send_notification( | ||
registration_id, data, | ||
get_manager().get_apns_topic(application_id=application_id), | ||
**notification_kwargs | ||
) | ||
return results | ||
|
||
|
||
def apns_send_message(registration_id, alert, application_id=None, creds=None, **kwargs): | ||
def apns_send_message(registration_id, alert, application_id=None, **kwargs): | ||
""" | ||
Sends an APNS notification to a single registration_id. | ||
This will send the notification as form data. | ||
|
@@ -109,23 +174,21 @@ def apns_send_message(registration_id, alert, application_id=None, creds=None, * | |
to this for silent notifications. | ||
""" | ||
|
||
try: | ||
_apns_send( | ||
registration_id, alert, application_id=application_id, | ||
creds=creds, **kwargs | ||
) | ||
except apns2_errors.APNsException as apns2_exception: | ||
if isinstance(apns2_exception, apns2_errors.Unregistered): | ||
device = models.APNSDevice.objects.get(registration_id=registration_id) | ||
device.active = False | ||
device.save() | ||
results = _apns_send( | ||
[registration_id], alert, application_id=application_id, **kwargs | ||
) | ||
result = results[registration_id] | ||
|
||
raise APNSServerError(status=apns2_exception.__class__.__name__) | ||
if result == SUCCESS_RESULT: | ||
return | ||
if result == UNREGISTERED_RESULT: | ||
models.APNSDevice.objects.filter(registration_id=registration_id).update( | ||
active=False | ||
) | ||
raise APNSServerError(status=result) | ||
|
||
|
||
def apns_send_bulk_message( | ||
registration_ids, alert, application_id=None, creds=None, **kwargs | ||
): | ||
def apns_send_bulk_message(registration_ids, alert, application_id=None, **kwargs): | ||
""" | ||
Sends an APNS notification to one or more registration_ids. | ||
The registration_ids argument needs to be a list. | ||
|
@@ -136,9 +199,12 @@ def apns_send_bulk_message( | |
""" | ||
|
||
results = _apns_send( | ||
registration_ids, alert, batch=True, application_id=application_id, | ||
creds=creds, **kwargs | ||
registration_ids, alert, application_id=application_id, **kwargs | ||
) | ||
inactive_tokens = [ | ||
token for token, result in results.items() if result == UNREGISTERED_RESULT | ||
] | ||
models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update( | ||
active=False | ||
) | ||
inactive_tokens = [token for token, result in results.items() if result == "Unregistered"] | ||
models.APNSDevice.objects.filter(registration_id__in=inactive_tokens).update(active=False) | ||
return results |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍🏼