diff --git a/scripts/image_migrate.py b/scripts/image_migrate.py index 2e4eaa57..d0f9cd89 100755 --- a/scripts/image_migrate.py +++ b/scripts/image_migrate.py @@ -147,46 +147,6 @@ class ImageMigrator: # 确保paramiko已导入 ensure_paramiko() - # 连接目标服务器(必须为SSH) - if self.target_config.is_local(): - print("错误: 目标服务器必须为SSH服务器,不能是本地目录") - return False - - self.target_client = paramiko.SSHClient() - self.target_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - if self.verbose: - print( - f"连接到目标服务器: {self.target_config.host}:{self.target_config.port}" - ) - - if self.target_config.key_file: - key = paramiko.RSAKey.from_private_key_file(self.target_config.key_file) - - self.target_client.connect( - hostname=self.target_config.host, - port=self.target_config.port, - username=self.target_config.username, - pkey=key, - ) - else: - self.target_client.connect( - hostname=self.target_config.host, - port=self.target_config.port, - username=self.target_config.username, - password=self.target_config.password, - ) - - self.target_sftp = self.target_client.open_sftp() - - # 检查SFTP服务器的工作目录 - if self.verbose: - try: - cwd = self.target_sftp.getcwd() - print(f"DEBUG: 目标SFTP服务器当前工作目录: {cwd}") - except Exception as e: - print(f"DEBUG: 无法获取目标SFTP服务器工作目录: {e}") - # 连接源服务器(如果是SSH类型) if not self.source_config.is_local(): self.source_client = paramiko.SSHClient() @@ -218,11 +178,61 @@ class ImageMigrator: self.source_sftp = self.source_client.open_sftp() + # 检查SFTP服务器的工作目录 + if self.verbose: + try: + cwd = self.source_sftp.getcwd() + print(f"DEBUG: 源SFTP服务器当前工作目录: {cwd}") + except Exception as e: + print(f"DEBUG: 无法获取源SFTP服务器工作目录: {e}") + + # 连接目标服务器(如果为SSH) + if not self.target_config.is_local(): + self.target_client = paramiko.SSHClient() + self.target_client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + if self.verbose: + print( + f"连接到目标服务器: {self.target_config.host}:{self.target_config.port}" + ) + + if self.target_config.key_file: + key = paramiko.RSAKey.from_private_key_file( + self.target_config.key_file + ) + + self.target_client.connect( + hostname=self.target_config.host, + port=self.target_config.port, + username=self.target_config.username, + pkey=key, + ) + else: + self.target_client.connect( + hostname=self.target_config.host, + port=self.target_config.port, + username=self.target_config.username, + password=self.target_config.password, + ) + + self.target_sftp = self.target_client.open_sftp() + + # 检查SFTP服务器的工作目录 + if self.verbose: + try: + cwd = self.target_sftp.getcwd() + print(f"DEBUG: 目标SFTP服务器当前工作目录: {cwd}") + except Exception as e: + print(f"DEBUG: 无法获取目标SFTP服务器工作目录: {e}") + if self.verbose: source_type = ( "本地目录" if self.source_config.is_local() else "SSH服务器" ) - print(f"连接成功! 源: {source_type}, 目标: SSH服务器") + target_type = ( + "本地目录" if self.target_config.is_local() else "SSH服务器" + ) + print(f"连接成功! 源: {source_type}, 目标: {target_type}") return True except Exception as e: @@ -718,7 +728,7 @@ def read_image_paths(image_list_file: str) -> List[str]: def main(): parser = argparse.ArgumentParser( - description="从本地目录或SSH服务器迁移图片到远程SSH服务器,保持相对路径结构", + description="本地目录之间迁移图片,本地目录或SSH服务器迁移图片到远程SSH服务器,保持相对路径结构", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" 示例: @@ -756,8 +766,7 @@ def main(): parser.add_argument( "--source-type", choices=["local", "ssh"], - default="local", - help="源服务器类型: local(本地目录) 或 ssh(SSH服务器),默认: local", + help="源服务器类型: local(本地目录) 或 ssh(SSH服务器)", ) parser.add_argument("--source-host", help="源服务器地址(仅SSH类型需要)") parser.add_argument( @@ -772,6 +781,11 @@ def main(): parser.add_argument("--source-dir", help="源服务器图片基础目录") # 目标服务器选项 + parser.add_argument( + "--target-type", + choices=["local", "ssh"], + help="目标服务器类型: local(本地目录) 或 ssh(SSH服务器)", + ) parser.add_argument("--target-host", help="目标服务器地址") parser.add_argument( "--target-port", type=int, default=22, help="目标服务器端口(默认: 22)" @@ -859,23 +873,38 @@ def main(): # 构建目标服务器配置 target_config_data = config.get("target", {}) - target_config_data["type"] = "ssh" # 目标必须是SSH - - ssh_config = target_config_data.get("ssh", {}) # 命令行参数覆盖配置文件 - if args.target_host: - ssh_config["host"] = args.target_host - if args.target_port: - ssh_config["port"] = args.target_port - if args.target_user: - ssh_config["username"] = args.target_user - if args.target_password: - ssh_config["password"] = args.target_password - if args.target_key: - ssh_config["key_file"] = args.target_key + if args.target_type: + target_config_data["type"] = args.target_type - target_config_data["ssh"] = ssh_config + # 命令行参数覆盖配置文件 + if args.target_type == "ssh" or target_config_data.get("type") == "ssh": + ssh_config = target_config_data.get("ssh", {}) + + if args.target_host: + ssh_config["host"] = args.target_host + if args.target_port: + ssh_config["port"] = args.target_port + if args.target_user: + ssh_config["username"] = args.target_user + if args.target_password: + ssh_config["password"] = args.target_password + if args.target_key: + ssh_config["key_file"] = args.target_key + + target_config_data["ssh"] = ssh_config + + # 检查必要的目标SSH参数 + if not ssh_config.get("host"): + print("错误: 必须指定目标服务器地址 (--target-host)") + sys.exit(1) + if not ssh_config.get("username"): + print("错误: 必须指定目标服务器用户名 (--target-user)") + sys.exit(1) + else: + # 本地类型,不需要SSH配置 + target_config_data.pop("ssh", None) # 设置基础目录 if args.target_dir: @@ -884,14 +913,6 @@ def main(): print("错误: 必须指定目标服务器图片基础目录 (--target-dir)") sys.exit(1) - # 检查必要的目标SSH参数 - if not ssh_config.get("host"): - print("错误: 必须指定目标服务器地址 (--target-host)") - sys.exit(1) - if not ssh_config.get("username"): - print("错误: 必须指定目标服务器用户名 (--target-user)") - sys.exit(1) - # 创建服务器配置对象 try: source_config = ServerConfig.from_dict(source_config_data)