diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 4aba752..fef3906 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -14,3 +14,4 @@ jobs:
uses: janosh/workflows/.github/workflows/npm-test-release.yml@main
with:
install-cmd: npm install --force
+ test-cmd: npm run test:unit
diff --git a/src/lib/structure/StructureScene.svelte b/src/lib/structure/StructureScene.svelte
index 253ea1b..11fd31a 100644
--- a/src/lib/structure/StructureScene.svelte
+++ b/src/lib/structure/StructureScene.svelte
@@ -51,7 +51,7 @@
export let active_site: Site | null = null
export let precision: string = `.3~f`
export let auto_rotate: number | boolean = 0 // auto rotate speed. set to 0 to disable auto rotation.
- export let bond_radius: number | undefined = undefined
+ export let bond_radius: number | undefined = 0.05
export let bond_opacity: number = 0.5
export let bond_color: string = `#ffffff` // must be hex code for
export let bonding_strategy: keyof typeof bonding_strategies = `nearest_neighbor`
@@ -84,7 +84,7 @@
}
// make bond thickness reactive to atom_radius unless bond_radius is set
- $: bond_thickness = bond_radius ?? 0.1 * atom_radius
+ $: bond_thickness = bond_radius ?? 0.05 * atom_radius
const gizmo_defaults: Partial> = {
horizontalPlacement: `left`,
size: 100,
diff --git a/src/lib/structure/bonding.ts b/src/lib/structure/bonding.ts
index 0c63f93..e022616 100644
--- a/src/lib/structure/bonding.ts
+++ b/src/lib/structure/bonding.ts
@@ -1,27 +1,30 @@
import type { BondPair, PymatgenStructure } from '$lib'
-import { euclidean_dist } from '$lib'
-// TODO add unit tests for these functions
+export type BondingAlgo = typeof max_dist | typeof nearest_neighbor
+
export function max_dist(
structure: PymatgenStructure,
{ max_bond_dist = 3, min_bond_dist = 0.4 } = {}, // in Angstroms
): BondPair[] {
// finds all pairs of atoms within the max_bond_dist cutoff
const bonds: BondPair[] = []
- const bond_set: Set = new Set()
+ const bond_set = new Set()
+ const max_bond_dist_sq = max_bond_dist ** 2
+ const min_bond_dist_sq = min_bond_dist ** 2
- for (let idx = 0; idx < structure.sites.length; idx++) {
- const { xyz } = structure.sites[idx]
+ for (let i = 0; i < structure.sites.length; i++) {
+ const { xyz: xyz1 } = structure.sites[i]
- for (let idx_2 = idx + 1; idx_2 < structure.sites.length; idx_2++) {
- const { xyz: xyz_2 } = structure.sites[idx_2]
+ for (let j = i + 1; j < structure.sites.length; j++) {
+ const { xyz: xyz2 } = structure.sites[j]
- const dist = euclidean_dist(xyz, xyz_2)
- if (dist < max_bond_dist && dist > min_bond_dist) {
- const bond_key = [xyz, xyz_2].sort().toString()
+ const dist_sq = euclidean_dist_sq(xyz1, xyz2)
+ if (dist_sq <= max_bond_dist_sq && dist_sq >= min_bond_dist_sq) {
+ const dist = Math.sqrt(dist_sq)
+ const bond_key = `${i},${j}`
if (!bond_set.has(bond_key)) {
bond_set.add(bond_key)
- bonds.push([xyz, xyz_2, idx, idx_2, dist])
+ bonds.push([xyz1, xyz2, i, j, dist])
}
}
}
@@ -34,29 +37,41 @@ export function nearest_neighbor(
{ scaling_factor = 1.2, min_bond_dist = 0.1 } = {}, // in Angstroms
): BondPair[] {
// finds bonds to sites less than scaling_factor farther away than the nearest neighbor
+
const num_sites = structure.sites.length
const bonds: BondPair[] = []
- const bond_set: Set = new Set()
+ const bond_set = new Set()
+ const min_bond_dist_sq = min_bond_dist ** 2
+
+ const nearest_distances = new Array(num_sites).fill(Infinity)
+ // First pass: find nearest neighbor distances
for (let i = 0; i < num_sites; i++) {
const { xyz: xyz1 } = structure.sites[i]
- let min_dist = Infinity
for (let j = i + 1; j < num_sites; j++) {
const { xyz: xyz2 } = structure.sites[j]
- const dist = euclidean_dist(xyz1, xyz2)
+ const dist_sq = euclidean_dist_sq(xyz1, xyz2)
- if (dist > min_bond_dist && dist < min_dist) {
- min_dist = dist
+ if (dist_sq >= min_bond_dist_sq) {
+ if (dist_sq < nearest_distances[i]) nearest_distances[i] = dist_sq
+ if (dist_sq < nearest_distances[j]) nearest_distances[j] = dist_sq
}
}
+ }
+
+ // Second pass: add bonds within scaled distance
+ for (let i = 0; i < num_sites; i++) {
+ const { xyz: xyz1 } = structure.sites[i]
+ const max_dist_sq = nearest_distances[i] * scaling_factor ** 2
for (let j = i + 1; j < num_sites; j++) {
const { xyz: xyz2 } = structure.sites[j]
- const dist = euclidean_dist(xyz1, xyz2)
+ const dist_sq = euclidean_dist_sq(xyz1, xyz2)
- if (dist <= min_dist * scaling_factor) {
- const bond_key = [xyz1, xyz2].sort().toString()
+ if (dist_sq >= min_bond_dist_sq && dist_sq <= max_dist_sq) {
+ const dist = Math.sqrt(dist_sq)
+ const bond_key = `${i},${j}`
if (!bond_set.has(bond_key)) {
bond_set.add(bond_key)
bonds.push([xyz1, xyz2, i, j, dist])
@@ -67,3 +82,9 @@ export function nearest_neighbor(
return bonds
}
+
+// redundant functionality-wise with euclidean_dist from $lib/math.ts but needed for performance
+// makes bonding algos 2-3x faster
+function euclidean_dist_sq(vec_a: number[], vec_b: number[]): number {
+ return vec_a.reduce((sum, _, i) => sum + (vec_a[i] - vec_b[i]) ** 2, 0)
+}
diff --git a/tests/unit/bonding.test.ts b/tests/unit/bonding.test.ts
new file mode 100644
index 0000000..dc02de9
--- /dev/null
+++ b/tests/unit/bonding.test.ts
@@ -0,0 +1,194 @@
+import type { PymatgenStructure } from '$lib/structure'
+import type { BondingAlgo } from '$lib/structure/bonding'
+import { max_dist, nearest_neighbor } from '$lib/structure/bonding'
+import { performance } from 'perf_hooks'
+import { describe, expect, test } from 'vitest'
+
+const ci_max_time_multiplier = process.env.CI ? 5 : 1
+
+// Function to generate a random structure
+function make_rand_structure(numAtoms: number) {
+ return {
+ sites: Array.from({ length: numAtoms }, () => ({
+ xyz: [Math.random() * 10, Math.random() * 10, Math.random() * 10],
+ })),
+ } as PymatgenStructure
+}
+
+// Updated performance test function
+function perf_test(func: BondingAlgo, atom_count: number, max_time: number) {
+ const run = () => {
+ const structure = make_rand_structure(atom_count)
+ const start = performance.now()
+ func(structure)
+ const end = performance.now()
+ return end - start
+ }
+
+ const time1 = run()
+ const time2 = run()
+ const avg_time = (time1 + time2) / 2
+
+ expect(
+ avg_time,
+ `average run time: ${Math.ceil(avg_time)}, max expected: ${max_time * ci_max_time_multiplier}`, // Apply scaling factor
+ ).toBeLessThanOrEqual(max_time * ci_max_time_multiplier)
+}
+
+describe(`Bonding Functions Performance Tests`, () => {
+ const bonding_functions = [
+ {
+ func: max_dist,
+ max_times: [
+ [10, 0.1],
+ [100, 1],
+ [1000, 40],
+ [5000, 1000],
+ ],
+ },
+ {
+ func: nearest_neighbor,
+ max_times: [
+ [10, 0.2],
+ [100, 3],
+ [1000, 50],
+ [5000, 1000],
+ ],
+ },
+ ]
+
+ for (const { func, max_times } of bonding_functions) {
+ for (const [atom_count, max_time] of max_times) {
+ test(`${func.name} performance for ${atom_count} atoms`, () => {
+ perf_test(func, atom_count, max_time)
+ })
+ }
+ }
+})
+
+// Helper function to create a simple structure
+const make_struct = (sites: number[][]): PymatgenStructure => ({
+ sites: sites.map((xyz) => ({ xyz })),
+})
+
+describe(`max_dist function`, () => {
+ test(`should return correct bonds for a simple structure`, () => {
+ const structure = make_struct([
+ [0, 0, 0],
+ [1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ ])
+ const bonds = max_dist(structure, {
+ max_bond_dist: 1.5,
+ min_bond_dist: 0.5,
+ })
+ expect(bonds).toHaveLength(6)
+ expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1])
+ })
+
+ test(`should not return bonds shorter than min_bond_dist`, () => {
+ const structure = make_struct([
+ [0, 0, 0],
+ [0.3, 0, 0],
+ ])
+ const bonds = max_dist(structure, { max_bond_dist: 1, min_bond_dist: 0.5 })
+ expect(bonds).toHaveLength(0)
+ })
+
+ test(`should not return bonds longer than max_bond_dist`, () => {
+ const structure = make_struct([
+ [0, 0, 0],
+ [2, 0, 0],
+ ])
+ const bonds = max_dist(structure, {
+ max_bond_dist: 1.5,
+ min_bond_dist: 0.5,
+ })
+ expect(bonds).toHaveLength(0)
+ })
+
+ test(`should handle empty structures`, () => {
+ const structure = make_struct([])
+ const bonds = max_dist(structure)
+ expect(bonds).toHaveLength(0)
+ })
+})
+
+describe(`nearest_neighbor function`, () => {
+ test(`should return correct bonds for a simple structure`, () => {
+ const structure = make_struct([
+ [0, 0, 0],
+ [1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [2, 0, 0],
+ ])
+ const bonds = nearest_neighbor(structure, {
+ scaling_factor: 1.1,
+ min_bond_dist: 0.5,
+ })
+ expect(bonds).toHaveLength(4)
+ expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1])
+ })
+
+ test(`should not return bonds shorter than min_bond_dist`, () => {
+ const structure = make_struct([
+ [0, 0, 0],
+ [0.05, 0, 0],
+ [1, 0, 0],
+ ])
+ const bonds = nearest_neighbor(structure, {
+ scaling_factor: 1.2,
+ min_bond_dist: 0.1,
+ })
+ expect(bonds).toHaveLength(2)
+ expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 2, 1])
+ })
+
+ test(`should handle structures with multiple equidistant nearest neighbors`, () => {
+ const structure = make_struct([
+ [0, 0, 0],
+ [1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ ])
+ const bonds = nearest_neighbor(structure, {
+ scaling_factor: 1.1,
+ min_bond_dist: 0.5,
+ })
+ expect(bonds).toHaveLength(3)
+ expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1])
+ })
+
+ test(`should handle empty structures`, () => {
+ const structure = make_struct([])
+ const bonds = nearest_neighbor(structure)
+ expect(bonds).toHaveLength(0)
+ })
+
+ test(`should respect the scaling_factor`, () => {
+ const structure = make_struct([
+ [0, 0, 0],
+ [1, 0, 0],
+ [0, 1, 0],
+ [0, 0, 1],
+ [1.5, 0, 0],
+ ])
+ const bonds = nearest_neighbor(structure, {
+ scaling_factor: 1.4,
+ min_bond_dist: 0.5,
+ })
+ expect(bonds).toHaveLength(4)
+ expect(bonds).toContainEqual([[0, 0, 0], [1, 0, 0], 0, 1, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 1, 0], 0, 2, 1])
+ expect(bonds).toContainEqual([[0, 0, 0], [0, 0, 1], 0, 3, 1])
+ expect(bonds).toContainEqual([[1, 0, 0], [1.5, 0, 0], 1, 4, 0.5])
+ })
+})