from manifest_edit.plugin.mpd import ManifestIteratorPlugin
from manifest_edit.context import Context
from manifest_edit import libfmp4
from schema import Schema


class Plugin(ManifestIteratorPlugin):
    """
    Move selected representations into new adaptation sets by ``set_id``.

    Every selected representation must be configured with a ``set_id``.
    Within each period, representations sharing a ``set_id`` are gathered
    into one new adaptation set whose id is
    ``max(existing numeric adaptation set ids) + int(set_id)``. The new set
    is built with ``AdaptationSet(<source set>)`` from the source adaptation
    set of the first representation assigned to it, with its representation
    list cleared. Every adaptation set that contributed a representation is
    then removed from its period and the new sets are appended.

    NOTE: representations of a source adaptation set that were not selected
    disappear together with that set.
    """

    _name = __name__

    _keys = ["set_id"]

    def __init__(self):
        # period.id -> bookkeeping dict for that period; filled by
        # _addRepresentationToAdaptationSet, consumed (and reset) by
        # splitAdaptationSet.
        self._periods = {}
        super().__init__()

    def schema(self):
        """
        Return the configuration schema: a string ``set_id`` per selection.

        The value is later converted with ``int()``, so it must be numeric.
        """
        return Schema({self._keys[0]: str})

    def findFathers(self, root, element):
        """
        Yield ``(parent, children)`` for the list holding ``element``.

        Recurses manifest -> periods -> adaptationSets -> representations
        starting at ``root``. ``children`` is the child list of ``parent``
        that contains ``element``. Nothing is yielded when ``element`` is not
        found below ``root`` (or ``root`` is a representation).
        """
        if isinstance(root, libfmp4.mpd.Manifest):
            children, child_type = root.periods, libfmp4.mpd.Period
        elif isinstance(root, libfmp4.mpd.Period):
            children, child_type = root.adaptationSets, libfmp4.mpd.AdaptationSet
        elif isinstance(root, libfmp4.mpd.AdaptationSet):
            children, child_type = (
                root.representations,
                libfmp4.mpd.Representation,
            )
        else:
            # Leaf level reached: not found.
            return

        if isinstance(element, child_type) and element in children:
            yield root, children
        else:
            for child in children:
                yield from self.findFathers(child, element)

    def _parentOf(self, manifest, element):
        """
        Return the object whose child list directly contains ``element``.

        Stops at the first match instead of walking the whole manifest.

        :raises LookupError: if ``element`` is not part of ``manifest``.
        """
        for parent, _children in self.findFathers(manifest, element):
            return parent
        raise LookupError(f"{type(element)} not found in manifest")

    def _minOrInvalidValue(self, attrib_list, invalid_value):
        """Return the minimum of ``attrib_list``, or ``invalid_value`` if empty."""
        return min(attrib_list, default=invalid_value)

    def _maxOrInvalidValue(self, attrib_list, invalid_value):
        """Return the maximum of ``attrib_list``, or ``invalid_value`` if empty."""
        return max(attrib_list, default=invalid_value)

    def checkAdaptationSet(self, parent):
        """
        Recompute the @bandwidth, @width, @height and @frameRate bounds of
        adaptation set ``parent`` from its current representations.

        Only bounds that are already set (truthy) are updated; bounds that
        were not set are never introduced. With an empty representation list
        each set bound is reset to its "invalid" default (0, or a
        default-constructed ``libfmp4.FractionUint32`` for frame rates),
        which presumably removes it from the output.
        """
        # (representation attribute, min bound, max bound, invalid factory)
        bounds = (
            ("width", "minWidth", "maxWidth", int),
            ("height", "minHeight", "maxHeight", int),
            ("bandwidth", "minBandwidth", "maxBandwidth", int),
            (
                "frameRate",
                "minFramerate",
                "maxFramerate",
                libfmp4.FractionUint32,
            ),
        )
        for repr_attr, min_attr, max_attr, make_invalid in bounds:
            values = [getattr(rep, repr_attr) for rep in parent.representations]
            if getattr(parent, min_attr):
                setattr(
                    parent,
                    min_attr,
                    self._minOrInvalidValue(values, make_invalid()),
                )
            if getattr(parent, max_attr):
                setattr(
                    parent,
                    max_attr,
                    self._maxOrInvalidValue(values, make_invalid()),
                )

    def checkCoherence(self, parent):
        """
        Final check after elements were moved, so the manifest stays coherent
        and legal.

        ``parent`` is the element whose children changed: a manifest (periods
        changed), a period (adaptation sets changed) or an adaptation set
        (representations changed). For an adaptation set its min/max
        attributes are recomputed from the representations it now holds,
        e.g. if the lowest-bandwidth representation left, minBandwidth (when
        set) is updated accordingly.

        :raises Exception: if ``parent`` is of any other type.
        """
        # NOTE(review): checkManifest/checkPeriod are not defined in this
        # class; presumably inherited - confirm before relying on them. This
        # plugin only ever passes adaptation sets here.
        if isinstance(parent, libfmp4.mpd.Manifest):
            self.checkManifest(parent)
        elif isinstance(parent, libfmp4.mpd.Period):
            self.checkPeriod(parent)
        elif isinstance(parent, libfmp4.mpd.AdaptationSet):
            self.checkAdaptationSet(parent)
        else:
            Context.log_error(
                f"Received type {type(parent)} in manifest coherence check!"
            )
            raise Exception(f"Logic error in {self._name} plugin! Aborting..")

    def _addRepresentationToAdaptationSet(
        self, set_id, period, adaptation_set, representation
    ):
        """
        Assign ``representation`` to the new adaptation set of ``set_id``.

        :param set_id: configured set id; must be parseable by ``int()``.
        :param period: the Period holding ``adaptation_set``.
        :param adaptation_set: the AdaptationSet currently holding
            ``representation``; it is scheduled for removal.
        :param representation: the Representation to move.
        """
        state = self._periods.get(period.id)
        if state is None:
            state = self._periods[period.id] = {
                "element": period,
                "adaptation_sets": {},
                "adaptation_sets_to_remove": set(),
                # Taken before any new set is appended; assumes numeric ids.
                "max_ad_set_id": max(
                    int(adset.id) for adset in period.adaptationSets
                ),
            }

        # Destination id is max existing id + set_id.
        new_ad_set_id = str(state["max_ad_set_id"] + int(set_id))

        new_adaptation_set = state["adaptation_sets"].get(new_ad_set_id)
        if new_adaptation_set is None:
            # Template the new set on the representation's current set.
            new_adaptation_set = libfmp4.mpd.AdaptationSet(adaptation_set)
            new_adaptation_set.representations.clear()
            new_adaptation_set.id = new_ad_set_id
            state["adaptation_sets"][new_ad_set_id] = new_adaptation_set

        # Every source set must be removed, including those whose
        # representations join an already-created target; otherwise the
        # representation would remain in its old set too (duplicated).
        state["adaptation_sets_to_remove"].add(adaptation_set)
        new_adaptation_set.representations.append(representation)

    def _selectedRepresentations(self, manifest, storage):
        """
        Yield ``(set_id, period, adaptation_set, representation)`` for every
        configured representation.

        Any other selected element type is logged as an error and skipped.
        """
        for config, element in self.config(manifest, storage):
            if not isinstance(element, libfmp4.mpd.Representation):
                Context.log_error(
                    f"In {self._name} plugin, you must select "
                    "representations, not periods or adaptation sets"
                )
                continue
            adaptation_set = self._parentOf(manifest, element)
            period = self._parentOf(manifest, adaptation_set)
            yield config[self._keys[0]], period, adaptation_set, element

    def splitAdaptationSet(self, manifest, storage):
        """
        Move every configured representation to its ``set_id`` adaptation
        set, then replace the source adaptation sets with the new ones.
        """
        # Forget bookkeeping from any manifest processed earlier.
        self._periods = {}

        for set_id, period, adaptation_set, representation in (
            self._selectedRepresentations(manifest, storage)
        ):
            self._addRepresentationToAdaptationSet(
                set_id, period, adaptation_set, representation
            )

        for state in self._periods.values():
            period = state["element"]
            for old_set in state["adaptation_sets_to_remove"]:
                # Only remove sets still present, so a set referenced twice
                # cannot make remove() fail.
                if old_set in period.adaptationSets:
                    period.adaptationSets.remove(old_set)
            for new_set in state["adaptation_sets"].values():
                # Fix the bounds before appending, so they are right even if
                # the list stores a copy of the object.
                self.checkCoherence(new_set)
                period.adaptationSets.append(new_set)

    def _mustSplit(self, manifest, storage):
        """
        Tell whether the configuration requires any splitting.

        Collects the distinct configured ``set_id`` values and, per period,
        the distinct adaptation sets the selected representations come from.
        Splitting is needed as soon as some period's selected representations
        span fewer adaptation sets than there are ``set_id`` values.

        Without this check a manifest that does not need splitting would
        still be rewritten: the output would probably be equivalent, but with
        different adaptation set ids, which we want to avoid.
        """
        set_ids = set()
        ad_set_ids_per_period = {}
        for set_id, period, adaptation_set, _representation in (
            self._selectedRepresentations(manifest, storage)
        ):
            set_ids.add(set_id)
            ad_set_ids_per_period.setdefault(period.id, set()).add(
                adaptation_set.id
            )

        return any(
            len(set_ids) > len(ad_set_ids)
            for ad_set_ids in ad_set_ids_per_period.values()
        )

    def process(self, manifest, storage):
        """Split adaptation sets, but only if the configuration requires it."""
        if self._mustSplit(manifest, storage):
            self.splitAdaptationSet(manifest, storage)