Compare commits
7 Commits
1fc0edcc50
...
6a20176a3a
Author | SHA1 | Date | |
---|---|---|---|
6a20176a3a | |||
6deb809f3c | |||
320c5d1f22 | |||
ad802057be | |||
83cdd05f3b | |||
0e5c49b138 | |||
4d8fe412d2 |
@ -1,12 +1,12 @@
|
|||||||
databases:
|
databases:
|
||||||
source_db:
|
source_db:
|
||||||
host: '192.168.31.204'
|
host: '82.156.30.246'
|
||||||
port: 3306
|
port: 3307
|
||||||
user: 'root'
|
user: 'root'
|
||||||
password: '123456'
|
password: '123456'
|
||||||
database: 'allocative/test_master'
|
database: 'allocative/test_master'
|
||||||
target_db:
|
target_db:
|
||||||
host: '192.168.31.203'
|
host: '82.156.30.246'
|
||||||
port: 3306
|
port: 3306
|
||||||
user: 'root'
|
user: 'root'
|
||||||
password: '123456'
|
password: '123456'
|
||||||
|
@ -6,6 +6,7 @@ import argparse
|
|||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import List
|
from typing import List
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
# 默认配置文件路径
|
# 默认配置文件路径
|
||||||
DEFAULT_CONFIG_PATH = "config.yaml"
|
DEFAULT_CONFIG_PATH = "config.yaml"
|
||||||
@ -17,32 +18,27 @@ def load_config(config_path="config.yaml"):
|
|||||||
config = yaml.safe_load(file)
|
config = yaml.safe_load(file)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
# 配置日志
|
# 配置日志
|
||||||
def setup_logging(config_level, log_file='migration.log', file_mode='a'):
|
def setup_logging(config_level, log_file='migration.log', file_mode='a'):
|
||||||
numeric_level = getattr(logging, config_level.upper(), None)
|
numeric_level = getattr(logging, config_level.upper(), None)
|
||||||
if not isinstance(numeric_level, int):
|
if not isinstance(numeric_level, int):
|
||||||
raise ValueError(f'Invalid log level: {config_level}')
|
raise ValueError(f'Invalid log level: {config_level}')
|
||||||
|
# 日志记录器
|
||||||
# 创建一个日志记录器
|
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
logger.setLevel(numeric_level)
|
logger.setLevel(numeric_level)
|
||||||
|
# 控制台日志处理器
|
||||||
# 创建一个控制台日志处理器并设置级别
|
|
||||||
ch = logging.StreamHandler()
|
ch = logging.StreamHandler()
|
||||||
ch.setLevel(numeric_level)
|
ch.setLevel(numeric_level)
|
||||||
|
# 文件日志处理器
|
||||||
# 创建一个文件日志处理器并设置级别
|
|
||||||
fh = logging.FileHandler(log_file, mode=file_mode)
|
fh = logging.FileHandler(log_file, mode=file_mode)
|
||||||
fh.setLevel(numeric_level)
|
fh.setLevel(numeric_level)
|
||||||
|
# 日志格式器
|
||||||
# 创建一个日志格式器
|
formatter = logging.Formatter('%(asctime)s [Thread ID:%(thread)d] [%(threadName)s] %(levelname)s: %(message)s')
|
||||||
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
# 添加格式器
|
||||||
|
|
||||||
# 为处理器添加格式器
|
|
||||||
ch.setFormatter(formatter)
|
ch.setFormatter(formatter)
|
||||||
fh.setFormatter(formatter)
|
fh.setFormatter(formatter)
|
||||||
|
# 添加处理器
|
||||||
# 为日志记录器添加处理器
|
|
||||||
logger.addHandler(ch)
|
logger.addHandler(ch)
|
||||||
logger.addHandler(fh)
|
logger.addHandler(fh)
|
||||||
|
|
||||||
@ -87,20 +83,45 @@ class DBMigrator:
|
|||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
def clear_target_database(self):
|
def clear_target_database(self):
|
||||||
# 清空目标数据库所有表
|
# 删除目标数据库所有表
|
||||||
try:
|
try:
|
||||||
with DatabaseConnection(self.target_config) as target_db:
|
with DatabaseConnection(self.target_config) as target_db:
|
||||||
with target_db.cursor() as cursor:
|
with target_db.cursor() as cursor:
|
||||||
|
cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
|
||||||
cursor.execute("SHOW TABLES;")
|
cursor.execute("SHOW TABLES;")
|
||||||
tables = cursor.fetchall()
|
tables = cursor.fetchall()
|
||||||
for table in tables:
|
for table in tables:
|
||||||
cursor.execute(f"TRUNCATE TABLE {table[0]};")
|
cursor.execute(f"DROP TABLE IF EXISTS {table[0]};")
|
||||||
|
cursor.execute("SET FOREIGN_KEY_CHECKS = 1;")
|
||||||
target_db.commit()
|
target_db.commit()
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Error clearing target database: {e}")
|
logging.error(f"Error clearing target database: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def copy_table_structure(self, table):
|
||||||
|
# 复制表结构
|
||||||
|
try:
|
||||||
|
with DatabaseConnection(self.source_config) as source_db, DatabaseConnection(
|
||||||
|
self.target_config) as target_db:
|
||||||
|
with source_db.cursor() as source_cursor, target_db.cursor() as target_cursor:
|
||||||
|
source_cursor.execute(f"SHOW CREATE TABLE {table};")
|
||||||
|
create_table_sql = source_cursor.fetchone()[1]
|
||||||
|
|
||||||
|
# 检查并修正BIT(1)类型字段的默认值表示, 当前脚本为简单替换,其他方案待讨论
|
||||||
|
corrected_sql = self.correct_bit_default_value(create_table_sql)
|
||||||
|
|
||||||
|
target_cursor.execute(corrected_sql)
|
||||||
|
target_db.commit()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"Error copying table structure for {table}: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
def correct_bit_default_value(self, sql):
|
||||||
|
corrected_sql = sql.replace("bit(1) DEFAULT '0'", "bit(1) DEFAULT b'0'")
|
||||||
|
return corrected_sql
|
||||||
|
|
||||||
def migrate_table_data(self, table):
|
def migrate_table_data(self, table):
|
||||||
# 为每个表迁移数据
|
# 为每个表迁移数据
|
||||||
try:
|
try:
|
||||||
@ -153,6 +174,12 @@ class DBMigrator:
|
|||||||
def migrate(self, concurrency):
|
def migrate(self, concurrency):
|
||||||
self.clear_target_database()
|
self.clear_target_database()
|
||||||
tables = self.get_tables()
|
tables = self.get_tables()
|
||||||
|
# 复制所有表结构
|
||||||
|
for table in tables:
|
||||||
|
if not self.copy_table_structure(table):
|
||||||
|
logging.error(f"Failed to copy structure for table {table}. Migration aborted.")
|
||||||
|
return
|
||||||
|
# 使用线程池并发迁移数据
|
||||||
with ThreadPoolExecutor(max_workers=concurrency) as executor:
|
with ThreadPoolExecutor(max_workers=concurrency) as executor:
|
||||||
futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
|
futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
|
||||||
for future in as_completed(futures):
|
for future in as_completed(futures):
|
||||||
@ -171,7 +198,7 @@ def main():
|
|||||||
args = parse_args()
|
args = parse_args()
|
||||||
config = load_config(args.config)
|
config = load_config(args.config)
|
||||||
setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')
|
setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')
|
||||||
# 更新数据库配置,如果通过命令行提供了具体信息
|
# 更新数据库配置,如果参数传入,则使用参数中的配置
|
||||||
if args.source_host and args.target_host:
|
if args.source_host and args.target_host:
|
||||||
source_config = {
|
source_config = {
|
||||||
"host": args.source_host,
|
"host": args.source_host,
|
||||||
|
37
database_migrate/pyinstalll.sh
Normal file
37
database_migrate/pyinstalll.sh
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# 脚本名称
|
||||||
|
SCRIPT_NAME="db_migrate.py"
|
||||||
|
|
||||||
|
# 打包后的输出目录
|
||||||
|
DIST_DIR="./dist"
|
||||||
|
BUILD_DIR="./build"
|
||||||
|
|
||||||
|
# 打包命令
|
||||||
|
package() {
|
||||||
|
echo "开始打包 $SCRIPT_NAME ..."
|
||||||
|
pyinstaller --onefile "$SCRIPT_NAME"
|
||||||
|
echo "打包完成."
|
||||||
|
}
|
||||||
|
|
||||||
|
# 清理命令
|
||||||
|
clean() {
|
||||||
|
echo "开始清理打包生成的文件..."
|
||||||
|
rm -rf "$DIST_DIR" "$BUILD_DIR" "${SCRIPT_NAME%.*}.spec"
|
||||||
|
echo "清理完成."
|
||||||
|
}
|
||||||
|
|
||||||
|
# 根据传入的命令参数执行对应的函数
|
||||||
|
case "$1" in
|
||||||
|
package)
|
||||||
|
package
|
||||||
|
;;
|
||||||
|
clean)
|
||||||
|
clean
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
echo "未知命令: $1"
|
||||||
|
echo "支持的命令: package, clean"
|
||||||
|
exit 1
|
||||||
|
;;
|
||||||
|
esac
|
Loading…
x
Reference in New Issue
Block a user