From 680b1a450713941fa8cb4b7c9ce9f03a10761c64 Mon Sep 17 00:00:00 2001 From: hfhbd Date: Tue, 21 Nov 2023 19:07:42 +0100 Subject: [PATCH 1/2] Replace forEach with for loop for better debugging --- .../sqldelight/dialect/api/SelectQueryable.kt | 8 +- .../sqldelight/core/SqlDelightEnvironment.kt | 35 +++-- .../core/compiler/DatabaseGenerator.kt | 122 ++++++++---------- .../core/compiler/ExecuteQueryGenerator.kt | 2 +- .../core/compiler/MutatorQueryGenerator.kt | 18 ++- .../core/compiler/QueriesTypeGenerator.kt | 14 +- .../core/compiler/QueryGenerator.kt | 24 ++-- .../core/compiler/QueryInterfaceGenerator.kt | 2 +- .../core/compiler/SelectQueryGenerator.kt | 8 +- .../core/compiler/SqlDelightCompiler.kt | 62 ++++----- .../core/compiler/TableInterfaceGenerator.kt | 11 +- .../core/compiler/model/BindableQuery.kt | 30 +++-- .../cash/sqldelight/core/lang/ParserUtil.kt | 2 +- .../sqldelight/core/lang/SqlDelightFile.kt | 4 +- .../core/lang/SqlDelightQueriesFile.kt | 4 +- .../core/lang/psi/ColumnTypeMixin.kt | 9 +- .../core/lang/util/InsertStmtUtil.kt | 2 +- .../sqldelight/core/lang/util/TreeUtil.kt | 46 ++++--- .../core/queries/InterfaceGeneration.kt | 2 +- .../core/tables/InterfaceGeneration.kt | 4 +- .../core/views/InterfaceGeneration.kt | 4 +- .../lang/SqlDelightFileViewProviderFactory.kt | 3 +- .../intellij/SqlDelightProjectTestCase.kt | 3 +- .../sqldelight/test/util/FixtureCompiler.kt | 5 +- 24 files changed, 211 insertions(+), 213 deletions(-) diff --git a/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/SelectQueryable.kt b/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/SelectQueryable.kt index 37dcc8c782e..6a670382cba 100644 --- a/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/SelectQueryable.kt +++ b/sqldelight-compiler/dialect/src/main/kotlin/app/cash/sqldelight/dialect/api/SelectQueryable.kt @@ -29,9 +29,9 @@ class SelectQueryable( PsiTreeUtil.getParentOfType(resolvedTable, Queryable::class.java)?.tableExposed() }.orEmpty() } - tablesSelected.forEach { - if (it.query.columns.flattenCompounded() == pureColumns) { - val table = it.query.table + for (tableSelected in tablesSelected) { + if (tableSelected.query.columns.flattenCompounded() == pureColumns) { + val table = tableSelected.query.table if (table is SqlViewName) { // check, if this view uses exactly 1 pure table and use this table, if found. val createViewStmt = table.nameIdentifier?.parentOfType()?.compoundSelectStmt @@ -42,7 +42,7 @@ class SelectQueryable( } } } - return@lazy it.tableName + return@lazy tableSelected.tableName } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/SqlDelightEnvironment.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/SqlDelightEnvironment.kt index 7cb0a6c14e1..6a1b4e54507 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/SqlDelightEnvironment.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/SqlDelightEnvironment.kt @@ -163,7 +163,7 @@ class SqlDelightEnvironment( if (it !is SqlDelightQueriesFile) return@forSourceFiles logger("----- START ${it.name} ms -------") val timeTaken = measureTimeMillis { - SqlDelightCompiler.writeInterfaces(module, dialect, it, writer) + SqlDelightCompiler.writeInterfaces(it, writer) sourceFile = it } logger("----- END ${it.name} in $timeTaken ms ------") @@ -198,25 +198,24 @@ class SqlDelightEnvironment( .map { localFileSystem.findFileByPath(it.absolutePath)!! } .map { psiManager.findDirectory(it)!! } .flatMap { directory: PsiDirectory -> directory.migrationFiles() } - migrationFiles.sortedBy { it.version } - .forEach { - val errorElements = ArrayList() - PsiTreeUtil.processElements(it) { element -> - when (element) { - is PsiErrorElement -> errorElements.add(element) - // Uncomment when sqm files understand their state of the world. - // is SqlAnnotatedElement -> element.annotate(annotationHolder) - } - return@processElements true - } - if (errorElements.isNotEmpty()) { - throw SqlDelightException( - "Error Reading ${it.name}:\n\n" + - errorElements.joinToString(separator = "\n") { errorMessage(it, it.errorDescription) }, - ) + for (it in migrationFiles.sortedBy { it.version }) { + val errorElements = ArrayList() + PsiTreeUtil.processElements(it) { element -> + when (element) { + is PsiErrorElement -> errorElements.add(element) + // Uncomment when sqm files understand their state of the world. + // is SqlAnnotatedElement -> element.annotate(annotationHolder) } - body(it) + return@processElements true } + if (errorElements.isNotEmpty()) { + throw SqlDelightException( + "Error Reading ${it.name}:\n\n" + + errorElements.joinToString(separator = "\n") { errorMessage(it, it.errorDescription) }, + ) + } + body(it) + } } private fun errorMessage(element: PsiElement, message: String): String = diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/DatabaseGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/DatabaseGenerator.kt index 85310734ec1..411d70b5d38 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/DatabaseGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/DatabaseGenerator.kt @@ -67,7 +67,7 @@ internal class DatabaseGenerator( val typeSpec = TypeSpec.interfaceBuilder(fileIndex.className) .addSuperinterface(if (generateAsync) SUSPENDING_TRANSACTER_TYPE else TRANSACTER_TYPE) - fileIndex.dependencies.forEach { + for (it in fileIndex.dependencies) { typeSpec.addSuperinterface(ClassName(it.packageName, it.className)) } @@ -84,18 +84,17 @@ internal class DatabaseGenerator( invoke.addParameter(dbParameter) invokeReturn.add("%N", dbParameter) - moduleFolders.flatMap { it.queryFiles() } + for (file in moduleFolders.flatMap { it.queryFiles() } .filterNot { it.isEmpty() } - .sortedBy { it.name } - .forEach { file -> - // queries property added to QueryWrapper type: - // val dataQueries = DataQueries(this, driver) - typeSpec.addProperty(file.queriesName, file.queriesType) - } + .sortedBy { it.name }) { + // queries property added to QueryWrapper type: + // val dataQueries = DataQueries(this, driver) + typeSpec.addProperty(file.queriesName, file.queriesType) + } - forAdapters { - invoke.addParameter(it.name, it.type) - invokeReturn.add(", %L", it.name) + for (adapter in adapters()) { + invoke.addParameter(adapter.name, adapter.type) + invokeReturn.add(", %L", adapter.name) } return typeSpec @@ -131,16 +130,11 @@ internal class DatabaseGenerator( .build() } - private fun forAdapters( - block: (PropertySpec) -> Unit, - ) { - sourceFolders - .flatMap { it.queryFiles() } - .flatMap { it.requiredAdapters } - .sortedBy { it.name } - .toSet() - .forEach(block) - } + private fun adapters(): Set = sourceFolders + .flatMap { it.queryFiles() } + .flatMap { it.requiredAdapters } + .sortedBy { it.name } + .toSet() fun type(): TypeSpec { val typeSpec = TypeSpec.classBuilder("${fileIndex.className}Impl") @@ -177,30 +171,29 @@ internal class DatabaseGenerator( .addParameter(DRIVER_NAME, DRIVER_TYPE) .addParameter(oldVersion) .addParameter(newVersion) - forAdapters { - constructor.addParameter(it.name, it.type) + for (adapter in adapters()) { + constructor.addParameter(adapter.name, adapter.type) } - sourceFolders.flatMap { it.queryFiles() } + for (file in sourceFolders.flatMap { it.queryFiles() } .filterNot { it.isEmpty() } - .sortedBy { it.name } - .forEach { file -> - var adapters = "" - if (file.requiredAdapters.isNotEmpty()) { - adapters = file.requiredAdapters.joinToString( - prefix = ", ", - transform = { it.name }, - ) - } - // queries property added to QueryWrapper type: - // val dataQueries = DataQueries(this, driver, transactions) - typeSpec.addProperty( - PropertySpec.builder(file.queriesName, file.queriesType) - .addModifiers(OVERRIDE) - .initializer("%T($DRIVER_NAME$adapters)", file.queriesType) - .build(), + .sortedBy { it.name }) { + var adapters = "" + if (file.requiredAdapters.isNotEmpty()) { + adapters = file.requiredAdapters.joinToString( + prefix = ", ", + transform = { it.name }, ) } + // queries property added to QueryWrapper type: + // val dataQueries = DataQueries(this, driver, transactions) + typeSpec.addProperty( + PropertySpec.builder(file.queriesName, file.queriesType) + .addModifiers(OVERRIDE) + .initializer("%T($DRIVER_NAME$adapters)", file.queriesType) + .build(), + ) + } if (generateAsync) { createFunction.beginControlFlow("return %T", ASYNC_RESULT_TYPE) @@ -221,39 +214,36 @@ internal class DatabaseGenerator( .sortedBy { it.order } // Derive the schema from migration files. - orderedMigrations.flatMap { it.sqlStatements() } - .filter { it.isSchema() } - .forEach { - val statement = - if (generateAsync) "$DRIVER_NAME.execute(null, %L, 0).await()" else "$DRIVER_NAME.execute(null, %L, 0)" - createFunction.addStatement(statement, it.rawSqlText().toCodeLiteral()) - } + for (it in orderedMigrations.flatMap { it.sqlStatements() } + .filter { it.isSchema() }) { + val statement = + if (generateAsync) "$DRIVER_NAME.execute(null, %L, 0).await()" else "$DRIVER_NAME.execute(null, %L, 0)" + createFunction.addStatement(statement, it.rawSqlText().toCodeLiteral()) + } } var maxVersion = 1L val hasMigrations = sourceFolders.flatMap { it.migrationFiles() }.isNotEmpty() - sourceFolders.flatMap { it.migrationFiles() } - .sortedBy { it.version } - .forEach { migrationFile -> - try { - maxVersion = maxOf(maxVersion, migrationFile.version + 1) - } catch (e: Throwable) { - throw SqlDelightException("Migration files can only have versioned names (1.sqm, 2.sqm, etc)") - } - migrateFunction.beginControlFlow( - "if (%N <= ${migrationFile.version} && %N > ${migrationFile.version})", - oldVersion, - newVersion, + for (migrationFile in sourceFolders.flatMap { it.migrationFiles() }.sortedBy { it.version }) { + try { + maxVersion = maxOf(maxVersion, migrationFile.version + 1) + } catch (e: Throwable) { + throw SqlDelightException("Migration files can only have versioned names (1.sqm, 2.sqm, etc)") + } + migrateFunction.beginControlFlow( + "if (%N <= ${migrationFile.version} && %N > ${migrationFile.version})", + oldVersion, + newVersion, + ) + for (sqlStmt in migrationFile.sqlStatements()) { + migrateFunction.addStatement( + if (generateAsync) "$DRIVER_NAME.execute(null, %S, 0).await()" else "$DRIVER_NAME.execute(null, %S, 0)", + sqlStmt.rawSqlText(), ) - migrationFile.sqlStatements().forEach { - migrateFunction.addStatement( - if (generateAsync) "$DRIVER_NAME.execute(null, %S, 0).await()" else "$DRIVER_NAME.execute(null, %S, 0)", - it.rawSqlText(), - ) - } - migrateFunction.endControlFlow() } + migrateFunction.endControlFlow() + } if (!generateAsync) { createFunction.addStatement("return %T", UNIT_RESULT_TYPE) diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/ExecuteQueryGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/ExecuteQueryGenerator.kt index 391b485c038..21caea65a23 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/ExecuteQueryGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/ExecuteQueryGenerator.kt @@ -57,7 +57,7 @@ open class ExecuteQueryGenerator( CodeBlock.builder() .beginControlFlow("notifyQueries(%L) { emit ->", query.id) .apply { - tablesUpdated.sortedBy { it.name }.forEach { + for (it in tablesUpdated.sortedBy { it.name }) { addStatement("emit(\"${it.name}\")") } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/MutatorQueryGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/MutatorQueryGenerator.kt index 77d2de40709..dc0e02d6e2c 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/MutatorQueryGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/MutatorQueryGenerator.kt @@ -23,15 +23,22 @@ class MutatorQueryGenerator( val tablesAffected = query.tablesAffected.toMutableList() if (foreignKeyCascadeCheck != null) { - psiFile.sqlStmtList?.stmtList?.mapNotNull { it.createTableStmt }?.forEach { table -> - val effected = table.findChildrenOfType().any { - (it.foreignTable.name in query.tablesAffected.map { it.name }) && it.node.findChildByType(foreignKeyCascadeCheck) != null + val createTableStmt = psiFile.sqlStmtList?.stmtList?.mapNotNull { it.createTableStmt } + if (createTableStmt != null) { + for (table in createTableStmt) { + val effected = table.findChildrenOfType().any { + (it.foreignTable.name in query.tablesAffected.map { it.name }) && it.node.findChildByType( + foreignKeyCascadeCheck, + ) != null + } + if (effected) { + tablesAffected.add(TableNameElement.CreateTableName(table.tableName)) + } } - if (effected) tablesAffected.add(TableNameElement.CreateTableName(table.tableName)) } } - psiFile.triggers.forEach { trigger -> + for (trigger in psiFile.triggers) { if (trigger.tableName?.name in query.tablesAffected.map { it.name }) { val triggered = when (query) { is NamedMutator.Delete -> trigger.childOfType(SqlTypes.DELETE) != null @@ -39,6 +46,7 @@ class MutatorQueryGenerator( trigger.childOfType(SqlTypes.INSERT) != null || (query.hasUpsertClause && trigger.childOfType(SqlTypes.UPDATE) != null) } + is NamedMutator.Update -> { val columns = trigger.columnNameList.map { it.name } val updateColumns = query.update.updateStmtSubsequentSetterList.map { it.columnName?.name } + diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueriesTypeGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueriesTypeGenerator.kt index 16d8d7af781..bd85450044e 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueriesTypeGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueriesTypeGenerator.kt @@ -8,15 +8,11 @@ import app.cash.sqldelight.core.lang.SUSPENDING_TRANSACTER_IMPL_TYPE import app.cash.sqldelight.core.lang.SqlDelightQueriesFile import app.cash.sqldelight.core.lang.TRANSACTER_IMPL_TYPE import app.cash.sqldelight.core.lang.queriesType -import app.cash.sqldelight.dialect.api.SqlDelightDialect -import com.intellij.openapi.module.Module import com.squareup.kotlinpoet.FunSpec import com.squareup.kotlinpoet.TypeSpec class QueriesTypeGenerator( - private val module: Module, private val file: SqlDelightQueriesFile, - private val dialect: SqlDelightDialect, ) { private val generateAsync = file.generateAsync @@ -30,7 +26,7 @@ class QueriesTypeGenerator( * transactions: ThreadLocal * ) : TransacterImpl(driver, transactions) */ - fun generateType(packageName: String): TypeSpec? { + fun generateType(): TypeSpec? { if (file.isEmpty()) { return null } @@ -46,12 +42,12 @@ class QueriesTypeGenerator( // Add any required adapters. // private val tableAdapter: Table.Adapter - file.requiredAdapters.forEach { + for (it in file.requiredAdapters) { type.addProperty(it) constructor.addParameter(it.name, it.type) } - file.namedQueries.forEach { query -> + for (query in file.namedQueries) { tryWithElement(query.select) { val generator = SelectQueryGenerator(query) @@ -67,11 +63,11 @@ class QueriesTypeGenerator( } } - file.namedMutators.forEach { mutator -> + for (mutator in file.namedMutators) { type.addExecute(mutator) } - file.namedExecutes.forEach { execute -> + for (execute in file.namedExecutes) { type.addExecute(execute) } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt index 161883124d5..acbcb313209 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt @@ -88,23 +88,22 @@ abstract class QueryGenerator( handledArrayArgs: Set, id: Int, ): Pair> { - val dialectPreparedStatementType = if (generateAsync) dialect.asyncRuntimeTypes.preparedStatementType else dialect.runtimeTypes.preparedStatementType + val dialectPreparedStatementType = + if (generateAsync) dialect.asyncRuntimeTypes.preparedStatementType else dialect.runtimeTypes.preparedStatementType val result = CodeBlock.builder() val positionToArgument = mutableListOf>() val seenArgs = mutableSetOf() val duplicateTypes = mutableSetOf() - query.arguments.forEach { argument -> + for (argument in query.arguments) { if (argument.bindArgs.isNotEmpty()) { - argument.bindArgs - .filter { PsiTreeUtil.isAncestor(statement, it, true) } - .forEach { bindArg -> - if (!seenArgs.add(argument)) { - duplicateTypes.add(argument.type) - } - positionToArgument.add(Triple(bindArg.node.textRange.startOffset, argument, bindArg)) + for (bindArg in argument.bindArgs.filter { PsiTreeUtil.isAncestor(statement, it, true) }) { + if (!seenArgs.add(argument)) { + duplicateTypes.add(argument.type) } + positionToArgument.add(Triple(bindArg.node.textRange.startOffset, argument, bindArg)) + } } else { positionToArgument.add(Triple(0, argument, null)) } @@ -119,7 +118,9 @@ abstract class QueryGenerator( val seenArrayArguments = mutableSetOf() val argumentNameAllocator = NameAllocator().apply { - query.arguments.forEach { newName(it.type.name) } + for (it in query.arguments) { + newName(it.type.name) + } } // A list of [SqlBindExpr] in order of appearance in the query. @@ -140,8 +141,9 @@ abstract class QueryGenerator( extractedVariables[type] = variableName bindStatements.add("val %N = $encodedJavaType\n", variableName) } + // For each argument in the sql - orderedBindArgs.forEach { (_, argument, bindArg) -> + for ((_, argument, bindArg) in orderedBindArgs) { val type = argument.type // Need to replace the single argument with a group of indexed arguments, calculated at // runtime from the list parameter: diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryInterfaceGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryInterfaceGenerator.kt index b846547ab4d..c9f70f1fe8f 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryInterfaceGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryInterfaceGenerator.kt @@ -29,7 +29,7 @@ class QueryInterfaceGenerator(val query: NamedQuery) { val constructor = FunSpec.constructorBuilder() - query.resultColumns.forEach { + for (it in query.resultColumns) { val javaType = it.javaType val typeWithoutAnnotations = javaType.copy(annotations = emptyList()) typeSpec.addProperty( diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SelectQueryGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SelectQueryGenerator.kt index 2de2101439d..024ba6d82a5 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SelectQueryGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SelectQueryGenerator.kt @@ -107,7 +107,7 @@ class SelectQueryGenerator( private fun defaultResultTypeFunctionInterface(params: List>): FunSpec.Builder { val function = FunSpec.builder(query.name) .also(this::addJavadoc) - params.forEach { (name, type) -> + for ((name, type) in params) { function.addParameter(name, type) } return function @@ -117,7 +117,7 @@ class SelectQueryGenerator( private fun customResultTypeFunctionInterface(): FunSpec.Builder { val function = FunSpec.builder(query.name).also(::addJavadoc) - query.parameters.forEach { + for (it in query.parameters) { // Adds each sqlite parameter to the argument list: // fun selectForId(<>, <>, ...) function.addParameter(it.name, it.argumentType()) @@ -166,7 +166,7 @@ class SelectQueryGenerator( val function = customResultTypeFunctionInterface() val dialectCursorType = if (generateAsync) dialect.asyncRuntimeTypes.cursorType else dialect.runtimeTypes.cursorType - query.resultColumns.forEach { resultColumn -> + for (resultColumn in query.resultColumns) { (listOf(resultColumn) + resultColumn.assumedCompatibleTypes) .takeIf { it.size > 1 } ?.map { assumedCompatibleType -> @@ -316,7 +316,7 @@ class SelectQueryGenerator( .addCode(executeBlock()) // For each bind argument the query has. - query.parameters.forEach { parameter -> + for (parameter in query.parameters) { // Add the argument as a constructor property. (Used later to figure out if query dirtied) // val id: Int queryType.addProperty( diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SqlDelightCompiler.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SqlDelightCompiler.kt index c9d6151e7a6..6c3e1c0e376 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SqlDelightCompiler.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/SqlDelightCompiler.kt @@ -24,7 +24,6 @@ import app.cash.sqldelight.core.lang.SqlDelightQueriesFile import app.cash.sqldelight.core.lang.queriesName import app.cash.sqldelight.core.lang.util.sqFile import app.cash.sqldelight.dialect.api.SelectQueryable -import app.cash.sqldelight.dialect.api.SqlDelightDialect import com.alecstrong.sql.psi.core.psi.InvalidElementDetectedException import com.alecstrong.sql.psi.core.psi.NamedElement import com.alecstrong.sql.psi.core.psi.SqlCreateViewStmt @@ -41,15 +40,13 @@ private typealias FileAppender = (fileName: String) -> Appendable object SqlDelightCompiler { fun writeInterfaces( - module: Module, - dialect: SqlDelightDialect, file: SqlDelightQueriesFile, output: FileAppender, ) { try { writeTableInterfaces(file, output) writeQueryInterfaces(file, output) - writeQueries(module, dialect, file, output) + writeQueries(file, output) } catch (e: InvalidElementDetectedException) { // It's okay if compilation is cut short, we can just quit out. } @@ -89,7 +86,7 @@ object SqlDelightCompiler { .addType(databaseImplementationType) .build() - fileIndex.outputDirectory(sourceFile).forEach { outputDirectory -> + for (outputDirectory in fileIndex.outputDirectory(sourceFile)) { val packageDirectory = "$outputDirectory/${packageName.replace(".", "/")}" fileSpec.writeToAndClose(output("$packageDirectory/${databaseImplementationType.name}.kt")) } @@ -108,15 +105,14 @@ object SqlDelightCompiler { // TODO: Remove these when kotlinpoet supports top level types. .addImport("$packageName.$implementationFolder", "newInstance", "schema") .apply { - var index = 0 - fileIndex.dependencies.forEach { - addAliasedImport(ClassName(it.packageName, it.className), "${it.className}${index++}") + for ((index, it) in fileIndex.dependencies.withIndex()) { + addAliasedImport(ClassName(it.packageName, it.className), "${it.className}$index") } } .addType(queryWrapperType) .build() - fileIndex.outputDirectory(sourceFile).forEach { outputDirectory -> + for (outputDirectory in fileIndex.outputDirectory(sourceFile)) { val packageDirectory = "$outputDirectory/${packageName.replace(".", "/")}" fileSpec.writeToAndClose(output("$packageDirectory/${queryWrapperType.name}.kt")) } @@ -128,16 +124,16 @@ object SqlDelightCompiler { includeAll: Boolean = false, ) { val packageName = file.packageName ?: return - file.tables(includeAll).forEach { query -> + for (query in file.tables(includeAll)) { val statement = query.tableName.parent if (statement is SqlCreateViewStmt && statement.compoundSelectStmt != null) { listOf(NamedQuery(allocateName(statement.viewName), SelectQueryable(statement.compoundSelectStmt!!))) .writeQueryInterfaces(file, output) - return@forEach + continue } - if (statement is SqlCreateVirtualTableStmt) return@forEach + if (statement is SqlCreateVirtualTableStmt) continue val fileSpec = FileSpec.builder(packageName, allocateName(query.tableName)) .apply { @@ -148,8 +144,11 @@ object SqlDelightCompiler { } .build() - statement.sqFile().generatedDirectories?.forEach { directory -> - fileSpec.writeToAndClose(output("$directory/${allocateName(query.tableName).capitalize()}.kt")) + val generatedDirectories = statement.sqFile().generatedDirectories + if (generatedDirectories != null) { + for (directory in generatedDirectories) { + fileSpec.writeToAndClose(output("$directory/${allocateName(query.tableName).capitalize()}.kt")) + } } } } @@ -162,21 +161,22 @@ object SqlDelightCompiler { } internal fun writeQueries( - module: Module, - dialect: SqlDelightDialect, file: SqlDelightQueriesFile, output: FileAppender, ) { val packageName = file.packageName ?: return - val queriesType = QueriesTypeGenerator(module, file, dialect) - .generateType(packageName) ?: return + val queriesType = QueriesTypeGenerator(file) + .generateType() ?: return val fileSpec = FileSpec.builder(packageName, file.queriesName.capitalize()) .addType(queriesType) .build() - file.generatedDirectories?.forEach { directory -> - fileSpec.writeToAndClose(output("$directory/${queriesType.name}.kt")) + val generatedDirectories = file.generatedDirectories + if (generatedDirectories != null) { + for (directory in generatedDirectories) { + fileSpec.writeToAndClose(output("$directory/${queriesType.name}.kt")) + } } } @@ -185,21 +185,23 @@ object SqlDelightCompiler { } private fun List.writeQueryInterfaces(file: SqlDelightFile, output: FileAppender) { - return filter { tryWithElement(it.select) { it.needsInterface() } == true } - .forEach { namedQuery -> - val fileSpec = FileSpec.builder(namedQuery.interfaceType.packageName, namedQuery.name) - .apply { - tryWithElement(namedQuery.select) { - val generator = QueryInterfaceGenerator(namedQuery) - addType(generator.kotlinImplementationSpec()) - } + for (namedQuery in filter { tryWithElement(it.select) { it.needsInterface() } == true }) { + val fileSpec = FileSpec.builder(namedQuery.interfaceType.packageName, namedQuery.name) + .apply { + tryWithElement(namedQuery.select) { + val generator = QueryInterfaceGenerator(namedQuery) + addType(generator.kotlinImplementationSpec()) } - .build() + } + .build() - file.generatedDirectories(namedQuery.interfaceType.packageName)?.forEach { directory -> + val generatedDirectories = file.generatedDirectories(namedQuery.interfaceType.packageName) + if (generatedDirectories != null) { + for (directory in generatedDirectories) { fileSpec.writeToAndClose(output("$directory/${namedQuery.name.capitalize()}.kt")) } } + } } private val NamedElement.normalizedName: String diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/TableInterfaceGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/TableInterfaceGenerator.kt index daa637a9257..527cdf5f6e5 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/TableInterfaceGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/TableInterfaceGenerator.kt @@ -51,7 +51,7 @@ internal class TableInterfaceGenerator(private val table: LazyQuery) { val constructor = FunSpec.constructorBuilder() - table.query.columns.map { it.element as NamedElement }.forEach { column -> + for (column in table.query.columns.map { it.element as NamedElement }) { val columnName = allocateName(column) val columnDef = column.columnDefSource()!! val columnType = columnDef.columnType as ColumnTypeMixin @@ -95,15 +95,14 @@ internal class TableInterfaceGenerator(private val table: LazyQuery) { ) } - table.query.columns + for (it in table.query.columns .mapNotNull { column -> val columnDef = (column.element as NamedElement).columnDefSource()!! val columnType = columnDef.columnType as ColumnTypeMixin columnType.valueClass() - } - .forEach { - typeSpec.addType(it) - } + }) { + typeSpec.addType(it) + } return typeSpec .primaryConstructor(constructor.build()) diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt index 8d1dd5b895d..d3bde850c9f 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt @@ -92,28 +92,30 @@ abstract class BindableQuery( val manuallyNamedIndexes = mutableSetOf() val namesSeen = mutableSetOf() var maxIndexSeen = 0 - statement.findChildrenOfType().forEach { bindArg -> + for (bindArg in statement.findChildrenOfType()) { val bindParameter = bindArg.bindParameter if (bindParameter is BindParameterMixin && bindParameter.text != "DEFAULT") { - bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt()?.let { index -> - if (!indexesSeen.add(index)) { - result.findAndReplace(bindArg, index) { it.index == index } - return@forEach + val bindIndex = bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt() + if (bindIndex != null) { + if (!indexesSeen.add(bindIndex)) { + result.findAndReplace(bindArg, bindIndex) { it.index == bindIndex } + continue } - maxIndexSeen = maxOf(maxIndexSeen, index) - result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg))) - return@forEach + maxIndexSeen = maxOf(maxIndexSeen, bindIndex) + result.add(Argument(bindIndex, typeResolver.argumentType(bindArg), mutableListOf(bindArg))) + continue } - bindParameter.identifier?.let { - if (!namesSeen.add(it.text)) { - result.findAndReplace(bindArg) { (_, type, _) -> type.name == it.text } - return@forEach + val identifier = bindParameter.identifier + if (identifier != null) { + if (!namesSeen.add(identifier.text)) { + result.findAndReplace(bindArg) { (_, type, _) -> type.name == identifier.text } + continue } val index = ++maxIndexSeen indexesSeen.add(index) manuallyNamedIndexes.add(index) - result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = it.text), mutableListOf(bindArg))) - return@forEach + result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = identifier.text), mutableListOf(bindArg))) + continue } val index = ++maxIndexSeen indexesSeen.add(index) diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt index e06264d5674..dcf7f6e4db8 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/ParserUtil.kt @@ -21,7 +21,7 @@ internal class ParserUtil { SqldelightParserUtil.reset() newDialect.setup() - ServiceLoader.load(SqlDelightModule::class.java, newDialect::class.java.classLoader).forEach { + for (it in ServiceLoader.load(SqlDelightModule::class.java, newDialect::class.java.classLoader)) { it.setup() } SqldelightParserUtil.overrideSqlParser() diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightFile.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightFile.kt index 6bed64635f6..dceeb80a0fc 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightFile.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightFile.kt @@ -46,8 +46,8 @@ abstract class SqlDelightFile( internal val typeResolver: TypeResolver by lazy { var parentResolver: TypeResolver = AnsiSqlTypeResolver - ServiceLoader.load(SqlDelightModule::class.java, dialect::class.java.classLoader).forEach { - parentResolver = it.typeResolver(parentResolver) + for (sqldelightModule in ServiceLoader.load(SqlDelightModule::class.java, dialect::class.java.classLoader)) { + parentResolver = sqldelightModule.typeResolver(parentResolver) } dialect.typeResolver(parentResolver) } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightQueriesFile.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightQueriesFile.kt index 6211f845e6a..5884685dbb3 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightQueriesFile.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/SqlDelightQueriesFile.kt @@ -138,13 +138,13 @@ class SqlDelightQueriesFile( val module = module ?: return fun PsiDirectory.iterateSqlFiles() { - children.forEach { + for (it in children) { if (it is PsiDirectory) it.iterateSqlFiles() if (it is SqlDelightQueriesFile) block(it) } } - SqlDelightFileIndex.getInstance(module).sourceFolders(this).forEach { dir -> + for (dir in SqlDelightFileIndex.getInstance(module).sourceFolders(this)) { dir.iterateSqlFiles() } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/psi/ColumnTypeMixin.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/psi/ColumnTypeMixin.kt index 95b10431d1a..c91e4493f0c 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/psi/ColumnTypeMixin.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/psi/ColumnTypeMixin.kt @@ -168,7 +168,7 @@ internal abstract class ColumnTypeMixin( } private fun SqlDelightJavaType.type(): ClassName { - parentOfType()!!.importStmtList.importStmtList.forEach { import -> + for (import in parentOfType()!!.importStmtList.importStmtList) { val typePrefix = text.substringBefore('.') if (import.javaType.text.endsWith(".$typePrefix")) { return text.split(".").drop(1).fold(import.javaType.type()) { current, nested -> @@ -232,10 +232,9 @@ internal abstract class ColumnTypeMixin( } } val children = node.getChildren(TokenSet.create(SqlTypes.ID)) - children.filter { (it.text == "as" || it.text == "As") && it.prevVisibleSibling?.psi is SqlTypeName } - .forEach { - annotationHolder.createErrorAnnotation(it.psi, "Expected 'AS', got '${it.text}'") - } + for (it in children.filter { (it.text == "as" || it.text == "As") && it.prevVisibleSibling?.psi is SqlTypeName }) { + annotationHolder.createErrorAnnotation(it.psi, "Expected 'AS', got '${it.text}'") + } } internal inner class ValueTypeDialectType( diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/InsertStmtUtil.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/InsertStmtUtil.kt index c3b907f7ff4..464dc86b76b 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/InsertStmtUtil.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/InsertStmtUtil.kt @@ -13,7 +13,7 @@ internal val SqlInsertStmt.columns: List .map { (it.element as NamedElement) } if (columnNameList.isEmpty()) return columns - val columnMap = linkedMapOf(*columns.map { it.name to it }.toTypedArray()) + val columnMap = columns.associateBy { it.name } return columnNameList.mapNotNull { columnMap[it.name] } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt index 4ccbe805dcf..a8f2e970265 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt @@ -21,6 +21,7 @@ import app.cash.sqldelight.core.lang.SqlDelightQueriesFile import app.cash.sqldelight.core.lang.acceptsTableInterface import app.cash.sqldelight.core.lang.psi.ColumnTypeMixin import app.cash.sqldelight.core.lang.psi.InsertStmtValuesMixin +import app.cash.sqldelight.core.lang.shouldInferColumns import app.cash.sqldelight.dialect.api.ExposableType import app.cash.sqldelight.dialect.api.IntermediateType import app.cash.sqldelight.dialect.api.PrimitiveType @@ -133,7 +134,7 @@ fun PsiElement.childOfType(types: TokenSet): PsiElement? { } fun ASTNode.findChildRecursive(type: IElementType): ASTNode? { - getChildren(null).forEach { + for (it in getChildren(null)) { if (it.elementType == type) return it it.findChildByType(type)?.let { return it } } @@ -217,25 +218,28 @@ fun Collection.forInitializationStatements( val creators = ArrayList() val miscellanious = ArrayList() - forEach { file -> - file.sqlStatements() - .filter { (label, _) -> label.name == null } - .forEach { (_, sqlStatement) -> - when { - sqlStatement.createTableStmt != null -> tables.add(sqlStatement.createTableStmt!!) - sqlStatement.createViewStmt != null -> views.add(sqlStatement.createViewStmt!!) - sqlStatement.createTriggerStmt != null -> creators.add(sqlStatement.createTriggerStmt!!) - sqlStatement.createIndexStmt != null -> creators.add(sqlStatement.createIndexStmt!!) - else -> miscellanious.add(sqlStatement) - } + for (file in this) { + for ((_, sqlStatement) in file.sqlStatements() + .filter { (label, _) -> label.name == null }) { + when { + sqlStatement.createTableStmt != null -> tables.add(sqlStatement.createTableStmt!!) + sqlStatement.createViewStmt != null -> views.add(sqlStatement.createViewStmt!!) + sqlStatement.createTriggerStmt != null -> creators.add(sqlStatement.createTriggerStmt!!) + sqlStatement.createIndexStmt != null -> creators.add(sqlStatement.createIndexStmt!!) + else -> miscellanious.add(sqlStatement) } + } } when (allowReferenceCycles) { // If we allow cycles, don't attempt to order the table creation statements. The dialect // is permissive. - true -> tables.forEach { body(it.rawSqlText()) } - false -> tables.buildGraph().topological().forEach { body(it.rawSqlText()) } + true -> for (it in tables) { + body(it.rawSqlText()) + } + false -> for (it in tables.buildGraph().topological()) { + body(it.rawSqlText()) + } } views.orderStatements( @@ -245,18 +249,22 @@ fun Collection.forInitializationStatements( body, ) - creators.forEach { body(it.rawSqlText()) } - miscellanious.forEach { body(it.rawSqlText()) } + for (it in creators) { + body(it.rawSqlText()) + } + for (it in miscellanious) { + body(it.rawSqlText()) + } } private fun ArrayList.buildGraph(): Graph { val graph = DirectedAcyclicGraph(DefaultEdge::class.java) val namedStatements = this.associateBy { it.tableName.name } - this.forEach { table -> + for (table in this) { graph.addVertex(table) - table.columnDefList.forEach { column -> - column.columnConstraintList.mapNotNull { it.foreignKeyClause?.foreignTable }.forEach { fk -> + for (column in table.columnDefList) { + for (fk in column.columnConstraintList.mapNotNull { it.foreignKeyClause?.foreignTable }) { try { val foreignTable = namedStatements[fk.name] graph.apply { diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt index 8dfafbbf3bf..d224e8bb8a0 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/queries/InterfaceGeneration.kt @@ -1069,7 +1069,7 @@ class InterfaceGeneration { private fun checkFixtureCompiles(fixtureRoot: String) { val result = FixtureCompiler.compileFixture( fixtureRoot = "src/test/query-interface-fixtures/$fixtureRoot", - compilationMethod = { _, _, file, output -> + compilationMethod = { file, output -> SqlDelightCompiler.writeQueryInterfaces(file, output) }, generateDb = false, diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/tables/InterfaceGeneration.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/tables/InterfaceGeneration.kt index 2081860af69..6cdff8cfbe4 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/tables/InterfaceGeneration.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/tables/InterfaceGeneration.kt @@ -428,9 +428,7 @@ class InterfaceGeneration { private fun checkFixtureCompiles(fixtureRoot: String) { val result = FixtureCompiler.compileFixture( fixtureRoot = "src/test/table-interface-fixtures/$fixtureRoot", - compilationMethod = { _, _, sqlDelightQueriesFile, writer -> - SqlDelightCompiler.writeTableInterfaces(sqlDelightQueriesFile, writer) - }, + compilationMethod = SqlDelightCompiler::writeTableInterfaces, generateDb = false, ) for ((expectedFile, actualOutput) in result.compilerOutput) { diff --git a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/views/InterfaceGeneration.kt b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/views/InterfaceGeneration.kt index 783d4928f06..ea2b6a36613 100644 --- a/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/views/InterfaceGeneration.kt +++ b/sqldelight-compiler/src/test/kotlin/app/cash/sqldelight/core/views/InterfaceGeneration.kt @@ -104,9 +104,7 @@ class InterfaceGeneration { private fun checkFixtureCompiles(fixtureRoot: String) { val result = FixtureCompiler.compileFixture( fixtureRoot = "src/test/view-interface-fixtures/$fixtureRoot", - compilationMethod = { _, _, sqlDelightQueriesFile, writer -> - SqlDelightCompiler.writeTableInterfaces(sqlDelightQueriesFile, writer) - }, + compilationMethod = SqlDelightCompiler::writeTableInterfaces, generateDb = false, ) assertThat(result.errors).isEmpty() diff --git a/sqldelight-idea-plugin/src/main/kotlin/app/cash/sqldelight/intellij/lang/SqlDelightFileViewProviderFactory.kt b/sqldelight-idea-plugin/src/main/kotlin/app/cash/sqldelight/intellij/lang/SqlDelightFileViewProviderFactory.kt index 3baed931254..5c56f8e3a4d 100644 --- a/sqldelight-idea-plugin/src/main/kotlin/app/cash/sqldelight/intellij/lang/SqlDelightFileViewProviderFactory.kt +++ b/sqldelight-idea-plugin/src/main/kotlin/app/cash/sqldelight/intellij/lang/SqlDelightFileViewProviderFactory.kt @@ -190,8 +190,7 @@ private class SqlDelightFileViewProvider( } } if (file is SqlDelightQueriesFile) { - val projectService = SqlDelightProjectService.getInstance(module.project) - SqlDelightCompiler.writeInterfaces(module, projectService.dialect, file, fileAppender) + SqlDelightCompiler.writeInterfaces(file, fileAppender) } else if (file is MigrationFile) { SqlDelightCompiler.writeInterfaces(file, fileAppender) } diff --git a/sqldelight-idea-plugin/src/test/kotlin/app/cash/sqldelight/intellij/SqlDelightProjectTestCase.kt b/sqldelight-idea-plugin/src/test/kotlin/app/cash/sqldelight/intellij/SqlDelightProjectTestCase.kt index 087fd4ccc9f..22b5feb6664 100644 --- a/sqldelight-idea-plugin/src/test/kotlin/app/cash/sqldelight/intellij/SqlDelightProjectTestCase.kt +++ b/sqldelight-idea-plugin/src/test/kotlin/app/cash/sqldelight/intellij/SqlDelightProjectTestCase.kt @@ -119,8 +119,7 @@ abstract class SqlDelightProjectTestCase : LightJavaCodeInsightFixtureTestCase() fileToGenerateDb = sqlFile return@iterateContentUnderDirectory true } - val dialect = SqliteDialect() - SqlDelightCompiler.writeInterfaces(module, dialect, fileToGenerateDb!!, virtualFileWriter) + SqlDelightCompiler.writeInterfaces(fileToGenerateDb!!, virtualFileWriter) SqlDelightCompiler.writeDatabaseInterface(module, fileToGenerateDb!!, module.name, virtualFileWriter) } } diff --git a/test-util/src/main/kotlin/app/cash/sqldelight/test/util/FixtureCompiler.kt b/test-util/src/main/kotlin/app/cash/sqldelight/test/util/FixtureCompiler.kt index f5dde517ddb..6fe33cd27de 100644 --- a/test-util/src/main/kotlin/app/cash/sqldelight/test/util/FixtureCompiler.kt +++ b/test-util/src/main/kotlin/app/cash/sqldelight/test/util/FixtureCompiler.kt @@ -22,14 +22,13 @@ import app.cash.sqldelight.core.lang.SqlDelightQueriesFile import app.cash.sqldelight.dialect.api.SqlDelightDialect import app.cash.sqldelight.dialects.sqlite_3_18.SqliteDialect import com.alecstrong.sql.psi.core.SqlAnnotationHolder -import com.intellij.openapi.module.Module import com.intellij.psi.PsiDocumentManager import com.intellij.psi.PsiElement import com.intellij.psi.PsiFile import java.io.File import org.junit.rules.TemporaryFolder -private typealias CompilationMethod = (Module, SqlDelightDialect, SqlDelightQueriesFile, (String) -> Appendable) -> Unit +private typealias CompilationMethod = (SqlDelightQueriesFile, (String) -> Appendable) -> Unit object FixtureCompiler { @@ -125,7 +124,7 @@ object FixtureCompiler { psiFile.log(sourceFiles) if (psiFile is SqlDelightQueriesFile) { if (errors.isEmpty()) { - compilationMethod(environment.module, environment.dialect, psiFile, fileWriter) + compilationMethod(psiFile, fileWriter) } file = psiFile } else if (psiFile is MigrationFile) { From 7db1093eabd88e9bcc787618299ba263022c8f29 Mon Sep 17 00:00:00 2001 From: hfhbd Date: Tue, 21 Nov 2023 21:55:28 +0100 Subject: [PATCH 2/2] Don't use class for insert with defaults --- gradle.properties | 3 + settings.gradle | 1 + .../core/compiler/QueryGenerator.kt | 151 ++++++++++-------- .../core/compiler/model/BindableQuery.kt | 58 ++++--- .../core/lang/AcceptsTableInterface.kt | 10 +- .../sqldelight/core/lang/util/TreeUtil.kt | 2 +- .../postgresql/integration/Defaults.sq | 11 ++ .../src/test/kotlin/MultithreadedTest.kt | 12 +- 8 files changed, 144 insertions(+), 104 deletions(-) create mode 100644 sqldelight-gradle-plugin/src/test/integration-postgresql/src/main/sqldelight/app/cash/sqldelight/postgresql/integration/Defaults.sq diff --git a/gradle.properties b/gradle.properties index 09c5423dcc4..bde507aeaea 100644 --- a/gradle.properties +++ b/gradle.properties @@ -26,3 +26,6 @@ kotlin.native.ignoreDisabledTargets=true # caches break the linkage of the sqlite amalgamation kotlin.native.cacheKind.linuxX64=none + +org.gradle.configuration-cache=false +org.gradle.configureondemand=true diff --git a/settings.gradle b/settings.gradle index f21e843b4cc..fa49008e7d4 100644 --- a/settings.gradle +++ b/settings.gradle @@ -27,6 +27,7 @@ gradleEnterprise { rootProject.name = 'sqldelight' enableFeaturePreview("TYPESAFE_PROJECT_ACCESSORS") +enableFeaturePreview("STABLE_CONFIGURATION_CACHE") include ':adapters:primitive-adapters' include ':dialects:hsql' diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt index acbcb313209..38b48d6c5fb 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/QueryGenerator.kt @@ -10,6 +10,7 @@ import app.cash.sqldelight.core.lang.MAPPER_NAME import app.cash.sqldelight.core.lang.PREPARED_STATEMENT_TYPE import app.cash.sqldelight.core.lang.encodedJavaType import app.cash.sqldelight.core.lang.preparedStatementBinder +import app.cash.sqldelight.core.lang.shouldInferColumns import app.cash.sqldelight.core.lang.util.childOfType import app.cash.sqldelight.core.lang.util.columnDefSource import app.cash.sqldelight.core.lang.util.findChildrenOfType @@ -22,6 +23,7 @@ import app.cash.sqldelight.dialect.api.IntermediateType import app.cash.sqldelight.dialect.grammar.mixins.BindParameterMixin import com.alecstrong.sql.psi.core.psi.SqlBinaryEqualityExpr import com.alecstrong.sql.psi.core.psi.SqlBindExpr +import com.alecstrong.sql.psi.core.psi.SqlInsertStmt import com.alecstrong.sql.psi.core.psi.SqlStmt import com.alecstrong.sql.psi.core.psi.SqlTypes import com.intellij.psi.PsiElement @@ -141,87 +143,94 @@ abstract class QueryGenerator( extractedVariables[type] = variableName bindStatements.add("val %N = $encodedJavaType\n", variableName) } - - // For each argument in the sql - for ((_, argument, bindArg) in orderedBindArgs) { - val type = argument.type - // Need to replace the single argument with a group of indexed arguments, calculated at - // runtime from the list parameter: - // val idIndexes = id.mapIndexed { index, _ -> "?${previousArray.size + index}" }.joinToString(prefix = "(", postfix = ")") - val offset = (precedingArrays.map { "$it.size" } + "$nonArrayBindArgsCount") - .joinToString(separator = " + ").replace(" + 0", "") - if (bindArg?.isArrayParameter() == true) { - needsFreshStatement = true - - if (!handledArrayArgs.contains(argument) && seenArrayArguments.add(argument)) { - result.addStatement( - """ + if (statement is SqlInsertStmt && statement.shouldInferColumns()) { + for ((index, arg) in orderedBindArgs.withIndex()) { + val type = arg.second.type + bindStatements.add(type.preparedStatementBinder(CodeBlock.of("$index"), extractedVariables[type])) + } + argumentCounts.add("${orderedBindArgs.size}") + } else { + // For each argument in the sql + for ((_, argument, bindArg) in orderedBindArgs) { + val type = argument.type + // Need to replace the single argument with a group of indexed arguments, calculated at + // runtime from the list parameter: + // val idIndexes = id.mapIndexed { index, _ -> "?${previousArray.size + index}" }.joinToString(prefix = "(", postfix = ")") + val offset = (precedingArrays.map { "$it.size" } + "$nonArrayBindArgsCount") + .joinToString(separator = " + ").replace(" + 0", "") + if (bindArg?.isArrayParameter() == true) { + needsFreshStatement = true + + if (!handledArrayArgs.contains(argument) && seenArrayArguments.add(argument)) { + result.addStatement( + """ |val ${type.name}Indexes = createArguments(count = ${type.name}.size) - """.trimMargin(), - ) - } + """.trimMargin(), + ) + } - // Replace the single bind argument with the array of bind arguments: - // WHERE id IN ${idIndexes} - replacements.add(bindArg.range to "\$${type.name}Indexes") - - // Perform the necessary binds: - // id.forEachIndex { index, parameter -> - // statement.bindLong(previousArray.size + index, parameter) - // } - val indexCalculator = CodeBlock.of( - if (offset == "0") { - "index" - } else { - "index + %L" - }, - offset, - ) - val elementName = argumentNameAllocator.newName(type.name) - bindStatements.add( - """ + // Replace the single bind argument with the array of bind arguments: + // WHERE id IN ${idIndexes} + replacements.add(bindArg.range to "\$${type.name}Indexes") + + // Perform the necessary binds: + // id.forEachIndex { index, parameter -> + // statement.bindLong(previousArray.size + index, parameter) + // } + val indexCalculator = CodeBlock.of( + if (offset == "0") { + "index" + } else { + "index + %L" + }, + offset, + ) + val elementName = argumentNameAllocator.newName(type.name) + bindStatements.add( + """ |${type.name}.forEachIndexed { index, $elementName -> | %L} | - """.trimMargin(), - type.copy(name = elementName).preparedStatementBinder(indexCalculator), - ) + """.trimMargin(), + type.copy(name = elementName).preparedStatementBinder(indexCalculator), + ) - precedingArrays.add(type.name) - argumentCounts.add("${type.name}.size") - } else { - val bindParameter = bindArg?.bindParameter as? BindParameterMixin - if (bindParameter == null || bindParameter.text != "DEFAULT") { - nonArrayBindArgsCount += 1 - - if (!treatNullAsUnknownForEquality && type.javaType.isNullable) { - val parent = bindArg?.parent - if (parent is SqlBinaryEqualityExpr) { - needsFreshStatement = true - - var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2) - val nullableEquality: String - if (symbol != null) { - nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}" - } else { - symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!! - nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}" + precedingArrays.add(type.name) + argumentCounts.add("${type.name}.size") + } else { + val bindParameter = bindArg?.bindParameter as? BindParameterMixin + if (bindParameter == null || bindParameter.text != "DEFAULT") { + nonArrayBindArgsCount += 1 + + if (!treatNullAsUnknownForEquality && type.javaType.isNullable) { + val parent = bindArg?.parent + if (parent is SqlBinaryEqualityExpr) { + needsFreshStatement = true + + var symbol = parent.childOfType(SqlTypes.EQ) ?: parent.childOfType(SqlTypes.EQ2) + val nullableEquality: String + if (symbol != null) { + nullableEquality = "${symbol.leftWhitspace()}IS${symbol.rightWhitespace()}" + } else { + symbol = parent.childOfType(SqlTypes.NEQ) ?: parent.childOfType(SqlTypes.NEQ2)!! + nullableEquality = "${symbol.leftWhitspace()}IS NOT${symbol.rightWhitespace()}" + } + + val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"") + replacements.add(symbol.range to "\${ $block }") } - - val block = CodeBlock.of("if (${type.name} == null) \"$nullableEquality\" else \"${symbol.text}\"") - replacements.add(symbol.range to "\${ $block }") } - } - // Binds each parameter to the statement: - // statement.bindLong(0, id) - bindStatements.add(type.preparedStatementBinder(CodeBlock.of(offset), extractedVariables[type])) + // Binds each parameter to the statement: + // statement.bindLong(0, id) + bindStatements.add(type.preparedStatementBinder(CodeBlock.of(offset), extractedVariables[type])) - // Replace the named argument with a non named/indexed argument. - // This allows us to use the same algorithm for non Sqlite dialects - // :name becomes ? - if (bindParameter != null) { - replacements.add(bindArg.range to bindParameter.replaceWith(generateAsync, index = nonArrayBindArgsCount)) + // Replace the named argument with a non named/indexed argument. + // This allows us to use the same algorithm for non Sqlite dialects + // :name becomes ? + if (bindParameter != null) { + replacements.add(bindArg.range to bindParameter.replaceWith(generateAsync, index = nonArrayBindArgsCount)) + } } } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt index d3bde850c9f..2a033dd255a 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/compiler/model/BindableQuery.kt @@ -20,6 +20,7 @@ import app.cash.sqldelight.core.compiler.SqlDelightCompiler.allocateName import app.cash.sqldelight.core.lang.acceptsTableInterface import app.cash.sqldelight.core.lang.psi.ColumnTypeMixin.ValueTypeDialectType import app.cash.sqldelight.core.lang.psi.StmtIdentifierMixin +import app.cash.sqldelight.core.lang.shouldInferColumns import app.cash.sqldelight.core.lang.types.typeResolver import app.cash.sqldelight.core.lang.util.argumentType import app.cash.sqldelight.core.lang.util.childOfType @@ -92,34 +93,49 @@ abstract class BindableQuery( val manuallyNamedIndexes = mutableSetOf() val namesSeen = mutableSetOf() var maxIndexSeen = 0 - for (bindArg in statement.findChildrenOfType()) { - val bindParameter = bindArg.bindParameter - if (bindParameter is BindParameterMixin && bindParameter.text != "DEFAULT") { - val bindIndex = bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt() - if (bindIndex != null) { - if (!indexesSeen.add(bindIndex)) { - result.findAndReplace(bindArg, bindIndex) { it.index == bindIndex } + if (statement is SqlInsertStmt && statement.shouldInferColumns()) { + for (column in statement.columns) { + val index = ++maxIndexSeen + indexesSeen.add(index) + manuallyNamedIndexes.add(index) + result.add(Argument(index, column.type())) + } + } else { + for (bindArg in statement.findChildrenOfType()) { + val bindParameter = bindArg.bindParameter + if (bindParameter is BindParameterMixin && bindParameter.text != "DEFAULT") { + val bindIndex = bindParameter.node.findChildByType(SqlTypes.DIGIT)?.text?.toInt() + if (bindIndex != null) { + if (!indexesSeen.add(bindIndex)) { + result.findAndReplace(bindArg, bindIndex) { it.index == bindIndex } + continue + } + maxIndexSeen = maxOf(maxIndexSeen, bindIndex) + result.add(Argument(bindIndex, typeResolver.argumentType(bindArg), mutableListOf(bindArg))) continue } - maxIndexSeen = maxOf(maxIndexSeen, bindIndex) - result.add(Argument(bindIndex, typeResolver.argumentType(bindArg), mutableListOf(bindArg))) - continue - } - val identifier = bindParameter.identifier - if (identifier != null) { - if (!namesSeen.add(identifier.text)) { - result.findAndReplace(bindArg) { (_, type, _) -> type.name == identifier.text } + val identifier = bindParameter.identifier + if (identifier != null) { + if (!namesSeen.add(identifier.text)) { + result.findAndReplace(bindArg) { (_, type, _) -> type.name == identifier.text } + continue + } + val index = ++maxIndexSeen + indexesSeen.add(index) + manuallyNamedIndexes.add(index) + result.add( + Argument( + index, + typeResolver.argumentType(bindArg).copy(name = identifier.text), + mutableListOf(bindArg), + ), + ) continue } val index = ++maxIndexSeen indexesSeen.add(index) - manuallyNamedIndexes.add(index) - result.add(Argument(index, typeResolver.argumentType(bindArg).copy(name = identifier.text), mutableListOf(bindArg))) - continue + result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg))) } - val index = ++maxIndexSeen - indexesSeen.add(index) - result.add(Argument(index, typeResolver.argumentType(bindArg), mutableListOf(bindArg))) } } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/AcceptsTableInterface.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/AcceptsTableInterface.kt index 66afed4207e..4ed86d4d32e 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/AcceptsTableInterface.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/AcceptsTableInterface.kt @@ -1,9 +1,15 @@ package app.cash.sqldelight.core.lang import app.cash.sqldelight.core.lang.util.childOfType +import app.cash.sqldelight.core.lang.util.table import com.alecstrong.sql.psi.core.psi.SqlInsertStmt import com.alecstrong.sql.psi.core.psi.SqlTypes -fun SqlInsertStmt.acceptsTableInterface(): Boolean { - return insertStmtValues?.childOfType(SqlTypes.BIND_EXPR) != null +fun SqlInsertStmt.shouldInferColumns() = insertStmtValues?.childOfType(SqlTypes.BIND_EXPR) != null && columnNameList.isNotEmpty() + +fun SqlInsertStmt.acceptsTableInterface(): Boolean = if (shouldInferColumns()) { + // It is safe to just compare the sizes because sql-psi already did check the references or default checks. + columnNameList.size == table.query.columns.size +} else { + insertStmtValues?.childOfType(SqlTypes.BIND_EXPR) != null } diff --git a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt index a8f2e970265..d4c5e87918f 100644 --- a/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt +++ b/sqldelight-compiler/src/main/kotlin/app/cash/sqldelight/core/lang/util/TreeUtil.kt @@ -173,7 +173,7 @@ private fun PsiElement.rangesToReplace(): List> { second = "", ), ) - } else if (this is InsertStmtValuesMixin && parent?.acceptsTableInterface() == true) { + } else if (this is InsertStmtValuesMixin && (parent?.acceptsTableInterface() == true || parent?.shouldInferColumns() == true)) { listOf( Pair( first = childOfType(SqlTypes.BIND_EXPR)!!.range, diff --git a/sqldelight-gradle-plugin/src/test/integration-postgresql/src/main/sqldelight/app/cash/sqldelight/postgresql/integration/Defaults.sq b/sqldelight-gradle-plugin/src/test/integration-postgresql/src/main/sqldelight/app/cash/sqldelight/postgresql/integration/Defaults.sq new file mode 100644 index 00000000000..8440e5aaa75 --- /dev/null +++ b/sqldelight-gradle-plugin/src/test/integration-postgresql/src/main/sqldelight/app/cash/sqldelight/postgresql/integration/Defaults.sq @@ -0,0 +1,11 @@ +CREATE TABLE foo( +a INT DEFAULT 42 NOT NULL, +b TEXT NOT NULL, +c TEXT NOT NULL +); + +insertFoo: +INSERT INTO foo (b, c) VALUES ?; + +insertFooDefault: +INSERT INTO foo (DEFAULT, b, c) VALUES ?; diff --git a/sqldelight-gradle-plugin/src/test/multithreaded-sqlite/src/test/kotlin/MultithreadedTest.kt b/sqldelight-gradle-plugin/src/test/multithreaded-sqlite/src/test/kotlin/MultithreadedTest.kt index c7fb3a91fe5..04ba59e4aec 100644 --- a/sqldelight-gradle-plugin/src/test/multithreaded-sqlite/src/test/kotlin/MultithreadedTest.kt +++ b/sqldelight-gradle-plugin/src/test/multithreaded-sqlite/src/test/kotlin/MultithreadedTest.kt @@ -6,9 +6,6 @@ import java.util.concurrent.atomic.AtomicInteger import kotlin.concurrent.thread import org.junit.After import org.junit.Test -import tables.TableA -import tables.TableB -import tables.TableC /** * I run two threads. Each thread selects all records from TableA, TableB and TableC @@ -96,13 +93,10 @@ private fun selectWithoutTransactions(dbHelper: DbHelper) { fun insertSomeData(dbHelper: DbHelper) { for (i in 0..100) { - val a = TableA(0, genString(10), genString(10), genLong()) - dbHelper.database.tableAQueries.insert(a) + dbHelper.database.tableAQueries.insert(genString(10), genString(10), genLong()) - val b = TableB(0, genString(10), genString(10), genLong()) - dbHelper.database.tableBQueries.insert(b) + dbHelper.database.tableBQueries.insert(genString(10), genString(10), genLong()) - val c = TableC(0, genString(10), genString(10), genLong()) - dbHelper.database.tableCQueries.insert(c) + dbHelper.database.tableCQueries.insert(genString(10), genString(10), genLong()) } }