# codec.py (3.3 KB)
  1. from .core import encode, decode, alabel, ulabel, IDNAError
  2. import codecs
  3. import re
  4. from typing import Any, Tuple, Optional
  5. _unicode_dots_re = re.compile('[\u002e\u3002\uff0e\uff61]')
  6. class Codec(codecs.Codec):
  7. def encode(self, data: str, errors: str = 'strict') -> Tuple[bytes, int]:
  8. if errors != 'strict':
  9. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  10. if not data:
  11. return b"", 0
  12. return encode(data), len(data)
  13. def decode(self, data: bytes, errors: str = 'strict') -> Tuple[str, int]:
  14. if errors != 'strict':
  15. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  16. if not data:
  17. return '', 0
  18. return decode(data), len(data)
  19. class IncrementalEncoder(codecs.BufferedIncrementalEncoder):
  20. def _buffer_encode(self, data: str, errors: str, final: bool) -> Tuple[bytes, int]:
  21. if errors != 'strict':
  22. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  23. if not data:
  24. return b'', 0
  25. labels = _unicode_dots_re.split(data)
  26. trailing_dot = b''
  27. if labels:
  28. if not labels[-1]:
  29. trailing_dot = b'.'
  30. del labels[-1]
  31. elif not final:
  32. # Keep potentially unfinished label until the next call
  33. del labels[-1]
  34. if labels:
  35. trailing_dot = b'.'
  36. result = []
  37. size = 0
  38. for label in labels:
  39. result.append(alabel(label))
  40. if size:
  41. size += 1
  42. size += len(label)
  43. # Join with U+002E
  44. result_bytes = b'.'.join(result) + trailing_dot
  45. size += len(trailing_dot)
  46. return result_bytes, size
  47. class IncrementalDecoder(codecs.BufferedIncrementalDecoder):
  48. def _buffer_decode(self, data: Any, errors: str, final: bool) -> Tuple[str, int]:
  49. if errors != 'strict':
  50. raise IDNAError('Unsupported error handling \"{}\"'.format(errors))
  51. if not data:
  52. return ('', 0)
  53. if not isinstance(data, str):
  54. data = str(data, 'ascii')
  55. labels = _unicode_dots_re.split(data)
  56. trailing_dot = ''
  57. if labels:
  58. if not labels[-1]:
  59. trailing_dot = '.'
  60. del labels[-1]
  61. elif not final:
  62. # Keep potentially unfinished label until the next call
  63. del labels[-1]
  64. if labels:
  65. trailing_dot = '.'
  66. result = []
  67. size = 0
  68. for label in labels:
  69. result.append(ulabel(label))
  70. if size:
  71. size += 1
  72. size += len(label)
  73. result_str = '.'.join(result) + trailing_dot
  74. size += len(trailing_dot)
  75. return (result_str, size)
  76. class StreamWriter(Codec, codecs.StreamWriter):
  77. pass
  78. class StreamReader(Codec, codecs.StreamReader):
  79. pass
  80. def search_function(name: str) -> Optional[codecs.CodecInfo]:
  81. if name != 'idna2008':
  82. return None
  83. return codecs.CodecInfo(
  84. name=name,
  85. encode=Codec().encode,
  86. decode=Codec().decode,
  87. incrementalencoder=IncrementalEncoder,
  88. incrementaldecoder=IncrementalDecoder,
  89. streamwriter=StreamWriter,
  90. streamreader=StreamReader,
  91. )
  92. codecs.register(search_function)