Heads up, I'm not a Celery expert.
But the other day I had to implement the following feature for one of my clients.
When we send a POST request to create model A, it should create model B and send the user associated with model B an email. Creating model B and sending the email both need to happen in the background.
I figured I should probably create two Celery tasks for this: one for creating model B and one for sending the email to the user.
The latter task should probably be a callback of the former.
I'd never done this before, so I did some research, and here's what I found.
The first thing I learned when reading about Celery workflows is that there are two concepts called signatures and partials.
I think it's best to explain these two concepts through an example.
Let's say, for the sake of argument, that we have a task that adds two numbers.
```python
from base.celery import app


@app.task
def sum(a, b):
    return a + b
```
We can run this task by calling its `apply_async` method.
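```python
from tasks import sum

result = sum.apply_async((2, 2))  # queues the task and returns an AsyncResult
result.get()  # blocks until a worker finishes; needs a result backend configured
# >> 4
```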
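We can also wrap the task and its arguments in a signature by calling `signature()` (or its shortcut, `s()`). A signature describes a single task invocation without running it, so you can pass it around and execute it later. Note that `signature()` takes the arguments as a tuple, while `s()` takes them directly.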
```python
from tasks import sum

sum.s(2, 2)  # shortcut for sum.signature((2, 2))
# >> tasks.sum(2, 2)
```
If you're familiar with functional programming, partials are partially applied signatures.
If that doesn't ring a bell, a partial is simply a signature that hasn't been given all of its arguments yet; the missing ones are supplied when you finally apply it.
```python
from tasks import sum

partial = sum.s(6)
partial.apply_async((4,))  # new args are prepended, so this runs sum(4, 6)
```
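Because addition doesn't care about argument order, the example above hides where the extra argument ends up. Here is a toy, pure-Python model of a signature that makes it visible. `ToySignature` and the `subtract` function are hypothetical stand-ins, not Celery's implementation: this `apply_async` runs the function immediately instead of queueing it. It only mimics the documented rule that arguments supplied when a partial is applied are prepended to the ones stored in it.

```python
class ToySignature:
    """Hypothetical stand-in for a Celery signature: a function plus stored args."""

    def __init__(self, func, *args):
        self.func = func
        self.args = args  # stored now, used only when the signature is applied

    def apply_async(self, args=()):
        # Like Celery, prepend the incoming args to the stored ones.
        # Unlike Celery, run synchronously and return the value itself
        # instead of queueing the task and returning an AsyncResult.
        return self.func(*tuple(args), *self.args)

    def __repr__(self):
        return f"{self.func.__name__}({', '.join(repr(a) for a in self.args)})"


def subtract(a, b):
    return a - b


full = ToySignature(subtract, 10, 4)  # every argument supplied
partial = ToySignature(subtract, 4)   # only b supplied: a partial

print(full)                        # subtract(10, 4)
print(full.apply_async())          # 6
print(partial.apply_async((10,)))  # runs subtract(10, 4), prints 6
```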
That's all we need to know about signatures and partials to implement our feature.
Now we can focus on callbacks.
In Celery, you can link tasks together by passing a callback partial to the `link` argument of `apply_async`.
When the task succeeds, its result is passed as the first argument of the callback task.
```python
from tasks import sum

# 8 + 8 = 16, so the callback runs as sum(16, 10)
sum.apply_async((8, 8), link=sum.s(10))
```
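To see the flow without a broker or a worker, here is a toy, synchronous model of what `link` does. `apply_with_link` is a hypothetical helper, not part of Celery; it reproduces two documented rules: the parent's result is prepended to the callback's stored arguments, and the callback only runs if the parent finishes without raising.

```python
calls = []  # records every (a, b) pair add() is called with


def add(a, b):
    calls.append((a, b))
    return a + b


def apply_with_link(func, args, callback, callback_args=()):
    """Toy, synchronous model of func.apply_async(args, link=callback.s(*callback_args))."""
    result = func(*args)  # if this raises, the callback below never runs
    # The parent's result is prepended to the callback's stored arguments.
    callback(result, *callback_args)
    return result


# Mirrors sum.apply_async((8, 8), link=sum.s(10))
apply_with_link(add, (8, 8), add, (10,))
print(calls)  # [(8, 8), (16, 10)]


def fail(a, b):
    raise ValueError("boom")


try:
    apply_with_link(fail, (1, 2), add, (10,))
except ValueError:
    pass
print(calls)  # still [(8, 8), (16, 10)]: the callback was skipped
```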
If you don't want your callback to receive the result of the previous task as an argument, you can make the linked signature immutable by using `si()` instead of `s()`.
```python
from tasks import sum, increase_counter

# Calculate 8 + 8, then increase the counter of sum task calls
sum.apply_async((8, 8), link=increase_counter.si())
```
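The difference between `s()` and `si()` boils down to one rule about incoming arguments. The toy function below, `merge_args`, is hypothetical and not part of Celery's API; it mirrors the documented behavior that a mutable signature prepends the arguments it receives, while an immutable one ignores them.

```python
def merge_args(stored, incoming, immutable=False):
    """Toy version of how a signature combines its stored args with incoming ones."""
    if immutable:
        # .si(): incoming args (e.g. the parent's result) are ignored
        return tuple(stored)
    # .s(): incoming args are prepended to the stored ones
    return tuple(incoming) + tuple(stored)


parent_result = 16
print(merge_args((10,), (parent_result,)))               # (16, 10), like sum.s(10)
print(merge_args((), (parent_result,), immutable=True))  # (), like increase_counter.si()
```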
Of course, I'm simplifying this for the sake of the article. Celery also offers more advanced workflow primitives, such as chains, groups, and chords, for running tasks in sequence or in parallel.
If you want to read more, see the Canvas (workflows) section of the Celery docs.
These concepts can get a little confusing, so let's walk through an example scenario to make Celery callbacks crystal clear.
Imagine we have a Django app with the following models.
```python
from django.db import models


class Employee(models.Model):
    name = models.CharField(max_length=50)
    email = models.EmailField()


class EmployeeTraining(models.Model):
    # on_delete is required on ForeignKey since Django 2.0
    employee = models.ForeignKey(Employee, on_delete=models.CASCADE, related_name="trainings")
```
We'll create a REST endpoint that creates an `Employee` and kicks off tasks to create an `EmployeeTraining` and send that `Employee` an email to start their training.
Let's go one step at a time and create an APIView first.
```python
# views.py
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response

from .serializers import EmployeeSerializer  # <---- pretend we have this


class EmployeeCreateView(CreateAPIView):
    serializer_class = EmployeeSerializer

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            serializer.save()  # actually creates the Employee
            return Response(data=serializer.data, status=201)
        # validation errors are client errors, so 400 rather than 500
        return Response(data=serializer.errors, status=400)
```
Next up, let's write the task that creates an `EmployeeTraining` in the background and call it from our `EmployeeCreateView`.
```python
# tasks.py
from base.celery import app  # <---- replace base with the module where your celery.py is

from .models import Employee, EmployeeTraining


@app.task
def create_employee_training(employee_id):
    try:
        employee = Employee.objects.get(pk=employee_id)
    except Employee.DoesNotExist:
        return None
    EmployeeTraining.objects.create(employee=employee)
    return employee_id  # <---- we return employee_id because we'll use it in the callback to send the email notification
```
```python
# views.py
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response

from .serializers import EmployeeSerializer
from .tasks import create_employee_training


class EmployeeCreateView(CreateAPIView):
    serializer_class = EmployeeSerializer

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            employee = serializer.save()  # save() returns the created Employee
            create_employee_training.apply_async((employee.id,))
            return Response(data=serializer.data, status=201)
        return Response(data=serializer.errors, status=400)
```
Now that we create an `EmployeeTraining` whenever we create a new `Employee`, let's send that employee an email notification to start their training.
```python
# tasks.py
from base.celery import app

from .models import Employee, EmployeeTraining
from .utils import send_email_notification  # <---- pretend we've already written a function that sends email notifications


@app.task
def create_employee_training(employee_id):
    try:
        employee = Employee.objects.get(pk=employee_id)
    except Employee.DoesNotExist:
        return None
    EmployeeTraining.objects.create(employee=employee)
    return employee_id


@app.task
def notify_employee_of_training(employee_id):
    try:
        employee = Employee.objects.get(pk=employee_id)
    except Employee.DoesNotExist:
        return  # nothing to notify, e.g. the parent task returned None
    message = f"Hey {employee.name}, you have 1 new training assigned to you."
    send_email_notification(email=employee.email, message=message)
```
```python
# views.py
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response

from .serializers import EmployeeSerializer
from .tasks import create_employee_training, notify_employee_of_training


class EmployeeCreateView(CreateAPIView):
    serializer_class = EmployeeSerializer

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            employee = serializer.save()
            create_employee_training.apply_async(
                (employee.id,),
                link=notify_employee_of_training.s(),
            )
            return Response(data=serializer.data, status=201)
        return Response(data=serializer.errors, status=400)
```
We've added `link=notify_employee_of_training.s()`, so if `create_employee_training` succeeds, Celery passes its return value (`employee_id`) to `notify_employee_of_training` as the first argument.
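To trace the whole flow end to end without Django, a broker, or a mail server, here is a synchronous, pure-Python sketch of the same workflow. Everything in it is a hypothetical stand-in: the `employees`, `trainings`, and `outbox` structures replace the database and the email backend, the sample employee is made-up data, and `run_with_link` only models `apply_async(..., link=...)`. In real Celery the callback runs only if the parent task doesn't raise; failures can be handled separately with the `link_error` option.

```python
# Hypothetical in-memory stand-ins for the database and the email backend.
employees = {1: {"name": "Ada", "email": "ada@example.com"}}
trainings = []  # rows that would go into EmployeeTraining
outbox = []     # (email, message) pairs instead of real emails


def create_employee_training(employee_id):
    if employee_id not in employees:
        return None
    trainings.append({"employee_id": employee_id})
    return employee_id  # becomes the callback's first argument


def notify_employee_of_training(employee_id):
    employee = employees.get(employee_id)
    if employee is None:  # e.g. the parent returned None
        return
    message = f"Hey {employee['name']}, you have 1 new training assigned to you."
    outbox.append((employee["email"], message))


def run_with_link(task, args, callback):
    """Synchronous sketch of task.apply_async(args, link=callback.s())."""
    result = task(*args)  # if this raised, the callback would be skipped
    callback(result)      # callback.s() stores no args, so it gets only the result
    return result


run_with_link(create_employee_training, (1,), notify_employee_of_training)
print(trainings)  # [{'employee_id': 1}]
print(outbox)     # [('ada@example.com', 'Hey Ada, you have 1 new training assigned to you.')]
```

Returning `None` still counts as success, so the callback still runs in that case; that's why `notify_employee_of_training` has to cope with a missing employee.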
Celery is a very powerful tool and we've only scratched the surface of what you can do with it.
I've tried to keep this article as short as possible while teaching you a thing or two about callbacks in Celery.
If you have any suggestions, feel more than free to DM me on Twitter @triforceop.