Skip to content

Commit e48e2c4

Browse files
committed
wip
1 parent 58b4b23 commit e48e2c4

File tree

6 files changed

+111
-65
lines changed

6 files changed

+111
-65
lines changed

nitransforms/io/itk.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,8 @@ def from_image(cls, imgobj):
347347
warnings.warn("Incorrect intent identified.")
348348
hdr.set_intent("vector")
349349

350-
field = np.squeeze(np.asanyarray(imgobj.dataobj))
350+
field = np.squeeze(np.asanyarray(imgobj.dataobj)).transpose(2, 1, 0, 3)
351+
351352
return imgobj.__class__(field, LPS @ imgobj.affine, hdr)
352353

353354
@classmethod
@@ -357,7 +358,7 @@ def to_image(cls, imgobj):
357358
hdr = imgobj.header.copy()
358359
hdr.set_intent("vector")
359360

360-
warp_data = imgobj.get_fdata().reshape(imgobj.shape[:3] + (1, imgobj.shape[-1]))
361+
warp_data = imgobj.get_fdata().transpose(2, 1, 0, 3)[..., None, :]
361362
return imgobj.__class__(warp_data, LPS @ imgobj.affine, hdr)
362363

363364

nitransforms/nonlinear.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ def to_x5(self, metadata=None):
285285
)
286286

287287
@classmethod
288-
def from_filename(cls, filename, fmt="X5", x5_position=0):
288+
def from_filename(cls, filename, is_deltas=True, fmt="X5", x5_position=0):
289289
_factory = {
290290
"afni": io.afni.AFNIDisplacementsField,
291291
"itk": io.itk.ITKDisplacementsField,
@@ -299,7 +299,7 @@ def from_filename(cls, filename, fmt="X5", x5_position=0):
299299
if fmt == "X5":
300300
return from_x5(load_x5(filename), x5_position=x5_position)
301301

302-
return cls(_factory[fmt.lower()].from_filename(filename))
302+
return cls(_factory[fmt.lower()].from_filename(filename), is_deltas=is_deltas)
303303

304304

305305
load = DenseFieldTransform.from_filename

nitransforms/resampling.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -254,22 +254,22 @@ def apply(
254254

255255
targets = None
256256
ref_ndcoords = _ref.ndcoords.T
257-
if hasattr(transform, "to_field") and callable(transform.to_field):
258-
targets = ImageGrid(spatialimage).index(
259-
_as_homogeneous(
260-
transform.to_field(reference=reference).map(ref_ndcoords),
261-
dim=_ref.ndim,
262-
)
263-
)
264-
else:
265-
# Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266-
targets = (
267-
ImageGrid(spatialimage).index(
268-
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
269-
)
270-
if targets is None
271-
else targets
257+
# if hasattr(transform, "to_field") and callable(transform.to_field):
258+
# targets = ImageGrid(spatialimage).index(
259+
# _as_homogeneous(
260+
# transform.to_field(reference=reference).map(ref_ndcoords),
261+
# dim=_ref.ndim,
262+
# )
263+
# )
264+
# else:
265+
# Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266+
targets = (
267+
ImageGrid(spatialimage).index(
268+
_as_homogeneous(transform.map(ref_ndcoords), dim=_ref.ndim)
272269
)
270+
if targets is None
271+
else targets
272+
)
273273

274274
if targets.ndim == 3:
275275
targets = np.rollaxis(targets, targets.ndim - 1, 0)

nitransforms/tests/test_io.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import pytest
1212
from h5py import File as H5File
1313

14+
import SimpleITK as sitk
1415
import nibabel as nb
1516
from nibabel.eulerangles import euler2mat
1617
from nibabel.affines import from_matvec
@@ -694,6 +695,47 @@ def test_itk_linear_h5(tmpdir, data_path, testdata_path):
694695
with pytest.raises(TransformIOError):
695696
itk.ITKLinearTransform.from_filename("test.h5")
696697

698+
# Added tests for displacements fields orientations (ANTs/ITK)
699+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
700+
def test_itk_displacements(tmp_path, get_testdata, image_orientation):
701+
"""Exercise I/O of ITK displacements fields."""
702+
703+
nii = get_testdata[image_orientation]
704+
705+
# Create a reference centered at the origin with various axis orders/flips
706+
shape = nii.shape
707+
ref_affine = nii.affine.copy()
708+
709+
field = np.hstack((
710+
np.linspace(-50, 50, num=np.prod(shape)),
711+
np.linspace(-80, 80, num=np.prod(shape)),
712+
np.zeros(np.prod(shape))
713+
)).reshape(shape + (3, ))
714+
715+
nit_nii = itk.ITKDisplacementsField.to_image(
716+
nb.Nifti1Image(field, ref_affine, None)
717+
)
718+
719+
itk_file = tmp_path / "itk_displacements.nii.gz"
720+
itk_img = sitk.GetImageFromArray(field, isVector=True)
721+
itk_img.SetOrigin(tuple(ref_affine[:3, 3]))
722+
zooms = np.sqrt((ref_affine[:3, :3] ** 2).sum(0))
723+
itk_img.SetSpacing(tuple(zooms))
724+
direction = (ref_affine[:3, :3] / zooms).ravel()
725+
itk_img.SetDirection(tuple(direction))
726+
sitk.WriteImage(itk_img, str(itk_file))
727+
728+
itk_nit_nii = itk.ITKDisplacementsField.from_filename(itk_file)
729+
730+
assert itk_nit_nii.shape == field.shape
731+
np.testing.assert_allclose(itk_nit_nii.affine, ref_affine)
732+
np.testing.assert_allclose(itk_nit_nii.dataobj, field)
733+
734+
itk_nii = nb.load(itk_file)
735+
assert nit_nii.shape == itk_nii.shape
736+
np.testing.assert_allclose(itk_nii.dataobj, nit_nii.dataobj)
737+
np.testing.assert_allclose(itk_nii.affine, nit_nii.affine)
738+
697739

698740
# Added tests for h5 orientation bug (#167)
699741
def _load_composite_testdata(data_path):

nitransforms/tests/test_nonlinear.py

Lines changed: 41 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -288,36 +288,30 @@ def test_densefield_map_against_ants(testdata_path, tmp_path):
288288
assert np.allclose(mapped, ants_pts, atol=1e-6)
289289

290290

291-
@pytest.mark.parametrize(
292-
"mat",
293-
[
294-
np.eye(3),
295-
np.diag([-1.0, 1.0, 1.0]),
296-
np.diag([1.0, -1.0, 1.0]),
297-
np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]),
298-
],
299-
)
300-
def test_constant_field_vs_ants(tmp_path, mat):
291+
@pytest.mark.parametrize("image_orientation", ["RAS", "LAS", "LPS", "oblique"])
292+
@pytest.mark.parametrize("gridpoints", [True, False])
293+
def test_constant_field_vs_ants(tmp_path, get_testdata, image_orientation, gridpoints):
301294
"""Create a constant displacement field and compare mappings."""
302295

296+
nii = get_testdata[image_orientation]
297+
303298
# Create a reference centered at the origin with various axis orders/flips
304-
shape = (25, 25, 25)
305-
center = (np.array(shape) - 1) / 2
306-
ref_affine = from_matvec(mat, -mat @ center)
299+
shape = nii.shape
300+
ref_affine = nii.affine.copy()
307301

308-
field = np.zeros(shape + (3,), dtype="float32")
309-
field[..., 0] = -5
310-
field[..., 1] = 2
311-
field[..., 2] = 0 # No displacement in the third axis
302+
field = np.hstack((
303+
np.linspace(-50, 50, num=np.prod(shape)),
304+
np.linspace(-80, 80, num=np.prod(shape)),
305+
np.zeros(np.prod(shape))
306+
)).reshape(shape + (3, ))
307+
fieldnii = nb.Nifti1Image(field, ref_affine, None)
312308

313309
warpfile = tmp_path / "const_disp.nii.gz"
314-
itk_img = sitk.GetImageFromArray(field, isVector=True)
315-
itk_img.SetOrigin(tuple(ref_affine[:3, 3]))
316-
zooms = np.sqrt((ref_affine[:3, :3] ** 2).sum(0))
317-
itk_img.SetSpacing(tuple(zooms))
318-
direction = (ref_affine[:3, :3] / zooms).ravel()
319-
itk_img.SetDirection(tuple(direction))
320-
sitk.WriteImage(itk_img, str(warpfile))
310+
ITKDisplacementsField.to_filename(fieldnii, warpfile)
311+
xfm = DenseFieldTransform(fieldnii)
312+
xfm2 = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
313+
314+
np.testing.assert_allclose(xfm.reference.affine, xfm2.reference.affine)
321315

322316
points = np.array(
323317
[
@@ -328,6 +322,11 @@ def test_constant_field_vs_ants(tmp_path, mat):
328322
[12.0, 0.0, -11.0],
329323
]
330324
)
325+
326+
if gridpoints:
327+
coords = xfm.reference.ndcoords
328+
points = (ref_affine @ np.vstack((coords, np.ones((1, coords.shape[1]))))).T[:, :3]
329+
331330
csvin = tmp_path / "points.csv"
332331
np.savetxt(csvin, points, delimiter=",", header="x,y,z", comments="")
333332

@@ -341,23 +340,25 @@ def test_constant_field_vs_ants(tmp_path, mat):
341340
ants_res = np.genfromtxt(csvout, delimiter=",", names=True)
342341
ants_pts = np.vstack([ants_res[n] for n in ("x", "y", "z")]).T
343342

344-
xfm = DenseFieldTransform(ITKDisplacementsField.from_filename(warpfile))
345-
mapped = xfm.map(points)
343+
import pdb; pdb.set_trace()
346344

347-
assert np.allclose(mapped, ants_pts, atol=1e-6)
345+
ants_field = ants_pts.reshape(shape + (3, ))
346+
diff = xfm._field[..., 0] - ants_field[..., 0]
347+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
348+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
348349

349-
# Verify deformation field generated via NiTransforms matches SimpleITK
350-
csvout2 = tmp_path / "out.csv"
351-
warpfile2 = tmp_path / "const_disp.nii.gz"
352-
ITKDisplacementsField.to_image(nb.Nifti1Image(field, ref_affine, None)).to_filename(
353-
warpfile2
354-
)
350+
diff = xfm._field[..., 1] - ants_field[..., 1]
351+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
352+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
355353

356-
check_call(
357-
f"antsApplyTransformsToPoints -d 3 -i {csvin} -o {csvout2} -t {warpfile}",
358-
shell=True,
359-
)
360-
ants_res2 = np.genfromtxt(csvout2, delimiter=",", names=True)
361-
ants_pts2 = np.vstack([ants_res2[n] for n in ("x", "y", "z")]).T
354+
diff = xfm._field[..., 2] - ants_field[..., 2]
355+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
356+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"
357+
358+
mapped = xfm.map(points)
359+
np.testing.assert_array_equal(np.round(mapped, 3), ants_pts)
360+
361+
diff = mapped - ants_pts
362+
mask = np.argwhere(np.abs(diff) > 1e-2)[:, 0]
362363

363-
assert np.allclose(ants_pts, ants_pts2, atol=1e-6)
364+
assert len(mask) == 0, f"A total of {len(mask)}/{ants_pts.shape[0]} contained errors:\n{diff[mask]}"

nitransforms/tests/test_resampling.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,21 @@ def test_displacements_field1(
177177
fieldmap[..., axis] = -10.0
178178

179179
_hdr = nii.header.copy()
180+
affine = nii.affine.copy()
180181
if sw_tool in ("itk",):
181182
_hdr.set_intent("vector")
183+
affine = io.itk.LPS @ affine
182184
_hdr.set_data_dtype("float32")
185+
186+
field = nb.Nifti1Image(fieldmap, affine, _hdr)
183187

184188
xfm_fname = "warp.nii.gz"
185-
field = nb.Nifti1Image(fieldmap, nii.affine, _hdr)
186189
field.to_filename(xfm_fname)
187190

188191
xfm = nitnl.load(xfm_fname, fmt=sw_tool)
189192

193+
np.testing.assert_array_equal(xfm._deltas, np.squeeze(field.dataobj))
194+
190195
# Then apply the transform and cross-check with software
191196
cmd = APPLY_NONLINEAR_CMD[sw_tool](
192197
transform=os.path.abspath(xfm_fname),
@@ -226,12 +231,9 @@ def test_displacements_field1(
226231
sw_moved = nb.load("resampled.nii.gz")
227232

228233
nt_moved = apply(xfm, nii, order=0)
229-
nt_moved.set_data_dtype(nii.get_data_dtype())
230234
nt_moved.to_filename("nt_resampled.nii.gz")
231-
sw_moved.set_data_dtype(nt_moved.get_data_dtype())
232-
diff = np.asanyarray(
233-
sw_moved.dataobj, dtype=sw_moved.get_data_dtype()
234-
) - np.asanyarray(nt_moved.dataobj, dtype=nt_moved.get_data_dtype())
235+
diff = sw_moved.get_fdata() - nt_moved.get_fdata()
236+
235237
# A certain tolerance is necessary because of resampling at borders
236238
assert np.sqrt((diff[brainmask] ** 2).mean()) < RMSE_TOL_LINEAR
237239

0 commit comments

Comments
 (0)