Skip to content

Commit 0357fc8

Browse files
committed
implemented shell_task and basic unittests. Generated tasks do not work as need to determine best way to map onto input_spec
1 parent 83c06e3 commit 0357fc8

File tree

2 files changed

+227
-76
lines changed

2 files changed

+227
-76
lines changed

pydra/mark/shell_commands.py

Lines changed: 57 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,45 @@ def shell_task(
5757
"input_field arguments"
5858
)
5959
name = klass_or_name
60+
6061
if output_fields is None:
6162
output_fields = {}
62-
if bases is None:
63-
bases = [pydra.engine.task.ShellCommandTask]
64-
if input_bases is None:
65-
input_bases = [pydra.engine.specs.ShellSpec]
66-
if output_bases is None:
67-
output_bases = [pydra.engine.specs.ShellOutSpec]
68-
Inputs = type("Inputs", tuple(input_bases), input_fields)
69-
Outputs = type("Outputs", tuple(output_bases), output_fields)
63+
64+
# Ensure bases are lists and can be modified
65+
bases = list(bases) if bases is not None else []
66+
input_bases = list(input_bases) if input_bases is not None else []
67+
output_bases = list(output_bases) if output_bases is not None else []
68+
69+
# Ensure base classes included somewhere in MRO
70+
def ensure_base_of(base_class: type, bases_list: list[type]):
71+
if not any(issubclass(b, base_class) for b in bases_list):
72+
bases_list.append(base_class)
73+
74+
ensure_base_of(pydra.engine.task.ShellCommandTask, bases)
75+
ensure_base_of(pydra.engine.specs.ShellSpec, input_bases)
76+
ensure_base_of(pydra.engine.specs.ShellOutSpec, output_bases)
77+
78+
def convert_to_attrs(fields: dict[str, dict[str, ty.Any]], attrs_func):
79+
annotations = {}
80+
attrs_dict = {"__annotations__": annotations}
81+
for name, dct in fields.items():
82+
kwargs = dict(dct) # copy to avoid modifying input to outer function
83+
annotations[name] = kwargs.pop("type")
84+
attrs_dict[name] = attrs_func(**kwargs)
85+
return attrs_dict
86+
87+
Inputs = attrs.define(kw_only=True, slots=False)(
88+
type(
89+
"Inputs", tuple(input_bases), convert_to_attrs(input_fields, shell_arg)
90+
)
91+
)
92+
Outputs = attrs.define(kw_only=True, slots=False)(
93+
type(
94+
"Outputs",
95+
tuple(output_bases),
96+
convert_to_attrs(output_fields, shell_out),
97+
)
98+
)
7099
else:
71100
if (
72101
executable,
@@ -96,39 +125,43 @@ def shell_task(
96125
"Classes decorated by `shell_task` should contain an `Inputs` class attribute "
97126
"specifying the inputs to the shell tool"
98127
)
99-
if not issubclass(Inputs, pydra.engine.specs.ShellSpec):
100-
Inputs = type("Inputs", (Inputs, pydra.engine.specs.ShellSpec), {})
128+
101129
try:
102130
Outputs = klass.Outputs
103131
except KeyError:
104132
Outputs = type("Outputs", (pydra.engine.specs.ShellOutSpec,))
133+
134+
Inputs = attrs.define(kw_only=True, slots=False)(Inputs)
135+
Outputs = attrs.define(kw_only=True, slots=False)(Outputs)
136+
137+
if not issubclass(Inputs, pydra.engine.specs.ShellSpec):
138+
Inputs = attrs.define(kw_only=True, slots=False)(
139+
type("Inputs", (Inputs, pydra.engine.specs.ShellSpec), {})
140+
)
141+
142+
if not issubclass(Outputs, pydra.engine.specs.ShellOutSpec):
143+
Outputs = attrs.define(kw_only=True, slots=False)(
144+
type("Outputs", (Outputs, pydra.engine.specs.ShellOutSpec), {})
145+
)
146+
105147
bases = [klass]
106148
if not issubclass(klass, pydra.engine.task.ShellCommandTask):
107149
bases.append(pydra.engine.task.ShellCommandTask)
108150

109-
Inputs = attrs.define(kw_only=True, slots=False)(Inputs)
110-
Outputs = attrs.define(kw_only=True, slots=False)(Outputs)
111-
112151
dct = {
113152
"executable": executable,
114-
"Inputs": Outputs,
115-
"Outputs": Inputs,
116-
"inputs": attrs.field(factory=Inputs),
117-
"outputs": attrs.field(factory=Outputs),
153+
"Inputs": Inputs,
154+
"Outputs": Outputs,
118155
"__annotations__": {
119156
"executable": str,
120157
"inputs": Inputs,
121158
"outputs": Outputs,
159+
"Inputs": type,
160+
"Outputs": type,
122161
},
123162
}
124163

125-
return attrs.define(kw_only=True, slots=False)(
126-
type(
127-
name,
128-
tuple(bases),
129-
dct,
130-
)
131-
)
164+
return type(name, tuple(bases), dct)
132165

133166

134167
def shell_arg(

pydra/mark/tests/test_shell_commands.py

Lines changed: 170 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,69 +1,187 @@
11
import os
22
import tempfile
3-
from pathlib import Path
43
import attrs
5-
import pydra.engine
4+
from pathlib import Path
5+
import pytest
6+
import cloudpickle as cp
67
from pydra.mark import shell_task, shell_arg, shell_out
78

89

9-
def test_shell_task_full():
10-
@attrs.define(kw_only=True, slots=False)
11-
class LsInputSpec(pydra.specs.ShellSpec):
12-
directory: os.PathLike = shell_arg(
13-
help_string="the directory to list the contents of",
14-
argstr="",
15-
mandatory=True,
16-
)
17-
hidden: bool = shell_arg(help_string=("display hidden FS objects"), argstr="-a")
18-
long_format: bool = shell_arg(
19-
help_string=(
20-
"display properties of FS object, such as permissions, size and timestamps "
21-
),
22-
argstr="-l",
23-
)
24-
human_readable: bool = shell_arg(
25-
help_string="display file sizes in human readable form",
26-
argstr="-h",
27-
requires=["long_format"],
28-
)
29-
complete_date: bool = shell_arg(
30-
help_string="Show complete date in long format",
31-
argstr="-T",
32-
requires=["long_format"],
33-
xor=["date_format_str"],
34-
)
35-
date_format_str: str = shell_arg(
36-
help_string="format string for ",
37-
argstr="-D",
38-
requires=["long_format"],
39-
xor=["complete_date"],
40-
)
10+
def list_entries(stdout):
11+
return stdout.split("\n")[:-1]
4112

42-
def list_outputs(stdout):
43-
return stdout.split("\n")[:-1]
4413

45-
@attrs.define(kw_only=True, slots=False)
46-
class LsOutputSpec(pydra.specs.ShellOutSpec):
47-
entries: list = shell_out(
48-
help_string="list of entries returned by ls command", callable=list_outputs
49-
)
14+
@pytest.fixture
15+
def tmpdir():
16+
return Path(tempfile.mkdtemp())
5017

51-
class Ls(pydra.engine.ShellCommandTask):
52-
"""Task definition for the `ls` command line tool"""
5318

54-
executable = "ls"
19+
@pytest.fixture(params=["static", "dynamic"])
20+
def Ls(request):
21+
if request.param == "static":
5522

56-
input_spec = pydra.specs.SpecInfo(
57-
name="LsInput",
58-
bases=(LsInputSpec,),
59-
)
23+
@shell_task
24+
class Ls:
25+
executable = "ls"
6026

61-
output_spec = pydra.specs.SpecInfo(
62-
name="LsOutput",
63-
bases=(LsOutputSpec,),
27+
class Inputs:
28+
directory: os.PathLike = shell_arg(
29+
help_string="the directory to list the contents of",
30+
argstr="",
31+
mandatory=True,
32+
)
33+
hidden: bool = shell_arg(
34+
help_string=("display hidden FS objects"),
35+
argstr="-a",
36+
default=False,
37+
)
38+
long_format: bool = shell_arg(
39+
help_string=(
40+
"display properties of FS object, such as permissions, size and "
41+
"timestamps "
42+
),
43+
default=False,
44+
argstr="-l",
45+
)
46+
human_readable: bool = shell_arg(
47+
help_string="display file sizes in human readable form",
48+
argstr="-h",
49+
default=False,
50+
requires=["long_format"],
51+
)
52+
complete_date: bool = shell_arg(
53+
help_string="Show complete date in long format",
54+
argstr="-T",
55+
default=False,
56+
requires=["long_format"],
57+
xor=["date_format_str"],
58+
)
59+
date_format_str: str = shell_arg(
60+
help_string="format string for ",
61+
argstr="-D",
62+
default=None,
63+
requires=["long_format"],
64+
xor=["complete_date"],
65+
)
66+
67+
class Outputs:
68+
entries: list = shell_out(
69+
help_string="list of entries returned by ls command",
70+
callable=list_entries,
71+
)
72+
73+
elif request.param == "dynamic":
74+
Ls = shell_task(
75+
"Ls",
76+
executable="ls",
77+
input_fields={
78+
"directory": {
79+
"type": os.PathLike,
80+
"help_string": "the directory to list the contents of",
81+
"argstr": "",
82+
"mandatory": True,
83+
},
84+
"hidden": {
85+
"type": bool,
86+
"help_string": "display hidden FS objects",
87+
"argstr": "-a",
88+
},
89+
"long_format": {
90+
"type": bool,
91+
"help_string": (
92+
"display properties of FS object, such as permissions, size and "
93+
"timestamps "
94+
),
95+
"argstr": "-l",
96+
},
97+
"human_readable": {
98+
"type": bool,
99+
"help_string": "display file sizes in human readable form",
100+
"argstr": "-h",
101+
"requires": ["long_format"],
102+
},
103+
"complete_date": {
104+
"type": bool,
105+
"help_string": "Show complete date in long format",
106+
"argstr": "-T",
107+
"requires": ["long_format"],
108+
"xor": ["date_format_str"],
109+
},
110+
"date_format_str": {
111+
"type": str,
112+
"help_string": "format string for ",
113+
"argstr": "-D",
114+
"requires": ["long_format"],
115+
"xor": ["complete_date"],
116+
},
117+
},
118+
output_fields={
119+
"entries": {
120+
"type": list,
121+
"help_string": "list of entries returned by ls command",
122+
"callable": list_entries,
123+
}
124+
},
64125
)
65126

66-
tmpdir = Path(tempfile.mkdtemp())
127+
else:
128+
assert False
129+
130+
return Ls
131+
132+
133+
def test_shell_task_fields(Ls):
134+
assert [a.name for a in attrs.fields(Ls.Inputs)] == [
135+
"executable",
136+
"args",
137+
"directory",
138+
"hidden",
139+
"long_format",
140+
"human_readable",
141+
"complete_date",
142+
"date_format_str",
143+
]
144+
145+
assert [a.name for a in attrs.fields(Ls.Outputs)] == [
146+
"return_code",
147+
"stdout",
148+
"stderr",
149+
"entries",
150+
]
151+
152+
153+
def test_shell_task_pickle_roundtrip(Ls, tmpdir):
154+
pkl_file = tmpdir / "ls.pkl"
155+
with open(pkl_file, "wb") as f:
156+
cp.dump(Ls, f)
157+
158+
with open(pkl_file, "rb") as f:
159+
RereadLs = cp.load(f)
160+
161+
assert RereadLs is Ls
162+
163+
164+
@pytest.mark.xfail(
165+
reason=(
166+
"Need to change relationship between Inputs/Outputs and input_spec/output_spec "
167+
"for the task to run"
168+
)
169+
)
170+
def test_shell_task_init(Ls, tmpdir):
171+
inputs = Ls.Inputs(directory=tmpdir)
172+
assert inputs.directory == tmpdir
173+
assert not inputs.hidden
174+
outputs = Ls.Outputs(entries=[])
175+
assert outputs.entries == []
176+
177+
178+
@pytest.mark.xfail(
179+
reason=(
180+
"Need to change relationship between Inputs/Outputs and input_spec/output_spec "
181+
"for the task to run"
182+
)
183+
)
184+
def test_shell_task_run(Ls, tmpdir):
67185
Path.touch(tmpdir / "a")
68186
Path.touch(tmpdir / "b")
69187
Path.touch(tmpdir / "c")

0 commit comments

Comments
 (0)