Skip to content

Commit c0e0c1c

Browse files
committed
Revise intersection optimizations to include string, number and symbol
1 parent 7c512fb commit c0e0c1c

File tree

2 files changed

+64
-19
lines changed

2 files changed

+64
-19
lines changed

src/compiler/checker.ts

Lines changed: 62 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8533,6 +8533,15 @@ namespace ts {
85338533
return binarySearch(types, type, getTypeId, compareValues) >= 0;
85348534
}
85358535

8536+
function insertType(types: Type[], type: Type): boolean {
8537+
const index = binarySearch(types, type, getTypeId, compareValues);
8538+
if (index < 0) {
8539+
types.splice(~index, 0, type);
8540+
return true;
8541+
}
8542+
return false;
8543+
}
8544+
85368545
// Return true if the given intersection type contains
85378546
// more than one unit type or,
85388547
// an object type and a nullable type (null or undefined), or
@@ -8700,7 +8709,7 @@ namespace ts {
87008709
includes & TypeFlags.Undefined ? includes & TypeFlags.NonWideningType ? undefinedType : undefinedWideningType :
87018710
neverType;
87028711
}
8703-
return getUnionTypeFromSortedList(typeSet, includes & TypeFlags.NotUnit ? 0 : TypeFlags.UnionOfUnitTypes, aliasSymbol, aliasTypeArguments);
8712+
return getUnionTypeFromSortedList(typeSet, includes & TypeFlags.NotPrimitiveUnion ? 0 : TypeFlags.UnionOfPrimitiveTypes, aliasSymbol, aliasTypeArguments);
87048713
}
87058714

87068715
function getUnionTypePredicate(signatures: ReadonlyArray<Signature>): TypePredicate | undefined {
@@ -8823,26 +8832,62 @@ namespace ts {
88238832
}
88248833
}
88258834

8826-
// When intersecting unions of unit types we can simply intersect based on type identity.
8827-
// Here we remove all unions of unit types from the given list and replace them with a
8828-
// a single union containing an intersection of the unit types.
8829-
function intersectUnionsOfUnitTypes(types: Type[]) {
8830-
const unionIndex = findIndex(types, t => (t.flags & TypeFlags.UnionOfUnitTypes) !== 0);
8831-
const unionType = <UnionType>types[unionIndex];
8832-
let intersection = unionType.types;
8833-
let i = types.length - 1;
8834-
while (i > unionIndex) {
8835+
// Check that the given type has a match in every union. A given type is matched by
8836+
// an identical type, and a literal type is additionally matched by its corresponding
8837+
// primitive type.
8838+
function eachUnionContains(unionTypes: UnionType[], type: Type) {
8839+
for (const u of unionTypes) {
8840+
if (!containsType(u.types, type)) {
8841+
const primitive = type.flags & TypeFlags.StringLiteral ? stringType :
8842+
type.flags & TypeFlags.NumberLiteral ? numberType :
8843+
type.flags & TypeFlags.UniqueESSymbol ? esSymbolType :
8844+
undefined;
8845+
if (!primitive || !containsType(u.types, primitive)) {
8846+
return false;
8847+
}
8848+
}
8849+
}
8850+
return true;
8851+
}
8852+
8853+
// Remove all unions of primitive types from the given list and replace them with a
8854+
// single union containing an intersection of those primitive types.
8855+
function intersectUnionsOfPrimitiveTypes(types: Type[]) {
8856+
let unionTypes: UnionType[] | undefined;
8857+
const index = findIndex(types, t => (t.flags & TypeFlags.UnionOfPrimitiveTypes) !== 0);
8858+
let i = index + 1;
8859+
// Remove all but the first union of primitive types and collect them in
8860+
// the unionTypes array.
8861+
while (i < types.length) {
88358862
const t = types[i];
8836-
if (t.flags & TypeFlags.UnionOfUnitTypes) {
8837-
intersection = filter(intersection, u => containsType((<UnionType>t).types, u));
8863+
if (t.flags & TypeFlags.UnionOfPrimitiveTypes) {
8864+
(unionTypes || (unionTypes = [<UnionType>types[index]])).push(<UnionType>t);
88388865
orderedRemoveItemAt(types, i);
88398866
}
8840-
i--;
8867+
else {
8868+
i++;
8869+
}
88418870
}
8842-
if (intersection === unionType.types) {
8871+
// Return false if there was only one union of primitive types
8872+
if (!unionTypes) {
88438873
return false;
88448874
}
8845-
types[unionIndex] = getUnionTypeFromSortedList(intersection, unionType.flags & TypeFlags.UnionOfUnitTypes);
8875+
// We have more than one union of primitive types, now intersect them. For each
8876+
// type in each union we check if the type is matched in every union and if so
8877+
// we include it in the result.
8878+
const checked: Type[] = [];
8879+
const result: Type[] = [];
8880+
for (const u of unionTypes) {
8881+
for (const t of u.types) {
8882+
if (insertType(checked, t)) {
8883+
if (eachUnionContains(unionTypes, t)) {
8884+
insertType(result, t);
8885+
}
8886+
}
8887+
}
8888+
}
8889+
// Finally replace the first union with the result
8890+
types[index] = getUnionTypeFromSortedList(result, TypeFlags.UnionOfPrimitiveTypes);
88468891
return true;
88478892
}
88488893

@@ -8883,7 +8928,7 @@ namespace ts {
88838928
return typeSet[0];
88848929
}
88858930
if (includes & TypeFlags.Union) {
8886-
if (includes & TypeFlags.UnionOfUnitTypes && intersectUnionsOfUnitTypes(typeSet)) {
8931+
if (includes & TypeFlags.UnionOfPrimitiveTypes && intersectUnionsOfPrimitiveTypes(typeSet)) {
88878932
// When the intersection creates a reduced set (which might mean that *all* union types have
88888933
// disappeared), we restart the operation to get a new set of combined flags. Once we have
88898934
// reduced we'll never reduce again, so this occurs at most once.
@@ -13980,7 +14025,7 @@ namespace ts {
1398014025
if (type.flags & TypeFlags.Union) {
1398114026
const types = (<UnionType>type).types;
1398214027
const filtered = filter(types, f);
13983-
return filtered === types ? type : getUnionTypeFromSortedList(filtered, type.flags & TypeFlags.UnionOfUnitTypes);
14028+
return filtered === types ? type : getUnionTypeFromSortedList(filtered, type.flags & TypeFlags.UnionOfPrimitiveTypes);
1398414029
}
1398514030
return f(type) ? type : neverType;
1398614031
}

src/compiler/types.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3675,7 +3675,7 @@ namespace ts {
36753675
/* @internal */
36763676
FreshLiteral = 1 << 25, // Fresh literal or unique type
36773677
/* @internal */
3678-
UnionOfUnitTypes = 1 << 26, // Type is union of unit types
3678+
UnionOfPrimitiveTypes = 1 << 26, // Type is union of primitive types
36793679
/* @internal */
36803680
ContainsWideningType = 1 << 27, // Type is or contains undefined or null widening type
36813681
/* @internal */
@@ -3720,7 +3720,7 @@ namespace ts {
37203720
Narrowable = Any | Unknown | StructuredOrInstantiable | StringLike | NumberLike | BooleanLike | ESSymbol | UniqueESSymbol | NonPrimitive,
37213721
NotUnionOrUnit = Any | Unknown | ESSymbol | Object | NonPrimitive,
37223722
/* @internal */
3723-
NotUnit = Any | String | Number | Boolean | Enum | ESSymbol | Void | Never | StructuredOrInstantiable,
3723+
NotPrimitiveUnion = Any | Unknown | Enum | Void | Never | StructuredOrInstantiable,
37243724
/* @internal */
37253725
RequiresWidening = ContainsWideningType | ContainsObjectLiteral,
37263726
/* @internal */

0 commit comments

Comments
 (0)