import argparse
import logging
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List

import pymysql
import pymysql.cursors
import yaml

# Default configuration file path, used when --config is not given.
DEFAULT_CONFIG_PATH = "config.yaml"

# A BIT(1) column whose DEFAULT is a quoted character literal ('0' or '1'),
# optionally with NULL / NOT NULL before DEFAULT. The quoted form is a
# character string rather than a bit value and can be rejected as an invalid
# default when the DDL is replayed on the target, so it is rewritten to the
# bit-value literal b'0' / b'1'.
_BIT1_CHAR_DEFAULT_RE = re.compile(
    r"(\bbit\(1\)(?:\s+(?:NOT\s+)?NULL)?\s+DEFAULT\s+)'([01])'",
    re.IGNORECASE,
)


def load_config(config_path=DEFAULT_CONFIG_PATH):
    """Read the YAML configuration file and return its parsed contents.

    The file is decoded as UTF-8 explicitly, so a config containing non-ASCII
    comments does not depend on the platform's default encoding. It is parsed
    with ``yaml.safe_load``, which does not construct arbitrary Python objects.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def setup_logging(config_level, log_file='migration.log', file_mode='a'):
    """Attach a console handler and a file handler to the root logger.

    Args:
        config_level: Level name such as "info" or "DEBUG" (case-insensitive).
        log_file: Path of the log file.
        file_mode: Mode used to open the log file ('a' appends, 'w' truncates).

    Raises:
        ValueError: If *config_level* is not a known logging level name.
    """
    numeric_level = getattr(logging, config_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {config_level}')

    logger = logging.getLogger()
    logger.setLevel(numeric_level)

    formatter = logging.Formatter(
        '%(asctime)s [Thread ID:%(thread)d] [%(threadName)s] %(levelname)s: %(message)s')

    console_handler = logging.StreamHandler()
    # Table names may contain non-ASCII characters; write the log file as UTF-8.
    file_handler = logging.FileHandler(log_file, mode=file_mode, encoding='utf-8')
    for handler in (console_handler, file_handler):
        handler.setLevel(numeric_level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)


def parse_args():
    """Parse command-line options; database options override the config file."""
    parser = argparse.ArgumentParser(description="Database Migration Script")
    parser.add_argument('--config', type=str, default=DEFAULT_CONFIG_PATH,
                        help="Path to the configuration file")
    parser.add_argument('--source_host', type=str, help="Source database host")
    parser.add_argument('--source_port', type=int, help="Source database port")
    parser.add_argument('--source_user', type=str, help="Source database user")
    parser.add_argument('--source_password', type=str, help="Source database password")
    parser.add_argument('--source_db', type=str, help="Source database name")
    parser.add_argument('--target_host', type=str, help="Target database host")
    parser.add_argument('--target_port', type=int, help="Target database port")
    parser.add_argument('--target_user', type=str, help="Target database user")
    parser.add_argument('--target_password', type=str, help="Target database password")
    parser.add_argument('--target_db', type=str, help="Target database name")
    return parser.parse_args()


def _quote_ident(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled. This lets table names that are reserved
    words, or that contain spaces or other special characters, reach the
    server verbatim instead of breaking the statement.
    """
    return "`" + str(name).replace("`", "``") + "`"


class DatabaseConnection:
    """Context manager that opens a PyMySQL connection and always closes it.

    ``db_config`` is passed unchanged as keyword arguments to
    ``pymysql.connect``. Exceptions raised inside the ``with`` body propagate.
    """

    def __init__(self, db_config):
        self.db_config = db_config
        self.connection = None

    def __enter__(self):
        self.connection = pymysql.connect(**self.db_config)
        return self.connection

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.connection:
            self.connection.close()


class DBMigrator:
    """Copy every base table (structure, then data) from a source MySQL
    database into a target database, after emptying the target.

    Views are not migrated (see ``get_tables``).
    """

    def __init__(self, source_config, target_config, batch_size):
        self.source_config = source_config
        self.target_config = target_config
        # Number of rows read from the source and inserted per executemany call.
        self.batch_size = batch_size

    def clear_target_database(self):
        """Drop every table and view in the target database.

        Returns:
            True on success; False (after logging the error) on any failure.
        """
        try:
            with DatabaseConnection(self.target_config) as target_db:
                with target_db.cursor() as cursor:
                    # Disable FK checks so tables can be dropped in any order.
                    cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                    cursor.execute("SHOW FULL TABLES;")
                    for name, table_type in cursor.fetchall():
                        # DROP TABLE does not drop views, so choose the statement by type.
                        kind = "VIEW" if table_type == "VIEW" else "TABLE"
                        cursor.execute(f"DROP {kind} IF EXISTS {_quote_ident(name)};")
                    cursor.execute("SET FOREIGN_KEY_CHECKS = 1;")
                target_db.commit()
            return True
        except Exception as e:
            logging.exception(f"Error clearing target database: {e}")
            return False

    def copy_table_structure(self, table):
        """Recreate *table* in the target from the source's SHOW CREATE TABLE.

        Returns:
            True on success; False (after logging the error) on any failure.
        """
        try:
            with DatabaseConnection(self.source_config) as source_db, DatabaseConnection(
                    self.target_config) as target_db:
                with source_db.cursor() as source_cursor, target_db.cursor() as target_cursor:
                    source_cursor.execute(f"SHOW CREATE TABLE {_quote_ident(table)};")
                    create_table_sql = source_cursor.fetchone()[1]
                    # Rewrite BIT(1) defaults written as character literals; this
                    # is a textual substitution on the DDL, not a full SQL parse.
                    corrected_sql = self.correct_bit_default_value(create_table_sql)
                    # Tables are created one at a time, so a FOREIGN KEY may point
                    # at a table that does not exist yet. Disabling FK checks for
                    # this session lets MySQL accept the forward reference.
                    target_cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                    target_cursor.execute(corrected_sql)
                target_db.commit()
            return True
        except Exception as e:
            logging.exception(f"Error copying table structure for {table}: {e}")
            return False

    def correct_bit_default_value(self, sql):
        """Return *sql* with BIT(1) character-literal defaults rewritten as bit literals.

        Rewrites ``bit(1) DEFAULT '0'`` to ``bit(1) DEFAULT b'0'``. The same
        applies to ``'1'`` and to an optional NULL / NOT NULL before DEFAULT.
        Defaults already written as ``b'...'`` are left untouched.
        """
        return _BIT1_CHAR_DEFAULT_RE.sub(r"\g<1>b'\g<2>'", sql)

    def migrate_table_data(self, table):
        """Copy all rows of *table* from source to target in batches, then verify counts.

        Returns:
            True when the target row count equals the source row count taken
            before the copy; False on a mismatch or any error (which is logged).
        """
        quoted = _quote_ident(table)
        try:
            with DatabaseConnection(self.source_config) as source_db, DatabaseConnection(
                    self.target_config) as target_db:
                with source_db.cursor() as count_cursor:
                    count_cursor.execute(f"SELECT COUNT(*) FROM {quoted};")
                    source_count = count_cursor.fetchone()[0]
                logging.info(f"Table {table}: {source_count} rows to migrate.")

                # SSCursor streams rows from the server. The default cursor would
                # buffer the whole result set in memory, so fetchmany would not
                # actually page through the table.
                with source_db.cursor(pymysql.cursors.SSCursor) as source_cursor, \
                        target_db.cursor() as target_cursor:
                    # Tables are loaded concurrently in no particular order, so a
                    # child row may be inserted before its parent row exists.
                    # Disable FK checks for this session only.
                    target_cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                    source_cursor.execute(f"SELECT * FROM {quoted};")
                    placeholders = ', '.join(['%s'] * len(source_cursor.description))
                    sql = f"INSERT INTO {quoted} VALUES ({placeholders})"

                    rows_fetched = 0
                    while True:
                        rows = source_cursor.fetchmany(self.batch_size)  # read one batch
                        if not rows:
                            break
                        target_cursor.executemany(sql, rows)
                        rows_fetched += len(rows)
                        logging.debug(f"Table {table}: {rows_fetched}/{source_count} rows migrated.")
                    target_db.commit()

                    # Verify the row counts match.
                    target_cursor.execute(f"SELECT COUNT(*) FROM {quoted};")
                    target_count = target_cursor.fetchone()[0]

            if source_count == target_count:
                logging.info(f"Table {table} migrated successfully. Total rows: {target_count}")
                return True
            logging.error(
                f"Table {table} migration failed. Source rows: {source_count}, Target rows: {target_count}")
            return False
        except Exception as e:
            logging.exception(f"Error migrating table {table}: {e}")
            return False

    def _list_base_tables(self) -> List[str]:
        """Return the source database's base-table names; raises on database errors."""
        with DatabaseConnection(self.source_config) as db:
            with db.cursor() as cursor:
                cursor.execute("SHOW FULL TABLES WHERE Table_type = 'BASE TABLE';")
                return [row[0] for row in cursor.fetchall()]

    def get_tables(self) -> List[str]:
        """Return the names of the base tables in the source database.

        Views are excluded. Copying a view's rows with INSERT could write them
        into the view's underlying table a second time. Returns an empty list
        (after logging) if the table list cannot be read.
        """
        try:
            return self._list_base_tables()
        except Exception as e:
            logging.error(f"Error fetching table list: {e}")
            return []

    def migrate(self, concurrency):
        """Run the full migration: list tables, clear target, copy structure, copy data.

        The source table list is read BEFORE the target is cleared, so an
        unreachable source never leaves the target wiped. Data for different
        tables is copied in parallel by up to *concurrency* worker threads.

        Returns:
            True if every table was migrated and verified; False otherwise.
        """
        try:
            tables = self._list_base_tables()
        except Exception as e:
            logging.error(f"Error fetching table list: {e}. Migration aborted.")
            return False
        if not tables:
            logging.warning("No base tables found in source database; nothing to migrate.")

        if not self.clear_target_database():
            logging.error("Failed to clear target database. Migration aborted.")
            return False

        # Copy all table structures before any data.
        for table in tables:
            if not self.copy_table_structure(table):
                logging.error(f"Failed to copy structure for table {table}. Migration aborted.")
                return False

        # Copy table data concurrently using a thread pool.
        failed = []
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
            for future in as_completed(futures):
                table = futures[future]
                try:
                    result = future.result()
                except Exception as e:
                    logging.error(f"Error migrating table {table}: {e}")
                    failed.append(table)
                    continue
                if result:
                    logging.info(f"Table {table} migration succeeded.")
                else:
                    logging.error(f"Table {table} migration failed.")
                    failed.append(table)

        if failed:
            logging.error(f"{len(failed)} table(s) failed to migrate: {', '.join(sorted(failed))}")
            return False
        return True


def main():
    """Script entry point. Returns True if the migration fully succeeded."""
    args = parse_args()
    config = load_config(args.config)
    setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')

    # When both hosts are given on the command line, the connection settings
    # come from the CLI; otherwise they come from the config file.
    if args.source_host and args.target_host:
        source_config = {
            "host": args.source_host,
            "port": args.source_port,
            "user": args.source_user,
            "password": args.source_password,
            "db": args.source_db,
        }
        target_config = {
            "host": args.target_host,
            "port": args.target_port,
            "user": args.target_user,
            "password": args.target_password,
            "db": args.target_db,
        }
    else:
        source_config = config["databases"]["source_db"]
        target_config = config["databases"]["target_db"]

    migrator = DBMigrator(source_config, target_config, config["batch_size"])

    start_time = time.perf_counter()
    ok = migrator.migrate(config["concurrency"])
    duration = time.perf_counter() - start_time  # wall-clock runtime of the migration

    print(f"Script completed in {duration:.2f} seconds.")
    logging.info(f"Script completed in {duration:.2f} seconds.")
    print(f"Used {config['concurrency']} threads and batch size of {config['batch_size']}.")
    return ok


if __name__ == "__main__":
    # Non-zero exit status on failure so shells and schedulers can detect it.
    raise SystemExit(0 if main() else 1)