自定义Celery任务记录器

自定义Celery任务记录器

celery 有一个特殊的记录器celery.task,这个记录器由 celery worker 设立,目的是将与任务相关的信息添加到日志消息中。这个日志包含两个新的参数:

  • task_id
  • task_name

通过访问任务记录器 celery.utils.log,这两个参数可以帮助我们了解日志消息来自哪个任务。

1
2
3
4
5
6
7
8
9
10
11
12
# tasks.py
import os
from celery.utils.log import get_task_logger
from worker import app

logger = get_task_logger(__name__)

@app.task()
def add(x, y):
result = x + y
logger.info(f'Add: {x} + {y} = {result}')
return result

执行任务,get_task_logger将产生如下的日志

1
2
3
[2019-10-31 07:30:13,545: INFO/MainProcess] Received task: tasks.get_request[9c332222-d2fc-47d9-adc3-04cebbe145cb]
[2019-10-31 07:30:13,546: INFO/MainProcess] tasks.get_request[9c332222-d2fc-47d9-adc3-04cebbe145cb]: Add: 3 + 5 = 8
[2019-10-31 07:30:13,598: INFO/MainProcess] Task tasks.get_request[9c332222-d2fc-47d9-adc3-04cebbe145cb] succeeded in 0.052071799989789724s: None

如果celery 应用程序处理非常多的任务,那么 celey.task 日志记录器对于日志输出是必不可少的

执行任务,标准 logging.getlogger 将产生如下日志

1
2
3
[2019-10-31 07:33:16,140: INFO/MainProcess] Received task: tasks.get_request[7d2ec1a7-0af2-4e8c-8354-02cd0975c906]
[2019-10-31 07:33:16,140: INFO/MainProcess] Add: 3 + 5 = 8
[2019-10-31 07:33:16,193: INFO/MainProcess] Task tasks.get_request[7d2ec1a7-0af2-4e8c-8354-02cd0975c906] succeeded in 0.052330999984405935s: None

如何自定义 celery.task 日志格式

celery.task 记录器有 after_setup_task_logger 信号,一旦 celery worker 设置了 celery.task 记录仪,就会触发该信号,这是我们要连接一定制日志格式化程序的信号。

这其中有个问题需要注意,为了访问task_id和 task_name,必须使用 celery.app.log.TaskFormatter 代替logging.Formatter。

1
2
3
4
5
6
7
8
9
10
11
12
# worker.py
import os
from celery import Celery
from celery.signals import after_setup_task_logger
from celery.app.log import TaskFormatter

app = Celery()

@after_setup_task_logger.connect
def setup_task_logger(logger, *args, **kwargs):
for handler in logger.handlers:
handler.setFormatter(TaskFormatter('%(asctime)s - %(task_id)s - %(task_name)s - %(name)s - %(levelname)s - %(message)s'))

如何使用标准记录器获取 task_id

celery.task 可以很好的完成celery 异步任务,但是有的时候我们的模型是用于 celery 和flask 两个环境中的,这个时候怎么解决呢?

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# models.py
import logging

from passlib.hash import sha256_crypt
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import validates
from sqlalchemy import text
from . import db

logger = logging.getLogger(__name__)

class User(db.Model):
__tablename__ = 'users'
id = db.Column(UUID(as_uuid=True), primary_key=True, server_default=text("uuid_generate_v4()"))
name = db.Column(db.String(64), unique=False, nullable=True)
email = db.Column(db.String(256), unique=True, nullable=False)

@validates('email')
def validate_email(self, key, value):
logger.info(f'Validate email address: {value}')
if value is not None:
assert '@' in value
return value.lower()

底层代码不在乎模型是在哪个上下文中运行,我们可以通过在celery 异步任务中调用 validate_email 然后在日志消息中获取 celery.task_id,validate_email 从 flask 中调用的时候,没有任务ID

我们还可以通过 celery._state.get_current_task在celery.app.log.TaskFormatter添加 task_id和 task_name。如果 celery_state.get_current_task 在celery 任务之外执行,则会返回 None

celery.app.log.TaskFormatter 通过打印 ???处理 None,而不是 task_id 和 task_name。这意味着我们可以在 celery 异步任务之外安全地创建日志处理程序 celery.app.log.TaskFormatter

1
2
3
4
5
6
7
8
import logging
from celery.app.log import TaskFormatter

logger = logging.getLogger()
sh = logging.StreamHandler()
sh.setFormatter(TaskFormatter('%(asctime)s - %(task_id)s - %(task_name)s - %(name)s - %(levelname)s - %(message)s'))
logger.setLevel(logging.INFO)
logger.addHandler(sh)

如果不喜欢默认值???,也可以自己在 celery.app.log 中自定义自己的任务格式化程序

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import logging

class TaskFormatter(logging.Formatter):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
try:
from celery._state import get_current_task
self.get_current_task = get_current_task
except ImportError:
self.get_current_task = lambda: None


def format(self, record):
task = self.get_current_task()
if task and task.request:
record.__dict__.update(task_id=task.request.id,
task_name=task.name)
else:
record.__dict__.setdefault('task_name', '')
record.__dict__.setdefault('task_id', '')
return super().format(record)

logger = logging.getLogger()
sh = logging.StreamHandler()
sh.setFormatter(
TaskFormatter(
'%(asctime)s - %(task_id)s - %(task_name)s - %(name)s - %(levelname)s - %(message)s'))
logger.setLevel(logging.INFO)
logger.addHandler(sh)

这个定制的 TaskFormatter 同样适用于 logging.getlogger。