from __future__ import annotations
import logging
import gemmi
import numpy as np
from scipy.spatial.transform import Rotation
import iotbx.phil
from dxtbx.model import Crystal
from dials.algorithms.indexing import DialsIndexError
from .strategy import Strategy
logger = logging.getLogger(__name__)
# PHIL definition of the user-facing parameters of the pink_indexer lattice
# search. The string is parsed with iotbx.phil.parse to build
# PinkIndexer.phil_scope, and the attributes declared here (max_refls,
# wavelength, percent_bandwidth, rotogram_grid_points, voxel_grid_points,
# min_lattices) are the ones PinkIndexer.__init__ reads from its params.
pink_indexer_phil_str = """
pink_indexer
.expert_level = 1
{
max_refls = 50
.type = int(value_min=10)
.help = "Maximum number of reflections to consider indexing"
wavelength = None
.type = float(value_min=0.)
.help = "The peak wavelength"
percent_bandwidth = 1.
.type = float(value_min=0.)
.help = "The percent bandwidth used to calculate the wavelength range for indexing. The wavelength range is defined (wavelength - wavelength*percent_bandwidth/200, wavelength + wavelength*percent_bandwidth/200). This parameter also reflects the uncertainty of the supplied cell constants with larger values appropriate for less certain unit cells."
rotogram_grid_points = 180
.type = int(value_min=10, value_max=1000)
.help = "Number of points at which to evaluate the angle search for each rlp-observation pair"
voxel_grid_points=150
.type = int(value_min=10, value_max=1000)
.help = "Controls the number of voxels onto which the rotograms are discretized"
min_lattices=1
.type = int(value_min=1, value_max=100)
.help = "The minimum number of candidate lattices to generate."
}
"""
def rotvec_to_quaternion(rotvec, deg=False, eps=1e-32):
    """
    Convert rotation vector(s) to quaternion(s).

    Parameters
    ----------
    rotvec : array
        Rotation vectors stored along the last axis (three components are
        read). The direction is the rotation axis and the L2 norm is the
        rotation angle.
    deg : bool (optional)
        If True the norm of `rotvec` is interpreted as degrees, otherwise as
        radians. The default is radians (deg=False).
    eps : float (optional)
        Not used by the implementation; kept so existing callers keep working.

    Returns
    -------
    quat : array
        Quaternions in scalar-last order (x, y, z, w) along the last axis,
        with the same leading dimensions as `rotvec`.
    """
    angle = norm2(rotvec)
    # A zero-length rotation vector has no axis; dividing by inf instead of
    # zero yields a zero axis (and hence the identity quaternion) without NaNs.
    divisor = np.where(angle == 0.0, np.inf, angle)
    unit_axis = rotvec / divisor[..., None]
    half_angle = 0.5 * (np.deg2rad(angle) if deg else angle)
    sin_half = np.sin(half_angle)
    vector_part = [sin_half * unit_axis[..., k] for k in range(3)]
    scalar_part = np.cos(half_angle)
    return np.stack((*vector_part, scalar_part), axis=-1)
def quaternion_multiply(a, b):
    """
    Multiply two quaternions (Hamilton product), return a*b.

    Quaternions are stored in scalar-last order (x, y, z, w) along the last
    axis. The leading (batch) dimensions of `a` and `b` are broadcast against
    each other, and the result has shape ``broadcast_shape + (4,)`` for any
    number of batch dimensions, including none.

    Parameters
    ----------
    a : array
        Left-hand quaternion(s), shape (..., 4).
    b : array
        Right-hand quaternion(s), shape (..., 4).

    Returns
    -------
    quat : array
        The product a*b, shape (..., 4).
    """
    a = np.asarray(a)
    b = np.asarray(b)

    ax, ay, az, aw = (a[..., k] for k in range(4))
    bx, by, bz, bw = (b[..., k] for k in range(4))

    # Vector part: aw*b_vec + bw*a_vec + (a_vec x b_vec)
    x = aw * bx + ax * bw + ay * bz - az * by
    y = aw * by - ax * bz + ay * bw + az * bx
    z = aw * bz + ax * by - ay * bx + az * bw
    # Scalar part: aw*bw - (a_vec . b_vec)
    w = aw * bw - ax * bx - ay * by - az * bz

    # Stack along a new trailing axis. np.dstack only produced the right
    # layout for exactly two batch dimensions: it padded 0-d/1-d batches to
    # 3-d and concatenated (rather than stacked) higher-dimensional batches.
    return np.stack((x, y, z, w), axis=-1)
def norm2(array, axis=-1, keepdims=False):
    """
    Faster version of np.linalg.norm for the L2 norm in lower dimensions.

    The squared components along `axis` are added one slice at a time and the
    square root of the total is returned. With keepdims=False the reduced
    axis is squeezed out of the result; otherwise it is kept with length 1.
    """
    squared = np.square(array)
    n_components = squared.shape[axis]
    total = 0.0
    # Each piece is a length-1 slice along `axis`, so the running total keeps
    # that axis until it is (optionally) squeezed below.
    for component in np.split(squared, n_components, axis=axis):
        total = total + component
    length = np.sqrt(total)
    return length if keepdims else np.squeeze(length, axis=axis)
def normalize(array, axis=-1):
    """
    Scale `array` to unit L2 norm along `axis`.

    The norm is computed with keepdims=True so it broadcasts against `array`.
    Zero-length vectors are not special-cased.
    """
    length = norm2(array, axis=axis, keepdims=True)
    return array / length
def angle_between(vec1, vec2, deg=True):
    """
    Compute the angle between vectors stored along the last axis of the inputs.

    Both inputs are normalized first, and the angle is then obtained as
    ``2 * arctan2(|v1 - v2|, |v1 + v2|)``. This is the numerically stable
    arctan2 formulation described in this post:
        - https://scicomp.stackexchange.com/a/27769/39858

    Parameters
    ----------
    vec1 : array
        An arbitrarily batched array of vectors
    vec2 : array
        An arbitrarily batched array of vectors
    deg : bool (optional)
        Whether angles are returned in degrees or radians. The default is degrees (deg=True).

    Returns
    -------
    angles : array
        An array of angles with the same leading dimensions as vec1 and vec2.
    """
    unit1 = normalize(vec1, axis=-1)
    unit2 = normalize(vec2, axis=-1)
    difference_length = norm2(unit1 - unit2, axis=-1)
    sum_length = norm2(unit1 + unit2, axis=-1)
    angle = 2.0 * np.arctan2(difference_length, sum_length)
    return np.rad2deg(angle) if deg else angle
def generate_reciprocal_cell(cell, dmin, dtype=np.int32):
    """
    Generate the miller indices of the full P1 reciprocal cell.

    Parameters
    ----------
    cell : gemmi.UnitCell
        Unit cell object. It must provide `get_hkl_limits(dmin)` and
        `calculate_d_array(hkl)`; both are called on it here.
    dmin : float
        Maximum resolution of the data in Å
    dtype : np.dtype (optional)
        The data type of the returned array. The default is np.int32.

    Returns
    -------
    hkl : np.array
        An n x 3 array of miller indices of type `dtype`, excluding (0, 0, 0)
        and every index whose d-spacing is below `dmin`.
    """
    hmax, kmax, lmax = cell.get_hkl_limits(dmin)
    # One axis per index, running from -limit to limit + 1 in unit steps
    axes = [
        np.linspace(-limit, limit + 1, 2 * limit + 2, dtype=dtype)
        for limit in (hmax, kmax, lmax)
    ]
    grid = np.meshgrid(*axes)
    candidates = np.stack(grid).reshape((3, -1)).T

    # Drop the origin (0, 0, 0)
    is_nonzero = np.any(candidates != 0, axis=1)
    candidates = candidates[is_nonzero]

    # Keep only indices inside the resolution limit
    spacings = cell.calculate_d_array(candidates).astype("float32")
    return candidates[spacings >= dmin]
class Indexer:
    """
    A class which implements the PinkIndexer algorithm.
    [Gevorkov Y, et al. pinkIndexer – a universal indexer for pink-beam X-ray and electron diffraction snapshots. Acta Cryst A. 2020 Mar 1;76(2):121–31.](https://doi.org/10.1107/S2053273319015559)
    """

    def __init__(self, cell, wavelength, bandwidth, float_dtype="float32"):
        """
        Args:
            cell (gemmi.UnitCell): The target cell.
            wavelength (float): Peak wavelength in Å.
            bandwidth (float): The percentage bandwidth used to calculate a wavelength range
            float_dtype (str or dtype, optional): The dtype used for the B matrix and for the
                normalized s0 / s1 vectors in index_pink.
        """
        self.float_dtype = float_dtype
        self.cell = cell
        self.wav_peak = wavelength
        # The wavelength range is the peak wavelength +/- half the percent bandwidth
        self.wav_min = wavelength - bandwidth * wavelength / 200.0
        self.wav_max = wavelength + bandwidth * wavelength / 200.0
        # Transposed fractionalization matrix; index_pink uses it to map
        # miller indices to cartesian reciprocal space coordinates (B @ hkl).
        self.B = np.asarray(cell.frac.mat, dtype=self.float_dtype).T

    def index_pink(
        self,
        s0,
        s1,
        max_refls=50,
        rotogram_grid_points=200,
        voxel_grid_points=200,
        int_dtype="uint8",
        float_dtype="float32",
        min_lattices=1,
        dilate_r=None,
    ):
        """
        Generate candidate crystal bases (UB matrices) for a set of observed spots.

        Args:
            s0 (array): An array of 3 floating point numbers corresponding to the s0 vector. This will be normalized.
            s1 (array): An n x 3 array floating point numbers corresponding to the s1 vectors. This will be normalized.
            max_refls (int, optional): The maximum number of refls to consider. When more are supplied, the
                ones with the largest d-spacings are kept (ties at the cutoff are all kept).
            rotogram_grid_points (int, optional): The number of points at which to evaluate the rotograms.
            voxel_grid_points (int, optional): The fineness of the discretization used to quantitate rotograms.
            int_dtype (str or dtype, optional): the dtype to use for the voxel grid
            float_dtype (str or dtype, optional): the dtype to use for floating point values. It is only
                used for the optional dilation step; other arrays use the dtype given to the constructor.
            min_lattices (int, optional): the minimum number of candidate lattices returned by this function
            dilate_r (float, optional): optionally dilate the voxel grid by a kernel with this radius in pixels.

        Yields:
            UB: a crystal basis as a 3 x 3 numpy array. More than min_lattices bases are
            produced when several voxels tie at the cutoff density. Nothing is yielded
            when there are no observations or when no candidate reciprocal lattice point
            is compatible with any observation.
        """
        s0_hat = normalize(np.asarray(s0, dtype=self.float_dtype))
        s1_hat = normalize(np.asarray(s1, dtype=self.float_dtype))
        q = s1_hat - s0_hat[None, :]
        q_len = norm2(q)  # length of q if lambda is 1A

        # Possible resolution range for each observation considering wavelength range
        res_min = self.wav_min / q_len
        res_max = self.wav_max / q_len

        if len(res_min) == 0:
            # No observations: there is nothing to index
            return

        # Truncate the resolution range if there are too many reflections
        if len(res_min) > max_refls:
            dmin = np.sort(res_min)[-max_refls]
            in_range = res_min >= dmin
            q = q[in_range]
            res_min = res_min[in_range]
            res_max = res_max[in_range]
        else:
            dmin = res_min.min()

        # Normalized q vector
        qhat = normalize(q)

        # Generate the feasible set of reflections from the current geometry
        # These are in cartesian reciprocal space coordinates in the
        # crystal-fixed system
        Hall = generate_reciprocal_cell(self.cell, dmin, dtype=self.float_dtype)
        h = (self.B @ Hall.T).T
        hhat = normalize(h)

        # Remove candidate rlps if incompatible with resolution range of observation
        dall = self.cell.calculate_d_array(Hall)
        mask = (dall <= res_max[:, None]) & (dall >= res_min[:, None])
        obs_idx, rlp_idx = np.where(mask)
        if obs_idx.size == 0:
            # No (observation, rlp) pair is compatible, so there are no
            # rotograms to accumulate and no meaningful peak to report.
            return
        hhat = hhat[rlp_idx]
        qhat = qhat[obs_idx]

        # mhat bisects hhat and qhat
        m = hhat + qhat
        mhat = normalize(m)

        # construct a rotation that maps hhat onto qhat
        # (a half turn about the bisector)
        Rm = rotvec_to_quaternion(np.pi * mhat)

        # construct rotations about qhat
        phimin = -np.pi
        phimax = np.pi
        phi = np.linspace(
            phimin, phimax, rotogram_grid_points + 1, dtype=self.float_dtype
        )[:-1]  # drop the endpoint; +pi duplicates -pi
        rotvec = qhat[:, None, :] * phi[None, :, None]
        Rq = rotvec_to_quaternion(rotvec)

        # combine mhat and qhat rotations into general rotation
        quat = quaternion_multiply(Rq, Rm[:, None, :])
        flat_quat = quat.reshape((-1, 4))
        rotvec = (
            Rotation.from_quat(flat_quat).as_rotvec().reshape(quat.shape[:-1] + (3,))
        )
        theta = norm2(rotvec)
        axis = rotvec / theta[..., None]

        # Scaling the general rotvec norm by this factor makes the discretization more uniform
        # without this, the rotations would be biased toward higher angles
        scales = np.arctan(theta / 4.0)
        scaled_rotvec = axis * scales[..., None]

        # Discretize scaled rotvecs
        scale_max = np.arctan(np.pi / 4.0)
        bins = np.linspace(-scale_max, scale_max, voxel_grid_points)
        idx = np.digitize(scaled_rotvec, bins)

        # This is how to calculate the bin centers to extract the rotation matrix from the voxel grid
        bin_centers = np.concatenate(
            (bins[[0]], 0.5 * (bins[1:] + bins[:-1]), bins[[-1]])
        )

        # Map discretized rotvecs into voxel grid
        # NOTE: counts wrap around if a voxel collects more hits than int_dtype
        # can represent (255 for the default uint8); pass a wider int_dtype if
        # that is a concern.
        n = voxel_grid_points + 1
        voxels = np.zeros((n, n, n), dtype=int_dtype)
        np.add.at(voxels, tuple(idx.transpose((2, 0, 1))), 1)

        # Optionally dilate the voxel grid
        if dilate_r is not None:
            # PinkIndexer does a little dilation to help avoid overfitting
            # I'm implementing this using a radially symmetric convolution
            # In the original paper, they use a cubic kernel ones((3, 3, 3))
            from scipy.signal import fftconvolve

            voxels = voxels.astype(float_dtype)
            x = np.arange(voxels.shape[0], dtype=float_dtype)
            x = np.square(x - x.mean())
            kernel = np.exp(
                -(x[:, None, None] + x[None, :, None] + x[None, None, :])
                / np.square(dilate_r)
            )
            voxels = fftconvolve(voxels, kernel, mode="same")

        # Possible solutions are voxels with the highest density.
        # Only the min_lattices-th largest value is needed, so a partial sort
        # is enough; there is no need to fully sort the whole grid.
        cutoff = np.partition(voxels.ravel(), -min_lattices)[-min_lattices]
        peaks = np.column_stack(np.where(voxels >= cutoff))

        for peak in peaks:
            v = bin_centers[peak]
            length = norm2(v)
            if length == 0.0:
                # The central voxel is the identity rotation; dividing by a
                # zero length would otherwise produce a NaN basis.
                rotvec = np.zeros_like(v)
            else:
                # Undo the arctan(theta / 4) scaling applied before binning
                theta = np.tan(length) * 4.0
                rotvec = theta * v / length
            U = Rotation.from_rotvec(rotvec).as_matrix()
            UB = U @ self.B
            yield UB
# (A stray Sphinx "[docs]" source-link marker was removed here; it is not part of the module.)
class PinkIndexer(Strategy):
    """
    A lattice search strategy using the pinkIndexer algorithm.
    For more info, see:
    [Gevorkov Y, et al. pinkIndexer – a universal indexer for pink-beam X-ray and electron diffraction snapshots. Acta Cryst A. 2020 Mar 1;76(2):121–31.](https://doi.org/10.1107/S2053273319015559)
    """

    phil_help = (
        "A lattice search strategy that matches low resolution spots to candidate "
        "indices based on a known unit cell and space group. It supports mono and "
        "polychromatic beams. "
    )

    phil_scope = iotbx.phil.parse(pink_indexer_phil_str)

    def __init__(
        self, target_symmetry_primitive, max_lattices, params=None, *args, **kwargs
    ):
        """Construct PinkIndexer object.

        Args:
            target_symmetry_primitive (cctbx.crystal.symmetry): The target
                crystal symmetry and unit cell
            max_lattices (int): The maximum number of lattice models to find
            params (phil,optional): Phil params (the extracted ``pink_indexer``
                scope). When omitted, the defaults declared in ``phil_scope``
                are used.

        Raises:
            DialsIndexError: if no target symmetry is provided.
            ValueError: if the target symmetry has no unit cell.
        """
        # NOTE(review): the base class is deliberately still given params=None,
        # as before; confirm against Strategy.__init__ whether it should
        # receive `params` instead.
        super().__init__(params=None, *args, **kwargs)
        self._target_symmetry_primitive = target_symmetry_primitive
        self._max_lattices = max_lattices

        if target_symmetry_primitive is None:
            raise DialsIndexError(
                "Target unit cell and space group must be provided for pink_indexer"
            )

        target_cell = target_symmetry_primitive.unit_cell()
        if target_cell is None:
            raise ValueError("Please specify known_symmetry.unit_cell")

        if params is None:
            # The signature advertises params=None, so fall back to the PHIL
            # defaults instead of failing on the attribute reads below.
            params = self.phil_scope.extract().pink_indexer

        self.cell = target_cell
        self.tarsym = target_symmetry_primitive
        self.spacegroup = target_symmetry_primitive.space_group()
        self.wavelength = params.wavelength
        self.percent_bandwidth = params.percent_bandwidth
        self.max_refls = params.max_refls
        self.rotogram_grid_points = params.rotogram_grid_points
        self.voxel_grid_points = params.voxel_grid_points
        self.min_lattices = params.min_lattices

    def find_crystal_models(self, reflections, experiments):
        """Find a list of candidate crystal models.

        Args:
            reflections (dials.array_family.flex.reflection_table):
                The found spots centroids and associated data
            experiments (dxtbx.model.experiment_list.ExperimentList):
                The experimental geometry models

        Returns:
            list: the candidate crystal models, also stored on
            ``self.candidate_crystal_models``.

        Raises:
            ValueError: if `experiments` does not contain exactly one experiment.
        """
        # This is a workaround for https://github.com/dials/dials/issues/2485
        # (every reflection id is reset to -1)
        reflections["id"] *= 0
        reflections["id"] -= 1

        n_experiments = len(experiments)
        if n_experiments < 1:
            raise ValueError(
                "pink_indexer received an empty experimentlist. Exactly one experiment is required"
            )
        if n_experiments > 1:
            # Only the first experiment would be used below, so joint indexing
            # of several experiments must be rejected rather than silently
            # ignoring the rest.
            msg = "pink_indexer received an experimentlist with length > 1. To use the pink_indexer method, you must set joint_indexing=False when you call dials.index"
            raise ValueError(msg)

        expt = experiments[0]
        cell = gemmi.UnitCell(*self.cell.parameters())
        beam = expt.beam
        s0 = beam.get_s0()

        # Fall back to the beam wavelength and a 1% bandwidth when unset
        wav, bw = self.wavelength, self.percent_bandwidth
        if wav is None:
            wav = beam.get_wavelength()
        if bw is None:
            bw = 1.0

        pidxr = Indexer(cell, wav, bw)
        s1 = np.array(reflections["s1"], dtype="float32")

        self.candidate_crystal_models = []
        for UB in pidxr.index_pink(
            s0,
            s1,
            self.max_refls,
            rotogram_grid_points=self.rotogram_grid_points,
            voxel_grid_points=self.voxel_grid_points,
            min_lattices=self.min_lattices,
        ):
            # The rows of the inverse of UB are passed as the real space basis vectors
            real_a, real_b, real_c = np.linalg.inv(UB.astype("double"))
            crystal = Crystal(real_a, real_b, real_c, self.spacegroup)
            self.candidate_crystal_models.append(crystal)

        return self.candidate_crystal_models