@@ -216,7 +216,11 @@ impl ClauseHeader<'_> {
216216 . decorator_list
217217 . last ( )
218218 . map_or_else ( || header. start ( ) , Ranged :: end) ;
219- find_keyword ( start_position, SimpleTokenKind :: Class , source)
219+ find_keyword (
220+ StartPosition :: ClauseStart ( start_position) ,
221+ SimpleTokenKind :: Class ,
222+ source,
223+ )
220224 }
221225 ClauseHeader :: Function ( header) => {
222226 let start_position = header
@@ -228,21 +232,39 @@ impl ClauseHeader<'_> {
228232 } else {
229233 SimpleTokenKind :: Def
230234 } ;
231- find_keyword ( start_position, keyword, source)
235+ find_keyword ( StartPosition :: ClauseStart ( start_position) , keyword, source)
232236 }
233- ClauseHeader :: If ( header) => find_keyword ( header. start ( ) , SimpleTokenKind :: If , source) ,
237+ ClauseHeader :: If ( header) => find_keyword (
238+ StartPosition :: clause_start ( header) ,
239+ SimpleTokenKind :: If ,
240+ source,
241+ ) ,
234242 ClauseHeader :: ElifElse ( ElifElseClause {
235243 test : None , range, ..
236- } ) => find_keyword ( range. start ( ) , SimpleTokenKind :: Else , source) ,
244+ } ) => find_keyword (
245+ StartPosition :: clause_start ( range) ,
246+ SimpleTokenKind :: Else ,
247+ source,
248+ ) ,
237249 ClauseHeader :: ElifElse ( ElifElseClause {
238250 test : Some ( _) ,
239251 range,
240252 ..
241- } ) => find_keyword ( range. start ( ) , SimpleTokenKind :: Elif , source) ,
242- ClauseHeader :: Try ( header) => find_keyword ( header. start ( ) , SimpleTokenKind :: Try , source) ,
243- ClauseHeader :: ExceptHandler ( header) => {
244- find_keyword ( header. start ( ) , SimpleTokenKind :: Except , source)
245- }
253+ } ) => find_keyword (
254+ StartPosition :: clause_start ( range) ,
255+ SimpleTokenKind :: Elif ,
256+ source,
257+ ) ,
258+ ClauseHeader :: Try ( header) => find_keyword (
259+ StartPosition :: clause_start ( header) ,
260+ SimpleTokenKind :: Try ,
261+ source,
262+ ) ,
263+ ClauseHeader :: ExceptHandler ( header) => find_keyword (
264+ StartPosition :: clause_start ( header) ,
265+ SimpleTokenKind :: Except ,
266+ source,
267+ ) ,
246268 ClauseHeader :: TryFinally ( header) => {
247269 let last_statement = header
248270 . orelse
@@ -253,36 +275,42 @@ impl ClauseHeader<'_> {
253275 . unwrap ( ) ;
254276
255277 find_keyword (
256- after_optional_semicolon ( last_statement. end ( ) , source ) ,
278+ StartPosition :: LastStatement ( last_statement. end ( ) ) ,
257279 SimpleTokenKind :: Finally ,
258280 source,
259281 )
260282 }
261- ClauseHeader :: Match ( header) => {
262- find_keyword ( header. start ( ) , SimpleTokenKind :: Match , source)
263- }
264- ClauseHeader :: MatchCase ( header) => {
265- find_keyword ( header. start ( ) , SimpleTokenKind :: Case , source)
266- }
283+ ClauseHeader :: Match ( header) => find_keyword (
284+ StartPosition :: clause_start ( header) ,
285+ SimpleTokenKind :: Match ,
286+ source,
287+ ) ,
288+ ClauseHeader :: MatchCase ( header) => find_keyword (
289+ StartPosition :: clause_start ( header) ,
290+ SimpleTokenKind :: Case ,
291+ source,
292+ ) ,
267293 ClauseHeader :: For ( header) => {
268294 let keyword = if header. is_async {
269295 SimpleTokenKind :: Async
270296 } else {
271297 SimpleTokenKind :: For
272298 } ;
273- find_keyword ( header. start ( ) , keyword, source)
274- }
275- ClauseHeader :: While ( header) => {
276- find_keyword ( header. start ( ) , SimpleTokenKind :: While , source)
299+ find_keyword ( StartPosition :: clause_start ( header) , keyword, source)
277300 }
301+ ClauseHeader :: While ( header) => find_keyword (
302+ StartPosition :: clause_start ( header) ,
303+ SimpleTokenKind :: While ,
304+ source,
305+ ) ,
278306 ClauseHeader :: With ( header) => {
279307 let keyword = if header. is_async {
280308 SimpleTokenKind :: Async
281309 } else {
282310 SimpleTokenKind :: With
283311 } ;
284312
285- find_keyword ( header . start ( ) , keyword, source)
313+ find_keyword ( StartPosition :: clause_start ( header ) , keyword, source)
286314 }
287315 ClauseHeader :: OrElse ( header) => match header {
288316 ElseClause :: Try ( try_stmt) => {
@@ -294,14 +322,14 @@ impl ClauseHeader<'_> {
294322 . unwrap ( ) ;
295323
296324 find_keyword (
297- after_optional_semicolon ( last_statement. end ( ) , source ) ,
325+ StartPosition :: LastStatement ( last_statement. end ( ) ) ,
298326 SimpleTokenKind :: Else ,
299327 source,
300328 )
301329 }
302330 ElseClause :: For ( StmtFor { body, .. } )
303331 | ElseClause :: While ( StmtWhile { body, .. } ) => find_keyword (
304- after_optional_semicolon ( body. last ( ) . unwrap ( ) . end ( ) , source ) ,
332+ StartPosition :: LastStatement ( body. last ( ) . unwrap ( ) . end ( ) ) ,
305333 SimpleTokenKind :: Else ,
306334 source,
307335 ) ,
@@ -444,16 +472,41 @@ impl Format<PyFormatContext<'_>> for FormatClauseBody<'_> {
444472 }
445473}
446474
447- /// Finds the range of `keyword` starting the search at `start_position`. Expects only trivia between
448- /// the `start_position` and the `keyword` token.
475+ /// Finds the range of `keyword` starting the search at `start_position`.
476+ ///
477+ /// If `start_position` is a `StartPosition::LastStatement` (the end of the
478+ /// previous statement), the search skips one optional semicolon that follows it.
479+ /// Other than this, we expect only trivia between the `start_position`
480+ /// and the keyword.
449481fn find_keyword (
450- start_position : TextSize ,
482+ start_position : StartPosition ,
451483 keyword : SimpleTokenKind ,
452484 source : & str ,
453485) -> FormatResult < TextRange > {
454- let mut tokenizer = SimpleTokenizer :: starts_at ( start_position, source) . skip_trivia ( ) ;
486+ let next_token = match start_position {
487+ StartPosition :: ClauseStart ( text_size) => SimpleTokenizer :: starts_at ( text_size, source)
488+ . skip_trivia ( )
489+ . next ( ) ,
490+ StartPosition :: LastStatement ( text_size) => {
491+ let mut tokenizer = SimpleTokenizer :: starts_at ( text_size, source) . skip_trivia ( ) ;
492+
493+ let mut token = tokenizer. next ( ) ;
494+
495+ // If the last statement ends with a semi-colon, skip it.
496+ if matches ! (
497+ token,
498+ Some ( SimpleToken {
499+ kind: SimpleTokenKind :: Semi ,
500+ ..
501+ } )
502+ ) {
503+ token = tokenizer. next ( ) ;
504+ }
505+ token
506+ }
507+ } ;
455508
456- match tokenizer . next ( ) {
509+ match next_token {
457510 Some ( token) if token. kind ( ) == keyword => Ok ( token. range ( ) ) ,
458511 Some ( other) => {
459512 debug_assert ! (
@@ -476,15 +529,32 @@ fn find_keyword(
476529 }
477530}
478531
479- fn after_optional_semicolon ( end_of_statement : TextSize , source : & str ) -> TextSize {
480- let mut tokenizer = SimpleTokenizer :: starts_at ( end_of_statement, source) ;
532+ /// Position from which `find_keyword` starts searching for a clause keyword.
533+ ///
534+ /// Can either be the beginning of the clause header
535+ /// or the end of the last statement preceding the clause.
536+ #[ derive( Clone , Copy ) ]
537+ enum StartPosition {
538+ /// The beginning of a clause header.
539+ ClauseStart ( TextSize ) ,
540+ /// The end of the last statement in the suite preceding a clause.
541+ /// Used for `finally` and for the `else` of `try`, `for` and `while`.
542+ /// For example, for the `else` of a `while` loop:
543+ /// ```python
544+ /// while cond:
545+ /// a
546+ /// b
547+ /// c;
548+ /// # ...^here
549+ /// else:
550+ /// d
551+ /// ```
552+ LastStatement ( TextSize ) ,
553+ }
481554
482- if let Some ( tok) = tokenizer. next ( )
483- && tok. kind ( ) == SimpleTokenKind :: Semi
484- {
485- tok. end ( )
486- } else {
487- end_of_statement
555+ impl StartPosition {
556+ fn clause_start ( ranged : impl Ranged ) -> Self { // builds a `ClauseStart` from the start offset of `ranged`
557+ Self :: ClauseStart ( ranged. start ( ) )
488558 }
489559}
490560
0 commit comments