456 lines
12 KiB
Python
456 lines
12 KiB
Python
import base64
|
|
from base64 import encodebytes as b64encodebytes
|
|
from binascii import Error as BinAsciiError
|
|
|
|
import pytest
|
|
|
|
import pybase64
|
|
|
|
# Import the SIMD helpers from the compiled module when it is available;
# `_has_extension` records whether that import succeeded.
try:
    from pybase64._pybase64 import (
        _get_simd_flags_compile,
        _get_simd_flags_runtime,
        _get_simd_path,
        _set_simd_path,
    )
except ImportError:
    _has_extension = False
else:
    _has_extension = True
|
|
|
|
|
|
# Alphabet ids; each one is an index into the per-alphabet lookup tables.
STD, URL, ALT1, ALT2, ALT3 = range(5)

# Display name of each alphabet, indexed by alphabet id.
name_lut = [
    "standard",
    "urlsafe",
    "alternative",
    "alternative2",
    "alternative3",
]

# The two bytes each alphabet uses in place of "+" and "/", indexed by
# alphabet id.
altchars_lut = [
    b"+/",
    b"-_",
    b"@&",
    b"+,",
    b";/",
]
|
|
# pybase64 encode helpers under test, indexed by alphabet id. Only the first
# two alphabets have an entry; the remaining three slots are None.
enc_helper_lut = [
    pybase64.standard_b64encode,
    pybase64.urlsafe_b64encode,
] + [None] * 3
|
|
# Reference encode helpers, indexed by alphabet id. They must come from the
# standard library's base64 module: pointing them at pybase64 would make
# test_enc_helper compare pybase64 against itself, so it could never fail.
# Only the first two alphabets have an entry; the remaining slots are None.
ref_enc_helper_lut = [
    base64.standard_b64encode,
    base64.urlsafe_b64encode,
    None,
    None,
    None,
]
|
|
# pybase64 decode helpers under test, indexed by alphabet id. Only the first
# two alphabets have an entry; the remaining three slots are None.
dec_helper_lut = [
    pybase64.standard_b64decode,
    pybase64.urlsafe_b64decode,
] + [None] * 3
|
|
# Standard-library decode helpers used as the reference, indexed by alphabet
# id. Only the first two alphabets have an entry; the remaining slots are None.
ref_dec_helper_lut = [
    base64.standard_b64decode,
    base64.urlsafe_b64decode,
] + [None] * 3
|
|
|
|
# Long custom payload: a 65-byte pattern repeated 32 * 16 times.
std = (
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/A"
) * (32 * 16)
std_len_minus_12 = len(std) - 12

# rfc4648 test vectors
_rfc4648_vectors = [
    b"",
    b"Zg==",
    b"Zm8=",
    b"Zm9v",
    b"Zm9vYg==",
    b"Zm9vYmE=",
    b"Zm9vYmFy",
]

# custom test vectors: the last 12 bytes of the payload, the whole payload,
# then prefixes of several lengths
_custom_vectors = [std[std_len_minus_12:], std] + [
    std[:length] for length in (72, 76, 80, 148, 152, 156)
]

# Base64 text in the standard alphabet, used to derive the other vector tables.
test_vectors_b64_list = _rfc4648_vectors + _custom_vectors
|
|
|
|
# One translation table per alphabet, mapping "+/" to that alphabet's altchars.
_translations = [bytes.maketrans(b"+/", altchars) for altchars in altchars_lut]

# Base64 text per alphabet: test_vectors_b64[alphabet_id][vector_id].
test_vectors_b64 = [
    [vector.translate(table) for vector in test_vectors_b64_list]
    for table in _translations
]

# Decoded payloads per alphabet: test_vectors_bin[alphabet_id][vector_id].
# Each entry is the standard-alphabet vector decoded by the stdlib with that
# alphabet's altchars.
test_vectors_bin = [
    [base64.b64decode(vector, altchars) for vector in test_vectors_b64_list]
    for altchars in altchars_lut
]
|
|
|
|
# SIMD path candidates. Entry 0 (flag value 0) is always present; when the
# extension is loaded, every bit set in the compile-time mask adds an entry.
compile_flags = [0]
runtime_flags = _get_simd_flags_runtime() if _has_extension else 0
if _has_extension:
    flags = _get_simd_flags_compile()
    compile_flags.extend(1 << bit for bit in range(31) if flags & (1 << bit))
|
|
|
|
|
|
# Labels for the SIMD flag values that get_simd_name recognises.
_SIMD_FLAG_NAMES = {
    0: "c",
    4: "ssse3",
    8: "sse41",
    16: "sse42",
    32: "avx",
    64: "avx2",
}


def get_simd_name(simd_id):
    """Return a short label for the SIMD path at index ``simd_id``.

    ``simd_id`` indexes ``compile_flags``. Without the extension the label is
    always ``"py"``; a flag value with no known label yields ``"unk"``.
    """
    if not _has_extension:
        return "py"
    return _SIMD_FLAG_NAMES.get(compile_flags[simd_id], "unk")
|
|
|
|
|
|
def simd_setup(simd_id):
    """Select the SIMD path at index ``simd_id`` of ``compile_flags``.

    Without the extension only index 0 is accepted. With it, the test is
    skipped when the flag is non-zero and absent from ``runtime_flags``;
    otherwise the path is set and read back to confirm it took effect.
    """
    if not _has_extension:
        assert 0 == simd_id
        return

    flag = compile_flags[simd_id]
    unavailable = flag != 0 and not (flag & runtime_flags)
    if unavailable:  # pragma: no branch
        pytest.skip("SIMD extension not available")  # pragma: no cover
    _set_simd_path(flag)
    assert _get_simd_path() == flag
|
|
|
|
|
|
# Shared parametrize decorators for the tests in this module.


def _alphabet_name(altchars_id):
    """Return the display name of an alphabet id, for use as a pytest id."""
    return name_lut[altchars_id]


# One case per test vector (index into a per-alphabet vector list).
param_vector = pytest.mark.parametrize(
    "vector_id", range(len(test_vectors_bin[0]))
)

# Decode with validate=False and with validate=True.
param_validate = pytest.mark.parametrize(
    "validate",
    [False, True],
    ids=["novalidate", "validate"],
)

# Every alphabet.
param_altchars = pytest.mark.parametrize(
    "altchars_id",
    [STD, URL, ALT1, ALT2, ALT3],
    ids=_alphabet_name,
)

# Only the alphabets that have an entry in the helper lookup tables.
param_altchars_helper = pytest.mark.parametrize(
    "altchars_id",
    [STD, URL],
    ids=_alphabet_name,
)

# One case per entry of compile_flags, routed through the `simd` fixture.
param_simd = pytest.mark.parametrize(
    "simd",
    range(len(compile_flags)),
    ids=get_simd_name,
    indirect=True,
)

# (encode function, converter bringing its result to bytes)
param_encode_functions = pytest.mark.parametrize(
    "efn, ecast",
    [
        (pybase64.b64encode, lambda encoded: encoded),
        (pybase64.b64encode_as_string, lambda encoded: encoded.encode("ascii")),
    ],
)

# (decode function, converter bringing its result to bytes)
param_decode_functions = pytest.mark.parametrize(
    "dfn, dcast",
    [
        (pybase64.b64decode, lambda decoded: decoded),
        (pybase64.b64decode_as_bytearray, lambda decoded: bytes(decoded)),
    ],
)
|
|
|
|
|
|
@pytest.fixture
def simd(request):
    """Set up the SIMD path identified by ``request.param`` and return that id."""
    simd_id = request.param
    simd_setup(simd_id)
    return simd_id
|
|
|
|
|
|
@param_simd
def test_version(simd):
    """get_version() starts with the package's __version__."""
    reported = pybase64.get_version()
    assert reported.startswith(pybase64.__version__)
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars_helper
def test_enc_helper(altchars_id, vector_id, simd):
    """The encode helper's output equals the reference helper's output."""
    payload = test_vectors_bin[altchars_id][vector_id]
    encode = enc_helper_lut[altchars_id]
    reference_encode = ref_enc_helper_lut[altchars_id]
    assert encode(payload) == reference_encode(payload)
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars_helper
def test_dec_helper(altchars_id, vector_id, simd):
    """The decode helper's output equals the reference helper's output (bytes input)."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    decode = dec_helper_lut[altchars_id]
    reference_decode = ref_dec_helper_lut[altchars_id]
    assert decode(encoded) == reference_decode(encoded)
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars_helper
def test_dec_helper_unicode(altchars_id, vector_id, simd):
    """The decode helper's output equals the reference helper's output (str input)."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    decode = dec_helper_lut[altchars_id]
    reference_decode = ref_dec_helper_lut[altchars_id]
    assert decode(encoded.decode("utf-8")) == reference_decode(encoded.decode("utf-8"))
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars_helper
def test_rnd_helper(altchars_id, vector_id, simd):
    """Decoding then re-encoding with the helpers returns the original bytes."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    decoded = dec_helper_lut[altchars_id](encoded)
    assert enc_helper_lut[altchars_id](decoded) == encoded
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars_helper
def test_rnd_helper_unicode(altchars_id, vector_id, simd):
    """Decoding a str then re-encoding with the helpers returns the original bytes."""
    encoded = test_vectors_b64[altchars_id][vector_id]
    decoded = dec_helper_lut[altchars_id](encoded.decode("utf-8"))
    assert enc_helper_lut[altchars_id](decoded) == encoded
|
|
|
|
|
|
@param_simd
@param_vector
def test_encbytes(vector_id, simd):
    """pybase64.encodebytes output equals base64.encodebytes output."""
    payload = test_vectors_bin[STD][vector_id]
    actual = pybase64.encodebytes(payload)
    assert actual == b64encodebytes(payload)
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars
@param_encode_functions
def test_enc(efn, ecast, altchars_id, vector_id, simd):
    """Each encode function's output equals base64.b64encode for every alphabet."""
    altchars = altchars_lut[altchars_id]
    payload = test_vectors_bin[altchars_id][vector_id]
    actual = ecast(efn(payload, altchars))
    assert actual == base64.b64encode(payload, altchars)
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars
@param_validate
@param_decode_functions
def test_dec(dfn, dcast, altchars_id, vector_id, validate, simd):
    """Each decode function's output equals base64.b64decode (bytes input)."""
    altchars = altchars_lut[altchars_id]
    encoded = test_vectors_b64[altchars_id][vector_id]
    # validate=False is base64.b64decode's default, so it is passed through
    # unconditionally.
    expected = base64.b64decode(encoded, altchars, validate)
    actual = dcast(dfn(encoded, altchars, validate))
    assert actual == expected
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars
@param_validate
@param_decode_functions
def test_dec_unicode(dfn, dcast, altchars_id, vector_id, validate, simd):
    """Each decode function's output equals base64.b64decode (str input and str altchars).

    For the standard alphabet, altchars is passed as None.
    """
    encoded = test_vectors_b64[altchars_id][vector_id].decode("utf-8")
    altchars = (
        None if altchars_id == STD else altchars_lut[altchars_id].decode("utf-8")
    )
    # validate=False is base64.b64decode's default, so it is passed through
    # unconditionally.
    expected = base64.b64decode(encoded, altchars, validate)
    actual = dcast(dfn(encoded, altchars, validate))
    assert actual == expected
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars
@param_validate
@param_encode_functions
@param_decode_functions
def test_rnd(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, simd):
    """Decoding then re-encoding returns the original base64 bytes."""
    altchars = altchars_lut[altchars_id]
    encoded = test_vectors_b64[altchars_id][vector_id]
    decoded = dcast(dfn(encoded, altchars, validate))
    assert ecast(efn(decoded, altchars)) == encoded
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars
@param_validate
@param_encode_functions
@param_decode_functions
def test_rnd_unicode(dfn, dcast, efn, ecast, altchars_id, vector_id, validate, simd):
    """Decoding a str then re-encoding returns the original base64 bytes."""
    altchars = altchars_lut[altchars_id]
    encoded = test_vectors_b64[altchars_id][vector_id]
    decoded = dcast(dfn(encoded.decode("utf-8"), altchars, validate))
    assert ecast(efn(decoded, altchars)) == encoded
|
|
|
|
|
|
@param_simd
@param_vector
@param_altchars
@param_validate
@param_decode_functions
def test_invalid_padding_dec(dfn, dcast, altchars_id, vector_id, validate, simd):
    """A vector with its first byte removed is rejected with binascii.Error."""
    truncated = test_vectors_b64[altchars_id][vector_id][1:]
    if not truncated:
        # Nothing is left to decode, so there is nothing to check.
        return
    altchars = altchars_lut[altchars_id]
    with pytest.raises(BinAsciiError):
        dfn(truncated, altchars, validate)
|
|
|
|
|
|
# (altchars value, exception expected from the encode/decode functions)
_invalid_altchars_cases = [
    [b"", AssertionError],
    [b"-", AssertionError],
    [b"-__", AssertionError],
    [3.0, TypeError],
    ["-€", ValueError],
]

# Cases are identified by their position in the list above.
params_invalid_altchars = pytest.mark.parametrize(
    "altchars,exception",
    _invalid_altchars_cases,
    ids=[str(index) for index, _ in enumerate(_invalid_altchars_cases)],
)
|
|
|
|
|
|
@param_simd
@params_invalid_altchars
@param_encode_functions
def test_invalid_altchars_enc(efn, ecast, altchars, exception, simd):
    """Encoding with an invalid altchars value raises the expected exception."""
    pytest.raises(exception, efn, b"ABCD", altchars)
|
|
|
|
|
|
@param_simd
@params_invalid_altchars
@param_decode_functions
def test_invalid_altchars_dec(dfn, dcast, altchars, exception, simd):
    """Decoding with an invalid altchars value raises the expected exception."""
    pytest.raises(exception, dfn, b"ABCD", altchars)
|
|
|
|
|
|
@param_simd
@params_invalid_altchars
@param_decode_functions
def test_invalid_altchars_dec_validate(dfn, dcast, altchars, exception, simd):
    """Same as test_invalid_altchars_dec, with validate=True."""
    pytest.raises(exception, dfn, b"ABCD", altchars, True)
|
|
|
|
|
|
# (input, altchars, expected exception) cases that test_invalid_data_dec
# expects to raise when decoding without the validate argument.
_invalid_data_novalidate = [
    [b"A@@@@FG", None, BinAsciiError],
    ["ABC€", None, ValueError],
    [3.0, None, TypeError],
]

# (input, altchars, expected exception) cases that decode like the base64
# module without the validate argument (see test_invalid_data_dec_skip) and
# are expected to raise with validate=True.
_invalid_data_validate = [
    [b"\x00\x00\x00\x00", None, BinAsciiError],
    [b"A@@@@FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError],
    [b"A@@@=FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError],
    [b"A@@=@FGHIJKLMNOPQRSTUVWXYZabcdef", b"-_", BinAsciiError],
    [b"A@@@@FGHIJKLMNOPQRSTUVWXYZabcde@=", b"-_", BinAsciiError],
    [b"A@@@@FGHIJKLMNOPQRSTUVWXYZabcd@==", b"-_", BinAsciiError],
    [b"A@@@@FGH" * 10000, b"-_", BinAsciiError],
    # [std, b'-_', BinAsciiError], TODO does not fail with base64 module
]

_invalid_data_all = _invalid_data_novalidate + _invalid_data_validate

# In all three decorators, cases are identified by their position in the list.
params_invalid_data_all = pytest.mark.parametrize(
    "vector,altchars,exception",
    _invalid_data_all,
    ids=[str(index) for index, _ in enumerate(_invalid_data_all)],
)
params_invalid_data_novalidate = pytest.mark.parametrize(
    "vector,altchars,exception",
    _invalid_data_novalidate,
    ids=[str(index) for index, _ in enumerate(_invalid_data_novalidate)],
)
params_invalid_data_validate = pytest.mark.parametrize(
    "vector,altchars,exception",
    _invalid_data_validate,
    ids=[str(index) for index, _ in enumerate(_invalid_data_validate)],
)
|
|
|
|
|
|
@param_simd
@params_invalid_data_novalidate
@param_decode_functions
def test_invalid_data_dec(dfn, dcast, vector, altchars, exception, simd):
    """Invalid input raises the expected exception without the validate argument."""
    pytest.raises(exception, dfn, vector, altchars)
|
|
|
|
|
|
@param_simd
@params_invalid_data_validate
@param_decode_functions
def test_invalid_data_dec_skip(dfn, dcast, vector, altchars, exception, simd):
    """Without the validate argument, decoding these inputs matches base64.b64decode.

    ``exception`` is unused here; it is only present because the shared
    parametrization supplies it.
    """
    actual = dcast(dfn(vector, altchars))
    expected = base64.b64decode(vector, altchars)
    assert actual == expected
|
|
|
|
|
|
@param_simd
@params_invalid_data_all
@param_decode_functions
def test_invalid_data_dec_validate(dfn, dcast, vector, altchars, exception, simd):
    """With validate=True, every invalid input raises the expected exception."""
    pytest.raises(exception, dfn, vector, altchars, True)
|
|
|
|
|
|
@param_encode_functions
def test_invalid_data_enc_0(efn, ecast):
    """Encoding a str raises TypeError."""
    pytest.raises(TypeError, efn, "this is a test")
|
|
|
|
|
|
@param_encode_functions
def test_invalid_args_enc_0(efn, ecast):
    """Calling an encode function with no arguments raises TypeError."""
    pytest.raises(TypeError, efn)
|
|
|
|
|
|
@param_decode_functions
def test_invalid_args_dec_0(dfn, dcast):
    """Calling a decode function with no arguments raises TypeError."""
    pytest.raises(TypeError, dfn)
|