@@ -12,24 +12,45 @@ import { DataViewTableHead } from '../DataViewTableHead';
12
12
import { DataViewTh , DataViewTrTree , isDataViewTdObject } from '../DataViewTable' ;
13
13
import { DataViewState } from '../DataView/DataView' ;
14
14
15
- const getDescendants = ( node : DataViewTrTree ) : DataViewTrTree [ ] => ( ! node . children || ! node . children . length ) ? [ node ] : node . children . flatMap ( getDescendants ) ;
16
-
17
- const isNodeChecked = ( node : DataViewTrTree , isSelected : ( node : DataViewTrTree ) => boolean ) => {
18
- let allSelected = true ;
19
- let someSelected = false ;
20
-
21
- for ( const descendant of getDescendants ( node ) ) {
22
- const selected = ! ! isSelected ?.( descendant ) ;
23
-
24
- someSelected ||= selected ;
25
- allSelected &&= selected ;
26
-
27
- if ( ! allSelected && someSelected ) { return null }
28
- }
29
-
30
- return allSelected ;
15
+ const getNodesAffectedBySelection = (
16
+ allRows : DataViewTrTree [ ] ,
17
+ node : DataViewTrTree ,
18
+ isChecking : boolean ,
19
+ isSelected ?: ( item : DataViewTrTree ) => boolean
20
+ ) : DataViewTrTree [ ] => {
21
+
22
+ const getDescendants = ( node : DataViewTrTree ) : DataViewTrTree [ ] =>
23
+ node . children ? node . children . flatMap ( getDescendants ) . concat ( node ) : [ node ] ;
24
+
25
+ const findParent = ( child : DataViewTrTree , rows : DataViewTrTree [ ] ) : DataViewTrTree | undefined =>
26
+ rows . find ( row => row . children ?. some ( c => c === child ) ) ??
27
+ rows . flatMap ( row => row . children ?? [ ] ) . map ( c => findParent ( child , [ c ] ) ) . find ( p => p ) ;
28
+
29
+ const getAncestors = ( node : DataViewTrTree ) : DataViewTrTree [ ] => {
30
+ const ancestors : DataViewTrTree [ ] = [ ] ;
31
+ let parent = findParent ( node , allRows ) ;
32
+ while ( parent ) {
33
+ ancestors . push ( parent ) ;
34
+ parent = findParent ( parent , allRows ) ;
35
+ }
36
+ return ancestors ;
37
+ } ;
38
+
39
+ const affectedNodes = new Set ( [ node , ...getDescendants ( node ) ] ) ;
40
+
41
+ getAncestors ( node ) . forEach ( ancestor => {
42
+ const allChildrenSelected = ancestor . children ?. every ( child => isSelected ?.( child ) || affectedNodes . has ( child ) ) ;
43
+ const anyChildAffected = ancestor . children ?. some ( child => affectedNodes . has ( child ) || child . id === node . id ) ;
44
+
45
+ if ( isChecking ? ! isSelected ?.( ancestor ) && allChildrenSelected : isSelected ?.( ancestor ) && anyChildAffected ) {
46
+ affectedNodes . add ( ancestor ) ;
47
+ }
48
+ } ) ;
49
+
50
+ return Array . from ( affectedNodes ) ;
31
51
} ;
32
52
53
+
33
54
/** extends TableProps */
34
55
export interface DataViewTableTreeProps extends Omit < TableProps , 'onSelect' | 'rows' > {
35
56
/** Columns definition */
@@ -83,7 +104,7 @@ export const DataViewTableTree: React.FC<DataViewTableTreeProps> = ({
83
104
}
84
105
const isExpanded = expandedNodeIds . includes ( node . id ) ;
85
106
const isDetailsExpanded = expandedDetailsNodeNames . includes ( node . id ) ;
86
- const isChecked = isSelected && isNodeChecked ( node , isSelected ) ;
107
+ const isChecked = isSelected ?. ( node ) ;
87
108
let icon = leafIcon ;
88
109
if ( node . children ) {
89
110
icon = isExpanded ? expandedIcon : collapsedIcon ;
@@ -100,7 +121,7 @@ export const DataViewTableTree: React.FC<DataViewTableTreeProps> = ({
100
121
const otherDetailsExpandedNodeIds = prevDetailsExpanded . filter ( id => id !== node . id ) ;
101
122
return isDetailsExpanded ? otherDetailsExpandedNodeIds : [ ...otherDetailsExpandedNodeIds , node . id ] ;
102
123
} ) ,
103
- onCheckChange : ( isSelectDisabled ?.( node ) || ! onSelect ) ? undefined : ( _event , isChecking ) => onSelect ?.( isChecking , getDescendants ( node ) ) ,
124
+ onCheckChange : ( isSelectDisabled ?.( node ) || ! onSelect ) ? undefined : ( _event , isChecking ) => onSelect ?.( isChecking , getNodesAffectedBySelection ( rows , node , isChecking , isSelected ) ) ,
104
125
rowIndex,
105
126
props : {
106
127
isExpanded,
0 commit comments