[sqlgen] Allow arbitrary SQL expressions in statements generated by SQLGenerator. Closes #132590

Introduce a SQLExpression class whose instances can be passed to the various methods in place of plain string values: those expressions are then inserted directly into the SQL instead of being kept in the values dictionary.

author: Vincent Michel <vincent.michel@logilab.fr>
changeset: 69bf85e32c2d
branch: stable
phase: draft
hidden: yes
parent revision: #3c58962f468c add TYPE_CONVERTERS dict on advanced helpers defining conversions callback from Python SQL types (closes #132587)
child revision: <not specified>
files modified by this revision
ChangeLog
sqlgen.py
test/unittest_sqlgen.py
# HG changeset patch
# User Vincent Michel <vincent.michel@logilab.fr>
# Date 1366117187 -7200
# Tue Apr 16 14:59:47 2013 +0200
# Branch stable
# Node ID 69bf85e32c2d579e988cbd7bfe8cd51c99d89e2a
# Parent 3c58962f468c7df96152967702ac2946891c5612
[sqlgen] Allow arbitrary SQL expressions in statements generated by SQLGenerator. Closes #132590

Introduce a SQLExpression class whose instances can be passed to the various
methods in place of plain string values: those expressions are then inserted
directly into the SQL instead of being kept in the values dictionary.

diff --git a/ChangeLog b/ChangeLog
@@ -4,10 +4,14 @@
1 
2  --
3      * #132587: TYPE_CONVERTERS dictionary on db helper containing type 
4        conversion callback
5 
6 +    * #132590 (sqlgen): allow passing e.g. SQL function calls as arguments
7 +      to select/insert/delete/set by giving SQLExpression instances instead
8 +      of bare strings
9 +
10  2012-02-03  --  1.8.2
11      * don't use CURRENT_DATETIME for postgresql (#83892)
12 
13      * sqlserver: sql_current_[date|timestamp] shouldn't return date[time] objects (#88279)
14 
diff --git a/sqlgen.py b/sqlgen.py
@@ -1,6 +1,6 @@
15 -# copyright 2003-2011 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
16 +# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
17  # contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
18  #
19  # This file is part of logilab-database.
20  #
21  # logilab-database is free software: you can redistribute it and/or modify it
@@ -20,28 +20,76 @@
22  """
23  __docformat__ = "restructuredtext en"
24 
25  # SQLGenerator ################################################################
26 
27 +class SQLExpression(object):
28 +    """Use this class when you need direct SQL expression in statements
29 +    generated by SQLGenerator. Arguments:
30 +
31 +    - a sqlstring that defines the SQL expression to be used, e.g. 'YEARS(%(date)s)'
32 +
33 +    - kwargs that define the values to be substituted in the SQL expression,
34 +      e.g. date='2013/01/01'
35 +
36 +    E.g. the SQL expression SQLExpression('YEARS(%(date)s)', date='2013/01/01')
37 +    will yield:
38 +
39 +    '..., age = YEARS(%(date)s), ...' in a SQL statement
40 +
41 +    and will modify accordingly the parameters:
42 +
43 +    {'date': '2013/01/01', ...}
44 +    """
45 +
46 +    def __init__(self, sqlstring, **kwargs):
47 +        self.sqlstring = sqlstring
48 +        self.kwargs = kwargs
49 +
50 +
51  class SQLGenerator :
52      """
53      Helper class to generate SQL strings to use with python's DB-API.
54      """
55 +    def _iterate_params(self, params):
56 +        """ Yield (column, sql) pairs from a parameters dictionary, sorted by
57 +        column; sql is '%(<column>)s' or the sqlstring of an SQLExpression """
58 +        # sort for predictability
59 +        for column, value in sorted(params.items()):
60 +            if isinstance(value, SQLExpression):
61 +                # The value to substitute is no longer the one in params: the
62 +                # SQLExpression kwargs are merged into params and the column
63 +                # key is removed (params is modified in place).
64 +                # E.g. 'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')
65 +                # will create the statement
66 +                # age = YEARS(%(date)s)
67 +                # and params gets 'date': '2013/01/01' for correct substitution.
68 +                # NOTE(review): a kwarg named like another column overwrites it.
69 +                params.update(value.kwargs)
70 +                params.pop(column)
71 +                yield column, value.sqlstring
72 +            else:
73 +                yield column, "%%(%s)s" % column
73 +                yield column, "%%(%s)s" % column
74 
75      def where(self, keys, addon=None):
76          """
77          :param keys: list of keys
78 +        :param addon: additional sql statement
79 
80          >>> s = SQLGenerator()
81          >>> s.where(['nom'])
82          'nom = %(nom)s'
83          >>> s.where(['nom','prenom'])
84          'nom = %(nom)s AND prenom = %(prenom)s'
85          >>> s.where(['nom','prenom'], 'x.id = y.id')
86          'x.id = y.id AND nom = %(nom)s AND prenom = %(prenom)s'
87          """
88 -        restriction = ["%s = %%(%s)s" % (x, x) for x in keys]
89 +        # Do not need SQLExpression here, as we have the addon argument.
90 +        if isinstance(keys, dict):
91 +            restriction = ["%s = %s" % (col, val) for col, val in self._iterate_params(keys)]
92 +        else:
93 +            restriction = ["%s = %%(%s)s" % (x, x) for x in keys]
94          if addon:
95              restriction.insert(0, addon)
96          return " AND ".join(restriction)
97 
98      def set(self, keys):
@@ -52,26 +100,37 @@
99          >>> s.set(['nom'])
100          'nom = %(nom)s'
101          >>> s.set(['nom','prenom'])
102          'nom = %(nom)s, prenom = %(prenom)s'
103          """
104 -        return ", ".join(["%s = %%(%s)s" % (x, x) for x in keys])
105 +        if isinstance(keys, dict):
106 +            set_parts = ["%s = %s" % (col, val) for col, val in self._iterate_params(keys)]
107 +        else:
108 +            set_parts = ["%s = %%(%s)s" % (x, x) for x in keys]
109 +        return ", ".join(set_parts)
110 
111      def insert(self, table, params):
112          """
113          :param table: name of the table
114          :param params:  dictionary that will be used as in cursor.execute(sql,params)
115 -
116          >>> s = SQLGenerator()
117          >>> s.insert('test',{'nom':'dupont'})
118          'INSERT INTO test ( nom ) VALUES ( %(nom)s )'
119 -        >>> s.insert('test',{'nom':'dupont','prenom':'jean'})
120 -        'INSERT INTO test ( nom, prenom ) VALUES ( %(nom)s, %(prenom)s )'
121 +        >>> params = {'nom':'dupont', 'prenom':'jean',
122 +        ...          'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
123 +        >>> s.insert('test', params)
124 +        'INSERT INTO test ( age, nom, prenom ) VALUES ( YEARS(%(date)s), %(nom)s, %(prenom)s )'
125 +        >>> params['date'] # params has been modified
126 +        '2013/01/01'
127          """
128 -        keys = ', '.join(params.keys())
129 -        values = ', '.join(["%%(%s)s" % x for x in params])
130 -        sql = 'INSERT INTO %s ( %s ) VALUES ( %s )' % (table, keys, values)
131 +        columns = []
132 +        values = []
133 +        # sort for predictability
134 +        for column, value in self._iterate_params(params):
135 +            columns.append(column)
136 +            values.append(value)
137 +        sql = 'INSERT INTO %s ( %s ) VALUES ( %s )' % (table, ', '.join(columns), ', '.join(values))
138          return sql
139 
140      def select(self, table, params=None, selection=None):
141          """
142          :param table: name of the table
@@ -86,14 +145,13 @@
143          'SELECT * FROM test WHERE nom = %(nom)s AND prenom = %(prenom)s'
144          """
145          if selection is None:
146              sql = 'SELECT * FROM %s' % table
147          else:
148 -            sql = 'SELECT %s FROM %s' % (','.join(col for col in selection),
149 -                                         table)
150 +            sql = 'SELECT %s FROM %s' % (','.join(col for col in selection), table)
151          if params is not None:
152 -            where = self.where(params.keys())
153 +            where = self.where(params)
154              if where :
155                  sql = sql + ' WHERE %s' % where
156          return sql
157 
158      def adv_select(self, model, tables, params, joins=None) :
@@ -112,44 +170,41 @@
159          """
160          table_names = ["%s AS %s" % (k, v) for k, v in tables]
161          sql = 'SELECT %s FROM %s' % (', '.join(model), ', '.join(table_names))
162          if joins and type(joins) != type(''):
163              joins = ' AND '.join(joins)
164 -        where = self.where(params.keys(), joins)
165 +        where = self.where(params, joins)
166          if where :
167              sql = sql + ' WHERE %s' % where
168          return sql
169 
170 -    def delete(self, table, params) :
171 +    def delete(self, table, params, addon=None) :
172          """
173          :param table: name of the table
174          :param params: dictionary that will be used as in cursor.execute(sql,params)
175 
176          >>> s = SQLGenerator()
177          >>> s.delete('test',{'nom':'dupont'})
178          'DELETE FROM test WHERE nom = %(nom)s'
179          >>> s.delete('test',{'nom':'dupont','prenom':'jean'})
180          'DELETE FROM test WHERE nom = %(nom)s AND prenom = %(prenom)s'
181          """
182 -        where = self.where(params.keys())
183 +        where = self.where(params, addon=addon)
184          sql = 'DELETE FROM %s WHERE %s' % (table, where)
185          return sql
186 
187      def delete_many(self, table, params):
188 -        restriction = []
189 -        to_pop = []
190 -        for col, value in params.iteritems():
191 -            if value.startswith('('): # we want IN
192 -                restriction.append('%s IN %s' % (col, value))
193 -                to_pop.append(col)
194 -            else:
195 -                restriction.append("%s = %%(%s)s" % (col, col))
196 -        for col in to_pop:
197 -            del params[col]
198 -        where = " AND ".join(restriction)
199 -        sql = 'DELETE FROM %s WHERE %s' % (table, where)
200 -        return sql
201 +        """ Delete many using the IN clause: string values like '(1, 2, 3)'
202 +        become 'col IN (1, 2, 3)' and are popped; other values go to where() """
203 +        addons = []
204 +        for key, value in sorted(params.items()): # copy: params is mutated
205 +            if hasattr(value, 'startswith') and value.startswith('('): # we want IN
206 +                addons.append('%s IN %s' % (key, value))
207 +                # The value is popped as it is not needed for substitution
208 +                # (the value is directly written in the SQL IN statement)
209 +                params.pop(key)
210 +        return self.delete(table, params, addon=' AND '.join(addons))
211 
212      def update(self, table, params, unique) :
213          """
214          :param table: name of the table
215          :param params: dictionary that will be used as in cursor.execute(sql,params)
@@ -159,11 +214,20 @@
216          'UPDATE test SET nom = %(nom)s WHERE id = %(id)s'
217          >>> s.update('test',{'id':'001','nom':'dupont','prenom':'jean'},['id'])
218          'UPDATE test SET nom = %(nom)s, prenom = %(prenom)s WHERE id = %(id)s'
219          """
220          where = self.where(unique)
221 -        set = self.set([key for key in params if key not in unique])
222 +        # Move the unique keys out of params (iterate over a copy: we pop)
223 +        unique_params = {}
224 +        for key, value in list(params.items()):
225 +            if key in unique:
226 +                params.pop(key)
227 +                unique_params[key] = value
228 +        set = self.set(params)
229 +        # Add the removed unique params to the (now possibly updated)
230 +        # params dict (if there were some SQLExpressions)
231 +        params.update(unique_params)
232          sql = 'UPDATE %s SET %s WHERE %s' % (table, set, where)
233          return sql
234 
235 
236  class BaseTable:
diff --git a/test/unittest_sqlgen.py b/test/unittest_sqlgen.py
@@ -0,0 +1,160 @@
237 +# copyright 2003-2013 LOGILAB S.A. (Paris, FRANCE), all rights reserved.
238 +# contact http://www.logilab.fr/ -- mailto:contact@logilab.fr
239 +#
240 +# This file is part of logilab-database.
241 +#
242 +# logilab-database is free software: you can redistribute it and/or modify it
243 +# under the terms of the GNU Lesser General Public License as published by the
244 +# Free Software Foundation, either version 2.1 of the License, or (at your
245 +# option) any later version.
246 +#
247 +# logilab-database is distributed in the hope that it will be useful, but
248 +# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
249 +# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public License
250 +# for more details.
251 +#
252 +# You should have received a copy of the GNU Lesser General Public License along
253 +# with logilab-database. If not, see <http://www.gnu.org/licenses/>.
254 +import unittest
255 +from datetime import datetime, time, date
256 +
257 +from logilab.database import get_db_helper
258 +from logilab.database.sqlgen import SQLGenerator, SQLExpression
259 +
260 +
261 +class SQLGenTC(unittest.TestCase):
262 +
263 +    def test_set_values(self):
264 +        s = SQLGenerator()
265 +        self.assertEqual(s.set(['nom']), 'nom = %(nom)s')
266 +        self.assertEqual(s.set(['nom','prenom']), 'nom = %(nom)s, prenom = %(prenom)s')
267 +        params = {'nom': 'dupont', 'prenom': 'jean'}
268 +        self.assertEqual(s.set(params), 'nom = %(nom)s, prenom = %(prenom)s')
269 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean'})
270 +
271 +    def test_set_functions(self):
272 +        s = SQLGenerator()
273 +        params = {'nom': 'dupont', 'prenom': 'jean', 'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
274 +        self.assertEqual(s.set(params), 'age = YEARS(%(date)s), nom = %(nom)s, prenom = %(prenom)s')
275 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'})
276 +
277 +    def test_where_values(self):
278 +        s = SQLGenerator()
279 +        self.assertEqual(s.where(['nom']), 'nom = %(nom)s')
280 +        self.assertEqual(s.where(['nom','prenom']), 'nom = %(nom)s AND prenom = %(prenom)s')
281 +        self.assertEqual(s.where(['nom','prenom'], 'x.id = y.id'),
282 +                         'x.id = y.id AND nom = %(nom)s AND prenom = %(prenom)s')
283 +        params = {'nom': 'dupont', 'prenom': 'jean'}
284 +        self.assertEqual(s.where(params), 'nom = %(nom)s AND prenom = %(prenom)s')
285 +        self.assertEqual(s.where(params, 'x.id = y.id'),
286 +                         'x.id = y.id AND nom = %(nom)s AND prenom = %(prenom)s')
287 +
288 +    def test_where_functions(self):
289 +        s = SQLGenerator()
290 +        params = {'nom': 'dupont', 'prenom': 'jean', 'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
291 +        self.assertEqual(s.where(params), 'age = YEARS(%(date)s) AND nom = %(nom)s AND prenom = %(prenom)s')
292 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'})
293 +        params = {'nom': 'dupont', 'prenom': 'jean', 'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
294 +        self.assertEqual(s.where(params, 'x.id = y.id'),
295 +                         'x.id = y.id AND age = YEARS(%(date)s) AND nom = %(nom)s AND prenom = %(prenom)s')
296 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'})
297 +
298 +    def test_insert_values(self):
299 +        s = SQLGenerator()
300 +        params = {'nom': 'dupont'}
301 +        sqlstr = s.insert('test', params)
302 +        self.assertEqual(sqlstr, 'INSERT INTO test ( nom ) VALUES ( %(nom)s )')
303 +        self.assertEqual(params, {'nom': 'dupont'})
304 +
305 +    def test_insert_functions(self):
306 +        s = SQLGenerator()
307 +        params = {'nom':'dupont', 'prenom':'jean',
308 +                  'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
309 +        sqlstr = s.insert('test', params)
310 +        self.assertEqual(sqlstr,  'INSERT INTO test ( age, nom, prenom ) VALUES '
311 +                         '( YEARS(%(date)s), %(nom)s, %(prenom)s )')
312 +        self.assertEqual(params, {'nom':'dupont', 'prenom':'jean', 'date': '2013/01/01'})
313 +
314 +    def test_select_values(self):
315 +        s = SQLGenerator()
316 +        self.assertEqual(s.select('test',{}), 'SELECT * FROM test')
317 +        self.assertEqual(s.select('test',{'nom':'dupont'}),
318 +                         'SELECT * FROM test WHERE nom = %(nom)s')
319 +        self.assertEqual(s.select('test',{'nom':'dupont','prenom':'jean'}),
320 +                         'SELECT * FROM test WHERE nom = %(nom)s AND prenom = %(prenom)s')
321 +
322 +    def test_select_functions(self):
323 +        s = SQLGenerator()
324 +        params = {'nom':'dupont', 'prenom':'jean',
325 +                  'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
326 +        self.assertEqual(s.select('test', params),
327 +                         'SELECT * FROM test WHERE age = YEARS(%(date)s) '
328 +                         'AND nom = %(nom)s AND prenom = %(prenom)s')
329 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'})
330 +
331 +    def test_adv_select_values(self):
332 +        s = SQLGenerator()
333 +        self.assertEqual(s.adv_select(['column'],[('test', 't')], {}),
334 +                         'SELECT column FROM test AS t')
335 +        self.assertEqual( s.adv_select(['column'],[('test', 't')], {'nom':'dupont'}),
336 +                          'SELECT column FROM test AS t WHERE nom = %(nom)s')
337 +
338 +    def test_adv_select_functions(self):
339 +        s = SQLGenerator()
340 +        params = {'nom':'dupont', 'prenom':'jean',
341 +                  'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
342 +        self.assertEqual( s.adv_select(['column'],[('test', 't')], params),
343 +                          'SELECT column FROM test AS t WHERE age = YEARS(%(date)s) '
344 +                         'AND nom = %(nom)s AND prenom = %(prenom)s')
345 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'})
346 +
347 +    def test_delete_values(self):
348 +        s = SQLGenerator()
349 +        self.assertEqual(s.delete('test',{'nom':'dupont'}),
350 +                         'DELETE FROM test WHERE nom = %(nom)s')
351 +        self.assertEqual(s.delete('test',{'nom':'dupont','prenom':'jean'}),
352 +                         'DELETE FROM test WHERE nom = %(nom)s AND prenom = %(prenom)s')
353 +
354 +    def test_delete_functions(self):
355 +        s = SQLGenerator()
356 +        params = {'nom':'dupont', 'prenom':'jean',
357 +                  'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
358 +        self.assertEqual( s.delete('test', params),
359 +                          'DELETE FROM test WHERE age = YEARS(%(date)s) '
360 +                         'AND nom = %(nom)s AND prenom = %(prenom)s')
361 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'})
362 +
363 +    def test_delete_many_values(self):
364 +        s = SQLGenerator()
365 +        params = {'nom':'dupont', 'eid': '(1, 2, 3)'}
366 +        self.assertEqual(s.delete_many('test', params),
367 +                         'DELETE FROM test WHERE eid IN (1, 2, 3) AND nom = %(nom)s')
368 +        self.assertEqual(params, {'nom':'dupont'})
369 +
370 +    def test_delete_many_functions(self):
371 +        s = SQLGenerator()
372 +        params = {'nom':'dupont', 'prenom':'jean', 'eid': '(1, 2, 3)',
373 +                  'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
374 +        self.assertEqual( s.delete_many('test', params),
375 +                          'DELETE FROM test WHERE eid IN (1, 2, 3) AND age = YEARS(%(date)s) '
376 +                          'AND nom = %(nom)s AND prenom = %(prenom)s')
377 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01'})
378 +
379 +    def test_update_values(self):
380 +        s = SQLGenerator()
381 +        self.assertEqual(s.update('test', {'id':'001','nom':'dupont'}, ['id']),
382 +                         'UPDATE test SET nom = %(nom)s WHERE id = %(id)s')
383 +        self.assertEqual(s.update('test',{'id':'001','nom':'dupont','prenom':'jean'},['id']),
384 +                         'UPDATE test SET nom = %(nom)s, prenom = %(prenom)s WHERE id = %(id)s')
385 +
386 +    def test_update_functions(self):
387 +        s = SQLGenerator()
388 +        params = {'id': '001', 'nom':'dupont', 'prenom':'jean',
389 +                  'age': SQLExpression('YEARS(%(date)s)', date='2013/01/01')}
390 +        self.assertEqual( s.update('test', params, ['id']),
391 +                          'UPDATE test SET age = YEARS(%(date)s), nom = %(nom)s, '
392 +                          'prenom = %(prenom)s WHERE id = %(id)s')
393 +        self.assertEqual(params, {'nom': 'dupont', 'prenom': 'jean', 'date': '2013/01/01', 'id': '001'})
394 +
395 +if __name__ == '__main__':
396 +    unittest.main()