#!/usr/bin/env any-python
from __future__ import print_function, absolute_import, unicode_literals

from argparse import ArgumentParser, FileType, RawTextHelpFormatter
import sys
import re

# PY2 is true when we're running under Python 2.x. It is used for appropriate
# return value selection of __str__ and __repr__ methods, which must both
# return str, not unicode (in Python 2) and str (in Python 3). In both cases
# the return type annotation is exactly the same, but due to unicode_literals
# being in effect, and the fact we often use a format string (which is a
# unicode string in Python 2), we must encode it to a byte string when
# running under Python 2.
PY2 = sys.version_info[0] == 2

# Define MYPY as False and use it as a conditional for typing import. Despite
# this declaration mypy will really treat MYPY as True when type-checking.
# This is required so that we can import typing on Python 2.x without the
# typing module installed. For more details see:
# https://mypy.readthedocs.io/en/latest/common_issues.html#import-cycles
MYPY = False
if MYPY:
    from typing import Any, Dict, List, Text, Tuple, Match


class Device(int):
    """
    Device is a device number with major and minor components.

    The number is kept as a plain integer: the minor number occupies the
    low 16 bits and the major number sits above them. This layout is
    private to this class and does not attempt to mimic the peculiar
    encoding used by the Linux kernel.
    """

    # Width of the minor component and the mask selecting it.
    _MINOR_BITS = 16
    _MINOR_MASK = (1 << _MINOR_BITS) - 1

    @classmethod
    def pack(cls, major, minor):
        # type: (int, int) -> Device
        """pack combines major and minor numbers into a Device."""
        # Minor bits that do not fit in the low 16 bits are discarded.
        low = minor & cls._MINOR_MASK
        high = major << cls._MINOR_BITS
        return cls(high | low)

    def _render(self, template):
        # type: (Text) -> str
        """_render fills template with major and minor as a native str."""
        text = template.format(self.major, self.minor)
        # Format strings are unicode under Python 2 (unicode_literals), but
        # __str__ and __repr__ must return byte strings there.
        return text.encode() if PY2 else text

    def __str__(self):
        # type: () -> str
        return self._render("{}:{}")

    def __repr__(self):
        # type: () -> str
        return self._render("Device.pack({}, {})")

    @property
    def major(self):
        # type: () -> int
        """major is the higher 16 bits of the device number."""
        return self >> self._MINOR_BITS

    @property
    def minor(self):
        # type: () -> int
        """minor is the lower 16 bits of the device number."""
        return self & self._MINOR_MASK


class MountInfoEntry(object):
    """Single entry in /proc/pid/mountinfo, see proc(5)"""

    # Names of the attributes that may be referenced in filter and display
    # expressions, each mapped to the type of the value it holds.
    known_attrs = {
        "mount_id": int,
        "parent_id": int,
        "dev": Device,
        "root_dir": str,
        "mount_point": str,
        "mount_opts": str,
        "opt_fields": list,
        "fs_type": str,
        "mount_source": str,
        "sb_opts": str,
    }

    def __init__(self):
        # type: () -> None
        """Create an entry with every field set to an empty or zero value."""
        self.mount_id = 0
        self.parent_id = 0
        self.dev = Device.pack(0, 0)
        self.root_dir = ""
        self.mount_point = ""
        self.mount_opts = ""
        self.opt_fields = []  # type: List[Text]
        self.fs_type = ""
        self.mount_source = ""
        self.sb_opts = ""

    @classmethod
    def parse(cls, line):
        # type: (Text) -> MountInfoEntry
        """
        parse builds an entry from a single line of mountinfo.

        The line is split on whitespace and the fields are consumed in
        order: mount ID, parent ID, major:minor, root directory, mount
        point, mount options, any number of optional fields terminated by
        a lone "-", filesystem type, mount source and super-block options.

        Raises ValueError when a numeric field cannot be converted, when
        the line ends before all fields were read (this includes blank
        lines and lines without the "-" separator), or when extra fields
        follow the super-block options.
        """
        it = iter(line.split())
        self = cls()
        try:
            self.mount_id = int(next(it))
            self.parent_id = int(next(it))
            dev_maj, dev_min = map(int, next(it).split(":"))
            # NOTE: unlike Device.pack the minor is not masked to 16 bits.
            self.dev = Device((dev_maj << 16) | dev_min)
            self.root_dir = next(it)
            self.mount_point = next(it)
            self.mount_opts = next(it)
            self.opt_fields = []
            # Optional fields run up to, but not including, the "-" marker.
            for opt_field in it:
                if opt_field == "-":
                    break
                self.opt_fields.append(opt_field)
            self.fs_type = next(it)
            self.mount_source = next(it)
            self.sb_opts = next(it)
        except StopIteration:
            # A bare next() ran out of fields. Report it as a parse error
            # instead of leaking StopIteration, which callers do not expect
            # and which turns into RuntimeError inside generators.
            raise ValueError("not enough fields in {!r}".format(line))
        # Fields are strings, so None reliably signals an exhausted iterator.
        if next(it, None) is not None:
            raise ValueError("leftovers after parsing {!r}".format(line))
        return self

    def __str__(self):
        # type: () -> str
        """__str__ renders the fields in the same order parse reads them."""
        result = (
            "{0.mount_id} {0.parent_id} {0.dev} {0.root_dir}"
            " {0.mount_point} {0.mount_opts} {opt_fields} {0.fs_type}"
            " {0.mount_source} {0.sb_opts}"
        ).format(self, opt_fields=" ".join(self.opt_fields + ["-"]))
        # The format string is unicode under Python 2; __str__ needs bytes.
        if PY2:
            return result.encode()
        return result

    @property
    def dev_maj(self):
        # type: () -> int
        """dev_maj is the major number of the device."""
        return self.dev.major

    @property
    def dev_min(self):
        # type: () -> int
        """dev_min is the minor number of the device."""
        return self.dev.minor


class FilterExpr(object):
    """
    FilterExpr is the interface for filtering mount entries.

    Filters are used with the ``in`` operator: ``entry in some_filter`` is
    true when the entry matches. Subclasses must override __contains__.
    """

    def __contains__(self, entry):
        # type: (MountInfoEntry) -> bool
        """__contains__ returns true if a mount entry matches the filter."""
        # Falling through would return None, which the ``in`` operator turns
        # into False, so a filter that forgot to implement this method would
        # silently reject every entry. Fail loudly instead.
        raise NotImplementedError(
            "{} must implement __contains__".format(type(self).__name__)
        )


class AttrFilter(FilterExpr):
    """AttrFilter matches entries whose attribute equals a fixed value."""

    def __init__(self, attr, value):
        # type: (Text, Any) -> None
        """Remember the attribute name and the value it is compared to."""
        self.attr, self.value = attr, value

    def __contains__(self, entry):
        # type: (MountInfoEntry) -> bool
        """__contains__ returns true if the entry attribute equals value."""
        is_equal = getattr(entry, self.attr) == self.value
        # Coerce to a real bool regardless of what == returned.
        return bool(is_equal)


class AttrPrefixFilter(FilterExpr):
    """AttrPrefixFilter matches entries whose attribute starts with a prefix."""

    def __init__(self, attr, value):
        # type: (Text, Text) -> None
        """Remember the attribute name and the prefix to look for."""
        self.attr, self.value = attr, value

    def __contains__(self, entry):
        # type: (MountInfoEntry) -> bool
        """__contains__ returns true if the attribute starts with value."""
        actual = getattr(entry, self.attr)
        # The attribute is converted to text before the prefix test.
        return str(actual).startswith(self.value)


def parse_filter(expr):
    # type: (Text) -> FilterExpr
    """
    parse_filter parses one of the known filter expressions.

    Recognized forms are:

        .attr=value or attr=value   exact match on a known attribute
        /path...                    prefix match on mount_point
        /path                       exact match on mount_point

    For attribute filters the value is converted to the type of the
    attribute. Two attributes get special treatment: opt_fields takes a
    whitespace-separated list of fields, and dev accepts MAJOR:MINOR in
    addition to a plain integer device number.

    Raises ValueError if the attribute is not known or the value cannot
    be converted.
    """
    if "=" in expr:
        # Accept both .attr=value and attr=value as exact attribute match.
        if expr.startswith("."):
            expr = expr.lstrip(".")
        attr, value = expr.split("=", 1)
        try:
            typ = MountInfoEntry.known_attrs[attr]
        except KeyError:
            raise ValueError("invalid filter expression {!r}".format(expr))
        if typ is list:
            # list(value) would explode the string into single characters,
            # which can never equal the list of fields held by an entry.
            return AttrFilter(attr, value.split())
        if typ is Device and ":" in value:
            # Accept the MAJOR:MINOR form that entries are displayed in.
            # Combine the numbers exactly like MountInfoEntry.parse does so
            # that the result compares equal to parsed entries.
            major, minor = map(int, value.split(":"))
            return AttrFilter(attr, Device((major << 16) | minor))
        return AttrFilter(attr, typ(value))
    if expr.endswith("..."):
        # Treat /path/... as prefix match on mount_point. Slice off exactly
        # the three-dot marker: rstrip("...") treats its argument as a set
        # of characters and would also eat dots that belong to the path.
        return AttrPrefixFilter("mount_point", expr[:-3])
    # Treat /path as exact match on mount_point.
    return AttrFilter("mount_point", expr)


def parse_attr(expr):
    # type: (Text) -> Text
    """
    parse_attr parses attribute references (for display).

    Leading dots are removed and the remaining name is returned if it is
    one of the known attributes. Raises ValueError otherwise.
    """
    known = sorted(MountInfoEntry.known_attrs)
    name = expr.lstrip(".")
    if name not in known:
        raise ValueError(
            "invalid attribute selector {!r} (known: {})".format(expr, known)
        )
    return name


def parse_exprs(exprs):
    # type: (List[Text]) -> Tuple[List[FilterExpr], List[Text]]
    """
    parse_exprs parses filter expressions and attribute references.

    Returns a pair: the list of parsed filters and the list of attribute
    names selected for display, each in the order given.
    """

    def is_attr_ref(expr):
        # type: (Text) -> bool
        # Attributes are always .attr, anything else is a filter
        # (.attr=value, /path, /path...).
        return expr.startswith(".") and "=" not in expr

    # All filters are parsed before any attribute reference is looked at.
    filters = []  # type: List[FilterExpr]
    for expr in exprs:
        if not is_attr_ref(expr):
            filters.append(parse_filter(expr))

    attrs = []  # type: List[Text]
    for expr in exprs:
        if is_attr_ref(expr):
            attrs.append(parse_attr(expr))

    return filters, attrs


def matches(entry, filters):
    # type: (MountInfoEntry, List[FilterExpr]) -> bool
    """
    matches checks if a mount entry matches a list of filter expressions.

    Filter expressions are ANDed together, so an empty list of filters
    matches every entry. Evaluation stops at the first filter that
    rejects the entry.
    """
    return all(entry in flt for flt in filters)


def renumber_snap_revision(entry, seen):
    # type: (MountInfoEntry, Dict[Tuple[Text, Text], int]) -> None
    """
    renumber_snap_revision re-numbers the snap revision in the mount point.

    Mount points of the form /snap/NAME/REV/... and
    /var/lib/snapd/snap/NAME/REV/... get REV replaced with a small number
    allocated per snap name, starting at one. Allocations are remembered
    in seen, keyed by (NAME, REV). Other mount points are left alone.
    """
    # Each layout is (leading path components, index of the snap name);
    # the revision always follows the name.
    layouts = (
        (["", "snap"], 2),
        (["", "var", "lib", "snapd", "snap"], 5),
    )
    parts = entry.mount_point.split("/")
    for prefix, name_idx in layouts:
        rev_idx = name_idx + 1
        if len(parts) > rev_idx and parts[: len(prefix)] == prefix:
            break
    else:
        # Not a snap mount point in either layout.
        return

    key = (parts[name_idx], parts[rev_idx])
    if key not in seen:
        # Next number is one more than the revisions seen for this snap.
        already = sum(1 for name, _ in seen if name == key[0])
        seen[key] = already + 1
    parts[rev_idx] = "{}".format(seen[key])
    entry.mount_point = "/".join(parts)


def renumber_opt_fields(entry, seen):
    # type: (MountInfoEntry, Dict[int, int]) -> None
    """
    renumber_opt_fields re-numbers peer group in optional fields.

    Every run of digits in each optional field is replaced with a number
    allocated in order of first appearance, starting at one. Allocations
    are remembered in seen, keyed by the original number.
    """

    def replace_group(match):
        # type: (Match[Text]) -> Text
        peer_group = int(match.group(1))
        if peer_group not in seen:
            seen[peer_group] = len(seen) + 1
        return "{}".format(seen[peer_group])

    digits = re.compile("(\\d+)")
    renumbered = []
    for field in entry.opt_fields:
        renumbered.append(digits.sub(replace_group, field))
    entry.opt_fields = renumbered


def renumber_loop_devices(entry, seen):
    # type: (MountInfoEntry, Dict[int, int]) -> None
    """
    renumber_loop_devices re-numbers loop device numbers.

    Every loopN in the mount source is rewritten with a number allocated
    in order of first appearance, starting at zero. Allocations are
    remembered in seen, keyed by the original loop number.
    """

    def replace_loop(match):
        # type: (Match[Text]) -> Text
        loop_nr = int(match.group(1))
        if loop_nr not in seen:
            seen[loop_nr] = len(seen)
        return "loop{}".format(seen[loop_nr])

    loop_pattern = re.compile("loop(\\d+)")
    entry.mount_source = loop_pattern.sub(replace_loop, entry.mount_source)


def renumber_mount_ids(entry, seen):
    # type: (MountInfoEntry, Dict[int, int]) -> None
    """
    renumber_mount_ids re-numbers mount and parent mount IDs.

    IDs are replaced with numbers allocated in order of first appearance,
    starting at zero. Allocations are remembered in seen, keyed by the
    original ID.
    """
    # NOTE: renumber the parent ahead of the mount to get more
    # expected relationship between them.
    for field in ("parent_id", "mount_id"):
        original = getattr(entry, field)
        if original not in seen:
            seen[original] = len(seen)
        setattr(entry, field, seen[original])


def renumber_devices(entry, seen):
    # type: (MountInfoEntry, Dict[Device, Device]) -> None
    """
    renumber_devices re-numbers major:minor device numbers.

    Allocations are remembered in seen, keyed by the original device.
    Devices sharing an original major number also share the remapped
    major number.
    """
    dev = entry.dev
    if dev not in seen:
        # We haven't seen the major:minor pair precisely but perhaps we've
        # seen the major number already?
        same_major = [orig for orig in seen if orig.major == dev.major]
        if same_major:
            # This major is already remapped, reuse that value.
            major = seen[same_major[0]].major
        else:
            # Allocate the next one based on the cardinality of the set of
            # major numbers we've seen.
            major = len({orig.major for orig in seen})
        # Allocate the next minor number based on the cardinality of the
        # set of minor numbers matching the major number.
        minor = len({orig.minor for orig in same_major})
        seen[dev] = Device.pack(major, minor)
    entry.dev = seen[dev]


def rewrite_renumber(entries):
    # type: (List[MountInfoEntry]) -> None
    """
    rewrite_renumber applies all re-numbering helpers.

    Each helper gets its own allocation table which is shared across all
    entries, so numbers stay consistent from one entry to the next.
    Entries are modified in place.
    """
    mount_id_map = {}  # type: Dict[int, int]
    device_map = {}  # type: Dict[Device, Device]
    snap_rev_map = {}  # type: Dict[Tuple[Text, Text], int]
    peer_group_map = {}  # type: Dict[int, int]
    loop_map = {}  # type: Dict[int, int]
    for item in entries:
        renumber_mount_ids(item, mount_id_map)
        renumber_devices(item, device_map)
        renumber_snap_revision(item, snap_rev_map)
        renumber_opt_fields(item, peer_group_map)
        renumber_loop_devices(item, loop_map)


def rewrite_rename(entries):
    # type: (List[MountInfoEntry]) -> None
    """
    rewrite_rename applies all re-naming helpers.

    Block devices named /dev/sdX or /dev/vdX in the mount source are all
    spelled /dev/sdX afterwards. Entries are modified in place.
    """
    # TODO: allocate devices like everything else above.
    disk_pattern = re.compile("/dev/[sv]d([a-z])")
    for item in entries:
        item.mount_source = disk_pattern.sub("/dev/sd\\1", item.mount_source)


def main():
    # type: () -> None
    """
    main parses command line options, then reads, filters and prints entries.

    Entries are read from the file given with -f (the mountinfo of the
    current process by default), filtered with the given expressions,
    optionally rewritten (--renumber, --rename) and printed either whole
    or limited to the selected attributes.

    Exits with an error message when an expression or a mountinfo line
    cannot be parsed, or when --one was given and the number of matching
    entries is not exactly one.
    """
    parser = ArgumentParser(
        epilog="""
Expressions are ANDed together and have one of the following forms:

    .ATTR=VALUE     mount entry attribute ATTR is equal to VALUE
    PATH            mount point is equal to PATH
    PATH...         mount point starts with PATH

In addition .ATTR syntax can be used to limit display to only certain
attributes. By default the output is identical to raw mountinfo.
Known attributes, applicable for both filtering and display.

    mount_point:    path where mount is attached in the file system
    mount_source:   path of the mounted device or bind-mount origin
    fs_type:        filesystem type
    mount_opts:     options applying to the mount point only
    sb_opts:        options applying to the mounted filesystem
    opt_fields:     optional fields, used for propagation information
    mount_id:       mount point identifier
    parent_id:      identifier of parent mount point
    dev:            major:minor numbers of the mounted device
    root_dir:       subtree of the mounted filesystem exposed at mount_point
    """,
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("-v", "--version", action="version", version="1.0")
    parser.add_argument(
        "-f",
        metavar="MOUNTINFO",
        dest="file",
        type=FileType(),
        default="/proc/self/mountinfo",
        help="parse specified mountinfo file",
    )
    parser.add_argument(
        "--one", default=False, action="store_true", help="expect exactly one match"
    )
    parser.add_argument(
        "exprs",
        metavar="EXPRESSION",
        nargs="*",
        help="filter or display expression (see below)",
    )
    group = parser.add_argument_group("Rewriting rules")
    group.add_argument(
        "--renumber",
        action="store_true",
        help="Reassign mount IDs, device numbers, snap revisions"
        " and loopback devices",
    )
    group.add_argument(
        "--rename", action="store_true", help="Reassign block device names"
    )
    opts = parser.parse_args()
    # argparse has already opened opts.file; make sure it gets closed once
    # the entries are read, whether or not parsing succeeds.
    try:
        try:
            filters, attrs = parse_exprs(opts.exprs)
        except ValueError as exc:
            raise SystemExit(exc)
        try:
            entries = [MountInfoEntry.parse(line) for line in opts.file]
        except ValueError as exc:
            # Report malformed input the same way as a bad expression,
            # with a message instead of a traceback.
            raise SystemExit("cannot parse {}: {}".format(opts.file.name, exc))
    finally:
        # FileType hands out sys.stdin itself for "-"; leave that one open.
        if opts.file is not sys.stdin:
            opts.file.close()
    entries = [e for e in entries if matches(e, filters)]
    if opts.renumber:
        rewrite_renumber(entries)
    if opts.rename:
        rewrite_rename(entries)
    for e in entries:
        if attrs:
            values = []  # type: List[Any]
            for attr in attrs:
                value = getattr(e, attr)
                # Optional fields are a list; show them space-separated.
                if isinstance(value, list):
                    value = " ".join(value)
                values.append(value)
            print(*values)
        else:
            print(e)
    # The check runs after printing, so matches are shown even on failure.
    if opts.one and len(entries) != 1:
        raise SystemExit(
            "--one requires exactly one match, found {}".format(len(entries))
        )

# Run main only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
