@@ -17,8 +17,7 @@ namespace ts.refactor.inferFunctionReturnType {
17
17
function getEditsForAction ( context : RefactorContext ) : RefactorEditInfo | undefined {
18
18
const info = getInfo ( context ) ;
19
19
if ( info && ! isRefactorErrorInfo ( info ) ) {
20
- const edits = textChanges . ChangeTracker . with ( context , t =>
21
- t . tryInsertTypeAnnotation ( context . file , info . declaration , info . returnTypeNode ) ) ;
20
+ const edits = textChanges . ChangeTracker . with ( context , t => doChange ( context . file , t , info . declaration , info . returnTypeNode ) ) ;
22
21
return { renameFilename : undefined , renameLocation : undefined , edits } ;
23
22
}
24
23
return undefined ;
@@ -55,6 +54,19 @@ namespace ts.refactor.inferFunctionReturnType {
55
54
returnTypeNode : TypeNode ;
56
55
}
57
56
57
+ function doChange ( sourceFile : SourceFile , changes : textChanges . ChangeTracker , declaration : ConvertibleDeclaration , typeNode : TypeNode ) {
58
+ const closeParen = findChildOfKind ( declaration , SyntaxKind . CloseParenToken , sourceFile ) ;
59
+ const needParens = isArrowFunction ( declaration ) && closeParen === undefined ;
60
+ const endNode = needParens ? first ( declaration . parameters ) : closeParen ;
61
+ if ( endNode ) {
62
+ if ( needParens ) {
63
+ changes . insertNodeBefore ( sourceFile , endNode , factory . createToken ( SyntaxKind . OpenParenToken ) ) ;
64
+ changes . insertNodeAfter ( sourceFile , endNode , factory . createToken ( SyntaxKind . CloseParenToken ) ) ;
65
+ }
66
+ changes . insertNodeAt ( sourceFile , endNode . end , typeNode , { prefix : ": " } ) ;
67
+ }
68
+ }
69
+
58
70
function getInfo ( context : RefactorContext ) : FunctionInfo | RefactorErrorInfo | undefined {
59
71
if ( isInJSFile ( context . file ) || ! refactorKindBeginsWith ( inferReturnTypeAction . kind , context . kind ) ) return ;
60
72
0 commit comments