Skip to content

Predict parameter values in the suggestion #12984

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ public AzPredictorServiceTests(ModelFixture fixture)
{
this._fixture = fixture;
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
this._suggestionsPredictor = new Predictor(this._fixture.PredictionCollection[startHistory]);
this._commandsPredictor = new Predictor(this._fixture.CommandCollection);
this._suggestionsPredictor = new Predictor(this._fixture.PredictionCollection[startHistory], null);
this._commandsPredictor = new Predictor(this._fixture.CommandCollection, null);

this._service = new MockAzPredictorService(startHistory, this._fixture.PredictionCollection[startHistory], this._fixture.CommandCollection);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public MockAzPredictorService(string history, IList<string> suggestions, IList<s
}

/// <inheritdoc/>
public override void RequestPredictions(string history)
public override void RequestPredictions(IEnumerable<string> history)
{
this.IsPredictionRequested = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public PredictorTests(ModelFixture fixture)
{
this._fixture = fixture;
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
this._predictor = new Predictor(this._fixture.PredictionCollection[startHistory]);
this._predictor = new Predictor(this._fixture.PredictionCollection[startHistory], null);
}

/// <summary>
Expand Down
8 changes: 7 additions & 1 deletion tools/Az.Tools.Predictor/Az.Tools.Predictor.sln
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ VisualStudioVersion = 16.0.30426.262
MinimumVisualStudioVersion = 10.0.40219.1
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor", "Az.Tools.Predictor\Az.Tools.Predictor.csproj", "{E4A5F697-086C-4908-B90E-A31EE47ECF13}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MockPSConsole", "MockPSConsole\MockPSConsole.csproj", "{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand All @@ -21,6 +23,10 @@ Global
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Debug|Any CPU.Build.0 = Debug|Any CPU
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Release|Any CPU.ActiveCfg = Release|Any CPU
{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}.Release|Any CPU.Build.0 = Release|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Debug|Any CPU.Build.0 = Debug|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Release|Any CPU.ActiveCfg = Release|Any CPU
{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down
70 changes: 37 additions & 33 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ public sealed class AzPredictor : ICommandPredictor
private const int SuggestionCountForTelemetry = 5;
private const string ParameterValueMask = "***";
private const char ParameterValueSeperator = ':';
private const char ParameterIndicator = '-';

private static readonly string[] CommonParameters = new string[] { "location" };

Expand All @@ -74,45 +73,45 @@ public void StartEarlyProcessing(IReadOnlyList<string> history)
{
if (history.Count > 0)
{
var historyLines = history.TakeLast(AzPredictorConstants.CommandHistoryCountToProcess).ToList();
var historyLines = history.TakeLast(AzPredictorConstants.CommandHistoryCountToProcess);

while (historyLines.Count < AzPredictorConstants.CommandHistoryCountToProcess)
while (historyLines.Count() < AzPredictorConstants.CommandHistoryCountToProcess)
{
historyLines.Insert(0, AzPredictorConstants.CommandHistoryPlaceholder);
historyLines = historyLines.Prepend(AzPredictorConstants.CommandHistoryPlaceholder);
}

for (int i = historyLines.Count - 1; i >= 0; --i)
{
var ast = Parser.ParseInput(historyLines[i], out Token[] tokens, out _);
var commandAsts = ast.FindAll((ast) => ast is CommandAst, true);
var commandAsts = historyLines.Select((h) =>
{
var ast = Parser.ParseInput(h, out Token[] tokens, out _);
var allAsts = ast?.FindAll((ast) => ast is CommandAst, true);
return allAsts?.LastOrDefault() as CommandAst;
}).ToArray();

if (!commandAsts.Any())
{
historyLines[i] = AzPredictorConstants.CommandHistoryPlaceholder;
continue;
}
var maskedHistoryLines = commandAsts.Select((c) =>
{
var commandName = c?.CommandElements?.FirstOrDefault().ToString();

var lastCommandAst = commandAsts.Last() as CommandAst;
var lastCommand = lastCommandAst?.CommandElements?.FirstOrDefault()?.ToString();
if (!_service.IsSupportedCommand(commandName))
{
return AzPredictorConstants.CommandHistoryPlaceholder;
}

if (string.IsNullOrWhiteSpace(lastCommand) || !_service.IsSupportedCommand(lastCommand))
{
historyLines[i] = AzPredictorConstants.CommandHistoryPlaceholder;
continue;
}
return AzPredictor.MaskCommandLine(c);
});

historyLines[i] = MaskCommandLine(lastCommandAst);
var lastMaskedHistoryLines = maskedHistoryLines.Last();

if (i == historyLines.Count - 1)
{
var suggestionIndex = _service.GetRankOfSuggestion(lastCommandAst, ast);
var fallbackIndex = _service.GetRankOfFallback(lastCommandAst, ast);
var topFiveSuggestion = _service.GetTopNSuggestions(AzPredictor.SuggestionCountForTelemetry);
_telemetryClient.OnSuggestionForHistory(historyLines[i], suggestionIndex, fallbackIndex, topFiveSuggestion);
}
if (lastMaskedHistoryLines != AzPredictorConstants.CommandHistoryPlaceholder)
{
var commandName = (commandAsts.LastOrDefault()?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var suggestionIndex = _service.GetRankOfSuggestion(commandName);
var fallbackIndex = _service.GetRankOfFallback(commandName);
var topFiveSuggestion = _service.GetTopNSuggestions(AzPredictor.SuggestionCountForTelemetry);
_telemetryClient.OnSuggestionForHistory(maskedHistoryLines.LastOrDefault(), suggestionIndex, fallbackIndex, topFiveSuggestion);
}

_service.RequestPredictions(String.Join(AzPredictorConstants.CommandConcatenator, historyLines));
_service.RecordHistory(commandAsts);
_service.RequestPredictions(maskedHistoryLines);
}
}

Expand Down Expand Up @@ -175,7 +174,12 @@ private static string MergeStrings(string a, string b)
/// <param name="cmdAst">The last user input command</param>
private static string MaskCommandLine(CommandAst cmdAst)
{
var commandElements = cmdAst.CommandElements;
var commandElements = cmdAst?.CommandElements;

if (commandElements == null)
{
return null;
}

if (commandElements.Count == 1)
{
Expand All @@ -196,15 +200,15 @@ private static string MaskCommandLine(CommandAst cmdAst)
if (param.Argument != null)
{
// Parameter is in the form of `-Name:name`
_ = sb.Append(AzPredictor.ParameterIndicator)
_ = sb.Append(AzPredictorConstants.ParameterIndicator)
.Append(param.ParameterName)
.Append(AzPredictor.ParameterValueSeperator)
.Append(AzPredictor.ParameterValueMask);
}
else
{
// Parameter is in the form of `-Name`
_ = sb.Append(AzPredictor.ParameterIndicator)
_ = sb.Append(AzPredictorConstants.ParameterIndicator)
.Append(param.ParameterName)
.Append(AzPredictorConstants.CommandParameterSeperator)
.Append(AzPredictor.ParameterValueMask);
Expand All @@ -223,8 +227,8 @@ public class PredictorInitializer : IModuleAssemblyInitializer
public void OnImport()
{
var settings = Settings.GetSettings();
var azPredictorService = new AzPredictorService(settings.ServiceUri);
var telemetryClient = new AzPredictorTelemetryClient();
var azPredictorService = new AzPredictorService(settings.ServiceUri);
var predictor = new AzPredictor(azPredictorService, telemetryClient);
SubsystemManager.RegisterSubsystem<ICommandPredictor, AzPredictor>(predictor);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ internal static class AzPredictorConstants
/// <summary>
/// The value to check to determine if it's an Az command.
/// </summary>
public const string AzCommandMoniktor = "az";
public const string AzCommandMoniker = "-Az";

/// <summary>
/// The character to use when we join the commands together.
Expand All @@ -54,6 +54,11 @@ internal static class AzPredictorConstants
/// </summary>
public const char CommandParameterSeperator = ' ';

/// <summary>
/// The character that begins a parameter.
/// </summary>
public const char ParameterIndicator = '-';

/// <summary>
/// The setting file name.
/// </summary>
Expand Down
51 changes: 30 additions & 21 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
// limitations under the License.
// ----------------------------------------------------------------------------------

using Microsoft.WindowsAzure.Commands.Utilities.Common;
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation.Language;
Expand Down Expand Up @@ -48,16 +50,15 @@ public sealed class RequestContext
public PredictionRequestBody(string history) => this.History = history;
};

private const int PredictionRequestInProgress = 1;
private const int PredictionRequestNotInProgress = 0;
private static readonly HttpClient _client = new HttpClient();
private readonly string _commandsEndpoint;
private readonly string _predictionsEndpoint;
private volatile Tuple<string, Predictor> _historySuggestions; // The history and the prediction for that.
private volatile Predictor _commands;
private volatile string _history;
private HashSet<string> _commandSet = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
private HashSet<string> _commandSet;
private CancellationTokenSource _predictionRequestCancellationSource;
private ParameterValuePredictor _parameterValuePredictor = new ParameterValuePredictor();

/// <summary>
/// The AzPredictor service interacts with the Aladdin service specified in serviceUri.
Expand Down Expand Up @@ -144,53 +145,58 @@ public Tuple<string, PredictionSource> GetSuggestion(Ast input, CancellationToke
}

/// <inheritdoc/>
public virtual void RequestPredictions(string history)
public virtual void RequestPredictions(IEnumerable<string> history)
{
// Even if it's called multiple times, we only need to keep the one for the latest history.

this._predictionRequestCancellationSource?.Cancel();
this._predictionRequestCancellationSource = new CancellationTokenSource();
var cancellationToken = this._predictionRequestCancellationSource.Token;
this._history = history;
var localHistory = string.Join(AzPredictorConstants.CommandConcatenator, history);
this._history = localHistory;

// We don't need to block on the task. We send the HTTP request and update prediction list at the background.
Task.Run(async () => {
var requestBody = JsonConvert.SerializeObject(new PredictionRequestBody(history));
var requestBody = JsonConvert.SerializeObject(new PredictionRequestBody(localHistory));
var httpResponseMessage = await _client.PostAsync(this._predictionsEndpoint, new StringContent(requestBody, Encoding.UTF8, "application/json"), cancellationToken);

var reply = await httpResponseMessage.Content.ReadAsStringAsync(cancellationToken);
var suggestionsList = JsonConvert.DeserializeObject<List<string>>(reply);

this.SetSuggestionPredictor(history, suggestionsList);
this.SetSuggestionPredictor(localHistory, suggestionsList);
},
cancellationToken);
}

/// <summary>
/// For logging purposes, get the rank of the user input in the model suggestions list.
/// </summary>
public int? GetRankOfSuggestion(CommandAst command, Ast input)
/// <inheritdoc/>
public virtual void RecordHistory(IEnumerable<CommandAst> history)
{
history.ForEach((h) => this._parameterValuePredictor.ProcessHistoryCommand(h));
}

/// <inhericdoc/>
public int? GetRankOfSuggestion(string commandName)
{
var historySuggestions = this._historySuggestions;
return historySuggestions?.Item2?.GetCommandPrediction(command, input, CancellationToken.None).Item2;
return historySuggestions?.Item2?.GetCommandPrediction(commandName, isCommandNameComplete: true, cancellationToken:CancellationToken.None).Item2;
}

/// <inheritdoc/>
public int? GetRankOfFallback(CommandAst command, Ast input)
/// <inhericdoc/>
public int? GetRankOfFallback(string commandName)
{
var commands = this._commands;
return commands?.GetCommandPrediction(command, input, CancellationToken.None).Item2;
return commands?.GetCommandPrediction(commandName, isCommandNameComplete:true, cancellationToken:CancellationToken.None).Item2;
}

/// <inheritdoc/>
/// <inhericdoc/>
public IEnumerable<string> GetTopNSuggestions(int n)
{
var historySuggestions = this._historySuggestions;
return historySuggestions?.Item2?.GetTopNPrediction(n);
}

/// <inheritdoc/>
public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && _commandSet.Contains(cmd);
public bool IsSupportedCommand(string cmd) => !string.IsNullOrWhiteSpace(cmd) && (_commandSet?.Contains(cmd) == true);

/// <summary>
/// Requests a list of popular commands from service. These commands are used as fallback suggestion
Expand All @@ -209,7 +215,10 @@ protected virtual void RequestCommands()

// Initialize predictions
var startHistory = $"{AzPredictorConstants.CommandHistoryPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandHistoryPlaceholder}";
RequestPredictions(startHistory);
RequestPredictions(new string[] {
AzPredictorConstants.CommandHistoryPlaceholder,
AzPredictorConstants.CommandHistoryPlaceholder});

});
}

Expand All @@ -219,8 +228,8 @@ protected virtual void RequestCommands()
/// <param name="commands">The command collection to set the predictor</param>
protected void SetCommandsPredictor(IList<string> commands)
{
this._commands = new Predictor(commands);
this._commandSet = new HashSet<string>(commands.Select(x => AzPredictorService.GetCommandName(x))); // this could be slow
this._commands = new Predictor(commands, this._parameterValuePredictor);
this._commandSet = commands.Select(x => AzPredictorService.GetCommandName(x)).ToHashSet<string>(StringComparer.OrdinalIgnoreCase); // this could be slow

}

Expand All @@ -231,7 +240,7 @@ protected void SetCommandsPredictor(IList<string> commands)
/// <param name="suggestions">The suggestion collection to set the predictor</param>
protected void SetSuggestionPredictor(string history, IList<string> suggestions)
{
this._historySuggestions = Tuple.Create(history, new Predictor(suggestions));
this._historySuggestions = Tuple.Create(history, new Predictor(suggestions, this._parameterValuePredictor));
}

/// <summary>
Expand Down
16 changes: 11 additions & 5 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/IAzPredictorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,29 @@ public interface IAzPredictorService
/// <summary>
/// Requests predictions, given a history string.
/// </summary>
/// <param name="history">A history string could look like: "Get-AzContext -Name NAME\nSet-AzContext"</param>
public void RequestPredictions(string history);
/// <param name="history">A list of history commands</param>
public void RequestPredictions(IEnumerable<string> history);

/// <summary>
/// For logging purposes, get the rank of the user input in the model suggestions list.
/// Record the history from PSReadLine.
/// </summary>
public int? GetRankOfSuggestion(CommandAst command, Ast input);
/// <param name="history">A list of history commands</param>
public void RecordHistory(IEnumerable<CommandAst> history);

/// <summary>
/// Return true if command is part of known set of Az cmdlets, false otherwise.
/// </summary>
public bool IsSupportedCommand(string cmd);

/// <summary>
/// For logging purposes, get the rank of the user input in the model suggestions list.
/// </summary>
public int? GetRankOfSuggestion(string commandName);

/// <summary>
/// For logging purposes, get the rank of the user input in the fallback commands cache.
/// </summary>
public int? GetRankOfFallback(CommandAst command, Ast input);
public int? GetRankOfFallback(string commandName);

/// <summary>
/// For logging purposes, get the top N suggestions from the model suggestions list.
Expand Down
Loading