1 #pragma once
2
6
10
12
14 {
15
// split_field: scatter the fields in v_base_field across the split
// communicator grid described by comm_key and receive this rank's pieces.
//
// What the visible code does:
//  - Send phase: n_replicates = product(comm_key) iterations; iteration i
//    packs v_base_field[i % n_fields] into host buffer v_send_buffer_h[i]
//    (input fields are reused cyclically when n_fields < n_replicates).
//  - Receive phase: n_replicates iterations; each unpacks a received buffer
//    into buffer_field, a temporary Field built from the first input field's
//    param_type, and computes that piece's offset partition_idx * field_dim.
//
// @param collect_field  destination field; presumably each received piece is
//                       copied into it at `offset` (that copy is not visible
//                       in this listing -- TODO confirm).
// @param v_base_field   input fields; must be non-empty, else errorQuda.
// @param comm_key       number of sub-partitions along each grid dimension.
//
// NOTE(review): this listing is incomplete and will not compile as shown.
// comm_grid_dim, comm_grid_idx, rank, dst_rank, src_rank, recv_buffer_h and the
// send loop's partition_idx are used but never declared here; the trailing
// parameter from the declared signature (`QudaPCType pc_type = QUDA_4D_PC`) is
// absent; and the buffer allocations, message declare/start/wait calls, the
// per-piece copy into collect_field and the bodies of the cleanup loops at the
// end are missing. Restore them from the upstream header.
16 template <class Field>
17 void inline split_field(Field &collect_field, std::vector<Field *> &v_base_field,
const CommKey &comm_key,
19 {
22
24 int total_rank =
product(comm_grid_dim);
25
// CommKey arithmetic below is presumably component-wise (operators not
// visible here).
39 auto processor_dim = comm_grid_dim / comm_key; // How many processors are there in a processor grid sub-partition?
40 auto partition_dim
41 = comm_grid_dim / processor_dim; // How many such sub-partitions are there? partition_dim == comm_key (when comm_key evenly divides comm_grid_dim)
42
43 int n_replicates =
product(comm_key);
// One host send buffer and one message handle per replicate.
44 std::vector<void *> v_send_buffer_h(n_replicates, nullptr);
45 std::vector<MsgHandle *> v_mh_send(n_replicates, nullptr);
46
47 int n_fields = v_base_field.size();
48 if (n_fields == 0) {
errorQuda(
"split_field: input field vec has zero size."); }
49
// First input field: supplies the message size and the parameters of the
// temporary receive field.
50 const auto meta = v_base_field[0];
51
52 // Send cycles
53 for (int i = 0; i < n_replicates; i++) {
55 auto processor_idx = comm_grid_idx / partition_dim; // Which processor in that partition to send to?
56
57 auto dst_idx = partition_idx * processor_dim + processor_idx;
58
// NOTE(review): tag grows like total_rank^2 in int arithmetic. If the comm
// backend is MPI only tags up to MPI_TAG_UB (guaranteed >= 32767) are valid,
// and total_rank > 46340 would overflow int -- confirm for large rank counts.
60 int tag = rank * total_rank + dst_rank; // tag = src_rank * total_rank + dst_rank
61
// NOTE(review): every message is sized from v_base_field[0]; this assumes all
// input fields have the same TotalBytes().
62 size_t bytes = meta->TotalBytes();
63
65
66 v_base_field[i % n_fields]->copy_to_buffer(v_send_buffer_h[i]);
67
70 }
71
72 using param_type = typename Field::param_type;
73
// Temporary field with the first input field's parameters, used to unpack
// each received buffer. NOTE(review): raw owning pointer released by the
// `delete` below; std::unique_ptr would also free it on an early exit
// (e.g. an exception).
74 param_type
param(*meta);
75 Field *buffer_field = Field::Create(
param);
76
// Extents of the first input field in dimensions 0-3 (full_dim).
77 CommKey field_dim = {meta->full_dim(0), meta->full_dim(1), meta->full_dim(2), meta->full_dim(3)};
78
79 // Receive cycles
80 for (int i = 0; i < n_replicates; i++) {
81 auto partition_idx
82 =
coordinate_from_index(i, comm_key);
// Here this means which partition of the field we are working on.
83 auto src_idx
84 = (comm_grid_idx % processor_dim) * partition_dim + partition_idx; // And where does this partition come from?
85
87 int tag = src_rank * total_rank + rank;
88
89 size_t bytes = buffer_field->TotalBytes();
90
92
94
97
98 buffer_field->copy_from_buffer(recv_buffer_h);
99
102
// Presumably the position of this piece inside collect_field.
103 auto offset = partition_idx * field_dim;
104
106 }
107
108 delete buffer_field;
109
111
// Release the per-replicate send resources. NOTE(review): the loop bodies are
// missing from this listing; join_field frees the same vectors with host_free
// and comm_free.
112 for (auto &p : v_send_buffer_h) {
114 };
115 for (auto &p : v_mh_send) {
117 };
118 }
119
// join_field: gather pieces from the split communicator grid described by
// comm_key into the fields of v_base_field. It mirrors split_field: the send
// loop here uses split_field's receive-side index formula and vice versa.
//
// What the visible code does:
//  - Send phase: n_replicates = product(comm_key) iterations; each computes a
//    piece offset partition_idx * field_dim and packs buffer_field (a
//    temporary Field built from the first output field's param_type) into
//    host buffer v_send_buffer_h[i].
//  - Receive phase: n_replicates iterations; iteration i unpacks a received
//    buffer into v_base_field[i % n_fields].
//  - Frees buffer_field, then each send buffer (host_free) and message handle
//    (comm_free).
//
// @param v_base_field   output fields; must be non-empty, else errorQuda.
// @param collect_field  source field; presumably the piece at `offset` is
//                       copied into buffer_field before packing (not visible
//                       in this listing -- TODO confirm).
// @param comm_key       number of sub-partitions along each grid dimension.
//
// NOTE(review): this listing is incomplete and will not compile as shown.
// comm_grid_dim, comm_grid_idx, rank, dst_rank, src_rank, recv_buffer_h and
// both loops' partition_idx are used but never declared here; the trailing
// parameter from the declared signature (`QudaPCType pc_type = QUDA_4D_PC`) is
// absent; and the buffer allocations, message declare/start/wait calls and the
// copy out of collect_field are missing. Restore them from the upstream header.
120 template <class Field>
121 void inline join_field(std::vector<Field *> &v_base_field,
const Field &collect_field,
const CommKey &comm_key,
123 {
126
128 int total_rank =
product(comm_grid_dim);
129
// CommKey arithmetic below is presumably component-wise (operators not
// visible here).
130 auto processor_dim = comm_grid_dim / comm_key; // Processors per sub-partition of the communicator grid.
131 auto partition_dim
132 = comm_grid_dim / processor_dim; // Number of sub-partitions (== comm_key when comm_key evenly divides comm_grid_dim); the full field is partitioned accordingly.
133
134 int n_replicates =
product(comm_key);
// One host send buffer and one message handle per replicate.
135 std::vector<void *> v_send_buffer_h(n_replicates, nullptr);
136 std::vector<MsgHandle *> v_mh_send(n_replicates, nullptr);
137
138 int n_fields = v_base_field.size();
139 if (n_fields == 0) {
errorQuda(
"join_field: output field vec has zero size."); }
140
// First output field: supplies the message size and the parameters of the
// temporary send field.
141 const auto &meta = *(v_base_field[0]);
142
143 using param_type = typename Field::param_type;
144
// Temporary field used to stage each outgoing piece. NOTE(review): raw owning
// pointer released by the `delete` below; std::unique_ptr would also free it
// on an early exit (e.g. an exception).
145 param_type
param(meta);
146 Field *buffer_field = Field::Create(
param);
147
// Extents of the first output field in dimensions 0-3 (full_dim).
148 CommKey field_dim = {meta.full_dim(0), meta.full_dim(1), meta.full_dim(2), meta.full_dim(3)};
149
150 // Send cycles
151 for (int i = 0; i < n_replicates; i++) {
152
154 auto dst_idx = (comm_grid_idx % processor_dim) * partition_dim + partition_idx;
155
// NOTE(review): tag grows like total_rank^2 in int arithmetic. If the comm
// backend is MPI only tags up to MPI_TAG_UB (guaranteed >= 32767) are valid,
// and total_rank > 46340 would overflow int -- confirm for large rank counts.
157 int tag = rank * total_rank + dst_rank;
158
159 size_t bytes = meta.TotalBytes();
160
// Presumably the position of this piece inside collect_field.
161 auto offset = partition_idx * field_dim;
163
165 buffer_field->copy_to_buffer(v_send_buffer_h[i]);
166
168
170 }
171
172 // Receive cycles
173 for (int i = 0; i < n_replicates; i++) {
174
176 auto processor_idx = comm_grid_idx / partition_dim;
177
178 auto src_idx = partition_idx * processor_dim + processor_idx;
179
181 int tag = src_rank * total_rank + rank;
182
183 size_t bytes = buffer_field->TotalBytes();
184
186
188
191
// NOTE(review): output fields are reused cyclically, so when n_replicates >
// n_fields a later receive overwrites an earlier one in the same field --
// confirm this is intended.
192 v_base_field[i % n_fields]->copy_from_buffer(recv_buffer_h);
193
196 }
197
198 delete buffer_field;
199
201
// Release the per-replicate host send buffers and message handles.
202 for (
auto &p : v_send_buffer_h) {
host_free(p); };
203 for (
auto &p : v_mh_send) {
comm_free(p); };
204 }
205
206 } // namespace quda
void comm_start(MsgHandle *mh)
MsgHandle * comm_declare_recv_rank(void *buffer, int rank, int tag, size_t nbytes)
MsgHandle * comm_declare_send_rank(void *buffer, int rank, int tag, size_t nbytes)
void comm_wait(MsgHandle *mh)
void comm_free(MsgHandle *&mh)
enum QudaPCType_s QudaPCType
#define pinned_malloc(size)
constexpr int product(const CommKey &input)
void join_field(std::vector< Field * > &v_base_field, const Field &collect_field, const CommKey &comm_key, QudaPCType pc_type=QUDA_4D_PC)
void split_field(Field &collect_field, std::vector< Field * > &v_base_field, const CommKey &comm_key, QudaPCType pc_type=QUDA_4D_PC)
constexpr CommKey coordinate_from_index(int index, CommKey dim)
void copyFieldOffset(CloverField &out, const CloverField &in, CommKey offset, QudaPCType pc_type)
This function is used for copying from a source clover field to a destination clover field with an of...
Main header file for the QUDA library.
int comm_rank_from_coords(const int *coords)