"""Significance of a 2x2 contingency table under two conditioning schemes.

For the observed table ``O_ij`` this script computes:

* ``pval_fisher`` -- Fisher's exact test. Both margins are fixed, so
  ``O11`` follows a hypergeometric distribution.
* ``plower`` / ``pupper`` / ``ptwo`` -- tail probabilities of the statistic
  ``(N*O11 - r1*C1) / sqrt(C1*C2)`` when only the row totals are fixed.
  The two rows are modelled as independent binomials that share the
  column-1 probability ``p_{.1}``. That probability is a nuisance
  parameter, swept over the grid ``p1``.
* ``pval_row`` -- the largest two-tailed value over that grid.

When run as a script, or in a notebook, it also plots the tail
probabilities against ``p_{.1}``.
"""
from __future__ import division

import numpy as np
from scipy import stats

# Relative tolerance so that exact ties (equal table probabilities or
# equal statistics) are not lost to floating-point rounding; this is the
# same safeguard SciPy's and R's fisher_exact use.
_REL_TOL = 1e-7


def _score_stat(o11, col1, col2, n_total, row1):
    """Return (n_total*o11 - row1*col1) / sqrt(col1*col2), elementwise.

    Tables with an empty column (col1*col2 == 0) also have a zero
    numerator, so the statistic is defined as 0 there. This avoids the
    0/0 without perturbing any other table's value.
    """
    o11 = np.asarray(o11, dtype=float)
    col1 = np.asarray(col1, dtype=float)
    col2 = np.asarray(col2, dtype=float)
    num = n_total * o11 - row1 * col1
    denom = np.sqrt(col1 * col2)
    safe = np.where(denom > 0, denom, 1.0)
    return np.where(denom > 0, num / safe, 0.0)


# Observed 2x2 table (rows i, columns j).
O_ij = np.array([[1, 6], [8, 2]])

r_i = np.sum(O_ij, axis=-1)  # row totals
c_j = np.sum(O_ij, axis=0)  # column totals
N = np.sum(r_i)  # grand total
O11 = O_ij[0, 0]
O21 = O_ij[1, 0]
r1, r2 = r_i
c1, c2 = c_j

# --- Fisher's exact test (both margins fixed) -------------------------------
# O11 ~ Hypergeom(total=N, successes=c1, draws=r1) on its support x.
x = np.arange(max(0, r1 + c1 - N), min(r1, c1) + 1)
hg_dist = stats.hypergeom(N, c1, r1)
px = hg_dist.pmf(x)
pxobs = hg_dist.pmf(O11)
# Two-sided p-value: the total probability of all tables that are no more
# likely than the observed one.
pval_fisher = np.sum(px[px <= pxobs * (1 + _REL_TOL)])

# --- Test conditioned on the row totals only --------------------------------
O11vals = np.arange(r1 + 1)  # every possible O11 given row total r1
O21vals = np.arange(r2 + 1)  # every possible O21 given row total r2
c1vals = O11vals[:, None] + O21vals[None, :]  # column-1 total per table
c2vals = N - c1vals
statvals = _score_stat(O11vals[:, None], c1vals, c2vals, N, r1)
# Read the observed statistic from the same grid. The observed table is
# then evaluated with exactly the same arithmetic as every candidate, so
# it is guaranteed to fall inside its own tail.
mystat = statvals[O11, O21]

# The tails are taken at the observed magnitude. The tolerance keeps
# exact ties (e.g. the mirror table) inside the rejection regions.
t_obs = abs(mystat)
tol = _REL_TOL * max(1.0, t_obs)
in_lower = statvals <= -t_obs + tol
in_upper = statvals >= t_obs - tol
# The two-tailed region is built directly rather than as lower+upper, so
# it cannot double count when the observed statistic is (close to) 0.
in_two = np.abs(statvals) >= t_obs - tol

p1 = np.linspace(0, 1, 100)  # grid for the nuisance parameter p_{.1}
# pmfvals[a, b, k] = P(O11 = a) * P(O21 = b) for independent
# Binomial(r1, p1[k]) and Binomial(r2, p1[k]) rows.
pmfvals = (stats.binom(r1, p1[None, None, :]).pmf(O11vals[:, None, None])
           * stats.binom(r2, p1[None, None, :]).pmf(O21vals[None, :, None]))
plower = np.sum(in_lower[:, :, None] * pmfvals, axis=(0, 1))
pupper = np.sum(in_upper[:, :, None] * pmfvals, axis=(0, 1))
ptwo = np.sum(in_two[:, :, None] * pmfvals, axis=(0, 1))
# Conservative p-value: the supremum over the nuisance parameter,
# approximated by the maximum over the grid.
pval_row = np.max(ptwo)

if __name__ == "__main__":
    # The original relied on an interactive pylab namespace for these
    # names; import explicitly so the script also runs stand-alone.
    import matplotlib.pyplot as plt

    print(pval_fisher)
    print(pval_row)
    plt.plot(p1, plower, 'b--', label='lower-tailed')
    plt.plot(p1, pupper, 'g-.', label='upper-tailed')
    plt.plot(p1, ptwo, 'k-', label='two-tailed')
    plt.legend(loc='lower center')
    plt.xlabel(r'$p_{\bullet 1}$')
    plt.ylabel(r'$\alpha$')
    plt.title(r'Significance for $2\times 2$ table')
    plt.savefig('notes04_alpha2x2.eps', bbox_inches='tight')