# Source code for autofragment.partitioners.molecular

# Copyright (c) 2026 Bryce M. Westheimer
# SPDX-License-Identifier: BSD-3-Clause

"""Molecular partitioner for water clusters and similar systems.

This module provides the main partitioner class for creating fragments
from collections of molecules (typically water clusters). Supports both
flat partitioning and tiered hierarchical partitioning (2-tier and 3-tier).
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np

from autofragment.algorithms.clustering import partition_labels
from autofragment.algorithms.geometric import partition_by_planes, partition_by_planes_tiered
from autofragment.core.geometry import compute_centroids
from autofragment.core.types import (
    ChemicalSystem,
    Fragment,
    FragmentTree,
    Molecule,
    system_to_molecules,
)
from autofragment.io.output import format_partitioning_info, format_source_info
from autofragment.partitioners.base import BasePartitioner
from autofragment.partitioners.topology import (
    BondPolicy,
    SelectionMode,
    TopologyNeighborSelection,
)

# Type aliases for tiered labels
LabelTuple2 = Tuple[int, int]
LabelTuple3 = Tuple[int, int, int]


class PartitionError(ValueError):
    """Signal that a partitioning request could not be satisfied.

    Derives from ``ValueError``, so handlers written for ``ValueError``
    also catch this exception.
    """


@dataclass(frozen=True)
class PartitionResult:
    """Immutable record describing one flat (single-level) partitioning.

    Attributes
    ----------
    n_fragments : int
        Number of fragments the molecules were split into.
    labels : np.ndarray
        Fragment label for each molecule, in molecule order.
    centroids : np.ndarray
        Centroids of the molecules that were used for the assignment.
    molecules : List[Molecule]
        The molecules that were partitioned.
    """

    # Field order is part of the positional constructor signature.
    n_fragments: int
    labels: np.ndarray
    centroids: np.ndarray
    molecules: List[Molecule]


@dataclass(frozen=True)
class TieredPartitionResult:
    """Immutable record describing one hierarchical (2- or 3-tier) partitioning.

    Attributes
    ----------
    tiers : int
        Depth of the hierarchy.
    n_primary : int
        Number of primary groups.
    n_secondary : int
        Number of secondary groups inside each primary group.
    n_tertiary : Optional[int]
        Number of tertiary groups inside each secondary group, or ``None``
        when there is no third tier.
    labels : List[Union[LabelTuple2, LabelTuple3]]
        One label tuple per molecule: ``(primary, secondary)`` or
        ``(primary, secondary, tertiary)``.
    centroids : np.ndarray
        Centroids of the molecules that were used for the assignment.
    molecules : List[Molecule]
        The molecules that were partitioned.
    """

    # Field order is part of the positional constructor signature.
    tiers: int
    n_primary: int
    n_secondary: int
    n_tertiary: Optional[int]
    labels: List[Union[LabelTuple2, LabelTuple3]]
    centroids: np.ndarray
    molecules: List[Molecule]


class MolecularPartitioner(BasePartitioner):
    """
    Partitioner for water clusters and similar molecular systems.

    Supports both flat partitioning (default) and tiered hierarchical
    partitioning (2-tier and 3-tier).

    Parameters
    ----------
    n_fragments : int
        Number of fragments (flat mode). Default is 4. In tiered mode the
        argument is ignored and the ``n_fragments`` attribute is set to the
        total number of leaf fragments.
    method : str, optional
        Clustering method. Default is "kmeans". The value "geom_planes"
        selects plane-based geometric partitioning instead of clustering.
    random_state : int, optional
        Random seed for clustering. Default is 42.
    strict_balanced : bool, optional
        If True, validate equal cluster sizes. Default is True for
        kmeans_constrained, False otherwise.
    topology_refine : bool, optional
        If True, flat-mode labels are post-processed by a topology-based
        refinement step. Not applied in tiered mode. Default is False.
    topology_mode : SelectionMode, optional
        Forwarded as ``mode`` to ``TopologyNeighborSelection`` during
        refinement. Default is "graph".
    topology_hops : int, optional
        Forwarded as ``hops`` to ``TopologyNeighborSelection``. Default is 1.
    topology_layers : int, optional
        Forwarded as ``layers`` to ``TopologyNeighborSelection``. Default is 1.
    topology_k_per_layer : int, optional
        Forwarded as ``k_per_layer`` to ``TopologyNeighborSelection``.
        Default is 1.
    topology_bond_policy : BondPolicy, optional
        Forwarded as ``bond_policy`` to ``TopologyNeighborSelection``.
        Default is "infer".
    tiers : int, optional
        Number of hierarchy tiers (2 or 3). None = flat mode (default).
    n_primary : int, optional
        Number of primary fragments (tiered mode).
    n_secondary : int, optional
        Number of secondary fragments per primary (tiered mode).
    n_tertiary : int, optional
        Number of tertiary fragments per secondary (3-tier mode).
    init_strategy : str | ndarray | dict | None, optional
        Default seeding strategy for all tiers / flat mode.
    init_strategy_primary : str | ndarray | dict | None, optional
        Override seeding strategy for primary (tier-1) clustering.
    init_strategy_secondary : str | ndarray | dict | None, optional
        Override seeding strategy for secondary (tier-2) clustering.
    init_strategy_tertiary : str | ndarray | dict | None, optional
        Override seeding strategy for tertiary (tier-3) clustering.

    Examples
    --------
    Flat mode:

    >>> partitioner = MolecularPartitioner(n_fragments=2, method="kmeans")
    >>> tree = partitioner.partition(system)
    >>> len(tree.fragments)
    2

    Tiered mode:

    >>> partitioner = MolecularPartitioner(
    ...     tiers=2, n_primary=4, n_secondary=4
    ... )
    >>> tree = partitioner.partition(system)
    >>> tree.n_primary
    4

    See Also
    --------
    autofragment.partitioners.batch.BatchPartitioner : For processing multiple files.
    autofragment.core.types.FragmentTree : The result object containing fragments.
    """

    def __init__(
        self,
        n_fragments: int = 4,
        method: str = "kmeans",
        random_state: int = 42,
        strict_balanced: Optional[bool] = None,
        topology_refine: bool = False,
        topology_mode: SelectionMode = "graph",
        topology_hops: int = 1,
        topology_layers: int = 1,
        topology_k_per_layer: int = 1,
        topology_bond_policy: BondPolicy = "infer",
        # Tiered parameters
        tiers: Optional[int] = None,
        n_primary: Optional[int] = None,
        n_secondary: Optional[int] = None,
        n_tertiary: Optional[int] = None,
        # Seeding parameters
        init_strategy: Union[None, str, np.ndarray, Dict[str, Any]] = None,
        init_strategy_primary: Union[None, str, np.ndarray, Dict[str, Any]] = None,
        init_strategy_secondary: Union[None, str, np.ndarray, Dict[str, Any]] = None,
        init_strategy_tertiary: Union[None, str, np.ndarray, Dict[str, Any]] = None,
    ):
        """Initialize a new MolecularPartitioner instance.

        Raises
        ------
        ValueError
            If ``tiers`` is not None, 2 or 3; if a tier count required for
            the selected number of tiers is missing or not positive; or if
            ``n_fragments`` is not positive in flat mode.
        """
        self.tiers = tiers
        self.method = method
        self.random_state = random_state
        self.topology_refine = topology_refine
        self.topology_mode = topology_mode
        self.topology_hops = topology_hops
        self.topology_layers = topology_layers
        self.topology_k_per_layer = topology_k_per_layer
        self.topology_bond_policy = topology_bond_policy

        self.n_primary: Optional[int] = None
        self.n_secondary: Optional[int] = None
        self.n_tertiary: Optional[int] = None

        if tiers is not None:
            # Tiered mode
            if tiers not in (2, 3):
                raise ValueError(f"tiers must be 2 or 3, got {tiers}")
            if n_primary is None or n_secondary is None:
                raise ValueError("n_primary and n_secondary are required for tiered mode")
            if tiers == 3 and n_tertiary is None:
                raise ValueError("n_tertiary is required for 3-tier partitioning")

            # Fail fast on non-positive tier counts, mirroring the flat-mode
            # n_fragments check; otherwise a zero only surfaces later as a
            # ZeroDivisionError and a negative silently yields no fragments.
            if n_primary <= 0:
                raise ValueError(f"n_primary must be positive, got {n_primary}")
            if n_secondary <= 0:
                raise ValueError(f"n_secondary must be positive, got {n_secondary}")
            if tiers == 3 and n_tertiary is not None and n_tertiary <= 0:
                raise ValueError(f"n_tertiary must be positive, got {n_tertiary}")

            self.n_primary = n_primary
            self.n_secondary = n_secondary
            self.n_tertiary = n_tertiary

            # n_fragments is still set for compatibility (total leaf fragments)
            total = n_primary * n_secondary
            if tiers == 3 and n_tertiary is not None:
                total *= n_tertiary
            self.n_fragments = total
        else:
            # Flat mode
            if n_fragments <= 0:
                raise ValueError(f"n_fragments must be positive, got {n_fragments}")
            self.n_fragments = n_fragments

        # Per-tier init strategies: per-tier override > general > None
        self._init_strategy = init_strategy
        self._init_primary = (
            init_strategy_primary if init_strategy_primary is not None else init_strategy
        )
        self._init_secondary = (
            init_strategy_secondary if init_strategy_secondary is not None else init_strategy
        )
        self._init_tertiary = (
            init_strategy_tertiary if init_strategy_tertiary is not None else init_strategy
        )

        if strict_balanced is None:
            self.strict_balanced = method == "kmeans_constrained"
        else:
            self.strict_balanced = strict_balanced

    def partition(
        self,
        system: ChemicalSystem,
        source_file: str | None = None,
    ) -> FragmentTree:
        """
        Partition a chemical system into fragments.

        Parameters
        ----------
        system : ChemicalSystem
            Chemical system to partition.
        source_file : str, optional
            Path to source file for metadata.

        Returns
        -------
        FragmentTree
            Fragmentation result. Flat for non-tiered mode,
            hierarchical for tiered mode.

        Raises
        ------
        PartitionError
            If a fragment ends up empty, or if ``strict_balanced`` is set
            and the cluster sizes are not equal.
        """
        molecules = system_to_molecules(system, require_metadata=True)

        if self.tiers is not None:
            # Tiered mode returns here, so topology refinement is never applied to it.
            return self._partition_tiered(list(molecules), source_file)

        # --- Flat mode (existing path) ---
        result = self._build_partition(list(molecules))

        if self.topology_refine:
            refined_labels = self._refine_partition_topology(system, result)
            # PartitionResult is frozen, so build a new one carrying the refined labels.
            result = PartitionResult(
                n_fragments=result.n_fragments,
                labels=refined_labels,
                centroids=result.centroids,
                molecules=result.molecules,
            )

        if self.strict_balanced:
            self._validate_partition(result)

        fragments = self._build_fragments(result)

        # Build metadata
        source = {}
        if source_file:
            source = format_source_info(source_file, "xyz")

        partitioning = format_partitioning_info(
            algorithm=self.method,
            n_fragments=self.n_fragments,
        )

        return FragmentTree(
            fragments=fragments,
            source=source,
            partitioning=partitioning,
        )

    def _build_partition(self, molecules: List[Molecule]) -> PartitionResult:
        """Build the partition with labels (flat mode).

        Uses ``partition_by_planes`` when ``method`` is "geom_planes" and
        ``partition_labels`` (seeded with the general init strategy) otherwise.
        """
        centroids = compute_centroids(molecules)

        if self.method == "geom_planes":
            labels = partition_by_planes(centroids, self.n_fragments)
        else:
            labels = partition_labels(
                centroids,
                self.n_fragments,
                self.method,
                self.random_state,
                init=self._init_strategy,
            )

        return PartitionResult(
            n_fragments=self.n_fragments,
            labels=labels.astype(int),
            centroids=centroids,
            molecules=molecules,
        )

    def _molecule_atom_indices(
        self,
        system: ChemicalSystem,
        molecules: List[Molecule],
    ) -> List[List[int]]:
        """Get atom indices for each molecule in order.

        Prefers ``system.metadata["molecule_atom_indices"]`` when present;
        otherwise assigns consecutive index ranges using ``len(mol)`` for
        each molecule in the order given.
        """
        metadata = system.metadata or {}
        if "molecule_atom_indices" in metadata:
            return [list(indices) for indices in metadata["molecule_atom_indices"]]

        indices: List[List[int]] = []
        start = 0
        for mol in molecules:
            end = start + len(mol)
            indices.append(list(range(start, end)))
            start = end
        return indices

    def _refine_partition_topology(
        self,
        system: ChemicalSystem,
        result: PartitionResult,
    ) -> np.ndarray:
        """Optionally refine molecule labels using topology neighborhoods around cluster centers.

        For every non-empty cluster, the molecule closest to the cluster's
        mean centroid seeds a ``TopologyNeighborSelection``. Each molecule is
        then moved to the cluster whose selection overlaps it with strictly
        more atoms than its current cluster's selection does. If the move
        would leave any cluster empty, the original labels are returned.

        Returns
        -------
        np.ndarray
            Integer labels, one per molecule.
        """
        labels = result.labels.copy().astype(int)
        n_molecules = len(result.molecules)
        if n_molecules == 0:
            return labels

        molecule_indices = self._molecule_atom_indices(system, result.molecules)
        coords = np.array([atom.coords for atom in system.atoms], dtype=float)
        elements = [atom.symbol for atom in system.atoms]
        # Bond records lacking either endpoint key are skipped.
        bonds = [
            (int(bond["atom1"]), int(bond["atom2"]))
            for bond in system.bonds
            if "atom1" in bond and "atom2" in bond
        ]

        # Representative molecule per cluster = one closest to cluster centroid.
        representatives: Dict[int, int] = {}
        for cluster_idx in range(self.n_fragments):
            members = np.where(labels == cluster_idx)[0].tolist()
            if not members:
                continue
            cluster_center = np.mean(result.centroids[members], axis=0)
            rep = min(
                members,
                key=lambda idx: float(np.linalg.norm(result.centroids[idx] - cluster_center)),
            )
            representatives[cluster_idx] = rep

        if not representatives:
            return labels

        # overlap_scores[c, m] = number of atoms of molecule m selected from cluster c's seed.
        overlap_scores = np.zeros((self.n_fragments, n_molecules), dtype=int)
        for cluster_idx, rep_idx in representatives.items():
            selector = TopologyNeighborSelection(
                seed_atoms=set(molecule_indices[rep_idx]),
                mode=self.topology_mode,
                hops=self.topology_hops,
                layers=self.topology_layers,
                k_per_layer=self.topology_k_per_layer,
                expand_residues=False,
                bond_policy=self.topology_bond_policy,
            )
            selected_atoms = selector.select(coords, elements, bonds=bonds).selected_atoms
            for mol_idx, atom_ids in enumerate(molecule_indices):
                overlap_scores[cluster_idx, mol_idx] = len(selected_atoms.intersection(atom_ids))

        refined = labels.copy()
        for mol_idx in range(n_molecules):
            current = int(labels[mol_idx])
            best_cluster = current
            best_overlap = int(overlap_scores[current, mol_idx])
            # Strict ">" keeps the current cluster on ties.
            for cluster_idx in range(self.n_fragments):
                overlap = int(overlap_scores[cluster_idx, mol_idx])
                if overlap > best_overlap:
                    best_overlap = overlap
                    best_cluster = cluster_idx
            if best_overlap > 0:
                refined[mol_idx] = best_cluster

        # Preserve valid non-empty partitioning; otherwise keep original labels.
        counts = [int(np.sum(refined == k)) for k in range(self.n_fragments)]
        if any(count == 0 for count in counts):
            return labels
        return refined

    def _validate_partition(self, result: PartitionResult) -> None:
        """Validate that cluster sizes are equal (when requested).

        Raises
        ------
        PartitionError
            If the molecule count is not divisible by ``n_fragments`` or any
            fragment's size differs from the expected equal share.
        """
        n_molecules = len(result.molecules)
        if n_molecules % self.n_fragments != 0:
            raise PartitionError(
                f"Non-integer cluster sizes: {n_molecules} molecules / {self.n_fragments} fragments"
            )
        expected = n_molecules // self.n_fragments

        for k in range(self.n_fragments):
            count_k = int(np.sum(result.labels == k))
            if count_k != expected:
                raise PartitionError(
                    f"Fragment {k}: size {count_k} != expected {expected}"
                )

    def _build_fragments(self, result: PartitionResult) -> List[Fragment]:
        """Build Fragment objects from partition result (flat mode).

        Fragments are named ``F1`` .. ``F<n_fragments>`` and each records its
        molecule count under ``metadata["n_molecules"]``.

        Raises
        ------
        PartitionError
            If any fragment received no molecules.
        """
        # Group molecules by label in a single pass instead of rescanning
        # (and re-materialising) the label array once per fragment.
        buckets: List[List[Molecule]] = [[] for _ in range(self.n_fragments)]
        for i, lbl in enumerate(result.labels.tolist()):
            k = int(lbl)
            # Labels outside [0, n_fragments) match no fragment and are ignored.
            if 0 <= k < self.n_fragments:
                buckets[k].append(result.molecules[i])

        fragments: List[Fragment] = []
        for k, chosen in enumerate(buckets):
            if not chosen:
                raise PartitionError(f"Fragment {k} is empty")
            f = Fragment.from_molecules(chosen, f"F{k + 1}")
            f.metadata["n_molecules"] = len(chosen)
            fragments.append(f)
        return fragments

    # ------------------------------------------------------------------
    # Tiered partitioning
    # ------------------------------------------------------------------

    def _partition_tiered(
        self,
        molecules: List[Molecule],
        source_file: str | None = None,
    ) -> FragmentTree:
        """Orchestrate hierarchical clustering and build a tiered FragmentTree."""
        result = self._build_partition_tiered(molecules)

        if self.strict_balanced:
            self._validate_tiered_hierarchy(result)

        fragments = self._build_tiered_fragments(result)

        # Build metadata
        source = {}
        if source_file:
            source = format_source_info(source_file, "xyz")

        partitioning = format_partitioning_info(
            algorithm=self.method,
            n_fragments=self.n_fragments,
            tiers=self.tiers,
            n_primary=self.n_primary,
            n_secondary=self.n_secondary,
            n_tertiary=self.n_tertiary,
        )

        return FragmentTree(
            fragments=fragments,
            source=source,
            partitioning=partitioning,
        )

    def _build_partition_tiered(self, molecules: List[Molecule]) -> TieredPartitionResult:
        """Build the tiered partition with tuple labels."""
        centroids = compute_centroids(molecules)

        if self.method == "geom_planes":
            return self._partition_geometric_tiered(molecules, centroids)
        return self._partition_clustering_tiered(molecules, centroids)

    def _partition_geometric_tiered(
        self,
        molecules: List[Molecule],
        centroids: np.ndarray,
    ) -> TieredPartitionResult:
        """Partition using geometric planes (tiered mode).

        Raises
        ------
        ValueError
            If ``tiers``, ``n_primary`` or ``n_secondary`` is unset.
        """
        if self.tiers is None:
            raise ValueError("tiers must be set for tiered partitioning")
        if self.n_primary is None:
            raise ValueError("n_primary must be set for tiered partitioning")
        if self.n_secondary is None:
            raise ValueError("n_secondary must be set for tiered partitioning")

        # A 2-tier run asks the helper for a single tertiary group.
        n_t = self.n_tertiary if self.tiers == 3 else 1
        prim, sec, ter = partition_by_planes_tiered(
            centroids, self.n_primary, self.n_secondary, n_t or 1
        )

        if self.tiers == 2:
            labels: List[Union[LabelTuple2, LabelTuple3]] = [
                (int(p), int(s)) for p, s in zip(prim, sec)
            ]
        else:
            labels = [
                (int(p), int(s), int(t)) for p, s, t in zip(prim, sec, ter)
            ]

        return TieredPartitionResult(
            tiers=self.tiers,
            n_primary=self.n_primary,
            n_secondary=self.n_secondary,
            n_tertiary=self.n_tertiary if self.tiers == 3 else None,
            labels=labels,
            centroids=centroids,
            molecules=molecules,
        )

    def _partition_clustering_tiered(
        self,
        molecules: List[Molecule],
        centroids: np.ndarray,
    ) -> TieredPartitionResult:
        """Partition using clustering algorithms (tiered mode).

        Clusters all centroids into primary groups, then clusters each
        primary group into secondary groups, and (3-tier only) each secondary
        group into tertiary groups. Every level uses ``self.method`` and
        ``self.random_state`` with that tier's init strategy.

        Raises
        ------
        ValueError
            If a tier count required for the configured depth is unset.
        PartitionError
            If a primary or secondary cluster turns out empty.
        """
        if self.tiers is None:
            raise ValueError("tiers must be set for tiered partitioning")
        if self.n_primary is None:
            raise ValueError("n_primary must be set for tiered partitioning")
        if self.n_secondary is None:
            raise ValueError("n_secondary must be set for tiered partitioning")

        primary_labels = partition_labels(
            centroids,
            self.n_primary,
            self.method,
            self.random_state,
            init=self._init_primary,
        )

        if self.tiers == 2:
            # Placeholder tuples; every entry is overwritten below.
            labels: List[Union[LabelTuple2, LabelTuple3]] = [(0, 0)] * len(molecules)
            for p in range(self.n_primary):
                idx = np.where(primary_labels == p)[0]
                if len(idx) == 0:
                    raise PartitionError(f"Primary cluster {p} is empty")
                sec = partition_labels(
                    centroids[idx],
                    self.n_secondary,
                    self.method,
                    self.random_state,
                    init=self._init_secondary,
                )
                # sec is indexed locally within the primary group; map back to global indices.
                for local_k, wi in enumerate(idx):
                    labels[int(wi)] = (int(p), int(sec[local_k]))

            return TieredPartitionResult(
                tiers=self.tiers,
                n_primary=self.n_primary,
                n_secondary=self.n_secondary,
                n_tertiary=None,
                labels=labels,
                centroids=centroids,
                molecules=molecules,
            )

        # 3-tier partitioning
        if self.n_tertiary is None:
            raise ValueError("n_tertiary must be set for 3-tier partitioning")

        labels3: List[Union[LabelTuple2, LabelTuple3]] = [(0, 0, 0)] * len(molecules)
        n_t = self.n_tertiary

        for p in range(self.n_primary):
            idx_p = np.where(primary_labels == p)[0]
            if len(idx_p) == 0:
                raise PartitionError(f"Primary cluster {p} is empty")

            sec_p = partition_labels(
                centroids[idx_p],
                self.n_secondary,
                self.method,
                self.random_state,
                init=self._init_secondary,
            )

            for s in range(self.n_secondary):
                # Global indices of the molecules in primary p / secondary s.
                idx_ps = idx_p[np.where(sec_p == s)[0]]
                if len(idx_ps) == 0:
                    raise PartitionError(f"Secondary cluster {p}:{s} is empty")

                ter_ps = partition_labels(
                    centroids[idx_ps],
                    n_t,
                    self.method,
                    self.random_state,
                    init=self._init_tertiary,
                )

                for local_k, wi in enumerate(idx_ps):
                    labels3[int(wi)] = (int(p), int(s), int(ter_ps[local_k]))

        return TieredPartitionResult(
            tiers=self.tiers,
            n_primary=self.n_primary,
            n_secondary=self.n_secondary,
            n_tertiary=n_t,
            labels=labels3,
            centroids=centroids,
            molecules=molecules,
        )

    def _validate_tiered_hierarchy(self, result: TieredPartitionResult) -> None:
        """Validate that cluster sizes are equal at every tier level.

        Does nothing when ``self.tiers`` is neither 2 nor 3.

        Raises
        ------
        ValueError
            If a tier count required for the configured depth is unset.
        PartitionError
            If the molecule count does not divide evenly at some tier, or any
            primary/secondary/tertiary group differs from its expected size.
        """
        if self.n_primary is None:
            raise ValueError("n_primary must be set for tiered validation")
        if self.n_secondary is None:
            raise ValueError("n_secondary must be set for tiered validation")
        if self.tiers not in (2, 3):
            return

        n_t: Optional[int] = None
        if self.tiers == 3:
            if self.n_tertiary is None:
                raise ValueError("n_tertiary must be set for 3-tier validation")
            n_t = self.n_tertiary

        n_p = self.n_primary
        n_s = self.n_secondary
        n_molecules = len(result.molecules)

        # Number of groups at each level: primary, secondary[, tertiary].
        group_counts = [n_p, n_p * n_s]
        if n_t is not None:
            group_counts.append(n_p * n_s * n_t)

        # Exact integer divisibility (all remainders are computed up front).
        remainders = [n_molecules % g for g in group_counts]
        if any(remainders):
            if n_t is None:
                raise PartitionError(
                    f"Non-integer cluster sizes for 2-tier hierarchy: "
                    f"{n_molecules} molecules / {n_p} primary "
                    f"/ {n_s} secondary"
                )
            raise PartitionError("Non-integer cluster sizes for 3-tier hierarchy")
        expected = [n_molecules // g for g in group_counts]

        # Tally every label prefix in one pass: (p,), (p, s) and, for 3 tiers, (p, s, t).
        depth = len(group_counts)
        sizes: Dict[Tuple[int, ...], int] = {}
        for lab in result.labels:
            for level in range(1, min(len(lab), depth) + 1):
                key = tuple(lab[:level])
                sizes[key] = sizes.get(key, 0) + 1

        level_names = ("Primary", "Secondary", "Tertiary")

        def require_size(key: Tuple[int, ...]) -> None:
            """Raise PartitionError unless the group ``key`` has its expected size."""
            level = len(key) - 1
            found = sizes.get(key, 0)
            if found != expected[level]:
                ident = ":".join(str(part) for part in key)
                raise PartitionError(
                    f"{level_names[level]} cluster {ident}: size {found} "
                    f"!= expected {expected[level]}"
                )

        # Depth-first order, so the first reported mismatch is the shallowest
        # one for the lowest-numbered group.
        for p in range(n_p):
            require_size((p,))
            for s in range(n_s):
                require_size((p, s))
                if n_t is not None:
                    for t in range(n_t):
                        require_size((p, s, t))

    def _build_tiered_fragments(self, result: TieredPartitionResult) -> List[Fragment]:
        """Build hierarchical Fragment objects from tiered partition result.

        Primary fragments are named ``PF<i>``, secondary ``PF<i>_SF<j>`` and
        tertiary ``PF<i>_SF<j>_TF<k>`` (all 1-based). Only leaf fragments hold
        molecules and record ``metadata["n_molecules"]``. Returns an empty
        list when ``self.tiers`` is neither 2 nor 3.

        Raises
        ------
        ValueError
            If a tier count required for the configured depth is unset.
        """
        if self.n_primary is None:
            raise ValueError("n_primary must be set for building tiered fragments")
        if self.n_secondary is None:
            raise ValueError("n_secondary must be set for building tiered fragments")

        fragments: List[Fragment] = []
        if self.tiers not in (2, 3):
            return fragments

        n_t: Optional[int] = None
        if self.tiers == 3:
            if self.n_tertiary is None:
                raise ValueError("n_tertiary must be set for building 3-tier fragments")
            n_t = self.n_tertiary

        # Group molecules by leaf label once instead of rescanning every
        # label for each leaf fragment.
        depth = 2 if n_t is None else 3
        leaves: Dict[Tuple[int, ...], List[Molecule]] = {}
        for i, lab in enumerate(result.labels):
            if depth == 3 and len(lab) <= 2:
                # Labels without a tertiary component belong to no 3-tier leaf.
                continue
            leaves.setdefault(tuple(lab[:depth]), []).append(result.molecules[i])

        # NOTE(review): unlike the flat builder, an empty leaf is not rejected
        # here; it is passed to Fragment.from_molecules as an empty list.
        for p in range(self.n_primary):
            pf_id = f"PF{p + 1}"
            pf = Fragment(id=pf_id)

            for s in range(self.n_secondary):
                sf_id = f"{pf_id}_SF{s + 1}"

                if n_t is None:
                    chosen = leaves.get((p, s), [])
                    sf = Fragment.from_molecules(chosen, sf_id)
                    sf.metadata["n_molecules"] = len(chosen)
                else:
                    sf = Fragment(id=sf_id)
                    for t in range(n_t):
                        tf_id = f"{sf_id}_TF{t + 1}"
                        chosen = leaves.get((p, s, t), [])
                        tf = Fragment.from_molecules(chosen, tf_id)
                        tf.metadata["n_molecules"] = len(chosen)
                        sf.fragments.append(tf)

                pf.fragments.append(sf)

            fragments.append(pf)

        return fragments