
How to Set a Callback to a Celery Task (With Examples)

Published on: March 14, 2021

Intro

Heads up, I'm not a Celery expert.

But the other day I had to implement the following feature for one of my clients.

When we send a POST request to create model A, it should create model B and send the user associated with model B an email. And all of this needs to be done in the background.

I figured I should probably create two Celery tasks for this: one for creating model B and one for sending an email to the user.

And that the latter task should probably be a callback to the former task.

I've never encountered this before so I did some research and here's what I found.

Celery Task Signatures

The first thing I learned when reading about Celery workflows is that there are things called signatures and partials.

I think it's best to explain these two concepts through an example.

Signatures

Let's say, for the sake of argument, we have a task that sums up two numbers.

```python
from base.celery import app


@app.task
def sum(a, b):
    return a + b
```

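We can run this task in the background by calling its apply_async method. Note that apply_async returns an AsyncResult immediately, not the sum itself.

```python
from tasks import sum

result = sum.apply_async((2, 2))  # returns an AsyncResult right away
result.get()  # 4 (waits for a worker; requires a result backend)
```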

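We can also build a signature of the task by calling signature() (or its shortcut, s()). A signature wraps the task together with its arguments, so the call can be passed around and executed later.

```python
from tasks import sum

sum.s(2, 2)  # tasks.sum(2, 2): a signature, nothing runs yet
```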
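If it helps to see what a signature holds, here is a tiny pure-Python model of the idea. The `Signature` class and its `delay()` method are hypothetical stand-ins written only for illustration. The real `celery.Signature` also carries execution options and hands the call to a worker through the broker, whereas this toy simply runs the function in-process.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(repr=False)
class Signature:
    """Toy model: a task function plus the arguments to call it with later."""
    func: Callable[..., Any]
    args: tuple = ()

    def delay(self) -> Any:
        # Real Celery would send this call to a worker via the broker;
        # the toy just calls the function in the current process.
        return self.func(*self.args)

    def __repr__(self) -> str:
        arg_list = ", ".join(repr(a) for a in self.args)
        return f"{self.func.__name__}({arg_list})"


def add(a: int, b: int) -> int:
    return a + b


sig = Signature(add, (2, 2))  # nothing runs yet
print(sig)  # add(2, 2)
print(sig.delay())  # 4
```

The key point is that creating the signature and running it are two separate steps.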

Partials

If you're familiar with functional programming, partials are partially applied signatures.

If that doesn't ring a bell, partials are signatures that haven't been given all arguments yet.

```python
from tasks import sum

partial = sum.s(6)  # signature still missing one argument
partial.apply_async((4,))  # runs sum(4, 6)
```
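Because addition is commutative, this example hides a detail. Celery prepends the arguments you pass at call time to the ones already stored in the partial, so the 4 ends up first. Python's built-in `functools.partial` does the opposite and appends them. The sketch below uses subtraction to make the order visible. `celery_style_partial` is a hypothetical helper that only models Celery's ordering rule; it is not part of Celery.

```python
from functools import partial
from typing import Any, Callable


def subtract(a: int, b: int) -> int:
    return a - b


def celery_style_partial(func: Callable[..., Any], *frozen: Any) -> Callable[..., Any]:
    """Toy model of Celery's rule: call-time args go before the frozen ones."""
    def call(*extra: Any) -> Any:
        return func(*extra, *frozen)
    return call


# functools.partial: frozen args first, call-time args appended -> subtract(6, 4)
print(partial(subtract, 6)(4))  # 2

# Celery-style: call-time args prepended -> subtract(4, 6)
print(celery_style_partial(subtract, 6)(4))  # -2
```

This ordering rule is also what makes callbacks work: the previous task's result is prepended to the callback's stored arguments.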

That's all we need to know about signatures and partials to implement our feature.

Now we can focus on callbacks.

Callbacks

In Celery, you can link tasks together by passing a signature (usually a partial) of the callback task to the link argument of apply_async.

The result of the task will be given as the first argument of the callback task.

```python
from tasks import sum

# The first task returns 16, so the callback is called with 16 and 10
sum.apply_async((8, 8), link=sum.s(10))
```

If you don't want your callback to receive the result of the previous task as an argument, make the linked signature immutable by using si() instead of s().

```python
from tasks import sum, increase_counter

# Calculate 8 + 8 and increase the counter of sum task calls
sum.apply_async((8, 8), link=increase_counter.si())
```
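Here's a small in-process model of how link treats s() and si() callbacks. Everything in it is hypothetical: `Sig`, `run_with_link`, and this `increase_counter` are illustration-only names. Real Celery runs the callback as a separate task on a worker, and only if the parent task succeeds, instead of calling it inline.

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass
class Sig:
    func: Callable[..., Any]
    args: tuple = ()
    immutable: bool = False  # True models si(), False models s()

    def call(self, *extra: Any) -> Any:
        # Mutable signatures prepend call-time args; immutable ones ignore them.
        args = self.args if self.immutable else (*extra, *self.args)
        return self.func(*args)


def add(a: int, b: int) -> int:
    return a + b


counter = {"add_calls": 0}


def increase_counter() -> int:
    counter["add_calls"] += 1
    return counter["add_calls"]


def run_with_link(task: Sig, args: tuple, link: Sig) -> Tuple[Any, Any]:
    """Run the parent, then offer its result to the callback as the first argument."""
    result = task.call(*args)
    callback_result = link.call(result)
    return result, callback_result


# Like sum.apply_async((8, 8), link=sum.s(10)): the callback runs add(16, 10).
print(run_with_link(Sig(add), (8, 8), Sig(add, (10,))))  # (16, 26)

# Like link=increase_counter.si(): the parent's result (16) is discarded.
print(run_with_link(Sig(add), (8, 8), Sig(increase_counter, immutable=True)))  # (16, 1)
```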

Of course, I'm simplifying things for the sake of the article. Celery offers much more for composing and parallelizing tasks, such as chains, groups, and chords.

If you want to read more, see the "Canvas: Designing Work-flows" page of the Celery documentation.

Example

These concepts can get a little confusing, so let's walk through an example scenario to make Celery callbacks crystal clear.

Imagine we have a Django app with the following models.

```python
from django.db import models


class Employee(models.Model):
    name = models.CharField(max_length=50)
    email = models.EmailField()


class EmployeeTraining(models.Model):
    # on_delete is required since Django 2.0
    employee = models.ForeignKey(
        Employee, on_delete=models.CASCADE, related_name="trainings"
    )
```

We'll create a REST endpoint that creates an Employee and kicks off background tasks that create an EmployeeTraining and email the employee to start their training.

Let's go one step at a time and create an APIView first.

```python
# views.py
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response

from .serializers import EmployeeSerializer  # <---- pretend we have this


class EmployeeCreateView(CreateAPIView):
    serializer_class = EmployeeSerializer

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            serializer.save()  # actually create the Employee
            return Response(data=serializer.data, status=201)
        # invalid input is a client error (400), not a server error (500)
        return Response(data=serializer.errors, status=400)
```

Next up, let's write the task that creates an EmployeeTraining in the background and call it from our EmployeeCreateView.

```python
# tasks.py
from base.celery import app  # <---- replace base with the module where your celery.py is
from .models import Employee, EmployeeTraining


@app.task
def create_employee_training(employee_id):
    try:
        employee = Employee.objects.get(pk=employee_id)
    except Employee.DoesNotExist:
        return None
    EmployeeTraining.objects.create(employee=employee)
    return employee_id  # <---- returned so the callback can send the email notification


# views.py
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response

from .serializers import EmployeeSerializer
from .tasks import create_employee_training


class EmployeeCreateView(CreateAPIView):
    serializer_class = EmployeeSerializer

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            employee = serializer.save()
            # pass only the id; the task fetches the Employee itself
            create_employee_training.apply_async((employee.id,))
            return Response(data=serializer.data, status=201)
        return Response(data=serializer.errors, status=400)
```

Now that we create EmployeeTraining when we create a new Employee, let's send that employee an email notification to start training.

```python
# tasks.py
from base.celery import app
from .models import Employee, EmployeeTraining
from .utils import send_email_notification  # <---- pretend we've already written a function that sends email notifications


@app.task
def create_employee_training(employee_id):
    try:
        employee = Employee.objects.get(pk=employee_id)
    except Employee.DoesNotExist:
        return None
    EmployeeTraining.objects.create(employee=employee)
    return employee_id


@app.task
def notify_employee_of_training(employee_id):
    if employee_id is None:
        return  # the parent task found no employee, so there is nobody to notify
    try:
        employee = Employee.objects.get(pk=employee_id)
    except Employee.DoesNotExist:
        return
    message = f"Hey {employee.name}, you have 1 new training assigned to you."
    send_email_notification(email=employee.email, message=message)


# views.py
from rest_framework.generics import CreateAPIView
from rest_framework.response import Response

from .serializers import EmployeeSerializer
from .tasks import create_employee_training, notify_employee_of_training


class EmployeeCreateView(CreateAPIView):
    serializer_class = EmployeeSerializer

    def post(self, request, *args, **kwargs):
        serializer = self.get_serializer(data=request.data)
        if serializer.is_valid():
            employee = serializer.save()
            create_employee_training.apply_async(
                (employee.id,),
                # receives create_employee_training's return value as its first argument
                link=notify_employee_of_training.s(),
            )
            return Response(data=serializer.data, status=201)
        return Response(data=serializer.errors, status=400)
```

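We've added link=notify_employee_of_training.s(), so when create_employee_training succeeds, Celery passes its return value (employee_id) to notify_employee_of_training as the first argument. Callbacks added with link only run when the parent task succeeds; for failures, apply_async also accepts a link_error callback.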
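To see the whole flow without a broker or database, here's a toy simulation. The dictionaries stand in for the Employee and EmployeeTraining tables, `outbox` stands in for the hypothetical send_email_notification helper, and `run_linked` roughly imitates `apply_async(..., link=....s())` by calling the callback inline. It also shows why the callback has to cope with `None`. Because create_employee_training returns `None` instead of raising when the employee is missing, Celery still counts the task as a success and fires the callback with `None`.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# Toy stand-ins for the database tables and the email helper.
employees: Dict[int, Dict[str, str]] = {1: {"name": "Ada", "email": "ada@example.com"}}
trainings: List[int] = []  # ids of employees that have a training
outbox: List[Tuple[str, str]] = []  # (email, message) pairs "sent" so far


def create_employee_training(employee_id: int) -> Optional[int]:
    if employee_id not in employees:
        return None  # returning (not raising) still counts as success
    trainings.append(employee_id)
    return employee_id


def notify_employee_of_training(employee_id: Optional[int]) -> None:
    if employee_id is None:
        return  # parent found no employee; nothing to notify
    employee = employees[employee_id]
    message = f"Hey {employee['name']}, you have 1 new training assigned to you."
    outbox.append((employee["email"], message))


def run_linked(task: Callable[..., Any], args: tuple, callback: Callable[[Any], Any]) -> Any:
    """Rough stand-in for task.apply_async(args, link=callback.s())."""
    result = task(*args)
    callback(result)  # the parent's return value becomes the callback's first argument
    return result


run_linked(create_employee_training, (1,), notify_employee_of_training)
run_linked(create_employee_training, (99,), notify_employee_of_training)  # unknown id
print(trainings)  # [1]
print(outbox)  # [('ada@example.com', 'Hey Ada, you have 1 new training assigned to you.')]
```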

Conclusion

Celery is a very powerful tool and we've only scratched the surface of what you can do with it.

I've tried to keep this article as short as possible while teaching you a thing or two about callbacks in Celery.

If you have any suggestions, feel more than free to DM me on Twitter @triforceop.