diff --git a/sqlor/crud.py b/sqlor/crud.py index 671cb1a..ee8cd5a 100755 --- a/sqlor/crud.py +++ b/sqlor/crud.py @@ -41,29 +41,33 @@ class CRUD(object): self.primary_data = None self.oa = ObjectAction() - async def primaryKey(self): + async def primaryKey(self,**kw): if self.primary_data is None: self.primary_data = await self.pool.getTablePrimaryKey(self.dbname, - self.tablename) + self.tablename,**kw) return self.primary_data - async def forignKeys(self): - data = self.pool.getTableForignKeys(self.dbname,self.tablename) + async def forignKeys(self,**kw): + data = self.pool.getTableForignKeys(self.dbname,self.tablename,**kw) return data - async def I(self): + async def I(self,**kw): """ fields information """ - pkdata = await self.primaryKey() - pks = [ i.field_name for i in pkdata ] - data = await self.pool.getTableFields(self.dbname,self.tablename) - for d in data: - if d.name in pks: - d.update({'primarykey':True}) - data = self.oa.execute(self.dbname+'_'+self.tablename,'tableInfo',data) - return data + @self.pool.inSQLor + async def main(dbname,NS,**kw): + pkdata = await self.primaryKey(**kw) + pks = [ i.field_name for i in pkdata ] + data = await self.pool.getTableFields(self.dbname,self.tablename,**kw) + for d in data: + if d.name in pks: + d.update({'primarykey':True}) + data = self.oa.execute(self.dbname+'_'+self.tablename,'tableInfo',data) + return data + + return await main(self.dbname,{},**kw) async def fromStr(self,data): fields = await self.pool.getTableFields(self.dbname,self.tablename) @@ -90,11 +94,11 @@ class CRUD(object): ret[k] = f(data[k]) return ret - async def datagrid(self,request,target): + async def datagrid(self,request,targeti,**kw): fields = await self.I() fs = [ self.defaultIOField(f) for f in fields ] id = self.dbname+':'+ self.tablename - pk = await self.primaryKey() + pk = await self.primaryKey(**kw) idField = pk[0]['field_name'] data = { "tmplname":"widget_js.tmpl", @@ -178,18 +182,12 @@ class CRUD(object): "iotype":"text" } - async def C(self,rec): + async def C(self,rec,**kw): """ create new data """ - fields = await self.pool.getTableFields(self.dbname,self.tablename) - flist = [ f['name'] for f in fields ] - fns = ','.join(flist) - vfs = ','.join([ '${' + f + '}$' for f in flist ]) - data = {} - [ data.update({k.lower():v}) for k,v in rec.items() ] @self.pool.runSQL - async def addSQL(dbname,data,callback=None): + async def addSQL(dbname,data,**kw): sqldesc={ "sql_string" : """ insert into %s (%s) values (%s) @@ -197,31 +195,40 @@ class CRUD(object): } return sqldesc - pk = await self.primaryKey() - k = pk[0]['field_name'] - if not data.get(k): - v = getID() - data[k] = v - data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeAdd',data) - await addSQL(self.dbname,data) - data = self.oa.execute(self.dbname+'_'+self.tablename,'afterAdd',data) - return {k:data[k]} - return data + @self.pool.inSQLor + async def main(dbname,NS,**kw): + fields = await self.pool.getTableFields(self.dbname,self.tablename,**kw) + flist = [ f['name'] for f in fields ] + fns = ','.join(flist) + vfs = ','.join([ '${' + f + '}$' for f in flist ]) + data = {} + [ data.update({k.lower():v}) for k,v in NS.items() ] + pk = await self.primaryKey(**kw) + k = pk[0]['field_name'] + if not data.get(k): + v = getID() + data[k] = v + data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeAdd',data) + await addSQL(self.dbname,data,**kw) + data = self.oa.execute(self.dbname+'_'+self.tablename,'afterAdd',data) + return {k:data[k]} + + return await main(self.dbname,rec,**kw) - async def defaultFilter(self,NS): - fields = await self.pool.getTableFields(self.dbname,self.tablename) + async def defaultFilter(self,NS,**kw): + fields = await self.pool.getTableFields(self.dbname,self.tablename,**kw) d = [ '%s = ${%s}$' % (f['name'],f['name']) for f in fields if f['name'] in NS.keys() ] if len(d) == 0: return '' ret = ' and ' + ' and '.join(d) return ret - async def R(self,filters=None,NS={}): + async def R(self,filters=None,NS={},**kw): """ retrieve data """ @self.pool.runSQL - async def retrieve(dbname,data,callback=None): + async def retrieve(dbname,data,**kw): fstr = '' if filters is not None: fstr = ' and ' @@ -235,7 +242,7 @@ class CRUD(object): return sqldesc @self.pool.runSQLPaging - async def pagingdata(dbname,data,filters=None): + async def pagingdata(dbname,data,filters=None,**kw): fstr = "" if filters is not None: fstr = ' and ' @@ -250,31 +257,35 @@ class CRUD(object): } return sqldesc - p = await self.primaryKey() - if NS.get('__id') is not None: - NS[p[0]['field_name']] = NS['__id'] - del NS['__id'] + @self.pool.inSQLor + async def main(dbname,NS,**kw): + p = await self.primaryKey(**kw) + if NS.get('__id') is not None: + NS[p[0]['field_name']] = NS['__id'] + del NS['__id'] + if NS.get('page'): + del NS['page'] + if NS.get('page'): - del NS['page'] + if NS.get('sort',None) is None: + NS['sort'] = p[0]['field_name'] - if NS.get('page'): - if NS.get('sort',None) is None: - NS['sort'] = p[0]['field_name'] + data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeRetrieve',NS) + if NS.get('page'): + data = await pagingdata(self.dbname,data,**kw) + else: + data = await retrieve(self.dbname,data,**kw) + data = self.oa.execute(self.dbname+'_'+self.tablename,'afterRetrieve',data) + return data - data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeRetrieve',NS) - if NS.get('page'): - data = await pagingdata(self.dbname,data) - else: - data = await retrieve(self.dbname,data) - data = self.oa.execute(self.dbname+'_'+self.tablename,'afterRetrieve',data) - return data + return await main(self.dbname,NS,**kw) - async def U(self,data): + async def U(self,data, **kw): """ update data """ @self.pool.runSQL - async def update(dbname,NS,callback=None): + async def update(dbname,NS,**kw): condi = [ i['field_name'] for i in self.primary_data ] newData = [ i for i in NS.keys() if i not in condi ] c = [ '%s = ${%s}$' % (i,i) for i in condi ] @@ -286,20 +297,23 @@ class CRUD(object): } return sqldesc - pk = await self.primaryKey() - pkfields = [k.field_name for k in pk ] - newData = [ k for k in data if k not in pkfields ] - data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeUpdate',data) - await update(self.dbname,data) - data = self.oa.execute(self.dbname+'_'+self.tablename,'afterUpdate',data) - return data + @self.pool.inSQLor + async def main(dbname,NS,**kw): + pk = await self.primaryKey(**kw) + pkfields = [k.field_name for k in pk ] + newData = [ k for k in data if k not in pkfields ] + data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeUpdate',data) + await update(self.dbname,data,**kw) + data = self.oa.execute(self.dbname+'_'+self.tablename,'afterUpdate',data) + return data + return await main(self.dbname,data,**kw) - async def D(self,data): + async def D(self,data,**kw): """ delete data """ @self.pool.runSQL - def delete(dbname,data): + def delete(dbname,data,**kw): pnames = [ i['field_name'] for i in self.primary_data ] c = [ '%s = ${%s}$' % (i,i) for i in pnames ] cs = ' and '.join(c) @@ -308,10 +322,13 @@ class CRUD(object): } return sqldesc - data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeDelete',data) - await delete(self.dbname,data,pkfields) - data = self.oa.execute(self.dbname+'_'+self.tablename,'afterDelete',data) - return data + @self.pool.inSQLor + async def main(dbname,NS,**kw): + data = self.oa.execute(self.dbname+'_'+self.tablename,'beforeDelete',data) + await delete(self.dbname,data,pkfields,**kw) + data = self.oa.execute(self.dbname+'_'+self.tablename,'afterDelete',data) + return data + return await main(self.dbname,data,**kw) if __name__ == '__main__': DBPools({ diff --git a/sqlor/dbpools.py b/sqlor/dbpools.py index 28e28d1..8ef140a 100644 --- a/sqlor/dbpools.py +++ b/sqlor/dbpools.py @@ -210,12 +210,24 @@ class DBPools: raise Exception('database (%s) not connected'%dbname) await p.release(conn) + async def useOrGetSor(self,dbname,**kw): + commit = False + if kw.get('sor'): + sor = kw['sor'] + else: + sor = await self.getSqlor(dbname) + commit = True + return sor, commit + def inSqlor(self,func): @wraps(func) - async def wrap_func(sor,dbname,*args,**kw): - sor = await self.getSqlor(dbname) + async def wrap_func(dbname,NS,*args,**kw): + sor, commit = self.useOrGetSor(dbname, **kw) + kw['sor'] = sor try: - ret = await func(sor,dbname,*args,**kw) + ret = await func(dbname,NS,*args,**kw) + if not commit: + return ret try: await sor.conn.commit() except: @@ -223,49 +235,61 @@ class DBPools: return ret except Exception as e: print('error',sor) + if not commit: + raise e try: await sor.conn.rollback() except: pass raise e finally: - await self.freeSqlor(sor) + if commit: + await self.freeSqlor(sor) return wrap_func def runSQL(self,func): @wraps(func) - async def wrap_func(dbname,NS,*,callback=None,**kw): - sor = await self.getSqlor(dbname) + async def wrap_func(dbname,NS,*args,**kw): + sor, commit = self.useOrGetSor(dbname,**kw) + kw['sor'] = sor ret = None try: - desc = await func(dbname,NS,callback=callback,**kw) - ret = await sor.runSQL(desc,NS,callback,**kw) - try: - await sor.conn.commit() - except: - pass + desc = await func(dbname,NS,*args,**kw) + callback = kw.get('callback',None) + kw1 = {} + [ kw1.update({k:v}) for k,v in kw.items if k!='callback' ] + ret = await sor.runSQL(desc,NS,callback,**kw1) + if commit: + try: + await sor.conn.commit() + except: + pass if NS.get('dummy'): return NS['dummy'] else: return [] except Exception as e: print('error:',e) + if not commit: + raise e try: await sor.conn.rollback() except: pass raise e finally: - await self.freeSqlor(sor) + if commit: + await self.freeSqlor(sor) return wrap_func def runSQLPaging(self,func): @wraps(func) - async def wrap_func(dbname,NS,**kw): - sor = await self.getSqlor(dbname) + async def wrap_func(dbname,NS,*args,**kw): + sor, commit = self.useOrGetSor(dbname,**kw) + kw['sor'] = sor try: - desc = await func(dbname,NS,**kw) + desc = await func(dbname,NS,*args,**kw) total = await sor.record_count(desc,NS) recs = await sor.pagingdata(desc,NS) data = { @@ -277,56 +301,64 @@ class DBPools: print('error',e) raise e finally: - await self.freeSqlor(sor) + if commit: + await self.freeSqlor(sor) return wrap_func - def runSQLResultFields(self, func): + async def runSQLResultFields(self, func): @wraps(func) - async def wrap_func(dbname,NS,**kw): - sor = await self.getSqlor(dbname) + async def wrap_func(dbname,NS,*args,**kw): + sor, commit = self.useOrGetSor(dbname,**kw) + kw['sor'] = sor try: - desc = await func(dbname,NS,**kw) + desc = await func(dbname,NS,*args,**kw) ret = await sor.resultFields(desc,NS) return ret except Exception as e: print('error=',e) raise e finally: - await self.freeSqlor(sor) + if commit: + await self.freeSqlor(sor) return wrap_func - async def getTables(self,dbname): + async def getTables(self,dbname,**kw): @self.inSqlor - async def _getTables(sor,dbname): + async def _getTables(dbname,NS,**kw): + sor = kw['sor'] ret = await sor.tables() return ret - return await _getTables(None,dbname) + return await _getTables(dbname,{},**kw) - async def getTableFields(self,dbname,tblname): + async def getTableFields(self,dbname,tblname,**kw): @self.inSqlor - async def _getTableFields(sor,dbname,tblname): + async def _getTableFields(dbname,NS,tblname,**kw): + sor = kw['sor'] ret = await sor.fields(tblname) return ret - return await _getTableFields(None,dbname,tblname) + return await _getTableFields(dbname,{},tblname,**kw) async def getTablePrimaryKey(self,dbname,tblname): @self.inSqlor - async def _getTablePrimaryKey(sor,dbname,tblname): + async def _getTablePrimaryKey(dbname,NS,tblname,**kw): + sor = kw['sor'] ret = await sor.primary(tblname) return ret - return await _getTablePrimaryKey(None,dbname,tblname) + return await _getTablePrimaryKey(dbname,{},tblname,**kw) async def getTableIndexes(self,dbname,tblname): @self.inSqlor - async def _getTablePrimaryKey(sor,dbname,tblname): + async def _getTablePrimaryKey(dbname,NS,tblname,**kw): + sor = kw['sor'] ret = await sor.indexes(tblname) return ret - return await _getTablePrimaryKey(None,dbname,tblname) + return await _getTablePrimaryKey(dbname,{},tblname,**kw) async def getTableForignKeys(self,dbname,tblname): @self.inSqlor - async def _getTableForignKeys(sor,dbname,tblname): + async def _getTableForignKeys(dbname,NS,tblname,**kw): + sor = kw['sor'] ret = await sor.fkeys(tblname) return ret - return await _getTableForignKeys(None,dbname,tblname) + return await _getTableForignKeys(dbname,{},tblname,**kw)