Skip to content

Commit 97b107b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 195b75a commit 97b107b

File tree

2 files changed

+5
-7
lines changed

2 files changed

+5
-7
lines changed

generation/maisi/scripts/diff_model_create_training_data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,9 @@ def process_file(
160160

161161

162162
@torch.inference_mode()
163-
def diff_model_create_training_data(env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int) -> None:
163+
def diff_model_create_training_data(
164+
env_config_path: str, model_config_path: str, model_def_path: str, num_gpus: int
165+
) -> None:
164166
"""
165167
Create training data for the diffusion model.
166168
@@ -224,9 +226,7 @@ def diff_model_create_training_data(env_config_path: str, model_config_path: str
224226
parser.add_argument(
225227
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
226228
)
227-
parser.add_argument(
228-
"--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training"
229-
)
229+
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for distributed training")
230230

231231
args = parser.parse_args()
232232
diff_model_create_training_data(args.env_config, args.model_config, args.model_def, args.num_gpus)

generation/maisi/scripts/diff_model_train.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,7 @@ def diff_model_train(env_config_path: str, model_config_path: str, model_def_pat
430430
parser.add_argument(
431431
"--model_def", type=str, default="./configs/config_maisi.json", help="Path to model definition file"
432432
)
433-
parser.add_argument(
434-
"--num_gpus", type=int, default=1, help="Number of GPUs to use for training"
435-
)
433+
parser.add_argument("--num_gpus", type=int, default=1, help="Number of GPUs to use for training")
436434

437435
args = parser.parse_args()
438436
diff_model_train(args.env_config, args.model_config, args.model_def, args.num_gpus)

0 commit comments

Comments
 (0)