Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/parser/cxx/ast_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,7 @@ auto ASTRewriter::instantiateClassTemplate(
ClassSymbol* classSymbol) -> ClassSymbol* {
auto templateDecl = classSymbol->templateDeclaration();

ClassSpecifierAST* classSpecifier =
ast_cast<ClassSpecifierAST>(classSymbol->declaration());

if (!classSpecifier) return nullptr;
if (!classSymbol->declaration()) return nullptr;

auto templateArguments =
make_substitution(unit, templateDecl, templateArgumentList);
Expand Down Expand Up @@ -128,6 +125,9 @@ auto ASTRewriter::instantiateClassTemplate(
return subst;
}

auto classSpecifier = ast_cast<ClassSpecifierAST>(classSymbol->declaration());
if (!classSpecifier) return nullptr;

auto parentScope = classSymbol->enclosingNonTemplateParametersScope();

auto rewriter = ASTRewriter{unit, parentScope, templateArguments};
Expand Down
10 changes: 5 additions & 5 deletions src/parser/cxx/ast_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ class ASTRewriter {
TemplateDeclarationAST* templateHead = nullptr)
-> DeclarationAST*;

[[nodiscard]] static auto make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
-> std::vector<TemplateArgument>;

private:
[[nodiscard]] auto templateArguments() const
-> const std::vector<TemplateArgument>& {
Expand All @@ -73,11 +78,6 @@ class ASTRewriter {
[[nodiscard]] auto restrictedToDeclarations() const -> bool;
void setRestrictedToDeclarations(bool restrictedToDeclarations);

[[nodiscard]] static auto make_substitution(
TranslationUnit* unit, TemplateDeclarationAST* templateDecl,
List<TemplateArgumentAST*>* templateArgumentList)
-> std::vector<TemplateArgument>;

// run on the base nodes
[[nodiscard]] auto unit(UnitAST* ast) -> UnitAST*;
[[nodiscard]] auto statement(StatementAST* ast) -> StatementAST*;
Expand Down
4 changes: 0 additions & 4 deletions src/parser/cxx/ast_rewriter_specifiers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -787,10 +787,6 @@ auto ASTRewriter::SpecifierVisitor::operator()(ClassSpecifierAST* ast)
classSymbol->setDeclaration(copy);
classSymbol->setTemplateDeclaration(templateHead);

if (templateHead) {
classSymbol->setTemplateParameters(binder()->currentTemplateParameters());
}

if (ast->symbol == rewrite.binder().instantiatingSymbol()) {
ast->symbol->addSpecialization(rewrite.templateArguments(), classSymbol);
} else {
Expand Down
160 changes: 112 additions & 48 deletions src/parser/cxx/binder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ void Binder::setScope(ScopeSymbol* scope) {
inTemplate_ = false;

for (auto current = scope_; current; current = current->parent()) {
if (current->isTemplateParameters()) {
if (auto params = current->templateParameters()) {
inTemplate_ = true;
break;
}
Expand Down Expand Up @@ -212,7 +212,6 @@ void Binder::bind(ElaboratedTypeSpecifierAST* ast, DeclSpecs& declSpecs,

classSymbol->setIsUnion(isUnion);
classSymbol->setName(name);
classSymbol->setTemplateParameters(currentTemplateParameters());
classSymbol->setTemplateDeclaration(declSpecs.templateHead);
declaringScope()->addSymbol(classSymbol);

Expand All @@ -230,42 +229,99 @@ void Binder::bind(ElaboratedTypeSpecifierAST* ast, DeclSpecs& declSpecs,
}

void Binder::bind(ClassSpecifierAST* ast, DeclSpecs& declSpecs) {
auto templateParameters = currentTemplateParameters();
auto check_optional_nested_name_specifier = [&] {
if (!ast->nestedNameSpecifier) return;

if (ast->nestedNameSpecifier) {
auto parent = ast->nestedNameSpecifier->symbol;

if (parent && parent->isClassOrNamespace()) {
setScope(static_cast<ScopeSymbol*>(parent));
if (!parent || !parent->isClassOrNamespace()) {
error(ast->nestedNameSpecifier->firstSourceLocation(),
"nested name specifier must be a class or namespace");
return;
}
}

auto className = get_name(control(), ast->unqualifiedId);
auto templateId = ast_cast<SimpleTemplateIdAST>(ast->unqualifiedId);
if (templateId) {
className = templateId->identifier;
}
setScope(static_cast<ScopeSymbol*>(parent));
};

auto location = ast->classLoc;
if (templateId) {
location = templateId->identifierLoc;
} else if (ast->unqualifiedId) {
location = ast->unqualifiedId->firstSourceLocation();
}
auto check_template_specialization = [&] {
auto templateId = ast_cast<SimpleTemplateIdAST>(ast->unqualifiedId);
if (!templateId) return false;

ClassSymbol* primaryTemplate = nullptr;
const auto location = templateId->identifierLoc;

if (templateId && scope()->isTemplateParameters()) {
for (auto candidate : declaringScope()->find(className) | views::classes) {
primaryTemplate = candidate;
ClassSymbol* primaryTemplateSymbol = nullptr;

for (auto candidate :
declaringScope()->find(templateId->identifier) | views::classes) {
primaryTemplateSymbol = candidate;
break;
}

if (!primaryTemplate) {
if (!primaryTemplateSymbol ||
!primaryTemplateSymbol->templateParameters()) {
error(location, std::format("specialization of undeclared template '{}'",
templateId->identifier->name()));
// return true;
}
}

std::vector<TemplateArgument> templateArguments;
ClassSymbol* specialization = nullptr;

if (primaryTemplateSymbol) {
templateArguments = ASTRewriter::make_substitution(
unit_, primaryTemplateSymbol->templateDeclaration(),
templateId->templateArgumentList);

specialization =
primaryTemplateSymbol
? primaryTemplateSymbol->findSpecialization(templateArguments)
: nullptr;

if (specialization) {
error(location, std::format("redefinition of specialization '{}'",
templateId->identifier->name()));
// return true;
}
}

const auto isUnion = ast->classKey == TokenKind::T_UNION;

auto classSymbol = control()->newClassSymbol(declaringScope(), location);
ast->symbol = classSymbol;

classSymbol->setIsUnion(isUnion);
classSymbol->setName(templateId->identifier);
ast->symbol->setDeclaration(ast);
ast->symbol->setFinal(ast->isFinal);

// if (declSpecs.templateHead) {
// warning(location, "setting template head");
// ast->symbol->setTemplateDeclaration(declSpecs.templateHead);
// }

declSpecs.setTypeSpecifier(ast);
declSpecs.setType(ast->symbol->type());

if (primaryTemplateSymbol) {
primaryTemplateSymbol->addSpecialization(std::move(templateArguments),
classSymbol);
}

return true;
};

check_optional_nested_name_specifier();

if (check_template_specialization()) return;

// get the component anme
const Identifier* className = nullptr;
if (auto nameId = ast_cast<NameIdAST>(ast->unqualifiedId))
className = nameId->identifier;

const auto location = ast->unqualifiedId
? ast->unqualifiedId->firstSourceLocation()
: ast->classLoc;

ClassSymbol* classSymbol = nullptr;

Expand All @@ -277,6 +333,9 @@ void Binder::bind(ClassSpecifierAST* ast, DeclSpecs& declSpecs) {
}

if (classSymbol && classSymbol->isComplete()) {
// not a template-id, but a class with the same name already exists
error(location,
std::format("redefinition of class '{}'", to_string(className)));
classSymbol = nullptr;
}

Expand All @@ -285,29 +344,22 @@ void Binder::bind(ClassSpecifierAST* ast, DeclSpecs& declSpecs) {
classSymbol = control()->newClassSymbol(scope(), location);
classSymbol->setIsUnion(isUnion);
classSymbol->setName(className);
classSymbol->setTemplateParameters(templateParameters);

if (!primaryTemplate) {
declaringScope()->addSymbol(classSymbol);
} else {
std::vector<TemplateArgument> arguments;
// TODO: parse template arguments
primaryTemplate->addSpecialization(arguments, classSymbol);
}
declaringScope()->addSymbol(classSymbol);
}

classSymbol->setDeclaration(ast);
ast->symbol = classSymbol;

ast->symbol->setDeclaration(ast);

if (declSpecs.templateHead) {
classSymbol->setTemplateDeclaration(declSpecs.templateHead);
ast->symbol->setTemplateDeclaration(declSpecs.templateHead);
}

classSymbol->setFinal(ast->isFinal);

ast->symbol = classSymbol;
ast->symbol->setFinal(ast->isFinal);

declSpecs.setTypeSpecifier(ast);
declSpecs.setType(classSymbol->type());
declSpecs.setType(ast->symbol->type());
}

void Binder::complete(ClassSpecifierAST* ast) {
Expand Down Expand Up @@ -417,7 +469,6 @@ auto Binder::declareTypeAlias(SourceLocation identifierLoc, TypeIdAST* typeId,
symbol->setName(name);

if (typeId) symbol->setType(typeId->type);
symbol->setTemplateParameters(currentTemplateParameters());

if (auto classType = type_cast<ClassType>(symbol->type())) {
auto classSymbol = classType->symbol();
Expand Down Expand Up @@ -565,7 +616,6 @@ void Binder::bind(ConceptDefinitionAST* ast) {

auto symbol = control()->newConceptSymbol(scope(), ast->identifierLoc);
symbol->setName(ast->identifier);
symbol->setTemplateParameters(templateParameters);

declaringScope()->addSymbol(symbol);
}
Expand Down Expand Up @@ -708,7 +758,6 @@ auto Binder::declareFunction(DeclaratorAST* declarator, const Decl& decl)
applySpecifiers(functionSymbol, decl.specs);
functionSymbol->setName(name);
functionSymbol->setType(type);
functionSymbol->setTemplateParameters(currentTemplateParameters());

if (isConstructor(functionSymbol)) {
auto enclosingClass = symbol_cast<ClassSymbol>(scope());
Expand Down Expand Up @@ -775,7 +824,6 @@ auto Binder::declareVariable(DeclaratorAST* declarator, const Decl& decl)
applySpecifiers(symbol, decl.specs);
symbol->setName(name);
symbol->setType(type);
symbol->setTemplateParameters(currentTemplateParameters());
declaringScope()->addSymbol(symbol);
return symbol;
}
Expand Down Expand Up @@ -890,13 +938,29 @@ auto Binder::resolve(NestedNameSpecifierAST* nestedNameSpecifier,
}

void Binder::bind(IdExpressionAST* ast) {
if (ast->unqualifiedId) {
auto name = get_name(control(), ast->unqualifiedId);
const Name* componentName = name;
if (auto templateId = name_cast<TemplateId>(name))
componentName = templateId->name();
ast->symbol = Lookup{scope()}(ast->nestedNameSpecifier, componentName);
if (!ast->unqualifiedId) {
error(ast->firstSourceLocation(),
"expected an unqualified identifier in id expression");
return;
}

auto name = get_name(control(), ast->unqualifiedId);

const Name* componentName = name;

if (auto templateId = name_cast<TemplateId>(name)) {
componentName = templateId->name();
}

if (ast->nestedNameSpecifier) {
if (!ast->nestedNameSpecifier->symbol) {
error(ast->nestedNameSpecifier->firstSourceLocation(),
"nested name specifier must be a class or namespace");
return;
}
}

ast->symbol = Lookup{scope()}(ast->nestedNameSpecifier, componentName);
}

auto Binder::getFunction(ScopeSymbol* scope, const Name* name, const Type* type)
Expand Down
26 changes: 26 additions & 0 deletions src/parser/cxx/external_name_encoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,31 @@ struct ExternalNameEncoder::EncodeUnqualifiedName {
ExternalNameEncoder& encoder;
Symbol* symbol = nullptr;

void encodeTemplateArguments(Symbol* symbol) {
if (!symbol) return;

std::span<const TemplateArgument> args;

if (auto classSymbol = symbol_cast<ClassSymbol>(symbol)) {
args = classSymbol->templateArguments();
}

if (args.empty()) return;

encoder.out("I");

for (const auto& arg : args) {
if (auto sym = std::get_if<Symbol*>(&arg)) {
auto type = (*sym)->type();
encoder.encodeType(type);
} else {
cxx_runtime_error("template argument not supported yet");
}
}

encoder.out("E");
}

void operator()(const Identifier* id) {
if (auto function = symbol_cast<FunctionSymbol>(symbol)) {
if (function->isConstructor()) {
Expand All @@ -338,6 +363,7 @@ struct ExternalNameEncoder::EncodeUnqualifiedName {
}

out(std::format("{}{}", id->name().length(), id->name()));
encodeTemplateArguments(symbol);
}

void operator()(const OperatorId* name) {
Expand Down
Loading
Loading