2
0

validate.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. """User input parameter validation.
  2. This module handles user input parameter validation
  3. against a provided input model.
  4. Note that the objects in this module do *not* mutate any
  5. arguments. No type conversion happens here. It is up to another
  6. layer to properly convert arguments to any required types.
  7. Validation Errors
  8. -----------------
  9. """
  10. from kscore.compat import six
  11. import decimal
  12. from datetime import datetime
  13. from kscore.utils import parse_to_aware_datetime
  14. from kscore.exceptions import ParamValidationError
  15. def validate_parameters(params, shape):
  16. """Validates input parameters against a schema.
  17. This is a convenience function that validates parameters against a schema.
  18. You can also instantiate and use the ParamValidator class directly if you
  19. want more control.
  20. If there are any validation errors then a ParamValidationError
  21. will be raised. If there are no validation errors than no exception
  22. is raised and a value of None is returned.
  23. :param params: The user provided input parameters.
  24. :type shape: kscore.model.Shape
  25. :param shape: The schema which the input parameters should
  26. adhere to.
  27. :raise: ParamValidationError
  28. """
  29. validator = ParamValidator()
  30. report = validator.validate(params, shape)
  31. if report.has_errors():
  32. raise ParamValidationError(report=report.generate_report())
  33. def type_check(valid_types):
  34. def _create_type_check_guard(func):
  35. def _on_passes_type_check(self, param, shape, errors, name):
  36. if _type_check(param, errors, name):
  37. return func(self, param, shape, errors, name)
  38. def _type_check(param, errors, name):
  39. if not isinstance(param, valid_types):
  40. valid_type_names = [six.text_type(t) for t in valid_types]
  41. errors.report(name, 'invalid type', param=param,
  42. valid_types=valid_type_names)
  43. return False
  44. return True
  45. return _on_passes_type_check
  46. return _create_type_check_guard
  47. def range_check(name, value, shape, error_type, errors):
  48. failed = False
  49. min_allowed = float('-inf')
  50. max_allowed = float('inf')
  51. if 'min' in shape.metadata:
  52. min_allowed = shape.metadata['min']
  53. if value < min_allowed:
  54. failed = True
  55. if failed:
  56. errors.report(name, error_type, param=value,
  57. valid_range=[min_allowed, max_allowed])
  58. class ValidationErrors(object):
  59. def __init__(self):
  60. self._errors = []
  61. def has_errors(self):
  62. if self._errors:
  63. return True
  64. return False
  65. def generate_report(self):
  66. error_messages = []
  67. for error in self._errors:
  68. error_messages.append(self._format_error(error))
  69. return '\n'.join(error_messages)
  70. def _format_error(self, error):
  71. error_type, name, additional = error
  72. name = self._get_name(name)
  73. if error_type == 'missing required field':
  74. return 'Missing required parameter in %s: "%s"' % (
  75. name, additional['required_name'])
  76. elif error_type == 'unknown field':
  77. return 'Unknown parameter in %s: "%s", must be one of: %s' % (
  78. name, additional['unknown_param'], ', '.join(additional['valid_names']))
  79. elif error_type == 'invalid type':
  80. return 'Invalid type for parameter %s, value: %s, type: %s, valid types: %s' % (
  81. name, additional['param'],
  82. str(type(additional['param'])),
  83. ', '.join(additional['valid_types']))
  84. elif error_type == 'invalid range':
  85. min_allowed = additional['valid_range'][0]
  86. max_allowed = additional['valid_range'][1]
  87. return ('Invalid range for parameter %s, value: %s, valid range: '
  88. '%s-%s' % (name, additional['param'],
  89. min_allowed, max_allowed))
  90. elif error_type == 'invalid length':
  91. min_allowed = additional['valid_range'][0]
  92. max_allowed = additional['valid_range'][1]
  93. return ('Invalid length for parameter %s, value: %s, valid range: '
  94. '%s-%s' % (name, additional['param'],
  95. min_allowed, max_allowed))
  96. def _get_name(self, name):
  97. if not name:
  98. return 'input'
  99. elif name.startswith('.'):
  100. return name[1:]
  101. else:
  102. return name
  103. def report(self, name, reason, **kwargs):
  104. self._errors.append((reason, name, kwargs))
  105. class ParamValidator(object):
  106. """Validates parameters against a shape model."""
  107. def validate(self, params, shape):
  108. """Validate parameters against a shape model.
  109. This method will validate the parameters against a provided shape model.
  110. All errors will be collected before returning to the caller. This means
  111. that this method will not stop at the first error, it will return all
  112. possible errors.
  113. :param params: User provided dict of parameters
  114. :param shape: A shape model describing the expected input.
  115. :return: A list of errors.
  116. """
  117. errors = ValidationErrors()
  118. self._validate(params, shape, errors, name='')
  119. return errors
  120. def _validate(self, params, shape, errors, name):
  121. getattr(self, '_validate_%s' % shape.type_name)(params, shape, errors, name)
  122. @type_check(valid_types=(dict,))
  123. def _validate_structure(self, params, shape, errors, name):
  124. # Validate required fields.
  125. for required_member in shape.metadata.get('required', []):
  126. if required_member not in params:
  127. errors.report(name, 'missing required field',
  128. required_name=required_member, user_params=params)
  129. members = shape.members
  130. known_params = []
  131. # Validate known params.
  132. for param in params:
  133. if param not in members:
  134. errors.report(name, 'unknown field', unknown_param=param,
  135. valid_names=list(members))
  136. else:
  137. known_params.append(param)
  138. # Validate structure members.
  139. for param in known_params:
  140. self._validate(params[param], shape.members[param],
  141. errors, '%s.%s' % (name, param))
  142. @type_check(valid_types=six.string_types)
  143. def _validate_string(self, param, shape, errors, name):
  144. # Validate range. For a string, the min/max contraints
  145. # are of the string length.
  146. # Looks like:
  147. # "WorkflowId":{
  148. # "type":"string",
  149. # "min":1,
  150. # "max":256
  151. # }
  152. range_check(name, len(param), shape, 'invalid length', errors)
  153. @type_check(valid_types=(list, tuple))
  154. def _validate_list(self, param, shape, errors, name):
  155. member_shape = shape.member
  156. range_check(name, len(param), shape, 'invalid length', errors)
  157. for i, item in enumerate(param):
  158. self._validate(item, member_shape, errors, '%s[%s]' % (name, i))
  159. @type_check(valid_types=(dict,))
  160. def _validate_map(self, param, shape, errors, name):
  161. key_shape = shape.key
  162. value_shape = shape.value
  163. for key, value in param.items():
  164. self._validate(key, key_shape, errors, "%s (key: %s)"
  165. % (name, key))
  166. self._validate(value, value_shape, errors, '%s.%s' % (name, key))
  167. @type_check(valid_types=six.integer_types)
  168. def _validate_integer(self, param, shape, errors, name):
  169. range_check(name, param, shape, 'invalid range', errors)
  170. def _validate_blob(self, param, shape, errors, name):
  171. if isinstance(param, (bytes, bytearray, six.text_type)):
  172. return
  173. elif hasattr(param, 'read'):
  174. # File like objects are also allowed for blob types.
  175. return
  176. else:
  177. errors.report(name, 'invalid type', param=param,
  178. valid_types=[str(bytes), str(bytearray),
  179. 'file-like object'])
  180. @type_check(valid_types=(bool,))
  181. def _validate_boolean(self, param, shape, errors, name):
  182. pass
  183. @type_check(valid_types=(float, decimal.Decimal) + six.integer_types)
  184. def _validate_double(self, param, shape, errors, name):
  185. range_check(name, param, shape, 'invalid range', errors)
  186. _validate_float = _validate_double
  187. @type_check(valid_types=six.integer_types)
  188. def _validate_long(self, param, shape, errors, name):
  189. range_check(name, param, shape, 'invalid range', errors)
  190. def _validate_timestamp(self, param, shape, errors, name):
  191. # We don't use @type_check because datetimes are a bit
  192. # more flexible. You can either provide a datetime
  193. # object, or a string that parses to a datetime.
  194. is_valid_type = self._type_check_datetime(param)
  195. if not is_valid_type:
  196. valid_type_names = [six.text_type(datetime), 'timestamp-string']
  197. errors.report(name, 'invalid type', param=param,
  198. valid_types=valid_type_names)
  199. def _type_check_datetime(self, value):
  200. try:
  201. parse_to_aware_datetime(value)
  202. return True
  203. except (TypeError, ValueError, AttributeError):
  204. # Yes, dateutil can sometimes raise an AttributeError
  205. # when parsing timestamps.
  206. return False
  207. class ParamValidationDecorator(object):
  208. def __init__(self, param_validator, serializer):
  209. self._param_validator = param_validator
  210. self._serializer = serializer
  211. def serialize_to_request(self, parameters, operation_model):
  212. input_shape = operation_model.input_shape
  213. if input_shape is not None:
  214. report = self._param_validator.validate(parameters,
  215. operation_model.input_shape)
  216. if report.has_errors():
  217. raise ParamValidationError(report=report.generate_report())
  218. return self._serializer.serialize_to_request(parameters,
  219. operation_model)