/*
 * matMultiply - dense square matrix product: AxBd = Ad * Bd.
 *
 * All three matrices are indexed as row-major SIZE x SIZE float arrays
 * (element (r, c) lives at index r * SIZE + c).
 *
 * Parameters:
 *   Ad, Bd - input matrices (read only by this kernel).
 *   AxBd   - output matrix; every element handled by a thread is overwritten.
 *   SIZE   - the matrix dimension (number of rows == number of columns).
 *   iter   - 0     : each thread computes at most one output element, the one
 *                    at its own (row, column) position in the grid.
 *            != 0  : used when the matrix is too large and each block has to
 *                    process several tiles; every thread then walks the
 *                    matrix in steps of the whole grid extent.
 *
 * Launch layout: 2D blocks in a 2D grid; x maps to columns, y maps to rows.
 * NOTE(review): the starting position is derived from TILE_W while the step
 * is derived from blockDim, so this presumably expects
 * blockDim.x == blockDim.y == TILE_W - confirm against the launch site.
 */
__global__ void matMultiply(float *Ad, float *Bd, float *AxBd, int SIZE, int iter)
{
    // Output element initially owned by this thread.
    const int row0 = threadIdx.y + blockIdx.y * TILE_W;
    const int col0 = threadIdx.x + blockIdx.x * TILE_W;

    if (iter == 0) {
        // Threads of a partial tile that fall outside the matrix must not
        // touch memory: without this guard they read and write out of bounds.
        if (row0 < SIZE && col0 < SIZE) {
            // The accumulator has to start from zero; it was previously left
            // uninitialized on this path.
            float acc = 0.0f;
            for (int k = 0; k < SIZE; k++) {
                acc += Ad[row0 * SIZE + k] * Bd[k * SIZE + col0];
            }
            AxBd[row0 * SIZE + col0] = acc;
        }
        return;
    }

    // iter != 0: the matrix is too large, so each block runs over several tiles.
    const int rowStep = blockDim.y * gridDim.y;
    const int colStep = blockDim.x * gridDim.x;

    for (int i = row0; i < SIZE; i += rowStep) {
        // The column counter restarts from this thread's own column for
        // every new row it visits.
        for (int j = col0; j < SIZE; j += colStep) {
            float acc = 0.0f;
            for (int k = 0; k < SIZE; k++) {
                acc += Ad[i * SIZE + k] * Bd[k * SIZE + j];
            }
            AxBd[i * SIZE + j] = acc;
        }
    }
}
Direct link: https://paste.plurk.com/show/407891