Graph Visualization
GraphUtils.plot_graph() renders a causal graph, optionally with colored nodes. It supports
manual per-node coloring, automatic category-based coloring, custom labels, a title, and
saving to a file.
Usage
Basic Plot
from causallearn.search.ScoreBased.GES import ges
from causallearn.utils.GraphUtils import GraphUtils
# Run GES to get a graph (X is an (n_samples, n_features) NumPy array)
Record = ges(X)
# Simple plot (no colors)
GraphUtils.plot_graph(Record['G'])
Plot with Custom Node Labels
GraphUtils.plot_graph(Record['G'],
                      labels=['Age', 'Income', 'Spending', 'Education'])
Plot with Manual Node Colors
# Assign a specific color to each node by label
colors = {
    'Age': 'lightblue',
    'Income': 'lightcoral',
    'Spending': 'lightgreen',
    'Education': 'lightsalmon',
}
GraphUtils.plot_graph(Record['G'],
                      labels=['Age', 'Income', 'Spending', 'Education'],
                      colors=colors,
                      title='Causal Graph with Manual Colors',
                      save_path='my_graph.png')
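A label that has no entry in the colors dict is drawn with the default white fill, so it can help to check for gaps before plotting. The snippet below is a minimal sketch of such a check; find_uncolored is a hypothetical helper written for this page, not part of causal-learn.

```python
def find_uncolored(labels, colors):
    """Return the labels that have no entry in the colors dict, in label order."""
    return [label for label in labels if label not in colors]


labels = ['Age', 'Income', 'Spending', 'Education']
# 'Education' is deliberately left out to show what the check reports
partial_colors = {
    'Age': 'lightblue',
    'Income': 'lightcoral',
    'Spending': 'lightgreen',
}

missing = find_uncolored(labels, partial_colors)
print(missing)  # ['Education']
```

An empty result means every node gets an explicit color.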
Plot with Category-based Colors (Auto-assigned)
# Group features into categories: each category gets a unique color
# Node labels are auto-derived from the category dict (flattened in order),
# so the flattened order should match the graph's node order;
# otherwise pass labels explicitly
category_to_features = {
    'demographics': ['Age', 'Education'],
    'financial': ['Income', 'Spending'],
}
GraphUtils.plot_graph(Record['G'],
                      category_to_features=category_to_features,
                      title='Causal Graph by Category',
                      save_path='graph_by_category.png')
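Because labels are derived by flattening the category dict in order, the derived order can differ from the column order of the data. The sketch below shows the flattening with plain Python. flatten_categories and column_names are hypothetical names used only for illustration, and the library's internal implementation may differ.

```python
from itertools import chain


def flatten_categories(category_to_features):
    """Concatenate the feature lists in dict insertion order."""
    return list(chain.from_iterable(category_to_features.values()))


# Assumed column order of the data matrix (illustrative only)
column_names = ['Age', 'Income', 'Spending', 'Education']

category_to_features = {
    'demographics': ['Age', 'Education'],
    'financial': ['Income', 'Spending'],
}

derived_labels = flatten_categories(category_to_features)
order_matches = derived_labels == column_names

print(derived_labels)  # ['Age', 'Education', 'Income', 'Spending']
print(order_matches)   # False
```

When the two orders differ, as here, passing labels explicitly alongside category_to_features avoids depending on the flattened order.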
Full Example: GES + Colored Visualization
import numpy as np
from causallearn.search.ScoreBased.GES import ges
from causallearn.utils.GraphUtils import GraphUtils
# Generate sample data
np.random.seed(42)
n = 500
X = np.random.randn(n, 5)
X[:, 1] = X[:, 0] * 0.8 + np.random.randn(n) * 0.5
X[:, 2] = X[:, 1] * 0.6 + np.random.randn(n) * 0.4
# Run GES
node_names = ['Drug_A', 'Drug_B', 'Recovery', 'Side_Effect', 'Dosage']
Record = ges(X, node_names=node_names)
# Visualize with category colors. Labels are passed explicitly because
# flattening the category dict would not follow the column order.
categories = {
    'treatment': ['Drug_A', 'Drug_B', 'Dosage'],
    'outcome': ['Recovery', 'Side_Effect'],
}
GraphUtils.plot_graph(Record['G'],
                      labels=node_names,
                      category_to_features=categories,
                      title='Drug Study Causal Graph',
                      save_path='drug_study.png',
                      dpi=300)
Parameters
G: Graph. A causal-learn graph object (e.g., Record['G'] from the output of ges() or dges()).
labels: list of str or None. Node labels. Must have the same length as G.get_nodes(). Default: None (uses graph node names).
colors: dict or None. Mapping from node label (str) to fill color (str). Any valid CSS color name or hex code can be used (e.g., 'lightblue', '#FF6B6B'). If both colors and category_to_features are provided, colors takes precedence. Default: None.
category_to_features: dict or None. Mapping from category name (str) to list of feature names (list of str). Each category is automatically assigned a distinct color from a built-in palette. If labels is not provided, node labels are auto-derived by flattening this dict in order. Default: None.
save_path: str or None. File path to save the rendered image (e.g., 'output/graph.png'). Supports PNG, PDF, and SVG formats. If None, the image is only displayed. Default: None.
title: str. Title displayed above the graph. Default: "".
dpi: float. Resolution in dots per inch. Default: 500.
figsize: tuple. Figure size in inches (width, height). Default: (20, 12).
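The constraints in the parameter list above can be checked before calling plot_graph(). The following is a minimal sketch under stated assumptions: check_plot_args is a hypothetical helper, not a causal-learn function, and n_nodes stands for len(G.get_nodes()).

```python
def check_plot_args(n_nodes, labels=None, colors=None, category_to_features=None):
    """Return a list of human-readable problems; an empty list means no issue found."""
    problems = []
    # labels must have one entry per graph node
    if labels is not None and len(labels) != n_nodes:
        problems.append(
            f"labels has {len(labels)} entries but the graph has {n_nodes} nodes")
    # colors takes precedence, so the category dict would have no effect
    if colors is not None and category_to_features is not None:
        problems.append(
            "colors and category_to_features both given; colors takes precedence")
    # when labels are derived from categories, the feature count must match
    if labels is None and category_to_features is not None:
        n_features = sum(len(f) for f in category_to_features.values())
        if n_features != n_nodes:
            problems.append(
                f"categories list {n_features} features but the graph has {n_nodes} nodes")
    return problems


issues = check_plot_args(
    n_nodes=5,
    labels=['Age', 'Income', 'Spending', 'Education'],
    colors={'Age': 'lightblue'},
    category_to_features={'demographics': ['Age', 'Education']},
)
for issue in issues:
    print(issue)
```

With these arguments the helper reports two problems: the label count does not match the node count, and both color arguments were supplied.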
get_category_colors()
A helper function that generates a color mapping from a category dictionary.
It is used internally by plot_graph(), but can also be called directly if you need
to customize the color mapping before plotting.
from causallearn.utils.GraphUtils import GraphUtils
categories = {
    'treatment': ['Drug_A', 'Drug_B'],
    'outcome': ['Recovery', 'Side_Effect'],
}
colors = GraphUtils.get_category_colors(categories)
# Result: {'Drug_A': 'lightblue', 'Drug_B': 'lightblue',
#          'Recovery': 'lightcoral', 'Side_Effect': 'lightcoral'}
# Customize one color, then plot
colors['Drug_A'] = 'gold'
# G is a causal-learn graph (e.g., Record['G'] from ges()) whose node
# names match the feature names used in categories
GraphUtils.plot_graph(G, colors=colors)
Parameters:
category_to_features: dict. Mapping from category name to list of feature names.
Returns:
feature_to_color: dict. Mapping from each feature name (str) to a CSS color string. All features in the same category share the same color.
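The returned mapping can be pictured with a small stand-in that assigns one palette color per category. This is only a sketch: sketch_category_colors and ASSUMED_PALETTE are hypothetical, the palette is a guess whose first two entries mirror the example result above, and wrapping around when categories outnumber colors is an assumption rather than documented behavior.

```python
# Hypothetical palette; the library's built-in palette may differ.
ASSUMED_PALETTE = ['lightblue', 'lightcoral', 'lightgreen', 'lightsalmon']


def sketch_category_colors(category_to_features, palette=ASSUMED_PALETTE):
    """Map every feature to the color assigned to its category."""
    feature_to_color = {}
    for index, features in enumerate(category_to_features.values()):
        # Assumption: reuse colors cyclically if the palette runs out
        color = palette[index % len(palette)]
        for feature in features:
            feature_to_color[feature] = color
    return feature_to_color


categories = {
    'treatment': ['Drug_A', 'Drug_B'],
    'outcome': ['Recovery', 'Side_Effect'],
}
sketch_colors = sketch_category_colors(categories)
print(sketch_colors)
```

Features in the same category share a color, which is the property stated for feature_to_color.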
Notes
Dependencies: Requires matplotlib and pydot (which requires Graphviz to be installed).
Edge notation: Directed edges are shown as arrows. Undirected edges are shown as lines.
Color priority: If both colors and category_to_features are passed, colors is used and category_to_features is ignored.
Nodes without colors: If a node label is not found in the colors dict, it will be rendered with the default white background.
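The dependencies listed in the notes can be checked up front using only the standard library. This is a minimal sketch: check_plot_dependencies is a hypothetical helper, and it assumes Graphviz is detectable through its dot executable on the PATH.

```python
import importlib.util
import shutil


def check_plot_dependencies():
    """Report which plotting dependencies appear to be available."""
    return {
        # Python packages: found if an import spec can be located
        'matplotlib': importlib.util.find_spec('matplotlib') is not None,
        'pydot': importlib.util.find_spec('pydot') is not None,
        # Graphviz: assumed present if the 'dot' executable is on the PATH
        'graphviz_dot': shutil.which('dot') is not None,
    }


status = check_plot_dependencies()
for name, available in status.items():
    print(f"{name}: {'found' if available else 'missing'}")
```

Any entry reported as missing would need to be installed before plotting.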