"""
Tools for constructing and visualising convex hulls (phase diagrams).
Provides a Plotly-based phase diagram plotter extending pymatgen's
PDPlotter, and helper functions for ternary plot configuration.
"""
import numpy as np
import plotly.graph_objects as go
from pymatgen.analysis.phase_diagram import PDPlotter
[docs]
def make_axis(title: str, tickangle: float) -> dict:
    """Build the layout settings for one axis of a ternary plot.

    Args:
        title: Text stored under the ``title`` key.
        tickangle: Value stored under the ``tickangle`` key.

    Returns:
        A new dict of axis options. Apart from the two arguments, every
        value is fixed: title font size 20, tick font size 15, tick labels
        hidden, fully transparent tick colour, tick length 5, no axis line
        and the grid switched on.
    """
    axis = {"title": title, "titlefont": {"size": 20}, "tickangle": tickangle}
    # Tick appearance: labels are hidden and the tick colour is fully transparent.
    axis.update(
        showticklabels=False,
        tickfont={"size": 15},
        tickcolor="rgba(0,0,0,0)",
        ticklen=5,
    )
    # Axis line off, grid on.
    axis.update(showline=False, showgrid=True)
    return axis
[docs]
def entry_type(entry, default: str = "Default") -> str:
    """Look up the type label stored on a phase diagram entry.

    Args:
        entry: Object exposing an ``attribute`` member that is either
            ``None`` or a mapping supporting ``.get``.
        default: Label returned when ``attribute`` is ``None`` or has no
            ``"entry_type"`` key.

    Returns:
        The value stored under ``"entry_type"``, or ``default``.
    """
    attrs = entry.attribute
    return default if attrs is None else attrs.get("entry_type", default)
[docs]
def entry_name(entry, default: str = "None") -> str:
    """Look up the structure name stored on a phase diagram entry.

    Args:
        entry: Object exposing an ``attribute`` member that is either
            ``None`` or a mapping supporting ``.get``.
        default: Name returned when ``attribute`` is ``None`` or has no
            ``"struct_name"`` key. Note that the default is the string
            ``"None"``, not the ``None`` object.

    Returns:
        The value stored under ``"struct_name"``, or ``default``.
    """
    attrs = entry.attribute
    return default if attrs is None else attrs.get("struct_name", default)
[docs]
class PlotlyPDPlotter(PDPlotter):
    """
    An extension of PDPlotter for plotting phase diagrams with Plotly.

    This may only work with old pymatgen versions that do not have
    native Plotly backend support.

    Entries are split into separate traces according to the label returned
    by ``entry_type``; entries that carry no label are grouped under
    ``"Default"`` for both the stable and the unstable traces.
    """

    @property
    def pd_plot_data_ternary(self):
        """
        Get the plot data for native ternary plot.

        Returns:
            (lines, stable_entries, unstable_entries):
                Same as pd_plot_data, but in proper ternary coordinates.
                ``lines`` is a list of arrays holding the ``(a, b, c)`` rows
                of the two end points of every line in ``self.lines``,
                ``stable_entries`` maps an ``(a, b, c)`` tuple of floats to
                the entry at that end point, and ``unstable_entries`` maps
                every entry that is not in ``stable_entries`` of the phase
                diagram to its ``(a, b, c)`` tuple.

        Raises:
            AssertionError: If the system is not ternary.
        """
        assert self._dim == 3, "Cannot get data for ternary plot - the system is not ternary"
        phase_diagram = self._pd

        def recover_a(old_data):
            """Prepend a column equal to one minus the sum of the first two columns."""
            # NOTE(review): assumes the first two columns of the hull data are
            # the fractions of the second and third element - confirm against
            # the pymatgen version in use.
            data = np.zeros((old_data.shape[0], old_data.shape[1] + 1))
            data[:, 1:] = old_data
            data[:, 0] = 1 - (old_data[:, 0] + old_data[:, 1])
            return data

        hull_entries = phase_diagram.qhull_entries
        hull_data = recover_a(phase_diagram.qhull_data)
        lines = []
        stable_entries = {}
        for line in self.lines:
            # Rows of the two end points, restricted to the (a, b, c) columns.
            coord = hull_data[list(line), :3]
            lines.append(coord)
            stable_entries[tuple(float(x) for x in coord[0])] = hull_entries[line[0]]
            stable_entries[tuple(float(x) for x in coord[1])] = hull_entries[line[1]]

        all_data = recover_a(phase_diagram.all_entries_hulldata)
        stable = phase_diagram.stable_entries
        unstable_entries = {}
        for idx, entry in enumerate(phase_diagram.all_entries):
            if entry in stable:
                continue
            unstable_entries[entry] = tuple(float(x) for x in all_data[idx, :3])
        return lines, stable_entries, unstable_entries

    def _get_3d_ternary_plot(self):  # pylint: disable=too-many-statements
        """
        Obtain a 3D ternary plot with all phases, stable and unstable.

        The x/y coordinates are taken from ``self.pd_plot_data`` and the z
        coordinate is the formation energy per atom of each entry.

        Returns:
            The Plotly figure holding one line trace for the hull, one marker
            trace per entry type for the stable entries and, when
            ``self.show_unstable`` is truthy, one marker trace per entry type
            for the unstable entries.
        """
        fig = go.Figure()
        lines, labels, unstable = self.pd_plot_data
        phase_diagram = self._pd

        # All hull lines go into one trace; a None after each segment keeps
        # consecutive segments from being joined.
        x_list, y_list, z_list = [], [], []
        for x, y in lines:
            x_list.extend(tuple(x) + (None,))
            y_list.extend(tuple(y) + (None,))
            entry1 = labels[(x[0], y[0])]
            entry2 = labels[(x[1], y[1])]
            form_eng1 = phase_diagram.get_form_energy_per_atom(entry1)
            form_eng2 = phase_diagram.get_form_energy_per_atom(entry2)
            z_list.extend((form_eng1, form_eng2, None))
        fig.add_scatter3d(x=x_list, y=y_list, z=z_list, mode="lines", hoverinfo="none")

        elems = phase_diagram.elements
        all_types = sorted({entry_type(entry, "Default") for entry in phase_diagram.all_entries})

        for etype in all_types:
            stable_coords = []
            stable_text = []
            form_engs = []
            for coords in sorted(labels.keys(), key=lambda x: -x[1]):
                entry = labels[coords]
                if entry_type(entry, "Default") != etype:
                    continue
                stable_coords.append(coords)
                stable_text.append(entry.name)
                form_engs.append(phase_diagram.get_form_energy_per_atom(entry))
            if not stable_coords:
                continue
            x, y = zip(*stable_coords)
            fig.add_scatter3d(
                x=x, y=y, z=form_engs, text=stable_text, name=f"Stable ({etype})",
                mode="markers+text", marker_size=5,
            )

        if self.show_unstable:
            for etype in all_types:
                unstable_coords = []
                unstable_text = []
                unstable_name = []
                form_engs = []
                for entry, coords in unstable.items():
                    # Use the same fallback label as ``all_types``; with a
                    # different fallback, unlabelled entries would never match
                    # any group and would be silently left out of the plot.
                    if entry_type(entry, "Default") != etype:
                        continue
                    # ``show_unstable`` doubles as the energy-above-hull cut-off.
                    if phase_diagram.get_e_above_hull(entry) > self.show_unstable:
                        continue
                    unstable_coords.append(coords)
                    unstable_text.append(entry.name)
                    unstable_name.append(entry_name(entry))
                    form_engs.append(phase_diagram.get_form_energy_per_atom(entry))
                if not unstable_coords:
                    continue
                x, y = zip(*unstable_coords)
                fig.add_scatter3d(
                    x=x, y=y, z=form_engs, text=unstable_text,
                    customdata=np.array([unstable_name]).T,
                    name=f"Unstable ({etype})", mode="markers",
                    hovertemplate="%{text} - %{customdata[0]}", marker_size=2,
                )

        pname = "-".join(x.name for x in elems)
        fig.update_layout({
            "title": f"Ternary plot for {pname}",
            "autosize": False, "height": 600, "width": 800,
        })
        return fig

    def _get_2d_ternary_plot(self):  # pylint: disable=too-many-statements,too-many-locals,too-many-branches
        """
        Obtain the 2D ternary plot.

        Coordinates come from ``pd_plot_data_ternary``. For the unstable
        entries only the entry with the lowest energy above hull is kept for
        each reduced formula, and entries whose name matches a plotted stable
        entry are skipped.

        Returns:
            The Plotly figure holding one line trace for the hull, one marker
            trace per entry type for the stable entries and, when
            ``self.show_unstable`` is truthy, one marker trace per entry type
            for the unstable entries, coloured by energy above hull.

        Raises:
            AssertionError: If the system is not ternary.
        """
        phase_diagram = self._pd
        lines, labels, unstable = self.pd_plot_data_ternary
        fig = go.Figure()

        # All hull lines go into one trace; a None after each segment keeps
        # consecutive segments from being joined.
        a_list, b_list, c_list = [], [], []
        for startfinish in lines:
            a, b, c = zip(*startfinish)
            a_list.extend(a + (None,))
            b_list.extend(b + (None,))
            c_list.extend(c + (None,))
        fig.add_scatterternary(a=a_list, b=b_list, c=c_list, mode="lines", hoverinfo="none")

        elems = phase_diagram.elements
        aname, bname, cname = (x.name for x in elems)
        # Hover prefix shared by the stable and unstable traces.
        coord_hover = f"{aname}: %{{a:.2f}} {bname}: %{{b:.2f}} {cname}: %{{c:.2f}}"
        all_types = sorted({entry_type(entry, "Default") for entry in phase_diagram.all_entries})
        all_stable_names = set()

        for etype in all_types:
            stable_coords = []
            stable_text = []
            stable_name = []
            for coords in sorted(labels.keys(), key=lambda x: -x[1]):
                entry = labels[coords]
                if entry_type(entry, "Default") != etype:
                    continue
                label = entry.name
                stable_coords.append(coords)
                stable_text.append(label)
                all_stable_names.add(label)
                stable_name.append(entry_name(entry))
            # A type that only occurs among unstable entries has nothing to
            # plot here; unpacking an empty zip below would raise ValueError.
            if not stable_coords:
                continue
            a, b, c = zip(*stable_coords)
            fig.add_scatterternary(
                a=a, b=b, c=c, text=stable_text, marker_symbol="circle",
                name=f"Stable ({etype})", mode="markers+text",
                customdata=np.array([stable_name]).T,
                hovertemplate=coord_hover + "<br>name: %{customdata[0]}",
                cliponaxis=False,
            )

        if self.show_unstable:
            # Keep only the entry closest to the hull for each reduced formula.
            comps = {}
            for entry, coords in unstable.items():
                ehull = phase_diagram.get_e_above_hull(entry)
                reduced = entry.composition.reduced_formula
                if reduced not in comps or ehull < comps[reduced][-1]:
                    comps[reduced] = (entry, coords, ehull)

            # Text labels are only drawn when there are few unstable entries.
            scatter_mode = "markers" if len(unstable) > 10 else "markers+text"
            for etype in all_types:
                unstable_coords = []
                unstable_text = []
                unstable_name = []
                dist2hull = []
                for entry, coords, ehull in comps.values():
                    # Use the same fallback label as ``all_types``; with a
                    # different fallback, unlabelled entries would never match
                    # any group and would be silently left out of the plot.
                    if entry_type(entry, "Default") != etype or entry.name in all_stable_names:
                        continue
                    # ``show_unstable`` doubles as the energy-above-hull cut-off.
                    if ehull > self.show_unstable:
                        continue
                    dist2hull.append(ehull)
                    unstable_coords.append(coords)
                    unstable_text.append(entry.name)
                    unstable_name.append(entry_name(entry))
                if not unstable_coords:
                    continue
                a, b, c = zip(*unstable_coords)
                fig.add_scatterternary(
                    a=a, b=b, c=c, marker_symbol="triangle-up",
                    marker_color=dist2hull, text=unstable_text,
                    name=f"Unstable ({etype})", mode=scatter_mode,
                    customdata=np.array([unstable_name, dist2hull]).T,
                    hovertemplate=(
                        coord_hover
                        + "<br>name: %{customdata[0]}"
                        + "<br>above_hull %{customdata[1]:.4f} eV"
                    ),
                    cliponaxis=False,
                )

        fig.update_layout({
            "title": f"Ternary plot for {aname}-{bname}-{cname}",
            "ternary": {
                "sum": 1,
                "aaxis": make_axis(elems[0].name, 0),
                "baxis": make_axis(elems[1].name, 45),
                "caxis": make_axis(elems[2].name, -45),
            },
            "autosize": False, "height": 600, "width": 800,
        })
        return fig