# utils.py
  1. # Copyright 2012-2014 ksyun.com, Inc. or its affiliates. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License"). You
  4. # may not use this file except in compliance with the License. A copy of
  5. # the License is located at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # or in the "license" file accompanying this file. This file is
  10. # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
  11. # ANY KIND, either express or implied. See the License for the specific
  12. # language governing permissions and limitations under the License.
  13. import re
  14. import logging
  15. import datetime
  16. import hashlib
  17. import binascii
  18. import functools
  19. from six import string_types, text_type
  20. import dateutil.parser
  21. from dateutil.tz import tzlocal, tzutc
  22. from kscore.exceptions import InvalidExpressionError, ConfigNotFound
  23. from kscore.exceptions import InvalidDNSNameError
  24. from kscore.compat import json, quote, zip_longest, urlsplit, urlunsplit
  25. from kscore.vendored import requests
  26. from kscore.compat import OrderedDict
  27. logger = logging.getLogger(__name__)
  28. DEFAULT_METADATA_SERVICE_TIMEOUT = 1
  29. METADATA_SECURITY_CREDENTIALS_URL = (
  30. 'http://iam.api.ksyun.com/latest/meta-data/iam/security-credentials/'
  31. )
  32. # These are chars that do not need to be urlencoded.
  33. # Based on rfc2986, section 2.3
  34. SAFE_CHARS = '-._~'
  35. LABEL_RE = re.compile('[a-z0-9][a-z0-9\-]*[a-z0-9]')
  36. RESTRICTED_REGIONS = [
  37. 'us-gov-west-1',
  38. 'fips-us-gov-west-1',
  39. ]
  40. S3_ACCELERATE_ENDPOINT = 's3.ksyun.com'
  41. class _RetriesExceededError(Exception):
  42. """Internal exception used when the number of retries are exceeded."""
  43. pass
  44. def get_service_module_name(service_model):
  45. """Returns the module name for a service
  46. This is the value used in both the documentation and client class name
  47. """
  48. name = service_model.metadata.get(
  49. 'serviceAbbreviation',
  50. service_model.metadata.get(
  51. 'serviceFullName', service_model.service_name))
  52. name = name.replace('AWS', '')
  53. name = re.sub('\W+', '', name)
  54. return name
  55. def normalize_url_path(path):
  56. if not path:
  57. return '/'
  58. return remove_dot_segments(path)
  59. def remove_dot_segments(url):
  60. # RFC 3986, section 5.2.4 "Remove Dot Segments"
  61. # Also, KSYUN services require consecutive slashes to be removed,
  62. # so that's done here as well
  63. if not url:
  64. return ''
  65. input_url = url.split('/')
  66. output_list = []
  67. for x in input_url:
  68. if x and x != '.':
  69. if x == '..':
  70. if output_list:
  71. output_list.pop()
  72. else:
  73. output_list.append(x)
  74. if url[0] == '/':
  75. first = '/'
  76. else:
  77. first = ''
  78. if url[-1] == '/' and output_list:
  79. last = '/'
  80. else:
  81. last = ''
  82. return first + '/'.join(output_list) + last
  83. def validate_jmespath_for_set(expression):
  84. # Validates a limited jmespath expression to determine if we can set a
  85. # value based on it. Only works with dotted paths.
  86. if not expression or expression == '.':
  87. raise InvalidExpressionError(expression=expression)
  88. for invalid in ['[', ']', '*']:
  89. if invalid in expression:
  90. raise InvalidExpressionError(expression=expression)
  91. def set_value_from_jmespath(source, expression, value, is_first=True):
  92. # This takes a (limited) jmespath-like expression & can set a value based
  93. # on it.
  94. # Limitations:
  95. # * Only handles dotted lookups
  96. # * No offsets/wildcards/slices/etc.
  97. if is_first:
  98. validate_jmespath_for_set(expression)
  99. bits = expression.split('.', 1)
  100. current_key, remainder = bits[0], bits[1] if len(bits) > 1 else ''
  101. if not current_key:
  102. raise InvalidExpressionError(expression=expression)
  103. if remainder:
  104. if current_key not in source:
  105. # We've got something in the expression that's not present in the
  106. # source (new key). If there's any more bits, we'll set the key
  107. # with an empty dictionary.
  108. source[current_key] = {}
  109. return set_value_from_jmespath(
  110. source[current_key],
  111. remainder,
  112. value,
  113. is_first=False
  114. )
  115. # If we're down to a single key, set it.
  116. source[current_key] = value
  117. class InstanceMetadataFetcher(object):
  118. def __init__(self, timeout=DEFAULT_METADATA_SERVICE_TIMEOUT,
  119. num_attempts=1, url=METADATA_SECURITY_CREDENTIALS_URL):
  120. self._timeout = timeout
  121. self._num_attempts = num_attempts
  122. self._url = url
  123. def _get_request(self, url, timeout, num_attempts=1):
  124. for i in range(num_attempts):
  125. try:
  126. response = requests.get(url, timeout=timeout)
  127. except (requests.Timeout, requests.ConnectionError) as e:
  128. logger.debug("Caught exception while trying to retrieve "
  129. "credentials: %s", e, exc_info=True)
  130. else:
  131. if response.status_code == 200:
  132. return response
  133. raise _RetriesExceededError()
  134. def retrieve_iam_role_credentials(self):
  135. data = {}
  136. url = self._url
  137. timeout = self._timeout
  138. num_attempts = self._num_attempts
  139. try:
  140. r = self._get_request(url, timeout, num_attempts)
  141. if r.content:
  142. fields = r.content.decode('utf-8').split('\n')
  143. for field in fields:
  144. if field.endswith('/'):
  145. data[field[0:-1]] = self.retrieve_iam_role_credentials(
  146. url + field, timeout, num_attempts)
  147. else:
  148. val = self._get_request(
  149. url + field,
  150. timeout=timeout,
  151. num_attempts=num_attempts).content.decode('utf-8')
  152. if val[0] == '{':
  153. val = json.loads(val)
  154. data[field] = val
  155. else:
  156. logger.debug("Metadata service returned non 200 status code "
  157. "of %s for url: %s, content body: %s",
  158. r.status_code, url, r.content)
  159. except _RetriesExceededError:
  160. logger.debug("Max number of attempts exceeded (%s) when "
  161. "attempting to retrieve data from metadata service.",
  162. num_attempts)
  163. # We sort for stable ordering. In practice, this should only consist
  164. # of one role, but may need revisiting if this expands in the future.
  165. final_data = {}
  166. for role_name in sorted(data):
  167. final_data = {
  168. 'role_name': role_name,
  169. 'access_key': data[role_name]['AccessKeyId'],
  170. 'secret_key': data[role_name]['SecretAccessKey'],
  171. 'token': data[role_name]['Token'],
  172. 'expiry_time': data[role_name]['Expiration'],
  173. }
  174. return final_data
  175. def merge_dicts(dict1, dict2, append_lists=False):
  176. """Given two dict, merge the second dict into the first.
  177. The dicts can have arbitrary nesting.
  178. :param append_lists: If true, instead of clobbering a list with the new
  179. value, append all of the new values onto the original list.
  180. """
  181. for key in dict2:
  182. if isinstance(dict2[key], dict):
  183. if key in dict1 and key in dict2:
  184. merge_dicts(dict1[key], dict2[key])
  185. else:
  186. dict1[key] = dict2[key]
  187. # If the value is a list and the ``append_lists`` flag is set,
  188. # append the new values onto the original list
  189. elif isinstance(dict2[key], list) and append_lists:
  190. # The value in dict1 must be a list in order to append new
  191. # values onto it.
  192. if key in dict1 and isinstance(dict1[key], list):
  193. dict1[key].extend(dict2[key])
  194. else:
  195. dict1[key] = dict2[key]
  196. else:
  197. # At scalar types, we iterate and merge the
  198. # current dict that we're on.
  199. dict1[key] = dict2[key]
  200. def parse_key_val_file(filename, _open=open):
  201. try:
  202. with _open(filename) as f:
  203. contents = f.read()
  204. return parse_key_val_file_contents(contents)
  205. except OSError:
  206. raise ConfigNotFound(path=filename)
  207. def parse_key_val_file_contents(contents):
  208. # This was originally extracted from the EC2 credential provider, which was
  209. # fairly lenient in its parsing. We only try to parse key/val pairs if
  210. # there's a '=' in the line.
  211. final = {}
  212. for line in contents.splitlines():
  213. if '=' not in line:
  214. continue
  215. key, val = line.split('=', 1)
  216. key = key.strip()
  217. val = val.strip()
  218. final[key] = val
  219. return final
  220. def percent_encode_sequence(mapping, safe=SAFE_CHARS):
  221. """Urlencode a dict or list into a string.
  222. This is similar to urllib.urlencode except that:
  223. * It uses quote, and not quote_plus
  224. * It has a default list of safe chars that don't need
  225. to be encoded, which matches what KSYUN services expect.
  226. If any value in the input ``mapping`` is a list type,
  227. then each list element wil be serialized. This is the equivalent
  228. to ``urlencode``'s ``doseq=True`` argument.
  229. This function should be preferred over the stdlib
  230. ``urlencode()`` function.
  231. :param mapping: Either a dict to urlencode or a list of
  232. ``(key, value)`` pairs.
  233. """
  234. encoded_pairs = []
  235. if hasattr(mapping, 'items'):
  236. pairs = mapping.items()
  237. else:
  238. pairs = mapping
  239. for key, value in pairs:
  240. if isinstance(value, list):
  241. for element in value:
  242. encoded_pairs.append('%s=%s' % (percent_encode(key),
  243. percent_encode(element)))
  244. else:
  245. encoded_pairs.append('%s=%s' % (percent_encode(key),
  246. percent_encode(value)))
  247. return '&'.join(encoded_pairs)
  248. def percent_encode(input_str, safe=SAFE_CHARS):
  249. """Urlencodes a string.
  250. Whereas percent_encode_sequence handles taking a dict/sequence and
  251. producing a percent encoded string, this function deals only with
  252. taking a string (not a dict/sequence) and percent encoding it.
  253. """
  254. if not isinstance(input_str, string_types):
  255. input_str = text_type(input_str)
  256. return quote(text_type(input_str).encode('utf-8'), safe=safe)
  257. def parse_timestamp(value):
  258. """Parse a timestamp into a datetime object.
  259. Supported formats:
  260. * iso8601
  261. * rfc822
  262. * epoch (value is an integer)
  263. This will return a ``datetime.datetime`` object.
  264. """
  265. if isinstance(value, (int, float)):
  266. # Possibly an epoch time.
  267. return datetime.datetime.fromtimestamp(value, tzlocal())
  268. else:
  269. try:
  270. return datetime.datetime.fromtimestamp(float(value), tzlocal())
  271. except (TypeError, ValueError):
  272. pass
  273. try:
  274. return dateutil.parser.parse(value)
  275. except (TypeError, ValueError) as e:
  276. raise ValueError('Invalid timestamp "%s": %s' % (value, e))
  277. def parse_to_aware_datetime(value):
  278. """Converted the passed in value to a datetime object with tzinfo.
  279. This function can be used to normalize all timestamp inputs. This
  280. function accepts a number of different types of inputs, but
  281. will always return a datetime.datetime object with time zone
  282. information.
  283. The input param ``value`` can be one of several types:
  284. * A datetime object (both naive and aware)
  285. * An integer representing the epoch time (can also be a string
  286. of the integer, i.e '0', instead of 0). The epoch time is
  287. considered to be UTC.
  288. * An iso8601 formatted timestamp. This does not need to be
  289. a complete timestamp, it can contain just the date portion
  290. without the time component.
  291. The returned value will be a datetime object that will have tzinfo.
  292. If no timezone info was provided in the input value, then UTC is
  293. assumed, not local time.
  294. """
  295. # This is a general purpose method that handles several cases of
  296. # converting the provided value to a string timestamp suitable to be
  297. # serialized to an http request. It can handle:
  298. # 1) A datetime.datetime object.
  299. if isinstance(value, datetime.datetime):
  300. datetime_obj = value
  301. else:
  302. # 2) A string object that's formatted as a timestamp.
  303. # We document this as being an iso8601 timestamp, although
  304. # parse_timestamp is a bit more flexible.
  305. datetime_obj = parse_timestamp(value)
  306. if datetime_obj.tzinfo is None:
  307. # I think a case would be made that if no time zone is provided,
  308. # we should use the local time. However, to restore backwards
  309. # compat, the previous behavior was to assume UTC, which is
  310. # what we're going to do here.
  311. datetime_obj = datetime_obj.replace(tzinfo=tzutc())
  312. else:
  313. datetime_obj = datetime_obj.astimezone(tzutc())
  314. return datetime_obj
  315. def datetime2timestamp(dt, default_timezone=None):
  316. """Calculate the timestamp based on the given datetime instance.
  317. :type dt: datetime
  318. :param dt: A datetime object to be converted into timestamp
  319. :type default_timezone: tzinfo
  320. :param default_timezone: If it is provided as None, we treat it as tzutc().
  321. But it is only used when dt is a naive datetime.
  322. :returns: The timestamp
  323. """
  324. epoch = datetime.datetime(1970, 1, 1)
  325. if dt.tzinfo is None:
  326. if default_timezone is None:
  327. default_timezone = tzutc()
  328. dt = dt.replace(tzinfo=default_timezone)
  329. d = dt.replace(tzinfo=None) - dt.utcoffset() - epoch
  330. if hasattr(d, "total_seconds"):
  331. return d.total_seconds() # Works in Python 2.7+
  332. return (d.microseconds + (d.seconds + d.days * 24 * 3600) * 10 ** 6) / 10 ** 6
  333. def calculate_sha256(body, as_hex=False):
  334. """Calculate a sha256 checksum.
  335. This method will calculate the sha256 checksum of a file like
  336. object. Note that this method will iterate through the entire
  337. file contents. The caller is responsible for ensuring the proper
  338. starting position of the file and ``seek()``'ing the file back
  339. to its starting location if other consumers need to read from
  340. the file like object.
  341. :param body: Any file like object. The file must be opened
  342. in binary mode such that a ``.read()`` call returns bytes.
  343. :param as_hex: If True, then the hex digest is returned.
  344. If False, then the digest (as binary bytes) is returned.
  345. :returns: The sha256 checksum
  346. """
  347. checksum = hashlib.sha256()
  348. for chunk in iter(lambda: body.read(1024 * 1024), b''):
  349. checksum.update(chunk)
  350. if as_hex:
  351. return checksum.hexdigest()
  352. else:
  353. return checksum.digest()
  354. def calculate_tree_hash(body):
  355. """Calculate a tree hash checksum.
  356. For more information see:
  357. https://github.com/liuyichen/
  358. :param body: Any file like object. This has the same constraints as
  359. the ``body`` param in calculate_sha256
  360. :rtype: str
  361. :returns: The hex version of the calculated tree hash
  362. """
  363. chunks = []
  364. required_chunk_size = 1024 * 1024
  365. sha256 = hashlib.sha256
  366. for chunk in iter(lambda: body.read(required_chunk_size), b''):
  367. chunks.append(sha256(chunk).digest())
  368. if not chunks:
  369. return sha256(b'').hexdigest()
  370. while len(chunks) > 1:
  371. new_chunks = []
  372. for first, second in _in_pairs(chunks):
  373. if second is not None:
  374. new_chunks.append(sha256(first + second).digest())
  375. else:
  376. # We're at the end of the list and there's no pair left.
  377. new_chunks.append(first)
  378. chunks = new_chunks
  379. return binascii.hexlify(chunks[0]).decode('ascii')
  380. def _in_pairs(iterable):
  381. # Creates iterator that iterates over the list in pairs:
  382. # for a, b in _in_pairs([0, 1, 2, 3, 4]):
  383. # print(a, b)
  384. #
  385. # will print:
  386. # 0, 1
  387. # 2, 3
  388. # 4, None
  389. shared_iter = iter(iterable)
  390. # Note that zip_longest is a compat import that uses
  391. # the itertools izip_longest. This creates an iterator,
  392. # this call below does _not_ immediately create the list
  393. # of pairs.
  394. return zip_longest(shared_iter, shared_iter)
  395. class CachedProperty(object):
  396. """A read only property that caches the initially computed value.
  397. This descriptor will only call the provided ``fget`` function once.
  398. Subsequent access to this property will return the cached value.
  399. """
  400. def __init__(self, fget):
  401. self._fget = fget
  402. def __get__(self, obj, cls):
  403. if obj is None:
  404. return self
  405. else:
  406. computed_value = self._fget(obj)
  407. obj.__dict__[self._fget.__name__] = computed_value
  408. return computed_value
  409. class ArgumentGenerator(object):
  410. """Generate sample input based on a shape model.
  411. This class contains a ``generate_skeleton`` method that will take
  412. an input shape (created from ``kscore.model``) and generate
  413. a sample dictionary corresponding to the input shape.
  414. The specific values used are place holder values. For strings an
  415. empty string is used, for numbers 0 or 0.0 is used. The intended
  416. usage of this class is to generate the *shape* of the input structure.
  417. This can be useful for operations that have complex input shapes.
  418. This allows a user to just fill in the necessary data instead of
  419. worrying about the specific structure of the input arguments.
  420. Example usage::
  421. s = kscore.session.get_session()
  422. ddb = s.get_service_model('dynamodb')
  423. arg_gen = ArgumentGenerator()
  424. sample_input = arg_gen.generate_skeleton(
  425. ddb.operation_model('CreateTable').input_shape)
  426. print("Sample input for dynamodb.CreateTable: %s" % sample_input)
  427. """
  428. def __init__(self):
  429. pass
  430. def generate_skeleton(self, shape):
  431. """Generate a sample input.
  432. :type shape: ``kscore.model.Shape``
  433. :param shape: The input shape.
  434. :return: The generated skeleton input corresponding to the
  435. provided input shape.
  436. """
  437. stack = []
  438. return self._generate_skeleton(shape, stack)
  439. def _generate_skeleton(self, shape, stack):
  440. stack.append(shape.name)
  441. try:
  442. if shape.type_name == 'structure':
  443. return self._generate_type_structure(shape, stack)
  444. elif shape.type_name == 'list':
  445. return self._generate_type_list(shape, stack)
  446. elif shape.type_name == 'map':
  447. return self._generate_type_map(shape, stack)
  448. elif shape.type_name == 'string':
  449. return ''
  450. elif shape.type_name in ['integer', 'long']:
  451. return 0
  452. elif shape.type_name == 'float':
  453. return 0.0
  454. elif shape.type_name == 'boolean':
  455. return True
  456. finally:
  457. stack.pop()
  458. def _generate_type_structure(self, shape, stack):
  459. if stack.count(shape.name) > 1:
  460. return {}
  461. skeleton = OrderedDict()
  462. for member_name, member_shape in shape.members.items():
  463. skeleton[member_name] = self._generate_skeleton(member_shape,
  464. stack)
  465. return skeleton
  466. def _generate_type_list(self, shape, stack):
  467. # For list elements we've arbitrarily decided to
  468. # return two elements for the skeleton list.
  469. return [
  470. self._generate_skeleton(shape.member, stack),
  471. ]
  472. def _generate_type_map(self, shape, stack):
  473. key_shape = shape.key
  474. value_shape = shape.value
  475. assert key_shape.type_name == 'string'
  476. return OrderedDict([
  477. ('KeyName', self._generate_skeleton(value_shape, stack)),
  478. ])
  479. def is_valid_endpoint_url(endpoint_url):
  480. """Verify the endpoint_url is valid.
  481. :type endpoint_url: string
  482. :param endpoint_url: An endpoint_url. Must have at least a scheme
  483. and a hostname.
  484. :return: True if the endpoint url is valid. False otherwise.
  485. """
  486. parts = urlsplit(endpoint_url)
  487. hostname = parts.hostname
  488. if hostname is None:
  489. return False
  490. if len(hostname) > 255:
  491. return False
  492. if hostname[-1] == ".":
  493. hostname = hostname[:-1]
  494. allowed = re.compile(
  495. "^((?!-)[A-Z\d-]{1,63}(?<!-)\.)*((?!-)[A-Z\d-]{1,63}(?<!-))$",
  496. re.IGNORECASE)
  497. return allowed.match(hostname)
  498. def check_dns_name(bucket_name):
  499. """
  500. Check to see if the ``bucket_name`` complies with the
  501. restricted DNS naming conventions necessary to allow
  502. access via virtual-hosting style.
  503. Even though "." characters are perfectly valid in this DNS
  504. naming scheme, we are going to punt on any name containing a
  505. "." character because these will cause SSL cert validation
  506. problems if we try to use virtual-hosting style addressing.
  507. """
  508. if '.' in bucket_name:
  509. return False
  510. n = len(bucket_name)
  511. if n < 3 or n > 63:
  512. # Wrong length
  513. return False
  514. if n == 1:
  515. if not bucket_name.isalnum():
  516. return False
  517. match = LABEL_RE.match(bucket_name)
  518. if match is None or match.end() != len(bucket_name):
  519. return False
  520. return True
  521. def fix_s3_host(request, signature_version, region_name, **kwargs):
  522. """
  523. This handler looks at S3 requests just before they are signed.
  524. If there is a bucket name on the path (true for everything except
  525. ListAllBuckets) it checks to see if that bucket name conforms to
  526. the DNS naming conventions. If it does, it alters the request to
  527. use ``virtual hosting`` style addressing rather than ``path-style``
  528. addressing. This allows us to avoid 301 redirects for all
  529. bucket names that can be CNAME'd.
  530. """
  531. # By default we do not use virtual hosted style addressing when
  532. # signed with signature version 4.
  533. if signature_version in ['s3v4', 'v4']:
  534. return
  535. elif not _allowed_region(region_name):
  536. return
  537. try:
  538. switch_to_virtual_host_style(
  539. request, signature_version, 's3.ksyun.com')
  540. except InvalidDNSNameError as e:
  541. bucket_name = e.kwargs['bucket_name']
  542. logger.debug('Not changing URI, bucket is not DNS compatible: %s',
  543. bucket_name)
def switch_to_virtual_host_style(request, signature_version,
                                 default_endpoint_url=None, **kwargs):
    """
    This is a handler to force virtual host style s3 addressing no matter
    the signature version (which is taken in consideration for the default
    case). If the bucket is not DNS compatible an InvalidDNSName is thrown.

    :param request: A KSRequest object that is about to be sent.
    :param signature_version: The signature version to sign with
    :param default_endpoint_url: The endpoint to use when switching to a
        virtual style. If None is supplied, the virtual host will be
        constructed from the url of the request.
    :raises InvalidDNSNameError: if the bucket name on the path is not
        DNS compatible.
    """
    if request.auth_path is not None:
        # The auth_path has already been applied (this may be a
        # retried request). We don't need to perform this
        # customization again.
        return
    elif _is_get_bucket_location_request(request):
        # For the GetBucketLocation response, we should not be using
        # the virtual host style addressing so we can avoid any sigv4
        # issues.
        logger.debug("Request is GetBucketLocation operation, not checking "
                     "for DNS compatibility.")
        return
    parts = urlsplit(request.url)
    # Preserve the original path for signing before the URL is rewritten.
    request.auth_path = parts.path
    path_parts = parts.path.split('/')
    # Retrieve what the endpoint we will be prepending the bucket name to.
    if default_endpoint_url is None:
        default_endpoint_url = parts.netloc
    if len(path_parts) > 1:
        # path_parts[0] is '' (path starts with '/'); [1] is the bucket.
        bucket_name = path_parts[1]
        if not bucket_name:
            # If the bucket name is empty we should not be checking for
            # dns compatibility.
            return
        logger.debug('Checking for DNS compatible bucket for: %s',
                     request.url)
        if check_dns_name(bucket_name):
            # If the operation is on a bucket, the auth_path must be
            # terminated with a '/' character.
            if len(path_parts) == 2:
                if request.auth_path[-1] != '/':
                    request.auth_path += '/'
            # The bucket moves from the path into the host.
            path_parts.remove(bucket_name)
            # At the very least the path must be a '/', such as with the
            # CreateBucket operation when DNS style is being used. If this
            # is not used you will get an empty path which is incorrect.
            path = '/'.join(path_parts) or '/'
            global_endpoint = default_endpoint_url
            host = bucket_name + '.' + global_endpoint
            # Rebuild the URL with the bucket-prefixed host; the fragment
            # component is always dropped.
            new_tuple = (parts.scheme, host, path,
                         parts.query, '')
            new_uri = urlunsplit(new_tuple)
            request.url = new_uri
            logger.debug('URI updated to: %s', new_uri)
        else:
            raise InvalidDNSNameError(bucket_name=bucket_name)
  602. def _is_get_bucket_location_request(request):
  603. return request.url.endswith('?location')
  604. def _allowed_region(region_name):
  605. return region_name not in RESTRICTED_REGIONS
  606. def instance_cache(func):
  607. """Method decorator for caching method calls to a single instance.
  608. **This is not a general purpose caching decorator.**
  609. In order to use this, you *must* provide an ``_instance_cache``
  610. attribute on the instance.
  611. This decorator is used to cache method calls. The cache is only
  612. scoped to a single instance though such that multiple instances
  613. will maintain their own cache. In order to keep things simple,
  614. this decorator requires that you provide an ``_instance_cache``
  615. attribute on your instance.
  616. """
  617. func_name = func.__name__
  618. @functools.wraps(func)
  619. def _cache_guard(self, *args, **kwargs):
  620. cache_key = (func_name, args)
  621. if kwargs:
  622. kwarg_items = tuple(sorted(kwargs.items()))
  623. cache_key = (func_name, args, kwarg_items)
  624. result = self._instance_cache.get(cache_key)
  625. if result is not None:
  626. return result
  627. result = func(self, *args, **kwargs)
  628. self._instance_cache[cache_key] = result
  629. return result
  630. return _cache_guard
  631. def switch_host_s3_accelerate(request, operation_name, **kwargs):
  632. """Switches the current s3 endpoint with an S3 Accelerate endpoint"""
  633. # Note that when registered the switching of the s3 host happens
  634. # before it gets changed to virtual. So we are not concerned with ensuring
  635. # that the bucket name is translated to the virtual style here and we
  636. # can hard code the Accelerate endpoint.
  637. endpoint = 'https://' + S3_ACCELERATE_ENDPOINT
  638. if operation_name in ['ListBuckets', 'CreateBucket', 'DeleteBucket']:
  639. return
  640. _switch_hosts(request, endpoint, use_new_scheme=False)
  641. def switch_host_with_param(request, param_name):
  642. """Switches the host using a parameter value from a JSON request body"""
  643. request_json = json.loads(request.data.decode('utf-8'))
  644. if request_json.get(param_name):
  645. new_endpoint = request_json[param_name]
  646. _switch_hosts(request, new_endpoint)
  647. def _switch_hosts(request, new_endpoint, use_new_scheme=True):
  648. new_endpoint_components = urlsplit(new_endpoint)
  649. original_endpoint = request.url
  650. original_endpoint_components = urlsplit(original_endpoint)
  651. scheme = original_endpoint_components.scheme
  652. if use_new_scheme:
  653. scheme = new_endpoint_components.scheme
  654. final_endpoint_components = (
  655. scheme,
  656. new_endpoint_components.netloc,
  657. original_endpoint_components.path,
  658. original_endpoint_components.query,
  659. ''
  660. )
  661. final_endpoint = urlunsplit(final_endpoint_components)
  662. logger.debug('Updating URI from %s to %s' % (request.url, final_endpoint))
  663. request.url = final_endpoint
  664. def set_logger_level(level=logging.DEBUG):
  665. for name in logging.Logger.manager.loggerDict.keys():
  666. if name.find('kscore') == 0:
  667. logging.getLogger(name).setLevel(level)