omp_set_num_threads(NUM_THREADS); |
rows = nj/NUM_THREADS; |
LDA = ni + 2; |
// main iteration loop |
#pragma omp parallel private(istep) |
{ |
float temp1, temp2, temp_tmp; |
int tid = omp_get_thread_num(); |
acc_set_device_num(tid+1, acc_device_not_host); |
temp1 = temp1_h + tidrowsLDA; |
temp2 = temp2_h + tidrowsLDA; |
#pragma acc data copyin(temp1[0:(rows+2)LDA]) |
copyin(temp2[0:(rows+2)LDA]) |
{ |
for(istep=0; istep < nstep; istep++){ |
step_kernel(ni+2, rows+2, tfac, temp1, temp2); |
/ all devices (except the last one) update the lower halo to the host / |
if(tid != NUM_THREADS-1){ |
#pragma acc update host(temp2[rowsLDA:LDA]) |
} |
/ all devices (except the first one) update the upper halo to the host / |
if(tid != 0){ |
#pragma acc update host(temp2[LDA:LDA]) |
} |
/ all host threads wait here to make sure halo data from all devices |
have been updated to the host / |
#pragma omp barrier |
/ update the upper halo to all devices (except the first one) / |
if(tid != 0){ |
#pragma acc update device(temp2[0:LDA]) |
} |
/ update the lower halo to all devices (except the last one) / |
if(tid != NUM_THREADS-1){ |
#pragma acc update device(temp2[(rows+1)LDA:LDA]) |
} |
temp_tmp = temp1; |
temp1 = temp2; |
temp2 = temp_tmp; |
} |
/ update the final result to host / |
#pragma acc update host(temp1[LDA:rowLDA]) |
} |
} |