1
- import React , { FC , ReactElement , cloneElement , useCallback , useEffect , useRef } from 'react'
2
- import { mergeRefs , isTabbable } from './utils'
3
- import { TABBABLE_SELECTOR } from './const'
1
+ import React , { FC , ReactElement , cloneElement , useEffect , useRef } from 'react'
2
+ import { mergeRefs , focusableChildren } from './utils'
4
3
5
4
export interface CFocusTrapProps {
6
5
/**
@@ -12,6 +11,13 @@ export interface CFocusTrapProps {
12
11
*/
13
12
active ?: boolean
14
13
14
+ /**
15
+ * Additional container elements to include in the focus trap.
16
+ * Useful for floating elements like tooltips or popovers that are
17
+ * rendered outside the main container but should be part of the trap.
18
+ */
19
+ additionalContainer ?: React . RefObject < HTMLElement | null >
20
+
15
21
/**
16
22
* Single React element that renders a DOM node and forwards refs properly.
17
23
* The focus trap will be applied to this element and all its focusable descendants.
@@ -61,6 +67,7 @@ export interface CFocusTrapProps {
61
67
62
68
export const CFocusTrap : FC < CFocusTrapProps > = ( {
63
69
active = true ,
70
+ additionalContainer,
64
71
children,
65
72
focusFirstElement = false ,
66
73
onActivate,
@@ -69,141 +76,176 @@ export const CFocusTrap: FC<CFocusTrapProps> = ({
69
76
} ) => {
70
77
const containerRef = useRef < HTMLElement | null > ( null )
71
78
const prevFocusedRef = useRef < HTMLElement | null > ( null )
72
- const addedTabIndexRef = useRef < boolean > ( false )
73
79
const isActiveRef = useRef < boolean > ( false )
74
- const focusingRef = useRef < boolean > ( false )
75
-
76
- const getTabbables = useCallback ( ( ) : HTMLElement [ ] => {
77
- const container = containerRef . current
78
- if ( ! container ) {
79
- return [ ]
80
- }
81
-
82
- // eslint-disable-next-line unicorn/prefer-spread
83
- const candidates = Array . from ( container . querySelectorAll < HTMLElement > ( TABBABLE_SELECTOR ) )
84
- return candidates . filter ( ( el ) => isTabbable ( el ) )
85
- } , [ ] )
86
-
87
- const focusFirst = useCallback ( ( ) => {
88
- const container = containerRef . current
89
- if ( ! container || focusingRef . current ) {
90
- return
91
- }
92
-
93
- focusingRef . current = true
94
-
95
- const tabbables = getTabbables ( )
96
- const target = focusFirstElement ? ( tabbables [ 0 ] ?? container ) : container
97
- // Ensure root can receive focus if there are no tabbables
98
- if ( target === container && container . getAttribute ( 'tabindex' ) == null ) {
99
- container . setAttribute ( 'tabindex' , '-1' )
100
- addedTabIndexRef . current = true
101
- }
102
-
103
- target . focus ( { preventScroll : true } )
104
-
105
- // Reset the flag after a short delay to allow the focus event to complete
106
- setTimeout ( ( ) => {
107
- focusingRef . current = false
108
- } , 0 )
109
- } , [ getTabbables , focusFirstElement ] )
80
+ const lastTabNavDirectionRef = useRef < 'forward' | 'backward' > ( 'forward' )
81
+ const tabEventSourceRef = useRef < HTMLElement | null > ( null )
110
82
111
83
useEffect ( ( ) => {
112
84
const container = containerRef . current
85
+ const _additionalContainer = additionalContainer ?. current || null
86
+
113
87
if ( ! active || ! container ) {
114
88
if ( isActiveRef . current ) {
115
89
// Deactivate cleanup
116
- if ( restoreFocus && prevFocusedRef . current && document . contains ( prevFocusedRef . current ) ) {
90
+ if ( restoreFocus && prevFocusedRef . current ?. isConnected ) {
117
91
prevFocusedRef . current . focus ( { preventScroll : true } )
118
92
}
119
93
120
- if ( addedTabIndexRef . current ) {
121
- container ?. removeAttribute ( 'tabindex' )
122
- addedTabIndexRef . current = false
123
- }
124
-
125
94
onDeactivate ?.( )
126
95
isActiveRef . current = false
96
+ prevFocusedRef . current = null
127
97
}
128
98
129
99
return
130
100
}
131
101
102
+ // Remember focused element BEFORE we move focus into the trap
103
+ prevFocusedRef . current = document . activeElement as HTMLElement | null
104
+
132
105
// Activating...
133
106
isActiveRef . current = true
107
+
108
+ // Set initial focus
109
+ if ( focusFirstElement ) {
110
+ const elements = focusableChildren ( container )
111
+ if ( elements . length > 0 ) {
112
+ elements [ 0 ] . focus ( { preventScroll : true } )
113
+ } else {
114
+ // Fallback to container if no focusable elements
115
+ container . focus ( { preventScroll : true } )
116
+ }
117
+ } else {
118
+ container . focus ( { preventScroll : true } )
119
+ }
120
+
134
121
onActivate ?.( )
135
122
136
- // Remember focused element BEFORE we move focus into the trap
137
- prevFocusedRef . current = ( document . activeElement as HTMLElement ) ?? null
123
+ const handleFocusIn = ( event : FocusEvent ) => {
124
+ // Only handle focus events from tab navigation
125
+ if ( containerRef . current !== tabEventSourceRef . current ) {
126
+ return
127
+ }
138
128
139
- // Move focus inside if focus is outside the container
140
- if ( ! container . contains ( document . activeElement ) ) {
141
- focusFirst ( )
142
- }
129
+ const target = event . target as Node
143
130
144
- const handleKeyDown = ( e : KeyboardEvent ) => {
145
- if ( e . key !== 'Tab' ) {
131
+ // Allow focus within container
132
+ if ( target === document || target === container || container . contains ( target ) ) {
146
133
return
147
134
}
148
135
149
- const tabbables = getTabbables ( )
150
- const current = document . activeElement as HTMLElement | null
136
+ // Allow focus within additional elements
137
+ if (
138
+ _additionalContainer &&
139
+ ( target === _additionalContainer || _additionalContainer . contains ( target ) )
140
+ ) {
141
+ return
142
+ }
151
143
152
- if ( tabbables . length === 0 ) {
144
+ // Focus escaped, bring it back
145
+ const elements = focusableChildren ( container )
146
+
147
+ if ( elements . length === 0 ) {
153
148
container . focus ( { preventScroll : true } )
154
- e . preventDefault ( )
149
+ } else if ( lastTabNavDirectionRef . current === 'backward' ) {
150
+ elements . at ( - 1 ) ?. focus ( { preventScroll : true } )
151
+ } else {
152
+ elements [ 0 ] . focus ( { preventScroll : true } )
153
+ }
154
+ }
155
+
156
+ const handleKeyDown = ( event : KeyboardEvent ) => {
157
+ if ( event . key !== 'Tab' ) {
158
+ return
159
+ }
160
+
161
+ tabEventSourceRef . current = container
162
+ lastTabNavDirectionRef . current = event . shiftKey ? 'backward' : 'forward'
163
+
164
+ if ( ! _additionalContainer ) {
155
165
return
156
166
}
157
167
158
- const first = tabbables [ 0 ]
159
- const last = tabbables . at ( - 1 ) !
168
+ const containerElements = focusableChildren ( container )
169
+ const additionalElements = focusableChildren ( _additionalContainer )
160
170
161
- if ( e . shiftKey ) {
162
- if ( ! current || ! container . contains ( current ) || current === first ) {
163
- last . focus ( { preventScroll : true } )
164
- e . preventDefault ( )
171
+ if ( containerElements . length === 0 && additionalElements . length === 0 ) {
172
+ // No focusable elements, prevent tab
173
+ event . preventDefault ( )
174
+ return
175
+ }
176
+
177
+ const activeElement = document . activeElement as HTMLElement
178
+ const isInContainer = containerElements . includes ( activeElement )
179
+ const isInAdditional = additionalElements . includes ( activeElement )
180
+
181
+ // Handle tab navigation between container and additional elements
182
+ if ( isInContainer ) {
183
+ const index = containerElements . indexOf ( activeElement )
184
+
185
+ if (
186
+ ! event . shiftKey &&
187
+ index === containerElements . length - 1 &&
188
+ additionalElements . length > 0
189
+ ) {
190
+ // Tab forward from last container element to first additional element
191
+ event . preventDefault ( )
192
+ additionalElements [ 0 ] . focus ( { preventScroll : true } )
193
+ } else if ( event . shiftKey && index === 0 && additionalElements . length > 0 ) {
194
+ // Tab backward from first container element to last additional element
195
+ event . preventDefault ( )
196
+ additionalElements . at ( - 1 ) ?. focus ( { preventScroll : true } )
165
197
}
166
- } else {
167
- if ( ! current || ! container . contains ( current ) || current === last ) {
168
- first . focus ( { preventScroll : true } )
169
- e . preventDefault ( )
198
+ } else if ( isInAdditional ) {
199
+ const index = additionalElements . indexOf ( activeElement )
200
+
201
+ if (
202
+ ! event . shiftKey &&
203
+ index === additionalElements . length - 1 &&
204
+ containerElements . length > 0
205
+ ) {
206
+ // Tab forward from last additional element to first container element
207
+ event . preventDefault ( )
208
+ containerElements [ 0 ] . focus ( { preventScroll : true } )
209
+ } else if ( event . shiftKey && index === 0 && containerElements . length > 0 ) {
210
+ // Tab backward from first additional element to last container element
211
+ event . preventDefault ( )
212
+ containerElements . at ( - 1 ) ?. focus ( { preventScroll : true } )
170
213
}
171
214
}
172
215
}
173
216
174
- const handleFocusIn = ( e : FocusEvent ) => {
175
- const target = e . target as Node
176
- if ( ! container . contains ( target ) && ! focusingRef . current ) {
177
- // Redirect stray focus back into the trap
178
- focusFirst ( )
179
- }
217
+ // Add event listeners
218
+ container . addEventListener ( 'keydown' , handleKeyDown , true )
219
+ if ( _additionalContainer ) {
220
+ _additionalContainer . addEventListener ( 'keydown' , handleKeyDown , true )
180
221
}
181
-
182
- document . addEventListener ( 'keydown' , handleKeyDown , true )
183
222
document . addEventListener ( 'focusin' , handleFocusIn , true )
184
223
224
+ // Cleanup function
185
225
return ( ) => {
186
- document . removeEventListener ( 'keydown' , handleKeyDown , true )
226
+ container . removeEventListener ( 'keydown' , handleKeyDown , true )
227
+ if ( _additionalContainer ) {
228
+ _additionalContainer . removeEventListener ( 'keydown' , handleKeyDown , true )
229
+ }
187
230
document . removeEventListener ( 'focusin' , handleFocusIn , true )
188
231
189
232
// On unmount (also considered deactivation)
190
- if ( restoreFocus && prevFocusedRef . current && document . contains ( prevFocusedRef . current ) ) {
233
+ if ( restoreFocus && prevFocusedRef . current ?. isConnected ) {
191
234
prevFocusedRef . current . focus ( { preventScroll : true } )
192
235
}
193
236
194
- if ( addedTabIndexRef . current ) {
195
- container . removeAttribute ( 'tabindex' )
196
- addedTabIndexRef . current = false
237
+ if ( isActiveRef . current ) {
238
+ onDeactivate ?. ( )
239
+ isActiveRef . current = false
197
240
}
198
241
199
- onDeactivate ?.( )
200
- isActiveRef . current = false
242
+ prevFocusedRef . current = null
201
243
}
202
- } , [ active , focusFirst , getTabbables , onActivate , onDeactivate , restoreFocus ] )
244
+ } , [ active , additionalContainer , focusFirstElement , onActivate , onDeactivate , restoreFocus ] )
203
245
204
- // Attach our ref to the ONLY child — no extra wrappers.
246
+ // Attach our ref to the ONLY child — no extra wrappers
205
247
const onlyChild = React . Children . only ( children )
206
- const childRef = ( onlyChild as ReactElement & { ref ?: React . Ref < HTMLElement > } ) . ref
248
+ const childRef = ( onlyChild as React . ReactElement & { ref ?: React . Ref < HTMLElement > } ) . ref
207
249
const mergedRef = mergeRefs ( childRef , ( node : HTMLElement | null ) => {
208
250
containerRef . current = node
209
251
} )
0 commit comments