"""Nengo implementation of transposed convolution (aka. deconvolution)
"""

import nengo
from nengo.builder import Operator, Signal
from nengo.builder.operator import Reset
from nengo.exceptions import BuildError
from nengo.transforms import ChannelShape
import nengo_dl
from nengo_dl.transform_builders import ConvIncBuilder
import numpy as np
try:
    import tensorflow as tf
except ImportError:
    tf = None


if tf is None:
    tf_convtranspose = None
else:
    # `tf.nn.conv2d_transpose` exists under this name both before and after
    # TF 1.14; the call below only passes keywords common to both signatures,
    # so no version-specific alias is needed
    tf_convtranspose = tf.nn.conv2d_transpose


class ConvolutionTranspose(nengo.Convolution):
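    """Transposed convolution (a.k.a. deconvolution) transform.

    Accepts the same parameters as `nengo.Convolution`; only the output shape
    computation differs, since a transposed convolution upsamples rather than
    downsamples its input.
    """
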
    def __init__(
        self,
        n_filters,
        input_shape,
        kernel_size=(2, 2),
        strides=(2, 2),
        padding="valid",
        channels_last=True,
        init=nengo.dists.Uniform(1, 1),
    ):
        super().__init__(
            n_filters=n_filters,
            input_shape=input_shape,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            channels_last=channels_last,
            init=init,
        )

    @property
    def output_shape(self):
        """Output shape after applying convolution to input."""
        output_shape = np.array(self.input_shape.spatial_shape, dtype=int)
        if self.padding == "same":
            output_shape = output_shape * self.strides
        elif self.padding == "valid":
            output_shape = (output_shape - 1) * self.strides + self.kernel_size
        else:
            raise ValueError("Unrecognized padding type %r" % self.padding)

        output_shape = tuple(output_shape)
        output_shape = (
            output_shape + (self.n_filters,)
            if self.channels_last
            else (self.n_filters,) + output_shape
        )
        return ChannelShape(output_shape, channels_last=self.channels_last)
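

# Minimal shape check (a sketch; the (4, 4, 1) input shape is illustrative):
#
#   transform = ConvolutionTranspose(n_filters=2, input_shape=(4, 4, 1))
#   # with the default "valid" padding: (4 - 1) * 2 + 2 = 8 per spatial axis
#   assert transform.output_shape.shape == (8, 8, 2)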


class ConvTransposeInc(Operator):
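    """Apply a transposed convolutional weight kernel to the input signal.

    Implemented only for the NengoDL backend (see ``ConvTransposeIncBuilder``
    below); ``make_step`` raises ``NotImplementedError`` in the reference
    backend.
    """
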
    def __init__(self, W, X, Y, conv, tag=None):
        super().__init__(tag=tag)

        self.conv = conv

        self.sets = []
        self.incs = [Y]
        self.reads = [W, X]
        self.updates = []

    @property
    def W(self):
        return self.reads[0]

    @property
    def X(self):
        return self.reads[1]

    @property
    def Y(self):
        return self.incs[0]

    def _descstr(self):
        return "convtranspose2d(%s, %s) -> %s" % (self.W, self.X, self.Y)

    def make_step(self, signals, dt, rng):
        raise NotImplementedError("Not supported outside NengoDL")


@nengo.builder.Builder.register(ConvolutionTranspose)
def build_convolutiontranspose(
    model, transform, sig_in, decoders=None, encoders=None, rng=np.random
):
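    """Build a `ConvolutionTranspose` transform into the Nengo model."""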
    if decoders is not None:
        raise BuildError(
            "Applying a convolution transform to a decoded "
            "connection is not supported"
        )

    # Shouldn't be possible for encoders to be non-None, since that only
    # occurs for a connection solver with weights=True, and those can only
    # be applied to decoded connections (which are disallowed above)
    assert encoders is None

    weights = transform.sample(rng=rng)
    weight_sig = Signal(weights, readonly=True, name="%s.weights" % transform)
    weighted = Signal(shape=transform.size_out, name="%s.weighted" % transform)
    model.add_op(Reset(weighted))

    model.add_op(
        ConvTransposeInc(
            weight_sig, sig_in, weighted, transform, tag="%s.apply_weights" % transform
        )
    )

    return weighted, weight_sig


@nengo_dl.builder.Builder.register(ConvTransposeInc)
class ConvTransposeIncBuilder(ConvIncBuilder):
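    """NengoDL build class for the `ConvTransposeInc` operator.

    Reuses the gather/scatter and operator-merging logic of ``ConvIncBuilder``,
    replacing the forward convolution with ``conv2d_transpose``.
    """
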
    def build_step(self, signals):
        W = signals.gather(self.W_data)
        X = signals.gather(self.X_data)

        if self.perm_x is not None:
            # move channels to end
            X = tf.transpose(X, perm=self.perm_x)

        if self.perm_w is not None:
            # concatenate kernels along output channel dimension
            W = tf.transpose(W, perm=self.perm_w)
            W = tf.reshape(W, self.reshape_w)

        channels_last = self.fmt == "NHWC"
        input_shape = tf.shape(X)
        input_size = input_shape[1:3] if channels_last else input_shape[2:4]
        input_n = input_shape[0:1]

        kernel_size = self.conv.kernel_size
        strides = self.conv.strides
        # compute the dynamic output spatial size; this mirrors
        # `ConvolutionTranspose.output_shape`, but with runtime tensor shapes
        if self.conv.padding.upper() == "SAME":
            output_size = input_size * strides
        elif self.conv.padding.upper() == "VALID":
            output_size = (input_size - (1, 1)) * strides + kernel_size
        else:
            raise BuildError("Unrecognized padding type %r" % self.conv.padding)

        output_channels = tf.shape(W)[3:4]
        output_shape = tf.concat(
            (input_n, output_size, output_channels)
            if channels_last
            else (input_n, output_channels, output_size),
            axis=0,
        )

        # swap the channel dimensions, since conv2d_transpose expects filters
        # in (height, width, output_channels, input_channels) order
        filters = tf.transpose(W, perm=(0, 1, 3, 2))

        Y = tf_convtranspose(
            X,
            filters,  # has shape (height, width, Y.channels, X.channels)
            output_shape=output_shape,
            strides=(1,) + strides + (1,) if channels_last else (1, 1) + strides,
            data_format=self.fmt,
            padding=self.conv.padding.upper(),
        )

        if self.reshape_y is not None:
            Y = tf.reshape(Y, self.reshape_y)
        if self.perm_y is not None:
            Y = tf.transpose(Y, perm=self.perm_y)

        # TensorFlow loses the static shape information in the reshapes and
        # transposes above, so restore it explicitly
        if self.reshape_y is None:
            Y.set_shape((signals.minibatch_size,) + self.conv.output_shape.shape)
        else:
            Y.set_shape(
                (signals.minibatch_size, self.n_ops) + self.conv.output_shape.shape
            )

        signals.scatter(self.Y_data, Y, mode="inc")
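

# End-to-end usage sketch (a hypothetical example, not part of this module;
# it assumes NengoDL is installed and uses an illustrative (4, 4, 1) input):
#
#   with nengo.Network() as net:
#       transform = ConvolutionTranspose(n_filters=1, input_shape=(4, 4, 1))
#       stim = nengo.Node(np.ones(transform.size_in))
#       out = nengo.Node(size_in=transform.size_out)
#       nengo.Connection(stim, out, transform=transform, synapse=None)
#
#   with nengo_dl.Simulator(net) as sim:
#       sim.step()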
