# coding: utf-8
# Copyright (c) Pymatgen Development Team.
# Distributed under the terms of the MIT License.
"""
This module contains classes to wrap Python VTK to make nice molecular plots.
"""
import itertools
import math
import os
import subprocess
import time
import numpy as np
try:
import vtk
from vtk import vtkInteractorStyleTrackballCamera
except ImportError:
# VTK not present. The Camera is to set object to avoid errors in unittest.
vtk = None
vtkInteractorStyleTrackballCamera = object
from monty.dev import requires
from monty.serialization import loadfn
from pymatgen.core.periodic_table import Species
from pymatgen.core.sites import PeriodicSite
from pymatgen.core.structure import Structure
from pymatgen.util.coord import in_coord_list
module_dir = os.path.dirname(os.path.abspath(__file__))
EL_COLORS = loadfn(os.path.join(module_dir, "ElementColorSchemes.yaml"))
[docs]class StructureVis:
"""
Provides Structure object visualization using VTK.
"""
@requires(vtk, "Visualization requires the installation of VTK with " "Python bindings.")
def __init__(
self,
element_color_mapping=None,
show_unit_cell=True,
show_bonds=False,
show_polyhedron=True,
poly_radii_tol_factor=0.5,
excluded_bonding_elements=None,
):
"""
Constructs a Structure Visualization.
Args:
element_color_mapping: Optional color mapping for the elements,
as a dict of {symbol: rgb tuple}. For example, {"Fe": (255,
123,0), ....} If None is specified, a default based on
Jmol"s color scheme is used.
show_unit_cell: Set to False to not show the unit cell
boundaries. Defaults to True.
show_bonds: Set to True to show bonds. Defaults to True.
show_polyhedron: Set to True to show polyhedrons. Defaults to
False.
poly_radii_tol_factor: The polyhedron and bonding code uses the
ionic radii of the elements or species to determine if two
atoms are bonded. This specifies a tolerance scaling factor
such that atoms which are (1 + poly_radii_tol_factor) * sum
of ionic radii apart are still considered as bonded.
excluded_bonding_elements: List of atom types to exclude from
bonding determination. Defaults to an empty list. Useful
when trying to visualize a certain atom type in the
framework (e.g., Li in a Li-ion battery cathode material).
Useful keyboard shortcuts implemented.
h : Show help
A/a : Increase/decrease cell by one unit vector in a-direction
B/b : Increase/decrease cell by one unit vector in b-direction
C/c : Increase/decrease cell by one unit vector in c-direction
# : Toggle showing of polyhedrons
- : Toggle showing of bonds
[ : Decrease poly_radii_tol_factor by 0.05
] : Increase poly_radii_tol_factor by 0.05
r : Reset camera direction
o : Orthogonalize structure
Up/Down : Rotate view along Up direction by 90 clock/anticlockwise
Left/right : Rotate view along camera direction by 90
clock/anticlockwise
"""
# create a rendering window and renderer
self.ren = vtk.vtkRenderer()
self.ren_win = vtk.vtkRenderWindow()
self.ren_win.AddRenderer(self.ren)
self.ren.SetBackground(1, 1, 1)
self.title = "Structure Visualizer"
# create a renderwindowinteractor
self.iren = vtk.vtkRenderWindowInteractor()
self.iren.SetRenderWindow(self.ren_win)
self.mapper_map = {}
self.structure = None
if element_color_mapping:
self.el_color_mapping = element_color_mapping
else:
self.el_color_mapping = EL_COLORS["VESTA"]
self.show_unit_cell = show_unit_cell
self.show_bonds = show_bonds
self.show_polyhedron = show_polyhedron
self.poly_radii_tol_factor = poly_radii_tol_factor
self.excluded_bonding_elements = excluded_bonding_elements if excluded_bonding_elements else []
self.show_help = True
self.supercell = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
self.redraw()
style = StructureInteractorStyle(self)
self.iren.SetInteractorStyle(style)
self.ren.parent = self
[docs] def rotate_view(self, axis_ind=0, angle=0):
"""
Rotate the camera view.
Args:
axis_ind: Index of axis to rotate. Defaults to 0, i.e., a-axis.
angle: Angle to rotate by. Defaults to 0.
"""
camera = self.ren.GetActiveCamera()
if axis_ind == 0:
camera.Roll(angle)
elif axis_ind == 1:
camera.Azimuth(angle)
else:
camera.Pitch(angle)
self.ren_win.Render()
[docs] def write_image(self, filename="image.png", magnification=1, image_format="png"):
"""
Save render window to an image.
Arguments:
filename:
filename to save to. Defaults to image.png.
magnification:
magnification. Use it to render high res images.
image_format:
choose between jpeg, png. Png is the default.
"""
render_large = vtk.vtkRenderLargeImage()
render_large.SetInput(self.ren)
if image_format == "jpeg":
writer = vtk.vtkJPEGWriter()
writer.SetQuality(80)
else:
writer = vtk.vtkPNGWriter()
render_large.SetMagnification(magnification)
writer.SetFileName(filename)
writer.SetInputConnection(render_large.GetOutputPort())
self.ren_win.Render()
writer.Write()
del render_large
[docs] def redraw(self, reset_camera=False):
"""
Redraw the render window.
Args:
reset_camera: Set to True to reset the camera to a
pre-determined default for each structure. Defaults to False.
"""
self.ren.RemoveAllViewProps()
self.picker = None
self.add_picker_fixed()
self.helptxt_mapper = vtk.vtkTextMapper()
tprops = self.helptxt_mapper.GetTextProperty()
tprops.SetFontSize(14)
tprops.SetFontFamilyToTimes()
tprops.SetColor(0, 0, 0)
if self.structure is not None:
self.set_structure(self.structure, reset_camera)
self.ren_win.Render()
[docs] def orthongonalize_structure(self):
"""
Orthogonalize the structure.
"""
if self.structure is not None:
self.set_structure(self.structure.copy(sanitize=True))
self.ren_win.Render()
[docs] def display_help(self):
"""
Display the help for various keyboard shortcuts.
"""
helptxt = [
"h : Toggle help",
"A/a, B/b or C/c : Increase/decrease cell by one a," " b or c unit vector",
"# : Toggle showing of polyhedrons",
"-: Toggle showing of bonds",
"r : Reset camera direction",
"[/]: Decrease or increase poly_radii_tol_factor " "by 0.05. Value = " + str(self.poly_radii_tol_factor),
"Up/Down: Rotate view along Up direction by 90 " "clockwise/anticlockwise",
"Left/right: Rotate view along camera direction by " "90 clockwise/anticlockwise",
"s: Save view to image.png",
"o: Orthogonalize structure",
]
self.helptxt_mapper.SetInput("\n".join(helptxt))
self.helptxt_actor.SetPosition(10, 10)
self.helptxt_actor.VisibilityOn()
[docs] def set_structure(self, structure, reset_camera=True, to_unit_cell=True):
"""
Add a structure to the visualizer.
Args:
structure: structure to visualize
reset_camera: Set to True to reset the camera to a default
determined based on the structure.
to_unit_cell: Whether or not to fall back sites into the unit cell.
"""
self.ren.RemoveAllViewProps()
has_lattice = hasattr(structure, "lattice")
if has_lattice:
s = Structure.from_sites(structure, to_unit_cell=to_unit_cell)
s.make_supercell(self.supercell, to_unit_cell=to_unit_cell)
else:
s = structure
inc_coords = []
for site in s:
self.add_site(site)
inc_coords.append(site.coords)
count = 0
labels = ["a", "b", "c"]
colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]
if has_lattice:
matrix = s.lattice.matrix
if self.show_unit_cell and has_lattice:
# matrix = s.lattice.matrix
self.add_text([0, 0, 0], "o")
for vec in matrix:
self.add_line((0, 0, 0), vec, colors[count])
self.add_text(vec, labels[count], colors[count])
count += 1
for (vec1, vec2) in itertools.permutations(matrix, 2):
self.add_line(vec1, vec1 + vec2)
for (vec1, vec2, vec3) in itertools.permutations(matrix, 3):
self.add_line(vec1 + vec2, vec1 + vec2 + vec3)
if self.show_bonds or self.show_polyhedron:
elements = sorted(s.composition.elements, key=lambda a: a.X)
anion = elements[-1]
def contains_anion(site):
for sp in site.species.keys():
if sp.symbol == anion.symbol:
return True
return False
anion_radius = anion.average_ionic_radius
for site in s:
exclude = False
max_radius = 0
color = np.array([0, 0, 0])
for sp, occu in site.species.items():
if sp.symbol in self.excluded_bonding_elements or sp == anion:
exclude = True
break
max_radius = max(max_radius, sp.average_ionic_radius)
color = color + occu * np.array(self.el_color_mapping.get(sp.symbol, [0, 0, 0]))
if not exclude:
max_radius = (1 + self.poly_radii_tol_factor) * (max_radius + anion_radius)
nn = structure.get_neighbors(site, float(max_radius))
nn_sites = []
for neighbor in nn:
if contains_anion(neighbor):
nn_sites.append(neighbor)
if not in_coord_list(inc_coords, neighbor.coords):
self.add_site(neighbor)
if self.show_bonds:
self.add_bonds(nn_sites, site)
if self.show_polyhedron:
color = [i / 255 for i in color]
self.add_polyhedron(nn_sites, site, color)
if self.show_help:
self.helptxt_actor = vtk.vtkActor2D()
self.helptxt_actor.VisibilityOn()
self.helptxt_actor.SetMapper(self.helptxt_mapper)
self.ren.AddActor(self.helptxt_actor)
self.display_help()
camera = self.ren.GetActiveCamera()
if reset_camera:
if has_lattice:
# Adjust the camera for best viewing
lengths = s.lattice.abc
pos = (matrix[1] + matrix[2]) * 0.5 + matrix[0] * max(lengths) / lengths[0] * 3.5
camera.SetPosition(pos)
camera.SetViewUp(matrix[2])
camera.SetFocalPoint((matrix[0] + matrix[1] + matrix[2]) * 0.5)
else:
origin = s.center_of_mass
max_site = max(s, key=lambda site: site.distance_from_point(origin))
camera.SetPosition(origin + 5 * (max_site.coords - origin))
camera.SetFocalPoint(s.center_of_mass)
self.structure = structure
self.title = s.composition.formula
[docs] def zoom(self, factor):
"""
Zoom the camera view by a factor.
"""
camera = self.ren.GetActiveCamera()
camera.Zoom(factor)
self.ren_win.Render()
[docs] def show(self):
"""
Display the visualizer.
"""
self.iren.Initialize()
self.ren_win.SetSize(800, 800)
self.ren_win.SetWindowName(self.title)
self.ren_win.Render()
self.iren.Start()
[docs] def add_site(self, site):
"""
Add a site to the render window. The site is displayed as a sphere, the
color of which is determined based on the element. Partially occupied
sites are displayed as a single element color, though the site info
still shows the partial occupancy.
Args:
site: Site to add.
"""
start_angle = 0
radius = 0
total_occu = 0
for specie, occu in site.species.items():
radius += occu * (
specie.ionic_radius
if isinstance(specie, Species) and specie.ionic_radius
else specie.average_ionic_radius
)
total_occu += occu
vis_radius = 0.2 + 0.002 * radius
for specie, occu in site.species.items():
if not specie:
color = (1, 1, 1)
elif specie.symbol in self.el_color_mapping:
color = [i / 255 for i in self.el_color_mapping[specie.symbol]]
mapper = self.add_partial_sphere(site.coords, vis_radius, color, start_angle, start_angle + 360 * occu)
self.mapper_map[mapper] = [site]
start_angle += 360 * occu
if total_occu < 1:
mapper = self.add_partial_sphere(
site.coords,
vis_radius,
(1, 1, 1),
start_angle,
start_angle + 360 * (1 - total_occu),
)
self.mapper_map[mapper] = [site]
[docs] def add_partial_sphere(self, coords, radius, color, start=0, end=360, opacity=1.0):
"""
Adding a partial sphere (to display partial occupancies.
Args:
coords (nd.array): Coordinates
radius (float): Radius of sphere
color (): Color of sphere.
start (float): Starting angle.
end (float): Ending angle.
opacity (float): Opacity.
"""
sphere = vtk.vtkSphereSource()
sphere.SetCenter(coords)
sphere.SetRadius(radius)
sphere.SetThetaResolution(18)
sphere.SetPhiResolution(18)
sphere.SetStartTheta(start)
sphere.SetEndTheta(end)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(sphere.GetOutputPort())
actor = vtk.vtkActor()
actor.SetMapper(mapper)
actor.GetProperty().SetColor(color)
actor.GetProperty().SetOpacity(opacity)
self.ren.AddActor(actor)
return mapper
[docs] def add_text(self, coords, text, color=(0, 0, 0)):
"""
Add text at a coordinate.
Args:
coords: Coordinates to add text at.
text: Text to place.
color: Color for text as RGB. Defaults to black.
"""
source = vtk.vtkVectorText()
source.SetText(text)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(source.GetOutputPort())
follower = vtk.vtkFollower()
follower.SetMapper(mapper)
follower.GetProperty().SetColor(color)
follower.SetPosition(coords)
follower.SetScale(0.5)
self.ren.AddActor(follower)
follower.SetCamera(self.ren.GetActiveCamera())
[docs] def add_line(self, start, end, color=(0.5, 0.5, 0.5), width=1):
"""
Adds a line.
Args:
start: Starting coordinates for line.
end: Ending coordinates for line.
color: Color for text as RGB. Defaults to grey.
width: Width of line. Defaults to 1.
"""
source = vtk.vtkLineSource()
source.SetPoint1(start)
source.SetPoint2(end)
vertexIDs = vtk.vtkStringArray()
vertexIDs.SetNumberOfComponents(1)
vertexIDs.SetName("VertexIDs")
# Set the vertex labels
vertexIDs.InsertNextValue("a")
vertexIDs.InsertNextValue("b")
source.GetOutput().GetPointData().AddArray(vertexIDs)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(source.GetOutputPort())
actor = vtk.vtkActor()
actor.SetMapper(mapper)
actor.GetProperty().SetColor(color)
actor.GetProperty().SetLineWidth(width)
self.ren.AddActor(actor)
[docs] def add_polyhedron(
self,
neighbors,
center,
color,
opacity=1.0,
draw_edges=False,
edges_color=[0.0, 0.0, 0.0],
edges_linewidth=2,
):
"""
Adds a polyhedron.
Args:
neighbors: Neighbors of the polyhedron (the vertices).
center: The atom in the center of the polyhedron.
color: Color for text as RGB.
opacity: Opacity of the polyhedron
draw_edges: If set to True, the a line will be drawn at each edge
edges_color: Color of the line for the edges
edges_linewidth: Width of the line drawn for the edges
"""
points = vtk.vtkPoints()
conv = vtk.vtkConvexPointSet()
for i, n in enumerate(neighbors):
x, y, z = n.coords
points.InsertPoint(i, x, y, z)
conv.GetPointIds().InsertId(i, i)
grid = vtk.vtkUnstructuredGrid()
grid.Allocate(1, 1)
grid.InsertNextCell(conv.GetCellType(), conv.GetPointIds())
grid.SetPoints(points)
dsm = vtk.vtkDataSetMapper()
polysites = [center]
polysites.extend(neighbors)
self.mapper_map[dsm] = polysites
if vtk.VTK_MAJOR_VERSION <= 5:
dsm.SetInputConnection(grid.GetProducerPort())
else:
dsm.SetInputData(grid)
ac = vtk.vtkActor()
# ac.SetMapper(mapHull)
ac.SetMapper(dsm)
ac.GetProperty().SetOpacity(opacity)
if color == "element":
# If partial occupations are involved, the color of the specie with
# the highest occupation is used
myoccu = 0.0
for specie, occu in center.species.items():
if occu > myoccu:
myspecie = specie
myoccu = occu
color = [i / 255 for i in self.el_color_mapping[myspecie.symbol]]
ac.GetProperty().SetColor(color)
else:
ac.GetProperty().SetColor(color)
if draw_edges:
ac.GetProperty().SetEdgeColor(edges_color)
ac.GetProperty().SetLineWidth(edges_linewidth)
ac.GetProperty().EdgeVisibilityOn()
self.ren.AddActor(ac)
[docs] def add_triangle(
self,
neighbors,
color,
center=None,
opacity=0.4,
draw_edges=False,
edges_color=[0.0, 0.0, 0.0],
edges_linewidth=2,
):
"""
Adds a triangular surface between three atoms.
Args:
atoms: Atoms between which a triangle will be drawn.
color: Color for triangle as RGB.
center: The "central atom" of the triangle
opacity: opacity of the triangle
draw_edges: If set to True, the a line will be drawn at each edge
edges_color: Color of the line for the edges
edges_linewidth: Width of the line drawn for the edges
"""
points = vtk.vtkPoints()
triangle = vtk.vtkTriangle()
for ii in range(3):
points.InsertNextPoint(neighbors[ii].x, neighbors[ii].y, neighbors[ii].z)
triangle.GetPointIds().SetId(ii, ii)
triangles = vtk.vtkCellArray()
triangles.InsertNextCell(triangle)
# polydata object
trianglePolyData = vtk.vtkPolyData()
trianglePolyData.SetPoints(points)
trianglePolyData.SetPolys(triangles)
# mapper
mapper = vtk.vtkPolyDataMapper()
mapper.SetInput(trianglePolyData)
ac = vtk.vtkActor()
ac.SetMapper(mapper)
ac.GetProperty().SetOpacity(opacity)
if color == "element":
if center is None:
raise ValueError(
"Color should be chosen according to the central atom, " "and central atom is not provided"
)
# If partial occupations are involved, the color of the specie with
# the highest occupation is used
myoccu = 0.0
for specie, occu in center.species.items():
if occu > myoccu:
myspecie = specie
myoccu = occu
color = [i / 255 for i in self.el_color_mapping[myspecie.symbol]]
ac.GetProperty().SetColor(color)
else:
ac.GetProperty().SetColor(color)
if draw_edges:
ac.GetProperty().SetEdgeColor(edges_color)
ac.GetProperty().SetLineWidth(edges_linewidth)
ac.GetProperty().EdgeVisibilityOn()
self.ren.AddActor(ac)
[docs] def add_faces(self, faces, color, opacity=0.35):
"""
Adding face of polygon.
Args:
faces (): Coordinates of the faces.
color (): Color.
opacity (float): Opacity
"""
for face in faces:
if len(face) == 3:
points = vtk.vtkPoints()
triangle = vtk.vtkTriangle()
for ii in range(3):
points.InsertNextPoint(face[ii][0], face[ii][1], face[ii][2])
triangle.GetPointIds().SetId(ii, ii)
triangles = vtk.vtkCellArray()
triangles.InsertNextCell(triangle)
trianglePolyData = vtk.vtkPolyData()
trianglePolyData.SetPoints(points)
trianglePolyData.SetPolys(triangles)
mapper = vtk.vtkPolyDataMapper()
if vtk.VTK_MAJOR_VERSION <= 5:
mapper.SetInputConnection(trianglePolyData.GetProducerPort())
else:
mapper.SetInputData(trianglePolyData)
# mapper.SetInput(trianglePolyData)
ac = vtk.vtkActor()
ac.SetMapper(mapper)
ac.GetProperty().SetOpacity(opacity)
ac.GetProperty().SetColor(color)
self.ren.AddActor(ac)
elif len(face) > 3:
center = np.zeros(3, np.float_)
for site in face:
center += site
center /= np.float_(len(face))
for ii, f in enumerate(face):
points = vtk.vtkPoints()
triangle = vtk.vtkTriangle()
points.InsertNextPoint(f[0], f[1], f[2])
ii2 = np.mod(ii + 1, len(face))
points.InsertNextPoint(face[ii2][0], face[ii2][1], face[ii2][2])
points.InsertNextPoint(center[0], center[1], center[2])
for ii in range(3):
triangle.GetPointIds().SetId(ii, ii)
triangles = vtk.vtkCellArray()
triangles.InsertNextCell(triangle)
trianglePolyData = vtk.vtkPolyData()
trianglePolyData.SetPoints(points)
trianglePolyData.SetPolys(triangles)
mapper = vtk.vtkPolyDataMapper()
if vtk.VTK_MAJOR_VERSION <= 5:
mapper.SetInputConnection(trianglePolyData.GetProducerPort())
else:
mapper.SetInputData(trianglePolyData)
# mapper.SetInput(trianglePolyData)
ac = vtk.vtkActor()
ac.SetMapper(mapper)
ac.GetProperty().SetOpacity(opacity)
ac.GetProperty().SetColor(color)
self.ren.AddActor(ac)
else:
raise ValueError("Number of points for a face should be >= 3")
[docs] def add_edges(self, edges, type="line", linewidth=2, color=[0.0, 0.0, 0.0]):
"""
Args:
edges (): List of edges
type ():
linewidth (): Width of line
color (nd.array/tuple): RGB color.
"""
points = vtk.vtkPoints()
lines = vtk.vtkCellArray()
for iedge, edge in enumerate(edges):
points.InsertPoint(2 * iedge, edge[0])
points.InsertPoint(2 * iedge + 1, edge[1])
lines.InsertNextCell(2)
lines.InsertCellPoint(2 * iedge)
lines.InsertCellPoint(2 * iedge + 1)
polydata = vtk.vtkPolyData()
polydata.SetPoints(points)
polydata.SetLines(lines)
mapper = vtk.vtkPolyDataMapper()
if vtk.VTK_MAJOR_VERSION <= 5:
mapper.SetInputConnection(polydata.GetProducerPort())
else:
mapper.SetInputData(polydata)
# mapper.SetInput(polydata)
ac = vtk.vtkActor()
ac.SetMapper(mapper)
ac.GetProperty().SetColor(color)
ac.GetProperty().SetLineWidth(linewidth)
self.ren.AddActor(ac)
[docs] def add_bonds(self, neighbors, center, color=None, opacity=None, radius=0.1):
"""
Adds bonds for a site.
Args:
neighbors: Neighbors of the site.
center: The site in the center for all bonds.
color: Color of the tubes representing the bonds
opacity: Opacity of the tubes representing the bonds
radius: Radius of tube s representing the bonds
"""
points = vtk.vtkPoints()
points.InsertPoint(0, center.x, center.y, center.z)
n = len(neighbors)
lines = vtk.vtkCellArray()
for i in range(n):
points.InsertPoint(i + 1, neighbors[i].coords)
lines.InsertNextCell(2)
lines.InsertCellPoint(0)
lines.InsertCellPoint(i + 1)
pd = vtk.vtkPolyData()
pd.SetPoints(points)
pd.SetLines(lines)
tube = vtk.vtkTubeFilter()
if vtk.VTK_MAJOR_VERSION <= 5:
tube.SetInputConnection(pd.GetProducerPort())
else:
tube.SetInputData(pd)
tube.SetRadius(radius)
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(tube.GetOutputPort())
actor = vtk.vtkActor()
actor.SetMapper(mapper)
if opacity is not None:
actor.GetProperty().SetOpacity(opacity)
if color is not None:
actor.GetProperty().SetColor(color)
self.ren.AddActor(actor)
[docs] def add_picker_fixed(self):
"""
Create a cell picker.Returns:
"""
picker = vtk.vtkCellPicker()
# Create a Python function to create the text for the text mapper used
# to display the results of picking.
def annotate_pick(obj, event):
if picker.GetCellId() < 0 and not self.show_help:
self.helptxt_actor.VisibilityOff()
else:
mapper = picker.GetMapper()
if mapper in self.mapper_map:
output = []
for site in self.mapper_map[mapper]:
row = [
"{} - ".format(site.species_string),
", ".join(["{:.3f}".format(c) for c in site.frac_coords]),
"[" + ", ".join(["{:.3f}".format(c) for c in site.coords]) + "]",
]
output.append("".join(row))
self.helptxt_mapper.SetInput("\n".join(output))
self.helptxt_actor.SetPosition(10, 10)
self.helptxt_actor.VisibilityOn()
self.show_help = False
self.picker = picker
picker.AddObserver("EndPickEvent", annotate_pick)
self.iren.SetPicker(picker)
[docs] def add_picker(self):
"""
Create a cell picker.
"""
picker = vtk.vtkCellPicker()
# Create a Python function to create the text for the text mapper used
# to display the results of picking.
source = vtk.vtkVectorText()
mapper = vtk.vtkPolyDataMapper()
mapper.SetInputConnection(source.GetOutputPort())
follower = vtk.vtkFollower()
follower.SetMapper(mapper)
follower.GetProperty().SetColor((0, 0, 0))
follower.SetScale(0.2)
self.ren.AddActor(follower)
follower.SetCamera(self.ren.GetActiveCamera())
follower.VisibilityOff()
def annotate_pick(obj, event):
if picker.GetCellId() < 0:
follower.VisibilityOff()
else:
pick_pos = picker.GetPickPosition()
mapper = picker.GetMapper()
if mapper in self.mapper_map:
site = self.mapper_map[mapper]
output = [
site.species_string,
"Frac. coords: " + " ".join(["{:.4f}".format(c) for c in site.frac_coords]),
]
source.SetText("\n".join(output))
follower.SetPosition(pick_pos)
follower.VisibilityOn()
picker.AddObserver("EndPickEvent", annotate_pick)
self.picker = picker
self.iren.SetPicker(picker)
[docs]class StructureInteractorStyle(vtkInteractorStyleTrackballCamera):
"""
A custom interactor style for visualizing structures.
"""
def __init__(self, parent):
"""
Args:
parent ():
"""
self.parent = parent
self.AddObserver("LeftButtonPressEvent", self.leftButtonPressEvent)
self.AddObserver("MouseMoveEvent", self.mouseMoveEvent)
self.AddObserver("LeftButtonReleaseEvent", self.leftButtonReleaseEvent)
self.AddObserver("KeyPressEvent", self.keyPressEvent)
[docs] def mouseMoveEvent(self, obj, event):
"""
Args:
obj ():
event ():
"""
self.mouse_motion = 1
self.OnMouseMove()
[docs] def keyPressEvent(self, obj, event):
"""
Args:
obj ():
event ():
"""
parent = obj.GetCurrentRenderer().parent
sym = parent.iren.GetKeySym()
if sym in "ABCabc":
if sym == "A":
parent.supercell[0][0] += 1
elif sym == "B":
parent.supercell[1][1] += 1
elif sym == "C":
parent.supercell[2][2] += 1
elif sym == "a":
parent.supercell[0][0] = max(parent.supercell[0][0] - 1, 1)
elif sym == "b":
parent.supercell[1][1] = max(parent.supercell[1][1] - 1, 1)
elif sym == "c":
parent.supercell[2][2] = max(parent.supercell[2][2] - 1, 1)
parent.redraw()
elif sym == "numbersign":
parent.show_polyhedron = not parent.show_polyhedron
parent.redraw()
elif sym == "minus":
parent.show_bonds = not parent.show_bonds
parent.redraw()
elif sym == "bracketleft":
parent.poly_radii_tol_factor -= 0.05 if parent.poly_radii_tol_factor > 0 else 0
parent.redraw()
elif sym == "bracketright":
parent.poly_radii_tol_factor += 0.05
parent.redraw()
elif sym == "h":
parent.show_help = not parent.show_help
parent.redraw()
elif sym == "r":
parent.redraw(True)
elif sym == "s":
parent.write_image("image.png")
elif sym == "Up":
parent.rotate_view(1, 90)
elif sym == "Down":
parent.rotate_view(1, -90)
elif sym == "Left":
parent.rotate_view(0, -90)
elif sym == "Right":
parent.rotate_view(0, 90)
elif sym == "o":
parent.orthongonalize_structure()
parent.redraw()
self.OnKeyPress()
[docs]def make_movie(structures, output_filename="movie.mp4", zoom=1.0, fps=20, bitrate="10000k", quality=1, **kwargs):
r"""
Generate a movie from a sequence of structures using vtk and ffmpeg.
Args:
structures ([Structure]): sequence of structures
output_filename (str): filename for structure output. defaults to
movie.mp4
zoom (float): A zoom to be applied to the visualizer. Defaults to 1.0.
fps (int): Frames per second for the movie. Defaults to 20.
bitrate (str): Video bitate. Defaults to "10000k" (fairly high
quality).
quality (int): A quality scale. Defaults to 1.
\\*\\*kwargs: Any kwargs supported by StructureVis to modify the images
generated.
"""
vis = StructureVis(**kwargs)
vis.show_help = False
vis.redraw()
vis.zoom(zoom)
sigfig = int(math.floor(math.log10(len(structures))) + 1)
filename = "image{0:0" + str(sigfig) + "d}.png"
for i, s in enumerate(structures):
vis.set_structure(s)
vis.write_image(filename.format(i), 3)
filename = "image%0" + str(sigfig) + "d.png"
args = [
"ffmpeg",
"-y",
"-i",
filename,
"-q:v",
str(quality),
"-r",
str(fps),
"-b:v",
str(bitrate),
output_filename,
]
subprocess.Popen(args)
[docs]class MultiStructuresVis(StructureVis):
"""
Visualization for multiple structures.
"""
DEFAULT_ANIMATED_MOVIE_OPTIONS = {
"time_between_frames": 0.1,
"looping_type": "restart",
"number_of_loops": 1,
"time_between_loops": 1.0,
}
def __init__(
self,
element_color_mapping=None,
show_unit_cell=True,
show_bonds=False,
show_polyhedron=False,
poly_radii_tol_factor=0.5,
excluded_bonding_elements=None,
animated_movie_options=DEFAULT_ANIMATED_MOVIE_OPTIONS,
):
"""
Args:
element_color_mapping: Optional color mapping for the elements,
as a dict of {symbol: rgb tuple}. For example, {"Fe": (255,
123,0), ....} If None is specified, a default based on
Jmol"s color scheme is used.
show_unit_cell: Set to False to not show the unit cell
boundaries. Defaults to True.
show_bonds: Set to True to show bonds. Defaults to True.
show_polyhedron: Set to True to show polyhedrons. Defaults to
False.
poly_radii_tol_factor: The polyhedron and bonding code uses the
ionic radii of the elements or species to determine if two
atoms are bonded. This specifies a tolerance scaling factor
such that atoms which are (1 + poly_radii_tol_factor) * sum
of ionic radii apart are still considered as bonded.
excluded_bonding_elements: List of atom types to exclude from
bonding determination. Defaults to an empty list. Useful
when trying to visualize a certain atom type in the
framework (e.g., Li in a Li-ion battery cathode material).
animated_movie_options (): Used for moving.
"""
super().__init__(
element_color_mapping=element_color_mapping,
show_unit_cell=show_unit_cell,
show_bonds=show_bonds,
show_polyhedron=show_polyhedron,
poly_radii_tol_factor=poly_radii_tol_factor,
excluded_bonding_elements=excluded_bonding_elements,
)
self.warningtxt_actor = vtk.vtkActor2D()
self.infotxt_actor = vtk.vtkActor2D()
self.structures = None
style = MultiStructuresInteractorStyle(self)
self.iren.SetInteractorStyle(style)
self.istruct = 0
self.current_structure = None
self.set_animated_movie_options(animated_movie_options=animated_movie_options)
[docs] def set_structures(self, structures, tags=None):
"""
Add list of structures to the visualizer.
Args:
structures (List of Structures):
tags (): List of tags.
"""
self.structures = structures
self.istruct = 0
self.current_structure = self.structures[self.istruct]
self.tags = tags if tags is not None else []
self.all_radii = []
self.all_vis_radii = []
for struct in self.structures:
struct_radii = []
struct_vis_radii = []
for site in struct:
radius = 0
for specie, occu in site.species.items():
radius += occu * (
specie.ionic_radius
if isinstance(specie, Species) and specie.ionic_radius
else specie.average_ionic_radius
)
vis_radius = 0.2 + 0.002 * radius
struct_radii.append(radius)
struct_vis_radii.append(vis_radius)
self.all_radii.append(struct_radii)
self.all_vis_radii.append(struct_vis_radii)
self.set_structure(self.current_structure, reset_camera=True, to_unit_cell=False)
[docs] def set_structure(self, structure, reset_camera=True, to_unit_cell=False):
"""
Add a structure to the visualizer.
Args:
structure: structure to visualize
reset_camera: Set to True to reset the camera to a default
determined based on the structure.
to_unit_cell: Whether or not to fall back sites into the unit cell.
"""
super().set_structure(structure=structure, reset_camera=reset_camera, to_unit_cell=to_unit_cell)
self.apply_tags()
[docs] def set_animated_movie_options(self, animated_movie_options=None):
"""
Args:
animated_movie_options ():
"""
if animated_movie_options is None:
self.animated_movie_options = self.DEFAULT_ANIMATED_MOVIE_OPTIONS.copy()
else:
self.animated_movie_options = self.DEFAULT_ANIMATED_MOVIE_OPTIONS.copy()
for key in animated_movie_options:
if key not in self.DEFAULT_ANIMATED_MOVIE_OPTIONS.keys():
raise ValueError("Wrong option for animated movie")
self.animated_movie_options.update(animated_movie_options)
[docs] def display_help(self):
"""
Display the help for various keyboard shortcuts.
"""
helptxt = [
"h : Toggle help",
"A/a, B/b or C/c : Increase/decrease cell by one a," " b or c unit vector",
"# : Toggle showing of polyhedrons",
"-: Toggle showing of bonds",
"r : Reset camera direction",
"[/]: Decrease or increase poly_radii_tol_factor " "by 0.05. Value = " + str(self.poly_radii_tol_factor),
"Up/Down: Rotate view along Up direction by 90 " "clockwise/anticlockwise",
"Left/right: Rotate view along camera direction by " "90 clockwise/anticlockwise",
"s: Save view to image.png",
"o: Orthogonalize structure",
"n: Move to next structure",
"p: Move to previous structure",
"m: Animated movie of the structures",
]
self.helptxt_mapper.SetInput("\n".join(helptxt))
self.helptxt_actor.SetPosition(10, 10)
self.helptxt_actor.VisibilityOn()
[docs] def display_warning(self, warning):
"""
Args:
warning (str): Warning
"""
self.warningtxt_mapper = vtk.vtkTextMapper()
tprops = self.warningtxt_mapper.GetTextProperty()
tprops.SetFontSize(14)
tprops.SetFontFamilyToTimes()
tprops.SetColor(1, 0, 0)
tprops.BoldOn()
tprops.SetJustificationToRight()
self.warningtxt = "WARNING : {}".format(warning)
self.warningtxt_actor = vtk.vtkActor2D()
self.warningtxt_actor.VisibilityOn()
self.warningtxt_actor.SetMapper(self.warningtxt_mapper)
self.ren.AddActor(self.warningtxt_actor)
self.warningtxt_mapper.SetInput(self.warningtxt)
winsize = self.ren_win.GetSize()
self.warningtxt_actor.SetPosition(winsize[0] - 10, 10)
self.warningtxt_actor.VisibilityOn()
[docs] def erase_warning(self):
"""
Remove warnings.
"""
self.warningtxt_actor.VisibilityOff()
[docs] def display_info(self, info):
"""
Args:
info (str): Information.
"""
self.infotxt_mapper = vtk.vtkTextMapper()
tprops = self.infotxt_mapper.GetTextProperty()
tprops.SetFontSize(14)
tprops.SetFontFamilyToTimes()
tprops.SetColor(0, 0, 1)
tprops.BoldOn()
tprops.SetVerticalJustificationToTop()
self.infotxt = "INFO : {}".format(info)
self.infotxt_actor = vtk.vtkActor2D()
self.infotxt_actor.VisibilityOn()
self.infotxt_actor.SetMapper(self.infotxt_mapper)
self.ren.AddActor(self.infotxt_actor)
self.infotxt_mapper.SetInput(self.infotxt)
winsize = self.ren_win.GetSize()
self.infotxt_actor.SetPosition(10, winsize[1] - 10)
self.infotxt_actor.VisibilityOn()
[docs] def erase_info(self):
"""
Erase all info.
"""
self.infotxt_actor.VisibilityOff()
[docs]class MultiStructuresInteractorStyle(StructureInteractorStyle):
"""
Interactor for MultiStructureVis.
"""
def __init__(self, parent):
"""
Args:
parent ():
"""
StructureInteractorStyle.__init__(self, parent=parent)
[docs] def keyPressEvent(self, obj, event):
"""
Args:
obj ():
event ():
"""
parent = obj.GetCurrentRenderer().parent
sym = parent.iren.GetKeySym()
if sym == "n":
if parent.istruct == len(parent.structures) - 1:
parent.display_warning("LAST STRUCTURE")
parent.ren_win.Render()
else:
parent.istruct += 1
parent.current_structure = parent.structures[parent.istruct]
parent.set_structure(parent.current_structure, reset_camera=False, to_unit_cell=False)
parent.erase_warning()
parent.ren_win.Render()
elif sym == "p":
if parent.istruct == 0:
parent.display_warning("FIRST STRUCTURE")
parent.ren_win.Render()
else:
parent.istruct -= 1
parent.current_structure = parent.structures[parent.istruct]
parent.set_structure(parent.current_structure, reset_camera=False, to_unit_cell=False)
parent.erase_warning()
parent.ren_win.Render()
elif sym == "m":
parent.istruct = 0
parent.current_structure = parent.structures[parent.istruct]
parent.set_structure(parent.current_structure, reset_camera=False, to_unit_cell=False)
parent.erase_warning()
parent.ren_win.Render()
nloops = parent.animated_movie_options["number_of_loops"]
tstep = parent.animated_movie_options["time_between_frames"]
tloops = parent.animated_movie_options["time_between_loops"]
if parent.animated_movie_options["looping_type"] == "restart":
loop_istructs = range(len(parent.structures))
elif parent.animated_movie_options["looping_type"] == "palindrome":
loop_istructs = range(len(parent.structures)) + range(len(parent.structures) - 2, -1, -1)
else:
raise ValueError('"looping_type" should be "restart" or "palindrome"')
for iloop in range(nloops):
for istruct in loop_istructs:
time.sleep(tstep)
parent.istruct = istruct
parent.current_structure = parent.structures[parent.istruct]
parent.set_structure(parent.current_structure, reset_camera=False, to_unit_cell=False)
parent.display_info(
"Animated movie : structure {:d}/{:d} "
"(loop {:d}/{:d})".format(istruct + 1, len(parent.structures), iloop + 1, nloops)
)
parent.ren_win.Render()
time.sleep(tloops)
parent.erase_info()
parent.display_info("Ended animated movie ...")
parent.ren_win.Render()
StructureInteractorStyle.keyPressEvent(self, obj, event)