Skip to content

Improve getting command ast and parse parameters #13864

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ public void VerifyParameterValues()
Action actual = () => this._service.GetSuggestion(null, 1, 1, CancellationToken.None);
Assert.Throws<ArgumentNullException>(actual);

actual = () => this._service.GetSuggestion(predictionContext.InputAst, 0, 1, CancellationToken.None);
actual = () => this._service.GetSuggestion(predictionContext, 0, 1, CancellationToken.None);
Assert.Throws<ArgumentOutOfRangeException>(actual);

actual = () => this._service.GetSuggestion(predictionContext.InputAst, 1, 0, CancellationToken.None);
actual = () => this._service.GetSuggestion(predictionContext, 1, 0, CancellationToken.None);
Assert.Throws<ArgumentOutOfRangeException>(actual);
}

Expand All @@ -110,8 +110,8 @@ public void VerifyParameterValues()
public void VerifyUsingCommandBasedPredictor(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst;
var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var commandAst = predictionContext.RelatedAsts.OfType<CommandAst>().LastOrDefault();
var commandName = commandAst?.GetCommandName();
var inputParameterSet = new ParameterSet(commandAst);
var rawUserInput = predictionContext.InputAst.Extent.Text;
var presentCommands = new Dictionary<string, int>();
Expand All @@ -123,7 +123,7 @@ public void VerifyUsingCommandBasedPredictor(string userInput)
1,
CancellationToken.None);

var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -133,7 +133,7 @@ public void VerifyUsingCommandBasedPredictor(string userInput)
Assert.Equal<string>(expected.SourceTexts, actual.SourceTexts);
Assert.All<SuggestionSource>(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.CurrentCommand, source));

actual = this._noFallbackPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noFallbackPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -153,7 +153,7 @@ public void VerifyUsingCommandBasedPredictor(string userInput)
public void VerifyUsingFallbackPredictor(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var commandAst = predictionContext.InputAst.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst;
var commandAst = predictionContext.RelatedAsts.OfType<CommandAst>().LastOrDefault();
var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var inputParameterSet = new ParameterSet(commandAst);
var rawUserInput = predictionContext.InputAst.Extent.Text;
Expand All @@ -166,7 +166,7 @@ public void VerifyUsingFallbackPredictor(string userInput)
1,
CancellationToken.None);

var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -176,7 +176,7 @@ public void VerifyUsingFallbackPredictor(string userInput)
Assert.Equal<string>(expected.SourceTexts, actual.SourceTexts);
Assert.All<SuggestionSource>(actual.SuggestionSources, (source) => Assert.Equal(SuggestionSource.StaticCommands, source));

actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.NotNull(actual);
Assert.True(actual.Count > 0);
Assert.NotNull(actual.PredictiveSuggestions.First());
Expand All @@ -199,33 +199,43 @@ public void VerifyUsingFallbackPredictor(string userInput)
public void VerifyNoPrediction(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var actual = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Equal(0, actual.Count);

actual = this._noFallbackPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noFallbackPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Equal(0, actual.Count);

actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noCommandBasedPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Equal(0, actual.Count);

actual = this._noPredictorService.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
actual = this._noPredictorService.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Null(actual);
}

/// <summary>
/// Verify that it returns null when we cannot parse the user input.
/// </summary>
[Theory]
[InlineData("git status")]
public void VerifyFailToParseUserInput(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var actual = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
Assert.Null(actual);
}

/// <summary>
/// Verify when we cannot parse the user input correctly.
/// </summary>
/// <remarks>
/// When we can parse them correctly, please move the InlineData to the corresponding test methods, for example, "git status"
/// doesn't have any prediction so it should move to <see cref="VerifyNoPrediction"/>.
/// When we can parse them correctly, please move the InlineData to the corresponding test methods.
/// </remarks>
[Theory]
[InlineData("git status")]
[InlineData("Get-AzContext Name")]
public void VerifyMalFormattedCommandLine(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
Action actual = () => this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
Action actual = () => this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
_ = Assert.Throws<InvalidOperationException>(actual);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ public void VerifySupportedAndUnsupportedCommands()
public void VerifySuggestion(string userInput)
{
var predictionContext = PredictionContext.Create(userInput);
var expected = _service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
var actual = _azPredictor.GetSuggestion(predictionContext, CancellationToken.None);
var expected = this._service.GetSuggestion(predictionContext, 1, 1, CancellationToken.None);
var actual = this._azPredictor.GetSuggestion(predictionContext, CancellationToken.None);

Assert.Equal(expected.Count, actual.Count);
Assert.Equal(expected.PredictiveSuggestions.First().SuggestionText, actual.First().SuggestionText);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
// limitations under the License.
// ----------------------------------------------------------------------------------

using System;
using Microsoft.Azure.PowerShell.Tools.AzPredictor.Telemetry;
using System;

namespace Microsoft.Azure.PowerShell.Tools.AzPredictor.Test.Mocks
{
Expand Down Expand Up @@ -58,5 +58,10 @@ public void OnGetSuggestion(GetSuggestionTelemetryData telemetryData)
{
}

/// <inheritdoc/>
public void OnLoadParameterMap(ParameterMapTelemetryData telemetryData)
{
}

}
}
5 changes: 4 additions & 1 deletion tools/Az.Tools.Predictor/Az.Tools.Predictor.sln
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor", "Az.To
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Az.Tools.Predictor.Test", "Az.Tools.Predictor.Test\Az.Tools.Predictor.Test.csproj", "{C7A3ED31-8F41-4643-ADCF-85DF032BD8AC}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "MockPSConsole", "MockPSConsole\MockPSConsole.csproj", "{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}"
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "MockPSConsole", "MockPSConsole\MockPSConsole.csproj", "{80AFAFC7-9542-4CEB-AB5B-D1385A28CCE5}"
ProjectSection(ProjectDependencies) = postProject
{E4A5F697-086C-4908-B90E-A31EE47ECF13} = {E4A5F697-086C-4908-B90E-A31EE47ECF13}
EndProjectSection
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Expand Down
4 changes: 2 additions & 2 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
using System.Runtime.CompilerServices;
using System.Threading;

[assembly:InternalsVisibleTo("Microsoft.Azure.PowerShell.Tools.AzPredictor.Test")]
[assembly: InternalsVisibleTo("Microsoft.Azure.PowerShell.Tools.AzPredictor.Test")]

namespace Microsoft.Azure.PowerShell.Tools.AzPredictor
{
Expand Down Expand Up @@ -195,7 +195,7 @@ public List<PredictiveSuggestion> GetSuggestion(PredictionContext context, Cance
{
var localCancellationToken = Settings.ContinueOnTimeout ? CancellationToken.None : cancellationToken;

suggestions = _service.GetSuggestion(context.InputAst, _settings.SuggestionCount.Value, _settings.MaxAllowedCommandDuplicate.Value, localCancellationToken);
suggestions = _service.GetSuggestion(context, _settings.SuggestionCount.Value, _settings.MaxAllowedCommandDuplicate.Value, localCancellationToken);

var returnedValue = suggestions?.PredictiveSuggestions?.ToList();
return returnedValue ?? new List<PredictiveSuggestion>();
Expand Down
41 changes: 34 additions & 7 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/AzPredictorService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation.Language;
using System.Management.Automation.Subsystem;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Text;
Expand Down Expand Up @@ -81,7 +82,7 @@ private sealed class CommandRequestContext
/// </summary>
private HashSet<string> _allPredictiveCommands;
private CancellationTokenSource _predictionRequestCancellationSource;
private readonly ParameterValuePredictor _parameterValuePredictor = new ParameterValuePredictor();
private readonly ParameterValuePredictor _parameterValuePredictor;

private readonly ITelemetryClient _telemetryClient;
private readonly IAzContext _azContext;
Expand All @@ -98,6 +99,8 @@ public AzPredictorService(string serviceUri, ITelemetryClient telemetryClient, I
Validation.CheckArgument(telemetryClient, $"{nameof(telemetryClient)} cannot be null.");
Validation.CheckArgument(azContext, $"{nameof(azContext)} cannot be null.");

_parameterValuePredictor = new ParameterValuePredictor(telemetryClient);

_commandsEndpoint = $"{serviceUri}{AzPredictorConstants.CommandsEndpoint}?clientType={AzPredictorService.ClientType}&context.versionNumber={azContext.AzVersion}";
_predictionsEndpoint = serviceUri + AzPredictorConstants.PredictionsEndpoint;
_telemetryClient = telemetryClient;
Expand Down Expand Up @@ -143,22 +146,46 @@ protected virtual void Dispose(bool disposing)
/// Tries to get the suggestions for the user input from the command history. If that doesn't find
/// <paramref name="suggestionCount"/> suggestions, it'll fallback to find the suggestion regardless of command history.
/// </remarks>
public CommandLineSuggestion GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken)
public CommandLineSuggestion GetSuggestion(PredictionContext context, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken)
{
Validation.CheckArgument(input, $"{nameof(input)} cannot be null");
Validation.CheckArgument(context, $"{nameof(context)} cannot be null");
Validation.CheckArgument<ArgumentOutOfRangeException>(suggestionCount > 0, $"{nameof(suggestionCount)} must be larger than 0.");
Validation.CheckArgument<ArgumentOutOfRangeException>(maxAllowedCommandDuplicate > 0, $"{nameof(maxAllowedCommandDuplicate)} must be larger than 0.");

var commandAst = input.FindAll(p => p is CommandAst, true).LastOrDefault() as CommandAst;
var commandName = (commandAst?.CommandElements?.FirstOrDefault() as StringConstantExpressionAst)?.Value;
var relatedAsts = context.RelatedAsts;
CommandAst commandAst = null;

for (var i = relatedAsts.Count - 1; i >= 0; --i)
{
if (relatedAsts[i] is CommandAst c)
{
commandAst = c;
break;
}
}

var commandName = commandAst?.GetCommandName();

if (string.IsNullOrWhiteSpace(commandName))
{
return null;
}

var inputParameterSet = new ParameterSet(commandAst);
var rawUserInput = input.Extent.Text;
ParameterSet inputParameterSet = null;

try
{
inputParameterSet = new ParameterSet(commandAst);
}
catch when (!IsSupportedCommand(commandName))
{
// We only ignore the exception when the command name is not supported.
// For the supported ones, this most likely happens when positional parameters are used.
// We want to collect the telemetry about the exception how common a positional parameter is used.
return null;
}

var rawUserInput = context.InputAst.ToString();
var presentCommands = new Dictionary<string, int>();
var commandBasedPredictor = _commandBasedPredictor;
var commandToRequestPrediction = _commandToRequestPrediction;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ public CommandLinePredictor(IList<PredictiveCommand> modelPredictions, Parameter
{
var predictionText = CommandLineUtilities.EscapePredictionText(predictiveCommand.Command);
Ast ast = Parser.ParseInput(predictionText, out Token[] tokens, out _);
var commandAst = (ast.Find((ast) => ast is CommandAst, searchNestedScriptBlocks: false) as CommandAst);
var commandAst = ast.Find((ast) => ast is CommandAst, searchNestedScriptBlocks: false) as CommandAst;
var commandName = commandAst?.GetCommandName();

if (commandAst?.CommandElements[0] is StringConstantExpressionAst commandName)
if (!string.IsNullOrWhiteSpace(commandName))
{
var parameterSet = new ParameterSet(commandAst);
this._commandLinePredictions.Add(new CommandLine(commandName.Value, predictiveCommand.Description, parameterSet));
this._commandLinePredictions.Add(new CommandLine(commandName, predictiveCommand.Description, parameterSet));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
// limitations under the License.
// ----------------------------------------------------------------------------------

using System;
using System.Collections.Generic;
using System.Management.Automation.Language;
using System.Management.Automation.Subsystem;
using System.Threading;

namespace Microsoft.Azure.PowerShell.Tools.AzPredictor
Expand All @@ -27,12 +27,12 @@ public interface IAzPredictorService
/// <summary>
/// Gest the suggestions for the user input.
/// </summary>
/// <param name="input">User input from PSReadLine.</param>
/// <param name="context">User input context from PSReadLine.</param>
/// <param name="suggestionCount">The number of suggestion to return.</param>
/// <param name="cancellationToken">The cancellation token</param>
/// <param name="maxAllowedCommandDuplicate">The maximum amount of the same commnds in the list of predictions.</param>
/// <returns>The suggestions for <paramref name="input"/>. The maximum number of suggestions is <paramref name="suggestionCount"/>.</returns>
public CommandLineSuggestion GetSuggestion(Ast input, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken);
/// <returns>The suggestions for <paramref name="context"/>. The maximum number of suggestions is <paramref name="suggestionCount"/>. A null will be returned if there the user input context isn't valid/supported at all.</returns>
public CommandLineSuggestion GetSuggestion(PredictionContext context, int suggestionCount, int maxAllowedCommandDuplicate, CancellationToken cancellationToken);

/// <summary>
/// Requests predictions, given a command string.
Expand Down
27 changes: 22 additions & 5 deletions tools/Az.Tools.Predictor/Az.Tools.Predictor/ParameterSet.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ namespace Microsoft.Azure.PowerShell.Tools.AzPredictor
/// does not matter to resulting prediction - the prediction should adapt to the
/// order of the parameters typed by the user.
/// </summary>
/// <remarks>
/// This doesn't handle the positional parameters yet.
/// </remarks>
sealed class ParameterSet
{
/// <summary>
Expand All @@ -36,11 +39,14 @@ public ParameterSet(CommandAst commandAst)
Validation.CheckArgument(commandAst, $"{nameof(commandAst)} cannot be null.");

var parameters = new List<Parameter>();
var elements = commandAst.CommandElements.Skip(1);
CommandParameterAst param = null;
Ast arg = null;
foreach (Ast elem in elements)

// Loop through all the parameters. The first element of CommandElements is the command name, so skip it.
for (var i = 1; i < commandAst.CommandElements.Count(); ++i)
{
var elem = commandAst.CommandElements[i];

if (elem is CommandParameterAst p)
{
AddParameter(param, arg);
Expand Down Expand Up @@ -68,11 +74,22 @@ public ParameterSet(CommandAst commandAst)

Parameters = parameters;

void AddParameter(CommandParameterAst parameterName, Ast parameterValue)
void AddParameter(CommandParameterAst parameter, Ast parameterValue)
{
if (parameterName != null)
if (parameter != null)
{
parameters.Add(new Parameter(parameterName.ParameterName, (parameterValue == null) ? null : CommandLineUtilities.UnescapePredictionText(parameterValue.ToString())));
var value = parameterValue?.ToString();
if (value == null)
{
value = parameter.Argument?.ToString();
}

if (value != null)
{
value = CommandLineUtilities.UnescapePredictionText(value);
}

parameters.Add(new Parameter(parameter.ParameterName, value));
}
}
}
Expand Down
Loading