From 20bab6f9bc772387bef3672fbb52020448e2cc53 Mon Sep 17 00:00:00 2001 From: yang-chaohao Date: Tue, 9 Dec 2025 19:13:24 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E8=A1=A5=E5=85=85=E5=8D=95=E5=85=83?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- unittest/test_cipher.py | 246 ++++++++++++++++++++++++++++++++++++++++ unittest/test_common.py | 221 ++++++++++++++++++++++++++++++++++++ 2 files changed, 467 insertions(+) create mode 100644 unittest/test_cipher.py create mode 100644 unittest/test_common.py diff --git a/unittest/test_cipher.py b/unittest/test_cipher.py new file mode 100644 index 0000000..fa4b3e7 --- /dev/null +++ b/unittest/test_cipher.py @@ -0,0 +1,246 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import unittest +import os +import json +import sys +from unittest.mock import patch, MagicMock + +# 添加 backend 目录到 Python 路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'backend')) +from utils.cipher import CustomCipher, DecryptError + + +class TestCustomCipher(unittest.TestCase): + + def setUp(self): + """测试前设置""" + self.cipher = CustomCipher() + + def test_generate_random_string_default(self): + """测试生成默认长度的随机字符串""" + result = self.cipher.generate_random_string() + self.assertEqual(len(result), 16) + # 检查字符串只包含字母和数字 + self.assertTrue(all(c.isalnum() for c in result)) + + def test_generate_random_string_custom_length(self): + """测试生成自定义长度的随机字符串""" + result = self.cipher.generate_random_string(length=32) + self.assertEqual(len(result), 32) + + def test_generate_random_string_custom_seed(self): + """测试使用自定义字符集生成随机字符串""" + seed = "ABC123" + result = self.cipher.generate_random_string(length=10, seed=seed) + self.assertEqual(len(result), 10) + # 检查字符串只包含自定义字符集中的字符 + self.assertTrue(all(c in seed for c in result)) + + @patch.dict(os.environ, {'CIPHER_HALF_KEY_2': 'test_half_key_2'}) + def test_generate_root_key_with_env(self): + """测试使用环境变量生成根密钥""" + half_key_1 = "test_half_key_1" + result = CustomCipher._generate_root_key(half_key_1) + self.assertIsInstance(result, bytes) + self.assertEqual(len(result), 32) # 32字节 = 64十六进制字符,但切片后为32字节 + + def test_generate_root_key_without_env(self): + """测试使用默认值生成根密钥""" + # 确保环境变量不存在 + if 'CIPHER_HALF_KEY_2' in os.environ: + del os.environ['CIPHER_HALF_KEY_2'] + + half_key_1 = "test_half_key_1" + result = CustomCipher._generate_root_key(half_key_1) + self.assertIsInstance(result, bytes) + self.assertEqual(len(result), 32) + + def test_encrypt_decrypt_roundtrip_string(self): + """测试字符串加密解密的往返正确性""" + plaintext = "这是一个测试字符串" + + # 加密 + ciphertext_data = self.cipher.encrypt_plaintext(plaintext) + + # 验证加密结果结构 + self.assertIn('half_key', ciphertext_data) + self.assertIn('encrypted_work_key', ciphertext_data) + self.assertIn('work_key_iv', ciphertext_data) + self.assertIn('plaintext_iv', ciphertext_data) + self.assertIn('ciphertext', ciphertext_data) + + # 解密 + decrypted = self.cipher.decrypt_ciphertext_data(ciphertext_data) + + # 验证解密结果 + self.assertEqual(decrypted, plaintext) + + def test_encrypt_decrypt_roundtrip_dict(self): + """测试字典加密解密的往返正确性""" + plaintext = { + "name": "测试", + "value": 123, + "nested": { + "key": "value" + } + } + + # 加密 + ciphertext_data = self.cipher.encrypt_plaintext(plaintext) + + # 解密 + decrypted = self.cipher.decrypt_ciphertext_data(ciphertext_data) + + # 验证解密结果 + self.assertEqual(decrypted, plaintext) + + def test_decrypt_invalid_ciphertext_missing_field(self): + """测试解密缺少字段的密文""" + ciphertext_data = { + 'half_key': 'test', + 'encrypted_work_key': 'test', + 'work_key_iv': 'test', + 'plaintext_iv': 'test', + # 缺少 'ciphertext' 字段 + } + + with self.assertRaises(DecryptError) as context: + self.cipher.decrypt_ciphertext_data(ciphertext_data) + + self.assertIn('Failed to decrypt', str(context.exception)) + + def test_decrypt_invalid_ciphertext_empty_field(self): + """测试解密有空字段的密文""" + ciphertext_data = { + 'half_key': '', + 'encrypted_work_key': 'test', + 'work_key_iv': 'test', + 'plaintext_iv': 'test', + 'ciphertext': 'test' + } + + with self.assertRaises(DecryptError) as context: + self.cipher.decrypt_ciphertext_data(ciphertext_data) + + self.assertIn('Failed to decrypt', str(context.exception)) + + def test_decrypt_invalid_base64(self): + """测试解密无效的Base64编码""" + ciphertext_data = { + 'half_key': 'test', + 'encrypted_work_key': 'not-valid-base64!', + 'work_key_iv': 'test', + 'plaintext_iv': 'test', + 'ciphertext': 'test' + } + + # 这里应该会抛出binascii.Error或类似异常 + with self.assertRaises(Exception): + self.cipher.decrypt_ciphertext_data(ciphertext_data) + + def test_encrypt_different_inputs_produce_different_outputs(self): + """测试相同明文多次加密产生不同的输出(由于随机IV)""" + plaintext = "相同的明文" + + # 第一次加密 + ciphertext_data1 = self.cipher.encrypt_plaintext(plaintext) + + # 第二次加密 + ciphertext_data2 = self.cipher.encrypt_plaintext(plaintext) + + # 验证两次加密的结果不同(由于随机IV和half_key) + self.assertNotEqual(ciphertext_data1['ciphertext'], ciphertext_data2['ciphertext']) + self.assertNotEqual(ciphertext_data1['half_key'], ciphertext_data2['half_key']) + + def test_decrypt_error_class(self): + """测试DecryptError异常类""" + error_message = "测试错误消息" + error = DecryptError(error_message) + + self.assertEqual(str(error), error_message) + self.assertIsInstance(error, Exception) + + @patch('utils.cipher.secrets.token_bytes') + def test_generate_work_key(self, mock_token_bytes): + """测试工作密钥生成""" + # 模拟secrets.token_bytes返回固定值 + mock_token_bytes.side_effect = [ + b'work_key_32_bytes_123456789012', # work_key + b'work_key_iv_16_bytes' # work_key_iv + ] + + half_key = "test_half_key" + encrypted_work_key, work_key_iv, work_key = self.cipher._generate_work_key(half_key) + + # 验证返回类型 + self.assertIsInstance(encrypted_work_key, str) + self.assertIsInstance(work_key_iv, str) + self.assertIsInstance(work_key, bytes) + + # 验证Base64编码 + import base64 + try: + base64.b64decode(encrypted_work_key) + base64.b64decode(work_key_iv) + except Exception: + self.fail("返回的字符串不是有效的Base64编码") + + @patch('utils.cipher.b64decode') + @patch('utils.cipher.CustomCipher._generate_root_key') + @patch('utils.cipher.CustomCipher._decrypt') + def test_decrypt_work_key(self, mock_decrypt, mock_generate_root_key, mock_b64decode): + """测试工作密钥解密""" + # 设置模拟 + mock_generate_root_key.return_value = b'root_key_32_bytes' + mock_b64decode.side_effect = [ + b'encrypted_work_key_bytes', + b'work_key_iv_bytes' + ] + mock_decrypt.return_value = b'decrypted_work_key' + + half_key = "test_half_key" + encrypted_work_key = "encrypted_work_key_base64" + work_key_iv = "work_key_iv_base64" + + result = self.cipher._decrypt_work_key(half_key, encrypted_work_key, work_key_iv) + + # 验证调用 + mock_generate_root_key.assert_called_once_with(half_key) + self.assertEqual(mock_b64decode.call_count, 2) + mock_decrypt.assert_called_once_with( + b'root_key_32_bytes', + b'work_key_iv_bytes', + b'encrypted_work_key_bytes' + ) + + # 验证结果 + self.assertEqual(result, b'decrypted_work_key') + + def test_encrypt_plaintext_handles_json_serialization(self): + """测试encrypt_plaintext正确处理JSON序列化""" + # 测试非字符串/字典类型应该被JSON序列化 + test_cases = [ + "字符串", + {"key": "value"}, + ["列表", 123], + 123, # 数字 + True, # 布尔值 + None # null + ] + + for plaintext in test_cases: + try: + ciphertext_data = self.cipher.encrypt_plaintext(plaintext) + decrypted = self.cipher.decrypt_ciphertext_data(ciphertext_data) + + # 验证解密后的值与原始值相等(经过JSON序列化/反序列化) + self.assertEqual(decrypted, plaintext) + except Exception as e: + self.fail(f"加密解密失败,输入: {plaintext}, 错误: {e}") + + +if __name__ == '__main__': + unittest.main() + diff --git a/unittest/test_common.py b/unittest/test_common.py new file mode 100644 index 0000000..2e91822 --- /dev/null +++ b/unittest/test_common.py @@ -0,0 +1,221 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import unittest +import sys +import os +import time +from unittest.mock import patch, MagicMock +import tempfile +import stat + +# 添加 backend 目录到 Python 路径 +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'backend')) +from utils.common import is_process_running, validate_executable_file + + +class TestCommonUtils(unittest.TestCase): + + @patch('utils.common.psutil.process_iter') + def test_is_process_running_found(self, mock_process_iter): + """测试找到包含关键字的进程""" + # 模拟进程迭代器 + mock_proc1 = MagicMock() + mock_proc1.info = { + 'pid': 1234, + 'name': 'python.exe', + 'cmdline': ['python', 'test_script.py'], + 'create_time': time.time() - 100 # 100秒前创建 + } + + mock_proc2 = MagicMock() + mock_proc2.info = { + 'pid': 5678, + 'name': 'chrome.exe', + 'cmdline': ['chrome', '--test-mode'], + 'create_time': time.time() - 50 # 50秒前创建 + } + + mock_process_iter.return_value = [mock_proc1, mock_proc2] + + # 测试找到进程 + result = is_process_running('python') + self.assertTrue(result) + + result = is_process_running('chrome') + self.assertTrue(result) + + result = is_process_running('test') + self.assertTrue(result) # 在cmdline中找到'test' + + @patch('utils.common.psutil.process_iter') + def test_is_process_running_not_found(self, mock_process_iter): + """测试未找到包含关键字的进程""" + # 模拟进程迭代器 + mock_proc = MagicMock() + mock_proc.info = { + 'pid': 1234, + 'name': 'python.exe', + 'cmdline': ['python', 'other_script.py'], + 'create_time': time.time() - 100 + } + + mock_process_iter.return_value = [mock_proc] + + # 测试未找到进程 + result = is_process_running('java') + self.assertFalse(result) + + result = is_process_running('chrome') + self.assertFalse(result) + + @patch('utils.common.psutil.process_iter') + def test_is_process_running_timeout(self, mock_process_iter): + """测试进程运行超时被忽略""" + # 模拟进程迭代器,进程创建时间超过timeout + mock_proc = MagicMock() + mock_proc.info = { + 'pid': 1234, + 'name': 'python.exe', + 'cmdline': ['python', 'test_script.py'], + 'create_time': time.time() - 700 # 700秒前创建,超过默认600秒timeout + } + + mock_process_iter.return_value = [mock_proc] + + # 测试超时进程被忽略 + result = is_process_running('python') + self.assertFalse(result) + + # 测试自定义timeout + result = is_process_running('python', timeout=800) + self.assertTrue(result) # 800秒timeout,进程创建700秒,应该找到 + + @patch('utils.common.psutil.process_iter') + def test_is_process_running_exception_handling(self, mock_process_iter): + """测试进程迭代异常处理""" + # 模拟进程迭代器抛出异常 + mock_proc = MagicMock() + mock_proc.info.side_effect = Exception("Access denied") + + mock_process_iter.return_value = [mock_proc] + + # 测试异常被正确处理 + result = is_process_running('python') + self.assertFalse(result) # 异常被捕获,返回False + + @patch('utils.common.psutil.process_iter') + def test_is_process_running_case_insensitive(self, mock_process_iter): + """测试关键字大小写不敏感""" + # 模拟进程迭代器 + mock_proc = MagicMock() + mock_proc.info = { + 'pid': 1234, + 'name': 'PYTHON.EXE', # 大写 + 'cmdline': ['Python', 'TestScript.py'], # 混合大小写 + 'create_time': time.time() - 100 + } + + mock_process_iter.return_value = [mock_proc] + + # 测试大小写不敏感匹配 + result = is_process_running('python') + self.assertTrue(result) + + result = is_process_running('PYTHON') + self.assertTrue(result) + + result = is_process_running('testscript') + self.assertTrue(result) + + def test_validate_executable_file_success(self): + """测试验证有效的可执行文件""" + # 创建临时文件并设置可执行权限 + with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + file_path = f.name + # 在Windows上设置可执行权限 + if os.name == 'nt': + # Windows没有os.X_OK,我们只需要文件存在 + os.chmod(file_path, stat.S_IREAD | stat.S_IWRITE) + else: + os.chmod(file_path, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) + + try: + # 测试验证成功 + is_valid, error_msg = validate_executable_file(file_path) + self.assertTrue(is_valid) + self.assertEqual(error_msg, "") + finally: + # 清理临时文件 + if os.path.exists(file_path): + os.unlink(file_path) + + def test_validate_executable_file_empty_path(self): + """测试验证空文件路径""" + is_valid, error_msg = validate_executable_file("") + self.assertFalse(is_valid) + self.assertEqual(error_msg, "文件路径不能为空") + + is_valid, error_msg = validate_executable_file(None) + self.assertFalse(is_valid) + self.assertEqual(error_msg, "文件路径不能为空") + + def test_validate_executable_file_not_exist(self): + """测试验证不存在的文件""" + non_existent_file = "/tmp/nonexistent_file_123456789" + is_valid, error_msg = validate_executable_file(non_existent_file) + self.assertFalse(is_valid) + self.assertIn("文件不存在", error_msg) + + def test_validate_executable_file_not_a_file(self): + """测试验证目录而不是文件""" + # 创建临时目录 + with tempfile.TemporaryDirectory() as temp_dir: + is_valid, error_msg = validate_executable_file(temp_dir) + self.assertFalse(is_valid) + self.assertIn("路径不是文件", error_msg) + + @unittest.skipIf(os.name == 'nt', "跳过Windows上的权限测试") + def test_validate_executable_file_no_execute_permission(self): + """测试验证没有可执行权限的文件(非Windows系统)""" + # 创建临时文件但不设置可执行权限 + with tempfile.NamedTemporaryFile(mode='w', delete=False) as f: + file_path = f.name + # 只设置读写权限,没有执行权限 + os.chmod(file_path, stat.S_IREAD | stat.S_IWRITE) + + try: + # 测试验证失败 + is_valid, error_msg = validate_executable_file(file_path) + self.assertFalse(is_valid) + self.assertIn("文件没有可执行权限", error_msg) + finally: + # 清理临时文件 + if os.path.exists(file_path): + os.unlink(file_path) + + def test_validate_executable_file_relative_path(self): + """测试验证相对路径文件""" + # 创建临时文件 + with tempfile.NamedTemporaryFile(mode='w', dir='.', delete=False) as f: + file_name = os.path.basename(f.name) + # 在Windows上设置权限 + if os.name == 'nt': + os.chmod(f.name, stat.S_IREAD | stat.S_IWRITE) + else: + os.chmod(f.name, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC) + + try: + # 测试相对路径 + is_valid, error_msg = validate_executable_file(file_name) + self.assertTrue(is_valid) + self.assertEqual(error_msg, "") + finally: + # 清理临时文件 + if os.path.exists(file_name): + os.unlink(file_name) + + +if __name__ == '__main__': + unittest.main() + -- Gitee