import json
import os
import warnings
from collections import namedtuple
from typing import Any, Dict, List, Mapping, Optional
import boto3
from botocore.exceptions import ClientError
from dagster import (
Array,
Field,
Noneable,
Permissive,
ScalarUnion,
StringSource,
_check as check,
)
from dagster._core.events import EngineEventData, MetadataEntry
from dagster._core.launcher.base import (
CheckRunHealthResult,
LaunchRunContext,
RunLauncher,
WorkerStatus,
)
from dagster._core.storage.pipeline_run import DagsterRun
from dagster._grpc.types import ExecuteRunArgs
from dagster._serdes import ConfigurableClass
from dagster._utils.backoff import backoff
from ..secretsmanager import get_secrets_from_arns
from .container_context import SHARED_ECS_SCHEMA, EcsContainerContext
from .tasks import (
DagsterEcsTaskDefinitionConfig,
get_current_ecs_task,
get_current_ecs_task_metadata,
get_task_definition_dict_from_current_task,
get_task_kwargs_from_current_task,
)
from .utils import sanitize_family, task_definitions_match
# ECS coordinates recorded on a run's tags (see _set_run_tags / _get_run_tags).
Tags = namedtuple("Tags", ["arn", "cluster", "cpu", "memory"])

# ECS task lifecycle states that count as "still running" for health checks.
RUNNING_STATUSES = (
    "PROVISIONING PENDING ACTIVATING RUNNING DEACTIVATING STOPPING DEPROVISIONING".split()
)

# Terminal ECS task lifecycle state.
STOPPED_STATUSES = ["STOPPED"]
class EcsRunLauncher(RunLauncher, ConfigurableClass):
    """RunLauncher that starts a task in ECS for each Dagster job run."""
    def __init__(
        self,
        inst_data=None,
        task_definition=None,
        container_name="run",
        secrets=None,
        secrets_tag="dagster",
        env_vars=None,
        include_sidecars=False,
        use_current_ecs_task_config: bool = True,
        run_task_kwargs: Optional[Mapping[str, Any]] = None,
        run_resources: Optional[Dict[str, str]] = None,
    ):
        """Configure the launcher and eagerly validate its task-definition settings.

        Args:
            inst_data: ConfigurableClass plumbing; serialized instance config.
            task_definition: Either the name/ARN of an existing task definition,
                a dict describing how to build one per run, or ``{"env": VAR}``
                to read the task definition name from an environment variable.
            container_name: Name of the container to override when launching runs.
            secrets: ECS-style secrets ({"name", "valueFrom"}), or (deprecated)
                a bare list of Secrets Manager ARNs.
            secrets_tag: Secrets Manager tag whose secrets are mounted as env vars.
            env_vars: Extra environment variables for the run container.
            include_sidecars: Whether launched tasks copy this task's sidecars.
            use_current_ecs_task_config: Whether to derive cluster/networking
                config from the ECS task this launcher itself runs in.
            run_task_kwargs: Extra kwargs passed through to ``ecs.run_task``.
            run_resources: Default cpu/memory for launched runs.

        Raises:
            Exception: If ``task_definition`` references an unset environment
                variable, or (via check.invariant) if config is malformed.
        """
        self._inst_data = inst_data
        # AWS clients are created once here and shared for the launcher's lifetime.
        self.ecs = boto3.client("ecs")
        self.ec2 = boto3.resource("ec2")
        self.secrets_manager = boto3.client("secretsmanager")
        self.logs = boto3.client("logs")
        # Exactly one of these two ends up set, depending on the config shape.
        self.task_definition = None
        self.task_definition_dict = None
        if isinstance(task_definition, str):
            # A plain string is the name/ARN of an existing task definition.
            self.task_definition = task_definition
        elif task_definition and "env" in task_definition:
            # Backwards-compat: {"env": VAR} sources the name from the environment.
            check.invariant(
                len(task_definition) == 1,
                (
                    "If `task_definition` is set to a dictionary with `env`, `env` must be the only"
                    " key."
                ),
            )
            env_var = task_definition["env"]
            self.task_definition = os.getenv(env_var)
            if not self.task_definition:
                raise Exception(
                    f"You have attempted to fetch the environment variable {env_var} which is not"
                    " set."
                )
        else:
            # Otherwise the dict describes how to build a task definition per run.
            self.task_definition_dict = task_definition
        self.container_name = container_name
        self.secrets = check.opt_list_param(secrets, "secrets")
        self.env_vars = check.opt_list_param(env_vars, "env_vars")
        if self.secrets and all(isinstance(secret, str) for secret in self.secrets):
            # Legacy shape: a bare list of secret ARNs. Convert to the ECS API's
            # {"name": ..., "valueFrom": ...} structure.
            warnings.warn(
                (
                    "Setting secrets as a list of ARNs is deprecated. "
                    "Secrets should instead follow the same structure as the ECS API: "
                    "https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_Secret.html"
                ),
                DeprecationWarning,
            )
            self.secrets = [
                {"name": name, "valueFrom": value_from}
                for name, value_from in get_secrets_from_arns(
                    self.secrets_manager, self.secrets
                ).items()
            ]
        self.secrets_tags = [secrets_tag] if secrets_tag else []
        self.include_sidecars = include_sidecars
        if self.task_definition:
            # Fail fast if the task definition doesn't exist or doesn't define the
            # container we plan to override; then pin the resolved ARN.
            task_definition = self.ecs.describe_task_definition(taskDefinition=self.task_definition)
            container_names = [
                container.get("name")
                for container in task_definition["taskDefinition"]["containerDefinitions"]
            ]
            check.invariant(
                container_name in container_names,
                (
                    f"Cannot override container '{container_name}' in task definition "
                    f"'{self.task_definition}' because the container is not defined."
                ),
            )
            self.task_definition = task_definition["taskDefinition"]["taskDefinitionArn"]
        self.use_current_ecs_task_config = check.opt_bool_param(
            use_current_ecs_task_config, "use_current_ecs_task_config"
        )
        self.run_task_kwargs = check.opt_mapping_param(run_task_kwargs, "run_task_kwargs")
        if run_task_kwargs:
            check.invariant(
                "taskDefinition" not in run_task_kwargs,
                "Use the `taskDefinition` config field to pass in a task definition to run.",
            )
            check.invariant(
                "overrides" not in run_task_kwargs,
                "Task overrides are set by the run launcher and cannot be set in run_task_kwargs.",
            )
            # Reject keys the ECS RunTask API would not accept, using botocore's
            # service model as the source of truth.
            expected_keys = [
                key for key in self.ecs.meta.service_model.shape_for("RunTaskRequest").members
            ]
            for key in run_task_kwargs:
                check.invariant(
                    key in expected_keys, f"Found an unexpected key {key} in run_task_kwargs"
                )
        self.run_resources = check.opt_mapping_param(run_resources, "run_resources")
        # Memoized description of the ECS task this launcher itself runs in.
        self._current_task_metadata = None
        self._current_task = None
    @property
    def inst_data(self):
        # ConfigurableClass hook: the config data used to re-construct this launcher.
        return self._inst_data
@classmethod
def config_type(cls):
return {
"task_definition": Field(
ScalarUnion(
scalar_type=str,
non_scalar_schema={
"log_group": Field(StringSource, is_required=False),
"sidecar_containers": Field(Array(Permissive({})), is_required=False),
"execution_role_arn": Field(StringSource, is_required=False),
"task_role_arn": Field(StringSource, is_required=False),
"requires_compatibilities": Field(Array(str), is_required=False),
"env": Field(
str,
is_required=False,
description=(
"Backwards-compatibility for when task_definition was a"
" StringSource.Can be used to source the task_definition scalar"
" from an environment variable."
),
),
},
),
is_required=False,
description=(
"Either the short name of an existing task definition to use when launching new"
" tasks, or a dictionary configuration to use when creating a task definition"
" for the run.If neither is provided, the task definition will be created based"
" on the current task's task definition."
),
),
"container_name": Field(
StringSource,
is_required=False,
default_value="run",
description=(
"The container name to use when launching new tasks. Defaults to 'run'."
),
),
"secrets": Field(
Array(
ScalarUnion(
scalar_type=str,
non_scalar_schema={"name": StringSource, "valueFrom": StringSource},
)
),
is_required=False,
description=(
"An array of AWS Secrets Manager secrets. These secrets will "
"be mounted as environment variables in the container. See "
"https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_Secret.html."
),
),
"secrets_tag": Field(
Noneable(StringSource),
is_required=False,
default_value="dagster",
description=(
"AWS Secrets Manager secrets with this tag will be mounted as "
"environment variables in the container. Defaults to 'dagster'."
),
),
"include_sidecars": Field(
bool,
is_required=False,
default_value=False,
description=(
"Whether each run should use the same sidecars as the task that launches it. "
"Defaults to False."
),
),
"use_current_ecs_task_config": Field(
bool,
is_required=False,
default_value=True,
description=(
"Whether to use the run launcher's current ECS task in order to determine "
"the cluster and networking configuration for the launched task. Defaults to "
"True. Should only be called if the run launcher is running within an ECS "
"task."
),
),
"run_task_kwargs": Field(
Permissive(
{
"cluster": Field(
StringSource,
is_required=False,
description="Name of the ECS cluster to launch ECS tasks in.",
),
}
),
is_required=False,
description=(
"Additional arguments to include while running the task. See"
" https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task"
" for the available parameters. The overrides and taskDefinition arguments will"
" always be set by the run launcher."
),
),
**SHARED_ECS_SCHEMA,
}
@staticmethod
def from_config_value(inst_data, config_value):
return EcsRunLauncher(inst_data=inst_data, **config_value)
def _set_run_tags(self, run_id: str, cluster: str, task_arn: str):
tags = {
"ecs/task_arn": task_arn,
"ecs/cluster": cluster,
}
self._instance.add_run_tags(run_id, tags)
def build_ecs_tags_for_run_task(self, run):
return [{"key": "dagster/run_id", "value": run.run_id}]
def _get_run_tags(self, run_id):
run = self._instance.get_run_by_id(run_id)
tags = run.tags if run else {}
arn = tags.get("ecs/task_arn")
cluster = tags.get("ecs/cluster")
cpu = tags.get("ecs/cpu")
memory = tags.get("ecs/memory")
return Tags(arn, cluster, cpu, memory)
def launch_run(self, context: LaunchRunContext) -> None:
"""
Launch a run in an ECS task.
"""
run = context.pipeline_run
container_context = EcsContainerContext.create_for_run(run, self)
pipeline_origin = check.not_none(context.pipeline_code_origin)
image = pipeline_origin.repository_origin.container_image
# ECS limits overrides to 8192 characters including json formatting
# https://docs.aws.amazon.com/AmazonECS/latest/APIReference/API_RunTask.html
# When container_context is serialized as part of the ExecuteRunArgs, we risk
# going over this limit (for example, if many secrets have been set). This strips
# the container context off of our pipeline origin because we don't actually need
# it to launch the run; we only needed it to create the task definition.
repository_origin = pipeline_origin.repository_origin
# pylint: disable=protected-access
stripped_repository_origin = repository_origin._replace(container_context={})
stripped_pipeline_origin = pipeline_origin._replace(
repository_origin=stripped_repository_origin
)
# pylint: enable=protected-access
args = ExecuteRunArgs(
pipeline_origin=stripped_pipeline_origin,
pipeline_run_id=run.run_id,
instance_ref=self._instance.get_ref(),
)
command = args.get_command_args()
run_task_kwargs = self._run_task_kwargs(run, image, container_context)
# Set cpu or memory overrides
# https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html
cpu_and_memory_overrides = self.get_cpu_and_memory_overrides(container_context, run)
task_overrides = self._get_task_overrides(run)
container_overrides: List[Dict[str, Any]] = [
{
"name": self._get_container_name(container_context),
"command": command,
# containerOverrides expects cpu/memory as integers
**{k: int(v) for k, v in cpu_and_memory_overrides.items()},
}
]
run_task_kwargs["overrides"] = {
"containerOverrides": container_overrides,
# taskOverrides expects cpu/memory as strings
**cpu_and_memory_overrides,
**task_overrides,
}
run_task_kwargs["tags"] = [
*run_task_kwargs.get("tags", []),
*self.build_ecs_tags_for_run_task(run),
]
# Run a task using the same network configuration as this processes's
# task.
response = self.ecs.run_task(**run_task_kwargs)
tasks = response["tasks"]
if not tasks:
failures = response["failures"]
exceptions = []
for failure in failures:
arn = failure.get("arn")
reason = failure.get("reason")
detail = failure.get("detail")
exceptions.append(Exception(f"Task {arn} failed because {reason}: {detail}"))
raise Exception(exceptions)
arn = tasks[0]["taskArn"]
cluster_arn = tasks[0]["clusterArn"]
self._set_run_tags(run.run_id, cluster=cluster_arn, task_arn=arn)
self.report_launch_events(run, arn, cluster_arn)
def report_launch_events(
self, run: DagsterRun, arn: Optional[str] = None, cluster: Optional[str] = None
):
# Extracted method to allow for subclasses to customize the launch reporting behavior
metadata_entries = []
if arn:
metadata_entries.append(MetadataEntry("ECS Task ARN", value=arn))
if cluster:
metadata_entries.append(MetadataEntry("ECS Cluster", value=cluster))
metadata_entries.append(MetadataEntry("Run ID", value=run.run_id))
self._instance.report_engine_event(
message="Launching run in ECS task",
pipeline_run=run,
engine_event_data=EngineEventData(metadata_entries),
cls=self.__class__,
)
def get_cpu_and_memory_overrides(
self, container_context: EcsContainerContext, run: DagsterRun
) -> Mapping[str, str]:
overrides = {}
cpu = run.tags.get("ecs/cpu", container_context.run_resources.get("cpu"))
memory = run.tags.get("ecs/memory", container_context.run_resources.get("memory"))
if cpu:
overrides["cpu"] = cpu
if memory:
overrides["memory"] = memory
return overrides
def _get_task_overrides(self, run: DagsterRun) -> Mapping[str, Any]:
overrides = run.tags.get("ecs/task_overrides")
if overrides:
return json.loads(overrides)
return {}
def terminate(self, run_id):
tags = self._get_run_tags(run_id)
if not (tags.arn and tags.cluster):
return False
tasks = self.ecs.describe_tasks(tasks=[tags.arn], cluster=tags.cluster).get("tasks")
if not tasks:
return False
status = tasks[0].get("lastStatus")
if status == "STOPPED":
return False
self.ecs.stop_task(task=tags.arn, cluster=tags.cluster)
return True
def _get_current_task_metadata(self):
if self._current_task_metadata is None:
self._current_task_metadata = get_current_ecs_task_metadata()
return self._current_task_metadata
def _get_current_task(self):
if self._current_task is None:
current_task_metadata = self._get_current_task_metadata()
self._current_task = get_current_ecs_task(
self.ecs, current_task_metadata.task_arn, current_task_metadata.cluster
)
return self._current_task
def _get_run_task_definition_family(self, run) -> str:
return sanitize_family(
run.external_pipeline_origin.external_repository_origin.repository_location_origin.location_name # type: ignore
)
def _get_container_name(self, container_context) -> str:
return container_context.container_name or self.container_name
def _run_task_kwargs(self, run, image, container_context) -> Dict[str, Any]:
"""
Return a dictionary of args to launch the ECS task, registering a new task
definition if needed.
"""
environment = self._environment(container_context)
environment.append({"name": "DAGSTER_RUN_JOB_NAME", "value": run.job_name})
secrets = self._secrets(container_context)
if container_context.task_definition_arn:
task_definition = container_context.task_definition_arn
else:
family = self._get_run_task_definition_family(run)
if self.task_definition_dict:
task_definition_config = DagsterEcsTaskDefinitionConfig(
family,
image,
self._get_container_name(container_context),
command=None,
log_configuration=(
{
"logDriver": "awslogs",
"options": {
"awslogs-group": self.task_definition_dict["log_group"],
"awslogs-region": self.ecs.meta.region_name,
"awslogs-stream-prefix": family,
},
}
if self.task_definition_dict.get("log_group")
else None
),
secrets=secrets if secrets else [],
environment=environment,
execution_role_arn=self.task_definition_dict.get("execution_role_arn"),
task_role_arn=self.task_definition_dict.get("task_role_arn"),
sidecars=self.task_definition_dict.get("sidecar_containers"),
requires_compatibilities=self.task_definition_dict.get(
"requires_compatibilities", []
),
)
task_definition_dict = task_definition_config.task_definition_dict()
else:
task_definition_dict = get_task_definition_dict_from_current_task(
self.ecs,
family,
self._get_current_task(),
image,
self._get_container_name(container_context),
environment=environment,
secrets=secrets if secrets else {},
include_sidecars=self.include_sidecars,
)
task_definition_config = DagsterEcsTaskDefinitionConfig.from_task_definition_dict(
task_definition_dict,
self._get_container_name(container_context),
)
container_name = self._get_container_name(container_context)
backoff(
self._reuse_or_register_task_definition,
retry_on=(Exception,),
kwargs={
"desired_task_definition_config": task_definition_config,
"container_name": container_name,
"task_definition_dict": task_definition_dict,
},
max_retries=5,
)
task_definition = family
if self.use_current_ecs_task_config:
current_task_metadata = get_current_ecs_task_metadata()
current_task = get_current_ecs_task(
self.ecs, current_task_metadata.task_arn, current_task_metadata.cluster
)
task_kwargs = get_task_kwargs_from_current_task(
self.ec2,
current_task_metadata.cluster,
current_task,
)
else:
task_kwargs = {}
return {**task_kwargs, **self.run_task_kwargs, "taskDefinition": task_definition}
def _reuse_task_definition(
self, desired_task_definition_config: DagsterEcsTaskDefinitionConfig, container_name: str
):
family = desired_task_definition_config.family
try:
existing_task_definition = self.ecs.describe_task_definition(taskDefinition=family)[
"taskDefinition"
]
except ClientError:
# task definition does not exist, do not reuse
return False
return task_definitions_match(
desired_task_definition_config,
existing_task_definition,
container_name=container_name,
)
def _reuse_or_register_task_definition(
self,
desired_task_definition_config: DagsterEcsTaskDefinitionConfig,
container_name: str,
task_definition_dict: dict,
):
if not self._reuse_task_definition(desired_task_definition_config, container_name):
self.ecs.register_task_definition(**task_definition_dict)
def _environment(self, container_context):
return [
{"name": key, "value": value}
for key, value in container_context.get_environment_dict().items()
]
def _secrets(self, container_context):
secrets = container_context.get_secrets_dict(self.secrets_manager)
return (
[{"name": key, "valueFrom": value} for key, value in secrets.items()] if secrets else []
)
    @property
    def supports_check_run_worker_health(self):
        # This launcher can poll ECS task status, so health checks are supported.
        return True
def check_run_worker_health(self, run: DagsterRun):
tags = self._get_run_tags(run.run_id)
if not (tags.arn and tags.cluster):
return CheckRunHealthResult(WorkerStatus.UNKNOWN, "")
tasks = self.ecs.describe_tasks(tasks=[tags.arn], cluster=tags.cluster).get("tasks")
if not tasks:
return CheckRunHealthResult(WorkerStatus.UNKNOWN, "")
t = tasks[0]
if t.get("lastStatus") in RUNNING_STATUSES:
return CheckRunHealthResult(WorkerStatus.RUNNING)
elif t.get("lastStatus") in STOPPED_STATUSES:
failed_containers = []
for c in t.get("containers"):
if c.get("exitCode") != 0:
failed_containers.append(c)
if len(failed_containers) > 0:
if len(failed_containers) > 1:
container_str = "Containers"
else:
container_str = "Container"
return CheckRunHealthResult(
WorkerStatus.FAILED,
(
f"ECS task failed. Stop code: {t.get('stopCode')}. Stop reason:"
f" {t.get('stoppedReason')}."
f" {container_str} {[c.get('name') for c in failed_containers]} failed."
f" Check the logs for task {t.get('taskArn')} for details."
),
)
return CheckRunHealthResult(WorkerStatus.SUCCESS)
return CheckRunHealthResult(WorkerStatus.UNKNOWN, "ECS task health status is unknown.")