2
0

__init__.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. # Copyright (c) 2012-2013 LiuYC https://github.com/liuyichen/
  2. # Copyright 2012-2014 ksyun.com, Inc. or its affiliates. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License"). You
  5. # may not use this file except in compliance with the License. A copy of
  6. # the License is located at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # or in the "license" file accompanying this file. This file is
  11. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  12. # ANY KIND, either express or implied. See the License for the specific
  13. # language governing permissions and limitations under the License.
  14. import os
  15. import sys
  16. import mock
  17. import time
  18. import random
  19. import shutil
  20. import contextlib
  21. import tempfile
  22. import binascii
  23. import platform
  24. import select
  25. import datetime
  26. from subprocess import Popen, PIPE
  27. from dateutil.tz import tzlocal
  28. # The unittest module got a significant overhaul
  29. # in 2.7, so if we're in 2.6 we can use the backported
  30. # version unittest2.
  31. if sys.version_info[:2] == (2, 6):
  32. import unittest2 as unittest
  33. else:
  34. import unittest
  35. import kscore.loaders
  36. import kscore.session
  37. from kscore import utils
  38. from kscore import credentials
  # Single loader instance registered by create_session() on every session it
  # builds, so that the same models are reused across tests.
  39. _LOADER = kscore.loaders.Loader()
  def skip_unless_has_memory_collection(cls):
      """Skip a test class on platforms without memory collection support.

      Decorate any test class that samples process memory (for example the
      resource-leak tests) with this. On macOS ("Darwin") and Linux the class
      is returned unchanged; everywhere else it is wrapped with
      ``unittest.skip`` so its tests are reported as skipped.
      """
      supported_systems = ('Darwin', 'Linux')
      if platform.system() in supported_systems:
          return cls
      skipper = unittest.skip('Memory tests only supported on mac/linux.')
      return skipper(cls)
  def random_chars(num_chars):
      """Return ``num_chars`` random lowercase hexadecimal characters.

      Useful for creating resources with random names.

      :param num_chars: number of hex characters wanted (non-negative).
      :returns: a ``str`` of exactly ``num_chars`` characters from
          ``0-9a-f``.
      """
      count = int(num_chars)
      # Each random byte hexlifies to two characters. Round the byte count up
      # so odd counts are covered, then trim the surplus character. The old
      # ``int(num_chars / 2)`` returned one character short for odd values.
      num_bytes = (count + 1) // 2
      return binascii.hexlify(os.urandom(num_bytes)).decode('ascii')[:count]
  def create_session(**kwargs):
      """Build a ``kscore.session.Session`` configured for tests.

      All keyword arguments are forwarded to the ``Session`` constructor.
      The module-level ``_LOADER`` is registered as the ``data_loader``
      component so loaded models are shared across tests. The
      ``credentials_file`` config variable is pointed at
      ``noexist/foo/kscore``, presumably a path that does not exist, so no
      real credentials file is read.
      """
      new_session = kscore.session.Session(**kwargs)
      new_session.register_component('data_loader', _LOADER)
      new_session.set_config_variable('credentials_file', 'noexist/foo/kscore')
      return new_session
  @contextlib.contextmanager
  def temporary_file(mode):
      """Yield an open temporary file that can be reopened by name.

      This is a cross-platform alternative to ``tempfile.NamedTemporaryFile``.
      On Windows that creates a secure file which other processes cannot read
      and which cannot be opened a second time. Tests usually *want* to open
      it again: the fixture writes the contents and the code under test reads
      them back.

      An empty file is created inside a fresh ``tempfile.mkdtemp()``
      directory, then reopened with ``mode`` and yielded. The whole directory
      is removed when the context exits, including when an error occurs.

      :param mode: mode string passed to ``open()`` for the yielded handle.
      """
      temporary_directory = tempfile.mkdtemp()
      try:
          basename = 'tmpfile-%s-%s' % (int(time.time()),
                                        random.randint(1, 1000))
          full_filename = os.path.join(temporary_directory, basename)
          # Create the file first so read-only modes such as 'r' also work.
          # This happens inside the try so the directory is cleaned up even
          # if creating the file fails.
          with open(full_filename, 'w'):
              pass
          with open(full_filename, mode) as f:
              yield f
      finally:
          shutil.rmtree(temporary_directory)
  81. class BaseEnvVar(unittest.TestCase):
  82. def setUp(self):
  83. # Automatically patches out os.environ for you
  84. # and gives you a self.environ attribute that simulates
  85. # the environment. Also will automatically restore state
  86. # for you in tearDown()
  87. self.environ = {}
  88. self.environ_patch = mock.patch('os.environ', self.environ)
  89. self.environ_patch.start()
  90. def tearDown(self):
  91. self.environ_patch.stop()
  class BaseSessionTest(BaseEnvVar):
      """Base class for tests that need a real but fully isolated Session.

      Use it when a test wants a genuine session object yet must not be
      influenced by the external environment (including environment
      variables, which ``BaseEnvVar`` replaces with ``self.environ``). Fake
      credential variables are placed in that environment so tests can make
      fake requests to services.
      """

      def setUp(self, **environ):
          super(BaseSessionTest, self).setUp()
          # Built-in fake values first; keyword arguments override them.
          fake_env = {
              'AWS_ACCESS_KEY_ID': 'access_key',
              'AWS_SECRET_ACCESS_KEY': 'secret_key',
              'AWS_CONFIG_FILE': 'no-exist-foo',
          }
          fake_env.update(environ)
          self.environ.update(fake_env)
          self.session = create_session()
          self.session.config_filename = 'no-exist-foo'
  @skip_unless_has_memory_collection
  class BaseClientDriverTest(unittest.TestCase):
      """Base TestCase that runs commands in a separate ``ClientDriver`` process.

      The class is skipped unless the platform is macOS or Linux. Each test
      gets a fresh driver, started in ``setUp`` and stopped in ``tearDown``.
      Set ``INJECT_DUMMY_CREDS = True`` in a subclass to start the child with
      an environment holding only two dummy credential variables, instead of
      inheriting this process's environment.
      """
      INJECT_DUMMY_CREDS = False

      def setUp(self):
          self.driver = ClientDriver()
          if self.INJECT_DUMMY_CREDS:
              child_env = dict(AWS_ACCESS_KEY_ID='foo',
                               AWS_SECRET_ACCESS_KEY='bar')
          else:
              child_env = None
          self.driver.start(env=child_env)

      def tearDown(self):
          self.driver.stop()

      # Thin pass-throughs so subclasses can drive the child via ``self``.
      def cmd(self, *args):
          self.driver.cmd(*args)

      def send_cmd(self, *args):
          self.driver.send_cmd(*args)

      def record_memory(self):
          self.driver.record_memory()

      @property
      def memory_samples(self):
          """Memory samples collected so far by ``record_memory``."""
          return self.driver.memory_samples
  class ClientDriver(object):
      """Drive a ``cmd-runner`` child process over its stdin/stdout pipes.

      Each command is written to the child's stdin as one space-joined,
      newline-terminated line. The child must answer each command with a
      line reading ``OK``. The driver can also sample the child's resident
      set size (via ``ps``) into ``memory_samples``.
      """
      # Script run by start(); expected to sit next to this module.
      CLIENT_SERVER = os.path.join(
          os.path.dirname(os.path.abspath(__file__)),
          'cmd-runner'
      )

      def __init__(self):
          self._popen = None
          # RSS samples in bytes, appended by record_memory().
          self.memory_samples = []

      def _get_memory_with_ps(self, pid):
          """Return the resident set size of process ``pid`` in bytes.

          :raises RuntimeError: if ``ps`` exits with a non-zero status.
          """
          # It would be better to eventually switch to psutil,
          # which should allow us to test on windows, but for now
          # we'll just use ps and run on POSIX platforms.
          command_list = ['ps', '-p', str(pid), '-o', 'rss']
          p = Popen(command_list, stdout=PIPE)
          stdout = p.communicate()[0]
          if p.returncode != 0:
              raise RuntimeError("Could not retrieve memory")
          # Get the RSS from output that looks like this:
          #   RSS
          #  4496
          # ps reports RSS in kilobytes; * 1024 converts it to bytes.
          return int(stdout.splitlines()[1].split()[0]) * 1024

      def record_memory(self):
          """Append the child's current RSS (bytes) to ``memory_samples``."""
          mem = self._get_memory_with_ps(self._popen.pid)
          self.memory_samples.append(mem)

      def start(self, env=None):
          """Start up the command runner process.

          :param env: environment for the child; ``None`` inherits ours.
          """
          self._popen = Popen([sys.executable, self.CLIENT_SERVER],
                              stdout=PIPE, stdin=PIPE, env=env)

      def stop(self):
          """Shut down the command runner process and release its pipes.

          Sends ``exit`` and waits for the child to terminate. The stdin and
          stdout pipes opened by ``start()`` are always closed, even if the
          exit handshake fails, because ``Popen.wait()`` does not close them.
          """
          popen = self._popen
          try:
              self.cmd('exit')
              popen.wait()
          finally:
              if popen is not None:
                  for stream in (popen.stdin, popen.stdout):
                      if stream is not None:
                          stream.close()

      def send_cmd(self, *cmd):
          """Send a command and return immediately.

          This is a lower level method than cmd().
          This method will instruct the cmd-runner process
          to execute a command, but this method will
          immediately return. You will need to use
          ``is_cmd_finished()`` to check that the command
          is finished.

          This method is useful if you want to record attributes
          about the process while an operation is occurring. For
          example, if you want to instruct the cmd-runner process
          to upload a 1GB file to S3 and you'd like to record
          the memory during the upload process, you can use
          send_cmd() instead of cmd().
          """
          cmd_str = ' '.join(cmd) + '\n'
          cmd_bytes = cmd_str.encode('utf-8')
          self._popen.stdin.write(cmd_bytes)
          self._popen.stdin.flush()

      def is_cmd_finished(self):
          """Return True if the child has output ready to read.

          Polls the child's stdout with a 10 ms ``select`` timeout. The
          response line itself is not consumed.
          """
          readable, _, _ = select.select(
              [self._popen.stdout.fileno()], [], [], 0.01)
          return bool(readable)

      def cmd(self, *cmd):
          """Send a command and block until it finishes.

          This method will send a command to the cmd-runner process
          to run. It will block until the cmd-runner process is
          finished executing the command and sends back a status
          response.

          :raises RuntimeError: if the response line is not ``OK``.
          """
          self.send_cmd(*cmd)
          result = self._popen.stdout.readline().strip()
          if result != b'OK':
              raise RuntimeError(
                  "Error from command '%s': %s" % (cmd, result))
  199. # This is added to this file because it's used in both
  200. # the functional and unit tests for cred refresh.
  class IntegerRefresher(credentials.RefreshableCredentials):
      """Refreshable credentials to help with testing.

      This class makes testing refreshable credentials easier.
      It has the following functionality:

      * A counter, ``self.refresh_counter``, to indicate how many
        times refresh was called.
      * A way to specify how many seconds to make credentials
        valid.
      * Configurable advisory/mandatory refresh.
      * An easy way to check consistency. Each time creds are
        refreshed, all the cred values are set to the next
        incrementing integer. Frozen credentials should always
        have this value.
      """
      # Defaults (seconds) for the matching __init__ arguments.
      _advisory_refresh_timeout = 2
      _mandatory_refresh_timeout = 1
      _credentials_expire = 3

      def __init__(self, creds_last_for=_credentials_expire,
                   advisory_refresh=_advisory_refresh_timeout,
                   mandatory_refresh=_mandatory_refresh_timeout,
                   refresh_function=None):
          """
          :param creds_last_for: seconds until the initial (and every
              refreshed) set of credentials expires.
          :param advisory_refresh: stored as ``_advisory_refresh_timeout``.
          :param mandatory_refresh: stored as ``_mandatory_refresh_timeout``.
          :param refresh_function: refresh callback handed to the parent
              class; defaults to ``self._do_refresh``.
          """
          expires_in = (
              self._current_datetime() +
              datetime.timedelta(seconds=creds_last_for))
          if refresh_function is None:
              refresh_function = self._do_refresh
          # Initial access key, secret key and token are all '0'.
          super(IntegerRefresher, self).__init__(
              '0', '0', '0', expires_in,
              refresh_function, 'INTREFRESH')
          self.creds_last_for = creds_last_for
          self.refresh_counter = 0
          self._advisory_refresh_timeout = advisory_refresh
          self._mandatory_refresh_timeout = mandatory_refresh

      def _do_refresh(self):
          """Return the next credential set, each value incremented by one.

          Bumps ``refresh_counter`` and returns a dict whose ``access_key``,
          ``secret_key`` and ``token`` all equal the current access key plus
          one. Its ``expiry_time`` is a UTC timestamp at least
          ``creds_last_for`` seconds from now.
          """
          self.refresh_counter += 1
          current = int(self._access_key)
          next_id = str(current + 1)
          return {
              'access_key': next_id,
              'secret_key': next_id,
              'token': next_id,
              'expiry_time': self._seconds_later(self.creds_last_for),
          }

      def _seconds_later(self, num_seconds):
          """Return the timestamp string for ``num_seconds`` (+1) from now."""
          # We need to guarantee at *least* num_seconds.
          # Because this doesn't handle subsecond precision
          # we'll round up to the next second.
          num_seconds += 1
          t = self._current_datetime() + datetime.timedelta(seconds=num_seconds)
          return self._to_timestamp(t)

      def _to_timestamp(self, datetime_obj):
          """Format ``datetime_obj`` as ``YYYY-MM-DDTHH:MM:SSZ`` in UTC."""
          # Local import: dateutil is already a dependency of this module
          # (tzlocal), and this keeps the file-level imports unchanged.
          from dateutil.tz import tzutc
          # Assumes parse_to_aware_datetime returns a tz-aware datetime, as its
          # name suggests; astimezone needs one. TODO confirm.
          obj = utils.parse_to_aware_datetime(datetime_obj)
          # The literal 'Z' suffix declares UTC. Values from _current_datetime()
          # are in the local zone, so convert first. Otherwise the expiry is
          # shifted by the local UTC offset on non-UTC machines.
          return obj.astimezone(tzutc()).strftime('%Y-%m-%dT%H:%M:%SZ')

      def _current_timestamp(self):
          """Return the current time formatted by ``_to_timestamp``."""
          return self._to_timestamp(self._current_datetime())

      def _current_datetime(self):
          """Return the current time as an aware datetime in the local zone."""
          return datetime.datetime.now(tzlocal())