Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Color and Molecule2 #17

Merged
merged 5 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .envrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
watch_file pixi.lock
eval "$(pixi shell-hook)"

1 change: 1 addition & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ jupyter = ">=1.1.1,<2"
opencv = "*"
numpy = "<2"
# pip = "*
polars = "*"
pillow = "*"
pytest = "*"
quarto = ">=1.6.40,<2"
Expand Down
91 changes: 91 additions & 0 deletions src/prettymol/color.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
from biotite.structure import AtomArray
# from molecularnodes.color import color_from_atomic_number, color_from_element, colors_from_elements, color_chains, color_chains_equidistant
from molecularnodes.color import (
color_from_atomic_number,
color_from_element,
colors_from_elements,
color_chains_equidistant,
color_chains,
iupac_colors_rgb
)

import numpy as np


class ColorArray(AtomArray):
def __init__(self, atom_array: AtomArray):
"""
Initialize a ColorArray from an existing AtomArray.

Parameters
----------
atom_array : AtomArray
The atom array to be converted into a ColorArray
"""
super().__init__(len(atom_array))

# Copy all annotations and coordinates from the input array
for annot in atom_array.get_annotation_categories():
self.set_annotation(annot, atom_array.get_annotation(annot))
self.coord = atom_array.coord.copy()

# Add color annotation
self.add_annotation("color", dtype=object)

def color_by_element(self):
"""
Assigns colors to atoms based on their chemical elements using IUPAC colors.
"""
atomic_numbers = np.array([atomic_number_map.get(elem, 0) for elem in self.element])
element_colors = colors_from_elements(atomic_numbers)
element_colors[:, 3] = 1.0
self.set_annotation("Color", element_colors)

def color_by_chain(self):
"""
Assigns colors to atoms based on their chain IDs using equidistant colors.
"""
chain_colors = color_chains_equidistant(self.chain_id)
chain_colors[:, 3] = 1.0
self.set_annotation("Color", chain_colors)

def color_by_chain_and_element(self):
"""
Assigns colors to atoms based on both chain IDs and elements.
Carbon atoms are colored by chain, other elements by their element color.
"""
atomic_numbers = np.array([atomic_number_map.get(elem, 0) for elem in self.element])
colors = color_chains(atomic_numbers, self.chain_id)
colors[:, 3] = 1.0
self.set_annotation("Color", colors)

def set_custom_colors(self, colors: np.ndarray):
"""
Sets custom colors for the atoms.

Parameters
----------
colors : np.ndarray
Array of RGBA colors with shape (n_atoms, 4)
"""
if len(colors) != len(self):
raise ValueError("Colors array must have the same length as the atom array")
if colors.shape[1] != 4:
raise ValueError("Colors must be RGBA values (shape: n_atoms x 4)")
self.set_annotation("Color", colors)

def get_colors(self) -> np.ndarray:
"""
Returns the current color array.

Returns
-------
np.ndarray
Array of RGBA colors for each atom
"""
return self.get_annotation("Color")

# Create atomic number mapping
atomic_number_map = {
element: i+1 for i, element in enumerate(iupac_colors_rgb.keys())
}
86 changes: 64 additions & 22 deletions src/prettymol/core.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,58 @@
import bpy
import databpy
import numpy as np
from typing import Union, Any
from biotite.structure import AtomArray, AtomArrayStack
from biotite.structure.io import pdbx
from biotite.structure import bonds
from molecularnodes.entities.molecule.base import _create_object
from molecularnodes.entities.molecule.base import _create_object, Molecule
from molecularnodes.download import download
from molecularnodes.blender import nodes as bl_nodes
from molecularnodes.blender.nodes import add_custom, get_input, get_output,get_mod, new_tree, styles_mapping

from .color import ColorArray
from .materials import Material, MaterialCreator
from .styles import BallStickStyle, CartoonStyle, RibbonStyle, SpheresStyle, SticksStyle, SurfaceStyle

from .molecule import Molecule2

# Modified form the original:
# https://github.com/BradyAJohnston/MolecularNodes/blob/main/molecularnodes/blender/nodes.py
# removes the color nodes to make a dead-simple node.
def create_starting_node_tree_minimal(
object: bpy.types.Object,
coll_frames: bpy.types.Collection | None = None,
style: str = "spheres",
name: str | None = None,
color: str = "common",
material: str = "MN Default",
is_modifier: bool = True,
) -> None:
mod = get_mod(object)
if not name:
name = f"MN_{object.name}"

try:
tree = bpy.data.node_groups[name]
mod.node_group = tree
return
except KeyError:
pass

tree = new_tree(name, input_name="Atoms")
tree.is_modifier = is_modifier
link = tree.links.new
mod.node_group = tree

# move the input and output nodes for the group
node_input = get_input(tree)
node_output = get_output(tree)
node_input.location = [0, 0]
node_output.location = [700, 0]
node_style = add_custom(tree, styles_mapping[style], [450, 0], material=material)
link(node_style.outputs[0], node_output.inputs[0])
link(node_input.outputs[0], node_style.inputs[0])
return None

# Connect bonds and center the structure
def load_pdb(code):
Expand All @@ -18,24 +61,24 @@ def load_pdb(code):
arr = next(iter(structures))
arr.bonds = bonds.connect_via_residue_names(arr)
arr.coord = arr.coord - np.mean(arr.coord, axis=0)
#arr = ColorArray(arr) # this will provide methods related to color
return arr


StyleType = Union[BallStickStyle, CartoonStyle, RibbonStyle, SpheresStyle, SticksStyle, SurfaceStyle]

def draw(arr: Any, style: StyleType, material: Material) -> None:

# ARR + Styles + Materials (+ Color Map....)
def draw(arr: AtomArray, style: StyleType, material: Material):

arr = ColorArray(arr)
arr.color_by_element()

# Create object and material
molname = f"mol_{id(arr)}"
matname = f"mol_{id(arr)}_mat"
obj, _ = _create_object(arr, name=molname)
bl_nodes.create_starting_node_tree(obj, style=style.style)

# Setup node tree
modifier = next(mod for mod in obj.modifiers if mod.type == "NODES")
node_tree = modifier.node_group
nodes = node_tree.nodes
style_node = next((node for node in nodes if "Style" in node.name), None)
mol = Molecule2.from_array(arr, name=molname)
obj = mol.create_object(style=style.style, color=None, name=f"{molname}")

# Create and setup material
mat = bpy.data.materials.new(matname)
Expand All @@ -46,22 +89,21 @@ def draw(arr: Any, style: StyleType, material: Material) -> None:
if value := material.get_by_key(input.name):
input.default_value = value

# Assign material to object
if obj.data.materials:
obj.data.materials[0] = mat
else:
obj.data.materials.append(mat)
mol.material = mat

# Apply style and materials
# Setup node tree
modifier = next(mod for mod in obj.modifiers if mod.type == "NODES")
# print(modifier)
node_tree = modifier.node_group
nodes = node_tree.nodes
style_node = next((node for node in nodes if "Style" in node.name), None)
print(style_node)

# Apply style overrides
if style_node:
for input in style_node.inputs:
if input.type != "GEOMETRY":
if value := style.get_by_key(input.name):
input.default_value = value

# Link material to style node
material_input = next((inp for inp in style_node.inputs if inp.name == "Material"), None)
if material_input:
material_input.default_value = mat

return obj
return mol.object
2 changes: 1 addition & 1 deletion src/prettymol/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class Material(StyleBase):
See the Blender documentation for full details:
https://docs.blender.org/manual/en/latest/render/shader_nodes/shader/principled.html
"""
base_color: Tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 1.0), metadata={"key": "Base Color"})
base_color: Tuple[float, float, float, float] = field(default=(0.8, 0.8, 0.8, 0.05), metadata={"key": "Base Color"})
metallic: float = field(default=0.0, metadata={"key": "Metallic"})
roughness: float = field(default=0.2, metadata={"key": "Roughness"})
ior: float = field(default=1.45, metadata={"key": "IOR"})
Expand Down
34 changes: 34 additions & 0 deletions src/prettymol/molecule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# shim for
import uuid
import bpy
from molecularnodes.session import MNSession
from molecularnodes.entities.molecule.base import Molecule
from molecularnodes.entities.molecule.pdb import _comp_secondary_structure
from molecularnodes.entities.base import EntityType
from biotite.structure import AtomArray

class Molecule2(Molecule):
def __init__(self, *args, **kwargs):
# super().__init__(file_path="UNUSED_TEST.pdb")
self._entity_type = EntityType.MOLECULE
self._uuid = str(uuid.uuid1())

@classmethod
def from_array(cls, array: AtomArray, name: str = "FromArray"):
instance = cls() # Create a new instance
instance.file = "ARRAY_LOADED_DIRECTLY"
instance._frames_collection = None
instance._entity_type = EntityType.MOLECULE
instance._assemblies = lambda : None #
instance._uuid = str(uuid.uuid1())
instance.array = instance._validate_structure_array(array)
instance.n_atoms = instance.array.array_length()
# from Molecular entity Init
# bpy.context.scene.MNSession.register_entity(instance)
#instance.create_object(name=name)
return instance # Return the new instance

def _validate_structure_array(self, array: AtomArray):
sec_struct = _comp_secondary_structure(array)
array.set_annotation("sec_struct", sec_struct)
return array
5 changes: 2 additions & 3 deletions src/prettymol/styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class BallStickStyle(StyleBase):
as_mesh: bool = field(default=True, metadata={"key": "As Mesh"})
sphere_radii: float = field(default=0.3, metadata={"key": "Sphere Radii"})
bond_split: bool = field(default=False, metadata={"key": "Bond Split"})
bond_find: bool = field(default=False, metadata={"key": "Bond Find"})
bond_find: bool = field(default=True, metadata={"key": "Bond Find"})
bond_radius: float = field(default=0.3, metadata={"key": "Bond Radius"})
color_blur: bool = field(default=False, metadata={"key": "Color Blur"})
shade_smooth: bool = field(default=True, metadata={"key": "Shade Smooth"})
Expand Down Expand Up @@ -71,7 +71,7 @@ class SpheresStyle(StyleBase):

@dataclass(frozen=True)
class SticksStyle(StyleBase):
style: str = field(default="stick", metadata={"key": "Style"})
style: str = field(default="sticks", metadata={"key": "Style"})
quality: int = field(default=2, metadata={"key": "Quality"})
radius: float = field(default=0.2, metadata={"key": "Radius"})
color_blur: bool = field(default=False, metadata={"key": "Color Blur"})
Expand All @@ -95,7 +95,6 @@ class SurfaceStyle(StyleBase):

StyleType = Union[BallStickStyle, CartoonStyle, RibbonStyle, SpheresStyle, SticksStyle, SurfaceStyle]


class StyleCreator():
def new() -> StyleType:
return CartoonStyle()
Expand Down