diff options
| author | Jørgen P. Tjernø <[email protected]> | 2013-12-02 19:31:46 -0800 |
|---|---|---|
| committer | Jørgen P. Tjernø <[email protected]> | 2013-12-02 19:46:31 -0800 |
| commit | f56bb35301836e56582a575a75864392a0177875 (patch) | |
| tree | de61ddd39de3e7df52759711950b4c288592f0dc /mp/src/thirdparty/protobuf-2.3.0/python | |
| parent | Mark some more files as text. (diff) | |
| download | source-sdk-2013-f56bb35301836e56582a575a75864392a0177875.tar.xz source-sdk-2013-f56bb35301836e56582a575a75864392a0177875.zip | |
Fix line endings. WHAMMY.
Diffstat (limited to 'mp/src/thirdparty/protobuf-2.3.0/python')
25 files changed, 11659 insertions, 11659 deletions
diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/ez_setup.py b/mp/src/thirdparty/protobuf-2.3.0/python/ez_setup.py index 0ce9920c..b7a9849e 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/ez_setup.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/ez_setup.py @@ -1,281 +1,281 @@ -#!python
-
-# This file was obtained from:
-# http://peak.telecommunity.com/dist/ez_setup.py
-# on 2009/4/17.
-
-"""Bootstrap setuptools installation
-
-If you want to use setuptools in your package's setup.py, just include this
-file in the same directory with it, and add this to the top of your setup.py::
-
- from ez_setup import use_setuptools
- use_setuptools()
-
-If you want to require a specific version of setuptools, set a download
-mirror, or use an alternate download directory, you can do so by supplying
-the appropriate options to ``use_setuptools()``.
-
-This file can also be run as a script to install or upgrade setuptools.
-"""
-import sys
-DEFAULT_VERSION = "0.6c9"
-DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3]
-
-md5_data = {
- 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca',
- 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb',
- 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b',
- 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a',
- 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618',
- 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac',
- 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5',
- 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4',
- 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c',
- 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b',
- 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27',
- 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277',
- 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa',
- 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e',
- 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e',
- 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f',
- 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2',
- 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc',
- 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167',
- 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64',
- 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d',
- 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20',
- 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab',
- 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53',
- 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2',
- 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e',
- 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372',
- 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902',
- 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de',
- 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b',
- 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03',
- 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a',
- 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6',
- 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a',
-}
-
-import sys, os
-try: from hashlib import md5
-except ImportError: from md5 import md5
-
-def _validate_md5(egg_name, data):
- if egg_name in md5_data:
- digest = md5(data).hexdigest()
- if digest != md5_data[egg_name]:
- print >>sys.stderr, (
- "md5 validation of %s failed! (Possible download problem?)"
- % egg_name
- )
- sys.exit(2)
- return data
-
-def use_setuptools(
- version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
- download_delay=15
-):
- """Automatically find/download setuptools and make it available on sys.path
-
- `version` should be a valid setuptools version number that is available
- as an egg for download under the `download_base` URL (which should end with
- a '/'). `to_dir` is the directory where setuptools will be downloaded, if
- it is not already available. If `download_delay` is specified, it should
- be the number of seconds that will be paused before initiating a download,
- should one be required. If an older version of setuptools is installed,
- this routine will print a message to ``sys.stderr`` and raise SystemExit in
- an attempt to abort the calling script.
- """
- was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules
- def do_download():
- egg = download_setuptools(version, download_base, to_dir, download_delay)
- sys.path.insert(0, egg)
- import setuptools; setuptools.bootstrap_install_from = egg
- try:
- import pkg_resources
- except ImportError:
- return do_download()
- try:
- pkg_resources.require("setuptools>="+version); return
- except pkg_resources.VersionConflict, e:
- if was_imported:
- print >>sys.stderr, (
- "The required version of setuptools (>=%s) is not available, and\n"
- "can't be installed while this script is running. Please install\n"
- " a more recent version first, using 'easy_install -U setuptools'."
- "\n\n(Currently using %r)"
- ) % (version, e.args[0])
- sys.exit(2)
- else:
- del pkg_resources, sys.modules['pkg_resources'] # reload ok
- return do_download()
- except pkg_resources.DistributionNotFound:
- return do_download()
-
-def download_setuptools(
- version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir,
- delay = 15
-):
- """Download setuptools from a specified location and return its filename
-
- `version` should be a valid setuptools version number that is available
- as an egg for download under the `download_base` URL (which should end
- with a '/'). `to_dir` is the directory where the egg will be downloaded.
- `delay` is the number of seconds to pause before an actual download attempt.
- """
- import urllib2, shutil
- egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3])
- url = download_base + egg_name
- saveto = os.path.join(to_dir, egg_name)
- src = dst = None
- if not os.path.exists(saveto): # Avoid repeated downloads
- try:
- from distutils import log
- if delay:
- log.warn("""
----------------------------------------------------------------------------
-This script requires setuptools version %s to run (even to display
-help). I will attempt to download it for you (from
-%s), but
-you may need to enable firewall access for this script first.
-I will start the download in %d seconds.
-
-(Note: if this machine does not have network access, please obtain the file
-
- %s
-
-and place it in this directory before rerunning this script.)
----------------------------------------------------------------------------""",
- version, download_base, delay, url
- ); from time import sleep; sleep(delay)
- log.warn("Downloading %s", url)
- src = urllib2.urlopen(url)
- # Read/write all in one block, so we don't create a corrupt file
- # if the download is interrupted.
- data = _validate_md5(egg_name, src.read())
- dst = open(saveto,"wb"); dst.write(data)
- finally:
- if src: src.close()
- if dst: dst.close()
- return os.path.realpath(saveto)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-def main(argv, version=DEFAULT_VERSION):
- """Install or upgrade setuptools and EasyInstall"""
- try:
- import setuptools
- except ImportError:
- egg = None
- try:
- egg = download_setuptools(version, delay=0)
- sys.path.insert(0,egg)
- from setuptools.command.easy_install import main
- return main(list(argv)+[egg]) # we're done here
- finally:
- if egg and os.path.exists(egg):
- os.unlink(egg)
- else:
- if setuptools.__version__ == '0.0.1':
- print >>sys.stderr, (
- "You have an obsolete version of setuptools installed. Please\n"
- "remove it from your system entirely before rerunning this script."
- )
- sys.exit(2)
-
- req = "setuptools>="+version
- import pkg_resources
- try:
- pkg_resources.require(req)
- except pkg_resources.VersionConflict:
- try:
- from setuptools.command.easy_install import main
- except ImportError:
- from easy_install import main
- main(list(argv)+[download_setuptools(delay=0)])
- sys.exit(0) # try to force an exit
- else:
- if argv:
- from setuptools.command.easy_install import main
- main(argv)
- else:
- print "Setuptools version",version,"or greater has been installed."
- print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)'
-
-def update_md5(filenames):
- """Update our built-in md5 registry"""
-
- import re
-
- for name in filenames:
- base = os.path.basename(name)
- f = open(name,'rb')
- md5_data[base] = md5(f.read()).hexdigest()
- f.close()
-
- data = [" %r: %r,\n" % it for it in md5_data.items()]
- data.sort()
- repl = "".join(data)
-
- import inspect
- srcfile = inspect.getsourcefile(sys.modules[__name__])
- f = open(srcfile, 'rb'); src = f.read(); f.close()
-
- match = re.search("\nmd5_data = {\n([^}]+)}", src)
- if not match:
- print >>sys.stderr, "Internal error!"
- sys.exit(2)
-
- src = src[:match.start(1)] + repl + src[match.end(1):]
- f = open(srcfile,'w')
- f.write(src)
- f.close()
-
-
-if __name__=='__main__':
- if len(sys.argv)>2 and sys.argv[1]=='--md5update':
- update_md5(sys.argv[2:])
- else:
- main(sys.argv[1:])
-
-
-
-
-
-
+#!python + +# This file was obtained from: +# http://peak.telecommunity.com/dist/ez_setup.py +# on 2009/4/17. + +"""Bootstrap setuptools installation + +If you want to use setuptools in your package's setup.py, just include this +file in the same directory with it, and add this to the top of your setup.py:: + + from ez_setup import use_setuptools + use_setuptools() + +If you want to require a specific version of setuptools, set a download +mirror, or use an alternate download directory, you can do so by supplying +the appropriate options to ``use_setuptools()``. + +This file can also be run as a script to install or upgrade setuptools. +""" +import sys +DEFAULT_VERSION = "0.6c9" +DEFAULT_URL = "http://pypi.python.org/packages/%s/s/setuptools/" % sys.version[:3] + +md5_data = { + 'setuptools-0.6b1-py2.3.egg': '8822caf901250d848b996b7f25c6e6ca', + 'setuptools-0.6b1-py2.4.egg': 'b79a8a403e4502fbb85ee3f1941735cb', + 'setuptools-0.6b2-py2.3.egg': '5657759d8a6d8fc44070a9d07272d99b', + 'setuptools-0.6b2-py2.4.egg': '4996a8d169d2be661fa32a6e52e4f82a', + 'setuptools-0.6b3-py2.3.egg': 'bb31c0fc7399a63579975cad9f5a0618', + 'setuptools-0.6b3-py2.4.egg': '38a8c6b3d6ecd22247f179f7da669fac', + 'setuptools-0.6b4-py2.3.egg': '62045a24ed4e1ebc77fe039aa4e6f7e5', + 'setuptools-0.6b4-py2.4.egg': '4cb2a185d228dacffb2d17f103b3b1c4', + 'setuptools-0.6c1-py2.3.egg': 'b3f2b5539d65cb7f74ad79127f1a908c', + 'setuptools-0.6c1-py2.4.egg': 'b45adeda0667d2d2ffe14009364f2a4b', + 'setuptools-0.6c2-py2.3.egg': 'f0064bf6aa2b7d0f3ba0b43f20817c27', + 'setuptools-0.6c2-py2.4.egg': '616192eec35f47e8ea16cd6a122b7277', + 'setuptools-0.6c3-py2.3.egg': 'f181fa125dfe85a259c9cd6f1d7b78fa', + 'setuptools-0.6c3-py2.4.egg': 'e0ed74682c998bfb73bf803a50e7b71e', + 'setuptools-0.6c3-py2.5.egg': 'abef16fdd61955514841c7c6bd98965e', + 'setuptools-0.6c4-py2.3.egg': 'b0b9131acab32022bfac7f44c5d7971f', + 'setuptools-0.6c4-py2.4.egg': '2a1f9656d4fbf3c97bf946c0a124e6e2', + 'setuptools-0.6c4-py2.5.egg': '8f5a052e32cdb9c72bcf4b5526f28afc', + 'setuptools-0.6c5-py2.3.egg': 'ee9fd80965da04f2f3e6b3576e9d8167', + 'setuptools-0.6c5-py2.4.egg': 'afe2adf1c01701ee841761f5bcd8aa64', + 'setuptools-0.6c5-py2.5.egg': 'a8d3f61494ccaa8714dfed37bccd3d5d', + 'setuptools-0.6c6-py2.3.egg': '35686b78116a668847237b69d549ec20', + 'setuptools-0.6c6-py2.4.egg': '3c56af57be3225019260a644430065ab', + 'setuptools-0.6c6-py2.5.egg': 'b2f8a7520709a5b34f80946de5f02f53', + 'setuptools-0.6c7-py2.3.egg': '209fdf9adc3a615e5115b725658e13e2', + 'setuptools-0.6c7-py2.4.egg': '5a8f954807d46a0fb67cf1f26c55a82e', + 'setuptools-0.6c7-py2.5.egg': '45d2ad28f9750e7434111fde831e8372', + 'setuptools-0.6c8-py2.3.egg': '50759d29b349db8cfd807ba8303f1902', + 'setuptools-0.6c8-py2.4.egg': 'cba38d74f7d483c06e9daa6070cce6de', + 'setuptools-0.6c8-py2.5.egg': '1721747ee329dc150590a58b3e1ac95b', + 'setuptools-0.6c9-py2.3.egg': 'a83c4020414807b496e4cfbe08507c03', + 'setuptools-0.6c9-py2.4.egg': '260a2be2e5388d66bdaee06abec6342a', + 'setuptools-0.6c9-py2.5.egg': 'fe67c3e5a17b12c0e7c541b7ea43a8e6', + 'setuptools-0.6c9-py2.6.egg': 'ca37b1ff16fa2ede6e19383e7b59245a', +} + +import sys, os +try: from hashlib import md5 +except ImportError: from md5 import md5 + +def _validate_md5(egg_name, data): + if egg_name in md5_data: + digest = md5(data).hexdigest() + if digest != md5_data[egg_name]: + print >>sys.stderr, ( + "md5 validation of %s failed! (Possible download problem?)" + % egg_name + ) + sys.exit(2) + return data + +def use_setuptools( + version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, + download_delay=15 +): + """Automatically find/download setuptools and make it available on sys.path + + `version` should be a valid setuptools version number that is available + as an egg for download under the `download_base` URL (which should end with + a '/'). `to_dir` is the directory where setuptools will be downloaded, if + it is not already available. If `download_delay` is specified, it should + be the number of seconds that will be paused before initiating a download, + should one be required. If an older version of setuptools is installed, + this routine will print a message to ``sys.stderr`` and raise SystemExit in + an attempt to abort the calling script. + """ + was_imported = 'pkg_resources' in sys.modules or 'setuptools' in sys.modules + def do_download(): + egg = download_setuptools(version, download_base, to_dir, download_delay) + sys.path.insert(0, egg) + import setuptools; setuptools.bootstrap_install_from = egg + try: + import pkg_resources + except ImportError: + return do_download() + try: + pkg_resources.require("setuptools>="+version); return + except pkg_resources.VersionConflict, e: + if was_imported: + print >>sys.stderr, ( + "The required version of setuptools (>=%s) is not available, and\n" + "can't be installed while this script is running. Please install\n" + " a more recent version first, using 'easy_install -U setuptools'." + "\n\n(Currently using %r)" + ) % (version, e.args[0]) + sys.exit(2) + else: + del pkg_resources, sys.modules['pkg_resources'] # reload ok + return do_download() + except pkg_resources.DistributionNotFound: + return do_download() + +def download_setuptools( + version=DEFAULT_VERSION, download_base=DEFAULT_URL, to_dir=os.curdir, + delay = 15 +): + """Download setuptools from a specified location and return its filename + + `version` should be a valid setuptools version number that is available + as an egg for download under the `download_base` URL (which should end + with a '/'). `to_dir` is the directory where the egg will be downloaded. + `delay` is the number of seconds to pause before an actual download attempt. + """ + import urllib2, shutil + egg_name = "setuptools-%s-py%s.egg" % (version,sys.version[:3]) + url = download_base + egg_name + saveto = os.path.join(to_dir, egg_name) + src = dst = None + if not os.path.exists(saveto): # Avoid repeated downloads + try: + from distutils import log + if delay: + log.warn(""" +--------------------------------------------------------------------------- +This script requires setuptools version %s to run (even to display +help). I will attempt to download it for you (from +%s), but +you may need to enable firewall access for this script first. +I will start the download in %d seconds. + +(Note: if this machine does not have network access, please obtain the file + + %s + +and place it in this directory before rerunning this script.) +---------------------------------------------------------------------------""", + version, download_base, delay, url + ); from time import sleep; sleep(delay) + log.warn("Downloading %s", url) + src = urllib2.urlopen(url) + # Read/write all in one block, so we don't create a corrupt file + # if the download is interrupted. + data = _validate_md5(egg_name, src.read()) + dst = open(saveto,"wb"); dst.write(data) + finally: + if src: src.close() + if dst: dst.close() + return os.path.realpath(saveto) + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +def main(argv, version=DEFAULT_VERSION): + """Install or upgrade setuptools and EasyInstall""" + try: + import setuptools + except ImportError: + egg = None + try: + egg = download_setuptools(version, delay=0) + sys.path.insert(0,egg) + from setuptools.command.easy_install import main + return main(list(argv)+[egg]) # we're done here + finally: + if egg and os.path.exists(egg): + os.unlink(egg) + else: + if setuptools.__version__ == '0.0.1': + print >>sys.stderr, ( + "You have an obsolete version of setuptools installed. Please\n" + "remove it from your system entirely before rerunning this script." + ) + sys.exit(2) + + req = "setuptools>="+version + import pkg_resources + try: + pkg_resources.require(req) + except pkg_resources.VersionConflict: + try: + from setuptools.command.easy_install import main + except ImportError: + from easy_install import main + main(list(argv)+[download_setuptools(delay=0)]) + sys.exit(0) # try to force an exit + else: + if argv: + from setuptools.command.easy_install import main + main(argv) + else: + print "Setuptools version",version,"or greater has been installed." + print '(Run "ez_setup.py -U setuptools" to reinstall or upgrade.)' + +def update_md5(filenames): + """Update our built-in md5 registry""" + + import re + + for name in filenames: + base = os.path.basename(name) + f = open(name,'rb') + md5_data[base] = md5(f.read()).hexdigest() + f.close() + + data = [" %r: %r,\n" % it for it in md5_data.items()] + data.sort() + repl = "".join(data) + + import inspect + srcfile = inspect.getsourcefile(sys.modules[__name__]) + f = open(srcfile, 'rb'); src = f.read(); f.close() + + match = re.search("\nmd5_data = {\n([^}]+)}", src) + if not match: + print >>sys.stderr, "Internal error!" + sys.exit(2) + + src = src[:match.start(1)] + repl + src[match.end(1):] + f = open(srcfile,'w') + f.write(src) + f.close() + + +if __name__=='__main__': + if len(sys.argv)>2 and sys.argv[1]=='--md5update': + update_md5(sys.argv[2:]) + else: + main(sys.argv[1:]) + + + + + + diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/__init__.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/__init__.py index 4d0b94e2..de40ea7c 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/__init__.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/__init__.py @@ -1 +1 @@ -__import__('pkg_resources').declare_namespace(__name__)
+__import__('pkg_resources').declare_namespace(__name__) diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/descriptor.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/descriptor.py index 212455b9..aa4ab969 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/descriptor.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/descriptor.py @@ -1,590 +1,590 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-# TODO(robinson): We probably need to provide deep-copy methods for
-# descriptor types. When a FieldDescriptor is passed into
-# Descriptor.__init__(), we should make a deep copy and then set
-# containing_type on it. Alternatively, we could just get
-# rid of containing_type (iit's not needed for reflection.py, at least).
-#
-# TODO(robinson): Print method?
-#
-# TODO(robinson): Useful __repr__?
-
-"""Descriptors essentially contain exactly the information found in a .proto
-file, in types that make this information accessible in Python.
-"""
-
-__author__ = '[email protected] (Will Robinson)'
-
-
-class Error(Exception):
- """Base error for this module."""
-
-
-class DescriptorBase(object):
-
- """Descriptors base class.
-
- This class is the base of all descriptor classes. It provides common options
- related functionaility.
-
- Attributes:
- has_options: True if the descriptor has non-default options. Usually it
- is not necessary to read this -- just call GetOptions() which will
- happily return the default instance. However, it's sometimes useful
- for efficiency, and also useful inside the protobuf implementation to
- avoid some bootstrapping issues.
- """
-
- def __init__(self, options, options_class_name):
- """Initialize the descriptor given its options message and the name of the
- class of the options message. The name of the class is required in case
- the options message is None and has to be created.
- """
- self._options = options
- self._options_class_name = options_class_name
-
- # Does this descriptor have non-default options?
- self.has_options = options is not None
-
- def GetOptions(self):
- """Retrieves descriptor options.
-
- This method returns the options set or creates the default options for the
- descriptor.
- """
- if self._options:
- return self._options
- from google.protobuf import descriptor_pb2
- try:
- options_class = getattr(descriptor_pb2, self._options_class_name)
- except AttributeError:
- raise RuntimeError('Unknown options class name %s!' %
- (self._options_class_name))
- self._options = options_class()
- return self._options
-
-
-class _NestedDescriptorBase(DescriptorBase):
- """Common class for descriptors that can be nested."""
-
- def __init__(self, options, options_class_name, name, full_name,
- file, containing_type, serialized_start=None,
- serialized_end=None):
- """Constructor.
-
- Args:
- options: Protocol message options or None
- to use default message options.
- options_class_name: (str) The class name of the above options.
-
- name: (str) Name of this protocol message type.
- full_name: (str) Fully-qualified name of this protocol message type,
- which will include protocol "package" name and the name of any
- enclosing types.
- file: (FileDescriptor) Reference to file info.
- containing_type: if provided, this is a nested descriptor, with this
- descriptor as parent, otherwise None.
- serialized_start: The start index (inclusive) in block in the
- file.serialized_pb that describes this descriptor.
- serialized_end: The end index (exclusive) in block in the
- file.serialized_pb that describes this descriptor.
- """
- super(_NestedDescriptorBase, self).__init__(
- options, options_class_name)
-
- self.name = name
- # TODO(falk): Add function to calculate full_name instead of having it in
- # memory?
- self.full_name = full_name
- self.file = file
- self.containing_type = containing_type
-
- self._serialized_start = serialized_start
- self._serialized_end = serialized_end
-
- def GetTopLevelContainingType(self):
- """Returns the root if this is a nested type, or itself if its the root."""
- desc = self
- while desc.containing_type is not None:
- desc = desc.containing_type
- return desc
-
- def CopyToProto(self, proto):
- """Copies this to the matching proto in descriptor_pb2.
-
- Args:
- proto: An empty proto instance from descriptor_pb2.
-
- Raises:
- Error: If self couldnt be serialized, due to to few constructor arguments.
- """
- if (self.file is not None and
- self._serialized_start is not None and
- self._serialized_end is not None):
- proto.ParseFromString(self.file.serialized_pb[
- self._serialized_start:self._serialized_end])
- else:
- raise Error('Descriptor does not contain serialization.')
-
-
-class Descriptor(_NestedDescriptorBase):
-
- """Descriptor for a protocol message type.
-
- A Descriptor instance has the following attributes:
-
- name: (str) Name of this protocol message type.
- full_name: (str) Fully-qualified name of this protocol message type,
- which will include protocol "package" name and the name of any
- enclosing types.
-
- containing_type: (Descriptor) Reference to the descriptor of the
- type containing us, or None if this is top-level.
-
- fields: (list of FieldDescriptors) Field descriptors for all
- fields in this type.
- fields_by_number: (dict int -> FieldDescriptor) Same FieldDescriptor
- objects as in |fields|, but indexed by "number" attribute in each
- FieldDescriptor.
- fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor
- objects as in |fields|, but indexed by "name" attribute in each
- FieldDescriptor.
-
- nested_types: (list of Descriptors) Descriptor references
- for all protocol message types nested within this one.
- nested_types_by_name: (dict str -> Descriptor) Same Descriptor
- objects as in |nested_types|, but indexed by "name" attribute
- in each Descriptor.
-
- enum_types: (list of EnumDescriptors) EnumDescriptor references
- for all enums contained within this type.
- enum_types_by_name: (dict str ->EnumDescriptor) Same EnumDescriptor
- objects as in |enum_types|, but indexed by "name" attribute
- in each EnumDescriptor.
- enum_values_by_name: (dict str -> EnumValueDescriptor) Dict mapping
- from enum value name to EnumValueDescriptor for that value.
-
- extensions: (list of FieldDescriptor) All extensions defined directly
- within this message type (NOT within a nested type).
- extensions_by_name: (dict, string -> FieldDescriptor) Same FieldDescriptor
- objects as |extensions|, but indexed by "name" attribute of each
- FieldDescriptor.
-
- is_extendable: Does this type define any extension ranges?
-
- options: (descriptor_pb2.MessageOptions) Protocol message options or None
- to use default message options.
-
- file: (FileDescriptor) Reference to file descriptor.
- """
-
- def __init__(self, name, full_name, filename, containing_type, fields,
- nested_types, enum_types, extensions, options=None,
- is_extendable=True, extension_ranges=None, file=None,
- serialized_start=None, serialized_end=None):
- """Arguments to __init__() are as described in the description
- of Descriptor fields above.
-
- Note that filename is an obsolete argument, that is not used anymore.
- Please use file.name to access this as an attribute.
- """
- super(Descriptor, self).__init__(
- options, 'MessageOptions', name, full_name, file,
- containing_type, serialized_start=serialized_start,
- serialized_end=serialized_start)
-
- # We have fields in addition to fields_by_name and fields_by_number,
- # so that:
- # 1. Clients can index fields by "order in which they're listed."
- # 2. Clients can easily iterate over all fields with the terse
- # syntax: for f in descriptor.fields: ...
- self.fields = fields
- for field in self.fields:
- field.containing_type = self
- self.fields_by_number = dict((f.number, f) for f in fields)
- self.fields_by_name = dict((f.name, f) for f in fields)
-
- self.nested_types = nested_types
- self.nested_types_by_name = dict((t.name, t) for t in nested_types)
-
- self.enum_types = enum_types
- for enum_type in self.enum_types:
- enum_type.containing_type = self
- self.enum_types_by_name = dict((t.name, t) for t in enum_types)
- self.enum_values_by_name = dict(
- (v.name, v) for t in enum_types for v in t.values)
-
- self.extensions = extensions
- for extension in self.extensions:
- extension.extension_scope = self
- self.extensions_by_name = dict((f.name, f) for f in extensions)
- self.is_extendable = is_extendable
- self.extension_ranges = extension_ranges
-
- self._serialized_start = serialized_start
- self._serialized_end = serialized_end
-
- def CopyToProto(self, proto):
- """Copies this to a descriptor_pb2.DescriptorProto.
-
- Args:
- proto: An empty descriptor_pb2.DescriptorProto.
- """
- # This function is overriden to give a better doc comment.
- super(Descriptor, self).CopyToProto(proto)
-
-
-# TODO(robinson): We should have aggressive checking here,
-# for example:
-# * If you specify a repeated field, you should not be allowed
-# to specify a default value.
-# * [Other examples here as needed].
-#
-# TODO(robinson): for this and other *Descriptor classes, we
-# might also want to lock things down aggressively (e.g.,
-# prevent clients from setting the attributes). Having
-# stronger invariants here in general will reduce the number
-# of runtime checks we must do in reflection.py...
-class FieldDescriptor(DescriptorBase):
-
- """Descriptor for a single field in a .proto file.
-
- A FieldDescriptor instance has the following attriubtes:
-
- name: (str) Name of this field, exactly as it appears in .proto.
- full_name: (str) Name of this field, including containing scope. This is
- particularly relevant for extensions.
- index: (int) Dense, 0-indexed index giving the order that this
- field textually appears within its message in the .proto file.
- number: (int) Tag number declared for this field in the .proto file.
-
- type: (One of the TYPE_* constants below) Declared type.
- cpp_type: (One of the CPPTYPE_* constants below) C++ type used to
- represent this field.
-
- label: (One of the LABEL_* constants below) Tells whether this
- field is optional, required, or repeated.
- has_default_value: (bool) True if this field has a default value defined,
- otherwise false.
- default_value: (Varies) Default value of this field. Only
- meaningful for non-repeated scalar fields. Repeated fields
- should always set this to [], and non-repeated composite
- fields should always set this to None.
-
- containing_type: (Descriptor) Descriptor of the protocol message
- type that contains this field. Set by the Descriptor constructor
- if we're passed into one.
- Somewhat confusingly, for extension fields, this is the
- descriptor of the EXTENDED message, not the descriptor
- of the message containing this field. (See is_extension and
- extension_scope below).
- message_type: (Descriptor) If a composite field, a descriptor
- of the message type contained in this field. Otherwise, this is None.
- enum_type: (EnumDescriptor) If this field contains an enum, a
- descriptor of that enum. Otherwise, this is None.
-
- is_extension: True iff this describes an extension field.
- extension_scope: (Descriptor) Only meaningful if is_extension is True.
- Gives the message that immediately contains this extension field.
- Will be None iff we're a top-level (file-level) extension field.
-
- options: (descriptor_pb2.FieldOptions) Protocol message field options or
- None to use default field options.
- """
-
- # Must be consistent with C++ FieldDescriptor::Type enum in
- # descriptor.h.
- #
- # TODO(robinson): Find a way to eliminate this repetition.
- TYPE_DOUBLE = 1
- TYPE_FLOAT = 2
- TYPE_INT64 = 3
- TYPE_UINT64 = 4
- TYPE_INT32 = 5
- TYPE_FIXED64 = 6
- TYPE_FIXED32 = 7
- TYPE_BOOL = 8
- TYPE_STRING = 9
- TYPE_GROUP = 10
- TYPE_MESSAGE = 11
- TYPE_BYTES = 12
- TYPE_UINT32 = 13
- TYPE_ENUM = 14
- TYPE_SFIXED32 = 15
- TYPE_SFIXED64 = 16
- TYPE_SINT32 = 17
- TYPE_SINT64 = 18
- MAX_TYPE = 18
-
- # Must be consistent with C++ FieldDescriptor::CppType enum in
- # descriptor.h.
- #
- # TODO(robinson): Find a way to eliminate this repetition.
- CPPTYPE_INT32 = 1
- CPPTYPE_INT64 = 2
- CPPTYPE_UINT32 = 3
- CPPTYPE_UINT64 = 4
- CPPTYPE_DOUBLE = 5
- CPPTYPE_FLOAT = 6
- CPPTYPE_BOOL = 7
- CPPTYPE_ENUM = 8
- CPPTYPE_STRING = 9
- CPPTYPE_MESSAGE = 10
- MAX_CPPTYPE = 10
-
- # Must be consistent with C++ FieldDescriptor::Label enum in
- # descriptor.h.
- #
- # TODO(robinson): Find a way to eliminate this repetition.
- LABEL_OPTIONAL = 1
- LABEL_REQUIRED = 2
- LABEL_REPEATED = 3
- MAX_LABEL = 3
-
- def __init__(self, name, full_name, index, number, type, cpp_type, label,
- default_value, message_type, enum_type, containing_type,
- is_extension, extension_scope, options=None,
- has_default_value=True):
- """The arguments are as described in the description of FieldDescriptor
- attributes above.
-
- Note that containing_type may be None, and may be set later if necessary
- (to deal with circular references between message types, for example).
- Likewise for extension_scope.
- """
- super(FieldDescriptor, self).__init__(options, 'FieldOptions')
- self.name = name
- self.full_name = full_name
- self.index = index
- self.number = number
- self.type = type
- self.cpp_type = cpp_type
- self.label = label
- self.has_default_value = has_default_value
- self.default_value = default_value
- self.containing_type = containing_type
- self.message_type = message_type
- self.enum_type = enum_type
- self.is_extension = is_extension
- self.extension_scope = extension_scope
-
-
-class EnumDescriptor(_NestedDescriptorBase):
-
- """Descriptor for an enum defined in a .proto file.
-
- An EnumDescriptor instance has the following attributes:
-
- name: (str) Name of the enum type.
- full_name: (str) Full name of the type, including package name
- and any enclosing type(s).
-
- values: (list of EnumValueDescriptors) List of the values
- in this enum.
- values_by_name: (dict str -> EnumValueDescriptor) Same as |values|,
- but indexed by the "name" field of each EnumValueDescriptor.
- values_by_number: (dict int -> EnumValueDescriptor) Same as |values|,
- but indexed by the "number" field of each EnumValueDescriptor.
- containing_type: (Descriptor) Descriptor of the immediate containing
- type of this enum, or None if this is an enum defined at the
- top level in a .proto file. Set by Descriptor's constructor
- if we're passed into one.
- file: (FileDescriptor) Reference to file descriptor.
- options: (descriptor_pb2.EnumOptions) Enum options message or
- None to use default enum options.
- """
-
- def __init__(self, name, full_name, filename, values,
- containing_type=None, options=None, file=None,
- serialized_start=None, serialized_end=None):
- """Arguments are as described in the attribute description above.
-
- Note that filename is an obsolete argument, that is not used anymore.
- Please use file.name to access this as an attribute.
- """
- super(EnumDescriptor, self).__init__(
- options, 'EnumOptions', name, full_name, file,
- containing_type, serialized_start=serialized_start,
- serialized_end=serialized_start)
-
- self.values = values
- for value in self.values:
- value.type = self
- self.values_by_name = dict((v.name, v) for v in values)
- self.values_by_number = dict((v.number, v) for v in values)
-
- self._serialized_start = serialized_start
- self._serialized_end = serialized_end
-
- def CopyToProto(self, proto):
- """Copies this to a descriptor_pb2.EnumDescriptorProto.
-
- Args:
- proto: An empty descriptor_pb2.EnumDescriptorProto.
- """
- # This function is overriden to give a better doc comment.
- super(EnumDescriptor, self).CopyToProto(proto)
-
-
-class EnumValueDescriptor(DescriptorBase):
-
- """Descriptor for a single value within an enum.
-
- name: (str) Name of this value.
- index: (int) Dense, 0-indexed index giving the order that this
- value appears textually within its enum in the .proto file.
- number: (int) Actual number assigned to this enum value.
- type: (EnumDescriptor) EnumDescriptor to which this value
- belongs. Set by EnumDescriptor's constructor if we're
- passed into one.
- options: (descriptor_pb2.EnumValueOptions) Enum value options message or
- None to use default enum value options options.
- """
-
- def __init__(self, name, index, number, type=None, options=None):
- """Arguments are as described in the attribute description above."""
- super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions')
- self.name = name
- self.index = index
- self.number = number
- self.type = type
-
-
-class ServiceDescriptor(_NestedDescriptorBase):
-
- """Descriptor for a service.
-
- name: (str) Name of the service.
- full_name: (str) Full name of the service, including package name.
- index: (int) 0-indexed index giving the order that this services
- definition appears withing the .proto file.
- methods: (list of MethodDescriptor) List of methods provided by this
- service.
- options: (descriptor_pb2.ServiceOptions) Service options message or
- None to use default service options.
- file: (FileDescriptor) Reference to file info.
- """
-
- def __init__(self, name, full_name, index, methods, options=None, file=None,
- serialized_start=None, serialized_end=None):
- super(ServiceDescriptor, self).__init__(
- options, 'ServiceOptions', name, full_name, file,
- None, serialized_start=serialized_start,
- serialized_end=serialized_end)
- self.index = index
- self.methods = methods
- # Set the containing service for each method in this service.
- for method in self.methods:
- method.containing_service = self
-
- def FindMethodByName(self, name):
- """Searches for the specified method, and returns its descriptor."""
- for method in self.methods:
- if name == method.name:
- return method
- return None
-
- def CopyToProto(self, proto):
- """Copies this to a descriptor_pb2.ServiceDescriptorProto.
-
- Args:
- proto: An empty descriptor_pb2.ServiceDescriptorProto.
- """
- # This function is overriden to give a better doc comment.
- super(ServiceDescriptor, self).CopyToProto(proto)
-
-
-class MethodDescriptor(DescriptorBase):
-
- """Descriptor for a method in a service.
-
- name: (str) Name of the method within the service.
- full_name: (str) Full name of method.
- index: (int) 0-indexed index of the method inside the service.
- containing_service: (ServiceDescriptor) The service that contains this
- method.
- input_type: The descriptor of the message that this method accepts.
- output_type: The descriptor of the message that this method returns.
- options: (descriptor_pb2.MethodOptions) Method options message or
- None to use default method options.
- """
-
- def __init__(self, name, full_name, index, containing_service,
- input_type, output_type, options=None):
- """The arguments are as described in the description of MethodDescriptor
- attributes above.
-
- Note that containing_service may be None, and may be set later if necessary.
- """
- super(MethodDescriptor, self).__init__(options, 'MethodOptions')
- self.name = name
- self.full_name = full_name
- self.index = index
- self.containing_service = containing_service
- self.input_type = input_type
- self.output_type = output_type
-
-
-class FileDescriptor(DescriptorBase):
- """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto.
-
- name: name of file, relative to root of source tree.
- package: name of the package
- serialized_pb: (str) Byte string of serialized
- descriptor_pb2.FileDescriptorProto.
- """
-
- def __init__(self, name, package, options=None, serialized_pb=None):
- """Constructor."""
- super(FileDescriptor, self).__init__(options, 'FileOptions')
-
- self.name = name
- self.package = package
- self.serialized_pb = serialized_pb
-
- def CopyToProto(self, proto):
- """Copies this to a descriptor_pb2.FileDescriptorProto.
-
- Args:
- proto: An empty descriptor_pb2.FileDescriptorProto.
- """
- proto.ParseFromString(self.serialized_pb)
-
-
-def _ParseOptions(message, string):
- """Parses serialized options.
-
- This helper function is used to parse serialized options in generated
- proto2 files. It must not be used outside proto2.
- """
- message.ParseFromString(string)
- return message
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): We probably need to provide deep-copy methods for +# descriptor types. When a FieldDescriptor is passed into +# Descriptor.__init__(), we should make a deep copy and then set +# containing_type on it. Alternatively, we could just get +# rid of containing_type (iit's not needed for reflection.py, at least). +# +# TODO(robinson): Print method? +# +# TODO(robinson): Useful __repr__? + +"""Descriptors essentially contain exactly the information found in a .proto +file, in types that make this information accessible in Python. +""" + +__author__ = '[email protected] (Will Robinson)' + + +class Error(Exception): + """Base error for this module.""" + + +class DescriptorBase(object): + + """Descriptors base class. + + This class is the base of all descriptor classes. It provides common options + related functionaility. + + Attributes: + has_options: True if the descriptor has non-default options. Usually it + is not necessary to read this -- just call GetOptions() which will + happily return the default instance. However, it's sometimes useful + for efficiency, and also useful inside the protobuf implementation to + avoid some bootstrapping issues. + """ + + def __init__(self, options, options_class_name): + """Initialize the descriptor given its options message and the name of the + class of the options message. The name of the class is required in case + the options message is None and has to be created. + """ + self._options = options + self._options_class_name = options_class_name + + # Does this descriptor have non-default options? + self.has_options = options is not None + + def GetOptions(self): + """Retrieves descriptor options. + + This method returns the options set or creates the default options for the + descriptor. + """ + if self._options: + return self._options + from google.protobuf import descriptor_pb2 + try: + options_class = getattr(descriptor_pb2, self._options_class_name) + except AttributeError: + raise RuntimeError('Unknown options class name %s!' % + (self._options_class_name)) + self._options = options_class() + return self._options + + +class _NestedDescriptorBase(DescriptorBase): + """Common class for descriptors that can be nested.""" + + def __init__(self, options, options_class_name, name, full_name, + file, containing_type, serialized_start=None, + serialized_end=None): + """Constructor. + + Args: + options: Protocol message options or None + to use default message options. + options_class_name: (str) The class name of the above options. + + name: (str) Name of this protocol message type. + full_name: (str) Fully-qualified name of this protocol message type, + which will include protocol "package" name and the name of any + enclosing types. + file: (FileDescriptor) Reference to file info. + containing_type: if provided, this is a nested descriptor, with this + descriptor as parent, otherwise None. + serialized_start: The start index (inclusive) in block in the + file.serialized_pb that describes this descriptor. + serialized_end: The end index (exclusive) in block in the + file.serialized_pb that describes this descriptor. + """ + super(_NestedDescriptorBase, self).__init__( + options, options_class_name) + + self.name = name + # TODO(falk): Add function to calculate full_name instead of having it in + # memory? + self.full_name = full_name + self.file = file + self.containing_type = containing_type + + self._serialized_start = serialized_start + self._serialized_end = serialized_end + + def GetTopLevelContainingType(self): + """Returns the root if this is a nested type, or itself if its the root.""" + desc = self + while desc.containing_type is not None: + desc = desc.containing_type + return desc + + def CopyToProto(self, proto): + """Copies this to the matching proto in descriptor_pb2. + + Args: + proto: An empty proto instance from descriptor_pb2. + + Raises: + Error: If self couldnt be serialized, due to to few constructor arguments. + """ + if (self.file is not None and + self._serialized_start is not None and + self._serialized_end is not None): + proto.ParseFromString(self.file.serialized_pb[ + self._serialized_start:self._serialized_end]) + else: + raise Error('Descriptor does not contain serialization.') + + +class Descriptor(_NestedDescriptorBase): + + """Descriptor for a protocol message type. + + A Descriptor instance has the following attributes: + + name: (str) Name of this protocol message type. + full_name: (str) Fully-qualified name of this protocol message type, + which will include protocol "package" name and the name of any + enclosing types. + + containing_type: (Descriptor) Reference to the descriptor of the + type containing us, or None if this is top-level. + + fields: (list of FieldDescriptors) Field descriptors for all + fields in this type. + fields_by_number: (dict int -> FieldDescriptor) Same FieldDescriptor + objects as in |fields|, but indexed by "number" attribute in each + FieldDescriptor. + fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor + objects as in |fields|, but indexed by "name" attribute in each + FieldDescriptor. + + nested_types: (list of Descriptors) Descriptor references + for all protocol message types nested within this one. + nested_types_by_name: (dict str -> Descriptor) Same Descriptor + objects as in |nested_types|, but indexed by "name" attribute + in each Descriptor. + + enum_types: (list of EnumDescriptors) EnumDescriptor references + for all enums contained within this type. + enum_types_by_name: (dict str ->EnumDescriptor) Same EnumDescriptor + objects as in |enum_types|, but indexed by "name" attribute + in each EnumDescriptor. + enum_values_by_name: (dict str -> EnumValueDescriptor) Dict mapping + from enum value name to EnumValueDescriptor for that value. + + extensions: (list of FieldDescriptor) All extensions defined directly + within this message type (NOT within a nested type). + extensions_by_name: (dict, string -> FieldDescriptor) Same FieldDescriptor + objects as |extensions|, but indexed by "name" attribute of each + FieldDescriptor. + + is_extendable: Does this type define any extension ranges? + + options: (descriptor_pb2.MessageOptions) Protocol message options or None + to use default message options. + + file: (FileDescriptor) Reference to file descriptor. + """ + + def __init__(self, name, full_name, filename, containing_type, fields, + nested_types, enum_types, extensions, options=None, + is_extendable=True, extension_ranges=None, file=None, + serialized_start=None, serialized_end=None): + """Arguments to __init__() are as described in the description + of Descriptor fields above. + + Note that filename is an obsolete argument, that is not used anymore. + Please use file.name to access this as an attribute. + """ + super(Descriptor, self).__init__( + options, 'MessageOptions', name, full_name, file, + containing_type, serialized_start=serialized_start, + serialized_end=serialized_start) + + # We have fields in addition to fields_by_name and fields_by_number, + # so that: + # 1. Clients can index fields by "order in which they're listed." + # 2. Clients can easily iterate over all fields with the terse + # syntax: for f in descriptor.fields: ... + self.fields = fields + for field in self.fields: + field.containing_type = self + self.fields_by_number = dict((f.number, f) for f in fields) + self.fields_by_name = dict((f.name, f) for f in fields) + + self.nested_types = nested_types + self.nested_types_by_name = dict((t.name, t) for t in nested_types) + + self.enum_types = enum_types + for enum_type in self.enum_types: + enum_type.containing_type = self + self.enum_types_by_name = dict((t.name, t) for t in enum_types) + self.enum_values_by_name = dict( + (v.name, v) for t in enum_types for v in t.values) + + self.extensions = extensions + for extension in self.extensions: + extension.extension_scope = self + self.extensions_by_name = dict((f.name, f) for f in extensions) + self.is_extendable = is_extendable + self.extension_ranges = extension_ranges + + self._serialized_start = serialized_start + self._serialized_end = serialized_end + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.DescriptorProto. + + Args: + proto: An empty descriptor_pb2.DescriptorProto. + """ + # This function is overriden to give a better doc comment. + super(Descriptor, self).CopyToProto(proto) + + +# TODO(robinson): We should have aggressive checking here, +# for example: +# * If you specify a repeated field, you should not be allowed +# to specify a default value. +# * [Other examples here as needed]. +# +# TODO(robinson): for this and other *Descriptor classes, we +# might also want to lock things down aggressively (e.g., +# prevent clients from setting the attributes). Having +# stronger invariants here in general will reduce the number +# of runtime checks we must do in reflection.py... +class FieldDescriptor(DescriptorBase): + + """Descriptor for a single field in a .proto file. + + A FieldDescriptor instance has the following attriubtes: + + name: (str) Name of this field, exactly as it appears in .proto. + full_name: (str) Name of this field, including containing scope. This is + particularly relevant for extensions. + index: (int) Dense, 0-indexed index giving the order that this + field textually appears within its message in the .proto file. + number: (int) Tag number declared for this field in the .proto file. + + type: (One of the TYPE_* constants below) Declared type. + cpp_type: (One of the CPPTYPE_* constants below) C++ type used to + represent this field. + + label: (One of the LABEL_* constants below) Tells whether this + field is optional, required, or repeated. + has_default_value: (bool) True if this field has a default value defined, + otherwise false. + default_value: (Varies) Default value of this field. Only + meaningful for non-repeated scalar fields. Repeated fields + should always set this to [], and non-repeated composite + fields should always set this to None. + + containing_type: (Descriptor) Descriptor of the protocol message + type that contains this field. Set by the Descriptor constructor + if we're passed into one. + Somewhat confusingly, for extension fields, this is the + descriptor of the EXTENDED message, not the descriptor + of the message containing this field. (See is_extension and + extension_scope below). + message_type: (Descriptor) If a composite field, a descriptor + of the message type contained in this field. Otherwise, this is None. + enum_type: (EnumDescriptor) If this field contains an enum, a + descriptor of that enum. Otherwise, this is None. + + is_extension: True iff this describes an extension field. + extension_scope: (Descriptor) Only meaningful if is_extension is True. + Gives the message that immediately contains this extension field. + Will be None iff we're a top-level (file-level) extension field. + + options: (descriptor_pb2.FieldOptions) Protocol message field options or + None to use default field options. + """ + + # Must be consistent with C++ FieldDescriptor::Type enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + TYPE_DOUBLE = 1 + TYPE_FLOAT = 2 + TYPE_INT64 = 3 + TYPE_UINT64 = 4 + TYPE_INT32 = 5 + TYPE_FIXED64 = 6 + TYPE_FIXED32 = 7 + TYPE_BOOL = 8 + TYPE_STRING = 9 + TYPE_GROUP = 10 + TYPE_MESSAGE = 11 + TYPE_BYTES = 12 + TYPE_UINT32 = 13 + TYPE_ENUM = 14 + TYPE_SFIXED32 = 15 + TYPE_SFIXED64 = 16 + TYPE_SINT32 = 17 + TYPE_SINT64 = 18 + MAX_TYPE = 18 + + # Must be consistent with C++ FieldDescriptor::CppType enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + CPPTYPE_INT32 = 1 + CPPTYPE_INT64 = 2 + CPPTYPE_UINT32 = 3 + CPPTYPE_UINT64 = 4 + CPPTYPE_DOUBLE = 5 + CPPTYPE_FLOAT = 6 + CPPTYPE_BOOL = 7 + CPPTYPE_ENUM = 8 + CPPTYPE_STRING = 9 + CPPTYPE_MESSAGE = 10 + MAX_CPPTYPE = 10 + + # Must be consistent with C++ FieldDescriptor::Label enum in + # descriptor.h. + # + # TODO(robinson): Find a way to eliminate this repetition. + LABEL_OPTIONAL = 1 + LABEL_REQUIRED = 2 + LABEL_REPEATED = 3 + MAX_LABEL = 3 + + def __init__(self, name, full_name, index, number, type, cpp_type, label, + default_value, message_type, enum_type, containing_type, + is_extension, extension_scope, options=None, + has_default_value=True): + """The arguments are as described in the description of FieldDescriptor + attributes above. + + Note that containing_type may be None, and may be set later if necessary + (to deal with circular references between message types, for example). + Likewise for extension_scope. + """ + super(FieldDescriptor, self).__init__(options, 'FieldOptions') + self.name = name + self.full_name = full_name + self.index = index + self.number = number + self.type = type + self.cpp_type = cpp_type + self.label = label + self.has_default_value = has_default_value + self.default_value = default_value + self.containing_type = containing_type + self.message_type = message_type + self.enum_type = enum_type + self.is_extension = is_extension + self.extension_scope = extension_scope + + +class EnumDescriptor(_NestedDescriptorBase): + + """Descriptor for an enum defined in a .proto file. + + An EnumDescriptor instance has the following attributes: + + name: (str) Name of the enum type. + full_name: (str) Full name of the type, including package name + and any enclosing type(s). + + values: (list of EnumValueDescriptors) List of the values + in this enum. + values_by_name: (dict str -> EnumValueDescriptor) Same as |values|, + but indexed by the "name" field of each EnumValueDescriptor. + values_by_number: (dict int -> EnumValueDescriptor) Same as |values|, + but indexed by the "number" field of each EnumValueDescriptor. + containing_type: (Descriptor) Descriptor of the immediate containing + type of this enum, or None if this is an enum defined at the + top level in a .proto file. Set by Descriptor's constructor + if we're passed into one. + file: (FileDescriptor) Reference to file descriptor. + options: (descriptor_pb2.EnumOptions) Enum options message or + None to use default enum options. + """ + + def __init__(self, name, full_name, filename, values, + containing_type=None, options=None, file=None, + serialized_start=None, serialized_end=None): + """Arguments are as described in the attribute description above. + + Note that filename is an obsolete argument, that is not used anymore. + Please use file.name to access this as an attribute. + """ + super(EnumDescriptor, self).__init__( + options, 'EnumOptions', name, full_name, file, + containing_type, serialized_start=serialized_start, + serialized_end=serialized_start) + + self.values = values + for value in self.values: + value.type = self + self.values_by_name = dict((v.name, v) for v in values) + self.values_by_number = dict((v.number, v) for v in values) + + self._serialized_start = serialized_start + self._serialized_end = serialized_end + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.EnumDescriptorProto. + + Args: + proto: An empty descriptor_pb2.EnumDescriptorProto. + """ + # This function is overriden to give a better doc comment. + super(EnumDescriptor, self).CopyToProto(proto) + + +class EnumValueDescriptor(DescriptorBase): + + """Descriptor for a single value within an enum. + + name: (str) Name of this value. + index: (int) Dense, 0-indexed index giving the order that this + value appears textually within its enum in the .proto file. + number: (int) Actual number assigned to this enum value. + type: (EnumDescriptor) EnumDescriptor to which this value + belongs. Set by EnumDescriptor's constructor if we're + passed into one. + options: (descriptor_pb2.EnumValueOptions) Enum value options message or + None to use default enum value options options. + """ + + def __init__(self, name, index, number, type=None, options=None): + """Arguments are as described in the attribute description above.""" + super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions') + self.name = name + self.index = index + self.number = number + self.type = type + + +class ServiceDescriptor(_NestedDescriptorBase): + + """Descriptor for a service. + + name: (str) Name of the service. + full_name: (str) Full name of the service, including package name. + index: (int) 0-indexed index giving the order that this services + definition appears withing the .proto file. + methods: (list of MethodDescriptor) List of methods provided by this + service. + options: (descriptor_pb2.ServiceOptions) Service options message or + None to use default service options. + file: (FileDescriptor) Reference to file info. + """ + + def __init__(self, name, full_name, index, methods, options=None, file=None, + serialized_start=None, serialized_end=None): + super(ServiceDescriptor, self).__init__( + options, 'ServiceOptions', name, full_name, file, + None, serialized_start=serialized_start, + serialized_end=serialized_end) + self.index = index + self.methods = methods + # Set the containing service for each method in this service. + for method in self.methods: + method.containing_service = self + + def FindMethodByName(self, name): + """Searches for the specified method, and returns its descriptor.""" + for method in self.methods: + if name == method.name: + return method + return None + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.ServiceDescriptorProto. + + Args: + proto: An empty descriptor_pb2.ServiceDescriptorProto. + """ + # This function is overriden to give a better doc comment. + super(ServiceDescriptor, self).CopyToProto(proto) + + +class MethodDescriptor(DescriptorBase): + + """Descriptor for a method in a service. + + name: (str) Name of the method within the service. + full_name: (str) Full name of method. + index: (int) 0-indexed index of the method inside the service. + containing_service: (ServiceDescriptor) The service that contains this + method. + input_type: The descriptor of the message that this method accepts. + output_type: The descriptor of the message that this method returns. + options: (descriptor_pb2.MethodOptions) Method options message or + None to use default method options. + """ + + def __init__(self, name, full_name, index, containing_service, + input_type, output_type, options=None): + """The arguments are as described in the description of MethodDescriptor + attributes above. + + Note that containing_service may be None, and may be set later if necessary. + """ + super(MethodDescriptor, self).__init__(options, 'MethodOptions') + self.name = name + self.full_name = full_name + self.index = index + self.containing_service = containing_service + self.input_type = input_type + self.output_type = output_type + + +class FileDescriptor(DescriptorBase): + """Descriptor for a file. Mimics the descriptor_pb2.FileDescriptorProto. + + name: name of file, relative to root of source tree. + package: name of the package + serialized_pb: (str) Byte string of serialized + descriptor_pb2.FileDescriptorProto. + """ + + def __init__(self, name, package, options=None, serialized_pb=None): + """Constructor.""" + super(FileDescriptor, self).__init__(options, 'FileOptions') + + self.name = name + self.package = package + self.serialized_pb = serialized_pb + + def CopyToProto(self, proto): + """Copies this to a descriptor_pb2.FileDescriptorProto. + + Args: + proto: An empty descriptor_pb2.FileDescriptorProto. + """ + proto.ParseFromString(self.serialized_pb) + + +def _ParseOptions(message, string): + """Parses serialized options. + + This helper function is used to parse serialized options in generated + proto2 files. It must not be used outside proto2. + """ + message.ParseFromString(string) + return message diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/containers.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/containers.py index 79b3b3ea..5cc7d6d0 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/containers.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/containers.py @@ -1,244 +1,244 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Contains container classes to represent different protocol buffer types.
-
-This file defines container classes which represent categories of protocol
-buffer field types which need extra maintenance. Currently these categories
-are:
- - Repeated scalar fields - These are all repeated fields which aren't
- composite (e.g. they are of simple types like int32, string, etc).
- - Repeated composite fields - Repeated fields which are composite. This
- includes groups and nested messages.
-"""
-
-__author__ = '[email protected] (Petar Petrov)'
-
-
-class BaseContainer(object):
-
- """Base container class."""
-
- # Minimizes memory usage and disallows assignment to other attributes.
- __slots__ = ['_message_listener', '_values']
-
- def __init__(self, message_listener):
- """
- Args:
- message_listener: A MessageListener implementation.
- The RepeatedScalarFieldContainer will call this object's
- Modified() method when it is modified.
- """
- self._message_listener = message_listener
- self._values = []
-
- def __getitem__(self, key):
- """Retrieves item by the specified key."""
- return self._values[key]
-
- def __len__(self):
- """Returns the number of elements in the container."""
- return len(self._values)
-
- def __ne__(self, other):
- """Checks if another instance isn't equal to this one."""
- # The concrete classes should define __eq__.
- return not self == other
-
- def __repr__(self):
- return repr(self._values)
-
-
-class RepeatedScalarFieldContainer(BaseContainer):
-
- """Simple, type-checked, list-like container for holding repeated scalars."""
-
- # Disallows assignment to other attributes.
- __slots__ = ['_type_checker']
-
- def __init__(self, message_listener, type_checker):
- """
- Args:
- message_listener: A MessageListener implementation.
- The RepeatedScalarFieldContainer will call this object's
- Modified() method when it is modified.
- type_checker: A type_checkers.ValueChecker instance to run on elements
- inserted into this container.
- """
- super(RepeatedScalarFieldContainer, self).__init__(message_listener)
- self._type_checker = type_checker
-
- def append(self, value):
- """Appends an item to the list. Similar to list.append()."""
- self._type_checker.CheckValue(value)
- self._values.append(value)
- if not self._message_listener.dirty:
- self._message_listener.Modified()
-
- def insert(self, key, value):
- """Inserts the item at the specified position. Similar to list.insert()."""
- self._type_checker.CheckValue(value)
- self._values.insert(key, value)
- if not self._message_listener.dirty:
- self._message_listener.Modified()
-
- def extend(self, elem_seq):
- """Extends by appending the given sequence. Similar to list.extend()."""
- if not elem_seq:
- return
-
- new_values = []
- for elem in elem_seq:
- self._type_checker.CheckValue(elem)
- new_values.append(elem)
- self._values.extend(new_values)
- self._message_listener.Modified()
-
- def MergeFrom(self, other):
- """Appends the contents of another repeated field of the same type to this
- one. We do not check the types of the individual fields.
- """
- self._values.extend(other._values)
- self._message_listener.Modified()
-
- def remove(self, elem):
- """Removes an item from the list. Similar to list.remove()."""
- self._values.remove(elem)
- self._message_listener.Modified()
-
- def __setitem__(self, key, value):
- """Sets the item on the specified position."""
- self._type_checker.CheckValue(value)
- self._values[key] = value
- self._message_listener.Modified()
-
- def __getslice__(self, start, stop):
- """Retrieves the subset of items from between the specified indices."""
- return self._values[start:stop]
-
- def __setslice__(self, start, stop, values):
- """Sets the subset of items from between the specified indices."""
- new_values = []
- for value in values:
- self._type_checker.CheckValue(value)
- new_values.append(value)
- self._values[start:stop] = new_values
- self._message_listener.Modified()
-
- def __delitem__(self, key):
- """Deletes the item at the specified position."""
- del self._values[key]
- self._message_listener.Modified()
-
- def __delslice__(self, start, stop):
- """Deletes the subset of items from between the specified indices."""
- del self._values[start:stop]
- self._message_listener.Modified()
-
- def __eq__(self, other):
- """Compares the current instance with another one."""
- if self is other:
- return True
- # Special case for the same type which should be common and fast.
- if isinstance(other, self.__class__):
- return other._values == self._values
- # We are presumably comparing against some other sequence type.
- return other == self._values
-
-
-class RepeatedCompositeFieldContainer(BaseContainer):
-
- """Simple, list-like container for holding repeated composite fields."""
-
- # Disallows assignment to other attributes.
- __slots__ = ['_message_descriptor']
-
- def __init__(self, message_listener, message_descriptor):
- """
- Note that we pass in a descriptor instead of the generated directly,
- since at the time we construct a _RepeatedCompositeFieldContainer we
- haven't yet necessarily initialized the type that will be contained in the
- container.
-
- Args:
- message_listener: A MessageListener implementation.
- The RepeatedCompositeFieldContainer will call this object's
- Modified() method when it is modified.
- message_descriptor: A Descriptor instance describing the protocol type
- that should be present in this container. We'll use the
- _concrete_class field of this descriptor when the client calls add().
- """
- super(RepeatedCompositeFieldContainer, self).__init__(message_listener)
- self._message_descriptor = message_descriptor
-
- def add(self):
- new_element = self._message_descriptor._concrete_class()
- new_element._SetListener(self._message_listener)
- self._values.append(new_element)
- if not self._message_listener.dirty:
- self._message_listener.Modified()
- return new_element
-
- def MergeFrom(self, other):
- """Appends the contents of another repeated field of the same type to this
- one, copying each individual message.
- """
- message_class = self._message_descriptor._concrete_class
- listener = self._message_listener
- values = self._values
- for message in other._values:
- new_element = message_class()
- new_element._SetListener(listener)
- new_element.MergeFrom(message)
- values.append(new_element)
- listener.Modified()
-
- def __getslice__(self, start, stop):
- """Retrieves the subset of items from between the specified indices."""
- return self._values[start:stop]
-
- def __delitem__(self, key):
- """Deletes the item at the specified position."""
- del self._values[key]
- self._message_listener.Modified()
-
- def __delslice__(self, start, stop):
- """Deletes the subset of items from between the specified indices."""
- del self._values[start:stop]
- self._message_listener.Modified()
-
- def __eq__(self, other):
- """Compares the current instance with another one."""
- if self is other:
- return True
- if not isinstance(other, self.__class__):
- raise TypeError('Can only compare repeated composite fields against '
- 'other repeated composite fields.')
- return self._values == other._values
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains container classes to represent different protocol buffer types. + +This file defines container classes which represent categories of protocol +buffer field types which need extra maintenance. Currently these categories +are: + - Repeated scalar fields - These are all repeated fields which aren't + composite (e.g. they are of simple types like int32, string, etc). + - Repeated composite fields - Repeated fields which are composite. This + includes groups and nested messages. +""" + +__author__ = '[email protected] (Petar Petrov)' + + +class BaseContainer(object): + + """Base container class.""" + + # Minimizes memory usage and disallows assignment to other attributes. + __slots__ = ['_message_listener', '_values'] + + def __init__(self, message_listener): + """ + Args: + message_listener: A MessageListener implementation. + The RepeatedScalarFieldContainer will call this object's + Modified() method when it is modified. + """ + self._message_listener = message_listener + self._values = [] + + def __getitem__(self, key): + """Retrieves item by the specified key.""" + return self._values[key] + + def __len__(self): + """Returns the number of elements in the container.""" + return len(self._values) + + def __ne__(self, other): + """Checks if another instance isn't equal to this one.""" + # The concrete classes should define __eq__. + return not self == other + + def __repr__(self): + return repr(self._values) + + +class RepeatedScalarFieldContainer(BaseContainer): + + """Simple, type-checked, list-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_type_checker'] + + def __init__(self, message_listener, type_checker): + """ + Args: + message_listener: A MessageListener implementation. + The RepeatedScalarFieldContainer will call this object's + Modified() method when it is modified. + type_checker: A type_checkers.ValueChecker instance to run on elements + inserted into this container. + """ + super(RepeatedScalarFieldContainer, self).__init__(message_listener) + self._type_checker = type_checker + + def append(self, value): + """Appends an item to the list. Similar to list.append().""" + self._type_checker.CheckValue(value) + self._values.append(value) + if not self._message_listener.dirty: + self._message_listener.Modified() + + def insert(self, key, value): + """Inserts the item at the specified position. Similar to list.insert().""" + self._type_checker.CheckValue(value) + self._values.insert(key, value) + if not self._message_listener.dirty: + self._message_listener.Modified() + + def extend(self, elem_seq): + """Extends by appending the given sequence. Similar to list.extend().""" + if not elem_seq: + return + + new_values = [] + for elem in elem_seq: + self._type_checker.CheckValue(elem) + new_values.append(elem) + self._values.extend(new_values) + self._message_listener.Modified() + + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one. We do not check the types of the individual fields. + """ + self._values.extend(other._values) + self._message_listener.Modified() + + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.Modified() + + def __setitem__(self, key, value): + """Sets the item on the specified position.""" + self._type_checker.CheckValue(value) + self._values[key] = value + self._message_listener.Modified() + + def __getslice__(self, start, stop): + """Retrieves the subset of items from between the specified indices.""" + return self._values[start:stop] + + def __setslice__(self, start, stop, values): + """Sets the subset of items from between the specified indices.""" + new_values = [] + for value in values: + self._type_checker.CheckValue(value) + new_values.append(value) + self._values[start:stop] = new_values + self._message_listener.Modified() + + def __delitem__(self, key): + """Deletes the item at the specified position.""" + del self._values[key] + self._message_listener.Modified() + + def __delslice__(self, start, stop): + """Deletes the subset of items from between the specified indices.""" + del self._values[start:stop] + self._message_listener.Modified() + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + # Special case for the same type which should be common and fast. + if isinstance(other, self.__class__): + return other._values == self._values + # We are presumably comparing against some other sequence type. + return other == self._values + + +class RepeatedCompositeFieldContainer(BaseContainer): + + """Simple, list-like container for holding repeated composite fields.""" + + # Disallows assignment to other attributes. + __slots__ = ['_message_descriptor'] + + def __init__(self, message_listener, message_descriptor): + """ + Note that we pass in a descriptor instead of the generated directly, + since at the time we construct a _RepeatedCompositeFieldContainer we + haven't yet necessarily initialized the type that will be contained in the + container. + + Args: + message_listener: A MessageListener implementation. + The RepeatedCompositeFieldContainer will call this object's + Modified() method when it is modified. + message_descriptor: A Descriptor instance describing the protocol type + that should be present in this container. We'll use the + _concrete_class field of this descriptor when the client calls add(). + """ + super(RepeatedCompositeFieldContainer, self).__init__(message_listener) + self._message_descriptor = message_descriptor + + def add(self): + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values.append(new_element) + if not self._message_listener.dirty: + self._message_listener.Modified() + return new_element + + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one, copying each individual message. + """ + message_class = self._message_descriptor._concrete_class + listener = self._message_listener + values = self._values + for message in other._values: + new_element = message_class() + new_element._SetListener(listener) + new_element.MergeFrom(message) + values.append(new_element) + listener.Modified() + + def __getslice__(self, start, stop): + """Retrieves the subset of items from between the specified indices.""" + return self._values[start:stop] + + def __delitem__(self, key): + """Deletes the item at the specified position.""" + del self._values[key] + self._message_listener.Modified() + + def __delslice__(self, start, stop): + """Deletes the subset of items from between the specified indices.""" + del self._values[start:stop] + self._message_listener.Modified() + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + if not isinstance(other, self.__class__): + raise TypeError('Can only compare repeated composite fields against ' + 'other repeated composite fields.') + return self._values == other._values diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/decoder.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/decoder.py index 03ab7287..461a30c0 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/decoder.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/decoder.py @@ -1,641 +1,641 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Code for decoding protocol buffer primitives.
-
-This code is very similar to encoder.py -- read the docs for that module first.
-
-A "decoder" is a function with the signature:
- Decode(buffer, pos, end, message, field_dict)
-The arguments are:
- buffer: The string containing the encoded message.
- pos: The current position in the string.
- end: The position in the string where the current message ends. May be
- less than len(buffer) if we're reading a sub-message.
- message: The message object into which we're parsing.
- field_dict: message._fields (avoids a hashtable lookup).
-The decoder reads the field and stores it into field_dict, returning the new
-buffer position. A decoder for a repeated field may proactively decode all of
-the elements of that field, if they appear consecutively.
-
-Note that decoders may throw any of the following:
- IndexError: Indicates a truncated message.
- struct.error: Unpacking of a fixed-width field failed.
- message.DecodeError: Other errors.
-
-Decoders are expected to raise an exception if they are called with pos > end.
-This allows callers to be lax about bounds checking: it's fineto read past
-"end" as long as you are sure that someone else will notice and throw an
-exception later on.
-
-Something up the call stack is expected to catch IndexError and struct.error
-and convert them to message.DecodeError.
-
-Decoders are constructed using decoder constructors with the signature:
- MakeDecoder(field_number, is_repeated, is_packed, key, new_default)
-The arguments are:
- field_number: The field number of the field we want to decode.
- is_repeated: Is the field a repeated field? (bool)
- is_packed: Is the field a packed field? (bool)
- key: The key to use when looking up the field within field_dict.
- (This is actually the FieldDescriptor but nothing in this
- file should depend on that.)
- new_default: A function which takes a message object as a parameter and
- returns a new instance of the default value for this field.
- (This is called for repeated fields and sub-messages, when an
- instance does not already exist.)
-
-As with encoders, we define a decoder constructor for every type of field.
-Then, for every field of every message class we construct an actual decoder.
-That decoder goes into a dict indexed by tag, so when we decode a message
-we repeatedly read a tag, look up the corresponding decoder, and invoke it.
-"""
-
-__author__ = '[email protected] (Kenton Varda)'
-
-import struct
-from google.protobuf.internal import encoder
-from google.protobuf.internal import wire_format
-from google.protobuf import message
-
-
-# This is not for optimization, but rather to avoid conflicts with local
-# variables named "message".
-_DecodeError = message.DecodeError
-
-
-def _VarintDecoder(mask):
- """Return an encoder for a basic varint value (does not include tag).
-
- Decoded values will be bitwise-anded with the given mask before being
- returned, e.g. to limit them to 32 bits. The returned decoder does not
- take the usual "end" parameter -- the caller is expected to do bounds checking
- after the fact (often the caller can defer such checking until later). The
- decoder returns a (value, new_pos) pair.
- """
-
- local_ord = ord
- def DecodeVarint(buffer, pos):
- result = 0
- shift = 0
- while 1:
- b = local_ord(buffer[pos])
- result |= ((b & 0x7f) << shift)
- pos += 1
- if not (b & 0x80):
- result &= mask
- return (result, pos)
- shift += 7
- if shift >= 64:
- raise _DecodeError('Too many bytes when decoding varint.')
- return DecodeVarint
-
-
-def _SignedVarintDecoder(mask):
- """Like _VarintDecoder() but decodes signed values."""
-
- local_ord = ord
- def DecodeVarint(buffer, pos):
- result = 0
- shift = 0
- while 1:
- b = local_ord(buffer[pos])
- result |= ((b & 0x7f) << shift)
- pos += 1
- if not (b & 0x80):
- if result > 0x7fffffffffffffff:
- result -= (1 << 64)
- result |= ~mask
- else:
- result &= mask
- return (result, pos)
- shift += 7
- if shift >= 64:
- raise _DecodeError('Too many bytes when decoding varint.')
- return DecodeVarint
-
-
-_DecodeVarint = _VarintDecoder((1 << 64) - 1)
-_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
-
-# Use these versions for values which must be limited to 32 bits.
-_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
-_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
-
-
-def ReadTag(buffer, pos):
- """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
-
- We return the raw bytes of the tag rather than decoding them. The raw
- bytes can then be used to look up the proper decoder. This effectively allows
- us to trade some work that would be done in pure-python (decoding a varint)
- for work that is done in C (searching for a byte string in a hash table).
- In a low-level language it would be much cheaper to decode the varint and
- use that, but not in Python.
- """
-
- start = pos
- while ord(buffer[pos]) & 0x80:
- pos += 1
- pos += 1
- return (buffer[start:pos], pos)
-
-
-# --------------------------------------------------------------------
-
-
-def _SimpleDecoder(wire_type, decode_value):
- """Return a constructor for a decoder for fields of a particular type.
-
- Args:
- wire_type: The field's wire type.
- decode_value: A function which decodes an individual value, e.g.
- _DecodeVarint()
- """
-
- def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default):
- if is_packed:
- local_DecodeVarint = _DecodeVarint
- def DecodePackedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- (endpoint, pos) = local_DecodeVarint(buffer, pos)
- endpoint += pos
- if endpoint > end:
- raise _DecodeError('Truncated message.')
- while pos < endpoint:
- (element, pos) = decode_value(buffer, pos)
- value.append(element)
- if pos > endpoint:
- del value[-1] # Discard corrupt value.
- raise _DecodeError('Packed element was truncated.')
- return pos
- return DecodePackedField
- elif is_repeated:
- tag_bytes = encoder.TagBytes(field_number, wire_type)
- tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- while 1:
- (element, new_pos) = decode_value(buffer, pos)
- value.append(element)
- # Predict that the next tag is another copy of the same repeated
- # field.
- pos = new_pos + tag_len
- if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
- # Prediction failed. Return.
- if new_pos > end:
- raise _DecodeError('Truncated message.')
- return new_pos
- return DecodeRepeatedField
- else:
- def DecodeField(buffer, pos, end, message, field_dict):
- (field_dict[key], pos) = decode_value(buffer, pos)
- if pos > end:
- del field_dict[key] # Discard corrupt value.
- raise _DecodeError('Truncated message.')
- return pos
- return DecodeField
-
- return SpecificDecoder
-
-
-def _ModifiedDecoder(wire_type, decode_value, modify_value):
- """Like SimpleDecoder but additionally invokes modify_value on every value
- before storing it. Usually modify_value is ZigZagDecode.
- """
-
- # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
- # not enough to make a significant difference.
-
- def InnerDecode(buffer, pos):
- (result, new_pos) = decode_value(buffer, pos)
- return (modify_value(result), new_pos)
- return _SimpleDecoder(wire_type, InnerDecode)
-
-
-def _StructPackDecoder(wire_type, format):
- """Return a constructor for a decoder for a fixed-width field.
-
- Args:
- wire_type: The field's wire type.
- format: The format string to pass to struct.unpack().
- """
-
- value_size = struct.calcsize(format)
- local_unpack = struct.unpack
-
- # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but
- # not enough to make a significant difference.
-
- # Note that we expect someone up-stack to catch struct.error and convert
- # it to _DecodeError -- this way we don't have to set up exception-
- # handling blocks every time we parse one value.
-
- def InnerDecode(buffer, pos):
- new_pos = pos + value_size
- result = local_unpack(format, buffer[pos:new_pos])[0]
- return (result, new_pos)
- return _SimpleDecoder(wire_type, InnerDecode)
-
-
-# --------------------------------------------------------------------
-
-
-Int32Decoder = EnumDecoder = _SimpleDecoder(
- wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
-
-Int64Decoder = _SimpleDecoder(
- wire_format.WIRETYPE_VARINT, _DecodeSignedVarint)
-
-UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32)
-UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint)
-
-SInt32Decoder = _ModifiedDecoder(
- wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode)
-SInt64Decoder = _ModifiedDecoder(
- wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode)
-
-# Note that Python conveniently guarantees that when using the '<' prefix on
-# formats, they will also have the same size across all platforms (as opposed
-# to without the prefix, where their sizes depend on the C compiler's basic
-# type sizes).
-Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
-Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
-SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
-SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
-FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f')
-DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d')
-
-BoolDecoder = _ModifiedDecoder(
- wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
-
-
-def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
- """Returns a decoder for a string field."""
-
- local_DecodeVarint = _DecodeVarint
- local_unicode = unicode
-
- assert not is_packed
- if is_repeated:
- tag_bytes = encoder.TagBytes(field_number,
- wire_format.WIRETYPE_LENGTH_DELIMITED)
- tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- while 1:
- (size, pos) = local_DecodeVarint(buffer, pos)
- new_pos = pos + size
- if new_pos > end:
- raise _DecodeError('Truncated string.')
- value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
- # Predict that the next tag is another copy of the same repeated field.
- pos = new_pos + tag_len
- if buffer[new_pos:pos] != tag_bytes or new_pos == end:
- # Prediction failed. Return.
- return new_pos
- return DecodeRepeatedField
- else:
- def DecodeField(buffer, pos, end, message, field_dict):
- (size, pos) = local_DecodeVarint(buffer, pos)
- new_pos = pos + size
- if new_pos > end:
- raise _DecodeError('Truncated string.')
- field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
- return new_pos
- return DecodeField
-
-
-def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
- """Returns a decoder for a bytes field."""
-
- local_DecodeVarint = _DecodeVarint
-
- assert not is_packed
- if is_repeated:
- tag_bytes = encoder.TagBytes(field_number,
- wire_format.WIRETYPE_LENGTH_DELIMITED)
- tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- while 1:
- (size, pos) = local_DecodeVarint(buffer, pos)
- new_pos = pos + size
- if new_pos > end:
- raise _DecodeError('Truncated string.')
- value.append(buffer[pos:new_pos])
- # Predict that the next tag is another copy of the same repeated field.
- pos = new_pos + tag_len
- if buffer[new_pos:pos] != tag_bytes or new_pos == end:
- # Prediction failed. Return.
- return new_pos
- return DecodeRepeatedField
- else:
- def DecodeField(buffer, pos, end, message, field_dict):
- (size, pos) = local_DecodeVarint(buffer, pos)
- new_pos = pos + size
- if new_pos > end:
- raise _DecodeError('Truncated string.')
- field_dict[key] = buffer[pos:new_pos]
- return new_pos
- return DecodeField
-
-
-def GroupDecoder(field_number, is_repeated, is_packed, key, new_default):
- """Returns a decoder for a group field."""
-
- end_tag_bytes = encoder.TagBytes(field_number,
- wire_format.WIRETYPE_END_GROUP)
- end_tag_len = len(end_tag_bytes)
-
- assert not is_packed
- if is_repeated:
- tag_bytes = encoder.TagBytes(field_number,
- wire_format.WIRETYPE_START_GROUP)
- tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- while 1:
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- # Read sub-message.
- pos = value.add()._InternalParse(buffer, pos, end)
- # Read end tag.
- new_pos = pos+end_tag_len
- if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
- raise _DecodeError('Missing group end tag.')
- # Predict that the next tag is another copy of the same repeated field.
- pos = new_pos + tag_len
- if buffer[new_pos:pos] != tag_bytes or new_pos == end:
- # Prediction failed. Return.
- return new_pos
- return DecodeRepeatedField
- else:
- def DecodeField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- # Read sub-message.
- pos = value._InternalParse(buffer, pos, end)
- # Read end tag.
- new_pos = pos+end_tag_len
- if buffer[pos:new_pos] != end_tag_bytes or new_pos > end:
- raise _DecodeError('Missing group end tag.')
- return new_pos
- return DecodeField
-
-
-def MessageDecoder(field_number, is_repeated, is_packed, key, new_default):
- """Returns a decoder for a message field."""
-
- local_DecodeVarint = _DecodeVarint
-
- assert not is_packed
- if is_repeated:
- tag_bytes = encoder.TagBytes(field_number,
- wire_format.WIRETYPE_LENGTH_DELIMITED)
- tag_len = len(tag_bytes)
- def DecodeRepeatedField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- while 1:
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- # Read length.
- (size, pos) = local_DecodeVarint(buffer, pos)
- new_pos = pos + size
- if new_pos > end:
- raise _DecodeError('Truncated message.')
- # Read sub-message.
- if value.add()._InternalParse(buffer, pos, new_pos) != new_pos:
- # The only reason _InternalParse would return early is if it
- # encountered an end-group tag.
- raise _DecodeError('Unexpected end-group tag.')
- # Predict that the next tag is another copy of the same repeated field.
- pos = new_pos + tag_len
- if buffer[new_pos:pos] != tag_bytes or new_pos == end:
- # Prediction failed. Return.
- return new_pos
- return DecodeRepeatedField
- else:
- def DecodeField(buffer, pos, end, message, field_dict):
- value = field_dict.get(key)
- if value is None:
- value = field_dict.setdefault(key, new_default(message))
- # Read length.
- (size, pos) = local_DecodeVarint(buffer, pos)
- new_pos = pos + size
- if new_pos > end:
- raise _DecodeError('Truncated message.')
- # Read sub-message.
- if value._InternalParse(buffer, pos, new_pos) != new_pos:
- # The only reason _InternalParse would return early is if it encountered
- # an end-group tag.
- raise _DecodeError('Unexpected end-group tag.')
- return new_pos
- return DecodeField
-
-
-# --------------------------------------------------------------------
-
-MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP)
-
-def MessageSetItemDecoder(extensions_by_number):
- """Returns a decoder for a MessageSet item.
-
- The parameter is the _extensions_by_number map for the message class.
-
- The message set message looks like this:
- message MessageSet {
- repeated group Item = 1 {
- required int32 type_id = 2;
- required string message = 3;
- }
- }
- """
-
- type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT)
- message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)
- item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP)
-
- local_ReadTag = ReadTag
- local_DecodeVarint = _DecodeVarint
- local_SkipField = SkipField
-
- def DecodeItem(buffer, pos, end, message, field_dict):
- type_id = -1
- message_start = -1
- message_end = -1
-
- # Technically, type_id and message can appear in any order, so we need
- # a little loop here.
- while 1:
- (tag_bytes, pos) = local_ReadTag(buffer, pos)
- if tag_bytes == type_id_tag_bytes:
- (type_id, pos) = local_DecodeVarint(buffer, pos)
- elif tag_bytes == message_tag_bytes:
- (size, message_start) = local_DecodeVarint(buffer, pos)
- pos = message_end = message_start + size
- elif tag_bytes == item_end_tag_bytes:
- break
- else:
- pos = SkipField(buffer, pos, end, tag_bytes)
- if pos == -1:
- raise _DecodeError('Missing group end tag.')
-
- if pos > end:
- raise _DecodeError('Truncated message.')
-
- if type_id == -1:
- raise _DecodeError('MessageSet item missing type_id.')
- if message_start == -1:
- raise _DecodeError('MessageSet item missing message.')
-
- extension = extensions_by_number.get(type_id)
- if extension is not None:
- value = field_dict.get(extension)
- if value is None:
- value = field_dict.setdefault(
- extension, extension.message_type._concrete_class())
- if value._InternalParse(buffer, message_start,message_end) != message_end:
- # The only reason _InternalParse would return early is if it encountered
- # an end-group tag.
- raise _DecodeError('Unexpected end-group tag.')
-
- return pos
-
- return DecodeItem
-
-# --------------------------------------------------------------------
-# Optimization is not as heavy here because calls to SkipField() are rare,
-# except for handling end-group tags.
-
-def _SkipVarint(buffer, pos, end):
- """Skip a varint value. Returns the new position."""
-
- while ord(buffer[pos]) & 0x80:
- pos += 1
- pos += 1
- if pos > end:
- raise _DecodeError('Truncated message.')
- return pos
-
-def _SkipFixed64(buffer, pos, end):
- """Skip a fixed64 value. Returns the new position."""
-
- pos += 8
- if pos > end:
- raise _DecodeError('Truncated message.')
- return pos
-
-def _SkipLengthDelimited(buffer, pos, end):
- """Skip a length-delimited value. Returns the new position."""
-
- (size, pos) = _DecodeVarint(buffer, pos)
- pos += size
- if pos > end:
- raise _DecodeError('Truncated message.')
- return pos
-
-def _SkipGroup(buffer, pos, end):
- """Skip sub-group. Returns the new position."""
-
- while 1:
- (tag_bytes, pos) = ReadTag(buffer, pos)
- new_pos = SkipField(buffer, pos, end, tag_bytes)
- if new_pos == -1:
- return pos
- pos = new_pos
-
-def _EndGroup(buffer, pos, end):
- """Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
-
- return -1
-
-def _SkipFixed32(buffer, pos, end):
- """Skip a fixed32 value. Returns the new position."""
-
- pos += 4
- if pos > end:
- raise _DecodeError('Truncated message.')
- return pos
-
-def _RaiseInvalidWireType(buffer, pos, end):
- """Skip function for unknown wire types. Raises an exception."""
-
- raise _DecodeError('Tag had invalid wire type.')
-
-def _FieldSkipper():
- """Constructs the SkipField function."""
-
- WIRETYPE_TO_SKIPPER = [
- _SkipVarint,
- _SkipFixed64,
- _SkipLengthDelimited,
- _SkipGroup,
- _EndGroup,
- _SkipFixed32,
- _RaiseInvalidWireType,
- _RaiseInvalidWireType,
- ]
-
- wiretype_mask = wire_format.TAG_TYPE_MASK
- local_ord = ord
-
- def SkipField(buffer, pos, end, tag_bytes):
- """Skips a field with the specified tag.
-
- |pos| should point to the byte immediately after the tag.
-
- Returns:
- The new position (after the tag value), or -1 if the tag is an end-group
- tag (in which case the calling loop should break).
- """
-
- # The wire type is always in the first byte since varints are little-endian.
- wire_type = local_ord(tag_bytes[0]) & wiretype_mask
- return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
-
- return SkipField
-
-SkipField = _FieldSkipper()
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Code for decoding protocol buffer primitives. + +This code is very similar to encoder.py -- read the docs for that module first. + +A "decoder" is a function with the signature: + Decode(buffer, pos, end, message, field_dict) +The arguments are: + buffer: The string containing the encoded message. + pos: The current position in the string. + end: The position in the string where the current message ends. May be + less than len(buffer) if we're reading a sub-message. + message: The message object into which we're parsing. + field_dict: message._fields (avoids a hashtable lookup). +The decoder reads the field and stores it into field_dict, returning the new +buffer position. A decoder for a repeated field may proactively decode all of +the elements of that field, if they appear consecutively. + +Note that decoders may throw any of the following: + IndexError: Indicates a truncated message. + struct.error: Unpacking of a fixed-width field failed. + message.DecodeError: Other errors. + +Decoders are expected to raise an exception if they are called with pos > end. +This allows callers to be lax about bounds checking: it's fineto read past +"end" as long as you are sure that someone else will notice and throw an +exception later on. + +Something up the call stack is expected to catch IndexError and struct.error +and convert them to message.DecodeError. + +Decoders are constructed using decoder constructors with the signature: + MakeDecoder(field_number, is_repeated, is_packed, key, new_default) +The arguments are: + field_number: The field number of the field we want to decode. + is_repeated: Is the field a repeated field? (bool) + is_packed: Is the field a packed field? (bool) + key: The key to use when looking up the field within field_dict. + (This is actually the FieldDescriptor but nothing in this + file should depend on that.) + new_default: A function which takes a message object as a parameter and + returns a new instance of the default value for this field. + (This is called for repeated fields and sub-messages, when an + instance does not already exist.) + +As with encoders, we define a decoder constructor for every type of field. +Then, for every field of every message class we construct an actual decoder. +That decoder goes into a dict indexed by tag, so when we decode a message +we repeatedly read a tag, look up the corresponding decoder, and invoke it. +""" + +__author__ = '[email protected] (Kenton Varda)' + +import struct +from google.protobuf.internal import encoder +from google.protobuf.internal import wire_format +from google.protobuf import message + + +# This is not for optimization, but rather to avoid conflicts with local +# variables named "message". +_DecodeError = message.DecodeError + + +def _VarintDecoder(mask): + """Return an encoder for a basic varint value (does not include tag). + + Decoded values will be bitwise-anded with the given mask before being + returned, e.g. to limit them to 32 bits. The returned decoder does not + take the usual "end" parameter -- the caller is expected to do bounds checking + after the fact (often the caller can defer such checking until later). The + decoder returns a (value, new_pos) pair. + """ + + local_ord = ord + def DecodeVarint(buffer, pos): + result = 0 + shift = 0 + while 1: + b = local_ord(buffer[pos]) + result |= ((b & 0x7f) << shift) + pos += 1 + if not (b & 0x80): + result &= mask + return (result, pos) + shift += 7 + if shift >= 64: + raise _DecodeError('Too many bytes when decoding varint.') + return DecodeVarint + + +def _SignedVarintDecoder(mask): + """Like _VarintDecoder() but decodes signed values.""" + + local_ord = ord + def DecodeVarint(buffer, pos): + result = 0 + shift = 0 + while 1: + b = local_ord(buffer[pos]) + result |= ((b & 0x7f) << shift) + pos += 1 + if not (b & 0x80): + if result > 0x7fffffffffffffff: + result -= (1 << 64) + result |= ~mask + else: + result &= mask + return (result, pos) + shift += 7 + if shift >= 64: + raise _DecodeError('Too many bytes when decoding varint.') + return DecodeVarint + + +_DecodeVarint = _VarintDecoder((1 << 64) - 1) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) + +# Use these versions for values which must be limited to 32 bits. +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) + + +def ReadTag(buffer, pos): + """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple. + + We return the raw bytes of the tag rather than decoding them. The raw + bytes can then be used to look up the proper decoder. This effectively allows + us to trade some work that would be done in pure-python (decoding a varint) + for work that is done in C (searching for a byte string in a hash table). + In a low-level language it would be much cheaper to decode the varint and + use that, but not in Python. + """ + + start = pos + while ord(buffer[pos]) & 0x80: + pos += 1 + pos += 1 + return (buffer[start:pos], pos) + + +# -------------------------------------------------------------------- + + +def _SimpleDecoder(wire_type, decode_value): + """Return a constructor for a decoder for fields of a particular type. + + Args: + wire_type: The field's wire type. + decode_value: A function which decodes an individual value, e.g. + _DecodeVarint() + """ + + def SpecificDecoder(field_number, is_repeated, is_packed, key, new_default): + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + (element, pos) = decode_value(buffer, pos) + value.append(element) + if pos > endpoint: + del value[-1] # Discard corrupt value. + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_type) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = decode_value(buffer, pos) + value.append(element) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (field_dict[key], pos) = decode_value(buffer, pos) + if pos > end: + del field_dict[key] # Discard corrupt value. + raise _DecodeError('Truncated message.') + return pos + return DecodeField + + return SpecificDecoder + + +def _ModifiedDecoder(wire_type, decode_value, modify_value): + """Like SimpleDecoder but additionally invokes modify_value on every value + before storing it. Usually modify_value is ZigZagDecode. + """ + + # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but + # not enough to make a significant difference. + + def InnerDecode(buffer, pos): + (result, new_pos) = decode_value(buffer, pos) + return (modify_value(result), new_pos) + return _SimpleDecoder(wire_type, InnerDecode) + + +def _StructPackDecoder(wire_type, format): + """Return a constructor for a decoder for a fixed-width field. + + Args: + wire_type: The field's wire type. + format: The format string to pass to struct.unpack(). + """ + + value_size = struct.calcsize(format) + local_unpack = struct.unpack + + # Reusing _SimpleDecoder is slightly slower than copying a bunch of code, but + # not enough to make a significant difference. + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + + def InnerDecode(buffer, pos): + new_pos = pos + value_size + result = local_unpack(format, buffer[pos:new_pos])[0] + return (result, new_pos) + return _SimpleDecoder(wire_type, InnerDecode) + + +# -------------------------------------------------------------------- + + +Int32Decoder = EnumDecoder = _SimpleDecoder( + wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) + +Int64Decoder = _SimpleDecoder( + wire_format.WIRETYPE_VARINT, _DecodeSignedVarint) + +UInt32Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint32) +UInt64Decoder = _SimpleDecoder(wire_format.WIRETYPE_VARINT, _DecodeVarint) + +SInt32Decoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint32, wire_format.ZigZagDecode) +SInt64Decoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint, wire_format.ZigZagDecode) + +# Note that Python conveniently guarantees that when using the '<' prefix on +# formats, they will also have the same size across all platforms (as opposed +# to without the prefix, where their sizes depend on the C compiler's basic +# type sizes). +Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') +Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') +SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') +SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') +FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f') +DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d') + +BoolDecoder = _ModifiedDecoder( + wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) + + +def StringDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a string field.""" + + local_DecodeVarint = _DecodeVarint + local_unicode = unicode + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') + return new_pos + return DecodeField + + +def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a bytes field.""" + + local_DecodeVarint = _DecodeVarint + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + value.append(buffer[pos:new_pos]) + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated string.') + field_dict[key] = buffer[pos:new_pos] + return new_pos + return DecodeField + + +def GroupDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a group field.""" + + end_tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_END_GROUP) + end_tag_len = len(end_tag_bytes) + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_START_GROUP) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read sub-message. + pos = value.add()._InternalParse(buffer, pos, end) + # Read end tag. + new_pos = pos+end_tag_len + if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: + raise _DecodeError('Missing group end tag.') + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read sub-message. + pos = value._InternalParse(buffer, pos, end) + # Read end tag. + new_pos = pos+end_tag_len + if buffer[pos:new_pos] != end_tag_bytes or new_pos > end: + raise _DecodeError('Missing group end tag.') + return new_pos + return DecodeField + + +def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): + """Returns a decoder for a message field.""" + + local_DecodeVarint = _DecodeVarint + + assert not is_packed + if is_repeated: + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + if value.add()._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + if value._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it encountered + # an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + return new_pos + return DecodeField + + +# -------------------------------------------------------------------- + +MESSAGE_SET_ITEM_TAG = encoder.TagBytes(1, wire_format.WIRETYPE_START_GROUP) + +def MessageSetItemDecoder(extensions_by_number): + """Returns a decoder for a MessageSet item. + + The parameter is the _extensions_by_number map for the message class. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + + type_id_tag_bytes = encoder.TagBytes(2, wire_format.WIRETYPE_VARINT) + message_tag_bytes = encoder.TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED) + item_end_tag_bytes = encoder.TagBytes(1, wire_format.WIRETYPE_END_GROUP) + + local_ReadTag = ReadTag + local_DecodeVarint = _DecodeVarint + local_SkipField = SkipField + + def DecodeItem(buffer, pos, end, message, field_dict): + type_id = -1 + message_start = -1 + message_end = -1 + + # Technically, type_id and message can appear in any order, so we need + # a little loop here. + while 1: + (tag_bytes, pos) = local_ReadTag(buffer, pos) + if tag_bytes == type_id_tag_bytes: + (type_id, pos) = local_DecodeVarint(buffer, pos) + elif tag_bytes == message_tag_bytes: + (size, message_start) = local_DecodeVarint(buffer, pos) + pos = message_end = message_start + size + elif tag_bytes == item_end_tag_bytes: + break + else: + pos = SkipField(buffer, pos, end, tag_bytes) + if pos == -1: + raise _DecodeError('Missing group end tag.') + + if pos > end: + raise _DecodeError('Truncated message.') + + if type_id == -1: + raise _DecodeError('MessageSet item missing type_id.') + if message_start == -1: + raise _DecodeError('MessageSet item missing message.') + + extension = extensions_by_number.get(type_id) + if extension is not None: + value = field_dict.get(extension) + if value is None: + value = field_dict.setdefault( + extension, extension.message_type._concrete_class()) + if value._InternalParse(buffer, message_start,message_end) != message_end: + # The only reason _InternalParse would return early is if it encountered + # an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + + return pos + + return DecodeItem + +# -------------------------------------------------------------------- +# Optimization is not as heavy here because calls to SkipField() are rare, +# except for handling end-group tags. + +def _SkipVarint(buffer, pos, end): + """Skip a varint value. Returns the new position.""" + + while ord(buffer[pos]) & 0x80: + pos += 1 + pos += 1 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _SkipFixed64(buffer, pos, end): + """Skip a fixed64 value. Returns the new position.""" + + pos += 8 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _SkipLengthDelimited(buffer, pos, end): + """Skip a length-delimited value. Returns the new position.""" + + (size, pos) = _DecodeVarint(buffer, pos) + pos += size + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _SkipGroup(buffer, pos, end): + """Skip sub-group. Returns the new position.""" + + while 1: + (tag_bytes, pos) = ReadTag(buffer, pos) + new_pos = SkipField(buffer, pos, end, tag_bytes) + if new_pos == -1: + return pos + pos = new_pos + +def _EndGroup(buffer, pos, end): + """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" + + return -1 + +def _SkipFixed32(buffer, pos, end): + """Skip a fixed32 value. Returns the new position.""" + + pos += 4 + if pos > end: + raise _DecodeError('Truncated message.') + return pos + +def _RaiseInvalidWireType(buffer, pos, end): + """Skip function for unknown wire types. Raises an exception.""" + + raise _DecodeError('Tag had invalid wire type.') + +def _FieldSkipper(): + """Constructs the SkipField function.""" + + WIRETYPE_TO_SKIPPER = [ + _SkipVarint, + _SkipFixed64, + _SkipLengthDelimited, + _SkipGroup, + _EndGroup, + _SkipFixed32, + _RaiseInvalidWireType, + _RaiseInvalidWireType, + ] + + wiretype_mask = wire_format.TAG_TYPE_MASK + local_ord = ord + + def SkipField(buffer, pos, end, tag_bytes): + """Skips a field with the specified tag. + + |pos| should point to the byte immediately after the tag. + + Returns: + The new position (after the tag value), or -1 if the tag is an end-group + tag (in which case the calling loop should break). + """ + + # The wire type is always in the first byte since varints are little-endian. + wire_type = local_ord(tag_bytes[0]) & wiretype_mask + return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) + + return SkipField + +SkipField = _FieldSkipper() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/descriptor_test.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/descriptor_test.py index 92447a7a..05c27452 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/descriptor_test.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/descriptor_test.py @@ -1,334 +1,334 @@ -#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Unittest for google.protobuf.internal.descriptor."""
-
-__author__ = '[email protected] (Will Robinson)'
-
-import unittest
-from google.protobuf import unittest_import_pb2
-from google.protobuf import unittest_pb2
-from google.protobuf import descriptor_pb2
-from google.protobuf import descriptor
-from google.protobuf import text_format
-
-
-TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """
-name: 'TestEmptyMessage'
-"""
-
-
-class DescriptorTest(unittest.TestCase):
-
- def setUp(self):
- self.my_file = descriptor.FileDescriptor(
- name='some/filename/some.proto',
- package='protobuf_unittest'
- )
- self.my_enum = descriptor.EnumDescriptor(
- name='ForeignEnum',
- full_name='protobuf_unittest.ForeignEnum',
- filename=None,
- file=self.my_file,
- values=[
- descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5),
- descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6),
- ])
- self.my_message = descriptor.Descriptor(
- name='NestedMessage',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage',
- filename=None,
- file=self.my_file,
- containing_type=None,
- fields=[
- descriptor.FieldDescriptor(
- name='bb',
- full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb',
- index=0, number=1,
- type=5, cpp_type=1, label=1,
- has_default_value=False, default_value=0,
- message_type=None, enum_type=None, containing_type=None,
- is_extension=False, extension_scope=None),
- ],
- nested_types=[],
- enum_types=[
- self.my_enum,
- ],
- extensions=[])
- self.my_method = descriptor.MethodDescriptor(
- name='Bar',
- full_name='protobuf_unittest.TestService.Bar',
- index=0,
- containing_service=None,
- input_type=None,
- output_type=None)
- self.my_service = descriptor.ServiceDescriptor(
- name='TestServiceWithOptions',
- full_name='protobuf_unittest.TestServiceWithOptions',
- file=self.my_file,
- index=0,
- methods=[
- self.my_method
- ])
-
- def testEnumFixups(self):
- self.assertEqual(self.my_enum, self.my_enum.values[0].type)
-
- def testContainingTypeFixups(self):
- self.assertEqual(self.my_message, self.my_message.fields[0].containing_type)
- self.assertEqual(self.my_message, self.my_enum.containing_type)
-
- def testContainingServiceFixups(self):
- self.assertEqual(self.my_service, self.my_method.containing_service)
-
- def testGetOptions(self):
- self.assertEqual(self.my_enum.GetOptions(),
- descriptor_pb2.EnumOptions())
- self.assertEqual(self.my_enum.values[0].GetOptions(),
- descriptor_pb2.EnumValueOptions())
- self.assertEqual(self.my_message.GetOptions(),
- descriptor_pb2.MessageOptions())
- self.assertEqual(self.my_message.fields[0].GetOptions(),
- descriptor_pb2.FieldOptions())
- self.assertEqual(self.my_method.GetOptions(),
- descriptor_pb2.MethodOptions())
- self.assertEqual(self.my_service.GetOptions(),
- descriptor_pb2.ServiceOptions())
-
- def testFileDescriptorReferences(self):
- self.assertEqual(self.my_enum.file, self.my_file)
- self.assertEqual(self.my_message.file, self.my_file)
-
- def testFileDescriptor(self):
- self.assertEqual(self.my_file.name, 'some/filename/some.proto')
- self.assertEqual(self.my_file.package, 'protobuf_unittest')
-
-
-class DescriptorCopyToProtoTest(unittest.TestCase):
- """Tests for CopyTo functions of Descriptor."""
-
- def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii):
- expected_proto = expected_class()
- text_format.Merge(expected_ascii, expected_proto)
-
- self.assertEqual(
- actual_proto, expected_proto,
- 'Not equal,\nActual:\n%s\nExpected:\n%s\n'
- % (str(actual_proto), str(expected_proto)))
-
- def _InternalTestCopyToProto(self, desc, expected_proto_class,
- expected_proto_ascii):
- actual = expected_proto_class()
- desc.CopyToProto(actual)
- self._AssertProtoEqual(
- actual, expected_proto_class, expected_proto_ascii)
-
- def testCopyToProto_EmptyMessage(self):
- self._InternalTestCopyToProto(
- unittest_pb2.TestEmptyMessage.DESCRIPTOR,
- descriptor_pb2.DescriptorProto,
- TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII)
-
- def testCopyToProto_NestedMessage(self):
- TEST_NESTED_MESSAGE_ASCII = """
- name: 'NestedMessage'
- field: <
- name: 'bb'
- number: 1
- label: 1 # Optional
- type: 5 # TYPE_INT32
- >
- """
-
- self._InternalTestCopyToProto(
- unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
- descriptor_pb2.DescriptorProto,
- TEST_NESTED_MESSAGE_ASCII)
-
- def testCopyToProto_ForeignNestedMessage(self):
- TEST_FOREIGN_NESTED_ASCII = """
- name: 'TestForeignNested'
- field: <
- name: 'foreign_nested'
- number: 1
- label: 1 # Optional
- type: 11 # TYPE_MESSAGE
- type_name: '.protobuf_unittest.TestAllTypes.NestedMessage'
- >
- """
-
- self._InternalTestCopyToProto(
- unittest_pb2.TestForeignNested.DESCRIPTOR,
- descriptor_pb2.DescriptorProto,
- TEST_FOREIGN_NESTED_ASCII)
-
- def testCopyToProto_ForeignEnum(self):
- TEST_FOREIGN_ENUM_ASCII = """
- name: 'ForeignEnum'
- value: <
- name: 'FOREIGN_FOO'
- number: 4
- >
- value: <
- name: 'FOREIGN_BAR'
- number: 5
- >
- value: <
- name: 'FOREIGN_BAZ'
- number: 6
- >
- """
-
- self._InternalTestCopyToProto(
- unittest_pb2._FOREIGNENUM,
- descriptor_pb2.EnumDescriptorProto,
- TEST_FOREIGN_ENUM_ASCII)
-
- def testCopyToProto_Options(self):
- TEST_DEPRECATED_FIELDS_ASCII = """
- name: 'TestDeprecatedFields'
- field: <
- name: 'deprecated_int32'
- number: 1
- label: 1 # Optional
- type: 5 # TYPE_INT32
- options: <
- deprecated: true
- >
- >
- """
-
- self._InternalTestCopyToProto(
- unittest_pb2.TestDeprecatedFields.DESCRIPTOR,
- descriptor_pb2.DescriptorProto,
- TEST_DEPRECATED_FIELDS_ASCII)
-
- def testCopyToProto_AllExtensions(self):
- TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII = """
- name: 'TestEmptyMessageWithExtensions'
- extension_range: <
- start: 1
- end: 536870912
- >
- """
-
- self._InternalTestCopyToProto(
- unittest_pb2.TestEmptyMessageWithExtensions.DESCRIPTOR,
- descriptor_pb2.DescriptorProto,
- TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII)
-
- def testCopyToProto_SeveralExtensions(self):
- TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII = """
- name: 'TestMultipleExtensionRanges'
- extension_range: <
- start: 42
- end: 43
- >
- extension_range: <
- start: 4143
- end: 4244
- >
- extension_range: <
- start: 65536
- end: 536870912
- >
- """
-
- self._InternalTestCopyToProto(
- unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR,
- descriptor_pb2.DescriptorProto,
- TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII)
-
- def testCopyToProto_FileDescriptor(self):
- UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = ("""
- name: 'google/protobuf/unittest_import.proto'
- package: 'protobuf_unittest_import'
- message_type: <
- name: 'ImportMessage'
- field: <
- name: 'd'
- number: 1
- label: 1 # Optional
- type: 5 # TYPE_INT32
- >
- >
- """ +
- """enum_type: <
- name: 'ImportEnum'
- value: <
- name: 'IMPORT_FOO'
- number: 7
- >
- value: <
- name: 'IMPORT_BAR'
- number: 8
- >
- value: <
- name: 'IMPORT_BAZ'
- number: 9
- >
- >
- options: <
- java_package: 'com.google.protobuf.test'
- optimize_for: 1 # SPEED
- >
- """)
-
- self._InternalTestCopyToProto(
- unittest_import_pb2.DESCRIPTOR,
- descriptor_pb2.FileDescriptorProto,
- UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII)
-
- def testCopyToProto_ServiceDescriptor(self):
- TEST_SERVICE_ASCII = """
- name: 'TestService'
- method: <
- name: 'Foo'
- input_type: '.protobuf_unittest.FooRequest'
- output_type: '.protobuf_unittest.FooResponse'
- >
- method: <
- name: 'Bar'
- input_type: '.protobuf_unittest.BarRequest'
- output_type: '.protobuf_unittest.BarResponse'
- >
- """
-
- self._InternalTestCopyToProto(
- unittest_pb2.TestService.DESCRIPTOR,
- descriptor_pb2.ServiceDescriptorProto,
- TEST_SERVICE_ASCII)
-
-
-if __name__ == '__main__':
- unittest.main()
+#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for google.protobuf.internal.descriptor.""" + +__author__ = '[email protected] (Will Robinson)' + +import unittest +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor +from google.protobuf import text_format + + +TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """ +name: 'TestEmptyMessage' +""" + + +class DescriptorTest(unittest.TestCase): + + def setUp(self): + self.my_file = descriptor.FileDescriptor( + name='some/filename/some.proto', + package='protobuf_unittest' + ) + self.my_enum = descriptor.EnumDescriptor( + name='ForeignEnum', + full_name='protobuf_unittest.ForeignEnum', + filename=None, + file=self.my_file, + values=[ + descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4), + descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5), + descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6), + ]) + self.my_message = descriptor.Descriptor( + name='NestedMessage', + full_name='protobuf_unittest.TestAllTypes.NestedMessage', + filename=None, + file=self.my_file, + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='bb', + full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb', + index=0, number=1, + type=5, cpp_type=1, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None), + ], + nested_types=[], + enum_types=[ + self.my_enum, + ], + extensions=[]) + self.my_method = descriptor.MethodDescriptor( + name='Bar', + full_name='protobuf_unittest.TestService.Bar', + index=0, + containing_service=None, + input_type=None, + output_type=None) + self.my_service = descriptor.ServiceDescriptor( + name='TestServiceWithOptions', + full_name='protobuf_unittest.TestServiceWithOptions', + file=self.my_file, + index=0, + methods=[ + self.my_method + ]) + + def testEnumFixups(self): + self.assertEqual(self.my_enum, self.my_enum.values[0].type) + + def testContainingTypeFixups(self): + self.assertEqual(self.my_message, self.my_message.fields[0].containing_type) + self.assertEqual(self.my_message, self.my_enum.containing_type) + + def testContainingServiceFixups(self): + self.assertEqual(self.my_service, self.my_method.containing_service) + + def testGetOptions(self): + self.assertEqual(self.my_enum.GetOptions(), + descriptor_pb2.EnumOptions()) + self.assertEqual(self.my_enum.values[0].GetOptions(), + descriptor_pb2.EnumValueOptions()) + self.assertEqual(self.my_message.GetOptions(), + descriptor_pb2.MessageOptions()) + self.assertEqual(self.my_message.fields[0].GetOptions(), + descriptor_pb2.FieldOptions()) + self.assertEqual(self.my_method.GetOptions(), + descriptor_pb2.MethodOptions()) + self.assertEqual(self.my_service.GetOptions(), + descriptor_pb2.ServiceOptions()) + + def testFileDescriptorReferences(self): + self.assertEqual(self.my_enum.file, self.my_file) + self.assertEqual(self.my_message.file, self.my_file) + + def testFileDescriptor(self): + self.assertEqual(self.my_file.name, 'some/filename/some.proto') + self.assertEqual(self.my_file.package, 'protobuf_unittest') + + +class DescriptorCopyToProtoTest(unittest.TestCase): + """Tests for CopyTo functions of Descriptor.""" + + def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii): + expected_proto = expected_class() + text_format.Merge(expected_ascii, expected_proto) + + self.assertEqual( + actual_proto, expected_proto, + 'Not equal,\nActual:\n%s\nExpected:\n%s\n' + % (str(actual_proto), str(expected_proto))) + + def _InternalTestCopyToProto(self, desc, expected_proto_class, + expected_proto_ascii): + actual = expected_proto_class() + desc.CopyToProto(actual) + self._AssertProtoEqual( + actual, expected_proto_class, expected_proto_ascii) + + def testCopyToProto_EmptyMessage(self): + self._InternalTestCopyToProto( + unittest_pb2.TestEmptyMessage.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII) + + def testCopyToProto_NestedMessage(self): + TEST_NESTED_MESSAGE_ASCII = """ + name: 'NestedMessage' + field: < + name: 'bb' + number: 1 + label: 1 # Optional + type: 5 # TYPE_INT32 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_NESTED_MESSAGE_ASCII) + + def testCopyToProto_ForeignNestedMessage(self): + TEST_FOREIGN_NESTED_ASCII = """ + name: 'TestForeignNested' + field: < + name: 'foreign_nested' + number: 1 + label: 1 # Optional + type: 11 # TYPE_MESSAGE + type_name: '.protobuf_unittest.TestAllTypes.NestedMessage' + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestForeignNested.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_FOREIGN_NESTED_ASCII) + + def testCopyToProto_ForeignEnum(self): + TEST_FOREIGN_ENUM_ASCII = """ + name: 'ForeignEnum' + value: < + name: 'FOREIGN_FOO' + number: 4 + > + value: < + name: 'FOREIGN_BAR' + number: 5 + > + value: < + name: 'FOREIGN_BAZ' + number: 6 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2._FOREIGNENUM, + descriptor_pb2.EnumDescriptorProto, + TEST_FOREIGN_ENUM_ASCII) + + def testCopyToProto_Options(self): + TEST_DEPRECATED_FIELDS_ASCII = """ + name: 'TestDeprecatedFields' + field: < + name: 'deprecated_int32' + number: 1 + label: 1 # Optional + type: 5 # TYPE_INT32 + options: < + deprecated: true + > + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestDeprecatedFields.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_DEPRECATED_FIELDS_ASCII) + + def testCopyToProto_AllExtensions(self): + TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII = """ + name: 'TestEmptyMessageWithExtensions' + extension_range: < + start: 1 + end: 536870912 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestEmptyMessageWithExtensions.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_EMPTY_MESSAGE_WITH_EXTENSIONS_ASCII) + + def testCopyToProto_SeveralExtensions(self): + TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII = """ + name: 'TestMultipleExtensionRanges' + extension_range: < + start: 42 + end: 43 + > + extension_range: < + start: 4143 + end: 4244 + > + extension_range: < + start: 65536 + end: 536870912 + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR, + descriptor_pb2.DescriptorProto, + TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII) + + def testCopyToProto_FileDescriptor(self): + UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" + name: 'google/protobuf/unittest_import.proto' + package: 'protobuf_unittest_import' + message_type: < + name: 'ImportMessage' + field: < + name: 'd' + number: 1 + label: 1 # Optional + type: 5 # TYPE_INT32 + > + > + """ + + """enum_type: < + name: 'ImportEnum' + value: < + name: 'IMPORT_FOO' + number: 7 + > + value: < + name: 'IMPORT_BAR' + number: 8 + > + value: < + name: 'IMPORT_BAZ' + number: 9 + > + > + options: < + java_package: 'com.google.protobuf.test' + optimize_for: 1 # SPEED + > + """) + + self._InternalTestCopyToProto( + unittest_import_pb2.DESCRIPTOR, + descriptor_pb2.FileDescriptorProto, + UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) + + def testCopyToProto_ServiceDescriptor(self): + TEST_SERVICE_ASCII = """ + name: 'TestService' + method: < + name: 'Foo' + input_type: '.protobuf_unittest.FooRequest' + output_type: '.protobuf_unittest.FooResponse' + > + method: < + name: 'Bar' + input_type: '.protobuf_unittest.BarRequest' + output_type: '.protobuf_unittest.BarResponse' + > + """ + + self._InternalTestCopyToProto( + unittest_pb2.TestService.DESCRIPTOR, + descriptor_pb2.ServiceDescriptorProto, + TEST_SERVICE_ASCII) + + +if __name__ == '__main__': + unittest.main() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/encoder.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/encoder.py index 645d14a6..aa05d5b3 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/encoder.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/encoder.py @@ -1,686 +1,686 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Code for encoding protocol message primitives.
-
-Contains the logic for encoding every logical protocol field type
-into one of the 5 physical wire types.
-
-This code is designed to push the Python interpreter's performance to the
-limits.
-
-The basic idea is that at startup time, for every field (i.e. every
-FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The
-sizer takes a value of this field's type and computes its byte size. The
-encoder takes a writer function and a value. It encodes the value into byte
-strings and invokes the writer function to write those strings. Typically the
-writer function is the write() method of a cStringIO.
-
-We try to do as much work as possible when constructing the writer and the
-sizer rather than when calling them. In particular:
-* We copy any needed global functions to local variables, so that we do not need
- to do costly global table lookups at runtime.
-* Similarly, we try to do any attribute lookups at startup time if possible.
-* Every field's tag is encoded to bytes at startup, since it can't change at
- runtime.
-* Whatever component of the field size we can compute at startup, we do.
-* We *avoid* sharing code if doing so would make the code slower and not sharing
- does not burden us too much. For example, encoders for repeated fields do
- not just call the encoders for singular fields in a loop because this would
- add an extra function call overhead for every loop iteration; instead, we
- manually inline the single-value encoder into the loop.
-* If a Python function lacks a return statement, Python actually generates
- instructions to pop the result of the last statement off the stack, push
- None onto the stack, and then return that. If we really don't care what
- value is returned, then we can save two instructions by returning the
- result of the last statement. It looks funny but it helps.
-* We assume that type and bounds checking has happened at a higher level.
-"""
-
-__author__ = '[email protected] (Kenton Varda)'
-
-import struct
-from google.protobuf.internal import wire_format
-
-
-def _VarintSize(value):
- """Compute the size of a varint value."""
- if value <= 0x7f: return 1
- if value <= 0x3fff: return 2
- if value <= 0x1fffff: return 3
- if value <= 0xfffffff: return 4
- if value <= 0x7ffffffff: return 5
- if value <= 0x3ffffffffff: return 6
- if value <= 0x1ffffffffffff: return 7
- if value <= 0xffffffffffffff: return 8
- if value <= 0x7fffffffffffffff: return 9
- return 10
-
-
-def _SignedVarintSize(value):
- """Compute the size of a signed varint value."""
- if value < 0: return 10
- if value <= 0x7f: return 1
- if value <= 0x3fff: return 2
- if value <= 0x1fffff: return 3
- if value <= 0xfffffff: return 4
- if value <= 0x7ffffffff: return 5
- if value <= 0x3ffffffffff: return 6
- if value <= 0x1ffffffffffff: return 7
- if value <= 0xffffffffffffff: return 8
- if value <= 0x7fffffffffffffff: return 9
- return 10
-
-
-def _TagSize(field_number):
- """Returns the number of bytes required to serialize a tag with this field
- number."""
- # Just pass in type 0, since the type won't affect the tag+type size.
- return _VarintSize(wire_format.PackTag(field_number, 0))
-
-
-# --------------------------------------------------------------------
-# In this section we define some generic sizers. Each of these functions
-# takes parameters specific to a particular field type, e.g. int32 or fixed64.
-# It returns another function which in turn takes parameters specific to a
-# particular field, e.g. the field number and whether it is repeated or packed.
-# Look at the next section to see how these are used.
-
-
-def _SimpleSizer(compute_value_size):
- """A sizer which uses the function compute_value_size to compute the size of
- each value. Typically compute_value_size is _VarintSize."""
-
- def SpecificSizer(field_number, is_repeated, is_packed):
- tag_size = _TagSize(field_number)
- if is_packed:
- local_VarintSize = _VarintSize
- def PackedFieldSize(value):
- result = 0
- for element in value:
- result += compute_value_size(element)
- return result + local_VarintSize(result) + tag_size
- return PackedFieldSize
- elif is_repeated:
- def RepeatedFieldSize(value):
- result = tag_size * len(value)
- for element in value:
- result += compute_value_size(element)
- return result
- return RepeatedFieldSize
- else:
- def FieldSize(value):
- return tag_size + compute_value_size(value)
- return FieldSize
-
- return SpecificSizer
-
-
-def _ModifiedSizer(compute_value_size, modify_value):
- """Like SimpleSizer, but modify_value is invoked on each value before it is
- passed to compute_value_size. modify_value is typically ZigZagEncode."""
-
- def SpecificSizer(field_number, is_repeated, is_packed):
- tag_size = _TagSize(field_number)
- if is_packed:
- local_VarintSize = _VarintSize
- def PackedFieldSize(value):
- result = 0
- for element in value:
- result += compute_value_size(modify_value(element))
- return result + local_VarintSize(result) + tag_size
- return PackedFieldSize
- elif is_repeated:
- def RepeatedFieldSize(value):
- result = tag_size * len(value)
- for element in value:
- result += compute_value_size(modify_value(element))
- return result
- return RepeatedFieldSize
- else:
- def FieldSize(value):
- return tag_size + compute_value_size(modify_value(value))
- return FieldSize
-
- return SpecificSizer
-
-
-def _FixedSizer(value_size):
- """Like _SimpleSizer except for a fixed-size field. The input is the size
- of one value."""
-
- def SpecificSizer(field_number, is_repeated, is_packed):
- tag_size = _TagSize(field_number)
- if is_packed:
- local_VarintSize = _VarintSize
- def PackedFieldSize(value):
- result = len(value) * value_size
- return result + local_VarintSize(result) + tag_size
- return PackedFieldSize
- elif is_repeated:
- element_size = value_size + tag_size
- def RepeatedFieldSize(value):
- return len(value) * element_size
- return RepeatedFieldSize
- else:
- field_size = value_size + tag_size
- def FieldSize(value):
- return field_size
- return FieldSize
-
- return SpecificSizer
-
-
-# ====================================================================
-# Here we declare a sizer constructor for each field type. Each "sizer
-# constructor" is a function that takes (field_number, is_repeated, is_packed)
-# as parameters and returns a sizer, which in turn takes a field value as
-# a parameter and returns its encoded size.
-
-
-Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize)
-
-UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize)
-
-SInt32Sizer = SInt64Sizer = _ModifiedSizer(
- _SignedVarintSize, wire_format.ZigZagEncode)
-
-Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4)
-Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8)
-
-BoolSizer = _FixedSizer(1)
-
-
-def StringSizer(field_number, is_repeated, is_packed):
- """Returns a sizer for a string field."""
-
- tag_size = _TagSize(field_number)
- local_VarintSize = _VarintSize
- local_len = len
- assert not is_packed
- if is_repeated:
- def RepeatedFieldSize(value):
- result = tag_size * len(value)
- for element in value:
- l = local_len(element.encode('utf-8'))
- result += local_VarintSize(l) + l
- return result
- return RepeatedFieldSize
- else:
- def FieldSize(value):
- l = local_len(value.encode('utf-8'))
- return tag_size + local_VarintSize(l) + l
- return FieldSize
-
-
-def BytesSizer(field_number, is_repeated, is_packed):
- """Returns a sizer for a bytes field."""
-
- tag_size = _TagSize(field_number)
- local_VarintSize = _VarintSize
- local_len = len
- assert not is_packed
- if is_repeated:
- def RepeatedFieldSize(value):
- result = tag_size * len(value)
- for element in value:
- l = local_len(element)
- result += local_VarintSize(l) + l
- return result
- return RepeatedFieldSize
- else:
- def FieldSize(value):
- l = local_len(value)
- return tag_size + local_VarintSize(l) + l
- return FieldSize
-
-
-def GroupSizer(field_number, is_repeated, is_packed):
- """Returns a sizer for a group field."""
-
- tag_size = _TagSize(field_number) * 2
- assert not is_packed
- if is_repeated:
- def RepeatedFieldSize(value):
- result = tag_size * len(value)
- for element in value:
- result += element.ByteSize()
- return result
- return RepeatedFieldSize
- else:
- def FieldSize(value):
- return tag_size + value.ByteSize()
- return FieldSize
-
-
-def MessageSizer(field_number, is_repeated, is_packed):
- """Returns a sizer for a message field."""
-
- tag_size = _TagSize(field_number)
- local_VarintSize = _VarintSize
- assert not is_packed
- if is_repeated:
- def RepeatedFieldSize(value):
- result = tag_size * len(value)
- for element in value:
- l = element.ByteSize()
- result += local_VarintSize(l) + l
- return result
- return RepeatedFieldSize
- else:
- def FieldSize(value):
- l = value.ByteSize()
- return tag_size + local_VarintSize(l) + l
- return FieldSize
-
-
-# --------------------------------------------------------------------
-# MessageSet is special.
-
-
-def MessageSetItemSizer(field_number):
- """Returns a sizer for extensions of MessageSet.
-
- The message set message looks like this:
- message MessageSet {
- repeated group Item = 1 {
- required int32 type_id = 2;
- required string message = 3;
- }
- }
- """
- static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) +
- _TagSize(3))
- local_VarintSize = _VarintSize
-
- def FieldSize(value):
- l = value.ByteSize()
- return static_size + local_VarintSize(l) + l
-
- return FieldSize
-
-
-# ====================================================================
-# Encoders!
-
-
-def _VarintEncoder():
- """Return an encoder for a basic varint value (does not include tag)."""
-
- local_chr = chr
- def EncodeVarint(write, value):
- bits = value & 0x7f
- value >>= 7
- while value:
- write(local_chr(0x80|bits))
- bits = value & 0x7f
- value >>= 7
- return write(local_chr(bits))
-
- return EncodeVarint
-
-
-def _SignedVarintEncoder():
- """Return an encoder for a basic signed varint value (does not include
- tag)."""
-
- local_chr = chr
- def EncodeSignedVarint(write, value):
- if value < 0:
- value += (1 << 64)
- bits = value & 0x7f
- value >>= 7
- while value:
- write(local_chr(0x80|bits))
- bits = value & 0x7f
- value >>= 7
- return write(local_chr(bits))
-
- return EncodeSignedVarint
-
-
-_EncodeVarint = _VarintEncoder()
-_EncodeSignedVarint = _SignedVarintEncoder()
-
-
-def _VarintBytes(value):
- """Encode the given integer as a varint and return the bytes. This is only
- called at startup time so it doesn't need to be fast."""
-
- pieces = []
- _EncodeVarint(pieces.append, value)
- return "".join(pieces)
-
-
-def TagBytes(field_number, wire_type):
- """Encode the given tag and return the bytes. Only called at startup."""
-
- return _VarintBytes(wire_format.PackTag(field_number, wire_type))
-
-# --------------------------------------------------------------------
-# As with sizers (see above), we have a number of common encoder
-# implementations.
-
-
-def _SimpleEncoder(wire_type, encode_value, compute_value_size):
- """Return a constructor for an encoder for fields of a particular type.
-
- Args:
- wire_type: The field's wire type, for encoding tags.
- encode_value: A function which encodes an individual value, e.g.
- _EncodeVarint().
- compute_value_size: A function which computes the size of an individual
- value, e.g. _VarintSize().
- """
-
- def SpecificEncoder(field_number, is_repeated, is_packed):
- if is_packed:
- tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
- write(tag_bytes)
- size = 0
- for element in value:
- size += compute_value_size(element)
- local_EncodeVarint(write, size)
- for element in value:
- encode_value(write, element)
- return EncodePackedField
- elif is_repeated:
- tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
- for element in value:
- write(tag_bytes)
- encode_value(write, element)
- return EncodeRepeatedField
- else:
- tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
- write(tag_bytes)
- return encode_value(write, value)
- return EncodeField
-
- return SpecificEncoder
-
-
-def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value):
- """Like SimpleEncoder but additionally invokes modify_value on every value
- before passing it to encode_value. Usually modify_value is ZigZagEncode."""
-
- def SpecificEncoder(field_number, is_repeated, is_packed):
- if is_packed:
- tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
- write(tag_bytes)
- size = 0
- for element in value:
- size += compute_value_size(modify_value(element))
- local_EncodeVarint(write, size)
- for element in value:
- encode_value(write, modify_value(element))
- return EncodePackedField
- elif is_repeated:
- tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
- for element in value:
- write(tag_bytes)
- encode_value(write, modify_value(element))
- return EncodeRepeatedField
- else:
- tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
- write(tag_bytes)
- return encode_value(write, modify_value(value))
- return EncodeField
-
- return SpecificEncoder
-
-
-def _StructPackEncoder(wire_type, format):
- """Return a constructor for an encoder for a fixed-width field.
-
- Args:
- wire_type: The field's wire type, for encoding tags.
- format: The format string to pass to struct.pack().
- """
-
- value_size = struct.calcsize(format)
-
- def SpecificEncoder(field_number, is_repeated, is_packed):
- local_struct_pack = struct.pack
- if is_packed:
- tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
- write(tag_bytes)
- local_EncodeVarint(write, len(value) * value_size)
- for element in value:
- write(local_struct_pack(format, element))
- return EncodePackedField
- elif is_repeated:
- tag_bytes = TagBytes(field_number, wire_type)
- def EncodeRepeatedField(write, value):
- for element in value:
- write(tag_bytes)
- write(local_struct_pack(format, element))
- return EncodeRepeatedField
- else:
- tag_bytes = TagBytes(field_number, wire_type)
- def EncodeField(write, value):
- write(tag_bytes)
- return write(local_struct_pack(format, value))
- return EncodeField
-
- return SpecificEncoder
-
-
-# ====================================================================
-# Here we declare an encoder constructor for each field type. These work
-# very similarly to sizer constructors, described earlier.
-
-
-Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder(
- wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize)
-
-UInt32Encoder = UInt64Encoder = _SimpleEncoder(
- wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize)
-
-SInt32Encoder = SInt64Encoder = _ModifiedEncoder(
- wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize,
- wire_format.ZigZagEncode)
-
-# Note that Python conveniently guarantees that when using the '<' prefix on
-# formats, they will also have the same size across all platforms (as opposed
-# to without the prefix, where their sizes depend on the C compiler's basic
-# type sizes).
-Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I')
-Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q')
-SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i')
-SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q')
-FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f')
-DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d')
-
-
-def BoolEncoder(field_number, is_repeated, is_packed):
- """Returns an encoder for a boolean field."""
-
- false_byte = chr(0)
- true_byte = chr(1)
- if is_packed:
- tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- local_EncodeVarint = _EncodeVarint
- def EncodePackedField(write, value):
- write(tag_bytes)
- local_EncodeVarint(write, len(value))
- for element in value:
- if element:
- write(true_byte)
- else:
- write(false_byte)
- return EncodePackedField
- elif is_repeated:
- tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
- def EncodeRepeatedField(write, value):
- for element in value:
- write(tag_bytes)
- if element:
- write(true_byte)
- else:
- write(false_byte)
- return EncodeRepeatedField
- else:
- tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT)
- def EncodeField(write, value):
- write(tag_bytes)
- if value:
- return write(true_byte)
- return write(false_byte)
- return EncodeField
-
-
-def StringEncoder(field_number, is_repeated, is_packed):
- """Returns an encoder for a string field."""
-
- tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- local_EncodeVarint = _EncodeVarint
- local_len = len
- assert not is_packed
- if is_repeated:
- def EncodeRepeatedField(write, value):
- for element in value:
- encoded = element.encode('utf-8')
- write(tag)
- local_EncodeVarint(write, local_len(encoded))
- write(encoded)
- return EncodeRepeatedField
- else:
- def EncodeField(write, value):
- encoded = value.encode('utf-8')
- write(tag)
- local_EncodeVarint(write, local_len(encoded))
- return write(encoded)
- return EncodeField
-
-
-def BytesEncoder(field_number, is_repeated, is_packed):
- """Returns an encoder for a bytes field."""
-
- tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- local_EncodeVarint = _EncodeVarint
- local_len = len
- assert not is_packed
- if is_repeated:
- def EncodeRepeatedField(write, value):
- for element in value:
- write(tag)
- local_EncodeVarint(write, local_len(element))
- write(element)
- return EncodeRepeatedField
- else:
- def EncodeField(write, value):
- write(tag)
- local_EncodeVarint(write, local_len(value))
- return write(value)
- return EncodeField
-
-
-def GroupEncoder(field_number, is_repeated, is_packed):
- """Returns an encoder for a group field."""
-
- start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP)
- end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP)
- assert not is_packed
- if is_repeated:
- def EncodeRepeatedField(write, value):
- for element in value:
- write(start_tag)
- element._InternalSerialize(write)
- write(end_tag)
- return EncodeRepeatedField
- else:
- def EncodeField(write, value):
- write(start_tag)
- value._InternalSerialize(write)
- return write(end_tag)
- return EncodeField
-
-
-def MessageEncoder(field_number, is_repeated, is_packed):
- """Returns an encoder for a message field."""
-
- tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
- local_EncodeVarint = _EncodeVarint
- assert not is_packed
- if is_repeated:
- def EncodeRepeatedField(write, value):
- for element in value:
- write(tag)
- local_EncodeVarint(write, element.ByteSize())
- element._InternalSerialize(write)
- return EncodeRepeatedField
- else:
- def EncodeField(write, value):
- write(tag)
- local_EncodeVarint(write, value.ByteSize())
- return value._InternalSerialize(write)
- return EncodeField
-
-
-# --------------------------------------------------------------------
-# As before, MessageSet is special.
-
-
-def MessageSetItemEncoder(field_number):
- """Encoder for extensions of MessageSet.
-
- The message set message looks like this:
- message MessageSet {
- repeated group Item = 1 {
- required int32 type_id = 2;
- required string message = 3;
- }
- }
- """
- start_bytes = "".join([
- TagBytes(1, wire_format.WIRETYPE_START_GROUP),
- TagBytes(2, wire_format.WIRETYPE_VARINT),
- _VarintBytes(field_number),
- TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)])
- end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP)
- local_EncodeVarint = _EncodeVarint
-
- def EncodeField(write, value):
- write(start_bytes)
- local_EncodeVarint(write, value.ByteSize())
- value._InternalSerialize(write)
- return write(end_bytes)
-
- return EncodeField
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Code for encoding protocol message primitives. + +Contains the logic for encoding every logical protocol field type +into one of the 5 physical wire types. + +This code is designed to push the Python interpreter's performance to the +limits. + +The basic idea is that at startup time, for every field (i.e. every +FieldDescriptor) we construct two functions: a "sizer" and an "encoder". The +sizer takes a value of this field's type and computes its byte size. The +encoder takes a writer function and a value. It encodes the value into byte +strings and invokes the writer function to write those strings. Typically the +writer function is the write() method of a cStringIO. + +We try to do as much work as possible when constructing the writer and the +sizer rather than when calling them. In particular: +* We copy any needed global functions to local variables, so that we do not need + to do costly global table lookups at runtime. +* Similarly, we try to do any attribute lookups at startup time if possible. +* Every field's tag is encoded to bytes at startup, since it can't change at + runtime. +* Whatever component of the field size we can compute at startup, we do. +* We *avoid* sharing code if doing so would make the code slower and not sharing + does not burden us too much. For example, encoders for repeated fields do + not just call the encoders for singular fields in a loop because this would + add an extra function call overhead for every loop iteration; instead, we + manually inline the single-value encoder into the loop. +* If a Python function lacks a return statement, Python actually generates + instructions to pop the result of the last statement off the stack, push + None onto the stack, and then return that. If we really don't care what + value is returned, then we can save two instructions by returning the + result of the last statement. It looks funny but it helps. +* We assume that type and bounds checking has happened at a higher level. +""" + +__author__ = '[email protected] (Kenton Varda)' + +import struct +from google.protobuf.internal import wire_format + + +def _VarintSize(value): + """Compute the size of a varint value.""" + if value <= 0x7f: return 1 + if value <= 0x3fff: return 2 + if value <= 0x1fffff: return 3 + if value <= 0xfffffff: return 4 + if value <= 0x7ffffffff: return 5 + if value <= 0x3ffffffffff: return 6 + if value <= 0x1ffffffffffff: return 7 + if value <= 0xffffffffffffff: return 8 + if value <= 0x7fffffffffffffff: return 9 + return 10 + + +def _SignedVarintSize(value): + """Compute the size of a signed varint value.""" + if value < 0: return 10 + if value <= 0x7f: return 1 + if value <= 0x3fff: return 2 + if value <= 0x1fffff: return 3 + if value <= 0xfffffff: return 4 + if value <= 0x7ffffffff: return 5 + if value <= 0x3ffffffffff: return 6 + if value <= 0x1ffffffffffff: return 7 + if value <= 0xffffffffffffff: return 8 + if value <= 0x7fffffffffffffff: return 9 + return 10 + + +def _TagSize(field_number): + """Returns the number of bytes required to serialize a tag with this field + number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarintSize(wire_format.PackTag(field_number, 0)) + + +# -------------------------------------------------------------------- +# In this section we define some generic sizers. Each of these functions +# takes parameters specific to a particular field type, e.g. int32 or fixed64. +# It returns another function which in turn takes parameters specific to a +# particular field, e.g. the field number and whether it is repeated or packed. +# Look at the next section to see how these are used. + + +def _SimpleSizer(compute_value_size): + """A sizer which uses the function compute_value_size to compute the size of + each value. Typically compute_value_size is _VarintSize.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = 0 + for element in value: + result += compute_value_size(element) + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += compute_value_size(element) + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + compute_value_size(value) + return FieldSize + + return SpecificSizer + + +def _ModifiedSizer(compute_value_size, modify_value): + """Like SimpleSizer, but modify_value is invoked on each value before it is + passed to compute_value_size. modify_value is typically ZigZagEncode.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = 0 + for element in value: + result += compute_value_size(modify_value(element)) + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += compute_value_size(modify_value(element)) + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + compute_value_size(modify_value(value)) + return FieldSize + + return SpecificSizer + + +def _FixedSizer(value_size): + """Like _SimpleSizer except for a fixed-size field. The input is the size + of one value.""" + + def SpecificSizer(field_number, is_repeated, is_packed): + tag_size = _TagSize(field_number) + if is_packed: + local_VarintSize = _VarintSize + def PackedFieldSize(value): + result = len(value) * value_size + return result + local_VarintSize(result) + tag_size + return PackedFieldSize + elif is_repeated: + element_size = value_size + tag_size + def RepeatedFieldSize(value): + return len(value) * element_size + return RepeatedFieldSize + else: + field_size = value_size + tag_size + def FieldSize(value): + return field_size + return FieldSize + + return SpecificSizer + + +# ==================================================================== +# Here we declare a sizer constructor for each field type. Each "sizer +# constructor" is a function that takes (field_number, is_repeated, is_packed) +# as parameters and returns a sizer, which in turn takes a field value as +# a parameter and returns its encoded size. + + +Int32Sizer = Int64Sizer = EnumSizer = _SimpleSizer(_SignedVarintSize) + +UInt32Sizer = UInt64Sizer = _SimpleSizer(_VarintSize) + +SInt32Sizer = SInt64Sizer = _ModifiedSizer( + _SignedVarintSize, wire_format.ZigZagEncode) + +Fixed32Sizer = SFixed32Sizer = FloatSizer = _FixedSizer(4) +Fixed64Sizer = SFixed64Sizer = DoubleSizer = _FixedSizer(8) + +BoolSizer = _FixedSizer(1) + + +def StringSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a string field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + local_len = len + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = local_len(element.encode('utf-8')) + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = local_len(value.encode('utf-8')) + return tag_size + local_VarintSize(l) + l + return FieldSize + + +def BytesSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a bytes field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + local_len = len + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = local_len(element) + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = local_len(value) + return tag_size + local_VarintSize(l) + l + return FieldSize + + +def GroupSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a group field.""" + + tag_size = _TagSize(field_number) * 2 + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + result += element.ByteSize() + return result + return RepeatedFieldSize + else: + def FieldSize(value): + return tag_size + value.ByteSize() + return FieldSize + + +def MessageSizer(field_number, is_repeated, is_packed): + """Returns a sizer for a message field.""" + + tag_size = _TagSize(field_number) + local_VarintSize = _VarintSize + assert not is_packed + if is_repeated: + def RepeatedFieldSize(value): + result = tag_size * len(value) + for element in value: + l = element.ByteSize() + result += local_VarintSize(l) + l + return result + return RepeatedFieldSize + else: + def FieldSize(value): + l = value.ByteSize() + return tag_size + local_VarintSize(l) + l + return FieldSize + + +# -------------------------------------------------------------------- +# MessageSet is special. + + +def MessageSetItemSizer(field_number): + """Returns a sizer for extensions of MessageSet. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + static_size = (_TagSize(1) * 2 + _TagSize(2) + _VarintSize(field_number) + + _TagSize(3)) + local_VarintSize = _VarintSize + + def FieldSize(value): + l = value.ByteSize() + return static_size + local_VarintSize(l) + l + + return FieldSize + + +# ==================================================================== +# Encoders! + + +def _VarintEncoder(): + """Return an encoder for a basic varint value (does not include tag).""" + + local_chr = chr + def EncodeVarint(write, value): + bits = value & 0x7f + value >>= 7 + while value: + write(local_chr(0x80|bits)) + bits = value & 0x7f + value >>= 7 + return write(local_chr(bits)) + + return EncodeVarint + + +def _SignedVarintEncoder(): + """Return an encoder for a basic signed varint value (does not include + tag).""" + + local_chr = chr + def EncodeSignedVarint(write, value): + if value < 0: + value += (1 << 64) + bits = value & 0x7f + value >>= 7 + while value: + write(local_chr(0x80|bits)) + bits = value & 0x7f + value >>= 7 + return write(local_chr(bits)) + + return EncodeSignedVarint + + +_EncodeVarint = _VarintEncoder() +_EncodeSignedVarint = _SignedVarintEncoder() + + +def _VarintBytes(value): + """Encode the given integer as a varint and return the bytes. This is only + called at startup time so it doesn't need to be fast.""" + + pieces = [] + _EncodeVarint(pieces.append, value) + return "".join(pieces) + + +def TagBytes(field_number, wire_type): + """Encode the given tag and return the bytes. Only called at startup.""" + + return _VarintBytes(wire_format.PackTag(field_number, wire_type)) + +# -------------------------------------------------------------------- +# As with sizers (see above), we have a number of common encoder +# implementations. + + +def _SimpleEncoder(wire_type, encode_value, compute_value_size): + """Return a constructor for an encoder for fields of a particular type. + + Args: + wire_type: The field's wire type, for encoding tags. + encode_value: A function which encodes an individual value, e.g. + _EncodeVarint(). + compute_value_size: A function which computes the size of an individual + value, e.g. _VarintSize(). + """ + + def SpecificEncoder(field_number, is_repeated, is_packed): + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + size = 0 + for element in value: + size += compute_value_size(element) + local_EncodeVarint(write, size) + for element in value: + encode_value(write, element) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + encode_value(write, element) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + return encode_value(write, value) + return EncodeField + + return SpecificEncoder + + +def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value): + """Like SimpleEncoder but additionally invokes modify_value on every value + before passing it to encode_value. Usually modify_value is ZigZagEncode.""" + + def SpecificEncoder(field_number, is_repeated, is_packed): + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + size = 0 + for element in value: + size += compute_value_size(modify_value(element)) + local_EncodeVarint(write, size) + for element in value: + encode_value(write, modify_value(element)) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + encode_value(write, modify_value(element)) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + return encode_value(write, modify_value(value)) + return EncodeField + + return SpecificEncoder + + +def _StructPackEncoder(wire_type, format): + """Return a constructor for an encoder for a fixed-width field. + + Args: + wire_type: The field's wire type, for encoding tags. + format: The format string to pass to struct.pack(). + """ + + value_size = struct.calcsize(format) + + def SpecificEncoder(field_number, is_repeated, is_packed): + local_struct_pack = struct.pack + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + local_EncodeVarint(write, len(value) * value_size) + for element in value: + write(local_struct_pack(format, element)) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + write(local_struct_pack(format, element)) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + return write(local_struct_pack(format, value)) + return EncodeField + + return SpecificEncoder + + +# ==================================================================== +# Here we declare an encoder constructor for each field type. These work +# very similarly to sizer constructors, described earlier. + + +Int32Encoder = Int64Encoder = EnumEncoder = _SimpleEncoder( + wire_format.WIRETYPE_VARINT, _EncodeSignedVarint, _SignedVarintSize) + +UInt32Encoder = UInt64Encoder = _SimpleEncoder( + wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize) + +SInt32Encoder = SInt64Encoder = _ModifiedEncoder( + wire_format.WIRETYPE_VARINT, _EncodeVarint, _VarintSize, + wire_format.ZigZagEncode) + +# Note that Python conveniently guarantees that when using the '<' prefix on +# formats, they will also have the same size across all platforms (as opposed +# to without the prefix, where their sizes depend on the C compiler's basic +# type sizes). +Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I') +Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q') +SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i') +SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q') +FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f') +DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d') + + +def BoolEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a boolean field.""" + + false_byte = chr(0) + true_byte = chr(1) + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + local_EncodeVarint(write, len(value)) + for element in value: + if element: + write(true_byte) + else: + write(false_byte) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + if element: + write(true_byte) + else: + write(false_byte) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) + def EncodeField(write, value): + write(tag_bytes) + if value: + return write(true_byte) + return write(false_byte) + return EncodeField + + +def StringEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a string field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + local_len = len + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + encoded = element.encode('utf-8') + write(tag) + local_EncodeVarint(write, local_len(encoded)) + write(encoded) + return EncodeRepeatedField + else: + def EncodeField(write, value): + encoded = value.encode('utf-8') + write(tag) + local_EncodeVarint(write, local_len(encoded)) + return write(encoded) + return EncodeField + + +def BytesEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a bytes field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + local_len = len + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + write(tag) + local_EncodeVarint(write, local_len(element)) + write(element) + return EncodeRepeatedField + else: + def EncodeField(write, value): + write(tag) + local_EncodeVarint(write, local_len(value)) + return write(value) + return EncodeField + + +def GroupEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a group field.""" + + start_tag = TagBytes(field_number, wire_format.WIRETYPE_START_GROUP) + end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP) + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + write(start_tag) + element._InternalSerialize(write) + write(end_tag) + return EncodeRepeatedField + else: + def EncodeField(write, value): + write(start_tag) + value._InternalSerialize(write) + return write(end_tag) + return EncodeField + + +def MessageEncoder(field_number, is_repeated, is_packed): + """Returns an encoder for a message field.""" + + tag = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + assert not is_packed + if is_repeated: + def EncodeRepeatedField(write, value): + for element in value: + write(tag) + local_EncodeVarint(write, element.ByteSize()) + element._InternalSerialize(write) + return EncodeRepeatedField + else: + def EncodeField(write, value): + write(tag) + local_EncodeVarint(write, value.ByteSize()) + return value._InternalSerialize(write) + return EncodeField + + +# -------------------------------------------------------------------- +# As before, MessageSet is special. + + +def MessageSetItemEncoder(field_number): + """Encoder for extensions of MessageSet. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + start_bytes = "".join([ + TagBytes(1, wire_format.WIRETYPE_START_GROUP), + TagBytes(2, wire_format.WIRETYPE_VARINT), + _VarintBytes(field_number), + TagBytes(3, wire_format.WIRETYPE_LENGTH_DELIMITED)]) + end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP) + local_EncodeVarint = _EncodeVarint + + def EncodeField(write, value): + write(start_bytes) + local_EncodeVarint(write, value.ByteSize()) + value._InternalSerialize(write) + return write(end_bytes) + + return EncodeField diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/generator_test.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/generator_test.py index 63a98a54..78360b53 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/generator_test.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/generator_test.py @@ -1,220 +1,220 @@ -#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-# TODO(robinson): Flesh this out considerably. We focused on reflection_test.py
-# first, since it's testing the subtler code, and since it provides decent
-# indirect testing of the protocol compiler output.
-
-"""Unittest that directly tests the output of the pure-Python protocol
-compiler. See //google/protobuf/reflection_test.py for a test which
-further ensures that we can use Python protocol message objects as we expect.
-"""
-
-__author__ = '[email protected] (Will Robinson)'
-
-import unittest
-from google.protobuf import unittest_import_pb2
-from google.protobuf import unittest_mset_pb2
-from google.protobuf import unittest_pb2
-from google.protobuf import unittest_no_generic_services_pb2
-
-
-MAX_EXTENSION = 536870912
-
-
-class GeneratorTest(unittest.TestCase):
-
- def testNestedMessageDescriptor(self):
- field_name = 'optional_nested_message'
- proto_type = unittest_pb2.TestAllTypes
- self.assertEqual(
- proto_type.NestedMessage.DESCRIPTOR,
- proto_type.DESCRIPTOR.fields_by_name[field_name].message_type)
-
- def testEnums(self):
- # We test only module-level enums here.
- # TODO(robinson): Examine descriptors directly to check
- # enum descriptor output.
- self.assertEqual(4, unittest_pb2.FOREIGN_FOO)
- self.assertEqual(5, unittest_pb2.FOREIGN_BAR)
- self.assertEqual(6, unittest_pb2.FOREIGN_BAZ)
-
- proto = unittest_pb2.TestAllTypes()
- self.assertEqual(1, proto.FOO)
- self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
- self.assertEqual(2, proto.BAR)
- self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
- self.assertEqual(3, proto.BAZ)
- self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
-
- def testExtremeDefaultValues(self):
- message = unittest_pb2.TestExtremeDefaultValues()
-
- # Python pre-2.6 does not have isinf() or isnan() functions, so we have
- # to provide our own.
- def isnan(val):
- # NaN is never equal to itself.
- return val != val
- def isinf(val):
- # Infinity times zero equals NaN.
- return not isnan(val) and isnan(val * 0)
-
- self.assertTrue(isinf(message.inf_double))
- self.assertTrue(message.inf_double > 0)
- self.assertTrue(isinf(message.neg_inf_double))
- self.assertTrue(message.neg_inf_double < 0)
- self.assertTrue(isnan(message.nan_double))
-
- self.assertTrue(isinf(message.inf_float))
- self.assertTrue(message.inf_float > 0)
- self.assertTrue(isinf(message.neg_inf_float))
- self.assertTrue(message.neg_inf_float < 0)
- self.assertTrue(isnan(message.nan_float))
-
- def testHasDefaultValues(self):
- desc = unittest_pb2.TestAllTypes.DESCRIPTOR
-
- expected_has_default_by_name = {
- 'optional_int32': False,
- 'repeated_int32': False,
- 'optional_nested_message': False,
- 'default_int32': True,
- }
-
- has_default_by_name = dict(
- [(f.name, f.has_default_value)
- for f in desc.fields
- if f.name in expected_has_default_by_name])
- self.assertEqual(expected_has_default_by_name, has_default_by_name)
-
- def testContainingTypeBehaviorForExtensions(self):
- self.assertEqual(unittest_pb2.optional_int32_extension.containing_type,
- unittest_pb2.TestAllExtensions.DESCRIPTOR)
- self.assertEqual(unittest_pb2.TestRequired.single.containing_type,
- unittest_pb2.TestAllExtensions.DESCRIPTOR)
-
- def testExtensionScope(self):
- self.assertEqual(unittest_pb2.optional_int32_extension.extension_scope,
- None)
- self.assertEqual(unittest_pb2.TestRequired.single.extension_scope,
- unittest_pb2.TestRequired.DESCRIPTOR)
-
- def testIsExtension(self):
- self.assertTrue(unittest_pb2.optional_int32_extension.is_extension)
- self.assertTrue(unittest_pb2.TestRequired.single.is_extension)
-
- message_descriptor = unittest_pb2.TestRequired.DESCRIPTOR
- non_extension_descriptor = message_descriptor.fields_by_name['a']
- self.assertTrue(not non_extension_descriptor.is_extension)
-
- def testOptions(self):
- proto = unittest_mset_pb2.TestMessageSet()
- self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format)
-
- def testNestedTypes(self):
- self.assertEquals(
- set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types),
- set([
- unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR,
- unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR,
- unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR,
- ]))
- self.assertEqual(unittest_pb2.TestEmptyMessage.DESCRIPTOR.nested_types, [])
- self.assertEqual(
- unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.nested_types, [])
-
- def testContainingType(self):
- self.assertTrue(
- unittest_pb2.TestEmptyMessage.DESCRIPTOR.containing_type is None)
- self.assertTrue(
- unittest_pb2.TestAllTypes.DESCRIPTOR.containing_type is None)
- self.assertEqual(
- unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
- unittest_pb2.TestAllTypes.DESCRIPTOR)
- self.assertEqual(
- unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type,
- unittest_pb2.TestAllTypes.DESCRIPTOR)
- self.assertEqual(
- unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR.containing_type,
- unittest_pb2.TestAllTypes.DESCRIPTOR)
-
- def testContainingTypeInEnumDescriptor(self):
- self.assertTrue(unittest_pb2._FOREIGNENUM.containing_type is None)
- self.assertEqual(unittest_pb2._TESTALLTYPES_NESTEDENUM.containing_type,
- unittest_pb2.TestAllTypes.DESCRIPTOR)
-
- def testPackage(self):
- self.assertEqual(
- unittest_pb2.TestAllTypes.DESCRIPTOR.file.package,
- 'protobuf_unittest')
- desc = unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR
- self.assertEqual(desc.file.package, 'protobuf_unittest')
- self.assertEqual(
- unittest_import_pb2.ImportMessage.DESCRIPTOR.file.package,
- 'protobuf_unittest_import')
-
- self.assertEqual(
- unittest_pb2._FOREIGNENUM.file.package, 'protobuf_unittest')
- self.assertEqual(
- unittest_pb2._TESTALLTYPES_NESTEDENUM.file.package,
- 'protobuf_unittest')
- self.assertEqual(
- unittest_import_pb2._IMPORTENUM.file.package,
- 'protobuf_unittest_import')
-
- def testExtensionRange(self):
- self.assertEqual(
- unittest_pb2.TestAllTypes.DESCRIPTOR.extension_ranges, [])
- self.assertEqual(
- unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges,
- [(1, MAX_EXTENSION)])
- self.assertEqual(
- unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges,
- [(42, 43), (4143, 4244), (65536, MAX_EXTENSION)])
-
- def testFileDescriptor(self):
- self.assertEqual(unittest_pb2.DESCRIPTOR.name,
- 'google/protobuf/unittest.proto')
- self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest')
- self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None)
-
- def testNoGenericServices(self):
- # unittest_no_generic_services.proto should contain defs for everything
- # except services.
- self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage"))
- self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO"))
- self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension"))
- self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService"))
-
-
-if __name__ == '__main__':
- unittest.main()
+#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): Flesh this out considerably. We focused on reflection_test.py +# first, since it's testing the subtler code, and since it provides decent +# indirect testing of the protocol compiler output. + +"""Unittest that directly tests the output of the pure-Python protocol +compiler. See //google/protobuf/reflection_test.py for a test which +further ensures that we can use Python protocol message objects as we expect. +""" + +__author__ = '[email protected] (Will Robinson)' + +import unittest +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_no_generic_services_pb2 + + +MAX_EXTENSION = 536870912 + + +class GeneratorTest(unittest.TestCase): + + def testNestedMessageDescriptor(self): + field_name = 'optional_nested_message' + proto_type = unittest_pb2.TestAllTypes + self.assertEqual( + proto_type.NestedMessage.DESCRIPTOR, + proto_type.DESCRIPTOR.fields_by_name[field_name].message_type) + + def testEnums(self): + # We test only module-level enums here. + # TODO(robinson): Examine descriptors directly to check + # enum descriptor output. + self.assertEqual(4, unittest_pb2.FOREIGN_FOO) + self.assertEqual(5, unittest_pb2.FOREIGN_BAR) + self.assertEqual(6, unittest_pb2.FOREIGN_BAZ) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(1, proto.FOO) + self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) + self.assertEqual(2, proto.BAR) + self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) + self.assertEqual(3, proto.BAZ) + self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + + def testExtremeDefaultValues(self): + message = unittest_pb2.TestExtremeDefaultValues() + + # Python pre-2.6 does not have isinf() or isnan() functions, so we have + # to provide our own. + def isnan(val): + # NaN is never equal to itself. + return val != val + def isinf(val): + # Infinity times zero equals NaN. + return not isnan(val) and isnan(val * 0) + + self.assertTrue(isinf(message.inf_double)) + self.assertTrue(message.inf_double > 0) + self.assertTrue(isinf(message.neg_inf_double)) + self.assertTrue(message.neg_inf_double < 0) + self.assertTrue(isnan(message.nan_double)) + + self.assertTrue(isinf(message.inf_float)) + self.assertTrue(message.inf_float > 0) + self.assertTrue(isinf(message.neg_inf_float)) + self.assertTrue(message.neg_inf_float < 0) + self.assertTrue(isnan(message.nan_float)) + + def testHasDefaultValues(self): + desc = unittest_pb2.TestAllTypes.DESCRIPTOR + + expected_has_default_by_name = { + 'optional_int32': False, + 'repeated_int32': False, + 'optional_nested_message': False, + 'default_int32': True, + } + + has_default_by_name = dict( + [(f.name, f.has_default_value) + for f in desc.fields + if f.name in expected_has_default_by_name]) + self.assertEqual(expected_has_default_by_name, has_default_by_name) + + def testContainingTypeBehaviorForExtensions(self): + self.assertEqual(unittest_pb2.optional_int32_extension.containing_type, + unittest_pb2.TestAllExtensions.DESCRIPTOR) + self.assertEqual(unittest_pb2.TestRequired.single.containing_type, + unittest_pb2.TestAllExtensions.DESCRIPTOR) + + def testExtensionScope(self): + self.assertEqual(unittest_pb2.optional_int32_extension.extension_scope, + None) + self.assertEqual(unittest_pb2.TestRequired.single.extension_scope, + unittest_pb2.TestRequired.DESCRIPTOR) + + def testIsExtension(self): + self.assertTrue(unittest_pb2.optional_int32_extension.is_extension) + self.assertTrue(unittest_pb2.TestRequired.single.is_extension) + + message_descriptor = unittest_pb2.TestRequired.DESCRIPTOR + non_extension_descriptor = message_descriptor.fields_by_name['a'] + self.assertTrue(not non_extension_descriptor.is_extension) + + def testOptions(self): + proto = unittest_mset_pb2.TestMessageSet() + self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) + + def testNestedTypes(self): + self.assertEquals( + set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types), + set([ + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR, + unittest_pb2.TestAllTypes.OptionalGroup.DESCRIPTOR, + unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR, + ])) + self.assertEqual(unittest_pb2.TestEmptyMessage.DESCRIPTOR.nested_types, []) + self.assertEqual( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.nested_types, []) + + def testContainingType(self): + self.assertTrue( + unittest_pb2.TestEmptyMessage.DESCRIPTOR.containing_type is None) + self.assertTrue( + unittest_pb2.TestAllTypes.DESCRIPTOR.containing_type is None) + self.assertEqual( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEqual( + unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEqual( + unittest_pb2.TestAllTypes.RepeatedGroup.DESCRIPTOR.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + + def testContainingTypeInEnumDescriptor(self): + self.assertTrue(unittest_pb2._FOREIGNENUM.containing_type is None) + self.assertEqual(unittest_pb2._TESTALLTYPES_NESTEDENUM.containing_type, + unittest_pb2.TestAllTypes.DESCRIPTOR) + + def testPackage(self): + self.assertEqual( + unittest_pb2.TestAllTypes.DESCRIPTOR.file.package, + 'protobuf_unittest') + desc = unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR + self.assertEqual(desc.file.package, 'protobuf_unittest') + self.assertEqual( + unittest_import_pb2.ImportMessage.DESCRIPTOR.file.package, + 'protobuf_unittest_import') + + self.assertEqual( + unittest_pb2._FOREIGNENUM.file.package, 'protobuf_unittest') + self.assertEqual( + unittest_pb2._TESTALLTYPES_NESTEDENUM.file.package, + 'protobuf_unittest') + self.assertEqual( + unittest_import_pb2._IMPORTENUM.file.package, + 'protobuf_unittest_import') + + def testExtensionRange(self): + self.assertEqual( + unittest_pb2.TestAllTypes.DESCRIPTOR.extension_ranges, []) + self.assertEqual( + unittest_pb2.TestAllExtensions.DESCRIPTOR.extension_ranges, + [(1, MAX_EXTENSION)]) + self.assertEqual( + unittest_pb2.TestMultipleExtensionRanges.DESCRIPTOR.extension_ranges, + [(42, 43), (4143, 4244), (65536, MAX_EXTENSION)]) + + def testFileDescriptor(self): + self.assertEqual(unittest_pb2.DESCRIPTOR.name, + 'google/protobuf/unittest.proto') + self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest') + self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None) + + def testNoGenericServices(self): + # unittest_no_generic_services.proto should contain defs for everything + # except services. + self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage")) + self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO")) + self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension")) + self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService")) + + +if __name__ == '__main__': + unittest.main() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_listener.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_listener.py index ab472e30..1080234d 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_listener.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_listener.py @@ -1,78 +1,78 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Defines a listener interface for observing certain
-state transitions on Message objects.
-
-Also defines a null implementation of this interface.
-"""
-
-__author__ = '[email protected] (Will Robinson)'
-
-
-class MessageListener(object):
-
- """Listens for modifications made to a message. Meant to be registered via
- Message._SetListener().
-
- Attributes:
- dirty: If True, then calling Modified() would be a no-op. This can be
- used to avoid these calls entirely in the common case.
- """
-
- def Modified(self):
- """Called every time the message is modified in such a way that the parent
- message may need to be updated. This currently means either:
- (a) The message was modified for the first time, so the parent message
- should henceforth mark the message as present.
- (b) The message's cached byte size became dirty -- i.e. the message was
- modified for the first time after a previous call to ByteSize().
- Therefore the parent should also mark its byte size as dirty.
- Note that (a) implies (b), since new objects start out with a client cached
- size (zero). However, we document (a) explicitly because it is important.
-
- Modified() will *only* be called in response to one of these two events --
- not every time the sub-message is modified.
-
- Note that if the listener's |dirty| attribute is true, then calling
- Modified at the moment would be a no-op, so it can be skipped. Performance-
- sensitive callers should check this attribute directly before calling since
- it will be true most of the time.
- """
-
- raise NotImplementedError
-
-
-class NullMessageListener(object):
-
- """No-op MessageListener implementation."""
-
- def Modified(self):
- pass
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Defines a listener interface for observing certain +state transitions on Message objects. + +Also defines a null implementation of this interface. +""" + +__author__ = '[email protected] (Will Robinson)' + + +class MessageListener(object): + + """Listens for modifications made to a message. Meant to be registered via + Message._SetListener(). + + Attributes: + dirty: If True, then calling Modified() would be a no-op. This can be + used to avoid these calls entirely in the common case. + """ + + def Modified(self): + """Called every time the message is modified in such a way that the parent + message may need to be updated. This currently means either: + (a) The message was modified for the first time, so the parent message + should henceforth mark the message as present. + (b) The message's cached byte size became dirty -- i.e. the message was + modified for the first time after a previous call to ByteSize(). + Therefore the parent should also mark its byte size as dirty. + Note that (a) implies (b), since new objects start out with a client cached + size (zero). However, we document (a) explicitly because it is important. + + Modified() will *only* be called in response to one of these two events -- + not every time the sub-message is modified. + + Note that if the listener's |dirty| attribute is true, then calling + Modified at the moment would be a no-op, so it can be skipped. Performance- + sensitive callers should check this attribute directly before calling since + it will be true most of the time. + """ + + raise NotImplementedError + + +class NullMessageListener(object): + + """No-op MessageListener implementation.""" + + def Modified(self): + pass diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_test.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_test.py index 02565ac9..73a9a3a3 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_test.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/message_test.py @@ -1,89 +1,89 @@ -#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests python protocol buffers against the golden message.
-
-Note that the golden messages exercise every known field type, thus this
-test ends up exercising and verifying nearly all of the parsing and
-serialization code in the whole library.
-
-TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of
-sense to call this a test of the "message" module, which only declares an
-abstract interface.
-"""
-
-__author__ = '[email protected] (Gregory P. Smith)'
-
-import unittest
-from google.protobuf import unittest_import_pb2
-from google.protobuf import unittest_pb2
-from google.protobuf.internal import test_util
-
-
-class MessageTest(unittest.TestCase):
-
- def testGoldenMessage(self):
- golden_data = test_util.GoldenFile('golden_message').read()
- golden_message = unittest_pb2.TestAllTypes()
- golden_message.ParseFromString(golden_data)
- test_util.ExpectAllFieldsSet(self, golden_message)
- self.assertTrue(golden_message.SerializeToString() == golden_data)
-
- def testGoldenExtensions(self):
- golden_data = test_util.GoldenFile('golden_message').read()
- golden_message = unittest_pb2.TestAllExtensions()
- golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(all_set)
- self.assertEquals(all_set, golden_message)
- self.assertTrue(golden_message.SerializeToString() == golden_data)
-
- def testGoldenPackedMessage(self):
- golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
- golden_message = unittest_pb2.TestPackedTypes()
- golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedTypes()
- test_util.SetAllPackedFields(all_set)
- self.assertEquals(all_set, golden_message)
- self.assertTrue(all_set.SerializeToString() == golden_data)
-
- def testGoldenPackedExtensions(self):
- golden_data = test_util.GoldenFile('golden_packed_fields_message').read()
- golden_message = unittest_pb2.TestPackedExtensions()
- golden_message.ParseFromString(golden_data)
- all_set = unittest_pb2.TestPackedExtensions()
- test_util.SetAllPackedExtensions(all_set)
- self.assertEquals(all_set, golden_message)
- self.assertTrue(all_set.SerializeToString() == golden_data)
-
-if __name__ == '__main__':
- unittest.main()
+#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests python protocol buffers against the golden message. + +Note that the golden messages exercise every known field type, thus this +test ends up exercising and verifying nearly all of the parsing and +serialization code in the whole library. + +TODO(kenton): Merge with wire_format_test? It doesn't make a whole lot of +sense to call this a test of the "message" module, which only declares an +abstract interface. +""" + +__author__ = '[email protected] (Gregory P. Smith)' + +import unittest +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf.internal import test_util + + +class MessageTest(unittest.TestCase): + + def testGoldenMessage(self): + golden_data = test_util.GoldenFile('golden_message').read() + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + test_util.ExpectAllFieldsSet(self, golden_message) + self.assertTrue(golden_message.SerializeToString() == golden_data) + + def testGoldenExtensions(self): + golden_data = test_util.GoldenFile('golden_message').read() + golden_message = unittest_pb2.TestAllExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(all_set) + self.assertEquals(all_set, golden_message) + self.assertTrue(golden_message.SerializeToString() == golden_data) + + def testGoldenPackedMessage(self): + golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(all_set) + self.assertEquals(all_set, golden_message) + self.assertTrue(all_set.SerializeToString() == golden_data) + + def testGoldenPackedExtensions(self): + golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_message = unittest_pb2.TestPackedExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestPackedExtensions() + test_util.SetAllPackedExtensions(all_set) + self.assertEquals(all_set, golden_message) + self.assertTrue(all_set.SerializeToString() == golden_data) + +if __name__ == '__main__': + unittest.main() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/reflection_test.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/reflection_test.py index 7582e550..2c9fa30b 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/reflection_test.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/reflection_test.py @@ -1,2236 +1,2236 @@ -#! /usr/bin/python
-# -*- coding: utf-8 -*-
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Unittest for reflection.py, which also indirectly tests the output of the
-pure-Python protocol compiler.
-"""
-
-__author__ = '[email protected] (Will Robinson)'
-
-import operator
-import struct
-
-import unittest
-# TODO(robinson): When we split this test in two, only some of these imports
-# will be necessary in each test.
-from google.protobuf import unittest_import_pb2
-from google.protobuf import unittest_mset_pb2
-from google.protobuf import unittest_pb2
-from google.protobuf import descriptor_pb2
-from google.protobuf import descriptor
-from google.protobuf import message
-from google.protobuf import reflection
-from google.protobuf.internal import more_extensions_pb2
-from google.protobuf.internal import more_messages_pb2
-from google.protobuf.internal import wire_format
-from google.protobuf.internal import test_util
-from google.protobuf.internal import decoder
-
-
-class _MiniDecoder(object):
- """Decodes a stream of values from a string.
-
- Once upon a time we actually had a class called decoder.Decoder. Then we
- got rid of it during a redesign that made decoding much, much faster overall.
- But a couple tests in this file used it to check that the serialized form of
- a message was correct. So, this class implements just the methods that were
- used by said tests, so that we don't have to rewrite the tests.
- """
-
- def __init__(self, bytes):
- self._bytes = bytes
- self._pos = 0
-
- def ReadVarint(self):
- result, self._pos = decoder._DecodeVarint(self._bytes, self._pos)
- return result
-
- ReadInt32 = ReadVarint
- ReadInt64 = ReadVarint
- ReadUInt32 = ReadVarint
- ReadUInt64 = ReadVarint
-
- def ReadSInt64(self):
- return wire_format.ZigZagDecode(self.ReadVarint())
-
- ReadSInt32 = ReadSInt64
-
- def ReadFieldNumberAndWireType(self):
- return wire_format.UnpackTag(self.ReadVarint())
-
- def ReadFloat(self):
- result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0]
- self._pos += 4
- return result
-
- def ReadDouble(self):
- result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0]
- self._pos += 8
- return result
-
- def EndOfStream(self):
- return self._pos == len(self._bytes)
-
-
-class ReflectionTest(unittest.TestCase):
-
- def assertIs(self, values, others):
- self.assertEqual(len(values), len(others))
- for i in range(len(values)):
- self.assertTrue(values[i] is others[i])
-
- def testScalarConstructor(self):
- # Constructor with only scalar types should succeed.
- proto = unittest_pb2.TestAllTypes(
- optional_int32=24,
- optional_double=54.321,
- optional_string='optional_string')
-
- self.assertEqual(24, proto.optional_int32)
- self.assertEqual(54.321, proto.optional_double)
- self.assertEqual('optional_string', proto.optional_string)
-
- def testRepeatedScalarConstructor(self):
- # Constructor with only repeated scalar types should succeed.
- proto = unittest_pb2.TestAllTypes(
- repeated_int32=[1, 2, 3, 4],
- repeated_double=[1.23, 54.321],
- repeated_bool=[True, False, False],
- repeated_string=["optional_string"])
-
- self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32))
- self.assertEquals([1.23, 54.321], list(proto.repeated_double))
- self.assertEquals([True, False, False], list(proto.repeated_bool))
- self.assertEquals(["optional_string"], list(proto.repeated_string))
-
- def testRepeatedCompositeConstructor(self):
- # Constructor with only repeated composite types should succeed.
- proto = unittest_pb2.TestAllTypes(
- repeated_nested_message=[
- unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.FOO),
- unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.BAR)],
- repeated_foreign_message=[
- unittest_pb2.ForeignMessage(c=-43),
- unittest_pb2.ForeignMessage(c=45324),
- unittest_pb2.ForeignMessage(c=12)],
- repeatedgroup=[
- unittest_pb2.TestAllTypes.RepeatedGroup(),
- unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
- unittest_pb2.TestAllTypes.RepeatedGroup(a=2)])
-
- self.assertEquals(
- [unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.FOO),
- unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.BAR)],
- list(proto.repeated_nested_message))
- self.assertEquals(
- [unittest_pb2.ForeignMessage(c=-43),
- unittest_pb2.ForeignMessage(c=45324),
- unittest_pb2.ForeignMessage(c=12)],
- list(proto.repeated_foreign_message))
- self.assertEquals(
- [unittest_pb2.TestAllTypes.RepeatedGroup(),
- unittest_pb2.TestAllTypes.RepeatedGroup(a=1),
- unittest_pb2.TestAllTypes.RepeatedGroup(a=2)],
- list(proto.repeatedgroup))
-
- def testMixedConstructor(self):
- # Constructor with only mixed types should succeed.
- proto = unittest_pb2.TestAllTypes(
- optional_int32=24,
- optional_string='optional_string',
- repeated_double=[1.23, 54.321],
- repeated_bool=[True, False, False],
- repeated_nested_message=[
- unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.FOO),
- unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.BAR)],
- repeated_foreign_message=[
- unittest_pb2.ForeignMessage(c=-43),
- unittest_pb2.ForeignMessage(c=45324),
- unittest_pb2.ForeignMessage(c=12)])
-
- self.assertEqual(24, proto.optional_int32)
- self.assertEqual('optional_string', proto.optional_string)
- self.assertEquals([1.23, 54.321], list(proto.repeated_double))
- self.assertEquals([True, False, False], list(proto.repeated_bool))
- self.assertEquals(
- [unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.FOO),
- unittest_pb2.TestAllTypes.NestedMessage(
- bb=unittest_pb2.TestAllTypes.BAR)],
- list(proto.repeated_nested_message))
- self.assertEquals(
- [unittest_pb2.ForeignMessage(c=-43),
- unittest_pb2.ForeignMessage(c=45324),
- unittest_pb2.ForeignMessage(c=12)],
- list(proto.repeated_foreign_message))
-
- def testSimpleHasBits(self):
- # Test a scalar.
- proto = unittest_pb2.TestAllTypes()
- self.assertTrue(not proto.HasField('optional_int32'))
- self.assertEqual(0, proto.optional_int32)
- # HasField() shouldn't be true if all we've done is
- # read the default value.
- self.assertTrue(not proto.HasField('optional_int32'))
- proto.optional_int32 = 1
- # Setting a value however *should* set the "has" bit.
- self.assertTrue(proto.HasField('optional_int32'))
- proto.ClearField('optional_int32')
- # And clearing that value should unset the "has" bit.
- self.assertTrue(not proto.HasField('optional_int32'))
-
- def testHasBitsWithSinglyNestedScalar(self):
- # Helper used to test foreign messages and groups.
- #
- # composite_field_name should be the name of a non-repeated
- # composite (i.e., foreign or group) field in TestAllTypes,
- # and scalar_field_name should be the name of an integer-valued
- # scalar field within that composite.
- #
- # I never thought I'd miss C++ macros and templates so much. :(
- # This helper is semantically just:
- #
- # assert proto.composite_field.scalar_field == 0
- # assert not proto.composite_field.HasField('scalar_field')
- # assert not proto.HasField('composite_field')
- #
- # proto.composite_field.scalar_field = 10
- # old_composite_field = proto.composite_field
- #
- # assert proto.composite_field.scalar_field == 10
- # assert proto.composite_field.HasField('scalar_field')
- # assert proto.HasField('composite_field')
- #
- # proto.ClearField('composite_field')
- #
- # assert not proto.composite_field.HasField('scalar_field')
- # assert not proto.HasField('composite_field')
- # assert proto.composite_field.scalar_field == 0
- #
- # # Now ensure that ClearField('composite_field') disconnected
- # # the old field object from the object tree...
- # assert old_composite_field is not proto.composite_field
- # old_composite_field.scalar_field = 20
- # assert not proto.composite_field.HasField('scalar_field')
- # assert not proto.HasField('composite_field')
- def TestCompositeHasBits(composite_field_name, scalar_field_name):
- proto = unittest_pb2.TestAllTypes()
- # First, check that we can get the scalar value, and see that it's the
- # default (0), but that proto.HasField('omposite') and
- # proto.composite.HasField('scalar') will still return False.
- composite_field = getattr(proto, composite_field_name)
- original_scalar_value = getattr(composite_field, scalar_field_name)
- self.assertEqual(0, original_scalar_value)
- # Assert that the composite object does not "have" the scalar.
- self.assertTrue(not composite_field.HasField(scalar_field_name))
- # Assert that proto does not "have" the composite field.
- self.assertTrue(not proto.HasField(composite_field_name))
-
- # Now set the scalar within the composite field. Ensure that the setting
- # is reflected, and that proto.HasField('composite') and
- # proto.composite.HasField('scalar') now both return True.
- new_val = 20
- setattr(composite_field, scalar_field_name, new_val)
- self.assertEqual(new_val, getattr(composite_field, scalar_field_name))
- # Hold on to a reference to the current composite_field object.
- old_composite_field = composite_field
- # Assert that the has methods now return true.
- self.assertTrue(composite_field.HasField(scalar_field_name))
- self.assertTrue(proto.HasField(composite_field_name))
-
- # Now call the clear method...
- proto.ClearField(composite_field_name)
-
- # ...and ensure that the "has" bits are all back to False...
- composite_field = getattr(proto, composite_field_name)
- self.assertTrue(not composite_field.HasField(scalar_field_name))
- self.assertTrue(not proto.HasField(composite_field_name))
- # ...and ensure that the scalar field has returned to its default.
- self.assertEqual(0, getattr(composite_field, scalar_field_name))
-
- # Finally, ensure that modifications to the old composite field object
- # don't have any effect on the parent.
- #
- # (NOTE that when we clear the composite field in the parent, we actually
- # don't recursively clear down the tree. Instead, we just disconnect the
- # cleared composite from the tree.)
- self.assertTrue(old_composite_field is not composite_field)
- setattr(old_composite_field, scalar_field_name, new_val)
- self.assertTrue(not composite_field.HasField(scalar_field_name))
- self.assertTrue(not proto.HasField(composite_field_name))
- self.assertEqual(0, getattr(composite_field, scalar_field_name))
-
- # Test simple, single-level nesting when we set a scalar.
- TestCompositeHasBits('optionalgroup', 'a')
- TestCompositeHasBits('optional_nested_message', 'bb')
- TestCompositeHasBits('optional_foreign_message', 'c')
- TestCompositeHasBits('optional_import_message', 'd')
-
- def testReferencesToNestedMessage(self):
- proto = unittest_pb2.TestAllTypes()
- nested = proto.optional_nested_message
- del proto
- # A previous version had a bug where this would raise an exception when
- # hitting a now-dead weak reference.
- nested.bb = 23
-
- def testDisconnectingNestedMessageBeforeSettingField(self):
- proto = unittest_pb2.TestAllTypes()
- nested = proto.optional_nested_message
- proto.ClearField('optional_nested_message') # Should disconnect from parent
- self.assertTrue(nested is not proto.optional_nested_message)
- nested.bb = 23
- self.assertTrue(not proto.HasField('optional_nested_message'))
- self.assertEqual(0, proto.optional_nested_message.bb)
-
- def testHasBitsWhenModifyingRepeatedFields(self):
- # Test nesting when we add an element to a repeated field in a submessage.
- proto = unittest_pb2.TestNestedMessageHasBits()
- proto.optional_nested_message.nestedmessage_repeated_int32.append(5)
- self.assertEqual(
- [5], proto.optional_nested_message.nestedmessage_repeated_int32)
- self.assertTrue(proto.HasField('optional_nested_message'))
-
- # Do the same test, but with a repeated composite field within the
- # submessage.
- proto.ClearField('optional_nested_message')
- self.assertTrue(not proto.HasField('optional_nested_message'))
- proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add()
- self.assertTrue(proto.HasField('optional_nested_message'))
-
- def testHasBitsForManyLevelsOfNesting(self):
- # Test nesting many levels deep.
- recursive_proto = unittest_pb2.TestMutualRecursionA()
- self.assertTrue(not recursive_proto.HasField('bb'))
- self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32)
- self.assertTrue(not recursive_proto.HasField('bb'))
- recursive_proto.bb.a.bb.a.bb.optional_int32 = 5
- self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32)
- self.assertTrue(recursive_proto.HasField('bb'))
- self.assertTrue(recursive_proto.bb.HasField('a'))
- self.assertTrue(recursive_proto.bb.a.HasField('bb'))
- self.assertTrue(recursive_proto.bb.a.bb.HasField('a'))
- self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb'))
- self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a'))
- self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32'))
-
- def testSingularListFields(self):
- proto = unittest_pb2.TestAllTypes()
- proto.optional_fixed32 = 1
- proto.optional_int32 = 5
- proto.optional_string = 'foo'
- # Access sub-message but don't set it yet.
- nested_message = proto.optional_nested_message
- self.assertEqual(
- [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
- (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
- (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ],
- proto.ListFields())
-
- proto.optional_nested_message.bb = 123
- self.assertEqual(
- [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5),
- (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1),
- (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'),
- (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ],
- nested_message) ],
- proto.ListFields())
-
- def testRepeatedListFields(self):
- proto = unittest_pb2.TestAllTypes()
- proto.repeated_fixed32.append(1)
- proto.repeated_int32.append(5)
- proto.repeated_int32.append(11)
- proto.repeated_string.extend(['foo', 'bar'])
- proto.repeated_string.extend([])
- proto.repeated_string.append('baz')
- proto.repeated_string.extend(str(x) for x in xrange(2))
- proto.optional_int32 = 21
- proto.repeated_bool # Access but don't set anything; should not be listed.
- self.assertEqual(
- [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21),
- (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]),
- (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]),
- (proto.DESCRIPTOR.fields_by_name['repeated_string' ],
- ['foo', 'bar', 'baz', '0', '1']) ],
- proto.ListFields())
-
- def testSingularListExtensions(self):
- proto = unittest_pb2.TestAllExtensions()
- proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1
- proto.Extensions[unittest_pb2.optional_int32_extension ] = 5
- proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo'
- self.assertEqual(
- [ (unittest_pb2.optional_int32_extension , 5),
- (unittest_pb2.optional_fixed32_extension, 1),
- (unittest_pb2.optional_string_extension , 'foo') ],
- proto.ListFields())
-
- def testRepeatedListExtensions(self):
- proto = unittest_pb2.TestAllExtensions()
- proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1)
- proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5)
- proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11)
- proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo')
- proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar')
- proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz')
- proto.Extensions[unittest_pb2.optional_int32_extension ] = 21
- self.assertEqual(
- [ (unittest_pb2.optional_int32_extension , 21),
- (unittest_pb2.repeated_int32_extension , [5, 11]),
- (unittest_pb2.repeated_fixed32_extension, [1]),
- (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ],
- proto.ListFields())
-
- def testListFieldsAndExtensions(self):
- proto = unittest_pb2.TestFieldOrderings()
- test_util.SetAllFieldsAndExtensions(proto)
- unittest_pb2.my_extension_int
- self.assertEqual(
- [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1),
- (unittest_pb2.my_extension_int , 23),
- (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'),
- (unittest_pb2.my_extension_string , 'bar'),
- (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ],
- proto.ListFields())
-
- def testDefaultValues(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertEqual(0, proto.optional_int32)
- self.assertEqual(0, proto.optional_int64)
- self.assertEqual(0, proto.optional_uint32)
- self.assertEqual(0, proto.optional_uint64)
- self.assertEqual(0, proto.optional_sint32)
- self.assertEqual(0, proto.optional_sint64)
- self.assertEqual(0, proto.optional_fixed32)
- self.assertEqual(0, proto.optional_fixed64)
- self.assertEqual(0, proto.optional_sfixed32)
- self.assertEqual(0, proto.optional_sfixed64)
- self.assertEqual(0.0, proto.optional_float)
- self.assertEqual(0.0, proto.optional_double)
- self.assertEqual(False, proto.optional_bool)
- self.assertEqual('', proto.optional_string)
- self.assertEqual('', proto.optional_bytes)
-
- self.assertEqual(41, proto.default_int32)
- self.assertEqual(42, proto.default_int64)
- self.assertEqual(43, proto.default_uint32)
- self.assertEqual(44, proto.default_uint64)
- self.assertEqual(-45, proto.default_sint32)
- self.assertEqual(46, proto.default_sint64)
- self.assertEqual(47, proto.default_fixed32)
- self.assertEqual(48, proto.default_fixed64)
- self.assertEqual(49, proto.default_sfixed32)
- self.assertEqual(-50, proto.default_sfixed64)
- self.assertEqual(51.5, proto.default_float)
- self.assertEqual(52e3, proto.default_double)
- self.assertEqual(True, proto.default_bool)
- self.assertEqual('hello', proto.default_string)
- self.assertEqual('world', proto.default_bytes)
- self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum)
- self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum)
- self.assertEqual(unittest_import_pb2.IMPORT_BAR,
- proto.default_import_enum)
-
- proto = unittest_pb2.TestExtremeDefaultValues()
- self.assertEqual(u'\u1234', proto.utf8_string)
-
- def testHasFieldWithUnknownFieldName(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertRaises(ValueError, proto.HasField, 'nonexistent_field')
-
- def testClearFieldWithUnknownFieldName(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field')
-
- def testDisallowedAssignments(self):
- # It's illegal to assign values directly to repeated fields
- # or to nonrepeated composite fields. Ensure that this fails.
- proto = unittest_pb2.TestAllTypes()
- # Repeated fields.
- self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10)
- # Lists shouldn't work, either.
- self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10])
- # Composite fields.
- self.assertRaises(AttributeError, setattr, proto,
- 'optional_nested_message', 23)
- # Assignment to a repeated nested message field without specifying
- # the index in the array of nested messages.
- self.assertRaises(AttributeError, setattr, proto.repeated_nested_message,
- 'bb', 34)
- # Assignment to an attribute of a repeated field.
- self.assertRaises(AttributeError, setattr, proto.repeated_float,
- 'some_attribute', 34)
- # proto.nonexistent_field = 23 should fail as well.
- self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23)
-
- # TODO(robinson): Add type-safety check for enums.
- def testSingleScalarTypeSafety(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1)
- self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo')
- self.assertRaises(TypeError, setattr, proto, 'optional_string', 10)
- self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10)
-
- def testSingleScalarBoundsChecking(self):
- def TestMinAndMaxIntegers(field_name, expected_min, expected_max):
- pb = unittest_pb2.TestAllTypes()
- setattr(pb, field_name, expected_min)
- setattr(pb, field_name, expected_max)
- self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1)
- self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1)
-
- TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1)
- TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff)
- TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1)
- TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff)
- TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1)
-
- def testRepeatedScalarTypeSafety(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertRaises(TypeError, proto.repeated_int32.append, 1.1)
- self.assertRaises(TypeError, proto.repeated_int32.append, 'foo')
- self.assertRaises(TypeError, proto.repeated_string, 10)
- self.assertRaises(TypeError, proto.repeated_bytes, 10)
-
- proto.repeated_int32.append(10)
- proto.repeated_int32[0] = 23
- self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23)
- self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc')
-
- def testSingleScalarGettersAndSetters(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertEqual(0, proto.optional_int32)
- proto.optional_int32 = 1
- self.assertEqual(1, proto.optional_int32)
- # TODO(robinson): Test all other scalar field types.
-
- def testSingleScalarClearField(self):
- proto = unittest_pb2.TestAllTypes()
- # Should be allowed to clear something that's not there (a no-op).
- proto.ClearField('optional_int32')
- proto.optional_int32 = 1
- self.assertTrue(proto.HasField('optional_int32'))
- proto.ClearField('optional_int32')
- self.assertEqual(0, proto.optional_int32)
- self.assertTrue(not proto.HasField('optional_int32'))
- # TODO(robinson): Test all other scalar field types.
-
- def testEnums(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertEqual(1, proto.FOO)
- self.assertEqual(1, unittest_pb2.TestAllTypes.FOO)
- self.assertEqual(2, proto.BAR)
- self.assertEqual(2, unittest_pb2.TestAllTypes.BAR)
- self.assertEqual(3, proto.BAZ)
- self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ)
-
- def testRepeatedScalars(self):
- proto = unittest_pb2.TestAllTypes()
-
- self.assertTrue(not proto.repeated_int32)
- self.assertEqual(0, len(proto.repeated_int32))
- proto.repeated_int32.append(5)
- proto.repeated_int32.append(10)
- proto.repeated_int32.append(15)
- self.assertTrue(proto.repeated_int32)
- self.assertEqual(3, len(proto.repeated_int32))
-
- self.assertEqual([5, 10, 15], proto.repeated_int32)
-
- # Test single retrieval.
- self.assertEqual(5, proto.repeated_int32[0])
- self.assertEqual(15, proto.repeated_int32[-1])
- # Test out-of-bounds indices.
- self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234)
- self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234)
- # Test incorrect types passed to __getitem__.
- self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo')
- self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None)
-
- # Test single assignment.
- proto.repeated_int32[1] = 20
- self.assertEqual([5, 20, 15], proto.repeated_int32)
-
- # Test insertion.
- proto.repeated_int32.insert(1, 25)
- self.assertEqual([5, 25, 20, 15], proto.repeated_int32)
-
- # Test slice retrieval.
- proto.repeated_int32.append(30)
- self.assertEqual([25, 20, 15], proto.repeated_int32[1:4])
- self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:])
-
- # Test slice assignment with an iterator
- proto.repeated_int32[1:4] = (i for i in xrange(3))
- self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32)
-
- # Test slice assignment.
- proto.repeated_int32[1:4] = [35, 40, 45]
- self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32)
-
- # Test that we can use the field as an iterator.
- result = []
- for i in proto.repeated_int32:
- result.append(i)
- self.assertEqual([5, 35, 40, 45, 30], result)
-
- # Test single deletion.
- del proto.repeated_int32[2]
- self.assertEqual([5, 35, 45, 30], proto.repeated_int32)
-
- # Test slice deletion.
- del proto.repeated_int32[2:]
- self.assertEqual([5, 35], proto.repeated_int32)
-
- # Test clearing.
- proto.ClearField('repeated_int32')
- self.assertTrue(not proto.repeated_int32)
- self.assertEqual(0, len(proto.repeated_int32))
-
- def testRepeatedScalarsRemove(self):
- proto = unittest_pb2.TestAllTypes()
-
- self.assertTrue(not proto.repeated_int32)
- self.assertEqual(0, len(proto.repeated_int32))
- proto.repeated_int32.append(5)
- proto.repeated_int32.append(10)
- proto.repeated_int32.append(5)
- proto.repeated_int32.append(5)
-
- self.assertEqual(4, len(proto.repeated_int32))
- proto.repeated_int32.remove(5)
- self.assertEqual(3, len(proto.repeated_int32))
- self.assertEqual(10, proto.repeated_int32[0])
- self.assertEqual(5, proto.repeated_int32[1])
- self.assertEqual(5, proto.repeated_int32[2])
-
- proto.repeated_int32.remove(5)
- self.assertEqual(2, len(proto.repeated_int32))
- self.assertEqual(10, proto.repeated_int32[0])
- self.assertEqual(5, proto.repeated_int32[1])
-
- proto.repeated_int32.remove(10)
- self.assertEqual(1, len(proto.repeated_int32))
- self.assertEqual(5, proto.repeated_int32[0])
-
- # Remove a non-existent element.
- self.assertRaises(ValueError, proto.repeated_int32.remove, 123)
-
- def testRepeatedComposites(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertTrue(not proto.repeated_nested_message)
- self.assertEqual(0, len(proto.repeated_nested_message))
- m0 = proto.repeated_nested_message.add()
- m1 = proto.repeated_nested_message.add()
- self.assertTrue(proto.repeated_nested_message)
- self.assertEqual(2, len(proto.repeated_nested_message))
- self.assertIs([m0, m1], proto.repeated_nested_message)
- self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage))
-
- # Test out-of-bounds indices.
- self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
- 1234)
- self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__,
- -1234)
-
- # Test incorrect types passed to __getitem__.
- self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
- 'foo')
- self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__,
- None)
-
- # Test slice retrieval.
- m2 = proto.repeated_nested_message.add()
- m3 = proto.repeated_nested_message.add()
- m4 = proto.repeated_nested_message.add()
- self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4])
- self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:])
-
- # Test that we can use the field as an iterator.
- result = []
- for i in proto.repeated_nested_message:
- result.append(i)
- self.assertIs([m0, m1, m2, m3, m4], result)
-
- # Test single deletion.
- del proto.repeated_nested_message[2]
- self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message)
-
- # Test slice deletion.
- del proto.repeated_nested_message[2:]
- self.assertIs([m0, m1], proto.repeated_nested_message)
-
- # Test clearing.
- proto.ClearField('repeated_nested_message')
- self.assertTrue(not proto.repeated_nested_message)
- self.assertEqual(0, len(proto.repeated_nested_message))
-
- def testHandWrittenReflection(self):
- # TODO(robinson): We probably need a better way to specify
- # protocol types by hand. But then again, this isn't something
- # we expect many people to do. Hmm.
- FieldDescriptor = descriptor.FieldDescriptor
- foo_field_descriptor = FieldDescriptor(
- name='foo_field', full_name='MyProto.foo_field',
- index=0, number=1, type=FieldDescriptor.TYPE_INT64,
- cpp_type=FieldDescriptor.CPPTYPE_INT64,
- label=FieldDescriptor.LABEL_OPTIONAL, default_value=0,
- containing_type=None, message_type=None, enum_type=None,
- is_extension=False, extension_scope=None,
- options=descriptor_pb2.FieldOptions())
- mydescriptor = descriptor.Descriptor(
- name='MyProto', full_name='MyProto', filename='ignored',
- containing_type=None, nested_types=[], enum_types=[],
- fields=[foo_field_descriptor], extensions=[],
- options=descriptor_pb2.MessageOptions())
- class MyProtoClass(message.Message):
- DESCRIPTOR = mydescriptor
- __metaclass__ = reflection.GeneratedProtocolMessageType
- myproto_instance = MyProtoClass()
- self.assertEqual(0, myproto_instance.foo_field)
- self.assertTrue(not myproto_instance.HasField('foo_field'))
- myproto_instance.foo_field = 23
- self.assertEqual(23, myproto_instance.foo_field)
- self.assertTrue(myproto_instance.HasField('foo_field'))
-
- def testTopLevelExtensionsForOptionalScalar(self):
- extendee_proto = unittest_pb2.TestAllExtensions()
- extension = unittest_pb2.optional_int32_extension
- self.assertTrue(not extendee_proto.HasExtension(extension))
- self.assertEqual(0, extendee_proto.Extensions[extension])
- # As with normal scalar fields, just doing a read doesn't actually set the
- # "has" bit.
- self.assertTrue(not extendee_proto.HasExtension(extension))
- # Actually set the thing.
- extendee_proto.Extensions[extension] = 23
- self.assertEqual(23, extendee_proto.Extensions[extension])
- self.assertTrue(extendee_proto.HasExtension(extension))
- # Ensure that clearing works as well.
- extendee_proto.ClearExtension(extension)
- self.assertEqual(0, extendee_proto.Extensions[extension])
- self.assertTrue(not extendee_proto.HasExtension(extension))
-
- def testTopLevelExtensionsForRepeatedScalar(self):
- extendee_proto = unittest_pb2.TestAllExtensions()
- extension = unittest_pb2.repeated_string_extension
- self.assertEqual(0, len(extendee_proto.Extensions[extension]))
- extendee_proto.Extensions[extension].append('foo')
- self.assertEqual(['foo'], extendee_proto.Extensions[extension])
- string_list = extendee_proto.Extensions[extension]
- extendee_proto.ClearExtension(extension)
- self.assertEqual(0, len(extendee_proto.Extensions[extension]))
- self.assertTrue(string_list is not extendee_proto.Extensions[extension])
- # Shouldn't be allowed to do Extensions[extension] = 'a'
- self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
- extension, 'a')
-
- def testTopLevelExtensionsForOptionalMessage(self):
- extendee_proto = unittest_pb2.TestAllExtensions()
- extension = unittest_pb2.optional_foreign_message_extension
- self.assertTrue(not extendee_proto.HasExtension(extension))
- self.assertEqual(0, extendee_proto.Extensions[extension].c)
- # As with normal (non-extension) fields, merely reading from the
- # thing shouldn't set the "has" bit.
- self.assertTrue(not extendee_proto.HasExtension(extension))
- extendee_proto.Extensions[extension].c = 23
- self.assertEqual(23, extendee_proto.Extensions[extension].c)
- self.assertTrue(extendee_proto.HasExtension(extension))
- # Save a reference here.
- foreign_message = extendee_proto.Extensions[extension]
- extendee_proto.ClearExtension(extension)
- self.assertTrue(foreign_message is not extendee_proto.Extensions[extension])
- # Setting a field on foreign_message now shouldn't set
- # any "has" bits on extendee_proto.
- foreign_message.c = 42
- self.assertEqual(42, foreign_message.c)
- self.assertTrue(foreign_message.HasField('c'))
- self.assertTrue(not extendee_proto.HasExtension(extension))
- # Shouldn't be allowed to do Extensions[extension] = 'a'
- self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
- extension, 'a')
-
- def testTopLevelExtensionsForRepeatedMessage(self):
- extendee_proto = unittest_pb2.TestAllExtensions()
- extension = unittest_pb2.repeatedgroup_extension
- self.assertEqual(0, len(extendee_proto.Extensions[extension]))
- group = extendee_proto.Extensions[extension].add()
- group.a = 23
- self.assertEqual(23, extendee_proto.Extensions[extension][0].a)
- group.a = 42
- self.assertEqual(42, extendee_proto.Extensions[extension][0].a)
- group_list = extendee_proto.Extensions[extension]
- extendee_proto.ClearExtension(extension)
- self.assertEqual(0, len(extendee_proto.Extensions[extension]))
- self.assertTrue(group_list is not extendee_proto.Extensions[extension])
- # Shouldn't be allowed to do Extensions[extension] = 'a'
- self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions,
- extension, 'a')
-
- def testNestedExtensions(self):
- extendee_proto = unittest_pb2.TestAllExtensions()
- extension = unittest_pb2.TestRequired.single
-
- # We just test the non-repeated case.
- self.assertTrue(not extendee_proto.HasExtension(extension))
- required = extendee_proto.Extensions[extension]
- self.assertEqual(0, required.a)
- self.assertTrue(not extendee_proto.HasExtension(extension))
- required.a = 23
- self.assertEqual(23, extendee_proto.Extensions[extension].a)
- self.assertTrue(extendee_proto.HasExtension(extension))
- extendee_proto.ClearExtension(extension)
- self.assertTrue(required is not extendee_proto.Extensions[extension])
- self.assertTrue(not extendee_proto.HasExtension(extension))
-
- # If message A directly contains message B, and
- # a.HasField('b') is currently False, then mutating any
- # extension in B should change a.HasField('b') to True
- # (and so on up the object tree).
- def testHasBitsForAncestorsOfExtendedMessage(self):
- # Optional scalar extension.
- toplevel = more_extensions_pb2.TopLevelMessage()
- self.assertTrue(not toplevel.HasField('submessage'))
- self.assertEqual(0, toplevel.submessage.Extensions[
- more_extensions_pb2.optional_int_extension])
- self.assertTrue(not toplevel.HasField('submessage'))
- toplevel.submessage.Extensions[
- more_extensions_pb2.optional_int_extension] = 23
- self.assertEqual(23, toplevel.submessage.Extensions[
- more_extensions_pb2.optional_int_extension])
- self.assertTrue(toplevel.HasField('submessage'))
-
- # Repeated scalar extension.
- toplevel = more_extensions_pb2.TopLevelMessage()
- self.assertTrue(not toplevel.HasField('submessage'))
- self.assertEqual([], toplevel.submessage.Extensions[
- more_extensions_pb2.repeated_int_extension])
- self.assertTrue(not toplevel.HasField('submessage'))
- toplevel.submessage.Extensions[
- more_extensions_pb2.repeated_int_extension].append(23)
- self.assertEqual([23], toplevel.submessage.Extensions[
- more_extensions_pb2.repeated_int_extension])
- self.assertTrue(toplevel.HasField('submessage'))
-
- # Optional message extension.
- toplevel = more_extensions_pb2.TopLevelMessage()
- self.assertTrue(not toplevel.HasField('submessage'))
- self.assertEqual(0, toplevel.submessage.Extensions[
- more_extensions_pb2.optional_message_extension].foreign_message_int)
- self.assertTrue(not toplevel.HasField('submessage'))
- toplevel.submessage.Extensions[
- more_extensions_pb2.optional_message_extension].foreign_message_int = 23
- self.assertEqual(23, toplevel.submessage.Extensions[
- more_extensions_pb2.optional_message_extension].foreign_message_int)
- self.assertTrue(toplevel.HasField('submessage'))
-
- # Repeated message extension.
- toplevel = more_extensions_pb2.TopLevelMessage()
- self.assertTrue(not toplevel.HasField('submessage'))
- self.assertEqual(0, len(toplevel.submessage.Extensions[
- more_extensions_pb2.repeated_message_extension]))
- self.assertTrue(not toplevel.HasField('submessage'))
- foreign = toplevel.submessage.Extensions[
- more_extensions_pb2.repeated_message_extension].add()
- self.assertTrue(foreign is toplevel.submessage.Extensions[
- more_extensions_pb2.repeated_message_extension][0])
- self.assertTrue(toplevel.HasField('submessage'))
-
- def testDisconnectionAfterClearingEmptyMessage(self):
- toplevel = more_extensions_pb2.TopLevelMessage()
- extendee_proto = toplevel.submessage
- extension = more_extensions_pb2.optional_message_extension
- extension_proto = extendee_proto.Extensions[extension]
- extendee_proto.ClearExtension(extension)
- extension_proto.foreign_message_int = 23
-
- self.assertTrue(extension_proto is not extendee_proto.Extensions[extension])
-
- def testExtensionFailureModes(self):
- extendee_proto = unittest_pb2.TestAllExtensions()
-
- # Try non-extension-handle arguments to HasExtension,
- # ClearExtension(), and Extensions[]...
- self.assertRaises(KeyError, extendee_proto.HasExtension, 1234)
- self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234)
- self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234)
- self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5)
-
- # Try something that *is* an extension handle, just not for
- # this message...
- unknown_handle = more_extensions_pb2.optional_int_extension
- self.assertRaises(KeyError, extendee_proto.HasExtension,
- unknown_handle)
- self.assertRaises(KeyError, extendee_proto.ClearExtension,
- unknown_handle)
- self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__,
- unknown_handle)
- self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__,
- unknown_handle, 5)
-
- # Try call HasExtension() with a valid handle, but for a
- # *repeated* field. (Just as with non-extension repeated
- # fields, Has*() isn't supported for extension repeated fields).
- self.assertRaises(KeyError, extendee_proto.HasExtension,
- unittest_pb2.repeated_string_extension)
-
- def testStaticParseFrom(self):
- proto1 = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(proto1)
-
- string1 = proto1.SerializeToString()
- proto2 = unittest_pb2.TestAllTypes.FromString(string1)
-
- # Messages should be equal.
- self.assertEqual(proto2, proto1)
-
- def testMergeFromSingularField(self):
- # Test merge with just a singular field.
- proto1 = unittest_pb2.TestAllTypes()
- proto1.optional_int32 = 1
-
- proto2 = unittest_pb2.TestAllTypes()
- # This shouldn't get overwritten.
- proto2.optional_string = 'value'
-
- proto2.MergeFrom(proto1)
- self.assertEqual(1, proto2.optional_int32)
- self.assertEqual('value', proto2.optional_string)
-
- def testMergeFromRepeatedField(self):
- # Test merge with just a repeated field.
- proto1 = unittest_pb2.TestAllTypes()
- proto1.repeated_int32.append(1)
- proto1.repeated_int32.append(2)
-
- proto2 = unittest_pb2.TestAllTypes()
- proto2.repeated_int32.append(0)
- proto2.MergeFrom(proto1)
-
- self.assertEqual(0, proto2.repeated_int32[0])
- self.assertEqual(1, proto2.repeated_int32[1])
- self.assertEqual(2, proto2.repeated_int32[2])
-
- def testMergeFromOptionalGroup(self):
- # Test merge with an optional group.
- proto1 = unittest_pb2.TestAllTypes()
- proto1.optionalgroup.a = 12
- proto2 = unittest_pb2.TestAllTypes()
- proto2.MergeFrom(proto1)
- self.assertEqual(12, proto2.optionalgroup.a)
-
- def testMergeFromRepeatedNestedMessage(self):
- # Test merge with a repeated nested message.
- proto1 = unittest_pb2.TestAllTypes()
- m = proto1.repeated_nested_message.add()
- m.bb = 123
- m = proto1.repeated_nested_message.add()
- m.bb = 321
-
- proto2 = unittest_pb2.TestAllTypes()
- m = proto2.repeated_nested_message.add()
- m.bb = 999
- proto2.MergeFrom(proto1)
- self.assertEqual(999, proto2.repeated_nested_message[0].bb)
- self.assertEqual(123, proto2.repeated_nested_message[1].bb)
- self.assertEqual(321, proto2.repeated_nested_message[2].bb)
-
- def testMergeFromAllFields(self):
- # With all fields set.
- proto1 = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(proto1)
- proto2 = unittest_pb2.TestAllTypes()
- proto2.MergeFrom(proto1)
-
- # Messages should be equal.
- self.assertEqual(proto2, proto1)
-
- # Serialized string should be equal too.
- string1 = proto1.SerializeToString()
- string2 = proto2.SerializeToString()
- self.assertEqual(string1, string2)
-
- def testMergeFromExtensionsSingular(self):
- proto1 = unittest_pb2.TestAllExtensions()
- proto1.Extensions[unittest_pb2.optional_int32_extension] = 1
-
- proto2 = unittest_pb2.TestAllExtensions()
- proto2.MergeFrom(proto1)
- self.assertEqual(
- 1, proto2.Extensions[unittest_pb2.optional_int32_extension])
-
- def testMergeFromExtensionsRepeated(self):
- proto1 = unittest_pb2.TestAllExtensions()
- proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1)
- proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2)
-
- proto2 = unittest_pb2.TestAllExtensions()
- proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0)
- proto2.MergeFrom(proto1)
- self.assertEqual(
- 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension]))
- self.assertEqual(
- 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0])
- self.assertEqual(
- 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1])
- self.assertEqual(
- 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2])
-
- def testMergeFromExtensionsNestedMessage(self):
- proto1 = unittest_pb2.TestAllExtensions()
- ext1 = proto1.Extensions[
- unittest_pb2.repeated_nested_message_extension]
- m = ext1.add()
- m.bb = 222
- m = ext1.add()
- m.bb = 333
-
- proto2 = unittest_pb2.TestAllExtensions()
- ext2 = proto2.Extensions[
- unittest_pb2.repeated_nested_message_extension]
- m = ext2.add()
- m.bb = 111
-
- proto2.MergeFrom(proto1)
- ext2 = proto2.Extensions[
- unittest_pb2.repeated_nested_message_extension]
- self.assertEqual(3, len(ext2))
- self.assertEqual(111, ext2[0].bb)
- self.assertEqual(222, ext2[1].bb)
- self.assertEqual(333, ext2[2].bb)
-
- def testCopyFromSingularField(self):
- # Test copy with just a singular field.
- proto1 = unittest_pb2.TestAllTypes()
- proto1.optional_int32 = 1
- proto1.optional_string = 'important-text'
-
- proto2 = unittest_pb2.TestAllTypes()
- proto2.optional_string = 'value'
-
- proto2.CopyFrom(proto1)
- self.assertEqual(1, proto2.optional_int32)
- self.assertEqual('important-text', proto2.optional_string)
-
- def testCopyFromRepeatedField(self):
- # Test copy with a repeated field.
- proto1 = unittest_pb2.TestAllTypes()
- proto1.repeated_int32.append(1)
- proto1.repeated_int32.append(2)
-
- proto2 = unittest_pb2.TestAllTypes()
- proto2.repeated_int32.append(0)
- proto2.CopyFrom(proto1)
-
- self.assertEqual(1, proto2.repeated_int32[0])
- self.assertEqual(2, proto2.repeated_int32[1])
-
- def testCopyFromAllFields(self):
- # With all fields set.
- proto1 = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(proto1)
- proto2 = unittest_pb2.TestAllTypes()
- proto2.CopyFrom(proto1)
-
- # Messages should be equal.
- self.assertEqual(proto2, proto1)
-
- # Serialized string should be equal too.
- string1 = proto1.SerializeToString()
- string2 = proto2.SerializeToString()
- self.assertEqual(string1, string2)
-
- def testCopyFromSelf(self):
- proto1 = unittest_pb2.TestAllTypes()
- proto1.repeated_int32.append(1)
- proto1.optional_int32 = 2
- proto1.optional_string = 'important-text'
-
- proto1.CopyFrom(proto1)
- self.assertEqual(1, proto1.repeated_int32[0])
- self.assertEqual(2, proto1.optional_int32)
- self.assertEqual('important-text', proto1.optional_string)
-
- def testClear(self):
- proto = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(proto)
- # Clear the message.
- proto.Clear()
- self.assertEquals(proto.ByteSize(), 0)
- empty_proto = unittest_pb2.TestAllTypes()
- self.assertEquals(proto, empty_proto)
-
- # Test if extensions which were set are cleared.
- proto = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(proto)
- # Clear the message.
- proto.Clear()
- self.assertEquals(proto.ByteSize(), 0)
- empty_proto = unittest_pb2.TestAllExtensions()
- self.assertEquals(proto, empty_proto)
-
- def assertInitialized(self, proto):
- self.assertTrue(proto.IsInitialized())
- # Neither method should raise an exception.
- proto.SerializeToString()
- proto.SerializePartialToString()
-
- def assertNotInitialized(self, proto):
- self.assertFalse(proto.IsInitialized())
- self.assertRaises(message.EncodeError, proto.SerializeToString)
- # "Partial" serialization doesn't care if message is uninitialized.
- proto.SerializePartialToString()
-
- def testIsInitialized(self):
- # Trivial cases - all optional fields and extensions.
- proto = unittest_pb2.TestAllTypes()
- self.assertInitialized(proto)
- proto = unittest_pb2.TestAllExtensions()
- self.assertInitialized(proto)
-
- # The case of uninitialized required fields.
- proto = unittest_pb2.TestRequired()
- self.assertNotInitialized(proto)
- proto.a = proto.b = proto.c = 2
- self.assertInitialized(proto)
-
- # The case of uninitialized submessage.
- proto = unittest_pb2.TestRequiredForeign()
- self.assertInitialized(proto)
- proto.optional_message.a = 1
- self.assertNotInitialized(proto)
- proto.optional_message.b = 0
- proto.optional_message.c = 0
- self.assertInitialized(proto)
-
- # Uninitialized repeated submessage.
- message1 = proto.repeated_message.add()
- self.assertNotInitialized(proto)
- message1.a = message1.b = message1.c = 0
- self.assertInitialized(proto)
-
- # Uninitialized repeated group in an extension.
- proto = unittest_pb2.TestAllExtensions()
- extension = unittest_pb2.TestRequired.multi
- message1 = proto.Extensions[extension].add()
- message2 = proto.Extensions[extension].add()
- self.assertNotInitialized(proto)
- message1.a = 1
- message1.b = 1
- message1.c = 1
- self.assertNotInitialized(proto)
- message2.a = 2
- message2.b = 2
- message2.c = 2
- self.assertInitialized(proto)
-
- # Uninitialized nonrepeated message in an extension.
- proto = unittest_pb2.TestAllExtensions()
- extension = unittest_pb2.TestRequired.single
- proto.Extensions[extension].a = 1
- self.assertNotInitialized(proto)
- proto.Extensions[extension].b = 2
- proto.Extensions[extension].c = 3
- self.assertInitialized(proto)
-
- # Try passing an errors list.
- errors = []
- proto = unittest_pb2.TestRequired()
- self.assertFalse(proto.IsInitialized(errors))
- self.assertEqual(errors, ['a', 'b', 'c'])
-
- def testStringUTF8Encoding(self):
- proto = unittest_pb2.TestAllTypes()
-
- # Assignment of a unicode object to a field of type 'bytes' is not allowed.
- self.assertRaises(TypeError,
- setattr, proto, 'optional_bytes', u'unicode object')
-
- # Check that the default value is of python's 'unicode' type.
- self.assertEqual(type(proto.optional_string), unicode)
-
- proto.optional_string = unicode('Testing')
- self.assertEqual(proto.optional_string, str('Testing'))
-
- # Assign a value of type 'str' which can be encoded in UTF-8.
- proto.optional_string = str('Testing')
- self.assertEqual(proto.optional_string, unicode('Testing'))
-
- # Values of type 'str' are also accepted as long as they can be encoded in
- # UTF-8.
- self.assertEqual(type(proto.optional_string), str)
-
- # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII.
- self.assertRaises(ValueError,
- setattr, proto, 'optional_string', str('a\x80a'))
- # Assign a 'str' object which contains a UTF-8 encoded string.
- self.assertRaises(ValueError,
- setattr, proto, 'optional_string', 'Тест')
- # No exception thrown.
- proto.optional_string = 'abc'
-
- def testStringUTF8Serialization(self):
- proto = unittest_mset_pb2.TestMessageSet()
- extension_message = unittest_mset_pb2.TestMessageSetExtension2
- extension = extension_message.message_set_extension
-
- test_utf8 = u'Тест'
- test_utf8_bytes = test_utf8.encode('utf-8')
-
- # 'Test' in another language, using UTF-8 charset.
- proto.Extensions[extension].str = test_utf8
-
- # Serialize using the MessageSet wire format (this is specified in the
- # .proto file).
- serialized = proto.SerializeToString()
-
- # Check byte size.
- self.assertEqual(proto.ByteSize(), len(serialized))
-
- raw = unittest_mset_pb2.RawMessageSet()
- raw.MergeFromString(serialized)
-
- message2 = unittest_mset_pb2.TestMessageSetExtension2()
-
- self.assertEqual(1, len(raw.item))
- # Check that the type_id is the same as the tag ID in the .proto file.
- self.assertEqual(raw.item[0].type_id, 1547769)
-
- # Check the actually bytes on the wire.
- self.assertTrue(
- raw.item[0].message.endswith(test_utf8_bytes))
- message2.MergeFromString(raw.item[0].message)
-
- self.assertEqual(type(message2.str), unicode)
- self.assertEqual(message2.str, test_utf8)
-
- # How about if the bytes on the wire aren't a valid UTF-8 encoded string.
- bytes = raw.item[0].message.replace(
- test_utf8_bytes, len(test_utf8_bytes) * '\xff')
- self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes)
-
- def testEmptyNestedMessage(self):
- proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.MergeFrom(
- unittest_pb2.TestAllTypes.NestedMessage())
- self.assertTrue(proto.HasField('optional_nested_message'))
-
- proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.CopyFrom(
- unittest_pb2.TestAllTypes.NestedMessage())
- self.assertTrue(proto.HasField('optional_nested_message'))
-
- proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.MergeFromString('')
- self.assertTrue(proto.HasField('optional_nested_message'))
-
- proto = unittest_pb2.TestAllTypes()
- proto.optional_nested_message.ParseFromString('')
- self.assertTrue(proto.HasField('optional_nested_message'))
-
- serialized = proto.SerializeToString()
- proto2 = unittest_pb2.TestAllTypes()
- proto2.MergeFromString(serialized)
- self.assertTrue(proto2.HasField('optional_nested_message'))
-
- def testSetInParent(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertFalse(proto.HasField('optionalgroup'))
- proto.optionalgroup.SetInParent()
- self.assertTrue(proto.HasField('optionalgroup'))
-
-
-# Since we had so many tests for protocol buffer equality, we broke these out
-# into separate TestCase classes.
-
-
-class TestAllTypesEqualityTest(unittest.TestCase):
-
- def setUp(self):
- self.first_proto = unittest_pb2.TestAllTypes()
- self.second_proto = unittest_pb2.TestAllTypes()
-
- def testSelfEquality(self):
- self.assertEqual(self.first_proto, self.first_proto)
-
- def testEmptyProtosEqual(self):
- self.assertEqual(self.first_proto, self.second_proto)
-
-
-class FullProtosEqualityTest(unittest.TestCase):
-
- """Equality tests using completely-full protos as a starting point."""
-
- def setUp(self):
- self.first_proto = unittest_pb2.TestAllTypes()
- self.second_proto = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(self.first_proto)
- test_util.SetAllFields(self.second_proto)
-
- def testNoneNotEqual(self):
- self.assertNotEqual(self.first_proto, None)
- self.assertNotEqual(None, self.second_proto)
-
- def testNotEqualToOtherMessage(self):
- third_proto = unittest_pb2.TestRequired()
- self.assertNotEqual(self.first_proto, third_proto)
- self.assertNotEqual(third_proto, self.second_proto)
-
- def testAllFieldsFilledEquality(self):
- self.assertEqual(self.first_proto, self.second_proto)
-
- def testNonRepeatedScalar(self):
- # Nonrepeated scalar field change should cause inequality.
- self.first_proto.optional_int32 += 1
- self.assertNotEqual(self.first_proto, self.second_proto)
- # ...as should clearing a field.
- self.first_proto.ClearField('optional_int32')
- self.assertNotEqual(self.first_proto, self.second_proto)
-
- def testNonRepeatedComposite(self):
- # Change a nonrepeated composite field.
- self.first_proto.optional_nested_message.bb += 1
- self.assertNotEqual(self.first_proto, self.second_proto)
- self.first_proto.optional_nested_message.bb -= 1
- self.assertEqual(self.first_proto, self.second_proto)
- # Clear a field in the nested message.
- self.first_proto.optional_nested_message.ClearField('bb')
- self.assertNotEqual(self.first_proto, self.second_proto)
- self.first_proto.optional_nested_message.bb = (
- self.second_proto.optional_nested_message.bb)
- self.assertEqual(self.first_proto, self.second_proto)
- # Remove the nested message entirely.
- self.first_proto.ClearField('optional_nested_message')
- self.assertNotEqual(self.first_proto, self.second_proto)
-
- def testRepeatedScalar(self):
- # Change a repeated scalar field.
- self.first_proto.repeated_int32.append(5)
- self.assertNotEqual(self.first_proto, self.second_proto)
- self.first_proto.ClearField('repeated_int32')
- self.assertNotEqual(self.first_proto, self.second_proto)
-
- def testRepeatedComposite(self):
- # Change value within a repeated composite field.
- self.first_proto.repeated_nested_message[0].bb += 1
- self.assertNotEqual(self.first_proto, self.second_proto)
- self.first_proto.repeated_nested_message[0].bb -= 1
- self.assertEqual(self.first_proto, self.second_proto)
- # Add a value to a repeated composite field.
- self.first_proto.repeated_nested_message.add()
- self.assertNotEqual(self.first_proto, self.second_proto)
- self.second_proto.repeated_nested_message.add()
- self.assertEqual(self.first_proto, self.second_proto)
-
- def testNonRepeatedScalarHasBits(self):
- # Ensure that we test "has" bits as well as value for
- # nonrepeated scalar field.
- self.first_proto.ClearField('optional_int32')
- self.second_proto.optional_int32 = 0
- self.assertNotEqual(self.first_proto, self.second_proto)
-
- def testNonRepeatedCompositeHasBits(self):
- # Ensure that we test "has" bits as well as value for
- # nonrepeated composite field.
- self.first_proto.ClearField('optional_nested_message')
- self.second_proto.optional_nested_message.ClearField('bb')
- self.assertNotEqual(self.first_proto, self.second_proto)
- # TODO(robinson): Replace next two lines with method
- # to set the "has" bit without changing the value,
- # if/when such a method exists.
- self.first_proto.optional_nested_message.bb = 0
- self.first_proto.optional_nested_message.ClearField('bb')
- self.assertEqual(self.first_proto, self.second_proto)
-
-
-class ExtensionEqualityTest(unittest.TestCase):
-
- def testExtensionEquality(self):
- first_proto = unittest_pb2.TestAllExtensions()
- second_proto = unittest_pb2.TestAllExtensions()
- self.assertEqual(first_proto, second_proto)
- test_util.SetAllExtensions(first_proto)
- self.assertNotEqual(first_proto, second_proto)
- test_util.SetAllExtensions(second_proto)
- self.assertEqual(first_proto, second_proto)
-
- # Ensure that we check value equality.
- first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1
- self.assertNotEqual(first_proto, second_proto)
- first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1
- self.assertEqual(first_proto, second_proto)
-
- # Ensure that we also look at "has" bits.
- first_proto.ClearExtension(unittest_pb2.optional_int32_extension)
- second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
- self.assertNotEqual(first_proto, second_proto)
- first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0
- self.assertEqual(first_proto, second_proto)
-
- # Ensure that differences in cached values
- # don't matter if "has" bits are both false.
- first_proto = unittest_pb2.TestAllExtensions()
- second_proto = unittest_pb2.TestAllExtensions()
- self.assertEqual(
- 0, first_proto.Extensions[unittest_pb2.optional_int32_extension])
- self.assertEqual(first_proto, second_proto)
-
-
-class MutualRecursionEqualityTest(unittest.TestCase):
-
- def testEqualityWithMutualRecursion(self):
- first_proto = unittest_pb2.TestMutualRecursionA()
- second_proto = unittest_pb2.TestMutualRecursionA()
- self.assertEqual(first_proto, second_proto)
- first_proto.bb.a.bb.optional_int32 = 23
- self.assertNotEqual(first_proto, second_proto)
- second_proto.bb.a.bb.optional_int32 = 23
- self.assertEqual(first_proto, second_proto)
-
-
-class ByteSizeTest(unittest.TestCase):
-
- def setUp(self):
- self.proto = unittest_pb2.TestAllTypes()
- self.extended_proto = more_extensions_pb2.ExtendedMessage()
- self.packed_proto = unittest_pb2.TestPackedTypes()
- self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
-
- def Size(self):
- return self.proto.ByteSize()
-
- def testEmptyMessage(self):
- self.assertEqual(0, self.proto.ByteSize())
-
- def testVarints(self):
- def Test(i, expected_varint_size):
- self.proto.Clear()
- self.proto.optional_int64 = i
- # Add one to the varint size for the tag info
- # for tag 1.
- self.assertEqual(expected_varint_size + 1, self.Size())
- Test(0, 1)
- Test(1, 1)
- for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)):
- Test((1 << i) - 1, num_bytes)
- Test(-1, 10)
- Test(-2, 10)
- Test(-(1 << 63), 10)
-
- def testStrings(self):
- self.proto.optional_string = ''
- # Need one byte for tag info (tag #14), and one byte for length.
- self.assertEqual(2, self.Size())
-
- self.proto.optional_string = 'abc'
- # Need one byte for tag info (tag #14), and one byte for length.
- self.assertEqual(2 + len(self.proto.optional_string), self.Size())
-
- self.proto.optional_string = 'x' * 128
- # Need one byte for tag info (tag #14), and TWO bytes for length.
- self.assertEqual(3 + len(self.proto.optional_string), self.Size())
-
- def testOtherNumerics(self):
- self.proto.optional_fixed32 = 1234
- # One byte for tag and 4 bytes for fixed32.
- self.assertEqual(5, self.Size())
- self.proto = unittest_pb2.TestAllTypes()
-
- self.proto.optional_fixed64 = 1234
- # One byte for tag and 8 bytes for fixed64.
- self.assertEqual(9, self.Size())
- self.proto = unittest_pb2.TestAllTypes()
-
- self.proto.optional_float = 1.234
- # One byte for tag and 4 bytes for float.
- self.assertEqual(5, self.Size())
- self.proto = unittest_pb2.TestAllTypes()
-
- self.proto.optional_double = 1.234
- # One byte for tag and 8 bytes for float.
- self.assertEqual(9, self.Size())
- self.proto = unittest_pb2.TestAllTypes()
-
- self.proto.optional_sint32 = 64
- # One byte for tag and 2 bytes for zig-zag-encoded 64.
- self.assertEqual(3, self.Size())
- self.proto = unittest_pb2.TestAllTypes()
-
- def testComposites(self):
- # 3 bytes.
- self.proto.optional_nested_message.bb = (1 << 14)
- # Plus one byte for bb tag.
- # Plus 1 byte for optional_nested_message serialized size.
- # Plus two bytes for optional_nested_message tag.
- self.assertEqual(3 + 1 + 1 + 2, self.Size())
-
- def testGroups(self):
- # 4 bytes.
- self.proto.optionalgroup.a = (1 << 21)
- # Plus two bytes for |a| tag.
- # Plus 2 * two bytes for START_GROUP and END_GROUP tags.
- self.assertEqual(4 + 2 + 2*2, self.Size())
-
- def testRepeatedScalars(self):
- self.proto.repeated_int32.append(10) # 1 byte.
- self.proto.repeated_int32.append(128) # 2 bytes.
- # Also need 2 bytes for each entry for tag.
- self.assertEqual(1 + 2 + 2*2, self.Size())
-
- def testRepeatedScalarsExtend(self):
- self.proto.repeated_int32.extend([10, 128]) # 3 bytes.
- # Also need 2 bytes for each entry for tag.
- self.assertEqual(1 + 2 + 2*2, self.Size())
-
- def testRepeatedScalarsRemove(self):
- self.proto.repeated_int32.append(10) # 1 byte.
- self.proto.repeated_int32.append(128) # 2 bytes.
- # Also need 2 bytes for each entry for tag.
- self.assertEqual(1 + 2 + 2*2, self.Size())
- self.proto.repeated_int32.remove(128)
- self.assertEqual(1 + 2, self.Size())
-
- def testRepeatedComposites(self):
- # Empty message. 2 bytes tag plus 1 byte length.
- foreign_message_0 = self.proto.repeated_nested_message.add()
- # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
- foreign_message_1 = self.proto.repeated_nested_message.add()
- foreign_message_1.bb = 7
- self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
-
- def testRepeatedCompositesDelete(self):
- # Empty message. 2 bytes tag plus 1 byte length.
- foreign_message_0 = self.proto.repeated_nested_message.add()
- # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
- foreign_message_1 = self.proto.repeated_nested_message.add()
- foreign_message_1.bb = 9
- self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size())
-
- # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
- del self.proto.repeated_nested_message[0]
- self.assertEqual(2 + 1 + 1 + 1, self.Size())
-
- # Now add a new message.
- foreign_message_2 = self.proto.repeated_nested_message.add()
- foreign_message_2.bb = 12
-
- # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
- # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
- self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size())
-
- # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int.
- del self.proto.repeated_nested_message[1]
- self.assertEqual(2 + 1 + 1 + 1, self.Size())
-
- del self.proto.repeated_nested_message[0]
- self.assertEqual(0, self.Size())
-
- def testRepeatedGroups(self):
- # 2-byte START_GROUP plus 2-byte END_GROUP.
- group_0 = self.proto.repeatedgroup.add()
- # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a|
- # plus 2-byte END_GROUP.
- group_1 = self.proto.repeatedgroup.add()
- group_1.a = 7
- self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size())
-
- def testExtensions(self):
- proto = unittest_pb2.TestAllExtensions()
- self.assertEqual(0, proto.ByteSize())
- extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte.
- proto.Extensions[extension] = 23
- # 1 byte for tag, 1 byte for value.
- self.assertEqual(2, proto.ByteSize())
-
- def testCacheInvalidationForNonrepeatedScalar(self):
- # Test non-extension.
- self.proto.optional_int32 = 1
- self.assertEqual(2, self.proto.ByteSize())
- self.proto.optional_int32 = 128
- self.assertEqual(3, self.proto.ByteSize())
- self.proto.ClearField('optional_int32')
- self.assertEqual(0, self.proto.ByteSize())
-
- # Test within extension.
- extension = more_extensions_pb2.optional_int_extension
- self.extended_proto.Extensions[extension] = 1
- self.assertEqual(2, self.extended_proto.ByteSize())
- self.extended_proto.Extensions[extension] = 128
- self.assertEqual(3, self.extended_proto.ByteSize())
- self.extended_proto.ClearExtension(extension)
- self.assertEqual(0, self.extended_proto.ByteSize())
-
- def testCacheInvalidationForRepeatedScalar(self):
- # Test non-extension.
- self.proto.repeated_int32.append(1)
- self.assertEqual(3, self.proto.ByteSize())
- self.proto.repeated_int32.append(1)
- self.assertEqual(6, self.proto.ByteSize())
- self.proto.repeated_int32[1] = 128
- self.assertEqual(7, self.proto.ByteSize())
- self.proto.ClearField('repeated_int32')
- self.assertEqual(0, self.proto.ByteSize())
-
- # Test within extension.
- extension = more_extensions_pb2.repeated_int_extension
- repeated = self.extended_proto.Extensions[extension]
- repeated.append(1)
- self.assertEqual(2, self.extended_proto.ByteSize())
- repeated.append(1)
- self.assertEqual(4, self.extended_proto.ByteSize())
- repeated[1] = 128
- self.assertEqual(5, self.extended_proto.ByteSize())
- self.extended_proto.ClearExtension(extension)
- self.assertEqual(0, self.extended_proto.ByteSize())
-
- def testCacheInvalidationForNonrepeatedMessage(self):
- # Test non-extension.
- self.proto.optional_foreign_message.c = 1
- self.assertEqual(5, self.proto.ByteSize())
- self.proto.optional_foreign_message.c = 128
- self.assertEqual(6, self.proto.ByteSize())
- self.proto.optional_foreign_message.ClearField('c')
- self.assertEqual(3, self.proto.ByteSize())
- self.proto.ClearField('optional_foreign_message')
- self.assertEqual(0, self.proto.ByteSize())
- child = self.proto.optional_foreign_message
- self.proto.ClearField('optional_foreign_message')
- child.c = 128
- self.assertEqual(0, self.proto.ByteSize())
-
- # Test within extension.
- extension = more_extensions_pb2.optional_message_extension
- child = self.extended_proto.Extensions[extension]
- self.assertEqual(0, self.extended_proto.ByteSize())
- child.foreign_message_int = 1
- self.assertEqual(4, self.extended_proto.ByteSize())
- child.foreign_message_int = 128
- self.assertEqual(5, self.extended_proto.ByteSize())
- self.extended_proto.ClearExtension(extension)
- self.assertEqual(0, self.extended_proto.ByteSize())
-
- def testCacheInvalidationForRepeatedMessage(self):
- # Test non-extension.
- child0 = self.proto.repeated_foreign_message.add()
- self.assertEqual(3, self.proto.ByteSize())
- self.proto.repeated_foreign_message.add()
- self.assertEqual(6, self.proto.ByteSize())
- child0.c = 1
- self.assertEqual(8, self.proto.ByteSize())
- self.proto.ClearField('repeated_foreign_message')
- self.assertEqual(0, self.proto.ByteSize())
-
- # Test within extension.
- extension = more_extensions_pb2.repeated_message_extension
- child_list = self.extended_proto.Extensions[extension]
- child0 = child_list.add()
- self.assertEqual(2, self.extended_proto.ByteSize())
- child_list.add()
- self.assertEqual(4, self.extended_proto.ByteSize())
- child0.foreign_message_int = 1
- self.assertEqual(6, self.extended_proto.ByteSize())
- child0.ClearField('foreign_message_int')
- self.assertEqual(4, self.extended_proto.ByteSize())
- self.extended_proto.ClearExtension(extension)
- self.assertEqual(0, self.extended_proto.ByteSize())
-
- def testPackedRepeatedScalars(self):
- self.assertEqual(0, self.packed_proto.ByteSize())
-
- self.packed_proto.packed_int32.append(10) # 1 byte.
- self.packed_proto.packed_int32.append(128) # 2 bytes.
- # The tag is 2 bytes (the field number is 90), and the varint
- # storing the length is 1 byte.
- int_size = 1 + 2 + 3
- self.assertEqual(int_size, self.packed_proto.ByteSize())
-
- self.packed_proto.packed_double.append(4.2) # 8 bytes
- self.packed_proto.packed_double.append(3.25) # 8 bytes
- # 2 more tag bytes, 1 more length byte.
- double_size = 8 + 8 + 3
- self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())
-
- self.packed_proto.ClearField('packed_int32')
- self.assertEqual(double_size, self.packed_proto.ByteSize())
-
- def testPackedExtensions(self):
- self.assertEqual(0, self.packed_extended_proto.ByteSize())
- extension = self.packed_extended_proto.Extensions[
- unittest_pb2.packed_fixed32_extension]
- extension.extend([1, 2, 3, 4]) # 16 bytes
- # Tag is 3 bytes.
- self.assertEqual(19, self.packed_extended_proto.ByteSize())
-
-
-# TODO(robinson): We need cross-language serialization consistency tests.
-# Issues to be sure to cover include:
-# * Handling of unrecognized tags ("uninterpreted_bytes").
-# * Handling of MessageSets.
-# * Consistent ordering of tags in the wire format,
-# including ordering between extensions and non-extension
-# fields.
-# * Consistent serialization of negative numbers, especially
-# negative int32s.
-# * Handling of empty submessages (with and without "has"
-# bits set).
-
-class SerializationTest(unittest.TestCase):
-
- def testSerializeEmtpyMessage(self):
- first_proto = unittest_pb2.TestAllTypes()
- second_proto = unittest_pb2.TestAllTypes()
- serialized = first_proto.SerializeToString()
- self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
- self.assertEqual(first_proto, second_proto)
-
- def testSerializeAllFields(self):
- first_proto = unittest_pb2.TestAllTypes()
- second_proto = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(first_proto)
- serialized = first_proto.SerializeToString()
- self.assertEqual(first_proto.ByteSize(), len(serialized))
- second_proto.MergeFromString(serialized)
- self.assertEqual(first_proto, second_proto)
-
- def testSerializeAllExtensions(self):
- first_proto = unittest_pb2.TestAllExtensions()
- second_proto = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(first_proto)
- serialized = first_proto.SerializeToString()
- second_proto.MergeFromString(serialized)
- self.assertEqual(first_proto, second_proto)
-
- def testSerializeNegativeValues(self):
- first_proto = unittest_pb2.TestAllTypes()
-
- first_proto.optional_int32 = -1
- first_proto.optional_int64 = -(2 << 40)
- first_proto.optional_sint32 = -3
- first_proto.optional_sint64 = -(4 << 40)
- first_proto.optional_sfixed32 = -5
- first_proto.optional_sfixed64 = -(6 << 40)
-
- second_proto = unittest_pb2.TestAllTypes.FromString(
- first_proto.SerializeToString())
-
- self.assertEqual(first_proto, second_proto)
-
- def testParseTruncated(self):
- first_proto = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(first_proto)
- serialized = first_proto.SerializeToString()
-
- for truncation_point in xrange(len(serialized) + 1):
- try:
- second_proto = unittest_pb2.TestAllTypes()
- unknown_fields = unittest_pb2.TestEmptyMessage()
- pos = second_proto._InternalParse(serialized, 0, truncation_point)
- # If we didn't raise an error then we read exactly the amount expected.
- self.assertEqual(truncation_point, pos)
-
- # Parsing to unknown fields should not throw if parsing to known fields
- # did not.
- try:
- pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point)
- self.assertEqual(truncation_point, pos2)
- except message.DecodeError:
- self.fail('Parsing unknown fields failed when parsing known fields '
- 'did not.')
- except message.DecodeError:
- # Parsing unknown fields should also fail.
- self.assertRaises(message.DecodeError, unknown_fields._InternalParse,
- serialized, 0, truncation_point)
-
- def testCanonicalSerializationOrder(self):
- proto = more_messages_pb2.OutOfOrderFields()
- # These are also their tag numbers. Even though we're setting these in
- # reverse-tag order AND they're listed in reverse tag-order in the .proto
- # file, they should nonetheless be serialized in tag order.
- proto.optional_sint32 = 5
- proto.Extensions[more_messages_pb2.optional_uint64] = 4
- proto.optional_uint32 = 3
- proto.Extensions[more_messages_pb2.optional_int64] = 2
- proto.optional_int32 = 1
- serialized = proto.SerializeToString()
- self.assertEqual(proto.ByteSize(), len(serialized))
- d = _MiniDecoder(serialized)
- ReadTag = d.ReadFieldNumberAndWireType
- self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag())
- self.assertEqual(1, d.ReadInt32())
- self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag())
- self.assertEqual(2, d.ReadInt64())
- self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag())
- self.assertEqual(3, d.ReadUInt32())
- self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag())
- self.assertEqual(4, d.ReadUInt64())
- self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag())
- self.assertEqual(5, d.ReadSInt32())
-
- def testCanonicalSerializationOrderSameAsCpp(self):
- # Copy of the same test we use for C++.
- proto = unittest_pb2.TestFieldOrderings()
- test_util.SetAllFieldsAndExtensions(proto)
- serialized = proto.SerializeToString()
- test_util.ExpectAllFieldsAndExtensionsInOrder(serialized)
-
- def testMergeFromStringWhenFieldsAlreadySet(self):
- first_proto = unittest_pb2.TestAllTypes()
- first_proto.repeated_string.append('foobar')
- first_proto.optional_int32 = 23
- first_proto.optional_nested_message.bb = 42
- serialized = first_proto.SerializeToString()
-
- second_proto = unittest_pb2.TestAllTypes()
- second_proto.repeated_string.append('baz')
- second_proto.optional_int32 = 100
- second_proto.optional_nested_message.bb = 999
-
- second_proto.MergeFromString(serialized)
- # Ensure that we append to repeated fields.
- self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string))
- # Ensure that we overwrite nonrepeatd scalars.
- self.assertEqual(23, second_proto.optional_int32)
- # Ensure that we recursively call MergeFromString() on
- # submessages.
- self.assertEqual(42, second_proto.optional_nested_message.bb)
-
- def testMessageSetWireFormat(self):
- proto = unittest_mset_pb2.TestMessageSet()
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
- extension_message2 = unittest_mset_pb2.TestMessageSetExtension2
- extension1 = extension_message1.message_set_extension
- extension2 = extension_message2.message_set_extension
- proto.Extensions[extension1].i = 123
- proto.Extensions[extension2].str = 'foo'
-
- # Serialize using the MessageSet wire format (this is specified in the
- # .proto file).
- serialized = proto.SerializeToString()
-
- raw = unittest_mset_pb2.RawMessageSet()
- self.assertEqual(False,
- raw.DESCRIPTOR.GetOptions().message_set_wire_format)
- raw.MergeFromString(serialized)
- self.assertEqual(2, len(raw.item))
-
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.MergeFromString(raw.item[0].message)
- self.assertEqual(123, message1.i)
-
- message2 = unittest_mset_pb2.TestMessageSetExtension2()
- message2.MergeFromString(raw.item[1].message)
- self.assertEqual('foo', message2.str)
-
- # Deserialize using the MessageSet wire format.
- proto2 = unittest_mset_pb2.TestMessageSet()
- proto2.MergeFromString(serialized)
- self.assertEqual(123, proto2.Extensions[extension1].i)
- self.assertEqual('foo', proto2.Extensions[extension2].str)
-
- # Check byte size.
- self.assertEqual(proto2.ByteSize(), len(serialized))
- self.assertEqual(proto.ByteSize(), len(serialized))
-
- def testMessageSetWireFormatUnknownExtension(self):
- # Create a message using the message set wire format with an unknown
- # message.
- raw = unittest_mset_pb2.RawMessageSet()
-
- # Add an item.
- item = raw.item.add()
- item.type_id = 1545008
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.i = 12345
- item.message = message1.SerializeToString()
-
- # Add a second, unknown extension.
- item = raw.item.add()
- item.type_id = 1545009
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
- message1 = unittest_mset_pb2.TestMessageSetExtension1()
- message1.i = 12346
- item.message = message1.SerializeToString()
-
- # Add another unknown extension.
- item = raw.item.add()
- item.type_id = 1545010
- message1 = unittest_mset_pb2.TestMessageSetExtension2()
- message1.str = 'foo'
- item.message = message1.SerializeToString()
-
- serialized = raw.SerializeToString()
-
- # Parse message using the message set wire format.
- proto = unittest_mset_pb2.TestMessageSet()
- proto.MergeFromString(serialized)
-
- # Check that the message parsed well.
- extension_message1 = unittest_mset_pb2.TestMessageSetExtension1
- extension1 = extension_message1.message_set_extension
- self.assertEquals(12345, proto.Extensions[extension1].i)
-
- def testUnknownFields(self):
- proto = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(proto)
-
- serialized = proto.SerializeToString()
-
- # The empty message should be parsable with all of the fields
- # unknown.
- proto2 = unittest_pb2.TestEmptyMessage()
-
- # Parsing this message should succeed.
- proto2.MergeFromString(serialized)
-
- # Now test with a int64 field set.
- proto = unittest_pb2.TestAllTypes()
- proto.optional_int64 = 0x0fffffffffffffff
- serialized = proto.SerializeToString()
- # The empty message should be parsable with all of the fields
- # unknown.
- proto2 = unittest_pb2.TestEmptyMessage()
- # Parsing this message should succeed.
- proto2.MergeFromString(serialized)
-
- def _CheckRaises(self, exc_class, callable_obj, exception):
- """This method checks if the excpetion type and message are as expected."""
- try:
- callable_obj()
- except exc_class, ex:
- # Check if the exception message is the right one.
- self.assertEqual(exception, str(ex))
- return
- else:
- raise self.failureException('%s not raised' % str(exc_class))
-
- def testSerializeUninitialized(self):
- proto = unittest_pb2.TestRequired()
- self._CheckRaises(
- message.EncodeError,
- proto.SerializeToString,
- 'Message is missing required fields: a,b,c')
- # Shouldn't raise exceptions.
- partial = proto.SerializePartialToString()
-
- proto.a = 1
- self._CheckRaises(
- message.EncodeError,
- proto.SerializeToString,
- 'Message is missing required fields: b,c')
- # Shouldn't raise exceptions.
- partial = proto.SerializePartialToString()
-
- proto.b = 2
- self._CheckRaises(
- message.EncodeError,
- proto.SerializeToString,
- 'Message is missing required fields: c')
- # Shouldn't raise exceptions.
- partial = proto.SerializePartialToString()
-
- proto.c = 3
- serialized = proto.SerializeToString()
- # Shouldn't raise exceptions.
- partial = proto.SerializePartialToString()
-
- proto2 = unittest_pb2.TestRequired()
- proto2.MergeFromString(serialized)
- self.assertEqual(1, proto2.a)
- self.assertEqual(2, proto2.b)
- self.assertEqual(3, proto2.c)
- proto2.ParseFromString(partial)
- self.assertEqual(1, proto2.a)
- self.assertEqual(2, proto2.b)
- self.assertEqual(3, proto2.c)
-
- def testSerializeUninitializedSubMessage(self):
- proto = unittest_pb2.TestRequiredForeign()
-
- # Sub-message doesn't exist yet, so this succeeds.
- proto.SerializeToString()
-
- proto.optional_message.a = 1
- self._CheckRaises(
- message.EncodeError,
- proto.SerializeToString,
- 'Message is missing required fields: '
- 'optional_message.b,optional_message.c')
-
- proto.optional_message.b = 2
- proto.optional_message.c = 3
- proto.SerializeToString()
-
- proto.repeated_message.add().a = 1
- proto.repeated_message.add().b = 2
- self._CheckRaises(
- message.EncodeError,
- proto.SerializeToString,
- 'Message is missing required fields: '
- 'repeated_message[0].b,repeated_message[0].c,'
- 'repeated_message[1].a,repeated_message[1].c')
-
- proto.repeated_message[0].b = 2
- proto.repeated_message[0].c = 3
- proto.repeated_message[1].a = 1
- proto.repeated_message[1].c = 3
- proto.SerializeToString()
-
- def testSerializeAllPackedFields(self):
- first_proto = unittest_pb2.TestPackedTypes()
- second_proto = unittest_pb2.TestPackedTypes()
- test_util.SetAllPackedFields(first_proto)
- serialized = first_proto.SerializeToString()
- self.assertEqual(first_proto.ByteSize(), len(serialized))
- bytes_read = second_proto.MergeFromString(serialized)
- self.assertEqual(second_proto.ByteSize(), bytes_read)
- self.assertEqual(first_proto, second_proto)
-
- def testSerializeAllPackedExtensions(self):
- first_proto = unittest_pb2.TestPackedExtensions()
- second_proto = unittest_pb2.TestPackedExtensions()
- test_util.SetAllPackedExtensions(first_proto)
- serialized = first_proto.SerializeToString()
- bytes_read = second_proto.MergeFromString(serialized)
- self.assertEqual(second_proto.ByteSize(), bytes_read)
- self.assertEqual(first_proto, second_proto)
-
- def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
- first_proto = unittest_pb2.TestPackedTypes()
- first_proto.packed_int32.extend([1, 2])
- first_proto.packed_double.append(3.0)
- serialized = first_proto.SerializeToString()
-
- second_proto = unittest_pb2.TestPackedTypes()
- second_proto.packed_int32.append(3)
- second_proto.packed_double.extend([1.0, 2.0])
- second_proto.packed_sint32.append(4)
-
- second_proto.MergeFromString(serialized)
- self.assertEqual([3, 1, 2], second_proto.packed_int32)
- self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
- self.assertEqual([4], second_proto.packed_sint32)
-
- def testPackedFieldsWireFormat(self):
- proto = unittest_pb2.TestPackedTypes()
- proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes
- proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes
- proto.packed_float.append(2.0) # 4 bytes, will be before double
- serialized = proto.SerializeToString()
- self.assertEqual(proto.ByteSize(), len(serialized))
- d = _MiniDecoder(serialized)
- ReadTag = d.ReadFieldNumberAndWireType
- self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
- self.assertEqual(1+1+1+2, d.ReadInt32())
- self.assertEqual(1, d.ReadInt32())
- self.assertEqual(2, d.ReadInt32())
- self.assertEqual(150, d.ReadInt32())
- self.assertEqual(3, d.ReadInt32())
- self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
- self.assertEqual(4, d.ReadInt32())
- self.assertEqual(2.0, d.ReadFloat())
- self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
- self.assertEqual(8+8, d.ReadInt32())
- self.assertEqual(1.0, d.ReadDouble())
- self.assertEqual(1000.0, d.ReadDouble())
- self.assertTrue(d.EndOfStream())
-
- def testParsePackedFromUnpacked(self):
- unpacked = unittest_pb2.TestUnpackedTypes()
- test_util.SetAllUnpackedFields(unpacked)
- packed = unittest_pb2.TestPackedTypes()
- packed.MergeFromString(unpacked.SerializeToString())
- expected = unittest_pb2.TestPackedTypes()
- test_util.SetAllPackedFields(expected)
- self.assertEqual(expected, packed)
-
- def testParseUnpackedFromPacked(self):
- packed = unittest_pb2.TestPackedTypes()
- test_util.SetAllPackedFields(packed)
- unpacked = unittest_pb2.TestUnpackedTypes()
- unpacked.MergeFromString(packed.SerializeToString())
- expected = unittest_pb2.TestUnpackedTypes()
- test_util.SetAllUnpackedFields(expected)
- self.assertEqual(expected, unpacked)
-
- def testFieldNumbers(self):
- proto = unittest_pb2.TestAllTypes()
- self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1)
- self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1)
- self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16)
- self.assertEqual(
- unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18)
- self.assertEqual(
- unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21)
- self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31)
- self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46)
- self.assertEqual(
- unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48)
- self.assertEqual(
- unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51)
-
- def testExtensionFieldNumbers(self):
- self.assertEqual(unittest_pb2.TestRequired.single.number, 1000)
- self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000)
- self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001)
- self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001)
- self.assertEqual(unittest_pb2.optional_int32_extension.number, 1)
- self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1)
- self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16)
- self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16)
- self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18)
- self.assertEqual(
- unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18)
- self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21)
- self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
- 21)
- self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31)
- self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31)
- self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46)
- self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46)
- self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48)
- self.assertEqual(
- unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48)
- self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51)
- self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
- 51)
-
- def testInitKwargs(self):
- proto = unittest_pb2.TestAllTypes(
- optional_int32=1,
- optional_string='foo',
- optional_bool=True,
- optional_bytes='bar',
- optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1),
- optional_foreign_message=unittest_pb2.ForeignMessage(c=1),
- optional_nested_enum=unittest_pb2.TestAllTypes.FOO,
- optional_foreign_enum=unittest_pb2.FOREIGN_FOO,
- repeated_int32=[1, 2, 3])
- self.assertTrue(proto.IsInitialized())
- self.assertTrue(proto.HasField('optional_int32'))
- self.assertTrue(proto.HasField('optional_string'))
- self.assertTrue(proto.HasField('optional_bool'))
- self.assertTrue(proto.HasField('optional_bytes'))
- self.assertTrue(proto.HasField('optional_nested_message'))
- self.assertTrue(proto.HasField('optional_foreign_message'))
- self.assertTrue(proto.HasField('optional_nested_enum'))
- self.assertTrue(proto.HasField('optional_foreign_enum'))
- self.assertEqual(1, proto.optional_int32)
- self.assertEqual('foo', proto.optional_string)
- self.assertEqual(True, proto.optional_bool)
- self.assertEqual('bar', proto.optional_bytes)
- self.assertEqual(1, proto.optional_nested_message.bb)
- self.assertEqual(1, proto.optional_foreign_message.c)
- self.assertEqual(unittest_pb2.TestAllTypes.FOO,
- proto.optional_nested_enum)
- self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum)
- self.assertEqual([1, 2, 3], proto.repeated_int32)
-
- def testInitArgsUnknownFieldName(self):
- def InitalizeEmptyMessageWithExtraKeywordArg():
- unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown')
- self._CheckRaises(ValueError,
- InitalizeEmptyMessageWithExtraKeywordArg,
- 'Protocol message has no "unknown" field.')
-
- def testInitRequiredKwargs(self):
- proto = unittest_pb2.TestRequired(a=1, b=1, c=1)
- self.assertTrue(proto.IsInitialized())
- self.assertTrue(proto.HasField('a'))
- self.assertTrue(proto.HasField('b'))
- self.assertTrue(proto.HasField('c'))
- self.assertTrue(not proto.HasField('dummy2'))
- self.assertEqual(1, proto.a)
- self.assertEqual(1, proto.b)
- self.assertEqual(1, proto.c)
-
- def testInitRequiredForeignKwargs(self):
- proto = unittest_pb2.TestRequiredForeign(
- optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1))
- self.assertTrue(proto.IsInitialized())
- self.assertTrue(proto.HasField('optional_message'))
- self.assertTrue(proto.optional_message.IsInitialized())
- self.assertTrue(proto.optional_message.HasField('a'))
- self.assertTrue(proto.optional_message.HasField('b'))
- self.assertTrue(proto.optional_message.HasField('c'))
- self.assertTrue(not proto.optional_message.HasField('dummy2'))
- self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1),
- proto.optional_message)
- self.assertEqual(1, proto.optional_message.a)
- self.assertEqual(1, proto.optional_message.b)
- self.assertEqual(1, proto.optional_message.c)
-
- def testInitRepeatedKwargs(self):
- proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3])
- self.assertTrue(proto.IsInitialized())
- self.assertEqual(1, proto.repeated_int32[0])
- self.assertEqual(2, proto.repeated_int32[1])
- self.assertEqual(3, proto.repeated_int32[2])
-
-
-class OptionsTest(unittest.TestCase):
-
- def testMessageOptions(self):
- proto = unittest_mset_pb2.TestMessageSet()
- self.assertEqual(True,
- proto.DESCRIPTOR.GetOptions().message_set_wire_format)
- proto = unittest_pb2.TestAllTypes()
- self.assertEqual(False,
- proto.DESCRIPTOR.GetOptions().message_set_wire_format)
-
- def testPackedOptions(self):
- proto = unittest_pb2.TestAllTypes()
- proto.optional_int32 = 1
- proto.optional_double = 3.0
- for field_descriptor, _ in proto.ListFields():
- self.assertEqual(False, field_descriptor.GetOptions().packed)
-
- proto = unittest_pb2.TestPackedTypes()
- proto.packed_int32.append(1)
- proto.packed_double.append(3.0)
- for field_descriptor, _ in proto.ListFields():
- self.assertEqual(True, field_descriptor.GetOptions().packed)
- self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
- field_descriptor.label)
-
-
-
-if __name__ == '__main__':
- unittest.main()
+#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for reflection.py, which also indirectly tests the output of the +pure-Python protocol compiler. +""" + +__author__ = '[email protected] (Will Robinson)' + +import operator +import struct + +import unittest +# TODO(robinson): When we split this test in two, only some of these imports +# will be necessary in each test. +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor +from google.protobuf import message +from google.protobuf import reflection +from google.protobuf.internal import more_extensions_pb2 +from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import wire_format +from google.protobuf.internal import test_util +from google.protobuf.internal import decoder + + +class _MiniDecoder(object): + """Decodes a stream of values from a string. + + Once upon a time we actually had a class called decoder.Decoder. Then we + got rid of it during a redesign that made decoding much, much faster overall. + But a couple tests in this file used it to check that the serialized form of + a message was correct. So, this class implements just the methods that were + used by said tests, so that we don't have to rewrite the tests. + """ + + def __init__(self, bytes): + self._bytes = bytes + self._pos = 0 + + def ReadVarint(self): + result, self._pos = decoder._DecodeVarint(self._bytes, self._pos) + return result + + ReadInt32 = ReadVarint + ReadInt64 = ReadVarint + ReadUInt32 = ReadVarint + ReadUInt64 = ReadVarint + + def ReadSInt64(self): + return wire_format.ZigZagDecode(self.ReadVarint()) + + ReadSInt32 = ReadSInt64 + + def ReadFieldNumberAndWireType(self): + return wire_format.UnpackTag(self.ReadVarint()) + + def ReadFloat(self): + result = struct.unpack("<f", self._bytes[self._pos:self._pos+4])[0] + self._pos += 4 + return result + + def ReadDouble(self): + result = struct.unpack("<d", self._bytes[self._pos:self._pos+8])[0] + self._pos += 8 + return result + + def EndOfStream(self): + return self._pos == len(self._bytes) + + +class ReflectionTest(unittest.TestCase): + + def assertIs(self, values, others): + self.assertEqual(len(values), len(others)) + for i in range(len(values)): + self.assertTrue(values[i] is others[i]) + + def testScalarConstructor(self): + # Constructor with only scalar types should succeed. + proto = unittest_pb2.TestAllTypes( + optional_int32=24, + optional_double=54.321, + optional_string='optional_string') + + self.assertEqual(24, proto.optional_int32) + self.assertEqual(54.321, proto.optional_double) + self.assertEqual('optional_string', proto.optional_string) + + def testRepeatedScalarConstructor(self): + # Constructor with only repeated scalar types should succeed. + proto = unittest_pb2.TestAllTypes( + repeated_int32=[1, 2, 3, 4], + repeated_double=[1.23, 54.321], + repeated_bool=[True, False, False], + repeated_string=["optional_string"]) + + self.assertEquals([1, 2, 3, 4], list(proto.repeated_int32)) + self.assertEquals([1.23, 54.321], list(proto.repeated_double)) + self.assertEquals([True, False, False], list(proto.repeated_bool)) + self.assertEquals(["optional_string"], list(proto.repeated_string)) + + def testRepeatedCompositeConstructor(self): + # Constructor with only repeated composite types should succeed. + proto = unittest_pb2.TestAllTypes( + repeated_nested_message=[ + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + repeated_foreign_message=[ + unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + repeatedgroup=[ + unittest_pb2.TestAllTypes.RepeatedGroup(), + unittest_pb2.TestAllTypes.RepeatedGroup(a=1), + unittest_pb2.TestAllTypes.RepeatedGroup(a=2)]) + + self.assertEquals( + [unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + list(proto.repeated_nested_message)) + self.assertEquals( + [unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + list(proto.repeated_foreign_message)) + self.assertEquals( + [unittest_pb2.TestAllTypes.RepeatedGroup(), + unittest_pb2.TestAllTypes.RepeatedGroup(a=1), + unittest_pb2.TestAllTypes.RepeatedGroup(a=2)], + list(proto.repeatedgroup)) + + def testMixedConstructor(self): + # Constructor with only mixed types should succeed. + proto = unittest_pb2.TestAllTypes( + optional_int32=24, + optional_string='optional_string', + repeated_double=[1.23, 54.321], + repeated_bool=[True, False, False], + repeated_nested_message=[ + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + repeated_foreign_message=[ + unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)]) + + self.assertEqual(24, proto.optional_int32) + self.assertEqual('optional_string', proto.optional_string) + self.assertEquals([1.23, 54.321], list(proto.repeated_double)) + self.assertEquals([True, False, False], list(proto.repeated_bool)) + self.assertEquals( + [unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.FOO), + unittest_pb2.TestAllTypes.NestedMessage( + bb=unittest_pb2.TestAllTypes.BAR)], + list(proto.repeated_nested_message)) + self.assertEquals( + [unittest_pb2.ForeignMessage(c=-43), + unittest_pb2.ForeignMessage(c=45324), + unittest_pb2.ForeignMessage(c=12)], + list(proto.repeated_foreign_message)) + + def testSimpleHasBits(self): + # Test a scalar. + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_int32')) + self.assertEqual(0, proto.optional_int32) + # HasField() shouldn't be true if all we've done is + # read the default value. + self.assertTrue(not proto.HasField('optional_int32')) + proto.optional_int32 = 1 + # Setting a value however *should* set the "has" bit. + self.assertTrue(proto.HasField('optional_int32')) + proto.ClearField('optional_int32') + # And clearing that value should unset the "has" bit. + self.assertTrue(not proto.HasField('optional_int32')) + + def testHasBitsWithSinglyNestedScalar(self): + # Helper used to test foreign messages and groups. + # + # composite_field_name should be the name of a non-repeated + # composite (i.e., foreign or group) field in TestAllTypes, + # and scalar_field_name should be the name of an integer-valued + # scalar field within that composite. + # + # I never thought I'd miss C++ macros and templates so much. :( + # This helper is semantically just: + # + # assert proto.composite_field.scalar_field == 0 + # assert not proto.composite_field.HasField('scalar_field') + # assert not proto.HasField('composite_field') + # + # proto.composite_field.scalar_field = 10 + # old_composite_field = proto.composite_field + # + # assert proto.composite_field.scalar_field == 10 + # assert proto.composite_field.HasField('scalar_field') + # assert proto.HasField('composite_field') + # + # proto.ClearField('composite_field') + # + # assert not proto.composite_field.HasField('scalar_field') + # assert not proto.HasField('composite_field') + # assert proto.composite_field.scalar_field == 0 + # + # # Now ensure that ClearField('composite_field') disconnected + # # the old field object from the object tree... + # assert old_composite_field is not proto.composite_field + # old_composite_field.scalar_field = 20 + # assert not proto.composite_field.HasField('scalar_field') + # assert not proto.HasField('composite_field') + def TestCompositeHasBits(composite_field_name, scalar_field_name): + proto = unittest_pb2.TestAllTypes() + # First, check that we can get the scalar value, and see that it's the + # default (0), but that proto.HasField('omposite') and + # proto.composite.HasField('scalar') will still return False. + composite_field = getattr(proto, composite_field_name) + original_scalar_value = getattr(composite_field, scalar_field_name) + self.assertEqual(0, original_scalar_value) + # Assert that the composite object does not "have" the scalar. + self.assertTrue(not composite_field.HasField(scalar_field_name)) + # Assert that proto does not "have" the composite field. + self.assertTrue(not proto.HasField(composite_field_name)) + + # Now set the scalar within the composite field. Ensure that the setting + # is reflected, and that proto.HasField('composite') and + # proto.composite.HasField('scalar') now both return True. + new_val = 20 + setattr(composite_field, scalar_field_name, new_val) + self.assertEqual(new_val, getattr(composite_field, scalar_field_name)) + # Hold on to a reference to the current composite_field object. + old_composite_field = composite_field + # Assert that the has methods now return true. + self.assertTrue(composite_field.HasField(scalar_field_name)) + self.assertTrue(proto.HasField(composite_field_name)) + + # Now call the clear method... + proto.ClearField(composite_field_name) + + # ...and ensure that the "has" bits are all back to False... + composite_field = getattr(proto, composite_field_name) + self.assertTrue(not composite_field.HasField(scalar_field_name)) + self.assertTrue(not proto.HasField(composite_field_name)) + # ...and ensure that the scalar field has returned to its default. + self.assertEqual(0, getattr(composite_field, scalar_field_name)) + + # Finally, ensure that modifications to the old composite field object + # don't have any effect on the parent. + # + # (NOTE that when we clear the composite field in the parent, we actually + # don't recursively clear down the tree. Instead, we just disconnect the + # cleared composite from the tree.) + self.assertTrue(old_composite_field is not composite_field) + setattr(old_composite_field, scalar_field_name, new_val) + self.assertTrue(not composite_field.HasField(scalar_field_name)) + self.assertTrue(not proto.HasField(composite_field_name)) + self.assertEqual(0, getattr(composite_field, scalar_field_name)) + + # Test simple, single-level nesting when we set a scalar. + TestCompositeHasBits('optionalgroup', 'a') + TestCompositeHasBits('optional_nested_message', 'bb') + TestCompositeHasBits('optional_foreign_message', 'c') + TestCompositeHasBits('optional_import_message', 'd') + + def testReferencesToNestedMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + del proto + # A previous version had a bug where this would raise an exception when + # hitting a now-dead weak reference. + nested.bb = 23 + + def testDisconnectingNestedMessageBeforeSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testHasBitsWhenModifyingRepeatedFields(self): + # Test nesting when we add an element to a repeated field in a submessage. + proto = unittest_pb2.TestNestedMessageHasBits() + proto.optional_nested_message.nestedmessage_repeated_int32.append(5) + self.assertEqual( + [5], proto.optional_nested_message.nestedmessage_repeated_int32) + self.assertTrue(proto.HasField('optional_nested_message')) + + # Do the same test, but with a repeated composite field within the + # submessage. + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add() + self.assertTrue(proto.HasField('optional_nested_message')) + + def testHasBitsForManyLevelsOfNesting(self): + # Test nesting many levels deep. + recursive_proto = unittest_pb2.TestMutualRecursionA() + self.assertTrue(not recursive_proto.HasField('bb')) + self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32) + self.assertTrue(not recursive_proto.HasField('bb')) + recursive_proto.bb.a.bb.a.bb.optional_int32 = 5 + self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32) + self.assertTrue(recursive_proto.HasField('bb')) + self.assertTrue(recursive_proto.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.HasField('bb')) + self.assertTrue(recursive_proto.bb.a.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb')) + self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32')) + + def testSingularListFields(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_fixed32 = 1 + proto.optional_int32 = 5 + proto.optional_string = 'foo' + # Access sub-message but don't set it yet. + nested_message = proto.optional_nested_message + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), + (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), + (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], + proto.ListFields()) + + proto.optional_nested_message.bb = 123 + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), + (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), + (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo'), + (proto.DESCRIPTOR.fields_by_name['optional_nested_message' ], + nested_message) ], + proto.ListFields()) + + def testRepeatedListFields(self): + proto = unittest_pb2.TestAllTypes() + proto.repeated_fixed32.append(1) + proto.repeated_int32.append(5) + proto.repeated_int32.append(11) + proto.repeated_string.extend(['foo', 'bar']) + proto.repeated_string.extend([]) + proto.repeated_string.append('baz') + proto.repeated_string.extend(str(x) for x in xrange(2)) + proto.optional_int32 = 21 + proto.repeated_bool # Access but don't set anything; should not be listed. + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), + (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), + (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]), + (proto.DESCRIPTOR.fields_by_name['repeated_string' ], + ['foo', 'bar', 'baz', '0', '1']) ], + proto.ListFields()) + + def testSingularListExtensions(self): + proto = unittest_pb2.TestAllExtensions() + proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1 + proto.Extensions[unittest_pb2.optional_int32_extension ] = 5 + proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo' + self.assertEqual( + [ (unittest_pb2.optional_int32_extension , 5), + (unittest_pb2.optional_fixed32_extension, 1), + (unittest_pb2.optional_string_extension , 'foo') ], + proto.ListFields()) + + def testRepeatedListExtensions(self): + proto = unittest_pb2.TestAllExtensions() + proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1) + proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5) + proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11) + proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo') + proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar') + proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz') + proto.Extensions[unittest_pb2.optional_int32_extension ] = 21 + self.assertEqual( + [ (unittest_pb2.optional_int32_extension , 21), + (unittest_pb2.repeated_int32_extension , [5, 11]), + (unittest_pb2.repeated_fixed32_extension, [1]), + (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ], + proto.ListFields()) + + def testListFieldsAndExtensions(self): + proto = unittest_pb2.TestFieldOrderings() + test_util.SetAllFieldsAndExtensions(proto) + unittest_pb2.my_extension_int + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1), + (unittest_pb2.my_extension_int , 23), + (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'), + (unittest_pb2.my_extension_string , 'bar'), + (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ], + proto.ListFields()) + + def testDefaultValues(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.optional_int32) + self.assertEqual(0, proto.optional_int64) + self.assertEqual(0, proto.optional_uint32) + self.assertEqual(0, proto.optional_uint64) + self.assertEqual(0, proto.optional_sint32) + self.assertEqual(0, proto.optional_sint64) + self.assertEqual(0, proto.optional_fixed32) + self.assertEqual(0, proto.optional_fixed64) + self.assertEqual(0, proto.optional_sfixed32) + self.assertEqual(0, proto.optional_sfixed64) + self.assertEqual(0.0, proto.optional_float) + self.assertEqual(0.0, proto.optional_double) + self.assertEqual(False, proto.optional_bool) + self.assertEqual('', proto.optional_string) + self.assertEqual('', proto.optional_bytes) + + self.assertEqual(41, proto.default_int32) + self.assertEqual(42, proto.default_int64) + self.assertEqual(43, proto.default_uint32) + self.assertEqual(44, proto.default_uint64) + self.assertEqual(-45, proto.default_sint32) + self.assertEqual(46, proto.default_sint64) + self.assertEqual(47, proto.default_fixed32) + self.assertEqual(48, proto.default_fixed64) + self.assertEqual(49, proto.default_sfixed32) + self.assertEqual(-50, proto.default_sfixed64) + self.assertEqual(51.5, proto.default_float) + self.assertEqual(52e3, proto.default_double) + self.assertEqual(True, proto.default_bool) + self.assertEqual('hello', proto.default_string) + self.assertEqual('world', proto.default_bytes) + self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) + self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) + self.assertEqual(unittest_import_pb2.IMPORT_BAR, + proto.default_import_enum) + + proto = unittest_pb2.TestExtremeDefaultValues() + self.assertEqual(u'\u1234', proto.utf8_string) + + def testHasFieldWithUnknownFieldName(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, proto.HasField, 'nonexistent_field') + + def testClearFieldWithUnknownFieldName(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + + def testDisallowedAssignments(self): + # It's illegal to assign values directly to repeated fields + # or to nonrepeated composite fields. Ensure that this fails. + proto = unittest_pb2.TestAllTypes() + # Repeated fields. + self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10) + # Lists shouldn't work, either. + self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10]) + # Composite fields. + self.assertRaises(AttributeError, setattr, proto, + 'optional_nested_message', 23) + # Assignment to a repeated nested message field without specifying + # the index in the array of nested messages. + self.assertRaises(AttributeError, setattr, proto.repeated_nested_message, + 'bb', 34) + # Assignment to an attribute of a repeated field. + self.assertRaises(AttributeError, setattr, proto.repeated_float, + 'some_attribute', 34) + # proto.nonexistent_field = 23 should fail as well. + self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) + + # TODO(robinson): Add type-safety check for enums. + def testSingleScalarTypeSafety(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) + self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') + self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) + self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + + def testSingleScalarBoundsChecking(self): + def TestMinAndMaxIntegers(field_name, expected_min, expected_max): + pb = unittest_pb2.TestAllTypes() + setattr(pb, field_name, expected_min) + setattr(pb, field_name, expected_max) + self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) + self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) + + TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) + TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) + TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) + TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) + TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1) + + def testRepeatedScalarTypeSafety(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) + self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') + self.assertRaises(TypeError, proto.repeated_string, 10) + self.assertRaises(TypeError, proto.repeated_bytes, 10) + + proto.repeated_int32.append(10) + proto.repeated_int32[0] = 23 + self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) + self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') + + def testSingleScalarGettersAndSetters(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.optional_int32) + proto.optional_int32 = 1 + self.assertEqual(1, proto.optional_int32) + # TODO(robinson): Test all other scalar field types. + + def testSingleScalarClearField(self): + proto = unittest_pb2.TestAllTypes() + # Should be allowed to clear something that's not there (a no-op). + proto.ClearField('optional_int32') + proto.optional_int32 = 1 + self.assertTrue(proto.HasField('optional_int32')) + proto.ClearField('optional_int32') + self.assertEqual(0, proto.optional_int32) + self.assertTrue(not proto.HasField('optional_int32')) + # TODO(robinson): Test all other scalar field types. + + def testEnums(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(1, proto.FOO) + self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) + self.assertEqual(2, proto.BAR) + self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) + self.assertEqual(3, proto.BAZ) + self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + + def testRepeatedScalars(self): + proto = unittest_pb2.TestAllTypes() + + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(5) + proto.repeated_int32.append(10) + proto.repeated_int32.append(15) + self.assertTrue(proto.repeated_int32) + self.assertEqual(3, len(proto.repeated_int32)) + + self.assertEqual([5, 10, 15], proto.repeated_int32) + + # Test single retrieval. + self.assertEqual(5, proto.repeated_int32[0]) + self.assertEqual(15, proto.repeated_int32[-1]) + # Test out-of-bounds indices. + self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) + self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) + # Test incorrect types passed to __getitem__. + self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') + self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) + + # Test single assignment. + proto.repeated_int32[1] = 20 + self.assertEqual([5, 20, 15], proto.repeated_int32) + + # Test insertion. + proto.repeated_int32.insert(1, 25) + self.assertEqual([5, 25, 20, 15], proto.repeated_int32) + + # Test slice retrieval. + proto.repeated_int32.append(30) + self.assertEqual([25, 20, 15], proto.repeated_int32[1:4]) + self.assertEqual([5, 25, 20, 15, 30], proto.repeated_int32[:]) + + # Test slice assignment with an iterator + proto.repeated_int32[1:4] = (i for i in xrange(3)) + self.assertEqual([5, 0, 1, 2, 30], proto.repeated_int32) + + # Test slice assignment. + proto.repeated_int32[1:4] = [35, 40, 45] + self.assertEqual([5, 35, 40, 45, 30], proto.repeated_int32) + + # Test that we can use the field as an iterator. + result = [] + for i in proto.repeated_int32: + result.append(i) + self.assertEqual([5, 35, 40, 45, 30], result) + + # Test single deletion. + del proto.repeated_int32[2] + self.assertEqual([5, 35, 45, 30], proto.repeated_int32) + + # Test slice deletion. + del proto.repeated_int32[2:] + self.assertEqual([5, 35], proto.repeated_int32) + + # Test clearing. + proto.ClearField('repeated_int32') + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + + def testRepeatedScalarsRemove(self): + proto = unittest_pb2.TestAllTypes() + + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(5) + proto.repeated_int32.append(10) + proto.repeated_int32.append(5) + proto.repeated_int32.append(5) + + self.assertEqual(4, len(proto.repeated_int32)) + proto.repeated_int32.remove(5) + self.assertEqual(3, len(proto.repeated_int32)) + self.assertEqual(10, proto.repeated_int32[0]) + self.assertEqual(5, proto.repeated_int32[1]) + self.assertEqual(5, proto.repeated_int32[2]) + + proto.repeated_int32.remove(5) + self.assertEqual(2, len(proto.repeated_int32)) + self.assertEqual(10, proto.repeated_int32[0]) + self.assertEqual(5, proto.repeated_int32[1]) + + proto.repeated_int32.remove(10) + self.assertEqual(1, len(proto.repeated_int32)) + self.assertEqual(5, proto.repeated_int32[0]) + + # Remove a non-existent element. + self.assertRaises(ValueError, proto.repeated_int32.remove, 123) + + def testRepeatedComposites(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.repeated_nested_message) + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + m1 = proto.repeated_nested_message.add() + self.assertTrue(proto.repeated_nested_message) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) + + # Test out-of-bounds indices. + self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, + 1234) + self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, + -1234) + + # Test incorrect types passed to __getitem__. + self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, + 'foo') + self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, + None) + + # Test slice retrieval. + m2 = proto.repeated_nested_message.add() + m3 = proto.repeated_nested_message.add() + m4 = proto.repeated_nested_message.add() + self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4]) + self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + + # Test that we can use the field as an iterator. + result = [] + for i in proto.repeated_nested_message: + result.append(i) + self.assertIs([m0, m1, m2, m3, m4], result) + + # Test single deletion. + del proto.repeated_nested_message[2] + self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message) + + # Test slice deletion. + del proto.repeated_nested_message[2:] + self.assertIs([m0, m1], proto.repeated_nested_message) + + # Test clearing. + proto.ClearField('repeated_nested_message') + self.assertTrue(not proto.repeated_nested_message) + self.assertEqual(0, len(proto.repeated_nested_message)) + + def testHandWrittenReflection(self): + # TODO(robinson): We probably need a better way to specify + # protocol types by hand. But then again, this isn't something + # we expect many people to do. Hmm. + FieldDescriptor = descriptor.FieldDescriptor + foo_field_descriptor = FieldDescriptor( + name='foo_field', full_name='MyProto.foo_field', + index=0, number=1, type=FieldDescriptor.TYPE_INT64, + cpp_type=FieldDescriptor.CPPTYPE_INT64, + label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, + containing_type=None, message_type=None, enum_type=None, + is_extension=False, extension_scope=None, + options=descriptor_pb2.FieldOptions()) + mydescriptor = descriptor.Descriptor( + name='MyProto', full_name='MyProto', filename='ignored', + containing_type=None, nested_types=[], enum_types=[], + fields=[foo_field_descriptor], extensions=[], + options=descriptor_pb2.MessageOptions()) + class MyProtoClass(message.Message): + DESCRIPTOR = mydescriptor + __metaclass__ = reflection.GeneratedProtocolMessageType + myproto_instance = MyProtoClass() + self.assertEqual(0, myproto_instance.foo_field) + self.assertTrue(not myproto_instance.HasField('foo_field')) + myproto_instance.foo_field = 23 + self.assertEqual(23, myproto_instance.foo_field) + self.assertTrue(myproto_instance.HasField('foo_field')) + + def testTopLevelExtensionsForOptionalScalar(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.optional_int32_extension + self.assertTrue(not extendee_proto.HasExtension(extension)) + self.assertEqual(0, extendee_proto.Extensions[extension]) + # As with normal scalar fields, just doing a read doesn't actually set the + # "has" bit. + self.assertTrue(not extendee_proto.HasExtension(extension)) + # Actually set the thing. + extendee_proto.Extensions[extension] = 23 + self.assertEqual(23, extendee_proto.Extensions[extension]) + self.assertTrue(extendee_proto.HasExtension(extension)) + # Ensure that clearing works as well. + extendee_proto.ClearExtension(extension) + self.assertEqual(0, extendee_proto.Extensions[extension]) + self.assertTrue(not extendee_proto.HasExtension(extension)) + + def testTopLevelExtensionsForRepeatedScalar(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.repeated_string_extension + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + extendee_proto.Extensions[extension].append('foo') + self.assertEqual(['foo'], extendee_proto.Extensions[extension]) + string_list = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + self.assertTrue(string_list is not extendee_proto.Extensions[extension]) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testTopLevelExtensionsForOptionalMessage(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.optional_foreign_message_extension + self.assertTrue(not extendee_proto.HasExtension(extension)) + self.assertEqual(0, extendee_proto.Extensions[extension].c) + # As with normal (non-extension) fields, merely reading from the + # thing shouldn't set the "has" bit. + self.assertTrue(not extendee_proto.HasExtension(extension)) + extendee_proto.Extensions[extension].c = 23 + self.assertEqual(23, extendee_proto.Extensions[extension].c) + self.assertTrue(extendee_proto.HasExtension(extension)) + # Save a reference here. + foreign_message = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) + # Setting a field on foreign_message now shouldn't set + # any "has" bits on extendee_proto. + foreign_message.c = 42 + self.assertEqual(42, foreign_message.c) + self.assertTrue(foreign_message.HasField('c')) + self.assertTrue(not extendee_proto.HasExtension(extension)) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testTopLevelExtensionsForRepeatedMessage(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.repeatedgroup_extension + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + group = extendee_proto.Extensions[extension].add() + group.a = 23 + self.assertEqual(23, extendee_proto.Extensions[extension][0].a) + group.a = 42 + self.assertEqual(42, extendee_proto.Extensions[extension][0].a) + group_list = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + self.assertTrue(group_list is not extendee_proto.Extensions[extension]) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testNestedExtensions(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.single + + # We just test the non-repeated case. + self.assertTrue(not extendee_proto.HasExtension(extension)) + required = extendee_proto.Extensions[extension] + self.assertEqual(0, required.a) + self.assertTrue(not extendee_proto.HasExtension(extension)) + required.a = 23 + self.assertEqual(23, extendee_proto.Extensions[extension].a) + self.assertTrue(extendee_proto.HasExtension(extension)) + extendee_proto.ClearExtension(extension) + self.assertTrue(required is not extendee_proto.Extensions[extension]) + self.assertTrue(not extendee_proto.HasExtension(extension)) + + # If message A directly contains message B, and + # a.HasField('b') is currently False, then mutating any + # extension in B should change a.HasField('b') to True + # (and so on up the object tree). + def testHasBitsForAncestorsOfExtendedMessage(self): + # Optional scalar extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension] = 23 + self.assertEqual(23, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertTrue(toplevel.HasField('submessage')) + + # Repeated scalar extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual([], toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension]) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension].append(23) + self.assertEqual([23], toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension]) + self.assertTrue(toplevel.HasField('submessage')) + + # Optional message extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int = 23 + self.assertEqual(23, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int) + self.assertTrue(toplevel.HasField('submessage')) + + # Repeated message extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, len(toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension])) + self.assertTrue(not toplevel.HasField('submessage')) + foreign = toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension].add() + self.assertTrue(foreign is toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension][0]) + self.assertTrue(toplevel.HasField('submessage')) + + def testDisconnectionAfterClearingEmptyMessage(self): + toplevel = more_extensions_pb2.TopLevelMessage() + extendee_proto = toplevel.submessage + extension = more_extensions_pb2.optional_message_extension + extension_proto = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + extension_proto.foreign_message_int = 23 + + self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) + + def testExtensionFailureModes(self): + extendee_proto = unittest_pb2.TestAllExtensions() + + # Try non-extension-handle arguments to HasExtension, + # ClearExtension(), and Extensions[]... + self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) + self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) + self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) + self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) + + # Try something that *is* an extension handle, just not for + # this message... + unknown_handle = more_extensions_pb2.optional_int_extension + self.assertRaises(KeyError, extendee_proto.HasExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.ClearExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, + unknown_handle, 5) + + # Try call HasExtension() with a valid handle, but for a + # *repeated* field. (Just as with non-extension repeated + # fields, Has*() isn't supported for extension repeated fields). + self.assertRaises(KeyError, extendee_proto.HasExtension, + unittest_pb2.repeated_string_extension) + + def testStaticParseFrom(self): + proto1 = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto1) + + string1 = proto1.SerializeToString() + proto2 = unittest_pb2.TestAllTypes.FromString(string1) + + # Messages should be equal. + self.assertEqual(proto2, proto1) + + def testMergeFromSingularField(self): + # Test merge with just a singular field. + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + + proto2 = unittest_pb2.TestAllTypes() + # This shouldn't get overwritten. + proto2.optional_string = 'value' + + proto2.MergeFrom(proto1) + self.assertEqual(1, proto2.optional_int32) + self.assertEqual('value', proto2.optional_string) + + def testMergeFromRepeatedField(self): + # Test merge with just a repeated field. + proto1 = unittest_pb2.TestAllTypes() + proto1.repeated_int32.append(1) + proto1.repeated_int32.append(2) + + proto2 = unittest_pb2.TestAllTypes() + proto2.repeated_int32.append(0) + proto2.MergeFrom(proto1) + + self.assertEqual(0, proto2.repeated_int32[0]) + self.assertEqual(1, proto2.repeated_int32[1]) + self.assertEqual(2, proto2.repeated_int32[2]) + + def testMergeFromOptionalGroup(self): + # Test merge with an optional group. + proto1 = unittest_pb2.TestAllTypes() + proto1.optionalgroup.a = 12 + proto2 = unittest_pb2.TestAllTypes() + proto2.MergeFrom(proto1) + self.assertEqual(12, proto2.optionalgroup.a) + + def testMergeFromRepeatedNestedMessage(self): + # Test merge with a repeated nested message. + proto1 = unittest_pb2.TestAllTypes() + m = proto1.repeated_nested_message.add() + m.bb = 123 + m = proto1.repeated_nested_message.add() + m.bb = 321 + + proto2 = unittest_pb2.TestAllTypes() + m = proto2.repeated_nested_message.add() + m.bb = 999 + proto2.MergeFrom(proto1) + self.assertEqual(999, proto2.repeated_nested_message[0].bb) + self.assertEqual(123, proto2.repeated_nested_message[1].bb) + self.assertEqual(321, proto2.repeated_nested_message[2].bb) + + def testMergeFromAllFields(self): + # With all fields set. + proto1 = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto1) + proto2 = unittest_pb2.TestAllTypes() + proto2.MergeFrom(proto1) + + # Messages should be equal. + self.assertEqual(proto2, proto1) + + # Serialized string should be equal too. + string1 = proto1.SerializeToString() + string2 = proto2.SerializeToString() + self.assertEqual(string1, string2) + + def testMergeFromExtensionsSingular(self): + proto1 = unittest_pb2.TestAllExtensions() + proto1.Extensions[unittest_pb2.optional_int32_extension] = 1 + + proto2 = unittest_pb2.TestAllExtensions() + proto2.MergeFrom(proto1) + self.assertEqual( + 1, proto2.Extensions[unittest_pb2.optional_int32_extension]) + + def testMergeFromExtensionsRepeated(self): + proto1 = unittest_pb2.TestAllExtensions() + proto1.Extensions[unittest_pb2.repeated_int32_extension].append(1) + proto1.Extensions[unittest_pb2.repeated_int32_extension].append(2) + + proto2 = unittest_pb2.TestAllExtensions() + proto2.Extensions[unittest_pb2.repeated_int32_extension].append(0) + proto2.MergeFrom(proto1) + self.assertEqual( + 3, len(proto2.Extensions[unittest_pb2.repeated_int32_extension])) + self.assertEqual( + 0, proto2.Extensions[unittest_pb2.repeated_int32_extension][0]) + self.assertEqual( + 1, proto2.Extensions[unittest_pb2.repeated_int32_extension][1]) + self.assertEqual( + 2, proto2.Extensions[unittest_pb2.repeated_int32_extension][2]) + + def testMergeFromExtensionsNestedMessage(self): + proto1 = unittest_pb2.TestAllExtensions() + ext1 = proto1.Extensions[ + unittest_pb2.repeated_nested_message_extension] + m = ext1.add() + m.bb = 222 + m = ext1.add() + m.bb = 333 + + proto2 = unittest_pb2.TestAllExtensions() + ext2 = proto2.Extensions[ + unittest_pb2.repeated_nested_message_extension] + m = ext2.add() + m.bb = 111 + + proto2.MergeFrom(proto1) + ext2 = proto2.Extensions[ + unittest_pb2.repeated_nested_message_extension] + self.assertEqual(3, len(ext2)) + self.assertEqual(111, ext2[0].bb) + self.assertEqual(222, ext2[1].bb) + self.assertEqual(333, ext2[2].bb) + + def testCopyFromSingularField(self): + # Test copy with just a singular field. + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto1.optional_string = 'important-text' + + proto2 = unittest_pb2.TestAllTypes() + proto2.optional_string = 'value' + + proto2.CopyFrom(proto1) + self.assertEqual(1, proto2.optional_int32) + self.assertEqual('important-text', proto2.optional_string) + + def testCopyFromRepeatedField(self): + # Test copy with a repeated field. + proto1 = unittest_pb2.TestAllTypes() + proto1.repeated_int32.append(1) + proto1.repeated_int32.append(2) + + proto2 = unittest_pb2.TestAllTypes() + proto2.repeated_int32.append(0) + proto2.CopyFrom(proto1) + + self.assertEqual(1, proto2.repeated_int32[0]) + self.assertEqual(2, proto2.repeated_int32[1]) + + def testCopyFromAllFields(self): + # With all fields set. + proto1 = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto1) + proto2 = unittest_pb2.TestAllTypes() + proto2.CopyFrom(proto1) + + # Messages should be equal. + self.assertEqual(proto2, proto1) + + # Serialized string should be equal too. + string1 = proto1.SerializeToString() + string2 = proto2.SerializeToString() + self.assertEqual(string1, string2) + + def testCopyFromSelf(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.repeated_int32.append(1) + proto1.optional_int32 = 2 + proto1.optional_string = 'important-text' + + proto1.CopyFrom(proto1) + self.assertEqual(1, proto1.repeated_int32[0]) + self.assertEqual(2, proto1.optional_int32) + self.assertEqual('important-text', proto1.optional_string) + + def testClear(self): + proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto) + # Clear the message. + proto.Clear() + self.assertEquals(proto.ByteSize(), 0) + empty_proto = unittest_pb2.TestAllTypes() + self.assertEquals(proto, empty_proto) + + # Test if extensions which were set are cleared. + proto = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(proto) + # Clear the message. + proto.Clear() + self.assertEquals(proto.ByteSize(), 0) + empty_proto = unittest_pb2.TestAllExtensions() + self.assertEquals(proto, empty_proto) + + def assertInitialized(self, proto): + self.assertTrue(proto.IsInitialized()) + # Neither method should raise an exception. + proto.SerializeToString() + proto.SerializePartialToString() + + def assertNotInitialized(self, proto): + self.assertFalse(proto.IsInitialized()) + self.assertRaises(message.EncodeError, proto.SerializeToString) + # "Partial" serialization doesn't care if message is uninitialized. + proto.SerializePartialToString() + + def testIsInitialized(self): + # Trivial cases - all optional fields and extensions. + proto = unittest_pb2.TestAllTypes() + self.assertInitialized(proto) + proto = unittest_pb2.TestAllExtensions() + self.assertInitialized(proto) + + # The case of uninitialized required fields. + proto = unittest_pb2.TestRequired() + self.assertNotInitialized(proto) + proto.a = proto.b = proto.c = 2 + self.assertInitialized(proto) + + # The case of uninitialized submessage. + proto = unittest_pb2.TestRequiredForeign() + self.assertInitialized(proto) + proto.optional_message.a = 1 + self.assertNotInitialized(proto) + proto.optional_message.b = 0 + proto.optional_message.c = 0 + self.assertInitialized(proto) + + # Uninitialized repeated submessage. + message1 = proto.repeated_message.add() + self.assertNotInitialized(proto) + message1.a = message1.b = message1.c = 0 + self.assertInitialized(proto) + + # Uninitialized repeated group in an extension. + proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.multi + message1 = proto.Extensions[extension].add() + message2 = proto.Extensions[extension].add() + self.assertNotInitialized(proto) + message1.a = 1 + message1.b = 1 + message1.c = 1 + self.assertNotInitialized(proto) + message2.a = 2 + message2.b = 2 + message2.c = 2 + self.assertInitialized(proto) + + # Uninitialized nonrepeated message in an extension. + proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.single + proto.Extensions[extension].a = 1 + self.assertNotInitialized(proto) + proto.Extensions[extension].b = 2 + proto.Extensions[extension].c = 3 + self.assertInitialized(proto) + + # Try passing an errors list. + errors = [] + proto = unittest_pb2.TestRequired() + self.assertFalse(proto.IsInitialized(errors)) + self.assertEqual(errors, ['a', 'b', 'c']) + + def testStringUTF8Encoding(self): + proto = unittest_pb2.TestAllTypes() + + # Assignment of a unicode object to a field of type 'bytes' is not allowed. + self.assertRaises(TypeError, + setattr, proto, 'optional_bytes', u'unicode object') + + # Check that the default value is of python's 'unicode' type. + self.assertEqual(type(proto.optional_string), unicode) + + proto.optional_string = unicode('Testing') + self.assertEqual(proto.optional_string, str('Testing')) + + # Assign a value of type 'str' which can be encoded in UTF-8. + proto.optional_string = str('Testing') + self.assertEqual(proto.optional_string, unicode('Testing')) + + # Values of type 'str' are also accepted as long as they can be encoded in + # UTF-8. + self.assertEqual(type(proto.optional_string), str) + + # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', str('a\x80a')) + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + # No exception thrown. + proto.optional_string = 'abc' + + def testStringUTF8Serialization(self): + proto = unittest_mset_pb2.TestMessageSet() + extension_message = unittest_mset_pb2.TestMessageSetExtension2 + extension = extension_message.message_set_extension + + test_utf8 = u'Тест' + test_utf8_bytes = test_utf8.encode('utf-8') + + # 'Test' in another language, using UTF-8 charset. + proto.Extensions[extension].str = test_utf8 + + # Serialize using the MessageSet wire format (this is specified in the + # .proto file). + serialized = proto.SerializeToString() + + # Check byte size. + self.assertEqual(proto.ByteSize(), len(serialized)) + + raw = unittest_mset_pb2.RawMessageSet() + raw.MergeFromString(serialized) + + message2 = unittest_mset_pb2.TestMessageSetExtension2() + + self.assertEqual(1, len(raw.item)) + # Check that the type_id is the same as the tag ID in the .proto file. + self.assertEqual(raw.item[0].type_id, 1547769) + + # Check the actually bytes on the wire. + self.assertTrue( + raw.item[0].message.endswith(test_utf8_bytes)) + message2.MergeFromString(raw.item[0].message) + + self.assertEqual(type(message2.str), unicode) + self.assertEqual(message2.str, test_utf8) + + # How about if the bytes on the wire aren't a valid UTF-8 encoded string. + bytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * '\xff') + self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) + + def testEmptyNestedMessage(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.MergeFrom( + unittest_pb2.TestAllTypes.NestedMessage()) + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.CopyFrom( + unittest_pb2.TestAllTypes.NestedMessage()) + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.MergeFromString('') + self.assertTrue(proto.HasField('optional_nested_message')) + + proto = unittest_pb2.TestAllTypes() + proto.optional_nested_message.ParseFromString('') + self.assertTrue(proto.HasField('optional_nested_message')) + + serialized = proto.SerializeToString() + proto2 = unittest_pb2.TestAllTypes() + proto2.MergeFromString(serialized) + self.assertTrue(proto2.HasField('optional_nested_message')) + + def testSetInParent(self): + proto = unittest_pb2.TestAllTypes() + self.assertFalse(proto.HasField('optionalgroup')) + proto.optionalgroup.SetInParent() + self.assertTrue(proto.HasField('optionalgroup')) + + +# Since we had so many tests for protocol buffer equality, we broke these out +# into separate TestCase classes. + + +class TestAllTypesEqualityTest(unittest.TestCase): + + def setUp(self): + self.first_proto = unittest_pb2.TestAllTypes() + self.second_proto = unittest_pb2.TestAllTypes() + + def testSelfEquality(self): + self.assertEqual(self.first_proto, self.first_proto) + + def testEmptyProtosEqual(self): + self.assertEqual(self.first_proto, self.second_proto) + + +class FullProtosEqualityTest(unittest.TestCase): + + """Equality tests using completely-full protos as a starting point.""" + + def setUp(self): + self.first_proto = unittest_pb2.TestAllTypes() + self.second_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.first_proto) + test_util.SetAllFields(self.second_proto) + + def testNoneNotEqual(self): + self.assertNotEqual(self.first_proto, None) + self.assertNotEqual(None, self.second_proto) + + def testNotEqualToOtherMessage(self): + third_proto = unittest_pb2.TestRequired() + self.assertNotEqual(self.first_proto, third_proto) + self.assertNotEqual(third_proto, self.second_proto) + + def testAllFieldsFilledEquality(self): + self.assertEqual(self.first_proto, self.second_proto) + + def testNonRepeatedScalar(self): + # Nonrepeated scalar field change should cause inequality. + self.first_proto.optional_int32 += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + # ...as should clearing a field. + self.first_proto.ClearField('optional_int32') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testNonRepeatedComposite(self): + # Change a nonrepeated composite field. + self.first_proto.optional_nested_message.bb += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.optional_nested_message.bb -= 1 + self.assertEqual(self.first_proto, self.second_proto) + # Clear a field in the nested message. + self.first_proto.optional_nested_message.ClearField('bb') + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.optional_nested_message.bb = ( + self.second_proto.optional_nested_message.bb) + self.assertEqual(self.first_proto, self.second_proto) + # Remove the nested message entirely. + self.first_proto.ClearField('optional_nested_message') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testRepeatedScalar(self): + # Change a repeated scalar field. + self.first_proto.repeated_int32.append(5) + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.ClearField('repeated_int32') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testRepeatedComposite(self): + # Change value within a repeated composite field. + self.first_proto.repeated_nested_message[0].bb += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.repeated_nested_message[0].bb -= 1 + self.assertEqual(self.first_proto, self.second_proto) + # Add a value to a repeated composite field. + self.first_proto.repeated_nested_message.add() + self.assertNotEqual(self.first_proto, self.second_proto) + self.second_proto.repeated_nested_message.add() + self.assertEqual(self.first_proto, self.second_proto) + + def testNonRepeatedScalarHasBits(self): + # Ensure that we test "has" bits as well as value for + # nonrepeated scalar field. + self.first_proto.ClearField('optional_int32') + self.second_proto.optional_int32 = 0 + self.assertNotEqual(self.first_proto, self.second_proto) + + def testNonRepeatedCompositeHasBits(self): + # Ensure that we test "has" bits as well as value for + # nonrepeated composite field. + self.first_proto.ClearField('optional_nested_message') + self.second_proto.optional_nested_message.ClearField('bb') + self.assertNotEqual(self.first_proto, self.second_proto) + # TODO(robinson): Replace next two lines with method + # to set the "has" bit without changing the value, + # if/when such a method exists. + self.first_proto.optional_nested_message.bb = 0 + self.first_proto.optional_nested_message.ClearField('bb') + self.assertEqual(self.first_proto, self.second_proto) + + +class ExtensionEqualityTest(unittest.TestCase): + + def testExtensionEquality(self): + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + self.assertEqual(first_proto, second_proto) + test_util.SetAllExtensions(first_proto) + self.assertNotEqual(first_proto, second_proto) + test_util.SetAllExtensions(second_proto) + self.assertEqual(first_proto, second_proto) + + # Ensure that we check value equality. + first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1 + self.assertNotEqual(first_proto, second_proto) + first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1 + self.assertEqual(first_proto, second_proto) + + # Ensure that we also look at "has" bits. + first_proto.ClearExtension(unittest_pb2.optional_int32_extension) + second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 + self.assertNotEqual(first_proto, second_proto) + first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 + self.assertEqual(first_proto, second_proto) + + # Ensure that differences in cached values + # don't matter if "has" bits are both false. + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + self.assertEqual( + 0, first_proto.Extensions[unittest_pb2.optional_int32_extension]) + self.assertEqual(first_proto, second_proto) + + +class MutualRecursionEqualityTest(unittest.TestCase): + + def testEqualityWithMutualRecursion(self): + first_proto = unittest_pb2.TestMutualRecursionA() + second_proto = unittest_pb2.TestMutualRecursionA() + self.assertEqual(first_proto, second_proto) + first_proto.bb.a.bb.optional_int32 = 23 + self.assertNotEqual(first_proto, second_proto) + second_proto.bb.a.bb.optional_int32 = 23 + self.assertEqual(first_proto, second_proto) + + +class ByteSizeTest(unittest.TestCase): + + def setUp(self): + self.proto = unittest_pb2.TestAllTypes() + self.extended_proto = more_extensions_pb2.ExtendedMessage() + self.packed_proto = unittest_pb2.TestPackedTypes() + self.packed_extended_proto = unittest_pb2.TestPackedExtensions() + + def Size(self): + return self.proto.ByteSize() + + def testEmptyMessage(self): + self.assertEqual(0, self.proto.ByteSize()) + + def testVarints(self): + def Test(i, expected_varint_size): + self.proto.Clear() + self.proto.optional_int64 = i + # Add one to the varint size for the tag info + # for tag 1. + self.assertEqual(expected_varint_size + 1, self.Size()) + Test(0, 1) + Test(1, 1) + for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)): + Test((1 << i) - 1, num_bytes) + Test(-1, 10) + Test(-2, 10) + Test(-(1 << 63), 10) + + def testStrings(self): + self.proto.optional_string = '' + # Need one byte for tag info (tag #14), and one byte for length. + self.assertEqual(2, self.Size()) + + self.proto.optional_string = 'abc' + # Need one byte for tag info (tag #14), and one byte for length. + self.assertEqual(2 + len(self.proto.optional_string), self.Size()) + + self.proto.optional_string = 'x' * 128 + # Need one byte for tag info (tag #14), and TWO bytes for length. + self.assertEqual(3 + len(self.proto.optional_string), self.Size()) + + def testOtherNumerics(self): + self.proto.optional_fixed32 = 1234 + # One byte for tag and 4 bytes for fixed32. + self.assertEqual(5, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_fixed64 = 1234 + # One byte for tag and 8 bytes for fixed64. + self.assertEqual(9, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_float = 1.234 + # One byte for tag and 4 bytes for float. + self.assertEqual(5, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_double = 1.234 + # One byte for tag and 8 bytes for float. + self.assertEqual(9, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_sint32 = 64 + # One byte for tag and 2 bytes for zig-zag-encoded 64. + self.assertEqual(3, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + def testComposites(self): + # 3 bytes. + self.proto.optional_nested_message.bb = (1 << 14) + # Plus one byte for bb tag. + # Plus 1 byte for optional_nested_message serialized size. + # Plus two bytes for optional_nested_message tag. + self.assertEqual(3 + 1 + 1 + 2, self.Size()) + + def testGroups(self): + # 4 bytes. + self.proto.optionalgroup.a = (1 << 21) + # Plus two bytes for |a| tag. + # Plus 2 * two bytes for START_GROUP and END_GROUP tags. + self.assertEqual(4 + 2 + 2*2, self.Size()) + + def testRepeatedScalars(self): + self.proto.repeated_int32.append(10) # 1 byte. + self.proto.repeated_int32.append(128) # 2 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + + def testRepeatedScalarsExtend(self): + self.proto.repeated_int32.extend([10, 128]) # 3 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + + def testRepeatedScalarsRemove(self): + self.proto.repeated_int32.append(10) # 1 byte. + self.proto.repeated_int32.append(128) # 2 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + self.proto.repeated_int32.remove(128) + self.assertEqual(1 + 2, self.Size()) + + def testRepeatedComposites(self): + # Empty message. 2 bytes tag plus 1 byte length. + foreign_message_0 = self.proto.repeated_nested_message.add() + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + foreign_message_1 = self.proto.repeated_nested_message.add() + foreign_message_1.bb = 7 + self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) + + def testRepeatedCompositesDelete(self): + # Empty message. 2 bytes tag plus 1 byte length. + foreign_message_0 = self.proto.repeated_nested_message.add() + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + foreign_message_1 = self.proto.repeated_nested_message.add() + foreign_message_1.bb = 9 + self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) + + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + del self.proto.repeated_nested_message[0] + self.assertEqual(2 + 1 + 1 + 1, self.Size()) + + # Now add a new message. + foreign_message_2 = self.proto.repeated_nested_message.add() + foreign_message_2.bb = 12 + + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + self.assertEqual(2 + 1 + 1 + 1 + 2 + 1 + 1 + 1, self.Size()) + + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. + del self.proto.repeated_nested_message[1] + self.assertEqual(2 + 1 + 1 + 1, self.Size()) + + del self.proto.repeated_nested_message[0] + self.assertEqual(0, self.Size()) + + def testRepeatedGroups(self): + # 2-byte START_GROUP plus 2-byte END_GROUP. + group_0 = self.proto.repeatedgroup.add() + # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a| + # plus 2-byte END_GROUP. + group_1 = self.proto.repeatedgroup.add() + group_1.a = 7 + self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size()) + + def testExtensions(self): + proto = unittest_pb2.TestAllExtensions() + self.assertEqual(0, proto.ByteSize()) + extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte. + proto.Extensions[extension] = 23 + # 1 byte for tag, 1 byte for value. + self.assertEqual(2, proto.ByteSize()) + + def testCacheInvalidationForNonrepeatedScalar(self): + # Test non-extension. + self.proto.optional_int32 = 1 + self.assertEqual(2, self.proto.ByteSize()) + self.proto.optional_int32 = 128 + self.assertEqual(3, self.proto.ByteSize()) + self.proto.ClearField('optional_int32') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.optional_int_extension + self.extended_proto.Extensions[extension] = 1 + self.assertEqual(2, self.extended_proto.ByteSize()) + self.extended_proto.Extensions[extension] = 128 + self.assertEqual(3, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForRepeatedScalar(self): + # Test non-extension. + self.proto.repeated_int32.append(1) + self.assertEqual(3, self.proto.ByteSize()) + self.proto.repeated_int32.append(1) + self.assertEqual(6, self.proto.ByteSize()) + self.proto.repeated_int32[1] = 128 + self.assertEqual(7, self.proto.ByteSize()) + self.proto.ClearField('repeated_int32') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.repeated_int_extension + repeated = self.extended_proto.Extensions[extension] + repeated.append(1) + self.assertEqual(2, self.extended_proto.ByteSize()) + repeated.append(1) + self.assertEqual(4, self.extended_proto.ByteSize()) + repeated[1] = 128 + self.assertEqual(5, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForNonrepeatedMessage(self): + # Test non-extension. + self.proto.optional_foreign_message.c = 1 + self.assertEqual(5, self.proto.ByteSize()) + self.proto.optional_foreign_message.c = 128 + self.assertEqual(6, self.proto.ByteSize()) + self.proto.optional_foreign_message.ClearField('c') + self.assertEqual(3, self.proto.ByteSize()) + self.proto.ClearField('optional_foreign_message') + self.assertEqual(0, self.proto.ByteSize()) + child = self.proto.optional_foreign_message + self.proto.ClearField('optional_foreign_message') + child.c = 128 + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.optional_message_extension + child = self.extended_proto.Extensions[extension] + self.assertEqual(0, self.extended_proto.ByteSize()) + child.foreign_message_int = 1 + self.assertEqual(4, self.extended_proto.ByteSize()) + child.foreign_message_int = 128 + self.assertEqual(5, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForRepeatedMessage(self): + # Test non-extension. + child0 = self.proto.repeated_foreign_message.add() + self.assertEqual(3, self.proto.ByteSize()) + self.proto.repeated_foreign_message.add() + self.assertEqual(6, self.proto.ByteSize()) + child0.c = 1 + self.assertEqual(8, self.proto.ByteSize()) + self.proto.ClearField('repeated_foreign_message') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.repeated_message_extension + child_list = self.extended_proto.Extensions[extension] + child0 = child_list.add() + self.assertEqual(2, self.extended_proto.ByteSize()) + child_list.add() + self.assertEqual(4, self.extended_proto.ByteSize()) + child0.foreign_message_int = 1 + self.assertEqual(6, self.extended_proto.ByteSize()) + child0.ClearField('foreign_message_int') + self.assertEqual(4, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testPackedRepeatedScalars(self): + self.assertEqual(0, self.packed_proto.ByteSize()) + + self.packed_proto.packed_int32.append(10) # 1 byte. + self.packed_proto.packed_int32.append(128) # 2 bytes. + # The tag is 2 bytes (the field number is 90), and the varint + # storing the length is 1 byte. + int_size = 1 + 2 + 3 + self.assertEqual(int_size, self.packed_proto.ByteSize()) + + self.packed_proto.packed_double.append(4.2) # 8 bytes + self.packed_proto.packed_double.append(3.25) # 8 bytes + # 2 more tag bytes, 1 more length byte. + double_size = 8 + 8 + 3 + self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) + + self.packed_proto.ClearField('packed_int32') + self.assertEqual(double_size, self.packed_proto.ByteSize()) + + def testPackedExtensions(self): + self.assertEqual(0, self.packed_extended_proto.ByteSize()) + extension = self.packed_extended_proto.Extensions[ + unittest_pb2.packed_fixed32_extension] + extension.extend([1, 2, 3, 4]) # 16 bytes + # Tag is 3 bytes. + self.assertEqual(19, self.packed_extended_proto.ByteSize()) + + +# TODO(robinson): We need cross-language serialization consistency tests. +# Issues to be sure to cover include: +# * Handling of unrecognized tags ("uninterpreted_bytes"). +# * Handling of MessageSets. +# * Consistent ordering of tags in the wire format, +# including ordering between extensions and non-extension +# fields. +# * Consistent serialization of negative numbers, especially +# negative int32s. +# * Handling of empty submessages (with and without "has" +# bits set). + +class SerializationTest(unittest.TestCase): + + def testSerializeEmtpyMessage(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllFields(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(first_proto) + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllExtensions(self): + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(first_proto) + serialized = first_proto.SerializeToString() + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeNegativeValues(self): + first_proto = unittest_pb2.TestAllTypes() + + first_proto.optional_int32 = -1 + first_proto.optional_int64 = -(2 << 40) + first_proto.optional_sint32 = -3 + first_proto.optional_sint64 = -(4 << 40) + first_proto.optional_sfixed32 = -5 + first_proto.optional_sfixed64 = -(6 << 40) + + second_proto = unittest_pb2.TestAllTypes.FromString( + first_proto.SerializeToString()) + + self.assertEqual(first_proto, second_proto) + + def testParseTruncated(self): + first_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(first_proto) + serialized = first_proto.SerializeToString() + + for truncation_point in xrange(len(serialized) + 1): + try: + second_proto = unittest_pb2.TestAllTypes() + unknown_fields = unittest_pb2.TestEmptyMessage() + pos = second_proto._InternalParse(serialized, 0, truncation_point) + # If we didn't raise an error then we read exactly the amount expected. + self.assertEqual(truncation_point, pos) + + # Parsing to unknown fields should not throw if parsing to known fields + # did not. + try: + pos2 = unknown_fields._InternalParse(serialized, 0, truncation_point) + self.assertEqual(truncation_point, pos2) + except message.DecodeError: + self.fail('Parsing unknown fields failed when parsing known fields ' + 'did not.') + except message.DecodeError: + # Parsing unknown fields should also fail. + self.assertRaises(message.DecodeError, unknown_fields._InternalParse, + serialized, 0, truncation_point) + + def testCanonicalSerializationOrder(self): + proto = more_messages_pb2.OutOfOrderFields() + # These are also their tag numbers. Even though we're setting these in + # reverse-tag order AND they're listed in reverse tag-order in the .proto + # file, they should nonetheless be serialized in tag order. + proto.optional_sint32 = 5 + proto.Extensions[more_messages_pb2.optional_uint64] = 4 + proto.optional_uint32 = 3 + proto.Extensions[more_messages_pb2.optional_int64] = 2 + proto.optional_int32 = 1 + serialized = proto.SerializeToString() + self.assertEqual(proto.ByteSize(), len(serialized)) + d = _MiniDecoder(serialized) + ReadTag = d.ReadFieldNumberAndWireType + self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(1, d.ReadInt32()) + self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(2, d.ReadInt64()) + self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(3, d.ReadUInt32()) + self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(4, d.ReadUInt64()) + self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(5, d.ReadSInt32()) + + def testCanonicalSerializationOrderSameAsCpp(self): + # Copy of the same test we use for C++. + proto = unittest_pb2.TestFieldOrderings() + test_util.SetAllFieldsAndExtensions(proto) + serialized = proto.SerializeToString() + test_util.ExpectAllFieldsAndExtensionsInOrder(serialized) + + def testMergeFromStringWhenFieldsAlreadySet(self): + first_proto = unittest_pb2.TestAllTypes() + first_proto.repeated_string.append('foobar') + first_proto.optional_int32 = 23 + first_proto.optional_nested_message.bb = 42 + serialized = first_proto.SerializeToString() + + second_proto = unittest_pb2.TestAllTypes() + second_proto.repeated_string.append('baz') + second_proto.optional_int32 = 100 + second_proto.optional_nested_message.bb = 999 + + second_proto.MergeFromString(serialized) + # Ensure that we append to repeated fields. + self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) + # Ensure that we overwrite nonrepeatd scalars. + self.assertEqual(23, second_proto.optional_int32) + # Ensure that we recursively call MergeFromString() on + # submessages. + self.assertEqual(42, second_proto.optional_nested_message.bb) + + def testMessageSetWireFormat(self): + proto = unittest_mset_pb2.TestMessageSet() + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 + extension1 = extension_message1.message_set_extension + extension2 = extension_message2.message_set_extension + proto.Extensions[extension1].i = 123 + proto.Extensions[extension2].str = 'foo' + + # Serialize using the MessageSet wire format (this is specified in the + # .proto file). + serialized = proto.SerializeToString() + + raw = unittest_mset_pb2.RawMessageSet() + self.assertEqual(False, + raw.DESCRIPTOR.GetOptions().message_set_wire_format) + raw.MergeFromString(serialized) + self.assertEqual(2, len(raw.item)) + + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.MergeFromString(raw.item[0].message) + self.assertEqual(123, message1.i) + + message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2.MergeFromString(raw.item[1].message) + self.assertEqual('foo', message2.str) + + # Deserialize using the MessageSet wire format. + proto2 = unittest_mset_pb2.TestMessageSet() + proto2.MergeFromString(serialized) + self.assertEqual(123, proto2.Extensions[extension1].i) + self.assertEqual('foo', proto2.Extensions[extension2].str) + + # Check byte size. + self.assertEqual(proto2.ByteSize(), len(serialized)) + self.assertEqual(proto.ByteSize(), len(serialized)) + + def testMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an item. + item = raw.item.add() + item.type_id = 1545008 + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + # Add a second, unknown extension. + item = raw.item.add() + item.type_id = 1545009 + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12346 + item.message = message1.SerializeToString() + + # Add another unknown extension. + item = raw.item.add() + item.type_id = 1545010 + message1 = unittest_mset_pb2.TestMessageSetExtension2() + message1.str = 'foo' + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Check that the message parsed well. + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension1 = extension_message1.message_set_extension + self.assertEquals(12345, proto.Extensions[extension1].i) + + def testUnknownFields(self): + proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto) + + serialized = proto.SerializeToString() + + # The empty message should be parsable with all of the fields + # unknown. + proto2 = unittest_pb2.TestEmptyMessage() + + # Parsing this message should succeed. + proto2.MergeFromString(serialized) + + # Now test with a int64 field set. + proto = unittest_pb2.TestAllTypes() + proto.optional_int64 = 0x0fffffffffffffff + serialized = proto.SerializeToString() + # The empty message should be parsable with all of the fields + # unknown. + proto2 = unittest_pb2.TestEmptyMessage() + # Parsing this message should succeed. + proto2.MergeFromString(serialized) + + def _CheckRaises(self, exc_class, callable_obj, exception): + """This method checks if the excpetion type and message are as expected.""" + try: + callable_obj() + except exc_class, ex: + # Check if the exception message is the right one. + self.assertEqual(exception, str(ex)) + return + else: + raise self.failureException('%s not raised' % str(exc_class)) + + def testSerializeUninitialized(self): + proto = unittest_pb2.TestRequired() + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: a,b,c') + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto.a = 1 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: b,c') + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto.b = 2 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: c') + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto.c = 3 + serialized = proto.SerializeToString() + # Shouldn't raise exceptions. + partial = proto.SerializePartialToString() + + proto2 = unittest_pb2.TestRequired() + proto2.MergeFromString(serialized) + self.assertEqual(1, proto2.a) + self.assertEqual(2, proto2.b) + self.assertEqual(3, proto2.c) + proto2.ParseFromString(partial) + self.assertEqual(1, proto2.a) + self.assertEqual(2, proto2.b) + self.assertEqual(3, proto2.c) + + def testSerializeUninitializedSubMessage(self): + proto = unittest_pb2.TestRequiredForeign() + + # Sub-message doesn't exist yet, so this succeeds. + proto.SerializeToString() + + proto.optional_message.a = 1 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: ' + 'optional_message.b,optional_message.c') + + proto.optional_message.b = 2 + proto.optional_message.c = 3 + proto.SerializeToString() + + proto.repeated_message.add().a = 1 + proto.repeated_message.add().b = 2 + self._CheckRaises( + message.EncodeError, + proto.SerializeToString, + 'Message is missing required fields: ' + 'repeated_message[0].b,repeated_message[0].c,' + 'repeated_message[1].a,repeated_message[1].c') + + proto.repeated_message[0].b = 2 + proto.repeated_message[0].c = 3 + proto.repeated_message[1].a = 1 + proto.repeated_message[1].c = 3 + proto.SerializeToString() + + def testSerializeAllPackedFields(self): + first_proto = unittest_pb2.TestPackedTypes() + second_proto = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(first_proto) + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + bytes_read = second_proto.MergeFromString(serialized) + self.assertEqual(second_proto.ByteSize(), bytes_read) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllPackedExtensions(self): + first_proto = unittest_pb2.TestPackedExtensions() + second_proto = unittest_pb2.TestPackedExtensions() + test_util.SetAllPackedExtensions(first_proto) + serialized = first_proto.SerializeToString() + bytes_read = second_proto.MergeFromString(serialized) + self.assertEqual(second_proto.ByteSize(), bytes_read) + self.assertEqual(first_proto, second_proto) + + def testMergePackedFromStringWhenSomeFieldsAlreadySet(self): + first_proto = unittest_pb2.TestPackedTypes() + first_proto.packed_int32.extend([1, 2]) + first_proto.packed_double.append(3.0) + serialized = first_proto.SerializeToString() + + second_proto = unittest_pb2.TestPackedTypes() + second_proto.packed_int32.append(3) + second_proto.packed_double.extend([1.0, 2.0]) + second_proto.packed_sint32.append(4) + + second_proto.MergeFromString(serialized) + self.assertEqual([3, 1, 2], second_proto.packed_int32) + self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) + self.assertEqual([4], second_proto.packed_sint32) + + def testPackedFieldsWireFormat(self): + proto = unittest_pb2.TestPackedTypes() + proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes + proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes + proto.packed_float.append(2.0) # 4 bytes, will be before double + serialized = proto.SerializeToString() + self.assertEqual(proto.ByteSize(), len(serialized)) + d = _MiniDecoder(serialized) + ReadTag = d.ReadFieldNumberAndWireType + self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(1+1+1+2, d.ReadInt32()) + self.assertEqual(1, d.ReadInt32()) + self.assertEqual(2, d.ReadInt32()) + self.assertEqual(150, d.ReadInt32()) + self.assertEqual(3, d.ReadInt32()) + self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(4, d.ReadInt32()) + self.assertEqual(2.0, d.ReadFloat()) + self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(8+8, d.ReadInt32()) + self.assertEqual(1.0, d.ReadDouble()) + self.assertEqual(1000.0, d.ReadDouble()) + self.assertTrue(d.EndOfStream()) + + def testParsePackedFromUnpacked(self): + unpacked = unittest_pb2.TestUnpackedTypes() + test_util.SetAllUnpackedFields(unpacked) + packed = unittest_pb2.TestPackedTypes() + packed.MergeFromString(unpacked.SerializeToString()) + expected = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(expected) + self.assertEqual(expected, packed) + + def testParseUnpackedFromPacked(self): + packed = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(packed) + unpacked = unittest_pb2.TestUnpackedTypes() + unpacked.MergeFromString(packed.SerializeToString()) + expected = unittest_pb2.TestUnpackedTypes() + test_util.SetAllUnpackedFields(expected) + self.assertEqual(expected, unpacked) + + def testFieldNumbers(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(unittest_pb2.TestAllTypes.NestedMessage.BB_FIELD_NUMBER, 1) + self.assertEqual(unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER, 1) + self.assertEqual(unittest_pb2.TestAllTypes.OPTIONALGROUP_FIELD_NUMBER, 16) + self.assertEqual( + unittest_pb2.TestAllTypes.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, 18) + self.assertEqual( + unittest_pb2.TestAllTypes.OPTIONAL_NESTED_ENUM_FIELD_NUMBER, 21) + self.assertEqual(unittest_pb2.TestAllTypes.REPEATED_INT32_FIELD_NUMBER, 31) + self.assertEqual(unittest_pb2.TestAllTypes.REPEATEDGROUP_FIELD_NUMBER, 46) + self.assertEqual( + unittest_pb2.TestAllTypes.REPEATED_NESTED_MESSAGE_FIELD_NUMBER, 48) + self.assertEqual( + unittest_pb2.TestAllTypes.REPEATED_NESTED_ENUM_FIELD_NUMBER, 51) + + def testExtensionFieldNumbers(self): + self.assertEqual(unittest_pb2.TestRequired.single.number, 1000) + self.assertEqual(unittest_pb2.TestRequired.SINGLE_FIELD_NUMBER, 1000) + self.assertEqual(unittest_pb2.TestRequired.multi.number, 1001) + self.assertEqual(unittest_pb2.TestRequired.MULTI_FIELD_NUMBER, 1001) + self.assertEqual(unittest_pb2.optional_int32_extension.number, 1) + self.assertEqual(unittest_pb2.OPTIONAL_INT32_EXTENSION_FIELD_NUMBER, 1) + self.assertEqual(unittest_pb2.optionalgroup_extension.number, 16) + self.assertEqual(unittest_pb2.OPTIONALGROUP_EXTENSION_FIELD_NUMBER, 16) + self.assertEqual(unittest_pb2.optional_nested_message_extension.number, 18) + self.assertEqual( + unittest_pb2.OPTIONAL_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 18) + self.assertEqual(unittest_pb2.optional_nested_enum_extension.number, 21) + self.assertEqual(unittest_pb2.OPTIONAL_NESTED_ENUM_EXTENSION_FIELD_NUMBER, + 21) + self.assertEqual(unittest_pb2.repeated_int32_extension.number, 31) + self.assertEqual(unittest_pb2.REPEATED_INT32_EXTENSION_FIELD_NUMBER, 31) + self.assertEqual(unittest_pb2.repeatedgroup_extension.number, 46) + self.assertEqual(unittest_pb2.REPEATEDGROUP_EXTENSION_FIELD_NUMBER, 46) + self.assertEqual(unittest_pb2.repeated_nested_message_extension.number, 48) + self.assertEqual( + unittest_pb2.REPEATED_NESTED_MESSAGE_EXTENSION_FIELD_NUMBER, 48) + self.assertEqual(unittest_pb2.repeated_nested_enum_extension.number, 51) + self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, + 51) + + def testInitKwargs(self): + proto = unittest_pb2.TestAllTypes( + optional_int32=1, + optional_string='foo', + optional_bool=True, + optional_bytes='bar', + optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), + optional_foreign_message=unittest_pb2.ForeignMessage(c=1), + optional_nested_enum=unittest_pb2.TestAllTypes.FOO, + optional_foreign_enum=unittest_pb2.FOREIGN_FOO, + repeated_int32=[1, 2, 3]) + self.assertTrue(proto.IsInitialized()) + self.assertTrue(proto.HasField('optional_int32')) + self.assertTrue(proto.HasField('optional_string')) + self.assertTrue(proto.HasField('optional_bool')) + self.assertTrue(proto.HasField('optional_bytes')) + self.assertTrue(proto.HasField('optional_nested_message')) + self.assertTrue(proto.HasField('optional_foreign_message')) + self.assertTrue(proto.HasField('optional_nested_enum')) + self.assertTrue(proto.HasField('optional_foreign_enum')) + self.assertEqual(1, proto.optional_int32) + self.assertEqual('foo', proto.optional_string) + self.assertEqual(True, proto.optional_bool) + self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(1, proto.optional_nested_message.bb) + self.assertEqual(1, proto.optional_foreign_message.c) + self.assertEqual(unittest_pb2.TestAllTypes.FOO, + proto.optional_nested_enum) + self.assertEqual(unittest_pb2.FOREIGN_FOO, proto.optional_foreign_enum) + self.assertEqual([1, 2, 3], proto.repeated_int32) + + def testInitArgsUnknownFieldName(self): + def InitalizeEmptyMessageWithExtraKeywordArg(): + unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') + self._CheckRaises(ValueError, + InitalizeEmptyMessageWithExtraKeywordArg, + 'Protocol message has no "unknown" field.') + + def testInitRequiredKwargs(self): + proto = unittest_pb2.TestRequired(a=1, b=1, c=1) + self.assertTrue(proto.IsInitialized()) + self.assertTrue(proto.HasField('a')) + self.assertTrue(proto.HasField('b')) + self.assertTrue(proto.HasField('c')) + self.assertTrue(not proto.HasField('dummy2')) + self.assertEqual(1, proto.a) + self.assertEqual(1, proto.b) + self.assertEqual(1, proto.c) + + def testInitRequiredForeignKwargs(self): + proto = unittest_pb2.TestRequiredForeign( + optional_message=unittest_pb2.TestRequired(a=1, b=1, c=1)) + self.assertTrue(proto.IsInitialized()) + self.assertTrue(proto.HasField('optional_message')) + self.assertTrue(proto.optional_message.IsInitialized()) + self.assertTrue(proto.optional_message.HasField('a')) + self.assertTrue(proto.optional_message.HasField('b')) + self.assertTrue(proto.optional_message.HasField('c')) + self.assertTrue(not proto.optional_message.HasField('dummy2')) + self.assertEqual(unittest_pb2.TestRequired(a=1, b=1, c=1), + proto.optional_message) + self.assertEqual(1, proto.optional_message.a) + self.assertEqual(1, proto.optional_message.b) + self.assertEqual(1, proto.optional_message.c) + + def testInitRepeatedKwargs(self): + proto = unittest_pb2.TestAllTypes(repeated_int32=[1, 2, 3]) + self.assertTrue(proto.IsInitialized()) + self.assertEqual(1, proto.repeated_int32[0]) + self.assertEqual(2, proto.repeated_int32[1]) + self.assertEqual(3, proto.repeated_int32[2]) + + +class OptionsTest(unittest.TestCase): + + def testMessageOptions(self): + proto = unittest_mset_pb2.TestMessageSet() + self.assertEqual(True, + proto.DESCRIPTOR.GetOptions().message_set_wire_format) + proto = unittest_pb2.TestAllTypes() + self.assertEqual(False, + proto.DESCRIPTOR.GetOptions().message_set_wire_format) + + def testPackedOptions(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_int32 = 1 + proto.optional_double = 3.0 + for field_descriptor, _ in proto.ListFields(): + self.assertEqual(False, field_descriptor.GetOptions().packed) + + proto = unittest_pb2.TestPackedTypes() + proto.packed_int32.append(1) + proto.packed_double.append(3.0) + for field_descriptor, _ in proto.ListFields(): + self.assertEqual(True, field_descriptor.GetOptions().packed) + self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, + field_descriptor.label) + + + +if __name__ == '__main__': + unittest.main() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/service_reflection_test.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/service_reflection_test.py index d90c1788..e04f8252 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/service_reflection_test.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/service_reflection_test.py @@ -1,136 +1,136 @@ -#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Tests for google.protobuf.internal.service_reflection."""
-
-__author__ = '[email protected] (Petar Petrov)'
-
-import unittest
-from google.protobuf import unittest_pb2
-from google.protobuf import service_reflection
-from google.protobuf import service
-
-
-class FooUnitTest(unittest.TestCase):
-
- def testService(self):
- class MockRpcChannel(service.RpcChannel):
- def CallMethod(self, method, controller, request, response, callback):
- self.method = method
- self.controller = controller
- self.request = request
- callback(response)
-
- class MockRpcController(service.RpcController):
- def SetFailed(self, msg):
- self.failure_message = msg
-
- self.callback_response = None
-
- class MyService(unittest_pb2.TestService):
- pass
-
- self.callback_response = None
-
- def MyCallback(response):
- self.callback_response = response
-
- rpc_controller = MockRpcController()
- channel = MockRpcChannel()
- srvc = MyService()
- srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
- self.assertEqual('Method Foo not implemented.',
- rpc_controller.failure_message)
- self.assertEqual(None, self.callback_response)
-
- rpc_controller.failure_message = None
-
- service_descriptor = unittest_pb2.TestService.GetDescriptor()
- srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
- unittest_pb2.BarRequest(), MyCallback)
- self.assertEqual('Method Bar not implemented.',
- rpc_controller.failure_message)
- self.assertEqual(None, self.callback_response)
-
- class MyServiceImpl(unittest_pb2.TestService):
- def Foo(self, rpc_controller, request, done):
- self.foo_called = True
- def Bar(self, rpc_controller, request, done):
- self.bar_called = True
-
- srvc = MyServiceImpl()
- rpc_controller.failure_message = None
- srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback)
- self.assertEqual(None, rpc_controller.failure_message)
- self.assertEqual(True, srvc.foo_called)
-
- rpc_controller.failure_message = None
- srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
- unittest_pb2.BarRequest(), MyCallback)
- self.assertEqual(None, rpc_controller.failure_message)
- self.assertEqual(True, srvc.bar_called)
-
- def testServiceStub(self):
- class MockRpcChannel(service.RpcChannel):
- def CallMethod(self, method, controller, request,
- response_class, callback):
- self.method = method
- self.controller = controller
- self.request = request
- callback(response_class())
-
- self.callback_response = None
-
- def MyCallback(response):
- self.callback_response = response
-
- channel = MockRpcChannel()
- stub = unittest_pb2.TestService_Stub(channel)
- rpc_controller = 'controller'
- request = 'request'
-
- # GetDescriptor now static, still works as instance method for compatability
- self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(),
- stub.GetDescriptor())
-
- # Invoke method.
- stub.Foo(rpc_controller, request, MyCallback)
-
- self.assertTrue(isinstance(self.callback_response,
- unittest_pb2.FooResponse))
- self.assertEqual(request, channel.request)
- self.assertEqual(rpc_controller, channel.controller)
- self.assertEqual(stub.GetDescriptor().methods[0], channel.method)
-
-
-if __name__ == '__main__':
- unittest.main()
+#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.internal.service_reflection.""" + +__author__ = '[email protected] (Petar Petrov)' + +import unittest +from google.protobuf import unittest_pb2 +from google.protobuf import service_reflection +from google.protobuf import service + + +class FooUnitTest(unittest.TestCase): + + def testService(self): + class MockRpcChannel(service.RpcChannel): + def CallMethod(self, method, controller, request, response, callback): + self.method = method + self.controller = controller + self.request = request + callback(response) + + class MockRpcController(service.RpcController): + def SetFailed(self, msg): + self.failure_message = msg + + self.callback_response = None + + class MyService(unittest_pb2.TestService): + pass + + self.callback_response = None + + def MyCallback(response): + self.callback_response = response + + rpc_controller = MockRpcController() + channel = MockRpcChannel() + srvc = MyService() + srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback) + self.assertEqual('Method Foo not implemented.', + rpc_controller.failure_message) + self.assertEqual(None, self.callback_response) + + rpc_controller.failure_message = None + + service_descriptor = unittest_pb2.TestService.GetDescriptor() + srvc.CallMethod(service_descriptor.methods[1], rpc_controller, + unittest_pb2.BarRequest(), MyCallback) + self.assertEqual('Method Bar not implemented.', + rpc_controller.failure_message) + self.assertEqual(None, self.callback_response) + + class MyServiceImpl(unittest_pb2.TestService): + def Foo(self, rpc_controller, request, done): + self.foo_called = True + def Bar(self, rpc_controller, request, done): + self.bar_called = True + + srvc = MyServiceImpl() + rpc_controller.failure_message = None + srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback) + self.assertEqual(None, rpc_controller.failure_message) + self.assertEqual(True, srvc.foo_called) + + rpc_controller.failure_message = None + srvc.CallMethod(service_descriptor.methods[1], rpc_controller, + unittest_pb2.BarRequest(), MyCallback) + self.assertEqual(None, rpc_controller.failure_message) + self.assertEqual(True, srvc.bar_called) + + def testServiceStub(self): + class MockRpcChannel(service.RpcChannel): + def CallMethod(self, method, controller, request, + response_class, callback): + self.method = method + self.controller = controller + self.request = request + callback(response_class()) + + self.callback_response = None + + def MyCallback(response): + self.callback_response = response + + channel = MockRpcChannel() + stub = unittest_pb2.TestService_Stub(channel) + rpc_controller = 'controller' + request = 'request' + + # GetDescriptor now static, still works as instance method for compatability + self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(), + stub.GetDescriptor()) + + # Invoke method. + stub.Foo(rpc_controller, request, MyCallback) + + self.assertTrue(isinstance(self.callback_response, + unittest_pb2.FooResponse)) + self.assertEqual(request, channel.request) + self.assertEqual(rpc_controller, channel.controller) + self.assertEqual(stub.GetDescriptor().methods[0], channel.method) + + +if __name__ == '__main__': + unittest.main() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/test_util.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/test_util.py index bf47fe5f..1df16194 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/test_util.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/test_util.py @@ -1,635 +1,635 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Utilities for Python proto2 tests.
-
-This is intentionally modeled on C++ code in
-//google/protobuf/test_util.*.
-"""
-
-__author__ = '[email protected] (Will Robinson)'
-
-import os.path
-
-from google.protobuf import unittest_import_pb2
-from google.protobuf import unittest_pb2
-
-
-def SetAllFields(message):
- """Sets every field in the message to a unique value.
-
- Args:
- message: A unittest_pb2.TestAllTypes instance.
- """
-
- #
- # Optional fields.
- #
-
- message.optional_int32 = 101
- message.optional_int64 = 102
- message.optional_uint32 = 103
- message.optional_uint64 = 104
- message.optional_sint32 = 105
- message.optional_sint64 = 106
- message.optional_fixed32 = 107
- message.optional_fixed64 = 108
- message.optional_sfixed32 = 109
- message.optional_sfixed64 = 110
- message.optional_float = 111
- message.optional_double = 112
- message.optional_bool = True
- # TODO(robinson): Firmly spec out and test how
- # protos interact with unicode. One specific example:
- # what happens if we change the literal below to
- # u'115'? What *should* happen? Still some discussion
- # to finish with Kenton about bytes vs. strings
- # and forcing everything to be utf8. :-/
- message.optional_string = '115'
- message.optional_bytes = '116'
-
- message.optionalgroup.a = 117
- message.optional_nested_message.bb = 118
- message.optional_foreign_message.c = 119
- message.optional_import_message.d = 120
-
- message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ
- message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ
- message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ
-
- message.optional_string_piece = '124'
- message.optional_cord = '125'
-
- #
- # Repeated fields.
- #
-
- message.repeated_int32.append(201)
- message.repeated_int64.append(202)
- message.repeated_uint32.append(203)
- message.repeated_uint64.append(204)
- message.repeated_sint32.append(205)
- message.repeated_sint64.append(206)
- message.repeated_fixed32.append(207)
- message.repeated_fixed64.append(208)
- message.repeated_sfixed32.append(209)
- message.repeated_sfixed64.append(210)
- message.repeated_float.append(211)
- message.repeated_double.append(212)
- message.repeated_bool.append(True)
- message.repeated_string.append('215')
- message.repeated_bytes.append('216')
-
- message.repeatedgroup.add().a = 217
- message.repeated_nested_message.add().bb = 218
- message.repeated_foreign_message.add().c = 219
- message.repeated_import_message.add().d = 220
-
- message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR)
- message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR)
-
- message.repeated_string_piece.append('224')
- message.repeated_cord.append('225')
-
- # Add a second one of each field.
- message.repeated_int32.append(301)
- message.repeated_int64.append(302)
- message.repeated_uint32.append(303)
- message.repeated_uint64.append(304)
- message.repeated_sint32.append(305)
- message.repeated_sint64.append(306)
- message.repeated_fixed32.append(307)
- message.repeated_fixed64.append(308)
- message.repeated_sfixed32.append(309)
- message.repeated_sfixed64.append(310)
- message.repeated_float.append(311)
- message.repeated_double.append(312)
- message.repeated_bool.append(False)
- message.repeated_string.append('315')
- message.repeated_bytes.append('316')
-
- message.repeatedgroup.add().a = 317
- message.repeated_nested_message.add().bb = 318
- message.repeated_foreign_message.add().c = 319
- message.repeated_import_message.add().d = 320
-
- message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ)
- message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
- message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ)
-
- message.repeated_string_piece.append('324')
- message.repeated_cord.append('325')
-
- #
- # Fields that have defaults.
- #
-
- message.default_int32 = 401
- message.default_int64 = 402
- message.default_uint32 = 403
- message.default_uint64 = 404
- message.default_sint32 = 405
- message.default_sint64 = 406
- message.default_fixed32 = 407
- message.default_fixed64 = 408
- message.default_sfixed32 = 409
- message.default_sfixed64 = 410
- message.default_float = 411
- message.default_double = 412
- message.default_bool = False
- message.default_string = '415'
- message.default_bytes = '416'
-
- message.default_nested_enum = unittest_pb2.TestAllTypes.FOO
- message.default_foreign_enum = unittest_pb2.FOREIGN_FOO
- message.default_import_enum = unittest_import_pb2.IMPORT_FOO
-
- message.default_string_piece = '424'
- message.default_cord = '425'
-
-
-def SetAllExtensions(message):
- """Sets every extension in the message to a unique value.
-
- Args:
- message: A unittest_pb2.TestAllExtensions instance.
- """
-
- extensions = message.Extensions
- pb2 = unittest_pb2
- import_pb2 = unittest_import_pb2
-
- #
- # Optional fields.
- #
-
- extensions[pb2.optional_int32_extension] = 101
- extensions[pb2.optional_int64_extension] = 102
- extensions[pb2.optional_uint32_extension] = 103
- extensions[pb2.optional_uint64_extension] = 104
- extensions[pb2.optional_sint32_extension] = 105
- extensions[pb2.optional_sint64_extension] = 106
- extensions[pb2.optional_fixed32_extension] = 107
- extensions[pb2.optional_fixed64_extension] = 108
- extensions[pb2.optional_sfixed32_extension] = 109
- extensions[pb2.optional_sfixed64_extension] = 110
- extensions[pb2.optional_float_extension] = 111
- extensions[pb2.optional_double_extension] = 112
- extensions[pb2.optional_bool_extension] = True
- extensions[pb2.optional_string_extension] = '115'
- extensions[pb2.optional_bytes_extension] = '116'
-
- extensions[pb2.optionalgroup_extension].a = 117
- extensions[pb2.optional_nested_message_extension].bb = 118
- extensions[pb2.optional_foreign_message_extension].c = 119
- extensions[pb2.optional_import_message_extension].d = 120
-
- extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
- extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ
- extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ
- extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ
-
- extensions[pb2.optional_string_piece_extension] = '124'
- extensions[pb2.optional_cord_extension] = '125'
-
- #
- # Repeated fields.
- #
-
- extensions[pb2.repeated_int32_extension].append(201)
- extensions[pb2.repeated_int64_extension].append(202)
- extensions[pb2.repeated_uint32_extension].append(203)
- extensions[pb2.repeated_uint64_extension].append(204)
- extensions[pb2.repeated_sint32_extension].append(205)
- extensions[pb2.repeated_sint64_extension].append(206)
- extensions[pb2.repeated_fixed32_extension].append(207)
- extensions[pb2.repeated_fixed64_extension].append(208)
- extensions[pb2.repeated_sfixed32_extension].append(209)
- extensions[pb2.repeated_sfixed64_extension].append(210)
- extensions[pb2.repeated_float_extension].append(211)
- extensions[pb2.repeated_double_extension].append(212)
- extensions[pb2.repeated_bool_extension].append(True)
- extensions[pb2.repeated_string_extension].append('215')
- extensions[pb2.repeated_bytes_extension].append('216')
-
- extensions[pb2.repeatedgroup_extension].add().a = 217
- extensions[pb2.repeated_nested_message_extension].add().bb = 218
- extensions[pb2.repeated_foreign_message_extension].add().c = 219
- extensions[pb2.repeated_import_message_extension].add().d = 220
-
- extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR)
- extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR)
- extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR)
-
- extensions[pb2.repeated_string_piece_extension].append('224')
- extensions[pb2.repeated_cord_extension].append('225')
-
- # Append a second one of each field.
- extensions[pb2.repeated_int32_extension].append(301)
- extensions[pb2.repeated_int64_extension].append(302)
- extensions[pb2.repeated_uint32_extension].append(303)
- extensions[pb2.repeated_uint64_extension].append(304)
- extensions[pb2.repeated_sint32_extension].append(305)
- extensions[pb2.repeated_sint64_extension].append(306)
- extensions[pb2.repeated_fixed32_extension].append(307)
- extensions[pb2.repeated_fixed64_extension].append(308)
- extensions[pb2.repeated_sfixed32_extension].append(309)
- extensions[pb2.repeated_sfixed64_extension].append(310)
- extensions[pb2.repeated_float_extension].append(311)
- extensions[pb2.repeated_double_extension].append(312)
- extensions[pb2.repeated_bool_extension].append(False)
- extensions[pb2.repeated_string_extension].append('315')
- extensions[pb2.repeated_bytes_extension].append('316')
-
- extensions[pb2.repeatedgroup_extension].add().a = 317
- extensions[pb2.repeated_nested_message_extension].add().bb = 318
- extensions[pb2.repeated_foreign_message_extension].add().c = 319
- extensions[pb2.repeated_import_message_extension].add().d = 320
-
- extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ)
- extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ)
- extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ)
-
- extensions[pb2.repeated_string_piece_extension].append('324')
- extensions[pb2.repeated_cord_extension].append('325')
-
- #
- # Fields with defaults.
- #
-
- extensions[pb2.default_int32_extension] = 401
- extensions[pb2.default_int64_extension] = 402
- extensions[pb2.default_uint32_extension] = 403
- extensions[pb2.default_uint64_extension] = 404
- extensions[pb2.default_sint32_extension] = 405
- extensions[pb2.default_sint64_extension] = 406
- extensions[pb2.default_fixed32_extension] = 407
- extensions[pb2.default_fixed64_extension] = 408
- extensions[pb2.default_sfixed32_extension] = 409
- extensions[pb2.default_sfixed64_extension] = 410
- extensions[pb2.default_float_extension] = 411
- extensions[pb2.default_double_extension] = 412
- extensions[pb2.default_bool_extension] = False
- extensions[pb2.default_string_extension] = '415'
- extensions[pb2.default_bytes_extension] = '416'
-
- extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO
- extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO
- extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO
-
- extensions[pb2.default_string_piece_extension] = '424'
- extensions[pb2.default_cord_extension] = '425'
-
-
-def SetAllFieldsAndExtensions(message):
- """Sets every field and extension in the message to a unique value.
-
- Args:
- message: A unittest_pb2.TestAllExtensions message.
- """
- message.my_int = 1
- message.my_string = 'foo'
- message.my_float = 1.0
- message.Extensions[unittest_pb2.my_extension_int] = 23
- message.Extensions[unittest_pb2.my_extension_string] = 'bar'
-
-
-def ExpectAllFieldsAndExtensionsInOrder(serialized):
- """Ensures that serialized is the serialization we expect for a message
- filled with SetAllFieldsAndExtensions(). (Specifically, ensures that the
- serialization is in canonical, tag-number order).
- """
- my_extension_int = unittest_pb2.my_extension_int
- my_extension_string = unittest_pb2.my_extension_string
- expected_strings = []
- message = unittest_pb2.TestFieldOrderings()
- message.my_int = 1 # Field 1.
- expected_strings.append(message.SerializeToString())
- message.Clear()
- message.Extensions[my_extension_int] = 23 # Field 5.
- expected_strings.append(message.SerializeToString())
- message.Clear()
- message.my_string = 'foo' # Field 11.
- expected_strings.append(message.SerializeToString())
- message.Clear()
- message.Extensions[my_extension_string] = 'bar' # Field 50.
- expected_strings.append(message.SerializeToString())
- message.Clear()
- message.my_float = 1.0
- expected_strings.append(message.SerializeToString())
- message.Clear()
- expected = ''.join(expected_strings)
-
- if expected != serialized:
- raise ValueError('Expected %r, found %r' % (expected, serialized))
-
-
-def ExpectAllFieldsSet(test_case, message):
- """Check all fields for correct values have after Set*Fields() is called."""
- test_case.assertTrue(message.HasField('optional_int32'))
- test_case.assertTrue(message.HasField('optional_int64'))
- test_case.assertTrue(message.HasField('optional_uint32'))
- test_case.assertTrue(message.HasField('optional_uint64'))
- test_case.assertTrue(message.HasField('optional_sint32'))
- test_case.assertTrue(message.HasField('optional_sint64'))
- test_case.assertTrue(message.HasField('optional_fixed32'))
- test_case.assertTrue(message.HasField('optional_fixed64'))
- test_case.assertTrue(message.HasField('optional_sfixed32'))
- test_case.assertTrue(message.HasField('optional_sfixed64'))
- test_case.assertTrue(message.HasField('optional_float'))
- test_case.assertTrue(message.HasField('optional_double'))
- test_case.assertTrue(message.HasField('optional_bool'))
- test_case.assertTrue(message.HasField('optional_string'))
- test_case.assertTrue(message.HasField('optional_bytes'))
-
- test_case.assertTrue(message.HasField('optionalgroup'))
- test_case.assertTrue(message.HasField('optional_nested_message'))
- test_case.assertTrue(message.HasField('optional_foreign_message'))
- test_case.assertTrue(message.HasField('optional_import_message'))
-
- test_case.assertTrue(message.optionalgroup.HasField('a'))
- test_case.assertTrue(message.optional_nested_message.HasField('bb'))
- test_case.assertTrue(message.optional_foreign_message.HasField('c'))
- test_case.assertTrue(message.optional_import_message.HasField('d'))
-
- test_case.assertTrue(message.HasField('optional_nested_enum'))
- test_case.assertTrue(message.HasField('optional_foreign_enum'))
- test_case.assertTrue(message.HasField('optional_import_enum'))
-
- test_case.assertTrue(message.HasField('optional_string_piece'))
- test_case.assertTrue(message.HasField('optional_cord'))
-
- test_case.assertEqual(101, message.optional_int32)
- test_case.assertEqual(102, message.optional_int64)
- test_case.assertEqual(103, message.optional_uint32)
- test_case.assertEqual(104, message.optional_uint64)
- test_case.assertEqual(105, message.optional_sint32)
- test_case.assertEqual(106, message.optional_sint64)
- test_case.assertEqual(107, message.optional_fixed32)
- test_case.assertEqual(108, message.optional_fixed64)
- test_case.assertEqual(109, message.optional_sfixed32)
- test_case.assertEqual(110, message.optional_sfixed64)
- test_case.assertEqual(111, message.optional_float)
- test_case.assertEqual(112, message.optional_double)
- test_case.assertEqual(True, message.optional_bool)
- test_case.assertEqual('115', message.optional_string)
- test_case.assertEqual('116', message.optional_bytes)
-
- test_case.assertEqual(117, message.optionalgroup.a)
- test_case.assertEqual(118, message.optional_nested_message.bb)
- test_case.assertEqual(119, message.optional_foreign_message.c)
- test_case.assertEqual(120, message.optional_import_message.d)
-
- test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
- message.optional_nested_enum)
- test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
- message.optional_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.optional_import_enum)
-
- # -----------------------------------------------------------------
-
- test_case.assertEqual(2, len(message.repeated_int32))
- test_case.assertEqual(2, len(message.repeated_int64))
- test_case.assertEqual(2, len(message.repeated_uint32))
- test_case.assertEqual(2, len(message.repeated_uint64))
- test_case.assertEqual(2, len(message.repeated_sint32))
- test_case.assertEqual(2, len(message.repeated_sint64))
- test_case.assertEqual(2, len(message.repeated_fixed32))
- test_case.assertEqual(2, len(message.repeated_fixed64))
- test_case.assertEqual(2, len(message.repeated_sfixed32))
- test_case.assertEqual(2, len(message.repeated_sfixed64))
- test_case.assertEqual(2, len(message.repeated_float))
- test_case.assertEqual(2, len(message.repeated_double))
- test_case.assertEqual(2, len(message.repeated_bool))
- test_case.assertEqual(2, len(message.repeated_string))
- test_case.assertEqual(2, len(message.repeated_bytes))
-
- test_case.assertEqual(2, len(message.repeatedgroup))
- test_case.assertEqual(2, len(message.repeated_nested_message))
- test_case.assertEqual(2, len(message.repeated_foreign_message))
- test_case.assertEqual(2, len(message.repeated_import_message))
- test_case.assertEqual(2, len(message.repeated_nested_enum))
- test_case.assertEqual(2, len(message.repeated_foreign_enum))
- test_case.assertEqual(2, len(message.repeated_import_enum))
-
- test_case.assertEqual(2, len(message.repeated_string_piece))
- test_case.assertEqual(2, len(message.repeated_cord))
-
- test_case.assertEqual(201, message.repeated_int32[0])
- test_case.assertEqual(202, message.repeated_int64[0])
- test_case.assertEqual(203, message.repeated_uint32[0])
- test_case.assertEqual(204, message.repeated_uint64[0])
- test_case.assertEqual(205, message.repeated_sint32[0])
- test_case.assertEqual(206, message.repeated_sint64[0])
- test_case.assertEqual(207, message.repeated_fixed32[0])
- test_case.assertEqual(208, message.repeated_fixed64[0])
- test_case.assertEqual(209, message.repeated_sfixed32[0])
- test_case.assertEqual(210, message.repeated_sfixed64[0])
- test_case.assertEqual(211, message.repeated_float[0])
- test_case.assertEqual(212, message.repeated_double[0])
- test_case.assertEqual(True, message.repeated_bool[0])
- test_case.assertEqual('215', message.repeated_string[0])
- test_case.assertEqual('216', message.repeated_bytes[0])
-
- test_case.assertEqual(217, message.repeatedgroup[0].a)
- test_case.assertEqual(218, message.repeated_nested_message[0].bb)
- test_case.assertEqual(219, message.repeated_foreign_message[0].c)
- test_case.assertEqual(220, message.repeated_import_message[0].d)
-
- test_case.assertEqual(unittest_pb2.TestAllTypes.BAR,
- message.repeated_nested_enum[0])
- test_case.assertEqual(unittest_pb2.FOREIGN_BAR,
- message.repeated_foreign_enum[0])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAR,
- message.repeated_import_enum[0])
-
- test_case.assertEqual(301, message.repeated_int32[1])
- test_case.assertEqual(302, message.repeated_int64[1])
- test_case.assertEqual(303, message.repeated_uint32[1])
- test_case.assertEqual(304, message.repeated_uint64[1])
- test_case.assertEqual(305, message.repeated_sint32[1])
- test_case.assertEqual(306, message.repeated_sint64[1])
- test_case.assertEqual(307, message.repeated_fixed32[1])
- test_case.assertEqual(308, message.repeated_fixed64[1])
- test_case.assertEqual(309, message.repeated_sfixed32[1])
- test_case.assertEqual(310, message.repeated_sfixed64[1])
- test_case.assertEqual(311, message.repeated_float[1])
- test_case.assertEqual(312, message.repeated_double[1])
- test_case.assertEqual(False, message.repeated_bool[1])
- test_case.assertEqual('315', message.repeated_string[1])
- test_case.assertEqual('316', message.repeated_bytes[1])
-
- test_case.assertEqual(317, message.repeatedgroup[1].a)
- test_case.assertEqual(318, message.repeated_nested_message[1].bb)
- test_case.assertEqual(319, message.repeated_foreign_message[1].c)
- test_case.assertEqual(320, message.repeated_import_message[1].d)
-
- test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ,
- message.repeated_nested_enum[1])
- test_case.assertEqual(unittest_pb2.FOREIGN_BAZ,
- message.repeated_foreign_enum[1])
- test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ,
- message.repeated_import_enum[1])
-
- # -----------------------------------------------------------------
-
- test_case.assertTrue(message.HasField('default_int32'))
- test_case.assertTrue(message.HasField('default_int64'))
- test_case.assertTrue(message.HasField('default_uint32'))
- test_case.assertTrue(message.HasField('default_uint64'))
- test_case.assertTrue(message.HasField('default_sint32'))
- test_case.assertTrue(message.HasField('default_sint64'))
- test_case.assertTrue(message.HasField('default_fixed32'))
- test_case.assertTrue(message.HasField('default_fixed64'))
- test_case.assertTrue(message.HasField('default_sfixed32'))
- test_case.assertTrue(message.HasField('default_sfixed64'))
- test_case.assertTrue(message.HasField('default_float'))
- test_case.assertTrue(message.HasField('default_double'))
- test_case.assertTrue(message.HasField('default_bool'))
- test_case.assertTrue(message.HasField('default_string'))
- test_case.assertTrue(message.HasField('default_bytes'))
-
- test_case.assertTrue(message.HasField('default_nested_enum'))
- test_case.assertTrue(message.HasField('default_foreign_enum'))
- test_case.assertTrue(message.HasField('default_import_enum'))
-
- test_case.assertEqual(401, message.default_int32)
- test_case.assertEqual(402, message.default_int64)
- test_case.assertEqual(403, message.default_uint32)
- test_case.assertEqual(404, message.default_uint64)
- test_case.assertEqual(405, message.default_sint32)
- test_case.assertEqual(406, message.default_sint64)
- test_case.assertEqual(407, message.default_fixed32)
- test_case.assertEqual(408, message.default_fixed64)
- test_case.assertEqual(409, message.default_sfixed32)
- test_case.assertEqual(410, message.default_sfixed64)
- test_case.assertEqual(411, message.default_float)
- test_case.assertEqual(412, message.default_double)
- test_case.assertEqual(False, message.default_bool)
- test_case.assertEqual('415', message.default_string)
- test_case.assertEqual('416', message.default_bytes)
-
- test_case.assertEqual(unittest_pb2.TestAllTypes.FOO,
- message.default_nested_enum)
- test_case.assertEqual(unittest_pb2.FOREIGN_FOO,
- message.default_foreign_enum)
- test_case.assertEqual(unittest_import_pb2.IMPORT_FOO,
- message.default_import_enum)
-
-def GoldenFile(filename):
- """Finds the given golden file and returns a file object representing it."""
-
- # Search up the directory tree looking for the C++ protobuf source code.
- path = '.'
- while os.path.exists(path):
- if os.path.exists(os.path.join(path, 'src/google/protobuf')):
- # Found it. Load the golden file from the testdata directory.
- full_path = os.path.join(path, 'src/google/protobuf/testdata', filename)
- return open(full_path, 'rb')
- path = os.path.join(path, '..')
-
- raise RuntimeError(
- 'Could not find golden files. This test must be run from within the '
- 'protobuf source package so that it can read test data files from the '
- 'C++ source tree.')
-
-
-def SetAllPackedFields(message):
- """Sets every field in the message to a unique value.
-
- Args:
- message: A unittest_pb2.TestPackedTypes instance.
- """
- message.packed_int32.extend([601, 701])
- message.packed_int64.extend([602, 702])
- message.packed_uint32.extend([603, 703])
- message.packed_uint64.extend([604, 704])
- message.packed_sint32.extend([605, 705])
- message.packed_sint64.extend([606, 706])
- message.packed_fixed32.extend([607, 707])
- message.packed_fixed64.extend([608, 708])
- message.packed_sfixed32.extend([609, 709])
- message.packed_sfixed64.extend([610, 710])
- message.packed_float.extend([611.0, 711.0])
- message.packed_double.extend([612.0, 712.0])
- message.packed_bool.extend([True, False])
- message.packed_enum.extend([unittest_pb2.FOREIGN_BAR,
- unittest_pb2.FOREIGN_BAZ])
-
-
-def SetAllPackedExtensions(message):
- """Sets every extension in the message to a unique value.
-
- Args:
- message: A unittest_pb2.TestPackedExtensions instance.
- """
- extensions = message.Extensions
- pb2 = unittest_pb2
-
- extensions[pb2.packed_int32_extension].extend([601, 701])
- extensions[pb2.packed_int64_extension].extend([602, 702])
- extensions[pb2.packed_uint32_extension].extend([603, 703])
- extensions[pb2.packed_uint64_extension].extend([604, 704])
- extensions[pb2.packed_sint32_extension].extend([605, 705])
- extensions[pb2.packed_sint64_extension].extend([606, 706])
- extensions[pb2.packed_fixed32_extension].extend([607, 707])
- extensions[pb2.packed_fixed64_extension].extend([608, 708])
- extensions[pb2.packed_sfixed32_extension].extend([609, 709])
- extensions[pb2.packed_sfixed64_extension].extend([610, 710])
- extensions[pb2.packed_float_extension].extend([611.0, 711.0])
- extensions[pb2.packed_double_extension].extend([612.0, 712.0])
- extensions[pb2.packed_bool_extension].extend([True, False])
- extensions[pb2.packed_enum_extension].extend([unittest_pb2.FOREIGN_BAR,
- unittest_pb2.FOREIGN_BAZ])
-
-
-def SetAllUnpackedFields(message):
- """Sets every field in the message to a unique value.
-
- Args:
- message: A unittest_pb2.TestUnpackedTypes instance.
- """
- message.unpacked_int32.extend([601, 701])
- message.unpacked_int64.extend([602, 702])
- message.unpacked_uint32.extend([603, 703])
- message.unpacked_uint64.extend([604, 704])
- message.unpacked_sint32.extend([605, 705])
- message.unpacked_sint64.extend([606, 706])
- message.unpacked_fixed32.extend([607, 707])
- message.unpacked_fixed64.extend([608, 708])
- message.unpacked_sfixed32.extend([609, 709])
- message.unpacked_sfixed64.extend([610, 710])
- message.unpacked_float.extend([611.0, 711.0])
- message.unpacked_double.extend([612.0, 712.0])
- message.unpacked_bool.extend([True, False])
- message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR,
- unittest_pb2.FOREIGN_BAZ])
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Utilities for Python proto2 tests. + +This is intentionally modeled on C++ code in +//google/protobuf/test_util.*. +""" + +__author__ = '[email protected] (Will Robinson)' + +import os.path + +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_pb2 + + +def SetAllFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllTypes instance. + """ + + # + # Optional fields. + # + + message.optional_int32 = 101 + message.optional_int64 = 102 + message.optional_uint32 = 103 + message.optional_uint64 = 104 + message.optional_sint32 = 105 + message.optional_sint64 = 106 + message.optional_fixed32 = 107 + message.optional_fixed64 = 108 + message.optional_sfixed32 = 109 + message.optional_sfixed64 = 110 + message.optional_float = 111 + message.optional_double = 112 + message.optional_bool = True + # TODO(robinson): Firmly spec out and test how + # protos interact with unicode. One specific example: + # what happens if we change the literal below to + # u'115'? What *should* happen? Still some discussion + # to finish with Kenton about bytes vs. strings + # and forcing everything to be utf8. :-/ + message.optional_string = '115' + message.optional_bytes = '116' + + message.optionalgroup.a = 117 + message.optional_nested_message.bb = 118 + message.optional_foreign_message.c = 119 + message.optional_import_message.d = 120 + + message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ + message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ + message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ + + message.optional_string_piece = '124' + message.optional_cord = '125' + + # + # Repeated fields. + # + + message.repeated_int32.append(201) + message.repeated_int64.append(202) + message.repeated_uint32.append(203) + message.repeated_uint64.append(204) + message.repeated_sint32.append(205) + message.repeated_sint64.append(206) + message.repeated_fixed32.append(207) + message.repeated_fixed64.append(208) + message.repeated_sfixed32.append(209) + message.repeated_sfixed64.append(210) + message.repeated_float.append(211) + message.repeated_double.append(212) + message.repeated_bool.append(True) + message.repeated_string.append('215') + message.repeated_bytes.append('216') + + message.repeatedgroup.add().a = 217 + message.repeated_nested_message.add().bb = 218 + message.repeated_foreign_message.add().c = 219 + message.repeated_import_message.add().d = 220 + + message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) + + message.repeated_string_piece.append('224') + message.repeated_cord.append('225') + + # Add a second one of each field. + message.repeated_int32.append(301) + message.repeated_int64.append(302) + message.repeated_uint32.append(303) + message.repeated_uint64.append(304) + message.repeated_sint32.append(305) + message.repeated_sint64.append(306) + message.repeated_fixed32.append(307) + message.repeated_fixed64.append(308) + message.repeated_sfixed32.append(309) + message.repeated_sfixed64.append(310) + message.repeated_float.append(311) + message.repeated_double.append(312) + message.repeated_bool.append(False) + message.repeated_string.append('315') + message.repeated_bytes.append('316') + + message.repeatedgroup.add().a = 317 + message.repeated_nested_message.add().bb = 318 + message.repeated_foreign_message.add().c = 319 + message.repeated_import_message.add().d = 320 + + message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) + + message.repeated_string_piece.append('324') + message.repeated_cord.append('325') + + # + # Fields that have defaults. + # + + message.default_int32 = 401 + message.default_int64 = 402 + message.default_uint32 = 403 + message.default_uint64 = 404 + message.default_sint32 = 405 + message.default_sint64 = 406 + message.default_fixed32 = 407 + message.default_fixed64 = 408 + message.default_sfixed32 = 409 + message.default_sfixed64 = 410 + message.default_float = 411 + message.default_double = 412 + message.default_bool = False + message.default_string = '415' + message.default_bytes = '416' + + message.default_nested_enum = unittest_pb2.TestAllTypes.FOO + message.default_foreign_enum = unittest_pb2.FOREIGN_FOO + message.default_import_enum = unittest_import_pb2.IMPORT_FOO + + message.default_string_piece = '424' + message.default_cord = '425' + + +def SetAllExtensions(message): + """Sets every extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllExtensions instance. + """ + + extensions = message.Extensions + pb2 = unittest_pb2 + import_pb2 = unittest_import_pb2 + + # + # Optional fields. + # + + extensions[pb2.optional_int32_extension] = 101 + extensions[pb2.optional_int64_extension] = 102 + extensions[pb2.optional_uint32_extension] = 103 + extensions[pb2.optional_uint64_extension] = 104 + extensions[pb2.optional_sint32_extension] = 105 + extensions[pb2.optional_sint64_extension] = 106 + extensions[pb2.optional_fixed32_extension] = 107 + extensions[pb2.optional_fixed64_extension] = 108 + extensions[pb2.optional_sfixed32_extension] = 109 + extensions[pb2.optional_sfixed64_extension] = 110 + extensions[pb2.optional_float_extension] = 111 + extensions[pb2.optional_double_extension] = 112 + extensions[pb2.optional_bool_extension] = True + extensions[pb2.optional_string_extension] = '115' + extensions[pb2.optional_bytes_extension] = '116' + + extensions[pb2.optionalgroup_extension].a = 117 + extensions[pb2.optional_nested_message_extension].bb = 118 + extensions[pb2.optional_foreign_message_extension].c = 119 + extensions[pb2.optional_import_message_extension].d = 120 + + extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ + extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ + extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ + extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ + + extensions[pb2.optional_string_piece_extension] = '124' + extensions[pb2.optional_cord_extension] = '125' + + # + # Repeated fields. + # + + extensions[pb2.repeated_int32_extension].append(201) + extensions[pb2.repeated_int64_extension].append(202) + extensions[pb2.repeated_uint32_extension].append(203) + extensions[pb2.repeated_uint64_extension].append(204) + extensions[pb2.repeated_sint32_extension].append(205) + extensions[pb2.repeated_sint64_extension].append(206) + extensions[pb2.repeated_fixed32_extension].append(207) + extensions[pb2.repeated_fixed64_extension].append(208) + extensions[pb2.repeated_sfixed32_extension].append(209) + extensions[pb2.repeated_sfixed64_extension].append(210) + extensions[pb2.repeated_float_extension].append(211) + extensions[pb2.repeated_double_extension].append(212) + extensions[pb2.repeated_bool_extension].append(True) + extensions[pb2.repeated_string_extension].append('215') + extensions[pb2.repeated_bytes_extension].append('216') + + extensions[pb2.repeatedgroup_extension].add().a = 217 + extensions[pb2.repeated_nested_message_extension].add().bb = 218 + extensions[pb2.repeated_foreign_message_extension].add().c = 219 + extensions[pb2.repeated_import_message_extension].add().d = 220 + + extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR) + extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) + extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR) + + extensions[pb2.repeated_string_piece_extension].append('224') + extensions[pb2.repeated_cord_extension].append('225') + + # Append a second one of each field. + extensions[pb2.repeated_int32_extension].append(301) + extensions[pb2.repeated_int64_extension].append(302) + extensions[pb2.repeated_uint32_extension].append(303) + extensions[pb2.repeated_uint64_extension].append(304) + extensions[pb2.repeated_sint32_extension].append(305) + extensions[pb2.repeated_sint64_extension].append(306) + extensions[pb2.repeated_fixed32_extension].append(307) + extensions[pb2.repeated_fixed64_extension].append(308) + extensions[pb2.repeated_sfixed32_extension].append(309) + extensions[pb2.repeated_sfixed64_extension].append(310) + extensions[pb2.repeated_float_extension].append(311) + extensions[pb2.repeated_double_extension].append(312) + extensions[pb2.repeated_bool_extension].append(False) + extensions[pb2.repeated_string_extension].append('315') + extensions[pb2.repeated_bytes_extension].append('316') + + extensions[pb2.repeatedgroup_extension].add().a = 317 + extensions[pb2.repeated_nested_message_extension].add().bb = 318 + extensions[pb2.repeated_foreign_message_extension].add().c = 319 + extensions[pb2.repeated_import_message_extension].add().d = 320 + + extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ) + extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) + extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ) + + extensions[pb2.repeated_string_piece_extension].append('324') + extensions[pb2.repeated_cord_extension].append('325') + + # + # Fields with defaults. + # + + extensions[pb2.default_int32_extension] = 401 + extensions[pb2.default_int64_extension] = 402 + extensions[pb2.default_uint32_extension] = 403 + extensions[pb2.default_uint64_extension] = 404 + extensions[pb2.default_sint32_extension] = 405 + extensions[pb2.default_sint64_extension] = 406 + extensions[pb2.default_fixed32_extension] = 407 + extensions[pb2.default_fixed64_extension] = 408 + extensions[pb2.default_sfixed32_extension] = 409 + extensions[pb2.default_sfixed64_extension] = 410 + extensions[pb2.default_float_extension] = 411 + extensions[pb2.default_double_extension] = 412 + extensions[pb2.default_bool_extension] = False + extensions[pb2.default_string_extension] = '415' + extensions[pb2.default_bytes_extension] = '416' + + extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO + extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO + extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO + + extensions[pb2.default_string_piece_extension] = '424' + extensions[pb2.default_cord_extension] = '425' + + +def SetAllFieldsAndExtensions(message): + """Sets every field and extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllExtensions message. + """ + message.my_int = 1 + message.my_string = 'foo' + message.my_float = 1.0 + message.Extensions[unittest_pb2.my_extension_int] = 23 + message.Extensions[unittest_pb2.my_extension_string] = 'bar' + + +def ExpectAllFieldsAndExtensionsInOrder(serialized): + """Ensures that serialized is the serialization we expect for a message + filled with SetAllFieldsAndExtensions(). (Specifically, ensures that the + serialization is in canonical, tag-number order). + """ + my_extension_int = unittest_pb2.my_extension_int + my_extension_string = unittest_pb2.my_extension_string + expected_strings = [] + message = unittest_pb2.TestFieldOrderings() + message.my_int = 1 # Field 1. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.Extensions[my_extension_int] = 23 # Field 5. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.my_string = 'foo' # Field 11. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.Extensions[my_extension_string] = 'bar' # Field 50. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.my_float = 1.0 + expected_strings.append(message.SerializeToString()) + message.Clear() + expected = ''.join(expected_strings) + + if expected != serialized: + raise ValueError('Expected %r, found %r' % (expected, serialized)) + + +def ExpectAllFieldsSet(test_case, message): + """Check all fields for correct values have after Set*Fields() is called.""" + test_case.assertTrue(message.HasField('optional_int32')) + test_case.assertTrue(message.HasField('optional_int64')) + test_case.assertTrue(message.HasField('optional_uint32')) + test_case.assertTrue(message.HasField('optional_uint64')) + test_case.assertTrue(message.HasField('optional_sint32')) + test_case.assertTrue(message.HasField('optional_sint64')) + test_case.assertTrue(message.HasField('optional_fixed32')) + test_case.assertTrue(message.HasField('optional_fixed64')) + test_case.assertTrue(message.HasField('optional_sfixed32')) + test_case.assertTrue(message.HasField('optional_sfixed64')) + test_case.assertTrue(message.HasField('optional_float')) + test_case.assertTrue(message.HasField('optional_double')) + test_case.assertTrue(message.HasField('optional_bool')) + test_case.assertTrue(message.HasField('optional_string')) + test_case.assertTrue(message.HasField('optional_bytes')) + + test_case.assertTrue(message.HasField('optionalgroup')) + test_case.assertTrue(message.HasField('optional_nested_message')) + test_case.assertTrue(message.HasField('optional_foreign_message')) + test_case.assertTrue(message.HasField('optional_import_message')) + + test_case.assertTrue(message.optionalgroup.HasField('a')) + test_case.assertTrue(message.optional_nested_message.HasField('bb')) + test_case.assertTrue(message.optional_foreign_message.HasField('c')) + test_case.assertTrue(message.optional_import_message.HasField('d')) + + test_case.assertTrue(message.HasField('optional_nested_enum')) + test_case.assertTrue(message.HasField('optional_foreign_enum')) + test_case.assertTrue(message.HasField('optional_import_enum')) + + test_case.assertTrue(message.HasField('optional_string_piece')) + test_case.assertTrue(message.HasField('optional_cord')) + + test_case.assertEqual(101, message.optional_int32) + test_case.assertEqual(102, message.optional_int64) + test_case.assertEqual(103, message.optional_uint32) + test_case.assertEqual(104, message.optional_uint64) + test_case.assertEqual(105, message.optional_sint32) + test_case.assertEqual(106, message.optional_sint64) + test_case.assertEqual(107, message.optional_fixed32) + test_case.assertEqual(108, message.optional_fixed64) + test_case.assertEqual(109, message.optional_sfixed32) + test_case.assertEqual(110, message.optional_sfixed64) + test_case.assertEqual(111, message.optional_float) + test_case.assertEqual(112, message.optional_double) + test_case.assertEqual(True, message.optional_bool) + test_case.assertEqual('115', message.optional_string) + test_case.assertEqual('116', message.optional_bytes) + + test_case.assertEqual(117, message.optionalgroup.a) + test_case.assertEqual(118, message.optional_nested_message.bb) + test_case.assertEqual(119, message.optional_foreign_message.c) + test_case.assertEqual(120, message.optional_import_message.d) + + test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, + message.optional_foreign_enum) + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.optional_import_enum) + + # ----------------------------------------------------------------- + + test_case.assertEqual(2, len(message.repeated_int32)) + test_case.assertEqual(2, len(message.repeated_int64)) + test_case.assertEqual(2, len(message.repeated_uint32)) + test_case.assertEqual(2, len(message.repeated_uint64)) + test_case.assertEqual(2, len(message.repeated_sint32)) + test_case.assertEqual(2, len(message.repeated_sint64)) + test_case.assertEqual(2, len(message.repeated_fixed32)) + test_case.assertEqual(2, len(message.repeated_fixed64)) + test_case.assertEqual(2, len(message.repeated_sfixed32)) + test_case.assertEqual(2, len(message.repeated_sfixed64)) + test_case.assertEqual(2, len(message.repeated_float)) + test_case.assertEqual(2, len(message.repeated_double)) + test_case.assertEqual(2, len(message.repeated_bool)) + test_case.assertEqual(2, len(message.repeated_string)) + test_case.assertEqual(2, len(message.repeated_bytes)) + + test_case.assertEqual(2, len(message.repeatedgroup)) + test_case.assertEqual(2, len(message.repeated_nested_message)) + test_case.assertEqual(2, len(message.repeated_foreign_message)) + test_case.assertEqual(2, len(message.repeated_import_message)) + test_case.assertEqual(2, len(message.repeated_nested_enum)) + test_case.assertEqual(2, len(message.repeated_foreign_enum)) + test_case.assertEqual(2, len(message.repeated_import_enum)) + + test_case.assertEqual(2, len(message.repeated_string_piece)) + test_case.assertEqual(2, len(message.repeated_cord)) + + test_case.assertEqual(201, message.repeated_int32[0]) + test_case.assertEqual(202, message.repeated_int64[0]) + test_case.assertEqual(203, message.repeated_uint32[0]) + test_case.assertEqual(204, message.repeated_uint64[0]) + test_case.assertEqual(205, message.repeated_sint32[0]) + test_case.assertEqual(206, message.repeated_sint64[0]) + test_case.assertEqual(207, message.repeated_fixed32[0]) + test_case.assertEqual(208, message.repeated_fixed64[0]) + test_case.assertEqual(209, message.repeated_sfixed32[0]) + test_case.assertEqual(210, message.repeated_sfixed64[0]) + test_case.assertEqual(211, message.repeated_float[0]) + test_case.assertEqual(212, message.repeated_double[0]) + test_case.assertEqual(True, message.repeated_bool[0]) + test_case.assertEqual('215', message.repeated_string[0]) + test_case.assertEqual('216', message.repeated_bytes[0]) + + test_case.assertEqual(217, message.repeatedgroup[0].a) + test_case.assertEqual(218, message.repeated_nested_message[0].bb) + test_case.assertEqual(219, message.repeated_foreign_message[0].c) + test_case.assertEqual(220, message.repeated_import_message[0].d) + + test_case.assertEqual(unittest_pb2.TestAllTypes.BAR, + message.repeated_nested_enum[0]) + test_case.assertEqual(unittest_pb2.FOREIGN_BAR, + message.repeated_foreign_enum[0]) + test_case.assertEqual(unittest_import_pb2.IMPORT_BAR, + message.repeated_import_enum[0]) + + test_case.assertEqual(301, message.repeated_int32[1]) + test_case.assertEqual(302, message.repeated_int64[1]) + test_case.assertEqual(303, message.repeated_uint32[1]) + test_case.assertEqual(304, message.repeated_uint64[1]) + test_case.assertEqual(305, message.repeated_sint32[1]) + test_case.assertEqual(306, message.repeated_sint64[1]) + test_case.assertEqual(307, message.repeated_fixed32[1]) + test_case.assertEqual(308, message.repeated_fixed64[1]) + test_case.assertEqual(309, message.repeated_sfixed32[1]) + test_case.assertEqual(310, message.repeated_sfixed64[1]) + test_case.assertEqual(311, message.repeated_float[1]) + test_case.assertEqual(312, message.repeated_double[1]) + test_case.assertEqual(False, message.repeated_bool[1]) + test_case.assertEqual('315', message.repeated_string[1]) + test_case.assertEqual('316', message.repeated_bytes[1]) + + test_case.assertEqual(317, message.repeatedgroup[1].a) + test_case.assertEqual(318, message.repeated_nested_message[1].bb) + test_case.assertEqual(319, message.repeated_foreign_message[1].c) + test_case.assertEqual(320, message.repeated_import_message[1].d) + + test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.repeated_nested_enum[1]) + test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, + message.repeated_foreign_enum[1]) + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.repeated_import_enum[1]) + + # ----------------------------------------------------------------- + + test_case.assertTrue(message.HasField('default_int32')) + test_case.assertTrue(message.HasField('default_int64')) + test_case.assertTrue(message.HasField('default_uint32')) + test_case.assertTrue(message.HasField('default_uint64')) + test_case.assertTrue(message.HasField('default_sint32')) + test_case.assertTrue(message.HasField('default_sint64')) + test_case.assertTrue(message.HasField('default_fixed32')) + test_case.assertTrue(message.HasField('default_fixed64')) + test_case.assertTrue(message.HasField('default_sfixed32')) + test_case.assertTrue(message.HasField('default_sfixed64')) + test_case.assertTrue(message.HasField('default_float')) + test_case.assertTrue(message.HasField('default_double')) + test_case.assertTrue(message.HasField('default_bool')) + test_case.assertTrue(message.HasField('default_string')) + test_case.assertTrue(message.HasField('default_bytes')) + + test_case.assertTrue(message.HasField('default_nested_enum')) + test_case.assertTrue(message.HasField('default_foreign_enum')) + test_case.assertTrue(message.HasField('default_import_enum')) + + test_case.assertEqual(401, message.default_int32) + test_case.assertEqual(402, message.default_int64) + test_case.assertEqual(403, message.default_uint32) + test_case.assertEqual(404, message.default_uint64) + test_case.assertEqual(405, message.default_sint32) + test_case.assertEqual(406, message.default_sint64) + test_case.assertEqual(407, message.default_fixed32) + test_case.assertEqual(408, message.default_fixed64) + test_case.assertEqual(409, message.default_sfixed32) + test_case.assertEqual(410, message.default_sfixed64) + test_case.assertEqual(411, message.default_float) + test_case.assertEqual(412, message.default_double) + test_case.assertEqual(False, message.default_bool) + test_case.assertEqual('415', message.default_string) + test_case.assertEqual('416', message.default_bytes) + + test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, + message.default_nested_enum) + test_case.assertEqual(unittest_pb2.FOREIGN_FOO, + message.default_foreign_enum) + test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, + message.default_import_enum) + +def GoldenFile(filename): + """Finds the given golden file and returns a file object representing it.""" + + # Search up the directory tree looking for the C++ protobuf source code. + path = '.' + while os.path.exists(path): + if os.path.exists(os.path.join(path, 'src/google/protobuf')): + # Found it. Load the golden file from the testdata directory. + full_path = os.path.join(path, 'src/google/protobuf/testdata', filename) + return open(full_path, 'rb') + path = os.path.join(path, '..') + + raise RuntimeError( + 'Could not find golden files. This test must be run from within the ' + 'protobuf source package so that it can read test data files from the ' + 'C++ source tree.') + + +def SetAllPackedFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestPackedTypes instance. + """ + message.packed_int32.extend([601, 701]) + message.packed_int64.extend([602, 702]) + message.packed_uint32.extend([603, 703]) + message.packed_uint64.extend([604, 704]) + message.packed_sint32.extend([605, 705]) + message.packed_sint64.extend([606, 706]) + message.packed_fixed32.extend([607, 707]) + message.packed_fixed64.extend([608, 708]) + message.packed_sfixed32.extend([609, 709]) + message.packed_sfixed64.extend([610, 710]) + message.packed_float.extend([611.0, 711.0]) + message.packed_double.extend([612.0, 712.0]) + message.packed_bool.extend([True, False]) + message.packed_enum.extend([unittest_pb2.FOREIGN_BAR, + unittest_pb2.FOREIGN_BAZ]) + + +def SetAllPackedExtensions(message): + """Sets every extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestPackedExtensions instance. + """ + extensions = message.Extensions + pb2 = unittest_pb2 + + extensions[pb2.packed_int32_extension].extend([601, 701]) + extensions[pb2.packed_int64_extension].extend([602, 702]) + extensions[pb2.packed_uint32_extension].extend([603, 703]) + extensions[pb2.packed_uint64_extension].extend([604, 704]) + extensions[pb2.packed_sint32_extension].extend([605, 705]) + extensions[pb2.packed_sint64_extension].extend([606, 706]) + extensions[pb2.packed_fixed32_extension].extend([607, 707]) + extensions[pb2.packed_fixed64_extension].extend([608, 708]) + extensions[pb2.packed_sfixed32_extension].extend([609, 709]) + extensions[pb2.packed_sfixed64_extension].extend([610, 710]) + extensions[pb2.packed_float_extension].extend([611.0, 711.0]) + extensions[pb2.packed_double_extension].extend([612.0, 712.0]) + extensions[pb2.packed_bool_extension].extend([True, False]) + extensions[pb2.packed_enum_extension].extend([unittest_pb2.FOREIGN_BAR, + unittest_pb2.FOREIGN_BAZ]) + + +def SetAllUnpackedFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestUnpackedTypes instance. + """ + message.unpacked_int32.extend([601, 701]) + message.unpacked_int64.extend([602, 702]) + message.unpacked_uint32.extend([603, 703]) + message.unpacked_uint64.extend([604, 704]) + message.unpacked_sint32.extend([605, 705]) + message.unpacked_sint64.extend([606, 706]) + message.unpacked_fixed32.extend([607, 707]) + message.unpacked_fixed64.extend([608, 708]) + message.unpacked_sfixed32.extend([609, 709]) + message.unpacked_sfixed64.extend([610, 710]) + message.unpacked_float.extend([611.0, 711.0]) + message.unpacked_double.extend([612.0, 712.0]) + message.unpacked_bool.extend([True, False]) + message.unpacked_enum.extend([unittest_pb2.FOREIGN_BAR, + unittest_pb2.FOREIGN_BAZ]) diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/text_format_test.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/text_format_test.py index 2d62ba0e..e0991cb1 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/text_format_test.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/text_format_test.py @@ -1,428 +1,428 @@ -#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.text_format."""
-
-__author__ = '[email protected] (Kenton Varda)'
-
-import difflib
-
-import unittest
-from google.protobuf import text_format
-from google.protobuf.internal import test_util
-from google.protobuf import unittest_pb2
-from google.protobuf import unittest_mset_pb2
-
-
-class TextFormatTest(unittest.TestCase):
- def ReadGolden(self, golden_filename):
- f = test_util.GoldenFile(golden_filename)
- golden_lines = f.readlines()
- f.close()
- return golden_lines
-
- def CompareToGoldenFile(self, text, golden_filename):
- golden_lines = self.ReadGolden(golden_filename)
- self.CompareToGoldenLines(text, golden_lines)
-
- def CompareToGoldenText(self, text, golden_text):
- self.CompareToGoldenLines(text, golden_text.splitlines(1))
-
- def CompareToGoldenLines(self, text, golden_lines):
- actual_lines = text.splitlines(1)
- self.assertEqual(golden_lines, actual_lines,
- "Text doesn't match golden. Diff:\n" +
- ''.join(difflib.ndiff(golden_lines, actual_lines)))
-
- def testPrintAllFields(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_data.txt')
-
- def testPrintAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.CompareToGoldenFile(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'text_format_unittest_extensions_data.txt')
-
- def testPrintMessageSet(self):
- message = unittest_mset_pb2.TestMessageSetContainer()
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- message.message_set.Extensions[ext1].i = 23
- message.message_set.Extensions[ext2].str = 'foo'
- self.CompareToGoldenText(text_format.MessageToString(message),
- 'message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
-
- def testPrintExotic(self):
- message = unittest_pb2.TestAllTypes()
- message.repeated_int64.append(-9223372036854775808);
- message.repeated_uint64.append(18446744073709551615);
- message.repeated_double.append(123.456);
- message.repeated_double.append(1.23e22);
- message.repeated_double.append(1.23e-18);
- message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"');
- self.CompareToGoldenText(
- self.RemoveRedundantZeros(text_format.MessageToString(message)),
- 'repeated_int64: -9223372036854775808\n'
- 'repeated_uint64: 18446744073709551615\n'
- 'repeated_double: 123.456\n'
- 'repeated_double: 1.23e+22\n'
- 'repeated_double: 1.23e-18\n'
- 'repeated_string: '
- '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n')
-
- def testMessageToString(self):
- message = unittest_pb2.ForeignMessage()
- message.c = 123
- self.assertEqual('c: 123\n', str(message))
-
- def RemoveRedundantZeros(self, text):
- # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove
- # these zeros in order to match the golden file.
- return text.replace('e+0','e+').replace('e+0','e+') \
- .replace('e-0','e-').replace('e-0','e-')
-
- def testMergeGolden(self):
- golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt'))
- parsed_message = unittest_pb2.TestAllTypes()
- text_format.Merge(golden_text, parsed_message)
-
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- self.assertEquals(message, parsed_message)
-
- def testMergeGoldenExtensions(self):
- golden_text = '\n'.join(self.ReadGolden(
- 'text_format_unittest_extensions_data.txt'))
- parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Merge(golden_text, parsed_message)
-
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- self.assertEquals(message, parsed_message)
-
- def testMergeAllFields(self):
- message = unittest_pb2.TestAllTypes()
- test_util.SetAllFields(message)
- ascii_text = text_format.MessageToString(message)
-
- parsed_message = unittest_pb2.TestAllTypes()
- text_format.Merge(ascii_text, parsed_message)
- self.assertEqual(message, parsed_message)
- test_util.ExpectAllFieldsSet(self, message)
-
- def testMergeAllExtensions(self):
- message = unittest_pb2.TestAllExtensions()
- test_util.SetAllExtensions(message)
- ascii_text = text_format.MessageToString(message)
-
- parsed_message = unittest_pb2.TestAllExtensions()
- text_format.Merge(ascii_text, parsed_message)
- self.assertEqual(message, parsed_message)
-
- def testMergeMessageSet(self):
- message = unittest_pb2.TestAllTypes()
- text = ('repeated_uint64: 1\n'
- 'repeated_uint64: 2\n')
- text_format.Merge(text, message)
- self.assertEqual(1, message.repeated_uint64[0])
- self.assertEqual(2, message.repeated_uint64[1])
-
- message = unittest_mset_pb2.TestMessageSetContainer()
- text = ('message_set {\n'
- ' [protobuf_unittest.TestMessageSetExtension1] {\n'
- ' i: 23\n'
- ' }\n'
- ' [protobuf_unittest.TestMessageSetExtension2] {\n'
- ' str: \"foo\"\n'
- ' }\n'
- '}\n')
- text_format.Merge(text, message)
- ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension
- ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension
- self.assertEquals(23, message.message_set.Extensions[ext1].i)
- self.assertEquals('foo', message.message_set.Extensions[ext2].str)
-
- def testMergeExotic(self):
- message = unittest_pb2.TestAllTypes()
- text = ('repeated_int64: -9223372036854775808\n'
- 'repeated_uint64: 18446744073709551615\n'
- 'repeated_double: 123.456\n'
- 'repeated_double: 1.23e+22\n'
- 'repeated_double: 1.23e-18\n'
- 'repeated_string: \n'
- '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n'
- 'repeated_string: "foo" \'corge\' "grault"')
- text_format.Merge(text, message)
-
- self.assertEqual(-9223372036854775808, message.repeated_int64[0])
- self.assertEqual(18446744073709551615, message.repeated_uint64[0])
- self.assertEqual(123.456, message.repeated_double[0])
- self.assertEqual(1.23e22, message.repeated_double[1])
- self.assertEqual(1.23e-18, message.repeated_double[2])
- self.assertEqual(
- '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0])
- self.assertEqual('foocorgegrault', message.repeated_string[1])
-
- def testMergeUnknownField(self):
- message = unittest_pb2.TestAllTypes()
- text = 'unknown_field: 8\n'
- self.assertRaisesWithMessage(
- text_format.ParseError,
- ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named '
- '"unknown_field".'),
- text_format.Merge, text, message)
-
- def testMergeBadExtension(self):
- message = unittest_pb2.TestAllExtensions()
- text = '[unknown_extension]: 8\n'
- self.assertRaisesWithMessage(
- text_format.ParseError,
- '1:2 : Extension "unknown_extension" not registered.',
- text_format.Merge, text, message)
- message = unittest_pb2.TestAllTypes()
- self.assertRaisesWithMessage(
- text_format.ParseError,
- ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have '
- 'extensions.'),
- text_format.Merge, text, message)
-
- def testMergeGroupNotClosed(self):
- message = unittest_pb2.TestAllTypes()
- text = 'RepeatedGroup: <'
- self.assertRaisesWithMessage(
- text_format.ParseError, '1:16 : Expected ">".',
- text_format.Merge, text, message)
-
- text = 'RepeatedGroup: {'
- self.assertRaisesWithMessage(
- text_format.ParseError, '1:16 : Expected "}".',
- text_format.Merge, text, message)
-
- def testMergeEmptyGroup(self):
- message = unittest_pb2.TestAllTypes()
- text = 'OptionalGroup: {}'
- text_format.Merge(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
- message.Clear()
-
- message = unittest_pb2.TestAllTypes()
- text = 'OptionalGroup: <>'
- text_format.Merge(text, message)
- self.assertTrue(message.HasField('optionalgroup'))
-
- def testMergeBadEnumValue(self):
- message = unittest_pb2.TestAllTypes()
- text = 'optional_nested_enum: BARR'
- self.assertRaisesWithMessage(
- text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value named BARR.'),
- text_format.Merge, text, message)
-
- message = unittest_pb2.TestAllTypes()
- text = 'optional_nested_enum: 100'
- self.assertRaisesWithMessage(
- text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
- 'has no value with number 100.'),
- text_format.Merge, text, message)
-
- def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs):
- """Same as assertRaises, but also compares the exception message."""
- if hasattr(e_class, '__name__'):
- exc_name = e_class.__name__
- else:
- exc_name = str(e_class)
-
- try:
- func(*args, **kwargs)
- except e_class, expr:
- if str(expr) != e:
- msg = '%s raised, but with wrong message: "%s" instead of "%s"'
- raise self.failureException(msg % (exc_name,
- str(expr).encode('string_escape'),
- e.encode('string_escape')))
- return
- else:
- raise self.failureException('%s not raised' % exc_name)
-
-
-class TokenizerTest(unittest.TestCase):
-
- def testSimpleTokenCases(self):
- text = ('identifier1:"string1"\n \n\n'
- 'identifier2 : \n \n123 \n identifier3 :\'string\'\n'
- 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n'
- 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n'
- 'ID9: 22 ID10: -111111111111111111 ID11: -22\n'
- 'ID12: 2222222222222222222')
- tokenizer = text_format._Tokenizer(text)
- methods = [(tokenizer.ConsumeIdentifier, 'identifier1'),
- ':',
- (tokenizer.ConsumeString, 'string1'),
- (tokenizer.ConsumeIdentifier, 'identifier2'),
- ':',
- (tokenizer.ConsumeInt32, 123),
- (tokenizer.ConsumeIdentifier, 'identifier3'),
- ':',
- (tokenizer.ConsumeString, 'string'),
- (tokenizer.ConsumeIdentifier, 'identifiER_4'),
- ':',
- (tokenizer.ConsumeFloat, 1.1e+2),
- (tokenizer.ConsumeIdentifier, 'ID5'),
- ':',
- (tokenizer.ConsumeFloat, -0.23),
- (tokenizer.ConsumeIdentifier, 'ID6'),
- ':',
- (tokenizer.ConsumeString, 'aaaa\'bbbb'),
- (tokenizer.ConsumeIdentifier, 'ID7'),
- ':',
- (tokenizer.ConsumeString, 'aa\"bb'),
- (tokenizer.ConsumeIdentifier, 'ID8'),
- ':',
- '{',
- (tokenizer.ConsumeIdentifier, 'A'),
- ':',
- (tokenizer.ConsumeFloat, text_format._INFINITY),
- (tokenizer.ConsumeIdentifier, 'B'),
- ':',
- (tokenizer.ConsumeFloat, -text_format._INFINITY),
- (tokenizer.ConsumeIdentifier, 'C'),
- ':',
- (tokenizer.ConsumeBool, True),
- (tokenizer.ConsumeIdentifier, 'D'),
- ':',
- (tokenizer.ConsumeBool, False),
- '}',
- (tokenizer.ConsumeIdentifier, 'ID9'),
- ':',
- (tokenizer.ConsumeUint32, 22),
- (tokenizer.ConsumeIdentifier, 'ID10'),
- ':',
- (tokenizer.ConsumeInt64, -111111111111111111),
- (tokenizer.ConsumeIdentifier, 'ID11'),
- ':',
- (tokenizer.ConsumeInt32, -22),
- (tokenizer.ConsumeIdentifier, 'ID12'),
- ':',
- (tokenizer.ConsumeUint64, 2222222222222222222)]
-
- i = 0
- while not tokenizer.AtEnd():
- m = methods[i]
- if type(m) == str:
- token = tokenizer.token
- self.assertEqual(token, m)
- tokenizer.NextToken()
- else:
- self.assertEqual(m[1], m[0]())
- i += 1
-
- def testConsumeIntegers(self):
- # This test only tests the failures in the integer parsing methods as well
- # as the '0' special cases.
- int64_max = (1 << 63) - 1
- uint32_max = (1 << 32) - 1
- text = '-1 %d %d' % (uint32_max + 1, int64_max + 1)
- tokenizer = text_format._Tokenizer(text)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64)
- self.assertEqual(-1, tokenizer.ConsumeInt32())
-
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32)
- self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64())
-
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64)
- self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64())
- self.assertTrue(tokenizer.AtEnd())
-
- text = '-0 -0 0 0'
- tokenizer = text_format._Tokenizer(text)
- self.assertEqual(0, tokenizer.ConsumeUint32())
- self.assertEqual(0, tokenizer.ConsumeUint64())
- self.assertEqual(0, tokenizer.ConsumeUint32())
- self.assertEqual(0, tokenizer.ConsumeUint64())
- self.assertTrue(tokenizer.AtEnd())
-
- def testConsumeByteString(self):
- text = '"string1\''
- tokenizer = text_format._Tokenizer(text)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
-
- text = 'string1"'
- tokenizer = text_format._Tokenizer(text)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
-
- text = '\n"\\xt"'
- tokenizer = text_format._Tokenizer(text)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
-
- text = '\n"\\"'
- tokenizer = text_format._Tokenizer(text)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
-
- text = '\n"\\x"'
- tokenizer = text_format._Tokenizer(text)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString)
-
- def testConsumeBool(self):
- text = 'not-a-bool'
- tokenizer = text_format._Tokenizer(text)
- self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool)
-
- def testInfNan(self):
- # Make sure our infinity and NaN definitions are sound.
- self.assertEquals(float, type(text_format._INFINITY))
- self.assertEquals(float, type(text_format._NAN))
- self.assertTrue(text_format._NAN != text_format._NAN)
-
- inf_times_zero = text_format._INFINITY * 0
- self.assertTrue(inf_times_zero != inf_times_zero)
- self.assertTrue(text_format._INFINITY > 0)
-
-
-if __name__ == '__main__':
- unittest.main()
+#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.text_format.""" + +__author__ = '[email protected] (Kenton Varda)' + +import difflib + +import unittest +from google.protobuf import text_format +from google.protobuf.internal import test_util +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_mset_pb2 + + +class TextFormatTest(unittest.TestCase): + def ReadGolden(self, golden_filename): + f = test_util.GoldenFile(golden_filename) + golden_lines = f.readlines() + f.close() + return golden_lines + + def CompareToGoldenFile(self, text, golden_filename): + golden_lines = self.ReadGolden(golden_filename) + self.CompareToGoldenLines(text, golden_lines) + + def CompareToGoldenText(self, text, golden_text): + self.CompareToGoldenLines(text, golden_text.splitlines(1)) + + def CompareToGoldenLines(self, text, golden_lines): + actual_lines = text.splitlines(1) + self.assertEqual(golden_lines, actual_lines, + "Text doesn't match golden. Diff:\n" + + ''.join(difflib.ndiff(golden_lines, actual_lines))) + + def testPrintAllFields(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data.txt') + + def testPrintAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintMessageSet(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText(text_format.MessageToString(message), + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + + def testPrintExotic(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int64.append(-9223372036854775808); + message.repeated_uint64.append(18446744073709551615); + message.repeated_double.append(123.456); + message.repeated_double.append(1.23e22); + message.repeated_double.append(1.23e-18); + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"'); + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string: ' + '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n') + + def testMessageToString(self): + message = unittest_pb2.ForeignMessage() + message.c = 123 + self.assertEqual('c: 123\n', str(message)) + + def RemoveRedundantZeros(self, text): + # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove + # these zeros in order to match the golden file. + return text.replace('e+0','e+').replace('e+0','e+') \ + .replace('e-0','e-').replace('e-0','e-') + + def testMergeGolden(self): + golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) + parsed_message = unittest_pb2.TestAllTypes() + text_format.Merge(golden_text, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + def testMergeGoldenExtensions(self): + golden_text = '\n'.join(self.ReadGolden( + 'text_format_unittest_extensions_data.txt')) + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Merge(golden_text, parsed_message) + + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.assertEquals(message, parsed_message) + + def testMergeAllFields(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + ascii_text = text_format.MessageToString(message) + + parsed_message = unittest_pb2.TestAllTypes() + text_format.Merge(ascii_text, parsed_message) + self.assertEqual(message, parsed_message) + test_util.ExpectAllFieldsSet(self, message) + + def testMergeAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + ascii_text = text_format.MessageToString(message) + + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Merge(ascii_text, parsed_message) + self.assertEqual(message, parsed_message) + + def testMergeMessageSet(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_uint64: 1\n' + 'repeated_uint64: 2\n') + text_format.Merge(text, message) + self.assertEqual(1, message.repeated_uint64[0]) + self.assertEqual(2, message.repeated_uint64[1]) + + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + text_format.Merge(text, message) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEquals(23, message.message_set.Extensions[ext1].i) + self.assertEquals('foo', message.message_set.Extensions[ext2].str) + + def testMergeExotic(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string: \n' + '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n' + 'repeated_string: "foo" \'corge\' "grault"') + text_format.Merge(text, message) + + self.assertEqual(-9223372036854775808, message.repeated_int64[0]) + self.assertEqual(18446744073709551615, message.repeated_uint64[0]) + self.assertEqual(123.456, message.repeated_double[0]) + self.assertEqual(1.23e22, message.repeated_double[1]) + self.assertEqual(1.23e-18, message.repeated_double[2]) + self.assertEqual( + '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0]) + self.assertEqual('foocorgegrault', message.repeated_string[1]) + + def testMergeUnknownField(self): + message = unittest_pb2.TestAllTypes() + text = 'unknown_field: 8\n' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' + '"unknown_field".'), + text_format.Merge, text, message) + + def testMergeBadExtension(self): + message = unittest_pb2.TestAllExtensions() + text = '[unknown_extension]: 8\n' + self.assertRaisesWithMessage( + text_format.ParseError, + '1:2 : Extension "unknown_extension" not registered.', + text_format.Merge, text, message) + message = unittest_pb2.TestAllTypes() + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' + 'extensions.'), + text_format.Merge, text, message) + + def testMergeGroupNotClosed(self): + message = unittest_pb2.TestAllTypes() + text = 'RepeatedGroup: <' + self.assertRaisesWithMessage( + text_format.ParseError, '1:16 : Expected ">".', + text_format.Merge, text, message) + + text = 'RepeatedGroup: {' + self.assertRaisesWithMessage( + text_format.ParseError, '1:16 : Expected "}".', + text_format.Merge, text, message) + + def testMergeEmptyGroup(self): + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: {}' + text_format.Merge(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + message.Clear() + + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: <>' + text_format.Merge(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + def testMergeBadEnumValue(self): + message = unittest_pb2.TestAllTypes() + text = 'optional_nested_enum: BARR' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' + 'has no value named BARR.'), + text_format.Merge, text, message) + + message = unittest_pb2.TestAllTypes() + text = 'optional_nested_enum: 100' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' + 'has no value with number 100.'), + text_format.Merge, text, message) + + def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): + """Same as assertRaises, but also compares the exception message.""" + if hasattr(e_class, '__name__'): + exc_name = e_class.__name__ + else: + exc_name = str(e_class) + + try: + func(*args, **kwargs) + except e_class, expr: + if str(expr) != e: + msg = '%s raised, but with wrong message: "%s" instead of "%s"' + raise self.failureException(msg % (exc_name, + str(expr).encode('string_escape'), + e.encode('string_escape'))) + return + else: + raise self.failureException('%s not raised' % exc_name) + + +class TokenizerTest(unittest.TestCase): + + def testSimpleTokenCases(self): + text = ('identifier1:"string1"\n \n\n' + 'identifier2 : \n \n123 \n identifier3 :\'string\'\n' + 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n' + 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' + 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' + 'ID12: 2222222222222222222') + tokenizer = text_format._Tokenizer(text) + methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), + ':', + (tokenizer.ConsumeString, 'string1'), + (tokenizer.ConsumeIdentifier, 'identifier2'), + ':', + (tokenizer.ConsumeInt32, 123), + (tokenizer.ConsumeIdentifier, 'identifier3'), + ':', + (tokenizer.ConsumeString, 'string'), + (tokenizer.ConsumeIdentifier, 'identifiER_4'), + ':', + (tokenizer.ConsumeFloat, 1.1e+2), + (tokenizer.ConsumeIdentifier, 'ID5'), + ':', + (tokenizer.ConsumeFloat, -0.23), + (tokenizer.ConsumeIdentifier, 'ID6'), + ':', + (tokenizer.ConsumeString, 'aaaa\'bbbb'), + (tokenizer.ConsumeIdentifier, 'ID7'), + ':', + (tokenizer.ConsumeString, 'aa\"bb'), + (tokenizer.ConsumeIdentifier, 'ID8'), + ':', + '{', + (tokenizer.ConsumeIdentifier, 'A'), + ':', + (tokenizer.ConsumeFloat, text_format._INFINITY), + (tokenizer.ConsumeIdentifier, 'B'), + ':', + (tokenizer.ConsumeFloat, -text_format._INFINITY), + (tokenizer.ConsumeIdentifier, 'C'), + ':', + (tokenizer.ConsumeBool, True), + (tokenizer.ConsumeIdentifier, 'D'), + ':', + (tokenizer.ConsumeBool, False), + '}', + (tokenizer.ConsumeIdentifier, 'ID9'), + ':', + (tokenizer.ConsumeUint32, 22), + (tokenizer.ConsumeIdentifier, 'ID10'), + ':', + (tokenizer.ConsumeInt64, -111111111111111111), + (tokenizer.ConsumeIdentifier, 'ID11'), + ':', + (tokenizer.ConsumeInt32, -22), + (tokenizer.ConsumeIdentifier, 'ID12'), + ':', + (tokenizer.ConsumeUint64, 2222222222222222222)] + + i = 0 + while not tokenizer.AtEnd(): + m = methods[i] + if type(m) == str: + token = tokenizer.token + self.assertEqual(token, m) + tokenizer.NextToken() + else: + self.assertEqual(m[1], m[0]()) + i += 1 + + def testConsumeIntegers(self): + # This test only tests the failures in the integer parsing methods as well + # as the '0' special cases. + int64_max = (1 << 63) - 1 + uint32_max = (1 << 32) - 1 + text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64) + self.assertEqual(-1, tokenizer.ConsumeInt32()) + + self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32) + self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64()) + + self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64) + self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64()) + self.assertTrue(tokenizer.AtEnd()) + + text = '-0 -0 0 0' + tokenizer = text_format._Tokenizer(text) + self.assertEqual(0, tokenizer.ConsumeUint32()) + self.assertEqual(0, tokenizer.ConsumeUint64()) + self.assertEqual(0, tokenizer.ConsumeUint32()) + self.assertEqual(0, tokenizer.ConsumeUint64()) + self.assertTrue(tokenizer.AtEnd()) + + def testConsumeByteString(self): + text = '"string1\'' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = 'string1"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = '\n"\\xt"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = '\n"\\"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + text = '\n"\\x"' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) + + def testConsumeBool(self): + text = 'not-a-bool' + tokenizer = text_format._Tokenizer(text) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) + + def testInfNan(self): + # Make sure our infinity and NaN definitions are sound. + self.assertEquals(float, type(text_format._INFINITY)) + self.assertEquals(float, type(text_format._NAN)) + self.assertTrue(text_format._NAN != text_format._NAN) + + inf_times_zero = text_format._INFINITY * 0 + self.assertTrue(inf_times_zero != inf_times_zero) + self.assertTrue(text_format._INFINITY > 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/type_checkers.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/type_checkers.py index bea94d28..2b3cd4de 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/type_checkers.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/type_checkers.py @@ -1,286 +1,286 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Provides type checking routines.
-
-This module defines type checking utilities in the forms of dictionaries:
-
-VALUE_CHECKERS: A dictionary of field types and a value validation object.
-TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing
- function.
-TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization
- function.
-FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their
- coresponding wire types.
-TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization
- function.
-"""
-
-__author__ = '[email protected] (Will Robinson)'
-
-from google.protobuf.internal import decoder
-from google.protobuf.internal import encoder
-from google.protobuf.internal import wire_format
-from google.protobuf import descriptor
-
-_FieldDescriptor = descriptor.FieldDescriptor
-
-
-def GetTypeChecker(cpp_type, field_type):
- """Returns a type checker for a message field of the specified types.
-
- Args:
- cpp_type: C++ type of the field (see descriptor.py).
- field_type: Protocol message field type (see descriptor.py).
-
- Returns:
- An instance of TypeChecker which can be used to verify the types
- of values assigned to a field of the specified type.
- """
- if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and
- field_type == _FieldDescriptor.TYPE_STRING):
- return UnicodeValueChecker()
- return _VALUE_CHECKERS[cpp_type]
-
-
-# None of the typecheckers below make any attempt to guard against people
-# subclassing builtin types and doing weird things. We're not trying to
-# protect against malicious clients here, just people accidentally shooting
-# themselves in the foot in obvious ways.
-
-class TypeChecker(object):
-
- """Type checker used to catch type errors as early as possible
- when the client is setting scalar fields in protocol messages.
- """
-
- def __init__(self, *acceptable_types):
- self._acceptable_types = acceptable_types
-
- def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, self._acceptable_types):
- message = ('%.1024r has type %s, but expected one of: %s' %
- (proposed_value, type(proposed_value), self._acceptable_types))
- raise TypeError(message)
-
-
-# IntValueChecker and its subclasses perform integer type-checks
-# and bounds-checks.
-class IntValueChecker(object):
-
- """Checker used for integer fields. Performs type-check and range check."""
-
- def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, (int, long)):
- message = ('%.1024r has type %s, but expected one of: %s' %
- (proposed_value, type(proposed_value), (int, long)))
- raise TypeError(message)
- if not self._MIN <= proposed_value <= self._MAX:
- raise ValueError('Value out of range: %d' % proposed_value)
-
-
-class UnicodeValueChecker(object):
-
- """Checker used for string fields."""
-
- def CheckValue(self, proposed_value):
- if not isinstance(proposed_value, (str, unicode)):
- message = ('%.1024r has type %s, but expected one of: %s' %
- (proposed_value, type(proposed_value), (str, unicode)))
- raise TypeError(message)
-
- # If the value is of type 'str' make sure that it is in 7-bit ASCII
- # encoding.
- if isinstance(proposed_value, str):
- try:
- unicode(proposed_value, 'ascii')
- except UnicodeDecodeError:
- raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII '
- 'encoding. Non-ASCII strings must be converted to '
- 'unicode objects before being added.' %
- (proposed_value))
-
-
-class Int32ValueChecker(IntValueChecker):
- # We're sure to use ints instead of longs here since comparison may be more
- # efficient.
- _MIN = -2147483648
- _MAX = 2147483647
-
-
-class Uint32ValueChecker(IntValueChecker):
- _MIN = 0
- _MAX = (1 << 32) - 1
-
-
-class Int64ValueChecker(IntValueChecker):
- _MIN = -(1 << 63)
- _MAX = (1 << 63) - 1
-
-
-class Uint64ValueChecker(IntValueChecker):
- _MIN = 0
- _MAX = (1 << 64) - 1
-
-
-# Type-checkers for all scalar CPPTYPEs.
-_VALUE_CHECKERS = {
- _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(),
- _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(),
- _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(),
- _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(),
- _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker(
- float, int, long),
- _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker(
- float, int, long),
- _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int),
- _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(),
- _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str),
- }
-
-
-# Map from field type to a function F, such that F(field_num, value)
-# gives the total byte size for a value of the given type. This
-# byte size includes tag information and any other additional space
-# associated with serializing "value".
-TYPE_TO_BYTE_SIZE_FN = {
- _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize,
- _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize,
- _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize,
- _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize,
- _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize,
- _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize,
- _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize,
- _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize,
- _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize,
- _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize,
- _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize,
- _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize,
- _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize,
- _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize,
- _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize,
- _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize,
- _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize,
- _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize
- }
-
-
-# Maps from field types to encoder constructors.
-TYPE_TO_ENCODER = {
- _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder,
- _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder,
- _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder,
- _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder,
- _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder,
- _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder,
- _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder,
- _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder,
- _FieldDescriptor.TYPE_STRING: encoder.StringEncoder,
- _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder,
- _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder,
- _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder,
- _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder,
- _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder,
- _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder,
- _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder,
- _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder,
- _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder,
- }
-
-
-# Maps from field types to sizer constructors.
-TYPE_TO_SIZER = {
- _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer,
- _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer,
- _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer,
- _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer,
- _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer,
- _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer,
- _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer,
- _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer,
- _FieldDescriptor.TYPE_STRING: encoder.StringSizer,
- _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer,
- _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer,
- _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer,
- _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer,
- _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer,
- _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer,
- _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer,
- _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer,
- _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer,
- }
-
-
-# Maps from field type to a decoder constructor.
-TYPE_TO_DECODER = {
- _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder,
- _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder,
- _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder,
- _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder,
- _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder,
- _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder,
- _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder,
- _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder,
- _FieldDescriptor.TYPE_STRING: decoder.StringDecoder,
- _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder,
- _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder,
- _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder,
- _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder,
- _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder,
- _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder,
- _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder,
- _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder,
- _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder,
- }
-
-# Maps from field type to expected wiretype.
-FIELD_TYPE_TO_WIRE_TYPE = {
- _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64,
- _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32,
- _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT,
- _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT,
- _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT,
- _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64,
- _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32,
- _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT,
- _FieldDescriptor.TYPE_STRING:
- wire_format.WIRETYPE_LENGTH_DELIMITED,
- _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP,
- _FieldDescriptor.TYPE_MESSAGE:
- wire_format.WIRETYPE_LENGTH_DELIMITED,
- _FieldDescriptor.TYPE_BYTES:
- wire_format.WIRETYPE_LENGTH_DELIMITED,
- _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT,
- _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT,
- _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32,
- _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64,
- _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT,
- _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT,
- }
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides type checking routines. + +This module defines type checking utilities in the forms of dictionaries: + +VALUE_CHECKERS: A dictionary of field types and a value validation object. +TYPE_TO_BYTE_SIZE_FN: A dictionary with field types and a size computing + function. +TYPE_TO_SERIALIZE_METHOD: A dictionary with field types and serialization + function. +FIELD_TYPE_TO_WIRE_TYPE: A dictionary with field typed and their + coresponding wire types. +TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization + function. +""" + +__author__ = '[email protected] (Will Robinson)' + +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import wire_format +from google.protobuf import descriptor + +_FieldDescriptor = descriptor.FieldDescriptor + + +def GetTypeChecker(cpp_type, field_type): + """Returns a type checker for a message field of the specified types. + + Args: + cpp_type: C++ type of the field (see descriptor.py). + field_type: Protocol message field type (see descriptor.py). + + Returns: + An instance of TypeChecker which can be used to verify the types + of values assigned to a field of the specified type. + """ + if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field_type == _FieldDescriptor.TYPE_STRING): + return UnicodeValueChecker() + return _VALUE_CHECKERS[cpp_type] + + +# None of the typecheckers below make any attempt to guard against people +# subclassing builtin types and doing weird things. We're not trying to +# protect against malicious clients here, just people accidentally shooting +# themselves in the foot in obvious ways. + +class TypeChecker(object): + + """Type checker used to catch type errors as early as possible + when the client is setting scalar fields in protocol messages. + """ + + def __init__(self, *acceptable_types): + self._acceptable_types = acceptable_types + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, self._acceptable_types): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), self._acceptable_types)) + raise TypeError(message) + + +# IntValueChecker and its subclasses perform integer type-checks +# and bounds-checks. +class IntValueChecker(object): + + """Checker used for integer fields. Performs type-check and range check.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if not self._MIN <= proposed_value <= self._MAX: + raise ValueError('Value out of range: %d' % proposed_value) + + +class UnicodeValueChecker(object): + + """Checker used for string fields.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (str, unicode)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (str, unicode))) + raise TypeError(message) + + # If the value is of type 'str' make sure that it is in 7-bit ASCII + # encoding. + if isinstance(proposed_value, str): + try: + unicode(proposed_value, 'ascii') + except UnicodeDecodeError: + raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII ' + 'encoding. Non-ASCII strings must be converted to ' + 'unicode objects before being added.' % + (proposed_value)) + + +class Int32ValueChecker(IntValueChecker): + # We're sure to use ints instead of longs here since comparison may be more + # efficient. + _MIN = -2147483648 + _MAX = 2147483647 + + +class Uint32ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 32) - 1 + + +class Int64ValueChecker(IntValueChecker): + _MIN = -(1 << 63) + _MAX = (1 << 63) - 1 + + +class Uint64ValueChecker(IntValueChecker): + _MIN = 0 + _MAX = (1 << 64) - 1 + + +# Type-checkers for all scalar CPPTYPEs. +_VALUE_CHECKERS = { + _FieldDescriptor.CPPTYPE_INT32: Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), + _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), + _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str), + } + + +# Map from field type to a function F, such that F(field_num, value) +# gives the total byte size for a value of the given type. This +# byte size includes tag information and any other additional space +# associated with serializing "value". +TYPE_TO_BYTE_SIZE_FN = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize, + _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize, + _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize, + _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize, + _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize, + _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize, + _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize, + _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize, + _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize, + _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize, + _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize, + _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize, + _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize, + _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize, + _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize, + _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize, + _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize, + _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize + } + + +# Maps from field types to encoder constructors. +TYPE_TO_ENCODER = { + _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleEncoder, + _FieldDescriptor.TYPE_FLOAT: encoder.FloatEncoder, + _FieldDescriptor.TYPE_INT64: encoder.Int64Encoder, + _FieldDescriptor.TYPE_UINT64: encoder.UInt64Encoder, + _FieldDescriptor.TYPE_INT32: encoder.Int32Encoder, + _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Encoder, + _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Encoder, + _FieldDescriptor.TYPE_BOOL: encoder.BoolEncoder, + _FieldDescriptor.TYPE_STRING: encoder.StringEncoder, + _FieldDescriptor.TYPE_GROUP: encoder.GroupEncoder, + _FieldDescriptor.TYPE_MESSAGE: encoder.MessageEncoder, + _FieldDescriptor.TYPE_BYTES: encoder.BytesEncoder, + _FieldDescriptor.TYPE_UINT32: encoder.UInt32Encoder, + _FieldDescriptor.TYPE_ENUM: encoder.EnumEncoder, + _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Encoder, + _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Encoder, + _FieldDescriptor.TYPE_SINT32: encoder.SInt32Encoder, + _FieldDescriptor.TYPE_SINT64: encoder.SInt64Encoder, + } + + +# Maps from field types to sizer constructors. +TYPE_TO_SIZER = { + _FieldDescriptor.TYPE_DOUBLE: encoder.DoubleSizer, + _FieldDescriptor.TYPE_FLOAT: encoder.FloatSizer, + _FieldDescriptor.TYPE_INT64: encoder.Int64Sizer, + _FieldDescriptor.TYPE_UINT64: encoder.UInt64Sizer, + _FieldDescriptor.TYPE_INT32: encoder.Int32Sizer, + _FieldDescriptor.TYPE_FIXED64: encoder.Fixed64Sizer, + _FieldDescriptor.TYPE_FIXED32: encoder.Fixed32Sizer, + _FieldDescriptor.TYPE_BOOL: encoder.BoolSizer, + _FieldDescriptor.TYPE_STRING: encoder.StringSizer, + _FieldDescriptor.TYPE_GROUP: encoder.GroupSizer, + _FieldDescriptor.TYPE_MESSAGE: encoder.MessageSizer, + _FieldDescriptor.TYPE_BYTES: encoder.BytesSizer, + _FieldDescriptor.TYPE_UINT32: encoder.UInt32Sizer, + _FieldDescriptor.TYPE_ENUM: encoder.EnumSizer, + _FieldDescriptor.TYPE_SFIXED32: encoder.SFixed32Sizer, + _FieldDescriptor.TYPE_SFIXED64: encoder.SFixed64Sizer, + _FieldDescriptor.TYPE_SINT32: encoder.SInt32Sizer, + _FieldDescriptor.TYPE_SINT64: encoder.SInt64Sizer, + } + + +# Maps from field type to a decoder constructor. +TYPE_TO_DECODER = { + _FieldDescriptor.TYPE_DOUBLE: decoder.DoubleDecoder, + _FieldDescriptor.TYPE_FLOAT: decoder.FloatDecoder, + _FieldDescriptor.TYPE_INT64: decoder.Int64Decoder, + _FieldDescriptor.TYPE_UINT64: decoder.UInt64Decoder, + _FieldDescriptor.TYPE_INT32: decoder.Int32Decoder, + _FieldDescriptor.TYPE_FIXED64: decoder.Fixed64Decoder, + _FieldDescriptor.TYPE_FIXED32: decoder.Fixed32Decoder, + _FieldDescriptor.TYPE_BOOL: decoder.BoolDecoder, + _FieldDescriptor.TYPE_STRING: decoder.StringDecoder, + _FieldDescriptor.TYPE_GROUP: decoder.GroupDecoder, + _FieldDescriptor.TYPE_MESSAGE: decoder.MessageDecoder, + _FieldDescriptor.TYPE_BYTES: decoder.BytesDecoder, + _FieldDescriptor.TYPE_UINT32: decoder.UInt32Decoder, + _FieldDescriptor.TYPE_ENUM: decoder.EnumDecoder, + _FieldDescriptor.TYPE_SFIXED32: decoder.SFixed32Decoder, + _FieldDescriptor.TYPE_SFIXED64: decoder.SFixed64Decoder, + _FieldDescriptor.TYPE_SINT32: decoder.SInt32Decoder, + _FieldDescriptor.TYPE_SINT64: decoder.SInt64Decoder, + } + +# Maps from field type to expected wiretype. +FIELD_TYPE_TO_WIRE_TYPE = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_STRING: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP, + _FieldDescriptor.TYPE_MESSAGE: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_BYTES: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, + } diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format.py index 12303138..c941fe1a 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format.py @@ -1,268 +1,268 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Constants and static functions to support protocol buffer wire format."""
-
-__author__ = '[email protected] (Will Robinson)'
-
-import struct
-from google.protobuf import descriptor
-from google.protobuf import message
-
-
-TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag.
-TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7
-
-# These numbers identify the wire type of a protocol buffer value.
-# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded
-# tag-and-type to store one of these WIRETYPE_* constants.
-# These values must match WireType enum in google/protobuf/wire_format.h.
-WIRETYPE_VARINT = 0
-WIRETYPE_FIXED64 = 1
-WIRETYPE_LENGTH_DELIMITED = 2
-WIRETYPE_START_GROUP = 3
-WIRETYPE_END_GROUP = 4
-WIRETYPE_FIXED32 = 5
-_WIRETYPE_MAX = 5
-
-
-# Bounds for various integer types.
-INT32_MAX = int((1 << 31) - 1)
-INT32_MIN = int(-(1 << 31))
-UINT32_MAX = (1 << 32) - 1
-
-INT64_MAX = (1 << 63) - 1
-INT64_MIN = -(1 << 63)
-UINT64_MAX = (1 << 64) - 1
-
-# "struct" format strings that will encode/decode the specified formats.
-FORMAT_UINT32_LITTLE_ENDIAN = '<I'
-FORMAT_UINT64_LITTLE_ENDIAN = '<Q'
-FORMAT_FLOAT_LITTLE_ENDIAN = '<f'
-FORMAT_DOUBLE_LITTLE_ENDIAN = '<d'
-
-
-# We'll have to provide alternate implementations of AppendLittleEndian*() on
-# any architectures where these checks fail.
-if struct.calcsize(FORMAT_UINT32_LITTLE_ENDIAN) != 4:
- raise AssertionError('Format "I" is not a 32-bit number.')
-if struct.calcsize(FORMAT_UINT64_LITTLE_ENDIAN) != 8:
- raise AssertionError('Format "Q" is not a 64-bit number.')
-
-
-def PackTag(field_number, wire_type):
- """Returns an unsigned 32-bit integer that encodes the field number and
- wire type information in standard protocol message wire format.
-
- Args:
- field_number: Expected to be an integer in the range [1, 1 << 29)
- wire_type: One of the WIRETYPE_* constants.
- """
- if not 0 <= wire_type <= _WIRETYPE_MAX:
- raise message.EncodeError('Unknown wire type: %d' % wire_type)
- return (field_number << TAG_TYPE_BITS) | wire_type
-
-
-def UnpackTag(tag):
- """The inverse of PackTag(). Given an unsigned 32-bit number,
- returns a (field_number, wire_type) tuple.
- """
- return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK)
-
-
-def ZigZagEncode(value):
- """ZigZag Transform: Encodes signed integers so that they can be
- effectively used with varint encoding. See wire_format.h for
- more details.
- """
- if value >= 0:
- return value << 1
- return (value << 1) ^ (~0)
-
-
-def ZigZagDecode(value):
- """Inverse of ZigZagEncode()."""
- if not value & 0x1:
- return value >> 1
- return (value >> 1) ^ (~0)
-
-
-
-# The *ByteSize() functions below return the number of bytes required to
-# serialize "field number + type" information and then serialize the value.
-
-
-def Int32ByteSize(field_number, int32):
- return Int64ByteSize(field_number, int32)
-
-
-def Int32ByteSizeNoTag(int32):
- return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32)
-
-
-def Int64ByteSize(field_number, int64):
- # Have to convert to uint before calling UInt64ByteSize().
- return UInt64ByteSize(field_number, 0xffffffffffffffff & int64)
-
-
-def UInt32ByteSize(field_number, uint32):
- return UInt64ByteSize(field_number, uint32)
-
-
-def UInt64ByteSize(field_number, uint64):
- return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64)
-
-
-def SInt32ByteSize(field_number, int32):
- return UInt32ByteSize(field_number, ZigZagEncode(int32))
-
-
-def SInt64ByteSize(field_number, int64):
- return UInt64ByteSize(field_number, ZigZagEncode(int64))
-
-
-def Fixed32ByteSize(field_number, fixed32):
- return TagByteSize(field_number) + 4
-
-
-def Fixed64ByteSize(field_number, fixed64):
- return TagByteSize(field_number) + 8
-
-
-def SFixed32ByteSize(field_number, sfixed32):
- return TagByteSize(field_number) + 4
-
-
-def SFixed64ByteSize(field_number, sfixed64):
- return TagByteSize(field_number) + 8
-
-
-def FloatByteSize(field_number, flt):
- return TagByteSize(field_number) + 4
-
-
-def DoubleByteSize(field_number, double):
- return TagByteSize(field_number) + 8
-
-
-def BoolByteSize(field_number, b):
- return TagByteSize(field_number) + 1
-
-
-def EnumByteSize(field_number, enum):
- return UInt32ByteSize(field_number, enum)
-
-
-def StringByteSize(field_number, string):
- return BytesByteSize(field_number, string.encode('utf-8'))
-
-
-def BytesByteSize(field_number, b):
- return (TagByteSize(field_number)
- + _VarUInt64ByteSizeNoTag(len(b))
- + len(b))
-
-
-def GroupByteSize(field_number, message):
- return (2 * TagByteSize(field_number) # START and END group.
- + message.ByteSize())
-
-
-def MessageByteSize(field_number, message):
- return (TagByteSize(field_number)
- + _VarUInt64ByteSizeNoTag(message.ByteSize())
- + message.ByteSize())
-
-
-def MessageSetItemByteSize(field_number, msg):
- # First compute the sizes of the tags.
- # There are 2 tags for the beginning and ending of the repeated group, that
- # is field number 1, one with field number 2 (type_id) and one with field
- # number 3 (message).
- total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3))
-
- # Add the number of bytes for type_id.
- total_size += _VarUInt64ByteSizeNoTag(field_number)
-
- message_size = msg.ByteSize()
-
- # The number of bytes for encoding the length of the message.
- total_size += _VarUInt64ByteSizeNoTag(message_size)
-
- # The size of the message.
- total_size += message_size
- return total_size
-
-
-def TagByteSize(field_number):
- """Returns the bytes required to serialize a tag with this field number."""
- # Just pass in type 0, since the type won't affect the tag+type size.
- return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0))
-
-
-# Private helper function for the *ByteSize() functions above.
-
-def _VarUInt64ByteSizeNoTag(uint64):
- """Returns the number of bytes required to serialize a single varint
- using boundary value comparisons. (unrolled loop optimization -WPierce)
- uint64 must be unsigned.
- """
- if uint64 <= 0x7f: return 1
- if uint64 <= 0x3fff: return 2
- if uint64 <= 0x1fffff: return 3
- if uint64 <= 0xfffffff: return 4
- if uint64 <= 0x7ffffffff: return 5
- if uint64 <= 0x3ffffffffff: return 6
- if uint64 <= 0x1ffffffffffff: return 7
- if uint64 <= 0xffffffffffffff: return 8
- if uint64 <= 0x7fffffffffffffff: return 9
- if uint64 > UINT64_MAX:
- raise message.EncodeError('Value out of range: %d' % uint64)
- return 10
-
-
-NON_PACKABLE_TYPES = (
- descriptor.FieldDescriptor.TYPE_STRING,
- descriptor.FieldDescriptor.TYPE_GROUP,
- descriptor.FieldDescriptor.TYPE_MESSAGE,
- descriptor.FieldDescriptor.TYPE_BYTES
-)
-
-
-def IsTypePackable(field_type):
- """Return true iff packable = true is valid for fields of this type.
-
- Args:
- field_type: a FieldDescriptor::Type value.
-
- Returns:
- True iff fields of this type are packable.
- """
- return field_type not in NON_PACKABLE_TYPES
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Constants and static functions to support protocol buffer wire format.""" + +__author__ = '[email protected] (Will Robinson)' + +import struct +from google.protobuf import descriptor +from google.protobuf import message + + +TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag. +TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 + +# These numbers identify the wire type of a protocol buffer value. +# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded +# tag-and-type to store one of these WIRETYPE_* constants. +# These values must match WireType enum in google/protobuf/wire_format.h. +WIRETYPE_VARINT = 0 +WIRETYPE_FIXED64 = 1 +WIRETYPE_LENGTH_DELIMITED = 2 +WIRETYPE_START_GROUP = 3 +WIRETYPE_END_GROUP = 4 +WIRETYPE_FIXED32 = 5 +_WIRETYPE_MAX = 5 + + +# Bounds for various integer types. +INT32_MAX = int((1 << 31) - 1) +INT32_MIN = int(-(1 << 31)) +UINT32_MAX = (1 << 32) - 1 + +INT64_MAX = (1 << 63) - 1 +INT64_MIN = -(1 << 63) +UINT64_MAX = (1 << 64) - 1 + +# "struct" format strings that will encode/decode the specified formats. +FORMAT_UINT32_LITTLE_ENDIAN = '<I' +FORMAT_UINT64_LITTLE_ENDIAN = '<Q' +FORMAT_FLOAT_LITTLE_ENDIAN = '<f' +FORMAT_DOUBLE_LITTLE_ENDIAN = '<d' + + +# We'll have to provide alternate implementations of AppendLittleEndian*() on +# any architectures where these checks fail. +if struct.calcsize(FORMAT_UINT32_LITTLE_ENDIAN) != 4: + raise AssertionError('Format "I" is not a 32-bit number.') +if struct.calcsize(FORMAT_UINT64_LITTLE_ENDIAN) != 8: + raise AssertionError('Format "Q" is not a 64-bit number.') + + +def PackTag(field_number, wire_type): + """Returns an unsigned 32-bit integer that encodes the field number and + wire type information in standard protocol message wire format. + + Args: + field_number: Expected to be an integer in the range [1, 1 << 29) + wire_type: One of the WIRETYPE_* constants. + """ + if not 0 <= wire_type <= _WIRETYPE_MAX: + raise message.EncodeError('Unknown wire type: %d' % wire_type) + return (field_number << TAG_TYPE_BITS) | wire_type + + +def UnpackTag(tag): + """The inverse of PackTag(). Given an unsigned 32-bit number, + returns a (field_number, wire_type) tuple. + """ + return (tag >> TAG_TYPE_BITS), (tag & TAG_TYPE_MASK) + + +def ZigZagEncode(value): + """ZigZag Transform: Encodes signed integers so that they can be + effectively used with varint encoding. See wire_format.h for + more details. + """ + if value >= 0: + return value << 1 + return (value << 1) ^ (~0) + + +def ZigZagDecode(value): + """Inverse of ZigZagEncode().""" + if not value & 0x1: + return value >> 1 + return (value >> 1) ^ (~0) + + + +# The *ByteSize() functions below return the number of bytes required to +# serialize "field number + type" information and then serialize the value. + + +def Int32ByteSize(field_number, int32): + return Int64ByteSize(field_number, int32) + + +def Int32ByteSizeNoTag(int32): + return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32) + + +def Int64ByteSize(field_number, int64): + # Have to convert to uint before calling UInt64ByteSize(). + return UInt64ByteSize(field_number, 0xffffffffffffffff & int64) + + +def UInt32ByteSize(field_number, uint32): + return UInt64ByteSize(field_number, uint32) + + +def UInt64ByteSize(field_number, uint64): + return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) + + +def SInt32ByteSize(field_number, int32): + return UInt32ByteSize(field_number, ZigZagEncode(int32)) + + +def SInt64ByteSize(field_number, int64): + return UInt64ByteSize(field_number, ZigZagEncode(int64)) + + +def Fixed32ByteSize(field_number, fixed32): + return TagByteSize(field_number) + 4 + + +def Fixed64ByteSize(field_number, fixed64): + return TagByteSize(field_number) + 8 + + +def SFixed32ByteSize(field_number, sfixed32): + return TagByteSize(field_number) + 4 + + +def SFixed64ByteSize(field_number, sfixed64): + return TagByteSize(field_number) + 8 + + +def FloatByteSize(field_number, flt): + return TagByteSize(field_number) + 4 + + +def DoubleByteSize(field_number, double): + return TagByteSize(field_number) + 8 + + +def BoolByteSize(field_number, b): + return TagByteSize(field_number) + 1 + + +def EnumByteSize(field_number, enum): + return UInt32ByteSize(field_number, enum) + + +def StringByteSize(field_number, string): + return BytesByteSize(field_number, string.encode('utf-8')) + + +def BytesByteSize(field_number, b): + return (TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(len(b)) + + len(b)) + + +def GroupByteSize(field_number, message): + return (2 * TagByteSize(field_number) # START and END group. + + message.ByteSize()) + + +def MessageByteSize(field_number, message): + return (TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(message.ByteSize()) + + message.ByteSize()) + + +def MessageSetItemByteSize(field_number, msg): + # First compute the sizes of the tags. + # There are 2 tags for the beginning and ending of the repeated group, that + # is field number 1, one with field number 2 (type_id) and one with field + # number 3 (message). + total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3)) + + # Add the number of bytes for type_id. + total_size += _VarUInt64ByteSizeNoTag(field_number) + + message_size = msg.ByteSize() + + # The number of bytes for encoding the length of the message. + total_size += _VarUInt64ByteSizeNoTag(message_size) + + # The size of the message. + total_size += message_size + return total_size + + +def TagByteSize(field_number): + """Returns the bytes required to serialize a tag with this field number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0)) + + +# Private helper function for the *ByteSize() functions above. + +def _VarUInt64ByteSizeNoTag(uint64): + """Returns the number of bytes required to serialize a single varint + using boundary value comparisons. (unrolled loop optimization -WPierce) + uint64 must be unsigned. + """ + if uint64 <= 0x7f: return 1 + if uint64 <= 0x3fff: return 2 + if uint64 <= 0x1fffff: return 3 + if uint64 <= 0xfffffff: return 4 + if uint64 <= 0x7ffffffff: return 5 + if uint64 <= 0x3ffffffffff: return 6 + if uint64 <= 0x1ffffffffffff: return 7 + if uint64 <= 0xffffffffffffff: return 8 + if uint64 <= 0x7fffffffffffffff: return 9 + if uint64 > UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % uint64) + return 10 + + +NON_PACKABLE_TYPES = ( + descriptor.FieldDescriptor.TYPE_STRING, + descriptor.FieldDescriptor.TYPE_GROUP, + descriptor.FieldDescriptor.TYPE_MESSAGE, + descriptor.FieldDescriptor.TYPE_BYTES +) + + +def IsTypePackable(field_type): + """Return true iff packable = true is valid for fields of this type. + + Args: + field_type: a FieldDescriptor::Type value. + + Returns: + True iff fields of this type are packable. + """ + return field_type not in NON_PACKABLE_TYPES diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format_test.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format_test.py index 3469467c..76007786 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format_test.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/internal/wire_format_test.py @@ -1,253 +1,253 @@ -#! /usr/bin/python
-#
-# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Test for google.protobuf.internal.wire_format."""
-
-__author__ = '[email protected] (Will Robinson)'
-
-import unittest
-from google.protobuf import message
-from google.protobuf.internal import wire_format
-
-
-class WireFormatTest(unittest.TestCase):
-
- def testPackTag(self):
- field_number = 0xabc
- tag_type = 2
- self.assertEqual((field_number << 3) | tag_type,
- wire_format.PackTag(field_number, tag_type))
- PackTag = wire_format.PackTag
- # Number too high.
- self.assertRaises(message.EncodeError, PackTag, field_number, 6)
- # Number too low.
- self.assertRaises(message.EncodeError, PackTag, field_number, -1)
-
- def testUnpackTag(self):
- # Test field numbers that will require various varint sizes.
- for expected_field_number in (1, 15, 16, 2047, 2048):
- for expected_wire_type in range(6): # Highest-numbered wiretype is 5.
- field_number, wire_type = wire_format.UnpackTag(
- wire_format.PackTag(expected_field_number, expected_wire_type))
- self.assertEqual(expected_field_number, field_number)
- self.assertEqual(expected_wire_type, wire_type)
-
- self.assertRaises(TypeError, wire_format.UnpackTag, None)
- self.assertRaises(TypeError, wire_format.UnpackTag, 'abc')
- self.assertRaises(TypeError, wire_format.UnpackTag, 0.0)
- self.assertRaises(TypeError, wire_format.UnpackTag, object())
-
- def testZigZagEncode(self):
- Z = wire_format.ZigZagEncode
- self.assertEqual(0, Z(0))
- self.assertEqual(1, Z(-1))
- self.assertEqual(2, Z(1))
- self.assertEqual(3, Z(-2))
- self.assertEqual(4, Z(2))
- self.assertEqual(0xfffffffe, Z(0x7fffffff))
- self.assertEqual(0xffffffff, Z(-0x80000000))
- self.assertEqual(0xfffffffffffffffe, Z(0x7fffffffffffffff))
- self.assertEqual(0xffffffffffffffff, Z(-0x8000000000000000))
-
- self.assertRaises(TypeError, Z, None)
- self.assertRaises(TypeError, Z, 'abcd')
- self.assertRaises(TypeError, Z, 0.0)
- self.assertRaises(TypeError, Z, object())
-
- def testZigZagDecode(self):
- Z = wire_format.ZigZagDecode
- self.assertEqual(0, Z(0))
- self.assertEqual(-1, Z(1))
- self.assertEqual(1, Z(2))
- self.assertEqual(-2, Z(3))
- self.assertEqual(2, Z(4))
- self.assertEqual(0x7fffffff, Z(0xfffffffe))
- self.assertEqual(-0x80000000, Z(0xffffffff))
- self.assertEqual(0x7fffffffffffffff, Z(0xfffffffffffffffe))
- self.assertEqual(-0x8000000000000000, Z(0xffffffffffffffff))
-
- self.assertRaises(TypeError, Z, None)
- self.assertRaises(TypeError, Z, 'abcd')
- self.assertRaises(TypeError, Z, 0.0)
- self.assertRaises(TypeError, Z, object())
-
- def NumericByteSizeTestHelper(self, byte_size_fn, value, expected_value_size):
- # Use field numbers that cause various byte sizes for the tag information.
- for field_number, tag_bytes in ((15, 1), (16, 2), (2047, 2), (2048, 3)):
- expected_size = expected_value_size + tag_bytes
- actual_size = byte_size_fn(field_number, value)
- self.assertEqual(expected_size, actual_size,
- 'byte_size_fn: %s, field_number: %d, value: %r\n'
- 'Expected: %d, Actual: %d'% (
- byte_size_fn, field_number, value, expected_size, actual_size))
-
- def testByteSizeFunctions(self):
- # Test all numeric *ByteSize() functions.
- NUMERIC_ARGS = [
- # Int32ByteSize().
- [wire_format.Int32ByteSize, 0, 1],
- [wire_format.Int32ByteSize, 127, 1],
- [wire_format.Int32ByteSize, 128, 2],
- [wire_format.Int32ByteSize, -1, 10],
- # Int64ByteSize().
- [wire_format.Int64ByteSize, 0, 1],
- [wire_format.Int64ByteSize, 127, 1],
- [wire_format.Int64ByteSize, 128, 2],
- [wire_format.Int64ByteSize, -1, 10],
- # UInt32ByteSize().
- [wire_format.UInt32ByteSize, 0, 1],
- [wire_format.UInt32ByteSize, 127, 1],
- [wire_format.UInt32ByteSize, 128, 2],
- [wire_format.UInt32ByteSize, wire_format.UINT32_MAX, 5],
- # UInt64ByteSize().
- [wire_format.UInt64ByteSize, 0, 1],
- [wire_format.UInt64ByteSize, 127, 1],
- [wire_format.UInt64ByteSize, 128, 2],
- [wire_format.UInt64ByteSize, wire_format.UINT64_MAX, 10],
- # SInt32ByteSize().
- [wire_format.SInt32ByteSize, 0, 1],
- [wire_format.SInt32ByteSize, -1, 1],
- [wire_format.SInt32ByteSize, 1, 1],
- [wire_format.SInt32ByteSize, -63, 1],
- [wire_format.SInt32ByteSize, 63, 1],
- [wire_format.SInt32ByteSize, -64, 1],
- [wire_format.SInt32ByteSize, 64, 2],
- # SInt64ByteSize().
- [wire_format.SInt64ByteSize, 0, 1],
- [wire_format.SInt64ByteSize, -1, 1],
- [wire_format.SInt64ByteSize, 1, 1],
- [wire_format.SInt64ByteSize, -63, 1],
- [wire_format.SInt64ByteSize, 63, 1],
- [wire_format.SInt64ByteSize, -64, 1],
- [wire_format.SInt64ByteSize, 64, 2],
- # Fixed32ByteSize().
- [wire_format.Fixed32ByteSize, 0, 4],
- [wire_format.Fixed32ByteSize, wire_format.UINT32_MAX, 4],
- # Fixed64ByteSize().
- [wire_format.Fixed64ByteSize, 0, 8],
- [wire_format.Fixed64ByteSize, wire_format.UINT64_MAX, 8],
- # SFixed32ByteSize().
- [wire_format.SFixed32ByteSize, 0, 4],
- [wire_format.SFixed32ByteSize, wire_format.INT32_MIN, 4],
- [wire_format.SFixed32ByteSize, wire_format.INT32_MAX, 4],
- # SFixed64ByteSize().
- [wire_format.SFixed64ByteSize, 0, 8],
- [wire_format.SFixed64ByteSize, wire_format.INT64_MIN, 8],
- [wire_format.SFixed64ByteSize, wire_format.INT64_MAX, 8],
- # FloatByteSize().
- [wire_format.FloatByteSize, 0.0, 4],
- [wire_format.FloatByteSize, 1000000000.0, 4],
- [wire_format.FloatByteSize, -1000000000.0, 4],
- # DoubleByteSize().
- [wire_format.DoubleByteSize, 0.0, 8],
- [wire_format.DoubleByteSize, 1000000000.0, 8],
- [wire_format.DoubleByteSize, -1000000000.0, 8],
- # BoolByteSize().
- [wire_format.BoolByteSize, False, 1],
- [wire_format.BoolByteSize, True, 1],
- # EnumByteSize().
- [wire_format.EnumByteSize, 0, 1],
- [wire_format.EnumByteSize, 127, 1],
- [wire_format.EnumByteSize, 128, 2],
- [wire_format.EnumByteSize, wire_format.UINT32_MAX, 5],
- ]
- for args in NUMERIC_ARGS:
- self.NumericByteSizeTestHelper(*args)
-
- # Test strings and bytes.
- for byte_size_fn in (wire_format.StringByteSize, wire_format.BytesByteSize):
- # 1 byte for tag, 1 byte for length, 3 bytes for contents.
- self.assertEqual(5, byte_size_fn(10, 'abc'))
- # 2 bytes for tag, 1 byte for length, 3 bytes for contents.
- self.assertEqual(6, byte_size_fn(16, 'abc'))
- # 2 bytes for tag, 2 bytes for length, 128 bytes for contents.
- self.assertEqual(132, byte_size_fn(16, 'a' * 128))
-
- # Test UTF-8 string byte size calculation.
- # 1 byte for tag, 1 byte for length, 8 bytes for content.
- self.assertEqual(10, wire_format.StringByteSize(
- 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8')))
-
- class MockMessage(object):
- def __init__(self, byte_size):
- self.byte_size = byte_size
- def ByteSize(self):
- return self.byte_size
-
- message_byte_size = 10
- mock_message = MockMessage(byte_size=message_byte_size)
- # Test groups.
- # (2 * 1) bytes for begin and end tags, plus message_byte_size.
- self.assertEqual(2 + message_byte_size,
- wire_format.GroupByteSize(1, mock_message))
- # (2 * 2) bytes for begin and end tags, plus message_byte_size.
- self.assertEqual(4 + message_byte_size,
- wire_format.GroupByteSize(16, mock_message))
-
- # Test messages.
- # 1 byte for tag, plus 1 byte for length, plus contents.
- self.assertEqual(2 + mock_message.byte_size,
- wire_format.MessageByteSize(1, mock_message))
- # 2 bytes for tag, plus 1 byte for length, plus contents.
- self.assertEqual(3 + mock_message.byte_size,
- wire_format.MessageByteSize(16, mock_message))
- # 2 bytes for tag, plus 2 bytes for length, plus contents.
- mock_message.byte_size = 128
- self.assertEqual(4 + mock_message.byte_size,
- wire_format.MessageByteSize(16, mock_message))
-
-
- # Test message set item byte size.
- # 4 bytes for tags, plus 1 byte for length, plus 1 byte for type_id,
- # plus contents.
- mock_message.byte_size = 10
- self.assertEqual(mock_message.byte_size + 6,
- wire_format.MessageSetItemByteSize(1, mock_message))
-
- # 4 bytes for tags, plus 2 bytes for length, plus 1 byte for type_id,
- # plus contents.
- mock_message.byte_size = 128
- self.assertEqual(mock_message.byte_size + 7,
- wire_format.MessageSetItemByteSize(1, mock_message))
-
- # 4 bytes for tags, plus 2 bytes for length, plus 2 byte for type_id,
- # plus contents.
- self.assertEqual(mock_message.byte_size + 8,
- wire_format.MessageSetItemByteSize(128, mock_message))
-
- # Too-long varint.
- self.assertRaises(message.EncodeError,
- wire_format.UInt64ByteSize, 1, 1 << 128)
-
-
-if __name__ == '__main__':
- unittest.main()
+#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.internal.wire_format.""" + +__author__ = '[email protected] (Will Robinson)' + +import unittest +from google.protobuf import message +from google.protobuf.internal import wire_format + + +class WireFormatTest(unittest.TestCase): + + def testPackTag(self): + field_number = 0xabc + tag_type = 2 + self.assertEqual((field_number << 3) | tag_type, + wire_format.PackTag(field_number, tag_type)) + PackTag = wire_format.PackTag + # Number too high. + self.assertRaises(message.EncodeError, PackTag, field_number, 6) + # Number too low. + self.assertRaises(message.EncodeError, PackTag, field_number, -1) + + def testUnpackTag(self): + # Test field numbers that will require various varint sizes. + for expected_field_number in (1, 15, 16, 2047, 2048): + for expected_wire_type in range(6): # Highest-numbered wiretype is 5. + field_number, wire_type = wire_format.UnpackTag( + wire_format.PackTag(expected_field_number, expected_wire_type)) + self.assertEqual(expected_field_number, field_number) + self.assertEqual(expected_wire_type, wire_type) + + self.assertRaises(TypeError, wire_format.UnpackTag, None) + self.assertRaises(TypeError, wire_format.UnpackTag, 'abc') + self.assertRaises(TypeError, wire_format.UnpackTag, 0.0) + self.assertRaises(TypeError, wire_format.UnpackTag, object()) + + def testZigZagEncode(self): + Z = wire_format.ZigZagEncode + self.assertEqual(0, Z(0)) + self.assertEqual(1, Z(-1)) + self.assertEqual(2, Z(1)) + self.assertEqual(3, Z(-2)) + self.assertEqual(4, Z(2)) + self.assertEqual(0xfffffffe, Z(0x7fffffff)) + self.assertEqual(0xffffffff, Z(-0x80000000)) + self.assertEqual(0xfffffffffffffffe, Z(0x7fffffffffffffff)) + self.assertEqual(0xffffffffffffffff, Z(-0x8000000000000000)) + + self.assertRaises(TypeError, Z, None) + self.assertRaises(TypeError, Z, 'abcd') + self.assertRaises(TypeError, Z, 0.0) + self.assertRaises(TypeError, Z, object()) + + def testZigZagDecode(self): + Z = wire_format.ZigZagDecode + self.assertEqual(0, Z(0)) + self.assertEqual(-1, Z(1)) + self.assertEqual(1, Z(2)) + self.assertEqual(-2, Z(3)) + self.assertEqual(2, Z(4)) + self.assertEqual(0x7fffffff, Z(0xfffffffe)) + self.assertEqual(-0x80000000, Z(0xffffffff)) + self.assertEqual(0x7fffffffffffffff, Z(0xfffffffffffffffe)) + self.assertEqual(-0x8000000000000000, Z(0xffffffffffffffff)) + + self.assertRaises(TypeError, Z, None) + self.assertRaises(TypeError, Z, 'abcd') + self.assertRaises(TypeError, Z, 0.0) + self.assertRaises(TypeError, Z, object()) + + def NumericByteSizeTestHelper(self, byte_size_fn, value, expected_value_size): + # Use field numbers that cause various byte sizes for the tag information. + for field_number, tag_bytes in ((15, 1), (16, 2), (2047, 2), (2048, 3)): + expected_size = expected_value_size + tag_bytes + actual_size = byte_size_fn(field_number, value) + self.assertEqual(expected_size, actual_size, + 'byte_size_fn: %s, field_number: %d, value: %r\n' + 'Expected: %d, Actual: %d'% ( + byte_size_fn, field_number, value, expected_size, actual_size)) + + def testByteSizeFunctions(self): + # Test all numeric *ByteSize() functions. + NUMERIC_ARGS = [ + # Int32ByteSize(). + [wire_format.Int32ByteSize, 0, 1], + [wire_format.Int32ByteSize, 127, 1], + [wire_format.Int32ByteSize, 128, 2], + [wire_format.Int32ByteSize, -1, 10], + # Int64ByteSize(). + [wire_format.Int64ByteSize, 0, 1], + [wire_format.Int64ByteSize, 127, 1], + [wire_format.Int64ByteSize, 128, 2], + [wire_format.Int64ByteSize, -1, 10], + # UInt32ByteSize(). + [wire_format.UInt32ByteSize, 0, 1], + [wire_format.UInt32ByteSize, 127, 1], + [wire_format.UInt32ByteSize, 128, 2], + [wire_format.UInt32ByteSize, wire_format.UINT32_MAX, 5], + # UInt64ByteSize(). + [wire_format.UInt64ByteSize, 0, 1], + [wire_format.UInt64ByteSize, 127, 1], + [wire_format.UInt64ByteSize, 128, 2], + [wire_format.UInt64ByteSize, wire_format.UINT64_MAX, 10], + # SInt32ByteSize(). + [wire_format.SInt32ByteSize, 0, 1], + [wire_format.SInt32ByteSize, -1, 1], + [wire_format.SInt32ByteSize, 1, 1], + [wire_format.SInt32ByteSize, -63, 1], + [wire_format.SInt32ByteSize, 63, 1], + [wire_format.SInt32ByteSize, -64, 1], + [wire_format.SInt32ByteSize, 64, 2], + # SInt64ByteSize(). + [wire_format.SInt64ByteSize, 0, 1], + [wire_format.SInt64ByteSize, -1, 1], + [wire_format.SInt64ByteSize, 1, 1], + [wire_format.SInt64ByteSize, -63, 1], + [wire_format.SInt64ByteSize, 63, 1], + [wire_format.SInt64ByteSize, -64, 1], + [wire_format.SInt64ByteSize, 64, 2], + # Fixed32ByteSize(). + [wire_format.Fixed32ByteSize, 0, 4], + [wire_format.Fixed32ByteSize, wire_format.UINT32_MAX, 4], + # Fixed64ByteSize(). + [wire_format.Fixed64ByteSize, 0, 8], + [wire_format.Fixed64ByteSize, wire_format.UINT64_MAX, 8], + # SFixed32ByteSize(). + [wire_format.SFixed32ByteSize, 0, 4], + [wire_format.SFixed32ByteSize, wire_format.INT32_MIN, 4], + [wire_format.SFixed32ByteSize, wire_format.INT32_MAX, 4], + # SFixed64ByteSize(). + [wire_format.SFixed64ByteSize, 0, 8], + [wire_format.SFixed64ByteSize, wire_format.INT64_MIN, 8], + [wire_format.SFixed64ByteSize, wire_format.INT64_MAX, 8], + # FloatByteSize(). + [wire_format.FloatByteSize, 0.0, 4], + [wire_format.FloatByteSize, 1000000000.0, 4], + [wire_format.FloatByteSize, -1000000000.0, 4], + # DoubleByteSize(). + [wire_format.DoubleByteSize, 0.0, 8], + [wire_format.DoubleByteSize, 1000000000.0, 8], + [wire_format.DoubleByteSize, -1000000000.0, 8], + # BoolByteSize(). + [wire_format.BoolByteSize, False, 1], + [wire_format.BoolByteSize, True, 1], + # EnumByteSize(). + [wire_format.EnumByteSize, 0, 1], + [wire_format.EnumByteSize, 127, 1], + [wire_format.EnumByteSize, 128, 2], + [wire_format.EnumByteSize, wire_format.UINT32_MAX, 5], + ] + for args in NUMERIC_ARGS: + self.NumericByteSizeTestHelper(*args) + + # Test strings and bytes. + for byte_size_fn in (wire_format.StringByteSize, wire_format.BytesByteSize): + # 1 byte for tag, 1 byte for length, 3 bytes for contents. + self.assertEqual(5, byte_size_fn(10, 'abc')) + # 2 bytes for tag, 1 byte for length, 3 bytes for contents. + self.assertEqual(6, byte_size_fn(16, 'abc')) + # 2 bytes for tag, 2 bytes for length, 128 bytes for contents. + self.assertEqual(132, byte_size_fn(16, 'a' * 128)) + + # Test UTF-8 string byte size calculation. + # 1 byte for tag, 1 byte for length, 8 bytes for content. + self.assertEqual(10, wire_format.StringByteSize( + 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'))) + + class MockMessage(object): + def __init__(self, byte_size): + self.byte_size = byte_size + def ByteSize(self): + return self.byte_size + + message_byte_size = 10 + mock_message = MockMessage(byte_size=message_byte_size) + # Test groups. + # (2 * 1) bytes for begin and end tags, plus message_byte_size. + self.assertEqual(2 + message_byte_size, + wire_format.GroupByteSize(1, mock_message)) + # (2 * 2) bytes for begin and end tags, plus message_byte_size. + self.assertEqual(4 + message_byte_size, + wire_format.GroupByteSize(16, mock_message)) + + # Test messages. + # 1 byte for tag, plus 1 byte for length, plus contents. + self.assertEqual(2 + mock_message.byte_size, + wire_format.MessageByteSize(1, mock_message)) + # 2 bytes for tag, plus 1 byte for length, plus contents. + self.assertEqual(3 + mock_message.byte_size, + wire_format.MessageByteSize(16, mock_message)) + # 2 bytes for tag, plus 2 bytes for length, plus contents. + mock_message.byte_size = 128 + self.assertEqual(4 + mock_message.byte_size, + wire_format.MessageByteSize(16, mock_message)) + + + # Test message set item byte size. + # 4 bytes for tags, plus 1 byte for length, plus 1 byte for type_id, + # plus contents. + mock_message.byte_size = 10 + self.assertEqual(mock_message.byte_size + 6, + wire_format.MessageSetItemByteSize(1, mock_message)) + + # 4 bytes for tags, plus 2 bytes for length, plus 1 byte for type_id, + # plus contents. + mock_message.byte_size = 128 + self.assertEqual(mock_message.byte_size + 7, + wire_format.MessageSetItemByteSize(1, mock_message)) + + # 4 bytes for tags, plus 2 bytes for length, plus 2 byte for type_id, + # plus contents. + self.assertEqual(mock_message.byte_size + 8, + wire_format.MessageSetItemByteSize(128, mock_message)) + + # Too-long varint. + self.assertRaises(message.EncodeError, + wire_format.UInt64ByteSize, 1, 1 << 128) + + +if __name__ == '__main__': + unittest.main() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/message.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/message.py index cb4b2171..f8398474 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/message.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/message.py @@ -1,254 +1,254 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-# TODO(robinson): We should just make these methods all "pure-virtual" and move
-# all implementation out, into reflection.py for now.
-
-
-"""Contains an abstract base class for protocol messages."""
-
-__author__ = '[email protected] (Will Robinson)'
-
-
-class Error(Exception): pass
-class DecodeError(Error): pass
-class EncodeError(Error): pass
-
-
-class Message(object):
-
- """Abstract base class for protocol messages.
-
- Protocol message classes are almost always generated by the protocol
- compiler. These generated types subclass Message and implement the methods
- shown below.
-
- TODO(robinson): Link to an HTML document here.
-
- TODO(robinson): Document that instances of this class will also
- have an Extensions attribute with __getitem__ and __setitem__.
- Again, not sure how to best convey this.
-
- TODO(robinson): Document that the class must also have a static
- RegisterExtension(extension_field) method.
- Not sure how to best express at this point.
- """
-
- # TODO(robinson): Document these fields and methods.
-
- __slots__ = []
-
- DESCRIPTOR = None
-
- def __eq__(self, other_msg):
- raise NotImplementedError
-
- def __ne__(self, other_msg):
- # Can't just say self != other_msg, since that would infinitely recurse. :)
- return not self == other_msg
-
- def __str__(self):
- raise NotImplementedError
-
- def MergeFrom(self, other_msg):
- """Merges the contents of the specified message into current message.
-
- This method merges the contents of the specified message into the current
- message. Singular fields that are set in the specified message overwrite
- the corresponding fields in the current message. Repeated fields are
- appended. Singular sub-messages and groups are recursively merged.
-
- Args:
- other_msg: Message to merge into the current message.
- """
- raise NotImplementedError
-
- def CopyFrom(self, other_msg):
- """Copies the content of the specified message into the current message.
-
- The method clears the current message and then merges the specified
- message using MergeFrom.
-
- Args:
- other_msg: Message to copy into the current one.
- """
- if self is other_msg:
- return
- self.Clear()
- self.MergeFrom(other_msg)
-
- def Clear(self):
- """Clears all data that was set in the message."""
- raise NotImplementedError
-
- def SetInParent(self):
- """Mark this as present in the parent.
-
- This normally happens automatically when you assign a field of a
- sub-message, but sometimes you want to make the sub-message
- present while keeping it empty. If you find yourself using this,
- you may want to reconsider your design."""
- raise NotImplementedError
-
- def IsInitialized(self):
- """Checks if the message is initialized.
-
- Returns:
- The method returns True if the message is initialized (i.e. all of its
- required fields are set).
- """
- raise NotImplementedError
-
- # TODO(robinson): MergeFromString() should probably return None and be
- # implemented in terms of a helper that returns the # of bytes read. Our
- # deserialization routines would use the helper when recursively
- # deserializing, but the end user would almost always just want the no-return
- # MergeFromString().
-
- def MergeFromString(self, serialized):
- """Merges serialized protocol buffer data into this message.
-
- When we find a field in |serialized| that is already present
- in this message:
- - If it's a "repeated" field, we append to the end of our list.
- - Else, if it's a scalar, we overwrite our field.
- - Else, (it's a nonrepeated composite), we recursively merge
- into the existing composite.
-
- TODO(robinson): Document handling of unknown fields.
-
- Args:
- serialized: Any object that allows us to call buffer(serialized)
- to access a string of bytes using the buffer interface.
-
- TODO(robinson): When we switch to a helper, this will return None.
-
- Returns:
- The number of bytes read from |serialized|.
- For non-group messages, this will always be len(serialized),
- but for messages which are actually groups, this will
- generally be less than len(serialized), since we must
- stop when we reach an END_GROUP tag. Note that if
- we *do* stop because of an END_GROUP tag, the number
- of bytes returned does not include the bytes
- for the END_GROUP tag information.
- """
- raise NotImplementedError
-
- def ParseFromString(self, serialized):
- """Like MergeFromString(), except we clear the object first."""
- self.Clear()
- self.MergeFromString(serialized)
-
- def SerializeToString(self):
- """Serializes the protocol message to a binary string.
-
- Returns:
- A binary string representation of the message if all of the required
- fields in the message are set (i.e. the message is initialized).
-
- Raises:
- message.EncodeError if the message isn't initialized.
- """
- raise NotImplementedError
-
- def SerializePartialToString(self):
- """Serializes the protocol message to a binary string.
-
- This method is similar to SerializeToString but doesn't check if the
- message is initialized.
-
- Returns:
- A string representation of the partial message.
- """
- raise NotImplementedError
-
- # TODO(robinson): Decide whether we like these better
- # than auto-generated has_foo() and clear_foo() methods
- # on the instances themselves. This way is less consistent
- # with C++, but it makes reflection-type access easier and
- # reduces the number of magically autogenerated things.
- #
- # TODO(robinson): Be sure to document (and test) exactly
- # which field names are accepted here. Are we case-sensitive?
- # What do we do with fields that share names with Python keywords
- # like 'lambda' and 'yield'?
- #
- # nnorwitz says:
- # """
- # Typically (in python), an underscore is appended to names that are
- # keywords. So they would become lambda_ or yield_.
- # """
- def ListFields(self):
- """Returns a list of (FieldDescriptor, value) tuples for all
- fields in the message which are not empty. A singular field is non-empty
- if HasField() would return true, and a repeated field is non-empty if
- it contains at least one element. The fields are ordered by field
- number"""
- raise NotImplementedError
-
- def HasField(self, field_name):
- raise NotImplementedError
-
- def ClearField(self, field_name):
- raise NotImplementedError
-
- def HasExtension(self, extension_handle):
- raise NotImplementedError
-
- def ClearExtension(self, extension_handle):
- raise NotImplementedError
-
- def ByteSize(self):
- """Returns the serialized size of this message.
- Recursively calls ByteSize() on all contained messages.
- """
- raise NotImplementedError
-
- def _SetListener(self, message_listener):
- """Internal method used by the protocol message implementation.
- Clients should not call this directly.
-
- Sets a listener that this message will call on certain state transitions.
-
- The purpose of this method is to register back-edges from children to
- parents at runtime, for the purpose of setting "has" bits and
- byte-size-dirty bits in the parent and ancestor objects whenever a child or
- descendant object is modified.
-
- If the client wants to disconnect this Message from the object tree, she
- explicitly sets callback to None.
-
- If message_listener is None, unregisters any existing listener. Otherwise,
- message_listener must implement the MessageListener interface in
- internal/message_listener.py, and we discard any listener registered
- via a previous _SetListener() call.
- """
- raise NotImplementedError
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# TODO(robinson): We should just make these methods all "pure-virtual" and move +# all implementation out, into reflection.py for now. + + +"""Contains an abstract base class for protocol messages.""" + +__author__ = '[email protected] (Will Robinson)' + + +class Error(Exception): pass +class DecodeError(Error): pass +class EncodeError(Error): pass + + +class Message(object): + + """Abstract base class for protocol messages. + + Protocol message classes are almost always generated by the protocol + compiler. These generated types subclass Message and implement the methods + shown below. + + TODO(robinson): Link to an HTML document here. + + TODO(robinson): Document that instances of this class will also + have an Extensions attribute with __getitem__ and __setitem__. + Again, not sure how to best convey this. + + TODO(robinson): Document that the class must also have a static + RegisterExtension(extension_field) method. + Not sure how to best express at this point. + """ + + # TODO(robinson): Document these fields and methods. + + __slots__ = [] + + DESCRIPTOR = None + + def __eq__(self, other_msg): + raise NotImplementedError + + def __ne__(self, other_msg): + # Can't just say self != other_msg, since that would infinitely recurse. :) + return not self == other_msg + + def __str__(self): + raise NotImplementedError + + def MergeFrom(self, other_msg): + """Merges the contents of the specified message into current message. + + This method merges the contents of the specified message into the current + message. Singular fields that are set in the specified message overwrite + the corresponding fields in the current message. Repeated fields are + appended. Singular sub-messages and groups are recursively merged. + + Args: + other_msg: Message to merge into the current message. + """ + raise NotImplementedError + + def CopyFrom(self, other_msg): + """Copies the content of the specified message into the current message. + + The method clears the current message and then merges the specified + message using MergeFrom. + + Args: + other_msg: Message to copy into the current one. + """ + if self is other_msg: + return + self.Clear() + self.MergeFrom(other_msg) + + def Clear(self): + """Clears all data that was set in the message.""" + raise NotImplementedError + + def SetInParent(self): + """Mark this as present in the parent. + + This normally happens automatically when you assign a field of a + sub-message, but sometimes you want to make the sub-message + present while keeping it empty. If you find yourself using this, + you may want to reconsider your design.""" + raise NotImplementedError + + def IsInitialized(self): + """Checks if the message is initialized. + + Returns: + The method returns True if the message is initialized (i.e. all of its + required fields are set). + """ + raise NotImplementedError + + # TODO(robinson): MergeFromString() should probably return None and be + # implemented in terms of a helper that returns the # of bytes read. Our + # deserialization routines would use the helper when recursively + # deserializing, but the end user would almost always just want the no-return + # MergeFromString(). + + def MergeFromString(self, serialized): + """Merges serialized protocol buffer data into this message. + + When we find a field in |serialized| that is already present + in this message: + - If it's a "repeated" field, we append to the end of our list. + - Else, if it's a scalar, we overwrite our field. + - Else, (it's a nonrepeated composite), we recursively merge + into the existing composite. + + TODO(robinson): Document handling of unknown fields. + + Args: + serialized: Any object that allows us to call buffer(serialized) + to access a string of bytes using the buffer interface. + + TODO(robinson): When we switch to a helper, this will return None. + + Returns: + The number of bytes read from |serialized|. + For non-group messages, this will always be len(serialized), + but for messages which are actually groups, this will + generally be less than len(serialized), since we must + stop when we reach an END_GROUP tag. Note that if + we *do* stop because of an END_GROUP tag, the number + of bytes returned does not include the bytes + for the END_GROUP tag information. + """ + raise NotImplementedError + + def ParseFromString(self, serialized): + """Like MergeFromString(), except we clear the object first.""" + self.Clear() + self.MergeFromString(serialized) + + def SerializeToString(self): + """Serializes the protocol message to a binary string. + + Returns: + A binary string representation of the message if all of the required + fields in the message are set (i.e. the message is initialized). + + Raises: + message.EncodeError if the message isn't initialized. + """ + raise NotImplementedError + + def SerializePartialToString(self): + """Serializes the protocol message to a binary string. + + This method is similar to SerializeToString but doesn't check if the + message is initialized. + + Returns: + A string representation of the partial message. + """ + raise NotImplementedError + + # TODO(robinson): Decide whether we like these better + # than auto-generated has_foo() and clear_foo() methods + # on the instances themselves. This way is less consistent + # with C++, but it makes reflection-type access easier and + # reduces the number of magically autogenerated things. + # + # TODO(robinson): Be sure to document (and test) exactly + # which field names are accepted here. Are we case-sensitive? + # What do we do with fields that share names with Python keywords + # like 'lambda' and 'yield'? + # + # nnorwitz says: + # """ + # Typically (in python), an underscore is appended to names that are + # keywords. So they would become lambda_ or yield_. + # """ + def ListFields(self): + """Returns a list of (FieldDescriptor, value) tuples for all + fields in the message which are not empty. A singular field is non-empty + if HasField() would return true, and a repeated field is non-empty if + it contains at least one element. The fields are ordered by field + number""" + raise NotImplementedError + + def HasField(self, field_name): + raise NotImplementedError + + def ClearField(self, field_name): + raise NotImplementedError + + def HasExtension(self, extension_handle): + raise NotImplementedError + + def ClearExtension(self, extension_handle): + raise NotImplementedError + + def ByteSize(self): + """Returns the serialized size of this message. + Recursively calls ByteSize() on all contained messages. + """ + raise NotImplementedError + + def _SetListener(self, message_listener): + """Internal method used by the protocol message implementation. + Clients should not call this directly. + + Sets a listener that this message will call on certain state transitions. + + The purpose of this method is to register back-edges from children to + parents at runtime, for the purpose of setting "has" bits and + byte-size-dirty bits in the parent and ancestor objects whenever a child or + descendant object is modified. + + If the client wants to disconnect this Message from the object tree, she + explicitly sets callback to None. + + If message_listener is None, unregisters any existing listener. Otherwise, + message_listener must implement the MessageListener interface in + internal/message_listener.py, and we discard any listener registered + via a previous _SetListener() call. + """ + raise NotImplementedError diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/reflection.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/reflection.py index eb3866b1..5b238031 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/reflection.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/reflection.py @@ -1,1148 +1,1148 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-# This code is meant to work on Python 2.4 and above only.
-#
-# TODO(robinson): Helpers for verbose, common checks like seeing if a
-# descriptor's cpp_type is CPPTYPE_MESSAGE.
-
-"""Contains a metaclass and helper functions used to create
-protocol message classes from Descriptor objects at runtime.
-
-Recall that a metaclass is the "type" of a class.
-(A class is to a metaclass what an instance is to a class.)
-
-In this case, we use the GeneratedProtocolMessageType metaclass
-to inject all the useful functionality into the classes
-output by the protocol compiler at compile-time.
-
-The upshot of all this is that the real implementation
-details for ALL pure-Python protocol buffers are *here in
-this file*.
-"""
-
-__author__ = '[email protected] (Will Robinson)'
-
-try:
- from cStringIO import StringIO
-except ImportError:
- from StringIO import StringIO
-import struct
-import weakref
-
-# We use "as" to avoid name collisions with variables.
-from google.protobuf.internal import containers
-from google.protobuf.internal import decoder
-from google.protobuf.internal import encoder
-from google.protobuf.internal import message_listener as message_listener_mod
-from google.protobuf.internal import type_checkers
-from google.protobuf.internal import wire_format
-from google.protobuf import descriptor as descriptor_mod
-from google.protobuf import message as message_mod
-from google.protobuf import text_format
-
-_FieldDescriptor = descriptor_mod.FieldDescriptor
-
-
-class GeneratedProtocolMessageType(type):
-
- """Metaclass for protocol message classes created at runtime from Descriptors.
-
- We add implementations for all methods described in the Message class. We
- also create properties to allow getting/setting all fields in the protocol
- message. Finally, we create slots to prevent users from accidentally
- "setting" nonexistent fields in the protocol message, which then wouldn't get
- serialized / deserialized properly.
-
- The protocol compiler currently uses this metaclass to create protocol
- message classes at runtime. Clients can also manually create their own
- classes at runtime, as in this example:
-
- mydescriptor = Descriptor(.....)
- class MyProtoClass(Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = mydescriptor
- myproto_instance = MyProtoClass()
- myproto.foo_field = 23
- ...
- """
-
- # Must be consistent with the protocol-compiler code in
- # proto2/compiler/internal/generator.*.
- _DESCRIPTOR_KEY = 'DESCRIPTOR'
-
- def __new__(cls, name, bases, dictionary):
- """Custom allocation for runtime-generated class types.
-
- We override __new__ because this is apparently the only place
- where we can meaningfully set __slots__ on the class we're creating(?).
- (The interplay between metaclasses and slots is not very well-documented).
-
- Args:
- name: Name of the class (ignored, but required by the
- metaclass protocol).
- bases: Base classes of the class we're constructing.
- (Should be message.Message). We ignore this field, but
- it's required by the metaclass protocol
- dictionary: The class dictionary of the class we're
- constructing. dictionary[_DESCRIPTOR_KEY] must contain
- a Descriptor object describing this protocol message
- type.
-
- Returns:
- Newly-allocated class.
- """
- descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
- _AddSlots(descriptor, dictionary)
- _AddClassAttributesForNestedExtensions(descriptor, dictionary)
- superclass = super(GeneratedProtocolMessageType, cls)
- return superclass.__new__(cls, name, bases, dictionary)
-
- def __init__(cls, name, bases, dictionary):
- """Here we perform the majority of our work on the class.
- We add enum getters, an __init__ method, implementations
- of all Message methods, and properties for all fields
- in the protocol type.
-
- Args:
- name: Name of the class (ignored, but required by the
- metaclass protocol).
- bases: Base classes of the class we're constructing.
- (Should be message.Message). We ignore this field, but
- it's required by the metaclass protocol
- dictionary: The class dictionary of the class we're
- constructing. dictionary[_DESCRIPTOR_KEY] must contain
- a Descriptor object describing this protocol message
- type.
- """
- descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
-
- cls._decoders_by_tag = {}
- cls._extensions_by_name = {}
- cls._extensions_by_number = {}
- if (descriptor.has_options and
- descriptor.GetOptions().message_set_wire_format):
- cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = (
- decoder.MessageSetItemDecoder(cls._extensions_by_number))
-
- # We act as a "friend" class of the descriptor, setting
- # its _concrete_class attribute the first time we use a
- # given descriptor to initialize a concrete protocol message
- # class. We also attach stuff to each FieldDescriptor for quick
- # lookup later on.
- concrete_class_attr_name = '_concrete_class'
- if not hasattr(descriptor, concrete_class_attr_name):
- setattr(descriptor, concrete_class_attr_name, cls)
- for field in descriptor.fields:
- _AttachFieldHelpers(cls, field)
-
- _AddEnumValues(descriptor, cls)
- _AddInitMethod(descriptor, cls)
- _AddPropertiesForFields(descriptor, cls)
- _AddPropertiesForExtensions(descriptor, cls)
- _AddStaticMethods(cls)
- _AddMessageMethods(descriptor, cls)
- _AddPrivateHelperMethods(cls)
- superclass = super(GeneratedProtocolMessageType, cls)
- superclass.__init__(name, bases, dictionary)
-
-
-# Stateless helpers for GeneratedProtocolMessageType below.
-# Outside clients should not access these directly.
-#
-# I opted not to make any of these methods on the metaclass, to make it more
-# clear that I'm not really using any state there and to keep clients from
-# thinking that they have direct access to these construction helpers.
-
-
-def _PropertyName(proto_field_name):
- """Returns the name of the public property attribute which
- clients can use to get and (in some cases) set the value
- of a protocol message field.
-
- Args:
- proto_field_name: The protocol message field name, exactly
- as it appears (or would appear) in a .proto file.
- """
- # TODO(robinson): Escape Python keywords (e.g., yield), and test this support.
- # nnorwitz makes my day by writing:
- # """
- # FYI. See the keyword module in the stdlib. This could be as simple as:
- #
- # if keyword.iskeyword(proto_field_name):
- # return proto_field_name + "_"
- # return proto_field_name
- # """
- # Kenton says: The above is a BAD IDEA. People rely on being able to use
- # getattr() and setattr() to reflectively manipulate field values. If we
- # rename the properties, then every such user has to also make sure to apply
- # the same transformation. Note that currently if you name a field "yield",
- # you can still access it just fine using getattr/setattr -- it's not even
- # that cumbersome to do so.
- # TODO(kenton): Remove this method entirely if/when everyone agrees with my
- # position.
- return proto_field_name
-
-
-def _VerifyExtensionHandle(message, extension_handle):
- """Verify that the given extension handle is valid."""
-
- if not isinstance(extension_handle, _FieldDescriptor):
- raise KeyError('HasExtension() expects an extension handle, got: %s' %
- extension_handle)
-
- if not extension_handle.is_extension:
- raise KeyError('"%s" is not an extension.' % extension_handle.full_name)
-
- if extension_handle.containing_type is not message.DESCRIPTOR:
- raise KeyError('Extension "%s" extends message type "%s", but this '
- 'message is of type "%s".' %
- (extension_handle.full_name,
- extension_handle.containing_type.full_name,
- message.DESCRIPTOR.full_name))
-
-
-def _AddSlots(message_descriptor, dictionary):
- """Adds a __slots__ entry to dictionary, containing the names of all valid
- attributes for this message type.
-
- Args:
- message_descriptor: A Descriptor instance describing this message type.
- dictionary: Class dictionary to which we'll add a '__slots__' entry.
- """
- dictionary['__slots__'] = ['_cached_byte_size',
- '_cached_byte_size_dirty',
- '_fields',
- '_is_present_in_parent',
- '_listener',
- '_listener_for_children',
- '__weakref__']
-
-
-def _IsMessageSetExtension(field):
- return (field.is_extension and
- field.containing_type.has_options and
- field.containing_type.GetOptions().message_set_wire_format and
- field.type == _FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
- field.label == _FieldDescriptor.LABEL_OPTIONAL)
-
-
-def _AttachFieldHelpers(cls, field_descriptor):
- is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
- is_packed = (field_descriptor.has_options and
- field_descriptor.GetOptions().packed)
-
- if _IsMessageSetExtension(field_descriptor):
- field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
- sizer = encoder.MessageSetItemSizer(field_descriptor.number)
- else:
- field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed)
- sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed)
-
- field_descriptor._encoder = field_encoder
- field_descriptor._sizer = sizer
- field_descriptor._default_constructor = _DefaultValueConstructorForField(
- field_descriptor)
-
- def AddDecoder(wiretype, is_packed):
- tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype)
- cls._decoders_by_tag[tag_bytes] = (
- type_checkers.TYPE_TO_DECODER[field_descriptor.type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor))
-
- AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type],
- False)
-
- if is_repeated and wire_format.IsTypePackable(field_descriptor.type):
- # To support wire compatibility of adding packed = true, add a decoder for
- # packed values regardless of the field's options.
- AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True)
-
-
-def _AddClassAttributesForNestedExtensions(descriptor, dictionary):
- extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
- assert extension_name not in dictionary
- dictionary[extension_name] = extension_field
-
-
-def _AddEnumValues(descriptor, cls):
- """Sets class-level attributes for all enum fields defined in this message.
-
- Args:
- descriptor: Descriptor object for this message type.
- cls: Class we're constructing for this message type.
- """
- for enum_type in descriptor.enum_types:
- for enum_value in enum_type.values:
- setattr(cls, enum_value.name, enum_value.number)
-
-
-def _DefaultValueConstructorForField(field):
- """Returns a function which returns a default value for a field.
-
- Args:
- field: FieldDescriptor object for this field.
-
- The returned function has one argument:
- message: Message instance containing this field, or a weakref proxy
- of same.
-
- That function in turn returns a default value for this field. The default
- value may refer back to |message| via a weak reference.
- """
-
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- if field.default_value != []:
- raise ValueError('Repeated field default value not empty list: %s' % (
- field.default_value))
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- # We can't look at _concrete_class yet since it might not have
- # been set. (Depends on order in which we initialize the classes).
- message_type = field.message_type
- def MakeRepeatedMessageDefault(message):
- return containers.RepeatedCompositeFieldContainer(
- message._listener_for_children, field.message_type)
- return MakeRepeatedMessageDefault
- else:
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
- def MakeRepeatedScalarDefault(message):
- return containers.RepeatedScalarFieldContainer(
- message._listener_for_children, type_checker)
- return MakeRepeatedScalarDefault
-
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- # _concrete_class may not yet be initialized.
- message_type = field.message_type
- def MakeSubMessageDefault(message):
- result = message_type._concrete_class()
- result._SetListener(message._listener_for_children)
- return result
- return MakeSubMessageDefault
-
- def MakeScalarDefault(message):
- return field.default_value
- return MakeScalarDefault
-
-
-def _AddInitMethod(message_descriptor, cls):
- """Adds an __init__ method to cls."""
- fields = message_descriptor.fields
- def init(self, **kwargs):
- self._cached_byte_size = 0
- self._cached_byte_size_dirty = False
- self._fields = {}
- self._is_present_in_parent = False
- self._listener = message_listener_mod.NullMessageListener()
- self._listener_for_children = _Listener(self)
- for field_name, field_value in kwargs.iteritems():
- field = _GetFieldByName(message_descriptor, field_name)
- if field is None:
- raise TypeError("%s() got an unexpected keyword argument '%s'" %
- (message_descriptor.name, field_name))
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- copy = field._default_constructor(self)
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite
- for val in field_value:
- copy.add().MergeFrom(val)
- else: # Scalar
- copy.extend(field_value)
- self._fields[field] = copy
- elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- copy = field._default_constructor(self)
- copy.MergeFrom(field_value)
- self._fields[field] = copy
- else:
- self._fields[field] = field_value
-
- init.__module__ = None
- init.__doc__ = None
- cls.__init__ = init
-
-
-def _GetFieldByName(message_descriptor, field_name):
- """Returns a field descriptor by field name.
-
- Args:
- message_descriptor: A Descriptor describing all fields in message.
- field_name: The name of the field to retrieve.
- Returns:
- The field descriptor associated with the field name.
- """
- try:
- return message_descriptor.fields_by_name[field_name]
- except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
-
-
-def _AddPropertiesForFields(descriptor, cls):
- """Adds properties for all fields in this protocol message type."""
- for field in descriptor.fields:
- _AddPropertiesForField(field, cls)
-
- if descriptor.is_extendable:
- # _ExtensionDict is just an adaptor with no state so we allocate a new one
- # every time it is accessed.
- cls.Extensions = property(lambda self: _ExtensionDict(self))
-
-
-def _AddPropertiesForField(field, cls):
- """Adds a public property for a protocol message field.
- Clients can use this property to get and (in the case
- of non-repeated scalar fields) directly set the value
- of a protocol message field.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
- """
- # Catch it if we add other types that we should
- # handle specially here.
- assert _FieldDescriptor.MAX_CPPTYPE == 10
-
- constant_name = field.name.upper() + "_FIELD_NUMBER"
- setattr(cls, constant_name, field.number)
-
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- _AddPropertiesForRepeatedField(field, cls)
- elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- _AddPropertiesForNonRepeatedCompositeField(field, cls)
- else:
- _AddPropertiesForNonRepeatedScalarField(field, cls)
-
-
-def _AddPropertiesForRepeatedField(field, cls):
- """Adds a public property for a "repeated" protocol message field. Clients
- can use this property to get the value of the field, which will be either a
- _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
- below).
-
- Note that when clients add values to these containers, we perform
- type-checking in the case of repeated scalar fields, and we also set any
- necessary "has" bits as a side-effect.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
- """
- proto_field_name = field.name
- property_name = _PropertyName(proto_field_name)
-
- def getter(self):
- field_value = self._fields.get(field)
- if field_value is None:
- # Construct a new object to represent this field.
- field_value = field._default_constructor(self)
-
- # Atomically check if another thread has preempted us and, if not, swap
- # in the new object we just created. If someone has preempted us, we
- # take that object and discard ours.
- # WARNING: We are relying on setdefault() being atomic. This is true
- # in CPython but we haven't investigated others. This warning appears
- # in several other locations in this file.
- field_value = self._fields.setdefault(field, field_value)
- return field_value
- getter.__module__ = None
- getter.__doc__ = 'Getter for %s.' % proto_field_name
-
- # We define a setter just so we can throw an exception with a more
- # helpful error message.
- def setter(self, new_value):
- raise AttributeError('Assignment not allowed to repeated field '
- '"%s" in protocol message object.' % proto_field_name)
-
- doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
-
-
-def _AddPropertiesForNonRepeatedScalarField(field, cls):
- """Adds a public property for a nonrepeated, scalar protocol message field.
- Clients can use this property to get and directly set the value of the field.
- Note that when the client sets the value of a field by using this property,
- all necessary "has" bits are set as a side-effect, and we also perform
- type-checking.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
- """
- proto_field_name = field.name
- property_name = _PropertyName(proto_field_name)
- type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type)
- default_value = field.default_value
-
- def getter(self):
- return self._fields.get(field, default_value)
- getter.__module__ = None
- getter.__doc__ = 'Getter for %s.' % proto_field_name
- def setter(self, new_value):
- type_checker.CheckValue(new_value)
- self._fields[field] = new_value
- # Check _cached_byte_size_dirty inline to improve performance, since scalar
- # setters are called frequently.
- if not self._cached_byte_size_dirty:
- self._Modified()
- setter.__module__ = None
- setter.__doc__ = 'Setter for %s.' % proto_field_name
-
- # Add a property to encapsulate the getter/setter.
- doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
-
-
-def _AddPropertiesForNonRepeatedCompositeField(field, cls):
- """Adds a public property for a nonrepeated, composite protocol message field.
- A composite field is a "group" or "message" field.
-
- Clients can use this property to get the value of the field, but cannot
- assign to the property directly.
-
- Args:
- field: A FieldDescriptor for this field.
- cls: The class we're constructing.
- """
- # TODO(robinson): Remove duplication with similar method
- # for non-repeated scalars.
- proto_field_name = field.name
- property_name = _PropertyName(proto_field_name)
- message_type = field.message_type
-
- def getter(self):
- field_value = self._fields.get(field)
- if field_value is None:
- # Construct a new object to represent this field.
- field_value = message_type._concrete_class()
- field_value._SetListener(self._listener_for_children)
-
- # Atomically check if another thread has preempted us and, if not, swap
- # in the new object we just created. If someone has preempted us, we
- # take that object and discard ours.
- # WARNING: We are relying on setdefault() being atomic. This is true
- # in CPython but we haven't investigated others. This warning appears
- # in several other locations in this file.
- field_value = self._fields.setdefault(field, field_value)
- return field_value
- getter.__module__ = None
- getter.__doc__ = 'Getter for %s.' % proto_field_name
-
- # We define a setter just so we can throw an exception with a more
- # helpful error message.
- def setter(self, new_value):
- raise AttributeError('Assignment not allowed to composite field '
- '"%s" in protocol message object.' % proto_field_name)
-
- # Add a property to encapsulate the getter.
- doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
-
-
-def _AddPropertiesForExtensions(descriptor, cls):
- """Adds properties for all fields in this protocol message type."""
- extension_dict = descriptor.extensions_by_name
- for extension_name, extension_field in extension_dict.iteritems():
- constant_name = extension_name.upper() + "_FIELD_NUMBER"
- setattr(cls, constant_name, extension_field.number)
-
-
-def _AddStaticMethods(cls):
- # TODO(robinson): This probably needs to be thread-safe(?)
- def RegisterExtension(extension_handle):
- extension_handle.containing_type = cls.DESCRIPTOR
- _AttachFieldHelpers(cls, extension_handle)
-
- # Try to insert our extension, failing if an extension with the same number
- # already exists.
- actual_handle = cls._extensions_by_number.setdefault(
- extension_handle.number, extension_handle)
- if actual_handle is not extension_handle:
- raise AssertionError(
- 'Extensions "%s" and "%s" both try to extend message type "%s" with '
- 'field number %d.' %
- (extension_handle.full_name, actual_handle.full_name,
- cls.DESCRIPTOR.full_name, extension_handle.number))
-
- cls._extensions_by_name[extension_handle.full_name] = extension_handle
-
- handle = extension_handle # avoid line wrapping
- if _IsMessageSetExtension(handle):
- # MessageSet extension. Also register under type name.
- cls._extensions_by_name[
- extension_handle.message_type.full_name] = extension_handle
-
- cls.RegisterExtension = staticmethod(RegisterExtension)
-
- def FromString(s):
- message = cls()
- message.MergeFromString(s)
- return message
- cls.FromString = staticmethod(FromString)
-
-
-def _IsPresent(item):
- """Given a (FieldDescriptor, value) tuple from _fields, return true if the
- value should be included in the list returned by ListFields()."""
-
- if item[0].label == _FieldDescriptor.LABEL_REPEATED:
- return bool(item[1])
- elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- return item[1]._is_present_in_parent
- else:
- return True
-
-
-def _AddListFieldsMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def ListFields(self):
- all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)]
- all_fields.sort(key = lambda item: item[0].number)
- return all_fields
-
- cls.ListFields = ListFields
-
-
-def _AddHasFieldMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- singular_fields = {}
- for field in message_descriptor.fields:
- if field.label != _FieldDescriptor.LABEL_REPEATED:
- singular_fields[field.name] = field
-
- def HasField(self, field_name):
- try:
- field = singular_fields[field_name]
- except KeyError:
- raise ValueError(
- 'Protocol message has no singular "%s" field.' % field_name)
-
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- value = self._fields.get(field)
- return value is not None and value._is_present_in_parent
- else:
- return field in self._fields
- cls.HasField = HasField
-
-
-def _AddClearFieldMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def ClearField(self, field_name):
- try:
- field = message_descriptor.fields_by_name[field_name]
- except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
-
- if field in self._fields:
- # Note: If the field is a sub-message, its listener will still point
- # at us. That's fine, because the worst than can happen is that it
- # will call _Modified() and invalidate our byte size. Big deal.
- del self._fields[field]
-
- # Always call _Modified() -- even if nothing was changed, this is
- # a mutating method, and thus calling it should cause the field to become
- # present in the parent message.
- self._Modified()
-
- cls.ClearField = ClearField
-
-
-def _AddClearExtensionMethod(cls):
- """Helper for _AddMessageMethods()."""
- def ClearExtension(self, extension_handle):
- _VerifyExtensionHandle(self, extension_handle)
-
- # Similar to ClearField(), above.
- if extension_handle in self._fields:
- del self._fields[extension_handle]
- self._Modified()
- cls.ClearExtension = ClearExtension
-
-
-def _AddClearMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def Clear(self):
- # Clear fields.
- self._fields = {}
- self._Modified()
- cls.Clear = Clear
-
-
-def _AddHasExtensionMethod(cls):
- """Helper for _AddMessageMethods()."""
- def HasExtension(self, extension_handle):
- _VerifyExtensionHandle(self, extension_handle)
- if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
- raise KeyError('"%s" is repeated.' % extension_handle.full_name)
-
- if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- value = self._fields.get(extension_handle)
- return value is not None and value._is_present_in_parent
- else:
- return extension_handle in self._fields
- cls.HasExtension = HasExtension
-
-
-def _AddEqualsMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def __eq__(self, other):
- if (not isinstance(other, message_mod.Message) or
- other.DESCRIPTOR != self.DESCRIPTOR):
- return False
-
- if self is other:
- return True
-
- return self.ListFields() == other.ListFields()
-
- cls.__eq__ = __eq__
-
-
-def _AddStrMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def __str__(self):
- return text_format.MessageToString(self)
- cls.__str__ = __str__
-
-
-def _AddSetListenerMethod(cls):
- """Helper for _AddMessageMethods()."""
- def SetListener(self, listener):
- if listener is None:
- self._listener = message_listener_mod.NullMessageListener()
- else:
- self._listener = listener
- cls._SetListener = SetListener
-
-
-def _BytesForNonRepeatedElement(value, field_number, field_type):
- """Returns the number of bytes needed to serialize a non-repeated element.
- The returned byte count includes space for tag information and any
- other additional space associated with serializing value.
-
- Args:
- value: Value we're serializing.
- field_number: Field number of this value. (Since the field number
- is stored as part of a varint-encoded tag, this has an impact
- on the total bytes required to serialize the value).
- field_type: The type of the field. One of the TYPE_* constants
- within FieldDescriptor.
- """
- try:
- fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type]
- return fn(field_number, value)
- except KeyError:
- raise message_mod.EncodeError('Unrecognized field type: %d' % field_type)
-
-
-def _AddByteSizeMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def ByteSize(self):
- if not self._cached_byte_size_dirty:
- return self._cached_byte_size
-
- size = 0
- for field_descriptor, field_value in self.ListFields():
- size += field_descriptor._sizer(field_value)
-
- self._cached_byte_size = size
- self._cached_byte_size_dirty = False
- self._listener_for_children.dirty = False
- return size
-
- cls.ByteSize = ByteSize
-
-
-def _AddSerializeToStringMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def SerializeToString(self):
- # Check if the message has all of its required fields set.
- errors = []
- if not self.IsInitialized():
- raise message_mod.EncodeError(
- 'Message is missing required fields: ' +
- ','.join(self.FindInitializationErrors()))
- return self.SerializePartialToString()
- cls.SerializeToString = SerializeToString
-
-
-def _AddSerializePartialToStringMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
-
- def SerializePartialToString(self):
- out = StringIO()
- self._InternalSerialize(out.write)
- return out.getvalue()
- cls.SerializePartialToString = SerializePartialToString
-
- def InternalSerialize(self, write_bytes):
- for field_descriptor, field_value in self.ListFields():
- field_descriptor._encoder(write_bytes, field_value)
- cls._InternalSerialize = InternalSerialize
-
-
-def _AddMergeFromStringMethod(message_descriptor, cls):
- """Helper for _AddMessageMethods()."""
- def MergeFromString(self, serialized):
- length = len(serialized)
- try:
- if self._InternalParse(serialized, 0, length) != length:
- # The only reason _InternalParse would return early is if it
- # encountered an end-group tag.
- raise message_mod.DecodeError('Unexpected end-group tag.')
- except IndexError:
- raise message_mod.DecodeError('Truncated message.')
- except struct.error, e:
- raise message_mod.DecodeError(e)
- return length # Return this for legacy reasons.
- cls.MergeFromString = MergeFromString
-
- local_ReadTag = decoder.ReadTag
- local_SkipField = decoder.SkipField
- decoders_by_tag = cls._decoders_by_tag
-
- def InternalParse(self, buffer, pos, end):
- self._Modified()
- field_dict = self._fields
- while pos != end:
- (tag_bytes, new_pos) = local_ReadTag(buffer, pos)
- field_decoder = decoders_by_tag.get(tag_bytes)
- if field_decoder is None:
- new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
- if new_pos == -1:
- return pos
- pos = new_pos
- else:
- pos = field_decoder(buffer, new_pos, end, self, field_dict)
- return pos
- cls._InternalParse = InternalParse
-
-
-def _AddIsInitializedMethod(message_descriptor, cls):
- """Adds the IsInitialized and FindInitializationError methods to the
- protocol message class."""
-
- required_fields = [field for field in message_descriptor.fields
- if field.label == _FieldDescriptor.LABEL_REQUIRED]
-
- def IsInitialized(self, errors=None):
- """Checks if all required fields of a message are set.
-
- Args:
- errors: A list which, if provided, will be populated with the field
- paths of all missing required fields.
-
- Returns:
- True iff the specified message has all required fields set.
- """
-
- # Performance is critical so we avoid HasField() and ListFields().
-
- for field in required_fields:
- if (field not in self._fields or
- (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and
- not self._fields[field]._is_present_in_parent)):
- if errors is not None:
- errors.extend(self.FindInitializationErrors())
- return False
-
- for field, value in self._fields.iteritems():
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for element in value:
- if not element.IsInitialized():
- if errors is not None:
- errors.extend(self.FindInitializationErrors())
- return False
- elif value._is_present_in_parent and not value.IsInitialized():
- if errors is not None:
- errors.extend(self.FindInitializationErrors())
- return False
-
- return True
-
- cls.IsInitialized = IsInitialized
-
- def FindInitializationErrors(self):
- """Finds required fields which are not initialized.
-
- Returns:
- A list of strings. Each string is a path to an uninitialized field from
- the top-level message, e.g. "foo.bar[5].baz".
- """
-
- errors = [] # simplify things
-
- for field in required_fields:
- if not self.HasField(field.name):
- errors.append(field.name)
-
- for field, value in self.ListFields():
- if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- if field.is_extension:
- name = "(%s)" % field.full_name
- else:
- name = field.name
-
- if field.label == _FieldDescriptor.LABEL_REPEATED:
- for i in xrange(len(value)):
- element = value[i]
- prefix = "%s[%d]." % (name, i)
- sub_errors = element.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
- else:
- prefix = name + "."
- sub_errors = value.FindInitializationErrors()
- errors += [ prefix + error for error in sub_errors ]
-
- return errors
-
- cls.FindInitializationErrors = FindInitializationErrors
-
-
-def _AddMergeFromMethod(cls):
- LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED
- CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE
-
- def MergeFrom(self, msg):
- assert msg is not self
- self._Modified()
-
- fields = self._fields
-
- for field, value in msg._fields.iteritems():
- if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE:
- field_value = fields.get(field)
- if field_value is None:
- # Construct a new object to represent this field.
- field_value = field._default_constructor(self)
- fields[field] = field_value
- field_value.MergeFrom(value)
- else:
- self._fields[field] = value
- cls.MergeFrom = MergeFrom
-
-
-def _AddMessageMethods(message_descriptor, cls):
- """Adds implementations of all Message methods to cls."""
- _AddListFieldsMethod(message_descriptor, cls)
- _AddHasFieldMethod(message_descriptor, cls)
- _AddClearFieldMethod(message_descriptor, cls)
- if message_descriptor.is_extendable:
- _AddClearExtensionMethod(cls)
- _AddHasExtensionMethod(cls)
- _AddClearMethod(message_descriptor, cls)
- _AddEqualsMethod(message_descriptor, cls)
- _AddStrMethod(message_descriptor, cls)
- _AddSetListenerMethod(cls)
- _AddByteSizeMethod(message_descriptor, cls)
- _AddSerializeToStringMethod(message_descriptor, cls)
- _AddSerializePartialToStringMethod(message_descriptor, cls)
- _AddMergeFromStringMethod(message_descriptor, cls)
- _AddIsInitializedMethod(message_descriptor, cls)
- _AddMergeFromMethod(cls)
-
-
-def _AddPrivateHelperMethods(cls):
- """Adds implementation of private helper methods to cls."""
-
- def Modified(self):
- """Sets the _cached_byte_size_dirty bit to true,
- and propagates this to our listener iff this was a state change.
- """
-
- # Note: Some callers check _cached_byte_size_dirty before calling
- # _Modified() as an extra optimization. So, if this method is ever
- # changed such that it does stuff even when _cached_byte_size_dirty is
- # already true, the callers need to be updated.
- if not self._cached_byte_size_dirty:
- self._cached_byte_size_dirty = True
- self._listener_for_children.dirty = True
- self._is_present_in_parent = True
- self._listener.Modified()
-
- cls._Modified = Modified
- cls.SetInParent = Modified
-
-
-class _Listener(object):
-
- """MessageListener implementation that a parent message registers with its
- child message.
-
- In order to support semantics like:
-
- foo.bar.baz.qux = 23
- assert foo.HasField('bar')
-
- ...child objects must have back references to their parents.
- This helper class is at the heart of this support.
- """
-
- def __init__(self, parent_message):
- """Args:
- parent_message: The message whose _Modified() method we should call when
- we receive Modified() messages.
- """
- # This listener establishes a back reference from a child (contained) object
- # to its parent (containing) object. We make this a weak reference to avoid
- # creating cyclic garbage when the client finishes with the 'parent' object
- # in the tree.
- if isinstance(parent_message, weakref.ProxyType):
- self._parent_message_weakref = parent_message
- else:
- self._parent_message_weakref = weakref.proxy(parent_message)
-
- # As an optimization, we also indicate directly on the listener whether
- # or not the parent message is dirty. This way we can avoid traversing
- # up the tree in the common case.
- self.dirty = False
-
- def Modified(self):
- if self.dirty:
- return
- try:
- # Propagate the signal to our parents iff this is the first field set.
- self._parent_message_weakref._Modified()
- except ReferenceError:
- # We can get here if a client has kept a reference to a child object,
- # and is now setting a field on it, but the child's parent has been
- # garbage-collected. This is not an error.
- pass
-
-
-# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous...
-# TODO(robinson): Unify error handling of "unknown extension" crap.
-# TODO(robinson): Support iteritems()-style iteration over all
-# extensions with the "has" bits turned on?
-class _ExtensionDict(object):
-
- """Dict-like container for supporting an indexable "Extensions"
- field on proto instances.
-
- Note that in all cases we expect extension handles to be
- FieldDescriptors.
- """
-
- def __init__(self, extended_message):
- """extended_message: Message instance for which we are the Extensions dict.
- """
-
- self._extended_message = extended_message
-
- def __getitem__(self, extension_handle):
- """Returns the current value of the given extension handle."""
-
- _VerifyExtensionHandle(self._extended_message, extension_handle)
-
- result = self._extended_message._fields.get(extension_handle)
- if result is not None:
- return result
-
- if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
- result = extension_handle._default_constructor(self._extended_message)
- elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
- result = extension_handle.message_type._concrete_class()
- try:
- result._SetListener(self._extended_message._listener_for_children)
- except ReferenceError:
- pass
- else:
- # Singular scalar -- just return the default without inserting into the
- # dict.
- return extension_handle.default_value
-
- # Atomically check if another thread has preempted us and, if not, swap
- # in the new object we just created. If someone has preempted us, we
- # take that object and discard ours.
- # WARNING: We are relying on setdefault() being atomic. This is true
- # in CPython but we haven't investigated others. This warning appears
- # in several other locations in this file.
- result = self._extended_message._fields.setdefault(
- extension_handle, result)
-
- return result
-
- def __eq__(self, other):
- if not isinstance(other, self.__class__):
- return False
-
- my_fields = self._extended_message.ListFields()
- other_fields = other._extended_message.ListFields()
-
- # Get rid of non-extension fields.
- my_fields = [ field for field in my_fields if field.is_extension ]
- other_fields = [ field for field in other_fields if field.is_extension ]
-
- return my_fields == other_fields
-
- def __ne__(self, other):
- return not self == other
-
- # Note that this is only meaningful for non-repeated, scalar extension
- # fields. Note also that we may have to call _Modified() when we do
- # successfully set a field this way, to set any necssary "has" bits in the
- # ancestors of the extended message.
- def __setitem__(self, extension_handle, value):
- """If extension_handle specifies a non-repeated, scalar extension
- field, sets the value of that field.
- """
-
- _VerifyExtensionHandle(self._extended_message, extension_handle)
-
- if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or
- extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE):
- raise TypeError(
- 'Cannot assign to extension "%s" because it is a repeated or '
- 'composite type.' % extension_handle.full_name)
-
- # It's slightly wasteful to lookup the type checker each time,
- # but we expect this to be a vanishingly uncommon case anyway.
- type_checker = type_checkers.GetTypeChecker(
- extension_handle.cpp_type, extension_handle.type)
- type_checker.CheckValue(value)
- self._extended_message._fields[extension_handle] = value
- self._extended_message._Modified()
-
- def _FindExtensionByName(self, name):
- """Tries to find a known extension with the specified name.
-
- Args:
- name: Extension full name.
-
- Returns:
- Extension field descriptor.
- """
- return self._extended_message._extensions_by_name.get(name, None)
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This code is meant to work on Python 2.4 and above only. +# +# TODO(robinson): Helpers for verbose, common checks like seeing if a +# descriptor's cpp_type is CPPTYPE_MESSAGE. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) + +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = '[email protected] (Will Robinson)' + +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO +import struct +import weakref + +# We use "as" to avoid name collisions with variables. +from google.protobuf.internal import containers +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import message_listener as message_listener_mod +from google.protobuf.internal import type_checkers +from google.protobuf.internal import wire_format +from google.protobuf import descriptor as descriptor_mod +from google.protobuf import message as message_mod +from google.protobuf import text_format + +_FieldDescriptor = descriptor_mod.FieldDescriptor + + +class GeneratedProtocolMessageType(type): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. + + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + _AddSlots(descriptor, dictionary) + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + superclass = super(GeneratedProtocolMessageType, cls) + return superclass.__new__(cls, name, bases, dictionary) + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number)) + + # We act as a "friend" class of the descriptor, setting + # its _concrete_class attribute the first time we use a + # given descriptor to initialize a concrete protocol message + # class. We also attach stuff to each FieldDescriptor for quick + # lookup later on. + concrete_class_attr_name = '_concrete_class' + if not hasattr(descriptor, concrete_class_attr_name): + setattr(descriptor, concrete_class_attr_name, cls) + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(cls) + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(name, bases, dictionary) + + +# Stateless helpers for GeneratedProtocolMessageType below. +# Outside clients should not access these directly. +# +# I opted not to make any of these methods on the metaclass, to make it more +# clear that I'm not really using any state there and to keep clients from +# thinking that they have direct access to these construction helpers. + + +def _PropertyName(proto_field_name): + """Returns the name of the public property attribute which + clients can use to get and (in some cases) set the value + of a protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. + # nnorwitz makes my day by writing: + # """ + # FYI. See the keyword module in the stdlib. This could be as simple as: + # + # if keyword.iskeyword(proto_field_name): + # return proto_field_name + "_" + # return proto_field_name + # """ + # Kenton says: The above is a BAD IDEA. People rely on being able to use + # getattr() and setattr() to reflectively manipulate field values. If we + # rename the properties, then every such user has to also make sure to apply + # the same transformation. Note that currently if you name a field "yield", + # you can still access it just fine using getattr/setattr -- it's not even + # that cumbersome to do so. + # TODO(kenton): Remove this method entirely if/when everyone agrees with my + # position. + return proto_field_name + + +def _VerifyExtensionHandle(message, extension_handle): + """Verify that the given extension handle is valid.""" + + if not isinstance(extension_handle, _FieldDescriptor): + raise KeyError('HasExtension() expects an extension handle, got: %s' % + extension_handle) + + if not extension_handle.is_extension: + raise KeyError('"%s" is not an extension.' % extension_handle.full_name) + + if extension_handle.containing_type is not message.DESCRIPTOR: + raise KeyError('Extension "%s" extends message type "%s", but this ' + 'message is of type "%s".' % + (extension_handle.full_name, + extension_handle.containing_type.full_name, + message.DESCRIPTOR.full_name)) + + +def _AddSlots(message_descriptor, dictionary): + """Adds a __slots__ entry to dictionary, containing the names of all valid + attributes for this message type. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + dictionary['__slots__'] = ['_cached_byte_size', + '_cached_byte_size_dirty', + '_fields', + '_is_present_in_parent', + '_listener', + '_listener_for_children', + '__weakref__'] + + +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == _FieldDescriptor.LABEL_OPTIONAL) + + +def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + + if _IsMessageSetExtension(field_descriptor): + field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) + sizer = encoder.MessageSetItemSizer(field_descriptor.number) + else: + field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + + field_descriptor._encoder = field_encoder + field_descriptor._sizer = sizer + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor) + + def AddDecoder(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + cls._decoders_by_tag[tag_bytes] = ( + type_checkers.TYPE_TO_DECODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor)) + + AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], + False) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. + AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + + +def _AddClassAttributesForNestedExtensions(descriptor, dictionary): + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddEnumValues(descriptor, cls): + """Sets class-level attributes for all enum fields defined in this message. + + Args: + descriptor: Descriptor object for this message type. + cls: Class we're constructing for this message type. + """ + for enum_type in descriptor.enum_types: + for enum_value in enum_type.values: + setattr(cls, enum_value.name, enum_value.number) + + +def _DefaultValueConstructorForField(field): + """Returns a function which returns a default value for a field. + + Args: + field: FieldDescriptor object for this field. + + The returned function has one argument: + message: Message instance containing this field, or a weakref proxy + of same. + + That function in turn returns a default value for this field. The default + value may refer back to |message| via a weak reference. + """ + + if field.label == _FieldDescriptor.LABEL_REPEATED: + if field.default_value != []: + raise ValueError('Repeated field default value not empty list: %s' % ( + field.default_value)) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # We can't look at _concrete_class yet since it might not have + # been set. (Depends on order in which we initialize the classes). + message_type = field.message_type + def MakeRepeatedMessageDefault(message): + return containers.RepeatedCompositeFieldContainer( + message._listener_for_children, field.message_type) + return MakeRepeatedMessageDefault + else: + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + def MakeRepeatedScalarDefault(message): + return containers.RepeatedScalarFieldContainer( + message._listener_for_children, type_checker) + return MakeRepeatedScalarDefault + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # _concrete_class may not yet be initialized. + message_type = field.message_type + def MakeSubMessageDefault(message): + result = message_type._concrete_class() + result._SetListener(message._listener_for_children) + return result + return MakeSubMessageDefault + + def MakeScalarDefault(message): + return field.default_value + return MakeScalarDefault + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + fields = message_descriptor.fields + def init(self, **kwargs): + self._cached_byte_size = 0 + self._cached_byte_size_dirty = False + self._fields = {} + self._is_present_in_parent = False + self._listener = message_listener_mod.NullMessageListener() + self._listener_for_children = _Listener(self) + for field_name, field_value in kwargs.iteritems(): + field = _GetFieldByName(message_descriptor, field_name) + if field is None: + raise TypeError("%s() got an unexpected keyword argument '%s'" % + (message_descriptor.name, field_name)) + if field.label == _FieldDescriptor.LABEL_REPEATED: + copy = field._default_constructor(self) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite + for val in field_value: + copy.add().MergeFrom(val) + else: # Scalar + copy.extend(field_value) + self._fields[field] = copy + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + copy = field._default_constructor(self) + copy.MergeFrom(field_value) + self._fields[field] = copy + else: + self._fields[field] = field_value + + init.__module__ = None + init.__doc__ = None + cls.__init__ = init + + +def _GetFieldByName(message_descriptor, field_name): + """Returns a field descriptor by field name. + + Args: + message_descriptor: A Descriptor describing all fields in message. + field_name: The name of the field to retrieve. + Returns: + The field descriptor associated with the field name. + """ + try: + return message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + +def _AddPropertiesForFields(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + for field in descriptor.fields: + _AddPropertiesForField(field, cls) + + if descriptor.is_extendable: + # _ExtensionDict is just an adaptor with no state so we allocate a new one + # every time it is accessed. + cls.Extensions = property(lambda self: _ExtensionDict(self)) + + +def _AddPropertiesForField(field, cls): + """Adds a public property for a protocol message field. + Clients can use this property to get and (in the case + of non-repeated scalar fields) directly set the value + of a protocol message field. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # Catch it if we add other types that we should + # handle specially here. + assert _FieldDescriptor.MAX_CPPTYPE == 10 + + constant_name = field.name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, field.number) + + if field.label == _FieldDescriptor.LABEL_REPEATED: + _AddPropertiesForRepeatedField(field, cls) + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + _AddPropertiesForNonRepeatedCompositeField(field, cls) + else: + _AddPropertiesForNonRepeatedScalarField(field, cls) + + +def _AddPropertiesForRepeatedField(field, cls): + """Adds a public property for a "repeated" protocol message field. Clients + can use this property to get the value of the field, which will be either a + _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see + below). + + Note that when clients add values to these containers, we perform + type-checking in the case of repeated scalar fields, and we also set any + necessary "has" bits as a side-effect. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % proto_field_name) + + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedScalarField(field, cls): + """Adds a public property for a nonrepeated, scalar protocol message field. + Clients can use this property to get and directly set the value of the field. + Note that when the client sets the value of a field by using this property, + all necessary "has" bits are set as a side-effect, and we also perform + type-checking. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + default_value = field.default_value + + def getter(self): + return self._fields.get(field, default_value) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + def setter(self, new_value): + type_checker.CheckValue(new_value) + self._fields[field] = new_value + # Check _cached_byte_size_dirty inline to improve performance, since scalar + # setters are called frequently. + if not self._cached_byte_size_dirty: + self._Modified() + setter.__module__ = None + setter.__doc__ = 'Setter for %s.' % proto_field_name + + # Add a property to encapsulate the getter/setter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedCompositeField(field, cls): + """Adds a public property for a nonrepeated, composite protocol message field. + A composite field is a "group" or "message" field. + + Clients can use this property to get the value of the field, but cannot + assign to the property directly. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # TODO(robinson): Remove duplication with similar method + # for non-repeated scalars. + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + message_type = field.message_type + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = message_type._concrete_class() + field_value._SetListener(self._listener_for_children) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to composite field ' + '"%s" in protocol message object.' % proto_field_name) + + # Add a property to encapsulate the getter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForExtensions(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + constant_name = extension_name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, extension_field.number) + + +def _AddStaticMethods(cls): + # TODO(robinson): This probably needs to be thread-safe(?) + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + _AttachFieldHelpers(cls, extension_handle) + + # Try to insert our extension, failing if an extension with the same number + # already exists. + actual_handle = cls._extensions_by_number.setdefault( + extension_handle.number, extension_handle) + if actual_handle is not extension_handle: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" with ' + 'field number %d.' % + (extension_handle.full_name, actual_handle.full_name, + cls.DESCRIPTOR.full_name, extension_handle.number)) + + cls._extensions_by_name[extension_handle.full_name] = extension_handle + + handle = extension_handle # avoid line wrapping + if _IsMessageSetExtension(handle): + # MessageSet extension. Also register under type name. + cls._extensions_by_name[ + extension_handle.message_type.full_name] = extension_handle + + cls.RegisterExtension = staticmethod(RegisterExtension) + + def FromString(s): + message = cls() + message.MergeFromString(s) + return message + cls.FromString = staticmethod(FromString) + + +def _IsPresent(item): + """Given a (FieldDescriptor, value) tuple from _fields, return true if the + value should be included in the list returned by ListFields().""" + + if item[0].label == _FieldDescriptor.LABEL_REPEATED: + return bool(item[1]) + elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + return item[1]._is_present_in_parent + else: + return True + + +def _AddListFieldsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ListFields(self): + all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] + all_fields.sort(key = lambda item: item[0].number) + return all_fields + + cls.ListFields = ListFields + + +def _AddHasFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + singular_fields = {} + for field in message_descriptor.fields: + if field.label != _FieldDescriptor.LABEL_REPEATED: + singular_fields[field.name] = field + + def HasField(self, field_name): + try: + field = singular_fields[field_name] + except KeyError: + raise ValueError( + 'Protocol message has no singular "%s" field.' % field_name) + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields + cls.HasField = HasField + + +def _AddClearFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def ClearField(self, field_name): + try: + field = message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + if field in self._fields: + # Note: If the field is a sub-message, its listener will still point + # at us. That's fine, because the worst than can happen is that it + # will call _Modified() and invalidate our byte size. Big deal. + del self._fields[field] + + # Always call _Modified() -- even if nothing was changed, this is + # a mutating method, and thus calling it should cause the field to become + # present in the parent message. + self._Modified() + + cls.ClearField = ClearField + + +def _AddClearExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearExtension(self, extension_handle): + _VerifyExtensionHandle(self, extension_handle) + + # Similar to ClearField(), above. + if extension_handle in self._fields: + del self._fields[extension_handle] + self._Modified() + cls.ClearExtension = ClearExtension + + +def _AddClearMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def Clear(self): + # Clear fields. + self._fields = {} + self._Modified() + cls.Clear = Clear + + +def _AddHasExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def HasExtension(self, extension_handle): + _VerifyExtensionHandle(self, extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + raise KeyError('"%s" is repeated.' % extension_handle.full_name) + + if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(extension_handle) + return value is not None and value._is_present_in_parent + else: + return extension_handle in self._fields + cls.HasExtension = HasExtension + + +def _AddEqualsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __eq__(self, other): + if (not isinstance(other, message_mod.Message) or + other.DESCRIPTOR != self.DESCRIPTOR): + return False + + if self is other: + return True + + return self.ListFields() == other.ListFields() + + cls.__eq__ = __eq__ + + +def _AddStrMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __str__(self): + return text_format.MessageToString(self) + cls.__str__ = __str__ + + +def _AddSetListenerMethod(cls): + """Helper for _AddMessageMethods().""" + def SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + cls._SetListener = SetListener + + +def _BytesForNonRepeatedElement(value, field_number, field_type): + """Returns the number of bytes needed to serialize a non-repeated element. + The returned byte count includes space for tag information and any + other additional space associated with serializing value. + + Args: + value: Value we're serializing. + field_number: Field number of this value. (Since the field number + is stored as part of a varint-encoded tag, this has an impact + on the total bytes required to serialize the value). + field_type: The type of the field. One of the TYPE_* constants + within FieldDescriptor. + """ + try: + fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] + return fn(field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) + + +def _AddByteSizeMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ByteSize(self): + if not self._cached_byte_size_dirty: + return self._cached_byte_size + + size = 0 + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) + + self._cached_byte_size = size + self._cached_byte_size_dirty = False + self._listener_for_children.dirty = False + return size + + cls.ByteSize = ByteSize + + +def _AddSerializeToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializeToString(self): + # Check if the message has all of its required fields set. + errors = [] + if not self.IsInitialized(): + raise message_mod.EncodeError( + 'Message is missing required fields: ' + + ','.join(self.FindInitializationErrors())) + return self.SerializePartialToString() + cls.SerializeToString = SerializeToString + + +def _AddSerializePartialToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializePartialToString(self): + out = StringIO() + self._InternalSerialize(out.write) + return out.getvalue() + cls.SerializePartialToString = SerializePartialToString + + def InternalSerialize(self, write_bytes): + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value) + cls._InternalSerialize = InternalSerialize + + +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def MergeFromString(self, serialized): + length = len(serialized) + try: + if self._InternalParse(serialized, 0, length) != length: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise message_mod.DecodeError('Unexpected end-group tag.') + except IndexError: + raise message_mod.DecodeError('Truncated message.') + except struct.error, e: + raise message_mod.DecodeError(e) + return length # Return this for legacy reasons. + cls.MergeFromString = MergeFromString + + local_ReadTag = decoder.ReadTag + local_SkipField = decoder.SkipField + decoders_by_tag = cls._decoders_by_tag + + def InternalParse(self, buffer, pos, end): + self._Modified() + field_dict = self._fields + while pos != end: + (tag_bytes, new_pos) = local_ReadTag(buffer, pos) + field_decoder = decoders_by_tag.get(tag_bytes) + if field_decoder is None: + new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) + if new_pos == -1: + return pos + pos = new_pos + else: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + return pos + cls._InternalParse = InternalParse + + +def _AddIsInitializedMethod(message_descriptor, cls): + """Adds the IsInitialized and FindInitializationError methods to the + protocol message class.""" + + required_fields = [field for field in message_descriptor.fields + if field.label == _FieldDescriptor.LABEL_REQUIRED] + + def IsInitialized(self, errors=None): + """Checks if all required fields of a message are set. + + Args: + errors: A list which, if provided, will be populated with the field + paths of all missing required fields. + + Returns: + True iff the specified message has all required fields set. + """ + + # Performance is critical so we avoid HasField() and ListFields(). + + for field in required_fields: + if (field not in self._fields or + (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + not self._fields[field]._is_present_in_parent)): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + for field, value in self._fields.iteritems(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for element in value: + if not element.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + elif value._is_present_in_parent and not value.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + return True + + cls.IsInitialized = IsInitialized + + def FindInitializationErrors(self): + """Finds required fields which are not initialized. + + Returns: + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + """ + + errors = [] # simplify things + + for field in required_fields: + if not self.HasField(field.name): + errors.append(field.name) + + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.is_extension: + name = "(%s)" % field.full_name + else: + name = field.name + + if field.label == _FieldDescriptor.LABEL_REPEATED: + for i in xrange(len(value)): + element = value[i] + prefix = "%s[%d]." % (name, i) + sub_errors = element.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + else: + prefix = name + "." + sub_errors = value.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + + return errors + + cls.FindInitializationErrors = FindInitializationErrors + + +def _AddMergeFromMethod(cls): + LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED + CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE + + def MergeFrom(self, msg): + assert msg is not self + self._Modified() + + fields = self._fields + + for field, value in msg._fields.iteritems(): + if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + else: + self._fields[field] = value + cls.MergeFrom = MergeFrom + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" + _AddListFieldsMethod(message_descriptor, cls) + _AddHasFieldMethod(message_descriptor, cls) + _AddClearFieldMethod(message_descriptor, cls) + if message_descriptor.is_extendable: + _AddClearExtensionMethod(cls) + _AddHasExtensionMethod(cls) + _AddClearMethod(message_descriptor, cls) + _AddEqualsMethod(message_descriptor, cls) + _AddStrMethod(message_descriptor, cls) + _AddSetListenerMethod(cls) + _AddByteSizeMethod(message_descriptor, cls) + _AddSerializeToStringMethod(message_descriptor, cls) + _AddSerializePartialToStringMethod(message_descriptor, cls) + _AddMergeFromStringMethod(message_descriptor, cls) + _AddIsInitializedMethod(message_descriptor, cls) + _AddMergeFromMethod(cls) + + +def _AddPrivateHelperMethods(cls): + """Adds implementation of private helper methods to cls.""" + + def Modified(self): + """Sets the _cached_byte_size_dirty bit to true, + and propagates this to our listener iff this was a state change. + """ + + # Note: Some callers check _cached_byte_size_dirty before calling + # _Modified() as an extra optimization. So, if this method is ever + # changed such that it does stuff even when _cached_byte_size_dirty is + # already true, the callers need to be updated. + if not self._cached_byte_size_dirty: + self._cached_byte_size_dirty = True + self._listener_for_children.dirty = True + self._is_present_in_parent = True + self._listener.Modified() + + cls._Modified = Modified + cls.SetInParent = Modified + + +class _Listener(object): + + """MessageListener implementation that a parent message registers with its + child message. + + In order to support semantics like: + + foo.bar.baz.qux = 23 + assert foo.HasField('bar') + + ...child objects must have back references to their parents. + This helper class is at the heart of this support. + """ + + def __init__(self, parent_message): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + """ + # This listener establishes a back reference from a child (contained) object + # to its parent (containing) object. We make this a weak reference to avoid + # creating cyclic garbage when the client finishes with the 'parent' object + # in the tree. + if isinstance(parent_message, weakref.ProxyType): + self._parent_message_weakref = parent_message + else: + self._parent_message_weakref = weakref.proxy(parent_message) + + # As an optimization, we also indicate directly on the listener whether + # or not the parent message is dirty. This way we can avoid traversing + # up the tree in the common case. + self.dirty = False + + def Modified(self): + if self.dirty: + return + try: + # Propagate the signal to our parents iff this is the first field set. + self._parent_message_weakref._Modified() + except ReferenceError: + # We can get here if a client has kept a reference to a child object, + # and is now setting a field on it, but the child's parent has been + # garbage-collected. This is not an error. + pass + + +# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... +# TODO(robinson): Unify error handling of "unknown extension" crap. +# TODO(robinson): Support iteritems()-style iteration over all +# extensions with the "has" bits turned on? +class _ExtensionDict(object): + + """Dict-like container for supporting an indexable "Extensions" + field on proto instances. + + Note that in all cases we expect extension handles to be + FieldDescriptors. + """ + + def __init__(self, extended_message): + """extended_message: Message instance for which we are the Extensions dict. + """ + + self._extended_message = extended_message + + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + result = self._extended_message._fields.get(extension_handle) + if result is not None: + return result + + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + result = extension_handle._default_constructor(self._extended_message) + elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + result = extension_handle.message_type._concrete_class() + try: + result._SetListener(self._extended_message._listener_for_children) + except ReferenceError: + pass + else: + # Singular scalar -- just return the default without inserting into the + # dict. + return extension_handle.default_value + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + result = self._extended_message._fields.setdefault( + extension_handle, result) + + return result + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + + my_fields = self._extended_message.ListFields() + other_fields = other._extended_message.ListFields() + + # Get rid of non-extension fields. + my_fields = [ field for field in my_fields if field.is_extension ] + other_fields = [ field for field in other_fields if field.is_extension ] + + return my_fields == other_fields + + def __ne__(self, other): + return not self == other + + # Note that this is only meaningful for non-repeated, scalar extension + # fields. Note also that we may have to call _Modified() when we do + # successfully set a field this way, to set any necssary "has" bits in the + # ancestors of the extended message. + def __setitem__(self, extension_handle, value): + """If extension_handle specifies a non-repeated, scalar extension + field, sets the value of that field. + """ + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or + extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + raise TypeError( + 'Cannot assign to extension "%s" because it is a repeated or ' + 'composite type.' % extension_handle.full_name) + + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker( + extension_handle.cpp_type, extension_handle.type) + type_checker.CheckValue(value) + self._extended_message._fields[extension_handle] = value + self._extended_message._Modified() + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_name.get(name, None) diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service.py index 52dec654..180b70e8 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service.py @@ -1,226 +1,226 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""DEPRECATED: Declares the RPC service interfaces.
-
-This module declares the abstract interfaces underlying proto2 RPC
-services. These are intended to be independent of any particular RPC
-implementation, so that proto2 services can be used on top of a variety
-of implementations. Starting with version 2.3.0, RPC implementations should
-not try to build on these, but should instead provide code generator plugins
-which generate code specific to the particular RPC implementation. This way
-the generated code can be more appropriate for the implementation in use
-and can avoid unnecessary layers of indirection.
-"""
-
-__author__ = '[email protected] (Petar Petrov)'
-
-
-class RpcException(Exception):
- """Exception raised on failed blocking RPC method call."""
- pass
-
-
-class Service(object):
-
- """Abstract base interface for protocol-buffer-based RPC services.
-
- Services themselves are abstract classes (implemented either by servers or as
- stubs), but they subclass this base interface. The methods of this
- interface can be used to call the methods of the service without knowing
- its exact type at compile time (analogous to the Message interface).
- """
-
- def GetDescriptor():
- """Retrieves this service's descriptor."""
- raise NotImplementedError
-
- def CallMethod(self, method_descriptor, rpc_controller,
- request, done):
- """Calls a method of the service specified by method_descriptor.
-
- If "done" is None then the call is blocking and the response
- message will be returned directly. Otherwise the call is asynchronous
- and "done" will later be called with the response value.
-
- In the blocking case, RpcException will be raised on error.
-
- Preconditions:
- * method_descriptor.service == GetDescriptor
- * request is of the exact same classes as returned by
- GetRequestClass(method).
- * After the call has started, the request must not be modified.
- * "rpc_controller" is of the correct type for the RPC implementation being
- used by this Service. For stubs, the "correct type" depends on the
- RpcChannel which the stub is using.
-
- Postconditions:
- * "done" will be called when the method is complete. This may be
- before CallMethod() returns or it may be at some point in the future.
- * If the RPC failed, the response value passed to "done" will be None.
- Further details about the failure can be found by querying the
- RpcController.
- """
- raise NotImplementedError
-
- def GetRequestClass(self, method_descriptor):
- """Returns the class of the request message for the specified method.
-
- CallMethod() requires that the request is of a particular subclass of
- Message. GetRequestClass() gets the default instance of this required
- type.
-
- Example:
- method = service.GetDescriptor().FindMethodByName("Foo")
- request = stub.GetRequestClass(method)()
- request.ParseFromString(input)
- service.CallMethod(method, request, callback)
- """
- raise NotImplementedError
-
- def GetResponseClass(self, method_descriptor):
- """Returns the class of the response message for the specified method.
-
- This method isn't really needed, as the RpcChannel's CallMethod constructs
- the response protocol message. It's provided anyway in case it is useful
- for the caller to know the response type in advance.
- """
- raise NotImplementedError
-
-
-class RpcController(object):
-
- """An RpcController mediates a single method call.
-
- The primary purpose of the controller is to provide a way to manipulate
- settings specific to the RPC implementation and to find out about RPC-level
- errors. The methods provided by the RpcController interface are intended
- to be a "least common denominator" set of features which we expect all
- implementations to support. Specific implementations may provide more
- advanced features (e.g. deadline propagation).
- """
-
- # Client-side methods below
-
- def Reset(self):
- """Resets the RpcController to its initial state.
-
- After the RpcController has been reset, it may be reused in
- a new call. Must not be called while an RPC is in progress.
- """
- raise NotImplementedError
-
- def Failed(self):
- """Returns true if the call failed.
-
- After a call has finished, returns true if the call failed. The possible
- reasons for failure depend on the RPC implementation. Failed() must not
- be called before a call has finished. If Failed() returns true, the
- contents of the response message are undefined.
- """
- raise NotImplementedError
-
- def ErrorText(self):
- """If Failed is true, returns a human-readable description of the error."""
- raise NotImplementedError
-
- def StartCancel(self):
- """Initiate cancellation.
-
- Advises the RPC system that the caller desires that the RPC call be
- canceled. The RPC system may cancel it immediately, may wait awhile and
- then cancel it, or may not even cancel the call at all. If the call is
- canceled, the "done" callback will still be called and the RpcController
- will indicate that the call failed at that time.
- """
- raise NotImplementedError
-
- # Server-side methods below
-
- def SetFailed(self, reason):
- """Sets a failure reason.
-
- Causes Failed() to return true on the client side. "reason" will be
- incorporated into the message returned by ErrorText(). If you find
- you need to return machine-readable information about failures, you
- should incorporate it into your response protocol buffer and should
- NOT call SetFailed().
- """
- raise NotImplementedError
-
- def IsCanceled(self):
- """Checks if the client cancelled the RPC.
-
- If true, indicates that the client canceled the RPC, so the server may
- as well give up on replying to it. The server should still call the
- final "done" callback.
- """
- raise NotImplementedError
-
- def NotifyOnCancel(self, callback):
- """Sets a callback to invoke on cancel.
-
- Asks that the given callback be called when the RPC is canceled. The
- callback will always be called exactly once. If the RPC completes without
- being canceled, the callback will be called after completion. If the RPC
- has already been canceled when NotifyOnCancel() is called, the callback
- will be called immediately.
-
- NotifyOnCancel() must be called no more than once per request.
- """
- raise NotImplementedError
-
-
-class RpcChannel(object):
-
- """Abstract interface for an RPC channel.
-
- An RpcChannel represents a communication line to a service which can be used
- to call that service's methods. The service may be running on another
- machine. Normally, you should not use an RpcChannel directly, but instead
- construct a stub {@link Service} wrapping it. Example:
-
- Example:
- RpcChannel channel = rpcImpl.Channel("remotehost.example.com:1234")
- RpcController controller = rpcImpl.Controller()
- MyService service = MyService_Stub(channel)
- service.MyMethod(controller, request, callback)
- """
-
- def CallMethod(self, method_descriptor, rpc_controller,
- request, response_class, done):
- """Calls the method identified by the descriptor.
-
- Call the given method of the remote service. The signature of this
- procedure looks the same as Service.CallMethod(), but the requirements
- are less strict in one important way: the request object doesn't have to
- be of any specific class as long as its descriptor is method.input_type.
- """
- raise NotImplementedError
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""DEPRECATED: Declares the RPC service interfaces. + +This module declares the abstract interfaces underlying proto2 RPC +services. These are intended to be independent of any particular RPC +implementation, so that proto2 services can be used on top of a variety +of implementations. Starting with version 2.3.0, RPC implementations should +not try to build on these, but should instead provide code generator plugins +which generate code specific to the particular RPC implementation. This way +the generated code can be more appropriate for the implementation in use +and can avoid unnecessary layers of indirection. +""" + +__author__ = '[email protected] (Petar Petrov)' + + +class RpcException(Exception): + """Exception raised on failed blocking RPC method call.""" + pass + + +class Service(object): + + """Abstract base interface for protocol-buffer-based RPC services. + + Services themselves are abstract classes (implemented either by servers or as + stubs), but they subclass this base interface. The methods of this + interface can be used to call the methods of the service without knowing + its exact type at compile time (analogous to the Message interface). + """ + + def GetDescriptor(): + """Retrieves this service's descriptor.""" + raise NotImplementedError + + def CallMethod(self, method_descriptor, rpc_controller, + request, done): + """Calls a method of the service specified by method_descriptor. + + If "done" is None then the call is blocking and the response + message will be returned directly. Otherwise the call is asynchronous + and "done" will later be called with the response value. + + In the blocking case, RpcException will be raised on error. + + Preconditions: + * method_descriptor.service == GetDescriptor + * request is of the exact same classes as returned by + GetRequestClass(method). + * After the call has started, the request must not be modified. + * "rpc_controller" is of the correct type for the RPC implementation being + used by this Service. For stubs, the "correct type" depends on the + RpcChannel which the stub is using. + + Postconditions: + * "done" will be called when the method is complete. This may be + before CallMethod() returns or it may be at some point in the future. + * If the RPC failed, the response value passed to "done" will be None. + Further details about the failure can be found by querying the + RpcController. + """ + raise NotImplementedError + + def GetRequestClass(self, method_descriptor): + """Returns the class of the request message for the specified method. + + CallMethod() requires that the request is of a particular subclass of + Message. GetRequestClass() gets the default instance of this required + type. + + Example: + method = service.GetDescriptor().FindMethodByName("Foo") + request = stub.GetRequestClass(method)() + request.ParseFromString(input) + service.CallMethod(method, request, callback) + """ + raise NotImplementedError + + def GetResponseClass(self, method_descriptor): + """Returns the class of the response message for the specified method. + + This method isn't really needed, as the RpcChannel's CallMethod constructs + the response protocol message. It's provided anyway in case it is useful + for the caller to know the response type in advance. + """ + raise NotImplementedError + + +class RpcController(object): + + """An RpcController mediates a single method call. + + The primary purpose of the controller is to provide a way to manipulate + settings specific to the RPC implementation and to find out about RPC-level + errors. The methods provided by the RpcController interface are intended + to be a "least common denominator" set of features which we expect all + implementations to support. Specific implementations may provide more + advanced features (e.g. deadline propagation). + """ + + # Client-side methods below + + def Reset(self): + """Resets the RpcController to its initial state. + + After the RpcController has been reset, it may be reused in + a new call. Must not be called while an RPC is in progress. + """ + raise NotImplementedError + + def Failed(self): + """Returns true if the call failed. + + After a call has finished, returns true if the call failed. The possible + reasons for failure depend on the RPC implementation. Failed() must not + be called before a call has finished. If Failed() returns true, the + contents of the response message are undefined. + """ + raise NotImplementedError + + def ErrorText(self): + """If Failed is true, returns a human-readable description of the error.""" + raise NotImplementedError + + def StartCancel(self): + """Initiate cancellation. + + Advises the RPC system that the caller desires that the RPC call be + canceled. The RPC system may cancel it immediately, may wait awhile and + then cancel it, or may not even cancel the call at all. If the call is + canceled, the "done" callback will still be called and the RpcController + will indicate that the call failed at that time. + """ + raise NotImplementedError + + # Server-side methods below + + def SetFailed(self, reason): + """Sets a failure reason. + + Causes Failed() to return true on the client side. "reason" will be + incorporated into the message returned by ErrorText(). If you find + you need to return machine-readable information about failures, you + should incorporate it into your response protocol buffer and should + NOT call SetFailed(). + """ + raise NotImplementedError + + def IsCanceled(self): + """Checks if the client cancelled the RPC. + + If true, indicates that the client canceled the RPC, so the server may + as well give up on replying to it. The server should still call the + final "done" callback. + """ + raise NotImplementedError + + def NotifyOnCancel(self, callback): + """Sets a callback to invoke on cancel. + + Asks that the given callback be called when the RPC is canceled. The + callback will always be called exactly once. If the RPC completes without + being canceled, the callback will be called after completion. If the RPC + has already been canceled when NotifyOnCancel() is called, the callback + will be called immediately. + + NotifyOnCancel() must be called no more than once per request. + """ + raise NotImplementedError + + +class RpcChannel(object): + + """Abstract interface for an RPC channel. + + An RpcChannel represents a communication line to a service which can be used + to call that service's methods. The service may be running on another + machine. Normally, you should not use an RpcChannel directly, but instead + construct a stub {@link Service} wrapping it. Example: + + Example: + RpcChannel channel = rpcImpl.Channel("remotehost.example.com:1234") + RpcController controller = rpcImpl.Controller() + MyService service = MyService_Stub(channel) + service.MyMethod(controller, request, callback) + """ + + def CallMethod(self, method_descriptor, rpc_controller, + request, response_class, done): + """Calls the method identified by the descriptor. + + Call the given method of the remote service. The signature of this + procedure looks the same as Service.CallMethod(), but the requirements + are less strict in one important way: the request object doesn't have to + be of any specific class as long as its descriptor is method.input_type. + """ + raise NotImplementedError diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service_reflection.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service_reflection.py index 4604e5c5..851e83e7 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service_reflection.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/service_reflection.py @@ -1,284 +1,284 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Contains metaclasses used to create protocol service and service stub
-classes from ServiceDescriptor objects at runtime.
-
-The GeneratedServiceType and GeneratedServiceStubType metaclasses are used to
-inject all useful functionality into the classes output by the protocol
-compiler at compile-time.
-"""
-
-__author__ = '[email protected] (Petar Petrov)'
-
-
-class GeneratedServiceType(type):
-
- """Metaclass for service classes created at runtime from ServiceDescriptors.
-
- Implementations for all methods described in the Service class are added here
- by this class. We also create properties to allow getting/setting all fields
- in the protocol message.
-
- The protocol compiler currently uses this metaclass to create protocol service
- classes at runtime. Clients can also manually create their own classes at
- runtime, as in this example:
-
- mydescriptor = ServiceDescriptor(.....)
- class MyProtoService(service.Service):
- __metaclass__ = GeneratedServiceType
- DESCRIPTOR = mydescriptor
- myservice_instance = MyProtoService()
- ...
- """
-
- _DESCRIPTOR_KEY = 'DESCRIPTOR'
-
- def __init__(cls, name, bases, dictionary):
- """Creates a message service class.
-
- Args:
- name: Name of the class (ignored, but required by the metaclass
- protocol).
- bases: Base classes of the class being constructed.
- dictionary: The class dictionary of the class being constructed.
- dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object
- describing this protocol service type.
- """
- # Don't do anything if this class doesn't have a descriptor. This happens
- # when a service class is subclassed.
- if GeneratedServiceType._DESCRIPTOR_KEY not in dictionary:
- return
- descriptor = dictionary[GeneratedServiceType._DESCRIPTOR_KEY]
- service_builder = _ServiceBuilder(descriptor)
- service_builder.BuildService(cls)
-
-
-class GeneratedServiceStubType(GeneratedServiceType):
-
- """Metaclass for service stubs created at runtime from ServiceDescriptors.
-
- This class has similar responsibilities as GeneratedServiceType, except that
- it creates the service stub classes.
- """
-
- _DESCRIPTOR_KEY = 'DESCRIPTOR'
-
- def __init__(cls, name, bases, dictionary):
- """Creates a message service stub class.
-
- Args:
- name: Name of the class (ignored, here).
- bases: Base classes of the class being constructed.
- dictionary: The class dictionary of the class being constructed.
- dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object
- describing this protocol service type.
- """
- super(GeneratedServiceStubType, cls).__init__(name, bases, dictionary)
- # Don't do anything if this class doesn't have a descriptor. This happens
- # when a service stub is subclassed.
- if GeneratedServiceStubType._DESCRIPTOR_KEY not in dictionary:
- return
- descriptor = dictionary[GeneratedServiceStubType._DESCRIPTOR_KEY]
- service_stub_builder = _ServiceStubBuilder(descriptor)
- service_stub_builder.BuildServiceStub(cls)
-
-
-class _ServiceBuilder(object):
-
- """This class constructs a protocol service class using a service descriptor.
-
- Given a service descriptor, this class constructs a class that represents
- the specified service descriptor. One service builder instance constructs
- exactly one service class. That means all instances of that class share the
- same builder.
- """
-
- def __init__(self, service_descriptor):
- """Initializes an instance of the service class builder.
-
- Args:
- service_descriptor: ServiceDescriptor to use when constructing the
- service class.
- """
- self.descriptor = service_descriptor
-
- def BuildService(self, cls):
- """Constructs the service class.
-
- Args:
- cls: The class that will be constructed.
- """
-
- # CallMethod needs to operate with an instance of the Service class. This
- # internal wrapper function exists only to be able to pass the service
- # instance to the method that does the real CallMethod work.
- def _WrapCallMethod(srvc, method_descriptor,
- rpc_controller, request, callback):
- return self._CallMethod(srvc, method_descriptor,
- rpc_controller, request, callback)
- self.cls = cls
- cls.CallMethod = _WrapCallMethod
- cls.GetDescriptor = staticmethod(lambda: self.descriptor)
- cls.GetDescriptor.__doc__ = "Returns the service descriptor."
- cls.GetRequestClass = self._GetRequestClass
- cls.GetResponseClass = self._GetResponseClass
- for method in self.descriptor.methods:
- setattr(cls, method.name, self._GenerateNonImplementedMethod(method))
-
- def _CallMethod(self, srvc, method_descriptor,
- rpc_controller, request, callback):
- """Calls the method described by a given method descriptor.
-
- Args:
- srvc: Instance of the service for which this method is called.
- method_descriptor: Descriptor that represent the method to call.
- rpc_controller: RPC controller to use for this method's execution.
- request: Request protocol message.
- callback: A callback to invoke after the method has completed.
- """
- if method_descriptor.containing_service != self.descriptor:
- raise RuntimeError(
- 'CallMethod() given method descriptor for wrong service type.')
- method = getattr(srvc, method_descriptor.name)
- return method(rpc_controller, request, callback)
-
- def _GetRequestClass(self, method_descriptor):
- """Returns the class of the request protocol message.
-
- Args:
- method_descriptor: Descriptor of the method for which to return the
- request protocol message class.
-
- Returns:
- A class that represents the input protocol message of the specified
- method.
- """
- if method_descriptor.containing_service != self.descriptor:
- raise RuntimeError(
- 'GetRequestClass() given method descriptor for wrong service type.')
- return method_descriptor.input_type._concrete_class
-
- def _GetResponseClass(self, method_descriptor):
- """Returns the class of the response protocol message.
-
- Args:
- method_descriptor: Descriptor of the method for which to return the
- response protocol message class.
-
- Returns:
- A class that represents the output protocol message of the specified
- method.
- """
- if method_descriptor.containing_service != self.descriptor:
- raise RuntimeError(
- 'GetResponseClass() given method descriptor for wrong service type.')
- return method_descriptor.output_type._concrete_class
-
- def _GenerateNonImplementedMethod(self, method):
- """Generates and returns a method that can be set for a service methods.
-
- Args:
- method: Descriptor of the service method for which a method is to be
- generated.
-
- Returns:
- A method that can be added to the service class.
- """
- return lambda inst, rpc_controller, request, callback: (
- self._NonImplementedMethod(method.name, rpc_controller, callback))
-
- def _NonImplementedMethod(self, method_name, rpc_controller, callback):
- """The body of all methods in the generated service class.
-
- Args:
- method_name: Name of the method being executed.
- rpc_controller: RPC controller used to execute this method.
- callback: A callback which will be invoked when the method finishes.
- """
- rpc_controller.SetFailed('Method %s not implemented.' % method_name)
- callback(None)
-
-
-class _ServiceStubBuilder(object):
-
- """Constructs a protocol service stub class using a service descriptor.
-
- Given a service descriptor, this class constructs a suitable stub class.
- A stub is just a type-safe wrapper around an RpcChannel which emulates a
- local implementation of the service.
-
- One service stub builder instance constructs exactly one class. It means all
- instances of that class share the same service stub builder.
- """
-
- def __init__(self, service_descriptor):
- """Initializes an instance of the service stub class builder.
-
- Args:
- service_descriptor: ServiceDescriptor to use when constructing the
- stub class.
- """
- self.descriptor = service_descriptor
-
- def BuildServiceStub(self, cls):
- """Constructs the stub class.
-
- Args:
- cls: The class that will be constructed.
- """
-
- def _ServiceStubInit(stub, rpc_channel):
- stub.rpc_channel = rpc_channel
- self.cls = cls
- cls.__init__ = _ServiceStubInit
- for method in self.descriptor.methods:
- setattr(cls, method.name, self._GenerateStubMethod(method))
-
- def _GenerateStubMethod(self, method):
- return (lambda inst, rpc_controller, request, callback=None:
- self._StubMethod(inst, method, rpc_controller, request, callback))
-
- def _StubMethod(self, stub, method_descriptor,
- rpc_controller, request, callback):
- """The body of all service methods in the generated stub class.
-
- Args:
- stub: Stub instance.
- method_descriptor: Descriptor of the invoked method.
- rpc_controller: Rpc controller to execute the method.
- request: Request protocol message.
- callback: A callback to execute when the method finishes.
- Returns:
- Response message (in case of blocking call).
- """
- return stub.rpc_channel.CallMethod(
- method_descriptor, rpc_controller, request,
- method_descriptor.output_type._concrete_class, callback)
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains metaclasses used to create protocol service and service stub +classes from ServiceDescriptor objects at runtime. + +The GeneratedServiceType and GeneratedServiceStubType metaclasses are used to +inject all useful functionality into the classes output by the protocol +compiler at compile-time. +""" + +__author__ = '[email protected] (Petar Petrov)' + + +class GeneratedServiceType(type): + + """Metaclass for service classes created at runtime from ServiceDescriptors. + + Implementations for all methods described in the Service class are added here + by this class. We also create properties to allow getting/setting all fields + in the protocol message. + + The protocol compiler currently uses this metaclass to create protocol service + classes at runtime. Clients can also manually create their own classes at + runtime, as in this example: + + mydescriptor = ServiceDescriptor(.....) + class MyProtoService(service.Service): + __metaclass__ = GeneratedServiceType + DESCRIPTOR = mydescriptor + myservice_instance = MyProtoService() + ... + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service class. + + Args: + name: Name of the class (ignored, but required by the metaclass + protocol). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service class is subclassed. + if GeneratedServiceType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceType._DESCRIPTOR_KEY] + service_builder = _ServiceBuilder(descriptor) + service_builder.BuildService(cls) + + +class GeneratedServiceStubType(GeneratedServiceType): + + """Metaclass for service stubs created at runtime from ServiceDescriptors. + + This class has similar responsibilities as GeneratedServiceType, except that + it creates the service stub classes. + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service stub class. + + Args: + name: Name of the class (ignored, here). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + super(GeneratedServiceStubType, cls).__init__(name, bases, dictionary) + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service stub is subclassed. + if GeneratedServiceStubType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceStubType._DESCRIPTOR_KEY] + service_stub_builder = _ServiceStubBuilder(descriptor) + service_stub_builder.BuildServiceStub(cls) + + +class _ServiceBuilder(object): + + """This class constructs a protocol service class using a service descriptor. + + Given a service descriptor, this class constructs a class that represents + the specified service descriptor. One service builder instance constructs + exactly one service class. That means all instances of that class share the + same builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + service class. + """ + self.descriptor = service_descriptor + + def BuildService(self, cls): + """Constructs the service class. + + Args: + cls: The class that will be constructed. + """ + + # CallMethod needs to operate with an instance of the Service class. This + # internal wrapper function exists only to be able to pass the service + # instance to the method that does the real CallMethod work. + def _WrapCallMethod(srvc, method_descriptor, + rpc_controller, request, callback): + return self._CallMethod(srvc, method_descriptor, + rpc_controller, request, callback) + self.cls = cls + cls.CallMethod = _WrapCallMethod + cls.GetDescriptor = staticmethod(lambda: self.descriptor) + cls.GetDescriptor.__doc__ = "Returns the service descriptor." + cls.GetRequestClass = self._GetRequestClass + cls.GetResponseClass = self._GetResponseClass + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateNonImplementedMethod(method)) + + def _CallMethod(self, srvc, method_descriptor, + rpc_controller, request, callback): + """Calls the method described by a given method descriptor. + + Args: + srvc: Instance of the service for which this method is called. + method_descriptor: Descriptor that represent the method to call. + rpc_controller: RPC controller to use for this method's execution. + request: Request protocol message. + callback: A callback to invoke after the method has completed. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'CallMethod() given method descriptor for wrong service type.') + method = getattr(srvc, method_descriptor.name) + return method(rpc_controller, request, callback) + + def _GetRequestClass(self, method_descriptor): + """Returns the class of the request protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + request protocol message class. + + Returns: + A class that represents the input protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetRequestClass() given method descriptor for wrong service type.') + return method_descriptor.input_type._concrete_class + + def _GetResponseClass(self, method_descriptor): + """Returns the class of the response protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + response protocol message class. + + Returns: + A class that represents the output protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetResponseClass() given method descriptor for wrong service type.') + return method_descriptor.output_type._concrete_class + + def _GenerateNonImplementedMethod(self, method): + """Generates and returns a method that can be set for a service methods. + + Args: + method: Descriptor of the service method for which a method is to be + generated. + + Returns: + A method that can be added to the service class. + """ + return lambda inst, rpc_controller, request, callback: ( + self._NonImplementedMethod(method.name, rpc_controller, callback)) + + def _NonImplementedMethod(self, method_name, rpc_controller, callback): + """The body of all methods in the generated service class. + + Args: + method_name: Name of the method being executed. + rpc_controller: RPC controller used to execute this method. + callback: A callback which will be invoked when the method finishes. + """ + rpc_controller.SetFailed('Method %s not implemented.' % method_name) + callback(None) + + +class _ServiceStubBuilder(object): + + """Constructs a protocol service stub class using a service descriptor. + + Given a service descriptor, this class constructs a suitable stub class. + A stub is just a type-safe wrapper around an RpcChannel which emulates a + local implementation of the service. + + One service stub builder instance constructs exactly one class. It means all + instances of that class share the same service stub builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service stub class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + stub class. + """ + self.descriptor = service_descriptor + + def BuildServiceStub(self, cls): + """Constructs the stub class. + + Args: + cls: The class that will be constructed. + """ + + def _ServiceStubInit(stub, rpc_channel): + stub.rpc_channel = rpc_channel + self.cls = cls + cls.__init__ = _ServiceStubInit + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateStubMethod(method)) + + def _GenerateStubMethod(self, method): + return (lambda inst, rpc_controller, request, callback=None: + self._StubMethod(inst, method, rpc_controller, request, callback)) + + def _StubMethod(self, stub, method_descriptor, + rpc_controller, request, callback): + """The body of all service methods in the generated stub class. + + Args: + stub: Stub instance. + method_descriptor: Descriptor of the invoked method. + rpc_controller: Rpc controller to execute the method. + request: Request protocol message. + callback: A callback to execute when the method finishes. + Returns: + Response message (in case of blocking call). + """ + return stub.rpc_channel.CallMethod( + method_descriptor, rpc_controller, request, + method_descriptor.output_type._concrete_class, callback) diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/text_format.py b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/text_format.py index 5b9fe50e..cc6ac902 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/text_format.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/google/protobuf/text_format.py @@ -1,673 +1,673 @@ -# Protocol Buffers - Google's data interchange format
-# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are
-# met:
-#
-# * Redistributions of source code must retain the above copyright
-# notice, this list of conditions and the following disclaimer.
-# * Redistributions in binary form must reproduce the above
-# copyright notice, this list of conditions and the following disclaimer
-# in the documentation and/or other materials provided with the
-# distribution.
-# * Neither the name of Google Inc. nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
-# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
-# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
-# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
-# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
-# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
-# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
-# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
-# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-"""Contains routines for printing protocol messages in text format."""
-
-__author__ = '[email protected] (Kenton Varda)'
-
-import cStringIO
-import re
-
-from collections import deque
-from google.protobuf.internal import type_checkers
-from google.protobuf import descriptor
-
-__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField',
- 'PrintFieldValue', 'Merge' ]
-
-
-# Infinity and NaN are not explicitly supported by Python pre-2.6, and
-# float('inf') does not work on Windows (pre-2.6).
-_INFINITY = 1e10000 # overflows, thus will actually be infinity.
-_NAN = _INFINITY * 0
-
-
-class ParseError(Exception):
- """Thrown in case of ASCII parsing error."""
-
-
-def MessageToString(message):
- out = cStringIO.StringIO()
- PrintMessage(message, out)
- result = out.getvalue()
- out.close()
- return result
-
-
-def PrintMessage(message, out, indent = 0):
- for field, value in message.ListFields():
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- for element in value:
- PrintField(field, element, out, indent)
- else:
- PrintField(field, value, out, indent)
-
-
-def PrintField(field, value, out, indent = 0):
- """Print a single field name/value pair. For repeated fields, the value
- should be a single element."""
-
- out.write(' ' * indent);
- if field.is_extension:
- out.write('[')
- if (field.containing_type.GetOptions().message_set_wire_format and
- field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
- field.message_type == field.extension_scope and
- field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL):
- out.write(field.message_type.full_name)
- else:
- out.write(field.full_name)
- out.write(']')
- elif field.type == descriptor.FieldDescriptor.TYPE_GROUP:
- # For groups, use the capitalized name.
- out.write(field.message_type.name)
- else:
- out.write(field.name)
-
- if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- # The colon is optional in this case, but our cross-language golden files
- # don't include it.
- out.write(': ')
-
- PrintFieldValue(field, value, out, indent)
- out.write('\n')
-
-
-def PrintFieldValue(field, value, out, indent = 0):
- """Print a single field value (not including name). For repeated fields,
- the value should be a single element."""
-
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- out.write(' {\n')
- PrintMessage(value, out, indent + 2)
- out.write(' ' * indent + '}')
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM:
- out.write(field.enum_type.values_by_number[value].name)
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
- out.write('\"')
- out.write(_CEscape(value))
- out.write('\"')
- elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL:
- if value:
- out.write("true")
- else:
- out.write("false")
- else:
- out.write(str(value))
-
-
-def Merge(text, message):
- """Merges an ASCII representation of a protocol message into a message.
-
- Args:
- text: Message ASCII representation.
- message: A protocol buffer message to merge into.
-
- Raises:
- ParseError: On ASCII parsing problems.
- """
- tokenizer = _Tokenizer(text)
- while not tokenizer.AtEnd():
- _MergeField(tokenizer, message)
-
-
-def _MergeField(tokenizer, message):
- """Merges a single protocol message field into a message.
-
- Args:
- tokenizer: A tokenizer to parse the field name and values.
- message: A protocol message to record the data.
-
- Raises:
- ParseError: In case of ASCII parsing problems.
- """
- message_descriptor = message.DESCRIPTOR
- if tokenizer.TryConsume('['):
- name = [tokenizer.ConsumeIdentifier()]
- while tokenizer.TryConsume('.'):
- name.append(tokenizer.ConsumeIdentifier())
- name = '.'.join(name)
-
- if not message_descriptor.is_extendable:
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" does not have extensions.' %
- message_descriptor.full_name)
- field = message.Extensions._FindExtensionByName(name)
- if not field:
- raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" not registered.' % name)
- elif message_descriptor != field.containing_type:
- raise tokenizer.ParseErrorPreviousToken(
- 'Extension "%s" does not extend message type "%s".' % (
- name, message_descriptor.full_name))
- tokenizer.Consume(']')
- else:
- name = tokenizer.ConsumeIdentifier()
- field = message_descriptor.fields_by_name.get(name, None)
-
- # Group names are expected to be capitalized as they appear in the
- # .proto file, which actually matches their type names, not their field
- # names.
- if not field:
- field = message_descriptor.fields_by_name.get(name.lower(), None)
- if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP:
- field = None
-
- if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and
- field.message_type.name != name):
- field = None
-
- if not field:
- raise tokenizer.ParseErrorPreviousToken(
- 'Message type "%s" has no field named "%s".' % (
- message_descriptor.full_name, name))
-
- if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE:
- tokenizer.TryConsume(':')
-
- if tokenizer.TryConsume('<'):
- end_token = '>'
- else:
- tokenizer.Consume('{')
- end_token = '}'
-
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- if field.is_extension:
- sub_message = message.Extensions[field].add()
- else:
- sub_message = getattr(message, field.name).add()
- else:
- if field.is_extension:
- sub_message = message.Extensions[field]
- else:
- sub_message = getattr(message, field.name)
- sub_message.SetInParent()
-
- while not tokenizer.TryConsume(end_token):
- if tokenizer.AtEnd():
- raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token))
- _MergeField(tokenizer, sub_message)
- else:
- _MergeScalarField(tokenizer, message, field)
-
-
-def _MergeScalarField(tokenizer, message, field):
- """Merges a single protocol message scalar field into a message.
-
- Args:
- tokenizer: A tokenizer to parse the field value.
- message: A protocol message to record the data.
- field: The descriptor of the field to be merged.
-
- Raises:
- ParseError: In case of ASCII parsing problems.
- RuntimeError: On runtime errors.
- """
- tokenizer.Consume(':')
- value = None
-
- if field.type in (descriptor.FieldDescriptor.TYPE_INT32,
- descriptor.FieldDescriptor.TYPE_SINT32,
- descriptor.FieldDescriptor.TYPE_SFIXED32):
- value = tokenizer.ConsumeInt32()
- elif field.type in (descriptor.FieldDescriptor.TYPE_INT64,
- descriptor.FieldDescriptor.TYPE_SINT64,
- descriptor.FieldDescriptor.TYPE_SFIXED64):
- value = tokenizer.ConsumeInt64()
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32,
- descriptor.FieldDescriptor.TYPE_FIXED32):
- value = tokenizer.ConsumeUint32()
- elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64,
- descriptor.FieldDescriptor.TYPE_FIXED64):
- value = tokenizer.ConsumeUint64()
- elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT,
- descriptor.FieldDescriptor.TYPE_DOUBLE):
- value = tokenizer.ConsumeFloat()
- elif field.type == descriptor.FieldDescriptor.TYPE_BOOL:
- value = tokenizer.ConsumeBool()
- elif field.type == descriptor.FieldDescriptor.TYPE_STRING:
- value = tokenizer.ConsumeString()
- elif field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- value = tokenizer.ConsumeByteString()
- elif field.type == descriptor.FieldDescriptor.TYPE_ENUM:
- # Enum can be specified by a number (the enum value), or by
- # a string literal (the enum name).
- enum_descriptor = field.enum_type
- if tokenizer.LookingAtInteger():
- number = tokenizer.ConsumeInt32()
- enum_value = enum_descriptor.values_by_number.get(number, None)
- if enum_value is None:
- raise tokenizer.ParseErrorPreviousToken(
- 'Enum type "%s" has no value with number %d.' % (
- enum_descriptor.full_name, number))
- else:
- identifier = tokenizer.ConsumeIdentifier()
- enum_value = enum_descriptor.values_by_name.get(identifier, None)
- if enum_value is None:
- raise tokenizer.ParseErrorPreviousToken(
- 'Enum type "%s" has no value named %s.' % (
- enum_descriptor.full_name, identifier))
- value = enum_value.number
- else:
- raise RuntimeError('Unknown field type %d' % field.type)
-
- if field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- if field.is_extension:
- message.Extensions[field].append(value)
- else:
- getattr(message, field.name).append(value)
- else:
- if field.is_extension:
- message.Extensions[field] = value
- else:
- setattr(message, field.name, value)
-
-
-class _Tokenizer(object):
- """Protocol buffer ASCII representation tokenizer.
-
- This class handles the lower level string parsing by splitting it into
- meaningful tokens.
-
- It was directly ported from the Java protocol buffer API.
- """
-
- _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE)
- _TOKEN = re.compile(
- '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier
- '[0-9+-][0-9a-zA-Z_.+-]*|' # a number
- '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string
- '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string
- _IDENTIFIER = re.compile('\w+')
- _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(),
- type_checkers.Int32ValueChecker(),
- type_checkers.Uint64ValueChecker(),
- type_checkers.Int64ValueChecker()]
- _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE)
- _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE)
-
- def __init__(self, text_message):
- self._text_message = text_message
-
- self._position = 0
- self._line = -1
- self._column = 0
- self._token_start = None
- self.token = ''
- self._lines = deque(text_message.split('\n'))
- self._current_line = ''
- self._previous_line = 0
- self._previous_column = 0
- self._SkipWhitespace()
- self.NextToken()
-
- def AtEnd(self):
- """Checks the end of the text was reached.
-
- Returns:
- True iff the end was reached.
- """
- return not self._lines and not self._current_line
-
- def _PopLine(self):
- while not self._current_line:
- if not self._lines:
- self._current_line = ''
- return
- self._line += 1
- self._column = 0
- self._current_line = self._lines.popleft()
-
- def _SkipWhitespace(self):
- while True:
- self._PopLine()
- match = re.match(self._WHITESPACE, self._current_line)
- if not match:
- break
- length = len(match.group(0))
- self._current_line = self._current_line[length:]
- self._column += length
-
- def TryConsume(self, token):
- """Tries to consume a given piece of text.
-
- Args:
- token: Text to consume.
-
- Returns:
- True iff the text was consumed.
- """
- if self.token == token:
- self.NextToken()
- return True
- return False
-
- def Consume(self, token):
- """Consumes a piece of text.
-
- Args:
- token: Text to consume.
-
- Raises:
- ParseError: If the text couldn't be consumed.
- """
- if not self.TryConsume(token):
- raise self._ParseError('Expected "%s".' % token)
-
- def LookingAtInteger(self):
- """Checks if the current token is an integer.
-
- Returns:
- True iff the current token is an integer.
- """
- if not self.token:
- return False
- c = self.token[0]
- return (c >= '0' and c <= '9') or c == '-' or c == '+'
-
- def ConsumeIdentifier(self):
- """Consumes protocol message field identifier.
-
- Returns:
- Identifier string.
-
- Raises:
- ParseError: If an identifier couldn't be consumed.
- """
- result = self.token
- if not re.match(self._IDENTIFIER, result):
- raise self._ParseError('Expected identifier.')
- self.NextToken()
- return result
-
- def ConsumeInt32(self):
- """Consumes a signed 32bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If a signed 32bit integer couldn't be consumed.
- """
- try:
- result = self._ParseInteger(self.token, is_signed=True, is_long=False)
- except ValueError, e:
- raise self._IntegerParseError(e)
- self.NextToken()
- return result
-
- def ConsumeUint32(self):
- """Consumes an unsigned 32bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If an unsigned 32bit integer couldn't be consumed.
- """
- try:
- result = self._ParseInteger(self.token, is_signed=False, is_long=False)
- except ValueError, e:
- raise self._IntegerParseError(e)
- self.NextToken()
- return result
-
- def ConsumeInt64(self):
- """Consumes a signed 64bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If a signed 64bit integer couldn't be consumed.
- """
- try:
- result = self._ParseInteger(self.token, is_signed=True, is_long=True)
- except ValueError, e:
- raise self._IntegerParseError(e)
- self.NextToken()
- return result
-
- def ConsumeUint64(self):
- """Consumes an unsigned 64bit integer number.
-
- Returns:
- The integer parsed.
-
- Raises:
- ParseError: If an unsigned 64bit integer couldn't be consumed.
- """
- try:
- result = self._ParseInteger(self.token, is_signed=False, is_long=True)
- except ValueError, e:
- raise self._IntegerParseError(e)
- self.NextToken()
- return result
-
- def ConsumeFloat(self):
- """Consumes an floating point number.
-
- Returns:
- The number parsed.
-
- Raises:
- ParseError: If a floating point number couldn't be consumed.
- """
- text = self.token
- if re.match(self._FLOAT_INFINITY, text):
- self.NextToken()
- if text.startswith('-'):
- return -_INFINITY
- return _INFINITY
-
- if re.match(self._FLOAT_NAN, text):
- self.NextToken()
- return _NAN
-
- try:
- result = float(text)
- except ValueError, e:
- raise self._FloatParseError(e)
- self.NextToken()
- return result
-
- def ConsumeBool(self):
- """Consumes a boolean value.
-
- Returns:
- The bool parsed.
-
- Raises:
- ParseError: If a boolean value couldn't be consumed.
- """
- if self.token == 'true':
- self.NextToken()
- return True
- elif self.token == 'false':
- self.NextToken()
- return False
- else:
- raise self._ParseError('Expected "true" or "false".')
-
- def ConsumeString(self):
- """Consumes a string value.
-
- Returns:
- The string parsed.
-
- Raises:
- ParseError: If a string value couldn't be consumed.
- """
- return unicode(self.ConsumeByteString(), 'utf-8')
-
- def ConsumeByteString(self):
- """Consumes a byte array value.
-
- Returns:
- The array parsed (as a string).
-
- Raises:
- ParseError: If a byte array value couldn't be consumed.
- """
- list = [self._ConsumeSingleByteString()]
- while len(self.token) > 0 and self.token[0] in ('\'', '"'):
- list.append(self._ConsumeSingleByteString())
- return "".join(list)
-
- def _ConsumeSingleByteString(self):
- """Consume one token of a string literal.
-
- String literals (whether bytes or text) can come in multiple adjacent
- tokens which are automatically concatenated, like in C or Python. This
- method only consumes one token.
- """
- text = self.token
- if len(text) < 1 or text[0] not in ('\'', '"'):
- raise self._ParseError('Exptected string.')
-
- if len(text) < 2 or text[-1] != text[0]:
- raise self._ParseError('String missing ending quote.')
-
- try:
- result = _CUnescape(text[1:-1])
- except ValueError, e:
- raise self._ParseError(str(e))
- self.NextToken()
- return result
-
- def _ParseInteger(self, text, is_signed=False, is_long=False):
- """Parses an integer.
-
- Args:
- text: The text to parse.
- is_signed: True if a signed integer must be parsed.
- is_long: True if a long integer must be parsed.
-
- Returns:
- The integer value.
-
- Raises:
- ValueError: Thrown Iff the text is not a valid integer.
- """
- pos = 0
- if text.startswith('-'):
- pos += 1
-
- base = 10
- if text.startswith('0x', pos) or text.startswith('0X', pos):
- base = 16
- elif text.startswith('0', pos):
- base = 8
-
- # Do the actual parsing. Exception handling is propagated to caller.
- result = int(text, base)
-
- # Check if the integer is sane. Exceptions handled by callers.
- checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
- checker.CheckValue(result)
- return result
-
- def ParseErrorPreviousToken(self, message):
- """Creates and *returns* a ParseError for the previously read token.
-
- Args:
- message: A message to set for the exception.
-
- Returns:
- A ParseError instance.
- """
- return ParseError('%d:%d : %s' % (
- self._previous_line + 1, self._previous_column + 1, message))
-
- def _ParseError(self, message):
- """Creates and *returns* a ParseError for the current token."""
- return ParseError('%d:%d : %s' % (
- self._line + 1, self._column + 1, message))
-
- def _IntegerParseError(self, e):
- return self._ParseError('Couldn\'t parse integer: ' + str(e))
-
- def _FloatParseError(self, e):
- return self._ParseError('Couldn\'t parse number: ' + str(e))
-
- def NextToken(self):
- """Reads the next meaningful token."""
- self._previous_line = self._line
- self._previous_column = self._column
- if self.AtEnd():
- self.token = ''
- return
- self._column += len(self.token)
-
- # Make sure there is data to work on.
- self._PopLine()
-
- match = re.match(self._TOKEN, self._current_line)
- if match:
- token = match.group(0)
- self._current_line = self._current_line[len(token):]
- self.token = token
- else:
- self.token = self._current_line[0]
- self._current_line = self._current_line[1:]
- self._SkipWhitespace()
-
-
-# text.encode('string_escape') does not seem to satisfy our needs as it
-# encodes unprintable characters using two-digit hex escapes whereas our
-# C++ unescaping function allows hex escapes to be any length. So,
-# "\0011".encode('string_escape') ends up being "\\x011", which will be
-# decoded in C++ as a single-character string with char code 0x11.
-def _CEscape(text):
- def escape(c):
- o = ord(c)
- if o == 10: return r"\n" # optional escape
- if o == 13: return r"\r" # optional escape
- if o == 9: return r"\t" # optional escape
- if o == 39: return r"\'" # optional escape
-
- if o == 34: return r'\"' # necessary escape
- if o == 92: return r"\\" # necessary escape
-
- if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes
- return c
- return "".join([escape(c) for c in text])
-
-
-_CUNESCAPE_HEX = re.compile('\\\\x([0-9a-fA-F]{2}|[0-9a-f-A-F])')
-
-
-def _CUnescape(text):
- def ReplaceHex(m):
- return chr(int(m.group(0)[2:], 16))
- # This is required because the 'string_escape' encoding doesn't
- # allow single-digit hex escapes (like '\xf').
- result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
- return result.decode('string_escape')
+# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains routines for printing protocol messages in text format.""" + +__author__ = '[email protected] (Kenton Varda)' + +import cStringIO +import re + +from collections import deque +from google.protobuf.internal import type_checkers +from google.protobuf import descriptor + +__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge' ] + + +# Infinity and NaN are not explicitly supported by Python pre-2.6, and +# float('inf') does not work on Windows (pre-2.6). +_INFINITY = 1e10000 # overflows, thus will actually be infinity. +_NAN = _INFINITY * 0 + + +class ParseError(Exception): + """Thrown in case of ASCII parsing error.""" + + +def MessageToString(message): + out = cStringIO.StringIO() + PrintMessage(message, out) + result = out.getvalue() + out.close() + return result + + +def PrintMessage(message, out, indent = 0): + for field, value in message.ListFields(): + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + for element in value: + PrintField(field, element, out, indent) + else: + PrintField(field, value, out, indent) + + +def PrintField(field, value, out, indent = 0): + """Print a single field name/value pair. For repeated fields, the value + should be a single element.""" + + out.write(' ' * indent); + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. + out.write(field.message_type.name) + else: + out.write(field.name) + + if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # The colon is optional in this case, but our cross-language golden files + # don't include it. + out.write(': ') + + PrintFieldValue(field, value, out, indent) + out.write('\n') + + +def PrintFieldValue(field, value, out, indent = 0): + """Print a single field value (not including name). For repeated fields, + the value should be a single element.""" + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + out.write(' {\n') + PrintMessage(value, out, indent + 2) + out.write(' ' * indent + '}') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + out.write(field.enum_type.values_by_number[value].name) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + out.write(_CEscape(value)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write("true") + else: + out.write("false") + else: + out.write(str(value)) + + +def Merge(text, message): + """Merges an ASCII representation of a protocol message into a message. + + Args: + text: Message ASCII representation. + message: A protocol buffer message to merge into. + + Raises: + ParseError: On ASCII parsing problems. + """ + tokenizer = _Tokenizer(text) + while not tokenizer.AtEnd(): + _MergeField(tokenizer, message) + + +def _MergeField(tokenizer, message): + """Merges a single protocol message field into a message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + message: A protocol message to record the data. + + Raises: + ParseError: In case of ASCII parsing problems. + """ + message_descriptor = message.DESCRIPTOR + if tokenizer.TryConsume('['): + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + name = '.'.join(name) + + if not message_descriptor.is_extendable: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" does not have extensions.' % + message_descriptor.full_name) + field = message.Extensions._FindExtensionByName(name) + if not field: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" not registered.' % name) + elif message_descriptor != field.containing_type: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" does not extend message type "%s".' % ( + name, message_descriptor.full_name)) + tokenizer.Consume(']') + else: + name = tokenizer.ConsumeIdentifier() + field = message_descriptor.fields_by_name.get(name, None) + + # Group names are expected to be capitalized as they appear in the + # .proto file, which actually matches their type names, not their field + # names. + if not field: + field = message_descriptor.fields_by_name.get(name.lower(), None) + if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: + field = None + + if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and + field.message_type.name != name): + field = None + + if not field: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" has no field named "%s".' % ( + message_descriptor.full_name, name)) + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + tokenizer.TryConsume(':') + + if tokenizer.TryConsume('<'): + end_token = '>' + else: + tokenizer.Consume('{') + end_token = '}' + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + sub_message = message.Extensions[field].add() + else: + sub_message = getattr(message, field.name).add() + else: + if field.is_extension: + sub_message = message.Extensions[field] + else: + sub_message = getattr(message, field.name) + sub_message.SetInParent() + + while not tokenizer.TryConsume(end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) + _MergeField(tokenizer, sub_message) + else: + _MergeScalarField(tokenizer, message, field) + + +def _MergeScalarField(tokenizer, message, field): + """Merges a single protocol message scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: A protocol message to record the data. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of ASCII parsing problems. + RuntimeError: On runtime errors. + """ + tokenizer.Consume(':') + value = None + + if field.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_SFIXED32): + value = tokenizer.ConsumeInt32() + elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_SINT64, + descriptor.FieldDescriptor.TYPE_SFIXED64): + value = tokenizer.ConsumeInt64() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_FIXED32): + value = tokenizer.ConsumeUint32() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_FIXED64): + value = tokenizer.ConsumeUint64() + elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, + descriptor.FieldDescriptor.TYPE_DOUBLE): + value = tokenizer.ConsumeFloat() + elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: + value = tokenizer.ConsumeBool() + elif field.type == descriptor.FieldDescriptor.TYPE_STRING: + value = tokenizer.ConsumeString() + elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: + value = tokenizer.ConsumeByteString() + elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: + # Enum can be specified by a number (the enum value), or by + # a string literal (the enum name). + enum_descriptor = field.enum_type + if tokenizer.LookingAtInteger(): + number = tokenizer.ConsumeInt32() + enum_value = enum_descriptor.values_by_number.get(number, None) + if enum_value is None: + raise tokenizer.ParseErrorPreviousToken( + 'Enum type "%s" has no value with number %d.' % ( + enum_descriptor.full_name, number)) + else: + identifier = tokenizer.ConsumeIdentifier() + enum_value = enum_descriptor.values_by_name.get(identifier, None) + if enum_value is None: + raise tokenizer.ParseErrorPreviousToken( + 'Enum type "%s" has no value named %s.' % ( + enum_descriptor.full_name, identifier)) + value = enum_value.number + else: + raise RuntimeError('Unknown field type %d' % field.type) + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + message.Extensions[field].append(value) + else: + getattr(message, field.name).append(value) + else: + if field.is_extension: + message.Extensions[field] = value + else: + setattr(message, field.name, value) + + +class _Tokenizer(object): + """Protocol buffer ASCII representation tokenizer. + + This class handles the lower level string parsing by splitting it into + meaningful tokens. + + It was directly ported from the Java protocol buffer API. + """ + + _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE) + _TOKEN = re.compile( + '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier + '[0-9+-][0-9a-zA-Z_.+-]*|' # a number + '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string + '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string + _IDENTIFIER = re.compile('\w+') + _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(), + type_checkers.Int32ValueChecker(), + type_checkers.Uint64ValueChecker(), + type_checkers.Int64ValueChecker()] + _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE) + _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE) + + def __init__(self, text_message): + self._text_message = text_message + + self._position = 0 + self._line = -1 + self._column = 0 + self._token_start = None + self.token = '' + self._lines = deque(text_message.split('\n')) + self._current_line = '' + self._previous_line = 0 + self._previous_column = 0 + self._SkipWhitespace() + self.NextToken() + + def AtEnd(self): + """Checks the end of the text was reached. + + Returns: + True iff the end was reached. + """ + return not self._lines and not self._current_line + + def _PopLine(self): + while not self._current_line: + if not self._lines: + self._current_line = '' + return + self._line += 1 + self._column = 0 + self._current_line = self._lines.popleft() + + def _SkipWhitespace(self): + while True: + self._PopLine() + match = re.match(self._WHITESPACE, self._current_line) + if not match: + break + length = len(match.group(0)) + self._current_line = self._current_line[length:] + self._column += length + + def TryConsume(self, token): + """Tries to consume a given piece of text. + + Args: + token: Text to consume. + + Returns: + True iff the text was consumed. + """ + if self.token == token: + self.NextToken() + return True + return False + + def Consume(self, token): + """Consumes a piece of text. + + Args: + token: Text to consume. + + Raises: + ParseError: If the text couldn't be consumed. + """ + if not self.TryConsume(token): + raise self._ParseError('Expected "%s".' % token) + + def LookingAtInteger(self): + """Checks if the current token is an integer. + + Returns: + True iff the current token is an integer. + """ + if not self.token: + return False + c = self.token[0] + return (c >= '0' and c <= '9') or c == '-' or c == '+' + + def ConsumeIdentifier(self): + """Consumes protocol message field identifier. + + Returns: + Identifier string. + + Raises: + ParseError: If an identifier couldn't be consumed. + """ + result = self.token + if not re.match(self._IDENTIFIER, result): + raise self._ParseError('Expected identifier.') + self.NextToken() + return result + + def ConsumeInt32(self): + """Consumes a signed 32bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 32bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=True, is_long=False) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeUint32(self): + """Consumes an unsigned 32bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 32bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=False, is_long=False) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeInt64(self): + """Consumes a signed 64bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 64bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=True, is_long=True) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeUint64(self): + """Consumes an unsigned 64bit integer number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 64bit integer couldn't be consumed. + """ + try: + result = self._ParseInteger(self.token, is_signed=False, is_long=True) + except ValueError, e: + raise self._IntegerParseError(e) + self.NextToken() + return result + + def ConsumeFloat(self): + """Consumes an floating point number. + + Returns: + The number parsed. + + Raises: + ParseError: If a floating point number couldn't be consumed. + """ + text = self.token + if re.match(self._FLOAT_INFINITY, text): + self.NextToken() + if text.startswith('-'): + return -_INFINITY + return _INFINITY + + if re.match(self._FLOAT_NAN, text): + self.NextToken() + return _NAN + + try: + result = float(text) + except ValueError, e: + raise self._FloatParseError(e) + self.NextToken() + return result + + def ConsumeBool(self): + """Consumes a boolean value. + + Returns: + The bool parsed. + + Raises: + ParseError: If a boolean value couldn't be consumed. + """ + if self.token == 'true': + self.NextToken() + return True + elif self.token == 'false': + self.NextToken() + return False + else: + raise self._ParseError('Expected "true" or "false".') + + def ConsumeString(self): + """Consumes a string value. + + Returns: + The string parsed. + + Raises: + ParseError: If a string value couldn't be consumed. + """ + return unicode(self.ConsumeByteString(), 'utf-8') + + def ConsumeByteString(self): + """Consumes a byte array value. + + Returns: + The array parsed (as a string). + + Raises: + ParseError: If a byte array value couldn't be consumed. + """ + list = [self._ConsumeSingleByteString()] + while len(self.token) > 0 and self.token[0] in ('\'', '"'): + list.append(self._ConsumeSingleByteString()) + return "".join(list) + + def _ConsumeSingleByteString(self): + """Consume one token of a string literal. + + String literals (whether bytes or text) can come in multiple adjacent + tokens which are automatically concatenated, like in C or Python. This + method only consumes one token. + """ + text = self.token + if len(text) < 1 or text[0] not in ('\'', '"'): + raise self._ParseError('Exptected string.') + + if len(text) < 2 or text[-1] != text[0]: + raise self._ParseError('String missing ending quote.') + + try: + result = _CUnescape(text[1:-1]) + except ValueError, e: + raise self._ParseError(str(e)) + self.NextToken() + return result + + def _ParseInteger(self, text, is_signed=False, is_long=False): + """Parses an integer. + + Args: + text: The text to parse. + is_signed: True if a signed integer must be parsed. + is_long: True if a long integer must be parsed. + + Returns: + The integer value. + + Raises: + ValueError: Thrown Iff the text is not a valid integer. + """ + pos = 0 + if text.startswith('-'): + pos += 1 + + base = 10 + if text.startswith('0x', pos) or text.startswith('0X', pos): + base = 16 + elif text.startswith('0', pos): + base = 8 + + # Do the actual parsing. Exception handling is propagated to caller. + result = int(text, base) + + # Check if the integer is sane. Exceptions handled by callers. + checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)] + checker.CheckValue(result) + return result + + def ParseErrorPreviousToken(self, message): + """Creates and *returns* a ParseError for the previously read token. + + Args: + message: A message to set for the exception. + + Returns: + A ParseError instance. + """ + return ParseError('%d:%d : %s' % ( + self._previous_line + 1, self._previous_column + 1, message)) + + def _ParseError(self, message): + """Creates and *returns* a ParseError for the current token.""" + return ParseError('%d:%d : %s' % ( + self._line + 1, self._column + 1, message)) + + def _IntegerParseError(self, e): + return self._ParseError('Couldn\'t parse integer: ' + str(e)) + + def _FloatParseError(self, e): + return self._ParseError('Couldn\'t parse number: ' + str(e)) + + def NextToken(self): + """Reads the next meaningful token.""" + self._previous_line = self._line + self._previous_column = self._column + if self.AtEnd(): + self.token = '' + return + self._column += len(self.token) + + # Make sure there is data to work on. + self._PopLine() + + match = re.match(self._TOKEN, self._current_line) + if match: + token = match.group(0) + self._current_line = self._current_line[len(token):] + self.token = token + else: + self.token = self._current_line[0] + self._current_line = self._current_line[1:] + self._SkipWhitespace() + + +# text.encode('string_escape') does not seem to satisfy our needs as it +# encodes unprintable characters using two-digit hex escapes whereas our +# C++ unescaping function allows hex escapes to be any length. So, +# "\0011".encode('string_escape') ends up being "\\x011", which will be +# decoded in C++ as a single-character string with char code 0x11. +def _CEscape(text): + def escape(c): + o = ord(c) + if o == 10: return r"\n" # optional escape + if o == 13: return r"\r" # optional escape + if o == 9: return r"\t" # optional escape + if o == 39: return r"\'" # optional escape + + if o == 34: return r'\"' # necessary escape + if o == 92: return r"\\" # necessary escape + + if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes + return c + return "".join([escape(c) for c in text]) + + +_CUNESCAPE_HEX = re.compile('\\\\x([0-9a-fA-F]{2}|[0-9a-f-A-F])') + + +def _CUnescape(text): + def ReplaceHex(m): + return chr(int(m.group(0)[2:], 16)) + # This is required because the 'string_escape' encoding doesn't + # allow single-digit hex escapes (like '\xf'). + result = _CUNESCAPE_HEX.sub(ReplaceHex, text) + return result.decode('string_escape') diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/mox.py b/mp/src/thirdparty/protobuf-2.3.0/python/mox.py index a577647b..ce80ba50 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/mox.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/mox.py @@ -1,1401 +1,1401 @@ -#!/usr/bin/python2.4
-#
-# Copyright 2008 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This file is used for testing. The original is at:
-# http://code.google.com/p/pymox/
-
-"""Mox, an object-mocking framework for Python.
-
-Mox works in the record-replay-verify paradigm. When you first create
-a mock object, it is in record mode. You then programmatically set
-the expected behavior of the mock object (what methods are to be
-called on it, with what parameters, what they should return, and in
-what order).
-
-Once you have set up the expected mock behavior, you put it in replay
-mode. Now the mock responds to method calls just as you told it to.
-If an unexpected method (or an expected method with unexpected
-parameters) is called, then an exception will be raised.
-
-Once you are done interacting with the mock, you need to verify that
-all the expected interactions occured. (Maybe your code exited
-prematurely without calling some cleanup method!) The verify phase
-ensures that every expected method was called; otherwise, an exception
-will be raised.
-
-Suggested usage / workflow:
-
- # Create Mox factory
- my_mox = Mox()
-
- # Create a mock data access object
- mock_dao = my_mox.CreateMock(DAOClass)
-
- # Set up expected behavior
- mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person)
- mock_dao.DeletePerson(person)
-
- # Put mocks in replay mode
- my_mox.ReplayAll()
-
- # Inject mock object and run test
- controller.SetDao(mock_dao)
- controller.DeletePersonById('1')
-
- # Verify all methods were called as expected
- my_mox.VerifyAll()
-"""
-
-from collections import deque
-import re
-import types
-import unittest
-
-import stubout
-
-class Error(AssertionError):
- """Base exception for this module."""
-
- pass
-
-
-class ExpectedMethodCallsError(Error):
- """Raised when Verify() is called before all expected methods have been called
- """
-
- def __init__(self, expected_methods):
- """Init exception.
-
- Args:
- # expected_methods: A sequence of MockMethod objects that should have been
- # called.
- expected_methods: [MockMethod]
-
- Raises:
- ValueError: if expected_methods contains no methods.
- """
-
- if not expected_methods:
- raise ValueError("There must be at least one expected method")
- Error.__init__(self)
- self._expected_methods = expected_methods
-
- def __str__(self):
- calls = "\n".join(["%3d. %s" % (i, m)
- for i, m in enumerate(self._expected_methods)])
- return "Verify: Expected methods never called:\n%s" % (calls,)
-
-
-class UnexpectedMethodCallError(Error):
- """Raised when an unexpected method is called.
-
- This can occur if a method is called with incorrect parameters, or out of the
- specified order.
- """
-
- def __init__(self, unexpected_method, expected):
- """Init exception.
-
- Args:
- # unexpected_method: MockMethod that was called but was not at the head of
- # the expected_method queue.
- # expected: MockMethod or UnorderedGroup the method should have
- # been in.
- unexpected_method: MockMethod
- expected: MockMethod or UnorderedGroup
- """
-
- Error.__init__(self)
- self._unexpected_method = unexpected_method
- self._expected = expected
-
- def __str__(self):
- return "Unexpected method call: %s. Expecting: %s" % \
- (self._unexpected_method, self._expected)
-
-
-class UnknownMethodCallError(Error):
- """Raised if an unknown method is requested of the mock object."""
-
- def __init__(self, unknown_method_name):
- """Init exception.
-
- Args:
- # unknown_method_name: Method call that is not part of the mocked class's
- # public interface.
- unknown_method_name: str
- """
-
- Error.__init__(self)
- self._unknown_method_name = unknown_method_name
-
- def __str__(self):
- return "Method called is not a member of the object: %s" % \
- self._unknown_method_name
-
-
-class Mox(object):
- """Mox: a factory for creating mock objects."""
-
- # A list of types that should be stubbed out with MockObjects (as
- # opposed to MockAnythings).
- _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType,
- types.ObjectType, types.TypeType]
-
- def __init__(self):
- """Initialize a new Mox."""
-
- self._mock_objects = []
- self.stubs = stubout.StubOutForTesting()
-
- def CreateMock(self, class_to_mock):
- """Create a new mock object.
-
- Args:
- # class_to_mock: the class to be mocked
- class_to_mock: class
-
- Returns:
- MockObject that can be used as the class_to_mock would be.
- """
-
- new_mock = MockObject(class_to_mock)
- self._mock_objects.append(new_mock)
- return new_mock
-
- def CreateMockAnything(self):
- """Create a mock that will accept any method calls.
-
- This does not enforce an interface.
- """
-
- new_mock = MockAnything()
- self._mock_objects.append(new_mock)
- return new_mock
-
- def ReplayAll(self):
- """Set all mock objects to replay mode."""
-
- for mock_obj in self._mock_objects:
- mock_obj._Replay()
-
-
- def VerifyAll(self):
- """Call verify on all mock objects created."""
-
- for mock_obj in self._mock_objects:
- mock_obj._Verify()
-
- def ResetAll(self):
- """Call reset on all mock objects. This does not unset stubs."""
-
- for mock_obj in self._mock_objects:
- mock_obj._Reset()
-
- def StubOutWithMock(self, obj, attr_name, use_mock_anything=False):
- """Replace a method, attribute, etc. with a Mock.
-
- This will replace a class or module with a MockObject, and everything else
- (method, function, etc) with a MockAnything. This can be overridden to
- always use a MockAnything by setting use_mock_anything to True.
-
- Args:
- obj: A Python object (class, module, instance, callable).
- attr_name: str. The name of the attribute to replace with a mock.
- use_mock_anything: bool. True if a MockAnything should be used regardless
- of the type of attribute.
- """
-
- attr_to_replace = getattr(obj, attr_name)
- if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything:
- stub = self.CreateMock(attr_to_replace)
- else:
- stub = self.CreateMockAnything()
-
- self.stubs.Set(obj, attr_name, stub)
-
- def UnsetStubs(self):
- """Restore stubs to their original state."""
-
- self.stubs.UnsetAll()
-
-def Replay(*args):
- """Put mocks into Replay mode.
-
- Args:
- # args is any number of mocks to put into replay mode.
- """
-
- for mock in args:
- mock._Replay()
-
-
-def Verify(*args):
- """Verify mocks.
-
- Args:
- # args is any number of mocks to be verified.
- """
-
- for mock in args:
- mock._Verify()
-
-
-def Reset(*args):
- """Reset mocks.
-
- Args:
- # args is any number of mocks to be reset.
- """
-
- for mock in args:
- mock._Reset()
-
-
-class MockAnything:
- """A mock that can be used to mock anything.
-
- This is helpful for mocking classes that do not provide a public interface.
- """
-
- def __init__(self):
- """ """
- self._Reset()
-
- def __getattr__(self, method_name):
- """Intercept method calls on this object.
-
- A new MockMethod is returned that is aware of the MockAnything's
- state (record or replay). The call will be recorded or replayed
- by the MockMethod's __call__.
-
- Args:
- # method name: the name of the method being called.
- method_name: str
-
- Returns:
- A new MockMethod aware of MockAnything's state (record or replay).
- """
-
- return self._CreateMockMethod(method_name)
-
- def _CreateMockMethod(self, method_name):
- """Create a new mock method call and return it.
-
- Args:
- # method name: the name of the method being called.
- method_name: str
-
- Returns:
- A new MockMethod aware of MockAnything's state (record or replay).
- """
-
- return MockMethod(method_name, self._expected_calls_queue,
- self._replay_mode)
-
- def __nonzero__(self):
- """Return 1 for nonzero so the mock can be used as a conditional."""
-
- return 1
-
- def __eq__(self, rhs):
- """Provide custom logic to compare objects."""
-
- return (isinstance(rhs, MockAnything) and
- self._replay_mode == rhs._replay_mode and
- self._expected_calls_queue == rhs._expected_calls_queue)
-
- def __ne__(self, rhs):
- """Provide custom logic to compare objects."""
-
- return not self == rhs
-
- def _Replay(self):
- """Start replaying expected method calls."""
-
- self._replay_mode = True
-
- def _Verify(self):
- """Verify that all of the expected calls have been made.
-
- Raises:
- ExpectedMethodCallsError: if there are still more method calls in the
- expected queue.
- """
-
- # If the list of expected calls is not empty, raise an exception
- if self._expected_calls_queue:
- # The last MultipleTimesGroup is not popped from the queue.
- if (len(self._expected_calls_queue) == 1 and
- isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and
- self._expected_calls_queue[0].IsSatisfied()):
- pass
- else:
- raise ExpectedMethodCallsError(self._expected_calls_queue)
-
- def _Reset(self):
- """Reset the state of this mock to record mode with an empty queue."""
-
- # Maintain a list of method calls we are expecting
- self._expected_calls_queue = deque()
-
- # Make sure we are in setup mode, not replay mode
- self._replay_mode = False
-
-
-class MockObject(MockAnything, object):
- """A mock object that simulates the public/protected interface of a class."""
-
- def __init__(self, class_to_mock):
- """Initialize a mock object.
-
- This determines the methods and properties of the class and stores them.
-
- Args:
- # class_to_mock: class to be mocked
- class_to_mock: class
- """
-
- # This is used to hack around the mixin/inheritance of MockAnything, which
- # is not a proper object (it can be anything. :-)
- MockAnything.__dict__['__init__'](self)
-
- # Get a list of all the public and special methods we should mock.
- self._known_methods = set()
- self._known_vars = set()
- self._class_to_mock = class_to_mock
- for method in dir(class_to_mock):
- if callable(getattr(class_to_mock, method)):
- self._known_methods.add(method)
- else:
- self._known_vars.add(method)
-
- def __getattr__(self, name):
- """Intercept attribute request on this object.
-
- If the attribute is a public class variable, it will be returned and not
- recorded as a call.
-
- If the attribute is not a variable, it is handled like a method
- call. The method name is checked against the set of mockable
- methods, and a new MockMethod is returned that is aware of the
- MockObject's state (record or replay). The call will be recorded
- or replayed by the MockMethod's __call__.
-
- Args:
- # name: the name of the attribute being requested.
- name: str
-
- Returns:
- Either a class variable or a new MockMethod that is aware of the state
- of the mock (record or replay).
-
- Raises:
- UnknownMethodCallError if the MockObject does not mock the requested
- method.
- """
-
- if name in self._known_vars:
- return getattr(self._class_to_mock, name)
-
- if name in self._known_methods:
- return self._CreateMockMethod(name)
-
- raise UnknownMethodCallError(name)
-
- def __eq__(self, rhs):
- """Provide custom logic to compare objects."""
-
- return (isinstance(rhs, MockObject) and
- self._class_to_mock == rhs._class_to_mock and
- self._replay_mode == rhs._replay_mode and
- self._expected_calls_queue == rhs._expected_calls_queue)
-
- def __setitem__(self, key, value):
- """Provide custom logic for mocking classes that support item assignment.
-
- Args:
- key: Key to set the value for.
- value: Value to set.
-
- Returns:
- Expected return value in replay mode. A MockMethod object for the
- __setitem__ method that has already been called if not in replay mode.
-
- Raises:
- TypeError if the underlying class does not support item assignment.
- UnexpectedMethodCallError if the object does not expect the call to
- __setitem__.
-
- """
- setitem = self._class_to_mock.__dict__.get('__setitem__', None)
-
- # Verify the class supports item assignment.
- if setitem is None:
- raise TypeError('object does not support item assignment')
-
- # If we are in replay mode then simply call the mock __setitem__ method.
- if self._replay_mode:
- return MockMethod('__setitem__', self._expected_calls_queue,
- self._replay_mode)(key, value)
-
-
- # Otherwise, create a mock method __setitem__.
- return self._CreateMockMethod('__setitem__')(key, value)
-
- def __getitem__(self, key):
- """Provide custom logic for mocking classes that are subscriptable.
-
- Args:
- key: Key to return the value for.
-
- Returns:
- Expected return value in replay mode. A MockMethod object for the
- __getitem__ method that has already been called if not in replay mode.
-
- Raises:
- TypeError if the underlying class is not subscriptable.
- UnexpectedMethodCallError if the object does not expect the call to
- __setitem__.
-
- """
- getitem = self._class_to_mock.__dict__.get('__getitem__', None)
-
- # Verify the class supports item assignment.
- if getitem is None:
- raise TypeError('unsubscriptable object')
-
- # If we are in replay mode then simply call the mock __getitem__ method.
- if self._replay_mode:
- return MockMethod('__getitem__', self._expected_calls_queue,
- self._replay_mode)(key)
-
-
- # Otherwise, create a mock method __getitem__.
- return self._CreateMockMethod('__getitem__')(key)
-
- def __call__(self, *params, **named_params):
- """Provide custom logic for mocking classes that are callable."""
-
- # Verify the class we are mocking is callable
- callable = self._class_to_mock.__dict__.get('__call__', None)
- if callable is None:
- raise TypeError('Not callable')
-
- # Because the call is happening directly on this object instead of a method,
- # the call on the mock method is made right here
- mock_method = self._CreateMockMethod('__call__')
- return mock_method(*params, **named_params)
-
- @property
- def __class__(self):
- """Return the class that is being mocked."""
-
- return self._class_to_mock
-
-
-class MockMethod(object):
- """Callable mock method.
-
- A MockMethod should act exactly like the method it mocks, accepting parameters
- and returning a value, or throwing an exception (as specified). When this
- method is called, it can optionally verify whether the called method (name and
- signature) matches the expected method.
- """
-
- def __init__(self, method_name, call_queue, replay_mode):
- """Construct a new mock method.
-
- Args:
- # method_name: the name of the method
- # call_queue: deque of calls, verify this call against the head, or add
- # this call to the queue.
- # replay_mode: False if we are recording, True if we are verifying calls
- # against the call queue.
- method_name: str
- call_queue: list or deque
- replay_mode: bool
- """
-
- self._name = method_name
- self._call_queue = call_queue
- if not isinstance(call_queue, deque):
- self._call_queue = deque(self._call_queue)
- self._replay_mode = replay_mode
-
- self._params = None
- self._named_params = None
- self._return_value = None
- self._exception = None
- self._side_effects = None
-
- def __call__(self, *params, **named_params):
- """Log parameters and return the specified return value.
-
- If the Mock(Anything/Object) associated with this call is in record mode,
- this MockMethod will be pushed onto the expected call queue. If the mock
- is in replay mode, this will pop a MockMethod off the top of the queue and
- verify this call is equal to the expected call.
-
- Raises:
- UnexpectedMethodCall if this call is supposed to match an expected method
- call and it does not.
- """
-
- self._params = params
- self._named_params = named_params
-
- if not self._replay_mode:
- self._call_queue.append(self)
- return self
-
- expected_method = self._VerifyMethodCall()
-
- if expected_method._side_effects:
- expected_method._side_effects(*params, **named_params)
-
- if expected_method._exception:
- raise expected_method._exception
-
- return expected_method._return_value
-
- def __getattr__(self, name):
- """Raise an AttributeError with a helpful message."""
-
- raise AttributeError('MockMethod has no attribute "%s". '
- 'Did you remember to put your mocks in replay mode?' % name)
-
- def _PopNextMethod(self):
- """Pop the next method from our call queue."""
- try:
- return self._call_queue.popleft()
- except IndexError:
- raise UnexpectedMethodCallError(self, None)
-
- def _VerifyMethodCall(self):
- """Verify the called method is expected.
-
- This can be an ordered method, or part of an unordered set.
-
- Returns:
- The expected mock method.
-
- Raises:
- UnexpectedMethodCall if the method called was not expected.
- """
-
- expected = self._PopNextMethod()
-
- # Loop here, because we might have a MethodGroup followed by another
- # group.
- while isinstance(expected, MethodGroup):
- expected, method = expected.MethodCalled(self)
- if method is not None:
- return method
-
- # This is a mock method, so just check equality.
- if expected != self:
- raise UnexpectedMethodCallError(self, expected)
-
- return expected
-
- def __str__(self):
- params = ', '.join(
- [repr(p) for p in self._params or []] +
- ['%s=%r' % x for x in sorted((self._named_params or {}).items())])
- desc = "%s(%s) -> %r" % (self._name, params, self._return_value)
- return desc
-
- def __eq__(self, rhs):
- """Test whether this MockMethod is equivalent to another MockMethod.
-
- Args:
- # rhs: the right hand side of the test
- rhs: MockMethod
- """
-
- return (isinstance(rhs, MockMethod) and
- self._name == rhs._name and
- self._params == rhs._params and
- self._named_params == rhs._named_params)
-
- def __ne__(self, rhs):
- """Test whether this MockMethod is not equivalent to another MockMethod.
-
- Args:
- # rhs: the right hand side of the test
- rhs: MockMethod
- """
-
- return not self == rhs
-
- def GetPossibleGroup(self):
- """Returns a possible group from the end of the call queue or None if no
- other methods are on the stack.
- """
-
- # Remove this method from the tail of the queue so we can add it to a group.
- this_method = self._call_queue.pop()
- assert this_method == self
-
- # Determine if the tail of the queue is a group, or just a regular ordered
- # mock method.
- group = None
- try:
- group = self._call_queue[-1]
- except IndexError:
- pass
-
- return group
-
- def _CheckAndCreateNewGroup(self, group_name, group_class):
- """Checks if the last method (a possible group) is an instance of our
- group_class. Adds the current method to this group or creates a new one.
-
- Args:
-
- group_name: the name of the group.
- group_class: the class used to create instance of this new group
- """
- group = self.GetPossibleGroup()
-
- # If this is a group, and it is the correct group, add the method.
- if isinstance(group, group_class) and group.group_name() == group_name:
- group.AddMethod(self)
- return self
-
- # Create a new group and add the method.
- new_group = group_class(group_name)
- new_group.AddMethod(self)
- self._call_queue.append(new_group)
- return self
-
- def InAnyOrder(self, group_name="default"):
- """Move this method into a group of unordered calls.
-
- A group of unordered calls must be defined together, and must be executed
- in full before the next expected method can be called. There can be
- multiple groups that are expected serially, if they are given
- different group names. The same group name can be reused if there is a
- standard method call, or a group with a different name, spliced between
- usages.
-
- Args:
- group_name: the name of the unordered group.
-
- Returns:
- self
- """
- return self._CheckAndCreateNewGroup(group_name, UnorderedGroup)
-
- def MultipleTimes(self, group_name="default"):
- """Move this method into group of calls which may be called multiple times.
-
- A group of repeating calls must be defined together, and must be executed in
- full before the next expected mehtod can be called.
-
- Args:
- group_name: the name of the unordered group.
-
- Returns:
- self
- """
- return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup)
-
- def AndReturn(self, return_value):
- """Set the value to return when this method is called.
-
- Args:
- # return_value can be anything.
- """
-
- self._return_value = return_value
- return return_value
-
- def AndRaise(self, exception):
- """Set the exception to raise when this method is called.
-
- Args:
- # exception: the exception to raise when this method is called.
- exception: Exception
- """
-
- self._exception = exception
-
- def WithSideEffects(self, side_effects):
- """Set the side effects that are simulated when this method is called.
-
- Args:
- side_effects: A callable which modifies the parameters or other relevant
- state which a given test case depends on.
-
- Returns:
- Self for chaining with AndReturn and AndRaise.
- """
- self._side_effects = side_effects
- return self
-
-class Comparator:
- """Base class for all Mox comparators.
-
- A Comparator can be used as a parameter to a mocked method when the exact
- value is not known. For example, the code you are testing might build up a
- long SQL string that is passed to your mock DAO. You're only interested that
- the IN clause contains the proper primary keys, so you can set your mock
- up as follows:
-
- mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
-
- Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'.
-
- A Comparator may replace one or more parameters, for example:
- # return at most 10 rows
- mock_dao.RunQuery(StrContains('SELECT'), 10)
-
- or
-
- # Return some non-deterministic number of rows
- mock_dao.RunQuery(StrContains('SELECT'), IsA(int))
- """
-
- def equals(self, rhs):
- """Special equals method that all comparators must implement.
-
- Args:
- rhs: any python object
- """
-
- raise NotImplementedError, 'method must be implemented by a subclass.'
-
- def __eq__(self, rhs):
- return self.equals(rhs)
-
- def __ne__(self, rhs):
- return not self.equals(rhs)
-
-
-class IsA(Comparator):
- """This class wraps a basic Python type or class. It is used to verify
- that a parameter is of the given type or class.
-
- Example:
- mock_dao.Connect(IsA(DbConnectInfo))
- """
-
- def __init__(self, class_name):
- """Initialize IsA
-
- Args:
- class_name: basic python type or a class
- """
-
- self._class_name = class_name
-
- def equals(self, rhs):
- """Check to see if the RHS is an instance of class_name.
-
- Args:
- # rhs: the right hand side of the test
- rhs: object
-
- Returns:
- bool
- """
-
- try:
- return isinstance(rhs, self._class_name)
- except TypeError:
- # Check raw types if there was a type error. This is helpful for
- # things like cStringIO.StringIO.
- return type(rhs) == type(self._class_name)
-
- def __repr__(self):
- return str(self._class_name)
-
-class IsAlmost(Comparator):
- """Comparison class used to check whether a parameter is nearly equal
- to a given value. Generally useful for floating point numbers.
-
- Example mock_dao.SetTimeout((IsAlmost(3.9)))
- """
-
- def __init__(self, float_value, places=7):
- """Initialize IsAlmost.
-
- Args:
- float_value: The value for making the comparison.
- places: The number of decimal places to round to.
- """
-
- self._float_value = float_value
- self._places = places
-
- def equals(self, rhs):
- """Check to see if RHS is almost equal to float_value
-
- Args:
- rhs: the value to compare to float_value
-
- Returns:
- bool
- """
-
- try:
- return round(rhs-self._float_value, self._places) == 0
- except TypeError:
- # This is probably because either float_value or rhs is not a number.
- return False
-
- def __repr__(self):
- return str(self._float_value)
-
-class StrContains(Comparator):
- """Comparison class used to check whether a substring exists in a
- string parameter. This can be useful in mocking a database with SQL
- passed in as a string parameter, for example.
-
- Example:
- mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result)
- """
-
- def __init__(self, search_string):
- """Initialize.
-
- Args:
- # search_string: the string you are searching for
- search_string: str
- """
-
- self._search_string = search_string
-
- def equals(self, rhs):
- """Check to see if the search_string is contained in the rhs string.
-
- Args:
- # rhs: the right hand side of the test
- rhs: object
-
- Returns:
- bool
- """
-
- try:
- return rhs.find(self._search_string) > -1
- except Exception:
- return False
-
- def __repr__(self):
- return '<str containing \'%s\'>' % self._search_string
-
-
-class Regex(Comparator):
- """Checks if a string matches a regular expression.
-
- This uses a given regular expression to determine equality.
- """
-
- def __init__(self, pattern, flags=0):
- """Initialize.
-
- Args:
- # pattern is the regular expression to search for
- pattern: str
- # flags passed to re.compile function as the second argument
- flags: int
- """
-
- self.regex = re.compile(pattern, flags=flags)
-
- def equals(self, rhs):
- """Check to see if rhs matches regular expression pattern.
-
- Returns:
- bool
- """
-
- return self.regex.search(rhs) is not None
-
- def __repr__(self):
- s = '<regular expression \'%s\'' % self.regex.pattern
- if self.regex.flags:
- s += ', flags=%d' % self.regex.flags
- s += '>'
- return s
-
-
-class In(Comparator):
- """Checks whether an item (or key) is in a list (or dict) parameter.
-
- Example:
- mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result)
- """
-
- def __init__(self, key):
- """Initialize.
-
- Args:
- # key is any thing that could be in a list or a key in a dict
- """
-
- self._key = key
-
- def equals(self, rhs):
- """Check to see whether key is in rhs.
-
- Args:
- rhs: dict
-
- Returns:
- bool
- """
-
- return self._key in rhs
-
- def __repr__(self):
- return '<sequence or map containing \'%s\'>' % self._key
-
-
-class ContainsKeyValue(Comparator):
- """Checks whether a key/value pair is in a dict parameter.
-
- Example:
- mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info))
- """
-
- def __init__(self, key, value):
- """Initialize.
-
- Args:
- # key: a key in a dict
- # value: the corresponding value
- """
-
- self._key = key
- self._value = value
-
- def equals(self, rhs):
- """Check whether the given key/value pair is in the rhs dict.
-
- Returns:
- bool
- """
-
- try:
- return rhs[self._key] == self._value
- except Exception:
- return False
-
- def __repr__(self):
- return '<map containing the entry \'%s: %s\'>' % (self._key, self._value)
-
-
-class SameElementsAs(Comparator):
- """Checks whether iterables contain the same elements (ignoring order).
-
- Example:
- mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki'))
- """
-
- def __init__(self, expected_seq):
- """Initialize.
-
- Args:
- expected_seq: a sequence
- """
-
- self._expected_seq = expected_seq
-
- def equals(self, actual_seq):
- """Check to see whether actual_seq has same elements as expected_seq.
-
- Args:
- actual_seq: sequence
-
- Returns:
- bool
- """
-
- try:
- expected = dict([(element, None) for element in self._expected_seq])
- actual = dict([(element, None) for element in actual_seq])
- except TypeError:
- # Fall back to slower list-compare if any of the objects are unhashable.
- expected = list(self._expected_seq)
- actual = list(actual_seq)
- expected.sort()
- actual.sort()
- return expected == actual
-
- def __repr__(self):
- return '<sequence with same elements as \'%s\'>' % self._expected_seq
-
-
-class And(Comparator):
- """Evaluates one or more Comparators on RHS and returns an AND of the results.
- """
-
- def __init__(self, *args):
- """Initialize.
-
- Args:
- *args: One or more Comparator
- """
-
- self._comparators = args
-
- def equals(self, rhs):
- """Checks whether all Comparators are equal to rhs.
-
- Args:
- # rhs: can be anything
-
- Returns:
- bool
- """
-
- for comparator in self._comparators:
- if not comparator.equals(rhs):
- return False
-
- return True
-
- def __repr__(self):
- return '<AND %s>' % str(self._comparators)
-
-
-class Or(Comparator):
- """Evaluates one or more Comparators on RHS and returns an OR of the results.
- """
-
- def __init__(self, *args):
- """Initialize.
-
- Args:
- *args: One or more Mox comparators
- """
-
- self._comparators = args
-
- def equals(self, rhs):
- """Checks whether any Comparator is equal to rhs.
-
- Args:
- # rhs: can be anything
-
- Returns:
- bool
- """
-
- for comparator in self._comparators:
- if comparator.equals(rhs):
- return True
-
- return False
-
- def __repr__(self):
- return '<OR %s>' % str(self._comparators)
-
-
-class Func(Comparator):
- """Call a function that should verify the parameter passed in is correct.
-
- You may need the ability to perform more advanced operations on the parameter
- in order to validate it. You can use this to have a callable validate any
- parameter. The callable should return either True or False.
-
-
- Example:
-
- def myParamValidator(param):
- # Advanced logic here
- return True
-
- mock_dao.DoSomething(Func(myParamValidator), true)
- """
-
- def __init__(self, func):
- """Initialize.
-
- Args:
- func: callable that takes one parameter and returns a bool
- """
-
- self._func = func
-
- def equals(self, rhs):
- """Test whether rhs passes the function test.
-
- rhs is passed into func.
-
- Args:
- rhs: any python object
-
- Returns:
- the result of func(rhs)
- """
-
- return self._func(rhs)
-
- def __repr__(self):
- return str(self._func)
-
-
-class IgnoreArg(Comparator):
- """Ignore an argument.
-
- This can be used when we don't care about an argument of a method call.
-
- Example:
- # Check if CastMagic is called with 3 as first arg and 'disappear' as third.
- mymock.CastMagic(3, IgnoreArg(), 'disappear')
- """
-
- def equals(self, unused_rhs):
- """Ignores arguments and returns True.
-
- Args:
- unused_rhs: any python object
-
- Returns:
- always returns True
- """
-
- return True
-
- def __repr__(self):
- return '<IgnoreArg>'
-
-
-class MethodGroup(object):
- """Base class containing common behaviour for MethodGroups."""
-
- def __init__(self, group_name):
- self._group_name = group_name
-
- def group_name(self):
- return self._group_name
-
- def __str__(self):
- return '<%s "%s">' % (self.__class__.__name__, self._group_name)
-
- def AddMethod(self, mock_method):
- raise NotImplementedError
-
- def MethodCalled(self, mock_method):
- raise NotImplementedError
-
- def IsSatisfied(self):
- raise NotImplementedError
-
-class UnorderedGroup(MethodGroup):
- """UnorderedGroup holds a set of method calls that may occur in any order.
-
- This construct is helpful for non-deterministic events, such as iterating
- over the keys of a dict.
- """
-
- def __init__(self, group_name):
- super(UnorderedGroup, self).__init__(group_name)
- self._methods = []
-
- def AddMethod(self, mock_method):
- """Add a method to this group.
-
- Args:
- mock_method: A mock method to be added to this group.
- """
-
- self._methods.append(mock_method)
-
- def MethodCalled(self, mock_method):
- """Remove a method call from the group.
-
- If the method is not in the set, an UnexpectedMethodCallError will be
- raised.
-
- Args:
- mock_method: a mock method that should be equal to a method in the group.
-
- Returns:
- The mock method from the group
-
- Raises:
- UnexpectedMethodCallError if the mock_method was not in the group.
- """
-
- # Check to see if this method exists, and if so, remove it from the set
- # and return it.
- for method in self._methods:
- if method == mock_method:
- # Remove the called mock_method instead of the method in the group.
- # The called method will match any comparators when equality is checked
- # during removal. The method in the group could pass a comparator to
- # another comparator during the equality check.
- self._methods.remove(mock_method)
-
- # If this group is not empty, put it back at the head of the queue.
- if not self.IsSatisfied():
- mock_method._call_queue.appendleft(self)
-
- return self, method
-
- raise UnexpectedMethodCallError(mock_method, self)
-
- def IsSatisfied(self):
- """Return True if there are not any methods in this group."""
-
- return len(self._methods) == 0
-
-
-class MultipleTimesGroup(MethodGroup):
- """MultipleTimesGroup holds methods that may be called any number of times.
-
- Note: Each method must be called at least once.
-
- This is helpful, if you don't know or care how many times a method is called.
- """
-
- def __init__(self, group_name):
- super(MultipleTimesGroup, self).__init__(group_name)
- self._methods = set()
- self._methods_called = set()
-
- def AddMethod(self, mock_method):
- """Add a method to this group.
-
- Args:
- mock_method: A mock method to be added to this group.
- """
-
- self._methods.add(mock_method)
-
- def MethodCalled(self, mock_method):
- """Remove a method call from the group.
-
- If the method is not in the set, an UnexpectedMethodCallError will be
- raised.
-
- Args:
- mock_method: a mock method that should be equal to a method in the group.
-
- Returns:
- The mock method from the group
-
- Raises:
- UnexpectedMethodCallError if the mock_method was not in the group.
- """
-
- # Check to see if this method exists, and if so add it to the set of
- # called methods.
-
- for method in self._methods:
- if method == mock_method:
- self._methods_called.add(mock_method)
- # Always put this group back on top of the queue, because we don't know
- # when we are done.
- mock_method._call_queue.appendleft(self)
- return self, method
-
- if self.IsSatisfied():
- next_method = mock_method._PopNextMethod();
- return next_method, None
- else:
- raise UnexpectedMethodCallError(mock_method, self)
-
- def IsSatisfied(self):
- """Return True if all methods in this group are called at least once."""
- # NOTE(psycho): We can't use the simple set difference here because we want
- # to match different parameters which are considered the same e.g. IsA(str)
- # and some string. This solution is O(n^2) but n should be small.
- tmp = self._methods.copy()
- for called in self._methods_called:
- for expected in tmp:
- if called == expected:
- tmp.remove(expected)
- if not tmp:
- return True
- break
- return False
-
-
-class MoxMetaTestBase(type):
- """Metaclass to add mox cleanup and verification to every test.
-
- As the mox unit testing class is being constructed (MoxTestBase or a
- subclass), this metaclass will modify all test functions to call the
- CleanUpMox method of the test class after they finish. This means that
- unstubbing and verifying will happen for every test with no additional code,
- and any failures will result in test failures as opposed to errors.
- """
-
- def __init__(cls, name, bases, d):
- type.__init__(cls, name, bases, d)
-
- # also get all the attributes from the base classes to account
- # for a case when test class is not the immediate child of MoxTestBase
- for base in bases:
- for attr_name in dir(base):
- d[attr_name] = getattr(base, attr_name)
-
- for func_name, func in d.items():
- if func_name.startswith('test') and callable(func):
- setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func))
-
- @staticmethod
- def CleanUpTest(cls, func):
- """Adds Mox cleanup code to any MoxTestBase method.
-
- Always unsets stubs after a test. Will verify all mocks for tests that
- otherwise pass.
-
- Args:
- cls: MoxTestBase or subclass; the class whose test method we are altering.
- func: method; the method of the MoxTestBase test class we wish to alter.
-
- Returns:
- The modified method.
- """
- def new_method(self, *args, **kwargs):
- mox_obj = getattr(self, 'mox', None)
- cleanup_mox = False
- if mox_obj and isinstance(mox_obj, Mox):
- cleanup_mox = True
- try:
- func(self, *args, **kwargs)
- finally:
- if cleanup_mox:
- mox_obj.UnsetStubs()
- if cleanup_mox:
- mox_obj.VerifyAll()
- new_method.__name__ = func.__name__
- new_method.__doc__ = func.__doc__
- new_method.__module__ = func.__module__
- return new_method
-
-
-class MoxTestBase(unittest.TestCase):
- """Convenience test class to make stubbing easier.
-
- Sets up a "mox" attribute which is an instance of Mox - any mox tests will
- want this. Also automatically unsets any stubs and verifies that all mock
- methods have been called at the end of each test, eliminating boilerplate
- code.
- """
-
- __metaclass__ = MoxMetaTestBase
-
- def setUp(self):
- self.mox = Mox()
+#!/usr/bin/python2.4 +# +# Copyright 2008 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is used for testing. The original is at: +# http://code.google.com/p/pymox/ + +"""Mox, an object-mocking framework for Python. + +Mox works in the record-replay-verify paradigm. When you first create +a mock object, it is in record mode. You then programmatically set +the expected behavior of the mock object (what methods are to be +called on it, with what parameters, what they should return, and in +what order). + +Once you have set up the expected mock behavior, you put it in replay +mode. Now the mock responds to method calls just as you told it to. +If an unexpected method (or an expected method with unexpected +parameters) is called, then an exception will be raised. + +Once you are done interacting with the mock, you need to verify that +all the expected interactions occured. (Maybe your code exited +prematurely without calling some cleanup method!) The verify phase +ensures that every expected method was called; otherwise, an exception +will be raised. + +Suggested usage / workflow: + + # Create Mox factory + my_mox = Mox() + + # Create a mock data access object + mock_dao = my_mox.CreateMock(DAOClass) + + # Set up expected behavior + mock_dao.RetrievePersonWithIdentifier('1').AndReturn(person) + mock_dao.DeletePerson(person) + + # Put mocks in replay mode + my_mox.ReplayAll() + + # Inject mock object and run test + controller.SetDao(mock_dao) + controller.DeletePersonById('1') + + # Verify all methods were called as expected + my_mox.VerifyAll() +""" + +from collections import deque +import re +import types +import unittest + +import stubout + +class Error(AssertionError): + """Base exception for this module.""" + + pass + + +class ExpectedMethodCallsError(Error): + """Raised when Verify() is called before all expected methods have been called + """ + + def __init__(self, expected_methods): + """Init exception. + + Args: + # expected_methods: A sequence of MockMethod objects that should have been + # called. + expected_methods: [MockMethod] + + Raises: + ValueError: if expected_methods contains no methods. + """ + + if not expected_methods: + raise ValueError("There must be at least one expected method") + Error.__init__(self) + self._expected_methods = expected_methods + + def __str__(self): + calls = "\n".join(["%3d. %s" % (i, m) + for i, m in enumerate(self._expected_methods)]) + return "Verify: Expected methods never called:\n%s" % (calls,) + + +class UnexpectedMethodCallError(Error): + """Raised when an unexpected method is called. + + This can occur if a method is called with incorrect parameters, or out of the + specified order. + """ + + def __init__(self, unexpected_method, expected): + """Init exception. + + Args: + # unexpected_method: MockMethod that was called but was not at the head of + # the expected_method queue. + # expected: MockMethod or UnorderedGroup the method should have + # been in. + unexpected_method: MockMethod + expected: MockMethod or UnorderedGroup + """ + + Error.__init__(self) + self._unexpected_method = unexpected_method + self._expected = expected + + def __str__(self): + return "Unexpected method call: %s. Expecting: %s" % \ + (self._unexpected_method, self._expected) + + +class UnknownMethodCallError(Error): + """Raised if an unknown method is requested of the mock object.""" + + def __init__(self, unknown_method_name): + """Init exception. + + Args: + # unknown_method_name: Method call that is not part of the mocked class's + # public interface. + unknown_method_name: str + """ + + Error.__init__(self) + self._unknown_method_name = unknown_method_name + + def __str__(self): + return "Method called is not a member of the object: %s" % \ + self._unknown_method_name + + +class Mox(object): + """Mox: a factory for creating mock objects.""" + + # A list of types that should be stubbed out with MockObjects (as + # opposed to MockAnythings). + _USE_MOCK_OBJECT = [types.ClassType, types.InstanceType, types.ModuleType, + types.ObjectType, types.TypeType] + + def __init__(self): + """Initialize a new Mox.""" + + self._mock_objects = [] + self.stubs = stubout.StubOutForTesting() + + def CreateMock(self, class_to_mock): + """Create a new mock object. + + Args: + # class_to_mock: the class to be mocked + class_to_mock: class + + Returns: + MockObject that can be used as the class_to_mock would be. + """ + + new_mock = MockObject(class_to_mock) + self._mock_objects.append(new_mock) + return new_mock + + def CreateMockAnything(self): + """Create a mock that will accept any method calls. + + This does not enforce an interface. + """ + + new_mock = MockAnything() + self._mock_objects.append(new_mock) + return new_mock + + def ReplayAll(self): + """Set all mock objects to replay mode.""" + + for mock_obj in self._mock_objects: + mock_obj._Replay() + + + def VerifyAll(self): + """Call verify on all mock objects created.""" + + for mock_obj in self._mock_objects: + mock_obj._Verify() + + def ResetAll(self): + """Call reset on all mock objects. This does not unset stubs.""" + + for mock_obj in self._mock_objects: + mock_obj._Reset() + + def StubOutWithMock(self, obj, attr_name, use_mock_anything=False): + """Replace a method, attribute, etc. with a Mock. + + This will replace a class or module with a MockObject, and everything else + (method, function, etc) with a MockAnything. This can be overridden to + always use a MockAnything by setting use_mock_anything to True. + + Args: + obj: A Python object (class, module, instance, callable). + attr_name: str. The name of the attribute to replace with a mock. + use_mock_anything: bool. True if a MockAnything should be used regardless + of the type of attribute. + """ + + attr_to_replace = getattr(obj, attr_name) + if type(attr_to_replace) in self._USE_MOCK_OBJECT and not use_mock_anything: + stub = self.CreateMock(attr_to_replace) + else: + stub = self.CreateMockAnything() + + self.stubs.Set(obj, attr_name, stub) + + def UnsetStubs(self): + """Restore stubs to their original state.""" + + self.stubs.UnsetAll() + +def Replay(*args): + """Put mocks into Replay mode. + + Args: + # args is any number of mocks to put into replay mode. + """ + + for mock in args: + mock._Replay() + + +def Verify(*args): + """Verify mocks. + + Args: + # args is any number of mocks to be verified. + """ + + for mock in args: + mock._Verify() + + +def Reset(*args): + """Reset mocks. + + Args: + # args is any number of mocks to be reset. + """ + + for mock in args: + mock._Reset() + + +class MockAnything: + """A mock that can be used to mock anything. + + This is helpful for mocking classes that do not provide a public interface. + """ + + def __init__(self): + """ """ + self._Reset() + + def __getattr__(self, method_name): + """Intercept method calls on this object. + + A new MockMethod is returned that is aware of the MockAnything's + state (record or replay). The call will be recorded or replayed + by the MockMethod's __call__. + + Args: + # method name: the name of the method being called. + method_name: str + + Returns: + A new MockMethod aware of MockAnything's state (record or replay). + """ + + return self._CreateMockMethod(method_name) + + def _CreateMockMethod(self, method_name): + """Create a new mock method call and return it. + + Args: + # method name: the name of the method being called. + method_name: str + + Returns: + A new MockMethod aware of MockAnything's state (record or replay). + """ + + return MockMethod(method_name, self._expected_calls_queue, + self._replay_mode) + + def __nonzero__(self): + """Return 1 for nonzero so the mock can be used as a conditional.""" + + return 1 + + def __eq__(self, rhs): + """Provide custom logic to compare objects.""" + + return (isinstance(rhs, MockAnything) and + self._replay_mode == rhs._replay_mode and + self._expected_calls_queue == rhs._expected_calls_queue) + + def __ne__(self, rhs): + """Provide custom logic to compare objects.""" + + return not self == rhs + + def _Replay(self): + """Start replaying expected method calls.""" + + self._replay_mode = True + + def _Verify(self): + """Verify that all of the expected calls have been made. + + Raises: + ExpectedMethodCallsError: if there are still more method calls in the + expected queue. + """ + + # If the list of expected calls is not empty, raise an exception + if self._expected_calls_queue: + # The last MultipleTimesGroup is not popped from the queue. + if (len(self._expected_calls_queue) == 1 and + isinstance(self._expected_calls_queue[0], MultipleTimesGroup) and + self._expected_calls_queue[0].IsSatisfied()): + pass + else: + raise ExpectedMethodCallsError(self._expected_calls_queue) + + def _Reset(self): + """Reset the state of this mock to record mode with an empty queue.""" + + # Maintain a list of method calls we are expecting + self._expected_calls_queue = deque() + + # Make sure we are in setup mode, not replay mode + self._replay_mode = False + + +class MockObject(MockAnything, object): + """A mock object that simulates the public/protected interface of a class.""" + + def __init__(self, class_to_mock): + """Initialize a mock object. + + This determines the methods and properties of the class and stores them. + + Args: + # class_to_mock: class to be mocked + class_to_mock: class + """ + + # This is used to hack around the mixin/inheritance of MockAnything, which + # is not a proper object (it can be anything. :-) + MockAnything.__dict__['__init__'](self) + + # Get a list of all the public and special methods we should mock. + self._known_methods = set() + self._known_vars = set() + self._class_to_mock = class_to_mock + for method in dir(class_to_mock): + if callable(getattr(class_to_mock, method)): + self._known_methods.add(method) + else: + self._known_vars.add(method) + + def __getattr__(self, name): + """Intercept attribute request on this object. + + If the attribute is a public class variable, it will be returned and not + recorded as a call. + + If the attribute is not a variable, it is handled like a method + call. The method name is checked against the set of mockable + methods, and a new MockMethod is returned that is aware of the + MockObject's state (record or replay). The call will be recorded + or replayed by the MockMethod's __call__. + + Args: + # name: the name of the attribute being requested. + name: str + + Returns: + Either a class variable or a new MockMethod that is aware of the state + of the mock (record or replay). + + Raises: + UnknownMethodCallError if the MockObject does not mock the requested + method. + """ + + if name in self._known_vars: + return getattr(self._class_to_mock, name) + + if name in self._known_methods: + return self._CreateMockMethod(name) + + raise UnknownMethodCallError(name) + + def __eq__(self, rhs): + """Provide custom logic to compare objects.""" + + return (isinstance(rhs, MockObject) and + self._class_to_mock == rhs._class_to_mock and + self._replay_mode == rhs._replay_mode and + self._expected_calls_queue == rhs._expected_calls_queue) + + def __setitem__(self, key, value): + """Provide custom logic for mocking classes that support item assignment. + + Args: + key: Key to set the value for. + value: Value to set. + + Returns: + Expected return value in replay mode. A MockMethod object for the + __setitem__ method that has already been called if not in replay mode. + + Raises: + TypeError if the underlying class does not support item assignment. + UnexpectedMethodCallError if the object does not expect the call to + __setitem__. + + """ + setitem = self._class_to_mock.__dict__.get('__setitem__', None) + + # Verify the class supports item assignment. + if setitem is None: + raise TypeError('object does not support item assignment') + + # If we are in replay mode then simply call the mock __setitem__ method. + if self._replay_mode: + return MockMethod('__setitem__', self._expected_calls_queue, + self._replay_mode)(key, value) + + + # Otherwise, create a mock method __setitem__. + return self._CreateMockMethod('__setitem__')(key, value) + + def __getitem__(self, key): + """Provide custom logic for mocking classes that are subscriptable. + + Args: + key: Key to return the value for. + + Returns: + Expected return value in replay mode. A MockMethod object for the + __getitem__ method that has already been called if not in replay mode. + + Raises: + TypeError if the underlying class is not subscriptable. + UnexpectedMethodCallError if the object does not expect the call to + __setitem__. + + """ + getitem = self._class_to_mock.__dict__.get('__getitem__', None) + + # Verify the class supports item assignment. + if getitem is None: + raise TypeError('unsubscriptable object') + + # If we are in replay mode then simply call the mock __getitem__ method. + if self._replay_mode: + return MockMethod('__getitem__', self._expected_calls_queue, + self._replay_mode)(key) + + + # Otherwise, create a mock method __getitem__. + return self._CreateMockMethod('__getitem__')(key) + + def __call__(self, *params, **named_params): + """Provide custom logic for mocking classes that are callable.""" + + # Verify the class we are mocking is callable + callable = self._class_to_mock.__dict__.get('__call__', None) + if callable is None: + raise TypeError('Not callable') + + # Because the call is happening directly on this object instead of a method, + # the call on the mock method is made right here + mock_method = self._CreateMockMethod('__call__') + return mock_method(*params, **named_params) + + @property + def __class__(self): + """Return the class that is being mocked.""" + + return self._class_to_mock + + +class MockMethod(object): + """Callable mock method. + + A MockMethod should act exactly like the method it mocks, accepting parameters + and returning a value, or throwing an exception (as specified). When this + method is called, it can optionally verify whether the called method (name and + signature) matches the expected method. + """ + + def __init__(self, method_name, call_queue, replay_mode): + """Construct a new mock method. + + Args: + # method_name: the name of the method + # call_queue: deque of calls, verify this call against the head, or add + # this call to the queue. + # replay_mode: False if we are recording, True if we are verifying calls + # against the call queue. + method_name: str + call_queue: list or deque + replay_mode: bool + """ + + self._name = method_name + self._call_queue = call_queue + if not isinstance(call_queue, deque): + self._call_queue = deque(self._call_queue) + self._replay_mode = replay_mode + + self._params = None + self._named_params = None + self._return_value = None + self._exception = None + self._side_effects = None + + def __call__(self, *params, **named_params): + """Log parameters and return the specified return value. + + If the Mock(Anything/Object) associated with this call is in record mode, + this MockMethod will be pushed onto the expected call queue. If the mock + is in replay mode, this will pop a MockMethod off the top of the queue and + verify this call is equal to the expected call. + + Raises: + UnexpectedMethodCall if this call is supposed to match an expected method + call and it does not. + """ + + self._params = params + self._named_params = named_params + + if not self._replay_mode: + self._call_queue.append(self) + return self + + expected_method = self._VerifyMethodCall() + + if expected_method._side_effects: + expected_method._side_effects(*params, **named_params) + + if expected_method._exception: + raise expected_method._exception + + return expected_method._return_value + + def __getattr__(self, name): + """Raise an AttributeError with a helpful message.""" + + raise AttributeError('MockMethod has no attribute "%s". ' + 'Did you remember to put your mocks in replay mode?' % name) + + def _PopNextMethod(self): + """Pop the next method from our call queue.""" + try: + return self._call_queue.popleft() + except IndexError: + raise UnexpectedMethodCallError(self, None) + + def _VerifyMethodCall(self): + """Verify the called method is expected. + + This can be an ordered method, or part of an unordered set. + + Returns: + The expected mock method. + + Raises: + UnexpectedMethodCall if the method called was not expected. + """ + + expected = self._PopNextMethod() + + # Loop here, because we might have a MethodGroup followed by another + # group. + while isinstance(expected, MethodGroup): + expected, method = expected.MethodCalled(self) + if method is not None: + return method + + # This is a mock method, so just check equality. + if expected != self: + raise UnexpectedMethodCallError(self, expected) + + return expected + + def __str__(self): + params = ', '.join( + [repr(p) for p in self._params or []] + + ['%s=%r' % x for x in sorted((self._named_params or {}).items())]) + desc = "%s(%s) -> %r" % (self._name, params, self._return_value) + return desc + + def __eq__(self, rhs): + """Test whether this MockMethod is equivalent to another MockMethod. + + Args: + # rhs: the right hand side of the test + rhs: MockMethod + """ + + return (isinstance(rhs, MockMethod) and + self._name == rhs._name and + self._params == rhs._params and + self._named_params == rhs._named_params) + + def __ne__(self, rhs): + """Test whether this MockMethod is not equivalent to another MockMethod. + + Args: + # rhs: the right hand side of the test + rhs: MockMethod + """ + + return not self == rhs + + def GetPossibleGroup(self): + """Returns a possible group from the end of the call queue or None if no + other methods are on the stack. + """ + + # Remove this method from the tail of the queue so we can add it to a group. + this_method = self._call_queue.pop() + assert this_method == self + + # Determine if the tail of the queue is a group, or just a regular ordered + # mock method. + group = None + try: + group = self._call_queue[-1] + except IndexError: + pass + + return group + + def _CheckAndCreateNewGroup(self, group_name, group_class): + """Checks if the last method (a possible group) is an instance of our + group_class. Adds the current method to this group or creates a new one. + + Args: + + group_name: the name of the group. + group_class: the class used to create instance of this new group + """ + group = self.GetPossibleGroup() + + # If this is a group, and it is the correct group, add the method. + if isinstance(group, group_class) and group.group_name() == group_name: + group.AddMethod(self) + return self + + # Create a new group and add the method. + new_group = group_class(group_name) + new_group.AddMethod(self) + self._call_queue.append(new_group) + return self + + def InAnyOrder(self, group_name="default"): + """Move this method into a group of unordered calls. + + A group of unordered calls must be defined together, and must be executed + in full before the next expected method can be called. There can be + multiple groups that are expected serially, if they are given + different group names. The same group name can be reused if there is a + standard method call, or a group with a different name, spliced between + usages. + + Args: + group_name: the name of the unordered group. + + Returns: + self + """ + return self._CheckAndCreateNewGroup(group_name, UnorderedGroup) + + def MultipleTimes(self, group_name="default"): + """Move this method into group of calls which may be called multiple times. + + A group of repeating calls must be defined together, and must be executed in + full before the next expected mehtod can be called. + + Args: + group_name: the name of the unordered group. + + Returns: + self + """ + return self._CheckAndCreateNewGroup(group_name, MultipleTimesGroup) + + def AndReturn(self, return_value): + """Set the value to return when this method is called. + + Args: + # return_value can be anything. + """ + + self._return_value = return_value + return return_value + + def AndRaise(self, exception): + """Set the exception to raise when this method is called. + + Args: + # exception: the exception to raise when this method is called. + exception: Exception + """ + + self._exception = exception + + def WithSideEffects(self, side_effects): + """Set the side effects that are simulated when this method is called. + + Args: + side_effects: A callable which modifies the parameters or other relevant + state which a given test case depends on. + + Returns: + Self for chaining with AndReturn and AndRaise. + """ + self._side_effects = side_effects + return self + +class Comparator: + """Base class for all Mox comparators. + + A Comparator can be used as a parameter to a mocked method when the exact + value is not known. For example, the code you are testing might build up a + long SQL string that is passed to your mock DAO. You're only interested that + the IN clause contains the proper primary keys, so you can set your mock + up as follows: + + mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) + + Now whatever query is passed in must contain the string 'IN (1, 2, 4, 5)'. + + A Comparator may replace one or more parameters, for example: + # return at most 10 rows + mock_dao.RunQuery(StrContains('SELECT'), 10) + + or + + # Return some non-deterministic number of rows + mock_dao.RunQuery(StrContains('SELECT'), IsA(int)) + """ + + def equals(self, rhs): + """Special equals method that all comparators must implement. + + Args: + rhs: any python object + """ + + raise NotImplementedError, 'method must be implemented by a subclass.' + + def __eq__(self, rhs): + return self.equals(rhs) + + def __ne__(self, rhs): + return not self.equals(rhs) + + +class IsA(Comparator): + """This class wraps a basic Python type or class. It is used to verify + that a parameter is of the given type or class. + + Example: + mock_dao.Connect(IsA(DbConnectInfo)) + """ + + def __init__(self, class_name): + """Initialize IsA + + Args: + class_name: basic python type or a class + """ + + self._class_name = class_name + + def equals(self, rhs): + """Check to see if the RHS is an instance of class_name. + + Args: + # rhs: the right hand side of the test + rhs: object + + Returns: + bool + """ + + try: + return isinstance(rhs, self._class_name) + except TypeError: + # Check raw types if there was a type error. This is helpful for + # things like cStringIO.StringIO. + return type(rhs) == type(self._class_name) + + def __repr__(self): + return str(self._class_name) + +class IsAlmost(Comparator): + """Comparison class used to check whether a parameter is nearly equal + to a given value. Generally useful for floating point numbers. + + Example mock_dao.SetTimeout((IsAlmost(3.9))) + """ + + def __init__(self, float_value, places=7): + """Initialize IsAlmost. + + Args: + float_value: The value for making the comparison. + places: The number of decimal places to round to. + """ + + self._float_value = float_value + self._places = places + + def equals(self, rhs): + """Check to see if RHS is almost equal to float_value + + Args: + rhs: the value to compare to float_value + + Returns: + bool + """ + + try: + return round(rhs-self._float_value, self._places) == 0 + except TypeError: + # This is probably because either float_value or rhs is not a number. + return False + + def __repr__(self): + return str(self._float_value) + +class StrContains(Comparator): + """Comparison class used to check whether a substring exists in a + string parameter. This can be useful in mocking a database with SQL + passed in as a string parameter, for example. + + Example: + mock_dao.RunQuery(StrContains('IN (1, 2, 4, 5)')).AndReturn(mock_result) + """ + + def __init__(self, search_string): + """Initialize. + + Args: + # search_string: the string you are searching for + search_string: str + """ + + self._search_string = search_string + + def equals(self, rhs): + """Check to see if the search_string is contained in the rhs string. + + Args: + # rhs: the right hand side of the test + rhs: object + + Returns: + bool + """ + + try: + return rhs.find(self._search_string) > -1 + except Exception: + return False + + def __repr__(self): + return '<str containing \'%s\'>' % self._search_string + + +class Regex(Comparator): + """Checks if a string matches a regular expression. + + This uses a given regular expression to determine equality. + """ + + def __init__(self, pattern, flags=0): + """Initialize. + + Args: + # pattern is the regular expression to search for + pattern: str + # flags passed to re.compile function as the second argument + flags: int + """ + + self.regex = re.compile(pattern, flags=flags) + + def equals(self, rhs): + """Check to see if rhs matches regular expression pattern. + + Returns: + bool + """ + + return self.regex.search(rhs) is not None + + def __repr__(self): + s = '<regular expression \'%s\'' % self.regex.pattern + if self.regex.flags: + s += ', flags=%d' % self.regex.flags + s += '>' + return s + + +class In(Comparator): + """Checks whether an item (or key) is in a list (or dict) parameter. + + Example: + mock_dao.GetUsersInfo(In('expectedUserName')).AndReturn(mock_result) + """ + + def __init__(self, key): + """Initialize. + + Args: + # key is any thing that could be in a list or a key in a dict + """ + + self._key = key + + def equals(self, rhs): + """Check to see whether key is in rhs. + + Args: + rhs: dict + + Returns: + bool + """ + + return self._key in rhs + + def __repr__(self): + return '<sequence or map containing \'%s\'>' % self._key + + +class ContainsKeyValue(Comparator): + """Checks whether a key/value pair is in a dict parameter. + + Example: + mock_dao.UpdateUsers(ContainsKeyValue('stevepm', stevepm_user_info)) + """ + + def __init__(self, key, value): + """Initialize. + + Args: + # key: a key in a dict + # value: the corresponding value + """ + + self._key = key + self._value = value + + def equals(self, rhs): + """Check whether the given key/value pair is in the rhs dict. + + Returns: + bool + """ + + try: + return rhs[self._key] == self._value + except Exception: + return False + + def __repr__(self): + return '<map containing the entry \'%s: %s\'>' % (self._key, self._value) + + +class SameElementsAs(Comparator): + """Checks whether iterables contain the same elements (ignoring order). + + Example: + mock_dao.ProcessUsers(SameElementsAs('stevepm', 'salomaki')) + """ + + def __init__(self, expected_seq): + """Initialize. + + Args: + expected_seq: a sequence + """ + + self._expected_seq = expected_seq + + def equals(self, actual_seq): + """Check to see whether actual_seq has same elements as expected_seq. + + Args: + actual_seq: sequence + + Returns: + bool + """ + + try: + expected = dict([(element, None) for element in self._expected_seq]) + actual = dict([(element, None) for element in actual_seq]) + except TypeError: + # Fall back to slower list-compare if any of the objects are unhashable. + expected = list(self._expected_seq) + actual = list(actual_seq) + expected.sort() + actual.sort() + return expected == actual + + def __repr__(self): + return '<sequence with same elements as \'%s\'>' % self._expected_seq + + +class And(Comparator): + """Evaluates one or more Comparators on RHS and returns an AND of the results. + """ + + def __init__(self, *args): + """Initialize. + + Args: + *args: One or more Comparator + """ + + self._comparators = args + + def equals(self, rhs): + """Checks whether all Comparators are equal to rhs. + + Args: + # rhs: can be anything + + Returns: + bool + """ + + for comparator in self._comparators: + if not comparator.equals(rhs): + return False + + return True + + def __repr__(self): + return '<AND %s>' % str(self._comparators) + + +class Or(Comparator): + """Evaluates one or more Comparators on RHS and returns an OR of the results. + """ + + def __init__(self, *args): + """Initialize. + + Args: + *args: One or more Mox comparators + """ + + self._comparators = args + + def equals(self, rhs): + """Checks whether any Comparator is equal to rhs. + + Args: + # rhs: can be anything + + Returns: + bool + """ + + for comparator in self._comparators: + if comparator.equals(rhs): + return True + + return False + + def __repr__(self): + return '<OR %s>' % str(self._comparators) + + +class Func(Comparator): + """Call a function that should verify the parameter passed in is correct. + + You may need the ability to perform more advanced operations on the parameter + in order to validate it. You can use this to have a callable validate any + parameter. The callable should return either True or False. + + + Example: + + def myParamValidator(param): + # Advanced logic here + return True + + mock_dao.DoSomething(Func(myParamValidator), true) + """ + + def __init__(self, func): + """Initialize. + + Args: + func: callable that takes one parameter and returns a bool + """ + + self._func = func + + def equals(self, rhs): + """Test whether rhs passes the function test. + + rhs is passed into func. + + Args: + rhs: any python object + + Returns: + the result of func(rhs) + """ + + return self._func(rhs) + + def __repr__(self): + return str(self._func) + + +class IgnoreArg(Comparator): + """Ignore an argument. + + This can be used when we don't care about an argument of a method call. + + Example: + # Check if CastMagic is called with 3 as first arg and 'disappear' as third. + mymock.CastMagic(3, IgnoreArg(), 'disappear') + """ + + def equals(self, unused_rhs): + """Ignores arguments and returns True. + + Args: + unused_rhs: any python object + + Returns: + always returns True + """ + + return True + + def __repr__(self): + return '<IgnoreArg>' + + +class MethodGroup(object): + """Base class containing common behaviour for MethodGroups.""" + + def __init__(self, group_name): + self._group_name = group_name + + def group_name(self): + return self._group_name + + def __str__(self): + return '<%s "%s">' % (self.__class__.__name__, self._group_name) + + def AddMethod(self, mock_method): + raise NotImplementedError + + def MethodCalled(self, mock_method): + raise NotImplementedError + + def IsSatisfied(self): + raise NotImplementedError + +class UnorderedGroup(MethodGroup): + """UnorderedGroup holds a set of method calls that may occur in any order. + + This construct is helpful for non-deterministic events, such as iterating + over the keys of a dict. + """ + + def __init__(self, group_name): + super(UnorderedGroup, self).__init__(group_name) + self._methods = [] + + def AddMethod(self, mock_method): + """Add a method to this group. + + Args: + mock_method: A mock method to be added to this group. + """ + + self._methods.append(mock_method) + + def MethodCalled(self, mock_method): + """Remove a method call from the group. + + If the method is not in the set, an UnexpectedMethodCallError will be + raised. + + Args: + mock_method: a mock method that should be equal to a method in the group. + + Returns: + The mock method from the group + + Raises: + UnexpectedMethodCallError if the mock_method was not in the group. + """ + + # Check to see if this method exists, and if so, remove it from the set + # and return it. + for method in self._methods: + if method == mock_method: + # Remove the called mock_method instead of the method in the group. + # The called method will match any comparators when equality is checked + # during removal. The method in the group could pass a comparator to + # another comparator during the equality check. + self._methods.remove(mock_method) + + # If this group is not empty, put it back at the head of the queue. + if not self.IsSatisfied(): + mock_method._call_queue.appendleft(self) + + return self, method + + raise UnexpectedMethodCallError(mock_method, self) + + def IsSatisfied(self): + """Return True if there are not any methods in this group.""" + + return len(self._methods) == 0 + + +class MultipleTimesGroup(MethodGroup): + """MultipleTimesGroup holds methods that may be called any number of times. + + Note: Each method must be called at least once. + + This is helpful, if you don't know or care how many times a method is called. + """ + + def __init__(self, group_name): + super(MultipleTimesGroup, self).__init__(group_name) + self._methods = set() + self._methods_called = set() + + def AddMethod(self, mock_method): + """Add a method to this group. + + Args: + mock_method: A mock method to be added to this group. + """ + + self._methods.add(mock_method) + + def MethodCalled(self, mock_method): + """Remove a method call from the group. + + If the method is not in the set, an UnexpectedMethodCallError will be + raised. + + Args: + mock_method: a mock method that should be equal to a method in the group. + + Returns: + The mock method from the group + + Raises: + UnexpectedMethodCallError if the mock_method was not in the group. + """ + + # Check to see if this method exists, and if so add it to the set of + # called methods. + + for method in self._methods: + if method == mock_method: + self._methods_called.add(mock_method) + # Always put this group back on top of the queue, because we don't know + # when we are done. + mock_method._call_queue.appendleft(self) + return self, method + + if self.IsSatisfied(): + next_method = mock_method._PopNextMethod(); + return next_method, None + else: + raise UnexpectedMethodCallError(mock_method, self) + + def IsSatisfied(self): + """Return True if all methods in this group are called at least once.""" + # NOTE(psycho): We can't use the simple set difference here because we want + # to match different parameters which are considered the same e.g. IsA(str) + # and some string. This solution is O(n^2) but n should be small. + tmp = self._methods.copy() + for called in self._methods_called: + for expected in tmp: + if called == expected: + tmp.remove(expected) + if not tmp: + return True + break + return False + + +class MoxMetaTestBase(type): + """Metaclass to add mox cleanup and verification to every test. + + As the mox unit testing class is being constructed (MoxTestBase or a + subclass), this metaclass will modify all test functions to call the + CleanUpMox method of the test class after they finish. This means that + unstubbing and verifying will happen for every test with no additional code, + and any failures will result in test failures as opposed to errors. + """ + + def __init__(cls, name, bases, d): + type.__init__(cls, name, bases, d) + + # also get all the attributes from the base classes to account + # for a case when test class is not the immediate child of MoxTestBase + for base in bases: + for attr_name in dir(base): + d[attr_name] = getattr(base, attr_name) + + for func_name, func in d.items(): + if func_name.startswith('test') and callable(func): + setattr(cls, func_name, MoxMetaTestBase.CleanUpTest(cls, func)) + + @staticmethod + def CleanUpTest(cls, func): + """Adds Mox cleanup code to any MoxTestBase method. + + Always unsets stubs after a test. Will verify all mocks for tests that + otherwise pass. + + Args: + cls: MoxTestBase or subclass; the class whose test method we are altering. + func: method; the method of the MoxTestBase test class we wish to alter. + + Returns: + The modified method. + """ + def new_method(self, *args, **kwargs): + mox_obj = getattr(self, 'mox', None) + cleanup_mox = False + if mox_obj and isinstance(mox_obj, Mox): + cleanup_mox = True + try: + func(self, *args, **kwargs) + finally: + if cleanup_mox: + mox_obj.UnsetStubs() + if cleanup_mox: + mox_obj.VerifyAll() + new_method.__name__ = func.__name__ + new_method.__doc__ = func.__doc__ + new_method.__module__ = func.__module__ + return new_method + + +class MoxTestBase(unittest.TestCase): + """Convenience test class to make stubbing easier. + + Sets up a "mox" attribute which is an instance of Mox - any mox tests will + want this. Also automatically unsets any stubs and verifies that all mock + methods have been called at the end of each test, eliminating boilerplate + code. + """ + + __metaclass__ = MoxMetaTestBase + + def setUp(self): + self.mox = Mox() diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/setup.py b/mp/src/thirdparty/protobuf-2.3.0/python/setup.py index 831c9cc4..7242dae2 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/setup.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/setup.py @@ -1,127 +1,127 @@ -#! /usr/bin/python
-#
-# See README for usage instructions.
-
-# We must use setuptools, not distutils, because we need to use the
-# namespace_packages option for the "google" package.
-from ez_setup import use_setuptools
-use_setuptools()
-
-from setuptools import setup
-from distutils.spawn import find_executable
-import sys
-import os
-import subprocess
-
-maintainer_email = "[email protected]"
-
-# Find the Protocol Compiler.
-if os.path.exists("../src/protoc"):
- protoc = "../src/protoc"
-elif os.path.exists("../src/protoc.exe"):
- protoc = "../src/protoc.exe"
-else:
- protoc = find_executable("protoc")
-
-def generate_proto(source):
- """Invokes the Protocol Compiler to generate a _pb2.py from the given
- .proto file. Does nothing if the output already exists and is newer than
- the input."""
-
- output = source.replace(".proto", "_pb2.py").replace("../src/", "")
-
- if not os.path.exists(source):
- print "Can't find required file: " + source
- sys.exit(-1)
-
- if (not os.path.exists(output) or
- (os.path.exists(source) and
- os.path.getmtime(source) > os.path.getmtime(output))):
- print "Generating %s..." % output
-
- if protoc == None:
- sys.stderr.write(
- "protoc is not installed nor found in ../src. Please compile it "
- "or install the binary package.\n")
- sys.exit(-1)
-
- protoc_command = [ protoc, "-I../src", "-I.", "--python_out=.", source ]
- if subprocess.call(protoc_command) != 0:
- sys.exit(-1)
-
-def MakeTestSuite():
- # This is apparently needed on some systems to make sure that the tests
- # work even if a previous version is already installed.
- if 'google' in sys.modules:
- del sys.modules['google']
-
- generate_proto("../src/google/protobuf/unittest.proto")
- generate_proto("../src/google/protobuf/unittest_import.proto")
- generate_proto("../src/google/protobuf/unittest_mset.proto")
- generate_proto("../src/google/protobuf/unittest_no_generic_services.proto")
- generate_proto("google/protobuf/internal/more_extensions.proto")
- generate_proto("google/protobuf/internal/more_messages.proto")
-
- import unittest
- import google.protobuf.internal.generator_test as generator_test
- import google.protobuf.internal.descriptor_test as descriptor_test
- import google.protobuf.internal.reflection_test as reflection_test
- import google.protobuf.internal.service_reflection_test \
- as service_reflection_test
- import google.protobuf.internal.text_format_test as text_format_test
- import google.protobuf.internal.wire_format_test as wire_format_test
-
- loader = unittest.defaultTestLoader
- suite = unittest.TestSuite()
- for test in [ generator_test,
- descriptor_test,
- reflection_test,
- service_reflection_test,
- text_format_test,
- wire_format_test ]:
- suite.addTest(loader.loadTestsFromModule(test))
-
- return suite
-
-if __name__ == '__main__':
- # TODO(kenton): Integrate this into setuptools somehow?
- if len(sys.argv) >= 2 and sys.argv[1] == "clean":
- # Delete generated _pb2.py files and .pyc files in the code tree.
- for (dirpath, dirnames, filenames) in os.walk("."):
- for filename in filenames:
- filepath = os.path.join(dirpath, filename)
- if filepath.endswith("_pb2.py") or filepath.endswith(".pyc"):
- os.remove(filepath)
- else:
- # Generate necessary .proto file if it doesn't exist.
- # TODO(kenton): Maybe we should hook this into a distutils command?
- generate_proto("../src/google/protobuf/descriptor.proto")
-
- setup(name = 'protobuf',
- version = '2.3.0',
- packages = [ 'google' ],
- namespace_packages = [ 'google' ],
- test_suite = 'setup.MakeTestSuite',
- # Must list modules explicitly so that we don't install tests.
- py_modules = [
- 'google.protobuf.internal.containers',
- 'google.protobuf.internal.decoder',
- 'google.protobuf.internal.encoder',
- 'google.protobuf.internal.message_listener',
- 'google.protobuf.internal.type_checkers',
- 'google.protobuf.internal.wire_format',
- 'google.protobuf.descriptor',
- 'google.protobuf.descriptor_pb2',
- 'google.protobuf.message',
- 'google.protobuf.reflection',
- 'google.protobuf.service',
- 'google.protobuf.service_reflection',
- 'google.protobuf.text_format' ],
- url = 'http://code.google.com/p/protobuf/',
- maintainer = maintainer_email,
- maintainer_email = '[email protected]',
- license = 'New BSD License',
- description = 'Protocol Buffers',
- long_description =
- "Protocol Buffers are Google's data interchange format.",
- )
+#! /usr/bin/python +# +# See README for usage instructions. + +# We must use setuptools, not distutils, because we need to use the +# namespace_packages option for the "google" package. +from ez_setup import use_setuptools +use_setuptools() + +from setuptools import setup +from distutils.spawn import find_executable +import sys +import os +import subprocess + +maintainer_email = "[email protected]" + +# Find the Protocol Compiler. +if os.path.exists("../src/protoc"): + protoc = "../src/protoc" +elif os.path.exists("../src/protoc.exe"): + protoc = "../src/protoc.exe" +else: + protoc = find_executable("protoc") + +def generate_proto(source): + """Invokes the Protocol Compiler to generate a _pb2.py from the given + .proto file. Does nothing if the output already exists and is newer than + the input.""" + + output = source.replace(".proto", "_pb2.py").replace("../src/", "") + + if not os.path.exists(source): + print "Can't find required file: " + source + sys.exit(-1) + + if (not os.path.exists(output) or + (os.path.exists(source) and + os.path.getmtime(source) > os.path.getmtime(output))): + print "Generating %s..." % output + + if protoc == None: + sys.stderr.write( + "protoc is not installed nor found in ../src. Please compile it " + "or install the binary package.\n") + sys.exit(-1) + + protoc_command = [ protoc, "-I../src", "-I.", "--python_out=.", source ] + if subprocess.call(protoc_command) != 0: + sys.exit(-1) + +def MakeTestSuite(): + # This is apparently needed on some systems to make sure that the tests + # work even if a previous version is already installed. + if 'google' in sys.modules: + del sys.modules['google'] + + generate_proto("../src/google/protobuf/unittest.proto") + generate_proto("../src/google/protobuf/unittest_import.proto") + generate_proto("../src/google/protobuf/unittest_mset.proto") + generate_proto("../src/google/protobuf/unittest_no_generic_services.proto") + generate_proto("google/protobuf/internal/more_extensions.proto") + generate_proto("google/protobuf/internal/more_messages.proto") + + import unittest + import google.protobuf.internal.generator_test as generator_test + import google.protobuf.internal.descriptor_test as descriptor_test + import google.protobuf.internal.reflection_test as reflection_test + import google.protobuf.internal.service_reflection_test \ + as service_reflection_test + import google.protobuf.internal.text_format_test as text_format_test + import google.protobuf.internal.wire_format_test as wire_format_test + + loader = unittest.defaultTestLoader + suite = unittest.TestSuite() + for test in [ generator_test, + descriptor_test, + reflection_test, + service_reflection_test, + text_format_test, + wire_format_test ]: + suite.addTest(loader.loadTestsFromModule(test)) + + return suite + +if __name__ == '__main__': + # TODO(kenton): Integrate this into setuptools somehow? + if len(sys.argv) >= 2 and sys.argv[1] == "clean": + # Delete generated _pb2.py files and .pyc files in the code tree. + for (dirpath, dirnames, filenames) in os.walk("."): + for filename in filenames: + filepath = os.path.join(dirpath, filename) + if filepath.endswith("_pb2.py") or filepath.endswith(".pyc"): + os.remove(filepath) + else: + # Generate necessary .proto file if it doesn't exist. + # TODO(kenton): Maybe we should hook this into a distutils command? + generate_proto("../src/google/protobuf/descriptor.proto") + + setup(name = 'protobuf', + version = '2.3.0', + packages = [ 'google' ], + namespace_packages = [ 'google' ], + test_suite = 'setup.MakeTestSuite', + # Must list modules explicitly so that we don't install tests. + py_modules = [ + 'google.protobuf.internal.containers', + 'google.protobuf.internal.decoder', + 'google.protobuf.internal.encoder', + 'google.protobuf.internal.message_listener', + 'google.protobuf.internal.type_checkers', + 'google.protobuf.internal.wire_format', + 'google.protobuf.descriptor', + 'google.protobuf.descriptor_pb2', + 'google.protobuf.message', + 'google.protobuf.reflection', + 'google.protobuf.service', + 'google.protobuf.service_reflection', + 'google.protobuf.text_format' ], + url = 'http://code.google.com/p/protobuf/', + maintainer = maintainer_email, + maintainer_email = '[email protected]', + license = 'New BSD License', + description = 'Protocol Buffers', + long_description = + "Protocol Buffers are Google's data interchange format.", + ) diff --git a/mp/src/thirdparty/protobuf-2.3.0/python/stubout.py b/mp/src/thirdparty/protobuf-2.3.0/python/stubout.py index 44dd2363..aee4f2da 100644 --- a/mp/src/thirdparty/protobuf-2.3.0/python/stubout.py +++ b/mp/src/thirdparty/protobuf-2.3.0/python/stubout.py @@ -1,140 +1,140 @@ -#!/usr/bin/python2.4
-#
-# Copyright 2008 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# This file is used for testing. The original is at:
-# http://code.google.com/p/pymox/
-
-class StubOutForTesting:
- """Sample Usage:
- You want os.path.exists() to always return true during testing.
-
- stubs = StubOutForTesting()
- stubs.Set(os.path, 'exists', lambda x: 1)
- ...
- stubs.UnsetAll()
-
- The above changes os.path.exists into a lambda that returns 1. Once
- the ... part of the code finishes, the UnsetAll() looks up the old value
- of os.path.exists and restores it.
-
- """
- def __init__(self):
- self.cache = []
- self.stubs = []
-
- def __del__(self):
- self.SmartUnsetAll()
- self.UnsetAll()
-
- def SmartSet(self, obj, attr_name, new_attr):
- """Replace obj.attr_name with new_attr. This method is smart and works
- at the module, class, and instance level while preserving proper
- inheritance. It will not stub out C types however unless that has been
- explicitly allowed by the type.
-
- This method supports the case where attr_name is a staticmethod or a
- classmethod of obj.
-
- Notes:
- - If obj is an instance, then it is its class that will actually be
- stubbed. Note that the method Set() does not do that: if obj is
- an instance, it (and not its class) will be stubbed.
- - The stubbing is using the builtin getattr and setattr. So, the __get__
- and __set__ will be called when stubbing (TODO: A better idea would
- probably be to manipulate obj.__dict__ instead of getattr() and
- setattr()).
-
- Raises AttributeError if the attribute cannot be found.
- """
- if (inspect.ismodule(obj) or
- (not inspect.isclass(obj) and obj.__dict__.has_key(attr_name))):
- orig_obj = obj
- orig_attr = getattr(obj, attr_name)
-
- else:
- if not inspect.isclass(obj):
- mro = list(inspect.getmro(obj.__class__))
- else:
- mro = list(inspect.getmro(obj))
-
- mro.reverse()
-
- orig_attr = None
-
- for cls in mro:
- try:
- orig_obj = cls
- orig_attr = getattr(obj, attr_name)
- except AttributeError:
- continue
-
- if orig_attr is None:
- raise AttributeError("Attribute not found.")
-
- # Calling getattr() on a staticmethod transforms it to a 'normal' function.
- # We need to ensure that we put it back as a staticmethod.
- old_attribute = obj.__dict__.get(attr_name)
- if old_attribute is not None and isinstance(old_attribute, staticmethod):
- orig_attr = staticmethod(orig_attr)
-
- self.stubs.append((orig_obj, attr_name, orig_attr))
- setattr(orig_obj, attr_name, new_attr)
-
- def SmartUnsetAll(self):
- """Reverses all the SmartSet() calls, restoring things to their original
- definition. Its okay to call SmartUnsetAll() repeatedly, as later calls
- have no effect if no SmartSet() calls have been made.
-
- """
- self.stubs.reverse()
-
- for args in self.stubs:
- setattr(*args)
-
- self.stubs = []
-
- def Set(self, parent, child_name, new_child):
- """Replace child_name's old definition with new_child, in the context
- of the given parent. The parent could be a module when the child is a
- function at module scope. Or the parent could be a class when a class'
- method is being replaced. The named child is set to new_child, while
- the prior definition is saved away for later, when UnsetAll() is called.
-
- This method supports the case where child_name is a staticmethod or a
- classmethod of parent.
- """
- old_child = getattr(parent, child_name)
-
- old_attribute = parent.__dict__.get(child_name)
- if old_attribute is not None and isinstance(old_attribute, staticmethod):
- old_child = staticmethod(old_child)
-
- self.cache.append((parent, old_child, child_name))
- setattr(parent, child_name, new_child)
-
- def UnsetAll(self):
- """Reverses all the Set() calls, restoring things to their original
- definition. Its okay to call UnsetAll() repeatedly, as later calls have
- no effect if no Set() calls have been made.
-
- """
- # Undo calls to Set() in reverse order, in case Set() was called on the
- # same arguments repeatedly (want the original call to be last one undone)
- self.cache.reverse()
-
- for (parent, old_child, child_name) in self.cache:
- setattr(parent, child_name, old_child)
- self.cache = []
+#!/usr/bin/python2.4 +# +# Copyright 2008 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This file is used for testing. The original is at: +# http://code.google.com/p/pymox/ + +class StubOutForTesting: + """Sample Usage: + You want os.path.exists() to always return true during testing. + + stubs = StubOutForTesting() + stubs.Set(os.path, 'exists', lambda x: 1) + ... + stubs.UnsetAll() + + The above changes os.path.exists into a lambda that returns 1. Once + the ... part of the code finishes, the UnsetAll() looks up the old value + of os.path.exists and restores it. + + """ + def __init__(self): + self.cache = [] + self.stubs = [] + + def __del__(self): + self.SmartUnsetAll() + self.UnsetAll() + + def SmartSet(self, obj, attr_name, new_attr): + """Replace obj.attr_name with new_attr. This method is smart and works + at the module, class, and instance level while preserving proper + inheritance. It will not stub out C types however unless that has been + explicitly allowed by the type. + + This method supports the case where attr_name is a staticmethod or a + classmethod of obj. + + Notes: + - If obj is an instance, then it is its class that will actually be + stubbed. Note that the method Set() does not do that: if obj is + an instance, it (and not its class) will be stubbed. + - The stubbing is using the builtin getattr and setattr. So, the __get__ + and __set__ will be called when stubbing (TODO: A better idea would + probably be to manipulate obj.__dict__ instead of getattr() and + setattr()). + + Raises AttributeError if the attribute cannot be found. + """ + if (inspect.ismodule(obj) or + (not inspect.isclass(obj) and obj.__dict__.has_key(attr_name))): + orig_obj = obj + orig_attr = getattr(obj, attr_name) + + else: + if not inspect.isclass(obj): + mro = list(inspect.getmro(obj.__class__)) + else: + mro = list(inspect.getmro(obj)) + + mro.reverse() + + orig_attr = None + + for cls in mro: + try: + orig_obj = cls + orig_attr = getattr(obj, attr_name) + except AttributeError: + continue + + if orig_attr is None: + raise AttributeError("Attribute not found.") + + # Calling getattr() on a staticmethod transforms it to a 'normal' function. + # We need to ensure that we put it back as a staticmethod. + old_attribute = obj.__dict__.get(attr_name) + if old_attribute is not None and isinstance(old_attribute, staticmethod): + orig_attr = staticmethod(orig_attr) + + self.stubs.append((orig_obj, attr_name, orig_attr)) + setattr(orig_obj, attr_name, new_attr) + + def SmartUnsetAll(self): + """Reverses all the SmartSet() calls, restoring things to their original + definition. Its okay to call SmartUnsetAll() repeatedly, as later calls + have no effect if no SmartSet() calls have been made. + + """ + self.stubs.reverse() + + for args in self.stubs: + setattr(*args) + + self.stubs = [] + + def Set(self, parent, child_name, new_child): + """Replace child_name's old definition with new_child, in the context + of the given parent. The parent could be a module when the child is a + function at module scope. Or the parent could be a class when a class' + method is being replaced. The named child is set to new_child, while + the prior definition is saved away for later, when UnsetAll() is called. + + This method supports the case where child_name is a staticmethod or a + classmethod of parent. + """ + old_child = getattr(parent, child_name) + + old_attribute = parent.__dict__.get(child_name) + if old_attribute is not None and isinstance(old_attribute, staticmethod): + old_child = staticmethod(old_child) + + self.cache.append((parent, old_child, child_name)) + setattr(parent, child_name, new_child) + + def UnsetAll(self): + """Reverses all the Set() calls, restoring things to their original + definition. Its okay to call UnsetAll() repeatedly, as later calls have + no effect if no Set() calls have been made. + + """ + # Undo calls to Set() in reverse order, in case Set() was called on the + # same arguments repeatedly (want the original call to be last one undone) + self.cache.reverse() + + for (parent, old_child, child_name) in self.cache: + setattr(parent, child_name, old_child) + self.cache = [] |