#!/usr/bin/python3

from samba.samdb import SamDB
from optparse import OptionParser
import samba
from samba import getopt
from samba.auth import system_session
from samba.join import DCJoinContext
from samba.dcerpc.security import dom_sid
from samba.netcmd import domain_backup
from samba.netcmd.dns import dns_connect, dns_type_flag, dns_record_match, data_to_dns_record, dns_client_version
from samba.ndr import ndr_unpack
from samba.remove_dc import remove_dc
from samba import Ldb
from samba.provision import secretsdb_self_join
from samba.dcerpc import misc, dnsserver
from configobj import ConfigObj
import ldb
import logging
from os.path import exists
import os
from sys import exit
from time import sleep
from shutil import copytree, rmtree, move
from glob import glob

# Samba state directory that is wiped and rebuilt by this restore.
targetdir = '/var/lib/samba'
# Host/realm identity comes from EOLE's generated samba4 variables file.
config = ConfigObj('/etc/eole/samba4-vars.conf')
netbios = config['AD_HOST_NAME'].upper()
ad_realm_lower = config['AD_REALM']
ad_realm = ad_realm_lower.upper()
ad_host_keytab_file = config['AD_HOST_KEYTAB_FILE']
# Temporary server name used for the first restore pass.
fake_netbios = f'{netbios}RESTORE'
addc_ip = config['AD_HOST_IP']
site = None

def fake_exists(path):
    """Replacement for os.path.exists used while samba's restore runs.

    * ``'fake'`` (the placeholder backup file name) is reported as existing.
    * ``targetdir`` is reported as missing, even though it is populated.
    * Any other path is checked with the real ``os.path.exists``.
    """
    if path == 'fake':
        print('fake exists for tarball')
        return True
    if path != targetdir:
        return exists(path)
    print('fake not exists for {}'.format(path))
    return False
# Keep the genuine implementation (restored after the restore step), then
# route os.path.exists through fake_exists, also for samba's domain_backup.
true_exists, os.path.exists = os.path.exists, fake_exists
domain_backup.os = os

class fake_open:
    """Stand-in for the archive object returned by ``tarfile.open()``.

    The files are not in a tar file, so nothing is extracted and nothing
    needs closing.
    """

    def extractall(self, *args, **kwargs):
        """Do nothing.

        Accepts any ``TarFile.extractall()`` arguments (path, members,
        numeric_owner, filter, ...). The original single argument, also
        passed by keyword as ``a=``, keeps working.
        """

    def close(self):
        """Nothing to release."""

    def __enter__(self):
        # Allow ``with tarfile.open(...) as tf:`` usage, as with a real TarFile.
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

class fake_tarfile:
    """Replacement for the ``tarfile`` module used by samba's domain_backup.

    The files are not in a tar file, so ``open()`` returns a ``fake_open``
    whose ``extractall()`` does not extract anything.
    """

    def open(self, path=None, *args, **kwargs):
        """Mimic ``tarfile.open(name=None, mode='r', ...)``.

        Extra positional/keyword arguments (mode, fileobj, ...) are ignored.
        The archive name may be given positionally or as ``name=``.
        """
        if path is None:
            path = kwargs.get('name')
        print('fake tarfile for {}'.format(path))
        return fake_open()
# Nothing has to be extracted: the backup files are copied in place below.
domain_backup.tarfile = fake_tarfile()

# first restore ADDC with a different name
cmd = 'systemctl stop samba-ad-dc'
os.system(cmd)
backup = domain_backup.cmd_domain_backup_restore()
parse = OptionParser()
options = getopt.SambaOptions(parse)
cred_options = getopt.CredentialsOptions(parse)
# Replace the Samba directory with the copy restored by Bareos.
rmtree(targetdir)
copytree('/home/backup/samba/bareos/', targetdir)
# Strip the full_audit connect/disconnect setting from the restored smb.conf.
# This is a plain literal replacement, done in Python instead of a sed
# shell command. Bytes are used so the file's encoding does not matter.
smb_conf = os.path.join(targetdir, 'etc/smb.conf')
with open(smb_conf, 'rb') as fh:
    smb_conf_data = fh.read()
with open(smb_conf, 'wb') as fh:
    fh.write(smb_conf_data.replace(b'full_audit:success = connect disconnect', b''))
# Use the same targetdir value that fake_exists compares against.
backup.run(sambaopts=options,
           credopts=cred_options,
           backup_file='fake',
           targetdir=targetdir,
           newservername=fake_netbios)

# rejoin a new server with old name
# The restore above registered the DC as fake_netbios. Here a join context
# for the original name (netbios) is run against the local sam.ldb, and the
# database and secrets.ldb are switched over to the new DC identity.
lp = options.get_loadparm()
creds = cred_options.get_credentials(lp)
samdb = SamDB(session_info=system_session(), credentials=creds, lp=lp)
logger = logging.getLogger('restore')
# forced_local_samdb hands the join context the database opened above
# (presumably instead of contacting a remote DC -- see samba.join; TODO confirm).
ctx = DCJoinContext(logger, creds=creds, lp=lp, site=site,
                    forced_local_samdb=samdb,
                    netbios_name=netbios)
# Account control flags for a domain controller machine account.
ctx.userAccountControl = (samba.dsdb.UF_SERVER_TRUST_ACCOUNT |
                          samba.dsdb.UF_TRUSTED_FOR_DELEGATION)
# Read all naming contexts from the rootDSE and use them as the join's
# (full) NC list; ncs is reused further down the script.
res = samdb.search(base="", scope=ldb.SCOPE_BASE,
                   attrs=['namingContexts'])
ncs = [str(r) for r in res[0].get('namingContexts')]
ctx.nc_list = ncs
ctx.full_nc_list = ncs
# Create the directory objects for the new DC (see samba.join for the list).
ctx.join_add_objects()
# Replace @ROOTDSE's dsServiceName with a GUID reference to ctx.ntds_guid,
# the NTDS object produced by the join step above.
m = ldb.Message()
m.dn = ldb.Dn(samdb, '@ROOTDSE')
ntds_guid = str(ctx.ntds_guid)
m["dsServiceName"] = ldb.MessageElement("<GUID=%s>" % ntds_guid,
                                        ldb.FLAG_MOD_REPLACE,
                                        "dsServiceName")
samdb.modify(m)
# Record the join (machine password ctx.acct_pass, SEC_CHAN_BDC channel)
# in <targetdir>/private/secrets.ldb.
private_dir = os.path.join(targetdir, 'private')
secrets_path = os.path.join(private_dir, 'secrets.ldb')
secrets_ldb = Ldb(secrets_path, session_info=system_session(), lp=lp)
secretsdb_self_join(secrets_ldb, domain=ctx.domain_name,
                    realm=ctx.realm, dnsdomain=ctx.dnsdomain,
                    netbiosname=ctx.myname, domainsid=ctx.domsid,
                    machinepass=ctx.acct_pass,
                    key_version_number=ctx.key_version_number,
                    secure_channel_type=misc.SEC_CHAN_BDC)

# Seize DNS roles
# Set fSMORoleOwner of the CN=Infrastructure object in the ForestDnsZones
# and DomainDnsZones partitions to this DC's dsServiceName.
domain_dn = samdb.domain_dn()
forest_dn = samba.dn_from_dns_name(samdb.forest_dns_name())
# (prefix, base DN) pairs; the modified DN is prefix + base DN.
domaindns_dn = ("CN=Infrastructure,DC=DomainDnsZones,", domain_dn)
forestdns_dn = ("CN=Infrastructure,DC=ForestDnsZones,", forest_dn)
for dn_prefix, dns_dn in [forestdns_dn, domaindns_dn]:
    # NOTE(review): this guard tests whether the domain/forest DN itself is a
    # naming context, not whether the DC=...DnsZones partition exists. The
    # domain DN is normally in ncs, so the guard rarely skips, and
    # samdb.modify() below would presumably raise if a DNS partition were
    # absent. This appears to mirror Samba's own restore code -- confirm
    # before changing (and beware case differences between DN strings).
    if dns_dn not in ncs:
        continue
    full_dn = dn_prefix + dns_dn
    m = ldb.Message()
    m.dn = ldb.Dn(samdb, full_dn)
    m["fSMORoleOwner"] = ldb.MessageElement(samdb.get_dsServiceName(),
                                            ldb.FLAG_MOD_REPLACE,
                                            "fSMORoleOwner")
    samdb.modify(m)

# Seize other roles: the five FSMO roles, with force=True.
for role in ['rid', 'pdc', 'naming', 'infrastructure', 'schema']:
    backup.seize_role(role, samdb, force=True)

# Find the server object of the temporary DC created by the first restore.
search_expr = "(&(objectClass=Server)(serverReference=*)(cn={}))".format(fake_netbios)
res = samdb.search(samdb.get_config_basedn(), scope=ldb.SCOPE_SUBTREE,
                   expression=search_expr)
if len(res) == 0:
    # Fail with an explicit message rather than an IndexError on res[0].
    print('le serveur temporaire {} est introuvable'.format(fake_netbios))
    exit(1)

# remove restored server
cn = str(res[0].get('cn')[0])
remove_dc(samdb, logger, cn)

# Empty repsFrom and repsTo on every naming context, so that the restored
# DC does not try (and fail) to talk to the old DCs.
for nc in ncs:
    msg = ldb.Message()
    msg.dn = ldb.Dn(samdb, nc)
    for reps_attr in ('repsFrom', 'repsTo'):
        msg[reps_attr] = ldb.MessageElement([], ldb.FLAG_MOD_REPLACE, reps_attr)
    samdb.modify(msg)

# Restore the genuine os.path.exists now that samba's restore has run.
os.path.exists = true_exists

# Delete the restore's extra data (etc/ and backup.txt) from targetdir.
rmtree(os.path.join(targetdir, 'etc'))
os.unlink(os.path.join(targetdir, 'backup.txt'))

# restore sysvol: replace /home/sysvol/<realm>/ with the copy restored under
# <targetdir>/state/sysvol/<realm>/, then drop <targetdir>/state/.
sysvol_path = f'/home/sysvol/{ad_realm_lower}/'
pol_path = f'{sysvol_path}Policies/'
for entry in glob(sysvol_path + '*'):
    # rmtree refuses plain files and symlinks, so unlink those instead.
    if os.path.isdir(entry) and not os.path.islink(entry):
        rmtree(entry)
    else:
        os.unlink(entry)
# The loop above also removed Policies/. Recreate it so the policies below
# are moved *into* it: when the destination does not exist, shutil.move
# renames the source, which would turn the first policy into Policies/.
os.makedirs(pol_path, exist_ok=True)
state_path = targetdir + '/state/'
res_sysvol_path = state_path + f'sysvol/{ad_realm_lower}/'
res_pol_path = res_sysvol_path + 'Policies/'
for pol in glob(res_pol_path + '*'):
    move(pol, pol_path)
rmtree(res_pol_path)
# Move the remaining restored sysvol entries next to Policies/.
for pol in glob(res_sysvol_path + '*'):
    move(pol, sysvol_path)
rmtree(state_path)

# Start Samba, then export the DC keytab; abort on the first failing command.
for cmd in ('systemctl start samba-ad-dc',
            'samba-tool domain exportkeytab "{0}" --principal="{1}@{2}" -s /etc/samba/smb.conf'.format(ad_host_keytab_file, netbios, ad_realm)):
    if os.system(cmd):
        print('la commande {} est en erreur'.format(cmd))
        exit(1)

# touch $AD_INSTANCE_LOCK_FILE
with open(os.path.join(targetdir, '.instance_ok'), 'a'):
    pass

# Get a Kerberos ticket from the keytab just exported.
cmd = 'kinit "{1}@{2}" -k -t "{0}"'.format(ad_host_keytab_file, netbios, ad_realm)
if os.system(cmd):
    print('la commande {} est en erreur'.format(cmd))
    exit(1)

# Update zones info (code from netcmd/dns.py)
record_type = dns_type_flag('SOA')
# SOA primary-server names that still refer to the temporary restore DC,
# with and without the trailing dot.
old_server = f'{fake_netbios.lower()}.{ad_realm.lower()}'
old_servers = [old_server, old_server + '.']
server = f'{netbios.lower()}.{ad_realm.lower()}'
creds = cred_options.get_credentials(lp)
dns_conn = dns_connect(server, lp, creds)
select_flags = 1

# get zones list
client_version = dns_client_version('longhorn')
request_filter = 1
typeid, zones = dns_conn.DnssrvComplexOperation2(
    client_version, 0, server, None, 'EnumZones',
    dnsserver.DNSSRV_TYPEID_DWORD, request_filter,
)
zones = [z.pszZoneName for z in zones.ZoneArray]

# For every zone: when its SOA still names the temporary restore DC, rewrite
# the primary server to this DC (serial + 1), then add this DC as an NS.
soa_type = dns_type_flag('SOA')
# The NS type has its own variable. Reusing record_type for it made every
# zone after the first enumerate NS records instead of its SOA.
ns_type = dns_type_flag('NS')
ns_data = '{}.{}.'.format(netbios, ad_realm)


def _soa_update_data(records):
    """Extract SOA rewrite data from a DnssrvEnumRecords2 record buffer.

    :param records: the buffer returned by DnssrvEnumRecords2 (may be None)
    :return: ``(primary_server_lowercase, old_data, new_data)`` for the first
        record exposing SOA fields, or ``None`` if there is none.
        ``old_data`` is the current SOA as text (as passed to
        dns_record_match). ``new_data`` is the same SOA with ``server`` as
        primary and the serial incremented.
    """
    if records is None:
        return None
    for rec in records.rec:
        for dns_rec in rec.records:
            data = dns_rec.data
            try:
                primary = data.NamePrimaryServer.str
                email = data.ZoneAdministratorEmail.str
                serial = data.dwSerialNo
                timers = [str(data.dwRefresh), str(data.dwRetry),
                          str(data.dwExpire), str(data.dwMinimumTtl)]
            except AttributeError:
                # Not an SOA payload; try the next record.
                continue
            old_data = " ".join([primary, email, str(serial)] + timers)
            new_data = " ".join([server, email, str(serial + 1)] + timers)
            return primary.lower(), old_data, new_data
    return None


for zone in zones:
    # retrieve SOA informations for the records named after the zone itself
    name = zone
    buflen, records = dns_conn.DnssrvEnumRecords2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                                  0,
                                                  server,
                                                  zone,
                                                  name,
                                                  None,
                                                  soa_type,
                                                  select_flags,
                                                  None,
                                                  None,
                                                  )
    soa = _soa_update_data(records)
    if soa is not None and soa[0] in old_servers:
        _, old_data, new_data = soa
        # update SOA informations
        rec_match = dns_record_match(dns_conn,
                                     server,
                                     zone,
                                     name,
                                     soa_type,
                                     old_data,
                                     )
        if rec_match is None:
            print('SOA introuvable pour la zone {}, pas de mise a jour'.format(zone))
        else:
            rec = data_to_dns_record(soa_type,
                                     new_data,
                                     )
            # Carry over flags, serial, TTL and timestamp of the existing record.
            rec.dwFlags = rec_match.dwFlags
            rec.dwSerial = rec_match.dwSerial
            rec.dwTtlSeconds = rec_match.dwTtlSeconds
            rec.dwTimeStamp = rec_match.dwTimeStamp

            add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
            add_rec_buf.rec = rec
            del_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
            del_rec_buf.rec = rec_match

            dns_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                         0,
                                         server,
                                         zone,
                                         name,
                                         add_rec_buf,
                                         del_rec_buf,
                                         )
    # Add NS record (best effort)
    try:
        rec = data_to_dns_record(ns_type,
                                 ns_data,
                                 )

        add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
        add_rec_buf.rec = rec
        dns_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                     0,
                                     server,
                                     zone,
                                     name,
                                     add_rec_buf,
                                     None,
                                     )
    except Exception as err:
        # Non-fatal (e.g. the record may already exist): report the error
        # instead of silently discarding it.
        print("l'ajout de l'enregistrement NS dans la zone {} est en erreur : {}".format(zone, err))
# Add A records: names ad_realm and netbios in zone ad_realm -> addc_ip.
record_type = dns_type_flag('A')
rec = data_to_dns_record(record_type,
                         addc_ip,
                         )

add_rec_buf = dnsserver.DNS_RPC_RECORD_BUF()
add_rec_buf.rec = rec

for a_name in (ad_realm, netbios):
    try:
        dns_conn.DnssrvUpdateRecord2(dnsserver.DNS_CLIENT_VERSION_LONGHORN,
                                     0,
                                     server,
                                     ad_realm,
                                     a_name,
                                     add_rec_buf,
                                     None,
                                     )
    except Exception as err:
        # Non-fatal (e.g. the record may already exist): report the error
        # instead of silently discarding it.
        print("l'ajout de l'enregistrement A {} est en erreur : {}".format(a_name, err))

import shlex

cmd = 'kdestroy'
if os.system(cmd):
    print('la commande {} est en erreur'.format(cmd))
    exit(1)

# Re-apply the saved local SID, if a backup of it exists.
sid_file = '/home/backup/samba/samba.sid'
if os.path.isfile(sid_file):
    with open(sid_file, 'r') as fh:
        sid = fh.read().strip()
    # Provide <targetdir>/etc/smb.conf (a symlink to /etc/samba/smb.conf)
    # while `net setlocalsid` runs. It is removed even if the command fails.
    etc_dir = os.path.join(targetdir, 'etc')
    os.makedirs(etc_dir)
    try:
        os.symlink('/etc/samba/smb.conf', os.path.join(etc_dir, 'smb.conf'))
        # shlex.quote leaves a normal S-1-5-... SID unchanged.
        cmd = 'net setlocalsid {}'.format(shlex.quote(sid))
        if os.system(cmd):
            print('la commande {} est en erreur'.format(cmd))
            exit(1)
    finally:
        rmtree(etc_dir)

bind_dns_dir = os.path.join(targetdir, 'bind-dns')
if os.path.isdir(bind_dns_dir):
    # Best effort: report failures but keep going.
    for cmd in ('chmod 755 {}'.format(targetdir),
                'chmod 770 {}'.format(bind_dns_dir),
                'chgrp bind {}'.format(bind_dns_dir)):
        if os.system(cmd):
            print('la commande {} est en erreur'.format(cmd))
