init commit
This commit is contained in:
commit
0514683ccb
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/.idea/
|
17
database_migrate/config.yaml
Normal file
17
database_migrate/config.yaml
Normal file
@ -0,0 +1,17 @@
|
||||
databases:
|
||||
source_db:
|
||||
host: '192.168.31.204'
|
||||
port: 3306
|
||||
user: 'root'
|
||||
password: '123456'
|
||||
database: 'allocative/test_master'
|
||||
target_db:
|
||||
host: '192.168.31.203'
|
||||
port: 3306
|
||||
user: 'root'
|
||||
password: '123456'
|
||||
database: 'allocative'
|
||||
concurrency: 5
|
||||
batch_size: 2000
|
||||
logging:
|
||||
level: DEBUG
|
208
database_migrate/db_migrate.py
Normal file
208
database_migrate/db_migrate.py
Normal file
@ -0,0 +1,208 @@
|
||||
import time
|
||||
|
||||
import yaml
|
||||
import pymysql
|
||||
import argparse
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List
|
||||
import logging
|
||||
|
||||
# Default path of the YAML configuration file.
DEFAULT_CONFIG_PATH = "config.yaml"


# Read the configuration file.
def load_config(config_path=DEFAULT_CONFIG_PATH):
    """Load the migration settings from a YAML file.

    Args:
        config_path: Path of the YAML file to read. Defaults to
            ``DEFAULT_CONFIG_PATH``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the contents are not valid YAML.
    """
    # Pin the encoding: the platform default is not UTF-8 everywhere, and a
    # config containing non-ASCII text would otherwise fail to decode.
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)
|
||||
|
||||
# Configure logging.
def setup_logging(config_level, log_file='migration.log', file_mode='a'):
    """Send root-logger records to both the console and a log file.

    Calling this function again replaces the handlers it installed earlier
    instead of stacking new ones, so records are never emitted twice.
    Handlers installed by other code are left untouched.

    Args:
        config_level: Name of a ``logging`` level (case-insensitive), for
            example ``"DEBUG"`` or ``"info"``.
        log_file: Path of the log file.
        file_mode: Mode used to open the log file (``'a'`` appends).

    Raises:
        ValueError: If ``config_level`` does not name a logging level.
    """
    numeric_level = getattr(logging, config_level.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError(f'Invalid log level: {config_level}')

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # Build the new handlers first so a bad log path raises before the
    # existing logging setup has been modified.
    new_handlers = (
        logging.StreamHandler(),
        # Explicit encoding so non-ASCII text in messages is written
        # consistently regardless of the platform default.
        logging.FileHandler(log_file, mode=file_mode, encoding='utf-8'),
    )

    logger = logging.getLogger()
    logger.setLevel(numeric_level)

    # Remove only the handlers a previous call of this function installed.
    for old_handler in list(logger.handlers):
        if getattr(old_handler, "_migration_handler", False):
            logger.removeHandler(old_handler)
            old_handler.close()

    for handler in new_handlers:
        handler.setLevel(numeric_level)
        handler.setFormatter(formatter)
        # Marker used above to recognise our own handlers on a later call.
        handler._migration_handler = True
        logger.addHandler(handler)
|
||||
|
||||
|
||||
# Parse command-line arguments.
def parse_args(argv=None):
    """Parse the migration script's command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the arguments are taken from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``config`` plus, for each of
        ``source`` and ``target``, the attributes ``<side>_host``,
        ``<side>_port``, ``<side>_user``, ``<side>_password`` and
        ``<side>_db``. Connection options that were not supplied are
        ``None``.
    """
    parser = argparse.ArgumentParser(description="Database Migration Script")
    parser.add_argument('--config', type=str, default="config.yaml", help="Path to the configuration file")

    # Both sides take the same five options; only the prefix differs.
    for side in ("source", "target"):
        label = side.capitalize()
        parser.add_argument(f'--{side}_host', type=str, help=f"{label} database host")
        parser.add_argument(f'--{side}_port', type=int, help=f"{label} database port")
        parser.add_argument(f'--{side}_user', type=str, help=f"{label} database user")
        parser.add_argument(f'--{side}_password', type=str, help=f"{label} database password")
        parser.add_argument(f'--{side}_db', type=str, help=f"{label} database name")

    return parser.parse_args(argv)
|
||||
|
||||
|
||||
# Context manager for a database connection.
class DatabaseConnection:
    """Open a PyMySQL connection on entry and close it on exit.

    The value bound by ``with ... as`` is the connection returned by
    ``pymysql.connect``. Exceptions raised inside the ``with`` body are not
    suppressed.
    """

    def __init__(self, db_config):
        """Remember the connection settings without connecting yet.

        Args:
            db_config: Mapping of keyword arguments passed to
                ``pymysql.connect``.
        """
        self.db_config = db_config
        # Populated by __enter__; None until the context is entered.
        self.connection = None

    def __enter__(self):
        """Connect using the stored settings and return the connection."""
        opened = pymysql.connect(**self.db_config)
        self.connection = opened
        return opened

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the connection if one was opened."""
        if not self.connection:
            return
        self.connection.close()
|
||||
|
||||
|
||||
# Database migration operations.
class DBMigrator:
    """Copy the rows of every base table from a source database to a target.

    The target tables are truncated first; each source table is then copied
    in batches by a worker thread that uses its own pair of connections.
    The target is expected to already contain tables with the same names
    and column order as the source, because rows are inserted positionally.
    """

    def __init__(self, source_config, target_config, batch_size):
        """Store the connection settings and the batch size.

        Args:
            source_config: ``pymysql.connect`` keyword arguments for the
                source database.
            target_config: ``pymysql.connect`` keyword arguments for the
                target database.
            batch_size: Number of rows read and inserted per round trip.

        Raises:
            ValueError: If ``batch_size`` is not a positive integer. A
                non-positive size would make every fetch return no rows, so
                nothing would be migrated.
        """
        if not isinstance(batch_size, int) or batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
        self.source_config = source_config
        self.target_config = target_config
        self.batch_size = batch_size

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a backtick-quoted MySQL identifier."""
        # A literal backtick inside an identifier is escaped by doubling it.
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _list_base_tables(cursor):
        """Return the names of the base tables (not views) visible to *cursor*."""
        cursor.execute("SHOW FULL TABLES WHERE Table_type = 'BASE TABLE';")
        return [row[0] for row in cursor.fetchall()]

    def clear_target_database(self):
        """Truncate every base table in the target database.

        Returns:
            True when all tables were truncated, False if any step failed
            (the error is logged with its traceback).
        """
        try:
            with DatabaseConnection(self.target_config) as target_db:
                with target_db.cursor() as cursor:
                    tables = self._list_base_tables(cursor)
                    # TRUNCATE is rejected for tables referenced by foreign
                    # keys while checks are on. The setting is per session
                    # and this connection is closed when the block exits.
                    cursor.execute("SET FOREIGN_KEY_CHECKS = 0;")
                    for table in tables:
                        cursor.execute(f"TRUNCATE TABLE {self._quote_identifier(table)};")
                target_db.commit()
            return True
        except Exception as e:
            logging.exception(f"Error clearing target database: {e}")
            return False

    def migrate_table_data(self, table):
        """Copy all rows of one table and verify the row counts match.

        Args:
            table: Name of the table, present in both databases.

        Returns:
            True when the target row count equals the source row count
            after copying; False on a mismatch or on any error (logged).
        """
        quoted = self._quote_identifier(table)
        try:
            with DatabaseConnection(self.source_config) as source_db, \
                    DatabaseConnection(self.target_config) as target_db:
                with source_db.cursor() as count_cursor:
                    count_cursor.execute(f"SELECT COUNT(*) FROM {quoted};")
                    source_count = count_cursor.fetchone()[0]
                logging.info(f"Table {table}: {source_count} rows to migrate.")

                # An unbuffered (server-side) cursor streams rows on demand.
                # The default cursor would pull the whole table into memory
                # before the first fetchmany() call.
                with source_db.cursor(pymysql.cursors.SSCursor) as source_cursor, \
                        target_db.cursor() as target_cursor:
                    source_cursor.execute(f"SELECT * FROM {quoted};")
                    insert_sql = None
                    rows_fetched = 0
                    while True:
                        rows = source_cursor.fetchmany(self.batch_size)  # read one batch
                        if not rows:
                            break
                        if insert_sql is None:
                            # Column count is fixed per table: build once.
                            placeholders = ', '.join(['%s'] * len(rows[0]))
                            insert_sql = f"INSERT INTO {quoted} VALUES ({placeholders})"
                        target_cursor.executemany(insert_sql, rows)
                        rows_fetched += len(rows)
                        logging.debug(f"Table {table}: {rows_fetched}/{source_count} rows migrated.")
                        target_db.commit()

                    # Verify that the row counts agree.
                    target_cursor.execute(f"SELECT COUNT(*) FROM {quoted};")
                    target_count = target_cursor.fetchone()[0]

                if source_count == target_count:
                    logging.info(f"Table {table} migrated successfully. Total rows: {target_count}")
                    return True
                logging.error(
                    f"Table {table} migration failed. Source rows: {source_count}, Target rows: {target_count}")
                return False
        except Exception as e:
            logging.exception(f"Error migrating table {table}: {e}")
            return False

    def get_tables(self) -> List[str]:
        """Return the names of all base tables in the source database.

        Returns:
            The table names, or an empty list if they could not be read
            (the error is logged).
        """
        try:
            with DatabaseConnection(self.source_config) as db:
                with db.cursor() as cursor:
                    return self._list_base_tables(cursor)
        except Exception as e:
            logging.exception(f"Error fetching table list: {e}")
            return []

    def migrate(self, concurrency):
        """Clear the target, then copy every source table in parallel.

        Args:
            concurrency: Maximum number of tables copied at the same time.

        Returns:
            True when the target was cleared and every discovered table
            migrated successfully; False otherwise.
        """
        # Copying into a target that still holds old rows can only produce
        # duplicate-key errors or count mismatches, so stop early.
        if not self.clear_target_database():
            logging.error("Aborting migration: target database could not be cleared.")
            return False

        tables = self.get_tables()
        if not tables:
            logging.warning("No tables found in the source database.")

        all_succeeded = True
        with ThreadPoolExecutor(max_workers=concurrency) as executor:
            futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
            for future in as_completed(futures):
                table = futures[future]
                try:
                    succeeded = future.result()
                except Exception as e:
                    logging.exception(f"Error migrating table {table}: {e}")
                    succeeded = False
                if succeeded:
                    logging.info(f"Table {table} migration succeeded.")
                else:
                    logging.error(f"Table {table} migration failed.")
                    all_succeeded = False
        return all_succeeded
|
||||
|
||||
|
||||
def _resolve_db_config(config, args, prefix):
    """Build the connection settings for one side of the migration.

    Starts from ``config["databases"]["<prefix>_db"]`` and overlays every
    ``--<prefix>_*`` option that was actually given on the command line.

    Raises:
        KeyError: If neither the config file nor the command line provides
            any settings for this side.
    """
    base = (config.get("databases") or {}).get(f"{prefix}_db")

    # CLI attribute suffix -> pymysql.connect() keyword. "database" is the
    # keyword the config file uses; "db" is only a deprecated alias.
    option_to_key = (
        ("host", "host"),
        ("port", "port"),
        ("user", "user"),
        ("password", "password"),
        ("db", "database"),
    )
    overrides = {}
    for suffix, key in option_to_key:
        value = getattr(args, f"{prefix}_{suffix}", None)
        if value is not None:
            overrides[key] = value

    if base is None and not overrides:
        raise KeyError(f"databases.{prefix}_db")

    merged = dict(base or {})
    if "database" in overrides:
        # Avoid passing both spellings when the file used the alias.
        merged.pop("db", None)
    merged.update(overrides)
    return merged


def main():
    """Run the migration described by the config file and CLI options.

    Reads the YAML config, configures logging, resolves the source and
    target connection settings (command-line options override the matching
    config-file values field by field), runs the migration and reports the
    elapsed time.
    """
    args = parse_args()
    config = load_config(args.config)
    setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')

    # Apply any connection details given on the command line on top of the
    # values from the config file.
    source_config = _resolve_db_config(config, args, "source")
    target_config = _resolve_db_config(config, args, "target")

    migrator = DBMigrator(source_config, target_config, config["batch_size"])

    # perf_counter() is monotonic, so the duration cannot be skewed by
    # system clock adjustments during a long run.
    start_time = time.perf_counter()
    migrator.migrate(config["concurrency"])
    duration = time.perf_counter() - start_time

    # Report the elapsed time.
    print(f"Script completed in {duration:.2f} seconds.")
    logging.info(f"Script completed in {duration:.2f} seconds.")
    print(f"Used {config['concurrency']} threads and batch size of {config['batch_size']}.")


if __name__ == "__main__":
    main()
|
Loading…
x
Reference in New Issue
Block a user