/*
*
*              Reverse MX Filter for Postfix 2
*                    reject_bad_rmx
*           (c) 2004 Elita rozanski@sergiusz.com
*
*/

/*
MXfilter is free software; you can redistribute it and/or modify it under
the terms of the GNU General Public License as published by the Free
Software Foundation; either version 2, or (at your option) any later
version.
*/

#include <stdlib.h>
#include <stdio.h>
#include <stdarg.h>
#include <stdint.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <fcntl.h>
#include <unistd.h>
#include <netinet/in.h>
#include <errno.h>
#include <string.h>
#include <netdb.h>
#include <sys/time.h>
#include <sys/resource.h>
#include <ctype.h>
#include <arpa/nameser.h>
#include <resolv.h>
#include <arpa/inet.h>

#include "mxget.h"
#include "utils.h"
#include "spfmacro.h"
#include "spf1.h"
#include "spfptr.h"
#include "gettxt.h"

/* Fixed part of a DNS resource record that follows the owner name
 * (RFC 1035 4.1.3): TYPE, CLASS, TTL, RDLENGTH, all in network byte order.
 * All four fields are unsigned on the wire.
 *
 * On the wire this header is 10 bytes; sizeof(struct rrecord) is usually
 * larger (trailing padding), so skip it with an explicit 10, never with
 * sizeof.  Overlaying this struct on a packet buffer via a pointer cast may
 * be a misaligned access; reading the bytes individually is safer. */
struct rrecord {
    uint16_t r_type;   /* TYPE, e.g. T_TXT */
    uint16_t r_class;  /* CLASS, e.g. C_IN */
    uint32_t r_ttl;    /* TTL */
    uint16_t r_length; /* RDLENGTH: number of RDATA bytes that follow */
};

/* check_a -- resolve hostname with gethostbyname() and test its addresses.
 *
 * Each address is formatted as "b0.b1.b2.b3" from its first four bytes.
 * Outside test mode it is compared with check_ip4(addr, client_address, net).
 * In test mode (tmode != 0) it is only printed as "[a<prefix>] addr/net".
 * At most 256 list entries are examined.
 *
 * *querycount is incremented on every call.
 *
 * Returns:
 *   0  an address matched (check_ip4() returned 0)
 *  -1  querymax exceeded, or the lookup failed with TRY_AGAIN/NO_RECOVERY
 *   1  anything else (no match, other lookup failures)
 */
int check_a(char *hostname,char *client_address, char *net,int tmode, char *prefix, int *querycount, int querymax) {
    struct hostent *he;
    char **addr;
    int slot;

    if (++*querycount > querymax)
        return -1;

    he = gethostbyname(hostname);
    if (he == NULL)
        return (h_errno == TRY_AGAIN || h_errno == NO_RECOVERY) ? -1 : 1;

    /* h_addr_list is NULL-terminated; also stop after 256 entries */
    for (slot = 0, addr = he->h_addr_list; slot < 256 && *addr != NULL; slot++, addr++) {
        const unsigned char *b = (const unsigned char *)*addr;
        char dotted[256];

        snprintf(dotted, sizeof(dotted), "%d.%d.%d.%d", b[0], b[1], b[2], b[3]);
        if (tmode)
            printf("[a%s] %s/%s\n", prefix, dotted, net);
        else if (check_ip4(dotted, client_address, net) == 0)
            return 0;
    }
    return 1;
}

/* get_spf1 -- fetch the single SPF version 1 TXT record of a domain.
 *
 * domain      name whose TXT records are queried (class IN)
 * answer      receives the record when 0 is returned.  The text always
 *             begins with "v=spf1 " (a bare "v=spf1" record is returned as
 *             "v=spf1 "), so callers may skip 7 chars when sanswer >= 8.
 * sanswer     size of answer in bytes
 * querycount  DNS query counter, incremented once per call
 * querymax    limit for *querycount
 *
 * -1 dns error (query limit exceeded or resolver failure)
 *  0 ok
 *  1 multiple, not found, txt but no spf, unparsable reply etc.
 */

/* Read a big-endian 16-bit field from a DNS message byte by byte, so no
 * (possibly misaligned) struct overlay is needed. */
static unsigned int spf1_get16(const unsigned char *b) {
    return ((unsigned int)b[0] << 8) | (unsigned int)b[1];
}

/* Join the <character-string>s of one TXT RDATA into out without
 * separators (RFC 7208 3.3), truncating to outsz-1 chars.  Returns 0, or
 * -1 if a string's length byte points past the end of the RDATA. */
static int spf1_txt_join(const unsigned char *rdata, unsigned int rdlen,
                         char *out, size_t outsz) {
    unsigned int pos = 0;
    size_t used = 0;

    out[0] = '\0';
    while (pos < rdlen) {
        unsigned int len = rdata[pos++];
        size_t room = outsz - 1 - used;
        size_t take = len < room ? len : room;

        if (len > rdlen - pos)
            return -1;
        memcpy(out + used, rdata + pos, take);
        used += take;
        out[used] = '\0';
        pos += len;
    }
    return 0;
}

int get_spf1(char *domain,char *answer, int sanswer, int *querycount, int querymax) {
    unsigned char query[8192];
    const unsigned char *p, *end;
    HEADER hdr;
    int qcount, acount, len;
    char buf[1025];
    char result[1024];
    char trueresult[1024];
    int spf1count = 0;

    querycount[0]++;
    if (querycount[0] > querymax)
        return -1;

    len = res_search(domain, C_IN, T_TXT, query, sizeof(query));
    if (len <= 0) {
        /* NOTE(review): TRY_AGAIN is a temporary failure but is mapped to 1
         * here, whereas check_a() maps it to -1; kept as-is. */
        if (h_errno == HOST_NOT_FOUND ||
                h_errno == TRY_AGAIN ||
                h_errno == NO_DATA)
            return 1;
        return -1;
    }
    /* res_search() may report a reply longer than the buffer it filled;
     * never parse past the bytes actually stored. */
    if (len > (int)sizeof(query))
        len = (int)sizeof(query);
    if (len < (int)sizeof(hdr))
        return 1;
    end = query + len;

    /* copy the header out instead of casting the byte buffer */
    memcpy(&hdr, query, sizeof(hdr));
    qcount = ntohs(hdr.qdcount);
    acount = ntohs(hdr.ancount);
    p = query + sizeof(hdr);

    /* skip the question section: name, QTYPE, QCLASS */
    while (qcount-- > 0) {
        int n = dn_expand(query, end, p, buf, sizeof(buf));

        if (n < 0 || end - p < n + 4)
            return 1;
        p += n + 4;
    }

    /* answer section: count TXT records that are SPF version 1 */
    while (acount-- > 0) {
        unsigned int type, rdlen;
        int n = dn_expand(query, end, p, buf, sizeof(buf));

        if (n < 0 || end - p < n + 10)
            return 1;
        p += n;
        type = spf1_get16(p);       /* TYPE */
        rdlen = spf1_get16(p + 8);  /* RDLENGTH, after CLASS and TTL */
        p += 10;                    /* fixed RR header on the wire */
        if ((size_t)(end - p) < rdlen)
            return 1;

        /* "v=spf1" must be followed by a space or end the record */
        if (type == T_TXT &&
                spf1_txt_join(p, rdlen, result, sizeof(result)) == 0 &&
                strncasecmp(result, "v=spf1", 6) == 0 &&
                (result[6] == ' ' || result[6] == '\0')) {
            spf1count++;
            /* pad a bare "v=spf1" so answer+7 stays inside the string */
            snprintf(trueresult, sizeof(trueresult), "%s%s",
                     result, result[6] ? "" : " ");
        }
        p += rdlen;
    }

    if (spf1count == 0)
        return 1; /* txt found, no spf1 */
    if (spf1count > 1)
        return 1; /* multiple spf1 records */

    snprintf(answer, sanswer, "%s", trueresult);
    /* NOTE(review): printed unconditionally (this function has no tmode
     * flag); confirm stdout is safe to write to in daemon use. */
    printf("%s\n", trueresult);
    return 0;
}


/* parse_spf1 -- evaluate the terms of one SPF version 1 record.
 *
 * r               record text after the leading "v=spf1 " (terms only)
 * client_ip       client address handed to check_a/check_ip4/find_mailhost/
 *                 find_ptr (presumably a dotted-quad IPv4 string; confirm)
 * domain          domain that published r; default target of a, mx and ptr
 * level           recursion depth; include/redirect recurse with level+1,
 *                 and level 10 stops evaluation with 2
 * tmode           test mode: print every term instead of returning on the
 *                 first match
 * spfm            macro context; spfm->d is temporarily replaced by the
 *                 include/redirect target while that record is evaluated
 * exp_res         receives the macro-expanded exp= explanation text
 * sizeof_exp_res  size of exp_res in bytes
 * querycount      DNS query counter shared with the lookup helpers
 * querymax        limit for *querycount
 *
 * return -1 - dns error
 * return  0 - pass
 * return  1 - false (fail)
 * return  2 - unknown (neutral, no matching term, unusable record)
 * return  3 - softfail
 */

/* Result returned when a term carrying this qualifier matches. */
static int spf1_qualifier_result(char qualifier) {
    switch (qualifier) {
    case '+': return 0; /* pass */
    case '-': return 1; /* false */
    case '~': return 3; /* softfail */
    default:  return 2; /* '?' neutral -> unknown */
    }
}

int parse_spf1(char *r,char *client_ip,char *domain, int level, int tmode, 
	struct tspfm *spfm, char *exp_res, int sizeof_exp_res, int *querycount, int querymax) {
    char *p, *k;
    char w[1024];
    char prefix[] = "+";   /* qualifier of the current term, as a string */
    int debug = 0;
    char redirect_domain[1024];
    int redirect = 0;
    char *exp;
    char expp[1024];
    char res[1024];

    if (level == 10)
        return 2; /* depth limit */

    if (tmode)
        printf("[answer level:%d] %s\n", level, r);

    /* exp= modifier: pad with spaces so " exp=" matches as the first term
     * too and its value is always space-terminated. */
    snprintf(expp, sizeof(expp), " %s ", r);
    exp = strstr(expp, " exp=");
    if (exp) {
        exp += 5;
        if (exp[0] != '\0') {
            char *expk = strchr(exp, ' ');

            if (expk) {
                char dmn[1024];

                *expk = '\0';
                snprintf(dmn, sizeof(dmn), "%s", exp);
                macro(dmn, sizeof(dmn), spfm);

                /* gettxt()'s result is not checked; start from an empty
                 * string so a failed lookup cannot leave res uninitialized
                 * before it is expanded and copied out. */
                res[0] = '\0';
                gettxt(dmn, res, sizeof(res), querycount, querymax);
                macro(res, sizeof(res), spfm);

                snprintf(exp_res, sizeof_exp_res, "%s", res);
            }
        }
    }

    snprintf(w, sizeof(w), "%s ", r);
    if (debug) printf("spf=`%s` client_ip=`%s` domain=`%s`\n", w, client_ip, domain);

    /* Walk space-separated terms; w ends in a space, so every term
     * (including the last) is NUL-terminated in place. */
    for (p = w; (k = strchr(p, ' ')) != NULL; p = k) {
        *k++ = '\0';

        if (debug) printf("> %s\n", p);

        /* only the first redirect= modifier is remembered */
        if (strncasecmp(p, "redirect=", 9) == 0 && !redirect) {
            snprintf(redirect_domain, sizeof(redirect_domain), "%s", p + 9);
            redirect = 1;
        }

        /* qualifier */
        if (p[0] == '+' || p[0] == '-' || p[0] == '?' || p[0] == '~') {
            prefix[0] = p[0];
            p++;
        } else {
            prefix[0] = '+';
        }

        /* mechanism */
        if (strcasecmp(p, "all") == 0) {
            if (debug) printf(">> match %sall\n", prefix);
            if (tmode) printf("[all%s] 0/0\n", prefix);
            else return spf1_qualifier_result(prefix[0]);
        } else if (strncasecmp(p, "include:", 8) == 0) {
            char dmn[1024];
            char re[1024];
            int found;

            snprintf(dmn, sizeof(dmn), "%s", p + 8);
            if (dmn[0] == '\0')
                return 2;

            macro(dmn, sizeof(dmn), spfm);

            found = get_spf1(dmn, re, sizeof(re), querycount, querymax);
            if (found == -1)
                return -1; /* dns error */
            if (found == 0) {
                char bkp[1024];
                int sub;

                snprintf(bkp, sizeof(bkp), "%s", spfm->d);
                snprintf(spfm->d, sizeof(spfm->d), "%s", dmn);
                sub = parse_spf1(re + 7, client_ip, dmn, level + 1, tmode, spfm,
                                 exp_res, sizeof_exp_res, querycount, querymax);
                snprintf(spfm->d, sizeof(spfm->d), "%s", bkp);

                if (!tmode) {
                    if (sub == -1) return -1;
                    if (sub == 0) return spf1_qualifier_result(prefix[0]);
                    /* fail / neutral / softfail inside include: no match, continue */
                }
            } else if (!tmode) {
                /* get_spf1() returns 1 for multiple or missing records */
                return 2;
            }
        } else if (strcasecmp(p, "a") == 0 || strncasecmp(p, "a:", 2) == 0 ||
                   strncasecmp(p, "a/", 2) == 0) {
            char dmn[1024];
            char net[4] = "32";
            int mr;

            if (strncasecmp(p, "a/", 2) == 0) {
                snprintf(dmn, sizeof(dmn), "%s", domain);
                snprintf(net, sizeof(net), "%s", p + 2);
            } else {
                char *n;

                snprintf(dmn, sizeof(dmn), "%s", strchr(p, ':') ? p + 2 : domain);
                n = strchr(dmn, '/');
                if (n) {
                    *n++ = '\0';
                    snprintf(net, sizeof(net), "%s", n);
                }
            }

            macro(dmn, sizeof(dmn), spfm);

            if (debug) printf(">> Test a: %s %s /%s\n", dmn, client_ip, net);
            mr = check_a(dmn, client_ip, net, tmode, prefix, querycount, querymax);

            if (!tmode) {
                if (mr == 0) {
                    if (debug) printf(">> match %sa %s %s\n", prefix, dmn, client_ip);
                    return spf1_qualifier_result(prefix[0]);
                }
                if (mr == -1) return -1; /* dns error */
            }
        } else if (strcasecmp(p, "mx") == 0 || strncasecmp(p, "mx:", 3) == 0 ||
                   strncasecmp(p, "mx/", 3) == 0) {
            char dmn[1024];
            char net[4] = "32";
            int mr;

            if (strncasecmp(p, "mx/", 3) == 0) {
                snprintf(dmn, sizeof(dmn), "%s", domain);
                snprintf(net, sizeof(net), "%s", p + 3);
            } else {
                char *n;

                snprintf(dmn, sizeof(dmn), "%s", strchr(p, ':') ? p + 3 : domain);
                n = strchr(dmn, '/');
                if (n) {
                    *n++ = '\0';
                    snprintf(net, sizeof(net), "%s", n);
                }
            }

            macro(dmn, sizeof(dmn), spfm);

            if (debug) printf(">> Test mx: %s %s\n", dmn, client_ip);
            mr = find_mailhost(dmn, client_ip, net, tmode, prefix, querycount, querymax);

            if (!tmode) {
                if (mr == 0) {
                    if (debug) printf(">> match %smx %s %s\n", prefix, dmn, client_ip);
                    return spf1_qualifier_result(prefix[0]);
                }
                if (mr == -1) return -1; /* dns error; 1 no match / -2 not found: continue */
            }
        } else if (strcasecmp(p, "ptr") == 0 || strncasecmp(p, "ptr:", 4) == 0) {
            char dmn[1024];
            int mr;

            if (strncasecmp(p, "ptr:", 4) == 0)
                snprintf(dmn, sizeof(dmn), "%s", p + 4);
            else
                snprintf(dmn, sizeof(dmn), "%s", domain);

            /* expand macros like the a/mx/exists mechanisms do */
            macro(dmn, sizeof(dmn), spfm);

            mr = find_ptr(dmn, client_ip, tmode, prefix, querycount, querymax);

            if (!tmode) {
                if (mr == -1) return -1; /* dns error; 1 = no match, continue */
                if (mr == 0) return spf1_qualifier_result(prefix[0]);
            }
        } else if (strncasecmp(p, "ip4:", 4) == 0) {
            char dmn[1024];
            char net[4] = "32";
            char *n;
            int mr;

            snprintf(dmn, sizeof(dmn), "%s", p + 4);
            n = strchr(dmn, '/');
            if (n) {
                *n++ = '\0';
                snprintf(net, sizeof(net), "%s", n);
            }
            if (debug) printf(">> Test ip4: %s %s\n", dmn, client_ip);
            mr = check_ip4(dmn, client_ip, net);

            if (!tmode) {
                if (mr == 0) {
                    if (debug) printf(">> match %sip4 %s %s\n", prefix, dmn, client_ip);
                    return spf1_qualifier_result(prefix[0]);
                }
                if (mr == 2) return 2; /* unknown */
            } else {
                printf("[ip4%s] %s/%s\n", prefix, dmn, net);
            }
        } else if (strncasecmp(p, "exists:", 7) == 0) {
            char dmn[1024];
            int mr;

            snprintf(dmn, sizeof(dmn), "%s", p + 7);
            macro(dmn, sizeof(dmn), spfm);

            if (debug) printf(">> Test exists: %s\n", dmn);
            /* net "0" (a /0 prefix, presumably matching any address) */
            mr = check_a(dmn, client_ip, "0", tmode, prefix, querycount, querymax);

            if (!tmode) {
                if (mr == 0) {
                    if (debug) printf(">> match exists %s\n", dmn);
                    return spf1_qualifier_result(prefix[0]);
                }
                if (mr == -1) return -1; /* dns error; 1 = no, continue */
            }
        }
    }

    if (redirect) {
        char dmn[1024];
        char re[1024];
        int found;

        snprintf(dmn, sizeof(dmn), "%s", redirect_domain);
        if (dmn[0] == '\0')
            return 2;

        macro(dmn, sizeof(dmn), spfm);

        found = get_spf1(dmn, re, sizeof(re), querycount, querymax);
        if (found == 0) {
            char bkp[1024];
            int sub;

            snprintf(bkp, sizeof(bkp), "%s", spfm->d);
            snprintf(spfm->d, sizeof(spfm->d), "%s", dmn);
            sub = parse_spf1(re + 7, client_ip, dmn, level + 1, tmode, spfm,
                             exp_res, sizeof_exp_res, querycount, querymax);
            snprintf(spfm->d, sizeof(spfm->d), "%s", bkp);

            return sub;
        }
        if (found == -1)
            return -1; /* (450) dns error */
        return 2;      /* unknown, not found, multiple */
    }

    return 2;
}

/* check_spf1 -- look up the SPF v1 record of domain and evaluate it for
 * client_ip, starting at recursion level 1.
 *
 * The record returned by get_spf1() begins with "v=spf1 ", so only the
 * text after those 7 characters is handed to parse_spf1().
 *
 * return -1 temp dns error (450)
 * return 0 OK
 * return 1 false
 * return 2 unknown (also when get_spf1 finds no single record)
 * return 3 softfail
 */
int check_spf1(char *domain, char *client_ip, int tmode, struct tspfm *spfm, 
	char *exp_res, int sizeof_exp_res, int *querycount, int querymax) {
    char record[1024];

    switch (get_spf1(domain, record, sizeof(record), querycount, querymax)) {
    case 0:
        return parse_spf1(record + 7, client_ip, domain, 1, tmode, spfm,
                          exp_res, sizeof_exp_res, querycount, querymax);
    case -1:
        return -1; /* (450) dns error */
    default:
        return 2;  /* unknown */
    }
}
