init commit

This commit is contained in:
hxuanyu 2024-03-27 22:49:54 +08:00
commit 0514683ccb
3 changed files with 226 additions and 0 deletions

1
.gitignore vendored Normal file
View File

@ -0,0 +1 @@
/.idea/

View File

@ -0,0 +1,17 @@
databases:
source_db:
host: '192.168.31.204'
port: 3306
user: 'root'
password: '123456'
database: 'allocative/test_master'
target_db:
host: '192.168.31.203'
port: 3306
user: 'root'
password: '123456'
database: 'allocative'
concurrency: 5
batch_size: 2000
logging:
level: DEBUG

View File

@ -0,0 +1,208 @@
import time
import yaml
import pymysql
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List
import logging
# 默认配置文件路径
DEFAULT_CONFIG_PATH = "config.yaml"
# 读取配置文件
def load_config(config_path="config.yaml"):
with open(config_path, "r") as file:
config = yaml.safe_load(file)
return config
# 配置日志
def setup_logging(config_level, log_file='migration.log', file_mode='a'):
numeric_level = getattr(logging, config_level.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError(f'Invalid log level: {config_level}')
# 创建一个日志记录器
logger = logging.getLogger()
logger.setLevel(numeric_level)
# 创建一个控制台日志处理器并设置级别
ch = logging.StreamHandler()
ch.setLevel(numeric_level)
# 创建一个文件日志处理器并设置级别
fh = logging.FileHandler(log_file, mode=file_mode)
fh.setLevel(numeric_level)
# 创建一个日志格式器
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
# 为处理器添加格式器
ch.setFormatter(formatter)
fh.setFormatter(formatter)
# 为日志记录器添加处理器
logger.addHandler(ch)
logger.addHandler(fh)
# 解析命令行参数
def parse_args():
parser = argparse.ArgumentParser(description="Database Migration Script")
parser.add_argument('--config', type=str, default="config.yaml", help="Path to the configuration file")
parser.add_argument('--source_host', type=str, help="Source database host")
parser.add_argument('--source_port', type=int, help="Source database port")
parser.add_argument('--source_user', type=str, help="Source database user")
parser.add_argument('--source_password', type=str, help="Source database password")
parser.add_argument('--source_db', type=str, help="Source database name")
parser.add_argument('--target_host', type=str, help="Target database host")
parser.add_argument('--target_port', type=int, help="Target database port")
parser.add_argument('--target_user', type=str, help="Target database user")
parser.add_argument('--target_password', type=str, help="Target database password")
parser.add_argument('--target_db', type=str, help="Target database name")
return parser.parse_args()
# 数据库连接上下文管理器
class DatabaseConnection:
def __init__(self, db_config):
self.db_config = db_config
self.connection = None
def __enter__(self):
self.connection = pymysql.connect(**self.db_config)
return self.connection
def __exit__(self, exc_type, exc_val, exc_tb):
if self.connection:
self.connection.close()
# 数据库操作类
class DBMigrator:
def __init__(self, source_config, target_config, batch_size):
self.source_config = source_config
self.target_config = target_config
self.batch_size = batch_size
def clear_target_database(self):
# 清空目标数据库所有表
try:
with DatabaseConnection(self.target_config) as target_db:
with target_db.cursor() as cursor:
cursor.execute("SHOW TABLES;")
tables = cursor.fetchall()
for table in tables:
cursor.execute(f"TRUNCATE TABLE {table[0]};")
target_db.commit()
return True
except Exception as e:
logging.error(f"Error clearing target database: {e}")
return False
def migrate_table_data(self, table):
# 为每个表迁移数据
try:
with DatabaseConnection(self.source_config) as source_db, DatabaseConnection(
self.target_config) as target_db:
with source_db.cursor() as source_cursor, target_db.cursor() as target_cursor:
source_cursor.execute(f"SELECT COUNT(*) FROM {table};")
source_count = source_cursor.fetchone()[0]
logging.info(f"Table {table}: {source_count} rows to migrate.")
source_cursor.execute(f"SELECT * FROM {table};")
rows_fetched = 0
while True:
rows = source_cursor.fetchmany(self.batch_size) # 分页读取
if not rows:
break
placeholders = ', '.join(['%s'] * len(rows[0]))
sql = f"INSERT INTO {table} VALUES ({placeholders})"
target_cursor.executemany(sql, rows)
rows_fetched += len(rows)
logging.debug(f"Table {table}: {rows_fetched}/{source_count} rows migrated.")
target_db.commit()
# 验证数据量是否一致
target_cursor.execute(f"SELECT COUNT(*) FROM {table};")
target_count = target_cursor.fetchone()[0]
if source_count == target_count:
logging.info(f"Table {table} migrated successfully. Total rows: {target_count}")
return True
else:
logging.error(
f"Table {table} migration failed. Source rows: {source_count}, Target rows: {target_count}")
return False
except Exception as e:
logging.error(f"Error migrating table {table}: {e}")
return False
def get_tables(self) -> List[str]:
# 获取源数据库的所有表名
try:
with DatabaseConnection(self.source_config) as db:
with db.cursor() as cursor:
cursor.execute("SHOW TABLES;")
tables = [table[0] for table in cursor.fetchall()]
return tables
except Exception as e:
logging.error(f"Error fetching table list: {e}")
return []
def migrate(self, concurrency):
self.clear_target_database()
tables = self.get_tables()
with ThreadPoolExecutor(max_workers=concurrency) as executor:
futures = {executor.submit(self.migrate_table_data, table): table for table in tables}
for future in as_completed(futures):
table = futures[future]
try:
result = future.result()
if result:
logging.info(f"Table {table} migration succeeded.")
else:
logging.error(f"Table {table} migration failed.")
except Exception as e:
logging.error(f"Error migrating table {table}: {e}")
def main():
args = parse_args()
config = load_config(args.config)
setup_logging(config["logging"]["level"], log_file='migration.log', file_mode='a')
# 更新数据库配置,如果通过命令行提供了具体信息
if args.source_host and args.target_host:
source_config = {
"host": args.source_host,
"port": args.source_port,
"user": args.source_user,
"password": args.source_password,
"db": args.source_db,
}
target_config = {
"host": args.target_host,
"port": args.target_port,
"user": args.target_user,
"password": args.target_password,
"db": args.target_db,
}
else:
# 使用默认配置
source_config = config["databases"]["source_db"]
target_config = config["databases"]["target_db"]
migrator = DBMigrator(source_config, target_config, config["batch_size"])
start_time = time.time()
migrator.migrate(config["concurrency"])
end_time = time.time() # 脚本完成执行的时间
duration = end_time - start_time # 计算脚本执行耗时
# 输出脚本执行耗时
print(f"Script completed in {duration:.2f} seconds.")
logging.info(f"Script completed in {duration:.2f} seconds.")
print(f"Used {config['concurrency']} threads and batch size of {config['batch_size']}.")
if __name__ == "__main__":
main()