Skip to content

Bug fixes and performance improvement #13410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ public PredictorTests(ModelFixture fixture)
/// </summary>
[Theory]
[InlineData("NEW-AZCONTEXT")]
[InlineData("Get-AzStorageAccount ")] // A complete command and we have exact the same on in the model.
[InlineData("get-azaccount ")]
[InlineData(AzPredictorConstants.CommandPlaceholder)]
[InlineData("git status")]
Expand Down
98 changes: 26 additions & 72 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// ----------------------------------------------------------------------------------
// ----------------------------------------------------------------------------------
//
// Copyright Microsoft Corporation
// Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -168,69 +168,27 @@ public void OnSuggestionAccepted(string acceptedSuggestion)
/// <inhericdoc />
public List<PredictiveSuggestion> GetSuggestion(PredictionContext context, CancellationToken cancellationToken)
{
// Eventually, the rendering layer in PSReadLine will show the suggestions in a list,
// with an experience similar to you typing in the text box in google/bing search page.
// Hence, the returned texts don't have to be prefixed with the `userInput`, but the
// text should be in the order of relevance, meaning that if you have a result that
// is prefixed with `userInput`, it should be ordered before result that is not prefixed
// with `userInput`.
var localCancellationToken = Settings.ContinueOnTimeout ? CancellationToken.None : cancellationToken;

IEnumerable<ValueTuple<string, IList<Tuple<string, string>>, PredictionSource>> suggestions = Enumerable.Empty<ValueTuple<string, IList<Tuple<string, string>>, PredictionSource>>();
IEnumerable<ValueTuple<string, string, PredictionSource>> suggestions = Enumerable.Empty<ValueTuple<string, string, PredictionSource>>();
string maskedUserInput = string.Empty;
// This is the list of records of the source suggestion and the prediction source.
var telemetryData = new List<ValueTuple<string, PredictionSource>>();

try
{
suggestions = _service.GetSuggestion(context.InputAst, _settings.SuggestionCount.Value, Settings.ContinueOnTimeout ? CancellationToken.None : cancellationToken);
maskedUserInput = AzPredictor.MaskCommandLine(context.InputAst.FindAll((ast) => ast is CommandAst, true).LastOrDefault() as CommandAst);

if (!Settings.ContinueOnTimeout)
{
cancellationToken.ThrowIfCancellationRequested();
}
suggestions = _service.GetSuggestion(context.InputAst, _settings.SuggestionCount.Value, localCancellationToken);

localCancellationToken.ThrowIfCancellationRequested();

var userInput = context.InputAst.Extent.Text;
return suggestions.Select((r, index) =>
{
return new PredictiveSuggestion(MergeStrings(userInput, r.Item1));
})
.ToList();
}
catch (Exception e) when (!(e is OperationCanceledException))
{
this._telemetryClient.OnGetSuggestionError(e);
}
finally
{
var maskedCommandLine = AzPredictor.MaskCommandLine(context.InputAst.FindAll((ast) => ast is CommandAst, true).LastOrDefault() as CommandAst);
var sb = new StringBuilder();
// This is the list of records of the original suggestion and the prediction source.
var telemetryData = new List<ValueTuple<string, PredictionSource>>();
var userAcceptedAndSuggestion = new Dictionary<string, string>();

foreach (var s in suggestions)
{
sb.Clear();
sb.Append(s.Item1.Split(' ')[0])
.Append(AzPredictorConstants.CommandParameterSeperator);

foreach (var p in s.Item2)
{
sb.Append(p.Item1);
if (p.Item2 != null)
{
sb.Append(AzPredictorConstants.CommandParameterSeperator)
.Append(p.Item2);
}

sb.Append(AzPredictorConstants.CommandParameterSeperator);
}

if (sb[sb.Length - 1] == AzPredictorConstants.CommandParameterSeperator)
{
sb.Remove(sb.Length - 1, 1);
}

var suggestedText = sb.ToString();
telemetryData.Add(ValueTuple.Create(suggestedText, s.Item3));
userAcceptedAndSuggestion[s.Item1] = suggestedText;
telemetryData.Add(ValueTuple.Create(s.Item2, s.Item3));
userAcceptedAndSuggestion[s.Item1] = s.Item2;
}

lock (_userAcceptedAndSuggestion)
Expand All @@ -241,31 +199,27 @@ public List<PredictiveSuggestion> GetSuggestion(PredictionContext context, Cance
}
}

_telemetryClient.OnGetSuggestion(maskedCommandLine,
localCancellationToken.ThrowIfCancellationRequested();

var returnedValue = suggestions.Select((r, index) =>
{
return new PredictiveSuggestion(r.Item1);
})
.ToList();

_telemetryClient.OnGetSuggestion(maskedUserInput,
telemetryData,
cancellationToken.IsCancellationRequested);
}

return new List<PredictiveSuggestion>();
}
return returnedValue;

// Merge strings a and b such that the prefix of b is deleted if it is the suffix of a
private static string MergeStrings(string a, string b)
{
for (int i = 0; i < a.Length; i++)
}
catch (Exception e) when (!(e is OperationCanceledException))
{
var j = i;
while (char.ToLower(a[j]) == char.ToLower(b[j - i]))
{
j++;
if (j == a.Length)
{
return a.Substring(0, i) + b;
}
}
this._telemetryClient.OnGetSuggestionError(e);
}

return a + b;
return new List<PredictiveSuggestion>();
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ protected virtual void Dispose(bool disposing)
/// <remarks>
/// Queries the Predictor with the user input if predictions are available, otherwise uses commands
/// </remarks>
public IEnumerable<ValueTuple<string, IList<Tuple<string, string>>, PredictionSource>> GetSuggestion(Ast input, int suggestionCount, CancellationToken cancellationToken)
public IEnumerable<ValueTuple<string, string, PredictionSource>> GetSuggestion(Ast input, int suggestionCount, CancellationToken cancellationToken)
{
var commandSuggestions = this._commandSuggestions;
var command = this._commandForPrediction;

IList<ValueTuple<string, IList<Tuple<string, string>>, PredictionSource>> results = new List<ValueTuple<string, IList<Tuple<string, string>>, PredictionSource>>();
IList<ValueTuple<string, string, PredictionSource>> results = new List<ValueTuple<string, string, PredictionSource>>();

var resultsFromSuggestion = commandSuggestions?.Item2?.Query(input, suggestionCount, cancellationToken);

Expand Down Expand Up @@ -157,7 +157,7 @@ public IEnumerable<ValueTuple<string, IList<Tuple<string, string>>, PredictionSo
{
foreach (var r in resultsFromCommands)
{
if (resultsFromCommands?.ContainsKey(r.Key) == true)
if (resultsFromSuggestion?.ContainsKey(r.Key) == true)
{
continue;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"serviceUri": "https://app.aladdin.microsoft.com/api/v1",
"suggestionCount": 5
"suggestionCount": 7
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ public interface IAzPredictorService
/// <param name="input">User input from PSReadLine</param>
/// <param name="suggestionCount">The number of suggestion to return.</param>
/// <param name="cancellationToken">The cancellation token</param>
/// <returns>The list of suggestions and the parameter set that construct the suggestion for <paramref name="input"/>. The maximum number of suggestion is <paramref name="suggestionCount"/></returns>
public IEnumerable<ValueTuple<string, IList<Tuple<string, string>>, PredictionSource>> GetSuggestion(Ast input, int suggestionCount, CancellationToken cancellationToken);
/// <returns>The list of suggestions for <paramref name="input"/> and the source that create the suggestion. The maximum number of suggestion is <paramref name="suggestionCount"/></returns>
public IEnumerable<ValueTuple<string, string, PredictionSource>> GetSuggestion(Ast input, int suggestionCount, CancellationToken cancellationToken);

/// <summary>
/// Requests predictions, given a command string.
Expand Down
5 changes: 5 additions & 0 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,15 @@ public ParameterSet(CommandAst commandAst)
}
}


if (param != null)
{
Parameters.Add(new Tuple<string, string>(param.ToString(), arg?.ToString()));
}
else if (arg != null)
{
throw new InvalidOperationException();
}
}
}
}
129 changes: 68 additions & 61 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/Predictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation.Language;
using System.Management.Automation.Subsystem;
using System.Text;
using System.Threading;

Expand Down Expand Up @@ -48,17 +47,8 @@ public Predictor(IList<string> modelPredictions, ParameterValuePredictor paramet

if (commandAst?.CommandElements[0] is StringConstantExpressionAst commandName)
{
var existingPrediction = this._predictions.FirstOrDefault(p => string.Equals(p.Command, commandName.Value, StringComparison.OrdinalIgnoreCase));
var parameterSet = new ParameterSet(commandAst);

if (existingPrediction != null)
{
existingPrediction.ParameterSets.Add(parameterSet);
}
else
{
this._predictions.Add(new Prediction(commandName.Value, parameterSet));
}
this._predictions.Add(new Prediction(commandName.Value, parameterSet));
}
}
}
Expand All @@ -69,8 +59,9 @@ public Predictor(IList<string> modelPredictions, ParameterValuePredictor paramet
/// <param name="input">PowerShell AST input of the user, generated by PSReadLine</param>
/// <param name="suggestionCount">The number of suggestion to return.</param>
/// <param name="cancellationToken">The cancellation token</param>
/// <returns>The collection of predicted texts for the user input and the parameter set that constructs the predicted texts.</returns>
public IDictionary<string, IList<Tuple<string, string>>> Query(Ast input, int suggestionCount, CancellationToken cancellationToken)
/// <returns>The collection of suggestions. The key is the predicted text adjusted based on <paramref name="input"/>. The
/// value is the original text to create the adjusted text. </returns>
public IDictionary<string, string> Query(Ast input, int suggestionCount, CancellationToken cancellationToken)
{
if (suggestionCount <= 0)
{
Expand All @@ -85,70 +76,86 @@ public IDictionary<string, IList<Tuple<string, string>>> Query(Ast input, int su
return null;
}

var inputParameterSet = new ParameterSet(commandAst);

// This stores the neccessary information to build the results.
// Each element in the list is a ValueTuple and each such tuple has this component
// 1. The parameter list from the predictions.
// 2. The index of the element that the user input matches in the previous parameter list.
// 3. The string that's built so far for the final result.
var results = new List<ValueTuple<IList<Tuple<string, string>>, HashSet<int>, StringBuilder>>();
var results = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);

var isCommandNameComplete = (((commandAst?.CommandElements != null) && (commandAst.CommandElements.Count > 1)) || ((input as ScriptBlockAst)?.Extent?.Text?.EndsWith(' ') == true));

Func<string, bool> commandNameQuery = (command) => command.Equals(commandName, StringComparison.OrdinalIgnoreCase);
if (!isCommandNameComplete)
try
{
commandNameQuery = (command) => command.StartsWith(commandName, StringComparison.OrdinalIgnoreCase);
}
var inputParameterSet = new ParameterSet(commandAst);

// Try to find the matching command and arrange the parameters in the order of the input.
//
// Predictions should be flexible, e.g. if "Command -Name N -Location L" is a possibility,
// then "Command -Location L -Name N" should also be possible.
//
// resultBuilder and usedParams are stored in results to help prediction more flexible.

for (var i = 0; i < _predictions.Count && results.Count < suggestionCount; ++i)
{
if (commandNameQuery(_predictions[i].Command))
var isCommandNameComplete = (((commandAst?.CommandElements != null) && (commandAst.CommandElements.Count > 1)) || ((input as ScriptBlockAst)?.Extent?.Text?.EndsWith(' ') == true));

Func<string, bool> commandNameQuery = (command) => command.Equals(commandName, StringComparison.OrdinalIgnoreCase);
if (!isCommandNameComplete)
{
foreach (var parameterSet in _predictions[i].ParameterSets)
{
cancellationToken.ThrowIfCancellationRequested();
commandNameQuery = (command) => command.StartsWith(commandName, StringComparison.OrdinalIgnoreCase);
}

var resultBuilder = new StringBuilder(_predictions[i].Command);
var usedParams = new HashSet<int>();
// Try to find the matching command and arrange the parameters in the order of the input.
//
// Predictions should be flexible, e.g. if "Command -Name N -Location L" is a possibility,
// then "Command -Location L -Name N" should also be possible.
//
// resultBuilder and usedParams are used to store the information to construct the result.
// We want to avoid too much heap allocation for the performance purpose.

if (DoesPredictionParameterSetMatchInput(resultBuilder, inputParameterSet, parameterSet, usedParams))
var resultBuilder = new StringBuilder();
var usedParams = new HashSet<int>();
var sourceBuilder = new StringBuilder();

for (var i = 0; i < _predictions.Count && results.Count < suggestionCount; ++i)
{
if (commandNameQuery(_predictions[i].Command))
{
foreach (var parameterSet in _predictions[i].ParameterSets)
{
results.Add(ValueTuple.Create(parameterSet.Parameters, usedParams, resultBuilder));
cancellationToken.ThrowIfCancellationRequested();

if (results.Count == suggestionCount)
resultBuilder.Clear();
resultBuilder.Append(_predictions[i].Command);
usedParams.Clear();

if (DoesPredictionParameterSetMatchInput(resultBuilder, inputParameterSet, parameterSet, usedParams))
{
break;
PredictRestOfParameters(resultBuilder, parameterSet.Parameters, usedParams);
var prediction = UnescapePredictionText(resultBuilder);

if (prediction.Length <= input.Extent.Text.Length)
{
continue;
}

sourceBuilder.Clear();
sourceBuilder.Append(_predictions[i].Command);

foreach (var p in parameterSet.Parameters)
{
_ = sourceBuilder.Append(AzPredictorConstants.CommandParameterSeperator);
_ = sourceBuilder.Append(p.Item1);

if (!string.IsNullOrWhiteSpace(p.Item2))
{
_ = sourceBuilder.Append(AzPredictorConstants.CommandParameterSeperator);
_ = sourceBuilder.Append(p.Item2);
}
}

results.Add(prediction.ToString(), sourceBuilder.ToString());

if (results.Count == suggestionCount)
{
break;
}
}
}
}
}
}
catch
{
}

cancellationToken.ThrowIfCancellationRequested();

return results.Select((r) =>
{
PredictRestOfParameters(r.Item3, r.Item1, r.Item2);
var prediction = UnescapePredictionText(r.Item3);

if (prediction.Length <= input.Extent.Text.Length)
{
return ValueTuple.Create<string, List<Tuple<string, string>>>(null, null);
}

return ValueTuple.Create(prediction?.ToString(), r.Item1);
})
.Where((t) => t.Item1 != null)
.ToDictionary((t) => t.Item1, (t) => t.Item2);
return results;
}

/// <summary>
Expand Down