# Listing metadata: 242 lines, 9.5 KiB, Python
import time
|
|
|
|
import yaml
|
|
import pymysql
|
|
import argparse
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
from typing import List
|
|
import logging
|
|
import re
|
|
|
|
# Default configuration file path
DEFAULT_CONFIG_PATH = "config.yaml"
|
|
|
|
|
|
# Read the configuration file
|
|
def load_config(config_path="config.yaml"):
    """Read a YAML configuration file and return its parsed content.

    Args:
        config_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the document.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding so
    # the same file parses identically on every machine.
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
|
|
|
|
|
|
# Configure logging
|
|
def setup_logging(config_level, log_file='migration.log', file_mode='a'):
    """Attach a console handler and a file handler to the root logger.

    Calling this function again replaces the two handlers it installed
    earlier instead of stacking new ones, so messages are never emitted
    twice. Handlers installed by other code are left untouched.

    Args:
        config_level: Name of a ``logging`` level (case-insensitive), e.g.
            ``"info"`` or ``"DEBUG"``.
        log_file: Path of the log file.
        file_mode: Mode used to open the log file (``'a'`` appends).

    Raises:
        ValueError: If ``config_level`` does not name a logging level.
    """
    numeric_level = getattr(logging, config_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {config_level}')

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Build the new handlers first: if the log file cannot be opened the
    # exception propagates before the existing configuration is touched.
    new_handlers = [
        logging.StreamHandler(),
        # Explicit UTF-8 so non-ASCII text in messages is written reliably
        # regardless of the platform's locale encoding.
        logging.FileHandler(log_file, mode=file_mode, encoding='utf-8'),
    ]
    for handler in new_handlers:
        handler.setLevel(numeric_level)
        handler.setFormatter(formatter)
        # Marker used to recognise (and replace) our own handlers later.
        handler._installed_by_setup_logging = True

    logger = logging.getLogger()
    logger.setLevel(numeric_level)

    previous = [h for h in logger.handlers
                if getattr(h, '_installed_by_setup_logging', False)]
    for old_handler in previous:
        logger.removeHandler(old_handler)
        old_handler.close()

    for handler in new_handlers:
        logger.addHandler(handler)
|
|
|
|
|
|
# Parse command-line arguments
|
|
def parse_args():
    """Define the script's command-line options and parse ``sys.argv``.

    Besides ``--config``, one option per connection field is registered for
    both the source and the target database (``--source_host``,
    ``--target_port``, ...). Options that are not supplied parse to ``None``.

    Returns:
        The ``argparse.Namespace`` holding the parsed values.
    """
    parser = argparse.ArgumentParser(description="Database Migration Script")
    parser.add_argument('--config', type=str, default="config.yaml", help="Path to the configuration file")

    # (option suffix, value type, wording used in the help text)
    connection_fields = (
        ("host", str, "host"),
        ("port", int, "port"),
        ("user", str, "user"),
        ("password", str, "password"),
        ("db", str, "name"),
    )
    for side in ("source", "target"):
        for suffix, value_type, label in connection_fields:
            parser.add_argument(
                f"--{side}_{suffix}",
                type=value_type,
                help=f"{side.capitalize()} database {label}",
            )

    return parser.parse_args()
|
|
|
|
|
|
# Database connection context manager
|
|
class DatabaseConnection:
    """Context manager that opens a pymysql connection on entry and closes it on exit."""

    def __init__(self, db_config):
        """Store the connection settings; no connection is opened yet.

        Args:
            db_config: Mapping of keyword arguments for ``pymysql.connect``.
        """
        self.connection = None
        self.db_config = db_config

    def __enter__(self):
        """Open the connection and return it to the ``with`` block."""
        opened = pymysql.connect(**self.db_config)
        self.connection = opened
        return opened

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection if there is one; exceptions are not suppressed."""
        if not self.connection:
            return
        self.connection.close()
|
|
|
|
|
|
# Database operations class
|
|
class DBMigrator:
    """Copy every table (structure first, then data) from a source database to a target one.

    Each operation opens its own connections through ``DatabaseConnection``,
    so the per-table data copies can run in separate threads.
    """

    # A BIT(1) column whose default is written as a quoted '0' / '1',
    # optionally preceded by NULL / NOT NULL.
    _BIT_DEFAULT_RE = re.compile(r"(\bbit\(1\)(?: NOT NULL| NULL)? DEFAULT )'([01])'")

    def __init__(self, source_config, target_config, batch_size):
        """Store the connection settings and the batch size.

        Args:
            source_config: Connection keyword arguments for the source database.
            target_config: Connection keyword arguments for the target database.
            batch_size: Number of rows fetched and inserted per batch.
        """
        self.source_config = source_config
        self.target_config = target_config
        self.batch_size = batch_size

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a backtick-quoted MySQL identifier.

        Embedded backticks are doubled, so names that are reserved words or
        contain special characters can be used safely in statements.
        """
        return "`" + str(name).replace("`", "``") + "`"

    def clear_target_database(self):
        """Drop every table listed by ``SHOW TABLES`` in the target database.

        Returns:
            True if all tables were dropped, False if an error occurred (the
            error is logged).
        """
        try:
            with DatabaseConnection(self.target_config) as target_db:
                with target_db.cursor() as cursor:
                    # With foreign-key checks off, tables can be dropped in any order.
                    cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                    cursor.execute("SHOW TABLES;")
                    for row in cursor.fetchall():
                        cursor.execute(f"DROP TABLE IF EXISTS {self._quote_identifier(row[0])};")
                    cursor.execute("SET FOREIGN_KEY_CHECKS = 1;")
                    target_db.commit()
            return True
        except Exception as e:
            logging.error(f"Error clearing target database: {e}")
            return False

    def copy_table_structure(self, table):
        """Recreate *table* in the target database from the source's ``SHOW CREATE TABLE``.

        Args:
            table: Name of the table to copy.

        Returns:
            True on success, False if an error occurred (the error is logged).
        """
        try:
            with DatabaseConnection(self.source_config) as source_db:
                with DatabaseConnection(self.target_config) as target_db:
                    with source_db.cursor() as source_cursor, target_db.cursor() as target_cursor:
                        source_cursor.execute(f"SHOW CREATE TABLE {self._quote_identifier(table)};")
                        create_table_sql = source_cursor.fetchone()[1]

                        # Fix the default-value notation of BIT(1) columns.
                        corrected_sql = self.correct_bit_default_value(create_table_sql)

                        # Tables are created in listing order, so a table may
                        # reference one that does not exist yet. The setting is
                        # per session and this connection is closed right after.
                        target_cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                        target_cursor.execute(corrected_sql)
                        target_db.commit()
            return True
        except Exception as e:
            logging.error(f"Error copying table structure for {table}: {e}")
            return False

    def correct_bit_default_value(self, sql):
        """Rewrite quoted BIT(1) defaults as bit literals.

        ``bit(1) DEFAULT '0'`` becomes ``bit(1) DEFAULT b'0'``; the same is
        done for ``'1'`` and for columns declared ``NULL`` / ``NOT NULL``.
        Everything else in *sql* is returned unchanged.

        Args:
            sql: A ``CREATE TABLE`` statement.

        Returns:
            The statement with the BIT(1) defaults rewritten.
        """
        return self._BIT_DEFAULT_RE.sub(r"\1b'\2'", sql)

    def migrate_table_data(self, table):
        """Copy all rows of *table* from the source to the target in batches.

        After copying, the target row count is compared with the source row
        count taken before the copy started.

        Args:
            table: Name of the table to copy.

        Returns:
            True if the row counts match, False on a mismatch or on any error
            (both are logged).
        """
        quoted_table = self._quote_identifier(table)
        try:
            with DatabaseConnection(self.source_config) as source_db:
                with DatabaseConnection(self.target_config) as target_db:
                    # Count with an ordinary cursor; the streaming cursor below
                    # must be the only active query on the source connection.
                    with source_db.cursor() as count_cursor:
                        count_cursor.execute(f"SELECT COUNT(*) FROM {quoted_table};")
                        source_count = count_cursor.fetchone()[0]
                    logging.info(f"Table {table}: {source_count} rows to migrate.")

                    # SSCursor is unbuffered: rows are pulled from the server
                    # batch by batch instead of loading the whole table into
                    # memory before the first fetchmany().
                    with source_db.cursor(pymysql.cursors.SSCursor) as source_cursor, \
                            target_db.cursor() as target_cursor:
                        # Tables are filled concurrently and in no particular
                        # order, so referenced rows may not be present yet.
                        target_cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")

                        source_cursor.execute(f"SELECT * FROM {quoted_table};")
                        insert_sql = None
                        rows_fetched = 0
                        while True:
                            rows = source_cursor.fetchmany(self.batch_size)
                            if not rows:
                                break
                            if insert_sql is None:
                                # The column count is fixed for the whole table,
                                # so the statement is built only once.
                                placeholders = ', '.join(['%s'] * len(rows[0]))
                                insert_sql = f"INSERT INTO {quoted_table} VALUES ({placeholders})"
                            target_cursor.executemany(insert_sql, rows)
                            rows_fetched += len(rows)
                            logging.debug(f"Table {table}: {rows_fetched}/{source_count} rows migrated.")
                            target_db.commit()

                        # Verify that the row counts match.
                        target_cursor.execute(f"SELECT COUNT(*) FROM {quoted_table};")
                        target_count = target_cursor.fetchone()[0]

            if source_count == target_count:
                logging.info(f"Table {table} migrated successfully. Total rows: {target_count}")
                return True
            logging.error(
                f"Table {table} migration failed. Source rows: {source_count}, Target rows: {target_count}")
            return False
        except Exception as e:
            logging.error(f"Error migrating table {table}: {e}")
            return False

    def get_tables(self) -> List[str]:
        """Return the names listed by ``SHOW TABLES`` in the source database.

        Returns:
            The list of names, or an empty list if an error occurred (the
            error is logged).
        """
        # NOTE(review): SHOW TABLES presumably lists views as well; confirm
        # whether the source schema contains any.
        try:
            with DatabaseConnection(self.source_config) as db:
                with db.cursor() as cursor:
                    cursor.execute("SHOW TABLES;")
                    return [row[0] for row in cursor.fetchall()]
        except Exception as e:
            logging.error(f"Error fetching table list: {e}")
            return []

    def migrate(self, concurrency):
        """Run the whole migration: clear the target, copy structures, copy data.

        The run stops early (with an error logged) if the target cannot be
        cleared or if any table structure cannot be copied. Data is copied
        with one task per table on a thread pool.

        Args:
            concurrency: Maximum number of worker threads.
        """
        # Continuing on a target that still holds old tables would make the
        # CREATE TABLE statements fail or mix old and new data.
        if not self.clear_target_database():
            logging.error("Failed to clear target database. Migration aborted.")
            return

        tables = self.get_tables()
        if not tables:
            logging.warning("No tables found in source database. Nothing to migrate.")
            return

        # Copy all table structures before any data is moved.
        for table in tables:
            if not self.copy_table_structure(table):
                logging.error(f"Failed to copy structure for table {table}. Migration aborted.")
                return

        # Migrate table data concurrently with a thread pool.
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
            for future in as_completed(futures):
                table = futures[future]
                try:
                    if future.result():
                        logging.info(f"Table {table} migration succeeded.")
                    else:
                        logging.error(f"Table {table} migration failed.")
                except Exception as e:
                    logging.error(f"Error migrating table {table}: {e}")
|
|
|
|
|
|
def _apply_cli_overrides(base_config, overrides):
    """Return a copy of *base_config* with every non-None override applied."""
    merged = dict(base_config or {})
    for key, value in overrides.items():
        if value is None:
            continue
        # If the file names the schema "database", override that key rather
        # than adding a second, competing "db" entry.
        if key == "db" and "database" in merged:
            key = "database"
        merged[key] = value
    return merged


def main():
    """Entry point: load the configuration, run the migration and report the duration.

    Connection settings come from the ``databases`` section of the
    configuration file. Any connection option given on the command line
    overrides the corresponding value, field by field; options that are not
    given keep the value from the file.
    """
    args = parse_args()
    config = load_config(args.config)
    setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')

    # Apply command-line overrides on top of the file configuration. Each
    # option is applied independently, so a single option such as
    # --source_host is honoured on its own.
    databases = config.get("databases") or {}
    source_config = _apply_cli_overrides(databases.get("source_db"), {
        "host": args.source_host,
        "port": args.source_port,
        "user": args.source_user,
        "password": args.source_password,
        "db": args.source_db,
    })
    target_config = _apply_cli_overrides(databases.get("target_db"), {
        "host": args.target_host,
        "port": args.target_port,
        "user": args.target_user,
        "password": args.target_password,
        "db": args.target_db,
    })

    migrator = DBMigrator(source_config, target_config, config["batch_size"])
    start_time = time.time()
    migrator.migrate(config["concurrency"])
    duration = time.time() - start_time  # Elapsed wall-clock time of the run

    # Report the elapsed time.
    print(f"Script completed in {duration:.2f} seconds.")
    logging.info(f"Script completed in {duration:.2f} seconds.")
    print(f"Used {config['concurrency']} threads and batch size of {config['batch_size']}.")
|
|
|
|
|
|
# Run the migration only when this file is executed as a script.
if __name__ == "__main__":
    main()
|