From 890e633736263b55c0ddb5fb7b5621d0aa997164 Mon Sep 17 00:00:00 2001 From: cat Date: Wed, 22 Jan 2020 18:36:09 -0500 Subject: [PATCH 01/19] add cuda files --- cuda/mexClustering2.cu | 217 ++++++++++ cuda/mexDistances2.cu | 70 ++++ cuda/mexGetSpikes2.cu | 256 ++++++++++++ cuda/mexMPnu8.cu | 523 +++++++++++++++++++++++ cuda/mexSVDsmall2.cu | 255 ++++++++++++ cuda/mexThSpkPC.cu | 172 ++++++++ cuda/mexWtW2.cu | 54 +++ src/yass/reordering/cluster.py | 570 ++++++++++++++++++++++++++ src/yass/reordering/default_params.py | 70 ++++ src/yass/reordering/preprocess.py | 433 +++++++++++++++++++ src/yass/reordering/utils.py | 311 ++++++++++++++ 11 files changed, 2931 insertions(+) create mode 100644 cuda/mexClustering2.cu create mode 100644 cuda/mexDistances2.cu create mode 100644 cuda/mexGetSpikes2.cu create mode 100644 cuda/mexMPnu8.cu create mode 100644 cuda/mexSVDsmall2.cu create mode 100644 cuda/mexThSpkPC.cu create mode 100644 cuda/mexWtW2.cu create mode 100644 src/yass/reordering/cluster.py create mode 100644 src/yass/reordering/default_params.py create mode 100644 src/yass/reordering/preprocess.py create mode 100644 src/yass/reordering/utils.py diff --git a/cuda/mexClustering2.cu b/cuda/mexClustering2.cu new file mode 100644 index 00000000..0eb78dfd --- /dev/null +++ b/cuda/mexClustering2.cu @@ -0,0 +1,217 @@ +__global__ void computeCost(const double *Params, const float *uproj, const float *mu, const float *W, + const bool *match, const int *iC, const int *call, float *cmax){ + + int NrankPC,j, NchanNear, tid, bid, Nspikes, Nthreads, k, my_chan, this_chan, Nchan; + float xsum = 0.0f, Ci, lam; + + Nspikes = (int) Params[0]; + NrankPC = (int) Params[1]; + Nthreads = blockDim.x; + lam = (float) Params[5]; + NchanNear = (int) Params[6]; + Nchan = (int) Params[7]; + + tid = threadIdx.x; + bid = blockIdx.x; + + while(tid max_running){ + id[tind] = ind; + max_running = cmax[tind + ind*Nspikes]; + } + + + cx[tind] = max_running; + + tind += Nblocks*Nthreads; + } +} +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void average_snips(const double *Params, const int *iC, const int *call, + const int *id, const float *uproj, const float *cmax, float *WU){ + + int my_chan, this_chan, tidx, tidy, bid, ind, Nspikes, NrankPC, NchanNear, Nchan; + float xsum = 0.0f; + + Nspikes = (int) Params[0]; + NrankPC = (int) Params[1]; + Nchan = (int) Params[7]; + NchanNear = (int) Params[6]; + + tidx = threadIdx.x; + tidy = threadIdx.y; + bid = blockIdx.x; + + for(ind=0; ind Cmax){ + Cmax = Cf*Cf /(1+j); + kmax = j + t*Nsum; + } + } + } + datasum[tid0 + NT * i] = Cmax; + kkmax[tid0 + NT * i] = kmax; + } + tid0 += blockDim.x * gridDim.x; + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ + volatile __shared__ float sW[81*NrankMax], sdata[(Nthreads+81)]; + float y; + int tid, tid0, bid, i, nid, Nrank, NT, nt0, Nchan; + + tid = threadIdx.x; + bid = blockIdx.x; + NT = (int) Params[0]; + Nrank = (int) Params[14]; + nt0 = (int) Params[4]; + Nchan = (int) Params[9]; + + if(tid Cbest + 1e-6){ + Cbest = Cf; + ibest = i; + kbest = kkmax[tid0 + NT*i]; + } + } + err[tid0] = Cbest; + ftype[tid0] = ibest; + kall[tid0] = kbest; + + tid0 += blockDim.x * gridDim.x; + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void cleanup_spikes(const double *Params, const float *err, + const int *ftype, float *x, int *st, int *id, int *counter){ + + int lockout, indx, tid, bid, NT, tid0, j, t0; + volatile __shared__ float sdata[Nthreads+2*81+1]; + bool flag=0; + float err0, Th; + + lockout = (int) Params[4] - 1; + tid = threadIdx.x; + bid = blockIdx.x; + + NT = (int) Params[0]; + tid0 = bid * blockDim.x ; + Th = (float) Params[2]; + + while(tid0 Th*Th && t0err0){ + flag = 1; + break; + } + if(flag==0){ + indx = atomicAdd(&counter[0], 1); + if (indx=0 & t=0 && tid0 Cbest + 1e-6){ + Cnextbest = Cbest; + Cbest = Cf; + ibest = i; + } + else + if (Cf > Cnextbest + 1e-6) + Cnextbest = Cf; + } + err[tid0] = Cbest; + eloss[tid0] = Cbest - Cnextbest; + ftype[tid0] = ibest; + + tid0 += blockDim.x * gridDim.x; + } +} + +// THIS UPDATE DOES NOT UPDATE ELOSS? +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void bestFilterUpdate(const double *Params, const float *data, + const float *mu, float *err, float *eloss, int *ftype, const int *st, const int *id, const int *counter){ + int tid, ind, i,t, NT, Nfilt, ibest = 0, nt0; + float Cf, Cbest, lam, b, a, Cnextbest; + + tid = threadIdx.x; + NT = (int) Params[0]; + Nfilt = (int) Params[1]; + lam = (float) Params[7]; + nt0 = (int) Params[4]; + + + // we only need to compute this at updated locations + ind = counter[1] + blockIdx.x; + + if (ind=0 && t Cbest + 1e-6){ + Cnextbest = Cbest; + Cbest = Cf; + ibest = i; + } + else + if (Cf > Cnextbest + 1e-6) + Cnextbest = Cf; + } + err[t] = Cbest; + ftype[t] = ibest; + } + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void cleanup_spikes(const double *Params, const float *data, + const float *mu, const float *err, const float *eloss, const int *ftype, int *st, + int *id, float *x, float *y, float *z, int *counter){ + + int lockout, indx, tid, bid, NT, tid0, j, id0, t0; + volatile __shared__ float sdata[Nthreads+2*81+1]; + bool flag=0; + float err0, Th; + + lockout = (int) Params[4] - 1; + tid = threadIdx.x; + bid = blockIdx.x; + + NT = (int) Params[0]; + tid0 = bid * blockDim.x ; + Th = (float) Params[2]; + //lam = (float) Params[7]; + + while(tid0Th*Th){ + flag = 0; + for(j=-lockout;j<=lockout;j++) + if(sdata[tid+lockout+j]>err0){ + flag = 1; + break; + } + if(flag==0){ + indx = atomicAdd(&counter[0], 1); + if (indxTh){ + if (id[currInd]==bid){ + if (tidx==0 && threadIdx.y==0) + nsp[bid]++; + + tidy = threadIdx.y; + while (tidyThS){ + + tidy = threadIdx.y; + // only do this if the spike is "BAD" + while (tidy xmax){ + xmax = abs(sW[t]); + imax = t; + } + + tid = threadIdx.x; + // shift by imax - tmax + for (k=0;k xmax){ + xmax = abs(sWup[t]); + imax = t; + sgnmax = copysign(1.0f, sWup[t]); + } + + // interpolate by imax + for (k=0;k Cf){ + flag = false; + break; + } + } + + if (flag){ + iChan = iC[NchanNear * i]; + if (Cf>spkTh){ + d = (double) dataraw[tid0+nt0min-1 + NT*iChan]; // + if (d > Cf-1e-6){ + // this is a hit, atomicAdd and return spikes + indx = atomicAdd(&counter[0], 1); + if (indxspkTh) + conv_sig[tid0 + tid + NT*bid] = y; + + tid0+=Nthreads; + __syncthreads(); + } +} diff --git a/cuda/mexWtW2.cu b/cuda/mexWtW2.cu new file mode 100644 index 00000000..ccd32fb6 --- /dev/null +++ b/cuda/mexWtW2.cu @@ -0,0 +1,54 @@ +const int nblock = 32; +////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void crossFilter(const double *Params, const float *W1, const float *W2, + const float *UtU, float *WtW){ + __shared__ float shW1[nblock*81], shW2[nblock*81]; + + float x; + int nt0, tidx, tidy , bidx, bidy, i, Nfilt, t, tid1, tid2; + + tidx = threadIdx.x; + tidy = threadIdx.y; + bidx = blockIdx.x; + bidy = blockIdx.y; + + Nfilt = (int) Params[1]; + nt0 = (int) Params[9]; + + tid1 = tidx + bidx*nblock; + + tid2 = tidy + bidx*nblock; + if (tid2= nChan] = nChan - 1 # only needed for channels not time (due to time buffer) + + indsT = cp.transpose(cp.atleast_3d(indsT), [0, 2, 1]) + indsC = cp.transpose(cp.atleast_3d(indsC), [2, 0, 1]) + + # believe it or not, these indices grab just the right timesamples forour spikes + ix = indsT + indsC * nT + + # grab the data and reshape it appropriately (time samples by channels by num spikes) + clips = dataRAW.T.ravel()[ix[:, 0, :]].reshape((dt.size, row.size), order='F') # HERE + return clips + + +def extractPCfromSnippets(proc, probe=None, params=None, Nbatch=None): + # extracts principal components for 1D snippets of spikes from all channels + # loads a subset of batches to find these snippets + + NT = params.NT + nPCs = params.nPCs + Nchan = probe.Nchan + + batchstart = np.arange(0, NT * Nbatch + 1, NT) + + # extract the PCA projections + # initialize the covariance of single-channel spike waveforms + CC = cp.zeros(params.nt0, dtype=np.float32) + + # from every 100th batch + for ibatch in range(0, Nbatch, 100): + offset = Nchan * batchstart[ibatch] + dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan), order='F') + if dat.shape[0] == 0: + continue + + # move data to GPU and scale it back to unit variance + dataRAW = cp.asarray(dat, dtype=np.float32) / params.scaleproc + + # find isolated spikes from each batch + row, col, mu = isolated_peaks_new(dataRAW, params) + + # for each peak, get the voltage snippet from that channel + c = get_SpikeSample(dataRAW, row, col, params) + + # scale covariance down by 1,000 to maintain a good dynamic range + CC = CC + cp.dot(c, c.T) / 1e3 + + # the singular vectors of the covariance matrix are the PCs of the waveforms + U, Sv, V = svdecon(CC) + + wPCA = U[:, :nPCs] # take as many as needed + + # adjust the arbitrary sign of the first PC so its negativity is downward + wPCA[:, 0] = -wPCA[:, 0] * cp.sign(wPCA[20, 0]) + + return wPCA + + +def sortBatches2(ccb0): + # takes as input a matrix of nBatches by nBatches containing + # dissimilarities. + # outputs a matrix of sorted batches, and the sorting order, such that + # ccb1 = ccb0(isort, isort) + + # put this matrix on the GPU + ccb0 = cp.asarray(ccb0, order='F') + + # compute its svd on the GPU (this might also be fast enough on CPU) + u, s, v = svdecon(ccb0) + # HACK: consistency with MATLAB + u = u * cp.sign(u[0, 0]) + v = v * cp.sign(u[0, 0]) + + # initialize the positions xs of the batch embeddings to be very small but proportional to + # the first PC + xs = .01 * u[:, 0] / cp.std(u[:, 0], ddof=1) + + # 200 iterations of gradient descent should be enough + niB = 200 + + # this learning rate should usually work fine, since it scales with the average gradient + # and ccb0 is z-scored + eta = 1 + for k in tqdm(range(niB), desc="Sorting %d batches" % ccb0.shape[0]): + # euclidian distances between 1D embedding positions + ds = (xs - xs[:, np.newaxis]) ** 2 + # the transformed distances go through this function + W = cp.log(1 + ds) + + # the error is the difference between ccb0 and W + err = ccb0 - W + + # ignore the mean value of ccb0 + err = err - cp.mean(err, axis=0) + + # backpropagate the gradients + err = err / (1 + ds) + err2 = err * (xs[:, np.newaxis] - xs) + D = cp.mean(err2, axis=1) # one half of the gradients is along this direction + E = cp.mean(err2, axis=0) # the other half is along this direction + # we don't need to worry about the gradients for the diagonal because those are 0 + + # final gradients for the embedding variable + dx = -D + E.T + + # take a gradient step + xs = xs - eta * dx + + # sort the embedding positions xs + isort = cp.argsort(xs, axis=0) + + # sort the matrix of dissimilarities + ccb1 = ccb0[isort, :][:, isort] + + return ccb1, isort + + +def initializeWdata2(call, uprojDAT, Nchan, nPCs, Nfilt, iC): + # this function initializes cluster means for the fast kmeans per batch + # call are time indices for the spikes + # uprojDAT are features projections (Nfeatures by Nspikes) + # some more parameters need to be passed in from the main workspace + + # pick random spikes from the sample + # WARNING: replace ceil by warning because this is a random index, and 0/1 indexing + # discrepancy between Python and MATLAB. + irand = np.floor(np.random.rand(Nfilt) * uprojDAT.shape[1]).astype(np.int32) + + W = cp.zeros((nPCs, Nchan, Nfilt), dtype=np.float32) + + for t in range(Nfilt): + ich = iC[:, call[irand[t]]] # the channels on which this spike lives + # for each selected spike, get its features + W[:, ich, t] = uprojDAT[:, irand[t]].reshape(W[:, ich, t].shape, order='F') + + W = W.reshape((-1, Nfilt), order='F') # HERE + # add small amount of noise in case we accidentally picked the same spike twice + W = W + .001 * cp.random.normal(size=W.shape).astype(np.float32) + mu = cp.sqrt(cp.sum(W ** 2, axis=0)) # get the mean of the template + W = W / (1e-5 + mu) # and normalize the template + W = W.reshape((nPCs, Nchan, Nfilt), order='F') # HERE + nW = (W[0, ...] ** 2) # squared amplitude of the first PC feture + W = W.reshape((nPCs * Nchan, Nfilt), order='F') # HERE + # determine biggest channel according to the amplitude of the first PC + Wheights = cp.argmax(nW, axis=0) + + return W, mu, Wheights, irand + + +def mexThSpkPC(Params, dataRAW, wPCA, iC): + code, constants = get_cuda('mexThSpkPC') + Nthreads = constants.Nthreads + maxFR = constants.maxFR + + NT, Nchan, NchanNear, nt0, nt0min, spkTh, NrankPC = Params + NT = int(NT) + Nchan = int(Nchan) + + # Input GPU arrays. + d_Params = cp.asarray(Params, dtype=np.float64, order='F') + d_data = cp.asarray(dataRAW, dtype=np.float32, order='F') + d_W = cp.asarray(wPCA, dtype=np.float32, order='F') + d_iC = cp.asarray(iC, dtype=np.int32, order='F') + + # New GPU arrays. + d_dout = cp.zeros((Nchan, NT), dtype=np.float32, order='F') + d_dmax = cp.zeros((Nchan, NT), dtype=np.float32, order='F') + d_st = cp.zeros(maxFR, dtype=np.int32, order='F') + d_id = cp.zeros(maxFR, dtype=np.int32, order='F') + d_counter = cp.zeros(1, dtype=np.int32, order='F') + + # filter the data with the temporal templates + Conv1D = cp.RawKernel(code, 'Conv1D') + Conv1D((Nchan,), (Nthreads,), (d_Params, d_data, d_W, d_dout)) + + # get the max of the data + max1D = cp.RawKernel(code, 'max1D') + max1D((Nchan,), (Nthreads,), (d_Params, d_dout, d_dmax)) + + # take max across nearby channels + maxChannels = cp.RawKernel(code, 'maxChannels') + maxChannels( + (int(NT // Nthreads),), (Nthreads,), + (d_Params, d_dout, d_dmax, d_iC, d_st, d_id, d_counter)) + + # move d_x to the CPU + minSize = 1 + minSize = min(maxFR, int(d_counter[0])) + + d_featPC = cp.zeros((NrankPC * NchanNear, minSize), dtype=np.float32, order='F') + + d_id2 = cp.zeros(minSize, dtype=np.int32, order='F') + + if (minSize > 0): + computeProjections = cp.RawKernel(code, 'computeProjections') + computeProjections( + (minSize,), (NchanNear, NrankPC), (d_Params, d_data, d_iC, d_st, d_id, d_W, d_featPC)) + + # TODO: check that the copy occurs on the GPU only + d_id2[:] = d_id[:minSize] + + # Free memory. + del d_st, d_id, d_counter, d_Params, d_dmax, d_dout + # free_gpu_memory() + + return d_featPC, d_id2 + + +def extractPCbatch2(proc, params, probe, wPCA, ibatch, iC, Nbatch): + # this function finds threshold crossings in the data using + # projections onto the pre-determined principal components + # wPCA is number of time samples by number of PCs + # ibatch is a scalar indicating which batch to analyze + # iC is NchanNear by Nchan, indicating for each channel the nearest + # channels to it + + nt0min = params.nt0min + spkTh = params.ThPre + nt0, NrankPC = wPCA.shape + NT, Nchan = params.NT, probe.Nchan + + # starts with predefined PCA waveforms + wPCA = wPCA[:, :3] + + NchanNear = iC.shape[0] + + batchstart = np.arange(0, NT * Nbatch + 1, NT) # batches start at these timepoints + + offset = Nchan * batchstart[ibatch] + dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan), order='F') + dataRAW = cp.asarray(dat, dtype=np.float32) / params.scaleproc + + # another Params variable to take all our parameters into the C++ code + Params = [NT, Nchan, NchanNear, nt0, nt0min, spkTh, NrankPC] + + # call a CUDA function to do the hard work + # returns a matrix of features uS, as well as the center channels for each spike + uS, idchan = mexThSpkPC(Params, dataRAW, wPCA, iC) + + return uS, idchan + + +def mexClustering2(Params, uproj, W, mu, call, iMatch, iC): + + code, _ = get_cuda('mexClustering2') + + Nspikes = int(Params[0]) + NrankPC = int(Params[1]) + Nfilters = int(Params[2]) + NchanNear = int(Params[6]) + Nchan = int(Params[7]) + + d_Params = cp.asarray(Params, dtype=np.float64, order='F') + d_uproj = cp.asarray(uproj, dtype=np.float32, order='F') + d_W = cp.asarray(W, dtype=np.float32, order='F') + d_mu = cp.asarray(mu, dtype=np.float32, order='F') + d_call = cp.asarray(call, dtype=np.int32, order='F') + d_iC = cp.asarray(iC, dtype=np.int32, order='F') + d_iMatch = cp.asarray(iMatch, dtype=np.bool, order='F') + + d_dWU = cp.zeros((NrankPC * Nchan, Nfilters), dtype=np.float32, order='F') + d_cmax = cp.zeros((Nspikes, Nfilters), dtype=np.float32, order='F') + d_id = cp.zeros(Nspikes, dtype=np.int32, order='F') + d_x = cp.zeros(Nspikes, dtype=np.float32, order='F') + d_nsp = cp.zeros(Nfilters, dtype=np.int32, order='F') + d_V = cp.zeros(Nfilters, dtype=np.float32, order='F') + + # get list of cmaxes for each combination of neuron and filter + computeCost = cp.RawKernel(code, 'computeCost') + computeCost( + (Nfilters,), (1024,), (d_Params, d_uproj, d_mu, d_W, d_iMatch, d_iC, d_call, d_cmax)) + + # loop through cmax to find best template + bestFilter = cp.RawKernel(code, 'bestFilter') + bestFilter((40,), (256,), (d_Params, d_iMatch, d_iC, d_call, d_cmax, d_id, d_x)) + + # average all spikes for same template -- ORIGINAL + average_snips = cp.RawKernel(code, 'average_snips') + average_snips( + (Nfilters,), (NrankPC, NchanNear), (d_Params, d_iC, d_call, d_id, d_uproj, d_cmax, d_dWU)) + + count_spikes = cp.RawKernel(code, 'count_spikes') + count_spikes((7,), (256,), (d_Params, d_id, d_nsp, d_x, d_V)) + + del d_Params, d_V + + return d_dWU, d_id, d_x, d_nsp, d_cmax + + +def mexDistances2(Params, Ws, W, iMatch, iC, Wh, mus, mu): + code, _ = get_cuda('mexDistances2') + + Nspikes = int(Params[0]) + Nfilters = int(Params[2]) + + d_Params = cp.asarray(Params, dtype=np.float64, order='F') + + d_Ws = cp.asarray(Ws, dtype=np.float32, order='F') + d_W = cp.asarray(W, dtype=np.float32, order='F') + d_iMatch = cp.asarray(iMatch, dtype=np.bool, order='F') + d_iC = cp.asarray(iC, dtype=np.int32, order='F') + d_Wh = cp.asarray(Wh, dtype=np.int32, order='F') + d_mu = cp.asarray(mu, dtype=np.float32, order='F') + d_mus = cp.asarray(mus, dtype=np.float32, order='F') + + d_cmax = cp.zeros(Nspikes * Nfilters, dtype=np.float32, order='F') + d_id = cp.zeros(Nspikes, dtype=np.int32, order='F') + d_x = cp.zeros(Nspikes, dtype=np.float32, order='F') + + # get list of cmaxes for each combination of neuron and filter + computeCost = cp.RawKernel(code, 'computeCost') + computeCost( + (Nfilters,), (1024,), (d_Params, d_Ws, d_mus, d_W, d_mu, d_iMatch, d_iC, d_Wh, d_cmax)) + + # loop through cmax to find best template + bestFilter = cp.RawKernel(code, 'bestFilter') + bestFilter((40,), (256,), (d_Params, d_iMatch, d_Wh, d_cmax, d_mus, d_id, d_x)) + + del d_Params, d_cmax + + return d_id, d_x + + +def clusterSingleBatches(ctx): + """ + outputs an ordering of the batches according to drift + for each batch, it extracts spikes as threshold crossings and clusters them with kmeans + the resulting cluster means are then compared for all pairs of batches, and a dissimilarity + score is assigned to each pair + the matrix of similarity scores is then re-ordered so that low dissimilaity is along + the diagonal + """ + Nbatch = ctx.intermediate.Nbatch + params = ctx.params + probe = ctx.probe + raw_data = ctx.raw_data + ir = ctx.intermediate + proc = ir.proc + + if not params.reorder: + # if reordering is turned off, return consecutive order + iorig = np.arange(Nbatch) + return iorig, None, None + + nPCs = params.nPCs + Nfilt = ceil(probe.Nchan / 2) + + # extract PCA waveforms pooled over channels + wPCA = extractPCfromSnippets(proc, probe=probe, params=params, Nbatch=Nbatch) + + Nchan = probe.Nchan + niter = 10 # iterations for k-means. we won't run it to convergence to save time + + nBatches = Nbatch + NchanNear = min(Nchan, 2 * 8 + 1) + + # initialize big arrays on the GPU to hold the results from each batch + # this holds the unit norm templates + Ws = cp.zeros((nPCs, NchanNear, Nfilt, nBatches), dtype=np.float32, order='F') + # this holds the scalings + mus = cp.zeros((Nfilt, nBatches), dtype=np.float32, order='F') + # this holds the number of spikes for that cluster + ns = cp.zeros((Nfilt, nBatches), dtype=np.float32, order='F') + # this holds the center channel for each template + Whs = ones((Nfilt, nBatches), dtype=np.int32, order='F') + + i0 = 0 + NrankPC = 3 # I am not sure if this gets used, but it goes into the function + + # return an array of closest channels for each channel + iC = getClosestChannels(probe, params.sigmaMask, NchanNear)[0] + + for ibatch in tqdm(range(nBatches), desc="Clustering spikes"): + + # extract spikes using PCA waveforms + uproj, call = extractPCbatch2( + proc, params, probe, wPCA, min(nBatches - 2, ibatch), iC, Nbatch) + print("Number of PCS is: " + str(wPCA.shape)) + if cp.sum(cp.isnan(uproj)) > 0: + break # I am not sure what case this safeguards against.... + print(uproj.shape[1]) + if uproj.shape[1] > Nfilt: + + # this initialize the k-means + W, mu, Wheights, irand = initializeWdata2(call, uproj, Nchan, nPCs, Nfilt, iC) + + # Params is a whole bunch of parameters sent to the C++ scripts inside a float64 vector + Params = [uproj.shape[1], NrankPC, Nfilt, 0, W.shape[0], 0, NchanNear, Nchan] + + for i in range(niter): + + Wheights = Wheights.reshape((1, 1, -1), order='F') + iC = cp.atleast_3d(iC) + + # we only compute distances to clusters on the same channels + # this tells us which spikes and which clusters might match + iMatch = cp.min(cp.abs(iC - Wheights), axis=0) < .1 + + # get iclust and update W + # CUDA script to efficiently compute distances for pairs in which iMatch is 1 + dWU, iclust, dx, nsp, dV = mexClustering2(Params, uproj, W, mu, call, iMatch, iC) + + dWU = dWU / (1e-5 + nsp.T) # divide the cumulative waveform by the number of spike + + mu = cp.sqrt(cp.sum(dWU ** 2, axis=0)) # norm of cluster template + W = dWU / (1e-5 + mu) # unit normalize templates + + W = W.reshape((nPCs, Nchan, Nfilt), order='F') + nW = W[0, ...] ** 2 # compute best channel from the square of the first PC feature + W = W.reshape((Nchan * nPCs, Nfilt), order='F') + + Wheights = cp.argmax(nW, axis=0) # the new best channel of each cluster template + + # carefully keep track of cluster templates in dense format + W = W.reshape((nPCs, Nchan, Nfilt), order='F') + W0 = cp.zeros((nPCs, NchanNear, Nfilt), dtype=np.float32, order='F') + for t in range(Nfilt): + W0[..., t] = W[:, iC[:, Wheights[t]], t].squeeze() + # I don't really know why this needs another normalization + W0 = W0 / (1e-5 + cp.sum(cp.sum(W0 ** 2, axis=0)[np.newaxis, ...], axis=1) ** .5) + + # if a batch doesn't have enough spikes, it gets the cluster templates of the previous batc + if 'W0' in locals(): + Ws[..., ibatch] = W0 + mus[:, ibatch] = mu + ns[:, ibatch] = nsp + Whs[:, ibatch] = Wheights.astype(np.int32) + else: + logger.warning('Data batch #%d only had %d spikes.', ibatch, uproj.shape[1]) + + i0 = i0 + Nfilt + + # anothr one of these Params variables transporting parameters to the C++ code + Params = [1, NrankPC, Nfilt, 0, W.shape[0], 0, NchanNear, Nchan] + # the total number of templates is the number of templates per batch times the number of batch + Params[0] = Ws.shape[2] * Ws.shape[3] + + # initialize dissimilarity matrix + ccb = cp.zeros((nBatches, nBatches), dtype=np.float32, order='F') + + for ibatch in tqdm(range(nBatches), desc="Computing distances"): + # for every batch, compute in parallel its dissimilarity to ALL other batches + Wh0 = Whs[:, ibatch] # this one is the primary batch + W0 = Ws[..., ibatch] + mu = mus[..., ibatch] + + # embed the templates from the primary batch back into a full, sparse representation + W = cp.zeros((nPCs, Nchan, Nfilt), dtype=np.float32, order='F') + for t in range(Nfilt): + W[:, iC[:, Wh0[t]], t] = cp.atleast_3d(Ws[:, :, t, ibatch]) + + # pairs of templates that live on the same channels are potential "matches" + iMatch = cp.min(cp.abs(iC - Wh0.reshape((1, 1, -1), order='F')), axis=0) < .1 + + # compute dissimilarities for iMatch = 1 + iclust, ds = mexDistances2(Params, Ws, W, iMatch, iC, Whs, mus, mu) + + # ds are squared Euclidian distances + ds = ds.reshape((Nfilt, -1), order='F') # this should just be an Nfilt-long vector + ds = cp.maximum(0, ds) + + # weigh the distances according to number of spikes in cluster + ccb[ibatch, :] = cp.mean(cp.sqrt(ds) * ns, axis=0) / cp.mean(ns, axis=0) + + # ccb = cp.asnumpy(ccb) + # some normalization steps are needed: zscoring, and symmetrizing ccb + ccb0 = zscore(ccb, axis=0) + ccb0 = ccb0 + ccb0.T + + # sort by manifold embedding algorithm + # iorig is the sorting of the batches + # ccbsort is the resorted matrix (useful for diagnosing drift) + ccbsort, iorig = sortBatches2(ccb0) + logger.info("Finished clustering.") + + return Bunch(iorig=iorig, ccb0=ccb0, ccbsort=ccbsort) diff --git a/src/yass/reordering/default_params.py b/src/yass/reordering/default_params.py new file mode 100644 index 00000000..5d2cba40 --- /dev/null +++ b/src/yass/reordering/default_params.py @@ -0,0 +1,70 @@ +from math import ceil +from .utils import Bunch + +default_params = Bunch() + +# sample rate +default_params.fs = 30000. + +# frequency for high pass filtering (150) +default_params.fshigh = 150. +default_params.fslow = None + +# minimum firing rate on a "good" channel (0 to skip) +default_params.minfr_goodchannels = 0.1 + +# threshold on projections (like in Kilosort1, can be different for last pass like [10 4]) +default_params.Th = [10, 4] + +# how important is the amplitude penalty (like in Kilosort1, 0 means not used, +# 10 is average, 50 is a lot) +default_params.lam = 10 + +# splitting a cluster at the end requires at least this much isolation for each sub-cluster (max=1) +default_params.AUCsplit = 0.9 + +# minimum spike rate (Hz), if a cluster falls below this for too long it gets removed +default_params.minFR = 1. / 50 + +# number of samples to average over (annealed from first to second value) +default_params.momentum = [20, 400] + +# spatial constant in um for computing residual variance of spike +default_params.sigmaMask = 30 + +# threshold crossings for pre-clustering (in PCA projection space) +default_params.ThPre = 8 + +# danger, changing these settings can lead to fatal errors +# options for determining PCs +default_params.spkTh = -6 # spike threshold in standard deviations (-6) +default_params.reorder = 1 # whether to reorder batches for drift correction. +default_params.nskip = 25 # how many batches to skip for determining spike PCs + +# default_params.GPU = 1 # has to be 1, no CPU version yet, sorry +# default_params.Nfilt = 1024 # max number of clusters +default_params.nfilt_factor = 4 # max number of clusters per good channel (even temporary ones) +default_params.ntbuff = 64 # samples of symmetrical buffer for whitening and spike detection +# must be multiple of 32 + ntbuff. This is the batch size (try decreasing if out of memory). +default_params.whiteningRange = 32 # number of channels to use for whitening each channel +default_params.nSkipCov = 25 # compute whitening matrix from every N-th batch +default_params.scaleproc = 200 # int16 scaling of whitened data +default_params.nPCs = 3 # how many PCs to project the spikes into +# default_params.useRAM = 0 # not yet available + +default_params.nt0 = 61 +default_params.nup = 10 +default_params.sig = 1 +default_params.gain = 1 + +default_params.loc_range = [5, 4] +default_params.long_range = [30, 6] + + +def set_dependent_params(params): + """Add dependent parameters.""" + # we need buffers on both sides for filtering + params.NT = params.get('NT', 64 * 1024 + params.ntbuff) + params.NTbuff = params.get('NTbuff', params.NT + 4 * params.ntbuff) + params.nt0min = params.get('nt0min', ceil(20 * params.nt0 / 61)) + return params diff --git a/src/yass/reordering/preprocess.py b/src/yass/reordering/preprocess.py new file mode 100644 index 00000000..c93af5b6 --- /dev/null +++ b/src/yass/reordering/preprocess.py @@ -0,0 +1,433 @@ +import logging +from math import ceil +from functools import lru_cache + +import numpy as np +from scipy.signal import butter +import cupy as cp +from tqdm import tqdm + +from .cptools import lfilter, _get_lfilter_fun, median, convolve +from .utils import is_fortran + +logger = logging.getLogger(__name__) + + +def get_filter_params(fs, fshigh=None, fslow=None): + if fslow and fslow < fs / 2: + # butterworth filter with only 3 nodes (otherwise it's unstable for float32) + return butter(3, (2 * fshigh / fs, 2 * fslow / fs), 'bandpass') + else: + # butterworth filter with only 3 nodes (otherwise it's unstable for float32) + return butter(3, fshigh / fs * 2, 'high') + + +def gpufilter(buff, chanMap=None, fs=None, fslow=None, fshigh=None, car=True): + # filter this batch of data after common average referencing with the + # median + # buff is timepoints by channels + # chanMap are indices of the channels to be kep + # params.fs and params.fshigh are sampling and high-pass frequencies respectively + # if params.fslow is present, it is used as low-pass frequency (discouraged) + + # set up the parameters of the filter + b1, a1 = get_filter_params(fs, fshigh=fshigh, fslow=fslow) + + dataRAW = buff.T + dataRAW.ndim == 2 + if chanMap is not None: + dataRAW = dataRAW[:, chanMap] # subsample only good channels + assert dataRAW.ndim == 2 + + # subtract the mean from each channel + assert dataRAW.ndim == 2 + dataRAW = dataRAW - cp.mean(dataRAW, axis=0) # subtract mean of each channel + assert dataRAW.ndim == 2 + + # CAR, common average referencing by median + if car: + # subtract median across channels + dataRAW = dataRAW - median(dataRAW, axis=1)[:, np.newaxis] + + # next four lines should be equivalent to filtfilt (which cannot be + # used because it requires float64) + datr = lfilter(b1, a1, dataRAW, axis=0) # causal forward filter + datr = lfilter(b1, a1, datr, axis=0, reverse=True) # backward + return datr + + +def _is_vect(x): + return hasattr(x, '__len__') and len(x) > 1 + + +def _make_vect(x): + if not hasattr(x, '__len__'): + x = np.array([x]) + return x + + +def my_min(S1, sig, varargin=None): + # returns a running minimum applied sequentially across a choice of dimensions and bin sizes + # S1 is the matrix to be filtered + # sig is either a scalar or a sequence of scalars, one for each axis to be filtered. + # it's the plus/minus bin length for the minimum filter + # varargin can be the dimensions to do filtering, if len(sig) != x.shape + # if sig is scalar and no axes are provided, the default axis is 2 + idims = 1 + if varargin is not None: + idims = varargin + idims = _make_vect(idims) + if _is_vect(idims) and _is_vect(sig): + sigall = sig + else: + sigall = np.tile(sig, len(idims)) + + for sig, idim in zip(sigall, idims): + Nd = S1.ndim + S1 = cp.transpose(S1, [idim] + list(range(0, idim)) + list(range(idim + 1, Nd))) + dsnew = S1.shape + S1 = cp.reshape(S1, (S1.shape[0], -1), order='F') + dsnew2 = S1.shape + S1 = cp.concatenate( + (cp.full((sig, dsnew2[1]), np.inf), S1, cp.full((sig, dsnew2[1]), np.inf)), axis=0) + Smax = S1[:dsnew2[0], :] + for j in range(1, 2 * sig + 1): + Smax = cp.minimum(Smax, S1[j:j + dsnew2[0], :]) + S1 = cp.reshape(Smax, dsnew, order='F') + S1 = cp.transpose(S1, list(range(1, idim + 1)) + [0] + list(range(idim + 1, Nd))) + return S1 + + +def my_sum(S1, sig, varargin=None): + # returns a running sum applied sequentially across a choice of dimensions and bin sizes + # S1 is the matrix to be filtered + # sig is either a scalar or a sequence of scalars, one for each axis to be filtered. + # it's the plus/minus bin length for the summing filter + # varargin can be the dimensions to do filtering, if len(sig) != x.shape + # if sig is scalar and no axes are provided, the default axis is 2 + idims = 1 + if varargin is not None: + idims = varargin + idims = _make_vect(idims) + if _is_vect(idims) and _is_vect(sig): + sigall = sig + else: + sigall = np.tile(sig, len(idims)) + + for sig, idim in zip(sigall, idims): + Nd = S1.ndim + S1 = cp.transpose(S1, [idim] + list(range(0, idim)) + list(range(idim + 1, Nd))) + dsnew = S1.shape + S1 = cp.reshape(S1, (S1.shape[0], -1), order='F') + dsnew2 = S1.shape + S1 = cp.concatenate( + (cp.full((sig, dsnew2[1]), 0), S1, cp.full((sig, dsnew2[1]), 0)), axis=0) + Smax = S1[:dsnew2[0], :] + for j in range(1, 2 * sig + 1): + Smax = Smax + S1[j:j + dsnew2[0], :] + S1 = cp.reshape(Smax, dsnew, order='F') + S1 = cp.transpose(S1, list(range(1, idim + 1)) + [0] + list(range(idim + 1, Nd))) + return S1 + + +@lru_cache(128) +def _gaus_lfilter(sig, axis=0, is_fortran=True, reverse=False): + tmax = ceil(4 * sig) + dt = np.arange(-tmax, tmax + 1) + gaus = np.exp(-dt ** 2 / (2 * sig ** 2)) + gaus = gaus[:, np.newaxis] / np.sum(gaus) + return _get_lfilter_fun(gaus, 1, is_fortran=is_fortran, axis=axis, reverse=reverse) + + +def my_conv2(S1, sig, varargin=None): + # S1 is the matrix to be filtered along a choice of axes + # sig is either a scalar or a sequence of scalars, one for each axis to be filtered + # varargin can be the dimensions to do filtering, if len(sig) != x.shape + # if sig is scalar and no axes are provided, the default axis is 2 + if sig <= .25: + return S1 + idims = 1 + if varargin is not None: + idims = varargin + idims = _make_vect(idims) + if _is_vect(idims) and _is_vect(sig): + sigall = sig + else: + sigall = np.tile(sig, len(idims)) + + for sig, idim in zip(sigall, idims): + Nd = S1.ndim + S1 = cp.transpose(S1, [idim] + list(range(0, idim)) + list(range(idim + 1, Nd))) + dsnew = S1.shape + S1 = cp.reshape(S1, (S1.shape[0], -1), order='F') + dsnew2 = S1.shape + + tmax = ceil(4 * sig) + dt = cp.arange(-tmax, tmax + 1) + gaus = cp.exp(-dt ** 2 / (2 * sig ** 2)) + gaus = gaus[:, cp.newaxis] / cp.sum(gaus) + + # This GPU FFT-based convolution leads to a splitting step 3.5x faster than the + # custom GPU lfilter implementation below. + cNorm = convolve(cp.ones((dsnew2[0], 1)), gaus).ravel()[:, cp.newaxis] + S1 = convolve(S1, gaus) + + # Slow Custom GPU lfilter implementation: + # cNorm = _apply_lfilter( + # _gaus_lfilter(sig), + # cp.concatenate((cp.ones(dsnew2[0]), cp.zeros(tmax)))[:, np.newaxis]) + # cNorm = cNorm[tmax:, :] + # S1 = _apply_lfilter(_gaus_lfilter(sig), cp.asfortranarray(cp.concatenate( + # (S1, cp.zeros((tmax, dsnew2[1]), order='F')), axis=0))) + # S1 = S1[tmax:, :] + + S1 = S1.reshape(dsnew, order='F') + S1 = S1 / cNorm + + S1 = cp.transpose(S1, list(range(1, idim + 1)) + [0] + list(range(idim + 1, Nd))) + return S1 + + +def whiteningFromCovariance(CC): + # function Wrot = whiteningFromCovariance(CC) + # takes as input the matrix CC of channel pairwise correlations + # outputs a symmetric rotation matrix (also Nchan by Nchan) that rotates + # the data onto uncorrelated, unit-norm axes + + # covariance eigendecomposition (same as svd for positive-definite matrix) + E, D, _ = cp.linalg.svd(CC) + eps = 1e-6 + Di = cp.diag(1. / (D + eps) ** .5) + Wrot = cp.dot(cp.dot(E, Di), E.T) # this is the symmetric whitening matrix (ZCA transform) + return Wrot + + +def whiteningLocal(CC, yc, xc, nRange): + # function to perform local whitening of channels + # CC is a matrix of Nchan by Nchan correlations + # yc and xc are vector of Y and X positions of each channel + # nRange is the number of nearest channels to consider + Wrot = cp.zeros((CC.shape[0], CC.shape[0])) + + for j in range(CC.shape[0]): + ds = (xc - xc[j]) ** 2 + (yc - yc[j]) ** 2 + ilocal = np.argsort(ds) + # take the closest channels to the primary channel. + # First channel in this list will always be the primary channel. + ilocal = ilocal[:nRange] + + wrot0 = cp.asnumpy(whiteningFromCovariance(CC[np.ix_(ilocal, ilocal)])) + # the first column of wrot0 is the whitening filter for the primary channel + Wrot[ilocal, j] = wrot0[:, 0] + + return Wrot + + +def get_whitening_matrix(raw_data=None, probe=None, params=None): + """ + based on a subset of the data, compute a channel whitening matrix + this requires temporal filtering first (gpufilter) + """ + Nbatch = get_Nbatch(raw_data, params) + ntbuff = params.ntbuff + NTbuff = params.NTbuff + whiteningRange = params.whiteningRange + scaleproc = params.scaleproc + NT = params.NT + fs = params.fs + fshigh = params.fshigh + nSkipCov = params.nSkipCov + + xc = probe.xc + yc = probe.yc + chanMap = probe.chanMap + Nchan = probe.Nchan + chanMap = probe.chanMap + + # Nchan is obtained after the bad channels have been removed + CC = cp.zeros((Nchan, Nchan)) + + for ibatch in tqdm(range(0, Nbatch, nSkipCov), desc="Computing the whitening matrix"): + # WARNING: we use Fortran order, so raw_data is NchanTOT x nsamples + i = max(0, (NT - ntbuff) * ibatch - 2 * ntbuff) + buff = raw_data[:, i:i + NT - ntbuff] + + nsampcurr = buff.shape[1] + if nsampcurr < NTbuff: + buff = np.concatenate( + (buff, np.tile(buff[:, nsampcurr - 1][:, np.newaxis], (1, NTbuff))), axis=1) + + buff_g = cp.asarray(buff, dtype=np.float32) + + # apply filters and median subtraction + datr = gpufilter(buff_g, fs=fs, fshigh=fshigh, chanMap=chanMap) + + CC = CC + cp.dot(datr.T, datr) / NT # sample covariance + + CC = CC / ceil((Nbatch - 1) / nSkipCov) + + if whiteningRange < np.inf: + # if there are too many channels, a finite whiteningRange is more robust to noise + # in the estimation of the covariance + whiteningRange = min(whiteningRange, Nchan) + # this function performs the same matrix inversions as below, just on subsets of + # channels around each channel + Wrot = whiteningLocal(CC, yc, xc, whiteningRange) + else: + Wrot = whiteningFromCovariance(CC) + + Wrot = Wrot * scaleproc + + logger.info("Computed the whitening matrix.") + + return Wrot + + +def get_good_channels(raw_data=None, probe=None, params=None): + """ + of the channels indicated by the user as good (chanMap) + further subset those that have a mean firing rate above a certain value + (default is ops.minfr_goodchannels = 0.1Hz) + needs the same filtering parameters in ops as usual + also needs to know where to start processing batches (twind) + and how many channels there are in total (NchanTOT) + """ + fs = params.fs + fshigh = params.fshigh + fslow = params.fslow + Nbatch = get_Nbatch(raw_data, params) + NT = params.NT + spkTh = params.spkTh + nt0 = params.nt0 + minfr_goodchannels = params.minfr_goodchannels + + chanMap = probe.chanMap + # Nchan = probe.Nchan + NchanTOT = len(chanMap) + + ich = [] + k = 0 + ttime = 0 + + # skip every 100 batches + for ibatch in tqdm(range(0, Nbatch, int(ceil(Nbatch / 100))), desc="Finding good channels"): + i = NT * ibatch + buff = raw_data[:, i:i + NT] + if buff.size == 0: + break + + # Put on GPU. + buff = cp.asarray(buff, dtype=np.float32) + + datr = gpufilter(buff, chanMap=chanMap, fs=fs, fshigh=fshigh, fslow=fslow) + + # very basic threshold crossings calculation + s = cp.std(datr, axis=0) + datr = datr / s # standardize each channel ( but don't whiten) + mdat = my_min(datr, 30, 0) # get local minima as min value in +/- 30-sample range + + # take local minima that cross the negative threshold + xi, xj = cp.nonzero((datr < mdat + 1e-3) & (datr < spkTh)) + + # filtering may create transients at beginning or end. Remove those. + xj = xj[(xi >= nt0) & (xi <= NT - nt0)] + + # collect the channel identities for the detected spikes + ich.append(xj) + k += xj.size + + # keep track of total time where we took spikes from + ttime += datr.shape[0] / fs + + ich = cp.concatenate(ich) + + # count how many spikes each channel got + nc, _ = cp.histogram(ich, cp.arange(NchanTOT + 1)) + + # divide by total time to get firing rate + nc = nc / ttime + + # keep only those channels above the preset mean firing rate + igood = cp.asnumpy(nc >= minfr_goodchannels) + + logger.info('Found %d threshold crossings in %2.2f seconds of data.' % (k, ttime)) + logger.info('Found %d/%d bad channels.' % (np.sum(~igood), len(igood))) + + return igood + + +def get_Nbatch(raw_data, params): + # WARNING: F order for now, so (n_channels, n_samples) + axis = 1 if is_fortran(raw_data) else 0 + n_samples = raw_data.shape[axis] + # we assume raw_data as been already virtually split with the requested trange + return ceil(n_samples / (params.NT - params.ntbuff)) # number of data batches + + +def preprocess(ctx): + # function rez = preprocessDataSub(ops) + # this function takes an ops struct, which contains all the Kilosort2 settings and file paths + # and creates a new binary file of preprocessed data, logging new variables into rez. + # The following steps are applied: + # 1) conversion to float32 + # 2) common median subtraction + # 3) bandpass filtering + # 4) channel whitening + # 5) scaling to int16 values + + params = ctx.params + probe = ctx.probe + raw_data = ctx.raw_data + ir = ctx.intermediate + + fs = params.fs + fshigh = params.fshigh + fslow = params.fslow + Nbatch = ir.Nbatch + NT = params.NT + NTbuff = params.NTbuff + + Wrot = cp.asarray(ir.Wrot) + + logger.info("Loading raw data and applying filters.") + + with open(ir.proc_path, 'wb') as fw: # open for writing processed data + for ibatch in tqdm(range(Nbatch), desc="Preprocessing"): + # we'll create a binary file of batches of NT samples, which overlap consecutively + # on params.ntbuff samples + # in addition to that, we'll read another params.ntbuff samples from before and after, + # to have as buffers for filtering + + # number of samples to start reading at. + i = max(0, (NT - params.ntbuff) * ibatch - 2 * params.ntbuff) + if ibatch == 0: + # The very first batch has no pre-buffer, and has to be treated separately + ioffset = 0 + else: + ioffset = params.ntbuff + + buff = raw_data[:, i:i + NTbuff] + if buff.size == 0: + logger.error("Loaded buffer has an empty size!") + break # this shouldn't really happen, unless we counted data batches wrong + + nsampcurr = buff.shape[1] # how many time samples the current batch has + if nsampcurr < NTbuff: + buff = np.concatenate( + (buff, np.tile(buff[:, nsampcurr - 1][:, np.newaxis], (1, NTbuff))), axis=1) + + # apply filters and median subtraction + buff = cp.asarray(buff, dtype=np.float32) + + datr = gpufilter(buff, chanMap=probe.chanMap, fs=fs, fshigh=fshigh, fslow=fslow) + + datr = datr[ioffset:ioffset + NT, :] # remove timepoints used as buffers + datr = cp.dot(datr, Wrot) # whiten the data and scale by 200 for int16 range + + # convert to int16, and gather on the CPU side + # WARNING: transpose because "tofile" always writes in C order, whereas we want + # to write in F order. + datcpu = cp.asnumpy(datr.T.astype(np.int16)) + + # write this batch to binary file + datcpu.tofile(fw) diff --git a/src/yass/reordering/utils.py b/src/yass/reordering/utils.py new file mode 100644 index 00000000..cf37f8e9 --- /dev/null +++ b/src/yass/reordering/utils.py @@ -0,0 +1,311 @@ +from contextlib import contextmanager +from functools import reduce +import json +import logging +from pathlib import Path +import operator +import os.path as op +import re +from time import perf_counter + +import numpy as np +import cupy as cp + +from .event import emit, connect, unconnect # noqa + +logger = logging.getLogger(__name__) + + +def prod(iterable): + return reduce(operator.mul, iterable, 1) + + +class Bunch(dict): + """A subclass of dictionary with an additional dot syntax.""" + def __init__(self, *args, **kwargs): + super(Bunch, self).__init__(*args, **kwargs) + self.__dict__ = self + + def copy(self): + """Return a new Bunch instance which is a copy of the current Bunch instance.""" + return Bunch(super(Bunch, self).copy()) + + +def p(x): + print("shape", x.shape, "mean", "%5e" % x.mean()) + print(x[:2, :2]) + print() + print(x[-2:, -2:]) + + +def _extend(x, i0, i1, val, axis=0): + """Extend an array along a dimension and fill it with some values.""" + shape = x.shape + if x.shape[axis] < i1: + s = list(x.shape) + s[axis] = i1 - s[axis] + x = cp.concatenate((x, cp.zeros(tuple(s), dtype=x.dtype, order='F')), axis=axis) + assert x.shape[axis] == i1 + s = [slice(None, None, None)] * x.ndim + s[axis] = slice(i0, i1, 1) + x[s] = val + for i in range(x.ndim): + if i != axis: + assert x.shape[i] == shape[i] + return x + + +def is_fortran(x): + if isinstance(x, np.ndarray): + return x.flags.f_contiguous + + +def read_data(dat_path, offset=0, shape=None, dtype=None, axis=0): + count = shape[0] * shape[1] if shape and -1 not in shape else -1 + buff = np.fromfile(dat_path, dtype=dtype, count=count, offset=offset) + if shape and -1 not in shape: + shape = (-1, shape[1]) if axis == 0 else (shape[0], -1) + if shape: + buff = buff.reshape(shape, order='F') + return buff + + +def memmap_binary_file(dat_path, n_channels=None, shape=None, dtype=None, offset=None): + """Memmap a dat file.""" + assert dtype is not None + item_size = np.dtype(dtype).itemsize + offset = offset if offset else 0 + if shape is None: + assert n_channels is not None + n_samples = (op.getsize(str(dat_path)) - offset) // (item_size * n_channels) + shape = (n_channels, n_samples) + assert shape + shape = tuple(shape) + return np.memmap(str(dat_path), dtype=dtype, shape=shape, offset=offset, order='F') + + +def extract_constants_from_cuda(code): + r = re.compile(r'const int\s+\S+\s+=\s+\S+.+') + m = r.search(code) + if m: + constants = m.group(0).replace('const int', '').replace(';', '').split(',') + for const in constants: + a, b = const.strip().split('=') + yield a.strip(), int(b.strip()) + + +def get_cuda(fn): + path = Path(__file__).parent / 'cuda' / (fn + '.cu') + assert path.exists + code = path.read_text() + code = code.replace('__global__ void', 'extern "C" __global__ void') + return code, Bunch(extract_constants_from_cuda(code)) + + +class LargeArrayWriter(object): + """Save a large array chunk by chunk, in a binary file with FORTRAN order.""" + def __init__(self, path, dtype=None, shape=None): + self.path = Path(path) + self.dtype = np.dtype(dtype) + self._shape = shape + assert shape[-1] == -1 # the last axis must be the extendable axis, in FORTRAN order + assert -1 not in shape[:-1] # shape may not contain -1 outside the last dimension + self.fw = open(self.path, 'wb') + self.extendable_axis_size = 0 + self.total_size = 0 + + def append(self, arr): + # We convert to the requested data type. + assert arr.flags.f_contiguous # only FORTRAN order arrays are currently supported + assert arr.shape[:-1] == self._shape[:-1] + arr = arr.astype(self.dtype) + es = arr.shape[-1] + if arr.flags.f_contiguous: + arr = arr.T + # We download the array from the GPU if required. + # We ensure the array is in FORTRAN order now. + assert arr.flags.c_contiguous + if isinstance(arr, cp.ndarray): + arr = cp.asnumpy(arr) + arr.tofile(self.fw) + self.total_size += arr.size + self.extendable_axis_size += es # the last dimension, but + assert prod(self.shape) == self.total_size + + @property + def shape(self): + return self._shape[:-1] + (self.extendable_axis_size,) + + def close(self): + self.fw.close() + # Save JSON metadata file. + with open(self.path.with_suffix('.json'), 'w') as f: + json.dump({'shape': self.shape, 'dtype': str(self.dtype), 'order': 'F'}, f) + + +def memmap_large_array(path): + """Memmap a large array saved by LargeArrayWriter.""" + path = Path(path) + with open(path.with_suffix('.json'), 'r') as f: + metadata = json.load(f) + assert metadata['order'] == 'F' + dtype = np.dtype(metadata['dtype']) + shape = metadata['shape'] + return memmap_binary_file(path, shape=shape, dtype=dtype) + + +class Context(Bunch): + def __init__(self, context_path): + super(Context, self).__init__() + self.context_path = context_path + self.intermediate = Bunch() + self.context_path.mkdir(exist_ok=True, parents=True) + self.timer = {} + + @property + def metadata_path(self): + return self.context_path / 'metadata.json' + + def path(self, name, ext='.npy'): + """Path to an array in the context directory.""" + return self.context_path / (name + ext) + + def read_metadata(self): + """Read the metadata dictionary from the metadata.json file in the context dir.""" + if not self.metadata_path.exists(): + return Bunch() + with open(self.metadata_path, 'r') as f: + return Bunch(json.load(f)) + + def write_metadata(self, metadata): + """Write metadata dictionary in the metadata.json file.""" + with open(self.metadata_path, 'w') as f: + json.dump(metadata, f, indent=2) + + def read(self, name): + """Read an array from memory (intermediate object) or from disk.""" + if name not in self.intermediate: + path = self.path(name) + # Load a NumPy file. + if path.exists(): + logger.debug("Loading %s.npy", name) + # Memmap for large files. + mmap_mode = 'r' if op.getsize(path) > 1e8 else None + self.intermediate[name] = np.load(path, mmap_mode=mmap_mode) + else: + # Load a value from the metadata file. + self.intermediate[name] = self.read_metadata().get(name, None) + return self.intermediate[name] + + def write(self, **kwargs): + """Write several arrays.""" + # Load the metadata. + if self.metadata_path.exists(): + metadata = self.read_metadata() + else: + metadata = Bunch() + # Write all variables. + for k, v in kwargs.items(): + # Transfer GPU arrays to the CPU before saving them. + if isinstance(v, cp.ndarray): + v = cp.asnumpy(v) + if isinstance(v, np.ndarray): + p = self.path(k) + overwrite = ' (overwrite)' if p.exists() else '' + logger.debug("Saving %s.npy%s", k, overwrite) + np.save(p, np.asfortranarray(v)) + elif v is not None: + logger.debug("Save %s in the metadata.json file.", k) + metadata[k] = v + # Write the metadata file. + self.write_metadata(metadata) + + def load(self): + """Load intermediate results from disk.""" + # Load metadata values that are not already loaded in the intermediate dictionary. + self.intermediate.update( + {k: v for k, v in self.read_metadata().items() if k not in self.intermediate}) + # Load NumPy arrays that are not already loaded in the intermediate dictionary. + names = [f.stem for f in self.context_path.glob('*.npy')] + self.intermediate.update( + {name: self.read(name) for name in names if name not in self.intermediate}) + + def save(self, **kwargs): + """Save intermediate results to the ctx.intermediate dictionary, and to disk also. + + This has two effects: + 1. variables are available via ctx.intermediate in the current session + 2. In a future session with ctx.load(), these variables will be readily available in + ctx.intermediate + + """ + for k, v in kwargs.items(): + if v is not None: + self.intermediate[k] = v + kwargs = kwargs or self.intermediate + self.write(**kwargs) + + @contextmanager + def time(self, name): + """Context manager to measure the time of a section of code.""" + t0 = perf_counter() + yield + t1 = perf_counter() + self.timer[name] = t1 - t0 + self.show_timer(name) + + def show_timer(self, name=None): + """Display the results of the timer.""" + if name: + logger.info("Step `{:s}` took {:.2f}s.".format(name, self.timer[name])) + return + for name in self.timer.keys(): + self.show_timer(name) + + +def load_probe(probe_path): + """Load a .mat probe file from Kilosort2, or a PRB file (experimental).""" + + # A bunch with the following attributes: + _required_keys = ('NchanTOT', 'chanMap', 'xc', 'yc', 'kcoords') + probe = Bunch() + probe.NchanTOT = 0 + probe_path = Path(probe_path).resolve() + + if probe_path.suffix == '.prb': + # Support for PRB files. + contents = probe_path.read_text() + metadata = {} + exec(contents, {}, metadata) + probe.chanMap = [] + probe.xc = [] + probe.yc = [] + probe.kcoords = [] + for cg in sorted(metadata['channel_groups']): + d = metadata['channel_groups'][cg] + ch = d['channels'] + pos = d.get('geometry', {}) + probe.chanMap.append(ch) + probe.NchanTOT += len(ch) + probe.xc.append([pos[c][0] for c in ch]) + probe.yc.append([pos[c][1] for c in ch]) + probe.kcoords.append([cg for c in ch]) + probe.chanMap = np.concatenate(probe.chanMap).ravel().astype(np.int32) + probe.xc = np.concatenate(probe.xc) + probe.yc = np.concatenate(probe.yc) + probe.kcoords = np.concatenate(probe.kcoords) + + elif probe_path.suffix == '.mat': + from scipy.io import loadmat + mat = loadmat(probe_path) + probe.xc = mat['xcoords'].ravel().astype(np.float64) + nc = len(probe.xc) + probe.yc = mat['ycoords'].ravel().astype(np.float64) + probe.kcoords = mat.get('kcoords', np.zeros(nc)).ravel().astype(np.float64) + probe.chanMap = (mat['chanMap'] - 1).ravel().astype(np.int32) # NOTE: 0-indexing in Python + probe.NchanTOT = len(probe.chanMap) # NOTE: should match the # of columns in the raw data + + for n in _required_keys: + assert n in probe.keys() + + return probe From dfa584ccbcaebbf527c0f860f435d1ee57a98504 Mon Sep 17 00:00:00 2001 From: cat Date: Tue, 28 Jan 2020 04:21:11 -0500 Subject: [PATCH 02/19] reorder --- src/yass/pipeline.py | 2 +- src/yass/preprocess/run.py | 22 +- src/yass/preprocess/util.py | 4 +- src/yass/reordering/__init__.py | 0 src/yass/reordering/cluster.py | 50 +- src/yass/reordering/cptools.py | 265 +++++++++ src/yass/reordering/cuda/__init__.py | 0 src/yass/reordering/cuda/mexClustering2.cu | 217 ++++++++ src/yass/reordering/cuda/mexDistances2.cu | 70 +++ src/yass/reordering/cuda/mexGetSpikes2.cu | 256 +++++++++ src/yass/reordering/cuda/mexMPnu8.cu | 523 ++++++++++++++++++ src/yass/reordering/cuda/mexSVDsmall2.cu | 255 +++++++++ src/yass/reordering/cuda/mexThSpkPC.cu | 172 ++++++ src/yass/reordering/cuda/mexWtW2.cu | 54 ++ src/yass/reordering/event.py | 153 +++++ src/yass/reordering/preprocess.py | 5 + src/yass/reordering/reorder.py | 63 +++ src/yass/reordering/run.py | 62 +++ .../soft_assignment/template_BACKUP_31194.py | 424 ++++++++++++++ .../soft_assignment/template_BACKUP_31287.py | 424 ++++++++++++++ .../soft_assignment/template_BASE_31194.py | 412 ++++++++++++++ .../soft_assignment/template_BASE_31287.py | 412 ++++++++++++++ .../soft_assignment/template_LOCAL_31194.py | 415 ++++++++++++++ .../soft_assignment/template_LOCAL_31287.py | 415 ++++++++++++++ .../soft_assignment/template_REMOTE_31194.py | 416 ++++++++++++++ .../soft_assignment/template_REMOTE_31287.py | 416 ++++++++++++++ 26 files changed, 5476 insertions(+), 31 deletions(-) create mode 100644 src/yass/reordering/__init__.py create mode 100644 src/yass/reordering/cptools.py create mode 100644 src/yass/reordering/cuda/__init__.py create mode 100644 src/yass/reordering/cuda/mexClustering2.cu create mode 100644 src/yass/reordering/cuda/mexDistances2.cu create mode 100644 src/yass/reordering/cuda/mexGetSpikes2.cu create mode 100644 src/yass/reordering/cuda/mexMPnu8.cu create mode 100644 src/yass/reordering/cuda/mexSVDsmall2.cu create mode 100644 src/yass/reordering/cuda/mexThSpkPC.cu create mode 100644 src/yass/reordering/cuda/mexWtW2.cu create mode 100644 src/yass/reordering/event.py create mode 100644 src/yass/reordering/reorder.py create mode 100644 src/yass/reordering/run.py create mode 100644 src/yass/soft_assignment/template_BACKUP_31194.py create mode 100644 src/yass/soft_assignment/template_BACKUP_31287.py create mode 100644 src/yass/soft_assignment/template_BASE_31194.py create mode 100644 src/yass/soft_assignment/template_BASE_31287.py create mode 100644 src/yass/soft_assignment/template_LOCAL_31194.py create mode 100644 src/yass/soft_assignment/template_LOCAL_31287.py create mode 100644 src/yass/soft_assignment/template_REMOTE_31194.py create mode 100644 src/yass/soft_assignment/template_REMOTE_31287.py diff --git a/src/yass/pipeline.py b/src/yass/pipeline.py index 2ace822f..2fe1f107 100644 --- a/src/yass/pipeline.py +++ b/src/yass/pipeline.py @@ -127,7 +127,7 @@ def run(config, logger_level='INFO', clean=False, output_dir='tmp/', # preprocess start = time.time() (standardized_path, - standardized_dtype) = preprocess.run( + standardized_dtype, reorder_path) = preprocess.run( os.path.join(TMP_FOLDER, 'preprocess')) #### Block 1: Detection, Clustering, Postprocess diff --git a/src/yass/preprocess/run.py b/src/yass/preprocess/run.py index 8f7086ab..e364cbad 100644 --- a/src/yass/preprocess/run.py +++ b/src/yass/preprocess/run.py @@ -10,7 +10,7 @@ from yass import read_config from yass.preprocess.util import * from yass.reader import READER - +from yass.reordering import reorder def run(output_directory): """Preprocess pipeline: filtering, standarization and whitening filter @@ -92,10 +92,19 @@ def run(output_directory): n_channels=n_channels) logger.info('Output dtype for transformed data will be {}' .format(CONFIG.preprocess.dtype)) - + reorder_fname = os.path.join(output_directory, "reorder.npy") # Check if data already saved to disk and skip: if os.path.exists(standardized_path): - return standardized_path, standardized_params['dtype'] + if os.path.exists(reorder_fname): + return standardized_path, standardized_params['dtype'], reorder_fname + reorder.run(save_fname = reorder_fname, + standardized_fname = standardized_path, + CONFIG = CONFIG, + n_sec_chunk = 5, + dtype = CONFIG.preprocess.dtype) + return standardized_path, standardized_params['dtype'], reorder_fname + + # ********************************************** # *********** run filter & stdarize *********** @@ -168,5 +177,10 @@ def run(output_directory): with open(path_to_yaml, 'w') as f: logger.info('Saving params...') yaml.dump(standardized_params, f) + reorder.run(save_fname = reorder_fname, + standardized_fname = standardized_path, + CONFIG = CONFIG, + n_sec_chunk = 5, + dtype = CONFIG.preprocess.dtype) - return standardized_path, standardized_params['dtype'] + return standardized_path, standardized_params['dtype'], reorder_fname diff --git a/src/yass/preprocess/util.py b/src/yass/preprocess/util.py index 3e657858..77d84a9d 100644 --- a/src/yass/preprocess/util.py +++ b/src/yass/preprocess/util.py @@ -106,8 +106,8 @@ def _standardize(rec, sd=None, centers=None): rec[:,idx1] = np.divide(rec[:,idx1] - centers[idx1][None], sd[idx1]) # zero out bad channels - idx2 = np.where(sd<0.1)[0] - rec[:,idx2]=0. + #idx2 = np.where(sd<0.1)[0] + #rec[:,idx2]=0. return rec #return np.divide(rec, sd) diff --git a/src/yass/reordering/__init__.py b/src/yass/reordering/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/yass/reordering/cluster.py b/src/yass/reordering/cluster.py index 457e02dd..e6dfc4f9 100644 --- a/src/yass/reordering/cluster.py +++ b/src/yass/reordering/cluster.py @@ -100,20 +100,20 @@ def get_SpikeSample(dataRAW, row, col, params): # believe it or not, these indices grab just the right timesamples forour spikes ix = indsT + indsC * nT - # grab the data and reshape it appropriately (time samples by channels by num spikes) + clips = dataRAW.T.ravel()[ix[:, 0, :]].reshape((dt.size, row.size), order='F') # HERE return clips -def extractPCfromSnippets(proc, probe=None, params=None, Nbatch=None): +def extractPCfromSnippets(proc, yass_batch,NT, probe=None, params=None, Nbatch=None): # extracts principal components for 1D snippets of spikes from all channels # loads a subset of batches to find these snippets - NT = params.NT + #NT = params.NT nPCs = params.nPCs Nchan = probe.Nchan - + Nbatch = yass_batch batchstart = np.arange(0, NT * Nbatch + 1, NT) # extract the PCA projections @@ -121,14 +121,14 @@ def extractPCfromSnippets(proc, probe=None, params=None, Nbatch=None): CC = cp.zeros(params.nt0, dtype=np.float32) # from every 100th batch - for ibatch in range(0, Nbatch, 100): + for ibatch in range(0, yass_batch, 100): offset = Nchan * batchstart[ibatch] - dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan), order='F') + dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan), order='C') if dat.shape[0] == 0: continue # move data to GPU and scale it back to unit variance - dataRAW = cp.asarray(dat, dtype=np.float32) / params.scaleproc + dataRAW = cp.asarray(dat, dtype=np.float32)#/params.scaleproc # find isolated spikes from each batch row, col, mu = isolated_peaks_new(dataRAW, params) @@ -279,6 +279,7 @@ def mexThSpkPC(Params, dataRAW, wPCA, iC): # move d_x to the CPU minSize = 1 + minSize = min(maxFR, int(d_counter[0])) d_featPC = cp.zeros((NrankPC * NchanNear, minSize), dtype=np.float32, order='F') @@ -300,18 +301,19 @@ def mexThSpkPC(Params, dataRAW, wPCA, iC): return d_featPC, d_id2 -def extractPCbatch2(proc, params, probe, wPCA, ibatch, iC, Nbatch): +def extractPCbatch2(proc, params, probe, wPCA, ibatch, iC, yass_batch, NT): # this function finds threshold crossings in the data using # projections onto the pre-determined principal components # wPCA is number of time samples by number of PCs # ibatch is a scalar indicating which batch to analyze # iC is NchanNear by Nchan, indicating for each channel the nearest # channels to it - + Nbatch = yass_batch nt0min = params.nt0min spkTh = params.ThPre + nt0, NrankPC = wPCA.shape - NT, Nchan = params.NT, probe.Nchan + Nchan = probe.Nchan # starts with predefined PCA waveforms wPCA = wPCA[:, :3] @@ -321,8 +323,8 @@ def extractPCbatch2(proc, params, probe, wPCA, ibatch, iC, Nbatch): batchstart = np.arange(0, NT * Nbatch + 1, NT) # batches start at these timepoints offset = Nchan * batchstart[ibatch] - dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan), order='F') - dataRAW = cp.asarray(dat, dtype=np.float32) / params.scaleproc + dat = proc.flat[offset:offset + NT * Nchan].reshape((-1, Nchan), order='C') + dataRAW = cp.asarray(dat, dtype=np.float32)#/params.scaleproc # another Params variable to take all our parameters into the C++ code Params = [NT, Nchan, NchanNear, nt0, nt0min, spkTh, NrankPC] @@ -415,7 +417,7 @@ def mexDistances2(Params, Ws, W, iMatch, iC, Wh, mus, mu): return d_id, d_x -def clusterSingleBatches(ctx): +def clusterSingleBatches(proc, params, probe, yass_batch, n_chunk_sec, nt0): """ outputs an ordering of the batches according to drift for each batch, it extracts spikes as threshold crossings and clusters them with kmeans @@ -424,12 +426,14 @@ def clusterSingleBatches(ctx): the matrix of similarity scores is then re-ordered so that low dissimilaity is along the diagonal """ - Nbatch = ctx.intermediate.Nbatch - params = ctx.params - probe = ctx.probe - raw_data = ctx.raw_data - ir = ctx.intermediate - proc = ir.proc + + #Nbatch = ctx.intermediate.Nbatch + #params = ctx.paramsƒextr + #probe = ctx.probe + #raw_data = ctx.raw_data + #ir = ctx.intermediate + #proc = ir.proc + params.nt0 = nt0 if not params.reorder: # if reordering is turned off, return consecutive order @@ -440,12 +444,12 @@ def clusterSingleBatches(ctx): Nfilt = ceil(probe.Nchan / 2) # extract PCA waveforms pooled over channels - wPCA = extractPCfromSnippets(proc, probe=probe, params=params, Nbatch=Nbatch) + wPCA = extractPCfromSnippets(proc,yass_batch, n_chunk_sec, probe=probe, params=params) Nchan = probe.Nchan niter = 10 # iterations for k-means. we won't run it to convergence to save time - nBatches = Nbatch + nBatches = yass_batch#Nbatch NchanNear = min(Nchan, 2 * 8 + 1) # initialize big arrays on the GPU to hold the results from each batch @@ -468,11 +472,9 @@ def clusterSingleBatches(ctx): # extract spikes using PCA waveforms uproj, call = extractPCbatch2( - proc, params, probe, wPCA, min(nBatches - 2, ibatch), iC, Nbatch) - print("Number of PCS is: " + str(wPCA.shape)) + proc, params, probe, wPCA, min(nBatches - 2, ibatch), iC, yass_batch, n_chunk_sec) if cp.sum(cp.isnan(uproj)) > 0: break # I am not sure what case this safeguards against.... - print(uproj.shape[1]) if uproj.shape[1] > Nfilt: # this initialize the k-means diff --git a/src/yass/reordering/cptools.py b/src/yass/reordering/cptools.py new file mode 100644 index 00000000..1de37039 --- /dev/null +++ b/src/yass/reordering/cptools.py @@ -0,0 +1,265 @@ +import ctypes +from math import ceil +from textwrap import dedent + +import numpy as np +import cupy as cp + + +# LTI filter on GPU +# ----------------------------------------------------------------------------- + +def make_kernel(kernel, name, **const_arrs): + """Compile a kernel and pass optional constant ararys.""" + mod = cp.core.core.compile_with_cache(kernel, prepend_cupy_headers=False) + b = cp.core.core.memory_module.BaseMemory() + # Pass constant arrays. + for n, arr in const_arrs.items(): + b.ptr = mod.get_global_var(n) + p = cp.core.core.memory_module.MemoryPointer(b, 0) + p.copy_from_host(arr.ctypes.data_as(ctypes.c_void_p), arr.nbytes) + return mod.get_function(name) + + +def get_lfilter_kernel(N, isfortran, reverse=False): + order = 'f' if isfortran else 'c' + return dedent(""" + const int N = %d; + __constant__ float a[N + 1]; + __constant__ float b[N + 1]; + + + __device__ int get_idx_f(int n, int col, int n_samples, int n_channels) { + return n_samples * col + n; // Fortran order. + } + __device__ int get_idx_c(int n, int col, int n_samples, int n_channels) { + return n * n_channels + col; // C order. + } + + // LTI IIR filter implemented using a difference equation. + // see https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.lfilter.html + extern "C" __global__ void lfilter( + const float* x, float* y, const int n_samples, const int n_channels){ + // Initialize the state variables. + float d[N + 1]; + for (int k = 0; k <= N; k++) { + d[k] = 0.0; + } + + float xn = 0.0; + float yn = 0.0; + + int idx = 0; + + // Column index. + int col = blockIdx.x * blockDim.x + threadIdx.x; + + + // IMPORTANT: avoid out of bounds memory accesses, which cause no errors but weird bugs. + if (col >= n_channels) return; + + for (int n = 0; n < n_samples; n++) { + idx = get_idx_%s(%s, col, n_samples, n_channels); + // Load the input element. + xn = x[idx]; + // Compute the output element. + yn = (b[0] * xn + d[0]) / a[0]; + // Update the state variables. + for (int k = 0; k < N; k++) { + d[k] = b[k + 1] * xn - a[k + 1] * yn + d[k + 1]; + } + // Update the output array. + y[idx] = yn; + } + } + """ % (N, order, 'n' if not reverse else 'n_samples - 1 - n')) + + +def _get_lfilter_fun(b, a, is_fortran=True, axis=0, reverse=False): + assert axis == 0, "Only filtering along the first axis is currently supported." + + b = np.atleast_1d(b).astype(np.float32) + a = np.atleast_1d(a).astype(np.float32) + N = max(len(b), len(a)) + if len(b) < N: + b = np.pad(b, (0, (N - len(b))), mode='constant') + if len(a) < N: + a = np.pad(a, (0, (N - len(a))), mode='constant') + assert len(a) == len(b) + kernel = get_lfilter_kernel(N - 1, is_fortran, reverse=reverse) + + lfilter = make_kernel(kernel, 'lfilter', b=b, a=a) + + return lfilter + + +def _apply_lfilter(lfilter_fun, arr): + assert isinstance(arr, cp.ndarray) + if arr.ndim == 1: + arr = arr[:, np.newaxis] + n_samples, n_channels = arr.shape + + block = (min(128, n_channels),) + grid = (int(ceil(n_channels / float(block[0]))),) + + arr = cp.asarray(arr, dtype=np.float32) + y = cp.zeros_like(arr, order='F' if arr.flags.f_contiguous else 'C', dtype=arr.dtype) + + assert arr.dtype == np.float32 + assert y.dtype == np.float32 + assert arr.shape == y.shape + + lfilter_fun(grid, block, (arr, y, int(y.shape[0]), int(y.shape[1]))) + return y + + +def lfilter(b, a, arr, axis=0, reverse=False): + """Perform a linear filter along the first axis on a GPU array.""" + lfilter_fun = _get_lfilter_fun( + b, a, is_fortran=arr.flags.f_contiguous, axis=axis, reverse=reverse) + return _apply_lfilter(lfilter_fun, arr) + + +def convolve(x, b, axis=0): + b = b.ravel() + assert axis == 0 + tmax = len(b) // 2 + xshape = x.shape + x = cp.concatenate((x, cp.zeros((tmax, x.shape[1])))) + n = x.shape[axis] + xf = cp.fft.rfft(x, axis=axis, n=n) + if xf.shape[axis] > b.shape[0]: + b = cp.pad(b, (0, n - b.shape[0]), mode='constant') + bf = cp.fft.rfft(b, n=n) + bf = bf[:, np.newaxis] + y = cp.fft.irfft(xf * bf, axis=axis) + y = y[y.shape[axis] - xshape[axis]:, :] + assert y.shape == xshape + return y + + +def svdecon(X, nPC0=None): + """ + Input: + X : m x n matrix + + Output: + X = U*S*V' + + Description: + + Does equivalent to svd(X,'econ') but faster + + Vipin Vijayan (2014) + + """ + + m, n = X.shape + + nPC = nPC0 or min(m, n) + + if m <= n: + C = cp.dot(X, X.T) + D, U = cp.linalg.eigh(C, 'U') + + ix = cp.argsort(np.abs(D))[::-1] + d = D[ix] + U = U[:, ix] + d = d[:nPC] + U = U[:, :nPC] + + V = cp.dot(X.T, U) + s = cp.sqrt(d) + V = V / s.T + S = cp.diag(s) + else: + C = cp.dot(X.T, X) + D, V = cp.linalg.eigh(C) + + ix = cp.argsort(cp.abs(D))[::-1] + d = D[ix] + V = V[:, ix] + + # convert evecs from X'*X to X*X'. the evals are the same. + U = cp.dot(X, V) + s = cp.sqrt(d) + U = U / s.T + S = cp.diag(s) + + return U, S, V + + +def svdecon_cpu(X): + U, S, V = np.linalg.svd(cp.asnumpy(X)) + return U, np.diag(S), V + + +def free_gpu_memory(): + mempool = cp.get_default_memory_pool() + pinned_mempool = cp.get_default_pinned_memory_pool() + mempool.free_all_blocks() + pinned_mempool.free_all_blocks() + + +# Work around CuPy bugs and limitations +# ----------------------------------------------------------------------------- + +def mean(x, axis=0): + if x.ndim == 1: + return cp.mean(x) if x.size else cp.nan + else: + s = list(x.shape) + del s[axis] + return ( + cp.mean(x, axis=axis) if x.shape[axis] > 0 + else cp.zeros(s, dtype=x.dtype, order='F')) + + +def median(a, axis=0): + """Compute the median of a CuPy array on the GPU.""" + a = cp.asarray(a) + + if axis is None: + sz = a.size + else: + sz = a.shape[axis] + if sz % 2 == 0: + szh = sz // 2 + kth = [szh - 1, szh] + else: + kth = [(sz - 1) // 2] + + part = cp.partition(a, kth, axis=axis) + + if part.shape == (): + # make 0-D arrays work + return part.item() + if axis is None: + axis = 0 + + indexer = [slice(None)] * part.ndim + index = part.shape[axis] // 2 + if part.shape[axis] % 2 == 1: + # index with slice to allow mean (below) to work + indexer[axis] = slice(index, index + 1) + else: + indexer[axis] = slice(index - 1, index + 1) + + return cp.mean(part[indexer], axis=axis) + + +def var(x): + return cp.var(x, ddof=1) if x.size > 0 else cp.nan + + +def ones(shape, dtype=None, order=None): + # HACK: cp.ones() has no order kwarg at the moment ! + x = cp.zeros(shape, dtype=dtype, order=order) + x.fill(1) + return x + + +def zscore(a, axis=0): + mns = a.mean(axis=axis) + sstd = a.std(axis=axis, ddof=0) + return (a - mns) / sstd diff --git a/src/yass/reordering/cuda/__init__.py b/src/yass/reordering/cuda/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/yass/reordering/cuda/mexClustering2.cu b/src/yass/reordering/cuda/mexClustering2.cu new file mode 100644 index 00000000..0eb78dfd --- /dev/null +++ b/src/yass/reordering/cuda/mexClustering2.cu @@ -0,0 +1,217 @@ +__global__ void computeCost(const double *Params, const float *uproj, const float *mu, const float *W, + const bool *match, const int *iC, const int *call, float *cmax){ + + int NrankPC,j, NchanNear, tid, bid, Nspikes, Nthreads, k, my_chan, this_chan, Nchan; + float xsum = 0.0f, Ci, lam; + + Nspikes = (int) Params[0]; + NrankPC = (int) Params[1]; + Nthreads = blockDim.x; + lam = (float) Params[5]; + NchanNear = (int) Params[6]; + Nchan = (int) Params[7]; + + tid = threadIdx.x; + bid = blockIdx.x; + + while(tid max_running){ + id[tind] = ind; + max_running = cmax[tind + ind*Nspikes]; + } + + + cx[tind] = max_running; + + tind += Nblocks*Nthreads; + } +} +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void average_snips(const double *Params, const int *iC, const int *call, + const int *id, const float *uproj, const float *cmax, float *WU){ + + int my_chan, this_chan, tidx, tidy, bid, ind, Nspikes, NrankPC, NchanNear, Nchan; + float xsum = 0.0f; + + Nspikes = (int) Params[0]; + NrankPC = (int) Params[1]; + Nchan = (int) Params[7]; + NchanNear = (int) Params[6]; + + tidx = threadIdx.x; + tidy = threadIdx.y; + bid = blockIdx.x; + + for(ind=0; ind Cmax){ + Cmax = Cf*Cf /(1+j); + kmax = j + t*Nsum; + } + } + } + datasum[tid0 + NT * i] = Cmax; + kkmax[tid0 + NT * i] = kmax; + } + tid0 += blockDim.x * gridDim.x; + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void Conv1D(const double *Params, const float *data, const float *W, float *conv_sig){ + volatile __shared__ float sW[81*NrankMax], sdata[(Nthreads+81)]; + float y; + int tid, tid0, bid, i, nid, Nrank, NT, nt0, Nchan; + + tid = threadIdx.x; + bid = blockIdx.x; + NT = (int) Params[0]; + Nrank = (int) Params[14]; + nt0 = (int) Params[4]; + Nchan = (int) Params[9]; + + if(tid Cbest + 1e-6){ + Cbest = Cf; + ibest = i; + kbest = kkmax[tid0 + NT*i]; + } + } + err[tid0] = Cbest; + ftype[tid0] = ibest; + kall[tid0] = kbest; + + tid0 += blockDim.x * gridDim.x; + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void cleanup_spikes(const double *Params, const float *err, + const int *ftype, float *x, int *st, int *id, int *counter){ + + int lockout, indx, tid, bid, NT, tid0, j, t0; + volatile __shared__ float sdata[Nthreads+2*81+1]; + bool flag=0; + float err0, Th; + + lockout = (int) Params[4] - 1; + tid = threadIdx.x; + bid = blockIdx.x; + + NT = (int) Params[0]; + tid0 = bid * blockDim.x ; + Th = (float) Params[2]; + + while(tid0 Th*Th && t0err0){ + flag = 1; + break; + } + if(flag==0){ + indx = atomicAdd(&counter[0], 1); + if (indx=0 & t=0 && tid0 Cbest + 1e-6){ + Cnextbest = Cbest; + Cbest = Cf; + ibest = i; + } + else + if (Cf > Cnextbest + 1e-6) + Cnextbest = Cf; + } + err[tid0] = Cbest; + eloss[tid0] = Cbest - Cnextbest; + ftype[tid0] = ibest; + + tid0 += blockDim.x * gridDim.x; + } +} + +// THIS UPDATE DOES NOT UPDATE ELOSS? +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void bestFilterUpdate(const double *Params, const float *data, + const float *mu, float *err, float *eloss, int *ftype, const int *st, const int *id, const int *counter){ + int tid, ind, i,t, NT, Nfilt, ibest = 0, nt0; + float Cf, Cbest, lam, b, a, Cnextbest; + + tid = threadIdx.x; + NT = (int) Params[0]; + Nfilt = (int) Params[1]; + lam = (float) Params[7]; + nt0 = (int) Params[4]; + + + // we only need to compute this at updated locations + ind = counter[1] + blockIdx.x; + + if (ind=0 && t Cbest + 1e-6){ + Cnextbest = Cbest; + Cbest = Cf; + ibest = i; + } + else + if (Cf > Cnextbest + 1e-6) + Cnextbest = Cf; + } + err[t] = Cbest; + ftype[t] = ibest; + } + } +} + +////////////////////////////////////////////////////////////////////////////////////////// +__global__ void cleanup_spikes(const double *Params, const float *data, + const float *mu, const float *err, const float *eloss, const int *ftype, int *st, + int *id, float *x, float *y, float *z, int *counter){ + + int lockout, indx, tid, bid, NT, tid0, j, id0, t0; + volatile __shared__ float sdata[Nthreads+2*81+1]; + bool flag=0; + float err0, Th; + + lockout = (int) Params[4] - 1; + tid = threadIdx.x; + bid = blockIdx.x; + + NT = (int) Params[0]; + tid0 = bid * blockDim.x ; + Th = (float) Params[2]; + //lam = (float) Params[7]; + + while(tid0Th*Th){ + flag = 0; + for(j=-lockout;j<=lockout;j++) + if(sdata[tid+lockout+j]>err0){ + flag = 1; + break; + } + if(flag==0){ + indx = atomicAdd(&counter[0], 1); + if (indxTh){ + if (id[currInd]==bid){ + if (tidx==0 && threadIdx.y==0) + nsp[bid]++; + + tidy = threadIdx.y; + while (tidyThS){ + + tidy = threadIdx.y; + // only do this if the spike is "BAD" + while (tidy xmax){ + xmax = abs(sW[t]); + imax = t; + } + + tid = threadIdx.x; + // shift by imax - tmax + for (k=0;k xmax){ + xmax = abs(sWup[t]); + imax = t; + sgnmax = copysign(1.0f, sWup[t]); + } + + // interpolate by imax + for (k=0;k Cf){ + flag = false; + break; + } + } + + if (flag){ + iChan = iC[NchanNear * i]; + if (Cf>spkTh){ + d = (double) dataraw[tid0+nt0min-1 + NT*iChan]; // + if (d > Cf-1e-6){ + // this is a hit, atomicAdd and return spikes + indx = atomicAdd(&counter[0], 1); + if (indxspkTh) + conv_sig[tid0 + tid + NT*bid] = y; + + tid0+=Nthreads; + __syncthreads(); + } +} diff --git a/src/yass/reordering/cuda/mexWtW2.cu b/src/yass/reordering/cuda/mexWtW2.cu new file mode 100644 index 00000000..ccd32fb6 --- /dev/null +++ b/src/yass/reordering/cuda/mexWtW2.cu @@ -0,0 +1,54 @@ +const int nblock = 32; +////////////////////////////////////////////////////////////////////////////////////////// + +__global__ void crossFilter(const double *Params, const float *W1, const float *W2, + const float *UtU, float *WtW){ + __shared__ float shW1[nblock*81], shW2[nblock*81]; + + float x; + int nt0, tidx, tidy , bidx, bidy, i, Nfilt, t, tid1, tid2; + + tidx = threadIdx.x; + tidy = threadIdx.y; + bidx = blockIdx.x; + bidy = blockIdx.y; + + Nfilt = (int) Params[1]; + nt0 = (int) Params[9]; + + tid1 = tidx + bidx*nblock; + + tid2 = tidy + bidx*nblock; + if (tid2()`.""" + r = re.match("^on_(.+)$", func.__name__) + if r: + event = r.group(1) + else: + raise ValueError("The function name should be " + "`on_`().") + return event + + @contextmanager + def silent(self): + """Prevent all callbacks to be called if events are raised + in the context manager. + """ + self.is_silent = not(self.is_silent) + yield + self.is_silent = not(self.is_silent) + + def connect(self, func=None, event=None, **kwargs): + """Register a callback function to a given event. + + To register a callback function to the `spam` event, where `obj` is + an instance of a class deriving from `EventEmitter`: + + ```python + @obj.connect + def on_spam(arg1, arg2): + pass + ``` + + This is called when `obj.emit('spam', arg1, arg2)` is called. + + Several callback functions can be registered for a given event. + + The registration order is conserved and may matter in applications. + + """ + if func is None: + return partial(self.connect, event=event, **kwargs) + + # Get the event name from the function. + if event is None: + event = self._get_on_name(func) + + # We register the callback function. + self._callbacks.append((event, func, kwargs)) + + return func + + def unconnect(self, *items): + """Unconnect specified callback functions.""" + self._callbacks = [ + (event, f, kwargs) + for (event, f, kwargs) in self._callbacks + if f not in items] + + def emit(self, event, *args, **kwargs): + """Call all callback functions registered with an event. + + Any positional and keyword arguments can be passed here, and they will + be forwarded to the callback functions. + + Return the list of callback return results. + + """ + if self.is_silent: + return + logger.log( + 5, "Emit %s(%s, %s)", event, + ', '.join(map(str, args)), ', '.join('%s=%s' % (k, v) for k, v in kwargs.items())) + # Call the last callback if this is a single event. + single = kwargs.pop('single', None) + res = [] + # Put `last=True` callbacks at the end. + callbacks = [c for c in self._callbacks if not c[-1].get('last', None)] + callbacks += [c for c in self._callbacks if c[-1].get('last', None)] + for e, f, k in callbacks: + if e == event: + f_name = getattr(f, '__qualname__', getattr(f, '__name__', str(f))) + logger.log(5, "Callback %s.", f_name) + res.append(f(*args, **kwargs)) + if single: + return res[-1] + return res + + +#------------------------------------------------------------------------------ +# Global event system +#------------------------------------------------------------------------------ + +_EVENT = EventEmitter() + +emit = _EVENT.emit +connect = _EVENT.connect +unconnect = _EVENT.unconnect +silent = _EVENT.silent +set_silent = _EVENT.set_silent +reset = _EVENT.reset diff --git a/src/yass/reordering/preprocess.py b/src/yass/reordering/preprocess.py index c93af5b6..16180294 100644 --- a/src/yass/reordering/preprocess.py +++ b/src/yass/reordering/preprocess.py @@ -418,8 +418,12 @@ def preprocess(ctx): # apply filters and median subtraction buff = cp.asarray(buff, dtype=np.float32) + print("weeeeeee") + print(buff.shape) datr = gpufilter(buff, chanMap=probe.chanMap, fs=fs, fshigh=fshigh, fslow=fslow) + print("weeeee filtered") + print(datr.shape) datr = datr[ioffset:ioffset + NT, :] # remove timepoints used as buffers datr = cp.dot(datr, Wrot) # whiten the data and scale by 200 for int16 range @@ -431,3 +435,4 @@ def preprocess(ctx): # write this batch to binary file datcpu.tofile(fw) + print(datcpu.shape) \ No newline at end of file diff --git a/src/yass/reordering/reorder.py b/src/yass/reordering/reorder.py new file mode 100644 index 00000000..8365ea24 --- /dev/null +++ b/src/yass/reordering/reorder.py @@ -0,0 +1,63 @@ +import yass.reordering.utils +import yass.reordering.cluster +import yass.reordering.default_params +import yass.reordering +from yass import read_config +import numpy as np +from yass.config import Config +from yass.reordering.preprocess import get_good_channels +import os +import cupy as cp +#initialize object + + +class PARAM: + pass + +class PROBE: + pass + +def run(save_fname, standardized_fname, CONFIG,n_sec_chunk, nPCs = 3, nt0 = 61, reorder = True, dtype = np.float32 ): + + + + params = PARAM() + probe = PROBE() + + params.sigmaMask = 30 + params.Nchan = CONFIG.recordings.n_channels + params.nPCs = nPCs + params.fs = CONFIG.recordings.sampling_rate + + #magic numbers from KS + #params.fshigh = 150. + #params.minfr_goodchannels = 0.1 + params.Th = [10, 4] + + #spkTh is the PCA threshold for detecting a spike + params.spkTh = -6 + params.ThPre = 8 + ## + params.loc_range = [5, 4] + params.long_range = [30, 6] + + probe.chanMap = np.arange(params.Nchan) + probe.xc = CONFIG.geom[:, 0] + probe.yc = CONFIG.geom[:, 1] + probe.kcoords = np.zeros(params.Nchan) + probe.Nchan = params.Nchan + shape = (params.Nchan, CONFIG.rec_len) + standardized_mmemap = np.memmap(standardized_fname, order = "F", dtype = dtype) + params.Nbatch = np.ceil(CONFIG.rec_len/(n_sec_chunk*CONFIG.recordings.sampling_rate)).astype(np.int16) + params.reorder = reorder + params.nt0min = np.ceil(20 * nt0 / 61).astype(np.int16) + + + result = yass.reordering.cluster.clusterSingleBatches(proc = standardized_mmemap, + params = params, + probe = probe, + yass_batch = params.Nbatch, + n_chunk_sec = int(n_sec_chunk*CONFIG.recordings.sampling_rate), + nt0 = nt0) + np.save(save_fname, cp.asnumpy(result['iorig'])) + diff --git a/src/yass/reordering/run.py b/src/yass/reordering/run.py new file mode 100644 index 00000000..266c5a38 --- /dev/null +++ b/src/yass/reordering/run.py @@ -0,0 +1,62 @@ +import yass.reordering.utils +import yass.reordering.cluster +import yass.reordering.default_params +import yass.reordering +from yass import read_config +import numpy as np +from yass.config import Config +from yass.reordering.preprocess import get_good_channels +import os + +#initialize object + + +class PARAM: + pass + +class PROBE: + pass + +def run(save_path, standardized_fname, CONFIG,n_sec_chunk, nPCs = 3, nt0 = 61, reorder = True, dtype = np.float32 ): + + + + params = PARAM() + probe = PROBE() + + params.sigmaMask = 30 + params.Nchan = CONFIG.recordings.n_channels + params.nPCs = nPCs + params.fs = CONFIG.recordings.sampling_rate + + #magic numbers from KS + #params.fshigh = 150. + #params.minfr_goodchannels = 0.1 + params.Th = [10, 4] + + #spkTh is the PCA threshold for detecting a spike + params.spkTh = -6 + params.ThPre = 8 + ## + params.loc_range = [5, 4] + params.long_range = [30, 6] + + probe.chanMap = np.arange(params.Nchan) + probe.xc = CONFIG.geom[:, 0] + probe.yc = CONFIG.geom[:, 1] + probe.kcoords = np.zeros(params.Nchan) + probe.Nchan = params.Nchan + shape = (params.Nchan, CONFIG.rec_len) + standardized_mmemap = np.memmap(standardized_fname, order = "F", dtype = dtype) + params.Nbatch = np.ceil(CONFIG.rec_len/(n_sec_chunk*CONFIG.recordings.sampling_rate)).astype(np.int16) + params.reorder = reorder + params.nt0min = np.ceil(20 * nt0 / 61).astype(np.int16) + + + result = yass.reordering.cluster.clusterSingleBatches(proc = standardized_mmemap, + params = params, + probe = probe, + yass_batch = params.Nbatch, + n_chunk_sec = int(n_sec_chunk*CONFIG.recordings.sampling_rate), + nt0 = nt0) + diff --git a/src/yass/soft_assignment/template_BACKUP_31194.py b/src/yass/soft_assignment/template_BACKUP_31194.py new file mode 100644 index 00000000..62948032 --- /dev/null +++ b/src/yass/soft_assignment/template_BACKUP_31194.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + # batch offsets + offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] + - self.reader_residual.buffer).cuda().long() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + + + +<<<<<<< HEAD + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) +======= + spike_train_batch = self.spike_train[idx_in] + spike_train_batch[:, 0] -= offsets[batch_id] + +>>>>>>> 3076e420ae3d8ecd83625301c6c3e3ed46f7cca7 + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + self.log_probs = log_probs + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file diff --git a/src/yass/soft_assignment/template_BACKUP_31287.py b/src/yass/soft_assignment/template_BACKUP_31287.py new file mode 100644 index 00000000..62948032 --- /dev/null +++ b/src/yass/soft_assignment/template_BACKUP_31287.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + # batch offsets + offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] + - self.reader_residual.buffer).cuda().long() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + + + +<<<<<<< HEAD + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) +======= + spike_train_batch = self.spike_train[idx_in] + spike_train_batch[:, 0] -= offsets[batch_id] + +>>>>>>> 3076e420ae3d8ecd83625301c6c3e3ed46f7cca7 + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + self.log_probs = log_probs + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file diff --git a/src/yass/soft_assignment/template_BASE_31194.py b/src/yass/soft_assignment/template_BASE_31194.py new file mode 100644 index 00000000..0955c193 --- /dev/null +++ b/src/yass/soft_assignment/template_BASE_31194.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file diff --git a/src/yass/soft_assignment/template_BASE_31287.py b/src/yass/soft_assignment/template_BASE_31287.py new file mode 100644 index 00000000..0955c193 --- /dev/null +++ b/src/yass/soft_assignment/template_BASE_31287.py @@ -0,0 +1,412 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file diff --git a/src/yass/soft_assignment/template_LOCAL_31194.py b/src/yass/soft_assignment/template_LOCAL_31194.py new file mode 100644 index 00000000..c3d631e6 --- /dev/null +++ b/src/yass/soft_assignment/template_LOCAL_31194.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + + + + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + self.log_probs = log_probs + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file diff --git a/src/yass/soft_assignment/template_LOCAL_31287.py b/src/yass/soft_assignment/template_LOCAL_31287.py new file mode 100644 index 00000000..c3d631e6 --- /dev/null +++ b/src/yass/soft_assignment/template_LOCAL_31287.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + + + + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + self.log_probs = log_probs + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file diff --git a/src/yass/soft_assignment/template_REMOTE_31194.py b/src/yass/soft_assignment/template_REMOTE_31194.py new file mode 100644 index 00000000..20b92cef --- /dev/null +++ b/src/yass/soft_assignment/template_REMOTE_31194.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + # batch offsets + offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] + - self.reader_residual.buffer).cuda().long() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + spike_train_batch[:, 0] -= offsets[batch_id] + + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file diff --git a/src/yass/soft_assignment/template_REMOTE_31287.py b/src/yass/soft_assignment/template_REMOTE_31287.py new file mode 100644 index 00000000..20b92cef --- /dev/null +++ b/src/yass/soft_assignment/template_REMOTE_31287.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +Created on Mon Nov 18 21:39:48 2019 + +@author: kevin +""" + +import numpy as np +from tqdm import tqdm +import scipy.spatial.distance as dist +import torch +import cudaSpline as deconv +from scipy.interpolate import splrep +from numpy.linalg import inv as inv + +def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): + if knots is None: + knots = np.arange(len(curve) + prepad + postpad) + return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) + +def transform_template(template, knots=None, prepad=7, postpad=3, order=3): + + if knots is None: + knots = np.arange(len(template.data[0]) + prepad + postpad) + splines = [ + fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) + for curve in template.data.cpu().numpy() + ] + coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') + return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) + +def get_cov_matrix(spat_cov, geom): + posistion = geom + dist_matrix = dist.squareform(dist.pdist(geom )) + + cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) + + for i in range(posistion.shape[0]): + for j in range(posistion.shape[0]): + if dist_matrix[i, j] > np.max(spat_cov[:, 1]): + cov_matrix[i, j] = 0 + continue + idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] + if len(idx) == 0: + cov_matrix[i, j] = 0 + continue + cov_matrix[i, j] = spat_cov[idx, 0] + return cov_matrix + +#Soft assign object + +class TEMPLATE_ASSIGN_OBJECT(object): + def __init__(self, fname_spike_train, fname_templates, fname_shifts, + reader_residual, spat_cov, temp_cov, channel_idx, geom, + large_unit_threshold = 5, n_chans = 5, rec_chans = 512, + sim_units = 3, temp_thresh= np.inf, lik_window = 50): + + #get the variance of the residual: + self.temp_thresh = temp_thresh + self.rec_chans = rec_chans + self.sim_units = sim_units + self.templates = np.load(fname_templates).astype('float32') + self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) + self.spike_train = np.load(fname_spike_train) + self.spike_train_og = np.load(fname_spike_train) + #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] + self.idx_included = set([]) + self.units_in = set([]) + self.shifts = np.load(fname_shifts) + self.reader_residual = reader_residual + self.spat_cov = get_cov_matrix(spat_cov, geom) + self.temp_cov = temp_cov[:lik_window, :lik_window] + self.channel_index = channel_idx + self.n_neigh_chans = self.channel_index.shape[1] + self.n_chans = n_chans + self.n_units, self.n_times, self.n_channels = self.templates.shape + + self.n_total_spikes = self.spike_train.shape[0] + + #get residual variance + self.get_residual_variance() + + self.get_similar() + self.exclude_large_units(large_unit_threshold) + #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] + self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) + self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) + self.spike_train = self.spike_train_og[self.idx_included] + self.shifts = self.shifts[self.idx_included] + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + + self.aligned_template_list = [] + self.coeff_list = [] + self.preprocess_templates_and_spike_times() + self.chan_list = [] + + for i in range(0, self.sim_units): + diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) + for j in range(0, self.n_units): + diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) + + self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) + self.aligned_template_list.append(diff_array) + + #align orignal templates at end + self.aligned_template_list.append(self.templates_aligned) + self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) + + #get aligned templates + + + self.move_to_torch() + self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) + self.chans = torch.from_numpy(self.chans) + self.get_kronecker() + def get_residual_variance(self): + num = int(60/self.reader_residual.n_sec_chunk) + var_array = np.zeros(num) + for batch_id in range(num): + var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) + self.resid_var = np.mean(var_array) + + def get_kronecker(self): + self.cov_list = [] + inv_temp = inv(self.temp_cov) + for unit in range(self.n_units): + chans = self.chans[unit] + chans = chans[chans < self.rec_chans] + indices = np.ix_(chans,chans) + covar = np.kron(inv(self.spat_cov[indices]), inv_temp) + self.cov_list.append(covar) + self.cov_list = np.asarray(self.cov_list) + self.cov_list = torch.from_numpy(self.cov_list).half().cuda() + + def get_similar(self): + max_time = self.templates.argmax(1).astype("int16") + padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) + reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) + for unit in range(self.n_units): + for chan in range(self.rec_chans): + reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] + + see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + for i in range(self.n_units): + sorted_see = np.sort(see[i])[0:self.sim_units] + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(reduced[i]))) + #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + # self.units_in.add(i) + if sorted_see[1]/norm < self.temp_thresh: + self.units_in.add(i) + + ''' + def get_similar(self): + see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) + self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") + units_in = [] + for i in range(self.n_units): + self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] + norm = np.sqrt(np.sum(np.square(self.templates[i]))) + if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: + self.units_in.add(i) + #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) + + #units_in= np.asarray(units_in) + #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(in_spikes) + ''' + #shift secondary template + def shift_template(self, template, shift): + if shift == 0: + return template + if shift > 0: + return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] + else: + return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] + + def preprocess_templates_and_spike_times(self): + + # templates on neighboring channels + self.mcs = self.templates.ptp(1).argmax(1) + + #template used for alignment defined on neighboring channels + templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) + + #template returned for likilihood calculation- defined on channels with largest ptp + return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) + for k in range(self.n_units): + neigh_chans = self.channel_index[self.mcs[k]] + neigh_chans = neigh_chans[neigh_chans 0: + buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) + return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) + + # get ailgned templates + t_in = np.arange(buffer_size, buffer_size + self.n_times) + + + templates_aligned = np.zeros((self.n_units, + self.n_times, + self.n_chans), 'float32') + for k in range(self.n_units): + t_in_temp = t_in + self.temp_shifts[k] + templates_aligned[k] = return_templates[k,t_in_temp] + + self.templates_aligned = templates_aligned + self.templates_aligned_numpy= templates_aligned + # shift spike times according to the alignment + self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] + + #shifted neighboring template according to shift in primary template + def subtract_template(self, primary_unit, neighbor_unit): + primary_unit_shift = self.temp_shifts[primary_unit] + shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) + return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] + + def exclude_large_units(self, threshold): + + norms = np.zeros(self.n_units) + for j in range(self.n_units): + temp = self.templates[j] + vis_chan = np.where(temp.ptp(0) > 1)[0] + norms[j] = np.sum(np.square(temp[:, vis_chan])) + + + self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) + + #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] + #self.idx_included.update(self.idx_included.intersection(idx_in)) + def move_to_torch(self): + self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] + self.spike_train = torch.from_numpy(self.spike_train).long().cuda() + self.shifts = torch.from_numpy(self.shifts).float().cuda() + + self.mcs = torch.from_numpy(self.mcs) + + def get_bspline_coeffs(self, template_aligned): + + n_data, n_times, n_channels = template_aligned.shape + + channels = torch.arange(n_channels).cuda() + temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() + + temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) + coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) + return coeffs + def get_shifted_templates(self, temp_ids, shifts, iteration): + + temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() + shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() + + n_sample_run = 1000 + n_times = self.aligned_template_list[iteration].shape[1] + + idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) + + shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() + for j in range(len(idx_run)-1): + ii_start = idx_run[j] + ii_end =idx_run[j+1] + obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() + times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() + deconv.subtract_splines(obj, + times, + shifts[ii_start:ii_end], + temp_ids[ii_start:ii_end], + self.coeff_list[iteration], + torch.full( (ii_end - ii_start, ), 2 ).cuda()) + obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) + shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) + + return shifted_templates + + def get_liklihood(self, unit, snip): + chans = self.chans[unit] + chans = chans < self.rec_chans + log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) + return log_prob + + def compute_soft_assignment(self): + + log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() + + # batch offsets + offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] + - self.reader_residual.buffer).cuda().long() + + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] + + spike_train_batch = self.spike_train[idx_in] + spike_train_batch[:, 0] -= offsets[batch_id] + + shift_batch = self.shifts[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] + shifted_templates = [element for element in shifted_templates] + + + clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] + unit_idx = spike_train_batch[:, 1] + logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() + for i in range(len(clean_wfs)): + cov_array = self.cov_list[unit_idx] + + restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] + unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) + temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) + result = torch.matmul( temp, unraveled[:, :, None].half()) + logs[:, i] = result.reshape(-1) + ''' + log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) + logs[:, i] = np.asarray(log_vec) + ''' + ''' + for j, spike in enumerate(clean_wfs[i]): + rel_unit = self.similar_array[unit_idx[j]][i] + logs[j, i]= self.get_liklihood(rel_unit, spike) + ''' + log_probs[idx_in] = logs + + pbar.update() + + + + self.log_probs = log_probs.cpu().numpy() + return log_probs.cpu().numpy() + def clean_wave_forms(self, spike_idx, unit): + return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) + with tqdm(total=self.reader_residual.n_batches) as pbar: + for batch_id in range(self.reader_residual.n_batches): + + # load residual data + resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) + resid_dat = torch.from_numpy(resid_dat).cuda() + + # relevant idx + s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] + s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] + + idx_in = torch.nonzero((s1 & s2))[:,0] + + spike_train_batch = self.spike_train[spike_idx] + spike_train_batch = spike_train_batch[idx_in] + spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) + + shift_batch = self.shifts[spike_idx] + shift_batch = shift_batch[idx_in] + # get residual snippets + + t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() + c_index = self.chans[spike_train_batch[:, 1]].long() + resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) + resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] + # get shifted templates + + shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) + clean_wfs = resid_snippets + shifted_og + return_wfs[idx_in] = clean_wfs.cpu() + + return return_wfs.cpu().numpy() + + def get_assign_probs(self, log_lik_array): + fix = log_lik_array*-.5 + fix = fix - fix.max(1)[:, None] + probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] + self.probs = probs + return probs + + def run(self): + + #construct array to identify soft assignment units + unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + for unit in range(self.n_units): + row_idx= self.spike_train_og[:, 1] == unit + unit_assignment[row_idx, :] = self.similar_array[unit, :] + + + log_probs = self.compute_soft_assignment() + probs = self.get_assign_probs(log_probs) + replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) + replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) + + replace_probs[:, 0] = 1 + replace_probs[self.idx_included, :] = probs + + replace_log[self.idx_included, :] = log_probs + + return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 00ba3e3eb1f992cb15a5a50a5e38d83657f3cdf0 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:27:33 -0500 Subject: [PATCH 03/19] Update util.py --- src/yass/preprocess/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/yass/preprocess/util.py b/src/yass/preprocess/util.py index 77d84a9d..3e657858 100644 --- a/src/yass/preprocess/util.py +++ b/src/yass/preprocess/util.py @@ -106,8 +106,8 @@ def _standardize(rec, sd=None, centers=None): rec[:,idx1] = np.divide(rec[:,idx1] - centers[idx1][None], sd[idx1]) # zero out bad channels - #idx2 = np.where(sd<0.1)[0] - #rec[:,idx2]=0. + idx2 = np.where(sd<0.1)[0] + rec[:,idx2]=0. return rec #return np.divide(rec, sd) From 25b698e0f1741a2669d5bb33cf894b17b503e4f9 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:28:33 -0500 Subject: [PATCH 04/19] Delete run.py --- src/yass/reordering/run.py | 62 -------------------------------------- 1 file changed, 62 deletions(-) delete mode 100644 src/yass/reordering/run.py diff --git a/src/yass/reordering/run.py b/src/yass/reordering/run.py deleted file mode 100644 index 266c5a38..00000000 --- a/src/yass/reordering/run.py +++ /dev/null @@ -1,62 +0,0 @@ -import yass.reordering.utils -import yass.reordering.cluster -import yass.reordering.default_params -import yass.reordering -from yass import read_config -import numpy as np -from yass.config import Config -from yass.reordering.preprocess import get_good_channels -import os - -#initialize object - - -class PARAM: - pass - -class PROBE: - pass - -def run(save_path, standardized_fname, CONFIG,n_sec_chunk, nPCs = 3, nt0 = 61, reorder = True, dtype = np.float32 ): - - - - params = PARAM() - probe = PROBE() - - params.sigmaMask = 30 - params.Nchan = CONFIG.recordings.n_channels - params.nPCs = nPCs - params.fs = CONFIG.recordings.sampling_rate - - #magic numbers from KS - #params.fshigh = 150. - #params.minfr_goodchannels = 0.1 - params.Th = [10, 4] - - #spkTh is the PCA threshold for detecting a spike - params.spkTh = -6 - params.ThPre = 8 - ## - params.loc_range = [5, 4] - params.long_range = [30, 6] - - probe.chanMap = np.arange(params.Nchan) - probe.xc = CONFIG.geom[:, 0] - probe.yc = CONFIG.geom[:, 1] - probe.kcoords = np.zeros(params.Nchan) - probe.Nchan = params.Nchan - shape = (params.Nchan, CONFIG.rec_len) - standardized_mmemap = np.memmap(standardized_fname, order = "F", dtype = dtype) - params.Nbatch = np.ceil(CONFIG.rec_len/(n_sec_chunk*CONFIG.recordings.sampling_rate)).astype(np.int16) - params.reorder = reorder - params.nt0min = np.ceil(20 * nt0 / 61).astype(np.int16) - - - result = yass.reordering.cluster.clusterSingleBatches(proc = standardized_mmemap, - params = params, - probe = probe, - yass_batch = params.Nbatch, - n_chunk_sec = int(n_sec_chunk*CONFIG.recordings.sampling_rate), - nt0 = nt0) - From bc007d981c67a39d19216c52a6d55b0b0da089d3 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:28:55 -0500 Subject: [PATCH 05/19] Delete template_BACKUP_31194.py --- .../soft_assignment/template_BACKUP_31194.py | 424 ------------------ 1 file changed, 424 deletions(-) delete mode 100644 src/yass/soft_assignment/template_BACKUP_31194.py diff --git a/src/yass/soft_assignment/template_BACKUP_31194.py b/src/yass/soft_assignment/template_BACKUP_31194.py deleted file mode 100644 index 62948032..00000000 --- a/src/yass/soft_assignment/template_BACKUP_31194.py +++ /dev/null @@ -1,424 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - # batch offsets - offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] - - self.reader_residual.buffer).cuda().long() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - - - -<<<<<<< HEAD - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) -======= - spike_train_batch = self.spike_train[idx_in] - spike_train_batch[:, 0] -= offsets[batch_id] - ->>>>>>> 3076e420ae3d8ecd83625301c6c3e3ed46f7cca7 - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - self.log_probs = log_probs - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From ffbf548ab3921dc182ecea995c6d430a2ace05f2 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:29:12 -0500 Subject: [PATCH 06/19] Delete template_REMOTE_31287.py --- .../soft_assignment/template_REMOTE_31287.py | 416 ------------------ 1 file changed, 416 deletions(-) delete mode 100644 src/yass/soft_assignment/template_REMOTE_31287.py diff --git a/src/yass/soft_assignment/template_REMOTE_31287.py b/src/yass/soft_assignment/template_REMOTE_31287.py deleted file mode 100644 index 20b92cef..00000000 --- a/src/yass/soft_assignment/template_REMOTE_31287.py +++ /dev/null @@ -1,416 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - # batch offsets - offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] - - self.reader_residual.buffer).cuda().long() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - spike_train_batch[:, 0] -= offsets[batch_id] - - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 7757d8341544a2e1404fdbdd4749ac92ad2ce2ee Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:29:28 -0500 Subject: [PATCH 07/19] Delete template_BACKUP_31287.py --- .../soft_assignment/template_BACKUP_31287.py | 424 ------------------ 1 file changed, 424 deletions(-) delete mode 100644 src/yass/soft_assignment/template_BACKUP_31287.py diff --git a/src/yass/soft_assignment/template_BACKUP_31287.py b/src/yass/soft_assignment/template_BACKUP_31287.py deleted file mode 100644 index 62948032..00000000 --- a/src/yass/soft_assignment/template_BACKUP_31287.py +++ /dev/null @@ -1,424 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - # batch offsets - offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] - - self.reader_residual.buffer).cuda().long() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - - - -<<<<<<< HEAD - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) -======= - spike_train_batch = self.spike_train[idx_in] - spike_train_batch[:, 0] -= offsets[batch_id] - ->>>>>>> 3076e420ae3d8ecd83625301c6c3e3ed46f7cca7 - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - self.log_probs = log_probs - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 15fa6e94859fbd0b412e9292598ca56aa872fe35 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:31:55 -0500 Subject: [PATCH 08/19] Delete template_BASE_31194.py --- .../soft_assignment/template_BASE_31194.py | 412 ------------------ 1 file changed, 412 deletions(-) delete mode 100644 src/yass/soft_assignment/template_BASE_31194.py diff --git a/src/yass/soft_assignment/template_BASE_31194.py b/src/yass/soft_assignment/template_BASE_31194.py deleted file mode 100644 index 0955c193..00000000 --- a/src/yass/soft_assignment/template_BASE_31194.py +++ /dev/null @@ -1,412 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 69806f9f1544212e8c89196588ac639f0c9d4163 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:32:11 -0500 Subject: [PATCH 09/19] Delete template_BASE_31287.py --- .../soft_assignment/template_BASE_31287.py | 412 ------------------ 1 file changed, 412 deletions(-) delete mode 100644 src/yass/soft_assignment/template_BASE_31287.py diff --git a/src/yass/soft_assignment/template_BASE_31287.py b/src/yass/soft_assignment/template_BASE_31287.py deleted file mode 100644 index 0955c193..00000000 --- a/src/yass/soft_assignment/template_BASE_31287.py +++ /dev/null @@ -1,412 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 516f1f4f81433bc86e25d7204bdcdbbaabcc7ed8 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:32:26 -0500 Subject: [PATCH 10/19] Delete template_LOCAL_31194.py --- .../soft_assignment/template_LOCAL_31194.py | 415 ------------------ 1 file changed, 415 deletions(-) delete mode 100644 src/yass/soft_assignment/template_LOCAL_31194.py diff --git a/src/yass/soft_assignment/template_LOCAL_31194.py b/src/yass/soft_assignment/template_LOCAL_31194.py deleted file mode 100644 index c3d631e6..00000000 --- a/src/yass/soft_assignment/template_LOCAL_31194.py +++ /dev/null @@ -1,415 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - - - - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - self.log_probs = log_probs - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 2b18d5ef23b9d9fa3094fb388f8801f90c3994a2 Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:32:38 -0500 Subject: [PATCH 11/19] Delete template_LOCAL_31287.py --- .../soft_assignment/template_LOCAL_31287.py | 415 ------------------ 1 file changed, 415 deletions(-) delete mode 100644 src/yass/soft_assignment/template_LOCAL_31287.py diff --git a/src/yass/soft_assignment/template_LOCAL_31287.py b/src/yass/soft_assignment/template_LOCAL_31287.py deleted file mode 100644 index c3d631e6..00000000 --- a/src/yass/soft_assignment/template_LOCAL_31287.py +++ /dev/null @@ -1,415 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(np.logical_and(self.spike_train_og[:, 0] < reader_residual.rec_len - self.n_times//2, self.spike_train_og[:, 0] > self.n_times//2), np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - - - - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0],self.templates.shape[1], self.n_chans)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - self.log_probs = log_probs - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 9da88872124078077ba6c417c43bd6d8b7e1c25a Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:34:13 -0500 Subject: [PATCH 12/19] Delete template_REMOTE_31194.py --- .../soft_assignment/template_REMOTE_31194.py | 416 ------------------ 1 file changed, 416 deletions(-) delete mode 100644 src/yass/soft_assignment/template_REMOTE_31194.py diff --git a/src/yass/soft_assignment/template_REMOTE_31194.py b/src/yass/soft_assignment/template_REMOTE_31194.py deleted file mode 100644 index 20b92cef..00000000 --- a/src/yass/soft_assignment/template_REMOTE_31194.py +++ /dev/null @@ -1,416 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -Created on Mon Nov 18 21:39:48 2019 - -@author: kevin -""" - -import numpy as np -from tqdm import tqdm -import scipy.spatial.distance as dist -import torch -import cudaSpline as deconv -from scipy.interpolate import splrep -from numpy.linalg import inv as inv - -def fit_spline(curve, knots=None, prepad=0, postpad=0, order=3): - if knots is None: - knots = np.arange(len(curve) + prepad + postpad) - return splrep(knots, np.pad(curve, (prepad, postpad), mode='symmetric'), k=order) - -def transform_template(template, knots=None, prepad=7, postpad=3, order=3): - - if knots is None: - knots = np.arange(len(template.data[0]) + prepad + postpad) - splines = [ - fit_spline(curve, knots=knots, prepad=prepad, postpad=postpad, order=order) - for curve in template.data.cpu().numpy() - ] - coefficients = np.array([spline[1][prepad-1:-1*(postpad+1)] for spline in splines], dtype='float32') - return deconv.Template(torch.from_numpy(coefficients).cuda(), template.indices) - -def get_cov_matrix(spat_cov, geom): - posistion = geom - dist_matrix = dist.squareform(dist.pdist(geom )) - - cov_matrix = np.zeros((posistion.shape[0], posistion.shape[0])) - - for i in range(posistion.shape[0]): - for j in range(posistion.shape[0]): - if dist_matrix[i, j] > np.max(spat_cov[:, 1]): - cov_matrix[i, j] = 0 - continue - idx = np.where(spat_cov[:, 1] == dist_matrix[i, j])[0] - if len(idx) == 0: - cov_matrix[i, j] = 0 - continue - cov_matrix[i, j] = spat_cov[idx, 0] - return cov_matrix - -#Soft assign object - -class TEMPLATE_ASSIGN_OBJECT(object): - def __init__(self, fname_spike_train, fname_templates, fname_shifts, - reader_residual, spat_cov, temp_cov, channel_idx, geom, - large_unit_threshold = 5, n_chans = 5, rec_chans = 512, - sim_units = 3, temp_thresh= np.inf, lik_window = 50): - - #get the variance of the residual: - self.temp_thresh = temp_thresh - self.rec_chans = rec_chans - self.sim_units = sim_units - self.templates = np.load(fname_templates).astype('float32') - self.offset = int((self.templates.shape[1] - (2*(lik_window//2) +1) )/2) - self.spike_train = np.load(fname_spike_train) - self.spike_train_og = np.load(fname_spike_train) - #self.spike_train = self.spike_train[self.spike_train[:, 0] > 40] - self.idx_included = set([]) - self.units_in = set([]) - self.shifts = np.load(fname_shifts) - self.reader_residual = reader_residual - self.spat_cov = get_cov_matrix(spat_cov, geom) - self.temp_cov = temp_cov[:lik_window, :lik_window] - self.channel_index = channel_idx - self.n_neigh_chans = self.channel_index.shape[1] - self.n_chans = n_chans - self.n_units, self.n_times, self.n_channels = self.templates.shape - - self.n_total_spikes = self.spike_train.shape[0] - - #get residual variance - self.get_residual_variance() - - self.get_similar() - self.exclude_large_units(large_unit_threshold) - #self.spike_train = self.spike_train[np.asarray(list(self.idx_included)).astype("int16"), :] - self.test = np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in))) - self.idx_included = np.logical_and(self.spike_train_og[:, 0] > 40, np.in1d(self.spike_train_og[:, 1], np.asarray(list(self.units_in)))) - self.spike_train = self.spike_train_og[self.idx_included] - self.shifts = self.shifts[self.idx_included] - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - - self.aligned_template_list = [] - self.coeff_list = [] - self.preprocess_templates_and_spike_times() - self.chan_list = [] - - for i in range(0, self.sim_units): - diff_array = np.zeros((self.n_units, self.n_times, self.n_chans)) - for j in range(0, self.n_units): - diff_array[j] = self.subtract_template(j, self.similar_array[j, i]) - - self.coeff_list.append(self.get_bspline_coeffs(diff_array.astype("float32"))) - self.aligned_template_list.append(diff_array) - - #align orignal templates at end - self.aligned_template_list.append(self.templates_aligned) - self.coeff_list.append(self.get_bspline_coeffs(self.templates_aligned)) - - #get aligned templates - - - self.move_to_torch() - self.chans = np.asarray([np.argsort(self.templates[unit].ptp(0))[::-1][:self.n_chans] for unit in range(self.n_units)]) - self.chans = torch.from_numpy(self.chans) - self.get_kronecker() - def get_residual_variance(self): - num = int(60/self.reader_residual.n_sec_chunk) - var_array = np.zeros(num) - for batch_id in range(num): - var_array[batch_id] = np.var(self.reader_residual.read_data_batch(batch_id, add_buffer=True)) - self.resid_var = np.mean(var_array) - - def get_kronecker(self): - self.cov_list = [] - inv_temp = inv(self.temp_cov) - for unit in range(self.n_units): - chans = self.chans[unit] - chans = chans[chans < self.rec_chans] - indices = np.ix_(chans,chans) - covar = np.kron(inv(self.spat_cov[indices]), inv_temp) - self.cov_list.append(covar) - self.cov_list = np.asarray(self.cov_list) - self.cov_list = torch.from_numpy(self.cov_list).half().cuda() - - def get_similar(self): - max_time = self.templates.argmax(1).astype("int16") - padded_templates = np.concatenate((np.zeros((self.n_units, 5, self.rec_chans)), self.templates, np.zeros((self.n_units, 5, self.rec_chans))), axis = 1) - reduced = np.zeros((self.templates.shape[0], 5, self.rec_chans)) - for unit in range(self.n_units): - for chan in range(self.rec_chans): - reduced[unit, :, chan] = padded_templates[unit, (5 +max_time[unit, chan] - 2):(5 + max_time[unit, chan] + 3), chan] - - see = dist.squareform(dist.pdist(reduced.reshape(self.n_units, self.n_channels*5))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - for i in range(self.n_units): - sorted_see = np.sort(see[i])[0:self.sim_units] - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(reduced[i]))) - #if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - # self.units_in.add(i) - if sorted_see[1]/norm < self.temp_thresh: - self.units_in.add(i) - - ''' - def get_similar(self): - see = dist.squareform(dist.pdist(self.templates.reshape(self.n_units, self.n_channels*self.n_times))) - self.similar_array = np.zeros((self.n_units, self.sim_units)).astype("int16") - units_in = [] - for i in range(self.n_units): - self.similar_array[i] = np.argsort(see[i])[0:self.sim_units] - norm = np.sqrt(np.sum(np.square(self.templates[i]))) - if np.min(self.similar_array[i][1:])/norm < self.temp_thresh: - self.units_in.add(i) - #self.idx_included.update(np.where(self.spike_train_og[:, 1] == i)[0]) - - #units_in= np.asarray(units_in) - #in_spikes = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(in_spikes) - ''' - #shift secondary template - def shift_template(self, template, shift): - if shift == 0: - return template - if shift > 0: - return np.concatenate((template, np.zeros((shift, template.shape[1]))), axis = 0)[shift:, :] - else: - return np.concatenate((np.zeros((-shift, template.shape[1])), template), axis = 0)[:(self.n_times), :] - - def preprocess_templates_and_spike_times(self): - - # templates on neighboring channels - self.mcs = self.templates.ptp(1).argmax(1) - - #template used for alignment defined on neighboring channels - templates_neigh = np.zeros((self.n_units, self.n_times, self.n_neigh_chans)) - - #template returned for likilihood calculation- defined on channels with largest ptp - return_templates = np.zeros((self.n_units, self.n_times, self.n_chans)) - for k in range(self.n_units): - neigh_chans = self.channel_index[self.mcs[k]] - neigh_chans = neigh_chans[neigh_chans 0: - buffer = np.zeros((self.n_units, buffer_size, self.n_chans)) - return_templates = np.concatenate((buffer, return_templates, buffer), axis=1) - - # get ailgned templates - t_in = np.arange(buffer_size, buffer_size + self.n_times) - - - templates_aligned = np.zeros((self.n_units, - self.n_times, - self.n_chans), 'float32') - for k in range(self.n_units): - t_in_temp = t_in + self.temp_shifts[k] - templates_aligned[k] = return_templates[k,t_in_temp] - - self.templates_aligned = templates_aligned - self.templates_aligned_numpy= templates_aligned - # shift spike times according to the alignment - self.spike_train[:, 0] += self.temp_shifts[self.spike_train[:, 1]] - - #shifted neighboring template according to shift in primary template - def subtract_template(self, primary_unit, neighbor_unit): - primary_unit_shift = self.temp_shifts[primary_unit] - shifted = self.shift_template(self.templates[neighbor_unit], primary_unit_shift) - return self.templates_aligned[primary_unit] - shifted[:, self.chans[primary_unit]] - - def exclude_large_units(self, threshold): - - norms = np.zeros(self.n_units) - for j in range(self.n_units): - temp = self.templates[j] - vis_chan = np.where(temp.ptp(0) > 1)[0] - norms[j] = np.sum(np.square(temp[:, vis_chan])) - - - self.units_in = self.units_in.intersection(np.where(self.templates.ptp(1).max(1) < threshold)[0]) - - #idx_in = np.where(np.in1d(self.spike_train[:,1], units_in))[0] - #self.idx_included.update(self.idx_included.intersection(idx_in)) - def move_to_torch(self): - self.templates_aligned = [torch.from_numpy(element).float().cuda() for element in self.aligned_template_list] - self.spike_train = torch.from_numpy(self.spike_train).long().cuda() - self.shifts = torch.from_numpy(self.shifts).float().cuda() - - self.mcs = torch.from_numpy(self.mcs) - - def get_bspline_coeffs(self, template_aligned): - - n_data, n_times, n_channels = template_aligned.shape - - channels = torch.arange(n_channels).cuda() - temps_torch = torch.from_numpy(-(template_aligned.transpose(0, 2, 1))/2).cuda() - - temp_cpp = deconv.BatchedTemplates([deconv.Template(temp, channels) for temp in temps_torch]) - coeffs = deconv.BatchedTemplates([transform_template(template) for template in temp_cpp]) - return coeffs - def get_shifted_templates(self, temp_ids, shifts, iteration): - - temp_ids = torch.from_numpy(temp_ids.cpu().numpy()).long().cuda() - shifts = torch.from_numpy(shifts.cpu().numpy()).float().cuda() - - n_sample_run = 1000 - n_times = self.aligned_template_list[iteration].shape[1] - - idx_run = np.hstack((np.arange(0, len(shifts), n_sample_run), len(shifts))) - - shifted_templates = torch.zeros((len(shifts), n_times, self.n_chans)).cuda() - for j in range(len(idx_run)-1): - ii_start = idx_run[j] - ii_end =idx_run[j+1] - obj = torch.zeros(self.n_chans, (ii_end-ii_start)*n_times).cuda() - times = torch.arange(0, (ii_end-ii_start)*n_times, n_times).long().cuda() - deconv.subtract_splines(obj, - times, - shifts[ii_start:ii_end], - temp_ids[ii_start:ii_end], - self.coeff_list[iteration], - torch.full( (ii_end - ii_start, ), 2 ).cuda()) - obj = obj.reshape((self.n_chans, (ii_end-ii_start), n_times)) - shifted_templates[ii_start:ii_end] = obj.transpose(0,1).transpose(1,2) - - return shifted_templates - - def get_liklihood(self, unit, snip): - chans = self.chans[unit] - chans = chans < self.rec_chans - log_prob = np.ravel(snip[:, chans].T) @ self.cov_list[unit] @ np.ravel(snip[:, chans].T) - return log_prob - - def compute_soft_assignment(self): - - log_probs = torch.zeros((len(self.spike_train), self.sim_units)).half()#torch.zeros((len(self.spike_train),3)).cuda() - - # batch offsets - offsets = torch.from_numpy(self.reader_residual.idx_list[:, 0] - - self.reader_residual.buffer).cuda().long() - - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True)/np.sqrt(self.resid_var) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - idx_in = torch.nonzero((self.spike_train[:, 0] >= self.reader_residual.idx_list[batch_id][0]) & (self.spike_train[:, 0] < self.reader_residual.idx_list[batch_id][1]))[:,0] - - spike_train_batch = self.spike_train[idx_in] - spike_train_batch[:, 0] -= offsets[batch_id] - - shift_batch = self.shifts[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - shifted_templates = [self.get_shifted_templates(spike_train_batch[:,1], shift_batch, i) for i in range(self.sim_units)] - shifted_templates = [element for element in shifted_templates] - - - clean_wfs = [(resid_snippets + shifted) for shifted in shifted_templates] - unit_idx = spike_train_batch[:, 1] - logs = torch.zeros((idx_in.shape[0],self.sim_units)).half() - for i in range(len(clean_wfs)): - cov_array = self.cov_list[unit_idx] - - restricted = clean_wfs[i][:, (self.offset):(self.n_times -(self.offset)), :] - unraveled = restricted.permute(0, 2, 1).reshape(restricted.shape[0], restricted.shape[1]*restricted.shape[2]) - temp = torch.matmul(unraveled[:, None, :].half(), cov_array.half()) - result = torch.matmul( temp, unraveled[:, :, None].half()) - logs[:, i] = result.reshape(-1) - ''' - log_vec = parmap.map(get_liklihood, zip(unit_idx, clean_wfs[i]) , pm_processes=6) - logs[:, i] = np.asarray(log_vec) - ''' - ''' - for j, spike in enumerate(clean_wfs[i]): - rel_unit = self.similar_array[unit_idx[j]][i] - logs[j, i]= self.get_liklihood(rel_unit, spike) - ''' - log_probs[idx_in] = logs - - pbar.update() - - - - self.log_probs = log_probs.cpu().numpy() - return log_probs.cpu().numpy() - def clean_wave_forms(self, spike_idx, unit): - return_wfs = torch.zeros((spike_idx.shape[0], 81, 5)) - with tqdm(total=self.reader_residual.n_batches) as pbar: - for batch_id in range(self.reader_residual.n_batches): - - # load residual data - resid_dat = self.reader_residual.read_data_batch(batch_id, add_buffer=True) - resid_dat = torch.from_numpy(resid_dat).cuda() - - # relevant idx - s1 = self.spike_train[spike_idx, 0] >= self.reader_residual.idx_list[batch_id][0] - s2 = self.spike_train[spike_idx, 0] < self.reader_residual.idx_list[batch_id][1] - - idx_in = torch.nonzero((s1 & s2))[:,0] - - spike_train_batch = self.spike_train[spike_idx] - spike_train_batch = spike_train_batch[idx_in] - spike_train_batch[:, 0] -= (self.reader_residual.idx_list[batch_id][0] - self.reader_residual.buffer) - - shift_batch = self.shifts[spike_idx] - shift_batch = shift_batch[idx_in] - # get residual snippets - - t_index = spike_train_batch[:, 0][:, None] + torch.arange(-(self.n_times//2), self.n_times//2+1).cuda() - c_index = self.chans[spike_train_batch[:, 1]].long() - resid_dat = torch.cat((resid_dat, torch.zeros((resid_dat.shape[0], 1)).cuda()), 1) - resid_snippets = resid_dat[t_index[:,:,None], c_index[:,None]] - # get shifted templates - - shifted_og = self.get_shifted_templates(spike_train_batch[:,1], shift_batch, self.sim_units) - clean_wfs = resid_snippets + shifted_og - return_wfs[idx_in] = clean_wfs.cpu() - - return return_wfs.cpu().numpy() - - def get_assign_probs(self, log_lik_array): - fix = log_lik_array*-.5 - fix = fix - fix.max(1)[:, None] - probs = np.exp(fix)/np.exp(fix).sum(1)[:, None] - self.probs = probs - return probs - - def run(self): - - #construct array to identify soft assignment units - unit_assignment = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - for unit in range(self.n_units): - row_idx= self.spike_train_og[:, 1] == unit - unit_assignment[row_idx, :] = self.similar_array[unit, :] - - - log_probs = self.compute_soft_assignment() - probs = self.get_assign_probs(log_probs) - replace_probs = np.zeros((self.spike_train_og.shape[0], self.sim_units)) - replace_log = np.zeros((self.spike_train_og.shape[0],self.sim_units)) - - replace_probs[:, 0] = 1 - replace_probs[self.idx_included, :] = probs - - replace_log[self.idx_included, :] = log_probs - - return replace_probs, probs, replace_log, unit_assignment \ No newline at end of file From 6879ae98d8c722d97f5ea1e1cfdd8b7cdfaec09c Mon Sep 17 00:00:00 2001 From: kevinli1324 Date: Tue, 28 Jan 2020 04:45:20 -0500 Subject: [PATCH 13/19] Delete mexMPnu8.cu --- cuda/mexMPnu8.cu | 523 ----------------------------------------------- 1 file changed, 523 deletions(-) delete mode 100644 cuda/mexMPnu8.cu diff --git a/cuda/mexMPnu8.cu b/cuda/mexMPnu8.cu deleted file mode 100644 index 84a60c6d..00000000 --- a/cuda/mexMPnu8.cu +++ /dev/null @@ -1,523 +0,0 @@ -const int Nthreads = 1024, maxFR = 100000, NrankMax = 3, nmaxiter = 500, NchanMax = 32; -////////////////////////////////////////////////////////////////////////////////////////// -__global__ void spaceFilter(const double *Params, const float *data, const float *U, - const int *iC, const int *iW, float *dprod){ - volatile __shared__ float sU[32*NrankMax]; - volatile __shared__ int iU[32]; - float x; - int tid, bid, i,k, Nrank, Nchan, NT, Nfilt, NchanU; - - tid = threadIdx.x; - bid = blockIdx.x; - NT = (int) Params[0]; - Nfilt = (int) Params[1]; - Nrank = (int) Params[6]; - NchanU = (int) Params[10]; - Nchan = (int) Params[9]; - - if (tid=0 & t=0 && tid0 Cbest + 1e-6){ - Cnextbest = Cbest; - Cbest = Cf; - ibest = i; - } - else - if (Cf > Cnextbest + 1e-6) - Cnextbest = Cf; - } - err[tid0] = Cbest; - eloss[tid0] = Cbest - Cnextbest; - ftype[tid0] = ibest; - - tid0 += blockDim.x * gridDim.x; - } -} - -// THIS UPDATE DOES NOT UPDATE ELOSS? -////////////////////////////////////////////////////////////////////////////////////////// -__global__ void bestFilterUpdate(const double *Params, const float *data, - const float *mu, float *err, float *eloss, int *ftype, const int *st, const int *id, const int *counter){ - int tid, ind, i,t, NT, Nfilt, ibest = 0, nt0; - float Cf, Cbest, lam, b, a, Cnextbest; - - tid = threadIdx.x; - NT = (int) Params[0]; - Nfilt = (int) Params[1]; - lam = (float) Params[7]; - nt0 = (int) Params[4]; - - - // we only need to compute this at updated locations - ind = counter[1] + blockIdx.x; - - if (ind=0 && t Cbest + 1e-6){ - Cnextbest = Cbest; - Cbest = Cf; - ibest = i; - } - else - if (Cf > Cnextbest + 1e-6) - Cnextbest = Cf; - } - err[t] = Cbest; - ftype[t] = ibest; - } - } -} - -////////////////////////////////////////////////////////////////////////////////////////// -__global__ void cleanup_spikes(const double *Params, const float *data, - const float *mu, const float *err, const float *eloss, const int *ftype, int *st, - int *id, float *x, float *y, float *z, int *counter){ - - int lockout, indx, tid, bid, NT, tid0, j, id0, t0; - volatile __shared__ float sdata[Nthreads+2*81+1]; - bool flag=0; - float err0, Th; - - lockout = (int) Params[4] - 1; - tid = threadIdx.x; - bid = blockIdx.x; - - NT = (int) Params[0]; - tid0 = bid * blockDim.x ; - Th = (float) Params[2]; - //lam = (float) Params[7]; - - while(tid0Th*Th){ - flag = 0; - for(j=-lockout;j<=lockout;j++) - if(sdata[tid+lockout+j]>err0){ - flag = 1; - break; - } - if(flag==0){ - indx = atomicAdd(&counter[0], 1); - if (indxTh){ - if (id[currInd]==bid){ - if (tidx==0 && threadIdx.y==0) - nsp[bid]++; - - tidy = threadIdx.y; - while (tidyThS){ - - tidy = threadIdx.y; - // only do this if the spike is "BAD" - while (tidy Date: Tue, 28 Jan 2020 04:45:55 -0500 Subject: [PATCH 14/19] Delete mexSVDsmall2.cu --- cuda/mexSVDsmall2.cu | 255 ------------------------------------------- 1 file changed, 255 deletions(-) delete mode 100644 cuda/mexSVDsmall2.cu diff --git a/cuda/mexSVDsmall2.cu b/cuda/mexSVDsmall2.cu deleted file mode 100644 index 54c929a7..00000000 --- a/cuda/mexSVDsmall2.cu +++ /dev/null @@ -1,255 +0,0 @@ -const int Nthreads = 1024, NrankMax = 3, nt0max = 71, NchanMax = 1024; - -////////////////////////////////////////////////////////////////////////////////////////// -__global__ void blankdWU(const double *Params, const double *dWU, - const int *iC, const int *iW, double *dWUblank){ - - int nt0, tidx, tidy, bid, Nchan, NchanNear, iChan; - - nt0 = (int) Params[4]; - Nchan = (int) Params[9]; - NchanNear = (int) Params[10]; - - tidx = threadIdx.x; - tidy = threadIdx.y; - - bid = blockIdx.x; - - while (tidy xmax){ - xmax = abs(sW[t]); - imax = t; - } - - tid = threadIdx.x; - // shift by imax - tmax - for (k=0;k xmax){ - xmax = abs(sWup[t]); - imax = t; - sgnmax = copysign(1.0f, sWup[t]); - } - - // interpolate by imax - for (k=0;k Date: Tue, 28 Jan 2020 04:46:12 -0500 Subject: [PATCH 15/19] Delete mexWtW2.cu --- src/yass/reordering/cuda/mexWtW2.cu | 54 ----------------------------- 1 file changed, 54 deletions(-) delete mode 100644 src/yass/reordering/cuda/mexWtW2.cu diff --git a/src/yass/reordering/cuda/mexWtW2.cu b/src/yass/reordering/cuda/mexWtW2.cu deleted file mode 100644 index ccd32fb6..00000000 --- a/src/yass/reordering/cuda/mexWtW2.cu +++ /dev/null @@ -1,54 +0,0 @@ -const int nblock = 32; -////////////////////////////////////////////////////////////////////////////////////////// - -__global__ void crossFilter(const double *Params, const float *W1, const float *W2, - const float *UtU, float *WtW){ - __shared__ float shW1[nblock*81], shW2[nblock*81]; - - float x; - int nt0, tidx, tidy , bidx, bidy, i, Nfilt, t, tid1, tid2; - - tidx = threadIdx.x; - tidy = threadIdx.y; - bidx = blockIdx.x; - bidy = blockIdx.y; - - Nfilt = (int) Params[1]; - nt0 = (int) Params[9]; - - tid1 = tidx + bidx*nblock; - - tid2 = tidy + bidx*nblock; - if (tid2 Date: Tue, 28 Jan 2020 04:46:33 -0500 Subject: [PATCH 16/19] Delete mexMPnu8.cu --- src/yass/reordering/cuda/mexMPnu8.cu | 523 --------------------------- 1 file changed, 523 deletions(-) delete mode 100644 src/yass/reordering/cuda/mexMPnu8.cu diff --git a/src/yass/reordering/cuda/mexMPnu8.cu b/src/yass/reordering/cuda/mexMPnu8.cu deleted file mode 100644 index 84a60c6d..00000000 --- a/src/yass/reordering/cuda/mexMPnu8.cu +++ /dev/null @@ -1,523 +0,0 @@ -const int Nthreads = 1024, maxFR = 100000, NrankMax = 3, nmaxiter = 500, NchanMax = 32; -////////////////////////////////////////////////////////////////////////////////////////// -__global__ void spaceFilter(const double *Params, const float *data, const float *U, - const int *iC, const int *iW, float *dprod){ - volatile __shared__ float sU[32*NrankMax]; - volatile __shared__ int iU[32]; - float x; - int tid, bid, i,k, Nrank, Nchan, NT, Nfilt, NchanU; - - tid = threadIdx.x; - bid = blockIdx.x; - NT = (int) Params[0]; - Nfilt = (int) Params[1]; - Nrank = (int) Params[6]; - NchanU = (int) Params[10]; - Nchan = (int) Params[9]; - - if (tid=0 & t=0 && tid0 Cbest + 1e-6){ - Cnextbest = Cbest; - Cbest = Cf; - ibest = i; - } - else - if (Cf > Cnextbest + 1e-6) - Cnextbest = Cf; - } - err[tid0] = Cbest; - eloss[tid0] = Cbest - Cnextbest; - ftype[tid0] = ibest; - - tid0 += blockDim.x * gridDim.x; - } -} - -// THIS UPDATE DOES NOT UPDATE ELOSS? -////////////////////////////////////////////////////////////////////////////////////////// -__global__ void bestFilterUpdate(const double *Params, const float *data, - const float *mu, float *err, float *eloss, int *ftype, const int *st, const int *id, const int *counter){ - int tid, ind, i,t, NT, Nfilt, ibest = 0, nt0; - float Cf, Cbest, lam, b, a, Cnextbest; - - tid = threadIdx.x; - NT = (int) Params[0]; - Nfilt = (int) Params[1]; - lam = (float) Params[7]; - nt0 = (int) Params[4]; - - - // we only need to compute this at updated locations - ind = counter[1] + blockIdx.x; - - if (ind=0 && t Cbest + 1e-6){ - Cnextbest = Cbest; - Cbest = Cf; - ibest = i; - } - else - if (Cf > Cnextbest + 1e-6) - Cnextbest = Cf; - } - err[t] = Cbest; - ftype[t] = ibest; - } - } -} - -////////////////////////////////////////////////////////////////////////////////////////// -__global__ void cleanup_spikes(const double *Params, const float *data, - const float *mu, const float *err, const float *eloss, const int *ftype, int *st, - int *id, float *x, float *y, float *z, int *counter){ - - int lockout, indx, tid, bid, NT, tid0, j, id0, t0; - volatile __shared__ float sdata[Nthreads+2*81+1]; - bool flag=0; - float err0, Th; - - lockout = (int) Params[4] - 1; - tid = threadIdx.x; - bid = blockIdx.x; - - NT = (int) Params[0]; - tid0 = bid * blockDim.x ; - Th = (float) Params[2]; - //lam = (float) Params[7]; - - while(tid0Th*Th){ - flag = 0; - for(j=-lockout;j<=lockout;j++) - if(sdata[tid+lockout+j]>err0){ - flag = 1; - break; - } - if(flag==0){ - indx = atomicAdd(&counter[0], 1); - if (indxTh){ - if (id[currInd]==bid){ - if (tidx==0 && threadIdx.y==0) - nsp[bid]++; - - tidy = threadIdx.y; - while (tidyThS){ - - tidy = threadIdx.y; - // only do this if the spike is "BAD" - while (tidy Date: Wed, 25 Mar 2020 16:28:46 -0400 Subject: [PATCH 17/19] diptest error fixed --- src/diptest/_diptest.c | 42 +++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/src/diptest/_diptest.c b/src/diptest/_diptest.c index 88fe399e..13be57ce 100644 --- a/src/diptest/_diptest.c +++ b/src/diptest/_diptest.c @@ -18208,9 +18208,9 @@ static CYTHON_INLINE PyObject* __Pyx_decode_c_string( static CYTHON_INLINE void __Pyx_ExceptionSave(PyObject **type, PyObject **value, PyObject **tb) { #if CYTHON_COMPILING_IN_CPYTHON PyThreadState *tstate = PyThreadState_GET(); - *type = tstate->exc_type; - *value = tstate->exc_value; - *tb = tstate->exc_traceback; + *type = tstate->curexc_type; + *value = tstate->curexc_value; + *tb = tstate->curexc_traceback; Py_XINCREF(*type); Py_XINCREF(*value); Py_XINCREF(*tb); @@ -18222,12 +18222,12 @@ static void __Pyx_ExceptionReset(PyObject *type, PyObject *value, PyObject *tb) #if CYTHON_COMPILING_IN_CPYTHON PyObject *tmp_type, *tmp_value, *tmp_tb; PyThreadState *tstate = PyThreadState_GET(); - tmp_type = tstate->exc_type; - tmp_value = tstate->exc_value; - tmp_tb = tstate->exc_traceback; - tstate->exc_type = type; - tstate->exc_value = value; - tstate->exc_traceback = tb; + tmp_type = tstate->curexc_type; + tmp_value = tstate->curexc_value; + tmp_tb = tstate->curexc_traceback; + tstate->curexc_type = type; + tstate->curexc_value = value; + tstate->curexc_traceback = tb; Py_XDECREF(tmp_type); Py_XDECREF(tmp_value); Py_XDECREF(tmp_tb); @@ -18270,12 +18270,12 @@ static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb) *value = local_value; *tb = local_tb; #if CYTHON_COMPILING_IN_CPYTHON - tmp_type = tstate->exc_type; - tmp_value = tstate->exc_value; - tmp_tb = tstate->exc_traceback; - tstate->exc_type = local_type; - tstate->exc_value = local_value; - tstate->exc_traceback = local_tb; + tmp_type = tstate->curexc_type; + tmp_value = tstate->curexc_value; + tmp_tb = tstate->curexc_traceback; + tstate->curexc_type = local_type; + tstate->curexc_value = local_value; + tstate->curexc_traceback = local_tb; Py_XDECREF(tmp_type); Py_XDECREF(tmp_value); Py_XDECREF(tmp_tb); @@ -18297,12 +18297,12 @@ static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject *tmp_type, *tmp_value, *tmp_tb; #if CYTHON_COMPILING_IN_CPYTHON PyThreadState *tstate = PyThreadState_GET(); - tmp_type = tstate->exc_type; - tmp_value = tstate->exc_value; - tmp_tb = tstate->exc_traceback; - tstate->exc_type = *type; - tstate->exc_value = *value; - tstate->exc_traceback = *tb; + tmp_type = tstate->curexc_type; + tmp_value = tstate->curexc_value; + tmp_tb = tstate->curexc_traceback; + tstate->curexc_type = *type; + tstate->curexc_value = *value; + tstate->curexc_traceback = *tb; #else PyErr_GetExcInfo(&tmp_type, &tmp_value, &tmp_tb); PyErr_SetExcInfo(*type, *value, *tb); From c33a939d7547c76252f30989ee6cbe65710b13f3 Mon Sep 17 00:00:00 2001 From: Cat Date: Thu, 2 Apr 2020 09:43:30 -0400 Subject: [PATCH 18/19] reorder updates --- src/yass/reordering/reorder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/yass/reordering/reorder.py b/src/yass/reordering/reorder.py index 8365ea24..551c0c0b 100644 --- a/src/yass/reordering/reorder.py +++ b/src/yass/reordering/reorder.py @@ -59,5 +59,6 @@ def run(save_fname, standardized_fname, CONFIG,n_sec_chunk, nPCs = 3, nt0 = 61, yass_batch = params.Nbatch, n_chunk_sec = int(n_sec_chunk*CONFIG.recordings.sampling_rate), nt0 = nt0) + print (" DONE REODERING, SAVING: ", save_fname) np.save(save_fname, cp.asnumpy(result['iorig'])) From 18f61a36dbac28bf2ff60097a1b8358f9b402b7a Mon Sep 17 00:00:00 2001 From: Cat Date: Fri, 3 Apr 2020 12:15:54 -0400 Subject: [PATCH 19/19] Reordering implemented --- samples/10chan/config.yaml | 13 ++-- src/yass/phy/run.py | 29 +++----- src/yass/pipeline.py | 35 ++++++--- src/yass/preprocess/run.py | 12 ++-- src/yass/reordering/reorder.py | 128 ++++++++++++++++++++++++++++++--- 5 files changed, 172 insertions(+), 45 deletions(-) diff --git a/samples/10chan/config.yaml b/samples/10chan/config.yaml index 6c6c7cb5..4d7582f6 100644 --- a/samples/10chan/config.yaml +++ b/samples/10chan/config.yaml @@ -20,20 +20,22 @@ resources: multi_processing: 1 # Number of cores to use n_processors: 2 - # Length of processing chunk in seconds for multi-processing stages - n_sec_chunk: 10 - # number of GPUs to use + # number of GPUs to use; n_gpu_processors: 1 - # n_sec_chunk for cpu + # n_sec_chunk for cpu n_sec_chunk: 5 # n_sec_chunk for gpu detection n_sec_chunk_gpu_detect: 0.5 # n_sec_chunk for gpu deconvolution n_sec_chunk_gpu_deconv: 5 + # drift chunk rearanging: 0 - do not track multi-chan drift; + drift: 1 + # length of chunks to rearange data for drift step + n_sec_drift_chunk: 5 # gpu_id; default is 0, i.e. first gpu; gpu_id: 0 # generate phy visualization files; 0 - do not run; 1: generate phy files - generate_phy: 0 + generate_phy: 1 recordings: # precision of the recording – must be a valid numpy dtype @@ -160,4 +162,5 @@ deconvolution: # minimum # of spikes required to split min_split_spikes: 50 + diff --git a/src/yass/phy/run.py b/src/yass/phy/run.py index cd9b3bd1..1343201c 100644 --- a/src/yass/phy/run.py +++ b/src/yass/phy/run.py @@ -14,7 +14,7 @@ #from yass.deconvolve.soft_assignment import get_soft_assignments -def run(CONFIG): +def run(CONFIG, fname_spike_train): """Generate phy2 visualization files """ @@ -25,7 +25,12 @@ def run(CONFIG): # set root directory for output root_dir = CONFIG.data.root_folder - fname_standardized = os.path.join(os.path.join(os.path.join( + + if CONFIG.resources.drift: + fname_standardized = os.path.join(os.path.join(os.path.join( + root_dir,'tmp'),'preprocess'),'standardized_original.bin') + else: + fname_standardized = os.path.join(os.path.join(os.path.join( root_dir,'tmp'),'preprocess'),'standardized.bin') # @@ -42,7 +47,7 @@ def run(CONFIG): # cluster id for each spike; [n_spikes] #spike_train = np.load(root_dir + '/tmp/spike_train.npy') - spike_train = np.load(root_dir + '/tmp/final_deconv/deconv/spike_train.npy') + spike_train = np.load(fname_spike_train) spike_clusters = spike_train[:,1] np.save(root_dir+'/phy/spike_clusters.npy', spike_clusters) @@ -100,7 +105,8 @@ def run(CONFIG): fname_out = os.path.join(output_directory,'pc_objects.npy') if os.path.exists(fname_out)==False: pc_projections = get_pc_objects(root_dir, pc_feature_ind, n_channels, - n_times, units, n_components, CONFIG, spike_train) + n_times, units, n_components, CONFIG, spike_train, + fname_standardized) np.save(fname_out, pc_projections) else: pc_projections = np.load(fname_out,allow_pickle=True) @@ -215,25 +221,12 @@ def get_pc_objects_parallel(units, n_channels, pc_feature_ind, spike_train, def get_pc_objects(root_dir,pc_feature_ind, n_channels, n_times, units, n_components, CONFIG, - spike_train): + spike_train, fname_standardized): ''' First grab 10% of the spikes on each channel and makes PCA objects for each channel Then generate PCA object for each channel using spikes ''' - # load templates from spike trains - # templates = np.load(root_dir + '/tmp/templates.npy') - # print (templates.shape) - - # standardized filename - fname_standardized = os.path.join(os.path.join(os.path.join(root_dir,'tmp'), - 'preprocess'),'standardized.bin') - - # spike_train - #spike_train = np.load(os.path.join(os.path.join(root_dir, 'tmp'),'spike_train.npy')) - #spike_train = np.load(os.path.join(os.path.join(root_dir, 'tmp'),'spike_train.npy')) - - # ******************************************** # ***** APPROXIMATE PROJ MATRIX EACH CHAN **** # ******************************************** diff --git a/src/yass/pipeline.py b/src/yass/pipeline.py index 2fe1f107..42baf197 100644 --- a/src/yass/pipeline.py +++ b/src/yass/pipeline.py @@ -37,6 +37,7 @@ from yass.util import (load_yaml, save_metadata, load_logging_config_file, human_readable_time) +from yass.reordering import (reorder) def run(config, logger_level='INFO', clean=False, output_dir='tmp/', complete=False, calculate_rf=False, visualize=False, set_zero_seed=False): @@ -86,7 +87,6 @@ def run(config, logger_level='INFO', clean=False, output_dir='tmp/', set_config(config, output_dir) CONFIG = read_config() TMP_FOLDER = CONFIG.path_to_output_directory - generate_phy = CONFIG.resources.generate_phy # remove tmp folder if needed if os.path.exists(TMP_FOLDER) and clean: shutil.rmtree(TMP_FOLDER) @@ -172,7 +172,6 @@ def run(config, logger_level='INFO', clean=False, output_dir='tmp/', standardized_path, standardized_dtype, fname_templates, - generate_phy, CONFIG, update_templates = CONFIG.deconvolution.update_templates, run_chunk_sec = CONFIG.final_deconv_chunk) @@ -456,7 +455,6 @@ def final_deconv(TMP_FOLDER, standardized_path, standardized_dtype, fname_templates, - generate_phy, CONFIG, update_templates, run_chunk_sec): @@ -500,15 +498,11 @@ def final_deconv(TMP_FOLDER, update_templates=update_templates, run_chunk_sec=run_chunk_sec) - + ''' ********************************************** - ************** GENERATE PHY FILES ************ + ************** GENERATE SOFT ASSIGNMENT ****** ********************************************** ''' - - if generate_phy: - phy.run(CONFIG) - logger.info('SOFT ASSIGNMENT') fname_noise_soft, fname_template_soft = soft_assignment.run( fname_templates, @@ -520,6 +514,29 @@ def final_deconv(TMP_FOLDER, fname_residual, residual_dtype) + + + ''' ********************************************** + ************** REORDER SPIKE TRAINS ********** + ********************************************** + ''' + + if CONFIG.resources.drift: + logger.info('REORDER SPIKE TRAINS') + reorder.reorder_spike_train(CONFIG, fname_spike_train) + + + ''' ********************************************** + ************** GENERATE PHY FILES ************ + ********************************************** + ''' + + if CONFIG.resources.generate_phy: + logger.info('GENERATE PHY FILES') + phy.run(CONFIG, fname_spike_train) + + + return (fname_templates, fname_spike_train, fname_noise_soft, diff --git a/src/yass/preprocess/run.py b/src/yass/preprocess/run.py index e364cbad..25ca1b97 100644 --- a/src/yass/preprocess/run.py +++ b/src/yass/preprocess/run.py @@ -177,10 +177,12 @@ def run(output_directory): with open(path_to_yaml, 'w') as f: logger.info('Saving params...') yaml.dump(standardized_params, f) - reorder.run(save_fname = reorder_fname, - standardized_fname = standardized_path, - CONFIG = CONFIG, - n_sec_chunk = 5, - dtype = CONFIG.preprocess.dtype) + + # + if CONFIG.resources.drift: + reorder.run(save_fname = reorder_fname, + standardized_fname = standardized_path, + CONFIG = CONFIG, + dtype = CONFIG.preprocess.dtype) return standardized_path, standardized_params['dtype'], reorder_fname diff --git a/src/yass/reordering/reorder.py b/src/yass/reordering/reorder.py index 551c0c0b..02fbf60f 100644 --- a/src/yass/reordering/reorder.py +++ b/src/yass/reordering/reorder.py @@ -8,8 +8,6 @@ from yass.reordering.preprocess import get_good_channels import os import cupy as cp -#initialize object - class PARAM: pass @@ -17,10 +15,73 @@ class PARAM: class PROBE: pass -def run(save_fname, standardized_fname, CONFIG,n_sec_chunk, nPCs = 3, nt0 = 61, reorder = True, dtype = np.float32 ): +class ReadReorder(object): + ''' Class that reorders the raw binary standardized file based on + output of rastermap function. + ''' + + def __init__(self): + pass + + def read(self, data_start, length_of_chunk, channels=None): + with open(self.bin_file, "rb") as fin: + # Seek position and read N bytes + fin.seek(int(data_start*self.dtype.itemsize), os.SEEK_SET) + data = np.fromfile(fin, dtype=self.dtype, + count=length_of_chunk) + fin.close() - + data = data.reshape(-1, self.n_chans) + + return data + + def reorder(self): + + # fixed value to pad raw data; + pad_len = 200 + self.dtype = np.float32([0]).dtype + zero_chunk = np.zeros((pad_len, self.n_chans),dtype=self.dtype) + pad_len_samples = pad_len*self.n_chans + + # set chunking information + #chunk_len= 2 + # Cat: TODO: this may miss small bits of data if irregular ended acquisition + data_size = int(os.path.getsize(self.bin_file)/self.n_chans/4) + n_chunks = data_size//(self.sample_rate*self.chunk_len) + #print ("Data_size: ", data_size, " nchunks: ", n_chunks, " idx: ", chunk_idxs) + + # load data in different order + new_file_name = self.bin_file[:-4]+"_reordered.bin" + new_file = open(new_file_name, 'wb') + for idx_ in self.chunk_idxs: + data_start = idx_*self.sample_rate*self.chunk_len*self.n_chans + length_of_chunk = self.sample_rate*self.chunk_len*self.n_chans + + # if neither last nor first grab chunk + padding + if (idx_ != 0) and (idx_!=(n_chunks-1)): + chunk = self.read(data_start-pad_len_samples, length_of_chunk+pad_len_samples*2) + #print(" start: ", data_start-pad_len_samples, " , length: ", length_of_chunk+pad_len_samples*2) + elif (idx_==0): + chunk = self.read(data_start, length_of_chunk+pad_len_samples) + #print(" start: ", data_start, " , length: ", length_of_chunk+pad_len_samples) + chunk = np.hstack((zero_chunk.T,chunk.T)).T + elif (idx_==(n_chunks-1)): + chunk = self.read(data_start-pad_len_samples, length_of_chunk+pad_len_samples) + #print(" start: ", data_start-pad_len_samples, " , length: ", length_of_chunk+pad_len_samples) + chunk = np.hstack((chunk.T, zero_chunk.T)).T + + new_file.write(chunk) + + new_file.close() + + # delete old standardized file and overwrite with the new one. + os.system('mv '+ self.bin_file + " " + self.bin_file[:-4]+"_original.bin") + os.system('mv '+ new_file_name + " " + self.bin_file) + + +def run(save_fname, standardized_fname, CONFIG, nPCs = 3, nt0 = 61, reorder = True, dtype = np.float32 ): + params = PARAM() probe = PROBE() @@ -48,7 +109,7 @@ def run(save_fname, standardized_fname, CONFIG,n_sec_chunk, nPCs = 3, nt0 = 61, probe.Nchan = params.Nchan shape = (params.Nchan, CONFIG.rec_len) standardized_mmemap = np.memmap(standardized_fname, order = "F", dtype = dtype) - params.Nbatch = np.ceil(CONFIG.rec_len/(n_sec_chunk*CONFIG.recordings.sampling_rate)).astype(np.int16) + params.Nbatch = np.ceil(CONFIG.rec_len/(CONFIG.resources.n_sec_drift_chunk*CONFIG.recordings.sampling_rate)).astype(np.int16) params.reorder = reorder params.nt0min = np.ceil(20 * nt0 / 61).astype(np.int16) @@ -57,8 +118,59 @@ def run(save_fname, standardized_fname, CONFIG,n_sec_chunk, nPCs = 3, nt0 = 61, params = params, probe = probe, yass_batch = params.Nbatch, - n_chunk_sec = int(n_sec_chunk*CONFIG.recordings.sampling_rate), + n_chunk_sec = int(CONFIG.resources.n_sec_drift_chunk*CONFIG.recordings.sampling_rate), nt0 = nt0) - print (" DONE REODERING, SAVING: ", save_fname) - np.save(save_fname, cp.asnumpy(result['iorig'])) + # save chunk order and reorder file + print (" saving chunk order: ", save_fname) + chunk_ids = cp.asnumpy(result['iorig']) + np.save(save_fname, chunk_ids) + + + # initialize READER + RR = ReadReorder() + + # initialize data + RR.sample_rate = CONFIG.recordings.sampling_rate + RR.bin_file = os.path.join(CONFIG.data.root_folder, + 'tmp/preprocess/standardized.bin') + print (RR.bin_file) + RR.chunk_len = CONFIG.resources.n_sec_drift_chunk + RR.n_chans = CONFIG.recordings.n_channels + + RR.chunk_idxs = chunk_ids + RR.reorder() + + +def reorder_spike_train(CONFIG, spike_train_fname): + ''' Re order the spike trains obtained from monotonic drift + version of standardized file back to original temporal order + ''' + + spike_train = np.load(spike_train_fname) + + indexes = np.load(os.path.join(CONFIG.data.root_folder, + 'tmp/preprocess/reorder.npy')) + sample_rate = CONFIG.recordings.sampling_rate + pad_len = 200 + chunk_len = CONFIG.resources.n_sec_drift_chunk + + spike_train_reordered = np.zeros((0,2),'int32') + for ctr, idx in enumerate(indexes): + start = ctr*(sample_rate*chunk_len+pad_len*2)+pad_len + end = (start + chunk_len*sample_rate) + temp_idx = np.where(np.logical_and(spike_train[:,0]>=start,spike_train[:,0]