# Copyright 2012-2014 ksyun.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from itertools import tee

from six import string_types

import jmespath
import json
import base64
import logging
from kscore.exceptions import PaginationError
from kscore.compat import zip
from kscore.utils import set_value_from_jmespath, merge_dicts


log = logging.getLogger(__name__)


class PaginatorModel(object):
    """Lookup table of per-operation pagination configurations."""

    def __init__(self, paginator_config):
        """
        :type paginator_config: dict
        :param paginator_config: Mapping with a ``'pagination'`` key whose
            value maps operation names to their pagination config.
        """
        self._paginator_config = paginator_config['pagination']

    def get_paginator(self, operation_name):
        """Return the pagination config registered for ``operation_name``.

        :raises ValueError: If no pagination config exists for the
            operation.
        """
        try:
            single_paginator_config = self._paginator_config[operation_name]
        except KeyError:
            raise ValueError("Paginator for operation does not exist: %s"
                             % operation_name)
        return single_paginator_config


class PageIterator(object):
    """Iterable that yields one response page per underlying request.

    Iterating calls ``method`` repeatedly, feeding the tokens found in each
    response back into the next call's keyword arguments, until no token
    is returned or ``max_items`` is reached.  When iteration stops early
    because of ``max_items``, ``resume_token`` holds an encoded token that
    can be passed back as ``starting_token`` to continue.
    """

    def __init__(self, method, input_token, output_token, more_results,
                 result_keys, non_aggregate_keys, limit_key, max_items,
                 starting_token, page_size, op_kwargs):
        """
        :param method: Callable invoked as ``method(**op_kwargs)`` for
            every page.
        :param input_token: List of keyword-argument names that carry the
            pagination token(s) on a request.
        :param output_token: List of compiled JMESPath expressions that
            locate the next token(s) in a response; paired positionally
            with ``input_token``.
        :param more_results: Compiled JMESPath expression that tells
            whether more pages exist, or ``None``.
        :param result_keys: List of compiled JMESPath expressions for the
            aggregated result values; the first one is the primary key
            used for counting and truncation.
        :param non_aggregate_keys: List of compiled JMESPath expressions
            whose values are recorded from the first page only.
        :param limit_key: Keyword-argument name used to pass ``page_size``.
        :param max_items: Maximum number of primary-key items to yield in
            total, or ``None`` for no limit.
        :param starting_token: Encoded token to resume from, or ``None``.
        :param page_size: Value sent as ``limit_key``, or ``None``.
        :param op_kwargs: Dict of keyword arguments for ``method``; it is
            updated in place while paginating.
        """
        self._method = method
        self._op_kwargs = op_kwargs
        self._input_token = input_token
        self._output_token = output_token
        self._more_results = more_results
        self._result_keys = result_keys
        self._max_items = max_items
        self._limit_key = limit_key
        self._starting_token = starting_token
        self._page_size = page_size
        self._resume_token = None
        self._non_aggregate_key_exprs = non_aggregate_keys
        self._non_aggregate_part = {}

    @property
    def result_keys(self):
        """Compiled JMESPath expressions of the aggregated result keys."""
        return self._result_keys

    @property
    def resume_token(self):
        """Token to specify to resume pagination."""
        return self._resume_token

    @resume_token.setter
    def resume_token(self, value):
        """Validate ``value`` and store it base64(JSON)-encoded.

        :type value: dict
        :param value: Mapping whose keys are exactly the input token names,
            optionally plus ``'ksc_truncate_amount'``.
        :raises ValueError: If ``value`` is not a dict or its keys do not
            match the expected set.
        """
        if not isinstance(value, dict):
            raise ValueError("Bad starting token: %s" % value)

        if 'ksc_truncate_amount' in value:
            token_keys = sorted(self._input_token + ['ksc_truncate_amount'])
        else:
            token_keys = sorted(self._input_token)
        dict_keys = sorted(value.keys())

        if token_keys != dict_keys:
            raise ValueError("Bad starting token: %s" % value)
        self._resume_token = base64.b64encode(
            json.dumps(value).encode('utf-8')).decode('utf-8')

    @property
    def non_aggregate_part(self):
        """Non-aggregate values recorded from the first page."""
        return self._non_aggregate_part

    def __iter__(self):
        """Yield response pages until exhausted or ``max_items`` is hit.

        :raises PaginationError: If two consecutive responses return the
            same next token.
        """
        current_kwargs = self._op_kwargs
        previous_next_token = None
        next_token = dict((key, None) for key in self._input_token)
        if self._starting_token is not None:
            # Seed next_token with the token(s) carried by the starting
            # token.  If the very first page has to be truncated, the
            # resume token must still point at this same page; using the
            # all-None placeholder would lose the position.
            next_token = self._parse_starting_token()[0]
        # The number of items from result_key we've seen so far.
        total_items = 0
        first_request = True
        primary_result_key = self.result_keys[0]
        starting_truncation = 0
        self._inject_starting_params(current_kwargs)
        while True:
            response = self._make_request(current_kwargs)
            parsed = self._extract_parsed_response(response)
            if first_request:
                # The first request is handled differently.  We could
                # possibly have a resume/starting token that tells us where
                # to index into the retrieved page.
                if self._starting_token is not None:
                    starting_truncation = self._handle_first_request(
                        parsed, primary_result_key, starting_truncation)
                first_request = False
                self._record_non_aggregate_key_values(parsed)
            else:
                # The left truncation only applies to the first page.  On
                # later pages it must not be added to the resume index.
                starting_truncation = 0
            current_response = primary_result_key.search(parsed)
            if current_response is None:
                current_response = []
            num_current_response = len(current_response)
            truncate_amount = 0
            if self._max_items is not None:
                truncate_amount = ((total_items + num_current_response)
                                   - self._max_items)
            if truncate_amount > 0:
                self._truncate_response(parsed, primary_result_key,
                                        truncate_amount, starting_truncation,
                                        next_token)
                yield response
                break
            else:
                yield response
                total_items += num_current_response
                next_token = self._get_next_token(parsed)
                if all(t is None for t in next_token.values()):
                    break
                if (self._max_items is not None and
                        total_items == self._max_items):
                    # We're on a page boundary so we can set the current
                    # next token to be the resume token.
                    self.resume_token = next_token
                    break
                if (previous_next_token is not None and
                        previous_next_token == next_token):
                    message = ("The same next token was received "
                               "twice: %s" % next_token)
                    raise PaginationError(message=message)
                self._inject_token_into_kwargs(current_kwargs, next_token)
                previous_next_token = next_token

    def search(self, expression):
        """Applies a JMESPath expression to a paginator

        Each page of results is searched using the provided JMESPath
        expression. If the result is not a list, it is yielded
        directly. If the result is a list, each element in the result
        is yielded individually (essentially implementing a flatmap in
        which the JMESPath search is the mapping function).

        :type expression: str
        :param expression: JMESPath expression to apply to each page.

        :return: Returns an iterator that yields the individual
            elements of applying a JMESPath expression to each page of
            results.
        """
        compiled = jmespath.compile(expression)
        for page in self:
            results = compiled.search(page)
            if isinstance(results, list):
                for element in results:
                    yield element
            else:
                # Yield result directly if it is not a list.
                yield results

    def _make_request(self, current_kwargs):
        """Invoke the wrapped method with the current keyword arguments."""
        return self._method(**current_kwargs)

    def _extract_parsed_response(self, response):
        """Return the parsed part of ``response`` (here, the response)."""
        return response

    def _record_non_aggregate_key_values(self, response):
        """Store the non-aggregate key values found in ``response``."""
        non_aggregate_keys = {}
        for expression in self._non_aggregate_key_exprs:
            result = expression.search(response)
            set_value_from_jmespath(non_aggregate_keys,
                                    expression.expression,
                                    result)
        self._non_aggregate_part = non_aggregate_keys

    def _inject_starting_params(self, op_kwargs):
        """Apply the starting token and page size to ``op_kwargs``."""
        # If the user has specified a starting token we need to
        # inject that into the operation's kwargs.
        if self._starting_token is not None:
            # Don't need to do anything special if there is no starting
            # token specified.
            next_token = self._parse_starting_token()[0]
            self._inject_token_into_kwargs(op_kwargs, next_token)
        if self._page_size is not None:
            # Pass the page size as the parameter name for limiting
            # page size, also known as the limit_key.
            op_kwargs[self._limit_key] = self._page_size

    def _inject_token_into_kwargs(self, op_kwargs, next_token):
        """Copy the tokens of ``next_token`` into ``op_kwargs`` in place.

        A token that is ``None`` (or the legacy string ``'None'``) is not
        sent; any value left under that name by a previous page is removed
        so a stale token is never combined with fresh ones.
        """
        for name, token in next_token.items():
            if token is not None and token != 'None':
                op_kwargs[name] = token
            elif name in op_kwargs:
                del op_kwargs[name]

    def _handle_first_request(self, parsed, primary_result_key,
                              starting_truncation):
        """Drop already-consumed data from the first page of a resume.

        The primary result value is sliced from the index stored in the
        starting token, and every other result key is reset to an empty
        value of its type.  ``parsed`` is modified in place.

        :return: The left-truncation index read from the starting token
            (the ``starting_truncation`` argument is ignored).
        """
        # If the payload is an array or string, we need to slice into it
        # and only return the truncated amount.
        starting_truncation = self._parse_starting_token()[1]
        all_data = primary_result_key.search(parsed)
        if isinstance(all_data, (list, string_types)):
            data = all_data[starting_truncation:]
        else:
            data = None
        set_value_from_jmespath(
            parsed,
            primary_result_key.expression,
            data
        )
        # We also need to truncate any secondary result keys
        # because they were not truncated in the previous last
        # response.
        for token in self.result_keys:
            if token == primary_result_key:
                continue
            sample = token.search(parsed)
            if isinstance(sample, list):
                empty_value = []
            elif isinstance(sample, string_types):
                empty_value = ''
            elif isinstance(sample, (int, float)):
                empty_value = 0
            else:
                empty_value = None
            set_value_from_jmespath(parsed, token.expression, empty_value)
        return starting_truncation

    def _truncate_response(self, parsed, primary_result_key, truncate_amount,
                           starting_truncation, next_token):
        """Trim the primary result of ``parsed`` and set the resume token.

        ``truncate_amount`` items are removed from the end of the primary
        result value.  ``next_token`` is updated in place with the index at
        which a resumed pagination must start inside this page.
        """
        original = primary_result_key.search(parsed)
        if original is None:
            original = []
        amount_to_keep = len(original) - truncate_amount
        truncated = original[:amount_to_keep]
        set_value_from_jmespath(
            parsed,
            primary_result_key.expression,
            truncated
        )
        # The issue here is that even though we know how much we've truncated
        # we need to account for this globally including any starting
        # left truncation. For example:
        # Raw response: [0,1,2,3]
        # Starting index: 1
        # Max items: 1
        # Starting left truncation: [1, 2, 3]
        # End right truncation for max items: [1]
        # However, even though we only kept 1, this is post
        # left truncation so the next starting index should be 2, not 1
        # (left_truncation + amount_to_keep).
        next_token['ksc_truncate_amount'] = \
            amount_to_keep + starting_truncation
        self.resume_token = next_token

    def _get_next_token(self, parsed):
        """Return a dict of input-token name to next token value.

        An empty dict is returned when the ``more_results`` expression is
        configured and evaluates falsy on ``parsed``.
        """
        if self._more_results is not None:
            if not self._more_results.search(parsed):
                return {}
        next_tokens = {}
        for output_token, input_key in \
                zip(self._output_token, self._input_token):
            next_token = output_token.search(parsed)
            # We do not want to include any empty strings as actual tokens.
            # Treat them as None.
            if next_token:
                next_tokens[input_key] = next_token
            else:
                next_tokens[input_key] = None
        return next_tokens

    def result_key_iters(self):
        """Return one ``ResultKeyIterator`` per result key."""
        teed_results = tee(self, len(self.result_keys))
        return [ResultKeyIterator(i, result_key) for i, result_key
                in zip(teed_results, self.result_keys)]

    def build_full_result(self):
        """Consume all pages and merge them into a single result dict.

        List values are concatenated, numeric and string values are added
        together, the non-aggregate part is merged in, and ``'NextToken'``
        is set when a resume token was produced.
        """
        complete_result = {}
        for response in self:
            page = response
            # We want to try to catch operation object pagination
            # and format correctly for those. They come in the form
            # of a tuple of two elements: (http_response, parsed_responsed).
            # We want the parsed_response as that is what the page iterator
            # uses. We can remove it though once operation objects are removed.
            if isinstance(response, tuple) and len(response) == 2:
                page = response[1]
            # We're incrementally building the full response page
            # by page. For each page in the response we need to
            # inject the necessary components from the page
            # into the complete_result.
            for result_expression in self.result_keys:
                # In order to incrementally update a result key
                # we need to search the existing value from complete_result,
                # then we need to search the _current_ page for the
                # current result key value. Then we append the current
                # value onto the existing value, and re-set that value
                # as the new value.
                result_value = result_expression.search(page)
                if result_value is None:
                    continue

                existing_value = result_expression.search(complete_result)
                if existing_value is None:
                    # Set the initial result
                    set_value_from_jmespath(
                        complete_result, result_expression.expression,
                        result_value)
                    continue

                # Now both result_value and existing_value contain something
                if isinstance(result_value, list):
                    existing_value.extend(result_value)
                elif isinstance(result_value, (int, float, string_types)):
                    # Modify the existing result with the sum or concatenation
                    set_value_from_jmespath(
                        complete_result, result_expression.expression,
                        existing_value + result_value)
        merge_dicts(complete_result, self.non_aggregate_part)
        if self.resume_token is not None:
            complete_result['NextToken'] = self.resume_token
        return complete_result

    def _parse_starting_token(self):
        """Decode the starting token.

        :return: ``None`` when there is no starting token, otherwise a
            ``(next_token, index)`` tuple where ``next_token`` maps input
            token names to values and ``index`` is the offset into the
            first page.
        """
        if self._starting_token is None:
            return None

        # The starting token is a dict passed as a base64 encoded string.
        try:
            next_token = json.loads(
                base64.b64decode(self._starting_token).decode('utf-8'))
            if not isinstance(next_token, dict):
                # Valid JSON but not a token dict: treat it like any other
                # undecodable value and try the legacy format instead of
                # failing later on ``.items()``.
                raise ValueError(
                    "Bad starting token: %s" % self._starting_token)
            index = next_token.pop('ksc_truncate_amount', 0)
        except (ValueError, TypeError):
            next_token, index = self._parse_starting_token_deprecated()
        return next_token, index

    def _parse_starting_token_deprecated(self):
        """
        This handles parsing of old style starting tokens, and attempts to
        coerce them into the new style.
        """
        log.debug("Attempting to fall back to old starting token parser. For "
                  "token: %s", self._starting_token)
        if self._starting_token is None:
            return None

        parts = self._starting_token.split('___')
        index = 0
        if len(parts) == len(self._input_token) + 1:
            # One more part than there are input tokens: the trailing part
            # is the index into the first page.
            try:
                index = int(parts.pop())
            except ValueError:
                raise ValueError("Bad starting token: %s" %
                                 self._starting_token)
        next_token = [None if part == 'None' else part for part in parts]
        return self._convert_deprecated_starting_token(next_token), index

    def _convert_deprecated_starting_token(self, deprecated_token):
        """
        This attempts to convert a deprecated starting token into the new
        style.
        """
        len_deprecated_token = len(deprecated_token)
        len_input_token = len(self._input_token)
        if len_deprecated_token > len_input_token:
            raise ValueError("Bad starting token: %s" % self._starting_token)
        elif len_deprecated_token < len_input_token:
            log.debug("Old format starting token does not contain all input "
                      "tokens. Setting the rest, in order, as None.")
            deprecated_token.extend(
                [None] * (len_input_token - len_deprecated_token))
        return dict(zip(self._input_token, deprecated_token))


class Paginator(object):
    """Factory of page iterators for a single paginated operation."""

    PAGE_ITERATOR_CLS = PageIterator

    def __init__(self, method, pagination_config):
        """
        :param method: Callable that performs one request of the operation.
        :type pagination_config: dict
        :param pagination_config: Pagination config of the operation; must
            contain ``'input_token'`` and ``'output_token'`` and may contain
            ``'more_results'``, ``'result_key'``, ``'non_aggregate_keys'``
            and ``'limit_key'``.
        """
        self._method = method
        self._pagination_cfg = pagination_config
        self._output_token = self._get_output_tokens(self._pagination_cfg)
        self._input_token = self._get_input_tokens(self._pagination_cfg)
        self._more_results = self._get_more_results_token(self._pagination_cfg)
        self._non_aggregate_keys = self._get_non_aggregate_keys(
            self._pagination_cfg)
        self._result_keys = self._get_result_keys(self._pagination_cfg)
        self._limit_key = self._get_limit_key(self._pagination_cfg)

    @property
    def result_keys(self):
        """Compiled JMESPath expressions of the result keys, or ``None``."""
        return self._result_keys

    def _get_non_aggregate_keys(self, config):
        """Compile the ``non_aggregate_keys`` expressions of ``config``."""
        return [jmespath.compile(key)
                for key in config.get('non_aggregate_keys', [])]

    def _get_output_tokens(self, config):
        """Compile the ``output_token`` expression(s) of ``config``."""
        output_token = config['output_token']
        if not isinstance(output_token, list):
            output_token = [output_token]
        return [jmespath.compile(token) for token in output_token]

    def _get_input_tokens(self, config):
        """Return the ``input_token`` name(s) of ``config`` as a list."""
        input_token = config['input_token']
        if not isinstance(input_token, list):
            input_token = [input_token]
        return input_token

    def _get_more_results_token(self, config):
        """Compile the ``more_results`` expression, or return ``None``."""
        more_results = config.get('more_results')
        if more_results is not None:
            return jmespath.compile(more_results)
        return None

    def _get_result_keys(self, config):
        """Compile the ``result_key`` expression(s), or return ``None``."""
        result_key = config.get('result_key')
        if result_key is not None:
            if not isinstance(result_key, list):
                result_key = [result_key]
            result_key = [jmespath.compile(rk) for rk in result_key]
            return result_key
        return None

    def _get_limit_key(self, config):
        """Return the ``limit_key`` of ``config``, or ``None``."""
        return config.get('limit_key')

    def paginate(self, **kwargs):
        """Create paginator object for an operation.

        This returns an iterable object.  Iterating over
        this object will yield a single page of a response
        at a time.

        The optional ``PaginationConfig`` keyword argument is removed from
        ``kwargs`` and may hold ``MaxItems``, ``StartingToken`` and
        ``PageSize``; all remaining keyword arguments are passed to the
        operation.
        """
        page_params = self._extract_paging_params(kwargs)
        return self.PAGE_ITERATOR_CLS(
            self._method, self._input_token,
            self._output_token, self._more_results,
            self._result_keys, self._non_aggregate_keys,
            self._limit_key,
            page_params['MaxItems'],
            page_params['StartingToken'],
            page_params['PageSize'],
            kwargs)

    def _extract_paging_params(self, kwargs):
        """Pop ``PaginationConfig`` from ``kwargs`` and normalize it.

        ``MaxItems`` and ``PageSize`` are converted with ``int`` when
        given.
        """
        pagination_config = kwargs.pop('PaginationConfig', {})
        max_items = pagination_config.get('MaxItems', None)
        if max_items is not None:
            max_items = int(max_items)
        page_size = pagination_config.get('PageSize', None)
        if page_size is not None:
            page_size = int(page_size)
        return {
            'MaxItems': max_items,
            'StartingToken': pagination_config.get('StartingToken', None),
            'PageSize': page_size,
        }


class ResultKeyIterator(object):
    """Iterates over the results of paginated responses.

    Each iterator is associated with a single result key.
    Iterating over this object will give you each element in
    the result key list.

    :param pages_iterator: An iterator that will give you
        pages of results (a ``PageIterator`` class).
    :param result_key: The JMESPath expression representing
        the result key.

    """

    def __init__(self, pages_iterator, result_key):
        self._pages_iterator = pages_iterator
        self.result_key = result_key

    def __iter__(self):
        for page in self._pages_iterator:
            results = self.result_key.search(page)
            if results is None:
                results = []
            for result in results:
                yield result