Skip to content

Commit baeddc8

Browse files
authored
Only use supported commands for prediction request (#13863)
* Only use supported commands for prediction request. - We only request predictions for the last two commands. When the commands are not supported, we replace them with "start_of_snippet". The problem of that is when the user has inputted a few unsupported commands such as assignments, we'll start to request prediction for [ "start_of_snippet", "start_of_snippet" ]. That'll the same as resetting the prediction as the beginning of the session. While in fact there may be Az commands in the history. - The change here is not to use those unsupported commands to request predictions. We'll skip the unsupported commands unless we don't have enough commands from history to use for the prediction. * Incorporate PR feedback
1 parent 6486ca7 commit baeddc8

File tree

3 files changed

+287
-80
lines changed

3 files changed

+287
-80
lines changed

tools/Az.Tools.Predictor/Az.Tools.Predictor.Test/AzPredictorTests.cs

Lines changed: 225 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@ public sealed class AzPredictorTests
3737
/// </summary>
3838
public AzPredictorTests(ModelFixture modelFixture)
3939
{
40-
this._fixture = modelFixture;
40+
_fixture = modelFixture;
4141
var startHistory = $"{AzPredictorConstants.CommandPlaceholder}{AzPredictorConstants.CommandConcatenator}{AzPredictorConstants.CommandPlaceholder}";
4242

43-
this._service = new MockAzPredictorService(startHistory, this._fixture.PredictionCollection[startHistory], this._fixture.CommandCollection);
44-
this._telemetryClient = new MockAzPredictorTelemetryClient();
45-
this._azPredictor = new AzPredictor(this._service, this._telemetryClient, new Settings()
43+
_service = new MockAzPredictorService(startHistory, _fixture.PredictionCollection[startHistory], _fixture.CommandCollection);
44+
_telemetryClient = new MockAzPredictorTelemetryClient();
45+
_azPredictor = new AzPredictor(_service, _telemetryClient, new Settings()
4646
{
4747
SuggestionCount = 1,
4848
MaxAllowedCommandDuplicate = 1,
@@ -51,84 +51,250 @@ public AzPredictorTests(ModelFixture modelFixture)
5151
}
5252

5353
/// <summary>
54-
/// Verifies when the last command in history are not supported.
55-
/// We don't collect the telemetry and only request prediction while StartEarlyProcess is called.
54+
/// Verify we replace unsupported command with <see cref="AzPredictorConstants.CommandPlaceholder"/>.
5655
/// </summary>
57-
[Theory]
58-
[InlineData("start_of_snippet\nstart_of_snippet\nstart_of_snippet")]
59-
[InlineData("start_of_snippet")]
60-
[InlineData("")]
61-
[InlineData("git status")]
62-
[InlineData("git status\nGet-ChildItem")]
63-
[InlineData("^29a9l2")]
64-
[InlineData("'Get-AzResource'")]
65-
[InlineData("Get-AzResource\ngit log")]
66-
[InlineData("Get-ChildItem")]
67-
public void VerifyWithNonSupportedCommand(string historyLine)
56+
[Fact]
57+
public void VerifyRequestPredictionForOneUnsupportedCommandInHistory()
6858
{
69-
IReadOnlyList<string> history = historyLine.Split('\n');
59+
IReadOnlyList<string> history = new List<string>()
60+
{
61+
"git status"
62+
};
7063

71-
this._telemetryClient.RecordedSuggestion = null;
72-
this._service.IsPredictionRequested = false;
64+
_telemetryClient.RecordedSuggestion = null;
65+
_service.Commands = null;
66+
_service.History = null;
7367

74-
this._azPredictor.StartEarlyProcessing(history);
68+
_azPredictor.StartEarlyProcessing(history);
7569

76-
Assert.True(this._service.IsPredictionRequested);
77-
Assert.NotNull(this._telemetryClient.RecordedSuggestion);
70+
Assert.Equal(new List<string>() { AzPredictorConstants.CommandPlaceholder, AzPredictorConstants.CommandPlaceholder }, _service.Commands);
71+
Assert.Equal(AzPredictorConstants.CommandPlaceholder, _telemetryClient.RecordedSuggestion.HistoryLine);
72+
Assert.Null(_service.History);
7873
}
7974

8075
/// <summary>
81-
/// Verifies when the last command in history are not supported.
82-
/// We don't collect the telemetry and only request prediction while StartEarlyProcess is called.
76+
/// Verify that we masked the supported command in requesting prediction and telemetry.
8377
/// </summary>
84-
[Theory]
85-
[InlineData("start_of_snippet\nConnect-AzAccount")]
86-
[InlineData("Get-AzResource")]
87-
[InlineData("git status\nGet-AzContext")]
88-
[InlineData("Get-AzContext\nGet-AzLog")]
89-
public void VerifyWithOneSupportedCommand(string historyLine)
78+
[Fact]
79+
public void VerifyRequestPredictionForOneSupportedCommandInHistory()
80+
{
81+
IReadOnlyList<string> history = new List<string>()
82+
{
83+
"New-AzVM -Name hello -Location WestUS"
84+
};
85+
86+
_telemetryClient.RecordedSuggestion = null;
87+
_service.Commands = null;
88+
_service.History = null;
89+
90+
_azPredictor.StartEarlyProcessing(history);
91+
92+
string maskedCommand = "New-AzVM -Location *** -Name ***";
93+
94+
Assert.Equal(new List<string>() { AzPredictorConstants.CommandPlaceholder, maskedCommand }, _service.Commands);
95+
Assert.Equal(maskedCommand, _telemetryClient.RecordedSuggestion.HistoryLine);
96+
Assert.Equal(history[0], _service.History.ToString());
97+
}
98+
99+
/// <summary>
100+
/// Verify that we can handle the two supported command in sequences.
101+
/// </summary>
102+
[Fact]
103+
public void VerifyRequestPredictionForTwoSupportedCommandInHistory()
104+
{
105+
IReadOnlyList<string> history = new List<string>()
106+
{
107+
"New-AzResourceGroup -Name 'resourceGroup01'",
108+
"New-AzVM -Name:hello -Location:WestUS"
109+
};
110+
111+
_telemetryClient.RecordedSuggestion = null;
112+
_service.Commands = null;
113+
_service.History = null;
114+
115+
_azPredictor.StartEarlyProcessing(history);
116+
117+
var maskedCommands = new List<string>()
118+
{
119+
"New-AzResourceGroup -Name ***",
120+
"New-AzVM -Location:*** -Name:***"
121+
};
122+
123+
Assert.Equal(maskedCommands, _service.Commands);
124+
Assert.Equal(maskedCommands[1], _telemetryClient.RecordedSuggestion.HistoryLine);
125+
Assert.Equal(history[1], _service.History.ToString());
126+
}
127+
128+
/// <summary>
129+
/// Verify that we can handle the two unsupported command in sequences.
130+
/// </summary>
131+
[Fact]
132+
public void VerifyRequestPredictionForTwoUnsupportedCommandInHistory()
133+
{
134+
IReadOnlyList<string> history = new List<string>()
135+
{
136+
"git status",
137+
@"$a='ResourceGroup01'",
138+
};
139+
140+
_telemetryClient.RecordedSuggestion = null;
141+
_service.Commands = null;
142+
_service.History = null;
143+
144+
_azPredictor.StartEarlyProcessing(history);
145+
146+
var maskedCommands = new List<string>()
147+
{
148+
AzPredictorConstants.CommandPlaceholder,
149+
AzPredictorConstants.CommandPlaceholder,
150+
};
151+
152+
Assert.Equal(maskedCommands, _service.Commands);
153+
Assert.Equal(maskedCommands[1], _telemetryClient.RecordedSuggestion.HistoryLine);
154+
Assert.Null(_service.History);
155+
}
156+
157+
/// <summary>
158+
/// Verify that we skip the unsupported commands.
159+
/// </summary>
160+
[Fact]
161+
public void VerifyNotTakeUnsupportedCommands()
162+
{
163+
var history = new List<string>()
164+
{
165+
"New-AzResourceGroup -Name:resourceGroup01",
166+
"New-AzVM -Name hello -Location WestUS"
167+
};
168+
169+
_telemetryClient.RecordedSuggestion = null;
170+
_service.Commands = null;
171+
_service.History = null;
172+
173+
_azPredictor.StartEarlyProcessing(history);
174+
175+
history.Add("git status");
176+
_azPredictor.StartEarlyProcessing(history);
177+
178+
history.Add(@"$a='NewResourceName'");
179+
_azPredictor.StartEarlyProcessing(history);
180+
181+
// We don't take the last two unsupported command to request predictions.
182+
// But we send the masked one in telemetry.
183+
184+
var maskedCommands = new List<string>()
185+
{
186+
"New-AzResourceGroup -Name:***",
187+
"New-AzVM -Location *** -Name ***"
188+
};
189+
190+
Assert.Equal(maskedCommands, _service.Commands);
191+
Assert.Equal(AzPredictorConstants.CommandPlaceholder, _telemetryClient.RecordedSuggestion.HistoryLine);
192+
Assert.Equal(history[1], _service.History.ToString());
193+
194+
// When there is a new supported command, we'll use that for prediction.
195+
196+
history.Add("Get-AzResourceGroup -Name ResourceGroup01");
197+
_azPredictor.StartEarlyProcessing(history);
198+
199+
maskedCommands = new List<string>()
200+
{
201+
"New-AzVM -Location *** -Name ***",
202+
"Get-AzResourceGroup -Name ***",
203+
};
204+
205+
Assert.Equal(maskedCommands, _service.Commands);
206+
Assert.Equal(maskedCommands[1], _telemetryClient.RecordedSuggestion.HistoryLine);
207+
Assert.Equal(history.Last(), _service.History.ToString());
208+
}
209+
210+
/// <summary>
211+
/// Verify that we handle the three supported command in the same order.
212+
/// </summary>
213+
[Fact]
214+
public void VerifyThreeSupportedCommands()
90215
{
91-
IReadOnlyList<string> history = historyLine.Split('\n');
216+
var history = new List<string>()
217+
{
218+
"New-AzResourceGroup -Name resourceGroup01",
219+
"New-AzVM -Name:hello -Location:WestUS"
220+
};
92221

93-
this._telemetryClient.RecordedSuggestion = null;
94-
this._service.IsPredictionRequested = false;
222+
_telemetryClient.RecordedSuggestion = null;
223+
_service.Commands = null;
224+
_service.History = null;
95225

96-
this._azPredictor.StartEarlyProcessing(history);
226+
_azPredictor.StartEarlyProcessing(history);
97227

98-
Assert.True(this._service.IsPredictionRequested);
99-
Assert.NotNull(this._telemetryClient.RecordedSuggestion);
228+
history.Add("Get-AzResourceGroup -Name resourceGroup01");
229+
_azPredictor.StartEarlyProcessing(history);
230+
231+
var maskedCommands = new List<string>()
232+
{
233+
"New-AzVM -Location:*** -Name:***",
234+
"Get-AzResourceGroup -Name ***",
235+
};
236+
237+
Assert.Equal(maskedCommands, _service.Commands);
238+
Assert.Equal(maskedCommands[1], _telemetryClient.RecordedSuggestion.HistoryLine);
239+
Assert.Equal(history.Last(), _service.History.ToString());
100240
}
101241

102242
/// <summary>
103-
/// Verify that the supported commands parameter values are masked.
243+
/// Verify that we handle the sequence of one unsupported command and one supported command.
104244
/// </summary>
105245
[Fact]
106-
public void VerifySupportedCommandMasked()
246+
public void VerifyUnsupportedAndSupportedCommands()
107247
{
108-
var input = "Get-AzVMExtension -ResourceGroupName 'ResourceGroup11' -VMName 'VirtualMachine22'";
109-
var expected = "Get-AzVMExtension -ResourceGroupName *** -VMName ***";
248+
var history = new List<string>()
249+
{
250+
"git status",
251+
"New-AzVM -Name:hello -Location:WestUS"
252+
};
110253

111-
this._telemetryClient.RecordedSuggestion = null;
112-
this._service.IsPredictionRequested = false;
254+
_telemetryClient.RecordedSuggestion = null;
255+
_service.Commands = null;
256+
_service.History = null;
113257

114-
this._azPredictor.StartEarlyProcessing(new List<string> { input } );
258+
_azPredictor.StartEarlyProcessing(history);
115259

116-
Assert.True(this._service.IsPredictionRequested);
117-
Assert.NotNull(this._telemetryClient.RecordedSuggestion);
118-
Assert.Equal(expected, this._telemetryClient.RecordedSuggestion.HistoryLine);
260+
var maskedCommands = new List<string>()
261+
{
262+
AzPredictorConstants.CommandPlaceholder,
263+
"New-AzVM -Location:*** -Name:***"
264+
};
119265

120-
input = "Get-AzStorageAccountKey -Name:'ContosoStorage' -ResourceGroupName:'ContosoGroup02'";
121-
expected = "Get-AzStorageAccountKey -Name:*** -ResourceGroupName:***";
266+
Assert.Equal(maskedCommands, _service.Commands);
267+
Assert.Equal(maskedCommands[1], _telemetryClient.RecordedSuggestion.HistoryLine);
268+
Assert.Equal(history.Last(), _service.History.ToString());
269+
}
122270

271+
/// <summary>
272+
/// Verify that we handle the sequence of one supported command and one unsupported command.
273+
/// </summary>
274+
[Fact]
275+
public void VerifySupportedAndUnsupportedCommands()
276+
{
277+
var history = new List<string>()
278+
{
279+
"New-AzVM -Name hello -Location WestUS",
280+
"git status",
281+
};
282+
283+
_telemetryClient.RecordedSuggestion = null;
284+
_service.Commands = null;
285+
_service.History = null;
123286

124-
this._telemetryClient.RecordedSuggestion = null;
125-
this._service.IsPredictionRequested = false;
287+
_azPredictor.StartEarlyProcessing(history);
126288

127-
this._azPredictor.StartEarlyProcessing(new List<string> { input } );
289+
var maskedCommands = new List<string>()
290+
{
291+
AzPredictorConstants.CommandPlaceholder,
292+
"New-AzVM -Location *** -Name ***",
293+
};
128294

129-
Assert.True(this._service.IsPredictionRequested);
130-
Assert.NotNull(this._telemetryClient.RecordedSuggestion);
131-
Assert.Equal(expected, this._telemetryClient.RecordedSuggestion.HistoryLine);
295+
Assert.Equal(maskedCommands, _service.Commands);
296+
Assert.Equal(AzPredictorConstants.CommandPlaceholder, _telemetryClient.RecordedSuggestion.HistoryLine);
297+
Assert.Equal(history.First(), _service.History.ToString());
132298
}
133299

134300
/// <summary>
@@ -140,8 +306,8 @@ public void VerifySupportedCommandMasked()
140306
public void VerifySuggestion(string userInput)
141307
{
142308
var predictionContext = PredictionContext.Create(userInput);
143-
var expected = this._service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
144-
var actual = this._azPredictor.GetSuggestion(predictionContext, CancellationToken.None);
309+
var expected = _service.GetSuggestion(predictionContext.InputAst, 1, 1, CancellationToken.None);
310+
var actual = _azPredictor.GetSuggestion(predictionContext, CancellationToken.None);
145311

146312
Assert.Equal(expected.Count, actual.Count);
147313
Assert.Equal(expected.PredictiveSuggestions.First().SuggestionText, actual.First().SuggestionText);
@@ -154,7 +320,7 @@ public void VerifySuggestion(string userInput)
154320
public void VerifySuggestionOnIncompleteCommand()
155321
{
156322
// We need to get the suggestions for more than one. So we create a local version az predictor.
157-
var localAzPredictor = new AzPredictor(this._service, this._telemetryClient, new Settings()
323+
var localAzPredictor = new AzPredictor(_service, _telemetryClient, new Settings()
158324
{
159325
SuggestionCount = 7,
160326
MaxAllowedCommandDuplicate = 1,
@@ -170,7 +336,6 @@ public void VerifySuggestionOnIncompleteCommand()
170336
Assert.Equal(expected, actual.First().SuggestionText);
171337
}
172338

173-
174339
/// <summary>
175340
/// Verify when we cannot parse the user input correctly.
176341
/// </summary>
@@ -183,7 +348,7 @@ public void VerifySuggestionOnIncompleteCommand()
183348
public void VerifyMalFormattedCommandLine(string userInput)
184349
{
185350
var predictionContext = PredictionContext.Create(userInput);
186-
var actual = this._azPredictor.GetSuggestion(predictionContext, CancellationToken.None);
351+
var actual = _azPredictor.GetSuggestion(predictionContext, CancellationToken.None);
187352

188353
Assert.Empty(actual);
189354
}

0 commit comments

Comments
 (0)