feat(db): 修改事务管理

This commit is contained in:
AiKrai 2025-03-11 14:56:57 +08:00
parent 098fa02ca6
commit dff4e2d0a5
7 changed files with 176 additions and 103 deletions

View File

@ -21,10 +21,6 @@ class ResponseHandler: ResponseHandlerInterface {
val requestId = ctx.get<Long>("requestId") ?: -1L
val code: Int
val resStr = when (responseData) {
is Unit -> {
code = HttpStatus.NO_CONTENT
null
}
is RespBean -> {
code = responseData.code
responseData.requestId = requestId
@ -50,10 +46,11 @@ class ResponseHandler: ResponseHandlerInterface {
// 业务异常处理
override suspend fun exception(ctx: RoutingContext, e: Throwable) {
logger.error { "${ctx.request().uri()}: ${e.stackTraceToString()}" }
val resObj = when(e) {
val resObj = when (e) {
is Meta -> {
RespBean.failure("${e.name}:${e.message}", e.data)
}
else -> {
RespBean.failure("${e.javaClass.simpleName}${if (e.message != null) ":${e.message}" else ""}")
}

View File

@ -5,6 +5,7 @@ import app.domain.account.modle.AccountRoleAccessDTO
import com.google.inject.Inject
import io.vertx.sqlclient.SqlClient
import org.aikrai.vertx.db.RepositoryImpl
import org.aikrai.vertx.db.tx.withTransaction
class AccountRepositoryImpl @Inject constructor(
sqlClient: SqlClient

View File

@ -8,6 +8,7 @@ import cn.hutool.core.lang.Snowflake
import cn.hutool.crypto.SecureUtil
import com.google.inject.Inject
import io.vertx.ext.web.RoutingContext
import mu.KotlinLogging
import org.aikrai.vertx.db.tx.withTransaction
import org.aikrai.vertx.utlis.IpUtil
import org.aikrai.vertx.utlis.Meta
@ -18,11 +19,20 @@ class AccountService @Inject constructor(
private val accountRepository: AccountRepository,
private val tokenService: TokenService,
) {
private val logger = KotlinLogging.logger { }
suspend fun testTransaction() {
withTransaction {
accountRepository.update(1L, mapOf("avatar" to "test001"))
// throw Meta.failure("test transaction", "test transaction")
accountRepository.update(1L, mapOf("avatar" to "test0001"))
try {
withTransaction {
accountRepository.update(1L, mapOf("avatar" to "test002"))
throw Meta.error("test transaction", "test transaction")
}
} catch (e: Exception) {
logger.info { "内层事务失败已处理: ${e.message}" }
}
}
}

View File

@ -9,7 +9,7 @@ import io.vertx.sqlclient.*
import io.vertx.sqlclient.templates.SqlTemplate
import mu.KotlinLogging
import org.aikrai.vertx.db.annotation.*
import org.aikrai.vertx.db.tx.TxCtx
import org.aikrai.vertx.db.tx.TxCtxElem
import org.aikrai.vertx.jackson.JsonUtil
import org.aikrai.vertx.utlis.Meta
import java.lang.reflect.Field
@ -171,9 +171,6 @@ open class RepositoryImpl<TId, TEntity : Any>(
"DELETE FROM $tableName WHERE $idFieldName = #{id}"
}
val params = mapOf("id" to id)
if (logger.isDebugEnabled) {
logger.debug { "SQL: $sqlTemplate, PARAMS: $params" }
}
return execute(sqlTemplate, params)
} catch (e: Exception) {
logger.error(e) { "Error deleting entity with id: $id" }
@ -190,7 +187,6 @@ open class RepositoryImpl<TId, TEntity : Any>(
"UPDATE $tableName SET $setClause WHERE $idFieldName = #{id}"
}
val params = getNonNullFields(t) + mapOf("id" to idField.get(t))
logger.debug { "SQL: $sqlTemplate, PARAMS: $params" }
return execute(sqlTemplate, params)
} catch (e: Exception) {
logger.error(e) { "Error updating entity: $t" }
@ -206,7 +202,6 @@ open class RepositoryImpl<TId, TEntity : Any>(
"UPDATE $tableName SET $setClause WHERE $idFieldName = #{id}"
}
val params = parameters + mapOf("id" to id)
logger.debug { "SQL: $sqlTemplate, PARAMS: $params" }
return execute(sqlTemplate, params)
} catch (e: Exception) {
logger.error(e) { "Error updating entity with id: $id" }
@ -236,7 +231,6 @@ open class RepositoryImpl<TId, TEntity : Any>(
"SELECT $columns FROM $tableName WHERE $field = #{value}"
}
val params = mapOf("value" to value)
logger.debug { "SQL: $sqlTemplate, PARAMS: $params" }
return get(sqlTemplate, params, clazz)
} catch (e: Exception) {
logger.error(e) { "Error getting entity by field: $field = $value" }
@ -252,7 +246,6 @@ open class RepositoryImpl<TId, TEntity : Any>(
"SELECT $columns FROM $tableName WHERE ${fieldMappings[field.name]} = #{value}"
}
val params = mapOf("value" to value)
logger.debug { "SQL: $sql, PARAMS: $params" }
return get(sql, params, clazz)
} catch (e: Exception) {
logger.error(e) { "Error getting entity by field: ${field.name} = $value" }
@ -308,7 +301,7 @@ open class RepositoryImpl<TId, TEntity : Any>(
.execute(params)
.coAwait()
.rowCount()
} catch (e: Exception) {
} catch (e: Throwable) {
logger.error(e) { "Error executing SQL: $sql, PARAMS: $params" }
throw Meta.repository(e.javaClass.simpleName, e.message)
}
@ -326,20 +319,14 @@ open class RepositoryImpl<TId, TEntity : Any>(
// 其他工具方法
private suspend fun getConnection(): SqlClient {
return if (TxCtx.isTransactionActive(coroutineContext)) {
TxCtx.currentSqlConnection(coroutineContext) ?: run {
logger.error("TransactionContextElement.sqlConnection is null")
return sqlClient
}
} else {
sqlClient
}
val txElem = coroutineContext[TxCtxElem]
return txElem?.connection ?: sqlClient
}
// 通用获取或创建 SQL 模板的方法
private fun getOrCreateSql(tableName: String, key: String, sqlProvider: () -> String): String {
val tableSqlMap = baseSqlCache.getOrPut(tableName) { ConcurrentHashMap() }
return tableSqlMap.getOrPut(key, sqlProvider)
private fun getOrCreateSql(tableName: String, sqlKey: String, generator: () -> String): String {
return baseSqlCache.computeIfAbsent(tableName) { ConcurrentHashMap() }
.computeIfAbsent(sqlKey) { generator() }
}
// 获取非空字段及其值

View File

@ -1,41 +1,80 @@
package org.aikrai.vertx.db.tx
import io.vertx.kotlin.coroutines.coAwait
import io.vertx.sqlclient.SqlClient
import io.vertx.sqlclient.SqlConnection
import io.vertx.sqlclient.Transaction
import org.aikrai.vertx.utlis.Meta
import java.util.*
import kotlin.coroutines.CoroutineContext
import kotlin.coroutines.coroutineContext
class TxCtxElem(
val sqlConnection: SqlConnection,
val transaction: Transaction,
val isActive: Boolean = true,
val isNested: Boolean = false,
val transactionStack: Stack<TxCtxElem>,
val index: Int = transactionStack.size,
val transactionId: String = UUID.randomUUID().toString()
val connection: SqlConnection,
val transaction: Transaction?, // 外层事务才有实际transaction对象
val savepointName: String? = null, // 内层事务使用savepoint名称
val depth: Int = 0, // 事务嵌套深度
) : CoroutineContext.Element {
companion object Key : CoroutineContext.Key<TxCtxElem>
override val key: CoroutineContext.Key<*> = Key
override fun toString(): String {
return "TransactionContextElement(transactionId=$transactionId, isActive=$isActive, isNested=$isNested)"
}
val isRoot: Boolean = depth == 0
val isNested: Boolean = depth > 0
val transactionId: String = UUID.randomUUID().toString().substring(0, 8)
// 标记是否已回滚或提交
var completed: Boolean = false
}
object TxCtx {
fun getTransactionId(context: CoroutineContext): String? {
return context[TxCtxElem.Key]?.transactionId
}
fun currentTransaction(context: CoroutineContext): Transaction? {
return context[TxCtxElem.Key]?.transaction
}
fun currentSqlConnection(context: CoroutineContext): SqlConnection? {
return context[TxCtxElem.Key]?.sqlConnection
}
/**
* 判断当前是否在事务上下文中
*/
fun isTransactionActive(context: CoroutineContext): Boolean {
return context[TxCtxElem.Key]?.isActive ?: false
return context[TxCtxElem] != null
}
/**
* 获取当前事务的连接
*/
fun currentSqlConnection(context: CoroutineContext): SqlClient? {
return context[TxCtxElem]?.connection
}
/**
* 获取当前事务深度
*/
fun currentTransactionDepth(context: CoroutineContext): Int {
return context[TxCtxElem]?.depth ?: 0
}
/**
* 手动控制设置当前事务回滚点
*/
suspend fun setSavepoint(name: String): String {
val context = coroutineContext
val txElem = context[TxCtxElem] ?: throw Meta.error(
"TransactionError",
"Cannot set savepoint. No active transaction."
)
val connection = txElem.connection
val pointName = "manual_$name"
connection.query("SAVEPOINT $pointName").execute().coAwait()
return pointName
}
/**
* 手动回滚到指定保存点
*/
suspend fun rollbackToSavepoint(name: String) {
val context = coroutineContext
val txElem = context[TxCtxElem] ?: throw Meta.error(
"TransactionError",
"Cannot rollback to savepoint. No active transaction."
)
val connection = txElem.connection
connection.query("ROLLBACK TO SAVEPOINT $name").execute().coAwait()
}
}

View File

@ -40,65 +40,90 @@ object TxMgrHolder {
}
}
class TxMgr(
private val pool: Pool
) {
class TxMgr(private val pool: Pool) {
private val logger = KotlinLogging.logger { }
private val transactionStackMap = ConcurrentHashMap<CoroutineContext, Stack<TxCtxElem>>()
/**
* 在事务上下文中执行一个块
*
* @param block 需要在事务中执行的挂起函数
* @return 块的结果
*/
suspend fun <T> withTransaction(block: suspend CoroutineScope.() -> T): Any? {
suspend fun <T> withTransaction(block: suspend CoroutineScope.() -> T): T {
val currentContext = coroutineContext
val transactionStack = currentContext[TxCtxElem]?.transactionStack ?: Stack<TxCtxElem>()
// 外层事务,嵌套事务,都创建新的连接和事务。实现外层事务回滚时所有嵌套事务回滚,嵌套事务回滚不影响外部事务
val connection: SqlConnection = pool.connection.coAwait()
val transaction: Transaction = connection.begin().coAwait()
val currentTx = currentContext[TxCtxElem]
return try {
val txCtxElem =
TxCtxElem(connection, transaction, true, transactionStack.isNotEmpty(), transactionStack)
transactionStack.push(txCtxElem)
logger.debug { (if (txCtxElem.isNested) "嵌套" else "") + "事务Id:" + txCtxElem.transactionId + "开始" }
// 已在事务中 - 创建SAVEPOINT
if (currentTx != null) {
return withSavepoint(currentTx, block)
}
withContext(currentContext + txCtxElem) {
val result = block()
if (txCtxElem.index == 0) {
while (transactionStack.isNotEmpty()) {
val txCtx = transactionStack.pop()
txCtx.transaction.commit().coAwait()
logger.debug { (if (txCtx.isNested) "嵌套" else "") + "事务Id:" + txCtx.transactionId + "提交" }
// 外层事务 - 创建实际事务
val connection = pool.connection.coAwait()
val transaction = connection.begin().coAwait()
val startTime = System.currentTimeMillis()
try {
// 创建根事务上下文
val txElem = TxCtxElem(connection, transaction, depth = 0)
logger.debug { "Root transaction ${txElem.transactionId} started" }
val result = withContext(currentContext + txElem) {
block()
}
// 提交事务
if (!txElem.completed) {
transaction.commit().coAwait()
txElem.completed = true
logger.debug { "Root transaction ${txElem.transactionId} committed, took ${System.currentTimeMillis() - startTime}ms" }
}
result
}
return result
} catch (e: Exception) {
logger.error(e) { "Transaction failed, rollback" }
if (transactionStack.isNotEmpty() && !transactionStack.peek().isNested) {
// 外层事务失败,回滚所有事务
logger.error { "Rolling back all transactions" }
while (transactionStack.isNotEmpty()) {
val txCtxElem = transactionStack.pop()
txCtxElem.transaction.rollback().coAwait()
logger.debug { (if (txCtxElem.isNested) "嵌套" else "") + "事务Id:" + txCtxElem.transactionId + "回滚" }
}
logger.error(e) { "Root transaction failed, rolling back" }
transaction.rollback().coAwait()
throw e
} else {
// 嵌套事务失败,只回滚当前事务
val txCtxElem = transactionStack.pop()
txCtxElem.transaction.rollback().coAwait()
logger.debug(e) { (if (txCtxElem.isNested) "嵌套" else "") + "事务Id:" + txCtxElem.transactionId + "回滚" }
}
} finally {
if (transactionStack.isEmpty()) {
transactionStackMap.remove(currentContext) // 清理上下文
connection.close() // 仅在外层事务时关闭连接
connection.close()
}
}
private suspend fun <T> withSavepoint(
parentTx: TxCtxElem,
block: suspend CoroutineScope.() -> T
): T {
val connection = parentTx.connection
val savepointName = "sp_${UUID.randomUUID().toString().replace("-", "").substring(0, 10)}"
val startTime = System.currentTimeMillis()
// 创建保存点
connection.query("SAVEPOINT $savepointName").execute().coAwait()
logger.debug { "Nested transaction with savepoint $savepointName started" }
try {
// 创建嵌套事务上下文
val nestedTxElem = TxCtxElem(
connection = connection,
transaction = null, // 嵌套事务没有独立的Transaction对象
savepointName = savepointName,
depth = parentTx.depth + 1,
)
val result = withContext(coroutineContext + nestedTxElem) {
block()
}
// 嵌套事务成功,释放保存点
if (!nestedTxElem.completed) {
connection.query("RELEASE SAVEPOINT $savepointName").execute().coAwait()
nestedTxElem.completed = true
logger.debug { "Savepoint $savepointName released, took ${System.currentTimeMillis() - startTime}ms" }
}
return result
} catch (e: Exception) {
logger.warn(e) { "Nested transaction failed, rolling back to savepoint $savepointName" }
// 回滚到保存点,但不影响外层事务
connection.query("ROLLBACK TO SAVEPOINT $savepointName").execute().coAwait()
throw e
}
}
}

View File

@ -113,15 +113,29 @@ class OpenApiSpecGenerator {
*/
private fun generatePaths(): Paths {
val paths = Paths()
val pathInfoMap = mutableMapOf<String, Pair<String, PathItem>>()
// 获取所有带有 @Controller 注解的类
val packageName = ClassUtil.getMainClass()?.packageName
val packageName = ClassUtil.getMainClass().packageName
val controllerClassSet = Reflections(packageName).getTypesAnnotatedWith(Controller::class.java)
ClassUtil.getPublicMethods(controllerClassSet).forEach { (controllerClass, methods) ->
val controllerInfo = extractControllerInfo(controllerClass)
methods.forEach { method ->
val pathInfo = generatePathInfo(method, controllerInfo)
paths.addPathItem(pathInfo.path, pathInfo.pathItem)
if (!pathInfo.pathItem.post?.tags?.first().isNullOrBlank()) {
pathInfoMap[pathInfo.path] = Pair(pathInfo.pathItem.post.tags.first(), pathInfo.pathItem)
}
if (!pathInfo.pathItem.get?.tags?.first().isNullOrBlank()) {
pathInfoMap[pathInfo.path] = Pair(pathInfo.pathItem.get.tags.first(), pathInfo.pathItem)
}
}
}
val sortedMap = pathInfoMap.toList()
.sortedBy { it.second.second.post?.summary }
.sortedBy { it.second.second.get?.summary }
.sortedBy { it.second.first }
.toMap()
for ((key, value) in sortedMap) {
paths[key] = value.second
}
return paths
}