from __future__ import annotations
import logging
import numpy
import iotbx.phil
from cctbx.sgtbx import space_group
from dxtbx import flumpy
from dxtbx.model import Crystal
from dials.algorithms.indexing import DialsIndexError
from .strategy import Strategy
# Import fast feedback indexer package (CUDA implementation of the TORO algorithm)
# https://github.com/paulscherrerinstitute/fast-feedback-indexer/tree/main/python
try:
import ffbidx
except ModuleNotFoundError:
ffbidx = None
logger = logging.getLogger(__name__)
ffbidx_phil_str = """
ffbidx
.expert_level = 1
{
max_output_cells = 32
.type = int(value_min=1)
.help = "Maximum number of output cells"
max_spots = 300
.type = int(value_min=8)
.help = "Maximum number of reciprocal spots taken into account"
num_candidate_vectors = 32
.type = int(value_min=1)
.help = "Number of candidate cell vectors"
redundant_computations = True
.type = bool
.help = "Calculate candidates for all three cell vectors"
dist1 = 0.3
.type = float(value_min=0.001, value_max=0.5)
.help = "Reciprocal spots within this threshold contribute to the score for vector sampling"
dist3 = 0.15
.type = float(value_min=0.001, value_max=0.8)
.help = "Reciprocal spots within this threshold contribute to the score for cell sampling"
num_halfsphere_points = 32768
.type = int(value_min=8000)
.help = "Number of sampling points on the half sphere"
max_dist = 0.00075
.type = float(value_min=0.0)
.help = "Maximum final distance between measured and calculated reciprocal spot"
min_spots = 8
.type = int(value_min=6)
.help = "Minimum number of reciprocal spots within distance max_dist"
method = *ifssr ifss ifse raw
.type = choice
.help = "Refinement method (consult algorithm description)"
triml = 0.001
.type = float(value_min=0, value_max=0.5)
.help = "lower trimming value for intermediate score calculations"
trimh = 0.3
.type = float(value_min=0, value_max=0.5)
.help = "higher trimming value for intermediate score calculations"
delta = 0.1
.type = float(value_min=0.000001)
.help = "log2 curve position for intermediate score calculations, lower values will me more selective in choosing close spots"
simple_data_filename = None
.type = path
.help = "Optional filename for the output of a simple data file for debugging"
.expert_level = 2
}
"""
def write_simple_data_file(filename, rlp, cell):
"""Write a simple data file for debugging."""
with open(filename, "w") as f:
f.write(" ".join(map(str, cell.ravel())) + "\n")
for r in rlp:
f.write(" ".join(map(str, r.ravel())) + "\n")
[docs]
class FfbIndexer(Strategy):
"""
A lattice search strategy using a Cuda-accelerated implementation of the TORO algorithm.
For more info, see:
[Gasparotto P, et al. TORO Indexer: a PyTorch-based indexing algorithm for kilohertz serial crystallography. J. Appl. Cryst. 2024 57(4)](https://doi.org/10.1107/S1600576724003182)
"""
phil_help = (
"A lattice search strategy for very fast indexing using Cuda acceleration"
)
phil_scope = iotbx.phil.parse(ffbidx_phil_str)
[docs]
def __init__(
self, target_symmetry_primitive, max_lattices, params=None, *args, **kwargs
):
"""Construct FfbIndexer object.
Args:
target_symmetry_primitive (cctbx.crystal.symmetry): The target
crystal symmetry and unit cell
max_lattices (int): The maximum number of lattice models to find
params (phil,optional): Phil params
Returns:
None
"""
super().__init__(params=None, *args, **kwargs)
if ffbidx is None:
raise DialsIndexError(
"ffbidx requires the fast feedback indexer package. See (https://github.com/paulscherrerinstitute/fast-feedback-indexer)"
)
self._target_symmetry_primitive = target_symmetry_primitive
self._max_lattices = max_lattices
if target_symmetry_primitive is None:
raise DialsIndexError("Target unit cell must be provided for ffbidx")
target_cell = target_symmetry_primitive.unit_cell()
if target_cell is None:
raise ValueError("Please specify known_symmetry.unit_cell")
self.params = params
# Need the real space cell as numpy float32 array with all x vector coordinates, followed by y and z coordinates consecutively in memory
self.input_cell = numpy.reshape(
numpy.array(target_cell.orthogonalization_matrix(), dtype="float32"), (3, 3)
)
# Create fast feedback indexer object (on default CUDA device)
try:
self.indexer = ffbidx.Indexer(
max_output_cells=params.max_output_cells,
max_spots=params.max_spots,
num_candidate_vectors=params.num_candidate_vectors,
redundant_computations=params.redundant_computations,
)
except RuntimeError as e:
raise DialsIndexError(
"The ffbidx package is not correctly configured for this system. See (https://github.com/paulscherrerinstitute/fast-feedback-indexer). Error: "
+ str(e)
)
[docs]
def find_crystal_models(self, reflections, experiments):
"""Find a list of candidate crystal models.
Args:
reflections (dials.array_family.flex.reflection_table):
The found spots centroids and associated data
experiments (dxtbx.model.experiment_list.ExperimentList):
The experimental geometry models
Returns:
A list of candidate crystal models.
"""
rlp = numpy.array(flumpy.to_numpy(reflections["rlp"]), dtype="float32")
if self.params.simple_data_filename is not None:
write_simple_data_file(
self.params.simple_data_filename, rlp, self.input_cell
)
# Need the reciprocal lattice points as numpy float32 array with all x coordinates, followed by y and z coordinates consecutively in memory
rlp = rlp.transpose().copy()
output_cells, scores = self.indexer.run(
rlp,
self.input_cell,
dist1=self.params.dist1,
dist3=self.params.dist3,
num_halfsphere_points=self.params.num_halfsphere_points,
max_dist=self.params.max_dist,
min_spots=self.params.min_spots,
n_output_cells=self.params.max_output_cells,
method=self.params.method,
triml=self.params.triml,
trimh=self.params.trimh,
delta=self.params.delta,
)
cell_indices = self.indexer.crystals(
output_cells,
rlp,
scores,
threshold=self.params.max_dist,
min_spots=self.params.min_spots,
method=self.params.method,
)
candidate_crystal_models = []
if cell_indices is None:
return candidate_crystal_models
for index in cell_indices:
j = 3 * index
real_a = output_cells[:, j]
real_b = output_cells[:, j + 1]
real_c = output_cells[:, j + 2]
crystal = Crystal(
real_a.tolist(),
real_b.tolist(),
real_c.tolist(),
space_group=space_group("P1"),
)
candidate_crystal_models.append(crystal)
return candidate_crystal_models