#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# firewall.py  -  Firewall control
#
# Copyright (C) 2017 PerFact Innovation GmbH & Co. KG
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
#
#

# For the shell wrapper (stdin, stdout)
import sys
import os
# For interfacing and executing iptables and friends
from .generic import safe_syscall, json_encode, json_decode
from .generic import to_ustring, tokenize_quoted


# Helpers for SSH and local shell execution

def firewall_control(host, fwstruct):
    '''Run /usr/bin/perfact-firewall-control on `host` via ssh and sudo.

    The fwstruct is JSON-encoded and fed to the remote command on stdin;
    the JSON printed by the remote side is decoded and returned.
    safe_syscall is called with raisemode=True.

    Use this method from within Zope.
    '''
    remote_cmd = [
        '/usr/bin/ssh', host, 'sudo', '/usr/bin/perfact-firewall-control',
    ]
    payload = json_encode(fwstruct)
    _, reply = safe_syscall(remote_cmd, stdin_data=payload, raisemode=True)
    return json_decode(reply)


def binwrapper(dryrun=False):
    '''Entry point used by /usr/bin/perfact-firewall-control.

    This runs on the firewall machine, mostly started via ssh by Zope.
    An optional JSON fwstruct is read from stdin and handed to
    FirewallControl.write(); the resulting firewall state is printed to
    stdout as JSON.
    '''
    raw_request = sys.stdin.read()
    request = json_decode(raw_request) if raw_request else None
    controller = FirewallControl(dryrun=dryrun)
    state = controller.write(request)
    print(json_encode(state))


# Private helper functions

def _parse_counts(val):
    ''' Read counter values from iptables-save

    >>> _parse_counts('[217547:16533572]') == (
    ...     {u'bytes': 16533572, u'packets': 217547})
    True
    '''
    packets, bytes = list(map(int, val.strip('[]').split(':')))
    return {u'packets': packets, u'bytes': bytes}


class FirewallControl:
    '''Front end for the kernel's IP packet filtering.

    This module constitutes a front end for the following kernel internals
    relating to IP:

    iptables (tables, chains, rules)


    Future support may include:

    ip rule (send marked packets into other routing tables)

    ip route (build alternative routing tables)

    tc (shape traffic depending on whatever)

    ipset (sets usable for rules)

    The exported interface FirewallControl.write() is wrapped in a
    script, which should be made available in "sudoers".
    '''

    # Chains that exist by default, per table. chain_ops() never creates or
    # deletes these, and read() records the policy only for these.
    builtin_chains = {
        u'filter': [
            u'INPUT', u'OUTPUT', u'FORWARD',
        ],
        u'nat': [
            u'PREROUTING', u'INPUT', u'OUTPUT', u'POSTROUTING',
        ],
        u'mangle': [
            u'PREROUTING', u'INPUT', u'FORWARD', u'OUTPUT', u'POSTROUTING',
        ],
        u'raw': [
            u'PREROUTING', u'OUTPUT',
        ],
        u'security': [
            u'INPUT', u'OUTPUT', u'FORWARD', u'SECMARK', u'CONNSECMARK',
        ],
    }
    # Short iptables options and their long names; read() uses this to
    # report every option under its long name.
    iptables_opts = {
        # This must be extended if needed
        u's': u'source',
        u'd': u'destination',
        u'p': u'protocol',
        u'm': u'match',
        u'j': u'jump',
        u'g': u'goto',
        u'i': u'in-interface',
        u'o': u'out-interface',
        u'f': u'fragment',
        u'c': u'set-counters',
    }

    def __init__(self, dryrun=False):
        '''Locate the iptables binaries and set defaults.

        dryrun: if true, commands are printed instead of executed, and
        read() parses example_iptables_save instead of the live tables.
        '''
        # If true, failing commands raise ValueError instead of being
        # ignored.
        self.strict = False
        self.dryrun = dryrun

        def find_script(name):
            '''
            Return path to script either in /sbin or /usr/sbin for
            compatibility to different Ubuntu versions.

            If it exists in neither place, return the bare name. The lookup
            is then left to whatever executes the command, instead of
            putting None into every command line.
            '''
            for path in ['/sbin/', '/usr/sbin/']:
                result = path + name
                if os.path.exists(result):
                    return result
            return name

        # '-w 60': wait up to 60 seconds for the xtables lock
        self.iptables_bin = [find_script('iptables'), '-w', '60']
        self.iptables_save = find_script('iptables-save')

        self.outformat_allowed = ['python', 'iptables']
        self.outformat = 'python'

        # Reverse of iptables_opts: long option name -> short option letter
        self.iptables_revopts = {
            longopt: shortopt
            for shortopt, longopt in self.iptables_opts.items()
        }

    def read(self):
        '''Read the complete firewall and return it as a dictionary.

        With outformat 'python' (the default), the output of
        `iptables-save -c` is parsed into
        {u'iptables': {table: {u'chains': {...}, u'rules': {...}}}}.
        Every chain maps to {u'counters': {...}}, and builtin chains
        additionally carry their u'policy'. Every rule is a list of
        options such as [u'source', u'10.0.0.1/32'] or
        [u'!', u'tcp-flags', ...], terminated by
        [u'counters', packets, bytes].

        Alternatively the returned data might be plain iptables console
        output: {u'iptables': {table: <output of iptables -L -n -v>}}.
        To enable this, the fwstruct given to write() needs to set
        outformat='iptables'.

        In dryrun mode the dump is taken from example_iptables_save; the
        'iptables' outformat raises NotImplementedError there.

        NOTE(review): no checksum for change detection is added to the
        result yet.
        '''
        if self.dryrun:
            if self.outformat == 'iptables':
                # TODO: no example console listing exists for dryrun yet
                raise NotImplementedError
            dump = example_iptables_save
        elif self.outformat == 'iptables':
            res = {u'iptables': {}}
            for table in self.builtin_chains:
                # only string arguments are allowed in safe_syscall
                table = str(table)
                _retcode, listing = safe_syscall(
                    self.iptables_bin + ['-L', '-n', '-v', '-t', table],
                    raisemode=True)
                res['iptables'][table] = listing
            return res
        else:
            _retcode, dump = safe_syscall(
                [self.iptables_save, '-c'], raisemode=True)

        tables = {}
        table = None
        chains = {}
        rules = {}
        # ensure the dump material is unicode
        for line in self._dump_lines(to_ustring(dump)):
            head = line[0]
            if head.startswith(u'*'):
                # "*table": switch to a new table
                table = head[1:]
                chains = {}
                rules = {}
            elif head == u'COMMIT':
                # end of the rules of the current table
                tables[table] = {u'chains': chains, u'rules': rules}
            elif head.startswith(u':'):
                # ":name policy [packets:bytes]" declares a chain. The
                # policy field is '-' for user-defined chains, so it is
                # only recorded for the table's builtin chains.
                name = head[1:]
                chain = {}
                if name in self.builtin_chains.get(table, ()):
                    chain[u'policy'] = line[1]
                chain[u'counters'] = _parse_counts(line[-1])
                chains[name] = chain
            else:
                # "[packets:bytes] -A chain option...": a rule
                options = self._parse_options(line[3:])
                counters = _parse_counts(line[0])
                options.append(
                    [u'counters', counters[u'packets'], counters[u'bytes']])
                rules.setdefault(line[2], []).append(options)

        if table:
            # also keep the last table if its COMMIT line is missing
            tables[table] = {u'chains': chains, u'rules': rules}

        return {u'iptables': tables}

    @staticmethod
    def _dump_lines(dump):
        '''Split an iptables-save dump into lines of words, skipping comments.

        Every returned line is a non-empty list of non-empty,
        whitespace-stripped words. Quoting is handled by tokenize_quoted
        (presumably keeping double-quoted strings together).
        '''
        lines = []
        line = []
        comment = False
        for word in tokenize_quoted(
                dump, quotes=u'"', separators=u' \n', append_separators=True):
            # TODO: implement backslash_mode in tokenize_quoted
            if not word:
                continue
            if word[0] == u'#' and not line:
                comment = True
            stripped = word.strip()
            # A token made of separators only (e.g. an empty line) carries
            # no word; it must not produce empty entries.
            if stripped:
                line.append(stripped)
            if word.endswith(u'\n'):
                if line and not comment:
                    lines.append(line)
                comment = False
                line = []
        return lines

    def _parse_options(self, words):
        '''Group the option words of one rule into a list of options.

        A new option starts at every word beginning with '-' and at every
        '!', which is kept as first element of the option it negates.
        Leading dashes are removed and short options are translated to
        their long names via iptables_opts (unknown short options keep
        their dash).
        '''
        words = list(words)
        options = []
        option = []
        while words:
            word = words.pop(0)
            if option and (word == u'!' or word.startswith(u'-')):
                # flush collected option, start afresh
                options.append(option)
                option = []
            if word == u'!':
                option.append(u'!')
                # the negated option follows
                word = words.pop(0)
            if word.startswith(u'--'):
                word = word[2:]
            elif word.startswith(u'-'):
                # map to long options
                word = self.iptables_opts.get(word[1], word)
            option.append(word)

        # flush the last collected option
        if option:
            options.append(option)
        return options

    def write(self, fwstruct=None):
        '''Apply firewall manipulation commands, then return read().

        fwstruct is a single command set or a list of them, applied in
        order. Each command set is a dict with
          u'command':   u'add', u'insert', u'delete' or None (no changes)
          u'fwstruct':  {u'iptables': {table: {u'chains': ..,
                                               u'rules': ..}}}
          u'outformat': optional, one of outformat_allowed; selects the
                        format of the returned state (see read())

        It is possible to give no commands at all and just read out the
        current firewall configuration. For this use a fwstruct like:
        fwstruct = {u'command': None, u'fwstruct': {u'iptables': {}}}

        NOTE(review): there is no save/restore cycle rolling back on
        errors, and no checksum is verified before applying commands.

        >>> fw = FirewallControl(dryrun=True)
        >>> a = fw.write(example_fwstruct) # doctest: +ELLIPSIS
        executing: ... -w 60 -t filter -C INPUT --source \
192.168.42.254/32 --protocol tcp --match tcp --dport 80:443 --jump DROP
        executing: ... -w 60 -t filter -A INPUT --source \
192.168.42.254/32 --protocol tcp --match tcp --dport 80:443 --jump DROP
        >>> expected = example_iptables_struct[u'iptables'][u'filter']
        >>> got = a[u'iptables'][u'filter']
        >>> got[u'rules'] == expected[u'rules']
        True
        >>> sorted(got[u'chains']) == sorted(expected[u'chains'])
        True
        >>> got[u'chains'][u'INPUT'][u'policy'] == u'DROP'
        True
        >>> u'policy' in got[u'chains'][u'EXT-INPUT']
        False
        '''
        if fwstruct:
            if isinstance(fwstruct, dict):
                # a single command set: wrap in a list
                fwstruct = [fwstruct, ]

            for command_set in fwstruct:
                command = command_set['command']
                tables_struct = command_set['fwstruct']

                outformat = command_set.get('outformat')
                if outformat in self.outformat_allowed:
                    self.outformat = outformat

                self.apply_commands(tables_struct, mode=command)

        return self.read()

    def syscall(self, cmd):
        '''Run cmd via safe_syscall, or only print it in dryrun mode.

        Returns (retcode, output) with output passed through to_ustring.
        In self.strict mode a non-zero retcode raises ValueError(output).
        '''
        if self.dryrun:
            print(u'executing: '+u' '.join(cmd))
            retcode, output = 0, u''
        else:
            retcode, output = safe_syscall(cmd, raisemode=False)
            if retcode and self.strict:
                raise ValueError(output)
        # ensure encoding neutrality
        return retcode, to_ustring(output)

    def apply_commands(self, fwstruct, mode):
        '''Apply the chains and rules described by fwstruct.

        fwstruct: {u'iptables': {table: {u'chains': {...},
                                         u'rules': {...}}}};
                  a missing u'chains' or u'rules' entry counts as empty.
        mode: u'insert', u'add' or u'delete'; any other value (e.g. None)
              changes nothing.
        '''
        for table, tablestruct in fwstruct[u'iptables'].items():
            chains = tablestruct.get(u'chains', {})
            rules = tablestruct.get(u'rules', {})
            if mode in (u'add', u'insert'):
                # Rules may jump to new chains, so create chains first
                self.chain_ops(table, chains, mode=mode)
                self.rule_ops(table, rules, mode=mode)
            elif mode == u'delete':
                # Chains cannot be deleted while rules remain or refer to
                # them, so delete rules first
                self.rule_ops(table, rules, mode=mode)
                self.chain_ops(table, chains, mode=mode)

    def chain_ops(self, table, chains, mode):
        '''Create or delete the user-defined chains of `table`.

        chains maps chain names to chain structures; only the names are
        used and builtin chains of the table are skipped. u'add' and
        u'insert' create chains (-N), u'delete' deletes them (-X).

        Creation fails if self.strict is set and the chain already
        exists. Deletion will fail if there are references or rules
        still in the chain.
        '''
        option = {
            u'insert': u'-N',
            u'add': u'-N',
            u'delete': u'-X',
        }.get(mode)
        for chain in chains:
            # Ignore builtin chains
            if chain in self.builtin_chains[table]:
                continue
            cmd = self.iptables_bin + ['-t', table, option, chain]
            self.syscall(cmd)

    def rule_ops(self, table, rules, mode):
        '''Insert, append or delete rules in the chains of `table`.

        rules maps chain names to lists of rule definitions (see
        rule_to_options()). Each rule is first checked with `iptables -C`:
        on u'add'/u'insert' it is skipped if already present, on
        u'delete' it is skipped if absent. In self.strict mode such a
        rule raises ValueError instead of being skipped.

        In dryrun mode the check is only printed and assumed to come out
        so that the command runs.
        '''
        option = {
            u'insert': u'-I',
            u'add': u'-A',
            u'delete': u'-D',
        }.get(mode)
        for chain, ruledefs in rules.items():
            # Build iptables command options
            check_cmd = self.iptables_bin + ['-t', table, u'-C', chain]
            cmd = self.iptables_bin + ['-t', table, option, chain]
            for ruledef in ruledefs:
                ruleopts = self.rule_to_options(ruledef)
                # Check for presence (retcode 0: rule exists)
                if self.dryrun:
                    print(u'executing: '+u' '.join(check_cmd+ruleopts))
                    retcode, output = (0 if mode == u'delete' else 1), str(
                        check_cmd+ruleopts)
                else:
                    retcode, output = safe_syscall(
                        check_cmd+ruleopts, raisemode=False)
                # Add expects retcode > 0, delete expects retcode == 0
                if mode in [u'add', u'insert'] and retcode == 0:
                    # Cannot add the rule.
                    if self.strict:
                        raise ValueError(output)
                    # Ignore this rule
                    continue
                elif mode == u'delete' and retcode != 0:
                    # Cannot delete the rule.
                    if self.strict:
                        raise ValueError(output)
                    # Ignore this rule
                    continue

                # Add or delete the rule
                self.syscall(cmd+ruleopts)

    def rule_to_options(self, ruledef):
        '''Convert a ruledef set to a list of iptables command line options.

        Options are prefixed with '--' unless they already start with '-';
        a leading '!' is kept in front of its option, and the u'counters'
        entry is skipped.

        >>> ruledef = [ [u'source', u'192.168.42.254/32',],
        ...             [u'protocol', u'tcp'],
        ...             [u'match', u'tcp'],
        ...             [u'dport', u'80:443'],
        ...             [u'jump', u'DROP'],
        ...             [u'counters', u'0', u'0'], ]
        >>> fw = FirewallControl()
        >>> (fw.rule_to_options(ruledef) ==
        ...  [u'--source', u'192.168.42.254/32', u'--protocol', u'tcp',
        ...   u'--match', u'tcp', u'--dport', u'80:443', u'--jump', u'DROP'])
        True
        '''
        out = []
        for item in ruledef:
            args = list(item)
            if args[0] == u'counters':
                continue
            if args[0] == u'!':
                out.append(args.pop(0))
            option = args.pop(0)
            if option.startswith(u'-'):
                out.append(option)
            else:
                out.append(u'--'+option)
            out.extend(args)
        return out


# Example iptables-save -c output, parsed by FirewallControl.read() in
# dryrun mode. Lines ending in a backslash are joined by Python's
# string-literal line continuation, so each rule forms a single dump line.
# The first rule's comment contains a non-ASCII character (öRULE-10941).
example_iptables_save = '''
# Completed on Wed Jul 12 16:04:17 2017
# Generated by iptables-save v1.4.12 on Wed Jul 12 16:04:17 2017
*filter
:INPUT DROP [1943:190432]
:FORWARD DROP [250:10500]
:OUTPUT DROP [0:0]
:10583-OVPN - [0:0]
:EXT-INPUT - [0:0]
:EXT-OUTPUT - [0:0]
[24330854:2635874734] -A INPUT -i eth0 -j EXT-INPUT
[583461:61223748] -A INPUT -i lo -j ACCEPT
[68316:4110844] -A INPUT -p icmp -j ACCEPT
[4409683:319456808] -A INPUT -j LOG
[403592:481620962] -A FORWARD -d 10.1.1.94/32 -i eth0 -o tun0 -j 14110-OVPN
[339540:56315108] -A FORWARD -s 10.1.1.94/32 -i tun0 -o eth0 -j 14110-OVPN
[21700943:11185930269] -A OUTPUT -o eth0 -j EXT-OUTPUT
[583461:61223748] -A OUTPUT -o lo -j ACCEPT
[68205:4097317] -A OUTPUT -p icmp -j ACCEPT
[21:2931] -A OUTPUT -j LOG
[44559:3533157] -A 10583-OVPN -d 10.2.164.210/32 -p tcp -m tcp --dport 3389 \
-m state --state NEW,ESTABLISHED -m comment --comment öRULE-10941 -j ACCEPT
[32721:13602899] -A 10583-OVPN -s 10.2.164.210/32 -p tcp -m tcp --sport 3389 \
-m state --state ESTABLISHED -m comment --comment RULE-10941 -j ACCEPT
[20162735:10906591383] -A EXT-OUTPUT -s 212.100.43.215/32 -p tcp -m tcp \
--sport 443 ! --tcp-flags FIN,SYN,RST,ACK SYN -m state --state ESTABLISHED \
-j ACCEPT
[0:0] -A EXT-OUTPUT -s 212.100.43.215/32 -p udp -m udp --sport 443 -m state \
--state ESTABLISHED -j ACCEPT
COMMIT
# Completed on Wed Jul 12 16:04:17 2017
'''

# Structure of the filter table in example_iptables_save as produced by
# FirewallControl.read(): per chain its counters, per chain its rules as
# lists of long-named options, each ending with [u'counters', packets, bytes].
# Builtin-chain 'policy' entries are not listed here. The write() doctest
# compares read() output against this structure.
example_iptables_struct = {
    u'iptables': {
        u'filter': {
            u'chains': {
                u'10583-OVPN': {
                    u'counters': {
                        u'bytes': 0,
                        u'packets': 0
                    }
                },
                u'EXT-INPUT': {
                    u'counters': {
                        u'bytes': 0,
                        u'packets': 0
                    }
                },
                u'EXT-OUTPUT': {
                    u'counters': {
                        u'bytes': 0,
                        u'packets': 0
                    }
                },
                u'FORWARD': {
                    u'counters': {
                        u'bytes': 10500,
                        u'packets': 250
                    }
                },
                u'INPUT': {
                    u'counters': {
                        u'bytes': 190432,
                        u'packets': 1943
                    }
                },
                u'OUTPUT': {
                    u'counters': {
                        u'bytes': 0,
                        u'packets': 0
                    }
                }
            },
            u'rules': {
                u'10583-OVPN': [
                    [
                        [u'destination', u'10.2.164.210/32'],
                        [u'protocol', u'tcp'],
                        [u'match', u'tcp'],
                        [u'dport', u'3389'],
                        [u'match', u'state'],
                        [u'state', u'NEW,ESTABLISHED'],
                        [u'match', u'comment'],
                        [u'comment', u'\xf6RULE-10941'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 44559, 3533157]
                    ],
                    [
                        [u'source', u'10.2.164.210/32'],
                        [u'protocol', u'tcp'],
                        [u'match', u'tcp'],
                        [u'sport', u'3389'],
                        [u'match', u'state'],
                        [u'state', u'ESTABLISHED'],
                        [u'match', u'comment'],
                        [u'comment', u'RULE-10941'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 32721, 13602899]
                    ]
                ],
                u'EXT-OUTPUT': [
                    [
                        [u'source', u'212.100.43.215/32'],
                        [u'protocol', u'tcp'],
                        [u'match', u'tcp'],
                        [u'sport', u'443'],
                        # a negated option keeps '!' as its first element
                        [u'!', u'tcp-flags', u'FIN,SYN,RST,ACK', u'SYN'],
                        [u'match', u'state'],
                        [u'state', u'ESTABLISHED'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 20162735, 10906591383]
                    ],
                    [
                        [u'source', u'212.100.43.215/32'],
                        [u'protocol', u'udp'],
                        [u'match', u'udp'],
                        [u'sport', u'443'],
                        [u'match', u'state'],
                        [u'state', u'ESTABLISHED'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 0, 0]
                    ]
                ],
                u'FORWARD': [
                    [
                        [u'destination', u'10.1.1.94/32'],
                        [u'in-interface', u'eth0'],
                        [u'out-interface', u'tun0'],
                        [u'jump', u'14110-OVPN'],
                        [u'counters', 403592, 481620962]
                    ],
                    [
                        [u'source', u'10.1.1.94/32'],
                        [u'in-interface', u'tun0'],
                        [u'out-interface', u'eth0'],
                        [u'jump', u'14110-OVPN'],
                        [u'counters', 339540, 56315108]
                    ]
                ],
                u'INPUT': [
                    [
                        [u'in-interface', u'eth0'],
                        [u'jump', u'EXT-INPUT'],
                        [u'counters', 24330854, 2635874734]
                    ],
                    [
                        [u'in-interface', u'lo'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 583461, 61223748]
                    ],
                    [
                        [u'protocol', u'icmp'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 68316, 4110844]
                    ],
                    [
                        [u'jump', u'LOG'],
                        [u'counters', 4409683, 319456808]
                    ]
                ],
                u'OUTPUT': [
                    [
                        [u'out-interface', u'eth0'],
                        [u'jump', u'EXT-OUTPUT'],
                        [u'counters', 21700943, 11185930269]
                    ],
                    [
                        [u'out-interface', u'lo'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 583461, 61223748]],
                    [
                        [u'protocol', u'icmp'],
                        [u'jump', u'ACCEPT'],
                        [u'counters', 68205, 4097317]
                    ],
                    [
                        [u'jump', u'LOG'],
                        [u'counters', 21, 2931]
                    ]
                ]
            }
        }
    }
}


# Example rule definition: a list of options, each [name, value...], as
# consumed by FirewallControl.rule_to_options() and produced by read().
example_ruledef = [
    # Mnemonic version:
    [u'source', u'192.168.42.254/32', ],
    # A negated option puts '!' first:
    # ('!', 'tcp-flags', 'SYN,ACK,FIN', 'SYN'),
    # NOTE: a full-blown dictionary form like the following is not
    # supported (open question whether it should be):
    # {
    #    'name': 'source',
    #    'negate': False,  # may be omitted, defaults to False
    #    'value': '192.168.42.254/32',
    # },
    [u'protocol', u'tcp'],
    [u'match', u'tcp'],
    [u'dport', u'80:443'],
    [u'jump', u'DROP'],
    # Counters are not interpreted in input mode (rule_to_options skips them)
    [u'counters', 0, 0],
]

# Example firewall structure in the format returned by read() and accepted
# as the u'fwstruct' part of a command set by write()/apply_commands().
example_output = {
    # One entry per table
    u'iptables': {
        # Dump says: *filter
        u'filter': {
            u'chains': {
                # Dump says: :INPUT ACCEPT [72:11452]
                u'INPUT': {
                    u'policy': 'ACCEPT',
                    # Counters are not interpreted in input mode
                    u'counters': {
                        u'packets': 72,
                        u'bytes': 11452,
                    },
                },
                # ...
            },
            u'rules': {
                u'INPUT': [
                    # Dump says: [0:0] -A INPUT -s 192.168.42.254/32 -p tcp
                    # -m tcp --dport 80:443 -j DROP
                    example_ruledef,
                    # more rules...
                ],
                # more chains...
            },
        },
    },
}


# Example argument for FirewallControl.write(): a list of command sets,
# applied in order (a single dict is accepted as well).
example_fwstruct = [
    {
        u'command': u'add',
        u'fwstruct': example_output,
        u'outformat': 'python',  # optional: python is the default
    },
    # ...
]
