Click here to go to the corresponding page for the latest version of DIALS
Source code for dials.algorithms.spot_finding.factory
from __future__ import absolute_import, division, print_function
import collections
import logging
logger = logging.getLogger(__name__)
[docs]def generate_phil_scope():
    from iotbx.phil import parse
    import dials.extensions
    phil_scope = parse(
        """
  spotfinder
    .help = "Parameters used in the spot finding algorithm."
  {
    include scope dials.data.lookup.phil_scope
    write_hot_mask = False
      .type = bool
      .help = "Write the hot mask"
    hot_mask_prefix = 'hot_mask'
      .type = str
      .help = "Prefix for the hot mask pickle file"
    force_2d = False
      .type = bool
      .help = "Do spot finding in 2D"
    scan_range = None
      .help = "The range of images to use in finding spots. The ranges are"
              "inclusive (e.g. j0 <= j < j1)."
              "For sweeps the scan range is interpreted as the literal scan"
              "range. Whereas for imagesets the scan range is interpreted as"
              "the image number in the imageset. Multiple ranges can be"
              "specified by repeating the scan_range= parameter."
      .type = ints(size=2)
      .multiple = True
    region_of_interest = None
      .type = ints(size=4)
      .help = "A region of interest to look for spots."
              "Specified as: x0,x1,y0,y1"
              "The pixels x0 and y0 are included in the range but the pixels x1 and y1"
              "are not. To specify an ROI covering the whole image set"
              "region_of_interest=0,width,0,height."
    compute_mean_background = False
      .type = bool
      .help = "Compute the mean background for each image"
    filter
      .help = "Parameters used in the spot finding filter strategy."
    {
      min_spot_size = Auto
        .help = "The minimum number of contiguous pixels for a spot"
                "to be accepted by the filtering algorithm."
        .type = int(value_min=1)
      max_spot_size = 100
        .help = "The maximum number of contiguous pixels for a spot"
                "to be accepted by the filtering algorithm."
        .type = int(value_min=1, allow_none=False)
      max_separation = 2
        .help = "The maximum peak-to-centroid separation (in pixels)"
                "for a spot to be accepted by the filtering algorithm."
        .type = float(value_min=0)
        .expert_level = 1
      max_strong_pixel_fraction = 0.25
        .help = "If the fraction of pixels in an image marked as strong is"
                "greater than this value, throw an exception"
        .type = float(value_min=0, value_max=1)
      background_gradient
        .expert_level=2
      {
        filter = False
          .type = bool
        background_size = 2
          .type = int(value_min=1)
        gradient_cutoff = 4
          .type = float(value_min=0)
      }
      spot_density
        .expert_level=2
      {
        filter = False
          .type = bool
      }
      include scope dials.util.masking.phil_scope
    }
    mp {
      method = *none drmaa sge lsf pbs
        .type = choice
        .help = "The cluster method to use"
      njobs = 1
        .type = int(value_min=1)
        .help = "The number of cluster jobs to use"
      nproc = 1
        .type = int(value_min=1)
        .help = "The number of processes to use per cluster job"
      chunksize = auto
        .type = int(value_min=1)
        .help = "The number of jobs to process per process"
      min_chunksize = 20
        .type = int(value_min=1)
        .help = "When chunksize is auto, this is the minimum chunksize"
    }
  }
  """,
        process_includes=True,
    )
    main_scope = phil_scope.get_without_substitution("spotfinder")
    assert len(main_scope) == 1
    main_scope = main_scope[0]
    main_scope.adopt_scope(dials.extensions.SpotFinderThreshold.phil_scope())
    return phil_scope
phil_scope = generate_phil_scope()
[docs]class FilterRunner(object):
    """
    A class to run multiple filters in succession.
    """
    def __init__(self, filters=None):
        """
        Initialise with a list of filters.
        :param filters: The list of filters
        """
        if filters is None:
            self.filters = []
        else:
            self.filters = filters
    def __call__(self, flags, **kwargs):
        """
        Call the filters one by one.
        :param flags: The input flags
        :returns: The filtered flags
        """
        flags = self.check_flags(flags, **kwargs)
        for f in self.filters:
            flags = f(flags, **kwargs)
        return flags
[docs]    def check_flags(
        self, flags, predictions=None, observations=None, shoeboxes=None, **kwargs
    ):
        """
        Check the flags are set, if they're not then create a list
        of Trues equal to the number of items given.
        :param flags: The input flags
        :param predictions: The predictions
        :param observations: The observations
        :param shoeboxes: The shoeboxes
        :return: The filtered flags
        """
        from scitbx.array_family import flex
        # If flags are not set then create a list of Trues
        if flags is None:
            length = 0
            if predictions:
                length = len(predictions)
            if observations:
                if length > 0:
                    assert length == len(observations)
                else:
                    length = len(observations)
            if shoeboxes:
                if length > 0:
                    assert length == len(observations)
                else:
                    length = len(shoeboxes)
            # Create an array of flags
            flags = flex.bool(length, True)
        # Return the flags
        return flags
[docs]class PeakCentroidDistanceFilter(object):
    def __init__(self, maxd):
        """
        Initialise
        :param maxd: The maximum distance allowed
        """
        self.maxd = maxd
[docs]    def run(self, flags, observations=None, shoeboxes=None, **kwargs):
        """
        Run the filtering.
        """
        # Get the peak locations and the centroids and return the flags of
        # those closer than the min distance
        peak = shoeboxes.peak_coordinates()
        cent = observations.centroids().px_position()
        return flags.__and__((peak - cent).norms() <= self.maxd)
    def __call__(self, flags, **kwargs):
        """ Call the filter and print information. """
        num_before = flags.count(True)
        flags = self.run(flags, **kwargs)
        num_after = flags.count(True)
        logger.info(
            "Filtered {0} of {1} spots by peak-centroid distance".format(
                num_after, num_before
            )
        )
        return flags
[docs]class BackgroundGradientFilter(object):
    def __init__(self, background_size=2, gradient_cutoff=4):
        self.background_size = background_size
        self.gradient_cutoff = gradient_cutoff
[docs]    def run(self, flags, sweep=None, shoeboxes=None, **kwargs):
        from dials.array_family import flex
        from dials.algorithms.shoebox import MaskCode
        from dials.algorithms.background.simple import Linear2dModeller
        bg_code = MaskCode.Valid | MaskCode.BackgroundUsed
        fg_code = MaskCode.Valid | MaskCode.Foreground
        strong_code = MaskCode.Valid | MaskCode.Strong
        modeller = Linear2dModeller()
        expanded_shoeboxes = flex.shoebox()
        detector = sweep.get_detector()
        zoffset = 0
        if sweep.get_scan() is not None:
            zoffset = sweep.get_scan().get_array_range()[0]
        class image_data_cache(object):
            def __init__(self, imageset, size=10):
                self.imageset = imageset
                self.size = size
                self._image_data = collections.OrderedDict()
            def __getitem__(self, i):
                image_data = self._image_data.get(i)
                if image_data is None:
                    image_data = self.imageset.get_raw_data(i)
                    if len(self._image_data) >= self.size:
                        # remove the oldest entry in the cache
                        del self._image_data[self._image_data.keys()[0]]
                    self._image_data[i] = image_data
                return image_data
        cache = image_data_cache(sweep)
        # cache = sweep
        # sort shoeboxes by centroid z
        frame = shoeboxes.centroid_all().position_frame()
        perm = flex.sort_permutation(frame)
        shoeboxes = shoeboxes.select(perm)
        buffer_size = 1
        bg_plus_buffer = self.background_size + buffer_size
        import time
        t0 = time.time()
        for i, shoebox in enumerate(shoeboxes):
            if not flags[perm[i]]:
                continue
            panel = detector[shoebox.panel]
            trusted_range = panel.get_trusted_range()
            max_x, max_y = panel.get_image_size()
            bbox = shoebox.bbox
            x1, x2, y1, y2, z1, z2 = bbox
            # expand the bbox with a background region around the spotfinder shoebox
            # perhaps also should use a buffer zone between the shoebox and the
            # background region
            expanded_bbox = (
                max(0, x1 - bg_plus_buffer),
                min(max_x, x2 + bg_plus_buffer),
                max(0, y1 - bg_plus_buffer),
                min(max_y, y2 + bg_plus_buffer),
                z1,
                z2,
            )
            shoebox.bbox = expanded_bbox
        t1 = time.time()
        logger.info("Time expand_shoebox: %s" % (t1 - t0))
        rlist = flex.reflection_table()
        rlist["shoebox"] = shoeboxes
        rlist["shoebox"].allocate()
        rlist["panel"] = shoeboxes.panels()
        rlist["bbox"] = shoeboxes.bounding_boxes()
        t0 = time.time()
        rlist.extract_shoeboxes(sweep)
        t1 = time.time()
        shoeboxes = rlist["shoebox"]
        shoeboxes.flatten()
        t0 = time.time()
        for i, shoebox in enumerate(shoeboxes):
            if not flags[perm[i]]:
                continue
            panel = detector[shoebox.panel]
            trusted_range = panel.get_trusted_range()
            max_x, max_y = panel.get_image_size()
            ex1, ex2, ey1, ey2, ez1, ez2 = shoebox.bbox
            data = shoebox.data
            mask = flex.bool(data.accessor(), False)
            for i_y, y in enumerate(range(ey1, ey2)):
                for i_x, x in enumerate(range(ex1, ex2)):
                    value = data[0, i_y, i_x]
                    if (
                        y >= (ey1 + buffer_size)
                        and y < (ey2 - buffer_size)
                        and x >= (ex1 + buffer_size)
                        and x < (ex2 - buffer_size)
                    ):
                        mask[0, i_y, i_x] = False  # foreground
                    elif value > trusted_range[0] and value < trusted_range[1]:
                        mask[0, i_y, i_x] = True  # background
            model = modeller.create(data.as_double(), mask)
            d, a, b = model.params()[:3]
            c = -1
            if abs(a) > self.gradient_cutoff or abs(b) > self.gradient_cutoff:
                flags[perm[i]] = False
        t1 = time.time()
        return flags
    def __call__(self, flags, **kwargs):
        """ Call the filter and print information. """
        num_before = flags.count(True)
        flags = self.run(flags, **kwargs)
        num_after = flags.count(True)
        logger.info(
            "Filtered {0} or {1} spots by background gradient".format(
                num_after, num_before
            )
        )
        return flags
[docs]class SpotDensityFilter(object):
    def __init__(self, nbins=50, gradient_cutoff=0.002):
        self.nbins = nbins
        self.gradient_cutoff = gradient_cutoff
[docs]    def run(self, flags, sweep=None, observations=None, **kwargs):
        obs_x, obs_y = observations.centroids().px_position_xy().parts()
        import numpy as np
        H, xedges, yedges = np.histogram2d(
            obs_x.as_numpy_array(), obs_y.as_numpy_array(), bins=self.nbins
        )
        from scitbx.array_family import flex
        H_flex = flex.double(H.flatten().astype(np.float64))
        n_slots = min(int(flex.max(H_flex)), 30)
        hist = flex.histogram(H_flex, n_slots=n_slots)
        slots = hist.slots()
        cumulative_hist = flex.long(len(slots))
        for i in range(len(slots)):
            cumulative_hist[i] = slots[i]
            if i > 0:
                cumulative_hist[i] += cumulative_hist[i - 1]
        cumulative_hist = cumulative_hist.as_double() / flex.max(
            cumulative_hist.as_double()
        )
        cutoff = None
        gradients = flex.double()
        for i in range(len(slots) - 1):
            x1 = cumulative_hist[i]
            x2 = cumulative_hist[i + 1]
            g = (x2 - x1) / hist.slot_width()
            gradients.append(g)
            if (
                cutoff is None
                and i > 0
                and g < self.gradient_cutoff
                and gradients[i - 1] < self.gradient_cutoff
            ):
                cutoff = hist.slot_centers()[i - 1] - 0.5 * hist.slot_width()
        H_flex = flex.double(np.ascontiguousarray(H))
        isel = (H_flex > cutoff).iselection()
        sel = np.column_stack(np.where(H > cutoff))
        for (ix, iy) in sel:
            flags.set_selected(
                (
                    (obs_x > xedges[ix])
                    & (obs_x < xedges[ix + 1])
                    & (obs_y > yedges[iy])
                    & (obs_y < yedges[iy + 1])
                ),
                False,
            )
        if 0:
            from matplotlib import pyplot
            fig, ax1 = pyplot.subplots()
            extent = [yedges[0], yedges[-1], xedges[0], xedges[-1]]
            plot1 = ax1.imshow(H, extent=extent, interpolation="nearest")
            pyplot.xlim((0, pyplot.xlim()[1]))
            pyplot.ylim((0, pyplot.ylim()[1]))
            pyplot.gca().invert_yaxis()
            cbar1 = pyplot.colorbar(plot1)
            pyplot.axes().set_aspect("equal")
            pyplot.show()
            fig, ax1 = pyplot.subplots()
            ax2 = ax1.twinx()
            ax1.scatter(hist.slot_centers() - 0.5 * hist.slot_width(), cumulative_hist)
            ax1.set_ylim(0, 1)
            ax2.plot(hist.slot_centers()[:-1] - 0.5 * hist.slot_width(), gradients)
            ymin, ymax = pyplot.ylim()
            pyplot.vlines(cutoff, ymin, ymax, color="r")
            pyplot.show()
            H2 = H.copy()
            if cutoff is not None:
                H2[np.where(H2 >= cutoff)] = 0
            fig, ax1 = pyplot.subplots()
            plot1 = ax1.pcolormesh(xedges, yedges, H2)
            pyplot.xlim((0, pyplot.xlim()[1]))
            pyplot.ylim((0, pyplot.ylim()[1]))
            pyplot.gca().invert_yaxis()
            cbar1 = pyplot.colorbar(plot1)
            pyplot.axes().set_aspect("equal")
            pyplot.show()
        return flags
    def __call__(self, flags, **kwargs):
        """ Call the filter and print information. """
        num_before = flags.count(True)
        flags = self.run(flags, **kwargs)
        num_after = flags.count(True)
        logger.info(
            "Filtered {0} of {1} spots by spot density".format(num_after, num_before)
        )
        return flags
[docs]class SpotFinderFactory(object):
    """
    Factory class to create spot finders
    """
[docs]    @staticmethod
    def from_parameters(params=None, datablock=None):
        """
        Given a set of parameters, construct the spot finder
        :param params: The input parameters
        :returns: The spot finder instance
        """
        from dials.util.masking import MaskGenerator
        from dials.algorithms.spot_finding.finder import SpotFinder
        from libtbx.phil import parse
        from dxtbx.imageset import ImageSweep
        if params is None:
            params = phil_scope.fetch(source=parse("")).extract()
        if params.spotfinder.force_2d and params.output.shoeboxes is False:
            no_shoeboxes_2d = True
        elif datablock is not None and params.output.shoeboxes is False:
            no_shoeboxes_2d = False
            all_stills = True
            for imageset in datablock.extract_imagesets():
                if isinstance(imageset, ImageSweep):
                    all_stills = False
                    break
            if all_stills:
                no_shoeboxes_2d = True
        else:
            no_shoeboxes_2d = False
        # Read in the lookup files
        mask = SpotFinderFactory.load_image(params.spotfinder.lookup.mask)
        params.spotfinder.lookup.mask = mask
        # Configure the filter options
        filter_spots = SpotFinderFactory.configure_filter(params)
        # Create the threshold strategy
        threshold_function = SpotFinderFactory.configure_threshold(params, datablock)
        # Configure the mask generator
        mask_generator = MaskGenerator(params.spotfinder.filter)
        # Make sure 'none' is interpreted as None
        if params.spotfinder.mp.method == "none":
            params.spotfinder.mp.method = None
        # Setup the spot finder
        return SpotFinder(
            threshold_function=threshold_function,
            mask=params.spotfinder.lookup.mask,
            filter_spots=filter_spots,
            scan_range=params.spotfinder.scan_range,
            write_hot_mask=params.spotfinder.write_hot_mask,
            hot_mask_prefix=params.spotfinder.hot_mask_prefix,
            mp_method=params.spotfinder.mp.method,
            mp_nproc=params.spotfinder.mp.nproc,
            mp_njobs=params.spotfinder.mp.njobs,
            mp_chunksize=params.spotfinder.mp.chunksize,
            max_strong_pixel_fraction=params.spotfinder.filter.max_strong_pixel_fraction,
            compute_mean_background=params.spotfinder.compute_mean_background,
            region_of_interest=params.spotfinder.region_of_interest,
            mask_generator=mask_generator,
            min_spot_size=params.spotfinder.filter.min_spot_size,
            max_spot_size=params.spotfinder.filter.max_spot_size,
            no_shoeboxes_2d=no_shoeboxes_2d,
            min_chunksize=params.spotfinder.mp.min_chunksize,
        )
[docs]    @staticmethod
    def configure_threshold(params, datablock):
        """
        Get the threshold strategy
        :param params: The input parameters
        :return: The threshold algorithm
        """
        import dials.extensions
        # Configure the algotihm
        Algorithm = dials.extensions.SpotFinderThreshold.load(
            params.spotfinder.threshold.algorithm
        )
        return Algorithm(params)
[docs]    @staticmethod
    def configure_filter(params):
        """
        Get the filter strategy.
        :param params: The input parameters
        :return: The filter algorithm
        """
        from dials.algorithms import shoebox
        # Initialise an empty list of filters
        filters = []
        # Add a peak-centroid distance filter
        if params.spotfinder.filter.max_separation is not None:
            filters.append(
                PeakCentroidDistanceFilter(params.spotfinder.filter.max_separation)
            )
        if params.spotfinder.filter.background_gradient.filter:
            bg_filter_params = params.spotfinder.filter.background_gradient
            filters.append(
                BackgroundGradientFilter(
                    background_size=bg_filter_params.background_size,
                    gradient_cutoff=bg_filter_params.gradient_cutoff,
                )
            )
        if params.spotfinder.filter.spot_density.filter:
            filters.append(SpotDensityFilter())
        # Return the filter runner with the list of filters
        return FilterRunner(filters)
[docs]    @staticmethod
    def load_image(filename_or_data):
        """
        Given a filename, load an image. If the data is already loaded, return it.
        :param filename_or_data: The input filename (or data)
        :return: The image or None
        """
        import six.moves.cPickle as pickle
        # If no filename is set then return None
        if not filename_or_data:
            return None
        # If it's already loaded, return early
        if isinstance(filename_or_data, tuple):
            return filename_or_data
        # Read the image and return the image data
        with open(filename_or_data, "rb") as fh:
            image = pickle.load(fh)
        if not isinstance(image, tuple):
            image = (image,)
        return image
  


