import warnings
from shutil import copy as copy_file
from os import path
from h5py import VirtualSource, VirtualLayout
from tomoscan.io import HDF5File
from ..resources.logger import Logger, LoggerOrPrint
from ..io.utils import get_first_hdf5_entry
from .cli_configs import ZSplitConfig
from .utils import parse_params_values
warnings.warn(
"This command-line utility is intended as a temporary solution. Please do not rely too much on it.", Warning
)
def _get_z_translations(fname, entry):
z_path = path.join(entry, "sample", "z_translation")
with HDF5File(fname, "r") as fid:
z_transl = fid[z_path][:]
return z_transl
[docs]
class NXZSplitter:
def __init__(self, fname, output_dir, n_stages=None, entry=None, logger=None, use_virtual_dataset=False):
self.fname = fname
self._ext = path.splitext(fname)[-1]
self.output_dir = output_dir
self.n_stages = n_stages
if entry is None:
entry = get_first_hdf5_entry(fname)
self.entry = entry
self.logger = LoggerOrPrint(logger)
self.use_virtual_dataset = use_virtual_dataset
def _patch_nx_file(self, fname, mask):
orig_fname = self.fname
detector_path = path.join(self.entry, "instrument", "detector")
sample_path = path.join(self.entry, "sample")
with HDF5File(fname, "a") as fid:
def patch_nx_entry(name):
newval = fid[name][mask]
del fid[name]
fid[name] = newval
detector_entries = [
path.join(detector_path, what) for what in ["count_time", "image_key", "image_key_control"]
]
sample_entries = [
path.join(sample_path, what)
for what in ["rotation_angle", "x_translation", "y_translation", "z_translation"]
]
for what in detector_entries + sample_entries:
self.logger.debug("Patching %s" % what)
patch_nx_entry(what)
# Patch "data" using a virtual dataset
self.logger.debug("Patching data")
data_path = path.join(detector_path, "data")
if self.use_virtual_dataset:
data_shape = fid[data_path].shape
data_dtype = fid[data_path].dtype
new_data_shape = (int(mask.sum()),) + data_shape[1:]
vlayout = VirtualLayout(shape=new_data_shape, dtype=data_dtype)
vsource = VirtualSource(orig_fname, name=data_path, shape=data_shape, dtype=data_dtype)
vlayout[:] = vsource[mask, :, :]
del fid[data_path]
fid[detector_path].create_virtual_dataset("data", vlayout)
if not (self.use_virtual_dataset):
data_path = path.join(self.entry, "instrument", "detector", "data")
with HDF5File(orig_fname, "r") as fid:
data_arr = fid[data_path][mask, :, :] # Actually load data. Heavy !
with HDF5File(fname, "a") as fid:
del fid[data_path]
fid[data_path] = data_arr
[docs]
def z_split(self):
"""
Split a HDF5-NX file according to different z_translation.
"""
z_transl = _get_z_translations(self.fname, self.entry)
different_z = set(z_transl)
n_z = len(different_z)
self.logger.info("Detected %d different z values: %s" % (n_z, str(different_z)))
if n_z <= 1:
raise ValueError("Detected only %d z-value. Stopping." % n_z)
if self.n_stages is not None and self.n_stages != n_z:
raise ValueError("Expected %d different stages, but I detected %d" % (self.n_stages, n_z))
masks = [(z_transl == z) for z in different_z]
for i_z, mask in enumerate(masks):
fname_curr_z = path.join(
self.output_dir, path.splitext(path.basename(self.fname))[0] + str("_%06d" % i_z) + self._ext
)
self.logger.info("Creating %s" % fname_curr_z)
copy_file(self.fname, fname_curr_z)
self._patch_nx_file(fname_curr_z, mask)
[docs]
def zsplit():
# Parse arguments
args = parse_params_values(
ZSplitConfig, parser_description="Split a HDF5-Nexus file according to z translation (z-series)"
)
# Sanitize arguments
fname = args["input_file"]
output_dir = args["output_directory"]
loglevel = args["loglevel"].upper()
entry = args["entry"]
if len(entry) == 0:
entry = None
n_stages = args["n_stages"]
if n_stages < 0:
n_stages = None
use_virtual_dataset = bool(args["use_virtual_dataset"])
# Instantiate and execute
logger = Logger("NX_z-splitter", level=loglevel, logfile="nxzsplit.log")
nx_splitter = NXZSplitter(
fname, output_dir, n_stages=n_stages, entry=entry, logger=logger, use_virtual_dataset=use_virtual_dataset
)
nx_splitter.z_split()
return 0
if __name__ == "__main__":
zsplit()