This documentation page refers to a previous release of DIALS (2.2).
Click here to go to the corresponding page for the latest version of DIALS.

Source code for dials.algorithms.indexing.lattice_search.low_res_spot_match

from __future__ import absolute_import, division, print_function

import copy
import logging
import math
import operator

import libtbx.phil
from cctbx import miller
from dials.algorithms.indexing import DialsIndexError
from dials.array_family import flex
from dxtbx.model import Crystal
from scitbx import matrix
from scitbx.math import least_squares_plane, superpose

from .strategy import Strategy

logger = logging.getLogger(__name__)

# Angular constants, in radians, used when working with spot "clock angles"
TWO_PI = math.pi * 2.0
FIVE_DEG = 5.0 * TWO_PI / 360.0


class CompleteGraph(object):
    """A complete graph whose vertices are spot-to-Miller-index assignments.

    Each vertex is a dict carrying at least the keys "spot_id" and
    "miller_index". ``weight[i][j]`` is the weight of the edge between
    ``vertices[i]`` and ``vertices[j]`` (0.0 on the diagonal), and
    ``total_weight`` accumulates the weights supplied as vertices are added.

    Vertices are kept sorted by "spot_id", so graphs built from the same
    assignments in a different order hash and compare equal. The graph is
    treated as immutable (it implements ``__hash__``): use
    :meth:`factory_add_vertex` to obtain an extended copy.
    """

    def __init__(self, seed_vertex):
        self.vertices = [seed_vertex]
        self.weight = [{0: 0.0}]
        self.total_weight = 0.0

    def factory_add_vertex(self, vertex, weights_to_other):
        """Return a copy of this graph with one extra vertex added.

        Args:
            vertex (dict): The new vertex, with "spot_id" and "miller_index".
            weights_to_other (sequence of float): Edge weights from the new
                vertex to each existing vertex, in the order of
                ``self.vertices``. The caller's sequence is not modified.

        Returns:
            CompleteGraph: A new graph. This graph is left unchanged.
        """
        # This is a factory rather than a change in-place because
        # CompleteGraph ought to be immutable to implement __hash__
        g = copy.deepcopy(self)

        # Work on a private copy so the caller's list is never mutated
        weights = list(weights_to_other)
        node = len(g.vertices)
        assert len(weights) == node
        g.vertices.append(vertex)

        # Update distances from other nodes to the new one
        for i, w in enumerate(weights):
            g.weight[i][node] = w

        # Add distances to other nodes (and itself) from the new one
        weights.append(0.0)
        g.weight.append(dict(enumerate(weights)))

        # Update the total weight
        g.total_weight += sum(weights)

        # Sort the vertices and weights by spot_id. The weight dicts are keyed
        # by vertex position, so their keys must be remapped to the new
        # positions as well as the rows being reordered.
        order = sorted(
            range(len(g.vertices)), key=lambda i: g.vertices[i]["spot_id"]
        )
        new_position = {old: new for new, old in enumerate(order)}
        g.vertices = [g.vertices[i] for i in order]
        g.weight = [
            {new_position[j]: w for j, w in g.weight[i].items()} for i in order
        ]

        return g

    def _key(self):
        """Identity of the graph: its (spot_id, miller_index) assignments."""
        return tuple((e["spot_id"], e["miller_index"]) for e in self.vertices)

    def __hash__(self):
        return hash(self._key())

    def __eq__(self, other):
        if not isinstance(other, CompleteGraph):
            return NotImplemented
        # Comparing the full tuples also rejects graphs of different sizes,
        # which a zip-based comparison would treat as equal prefixes
        return self._key() == other._key()

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


# PHIL parameter definitions for the low_res_spot_match lattice search; this
# string is parsed into LowResSpotMatch.phil_scope.
#   candidate_spots: how the low resolution spots to match are chosen, either
#     the n_spots lowest resolution spots or all spots with d >= d_min. In
#     n_spots mode candidate_spots.d_min is overwritten at run time. In both
#     modes d_min is also the resolution limit used to generate the candidate
#     Miller indices. d_star_tolerance multiplies the propagated centroid sigma
#     when setting each spot's tolerated d* band.
#   use_P1_indices_as_seeds: seed the search with P 1 indices (the "stems")
#     instead of indices from one asymmetric unit.
#   search_depth: extend matched pairs to triplets, or further to quads.
#   bootstrap_crystal: after fitting, index the low resolution spots with the
#     model and re-fit the orientation from that indexing.
#   max_pairs / max_triplets / max_quads: keep only this many lowest-weight
#     graphs at each stage; a falsy value (0 or None) disables the cap.
low_res_spot_match_phil_str = """\
candidate_spots
{
    limit_resolution_by = *n_spots d_min
    .type = choice

    d_min = 15.0
    .type = float(value_min=0)

    n_spots = 10
    .type = int

    d_star_tolerance = 4.0
    .help = "Number of sigmas from the centroid position for which to"
            "calculate d* bands"
    .type = float
}

use_P1_indices_as_seeds = False
    .type = bool

search_depth = *triplets quads
    .type = choice

bootstrap_crystal = False
    .type = bool

max_pairs = 200
    .type = int

max_triplets = 600
    .type = int

max_quads = 600
    .type = int
"""


class LowResSpotMatch(Strategy):
    """Lattice search by matching low resolution spots to candidate indices
    based on resolution and reciprocal space distance between observed spots.

    The search proceeds in stages:

    1. Each low resolution spot is matched to candidate Miller indices whose
       d* lies within the spot's tolerated d* band. "Seeds" use indices from
       one asymmetric unit; "stems" use all P 1 indices.
    2. Seeds are paired with stems whose observed reciprocal space separation
       agrees with the expected one.
    3. Pairs are extended to triplets (and optionally quads) using the same
       distance test plus a co-planarity test.
    4. A crystal orientation is fitted to each surviving group of spots.
    """

    phil_scope = libtbx.phil.parse(low_res_spot_match_phil_str)

    def __init__(
        self, target_symmetry_primitive, max_lattices, params=None, *args, **kwargs
    ):
        """Construct a LowResSpotMatch object.

        Args:
            target_symmetry_primitive (cctbx.crystal.symmetry): The target
                crystal symmetry and unit cell
            max_lattices (int): The maximum number of lattice models to find

        Raises:
            DialsIndexError: If target_symmetry_primitive is None.
        """
        super(LowResSpotMatch, self).__init__(params=params, *args, **kwargs)
        self._target_symmetry_primitive = target_symmetry_primitive
        self._max_lattices = max_lattices

        if target_symmetry_primitive is None:
            raise DialsIndexError(
                "Target unit cell and space group must be provided for low_res_spot_match"
            )

        # Set reciprocal space orthogonalisation matrix
        uc = self._target_symmetry_primitive.unit_cell()
        self.Bmat = matrix.sqr(uc.fractionalization_matrix()).transpose()

    def find_crystal_models(self, reflections, experiments):
        """Find a list of candidate crystal models.

        Args:
            reflections (dials.array_family.flex.reflection_table): The found
                spots centroids and associated data

            experiments (dxtbx.model.experiment_list.ExperimentList): The
                experimental geometry models

        Returns:
            list: Crystal models, best (lowest total graph weight) first, at
            most max_lattices of them. Also stored as
            self.candidate_crystal_models.

        Raises:
            DialsIndexError: If there are no spots to work with.
        """
        # Take a subset of the observations at the same resolution and
        # calculate some values that will be needed for the search
        self._calc_obs_data(reflections, experiments)

        # Construct a library of candidate low res indices with their d* values
        self._calc_candidate_hkls()

        # First search: match each observation with candidate indices within
        # the acceptable resolution band
        self._calc_seeds_and_stems()
        if self._params.use_P1_indices_as_seeds:
            seeds = self.stems
        else:
            seeds = self.seeds
        logger.info("Using {0} seeds".format(len(seeds)))

        # Second search: match seed spots with another spot from a different
        # reciprocal lattice row, such that the observed reciprocal space
        # distances are within tolerances
        pairs = []
        for seed in seeds:
            pairs.extend(self._pairs_with_seed(seed))
        pairs = self._filter_branches(pairs, self._params.max_pairs, "pairs")

        # Further search iterations: extend to more spots within tolerated
        # distances
        triplets = []
        for pair in pairs:
            triplets.extend(self._extend_by_candidates(pair))
        branches = self._filter_branches(
            triplets, self._params.max_triplets, "triplets"
        )

        if self._params.search_depth == "quads":
            quads = []
            for triplet in branches:
                quads.extend(self._extend_by_candidates(triplet))
            branches = self._filter_branches(quads, self._params.max_quads, "quads")

        # Sort branches by total deviation of observed distances from expected
        branches.sort(key=operator.attrgetter("total_weight"))
        candidate_crystal_models = []
        for branch in branches:
            model = self._fit_crystal_model(branch)
            if model:
                candidate_crystal_models.append(model)
            if len(candidate_crystal_models) == self._max_lattices:
                break

        self.candidate_crystal_models = candidate_crystal_models
        return self.candidate_crystal_models

    @staticmethod
    def _filter_branches(branches, max_branches, label):
        """Remove duplicate graphs and keep the max_branches best ones.

        A lower total_weight (total deviation of observed from expected
        distances) is better. A falsy max_branches disables truncation.

        Args:
            branches (list): CompleteGraph objects found at this search stage
            max_branches (int): Maximum number to keep, or 0/None for no limit
            label (str): Name of the stage, used in log messages

        Returns:
            list: The de-duplicated, possibly truncated, graphs.
        """
        logger.info("Found {0} {1}".format(len(branches), label))
        branches = list(set(branches))  # filter duplicates
        if max_branches:
            branches.sort(key=operator.attrgetter("total_weight"))
            branches = branches[0:max_branches]
            logger.info(
                "Using {0} highest-scoring {1}".format(len(branches), label)
            )
        return branches

    def _calc_candidate_hkls(self):
        """Build the tables of candidate low resolution Miller indices.

        Sets self.candidate_hkls (indices filling one asymmetric unit) and
        self.candidate_hkls_p1 (P 1 indices with separate Friedel pairs). Both
        are limited to candidate_spots.d_min.
        """
        d_min = self._params.candidate_spots.d_min

        # First a list of indices that fill 1 ASU
        asu_set = miller.build_set(
            self._target_symmetry_primitive, anomalous_flag=False, d_min=d_min
        )
        self.candidate_hkls = self._make_hkl_table(asu_set)

        # Now P1 indices with separate Friedel pairs
        anom_set = miller.build_set(
            self._target_symmetry_primitive, anomalous_flag=True, d_min=d_min
        )
        self.candidate_hkls_p1 = self._make_hkl_table(anom_set.expand_to_p1())

    def _make_hkl_table(self, miller_set):
        """Return a reflection table of the indices in miller_set.

        The table has "miller_index", "d_star" and "rlp_datum" columns. The
        last holds the ideal reciprocal lattice point B*h for the target cell.
        """
        rt = flex.reflection_table()
        rt["miller_index"] = miller_set.indices()
        rt["d_star"] = 1.0 / miller_set.d_spacings().data()
        rt["rlp_datum"] = self.Bmat.elems * rt["miller_index"].as_vec3_double()
        return rt

    def _calc_obs_data(self, reflections, experiments):
        """Calculates a set of low resolution observations to try to match to
        indices. Each observation will record its d* value as well as tolerated
        d* bands and a 'clock angle'.

        Sets self.spots, the selected reflections with added columns "d_star",
        "clock_angle", "d_star_outer", "d_star_inner" and "d_star_band2". In
        n_spots mode, candidate_spots.d_min is overwritten.

        Raises:
            DialsIndexError: If there are no reflections, or none at lower
                resolution than d_min.
        """
        if len(reflections) == 0:
            raise DialsIndexError("No reflections provided for low_res_spot_match")

        spot_d_star = reflections["rlp"].norms()
        if self._params.candidate_spots.limit_resolution_by == "n_spots":
            # Choose d_min so that the n_spots lowest resolution spots are kept
            # (all of them if there are fewer reflections than that)
            n_spots = min(self._params.candidate_spots.n_spots, len(reflections))
            d_star_max = flex.sorted(spot_d_star)[n_spots - 1]
            self._params.candidate_spots.d_min = 1.0 / d_star_max
        else:
            d_star_max = 1.0 / self._params.candidate_spots.d_min

        # First select low resolution spots only
        sel = spot_d_star <= d_star_max
        self.spots = reflections.select(sel)
        self.spots["d_star"] = spot_d_star.select(sel)
        if len(self.spots) == 0:
            raise DialsIndexError(
                "No spots found with d >= candidate_spots.d_min ({0:.3f})".format(
                    self._params.candidate_spots.d_min
                )
            )

        # XXX In what circumstance might there be more than one experiment?
        detector = experiments.detectors()[0]
        beam = experiments.beams()[0]

        # Lab coordinate of the beam centre, using the first spot's panel
        bc_panel = detector[self.spots[0]["panel"]]
        bc = bc_panel.get_ray_intersection(beam.get_s0())
        bc_lab = bc_panel.get_lab_coord(bc)

        # Lab coordinate of each spot
        spot_lab = flex.vec3_double(len(self.spots))
        pnl_ids = set(self.spots["panel"])
        for pnl in pnl_ids:
            sel = self.spots["panel"] == pnl
            panel = detector[pnl]
            x_mm, y_mm, _ = self.spots["xyzobs.mm.value"].select(sel).parts()
            spot_lab.set_selected(
                sel, panel.get_lab_coord(flex.vec2_double(x_mm, y_mm))
            )

        # Radius vectors for each spot
        radius = spot_lab - bc_lab

        # Usually the radius vectors would all be in a single plane, but this
        # might not be the case if the spots are on different panels. To put
        # them on the same plane, project onto fast/slow of the first panel of
        # the detector.
        fast0 = detector[0].get_fast_axis()
        slow0 = detector[0].get_slow_axis()
        df = flex.vec3_double(len(self.spots), fast0)
        ds = flex.vec3_double(len(self.spots), slow0)
        clock_dirs = (radius.dot(df) * df + radius.dot(ds) * ds).each_normalize()

        # From this, find positive angles of each vector around a clock, using
        # the fast axis as 12 o'clock
        angs = clock_dirs.angle(fast0)
        dots = clock_dirs.dot(slow0)
        sel = dots < 0  # select directions in the second half of the clock face
        angs.set_selected(sel, (TWO_PI - angs.select(sel)))
        self.spots["clock_angle"] = angs

        # Project radius vectors onto fast/slow of the relevant panels
        df = flex.vec3_double(len(self.spots))
        ds = flex.vec3_double(len(self.spots))
        for pnl in pnl_ids:
            sel = self.spots["panel"] == pnl
            panel = detector[pnl]
            df.set_selected(sel, panel.get_fast_axis())
            ds.set_selected(sel, panel.get_slow_axis())
        panel_dirs = (radius.dot(df) * df + radius.dot(ds) * ds).each_normalize()

        # Calc error along each panel direction with simple error propagation
        # that assumes no covariance between x and y centroid errors.
        x = panel_dirs.dot(df)
        y = panel_dirs.dot(ds)
        x2, y2 = flex.pow2(x), flex.pow2(y)
        r2 = x2 + y2
        sig_x2, sig_y2, _ = self.spots["xyzobs.mm.variance"].parts()
        var_r = (x2 / r2) * sig_x2 + (y2 / r2) * sig_y2
        sig_r = flex.sqrt(var_r)

        # Lab coordinates at limits of the band
        tol = self._params.candidate_spots.d_star_tolerance
        outer_spot_lab = spot_lab + panel_dirs * (tol * sig_r)
        inner_spot_lab = spot_lab - panel_dirs * (tol * sig_r)

        # Set d* at band limits
        inv_lambda = 1.0 / beam.get_wavelength()
        s1_outer = outer_spot_lab.each_normalize() * inv_lambda
        s1_inner = inner_spot_lab.each_normalize() * inv_lambda
        self.spots["d_star_outer"] = (s1_outer - beam.get_s0()).norms()
        self.spots["d_star_inner"] = (s1_inner - beam.get_s0()).norms()
        # Squared width of each spot's tolerated d* band
        self.spots["d_star_band2"] = flex.pow2(
            self.spots["d_star_outer"] - self.spots["d_star_inner"]
        )

    def _calc_seeds_and_stems(self):
        """Match every spot to candidate indices within its d* band.

        As the first stage of search, sets self.seeds (matches to indices in
        1 ASU) and self.stems (matches to all indices in P 1). Each list is
        ordered by the distance of the observed d* from the candidate
        reflection's canonical d*.
        """
        # First the 'seeds' (in 1 ASU)
        self.seeds = self._match_spots_to_candidates(self.candidate_hkls)

        # Now the 'stems' to use in second search level, using all indices in P 1
        self.stems = self._match_spots_to_candidates(self.candidate_hkls_p1)

    def _match_spots_to_candidates(self, candidates):
        """Return dicts pairing each spot with every candidate index in its band.

        Each dict has keys "spot_id", "miller_index", "rlp_datum",
        "residual_d_star" and "clock_angle". The list is sorted by ascending
        residual_d_star.
        """
        matches = []
        for i, spot in enumerate(self.spots.rows()):
            sel = (candidates["d_star"] <= spot["d_star_outer"]) & (
                candidates["d_star"] >= spot["d_star_inner"]
            )
            for c in candidates.select(sel).rows():
                matches.append(
                    {
                        "spot_id": i,
                        "miller_index": c["miller_index"],
                        "rlp_datum": matrix.col(c["rlp_datum"]),
                        "residual_d_star": abs(c["d_star"] - spot["d_star"]),
                        "clock_angle": spot["clock_angle"],
                    }
                )
        matches.sort(key=operator.itemgetter("residual_d_star"))
        return matches

    @staticmethod
    def _indices_are_collinear(h1, h2):
        """Return True if Miller indices h1 and h2 lie on one line through 000.

        Uses the integer cross product, so the test is exact (unlike a test
        on the floating point reciprocal lattice vectors).
        """
        return (
            h1[1] * h2[2] - h1[2] * h2[1] == 0
            and h1[2] * h2[0] - h1[0] * h2[2] == 0
            and h1[0] * h2[1] - h1[1] * h2[0] == 0
        )

    def _pairs_with_seed(self, seed):
        """Find stems that form a geometrically consistent pair with a seed.

        Returns:
            list: Two-vertex CompleteGraph objects. Each edge weight is the
            absolute difference between the observed and expected reciprocal
            space distance between the two spots.
        """
        seed_row = self.spots[seed["spot_id"]]
        seed_rlp = matrix.col(seed_row["rlp"])
        seed_vec = seed["rlp_datum"]
        sq_seed_band = seed_row["d_star_band2"]

        result = []
        for cand in self.stems:
            # Don't check the seed spot itself
            if cand["spot_id"] == seed["spot_id"]:
                continue

            # Skip spots at a very similar clock angle, which probably belong
            # to the same line of indices from the origin
            angle_diff = cand["clock_angle"] - seed["clock_angle"]
            angle_diff = abs(((angle_diff + math.pi) % TWO_PI) - math.pi)
            if angle_diff < FIVE_DEG:
                continue

            # Skip pairs of Miller indices that belong to the same line, for
            # which the seed and stem do not define a plane
            if self._indices_are_collinear(seed["miller_index"], cand["miller_index"]):
                continue

            # Compare expected reciprocal space distance with observed distance
            cand_row = self.spots[cand["spot_id"]]
            cand_vec = cand["rlp_datum"]
            obs_dist = (matrix.col(cand_row["rlp"]) - seed_rlp).length()
            exp_dist = (seed_vec - cand_vec).length()
            r_dist = abs(obs_dist - exp_dist)

            # If the distance difference is larger than the sum in quadrature
            # of the tolerated d* bands then reject the candidate
            if r_dist > math.sqrt(sq_seed_band + cand_row["d_star_band2"]):
                continue

            # Store the seed-stem match as a 2-node graph
            g = CompleteGraph(
                {
                    "spot_id": seed["spot_id"],
                    "miller_index": seed["miller_index"],
                    "rlp_datum": seed["rlp_datum"],
                }
            )
            g = g.factory_add_vertex(
                {
                    "spot_id": cand["spot_id"],
                    "miller_index": cand["miller_index"],
                    "rlp_datum": cand["rlp_datum"],
                },
                weights_to_other=[r_dist],
            )
            result.append(g)
        return result

    def _extend_by_candidates(self, graph):
        """Return copies of graph, each extended by one consistent stem.

        A stem is accepted if its observed distance to every spot already in
        the graph agrees with the expected distance within tolerance, and if
        all the ideal relps (plus the origin) remain nearly co-planar.
        """
        existing_ids = [e["spot_id"] for e in graph.vertices]
        existing_rows = [self.spots[i] for i in existing_ids]
        obs_relps = [matrix.col(row["rlp"]) for row in existing_rows]
        sq_existing_bands = [row["d_star_band2"] for row in existing_rows]
        exp_relps = [e["rlp_datum"] for e in graph.vertices]

        result = []
        for cand in self.stems:
            # Don't check spots already matched
            if cand["spot_id"] in existing_ids:
                continue

            # Compare expected reciprocal space distances with observed distances
            cand_row = self.spots[cand["spot_id"]]
            cand_rlp = matrix.col(cand_row["rlp"])
            cand_vec = cand["rlp_datum"]

            obs_dists = [(cand_rlp - rlp).length() for rlp in obs_relps]
            exp_dists = [(vec - cand_vec).length() for vec in exp_relps]
            residual_dist = [abs(a - b) for (a, b) in zip(obs_dists, exp_dists)]

            # If any of the distance differences is larger than the sum in
            # quadrature of the tolerated d* bands then reject the candidate
            sq_candidate_band = cand_row["d_star_band2"]
            if any(
                r_dist > math.sqrt(sq_band + sq_candidate_band)
                for r_dist, sq_band in zip(residual_dist, sq_existing_bands)
            ):
                continue

            # Calculate co-planarity of the relps, including the origin
            points = flex.vec3_double(exp_relps + [cand_vec, (0.0, 0.0, 0.0)])
            plane = least_squares_plane(points)
            plane_score = flex.sum_sq(
                points.dot(plane.normal) - plane.distance_to_origin
            )

            # Reject if the group of relps are too far from lying in a single
            # plane. This cut-off was determined by trial and error using
            # simulated images.
            if plane_score > 6e-7:
                continue

            # Construct a graph including the accepted candidate node
            g = graph.factory_add_vertex(
                {
                    "spot_id": cand["spot_id"],
                    "miller_index": cand["miller_index"],
                    "rlp_datum": cand["rlp_datum"],
                },
                weights_to_other=residual_dist,
            )
            result.append(g)
        return result

    @staticmethod
    def _fit_U_from_superposed_points(reference, other):
        """Return the rotation fitted by superposing other onto reference.

        The origin is added as an extra point to copies of both sets. The
        input arrays are left unmodified, so callers can compute residuals
        over the real points only.
        """
        # Add the origin to (copies of) both sets of points
        reference = reference.deep_copy()
        other = other.deep_copy()
        reference.append((0, 0, 0))
        other.append((0, 0, 0))

        # Find U matrix that takes ideal relps to the reference
        fit = superpose.least_squares_fit(reference, other)
        return fit.r

    def _fit_crystal_model(self, graph):
        """Fit a P 1 crystal model to the spot/index assignments in graph.

        Returns:
            dxtbx.model.Crystal: The model, with an added ``rms`` attribute
            holding the RMSD between the observed relps and the rotated ideal
            relps used in the fit.
        """
        vertices = graph.vertices

        # Reciprocal lattice points of the observations
        sel = flex.size_t([e["spot_id"] for e in vertices])
        reference = self.spots["rlp"].select(sel)

        # Ideal relps from the known cell
        other = flex.vec3_double([e["rlp_datum"] for e in vertices])

        U = self._fit_U_from_superposed_points(reference, other)
        UB = U * self.Bmat

        if self._params.bootstrap_crystal:
            # Attempt to index the low resolution spots
            from dials_algorithms_indexing_ext import AssignIndices

            phi = self.spots["xyzobs.mm.value"].parts()[2]
            UB_matrices = flex.mat3_double([UB])
            result = AssignIndices(self.spots["rlp"], phi, UB_matrices, tolerance=0.3)
            hkl = result.miller_indices()
            sel = hkl != (0, 0, 0)
            hkl_vec = hkl.as_vec3_double().select(sel)

            # Use the result to get a new UB matrix
            reference = self.spots["rlp"].select(sel)
            other = self.Bmat.elems * hkl_vec
            U = self._fit_U_from_superposed_points(reference, other)
            UB = U * self.Bmat

        # Calculate RMSD of the fit over the observed points only (the origin
        # added for the superposition is not counted)
        rms = reference.rms_difference(U.elems * other)

        # Construct a crystal model
        xl = Crystal(A=UB, space_group_symbol="P1")

        # Monkey-patch crystal to return rms of the fit (useful?)
        xl.rms = rms

        return xl