Files
ImageUtils/.CondaPkg/env/Lib/site-packages/pybase64/tests/test_pybase64.py
2023-04-02 10:04:46 +07:00

456 lines
12 KiB
Python

import base64
from base64 import encodebytes as b64encodebytes
from binascii import Error as BinAsciiError
import pytest
import pybase64
# The C extension is optional: record whether its private SIMD hooks could be
# imported so the helpers further down know whether they may call them.
try:
    from pybase64._pybase64 import (
        _get_simd_flags_compile,
        _get_simd_flags_runtime,
        _get_simd_path,
        _set_simd_path,
    )
except ImportError:
    _has_extension = False
else:
    _has_extension = True
# Alphabet identifiers. Each value is an index into the ``*_lut`` tables and
# into the per-alphabet test-vector lists of this module.
STD = 0
URL = 1
ALT1 = 2
ALT2 = 3
ALT3 = 4
# Readable name for each alphabet index; used to build pytest ids.
name_lut = ["standard", "urlsafe", "alternative", "alternative2", "alternative3"]
# For each alphabet index, the two bytes that replace "+" and "/".
altchars_lut = [b"+/", b"-_", b"@&", b"+,", b";/"]
# Helper-function tables, indexed by alphabet id (STD, URL, ALT1, ALT2, ALT3).
# Only the standard and urlsafe alphabets have dedicated helper functions, so
# the remaining slots are None.

# pybase64 encoders under test.
enc_helper_lut = [
    pybase64.standard_b64encode,
    pybase64.urlsafe_b64encode,
    None,
    None,
    None,
]
# Reference encoders. These must come from the stdlib ``base64`` module:
# pointing them at pybase64 would make the encoder helper test compare the
# library with itself and always pass.
ref_enc_helper_lut = [
    base64.standard_b64encode,
    base64.urlsafe_b64encode,
    None,
    None,
    None,
]
# pybase64 decoders under test.
dec_helper_lut = [
    pybase64.standard_b64decode,
    pybase64.urlsafe_b64decode,
    None,
    None,
    None,
]
# Reference decoders from the stdlib ``base64`` module.
ref_dec_helper_lut = [
    base64.standard_b64decode,
    base64.urlsafe_b64decode,
    None,
    None,
    None,
]
# A 65-byte pattern (the 64 standard alphabet characters followed by "A")
# repeated 512 times; the resulting length is a multiple of 4.
std = (
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/A"
    * (32 * 16)
)
std_len_minus_12 = len(std) - 12
# Encoded test vectors, expressed in the standard alphabet.
test_vectors_b64_list = [
    # rfc4648 test vectors
    b"",
    b"Zg==",
    b"Zm8=",
    b"Zm9v",
    b"Zm9vYg==",
    b"Zm9vYmE=",
    b"Zm9vYmFy",
    # custom test vectors: the last 12 bytes, the whole payload, then prefixes
    std[std_len_minus_12:],
    std,
] + [std[:size] for size in (72, 76, 80, 148, 152, 156)]
# Per-alphabet fixtures, indexed as [altchars_id][vector_id]:
#   test_vectors_b64 - the encoded vectors with "+/" translated to the
#                      alphabet's own two characters;
#   test_vectors_bin - the result of decoding the standard-alphabet vectors
#                      with the stdlib, passing the alphabet's altchars.
test_vectors_b64 = []
test_vectors_bin = []
for altchars in altchars_lut:
    trans = bytes.maketrans(b"+/", altchars)
    test_vectors_b64.append(
        [encoded.translate(trans) for encoded in test_vectors_b64_list]
    )
    test_vectors_bin.append(
        [base64.b64decode(encoded, altchars) for encoded in test_vectors_b64_list]
    )
# compile_flags always starts with flag 0; when the extension is available,
# one entry is appended for every bit (0..30) set in the value returned by
# _get_simd_flags_compile(). runtime_flags stays 0 without the extension.
runtime_flags = 0
compile_flags = [0]
if _has_extension:
    runtime_flags = _get_simd_flags_runtime()
    flags = _get_simd_flags_compile()
    for i in range(31):
        if flags & (1 << i):
            compile_flags.append(1 << i)
def get_simd_name(simd_id):
    """Return a short label for the SIMD path at index ``simd_id``.

    Without the C extension the label is always ``"py"``. Otherwise the flag
    stored at ``compile_flags[simd_id]`` is translated to a name, falling back
    to ``"unk"`` for any flag that is not listed.
    """
    if not _has_extension:
        return "py"
    names_by_flag = {
        0: "c",
        4: "ssse3",
        8: "sse41",
        16: "sse42",
        32: "avx",
        64: "avx2",
    }
    return names_by_flag.get(compile_flags[simd_id], "unk")
def simd_setup(simd_id):
    """Select the SIMD path at index ``simd_id`` for the current test.

    Without the C extension only index 0 is acceptable. With it, the test is
    skipped when the selected flag is non-zero and not present in
    ``runtime_flags``; otherwise the path is set and read back to confirm it.
    """
    if not _has_extension:
        assert 0 == simd_id
        return
    flag = compile_flags[simd_id]
    unavailable = flag != 0 and (flag & runtime_flags) == 0
    if unavailable:  # pragma: no branch
        pytest.skip("SIMD extension not available")  # pragma: no cover
    _set_simd_path(flag)
    assert _get_simd_path() == flag
# Reusable parametrize marks shared by the tests below.

# One case per test vector (every alphabet has the same number of vectors).
param_vector = pytest.mark.parametrize("vector_id", range(len(test_vectors_bin[0])))
# Run with and without the ``validate`` flag.
param_validate = pytest.mark.parametrize(
    "validate", [False, True], ids=["novalidate", "validate"]
)
# Every alphabet, labelled with its name from name_lut.
param_altchars = pytest.mark.parametrize(
    "altchars_id", [STD, URL, ALT1, ALT2, ALT3], ids=lambda x: name_lut[x]
)
# Only the alphabets that have dedicated helper functions.
param_altchars_helper = pytest.mark.parametrize(
    "altchars_id", [STD, URL], ids=lambda x: name_lut[x]
)
# One case per entry of compile_flags; indirect=True hands the index to the
# ``simd`` fixture (as request.param) instead of directly to the test.
param_simd = pytest.mark.parametrize(
    "simd", range(len(compile_flags)), ids=lambda x: get_simd_name(x), indirect=True
)
# Each encoder paired with a cast that turns its result into bytes.
param_encode_functions = pytest.mark.parametrize(
    "efn, ecast",
    [
        (pybase64.b64encode, lambda x: x),
        (pybase64.b64encode_as_string, lambda x: x.encode("ascii")),
    ],
)
# Each decoder paired with a cast that turns its result into bytes.
param_decode_functions = pytest.mark.parametrize(
    "dfn, dcast",
    [
        (pybase64.b64decode, lambda x: x),
        (pybase64.b64decode_as_bytearray, lambda x: bytes(x)),
    ],
)
@pytest.fixture
def simd(request):
    """Activate the SIMD path for the parametrized index and return that index."""
    simd_id = request.param
    simd_setup(simd_id)
    return simd_id
@param_simd
def test_version(simd):
    """``get_version()`` starts with the package's ``__version__`` string."""
    reported = pybase64.get_version()
    assert reported.startswith(pybase64.__version__)
@param_simd
@param_vector
@param_altchars_helper
def test_enc_helper(altchars_id, vector_id, simd):
    """The encoder helper output equals the reference helper output."""
    raw = test_vectors_bin[altchars_id][vector_id]
    actual = enc_helper_lut[altchars_id](raw)
    expected = ref_enc_helper_lut[altchars_id](raw)
    assert actual == expected
@param_simd
@param_vector
@param_altchars_helper
def test_dec_helper(altchars_id, vector_id, simd):
    """The decoder helper output equals the reference helper output."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    actual = dec_helper_lut[altchars_id](encoded)
    expected = ref_dec_helper_lut[altchars_id](encoded)
    assert actual == expected
@param_simd
@param_vector
@param_altchars_helper
def test_dec_helper_unicode(altchars_id, vector_id, simd):
    """The decoder helpers agree with the reference when given ``str`` input."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    actual = dec_helper_lut[altchars_id](encoded.decode("utf-8"))
    expected = ref_dec_helper_lut[altchars_id](encoded.decode("utf-8"))
    assert actual == expected
@param_simd
@param_vector
@param_altchars_helper
def test_rnd_helper(altchars_id, vector_id, simd):
    """Decoding then re-encoding with the helpers reproduces the vector."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    decoded = dec_helper_lut[altchars_id](encoded)
    round_tripped = enc_helper_lut[altchars_id](decoded)
    assert round_tripped == encoded
@param_simd
@param_vector
@param_altchars_helper
def test_rnd_helper_unicode(altchars_id, vector_id, simd):
    """Round trip through the helpers, feeding the decoder ``str`` input."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    decoded = dec_helper_lut[altchars_id](encoded.decode("utf-8"))
    round_tripped = enc_helper_lut[altchars_id](decoded)
    assert round_tripped == encoded
@param_simd
@param_vector
def test_encbytes(vector_id, simd):
    """``pybase64.encodebytes`` matches the stdlib ``base64.encodebytes``."""
    raw = test_vectors_bin[STD][vector_id]
    actual = pybase64.encodebytes(raw)
    expected = b64encodebytes(raw)
    assert actual == expected
@param_simd
@param_vector
@param_altchars
@param_encode_functions
def test_enc(efn, ecast, altchars_id, vector_id, simd):
    """Each encoder matches ``base64.b64encode`` for every vector and alphabet."""
    altchars = altchars_lut[altchars_id]
    raw = test_vectors_bin[altchars_id][vector_id]
    actual = ecast(efn(raw, altchars))
    expected = base64.b64encode(raw, altchars)
    assert actual == expected
@param_simd
@param_vector
@param_altchars
@param_validate
@param_decode_functions
def test_dec(dfn, dcast, altchars_id, vector_id, validate, simd):
    """Each decoder matches ``base64.b64decode`` for every vector and alphabet."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    altchars = altchars_lut[altchars_id]
    # validate is False or True; passing False is the same as omitting it.
    expected = base64.b64decode(encoded, altchars, validate)
    actual = dcast(dfn(encoded, altchars, validate))
    assert actual == expected
@param_simd
@param_vector
@param_altchars
@param_validate
@param_decode_functions
def test_dec_unicode(dfn, dcast, altchars_id, vector_id, validate, simd):
    """Decoders match the stdlib when both data and altchars are ``str``.

    For the standard alphabet ``altchars`` is passed as ``None`` instead.
    """
    encoded = test_vectors_b64[altchars_id][vector_id].decode("utf-8")
    if altchars_id == STD:
        altchars = None
    else:
        altchars = altchars_lut[altchars_id].decode("utf-8")
    # validate is False or True; passing False is the same as omitting it.
    expected = base64.b64decode(encoded, altchars, validate)
    actual = dcast(dfn(encoded, altchars, validate))
    assert actual == expected
@param_simd
@param_vector
@param_altchars
@param_validate
@param_encode_functions
@param_decode_functions
def test_rnd(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, simd):
    """Decoding then re-encoding reproduces the original encoded vector."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    altchars = altchars_lut[altchars_id]
    decoded = dcast(dfn(encoded, altchars, validate))
    round_tripped = ecast(efn(decoded, altchars))
    assert round_tripped == encoded
@param_simd
@param_vector
@param_altchars
@param_validate
@param_encode_functions
@param_decode_functions
def test_rnd_unicode(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, simd):
    """Round trip where the decoder receives the vector as ``str``."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    altchars = altchars_lut[altchars_id]
    decoded = dcast(dfn(encoded.decode("utf-8"), altchars, validate))
    round_tripped = ecast(efn(decoded, altchars))
    assert round_tripped == encoded
@param_simd
@param_vector
@param_altchars
@param_validate
@param_decode_functions
def test_invalid_padding_dec(dfn, dcast, altchars_id, vector_id, validate, simd):
    """Dropping the first byte of a vector makes decoding raise ``binascii.Error``.

    Vectors that become empty after truncation are not checked.
    """
    truncated = test_vectors_b64[altchars_id][vector_id][1:]
    if not truncated:
        return
    altchars = altchars_lut[altchars_id]
    with pytest.raises(BinAsciiError):
        dfn(truncated, altchars, validate)
# Each entry pairs an unusable ``altchars`` value with the exception type the
# tests expect it to trigger.
params_invalid_altchars = [
    [b"", AssertionError],
    [b"-", AssertionError],
    [b"-__", AssertionError],
    [3.0, TypeError],
    ["-€", ValueError],
]
# Rebind the name to the parametrize mark; each id is the entry's position.
params_invalid_altchars = pytest.mark.parametrize(
    "altchars,exception",
    params_invalid_altchars,
    ids=[str(index) for index, _ in enumerate(params_invalid_altchars)],
)
@param_simd
@params_invalid_altchars
@param_encode_functions
def test_invalid_altchars_enc(efn, ecast, altchars, exception, simd):
    """Encoding with an invalid ``altchars`` raises the expected exception."""
    payload = b"ABCD"
    with pytest.raises(exception):
        efn(payload, altchars)
@param_simd
@params_invalid_altchars
@param_decode_functions
def test_invalid_altchars_dec(dfn, dcast, altchars, exception, simd):
    """Decoding with an invalid ``altchars`` raises the expected exception."""
    payload = b"ABCD"
    with pytest.raises(exception):
        dfn(payload, altchars)
@param_simd
@params_invalid_altchars
@param_decode_functions
def test_invalid_altchars_dec_validate(dfn, dcast, altchars, exception, simd):
    """Same as the non-validating check, with ``validate`` passed as ``True``."""
    payload = b"ABCD"
    with pytest.raises(exception):
        dfn(payload, altchars, True)
# Entries are [data, altchars, expected exception].

# Inputs expected to be rejected when decoding without validation.
params_invalid_data_novalidate = [
    [b"A@@@@FG", None, BinAsciiError],
    ["ABC€", None, ValueError],
    [3.0, None, TypeError],
]
# Inputs expected to be rejected only when validation is requested.
params_invalid_data_validate = [
    [b"\x00\x00\x00\x00", None, BinAsciiError],
    [b"A@@@@FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError],
    [b"A@@@=FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError],
    [b"A@@=@FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError],
    [b"A@@@@FGHIJKLMNOPQRSTUVWXYZabcde@=", b"-_", BinAsciiError],
    [b"A@@@@FGHIJKLMNOPQRSTUVWXYZabcd@==", b"-_", BinAsciiError],
    [b"A@@@@FGH" * 10000, b"-_", BinAsciiError],
    # [std, b'-_', BinAsciiError], TODO: does not fail with the base64 module
]
# The combined mark must be built first, while both names still refer to the
# plain lists; they are rebound to parametrize marks right after.
params_invalid_data_all = pytest.mark.parametrize(
    "vector,altchars,exception",
    params_invalid_data_novalidate + params_invalid_data_validate,
    ids=[
        str(index)
        for index, _ in enumerate(
            params_invalid_data_novalidate + params_invalid_data_validate
        )
    ],
)
params_invalid_data_novalidate = pytest.mark.parametrize(
    "vector,altchars,exception",
    params_invalid_data_novalidate,
    ids=[str(index) for index, _ in enumerate(params_invalid_data_novalidate)],
)
params_invalid_data_validate = pytest.mark.parametrize(
    "vector,altchars,exception",
    params_invalid_data_validate,
    ids=[str(index) for index, _ in enumerate(params_invalid_data_validate)],
)
@param_simd
@params_invalid_data_novalidate
@param_decode_functions
def test_invalid_data_dec(dfn, dcast, vector, altchars, exception, simd):
    """Decoding these inputs without validation raises the listed exception."""
    expected_error = exception
    with pytest.raises(expected_error):
        dfn(vector, altchars)
@param_simd
@params_invalid_data_validate
@param_decode_functions
def test_invalid_data_dec_skip(dfn, dcast, vector, altchars, exception, simd):
    """Without validation, these inputs decode to the same bytes as the stdlib.

    ``exception`` is supplied by the shared parametrization and unused here.
    """
    actual = dcast(dfn(vector, altchars))
    expected = base64.b64decode(vector, altchars)
    assert actual == expected
@param_simd
@params_invalid_data_all
@param_decode_functions
def test_invalid_data_dec_validate(dfn, dcast, vector, altchars, exception, simd):
    """With ``validate`` set to ``True``, every invalid input raises its exception."""
    expected_error = exception
    with pytest.raises(expected_error):
        dfn(vector, altchars, True)
@param_encode_functions
def test_invalid_data_enc_0(efn, ecast):
    """Passing a ``str`` to an encoder raises ``TypeError``."""
    text = "this is a test"
    with pytest.raises(TypeError):
        efn(text)
@param_encode_functions
def test_invalid_args_enc_0(efn, ecast):
    """Calling an encoder with no arguments raises ``TypeError``."""
    no_args = ()
    with pytest.raises(TypeError):
        efn(*no_args)
@param_decode_functions
def test_invalid_args_dec_0(dfn, dcast):
    """Calling a decoder with no arguments raises ``TypeError``."""
    no_args = ()
    with pytest.raises(TypeError):
        dfn(*no_args)