[sqlgen] Allow arbitrary SQL expressions in statements generated by SQLGenerator. Closes #132590

Introduce a SQLExpression class whose instances can be given as arguments to the select/insert/delete/set methods instead of plain string values: such expressions are inserted directly into the generated SQL rather than kept in the values dictionary.

author: Vincent Michel <vincent.michel@logilab.fr>
changeset: 8d9520575c0e
branch: default
phase: public
hidden: no
parent revision: #ff56f6dbe746 merge
child revision: #afa47933b69d [sqlite] drop pysqlite2 support, using plain python sqlite3 module (closes #125083)
files modified by this revision
ChangeLog
sqlgen.py
test/unittest_sqlgen.py
# HG changeset patch
# User Vincent Michel <vincent.michel@logilab.fr>
# Date 1366275435 -7200
# Thu Apr 18 10:57:15 2013 +0200
# Node ID 8d9520575c0e289d90d1beb36af6dd11842eeca7
# Parent ff56f6dbe746624843621566a972981bd534284e
[sqlgen] Allow arbitrary SQL expressions in statements generated by SQLGenerator. Closes #132590

Introduce a SQLExpression class whose instances can be given as arguments
to the select/insert/delete/set methods instead of plain string values: such expressions
are inserted directly into the generated SQL rather than kept in the values dictionary.

diff --git a/ChangeLog b/ChangeLog
@@ -1,8 +1,15 @@
1  Changelog for logilab database package
2  ======================================
3 
4 +
5 +--
6 +
7 +    * #132590 (sqlgen): allow passing arbitrary SQL expressions (e.g. function calls)
8 +      as arguments to the select/insert/delete/set methods, by giving SQLExpression instances
9 +      instead of bare strings
10 +
11  2012-02-03  --  1.8.2
12      * don't use CURRENT_DATETIME for postgresql (#83892)
13 
14      * sqlserver: sql_current_[date|timestamp] shouldn't return date[time] objects (#88279)
15 
diff --git a/sqlgen.py b/sqlgen.py
@@ -1,6 +1,6 @@
16 -# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
17 +# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
18  # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
19  #
20  # This file is part of logilab-database.
21  #
22  # logilab-database is free software: you can redistribute it and/or modify it
@@ -20,28 +20,76 @@
23  """
24  __docformat__ = "restructuredtext en"
25 
26  # SQLGenerator ################################################################
27 
class SQLExpression(object):
    """Raw SQL expression to embed in statements generated by SQLGenerator.

    Give an instance as a value of the parameters dictionary passed to the
    SQLGenerator methods, instead of a plain value. Arguments:

    - `sqlstring`: the SQL expression to insert verbatim in the statement,
      e.g. 'YEARS(%(date)s)'; it may use pyformat placeholders

    - `kwargs`: the values to be substituted for the placeholders of
      `sqlstring`, e.g. date='2013/01/01'

    E.g. the parameter 'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')
    will yield:

    '..., age = YEARS(%(date)s), ...' in a SQL statement

    and will modify accordingly the parameters dictionary, which then holds
    the expression's keyword values instead of the 'age' entry:

    {'date': '2013/01/01', ...}
    """

    def __init__(self, sqlstring, **kwargs):
        self.sqlstring = sqlstring
        self.kwargs = kwargs

    def __repr__(self):
        # Mirror the constructor call; kwargs are sorted so the output is stable.
        args = [repr(self.sqlstring)]
        args.extend('%s=%r' % (name, self.kwargs[name])
                    for name in sorted(self.kwargs))
        return '%s(%s)' % (self.__class__.__name__, ', '.join(args))
50 +
51 +
52  class SQLGenerator :
53      """
54      Helper class to generate SQL strings to use with python's DB-API.
55      """
def _iterate_params(self, params):
    """Iterate over a parameters dictionary, yielding (column, sql) pairs.

    Columns are yielded in sorted order, for predictability. For a plain
    value, `sql` is the ``%(column)s`` placeholder. For a SQLExpression
    value, `sql` is the expression's SQL string and `params` is modified in
    place: the column entry is removed (its value is no longer substituted)
    and the expression's keyword values are added for substitution.

    E.g. 'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01') yields
    ('age', 'YEARS(%(date)s)') and replaces the 'age' entry of `params` by
    'date': '2013/01/01'.

    `params` is only fully updated once the generator is exhausted.

    :raise ValueError: if a SQLExpression keyword value conflicts with a
      different value already present in `params`
    """
    # iterate over a sorted snapshot, since params is modified below
    for column, value in sorted(params.items()):
        if not isinstance(value, SQLExpression):
            yield column, "%%(%s)s" % column
            continue
        # Refuse to silently overwrite another parameter with a different
        # value: the statement would then be executed with wrong data.
        # (The expression's own column is about to be dropped, so it may
        # be reused as a keyword name.)
        for name, kwvalue in value.kwargs.items():
            if name != column and name in params and params[name] != kwvalue:
                raise ValueError(
                    'SQLExpression for column %r sets parameter %r to %r, '
                    'conflicting with existing value %r'
                    % (column, name, kwvalue, params[name]))
        # Drop the column entry *before* merging the expression values:
        # otherwise a keyword named like the column, e.g.
        # 'count': SQLExpression('count + %(count)s', count=1), would be
        # added then immediately removed, breaking substitution.
        params.pop(column)
        params.update(value.kwargs)
        yield column, value.sqlstring
75 
def where(self, keys, addon=None):
    """
    Build the restriction of a WHERE clause: one ``column = value`` test
    per key, joined with AND.

    :param keys: list of column names, or parameters dictionary; in the
      latter case SQLExpression values are inserted verbatim in the
      restriction and the dictionary is modified in place (see
      ``_iterate_params``)
    :param addon: additional sql statement, put first in the restriction

    >>> s = SQLGenerator()
    >>> s.where(['nom'])
    'nom = %(nom)s'
    >>> s.where(['nom','prenom'])
    'nom = %(nom)s AND prenom = %(prenom)s'
    >>> s.where(['nom','prenom'], 'x.id = y.id')
    'x.id = y.id AND nom = %(nom)s AND prenom = %(prenom)s'
    >>> s.where({'nom': 'dupont', 'prenom': 'jean'})
    'nom = %(nom)s AND prenom = %(prenom)s'
    """
    if isinstance(keys, dict):
        # parameters dictionary: let _iterate_params inline SQLExpression
        # values and merge their own parameters into the dictionary
        restriction = ["%s = %s" % (col, val) for col, val in self._iterate_params(keys)]
    else:
        restriction = ["%s = %%(%s)s" % (x, x) for x in keys]
    if addon:
        restriction.insert(0, addon)
    return " AND ".join(restriction)
98 
99      def set(self, keys):
@@ -52,26 +100,37 @@
100          >>> s.set(['nom'])
101          'nom = %(nom)s'
102          >>> s.set(['nom','prenom'])
103          'nom = %(nom)s, prenom = %(prenom)s'
104          """
105 -        return ", ".join(["%s = %%(%s)s" % (x, x) for x in keys])
106 +        if isinstance(keys, dict):
107 +            set_parts = ["%s = %s" % (col, val) for col, val in self._iterate_params(keys)]
108 +        else:
109 +            set_parts = ["%s = %%(%s)s" % (x, x) for x in keys]
110 +        return ", ".join(set_parts)
111 
def insert(self, table, params):
    """
    Build an INSERT statement for the given parameters.

    :param table: name of the table
    :param params: dictionary that will be used as in cursor.execute(sql,params);
      SQLExpression values are inserted verbatim in the VALUES list, and
      their column entry is replaced in `params` by the expression's
      keyword values

    >>> s = SQLGenerator()
    >>> s.insert('test',{'nom':'dupont'})
    'INSERT INTO test ( nom ) VALUES ( %(nom)s )'
    >>> params = {'nom':'dupont', 'prenom':'jean',
    ...          'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
    >>> s.insert('test', params)
    'INSERT INTO test ( age, nom, prenom ) VALUES ( YEARS(%(date)s), %(nom)s, %(prenom)s )'
    >>> params['date'] # params has been modified
    '2013/01/01'
    """
    # _iterate_params yields (column, sql) pairs sorted by column, which
    # keeps both lists aligned and the statement predictable.
    pairs = list(self._iterate_params(params))
    columns = ', '.join([column for column, _ in pairs])
    values = ', '.join([value for _, value in pairs])
    return 'INSERT INTO %s ( %s ) VALUES ( %s )' % (table, columns, values)
140 
141      def select(self, table, params=None, selection=None):
142          """
143          :param table: name of the table
@@ -86,14 +145,13 @@
144          'SELECT * FROM test WHERE nom = %(nom)s AND prenom = %(prenom)s'
145          """
146          if selection is None:
147              sql = 'SELECT * FROM %s' % table
148          else:
149 -            sql = 'SELECT %s FROM %s' % (','.join(col for col in selection),
150 -                                         table)
151 +            sql = 'SELECT %s FROM %s' % (','.join(col for col in selection), table)
152          if params is not None:
153 -            where = self.where(params.keys())
154 +            where = self.where(params)
155              if where :
156                  sql = sql + ' WHERE %s' % where
157          return sql
158 
159      def adv_select(self, model, tables, params, joins=None) :
@@ -112,44 +170,41 @@
160          """
161          table_names = ["%s AS %s" % (k, v) for k, v in tables]
162          sql = 'SELECT %s FROM %s' % (', '.join(model), ', '.join(table_names))
163          if joins and type(joins) != type(''):
164              joins = ' AND '.join(joins)
165 -        where = self.where(params.keys(), joins)
166 +        where = self.where(params, joins)
167          if where :
168              sql = sql + ' WHERE %s' % where
169          return sql
170 
def delete(self, table, params, addon=None):
    """
    Build a DELETE statement restricted by the given parameters.

    :param table: name of the table
    :param params: dictionary that will be used as in cursor.execute(sql,params)
    :param addon: optional SQL restriction, put before the ones built from
      `params`

    >>> s = SQLGenerator()
    >>> s.delete('test',{'nom':'dupont'})
    'DELETE FROM test WHERE nom = %(nom)s'
    >>> s.delete('test',{'nom':'dupont','prenom':'jean'})
    'DELETE FROM test WHERE nom = %(nom)s AND prenom = %(prenom)s'
    """
    statement = 'DELETE FROM %s' % table
    return '%s WHERE %s' % (statement, self.where(params, addon=addon))
187 
def delete_many(self, table, params):
    """
    Build a DELETE statement, using IN clauses for list values.

    :param table: name of the table
    :param params: dictionary that will be used as in cursor.execute(sql,params).
      A string value starting with '(' such as '(1, 2, 3)' produces a
      ``column IN value`` test and is removed from `params`; such a value
      is written as-is in the SQL (it is not escaped), so it must not come
      from untrusted input. Other values, SQLExpression included, are
      handled as in ``delete``.

    >>> s = SQLGenerator()
    >>> s.delete_many('test', {'nom': 'dupont', 'eid': '(1, 2, 3)'})
    'DELETE FROM test WHERE eid IN (1, 2, 3) AND nom = %(nom)s'
    """
    # str and unicode on Python 2, str on Python 3
    text_types = (str, type(u''))
    in_clauses = []
    # Iterate over a sorted copy of the keys: removing entries from a dict
    # while iterating over it raises RuntimeError on Python 3, and sorting
    # keeps the order of the IN clauses predictable.
    for key in sorted(params):
        value = params[key]
        # Only text can denote an IN list; SQLExpression instances and
        # non-text values (e.g. integers) are left for the regular
        # restriction instead of failing on .startswith().
        if isinstance(value, text_types) and value.startswith('('):
            in_clauses.append('%s IN %s' % (key, value))
            # written directly in the SQL, so not needed for substitution
            del params[key]
    return self.delete(table, params, addon=' AND '.join(in_clauses))
212 
213      def update(self, table, params, unique) :
214          """
215          :param table: name of the table
216          :param params: dictionary that will be used as in cursor.execute(sql,params)
@@ -159,11 +214,20 @@
217          'UPDATE test SET nom = %(nom)s WHERE id = %(id)s'
218          >>> s.update('test',{'id':'001','nom':'dupont','prenom':'jean'},['id'])
219          'UPDATE test SET nom = %(nom)s, prenom = %(prenom)s WHERE id = %(id)s'
220          """
221          where = self.where(unique)
222 -        set = self.set([key for key in params if key not in unique])
223 +        # Remove the unique keys from the params dictionnary
224 +        unique_params = {}
225 +        for key, value in params.items():
226 +            if key in unique:
227 +                params.pop(key)
228 +                unique_params[key] = value
229 +        set = self.set(params)
230 +        # Add the removed unique params to the (now possibly updated)
231 +        # params dict (if there were some SQLExpressions)
232 +        params.update(unique_params)
233          sql = 'UPDATE %s SET %s WHERE %s' % (table, set, where)
234          return sql
235 
236 
237  class BaseTable:
diff --git a/test/unittest_sqlgen.py b/test/unittest_sqlgen.py
@@ -0,0 +1,160 @@
238 +# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
239 +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
240 +#
241 +# This file is part of logilab-database.
242 +#
243 +# logilab-database is free software: you can redistribute it and/or modify it
244 +# under the terms of the GNU Lesser General Public License as published by the
245 +# Free Software Foundation, either version 2.1 of the License, or (at your
246 +# option) any later version.
247 +#
248 +# logilab-database is distributed in the hope that it will be useful, but
249 +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
250 +# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
251 +# for more details.
252 +#
253 +# You should have received a copy of the GNU Lesser General Public License along
254 +# with logilab-database. If not, see <http://www.gnu.org/licenses/>.
255 +import unittest
256 +from datetime import datetime, time, date
257 +
258 +from logilab.database import get_db_helper
259 +from logilab.database.sqlgen import SQLGenerator, SQLExpression
260 +
261 +
class SQLGenTC(unittest.TestCase):
    """SQLGenerator statements built from plain values and from SQLExpression."""

    # restriction generated from the parameters of _expression_params()
    EXPR_WHERE = 'age = YEARS(%(date)s) AND nom = %(nom)s AND prenom = %(prenom)s'
    # restriction generated from plain nom/prenom parameters
    PLAIN_WHERE = 'nom = %(nom)s AND prenom = %(prenom)s'
    # content of _expression_params() once the statement has been generated
    EXPR_PARAMS_AFTER = {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'}

    def setUp(self):
        self.sqlgen = SQLGenerator()

    @staticmethod
    def _expression_params(**extra):
        """Return fresh nom/prenom parameters with an 'age' SQLExpression."""
        params = {'nom': 'dupont', 'prenom': 'jean',
                  'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
        params.update(extra)
        return params

    def test_set_values(self):
        gen = self.sqlgen
        self.assertEqual(gen.set(['nom']), 'nom = %(nom)s')
        self.assertEqual(gen.set(['nom', 'prenom']), 'nom = %(nom)s, prenom = %(prenom)s')
        params = {'nom': 'dupont', 'prenom': 'jean'}
        self.assertEqual(gen.set(params), 'nom = %(nom)s, prenom = %(prenom)s')
        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean'})

    def test_set_functions(self):
        params = self._expression_params()
        self.assertEqual(self.sqlgen.set(params),
                         'age = YEARS(%(date)s), nom = %(nom)s, prenom = %(prenom)s')
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)

    def test_where_values(self):
        gen = self.sqlgen
        self.assertEqual(gen.where(['nom']), 'nom = %(nom)s')
        self.assertEqual(gen.where(['nom', 'prenom']), self.PLAIN_WHERE)
        self.assertEqual(gen.where(['nom', 'prenom'], 'x.id = y.id'),
                         'x.id = y.id AND ' + self.PLAIN_WHERE)
        params = {'nom': 'dupont', 'prenom': 'jean'}
        self.assertEqual(gen.where(params), self.PLAIN_WHERE)
        self.assertEqual(gen.where(params, 'x.id = y.id'),
                         'x.id = y.id AND ' + self.PLAIN_WHERE)

    def test_where_functions(self):
        params = self._expression_params()
        self.assertEqual(self.sqlgen.where(params), self.EXPR_WHERE)
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)
        params = self._expression_params()
        self.assertEqual(self.sqlgen.where(params, 'x.id = y.id'),
                         'x.id = y.id AND ' + self.EXPR_WHERE)
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)

    def test_insert_values(self):
        params = {'nom': 'dupont'}
        self.assertEqual(self.sqlgen.insert('test', params),
                         'INSERT INTO test ( nom ) VALUES ( %(nom)s )')
        self.assertEqual(params, {'nom': 'dupont'})

    def test_insert_functions(self):
        params = self._expression_params()
        self.assertEqual(self.sqlgen.insert('test', params),
                         'INSERT INTO test ( age, nom, prenom ) VALUES '
                         '( YEARS(%(date)s), %(nom)s, %(prenom)s )')
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)

    def test_select_values(self):
        gen = self.sqlgen
        self.assertEqual(gen.select('test', {}), 'SELECT * FROM test')
        self.assertEqual(gen.select('test', {'nom': 'dupont'}),
                         'SELECT * FROM test WHERE nom = %(nom)s')
        self.assertEqual(gen.select('test', {'nom': 'dupont', 'prenom': 'jean'}),
                         'SELECT * FROM test WHERE ' + self.PLAIN_WHERE)

    def test_select_functions(self):
        params = self._expression_params()
        self.assertEqual(self.sqlgen.select('test', params),
                         'SELECT * FROM test WHERE ' + self.EXPR_WHERE)
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)

    def test_adv_select_values(self):
        gen = self.sqlgen
        self.assertEqual(gen.adv_select(['column'], [('test', 't')], {}),
                         'SELECT column FROM test AS t')
        self.assertEqual(gen.adv_select(['column'], [('test', 't')], {'nom': 'dupont'}),
                         'SELECT column FROM test AS t WHERE nom = %(nom)s')

    def test_adv_select_functions(self):
        params = self._expression_params()
        self.assertEqual(self.sqlgen.adv_select(['column'], [('test', 't')], params),
                         'SELECT column FROM test AS t WHERE ' + self.EXPR_WHERE)
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)

    def test_delete_values(self):
        gen = self.sqlgen
        self.assertEqual(gen.delete('test', {'nom': 'dupont'}),
                         'DELETE FROM test WHERE nom = %(nom)s')
        self.assertEqual(gen.delete('test', {'nom': 'dupont', 'prenom': 'jean'}),
                         'DELETE FROM test WHERE ' + self.PLAIN_WHERE)

    def test_delete_functions(self):
        params = self._expression_params()
        self.assertEqual(self.sqlgen.delete('test', params),
                         'DELETE FROM test WHERE ' + self.EXPR_WHERE)
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)

    def test_delete_many_values(self):
        params = {'nom': 'dupont', 'eid': '(1, 2, 3)'}
        self.assertEqual(self.sqlgen.delete_many('test', params),
                         'DELETE FROM test WHERE eid IN (1, 2, 3) AND nom = %(nom)s')
        self.assertEqual(params, {'nom': 'dupont'})

    def test_delete_many_functions(self):
        params = self._expression_params(eid='(1, 2, 3)')
        self.assertEqual(self.sqlgen.delete_many('test', params),
                         'DELETE FROM test WHERE eid IN (1, 2, 3) AND ' + self.EXPR_WHERE)
        self.assertEqual(params, self.EXPR_PARAMS_AFTER)

    def test_update_values(self):
        gen = self.sqlgen
        self.assertEqual(gen.update('test', {'id': '001', 'nom': 'dupont'}, ['id']),
                         'UPDATE test SET nom = %(nom)s WHERE id = %(id)s')
        self.assertEqual(gen.update('test', {'id': '001', 'nom': 'dupont', 'prenom': 'jean'}, ['id']),
                         'UPDATE test SET nom = %(nom)s, prenom = %(prenom)s WHERE id = %(id)s')

    def test_update_functions(self):
        params = self._expression_params(id='001')
        self.assertEqual(self.sqlgen.update('test', params, ['id']),
                         'UPDATE test SET age = YEARS(%(date)s), nom = %(nom)s, '
                         'prenom = %(prenom)s WHERE id = %(id)s')
        self.assertEqual(params, dict(self.EXPR_PARAMS_AFTER, id='001'))


if __name__ == '__main__':
    unittest.main()