Skip to content

Commit 57cab6d

Browse files
authored
Predict parameter values in the suggestion (#12984)
* Get the parameter value from the history. * Add a mock ps console for testing purpose. - The mock ps console will echo back most of the commands. So that we don't need to really execute the Az command on Azure to test the prediction.
1 parent 7893ac5 commit 57cab6d

14 files changed

+465
-106
lines changed

tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorServiceTests.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ public AzPredictorServiceTests(ModelFixture fixture)
3838
{
3939
this._fixture = fixture;
4040
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
41-
this._suggestionsPredictor = new Predictor(this._fixture.PredictionCollection[startHistory]);
42-
this._commandsPredictor = new Predictor(this._fixture.CommandCollection);
41+
this._suggestionsPredictor = new Predictor(this._fixture.PredictionCollection[startHistory], null);
42+
this._commandsPredictor = new Predictor(this._fixture.CommandCollection, null);
4343

4444
this._service = new MockAzPredictorService(startHistory, this._fixture.PredictionCollection[startHistory], this._fixture.CommandCollection);
4545
}

tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/Mocks/MockAzPredictorService.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public MockAzPredictorService(string history, IList<string> suggestions, IList<s
4040
}
4141

4242
/// <inheritdoc/>
43-
public override void RequestPredictions(string history)
43+
public override void RequestPredictions(IEnumerable<string> history)
4444
{
4545
this.IsPredictionRequested = true;
4646
}

tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/PredictorTests.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ public PredictorTests(ModelFixture fixture)
3434
{
3535
this._fixture = fixture;
3636
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
37-
this._predictor = new Predictor(this._fixture.PredictionCollection[startHistory]);
37+
this._predictor = new Predictor(this._fixture.PredictionCollection[startHistory], null);
3838
}
3939

4040
/// <summary>

tools/Az.Tools.Predictor/Az.Tools.Predictor.sln

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ VisualStudioVersion = 16.0.30426.262
55
MinimumVisualStudioVersion = 10.0.40219.1
66
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor", "Az.Tools.Predictor\Az.Tools.Predictor.csproj", "{E4A5F697-086C-4908-B90E-A31EE47ECF13}"
77
EndProject
8-
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
8+
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
9+
EndProject
10+
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MockPSConsole", "MockPSConsole\MockPSConsole.csproj", "{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}"
911
EndProject
1012
Global
1113
GlobalSection(SolutionConfigurationPlatforms) = preSolution
@@ -21,6 +23,10 @@ Global
2123
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Debug|Any CPU.Build.0 = Debug|Any CPU
2224
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Release|Any CPU.ActiveCfg = Release|Any CPU
2325
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Release|Any CPU.Build.0 = Release|Any CPU
26+
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
27+
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Debug|Any CPU.Build.0 = Debug|Any CPU
28+
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Release|Any CPU.ActiveCfg = Release|Any CPU
29+
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Release|Any CPU.Build.0 = Release|Any CPU
2430
EndGlobalSection
2531
GlobalSection(SolutionProperties) = preSolution
2632
HideSolutionNode = FALSE

tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ public sealed class AzPredictor : ICommandPredictor
5151
private const int SuggestionCountForTelemetry = 5;
5252
private const string ParameterValueMask = "***";
5353
private const char ParameterValueSeperator = ':';
54-
private const char ParameterIndicator = '-';
5554

5655
private static readonly string[] CommonParameters = new string[] { "location" };
5756

@@ -74,45 +73,45 @@ public void StartEarlyProcessing(IReadOnlyList<string> history)
7473
{
7574
if (history.Count > 0)
7675
{
77-
var historyLines = history.TakeLast(AzPredictorConstants.CommandHistoryCountToProcess).ToList();
76+
var historyLines = history.TakeLast(AzPredictorConstants.CommandHistoryCountToProcess);
7877

79-
while (historyLines.Count < AzPredictorConstants.CommandHistoryCountToProcess)
78+
while (historyLines.Count() < AzPredictorConstants.CommandHistoryCountToProcess)
8079
{
81-
historyLines.Insert(0, AzPredictorConstants.CommandHistoryPlaceholder);
80+
historyLines = historyLines.Prepend(AzPredictorConstants.CommandHistoryPlaceholder);
8281
}
8382

84-
for (int i = historyLines.Count - 1; i >= 0; --i)
85-
{
86-
var ast = Parser.ParseInput(historyLines[i], out Token[] tokens, out _);
87-
var commandAsts = ast.FindAll((ast) => ast is CommandAst, true);
83+
var commandAsts = historyLines.Select((h) =>
84+
{
85+
var ast = Parser.ParseInput(h, out Token[] tokens, out _);
86+
var allAsts = ast?.FindAll((ast) => ast is CommandAst, true);
87+
return allAsts?.LastOrDefault() as CommandAst;
88+
}).ToArray();
8889

89-
if (!commandAsts.Any())
90-
{
91-
historyLines[i] = AzPredictorConstants.CommandHistoryPlaceholder;
92-
continue;
93-
}
90+
var maskedHistoryLines = commandAsts.Select((c) =>
91+
{
92+
var commandName = c?.CommandElements?.FirstOrDefault().ToString();
9493

95-
var lastCommandAst = commandAsts.Last() as CommandAst;
96-
var lastCommand = lastCommandAst?.CommandElements?.FirstOrDefault()?.ToString();
94+
if (!_service.IsSupportedCommand(commandName))
95+
{
96+
return AzPredictorConstants.CommandHistoryPlaceholder;
97+
}
9798

98-
if (string.IsNullOrWhiteSpace(lastCommand) || !_service.IsSupportedCommand(lastCommand))
99-
{
100-
historyLines[i] = AzPredictorConstants.CommandHistoryPlaceholder;
101-
continue;
102-
}
99+
return AzPredictor.MaskCommandLine(c);
100+
});
103101

104-
historyLines[i] = MaskCommandLine(lastCommandAst);
102+
var lastMaskedHistoryLines = maskedHistoryLines.Last();
105103

106-
if (i == historyLines.Count - 1)
107-
{
108-
var suggestionIndex = _service.GetRankOfSuggestion(lastCommandAst, ast);
109-
var fallbackIndex = _service.GetRankOfFallback(lastCommandAst, ast);
110-
var topFiveSuggestion = _service.GetTopNSuggestions(AzPredictor.SuggestionCountForTelemetry);
111-
_telemetryClient.OnSuggestionForHistory(historyLines[i], suggestionIndex, fallbackIndex, topFiveSuggestion);
112-
}
104+
if (lastMaskedHistoryLines != AzPredictorConstants.CommandHistoryPlaceholder)
105+
{
106+
var commandName = (commandAsts.LastOrDefault()?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
107+
var suggestionIndex = _service.GetRankOfSuggestion(commandName);
108+
var fallbackIndex = _service.GetRankOfFallback(commandName);
109+
var topFiveSuggestion = _service.GetTopNSuggestions(AzPredictor.SuggestionCountForTelemetry);
110+
_telemetryClient.OnSuggestionForHistory(maskedHistoryLines.LastOrDefault(), suggestionIndex, fallbackIndex, topFiveSuggestion);
113111
}
114112

115-
_service.RequestPredictions(String.Join(AzPredictorConstants.CommandConcatenator, historyLines));
113+
_service.RecordHistory(commandAsts);
114+
_service.RequestPredictions(maskedHistoryLines);
116115
}
117116
}
118117

@@ -175,7 +174,12 @@ private static string MergeStrings(string a, string b)
175174
/// <param name="cmdAst">The last user input command</param>
176175
private static string MaskCommandLine(CommandAst cmdAst)
177176
{
178-
var commandElements = cmdAst.CommandElements;
177+
var commandElements = cmdAst?.CommandElements;
178+
179+
if (commandElements == null)
180+
{
181+
return null;
182+
}
179183

180184
if (commandElements.Count == 1)
181185
{
@@ -196,15 +200,15 @@ private static string MaskCommandLine(CommandAst cmdAst)
196200
if (param.Argument != null)
197201
{
198202
// Parameter is in the form of `-Name:name`
199-
_ = sb.Append(AzPredictor.ParameterIndicator)
203+
_ = sb.Append(AzPredictorConstants.ParameterIndicator)
200204
.Append(param.ParameterName)
201205
.Append(AzPredictor.ParameterValueSeperator)
202206
.Append(AzPredictor.ParameterValueMask);
203207
}
204208
else
205209
{
206210
// Parameter is in the form of `-Name`
207-
_ = sb.Append(AzPredictor.ParameterIndicator)
211+
_ = sb.Append(AzPredictorConstants.ParameterIndicator)
208212
.Append(param.ParameterName)
209213
.Append(AzPredictorConstants.CommandParameterSeperator)
210214
.Append(AzPredictor.ParameterValueMask);
@@ -223,8 +227,8 @@ public class PredictorInitializer : IModuleAssemblyInitializer
223227
public void OnImport()
224228
{
225229
var settings = Settings.GetSettings();
226-
var azPredictorService = new AzPredictorService(settings.ServiceUri);
227230
var telemetryClient = new AzPredictorTelemetryClient();
231+
var azPredictorService = new AzPredictorService(settings.ServiceUri);
228232
var predictor = new AzPredictor(azPredictorService, telemetryClient);
229233
SubsystemManager.RegisterSubsystem<ICommandPredictor, AzPredictor>(predictor);
230234
}

tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorConstants.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ internal static class AzPredictorConstants
2727
/// <summary>
2828
/// The value to check to determine if it's an Az command.
2929
/// </summary>
30-
public const string AzCommandMoniktor = "az";
30+
public const string AzCommandMoniker = "-Az";
3131

3232
/// <summary>
3333
/// The character to use when we join the commands together.
@@ -54,6 +54,11 @@ internal static class AzPredictorConstants
5454
/// </summary>
5555
public const char CommandParameterSeperator = ' ';
5656

57+
/// <summary>
58+
/// The character that begins a parameter.
59+
/// </summary>
60+
public const char ParameterIndicator = '-';
61+
5762
/// <summary>
5863
/// The setting file name.
5964
/// </summary>

tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@
1212
// limitations under the License.
1313
// ----------------------------------------------------------------------------------
1414

15+
using Microsoft.WindowsAzure.Commands.Utilities.Common;
1516
using Newtonsoft.Json;
1617
using Newtonsoft.Json.Serialization;
1718
using System;
19+
using System.Collections.Concurrent;
1820
using System.Collections.Generic;
1921
using System.Linq;
2022
using System.Management.Automation.Language;
@@ -48,16 +50,15 @@ public sealed class RequestContext
4850
public PredictionRequestBody(string history) => this.History = history;
4951
};
5052

51-
private const int PredictionRequestInProgress = 1;
52-
private const int PredictionRequestNotInProgress = 0;
5353
private static readonly HttpClient _client = new HttpClient();
5454
private readonly string _commandsEndpoint;
5555
private readonly string _predictionsEndpoint;
5656
private volatile Tuple<string, Predictor> _historySuggestions; // The history and the prediction for that.
5757
private volatile Predictor _commands;
5858
private volatile string _history;
59-
private HashSet<string> _commandSet = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
59+
private HashSet<string> _commandSet;
6060
private CancellationTokenSource _predictionRequestCancellationSource;
61+
private ParameterValuePredictor _parameterValuePredictor = new ParameterValuePredictor();
6162

6263
/// <summary>
6364
/// The AzPredictor service interacts with the Aladdin service specified in serviceUri.
@@ -144,53 +145,58 @@ public Tuple<string, PredictionSource> GetSuggestion(Ast input, CancellationToke
144145
}
145146

146147
/// <inheritdoc/>
147-
public virtual void RequestPredictions(string history)
148+
public virtual void RequestPredictions(IEnumerable<string> history)
148149
{
149150
// Even if it's called multiple times, we only need to keep the one for the latest history.
150151

151152
this._predictionRequestCancellationSource?.Cancel();
152153
this._predictionRequestCancellationSource = new CancellationTokenSource();
153154
var cancellationToken = this._predictionRequestCancellationSource.Token;
154-
this._history = history;
155+
var localHistory = string.Join(AzPredictorConstants.CommandConcatenator, history);
156+
this._history = localHistory;
155157

156158
// We don't need to block on the task. We send the HTTP request and update prediction list at the background.
157159
Task.Run(async () => {
158-
var requestBody = JsonConvert.SerializeObject(new PredictionRequestBody(history));
160+
var requestBody = JsonConvert.SerializeObject(new PredictionRequestBody(localHistory));
159161
var httpResponseMessage = await _client.PostAsync(this._predictionsEndpoint, new StringContent(requestBody, Encoding.UTF8, "application/json"), cancellationToken);
160162

161163
var reply = await httpResponseMessage.Content.ReadAsStringAsync(cancellationToken);
162164
var suggestionsList = JsonConvert.DeserializeObject<List<string>>(reply);
163165

164-
this.SetSuggestionPredictor(history, suggestionsList);
166+
this.SetSuggestionPredictor(localHistory, suggestionsList);
165167
},
166168
cancellationToken);
167169
}
168170

169-
/// <summary>
170-
/// For logging purposes, get the rank of the user input in the model suggestions list.
171-
/// </summary>
172-
public int? GetRankOfSuggestion(CommandAst command, Ast input)
171+
/// <inheritdoc/>
172+
public virtual void RecordHistory(IEnumerable<CommandAst> history)
173+
{
174+
history.ForEach((h) => this._parameterValuePredictor.ProcessHistoryCommand(h));
175+
}
176+
177+
/// <inhericdoc/>
178+
public int? GetRankOfSuggestion(string commandName)
173179
{
174180
var historySuggestions = this._historySuggestions;
175-
return historySuggestions?.Item2?.GetCommandPrediction(command, input, CancellationToken.None).Item2;
181+
return historySuggestions?.Item2?.GetCommandPrediction(commandName, isCommandNameComplete: true, cancellationToken:CancellationToken.None).Item2;
176182
}
177183

178-
/// <inheritdoc/>
179-
public int? GetRankOfFallback(CommandAst command, Ast input)
184+
/// <inhericdoc/>
185+
public int? GetRankOfFallback(string commandName)
180186
{
181187
var commands = this._commands;
182-
return commands?.GetCommandPrediction(command, input, CancellationToken.None).Item2;
188+
return commands?.GetCommandPrediction(commandName, isCommandNameComplete:true, cancellationToken:CancellationToken.None).Item2;
183189
}
184190

185-
/// <inheritdoc/>
191+
/// <inhericdoc/>
186192
public IEnumerable<string> GetTopNSuggestions(int n)
187193
{
188194
var historySuggestions = this._historySuggestions;
189195
return historySuggestions?.Item2?.GetTopNPrediction(n);
190196
}
191197

192198
/// <inheritdoc/>
193-
public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && _commandSet.Contains(cmd);
199+
public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && (_commandSet?.Contains(cmd) == true);
194200

195201
/// <summary>
196202
/// Requests a list of popular commands from service. These commands are used as fallback suggestion
@@ -209,7 +215,10 @@ protected virtual void RequestCommands()
209215

210216
// Initialize predictions
211217
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
212-
RequestPredictions(startHistory);
218+
RequestPredictions(new string[] {
219+
AzPredictorConstants.CommandHistoryPlaceholder,
220+
AzPredictorConstants.CommandHistoryPlaceholder});
221+
213222
});
214223
}
215224

@@ -219,8 +228,8 @@ protected virtual void RequestCommands()
219228
/// <param name="commands">The command collection to set the predictor</param>
220229
protected void SetCommandsPredictor(IList<string> commands)
221230
{
222-
this._commands = new Predictor(commands);
223-
this._commandSet = new HashSet<string>(commands.Select(x => AzPredictorService.GetCommandName(x))); // this could be slow
231+
this._commands = new Predictor(commands, this._parameterValuePredictor);
232+
this._commandSet = commands.Select(x => AzPredictorService.GetCommandName(x)).ToHashSet<string>(StringComparer.OrdinalIgnoreCase); // this could be slow
224233

225234
}
226235

@@ -231,7 +240,7 @@ protected void SetCommandsPredictor(IList<string> commands)
231240
/// <param name="suggestions">The suggestion collection to set the predictor</param>
232241
protected void SetSuggestionPredictor(string history, IList<string> suggestions)
233242
{
234-
this._historySuggestions = Tuple.Create(history, new Predictor(suggestions));
243+
this._historySuggestions = Tuple.Create(history, new Predictor(suggestions, this._parameterValuePredictor));
235244
}
236245

237246
/// <summary>

tools/Az.Tools.Predictor/Az.Tools.Predictor/IAzPredictorService.cs

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,29 @@ public interface IAzPredictorService
3434
/// <summary>
3535
/// Requests predictions, given a history string.
3636
/// </summary>
37-
/// <param name="history">A history string could look like: "Get-AzContext -Name NAME\nSet-AzContext"</param>
38-
public void RequestPredictions(string history);
37+
/// <param name="history">A list of history commands</param>
38+
public void RequestPredictions(IEnumerable<string> history);
3939

4040
/// <summary>
41-
/// For logging purposes, get the rank of the user input in the model suggestions list.
41+
/// Record the history from PSReadLine.
4242
/// </summary>
43-
public int? GetRankOfSuggestion(CommandAst command, Ast input);
43+
/// <param name="history">A list of history commands</param>
44+
public void RecordHistory(IEnumerable<CommandAst> history);
4445

4546
/// <summary>
4647
/// Return true if command is part of known set of Az cmdlets, false otherwise.
4748
/// </summary>
4849
public bool IsSupportedCommand(string cmd);
4950

51+
/// <summary>
52+
/// For logging purposes, get the rank of the user input in the model suggestions list.
53+
/// </summary>
54+
public int? GetRankOfSuggestion(string commandName);
55+
5056
/// <summary>
5157
/// For logging purposes, get the rank of the user input in the fallback commands cache.
5258
/// </summary>
53-
public int? GetRankOfFallback(CommandAst command, Ast input);
59+
public int? GetRankOfFallback(string commandName);
5460

5561
/// <summary>
5662
/// For logging purposes, get the top N suggestions from the model suggestions list.

0 commit comments

Comments
 (0)