# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections import OrderedDict
# To avoid circular imports
import airflow.utils.dag_processing
from airflow import configuration
from airflow.stats import Stats
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.state import State
PARALLELISM = configuration.conf.getint('core', 'PARALLELISM')
class BaseExecutor(LoggingMixin):

    def __init__(self, parallelism=PARALLELISM):
        """
        Class to derive from in order to interface with executor-type systems
        like Celery, Yarn and the like.

        :param parallelism: how many jobs should run at one time. Set to
            ``0`` for infinity
        :type parallelism: int
        """
self.parallelism = parallelism
self.queued_tasks = OrderedDict()
self.running = {}
self.event_buffer = {}
    def start(self):  # pragma: no cover
        """
        Executors may need to get things started. For example, LocalExecutor
        starts N workers.
        """
    def queue_command(self, simple_task_instance, command, priority=1, queue=None):
        key = simple_task_instance.key
        if key not in self.queued_tasks and key not in self.running:
            self.log.info("Adding to queue: %s", command)
            self.queued_tasks[key] = (command, priority, queue, simple_task_instance)
        else:
            self.log.info("Could not queue task %s; it is already queued or running", key)
    def queue_task_instance(
self,
task_instance,
mark_success=False,
pickle_id=None,
ignore_all_deps=False,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pool=None,
cfg_path=None):
pool = pool or task_instance.pool
# TODO (edgarRd): AIRFLOW-1985:
# cfg_path is needed to propagate the config values if using impersonation
# (run_as_user), given that there are different code paths running tasks.
# For a long term solution we need to address AIRFLOW-1986
command = task_instance.command_as_list(
local=True,
mark_success=mark_success,
ignore_all_deps=ignore_all_deps,
ignore_depends_on_past=ignore_depends_on_past,
ignore_task_deps=ignore_task_deps,
ignore_ti_state=ignore_ti_state,
pool=pool,
pickle_id=pickle_id,
cfg_path=cfg_path)
self.queue_command(
airflow.utils.dag_processing.SimpleTaskInstance(task_instance),
command,
priority=task_instance.task.priority_weight_total,
queue=task_instance.task.queue)
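    # Note (illustrative): command_as_list() produces the ``airflow run``
    # CLI invocation for this task instance, which is then queued under the
    # task's priority_weight_total and its configured queue.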
    def has_task(self, task_instance):
        """
        Checks if a task is either queued or running in this executor.

        :param task_instance: TaskInstance
        :return: True if the task is known to this executor
        """
        return task_instance.key in self.queued_tasks or task_instance.key in self.running
"""
Sync will get called periodically by the heartbeat method.
Executors should override this to perform gather statuses.
"""
    def heartbeat(self):
        """
        Triggers as many queued tasks as open slots allow, then calls the
        subclass's sync method to refresh task states.
        """
# Triggering new jobs
if not self.parallelism:
open_slots = len(self.queued_tasks)
else:
open_slots = self.parallelism - len(self.running)
num_running_tasks = len(self.running)
num_queued_tasks = len(self.queued_tasks)
self.log.debug("%s running task instances", num_running_tasks)
self.log.debug("%s in queue", num_queued_tasks)
self.log.debug("%s open slots", open_slots)
Stats.gauge('executor.open_slots', open_slots)
Stats.gauge('executor.queued_tasks', num_queued_tasks)
Stats.gauge('executor.running_tasks', num_running_tasks)
self.trigger_tasks(open_slots)
# Calling child class sync method
self.log.debug("Calling the %s sync method", self.__class__)
self.sync()
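    # Worked example (illustrative): with parallelism=32 and 10 running task
    # instances, open_slots is 22, so at most 22 queued tasks are triggered
    # this heartbeat; with parallelism=0 every queued task is runnable.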
    def trigger_tasks(self, open_slots):
        """
        Triggers tasks by popping up to ``open_slots`` queued entries, in
        descending priority order, and executing them asynchronously.

        :param open_slots: Number of open slots
        """
        # Each queued value is a (command, priority, queue, simple_ti)
        # tuple, so sort on the priority field, highest first.
        sorted_queue = sorted(
            self.queued_tasks.items(),
            key=lambda x: x[1][1],
            reverse=True)
        for _ in range(min(open_slots, len(self.queued_tasks))):
key, (command, _, queue, simple_ti) = sorted_queue.pop(0)
self.queued_tasks.pop(key)
self.running[key] = command
self.execute_async(key=key,
command=command,
queue=queue,
executor_config=simple_ti.executor_config)
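    # Note (illustrative): because the queue is sorted on the priority field
    # in reverse, a task queued with priority=10 is popped and handed to
    # execute_async before one queued with priority=1.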
    def change_state(self, key, state):
self.log.debug("Changing state: %s", key)
self.running.pop(key, None)
self.event_buffer[key] = state
    def fail(self, key):
        self.change_state(key, State.FAILED)

    def success(self, key):
        self.change_state(key, State.SUCCESS)
    def get_event_buffer(self, dag_ids=None):
        """
        Returns and flushes the event buffer. If dag_ids is specified, it
        only returns and flushes events for the given dag_ids. Otherwise,
        it returns and flushes all events.

        :param dag_ids: the dag_ids to return events for; if None, returns all
        :return: a dict of events
        """
cleared_events = dict()
if dag_ids is None:
cleared_events = self.event_buffer
self.event_buffer = dict()
else:
for key in list(self.event_buffer.keys()):
dag_id, _, _, _ = key
if dag_id in dag_ids:
cleared_events[key] = self.event_buffer.pop(key)
return cleared_events
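    # Example (illustrative): event-buffer keys are task instance keys of
    # the form (dag_id, task_id, execution_date, try_number), so
    # get_event_buffer(dag_ids=['example_dag']) drains only the events that
    # belong to 'example_dag'.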
    def execute_async(self,
key,
command,
queue=None,
executor_config=None): # pragma: no cover
"""
This method will execute the command asynchronously.
"""
raise NotImplementedError()
    def end(self):  # pragma: no cover
        """
        This method is called when the caller is done submitting jobs and
        wants to wait synchronously for all previously submitted jobs to
        finish.
        """
raise NotImplementedError()
    def terminate(self):
        """
        This method is called when the daemon receives a SIGTERM.
        """
raise NotImplementedError()
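

# ---------------------------------------------------------------------------
# Illustrative sketch only, not part of the original Airflow source: a
# minimal concrete executor showing how the hooks above fit together. The
# class name, and its use of the subprocess module, are assumptions made
# purely for demonstration.
import subprocess
import time


class _SketchSubprocessExecutor(BaseExecutor):
    """Runs each queued command in a local subprocess; illustration only."""

    def start(self):
        # Nothing to spin up; just track the processes we launch.
        self._processes = {}

    def execute_async(self, key, command, queue=None, executor_config=None):
        # Launch the task command without blocking the heartbeat loop.
        self._processes[key] = subprocess.Popen(command)

    def sync(self):
        # Reap finished processes and record their terminal states.
        for key, proc in list(self._processes.items()):
            returncode = proc.poll()
            if returncode is None:
                continue  # still running
            del self._processes[key]
            if returncode == 0:
                self.success(key)
            else:
                self.fail(key)

    def end(self):
        # Wait synchronously for every submitted job to finish.
        while self._processes:
            self.sync()
            time.sleep(1)

    def terminate(self):
        # On SIGTERM, forward termination to any live subprocesses.
        for proc in self._processes.values():
            proc.terminate()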