from __future__ import annotations
import logging
import gemmi
import numpy as np
from scipy.spatial.transform import Rotation
import iotbx.phil
from dxtbx.model import Crystal
from scitbx import matrix
from dials.algorithms.indexing import DialsIndexError
from dials.array_family import flex
from .strategy import Strategy
logger = logging.getLogger(__name__)
# PHIL definition of the ``pink_indexer`` lattice-search options. It is parsed
# into PinkIndexer.phil_scope and the values are read in PinkIndexer.__init__.
# ``wavelength = None`` means the wavelength is taken at indexing time from the
# beam, or from the per-reflection ``wavelength`` column for Laue / ToF data.
pink_indexer_phil_str = """
pink_indexer
.expert_level = 1
{
max_refls = 50
.type = int(value_min=10)
.help = "Maximum number of reflections to consider indexing"
wavelength = None
.type = float(value_min=0.)
.help = "The peak wavelength"
percent_bandwidth = 1.
.type = float(value_min=0.)
.help = "The percent bandwidth used to calculate the wavelength range for indexing. The wavelength range is defined (wavelength - wavelength*percent_bandwidth/200, wavelength + wavelength*percent_bandwidth/200). This parameter also reflects the uncertainty of the supplied cell constants with larger values appropriate for less certain unit cells."
rotogram_grid_points = 180
.type = int(value_min=10, value_max=1000)
.help = "Number of points at which to evaluate the angle search for each rlp-observation pair"
voxel_grid_points=150
.type = int(value_min=10, value_max=1000)
.help = "Controls the number of voxels onto which the rotograms are discretized"
min_lattices=1
.type = int(value_min=1, value_max=100)
.help = "The minimum number of candidate lattices to generate."
}
"""
def rotvec_to_quaternion(rotvec, deg=False, eps=1e-32):
    """Convert rotation vector(s) into unit quaternion(s).

    Args:
        rotvec (array): Rotation vector(s) with the components on the last
            axis. The direction is the rotation axis and the norm is the angle.
        deg (bool, optional): If True, the norm of ``rotvec`` is read as an
            angle in degrees instead of radians.
        eps (float, optional): Unused; kept for interface compatibility.

    Returns:
        array: Quaternion(s) in scalar-last ``(x, y, z, w)`` order, with the
        leading shape of ``rotvec``. A zero rotation vector maps to the
        identity quaternion ``(0, 0, 0, 1)``.
    """
    angle = norm2(rotvec)
    # Divide null rotations by +inf so their axis becomes 0 instead of NaN.
    safe_angle = np.where(angle == 0.0, np.inf, angle)
    unit_axis = rotvec / safe_angle[..., None]
    half_angle = 0.5 * (np.deg2rad(angle) if deg else angle)
    sin_half = np.sin(half_angle)
    return np.stack(
        (
            sin_half * unit_axis[..., 0],
            sin_half * unit_axis[..., 1],
            sin_half * unit_axis[..., 2],
            np.cos(half_angle),
        ),
        axis=-1,
    )
def quaternion_multiply(a, b):
    """Return the Hamilton product ``a * b`` of two (batches of) quaternions.

    Quaternions are stored in scalar-last ``(x, y, z, w)`` order. This is the
    layout produced by ``rotvec_to_quaternion`` and consumed by
    ``scipy.spatial.transform.Rotation.from_quat``. As rotations, ``a * b``
    applies ``b`` first and then ``a``.

    Args:
        a (array): Quaternion(s) with the 4 components on the last axis.
        b (array): Quaternion(s) with the 4 components on the last axis. The
            leading dimensions of ``a`` and ``b`` must broadcast together.

    Returns:
        array: The product. Its shape is the broadcast leading shape of ``a``
        and ``b`` followed by a final axis of length 4.
    """
    ax, ay, az, aw = (a[..., k] for k in range(4))
    bx, by, bz, bw = (b[..., k] for k in range(4))
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw
    w = aw * bw - ax * bx - ay * by - az * bz
    # Stack on the last axis. np.dstack is only correct for 2-D leading shapes:
    # it turns 1-D inputs into (1, N, 4) and concatenates (not stacks) along
    # axis 2 when there are three or more leading dimensions.
    return np.stack((x, y, z, w), axis=-1)
def norm2(array, axis=-1, keepdims=False):
    """Return the Euclidean (L2) norm of ``array`` along ``axis``.

    Intended as a faster stand-in for ``np.linalg.norm`` on short axes such
    as 3-vectors: the squared components are summed one slice at a time.

    Args:
        array (array): Input values.
        axis (int, optional): The axis to reduce. Defaults to the last one.
        keepdims (bool, optional): If True, keep the reduced axis with length 1.

    Returns:
        array: The L2 norms.
    """
    squared = np.square(array)
    slices = np.split(squared, squared.shape[axis], axis=axis)
    total = sum(slices, 0.0)
    norms = np.sqrt(total)
    return norms if keepdims else np.squeeze(norms, axis=axis)
def normalize(array, axis=-1):
    """Scale ``array`` to unit L2 norm along ``axis``.

    Zero-length vectors are not special-cased; dividing them by their zero
    norm yields NaN entries.

    Args:
        array (array): Input values.
        axis (int, optional): The axis along which the norm is taken.

    Returns:
        array: ``array`` divided by its norm, with the same shape.
    """
    magnitude = norm2(array, axis=axis, keepdims=True)
    return array / magnitude
def angle_between(vec1, vec2, deg=True):
    """
    Compute the angle between vectors along the last dimension of two arrays.

    The angle is evaluated as ``2 * arctan2(|u1 - u2|, |u1 + u2|)`` on the unit
    vectors ``u1`` and ``u2``. This is a numerically stable alternative to
    ``arccos`` of the dot product, as described in this post:
    - https://scicomp.stackexchange.com/a/27769/39858

    Parameters
    ----------
    vec1 : array
        An arbitrarily batched array of vectors
    vec2 : array
        An arbitrarily batched array of vectors
    deg : bool (optional)
        Whether angles are returned in degrees or radians. The default is degrees (deg=True).

    Returns
    -------
    angles : array
        An array of angles with the leading dimensions of vec1 and vec2.
    """
    unit1 = normalize(vec1, axis=-1)
    unit2 = normalize(vec2, axis=-1)
    half_angle = np.arctan2(
        norm2(unit1 - unit2, axis=-1), norm2(unit1 + unit2, axis=-1)
    )
    angle = 2.0 * half_angle
    return np.rad2deg(angle) if deg else angle
def generate_reciprocal_cell(cell, dmin, dtype=np.int32):
    """
    Generate the Miller indices of the full P1 reciprocal cell out to ``dmin``.

    Parameters
    ----------
    cell : gemmi.UnitCell, or tuple, list, np.ndarray of cell parameters
        Unit cell. Objects that do not provide ``get_hkl_limits`` are treated
        as the six cell parameters (a, b, c, alpha, beta, gamma) and converted
        with ``gemmi.UnitCell(*cell)``.
    dmin : float
        Maximum resolution of the data in Å
    dtype : np.dtype (optional)
        The data type of the returned array. The default is np.int32.

    Returns
    -------
    hkl : np.ndarray
        An (n, 3) array of dtype ``dtype``. It contains every Miller index
        except (0, 0, 0) whose d-spacing is at least ``dmin``.
    """
    if not hasattr(cell, "get_hkl_limits"):
        # Support the documented plain-parameter input by building a gemmi cell.
        cell = gemmi.UnitCell(*cell)
    hmax, kmax, lmax = cell.get_hkl_limits(dmin)
    # Symmetric integer range -max..max on every axis. get_hkl_limits bounds
    # |h|, |k| and |l| at this resolution; anything else is filtered below.
    hkl = np.meshgrid(
        np.arange(-hmax, hmax + 1, dtype=dtype),
        np.arange(-kmax, kmax + 1, dtype=dtype),
        np.arange(-lmax, lmax + 1, dtype=dtype),
    )
    hkl = np.stack(hkl).reshape((3, -1)).T
    # Remove reflection 0,0,0
    hkl = hkl[np.any(hkl != 0, axis=1)]
    # Remove reflections outside of resolution range (d-spacings compared in
    # single precision, as before)
    d_hkl = cell.calculate_d_array(hkl).astype("float32")
    return hkl[d_hkl >= dmin]
class Indexer:
    """
    A class which implements the PinkIndexer algorithm.

    [Gevorkov Y, et al. pinkIndexer – a universal indexer for pink-beam X-ray and electron diffraction snapshots. Acta Cryst A. 2020 Mar 1;76(2):121–31.](https://doi.org/10.1107/S2053273319015559)
    """

    def __init__(self, cell, wavelength, bandwidth, float_dtype="float32"):
        """
        Args:
            cell (gemmi.UnitCell): The target cell.
            wavelength (float or array): Peak wavelength in Å. An array of
                per-reflection wavelengths is also accepted. It must broadcast
                against the rlps later passed to ``index_pink``.
            bandwidth (float): The percentage bandwidth used to calculate the
                wavelength range ``wavelength * (1 -/+ bandwidth / 200)``.
            float_dtype (str or dtype, optional): Floating point dtype of the
                reciprocal basis and of the candidate Miller indices.
        """
        self.float_dtype = float_dtype
        self.cell = cell
        self.wav_peak = wavelength
        self.wav_min = wavelength - bandwidth * wavelength / 200.0
        self.wav_max = wavelength + bandwidth * wavelength / 200.0
        # Transposed fractionalization matrix: its columns are the reciprocal
        # basis vectors, so B @ hkl is the Cartesian reciprocal-space vector
        # of a reflection in the crystal-fixed frame.
        self.B = np.asarray(cell.frac.mat, dtype=self.float_dtype).T

    def index_pink(
        self,
        rlps,
        max_refls=50,
        rotogram_grid_points=200,
        voxel_grid_points=200,
        int_dtype="uint8",
        float_dtype="float32",
        min_lattices=1,
        dilate_r=None,
    ):
        """
        Generate candidate orientation matrices for the observed rlps.

        Each observation is paired with every candidate reflection whose
        d-spacing fits the observation's resolution range. The rotations that
        map the candidate direction onto the observation direction form a
        one-parameter family (a rotogram). These rotograms are sampled,
        rescaled and histogrammed on a voxel grid, and the most populated
        voxels are converted back into rotations.

        Args:
            rlps (array): An n x 3 array of floating point rlp vectors. Their
                lengths are read as wavelength / d (1 / d for a 1 Å wavelength);
                their directions are normalized internally.
            max_refls (int, optional): The maximum number of refls to consider.
                When exceeded, the refls with the largest minimum resolution
                (i.e. the lowest resolution refls) are kept.
            rotogram_grid_points (int, optional): The number of points at which to evaluate the rotograms.
            voxel_grid_points (int, optional): The fineness of the discretization used to quantitate rotograms.
            int_dtype (str or dtype, optional): the dtype to use for the voxel grid.
                For integer dtypes, counts above the dtype's maximum saturate
                at that maximum instead of wrapping around.
            float_dtype (str or dtype, optional): the dtype of the dilated voxel
                grid (only used when ``dilate_r`` is given)
            min_lattices (int, optional): the minimum number of candidate lattices returned by this function.
                Ties can yield more. Only voxels with a positive value are
                considered, so a sparse grid can yield fewer.
            dilate_r (float, optional): optionally dilate the voxel grid by a kernel with this radius in pixels.

        Yields:
            np.ndarray: 3 x 3 ``U @ B`` matrices, one per selected voxel.
            Nothing is yielded when there are no rlps or no compatible
            candidate reflections.
        """
        if len(rlps) == 0:
            return

        q = rlps
        q_len = norm2(q, keepdims=True)  # length of q if lambda is 1A
        # Possible resolution range for each observation considering wavelength range
        res_min = self.wav_min / q_len.squeeze(-1)
        res_max = self.wav_max / q_len.squeeze(-1)

        # Truncate the resolution range if there are too many reflections
        if len(res_min) > max_refls:
            dmin = np.sort(res_min)[-max_refls]
            keep = res_min >= dmin
            q = q[keep]
            res_min = res_min[keep]
            res_max = res_max[keep]
        else:
            dmin = res_min.min()

        # Normalized q vector
        qhat = normalize(q)

        # Generate the feasible set of reflections from the current geometry.
        # These are in cartesian reciprocal space coordinates in the
        # crystal-fixed system
        Hall = generate_reciprocal_cell(self.cell, dmin, dtype=self.float_dtype)
        h = (self.B @ Hall.T).T
        hhat = normalize(h)

        # Pair every observation (i) with every candidate (j) whose d-spacing
        # lies inside the observation's resolution range
        dall = self.cell.calculate_d_array(Hall)
        mask = (dall <= res_max[:, None]) & (dall >= res_min[:, None])
        i, j = np.where(mask)
        if i.size == 0:
            # No pairs to vote with; an all-zero grid would make every voxel a peak
            return
        hhat = hhat[j]
        qhat = qhat[i]

        # mhat bisects hhat and qhat
        mhat = normalize(hhat + qhat)

        # A rotation by pi about mhat maps hhat onto qhat
        Rm = rotvec_to_quaternion(np.pi * mhat)

        # Rotations about qhat, sampled over [-pi, pi)
        phi = np.linspace(
            -np.pi, np.pi, rotogram_grid_points + 1, dtype=self.float_dtype
        )[:-1]
        Rq = rotvec_to_quaternion(qhat[:, None, :] * phi[None, :, None])

        # Apply Rm then Rq: each combined rotation still maps hhat onto qhat
        quat = quaternion_multiply(Rq, Rm[:, None, :])
        flat_quat = quat.reshape((-1, 4))
        rotvec = (
            Rotation.from_quat(flat_quat).as_rotvec().reshape(quat.shape[:-1] + (-1,))
        )
        theta = norm2(rotvec)
        # Identity rotations (theta == 0) get a zero axis instead of NaN
        axis = rotvec / np.where(theta == 0.0, np.inf, theta)[..., None]

        # Scaling the general rotvec norm by this factor makes the discretization more uniform
        # without this, the rotations would be biased toward higher angles
        scales = np.arctan(theta / 4.0)
        scaled_rotvec = axis * scales[..., None]

        # Discretize scaled rotvecs
        scale_max = np.arctan(np.pi / 4.0)
        bins = np.linspace(-scale_max, scale_max, voxel_grid_points)
        idx = np.digitize(scaled_rotvec, bins)

        # bin_centers[k] is the center of the cell np.digitize labels k; the two
        # open-ended outer cells use the outermost bin edges
        bin_centers = np.concatenate(
            (bins[[0]], 0.5 * (bins[1:] + bins[:-1]), bins[[-1]])
        )

        # Map discretized rotvecs into the voxel grid. Counting with np.unique
        # lets integer grids saturate at their maximum; np.add.at on a uint8
        # grid would silently wrap around at 256.
        n = voxel_grid_points + 1
        flat_idx = np.ravel_multi_index(tuple(idx.reshape((-1, 3)).T), (n, n, n))
        occupied, counts = np.unique(flat_idx, return_counts=True)
        voxel_dtype = np.dtype(int_dtype)
        if np.issubdtype(voxel_dtype, np.integer):
            counts = np.minimum(counts, np.iinfo(voxel_dtype).max)
        voxels = np.zeros(n**3, dtype=voxel_dtype)
        voxels[occupied] = counts
        voxels = voxels.reshape((n, n, n))

        # Optionally dilate the voxel grid
        if dilate_r is not None:
            # PinkIndexer does a little dilation to help avoid overfitting
            # I'm implementing this using a radially symmetric convolution
            # In the original paper, they use a cubic kernel ones((3, 3, 3))
            from scipy.signal import fftconvolve

            voxels = voxels.astype(float_dtype)
            x = np.arange(voxels.shape[0], dtype=float_dtype)
            x = np.square(x - x.mean())
            kernel = np.exp(
                -(x[:, None, None] + x[None, :, None] + x[None, None, :])
                / np.square(dilate_r)
            )
            voxels = fftconvolve(voxels, kernel, mode="same")

        # Possible solutions are voxels with the highest density. Requiring a
        # positive value stops a cutoff of 0 from selecting every empty voxel.
        cutoff = np.sort(voxels.flatten())[-min_lattices]
        peaks = np.column_stack(np.where((voxels >= cutoff) & (voxels > 0)))
        for peak in peaks:
            # Undo the arctan rescaling to recover the rotation vector
            v = bin_centers[peak]
            length = norm2(v)
            angle = np.tan(length) * 4.0
            if length > 0:
                rotvec = angle * v / length
            else:
                rotvec = np.zeros(3)  # grid origin is the identity rotation
            U = Rotation.from_rotvec(rotvec).as_matrix()
            UB = U @ self.B
            yield UB
class PinkIndexer(Strategy):
    """
    A lattice search strategy using the pinkIndexer algorithm.

    For more info, see:
    [Gevorkov Y, et al. pinkIndexer – a universal indexer for pink-beam X-ray and electron diffraction snapshots. Acta Cryst A. 2020 Mar 1;76(2):121–31.](https://doi.org/10.1107/S2053273319015559)
    """

    phil_help = (
        "A lattice search strategy that matches low resolution spots to candidate "
        "indices based on a known unit cell and space group. It supports mono and "
        "polychromatic beams. "
    )

    phil_scope = iotbx.phil.parse(pink_indexer_phil_str)

    def __init__(
        self, target_symmetry_primitive, max_lattices, params=None, *args, **kwargs
    ):
        """Construct PinkIndexer object.

        Args:
            target_symmetry_primitive (cctbx.crystal.symmetry): The target
                crystal symmetry and unit cell
            max_lattices (int): The maximum number of lattice models to find.
                NOTE(review): stored as ``self._max_lattices`` but not applied
                within this class.
            params (phil,optional): Phil params of the ``pink_indexer`` scope.
                If None, the defaults from ``phil_scope`` are used.

        Raises:
            DialsIndexError: If no target symmetry is given.
            ValueError: If the target symmetry has no unit cell.
        """
        super().__init__(params=None, *args, **kwargs)
        if params is None:
            # Fall back to the defaults instead of failing on params.wavelength
            params = self.phil_scope.extract().pink_indexer
        self._target_symmetry_primitive = target_symmetry_primitive
        self._max_lattices = max_lattices
        if target_symmetry_primitive is None:
            raise DialsIndexError(
                "Target unit cell and space group must be provided for pink_indexer"
            )
        target_cell = target_symmetry_primitive.unit_cell()
        if target_cell is None:
            raise ValueError("Please specify known_symmetry.unit_cell")

        self.cell = target_cell
        self.tarsym = target_symmetry_primitive
        self.spacegroup = target_symmetry_primitive.space_group()
        self.wavelength = params.wavelength
        self.percent_bandwidth = params.percent_bandwidth
        self.max_refls = params.max_refls
        self.rotogram_grid_points = params.rotogram_grid_points
        self.voxel_grid_points = params.voxel_grid_points
        self.min_lattices = params.min_lattices

    def find_crystal_models(self, reflections, experiments):
        """Find a list of candidate crystal models.

        Args:
            reflections (dials.array_family.flex.reflection_table):
                The found spots centroids and associated data. The ``id``
                column is overwritten with -1 (see the workaround below).
            experiments (dxtbx.model.experiment_list.ExperimentList):
                The experimental geometry models

        Returns:
            list: Candidate ``dxtbx.model.Crystal`` models; the list is also
            stored as ``self.candidate_crystal_models``.

        Raises:
            DialsIndexError: If Laue / ToF data lack a ``wavelength`` column
                and no wavelength was configured.
        """
        # This is a workaround for https://github.com/dials/dials/issues/2485
        reflections["id"] *= 0
        reflections["id"] -= 1

        beam = experiments[0].beam
        cell = gemmi.UnitCell(*self.cell.parameters())

        wav, bw = self.wavelength, self.percent_bandwidth
        if wav is None:
            if experiments.all_laue() or experiments.all_tof():
                if "wavelength" not in reflections:
                    raise DialsIndexError(
                        "A reflection wavelength column is required for Laue or ToF indexing"
                    )
                # Per-reflection wavelengths as numpy, so they combine
                # element-wise with the numpy rlps inside Indexer
                wav = reflections["wavelength"].as_numpy_array()
            else:
                wav = beam.get_wavelength()
        if bw is None:
            bw = 1.0
        pidxr = Indexer(cell, wav, bw)

        rlps = flex.vec3_double(len(reflections))
        s1 = reflections["s1"]
        s1_hat = s1 / s1.norms()
        s0_hat = beam.get_unit_s0()

        # Rotate reflections to common coordinate frame
        for i, expt in enumerate(experiments):
            sel_expt = reflections["imageset_id"] == i
            q = s1_hat.select(sel_expt) - s0_hat
            if expt.goniometer is not None:
                setting_rotation = matrix.sqr(expt.goniometer.get_setting_rotation())
                rotation_axis = expt.goniometer.get_rotation_axis_datum()
                sample_rotation = matrix.sqr(expt.goniometer.get_fixed_rotation())
                q = tuple(setting_rotation.inverse()) * q
                if expt.scan is not None and expt.scan.has_property("oscillation"):
                    _, _, z = reflections["xyzobs.mm.value"].select(sel_expt).parts()
                    # rotate_around_origin returns a new array rather than
                    # modifying q in place, so the result must be kept
                    q = q.rotate_around_origin(rotation_axis, -z)
                q = tuple(sample_rotation.inverse()) * q
            rlps.set_selected(sel_expt, q)
        rlps = np.array(rlps, dtype="float32")

        self.candidate_crystal_models = []
        for UB in pidxr.index_pink(
            rlps,
            self.max_refls,
            rotogram_grid_points=self.rotogram_grid_points,
            voxel_grid_points=self.voxel_grid_points,
            min_lattices=self.min_lattices,
        ):
            # The columns of UB are reciprocal basis vectors, so the rows of
            # its inverse are the real-space basis vectors a, b, c
            real_a, real_b, real_c = np.linalg.inv(UB.astype("double"))
            crystal = Crystal(real_a, real_b, real_c, self.spacegroup)
            self.candidate_crystal_models.append(crystal)
        return self.candidate_crystal_models