"""One-shot MySQL-to-MySQL data migration tool.

Reads connection settings from a YAML config file (optionally overridden on
the command line), truncates every table in the target schema, then copies
all rows table-by-table from the source schema using a thread pool.

NOTE(review): the script moves rows only, not DDL -- it assumes the target
schema already contains the same table definitions (and column order) as the
source. Confirm before running against a fresh target.
"""

import argparse
import logging
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

import pymysql
import yaml

# Default path of the YAML configuration file.
DEFAULT_CONFIG_PATH = "config.yaml"


def load_config(config_path=DEFAULT_CONFIG_PATH):
    """Load and return the YAML configuration as a dict.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration dictionary.
    """
    with open(config_path, "r") as file:
        return yaml.safe_load(file)


def setup_logging(config_level, log_file='migration.log', file_mode='a'):
    """Configure the root logger with a console handler and a file handler.

    Args:
        config_level: Log level name, e.g. "DEBUG" or "INFO" (case-insensitive).
        log_file: Path of the log file.
        file_mode: Mode used to open the log file ('a' appends, 'w' truncates).

    Raises:
        ValueError: If ``config_level`` is not a valid logging level name.
    """
    numeric_level = getattr(logging, config_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {config_level}')

    logger = logging.getLogger()
    logger.setLevel(numeric_level)

    # Bug fix: the original appended handlers unconditionally, so calling
    # setup_logging twice duplicated every log line. Remove stale handlers
    # before installing fresh ones.
    for handler in list(logger.handlers):
        logger.removeHandler(handler)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    console_handler = logging.StreamHandler()
    console_handler.setLevel(numeric_level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    file_handler = logging.FileHandler(log_file, mode=file_mode)
    file_handler.setLevel(numeric_level)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)


def parse_args():
    """Parse command-line arguments for the migration script.

    Returns:
        argparse.Namespace with the config path and optional per-database
        connection overrides.
    """
    parser = argparse.ArgumentParser(description="Database Migration Script")
    # Bug fix: the default previously hard-coded "config.yaml" instead of
    # reusing the DEFAULT_CONFIG_PATH constant declared for exactly this.
    parser.add_argument('--config', type=str, default=DEFAULT_CONFIG_PATH,
                        help="Path to the configuration file")
    parser.add_argument('--source_host', type=str, help="Source database host")
    parser.add_argument('--source_port', type=int, help="Source database port")
    parser.add_argument('--source_user', type=str, help="Source database user")
    parser.add_argument('--source_password', type=str, help="Source database password")
    parser.add_argument('--source_db', type=str, help="Source database name")
    parser.add_argument('--target_host', type=str, help="Target database host")
    parser.add_argument('--target_port', type=int, help="Target database port")
    parser.add_argument('--target_user', type=str, help="Target database user")
    parser.add_argument('--target_password', type=str, help="Target database password")
    parser.add_argument('--target_db', type=str, help="Target database name")
    return parser.parse_args()


class DatabaseConnection:
    """Context manager that opens a pymysql connection and closes it on exit.

    Usage::

        with DatabaseConnection(config_dict) as conn:
            ...
    """

    def __init__(self, db_config):
        # db_config is passed straight to pymysql.connect(**db_config).
        self.db_config = db_config
        self.connection = None

    def __enter__(self):
        self.connection = pymysql.connect(**self.db_config)
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Close unconditionally; exceptions propagate to the caller.
        if self.connection:
            self.connection.close()


class DBMigrator:
    """Copies all rows of every table from a source MySQL DB to a target DB."""

    def __init__(self, source_config, target_config, batch_size):
        """
        Args:
            source_config: pymysql connection kwargs for the source database.
            target_config: pymysql connection kwargs for the target database.
            batch_size: Number of rows fetched/inserted per batch.
        """
        self.source_config = source_config
        self.target_config = target_config
        self.batch_size = batch_size

    def clear_target_database(self):
        """Truncate every table in the target database.

        Returns:
            True on success, False if any error occurred (already logged).
        """
        try:
            with DatabaseConnection(self.target_config) as target_db:
                with target_db.cursor() as cursor:
                    cursor.execute("SHOW TABLES;")
                    tables = cursor.fetchall()
                    # Robustness fix: TRUNCATE fails on tables referenced by a
                    # foreign key, so disable FK checks for the duration.
                    cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                    try:
                        for table in tables:
                            # Backtick-quote the identifier so reserved-word
                            # table names don't break the statement. Names come
                            # from SHOW TABLES, not user input.
                            cursor.execute(f"TRUNCATE TABLE `{table[0]}`;")
                    finally:
                        cursor.execute("SET FOREIGN_KEY_CHECKS = 1;")
                target_db.commit()
            return True
        except Exception as e:
            logging.error(f"Error clearing target database: {e}")
            return False

    def migrate_table_data(self, table):
        """Copy all rows of one table from source to target, in batches.

        Args:
            table: Table name (must exist in both databases with the same
                column order -- the INSERT omits an explicit column list).

        Returns:
            True if the row counts match after the copy, False otherwise.
        """
        try:
            with DatabaseConnection(self.source_config) as source_db, DatabaseConnection(
                    self.target_config) as target_db:
                with source_db.cursor() as source_cursor, target_db.cursor() as target_cursor:
                    source_cursor.execute(f"SELECT COUNT(*) FROM `{table}`;")
                    source_count = source_cursor.fetchone()[0]
                    logging.info(f"Table {table}: {source_count} rows to migrate.")

                    source_cursor.execute(f"SELECT * FROM `{table}`;")
                    rows_fetched = 0
                    while True:
                        # Stream the result set in batches to bound memory use.
                        rows = source_cursor.fetchmany(self.batch_size)
                        if not rows:
                            break
                        placeholders = ', '.join(['%s'] * len(rows[0]))
                        sql = f"INSERT INTO `{table}` VALUES ({placeholders})"
                        target_cursor.executemany(sql, rows)
                        rows_fetched += len(rows)
                        logging.debug(f"Table {table}: {rows_fetched}/{source_count} rows migrated.")
                        # Commit per batch so a failure does not roll back
                        # the whole table's worth of work.
                        target_db.commit()

                    # Verify the copy by comparing row counts.
                    target_cursor.execute(f"SELECT COUNT(*) FROM `{table}`;")
                    target_count = target_cursor.fetchone()[0]
                    if source_count == target_count:
                        logging.info(f"Table {table} migrated successfully. Total rows: {target_count}")
                        return True
                    else:
                        logging.error(
                            f"Table {table} migration failed. Source rows: {source_count}, Target rows: {target_count}")
                        return False
        except Exception as e:
            logging.error(f"Error migrating table {table}: {e}")
            return False

    def get_tables(self) -> List[str]:
        """Return the names of all tables in the source database.

        Returns:
            List of table names; empty on error (already logged).
        """
        try:
            with DatabaseConnection(self.source_config) as db:
                with db.cursor() as cursor:
                    cursor.execute("SHOW TABLES;")
                    return [table[0] for table in cursor.fetchall()]
        except Exception as e:
            logging.error(f"Error fetching table list: {e}")
            return []

    def migrate(self, concurrency):
        """Clear the target database, then migrate every table concurrently.

        Args:
            concurrency: Maximum number of tables migrated in parallel.
        """
        # Bug fix: the original ignored the boolean returned by
        # clear_target_database and migrated into an un-cleared target,
        # which would double rows on INSERT. Abort instead.
        if not self.clear_target_database():
            logging.error("Aborting migration: failed to clear target database.")
            return
        tables = self.get_tables()
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
            for future in as_completed(futures):
                table = futures[future]
                try:
                    result = future.result()
                    if result:
                        logging.info(f"Table {table} migration succeeded.")
                    else:
                        logging.error(f"Table {table} migration failed.")
                except Exception as e:
                    logging.error(f"Error migrating table {table}: {e}")


def main():
    """Entry point: load config, set up logging, run and time the migration."""
    args = parse_args()
    config = load_config(args.config)
    setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')
    # Command-line overrides are applied only when BOTH hosts are given;
    # otherwise the YAML config is used wholesale.
    if args.source_host and args.target_host:
        source_config = {
            "host": args.source_host,
            "port": args.source_port,
            "user": args.source_user,
            "password": args.source_password,
            # pymysql accepts "db" as an alias for "database".
            "db": args.source_db,
        }
        target_config = {
            "host": args.target_host,
            "port": args.target_port,
            "user": args.target_user,
            "password": args.target_password,
            "db": args.target_db,
        }
    else:
        source_config = config["databases"]["source_db"]
        target_config = config["databases"]["target_db"]

    migrator = DBMigrator(source_config, target_config, config["batch_size"])
    start_time = time.time()
    migrator.migrate(config["concurrency"])
    end_time = time.time()
    duration = end_time - start_time

    print(f"Script completed in {duration:.2f} seconds.")
    logging.info(f"Script completed in {duration:.2f} seconds.")
    print(f"Used {config['concurrency']} threads and batch size of {config['batch_size']}.")


if __name__ == "__main__":
    main()