From a1d83440397f52402fd0ff6ca7c47acfe2a2589b Mon Sep 17 00:00:00 2001
From: Zach Teed <zariteed@umich.edu>
Date: Thu, 30 Jul 2020 21:25:36 -0600
Subject: [PATCH] added training code

---
 README.md         |  9 ++++-----
 core/datasets.py  | 10 +++++-----
 train.py          |  7 ++++---
 train_mixed.sh    |  6 ++++++
 train_standard.sh |  6 ++++++
 5 files changed, 25 insertions(+), 13 deletions(-)
 create mode 100755 train_mixed.sh
 create mode 100755 train_standard.sh

diff --git a/README.md b/README.md
index 87dd170..330b256 100644
--- a/README.md
+++ b/README.md
@@ -8,11 +8,11 @@ Zachary Teed and Jia Deng<br/>
 <img src="RAFT.png">
 
 ## Requirements
-The code has been tested with PyTorch 1.5.1 and PyTorch Nightly. If you want to train with mixed precision, you will have to install the nightly build.
+The code has been tested with PyTorch 1.6 and CUDA 10.1.
 ```Shell
 conda create --name raft
 conda activate raft
-conda install pytorch torchvision cudatoolkit=10.1 -c pytorch-nightly
+conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 -c pytorch
 conda install matplotlib
 conda install tensorboard
 conda install scipy
@@ -67,8 +67,7 @@ python evaluate.py --model=models/raft-things.pth --dataset=sintel
 ```
 
 ## Training
-Training code will be made available in the next few days
-<!-- We used the following training schedule in our paper (note: we use 2 GPUs for training). Training logs will be written to the `runs` which can be visualized using tensorboard
+We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` directory, which can be visualized using TensorBoard.
 ```Shell
 ./train_standard.sh
 ```
@@ -76,4 +75,4 @@ Training code will be made available in the next few days
 If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU)
 ```Shell
 ./train_mixed.sh
-``` -->
+```
diff --git a/core/datasets.py b/core/datasets.py
index c5f0a36..3411fda 100644
--- a/core/datasets.py
+++ b/core/datasets.py
@@ -200,7 +200,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
     """ Create the data loader for the corresponding trainign set """
 
     if args.stage == 'chairs':
-        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 1.0, 'do_flip': True}
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
         train_dataset = FlyingChairs(aug_params, split='training')
     
     elif args.stage == 'things':
@@ -210,14 +210,14 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
         train_dataset = clean_dataset + final_dataset
 
     elif args.stage == 'sintel':
-        aug_params = {'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True}
+        aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
         things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
         sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
         sintel_final = MpiSintel(aug_params, split='training', dstype='final')        
 
         if TRAIN_DS == 'C+T+K+S+H':
-            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.7, 'do_flip': True})
-            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.5, 'do_flip': True})
+            kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
+            hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
             train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
 
         elif TRAIN_DS == 'C+T+K/S':
@@ -225,7 +225,7 @@ def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
 
     elif args.stage == 'kitti':
         aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
-        train_dataset = KITTI(args, image_size=args.image_size, is_val=False)
+        train_dataset = KITTI(aug_params, split='training')
 
     train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 
         pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
diff --git a/train.py b/train.py
index c2c58c5..1314141 100644
--- a/train.py
+++ b/train.py
@@ -39,7 +39,7 @@ except:
 
 
 # exclude extremly large displacements
-MAX_FLOW = 500
+MAX_FLOW = 400
 SUM_FREQ = 100
 VAL_FREQ = 5000
 
@@ -181,13 +181,14 @@ def train(args):
 
             loss, metrics = sequence_loss(flow_predictions, flow, valid)
             scaler.scale(loss).backward()
-
-            scaler.unscale_(optimizer)
+            scaler.unscale_(optimizer)                
             torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
             
             scaler.step(optimizer)
             scheduler.step()
             scaler.update()
+
+
             logger.push(metrics)
 
             if total_steps % VAL_FREQ == VAL_FREQ - 1:
diff --git a/train_mixed.sh b/train_mixed.sh
new file mode 100755
index 0000000..ae92aac
--- /dev/null
+++ b/train_mixed.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+# Four-stage training schedule (chairs -> things -> sintel -> kitti) run with
+# --mixed_precision on a single GPU (--gpus 0). Each later stage is started
+# from checkpoints/<previous stage name>.pth via --restore_ckpt.
+
+# Stop at the first failing stage so that a later stage never starts from a
+# missing checkpoint or from a stale one left behind by an earlier run.
+set -e
+
+mkdir -p checkpoints
+python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision
+python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision
+python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --mixed_precision
+python -u train.py --name raft-kitti  --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --mixed_precision
diff --git a/train_standard.sh b/train_standard.sh
new file mode 100755
index 0000000..19b5809
--- /dev/null
+++ b/train_standard.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+# Four-stage training schedule (chairs -> things -> sintel -> kitti) run in
+# full precision on two GPUs (--gpus 0 1). Each later stage is started from
+# checkpoints/<previous stage name>.pth via --restore_ckpt.
+
+# Stop at the first failing stage so that a later stage never starts from a
+# missing checkpoint or from a stale one left behind by an earlier run.
+set -e
+
+mkdir -p checkpoints
+python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 12 --lr 0.0004 --image_size 368 496 --wdecay 0.0001
+python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001
+python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001
+python -u train.py --name raft-kitti  --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001
-- 
GitLab