Compare commits

...

7 Commits

3 changed files with 83 additions and 19 deletions

View File

@@ -1,12 +1,12 @@
 databases:
   source_db:
-    host: '192.168.31.204'
-    port: 3306
+    host: '82.156.30.246'
+    port: 3307
     user: 'root'
     password: '123456'
     database: 'allocative/test_master'
   target_db:
-    host: '192.168.31.203'
+    host: '82.156.30.246'
     port: 3306
     user: 'root'
     password: '123456'
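Both endpoints now point at the same host (82.156.30.246), with the source on port 3307 and the target on 3306. For reference, this databases block is what load_config returns via yaml.safe_load; a minimal sketch of reading it and opening the source connection follows (the pymysql.connect call is an assumption for illustration, the repository's own DatabaseConnection wrapper is not shown in this compare):

import yaml
import pymysql  # assumed MySQL driver; the script's DatabaseConnection wrapper may differ

def load_db_configs(path="config.yaml"):
    # Parse the YAML shown above and return the two connection configs
    with open(path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    return config["databases"]["source_db"], config["databases"]["target_db"]

source_cfg, target_cfg = load_db_configs()
conn = pymysql.connect(
    host=source_cfg["host"],          # '82.156.30.246'
    port=source_cfg["port"],          # 3307
    user=source_cfg["user"],
    password=source_cfg["password"],
    database=source_cfg["database"],
)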

View File

@@ -6,6 +6,7 @@ import argparse
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import List
 import logging
+import re

 # Default config file path
 DEFAULT_CONFIG_PATH = "config.yaml"
@@ -17,32 +18,27 @@ def load_config(config_path="config.yaml"):
         config = yaml.safe_load(file)
     return config

 # Configure logging
 def setup_logging(config_level, log_file='migration.log', file_mode='a'):
     numeric_level = getattr(logging, config_level.upper(), None)
     if not isinstance(numeric_level, int):
         raise ValueError(f'Invalid log level: {config_level}')
-
-    # Create a logger
+    # Logger
     logger = logging.getLogger()
     logger.setLevel(numeric_level)
-
-    # Create a console log handler and set its level
+    # Console log handler
     ch = logging.StreamHandler()
     ch.setLevel(numeric_level)
-
-    # Create a file log handler and set its level
+    # File log handler
     fh = logging.FileHandler(log_file, mode=file_mode)
     fh.setLevel(numeric_level)
-
-    # Create a log formatter
-    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
-    # Attach the formatter to the handlers
+    # Log formatter
+    formatter = logging.Formatter('%(asctime)s [Thread ID:%(thread)d] [%(threadName)s] %(levelname)s: %(message)s')
+    # Attach the formatter
     ch.setFormatter(formatter)
     fh.setFormatter(formatter)
-
-    # Attach the handlers to the logger
+    # Attach the handlers
     logger.addHandler(ch)
     logger.addHandler(fh)
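The reworked formatter stamps every record with the emitting thread's id and name, which is what makes per-table log lines attributable once migration runs in a thread pool (see the migrate hunk further down). A minimal, self-contained sketch of the same handler/formatter wiring, with placeholder level and file name:

import logging
import threading

def setup_logging_demo(level="INFO", log_file="migration.log"):
    # Root logger with one console and one file handler, both using the thread-aware format
    logger = logging.getLogger()
    logger.setLevel(getattr(logging, level.upper()))
    formatter = logging.Formatter(
        '%(asctime)s [Thread ID:%(thread)d] [%(threadName)s] %(levelname)s: %(message)s'
    )
    for handler in (logging.StreamHandler(), logging.FileHandler(log_file, mode='a')):
        handler.setLevel(logger.level)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

setup_logging_demo()
# A record emitted from a worker thread carries that thread's id and name
worker = threading.Thread(target=lambda: logging.info("copying table t_user"),
                          name="migrate-worker-1")
worker.start()
worker.join()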
@@ -87,20 +83,45 @@ class DBMigrator:
         self.batch_size = batch_size

     def clear_target_database(self):
-        # Truncate all tables in the target database
+        # Drop all tables in the target database
         try:
             with DatabaseConnection(self.target_config) as target_db:
                 with target_db.cursor() as cursor:
+                    cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                     cursor.execute("SHOW TABLES;")
                     tables = cursor.fetchall()
                     for table in tables:
-                        cursor.execute(f"TRUNCATE TABLE {table[0]};")
+                        cursor.execute(f"DROP TABLE IF EXISTS {table[0]};")
+                    cursor.execute("SET FOREIGN_KEY_CHECKS = 1;")
                 target_db.commit()
             return True
         except Exception as e:
             logging.error(f"Error clearing target database: {e}")
             return False

+    def copy_table_structure(self, table):
+        # Copy the table structure from source to target
+        try:
+            with DatabaseConnection(self.source_config) as source_db, DatabaseConnection(
+                    self.target_config) as target_db:
+                with source_db.cursor() as source_cursor, target_db.cursor() as target_cursor:
+                    source_cursor.execute(f"SHOW CREATE TABLE {table};")
+                    create_table_sql = source_cursor.fetchone()[1]
+                    # Check and correct the default-value notation for BIT(1) columns.
+                    # For now this is a simple replacement; other approaches are still to be discussed.
+                    corrected_sql = self.correct_bit_default_value(create_table_sql)
+                    target_cursor.execute(corrected_sql)
+                target_db.commit()
+            return True
+        except Exception as e:
+            logging.error(f"Error copying table structure for {table}: {e}")
+            return False
+
+    def correct_bit_default_value(self, sql):
+        corrected_sql = sql.replace("bit(1) DEFAULT '0'", "bit(1) DEFAULT b'0'")
+        return corrected_sql
+
     def migrate_table_data(self, table):
         # Migrate data for each table
         try:
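copy_table_structure replays the source's SHOW CREATE TABLE output on the target, and correct_bit_default_value rewrites bit(1) DEFAULT '0' into the bit-literal form b'0', presumably because the quoted form is not accepted when the DDL is re-executed on the target server. The committed fix is a plain string replacement (flagged in the comment as provisional); since import re arrives in the same change, a regex variant that also covers DEFAULT '1', case differences and extra whitespace might look like this sketch (function name and pattern are illustrative, not part of the diff):

import re

def correct_bit_default_value(sql: str) -> str:
    # Rewrite bit(1) DEFAULT '0' / DEFAULT '1' into the bit-literal form b'0' / b'1',
    # tolerating case and whitespace variations in the generated DDL.
    return re.sub(r"(?i)(bit\(1\)\s+DEFAULT\s+)'([01])'", r"\1b'\2'", sql)

ddl = "`enabled` bit(1) DEFAULT '0' COMMENT 'flag',"
print(correct_bit_default_value(ddl))
# `enabled` bit(1) DEFAULT b'0' COMMENT 'flag',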
@@ -153,6 +174,12 @@ class DBMigrator:
     def migrate(self, concurrency):
         self.clear_target_database()
         tables = self.get_tables()
+        # Copy the structure of all tables first
+        for table in tables:
+            if not self.copy_table_structure(table):
+                logging.error(f"Failed to copy structure for table {table}. Migration aborted.")
+                return
+        # Migrate the data concurrently with a thread pool
         with ThreadPoolExecutor(max_workers=concurrency) as executor:
             futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
             for future in as_completed(futures):
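Structure is copied serially before any data moves, and only the per-table data copy is fanned out to the pool; the futures dict maps each future back to its table so a failure can be reported by name. A stripped-down sketch of that submit/as_completed pattern, with a stand-in worker instead of migrate_table_data:

from concurrent.futures import ThreadPoolExecutor, as_completed

def migrate_table_data(table):
    # Stand-in worker: pretend every table except one migrates successfully
    return table != "t_broken"

tables = ["t_user", "t_order", "t_broken"]
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = {executor.submit(migrate_table_data, t): t for t in tables}
    for future in as_completed(futures):
        table = futures[future]
        if not future.result():
            print(f"migration of table {table} failed")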
@@ -171,7 +198,7 @@ def main():
     args = parse_args()
     config = load_config(args.config)
     setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')
-    # Update the database config if the details were provided on the command line
+    # Update the database config: if arguments are passed in, use the settings from the arguments
     if args.source_host and args.target_host:
         source_config = {
             "host": args.source_host,

View File

@@ -0,0 +1,37 @@
#!/bin/bash

# Script to package
SCRIPT_NAME="db_migrate.py"

# Output directories produced by packaging
DIST_DIR="./dist"
BUILD_DIR="./build"

# Package command
package() {
    echo "Packaging $SCRIPT_NAME ..."
    pyinstaller --onefile "$SCRIPT_NAME"
    echo "Packaging finished."
}

# Clean command
clean() {
    echo "Cleaning up files generated by packaging..."
    rm -rf "$DIST_DIR" "$BUILD_DIR" "${SCRIPT_NAME%.*}.spec"
    echo "Cleanup finished."
}

# Run the matching function based on the given command argument
case "$1" in
    package)
        package
        ;;
    clean)
        clean
        ;;
    *)
        echo "Unknown command: $1"
        echo "Supported commands: package, clean"
        exit 1
        ;;
esac