@@ -11,7 +11,7 @@ use crate::{
11
11
ted:: { self , Position } ,
12
12
AstNode , AstToken , Direction ,
13
13
SyntaxKind :: { ATTR , COMMENT , WHITESPACE } ,
14
- SyntaxNode ,
14
+ SyntaxNode , SyntaxToken ,
15
15
} ;
16
16
17
17
use super :: HasName ;
@@ -506,19 +506,7 @@ impl ast::RecordExprFieldList {
506
506
507
507
let position = match self . fields ( ) . last ( ) {
508
508
Some ( last_field) => {
509
- let comma = match last_field
510
- . syntax ( )
511
- . siblings_with_tokens ( Direction :: Next )
512
- . filter_map ( |it| it. into_token ( ) )
513
- . find ( |it| it. kind ( ) == T ! [ , ] )
514
- {
515
- Some ( it) => it,
516
- None => {
517
- let comma = ast:: make:: token ( T ! [ , ] ) ;
518
- ted:: insert ( Position :: after ( last_field. syntax ( ) ) , & comma) ;
519
- comma
520
- }
521
- } ;
509
+ let comma = get_or_insert_comma_after ( last_field. syntax ( ) ) ;
522
510
Position :: after ( comma)
523
511
}
524
512
None => match self . l_curly_token ( ) {
@@ -579,19 +567,8 @@ impl ast::RecordPatFieldList {
579
567
580
568
let position = match self . fields ( ) . last ( ) {
581
569
Some ( last_field) => {
582
- let comma = match last_field
583
- . syntax ( )
584
- . siblings_with_tokens ( Direction :: Next )
585
- . filter_map ( |it| it. into_token ( ) )
586
- . find ( |it| it. kind ( ) == T ! [ , ] )
587
- {
588
- Some ( it) => it,
589
- None => {
590
- let comma = ast:: make:: token ( T ! [ , ] ) ;
591
- ted:: insert ( Position :: after ( last_field. syntax ( ) ) , & comma) ;
592
- comma
593
- }
594
- } ;
570
+ let syntax = last_field. syntax ( ) ;
571
+ let comma = get_or_insert_comma_after ( syntax) ;
595
572
Position :: after ( comma)
596
573
}
597
574
None => match self . l_curly_token ( ) {
@@ -606,12 +583,53 @@ impl ast::RecordPatFieldList {
606
583
}
607
584
}
608
585
}
586
+
587
+ fn get_or_insert_comma_after ( syntax : & SyntaxNode ) -> SyntaxToken {
588
+ let comma = match syntax
589
+ . siblings_with_tokens ( Direction :: Next )
590
+ . filter_map ( |it| it. into_token ( ) )
591
+ . find ( |it| it. kind ( ) == T ! [ , ] )
592
+ {
593
+ Some ( it) => it,
594
+ None => {
595
+ let comma = ast:: make:: token ( T ! [ , ] ) ;
596
+ ted:: insert ( Position :: after ( syntax) , & comma) ;
597
+ comma
598
+ }
599
+ } ;
600
+ comma
601
+ }
602
+
609
603
impl ast:: StmtList {
610
604
pub fn push_front ( & self , statement : ast:: Stmt ) {
611
605
ted:: insert ( Position :: after ( self . l_curly_token ( ) . unwrap ( ) ) , statement. syntax ( ) ) ;
612
606
}
613
607
}
614
608
609
+ impl ast:: VariantList {
610
+ pub fn add_variant ( & self , variant : ast:: Variant ) {
611
+ let ( indent, position) = match self . variants ( ) . last ( ) {
612
+ Some ( last_item) => (
613
+ IndentLevel :: from_node ( last_item. syntax ( ) ) ,
614
+ Position :: after ( get_or_insert_comma_after ( last_item. syntax ( ) ) ) ,
615
+ ) ,
616
+ None => match self . l_curly_token ( ) {
617
+ Some ( l_curly) => {
618
+ normalize_ws_between_braces ( self . syntax ( ) ) ;
619
+ ( IndentLevel :: from_token ( & l_curly) + 1 , Position :: after ( & l_curly) )
620
+ }
621
+ None => ( IndentLevel :: single ( ) , Position :: last_child_of ( self . syntax ( ) ) ) ,
622
+ } ,
623
+ } ;
624
+ let elements: Vec < SyntaxElement < _ > > = vec ! [
625
+ make:: tokens:: whitespace( & format!( "{}{}" , "\n " , indent) ) . into( ) ,
626
+ variant. syntax( ) . clone( ) . into( ) ,
627
+ ast:: make:: token( T ![ , ] ) . into( ) ,
628
+ ] ;
629
+ ted:: insert_all ( position, elements) ;
630
+ }
631
+ }
632
+
615
633
fn normalize_ws_between_braces ( node : & SyntaxNode ) -> Option < ( ) > {
616
634
let l = node
617
635
. children_with_tokens ( )
@@ -661,6 +679,9 @@ impl<N: AstNode + Clone> Indent for N {}
661
679
mod tests {
662
680
use std:: fmt;
663
681
682
+ use stdx:: trim_indent;
683
+ use test_utils:: assert_eq_text;
684
+
664
685
use crate :: SourceFile ;
665
686
666
687
use super :: * ;
@@ -714,4 +735,100 @@ mod tests {
714
735
}" ,
715
736
) ;
716
737
}
738
+
739
+ #[ test]
740
+ fn add_variant_to_empty_enum ( ) {
741
+ let variant = make:: variant ( make:: name ( "Bar" ) , None ) . clone_for_update ( ) ;
742
+
743
+ check_add_variant (
744
+ r#"
745
+ enum Foo {}
746
+ "# ,
747
+ r#"
748
+ enum Foo {
749
+ Bar,
750
+ }
751
+ "# ,
752
+ variant,
753
+ ) ;
754
+ }
755
+
756
+ #[ test]
757
+ fn add_variant_to_non_empty_enum ( ) {
758
+ let variant = make:: variant ( make:: name ( "Baz" ) , None ) . clone_for_update ( ) ;
759
+
760
+ check_add_variant (
761
+ r#"
762
+ enum Foo {
763
+ Bar,
764
+ }
765
+ "# ,
766
+ r#"
767
+ enum Foo {
768
+ Bar,
769
+ Baz,
770
+ }
771
+ "# ,
772
+ variant,
773
+ ) ;
774
+ }
775
+
776
+ #[ test]
777
+ fn add_variant_with_tuple_field_list ( ) {
778
+ let variant = make:: variant (
779
+ make:: name ( "Baz" ) ,
780
+ Some ( ast:: FieldList :: TupleFieldList ( make:: tuple_field_list ( std:: iter:: once (
781
+ make:: tuple_field ( None , make:: ty ( "bool" ) ) ,
782
+ ) ) ) ) ,
783
+ )
784
+ . clone_for_update ( ) ;
785
+
786
+ check_add_variant (
787
+ r#"
788
+ enum Foo {
789
+ Bar,
790
+ }
791
+ "# ,
792
+ r#"
793
+ enum Foo {
794
+ Bar,
795
+ Baz(bool),
796
+ }
797
+ "# ,
798
+ variant,
799
+ ) ;
800
+ }
801
+
802
+ #[ test]
803
+ fn add_variant_with_record_field_list ( ) {
804
+ let variant = make:: variant (
805
+ make:: name ( "Baz" ) ,
806
+ Some ( ast:: FieldList :: RecordFieldList ( make:: record_field_list ( std:: iter:: once (
807
+ make:: record_field ( None , make:: name ( "x" ) , make:: ty ( "bool" ) ) ,
808
+ ) ) ) ) ,
809
+ )
810
+ . clone_for_update ( ) ;
811
+
812
+ check_add_variant (
813
+ r#"
814
+ enum Foo {
815
+ Bar,
816
+ }
817
+ "# ,
818
+ r#"
819
+ enum Foo {
820
+ Bar,
821
+ Baz { x: bool },
822
+ }
823
+ "# ,
824
+ variant,
825
+ ) ;
826
+ }
827
+
828
+ fn check_add_variant ( before : & str , expected : & str , variant : ast:: Variant ) {
829
+ let enum_ = ast_mut_from_text :: < ast:: Enum > ( before) ;
830
+ enum_. variant_list ( ) . map ( |it| it. add_variant ( variant) ) ;
831
+ let after = enum_. to_string ( ) ;
832
+ assert_eq_text ! ( & trim_indent( expected. trim( ) ) , & trim_indent( & after. trim( ) ) ) ;
833
+ }
717
834
}
0 commit comments