Source code for decaylanguage.decay.viewer

# Copyright (c) 2018-2024, Eduardo Rodrigues and Henry Schreiner.
#
# Distributed under the 3-clause BSD license, see accompanying file LICENSE
# or https://github.com/scikit-hep/decaylanguage for details.

"""
Submodule with classes and utilities to visualize decay chains.
Decay chains are typically provided by the parser of .dec decay files,
see the ``DecFileParser`` class.
"""

from __future__ import annotations

import itertools
from typing import Any

import graphviz
from particle import latex_to_html_name
from particle.converters.bimap import DirectionalMaps

counter = iter(itertools.count())


_EvtGen2LatexNameMap, _Latex2EvtGenNameMap = DirectionalMaps("EvtGenName", "LaTexName")


class GraphNotBuiltError(RuntimeError):
    pass


[docs] class DecayChainViewer: """ The class to visualize a decay chain. Examples -------- >>> dfp = DecFileParser('my-Dst-decay-file.dec') # doctest: +SKIP >>> dfp.parse() # doctest: +SKIP >>> chain = dfp.build_decay_chains('D*+') # doctest: +SKIP >>> dcv = DecayChainViewer(chain) # doctest: +SKIP >>> # display the SVG figure in a notebook >>> dcv # doctest: +SKIP When not in notebooks the graph can easily be visualized with the ``graphviz.Digraph.render`` or ``graphviz.Digraph.view`` functions, e.g.: >>> dcv.graph.render(filename="test", format="pdf", view=True, cleanup=True) # doctest: +SKIP """ __slots__ = ("_chain", "_graph", "_graph_attributes") def __init__( self, decaychain: dict[str, list[dict[str, float | str | list[Any]]]], **attrs: dict[str, bool | int | float | str], ) -> None: """ Default constructor. Parameters ---------- decaychain: dict Input decay chain in dict format, typically created from ``decaylanguage.DecFileParser.build_decay_chains`` after parsing a .dec decay file, or from building a decay chain representation with ``decaylanguage.DecayChain.to_dict``. attrs: optional User input ``graphviz.Digraph`` class attributes. See also -------- decaylanguage.DecFileParser.build_decay_chains for creating a decay chain dict from parsing a .dec file. decaylanguage.DecFileParser: class for creating an input decay chain. """ # Store the input decay chain self._chain = decaychain # Instantiate the digraph with defaults possibly overridden by user attributes self._graph = self._instantiate_graph(**attrs) # Build the actual graph from the input decay chain structure self._build_decay_graph() def _build_decay_graph(self) -> None: """ Recursively navigate the decay chain tree and produce a Digraph in the DOT language. """ def safe_html_name(name: str) -> str: """ Get a safe HTML name from the EvtGen name. Note ---- The match is done using a conversion map rather than via ``Particle.from_evtgen_name(name).html_name`` for 2 reasons: - Some decay-file-specific "particle" names (e.g. cs_0) are not in the PDG table. - No need to load all particle information if all that's needed is a match EvtGen - HTML name. """ try: return latex_to_html_name(_EvtGen2LatexNameMap[name]) except Exception: return name def html_table_label( names: list[str], add_tags: bool = False, bgcolor: str = "#9abad6", ) -> str: if add_tags: label = f'<<TABLE BORDER="0" CELLSPACING="0" BGCOLOR="{bgcolor}">' else: label = f'<<TABLE BORDER="0" CELLSPACING="0" CELLPADDING="0" BGCOLOR="{bgcolor}"><TR>' for i, n in enumerate(names): if add_tags: label += '<TR><TD BORDER="1" CELLPADDING="5" PORT="p{tag}">{name}</TD></TR>'.format( tag=i, name=safe_html_name(n) ) else: label += f'<TD BORDER="0" CELLPADDING="2">{safe_html_name(n)}</TD>' label += "{tr}</TABLE>>".format(tr="" if add_tags else "</TR>") return label def new_node_no_subchain(list_parts: list[str]) -> str: label = html_table_label(list_parts, bgcolor="#eef3f8") r = f"dec{next(counter)}" self.graph.node(r, label=label, style="filled", fillcolor="#eef3f8") return r def new_node_with_subchain(list_parts: list[Any]) -> str: _list_parts = [ next(iter(p.keys())) if isinstance(p, dict) else p for p in list_parts ] label = html_table_label(_list_parts, add_tags=True) r = f"dec{next(counter)}" self.graph.node(r, shape="none", label=label) return r def iterate_chain( subchain: list[dict[str, float | str | list[Any]]], top_node: str | None = None, link_pos: int | None = None, ) -> None: if not top_node: top_node = "mother" self.graph.node("mother", shape="none", label=label) n_decaymodes = len(subchain) for idm in range(n_decaymodes): _list_parts = subchain[idm]["fs"] if not has_subdecay(_list_parts): # type: ignore[arg-type] _ref = new_node_no_subchain(_list_parts) # type: ignore[arg-type] _bf = subchain[idm]["bf"] if link_pos is None: self.graph.edge(top_node, _ref, label=str(_bf)) else: self.graph.edge(f"{top_node}:p{link_pos}", _ref, label=str(_bf)) else: _ref_1 = new_node_with_subchain(_list_parts) # type: ignore[arg-type] _bf_1 = subchain[idm]["bf"] if link_pos is None: self.graph.edge(top_node, _ref_1, label=str(_bf_1)) else: self.graph.edge( f"{top_node}:p{link_pos}", _ref_1, label=str(_bf_1), ) for i, _p in enumerate(_list_parts): # type: ignore[arg-type] if not isinstance(_p, str): _k = next(iter(_p.keys())) iterate_chain(_p[_k], top_node=_ref_1, link_pos=i) def has_subdecay(ds: list[Any]) -> bool: return not all(isinstance(p, str) for p in ds) k = next(iter(self._chain.keys())) label = html_table_label([k], add_tags=True, bgcolor="#568dba") sc = self._chain[k] # Actually build the whole decay chain, iteratively iterate_chain(sc) @property def graph(self) -> graphviz.Digraph: """ Get the actual ``graphviz.Digraph`` object. The user now has full control ... """ return self._graph
[docs] def to_string(self) -> str: """ Return a string representation of the built graph in the DOT language. The function is a trivial shortcut for ``graphviz.Digraph.source``. """ return self.graph.source # type: ignore[no-any-return]
def _instantiate_graph( self, **attrs: dict[str, bool | int | float | str] ) -> graphviz.Digraph: """ Return a ``graphviz.Digraph`` class instance using the default attributes specified in this class: - Default graph attributes are overridden by input by the user. - Class and node and edge defaults. """ graph_attr = self._get_graph_defaults() node_attr = self._get_node_defaults() edge_attr = self._get_edge_defaults() if "graph_attr" in attrs: graph_attr.update(**attrs["graph_attr"]) attrs.pop("graph_attr") if "node_attr" in attrs: node_attr.update(**attrs["node_attr"]) attrs.pop("node_attr") if "edge_attr" in attrs: edge_attr.update(**attrs["edge_attr"]) attrs.pop("edge_attr") arguments = self._get_default_arguments() arguments.update(**attrs) # type: ignore[call-overload] return graphviz.Digraph( graph_attr=graph_attr, node_attr=node_attr, edge_attr=edge_attr, **arguments ) def _get_default_arguments(self) -> dict[str, bool | int | float | str]: """ ``graphviz.Digraph`` default arguments. """ return { "name": "DecayChainGraph", "comment": "Created by https://github.com/scikit-hep/decaylanguage", "engine": "dot", "format": "png", } def _get_graph_defaults(self) -> dict[str, bool | int | float | str]: d = self._get_default_arguments() d.update(rankdir="LR") return d def _get_node_defaults(self) -> dict[str, bool | int | float | str]: return {"fontname": "Helvetica", "fontsize": "11", "shape": "oval"} def _get_edge_defaults(self) -> dict[str, bool | int | float | str]: return {"fontcolor": "#4c4c4c", "fontsize": "11"} def _repr_mimebundle_( self, include: bool | None = None, exclude: bool | None = None, **kwargs: Any, ) -> Any: # pragma: no cover """ IPython display helper. """ try: return self._graph._repr_mimebundle_( include=include, exclude=exclude, **kwargs ) except AttributeError: return {"image/svg+xml": self._graph._repr_svg_()} # for graphviz < 0.19