Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix throw in generator comparer #76769

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,54 @@ [Attr] class D { }
Assert.Equal(e, runResults.Results.Single().Exception);
}

[Fact, WorkItem("https://github.com/dotnet/roslyn/issues/76765")]
public void Incremental_Generators_Exception_In_DefaultComparer()
{
var source = """
class C { }
""";
var parseOptions = TestOptions.RegularPreview;
Compilation compilation = CreateCompilation(source, options: TestOptions.DebugDllThrowing, parseOptions: parseOptions);
compilation.VerifyDiagnostics();

var syntaxTree = compilation.SyntaxTrees.Single();

var e = new InvalidOperationException("abc");
var generator = new PipelineCallbackGenerator((ctx) =>
{
var name = ctx.CompilationProvider.Select((c, _) => new ThrowWhenEqualsItem(e));
ctx.RegisterSourceOutput(name, (spc, n) => spc.AddSource("item.cs", "// generated"));
});

GeneratorDriver driver = CSharpGeneratorDriver.Create([generator.AsSourceGenerator()], parseOptions: parseOptions);
driver = driver.RunGenerators(compilation);
var runResults = driver.GetRunResult();

Assert.Empty(runResults.Diagnostics);
Assert.Equal("// generated", runResults.Results.Single().GeneratedSources.Single().SourceText.ToString());

compilation = compilation.ReplaceSyntaxTree(syntaxTree, CSharpSyntaxTree.ParseText("""
class D { }
""", parseOptions));
compilation.VerifyDiagnostics();

driver = driver.RunGenerators(compilation);
runResults = driver.GetRunResult();

VerifyGeneratorExceptionDiagnostic<InvalidOperationException>(runResults.Diagnostics.Single(), nameof(PipelineCallbackGenerator), "abc");
Assert.Empty(runResults.GeneratedTrees);
Assert.Equal(e, runResults.Results.Single().Exception);
}

class ThrowWhenEqualsItem(Exception toThrow)
{
readonly Exception _toThrow = toThrow;

public override bool Equals(object? obj) => throw _toThrow;

public override int GetHashCode() => throw new NotImplementedException();
}

[Fact]
public void Incremental_Generators_Exception_During_Execution_Doesnt_Produce_AnySource()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ public void Node_Builder_Can_Add_Entries_From_Previous_Table()
var previousTable = builder.ToImmutableAndFree();

builder = previousTable.ToBuilder(stepName: null, false);
builder.TryModifyEntries(ImmutableArray.Create(10, 11), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified);
builder.TryModifyEntries(ImmutableArray.Create(10, 11), TimeSpan.Zero, default, EntryState.Modified);
builder.TryUseCachedEntries(TimeSpan.Zero, default, out var cachedEntries); // ((2, EntryState.Cached), (3, EntryState.Cached))
builder.TryModifyEntries(ImmutableArray.Create(20, 21, 22), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified);
builder.TryModifyEntries(ImmutableArray.Create(20, 21, 22), TimeSpan.Zero, default, EntryState.Modified);
bool didRemoveEntries = builder.TryRemoveEntries(TimeSpan.Zero, default, out var removedEntries); //((6, EntryState.Removed))
var newTable = builder.ToImmutableAndFree();

Expand Down Expand Up @@ -185,9 +185,9 @@ public void Node_Builder_Handles_Modification_When_Both_Tables_Have_Empty_Entrie
AssertTableEntries(previousTable, expected);

builder = previousTable.ToBuilder(stepName: null, false);
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 2), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray<int>.Empty, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 5), EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 2), TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray<int>.Empty, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntries(ImmutableArray.Create(3, 5), TimeSpan.Zero, default, EntryState.Modified));

var newTable = builder.ToImmutableAndFree();

Expand All @@ -209,10 +209,10 @@ public void Node_Table_Doesnt_Modify_Single_Item_Multiple_Times_When_Same()
AssertTableEntries(previousTable, expected);

builder = previousTable.ToBuilder(stepName: null, false);
Assert.True(builder.TryModifyEntry(1, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(2, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(5, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(4, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(1, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(2, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(5, TimeSpan.Zero, default, EntryState.Modified));
Assert.True(builder.TryModifyEntry(4, TimeSpan.Zero, default, EntryState.Modified));

var newTable = builder.ToImmutableAndFree();

Expand All @@ -232,10 +232,10 @@ public void Node_Table_Caches_Previous_Object_When_Modification_Considered_Cache
var expected = ImmutableArray.Create((1, EntryState.Added, 0), (2, EntryState.Added, 0), (3, EntryState.Added, 0));
AssertTableEntries(previousTable, expected);

builder = previousTable.ToBuilder(stepName: null, false);
Assert.True(builder.TryModifyEntry(1, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified)); // ((1, EntryState.Cached))
Assert.True(builder.TryModifyEntry(4, EqualityComparer<int>.Default, TimeSpan.Zero, default, EntryState.Modified)); // ((4, EntryState.Modified))
Assert.True(builder.TryModifyEntry(5, new LambdaComparer<int>((i, j) => true), TimeSpan.Zero, default, EntryState.Modified)); // ((3, EntryState.Cached))
builder = previousTable.ToBuilder(stepName: null, false, new LambdaComparer<int>((i, j) => i == 3 || i == j));
Assert.True(builder.TryModifyEntry(1, TimeSpan.Zero, default, EntryState.Modified)); // ((1, EntryState.Cached))
Assert.True(builder.TryModifyEntry(4, TimeSpan.Zero, default, EntryState.Modified)); // ((4, EntryState.Modified))
Assert.True(builder.TryModifyEntry(5, TimeSpan.Zero, default, EntryState.Modified)); // ((3, EntryState.Cached))
var newTable = builder.ToImmutableAndFree();

expected = ImmutableArray.Create((1, EntryState.Cached, 0), (4, EntryState.Modified, 0), (3, EntryState.Cached, 0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ internal sealed class BatchNode<TInput> : IIncrementalGeneratorNode<ImmutableArr
private static readonly string? s_tableType = typeof(ImmutableArray<TInput>).FullName;

private readonly IIncrementalGeneratorNode<TInput> _sourceNode;
private readonly IEqualityComparer<ImmutableArray<TInput>> _comparer;
private readonly IEqualityComparer<ImmutableArray<TInput>>? _comparer;
private readonly string? _name;

public BatchNode(IIncrementalGeneratorNode<TInput> sourceNode, IEqualityComparer<ImmutableArray<TInput>>? comparer = null, string? name = null)
{
_sourceNode = sourceNode;
_comparer = comparer ?? EqualityComparer<ImmutableArray<TInput>>.Default;
_comparer = comparer;
_name = name;
}

Expand Down Expand Up @@ -136,7 +136,7 @@ public NodeStateTable<ImmutableArray<TInput>> UpdateStateTable(DriverStateTable.
}
else if (!sourceTable.IsCached || !tableBuilder.TryUseCachedEntries(stopwatch.Elapsed, sourceInputs))
{
if (!tableBuilder.TryModifyEntry(sourceValues, _comparer, stopwatch.Elapsed, sourceInputs, EntryState.Modified))
if (!tableBuilder.TryModifyEntry(sourceValues, stopwatch.Elapsed, sourceInputs, EntryState.Modified))
{
tableBuilder.AddEntry(sourceValues, EntryState.Added, stopwatch.Elapsed, sourceInputs, EntryState.Added);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ public CombineNode(IIncrementalGeneratorNode<TInput1> input1, IIncrementalGenera
};

var entry = (entry1.Item, input2);
if (state != EntryState.Modified || _comparer is null || !tableBuilder.TryModifyEntry(entry, _comparer, stopwatch.Elapsed, stepInputs, state))
if (state != EntryState.Modified || _comparer is null || !tableBuilder.TryModifyEntry(entry, stopwatch.Elapsed, stepInputs, state))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does the check _comparer is null serve here? The null case means it's effectively the default comparer so why are we not trying to use that for TryModifyEntry?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an optimization. If we've already determined that one of the two sides is modified then the default comparer will always find the same result. However if there's a custom comparer it's possible that will override the result and return cached.

{
tableBuilder.AddEntry(entry, state, stopwatch.Elapsed, stepInputs, state);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ internal sealed class InputNode<T> : IIncrementalGeneratorNode<T>
private readonly Func<DriverStateTable.Builder, ImmutableArray<T>> _getInput;
private readonly Action<IIncrementalGeneratorOutputNode> _registerOutput;
private readonly IEqualityComparer<T> _inputComparer;
private readonly IEqualityComparer<T> _comparer;
private readonly IEqualityComparer<T>? _comparer;
private readonly string? _name;

public InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, IEqualityComparer<T>? inputComparer = null)
Expand All @@ -35,7 +35,7 @@ public InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, IEq
private InputNode(Func<DriverStateTable.Builder, ImmutableArray<T>> getInput, Action<IIncrementalGeneratorOutputNode>? registerOutput, IEqualityComparer<T>? inputComparer = null, IEqualityComparer<T>? comparer = null, string? name = null)
{
_getInput = getInput;
_comparer = comparer ?? EqualityComparer<T>.Default;
_comparer = comparer;
_inputComparer = inputComparer ?? EqualityComparer<T>.Default;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should _inputComparer be handled similarly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we control the input comparer and all the inputs it compares. If that's throwing it's a real product bug and we don't want to falsely attribute it to a generator.

_registerOutput = registerOutput ?? (o => throw ExceptionUtilities.Unreachable());
_name = name;
Expand Down Expand Up @@ -83,7 +83,7 @@ public NodeStateTable<T> UpdateStateTable(DriverStateTable.Builder graphState, N
// This allows us to correctly 'replace' items even when they aren't actually the same. In the case that the
// item really isn't modified, but a new item, we still function correctly as we mostly treat them the same,
// but will perform an extra comparison that is omitted in the pure 'added' case.
var modified = tableBuilder.TryModifyEntry(inputItems[itemIndex], _comparer, elapsedTime, noInputStepsStepInfo, EntryState.Modified);
var modified = tableBuilder.TryModifyEntry(inputItems[itemIndex], elapsedTime, noInputStepsStepInfo, EntryState.Modified);
Debug.Assert(modified);
itemsSet.Remove(inputItems[itemIndex]);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public NodeStateTable<T> AsCached()
public Builder ToBuilder(string? stepName, bool stepTrackingEnabled, IEqualityComparer<T>? equalityComparer = null, int? tableCapacity = null)
=> new(this, stepName, stepTrackingEnabled, equalityComparer, tableCapacity);

public NodeStateTable<T> CreateCachedTableWithUpdatedSteps<TInput>(NodeStateTable<TInput> inputTable, string? stepName, IEqualityComparer<T> equalityComparer)
public NodeStateTable<T> CreateCachedTableWithUpdatedSteps<TInput>(NodeStateTable<TInput> inputTable, string? stepName, IEqualityComparer<T>? equalityComparer)
{
Debug.Assert(inputTable.HasTrackedSteps && inputTable.IsCached);
NodeStateTable<T>.Builder builder = ToBuilder(stepName, stepTrackingEnabled: true, equalityComparer);
Expand Down Expand Up @@ -256,7 +256,7 @@ internal Builder(
_states = ArrayBuilder<TableEntry>.GetInstance(tableCapacity ?? previous.GetTotalEntryItemCount());
_previous = previous;
_name = name;
_equalityComparer = equalityComparer ?? EqualityComparer<T>.Default;
_equalityComparer = equalityComparer ?? WrappedUserComparer<T>.Default;
if (stepTrackingEnabled)
{
_steps = ArrayBuilder<IncrementalGeneratorRunStep>.GetInstance();
Expand Down Expand Up @@ -320,7 +320,7 @@ internal bool TryUseCachedEntries(TimeSpan elapsedTime, ImmutableArray<(Incremen
return true;
}

public bool TryModifyEntry(T value, IEqualityComparer<T> comparer, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
public bool TryModifyEntry(T value, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
{
if (!TryGetPreviousEntry(out var previousEntry))
{
Expand All @@ -335,13 +335,13 @@ public bool TryModifyEntry(T value, IEqualityComparer<T> comparer, TimeSpan elap
}

Debug.Assert(previousEntry.Count == 1);
var (chosen, state, _) = GetModifiedItemAndState(previousEntry.GetItem(0), value, comparer);
var (chosen, state, _) = GetModifiedItemAndState(previousEntry.GetItem(0), value);
_states.Add(new TableEntry(OneOrMany.Create(chosen), state));
RecordStepInfoForLastEntry(elapsedTime, stepInputs, overallInputState);
return true;
}

public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> comparer, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
public bool TryModifyEntries(ImmutableArray<T> outputs, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState)
{
// Semantics:
// For each item in the row, we compare with the new matching new value.
Expand Down Expand Up @@ -384,7 +384,7 @@ public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> com
var previousState = previousEntry.GetState(i);
var replacementItem = outputs[i];

var (chosenItem, state, chosePrevious) = GetModifiedItemAndState(previousItem, replacementItem, comparer);
var (chosenItem, state, chosePrevious) = GetModifiedItemAndState(previousItem, replacementItem);

if (builder != null)
{
Expand Down Expand Up @@ -433,9 +433,9 @@ public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> com
return true;
}

public bool TryModifyEntries(ImmutableArray<T> outputs, IEqualityComparer<T> comparer, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState, out TableEntry entry)
public bool TryModifyEntries(ImmutableArray<T> outputs, TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs, EntryState overallInputState, out TableEntry entry)
{
if (!TryModifyEntries(outputs, comparer, elapsedTime, stepInputs, overallInputState))
if (!TryModifyEntries(outputs, elapsedTime, stepInputs, overallInputState))
{
entry = default;
return false;
Expand Down Expand Up @@ -554,11 +554,11 @@ public NodeStateTable<T> ToImmutableAndFree()
isCached: finalStates.All(static s => s.IsCached) && _previous.GetTotalEntryItemCount() == finalStates.Sum(static s => s.Count));
}

private static (T chosen, EntryState state, bool chosePrevious) GetModifiedItemAndState(T previous, T replacement, IEqualityComparer<T> comparer)
private (T chosen, EntryState state, bool chosePrevious) GetModifiedItemAndState(T previous, T replacement)
{
// when comparing an item to check if its modified we explicitly cache the *previous* item in the case where its
// considered to be equal. This ensures that subsequent comparisons are stable across future generation passes.
return comparer.Equals(previous, replacement)
return _equalityComparer.Equals(previous, replacement)
? (previous, EntryState.Cached, chosePrevious: true)
: (replacement, EntryState.Modified, chosePrevious: false);
}
Expand Down
Loading