Source code for dials.algorithms.spot_finding.finder
"""
Contains implementation interface for finding spots on one or many images
"""
from __future__ import annotations
import logging
import math
import pickle
from collections.abc import Iterable
import libtbx
from dxtbx.format.image import ImageBool
from dxtbx.imageset import ImageSequence, ImageSet
from dxtbx.model import ExperimentList
from dxtbx.model.tof_helpers import wavelength_from_tof
from dials.array_family import flex
from dials.model.data import PixelList, PixelListLabeller
from dials.util import Sorry, log
from dials.util.log import rehandle_cached_records
from dials.util.mp import batch_multi_node_parallel_map
from dials.util.system import CPU_COUNT
logger = logging.getLogger(__name__)
[docs]
class ExtractPixelsFromImage:
"""
A class to extract pixels from a single image
"""
[docs]
def __init__(
self,
imageset,
threshold_function,
mask,
region_of_interest,
max_strong_pixel_fraction,
compute_mean_background,
):
"""
Initialise the class
:param imageset: The imageset to extract from
:param threshold_function: The function to threshold with
:param mask: The image mask
:param region_of_interest: A region of interest to process
:param max_strong_pixel_fraction: The maximum fraction of pixels allowed
"""
self.threshold_function = threshold_function
self.imageset = imageset
self.mask = mask
self.region_of_interest = region_of_interest
self.max_strong_pixel_fraction = max_strong_pixel_fraction
self.compute_mean_background = compute_mean_background
if self.mask is not None:
detector = self.imageset.get_detector()
assert len(self.mask) == len(detector)
def __call__(self, index):
"""
Extract strong pixels from an image
:param index: The index of the image
"""
# Get the frame number
if isinstance(self.imageset, ImageSequence):
frame = self.imageset.get_array_range()[0] + index
else:
ind = self.imageset.indices()
if len(ind) > 1:
assert all(i1 + 1 == i2 for i1, i2 in zip(ind[0:-1], ind[1:-1]))
frame = ind[index]
# Create the list of pixel lists
pixel_list = []
# Get the image and mask
image = self.imageset.get_corrected_data(index)
mask = self.imageset.get_mask(index)
# Set the mask
if self.mask is not None:
assert len(self.mask) == len(mask)
mask = tuple(m1 & m2 for m1, m2 in zip(mask, self.mask))
logger.debug(
f"Number of masked pixels for image {index}: {sum(m.count(False) for m in mask)}",
)
# Add the images to the pixel lists
num_strong = 0
average_background = 0
for i_panel, (im, mk) in enumerate(zip(image, mask)):
if self.imageset.is_marked_for_rejection(index):
threshold_mask = flex.bool(im.accessor(), False)
elif self.region_of_interest is not None:
x0, x1, y0, y1 = self.region_of_interest
height, width = im.all()
assert x0 < x1, "x0 < x1"
assert y0 < y1, "y0 < y1"
assert x0 >= 0, "x0 >= 0"
assert y0 >= 0, "y0 >= 0"
assert x1 <= width, "x1 <= width"
assert y1 <= height, "y1 <= height"
im_roi = im[y0:y1, x0:x1]
mk_roi = mk[y0:y1, x0:x1]
tm_roi = self.threshold_function.compute_threshold(
im_roi,
mk_roi,
imageset=self.imageset,
i_panel=i_panel,
region_of_interest=self.region_of_interest,
)
threshold_mask = flex.bool(im.accessor(), False)
threshold_mask[y0:y1, x0:x1] = tm_roi
else:
threshold_mask = self.threshold_function.compute_threshold(
im, mk, imageset=self.imageset, i_panel=i_panel
)
# Add the pixel list
plist = PixelList(frame, im, threshold_mask)
pixel_list.append(plist)
# Get average background
if self.compute_mean_background:
background = im.as_1d().select((mk & ~threshold_mask).as_1d())
average_background += flex.mean(background)
# Add to the spot count
num_strong += len(plist)
# Make average background
average_background /= len(image)
# Check total number of strong pixels
if self.max_strong_pixel_fraction < 1:
num_image = 0
for im in image:
num_image += len(im)
max_strong = int(math.ceil(self.max_strong_pixel_fraction * num_image))
if num_strong > max_strong:
raise RuntimeError(
f"""
The number of strong pixels found ({num_strong}) is greater than the
maximum allowed ({max_strong}). Try changing spot finding parameters
"""
)
# Print some info
if self.compute_mean_background:
logger.info(
f"Found {num_strong} strong pixels on image {frame + 1} with average background {average_background}",
)
else:
logger.info(f"Found {num_strong} strong pixels on image {frame + 1}")
# Return the result
return pixel_list
[docs]
class ExtractPixelsFromImage2DNoShoeboxes(ExtractPixelsFromImage):
"""
A class to extract pixels from a single image
"""
[docs]
def __init__(
self,
imageset,
threshold_function,
mask,
region_of_interest,
max_strong_pixel_fraction,
compute_mean_background,
min_spot_size,
max_spot_size,
filter_spots,
):
"""
Initialise the class
:param imageset: The imageset to extract from
:param threshold_function: The function to threshold with
:param mask: The image mask
:param region_of_interest: A region of interest to process
:param max_strong_pixel_fraction: The maximum fraction of pixels allowed
"""
super().__init__(
imageset,
threshold_function,
mask,
region_of_interest,
max_strong_pixel_fraction,
compute_mean_background,
)
# Save some stuff
self.min_spot_size = min_spot_size
self.max_spot_size = max_spot_size
self.filter_spots = filter_spots
def __call__(self, index):
"""
Extract strong pixels from an image
:param index: The index of the image
"""
# Initialise the pixel labeller
num_panels = len(self.imageset.get_detector())
pixel_labeller = [PixelListLabeller() for p in range(num_panels)]
# Call the super function
result = super().__call__(index)
# Add pixel lists to the labeller
assert len(pixel_labeller) == len(result), "Inconsistent size"
for plabeller, plist in zip(pixel_labeller, result):
plabeller.add(plist)
# Create shoeboxes from pixel list
reflections, _ = pixel_list_to_reflection_table(
self.imageset,
pixel_labeller,
filter_spots=self.filter_spots,
min_spot_size=self.min_spot_size,
max_spot_size=self.max_spot_size,
write_hot_pixel_mask=False,
)
# Delete the shoeboxes
del reflections["shoeboxes"]
# Return the reflections
return [reflections]
[docs]
class ExtractSpotsParallelTask:
"""
Execute the spot finder task in parallel
We need this external class so that we can pickle it for cluster jobs
"""
[docs]
def __init__(self, function):
"""
Initialise with the function to call
"""
self.function = function
def __call__(self, task):
"""
Call the function with th task and save the IO
"""
log.config_simple_cached()
result = self.function(task)
handlers = logging.getLogger("dials").handlers
assert len(handlers) == 1, "Invalid number of logging handlers"
return result, handlers[0].records
[docs]
def pixel_list_to_shoeboxes(
imageset: ImageSet,
pixel_labeller: Iterable[PixelListLabeller],
min_spot_size: int,
max_spot_size: int,
write_hot_pixel_mask: bool,
) -> tuple[flex.shoebox, tuple[flex.size_t, ...]]:
"""Convert a pixel list to shoeboxes"""
# Extract the pixel lists into a list of reflections
shoeboxes = flex.shoebox()
spotsizes = flex.size_t()
hotpixels = tuple(flex.size_t() for i in range(len(imageset.get_detector())))
if isinstance(imageset, ImageSequence):
twod = imageset.get_scan().is_still()
else:
twod = True
for i, (p, hp) in enumerate(zip(pixel_labeller, hotpixels)):
if p.num_pixels() > 0:
creator = flex.PixelListShoeboxCreator(
p,
i, # panel
0, # zrange
twod, # twod
min_spot_size, # min_pixels
max_spot_size, # max_pixels
write_hot_pixel_mask,
)
shoeboxes.extend(creator.result())
spotsizes.extend(creator.spot_size())
hp.extend(creator.hot_pixels())
logger.info(f"\nExtracted {len(shoeboxes)} spots")
# Get the unallocated spots and print some info
selection = shoeboxes.is_allocated()
shoeboxes = shoeboxes.select(selection)
ntoosmall = (spotsizes < min_spot_size).count(True)
ntoolarge = (spotsizes > max_spot_size).count(True)
assert ntoosmall + ntoolarge == selection.count(False)
logger.info(f"Removed {ntoosmall} spots with size < {min_spot_size} pixels")
logger.info(f"Removed {ntoolarge} spots with size > {max_spot_size} pixels")
# Return the shoeboxes
return shoeboxes, hotpixels
[docs]
def shoeboxes_to_reflection_table(
imageset: ImageSet, shoeboxes: flex.shoebox, filter_spots
) -> flex.reflection_table:
"""Filter shoeboxes and create reflection table"""
# Calculate the spot centroids
centroid = shoeboxes.centroid_valid()
logger.info(f"Calculated {len(shoeboxes)} spot centroids")
# Calculate the spot intensities
intensity = shoeboxes.summed_intensity()
logger.info(f"Calculated {len(shoeboxes)} spot intensities")
# Create the observations
observed = flex.observation(shoeboxes.panels(), centroid, intensity)
# Filter the reflections and select only the desired spots
flags = filter_spots(
None, sweep=imageset, observations=observed, shoeboxes=shoeboxes
)
observed = observed.select(flags)
shoeboxes = shoeboxes.select(flags)
# Return as a reflection list
return flex.reflection_table(observed, shoeboxes)
[docs]
def pixel_list_to_reflection_table(
imageset: ImageSet,
pixel_labeller: Iterable[PixelListLabeller],
filter_spots,
min_spot_size: int,
max_spot_size: int,
write_hot_pixel_mask: bool,
) -> tuple[flex.shoebox, tuple[flex.size_t, ...]]:
"""Convert pixel list to reflection table"""
shoeboxes, hot_pixels = pixel_list_to_shoeboxes(
imageset,
pixel_labeller,
min_spot_size=min_spot_size,
max_spot_size=max_spot_size,
write_hot_pixel_mask=write_hot_pixel_mask,
)
# Setup the reflection table converter
return (
shoeboxes_to_reflection_table(imageset, shoeboxes, filter_spots=filter_spots),
hot_pixels,
)
[docs]
class ExtractSpots:
"""
Class to find spots in an image and extract them into shoeboxes.
"""
[docs]
def __init__(
self,
threshold_function=None,
mask=None,
region_of_interest=None,
max_strong_pixel_fraction=0.1,
compute_mean_background=False,
mp_method=None,
mp_nproc=1,
mp_njobs=1,
mp_chunksize=1,
min_spot_size=1,
max_spot_size=20,
filter_spots=None,
no_shoeboxes_2d=False,
min_chunksize=50,
write_hot_pixel_mask=False,
):
"""
Initialise the class with the strategy
:param threshold_function: The image thresholding strategy
:param mask: The mask to use
:param mp_method: The multi processing method
:param nproc: The number of processors
:param max_strong_pixel_fraction: The maximum number of strong pixels
"""
# Set the required strategies
self.threshold_function = threshold_function
self.mask = mask
self.mp_method = mp_method
self.mp_chunksize = mp_chunksize
self.mp_nproc = mp_nproc
self.mp_njobs = mp_njobs
self.max_strong_pixel_fraction = max_strong_pixel_fraction
self.compute_mean_background = compute_mean_background
self.region_of_interest = region_of_interest
self.min_spot_size = min_spot_size
self.max_spot_size = max_spot_size
self.filter_spots = filter_spots
self.no_shoeboxes_2d = no_shoeboxes_2d
self.min_chunksize = min_chunksize
self.write_hot_pixel_mask = write_hot_pixel_mask
def __call__(self, imageset):
"""
Find the spots in the imageset
:param imageset: The imageset to process
:return: The list of spot shoeboxes
"""
if not self.no_shoeboxes_2d:
return self._find_spots(imageset)
else:
return self._find_spots_2d_no_shoeboxes(imageset)
def _compute_chunksize(self, nimg, nproc, min_chunksize):
"""
Compute the chunk size for a given number of images and processes
"""
chunksize = int(math.ceil(nimg / nproc))
remainder = nimg % (chunksize * nproc)
test_chunksize = chunksize - 1
while test_chunksize >= min_chunksize:
test_remainder = nimg % (test_chunksize * nproc)
if test_remainder <= remainder:
chunksize = test_chunksize
remainder = test_remainder
test_chunksize -= 1
return chunksize
def _find_spots(self, imageset):
"""
Find the spots in the imageset
:param imageset: The imageset to process
:return: The list of spot shoeboxes
"""
# Change the number of processors if necessary
mp_nproc = self.mp_nproc
mp_njobs = self.mp_njobs
if mp_nproc is libtbx.Auto:
mp_nproc = CPU_COUNT
logger.info(f"Setting nproc={mp_nproc}")
if mp_nproc * mp_njobs > len(imageset):
mp_nproc = min(mp_nproc, len(imageset))
mp_njobs = int(math.ceil(len(imageset) / mp_nproc))
mp_method = self.mp_method
mp_chunksize = self.mp_chunksize
if mp_chunksize is libtbx.Auto:
mp_chunksize = self._compute_chunksize(
len(imageset), mp_njobs * mp_nproc, self.min_chunksize
)
logger.info(f"Setting chunksize={mp_chunksize}")
len_by_nproc = int(math.floor(len(imageset) / (mp_njobs * mp_nproc)))
if mp_chunksize > len_by_nproc:
mp_chunksize = len_by_nproc
if mp_chunksize == 0:
mp_chunksize = 1
assert mp_nproc > 0, "Invalid number of processors"
assert mp_njobs > 0, "Invalid number of jobs"
assert mp_njobs == 1 or mp_method is not None, "Invalid cluster method"
assert mp_chunksize > 0, "Invalid chunk size"
# The extract pixels function
function = ExtractPixelsFromImage(
imageset=imageset,
threshold_function=self.threshold_function,
mask=self.mask,
max_strong_pixel_fraction=self.max_strong_pixel_fraction,
compute_mean_background=self.compute_mean_background,
region_of_interest=self.region_of_interest,
)
# The indices to iterate over
indices = list(range(len(imageset)))
# Initialise the pixel labeller
num_panels = len(imageset.get_detector())
pixel_labeller = [PixelListLabeller() for p in range(num_panels)]
# Do the processing
logger.info("Extracting strong pixels from images")
if mp_njobs > 1:
logger.info(
f" Using {mp_method} with {mp_njobs} parallel job(s) and {mp_nproc} processes per node\n"
)
else:
logger.info(f" Using multiprocessing with {mp_nproc} parallel job(s)\n")
if mp_nproc > 1 or mp_njobs > 1:
def process_output(result):
rehandle_cached_records(result[1])
assert len(pixel_labeller) == len(result[0]), "Inconsistent size"
for plabeller, plist in zip(pixel_labeller, result[0]):
plabeller.add(plist)
batch_multi_node_parallel_map(
func=ExtractSpotsParallelTask(function),
iterable=indices,
nproc=mp_nproc,
njobs=mp_njobs,
cluster_method=mp_method,
chunksize=mp_chunksize,
callback=process_output,
)
else:
for task in indices:
result = function(task)
assert len(pixel_labeller) == len(result), "Inconsistent size"
for plabeller, plist in zip(pixel_labeller, result):
plabeller.add(plist)
result.clear()
# Create shoeboxes from pixel list
return pixel_list_to_reflection_table(
imageset,
pixel_labeller,
filter_spots=self.filter_spots,
min_spot_size=self.min_spot_size,
max_spot_size=self.max_spot_size,
write_hot_pixel_mask=self.write_hot_pixel_mask,
)
def _find_spots_2d_no_shoeboxes(self, imageset):
"""
Find the spots in the imageset
:param imageset: The imageset to process
:return: The list of spot shoeboxes
"""
# Change the number of processors if necessary
mp_nproc = self.mp_nproc
mp_njobs = self.mp_njobs
if mp_nproc * mp_njobs > len(imageset):
mp_nproc = min(mp_nproc, len(imageset))
mp_njobs = int(math.ceil(len(imageset) / mp_nproc))
mp_method = self.mp_method
mp_chunksize = self.mp_chunksize
if mp_chunksize == libtbx.Auto:
mp_chunksize = self._compute_chunksize(
len(imageset), mp_njobs * mp_nproc, self.min_chunksize
)
logger.info(f"Setting chunksize={mp_chunksize}")
len_by_nproc = int(math.floor(len(imageset) / (mp_njobs * mp_nproc)))
if mp_chunksize > len_by_nproc:
mp_chunksize = len_by_nproc
assert mp_nproc > 0, "Invalid number of processors"
assert mp_njobs > 0, "Invalid number of jobs"
assert mp_njobs == 1 or mp_method is not None, "Invalid cluster method"
assert mp_chunksize > 0, "Invalid chunk size"
# The extract pixels function
function = ExtractPixelsFromImage2DNoShoeboxes(
imageset=imageset,
threshold_function=self.threshold_function,
mask=self.mask,
max_strong_pixel_fraction=self.max_strong_pixel_fraction,
compute_mean_background=self.compute_mean_background,
region_of_interest=self.region_of_interest,
min_spot_size=self.min_spot_size,
max_spot_size=self.max_spot_size,
filter_spots=self.filter_spots,
)
# The indices to iterate over
indices = list(range(len(imageset)))
# The resulting reflections
reflections = flex.reflection_table()
# Do the processing
logger.info("Extracting strong spots from images")
if mp_njobs > 1:
logger.info(
f" Using {mp_method} with {mp_njobs} parallel job(s) and {mp_nproc} processes per node\n"
)
else:
logger.info(f" Using multiprocessing with {mp_nproc} parallel job(s)\n")
if mp_nproc > 1 or mp_njobs > 1:
def process_output(result):
for message in result[1]:
logger.log(message.levelno, message.msg)
reflections.extend(result[0][0])
result[0][0] = None
batch_multi_node_parallel_map(
func=ExtractSpotsParallelTask(function),
iterable=indices,
nproc=mp_nproc,
njobs=mp_njobs,
cluster_method=mp_method,
chunksize=mp_chunksize,
callback=process_output,
)
else:
for task in indices:
reflections.extend(function(task)[0])
# Return the reflections
return reflections, None
[docs]
class SpotFinder:
"""
A class to do spot finding and filtering.
"""
[docs]
def __init__(
self,
threshold_function=None,
mask=None,
region_of_interest=None,
max_strong_pixel_fraction=0.1,
compute_mean_background=False,
mp_method=None,
mp_nproc=1,
mp_njobs=1,
mp_chunksize=1,
mask_generator=None,
filter_spots=None,
scan_range=None,
write_hot_mask=True,
hot_mask_prefix="hot_mask",
min_spot_size=1,
max_spot_size=20,
no_shoeboxes_2d=False,
min_chunksize=50,
is_stills=False,
):
"""
Initialise the class.
:param find_spots: The spot finding algorithm
:param filter_spots: The spot filtering algorithm
:param scan_range: The scan range to find spots over
:param is_stills: [ADVANCED] Force still-handling of experiment
ID remapping for dials.stills_process.
"""
# Set the filter and some other stuff
self.threshold_function = threshold_function
self.mask = mask
self.region_of_interest = region_of_interest
self.max_strong_pixel_fraction = max_strong_pixel_fraction
self.compute_mean_background = compute_mean_background
self.mask_generator = mask_generator
self.filter_spots = filter_spots
self.scan_range = scan_range
self.write_hot_mask = write_hot_mask
self.hot_mask_prefix = hot_mask_prefix
self.min_spot_size = min_spot_size
self.max_spot_size = max_spot_size
self.mp_method = mp_method
self.mp_chunksize = mp_chunksize
self.mp_nproc = mp_nproc
self.mp_njobs = mp_njobs
self.no_shoeboxes_2d = no_shoeboxes_2d
self.min_chunksize = min_chunksize
self.is_stills = is_stills
[docs]
def find_spots(self, experiments: ExperimentList) -> flex.reflection_table:
"""
Do spotfinding for a set of experiments.
Args:
experiments: The experiment list to process
Returns:
A new reflection table of found reflections
"""
# Loop through all the experiments and get the unique imagesets
imagesets = []
for experiment in experiments:
if experiment.imageset not in imagesets:
imagesets.append(experiment.imageset)
# Loop through all the imagesets and find the strong spots
reflections = flex.reflection_table()
for j, imageset in enumerate(imagesets):
# Find the strong spots in the sequence
logger.info(
"-" * 80 + f"\nFinding strong spots in imageset {j}\n" + "-" * 80
)
table, hot_mask = self._find_spots_in_imageset(imageset)
# Fix up the experiment ID's now
table["id"] = flex.int(table.nrows(), -1)
for i, experiment in enumerate(experiments):
if experiment.imageset is not imageset:
continue
if not self.is_stills and experiment.scan:
z0, z1 = experiment.scan.get_array_range()
z = table["xyzobs.px.value"].parts()[2]
table["id"].set_selected((z > z0) & (z < z1), i)
if experiment.identifier:
table.experiment_identifiers()[i] = experiment.identifier
else:
table["id"] = flex.int(table.nrows(), j)
if experiment.identifier:
table.experiment_identifiers()[j] = experiment.identifier
missed = table["id"] == -1
assert (
missed.count(True) == 0
), f"Failed to remap {missed.count(True)} experiment IDs"
reflections.extend(table)
# Write a hot pixel mask
if self.write_hot_mask:
if not imageset.external_lookup.mask.data.empty():
for m1, m2 in zip(hot_mask, imageset.external_lookup.mask.data):
m1 &= m2.data()
imageset.external_lookup.mask.data = ImageBool(hot_mask)
else:
imageset.external_lookup.mask.data = ImageBool(hot_mask)
imageset.external_lookup.mask.filename = "%s_%d.pickle" % (
self.hot_mask_prefix,
i,
)
# Write the hot mask
with open(imageset.external_lookup.mask.filename, "wb") as outfile:
pickle.dump(hot_mask, outfile, protocol=pickle.HIGHEST_PROTOCOL)
# Set the strong spot flag
reflections.set_flags(
flex.size_t_range(len(reflections)), reflections.flags.strong
)
reflections.is_overloaded(experiments)
reflections = self._post_process(reflections)
return reflections
def _find_spots_in_imageset(self, imageset):
"""
Do the spot finding.
:param imageset: The imageset to process
:return: The observed spots
"""
# The input mask
mask = self.mask_generator(imageset)
if self.mask is not None:
mask = tuple(m1 & m2 for m1, m2 in zip(mask, self.mask))
# Set the spot finding algorithm
extract_spots = ExtractSpots(
threshold_function=self.threshold_function,
mask=mask,
region_of_interest=self.region_of_interest,
max_strong_pixel_fraction=self.max_strong_pixel_fraction,
compute_mean_background=self.compute_mean_background,
mp_method=self.mp_method,
mp_nproc=self.mp_nproc,
mp_njobs=self.mp_njobs,
mp_chunksize=self.mp_chunksize,
min_spot_size=self.min_spot_size,
max_spot_size=self.max_spot_size,
filter_spots=self.filter_spots,
no_shoeboxes_2d=self.no_shoeboxes_2d,
min_chunksize=self.min_chunksize,
write_hot_pixel_mask=self.write_hot_mask,
)
# Get the max scan range
if isinstance(imageset, ImageSequence):
max_scan_range = imageset.get_array_range()
else:
max_scan_range = (0, len(imageset))
# Get list of scan ranges
if not self.scan_range or self.scan_range[0] is None:
scan_range = [(max_scan_range[0] + 1, max_scan_range[1])]
else:
scan_range = self.scan_range
# Get spots from bits of scan
hot_pixels = tuple(flex.size_t() for i in range(len(imageset.get_detector())))
reflections = flex.reflection_table()
for j0, j1 in scan_range:
# Make sure we were asked to do something sensible
if j1 < j0:
raise Sorry("Scan range must be in ascending order")
elif j0 < max_scan_range[0] or j1 > max_scan_range[1]:
raise Sorry(
f"Scan range must be within image range {max_scan_range[0] + 1}..{max_scan_range[1]}"
)
logger.info(f"\nFinding spots in image {j0} to {j1}...")
j0 -= 1
if isinstance(imageset, ImageSequence):
j0 -= imageset.get_array_range()[0]
j1 -= imageset.get_array_range()[0]
if len(imageset) == 1:
r, h = extract_spots(imageset)
else:
r, h = extract_spots(imageset[j0:j1])
reflections.extend(r)
if h is not None:
for h1, h2 in zip(hot_pixels, h):
h1.extend(h2)
# Find hot pixels
hot_mask = self._create_hot_mask(imageset, hot_pixels)
# Return as a reflection list
return reflections, hot_mask
def _create_hot_mask(self, imageset, hot_pixels):
"""
Find hot pixels in images
"""
# Write the hot mask
if self.write_hot_mask:
# Create the hot pixel mask
hot_mask = tuple(
flex.bool(flex.grid(p.get_image_size()[::-1]), True)
for p in imageset.get_detector()
)
num_hot = 0
if hot_pixels:
for hp, hm in zip(hot_pixels, hot_mask):
for i in range(len(hp)):
hm[hp[i]] = False
num_hot += len(hp)
logger.info(f"Found {num_hot} possible hot pixel(s)")
else:
hot_mask = None
# Return the hot mask
return hot_mask
def _post_process(self, reflections):
return reflections
[docs]
class TOFSpotFinder(SpotFinder):
"""
Class to do spot finding tailored to time of flight experiments
"""
[docs]
def __init__(
self,
experiments,
threshold_function=None,
mask=None,
region_of_interest=None,
max_strong_pixel_fraction=0.1,
compute_mean_background=False,
mp_method=None,
mp_nproc=1,
mp_njobs=1,
mp_chunksize=1,
mask_generator=None,
filter_spots=None,
scan_range=None,
write_hot_mask=True,
hot_mask_prefix="hot_mask",
min_spot_size=1,
max_spot_size=20,
min_chunksize=50,
):
super().__init__(
threshold_function=threshold_function,
mask=mask,
region_of_interest=region_of_interest,
max_strong_pixel_fraction=max_strong_pixel_fraction,
compute_mean_background=compute_mean_background,
mp_method=mp_method,
mp_nproc=mp_nproc,
mp_njobs=mp_njobs,
mp_chunksize=mp_chunksize,
mask_generator=mask_generator,
filter_spots=filter_spots,
scan_range=scan_range,
write_hot_mask=write_hot_mask,
hot_mask_prefix=hot_mask_prefix,
min_spot_size=min_spot_size,
max_spot_size=max_spot_size,
no_shoeboxes_2d=False,
min_chunksize=min_chunksize,
is_stills=False,
)
self.experiments = experiments
def _correct_centroid_tof(self, reflections):
"""
Sets the centroid of the spot to the peak position along the
time of flight, as this tends to more accurately represent the true
centroid for spallation sources.
"""
x, y, tof = reflections["xyzobs.px.value"].parts()
peak_x, peak_y, peak_tof = reflections["shoebox"].peak_coordinates().parts()
reflections["xyzobs.px.value"] = flex.vec3_double(x, y, peak_tof)
return reflections
def _post_process(self, reflections):
reflections = self._correct_centroid_tof(reflections)
# Filter any reflections outside of the tof range
for i, expt in enumerate(self.experiments):
_, _, frame = reflections["xyzobs.px.value"].parts()
tof_frame_range = (0, len(expt.scan.get_property("time_of_flight")) - 1)
if "imageset_id" in reflections:
sel_expt = reflections["imageset_id"] == i
else:
sel_expt = reflections["id"] == i
sel = sel_expt & (
(frame > tof_frame_range[1]) | (frame < tof_frame_range[0])
)
reflections = reflections.select(~sel)
n_rows = reflections.nrows()
panel_numbers = flex.size_t(reflections["panel"])
reflections["L1"] = flex.double(n_rows)
reflections["wavelength"] = flex.double(n_rows)
reflections["s0"] = flex.vec3_double(n_rows)
reflections.centroid_px_to_mm(self.experiments)
for i, expt in enumerate(self.experiments):
if "imageset_id" in reflections:
sel_expt = reflections["imageset_id"] == i
else:
sel_expt = reflections["id"] == i
L0 = expt.beam.get_sample_to_source_distance() * 10**-3 # (m)
unit_s0 = expt.beam.get_unit_s0()
for i_panel in range(len(expt.detector)):
sel = sel_expt & (panel_numbers == i_panel)
x, y, tof = reflections["xyzobs.mm.value"].select(sel).parts()
px, py, frame = reflections["xyzobs.px.value"].select(sel).parts()
s1 = expt.detector[i_panel].get_lab_coord(flex.vec2_double(x, y))
L1 = s1.norms()
wavelengths = wavelength_from_tof(L0 + L1 * 10**-3, tof * 10**-6)
s0s = flex.vec3_double(
unit_s0[0] / wavelengths,
unit_s0[1] / wavelengths,
unit_s0[2] / wavelengths,
)
reflections["wavelength"].set_selected(sel, wavelengths)
reflections["s0"].set_selected(sel, s0s)
reflections["L1"].set_selected(sel, L1)
return reflections