# Skip to content

# Instantly share code, notes, and snippets.

import torch
import glob
import numpy as np
import imageio
import cv2
import torch.utils.data as data
import torch.nn.functional as F
from torchvision.transforms import Compose
# Gist by @ranftlr: ranftlr / loss.py
# Last active December 26, 2023 08:45
def trimmed_mae_loss(prediction, target, mask, trim=0.2):
    """Compute the trimmed mean-absolute-error loss over masked entries.

    Absolute residuals ``|prediction - target|`` are taken at the positions
    where ``mask`` is non-zero and sorted in ascending order. The largest
    ``trim`` fraction of them is then discarded. The surviving residuals are
    summed and divided by ``2 * M.sum()``, where ``M`` is ``mask`` summed
    over dims ``(1, 2)``.

    Args:
        prediction: Predicted tensor.
        target: Ground-truth tensor with the same shape as ``prediction``.
        mask: Mask selecting which entries contribute. It is reduced over
            dims ``(1, 2)``, so it needs at least three dimensions
            (presumably ``(batch, height, width)`` -- TODO confirm with callers).
        trim: Fraction (expected in ``[0, 1]``) of the largest residuals
            to drop.

    Returns:
        A 0-dim tensor. If ``mask`` selects nothing, the result is zero
        instead of ``0 / 0 = nan``, and it stays attached to the autograd
        graph.
    """
    M = torch.sum(mask, (1, 2))
    res = prediction - target
    res = res[mask.bool()].abs()

    # torch.sort returns a (values, indices) pair. Slice the sorted *values*,
    # not the pair itself: slicing the tuple silently skipped the trimming,
    # or failed to unpack when fewer than two residuals were kept.
    sorted_res, _ = torch.sort(res.view(-1), descending=False)
    trimmed = sorted_res[: int(len(res) * (1.0 - trim))]

    # Normalisation by 2M follows the trimmed MAE loss of Ranftl et al. (MiDaS).
    divisor = 2 * M.sum()
    if divisor == 0:
        # Empty mask: `trimmed` is empty, so its sum is a graph-connected zero.
        return trimmed.sum()
    return trimmed.sum() / divisor
import torch
import cv2
import h5py
import numpy as np
from scipy.io import loadmat
import torch.utils.data as data
import torch.nn.functional as F
from torchvision.transforms import Compose