```python
# Assumes the standard needle scaffold from the dlsys course: Module,
# Parameter, Tensor, ops, and init are in scope (as in needle/nn.py), e.g.:
#   from needle.autograd import Tensor
#   from needle import ops
#   import needle.init as init

class BatchNorm1d(Module):
    def __init__(self, dim, eps=1e-5, momentum=0.1, device=None, dtype="float32"):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.momentum = momentum
        # Learnable affine parameters: per-feature scale (gamma) and shift (beta).
        self.weight = Parameter(init.ones(dim, device=device, dtype=dtype))
        self.bias = Parameter(init.zeros(dim, device=device, dtype=dtype))
        # Running statistics are buffers, not Parameters: wrapping them in
        # Parameter would expose them to the optimizer via parameters().
        self.running_mean = init.zeros(dim, device=device, dtype=dtype)
        self.running_var = init.ones(dim, device=device, dtype=dtype)

    def forward(self, x: Tensor) -> Tensor:
        batch_size = x.shape[0]
        # Reshape to (1, dim) explicitly; not every NDArray backend
        # supports -1 in reshape.
        broadcast_weight = ops.broadcast_to(ops.reshape(self.weight, (1, self.dim)), x.shape)
        broadcast_bias = ops.broadcast_to(ops.reshape(self.bias, (1, self.dim)), x.shape)

        if self.training:
            # Per-feature batch mean E[x] over the batch dimension.
            mean_x = ops.divide_scalar(ops.summation(x, axes=0), batch_size)
            broadcast_mean = ops.broadcast_to(ops.reshape(mean_x, (1, self.dim)), x.shape)
            numerator = x - broadcast_mean
            # Biased per-feature variance E[(x - E[x])^2] over the batch.
            var_x = ops.summation(ops.divide_scalar(ops.power_scalar(numerator, 2), batch_size), axes=0)
            broadcast_var = ops.broadcast_to(ops.reshape(var_x, (1, self.dim)), x.shape)
            denominator = ops.power_scalar(broadcast_var + self.eps, 0.5)
            y = broadcast_weight * (numerator / denominator) + broadcast_bias
            # Exponential moving average of the batch statistics; take .data
            # so the buffers stay detached from the autograd graph.
            self.running_mean = ((1 - self.momentum) * self.running_mean + self.momentum * mean_x).data
            self.running_var = ((1 - self.momentum) * self.running_var + self.momentum * var_x).data
        else:
            # Inference: normalize with the accumulated running statistics
            # instead of the statistics of the current batch.
            broadcast_rm = ops.broadcast_to(ops.reshape(self.running_mean, (1, self.dim)), x.shape)
            broadcast_rv = ops.broadcast_to(ops.reshape(self.running_var, (1, self.dim)), x.shape)
            numerator = x - broadcast_rm
            denominator = ops.power_scalar(broadcast_rv + self.eps, 0.5)
            y = broadcast_weight * (numerator / denominator) + broadcast_bias
        return y
```
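As a quick sanity check, a minimal usage sketch is below. It assumes the standard needle package layout from the course scaffold: `needle` importable as `ndl` with a numpy-backed `Tensor`, this class exposed as `needle.nn.BatchNorm1d`, and the `Module.train()` / `Module.eval()` helpers that flip `self.training`. In training mode each output feature should come out with roughly zero mean and unit (biased) standard deviation; in eval mode the layer reuses the accumulated running statistics instead.

```python
import numpy as np
import needle as ndl
import needle.nn as nn

np.random.seed(0)
x = ndl.Tensor(np.random.randn(8, 4).astype("float32"))

bn = nn.BatchNorm1d(4)

# Training mode: normalize with the current batch statistics.
bn.train()
y = bn(x)
print(y.numpy().mean(axis=0))  # ~0 for every feature
print(y.numpy().std(axis=0))   # ~1 for every feature (np.std is biased, matching the layer)

# Evaluation mode: normalize with the running statistics accumulated above,
# so each row's output no longer depends on the rest of the batch.
bn.eval()
y_eval = bn(x)
print(bn.running_mean.numpy())  # EMA of the batch mean after one training step
```

Note that after a single training step the running mean has only moved `momentum` (0.1) of the way toward the batch mean, so `y_eval` will not look fully normalized until the running statistics have converged over many batches.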