"""
Job scheduler interface for Slurm, SGE, and local execution.
Provides a unified interface for detecting the current scheduler
environment, querying job metadata (CPUs, remaining walltime),
and managing job arrays.
"""
import logging
import os
import re
import subprocess
import tempfile
from datetime import datetime, timedelta, timezone
from typing import ClassVar, Optional
logger = logging.getLogger(__name__)
class Scheduler:
    """Base class for job scheduler interfaces.

    Concrete schedulers override ``job_id``, ``get_n_cpus`` and
    ``get_remaining_seconds``; the versions defined here raise
    ``NotImplementedError``.
    """

    def __init__(self) -> None:
        # Cache slot for the job ID; ``None`` means "not looked up yet".
        self._job_id: Optional[str] = None
        # Cache slot for the CPU count; nothing in this class reads it.
        self._ncpus: Optional[int] = None

    def get_n_cpus(self) -> Optional[int]:
        """Return the number of CPUs allocated to this job."""
        raise NotImplementedError

    @property
    def user_name(self) -> str:
        """Return the name of the current user.

        ``$USER`` is used when it is set and non-empty.  Otherwise the
        lookup falls back to :func:`getpass.getuser`, so environments that
        do not export ``USER`` no longer fail with a bare ``KeyError``.
        """
        name = os.environ.get("USER")
        if name:
            return name
        # Imported here so the module's import block does not have to change.
        import getpass

        return getpass.getuser()

    def get_remaining_seconds(self) -> int:
        """Get the remaining walltime in seconds."""
        raise NotImplementedError

    @property
    def is_in_job(self) -> bool:
        """Return whether we are inside a scheduler job."""
        return self.job_id is not None

    @property
    def job_id(self) -> Optional[str]:
        """Return the job ID."""
        raise NotImplementedError

    @classmethod
    def get_scheduler(cls) -> Optional["Scheduler"]:
        """Detect and return a scheduler instance for the current environment.

        ``Slurm``, ``SGE`` and ``Dummy`` are instantiated in that order and
        the first instance whose ``is_in_job`` is true is returned.
        ``Dummy.is_in_job`` is always true, so ``Dummy`` acts as the local
        fallback; the trailing ``return None`` is only reached if none of
        the candidates reports being in a job.
        """
        for candidate in (Slurm, SGE, Dummy):
            instance = candidate()
            if instance.is_in_job:
                return instance
        return None
class Dummy(Scheduler):
    """Dummy scheduler for local execution.

    Always reports being inside a job (with the fixed job ID ``"0"``) so it
    can serve as the fallback when no real scheduler is detected.
    """

    DEFAULT_REMAINING_TIME = 3600 * 24 * 30  # 30 days
    # CPU count reported for local runs; override on a subclass to change it.
    DEFAULT_N_CPUS = 4

    def __init__(self) -> None:
        super().__init__()
        self._job_id = "0"

    def get_n_cpus(self) -> int:
        """Return the fixed CPU count configured in ``DEFAULT_N_CPUS``."""
        return self.DEFAULT_N_CPUS

    @property
    def job_id(self) -> str:
        """Return the placeholder job ID set in ``__init__``."""
        return self._job_id

    def get_remaining_seconds(self) -> int:
        """Get the remaining time. Defaults to 30 days."""
        return self.DEFAULT_REMAINING_TIME

    @property
    def is_in_job(self) -> bool:
        """Return ``True``: local execution is always treated as a job."""
        return True
class SGE(Scheduler):
    """Scheduler object for Sun Grid Engine (SGE).

    Job metadata comes from the ``JOB_ID`` / ``SGE_TASK_ID`` / ``NSLOTS``
    environment variables and from the output of ``qstat -j``.
    """

    def __init__(self) -> None:
        """Read the ``qstat -j`` details when running inside an SGE job.

        Raises:
            subprocess.CalledProcessError: if ``qstat`` exits non-zero.
            OSError: if ``qstat`` cannot be executed.
        """
        super().__init__()
        self._start_time: Optional[datetime] = None
        # Always defined, so the getters also work outside a job.
        self._task_info: dict[str, str] = {}
        # ``JOB_ID`` may be exported but empty; only query qstat for a real ID.
        if self.is_in_job and self.job_id:
            self._read_task_info()

    @property
    def job_id(self) -> Optional[str]:
        """Return ``$JOB_ID``, suffixed with ``.$SGE_TASK_ID`` for array tasks.

        The value is computed on first access and cached.  ``None`` is
        returned when ``JOB_ID`` is unset or empty.
        """
        if self._job_id is None:
            job_id = os.environ.get("JOB_ID")
            task_id = os.environ.get("SGE_TASK_ID")
            if job_id:
                # Outside array jobs SGE_TASK_ID is the literal "undefined".
                if task_id and task_id != "undefined":
                    job_id = f"{job_id}.{task_id}"
                    logger.warning("WARNING: REMAINING TIME IS NOT CORRECT FOR TASK ARRAY")
                self._job_id = job_id
        return self._job_id

    @property
    def is_in_job(self) -> bool:
        """Return whether ``JOB_ID`` is present in the environment."""
        return "JOB_ID" in os.environ

    def _read_task_info(self) -> None:
        """Read detailed task information from qstat.

        Every ``key: value`` line of ``qstat -j <job_id>`` (after the first
        line, which is skipped) is stored in ``self._task_info`` with both
        sides stripped.  Lines without a colon are ignored.
        """
        output = subprocess.check_output(
            ["qstat", "-j", str(self.job_id)], universal_newlines=True
        )
        task_info: dict[str, str] = {}
        for line in output.split("\n")[1:]:
            key, colon, value = line.partition(":")
            if not colon:
                continue
            task_info[key.strip()] = value.strip()
        self._task_info = task_info

    def get_n_cpus(self) -> Optional[int]:
        """Get the number of CPUs from the NSLOTS environment variable."""
        nslots = os.environ.get("NSLOTS")
        if nslots:
            return int(nslots)
        return None

    def get_max_run_seconds(self) -> Optional[int]:
        """Return the maximum run time in seconds.

        The value is the ``h_rt=<digits>`` entry of qstat's
        ``hard resource_list`` field; ``None`` is returned when the field or
        the entry is missing.
        """
        resources = self._task_info.get("hard resource_list", "")
        match = re.search(r"h_rt=(\d+)", resources)
        if match:
            return int(match.group(1))
        return None

    def get_end_time(self) -> Optional[datetime]:
        """Return the expected finish time of this job.

        This is the start time plus the maximum run time, or ``None`` when
        either of them is unavailable.
        """
        start = self.get_start_time()
        max_run = self.get_max_run_seconds()
        if start and max_run:
            return start + timedelta(seconds=max_run)
        return None

    def get_start_time(self) -> Optional[datetime]:
        """Return the start time of this job.

        The ``<JAT_start_time>`` element of ``qstat -j <job_id> -xml`` is
        parsed as integer epoch seconds and returned as a UTC-aware
        datetime.  A successfully parsed value is cached, so later calls do
        not run ``qstat`` again.  ``None`` is returned while no start time
        can be found.
        """
        if self._start_time is not None:
            return self._start_time
        output = subprocess.check_output(
            ["qstat", "-j", str(self.job_id), "-xml"], universal_newlines=True
        )
        # Digits only: a greedy ".+" could swallow several elements on one line.
        match = re.search(r"<JAT_start_time>(\d+)</JAT_start_time>", output)
        if match:
            self._start_time = datetime.fromtimestamp(
                int(match.group(1)), tz=timezone.utc
            )
        return self._start_time

    def get_remaining_seconds(self) -> int:
        """Return the remaining walltime in seconds.

        Returns ``0`` when the end time cannot be determined.
        """
        end_time = self.get_end_time()
        if end_time is None:
            return 0
        # ``end_time`` is timezone-aware, so "now" must be aware as well.
        remaining = end_time - datetime.now(timezone.utc)
        return int(remaining.total_seconds())
class Slurm(Scheduler):
    """Slurm scheduler interface.

    Job metadata is parsed from ``scontrol show jobid=<id>`` the first time
    an instance is created and is then shared by all instances through the
    class-level ``_task_info`` cache.
    """

    # Parsed scontrol fields shared by every instance; None = not read yet.
    _task_info: ClassVar[Optional[dict[str, str]]] = None
    # Number of times the "not started from Slurm" notice has been logged.
    _warning: ClassVar[int] = 0

    def __init__(self) -> None:
        super().__init__()
        self.task_info: dict[str, str] = {}
        if Slurm._task_info is None:
            self._read_task_info()
            Slurm._task_info = self.task_info
        else:
            self.task_info = Slurm._task_info

    @property
    def is_in_job(self) -> bool:
        """Whether we are inside a Slurm job."""
        return "SLURM_JOB_ID" in os.environ

    @property
    def job_id(self) -> Optional[str]:
        """Return ``$SLURM_JOB_ID`` (cached once found), or ``None``."""
        if self._job_id is None:
            self._job_id = os.environ.get("SLURM_JOB_ID")
        return self._job_id

    @staticmethod
    def _parse_scontrol(text: str) -> dict[str, str]:
        """Turn whitespace-separated ``Key=Value`` tokens into a dict.

        Only the first ``=`` separates key and value, so values that contain
        ``=`` themselves (for example ``TRES=cpu=4,mem=16G``) are kept
        intact.  A token without ``=`` maps to the empty string.
        """
        parsed: dict[str, str] = {}
        for token in text.split():
            key, _, value = token.partition("=")
            parsed[key] = value
        return parsed

    def _read_task_info(self) -> None:
        """Extract job information from scontrol.

        Outside a Slurm job ``task_info`` is set to an empty dict and a debug
        message is logged once per process.

        Raises:
            subprocess.CalledProcessError: if ``scontrol`` exits non-zero.
        """
        if not self.is_in_job:
            if Slurm._warning == 0:
                logger.debug("NOT STARTED FROM SLURM")
                Slurm._warning += 1
            self.task_info = {}
            return
        proc = subprocess.run(
            ["scontrol", "show", f"jobid={self.job_id}"],
            check=True,
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
        info = self._parse_scontrol(proc.stdout)
        Slurm._task_info = info
        self.task_info = info

    def get_end_time(self) -> Optional[datetime]:
        """Return the end time of this job.

        The ``EndTime`` field is parsed with ``%Y-%m-%dT%H:%M:%S`` and
        returned as a naive datetime.  ``None`` is returned when the field is
        missing, empty, or not a timestamp in that format.
        """
        raw = self.task_info.get("EndTime")
        if not raw:
            return None
        try:
            return datetime.strptime(raw, "%Y-%m-%dT%H:%M:%S")
        except ValueError:
            # NOTE(review): presumably placeholders such as "Unknown" land here.
            return None

    def get_remaining_seconds(self) -> int:
        """Return the remaining walltime in seconds.

        Returns ``0`` when the end time cannot be determined.
        """
        end_time = self.get_end_time()
        if end_time is None:
            return 0
        # ``end_time`` is naive, so it must be compared with a naive "now";
        # mixing naive and aware datetimes raises TypeError.
        # NOTE(review): assumes scontrol prints local wall-clock time.
        now = datetime.now(tz=end_time.tzinfo)
        return int((end_time - now).total_seconds())

    def get_user_name(self) -> Optional[str]:
        """Parse the user name from task info.

        The ``UserId`` field is expected to look like ``name(uid)``; the part
        before the parenthesis is returned.  ``None`` is returned when the
        field is absent or does not have that shape.
        """
        user_id = self.task_info.get("UserId")
        if not user_id:
            return None
        match = re.match(r"([^()\s]+)\([0-9]+\)", user_id)
        if match:
            return match.group(1)
        return None

    def get_n_cpus(self) -> Optional[int]:
        """Return number of CPUs allocated.

        The ``NumCPUs`` field is converted to ``int`` so the result matches
        the base-class contract and the other schedulers.  ``None`` is
        returned when the field is missing or is not a plain integer.
        """
        raw = self.task_info.get("NumCPUs")
        if raw is None:
            return None
        try:
            return int(raw)
        except ValueError:
            logger.warning("Cannot interpret NumCPUs value %r", raw)
            return None

    def get_array_id(self) -> Optional[str]:
        """Return the array job ID, or None."""
        return self.task_info.get("ArrayJobId")

    def get_array_task_id(self) -> Optional[str]:
        """Return the array task ID, or None."""
        return self.task_info.get("ArrayTaskId")

    def get_array_job_id(self) -> Optional[str]:
        """Return the array job ID as ``{array_id}_{task_id}``.

        ``None`` is returned unless both ``ArrayJobId`` and ``ArrayTaskId``
        are present in the task info.
        """
        info = self.task_info
        if "ArrayJobId" not in info or "ArrayTaskId" not in info:
            return None
        return f"{info['ArrayJobId']}_{info['ArrayTaskId']}"

    def hold_array(self, array_num: Optional[str] = None) -> None:
        """Hold the array this job belongs to.

        Args:
            array_num: Array to hold.  When omitted, the array part of this
                job's own ``{array_id}_{task_id}`` identifier is used.

        The outcome is logged; a failing ``scontrol hold`` does not raise.
        """
        own_id = self.get_array_job_id()
        if array_num:
            array: Optional[str] = array_num
        elif own_id:
            array = own_id.split("_")[0]
        else:
            array = None

        if array is None:
            logger.error("Cannot find the array to hold")
            return

        proc = subprocess.run(["scontrol", "hold", str(array)], check=False)
        if proc.returncode == 0:
            logger.info("Successfully held array %s", array)
        else:
            logger.error("Cannot hold array %s", array)

    def hold_all_pd_arrays(self, user_name: Optional[str] = None) -> None:
        """Hold all pending arrays.

        Raises:
            subprocess.CalledProcessError: if ``scontrol hold`` exits non-zero.
        """
        arrays = self.get_pd_arrays(user_name)
        if not arrays:
            logger.warning("No array found to hold")
            return
        logger.info("Holding all pending arrays")
        subprocess.run(["scontrol", "hold", ",".join(arrays)], check=True)

    def release_all_pd_arrays(self, user_name: Optional[str] = None) -> None:
        """Release all pending arrays.

        Raises:
            subprocess.CalledProcessError: if ``scontrol release`` exits
                non-zero.
        """
        arrays = self.get_pd_arrays(user_name)
        if not arrays:
            logger.warning("No array found to release")
            return
        subprocess.run(["scontrol", "release", ",".join(arrays)], check=True)

    def get_running_jobs(self, user_name: Optional[str] = None) -> Optional[list[str]]:
        """Return a list of running job IDs for the current user."""
        return self._get_id_of_state("R", "%A %t", user_name)

    def get_pd_arrays(self, user_name: Optional[str] = None) -> Optional[list[str]]:
        """Return a list of pending array job IDs."""
        return self._get_id_of_state("PD", "%F %t", user_name)

    def _get_id_of_state(
        self, criteria: str, fmt_str: str, user_name: Optional[str] = None
    ) -> Optional[list[str]]:
        """Get IDs for jobs matching a given state.

        Args:
            criteria: State code that the second output column must equal.
            fmt_str: ``squeue -o`` format; its first column is taken as the
                ID and its second column as the state.
            user_name: User whose jobs are listed.  Defaults to the user
                parsed from this job's task info.

        Returns:
            The matching IDs, or ``None`` when no user name is available.

        Raises:
            subprocess.CalledProcessError: if ``squeue`` exits non-zero.
        """
        if not user_name:
            if not self.task_info:
                return None
            user_name = self.get_user_name()
            if not user_name:
                return None
        # Capturing through a pipe leaves no temporary file to clean up when
        # squeue fails.
        proc = subprocess.run(
            ["squeue", "-u", user_name, f"-o{fmt_str}"],
            check=True,
            stdout=subprocess.PIPE,
            universal_newlines=True,
        )
        return [
            fields[0]
            for fields in map(str.split, proc.stdout.splitlines())
            if len(fields) >= 2 and fields[1] == criteria
        ]

    def __bool__(self) -> bool:
        """Return whether any task info was read."""
        return bool(self.task_info)

    def __getitem__(self, key: str) -> Optional[str]:
        """Return the task-info field *key*, or ``None`` without task info.

        Raises:
            KeyError: if task info exists but does not contain *key*.
        """
        if self.task_info:
            return self.task_info[key]
        return None

    def __contains__(self, key: str) -> bool:
        """Return whether *key* is a known task-info field."""
        return key in self.task_info