You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

595 lines
20 KiB

6 years ago
  1. """MySQLdb Cursors
  2. This module implements Cursors of various types for MySQLdb. By
  3. default, MySQLdb uses the Cursor class.
  4. """
  5. from __future__ import print_function, absolute_import
  6. from functools import partial
  7. import re
  8. import sys
  9. from MySQLdb.compat import unicode
  10. from _mysql_exceptions import (
  11. Warning, Error, InterfaceError, DataError,
  12. DatabaseError, OperationalError, IntegrityError, InternalError,
  13. NotSupportedError, ProgrammingError)
  14. PY2 = sys.version_info[0] == 2
  15. if PY2:
  16. text_type = unicode
  17. else:
  18. text_type = str
  19. #: Regular expression for :meth:`Cursor.executemany`.
  20. #: executemany only supports simple bulk insert.
  21. #: You can use it to load large dataset.
  22. RE_INSERT_VALUES = re.compile(
  23. r"\s*((?:INSERT|REPLACE)\b.+\bVALUES?\s*)" +
  24. r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
  25. r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
  26. re.IGNORECASE | re.DOTALL)
  27. class BaseCursor(object):
  28. """A base for Cursor classes. Useful attributes:
  29. description
  30. A tuple of DB API 7-tuples describing the columns in
  31. the last executed query; see PEP-249 for details.
  32. description_flags
  33. Tuple of column flags for last query, one entry per column
  34. in the result set. Values correspond to those in
  35. MySQLdb.constants.FLAG. See MySQL documentation (C API)
  36. for more information. Non-standard extension.
  37. arraysize
  38. default number of rows fetchmany() will fetch
  39. """
  40. #: Max stetement size which :meth:`executemany` generates.
  41. #:
  42. #: Max size of allowed statement is max_allowed_packet - packet_header_size.
  43. #: Default value of max_allowed_packet is 1048576.
  44. max_stmt_length = 64*1024
  45. from _mysql_exceptions import MySQLError, Warning, Error, InterfaceError, \
  46. DatabaseError, DataError, OperationalError, IntegrityError, \
  47. InternalError, ProgrammingError, NotSupportedError
  48. _defer_warnings = False
  49. connection = None
  50. def __init__(self, connection):
  51. self.connection = connection
  52. self.description = None
  53. self.description_flags = None
  54. self.rowcount = -1
  55. self.arraysize = 1
  56. self._executed = None
  57. self.lastrowid = None
  58. self.messages = []
  59. self.errorhandler = connection.errorhandler
  60. self._result = None
  61. self._warnings = None
  62. self.rownumber = None
  63. def close(self):
  64. """Close the cursor. No further queries will be possible."""
  65. try:
  66. if self.connection is None:
  67. return
  68. while self.nextset():
  69. pass
  70. finally:
  71. self.connection = None
  72. self.errorhandler = None
  73. self._result = None
  74. def __enter__(self):
  75. return self
  76. def __exit__(self, *exc_info):
  77. del exc_info
  78. self.close()
  79. def _ensure_bytes(self, x, encoding=None):
  80. if isinstance(x, text_type):
  81. x = x.encode(encoding)
  82. elif isinstance(x, (tuple, list)):
  83. x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x)
  84. return x
  85. def _escape_args(self, args, conn):
  86. ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding)
  87. if isinstance(args, (tuple, list)):
  88. if PY2:
  89. args = tuple(map(ensure_bytes, args))
  90. return tuple(conn.literal(arg) for arg in args)
  91. elif isinstance(args, dict):
  92. if PY2:
  93. args = dict((ensure_bytes(key), ensure_bytes(val)) for
  94. (key, val) in args.items())
  95. return dict((key, conn.literal(val)) for (key, val) in args.items())
  96. else:
  97. # If it's not a dictionary let's try escaping it anyways.
  98. # Worst case it will throw a Value error
  99. if PY2:
  100. args = ensure_bytes(args)
  101. return conn.literal(args)
  102. def _check_executed(self):
  103. if not self._executed:
  104. self.errorhandler(self, ProgrammingError, "execute() first")
  105. def _warning_check(self):
  106. from warnings import warn
  107. db = self._get_db()
  108. # None => warnings not interrogated for current query yet
  109. # 0 => no warnings exists or have been handled already for this query
  110. if self._warnings is None:
  111. self._warnings = db.warning_count()
  112. if self._warnings:
  113. # Only propagate warnings for current query once
  114. warning_count = self._warnings
  115. self._warnings = 0
  116. # When there is next result, fetching warnings cause "command
  117. # out of sync" error.
  118. if self._result and self._result.has_next:
  119. msg = "There are %d MySQL warnings." % (warning_count,)
  120. self.messages.append(msg)
  121. warn(self.Warning(0, msg), stacklevel=3)
  122. return
  123. warnings = db.show_warnings()
  124. if warnings:
  125. # This is done in two loops in case
  126. # Warnings are set to raise exceptions.
  127. for w in warnings:
  128. self.messages.append((self.Warning, w))
  129. for w in warnings:
  130. warn(self.Warning(*w[1:3]), stacklevel=3)
  131. else:
  132. info = db.info()
  133. if info:
  134. self.messages.append((self.Warning, info))
  135. warn(self.Warning(0, info), stacklevel=3)
  136. def nextset(self):
  137. """Advance to the next result set.
  138. Returns None if there are no more result sets.
  139. """
  140. if self._executed:
  141. self.fetchall()
  142. del self.messages[:]
  143. db = self._get_db()
  144. nr = db.next_result()
  145. if nr == -1:
  146. return None
  147. self._do_get_result()
  148. self._post_get_result()
  149. self._warning_check()
  150. return 1
  151. def _post_get_result(self): pass
  152. def _do_get_result(self):
  153. db = self._get_db()
  154. self._result = self._get_result()
  155. self.rowcount = db.affected_rows()
  156. self.rownumber = 0
  157. self.description = self._result and self._result.describe() or None
  158. self.description_flags = self._result and self._result.field_flags() or None
  159. self.lastrowid = db.insert_id()
  160. self._warnings = None
  161. def setinputsizes(self, *args):
  162. """Does nothing, required by DB API."""
  163. def setoutputsizes(self, *args):
  164. """Does nothing, required by DB API."""
  165. def _get_db(self):
  166. con = self.connection
  167. if con is None:
  168. raise ProgrammingError("cursor closed")
  169. return con
  170. def execute(self, query, args=None):
  171. """Execute a query.
  172. query -- string, query to execute on server
  173. args -- optional sequence or mapping, parameters to use with query.
  174. Note: If args is a sequence, then %s must be used as the
  175. parameter placeholder in the query. If a mapping is used,
  176. %(key)s must be used as the placeholder.
  177. Returns integer represents rows affected, if any
  178. """
  179. while self.nextset():
  180. pass
  181. db = self._get_db()
  182. # NOTE:
  183. # Python 2: query should be bytes when executing %.
  184. # All unicode in args should be encoded to bytes on Python 2.
  185. # Python 3: query should be str (unicode) when executing %.
  186. # All bytes in args should be decoded with ascii and surrogateescape on Python 3.
  187. # db.literal(obj) always returns str.
  188. if PY2 and isinstance(query, unicode):
  189. query = query.encode(db.encoding)
  190. if args is not None:
  191. if isinstance(args, dict):
  192. args = dict((key, db.literal(item)) for key, item in args.items())
  193. else:
  194. args = tuple(map(db.literal, args))
  195. if not PY2 and isinstance(query, (bytes, bytearray)):
  196. query = query.decode(db.encoding)
  197. try:
  198. query = query % args
  199. except TypeError as m:
  200. self.errorhandler(self, ProgrammingError, str(m))
  201. if isinstance(query, unicode):
  202. query = query.encode(db.encoding, 'surrogateescape')
  203. res = None
  204. try:
  205. res = self._query(query)
  206. except Exception:
  207. exc, value = sys.exc_info()[:2]
  208. self.errorhandler(self, exc, value)
  209. self._executed = query
  210. if not self._defer_warnings:
  211. self._warning_check()
  212. return res
  213. def executemany(self, query, args):
  214. # type: (str, list) -> int
  215. """Execute a multi-row query.
  216. :param query: query to execute on server
  217. :param args: Sequence of sequences or mappings. It is used as parameter.
  218. :return: Number of rows affected, if any.
  219. This method improves performance on multiple-row INSERT and
  220. REPLACE. Otherwise it is equivalent to looping over args with
  221. execute().
  222. """
  223. del self.messages[:]
  224. if not args:
  225. return
  226. m = RE_INSERT_VALUES.match(query)
  227. if m:
  228. q_prefix = m.group(1) % ()
  229. q_values = m.group(2).rstrip()
  230. q_postfix = m.group(3) or ''
  231. assert q_values[0] == '(' and q_values[-1] == ')'
  232. return self._do_execute_many(q_prefix, q_values, q_postfix, args,
  233. self.max_stmt_length,
  234. self._get_db().encoding)
  235. self.rowcount = sum(self.execute(query, arg) for arg in args)
  236. return self.rowcount
  237. def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding):
  238. conn = self._get_db()
  239. escape = self._escape_args
  240. if isinstance(prefix, text_type):
  241. prefix = prefix.encode(encoding)
  242. if PY2 and isinstance(values, text_type):
  243. values = values.encode(encoding)
  244. if isinstance(postfix, text_type):
  245. postfix = postfix.encode(encoding)
  246. sql = bytearray(prefix)
  247. args = iter(args)
  248. v = values % escape(next(args), conn)
  249. if isinstance(v, text_type):
  250. if PY2:
  251. v = v.encode(encoding)
  252. else:
  253. v = v.encode(encoding, 'surrogateescape')
  254. sql += v
  255. rows = 0
  256. for arg in args:
  257. v = values % escape(arg, conn)
  258. if isinstance(v, text_type):
  259. if PY2:
  260. v = v.encode(encoding)
  261. else:
  262. v = v.encode(encoding, 'surrogateescape')
  263. if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length:
  264. rows += self.execute(sql + postfix)
  265. sql = bytearray(prefix)
  266. else:
  267. sql += b','
  268. sql += v
  269. rows += self.execute(sql + postfix)
  270. self.rowcount = rows
  271. return rows
  272. def callproc(self, procname, args=()):
  273. """Execute stored procedure procname with args
  274. procname -- string, name of procedure to execute on server
  275. args -- Sequence of parameters to use with procedure
  276. Returns the original args.
  277. Compatibility warning: PEP-249 specifies that any modified
  278. parameters must be returned. This is currently impossible
  279. as they are only available by storing them in a server
  280. variable and then retrieved by a query. Since stored
  281. procedures return zero or more result sets, there is no
  282. reliable way to get at OUT or INOUT parameters via callproc.
  283. The server variables are named @_procname_n, where procname
  284. is the parameter above and n is the position of the parameter
  285. (from zero). Once all result sets generated by the procedure
  286. have been fetched, you can issue a SELECT @_procname_0, ...
  287. query using .execute() to get any OUT or INOUT values.
  288. Compatibility warning: The act of calling a stored procedure
  289. itself creates an empty result set. This appears after any
  290. result sets generated by the procedure. This is non-standard
  291. behavior with respect to the DB-API. Be sure to use nextset()
  292. to advance through all result sets; otherwise you may get
  293. disconnected.
  294. """
  295. db = self._get_db()
  296. if args:
  297. fmt = '@_{0}_%d=%s'.format(procname)
  298. q = 'SET %s' % ','.join(fmt % (index, db.literal(arg))
  299. for index, arg in enumerate(args))
  300. if isinstance(q, unicode):
  301. q = q.encode(db.encoding, 'surrogateescape')
  302. self._query(q)
  303. self.nextset()
  304. q = "CALL %s(%s)" % (procname,
  305. ','.join(['@_%s_%d' % (procname, i)
  306. for i in range(len(args))]))
  307. if isinstance(q, unicode):
  308. q = q.encode(db.encoding, 'surrogateescape')
  309. self._query(q)
  310. self._executed = q
  311. if not self._defer_warnings:
  312. self._warning_check()
  313. return args
  314. def _do_query(self, q):
  315. db = self._get_db()
  316. self._last_executed = q
  317. db.query(q)
  318. self._do_get_result()
  319. return self.rowcount
  320. def _query(self, q):
  321. return self._do_query(q)
  322. def _fetch_row(self, size=1):
  323. if not self._result:
  324. return ()
  325. return self._result.fetch_row(size, self._fetch_type)
  326. def __iter__(self):
  327. return iter(self.fetchone, None)
  328. Warning = Warning
  329. Error = Error
  330. InterfaceError = InterfaceError
  331. DatabaseError = DatabaseError
  332. DataError = DataError
  333. OperationalError = OperationalError
  334. IntegrityError = IntegrityError
  335. InternalError = InternalError
  336. ProgrammingError = ProgrammingError
  337. NotSupportedError = NotSupportedError
  338. class CursorStoreResultMixIn(object):
  339. """This is a MixIn class which causes the entire result set to be
  340. stored on the client side, i.e. it uses mysql_store_result(). If the
  341. result set can be very large, consider adding a LIMIT clause to your
  342. query, or using CursorUseResultMixIn instead."""
  343. def _get_result(self):
  344. return self._get_db().store_result()
  345. def _query(self, q):
  346. rowcount = self._do_query(q)
  347. self._post_get_result()
  348. return rowcount
  349. def _post_get_result(self):
  350. self._rows = self._fetch_row(0)
  351. self._result = None
  352. def fetchone(self):
  353. """Fetches a single row from the cursor. None indicates that
  354. no more rows are available."""
  355. self._check_executed()
  356. if self.rownumber >= len(self._rows):
  357. return None
  358. result = self._rows[self.rownumber]
  359. self.rownumber = self.rownumber + 1
  360. return result
  361. def fetchmany(self, size=None):
  362. """Fetch up to size rows from the cursor. Result set may be smaller
  363. than size. If size is not defined, cursor.arraysize is used."""
  364. self._check_executed()
  365. end = self.rownumber + (size or self.arraysize)
  366. result = self._rows[self.rownumber:end]
  367. self.rownumber = min(end, len(self._rows))
  368. return result
  369. def fetchall(self):
  370. """Fetchs all available rows from the cursor."""
  371. self._check_executed()
  372. if self.rownumber:
  373. result = self._rows[self.rownumber:]
  374. else:
  375. result = self._rows
  376. self.rownumber = len(self._rows)
  377. return result
  378. def scroll(self, value, mode='relative'):
  379. """Scroll the cursor in the result set to a new position according
  380. to mode.
  381. If mode is 'relative' (default), value is taken as offset to
  382. the current position in the result set, if set to 'absolute',
  383. value states an absolute target position."""
  384. self._check_executed()
  385. if mode == 'relative':
  386. r = self.rownumber + value
  387. elif mode == 'absolute':
  388. r = value
  389. else:
  390. self.errorhandler(self, ProgrammingError,
  391. "unknown scroll mode %s" % repr(mode))
  392. if r < 0 or r >= len(self._rows):
  393. self.errorhandler(self, IndexError, "out of range")
  394. self.rownumber = r
  395. def __iter__(self):
  396. self._check_executed()
  397. result = self.rownumber and self._rows[self.rownumber:] or self._rows
  398. return iter(result)
  399. class CursorUseResultMixIn(object):
  400. """This is a MixIn class which causes the result set to be stored
  401. in the server and sent row-by-row to client side, i.e. it uses
  402. mysql_use_result(). You MUST retrieve the entire result set and
  403. close() the cursor before additional queries can be performed on
  404. the connection."""
  405. _defer_warnings = True
  406. def _get_result(self): return self._get_db().use_result()
  407. def fetchone(self):
  408. """Fetches a single row from the cursor."""
  409. self._check_executed()
  410. r = self._fetch_row(1)
  411. if not r:
  412. self._warning_check()
  413. return None
  414. self.rownumber = self.rownumber + 1
  415. return r[0]
  416. def fetchmany(self, size=None):
  417. """Fetch up to size rows from the cursor. Result set may be smaller
  418. than size. If size is not defined, cursor.arraysize is used."""
  419. self._check_executed()
  420. r = self._fetch_row(size or self.arraysize)
  421. self.rownumber = self.rownumber + len(r)
  422. if not r:
  423. self._warning_check()
  424. return r
  425. def fetchall(self):
  426. """Fetchs all available rows from the cursor."""
  427. self._check_executed()
  428. r = self._fetch_row(0)
  429. self.rownumber = self.rownumber + len(r)
  430. self._warning_check()
  431. return r
  432. def __iter__(self):
  433. return self
  434. def next(self):
  435. row = self.fetchone()
  436. if row is None:
  437. raise StopIteration
  438. return row
  439. __next__ = next
  440. class CursorTupleRowsMixIn(object):
  441. """This is a MixIn class that causes all rows to be returned as tuples,
  442. which is the standard form required by DB API."""
  443. _fetch_type = 0
  444. class CursorDictRowsMixIn(object):
  445. """This is a MixIn class that causes all rows to be returned as
  446. dictionaries. This is a non-standard feature."""
  447. _fetch_type = 1
  448. def fetchoneDict(self):
  449. """Fetch a single row as a dictionary. Deprecated:
  450. Use fetchone() instead. Will be removed in 1.3."""
  451. from warnings import warn
  452. warn("fetchoneDict() is non-standard and will be removed in 1.3",
  453. DeprecationWarning, 2)
  454. return self.fetchone()
  455. def fetchmanyDict(self, size=None):
  456. """Fetch several rows as a list of dictionaries. Deprecated:
  457. Use fetchmany() instead. Will be removed in 1.3."""
  458. from warnings import warn
  459. warn("fetchmanyDict() is non-standard and will be removed in 1.3",
  460. DeprecationWarning, 2)
  461. return self.fetchmany(size)
  462. def fetchallDict(self):
  463. """Fetch all available rows as a list of dictionaries. Deprecated:
  464. Use fetchall() instead. Will be removed in 1.3."""
  465. from warnings import warn
  466. warn("fetchallDict() is non-standard and will be removed in 1.3",
  467. DeprecationWarning, 2)
  468. return self.fetchall()
  469. class CursorOldDictRowsMixIn(CursorDictRowsMixIn):
  470. """This is a MixIn class that returns rows as dictionaries with
  471. the same key convention as the old Mysqldb (MySQLmodule). Don't
  472. use this."""
  473. _fetch_type = 2
  474. class Cursor(CursorStoreResultMixIn, CursorTupleRowsMixIn,
  475. BaseCursor):
  476. """This is the standard Cursor class that returns rows as tuples
  477. and stores the result set in the client."""
  478. class DictCursor(CursorStoreResultMixIn, CursorDictRowsMixIn,
  479. BaseCursor):
  480. """This is a Cursor class that returns rows as dictionaries and
  481. stores the result set in the client."""
  482. class SSCursor(CursorUseResultMixIn, CursorTupleRowsMixIn,
  483. BaseCursor):
  484. """This is a Cursor class that returns rows as tuples and stores
  485. the result set in the server."""
  486. class SSDictCursor(CursorUseResultMixIn, CursorDictRowsMixIn,
  487. BaseCursor):
  488. """This is a Cursor class that returns rows as dictionaries and
  489. stores the result set in the server."""

Powered by TurnKey Linux.