From 2271bc4e1525153417640a8e35ddc0741f90c8d9 Mon Sep 17 00:00:00 2001 From: yumoqing Date: Wed, 6 May 2020 12:17:05 +0800 Subject: [PATCH] add sqlor context --- sqlor/dbpools.py | 5 +++++ sqlor/sor.py | 37 ++++++++++++++++++++++++++++++------- 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/sqlor/dbpools.py b/sqlor/dbpools.py index ed018d5..de0aee2 100644 --- a/sqlor/dbpools.py +++ b/sqlor/dbpools.py @@ -192,7 +192,12 @@ class DBPools: sqlor = await self.getSqlor(name) try: yield sqlor + except: + if sqlor.dataChanged: + sqlor.rollback() finally: + if sqlor.dataChanged: + sqlor.commit() await self.freeSqlor(sqlor) async def _aquireConn(self,dbname): diff --git a/sqlor/sor.py b/sqlor/sor.py index 7fddb65..cbef5e4 100644 --- a/sqlor/sor.py +++ b/sqlor/sor.py @@ -63,6 +63,7 @@ class SQLor(object): self.writer = None self.convfuncs = {} self.cc = ConditionConvert() + self.dataChanged = False def setCursor(self,async_mode,conn,cur): self.async_mode = async_mode @@ -197,10 +198,28 @@ class SQLor(object): return (m_sql,newdata) + def getSQLType(self,sql): + """ + return one of "qry", "dml" and "ddl" + ddl change the database schema + dml change the database data + qry query data + """ + + a = sql.lstrip(' \t\n\r') + a = a.lower() + al = a.split(' ') + if al[0] == 'select': + return 'qry' + if al[0] in ['update','delete','insert']: + return 'dml' + return 'ddl' + async def execute(self,sql,value,callback,**kwargs): + sqltype = self.getSqlType(sql) cur = self.cursor() await self.runVarSQL(cur,sql,value) - if callback is not None: + if sqltype == 'qry' and callback is not None: fields = [ i[0].lower() for i in cur.description ] rec = None if self.async_mode: @@ -217,7 +236,8 @@ class SQLor(object): rec = await cur.fetchone() else: rec = cur.fetchone() - + if sqltype == 'dml': + self.dataChanged = True async def executemany(self,sql,values): cur = self.cursor() @@ -282,10 +302,7 @@ class SQLor(object): return await self.execute(sql,{},lambda x:ret.append(x)) def isSelectSql(self,sql): - i = 0 - while sql[i] in "\r\n \t": - i = i + 1 - return sql.lower().startswith('select ') + return self.getSqlType(sql) == 'qry' def getSQLfromDesc(self,desc): sql = '' @@ -455,4 +472,10 @@ class SQLor(object): desc['validation'].append(idx) return desc - + def rollback(self): + self.conn.rollback() + self.dataChanged = False + + def commit(self): + self.conn.commit() + self.datachanged = False