Skip to content

Commit

Permalink
restore the previous newline behavior
Browse files Browse the repository at this point in the history
  • Loading branch information
Toddelismyname committed Feb 1, 2025
1 parent 501cf3e commit 198392d
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 20 deletions.
150 changes: 130 additions & 20 deletions format_precice_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,6 @@ def printElement(self, element, level):
self.printTagEnd(element, level=level)

def printChildren(self, element, level):
"""
Print all child elements in a specific order with enhanced formatting.
"""
if level > self.maxgrouplevel:
for child in element.getchildren():
self.printElement(child, level=level)
Expand Down Expand Up @@ -211,16 +208,66 @@ def custom_sort_key(elem):
)
)

# Print participant with reordered children
self.printTagStart(group, level=level)
# Separate different types of elements
mesh_elements = []
data_elements = []
mapping_elements = []

for child in sorted_participant_children:
self.printElement(child, level=level + 1)
# Add a newline after read/write mesh elements
if str(child.tag) in ['write-mesh', 'read-mesh']:
self.print()
self.printTagEnd(group, level=level)
if str(child.tag) in ['provide-mesh', 'receive-mesh']:
mesh_elements.append(child)
elif str(child.tag) in ['write-data', 'read-data']:
data_elements.append(child)
elif str(child.tag).startswith('mapping:'):
mapping_elements.append(child)

# Construct participant tag with attributes
participant_tag = "<{}".format(group.tag)
for attr, value in group.items():
participant_tag += ' {}="{}"'.format(attr, value)
participant_tag += ">"

# Print participant opening tag
self.print(self.indent * level + participant_tag)

# Print mesh elements
for child in mesh_elements:
self.printElement(child, level + 1)

# Add newline between mesh and data
if mesh_elements and data_elements:
self.print()

# Print data elements
for child in data_elements:
self.printElement(child, level + 1)

# Add newline before mapping
if data_elements and mapping_elements:
self.print()

# Print mapping elements with multi-line formatting
for mapping_elem in mapping_elements:
# Check if the mapping element has multiple attributes
if len(mapping_elem.items()) > 2:
self.print("{}<{}".format(self.indent * (level + 1), mapping_elem.tag))
for k, v in mapping_elem.items():
self.print("{}{}=\"{}\"".format(self.indent * (level + 2), k, v))
self.print("{} />".format(self.indent * (level + 1)))
else:
# Single-line formatting for simple mappings
self.printElement(mapping_elem, level + 1)

# Close participant tag
self.print("{}</participant>".format(self.indent * level))

# Add newline after participant if not the last element
if i < last:
self.print()

continue

# Special handling for coupling-scheme to pair relative-convergence-measure and exchange
# Special handling for coupling-scheme elements
elif 'coupling-scheme' in str(group.tag):
# Sort children of coupling-scheme
sorted_scheme_children = sorted(
Expand All @@ -229,18 +276,81 @@ def custom_sort_key(elem):
1 if str(child.tag) == 'exchange' else 2
)

# Print coupling-scheme with reordered children
self.printTagStart(group, level=level)
# Separate different types of elements
other_elements = []
exchange_elements = []
convergence_elements = []
acceleration_elements = []

for child in sorted_scheme_children:
self.printElement(child, level=level + 1)
self.printTagEnd(group, level=level)
tag = str(child.tag)
if tag == 'exchange':
exchange_elements.append(child)
elif tag == 'relative-convergence-measure':
convergence_elements.append(child)
elif tag == 'acceleration:IQN-ILS':
acceleration_elements.append(child)
else:
other_elements.append(child)

# Print coupling-scheme opening tag
self.print(self.indent * level + "<{}>".format(group.tag))

# Print initial elements
initial_elements = [
elem for elem in other_elements
if str(elem.tag) in ['participants', 'max-time', 'time-window-size']
]
for child in initial_elements:
self.printElement(child, level + 1)

# Print convergence measures first
if convergence_elements:
if initial_elements:
self.print()
for conv in convergence_elements:
self.printElement(conv, level + 1)

# Print exchanges
if exchange_elements:
if initial_elements or convergence_elements:
self.print()
for exchange in exchange_elements:
self.printElement(exchange, level + 1)

# Print max-iterations if present
max_iterations = [
elem for elem in other_elements
if str(elem.tag) == 'max-iterations'
]
if max_iterations:
if exchange_elements or convergence_elements or initial_elements:
self.print()
for child in max_iterations:
self.printElement(child, level + 1)

# Print acceleration elements
if acceleration_elements:
if exchange_elements or convergence_elements or max_iterations or initial_elements:
self.print()
for child in acceleration_elements:
self.printElement(child, level + 1)

# Close coupling-scheme tag
self.print("{}</{}>"
.format(self.indent * level, group.tag))

# Add newline after coupling-scheme if not the last element
if i < last:
self.print()

continue

# Default handling for other elements
else:
self.printElement(group, level=level)
# Print the element normally
self.printElement(group, level=level)

# Add a newline between groups, except for the last group or comments
if not (isComment(group) or (i == last)):
# Add an extra newline between top-level groups
if i < last:
self.print()

@staticmethod
Expand Down
3 changes: 3 additions & 0 deletions format_precice_config_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def custom_sort_key(elem):
if str(child.tag) in ['write-mesh', 'read-mesh']:
self.print()
self.printTagEnd(group, level=level)
self.print()

# Special handling for coupling-scheme to pair relative-convergence-measure and exchange
elif 'coupling-scheme' in str(group.tag):
Expand All @@ -183,10 +184,12 @@ def custom_sort_key(elem):
for child in sorted_scheme_children:
self.printElement(child, level=level + 1)
self.printTagEnd(group, level=level)
self.print()

# Default handling for other elements
else:
self.printElement(group, level=level)
self.print()

# Add a newline between groups, except for the last group or comments
if not (isComment(group) or (i == last)):
Expand Down

0 comments on commit 198392d

Please sign in to comment.