diff --git a/srpt/src/ib_srpt.c b/srpt/src/ib_srpt.c index 7872e7cf8..48ca77917 100644 --- a/srpt/src/ib_srpt.c +++ b/srpt/src/ib_srpt.c @@ -826,7 +826,7 @@ static int srpt_post_recv(struct srpt_device *sdev, struct ib_recv_wr wr, *bad_wr; BUG_ON(!sdev); - wr.wr_id = encode_wr_id(IB_WC_RECV, ioctx->ioctx.index); + wr.wr_id = encode_wr_id(SRPT_RECV, ioctx->ioctx.index); list.addr = ioctx->ioctx.dma; list.length = srp_max_req_size; @@ -869,7 +869,7 @@ static int srpt_post_send(struct srpt_rdma_ch *ch, list.lkey = sdev->mr->lkey; wr.next = NULL; - wr.wr_id = encode_wr_id(IB_WC_SEND, ioctx->ioctx.index); + wr.wr_id = encode_wr_id(SRPT_SEND, ioctx->ioctx.index); wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IB_WR_SEND; @@ -1317,6 +1317,7 @@ static void srpt_handle_send_comp(struct srpt_rdma_ch *ch, */ static void srpt_handle_rdma_comp(struct srpt_rdma_ch *ch, struct srpt_send_ioctx *ioctx, + enum srpt_opcode opcode, enum scst_exec_context context) { enum srpt_command_state state; @@ -1326,7 +1327,7 @@ static void srpt_handle_rdma_comp(struct srpt_rdma_ch *ch, atomic_add(ioctx->n_rdma, &ch->sq_wr_avail); scmnd = ioctx->scmnd; - if (scmnd) { + if (opcode == SRPT_RDMA_READ_LAST && scmnd) { state = srpt_test_and_set_cmd_state(ioctx, SRPT_STATE_NEED_DATA, SRPT_STATE_DATA_IN); if (state == SRPT_STATE_NEED_DATA) @@ -1335,8 +1336,13 @@ static void srpt_handle_rdma_comp(struct srpt_rdma_ch *ch, else PRINT_ERROR("%s[%d]: wrong state = %d", __func__, __LINE__, state); - } else - PRINT_ERROR("%s[%d]: scmnd == NULL", __func__, __LINE__); + } else if (opcode == SRPT_RDMA_ABORT) { + ioctx->rdma_aborted = true; + } else { + WARN_ON(opcode != SRPT_RDMA_READ_LAST); + PRINT_ERROR("%s[%d]: scmnd == NULL (opcode %d)", __func__, + __LINE__, opcode); + } } /** @@ -1344,7 +1350,7 @@ static void srpt_handle_rdma_comp(struct srpt_rdma_ch *ch, */ static void srpt_handle_rdma_err_comp(struct srpt_rdma_ch *ch, struct srpt_send_ioctx *ioctx, - u8 opcode, + enum srpt_opcode opcode, enum scst_exec_context context) { struct scst_cmd *scmnd; @@ -1354,7 +1360,7 @@ static void srpt_handle_rdma_err_comp(struct srpt_rdma_ch *ch, state = srpt_get_cmd_state(ioctx); if (scmnd) { switch (opcode) { - case IB_WC_RDMA_READ: + case SRPT_RDMA_READ_LAST: if (ioctx->n_rdma <= 0) { PRINT_ERROR("Received invalid RDMA read error" " completion with idx %d", @@ -1368,7 +1374,7 @@ static void srpt_handle_rdma_err_comp(struct srpt_rdma_ch *ch, PRINT_ERROR("%s[%d]: wrong state = %d", __func__, __LINE__, state); break; - case IB_WC_RDMA_WRITE: + case SRPT_RDMA_WRITE_LAST: scst_set_delivery_status(scmnd, SCST_CMD_DELIVERY_ABORTED); break; @@ -1791,34 +1797,33 @@ static void srpt_process_send_completion(struct ib_cq *cq, { struct srpt_send_ioctx *send_ioctx; uint32_t index; - u8 opcode; + enum srpt_opcode opcode; index = idx_from_wr_id(wc->wr_id); opcode = opcode_from_wr_id(wc->wr_id); send_ioctx = ch->ioctx_ring[index]; if (wc->status == IB_WC_SUCCESS) { - if (opcode == IB_WC_SEND) + if (opcode == SRPT_SEND) srpt_handle_send_comp(ch, send_ioctx, context); else { - EXTRACHECKS_WARN_ON(wc->opcode != IB_WC_RDMA_READ); - srpt_handle_rdma_comp(ch, send_ioctx, context); + EXTRACHECKS_WARN_ON(opcode != SRPT_RDMA_ABORT && + wc->opcode != IB_WC_RDMA_READ); + srpt_handle_rdma_comp(ch, send_ioctx, opcode, context); } } else { - if (opcode == IB_WC_SEND) { + if (opcode == SRPT_SEND) { PRINT_INFO("sending response for idx %u failed with" " status %d", index, wc->status); srpt_handle_send_err_comp(ch, wc->wr_id, context); - } else { - PRINT_INFO("RDMA %s for idx %u failed with status %d", - opcode == IB_WC_RDMA_READ ? "read" - : opcode == IB_WC_RDMA_WRITE ? "write" - : "???", index, wc->status); + } else if (opcode != SRPT_RDMA_MID) { + PRINT_INFO("RDMA t %d for idx %u failed with status %d", + opcode, index, wc->status); srpt_handle_rdma_err_comp(ch, send_ioctx, opcode, context); } } - while (unlikely(opcode == IB_WC_SEND + while (unlikely(opcode == SRPT_SEND && !list_empty(&ch->cmd_wait_list) && atomic_read(&ch->state) == RDMA_CHANNEL_LIVE && (send_ioctx = srpt_get_send_ioctx(ch)) != NULL)) { @@ -1844,7 +1849,7 @@ static void srpt_process_completion(struct ib_cq *cq, ib_req_notify_cq(cq, IB_CQ_NEXT_COMP); while ((n = ib_poll_cq(cq, ARRAY_SIZE(ch->wc), wc)) > 0) { for (i = 0; i < n; i++) { - if (opcode_from_wr_id(wc[i].wr_id) & IB_WC_RECV) + if (opcode_from_wr_id(wc[i].wr_id) == SRPT_RECV) srpt_process_rcv_completion(cq, ch, context, &wc[i]); else @@ -1904,6 +1909,11 @@ static int srpt_compl_thread(void *arg) } PRINT_INFO("Session %s: kernel thread %s (PID %d) stopped", ch->sess_name, ch->thread->comm, current->pid); + while (!kthread_should_stop()) { + set_current_state(TASK_INTERRUPTIBLE); + schedule(); + } + return 0; } @@ -2895,6 +2905,7 @@ static int srpt_perform_rdmas(struct srpt_rdma_ch *ch, int i; int ret; int sq_wr_avail; + const int n_rdma = ioctx->n_rdma; if (dir == SCST_DATA_WRITE) { ret = -ENOMEM; @@ -2902,23 +2913,28 @@ static int srpt_perform_rdmas(struct srpt_rdma_ch *ch, &ch->sq_wr_avail); if (sq_wr_avail < 0) { PRINT_WARNING("IB send queue full (needed %d)", - ioctx->n_rdma); + n_rdma); goto out; } } + ioctx->rdma_aborted = false; ret = 0; riu = ioctx->rdma_ius; memset(&wr, 0, sizeof wr); - for (i = 0; i < ioctx->n_rdma; ++i, ++riu) { + for (i = 0; i < n_rdma; ++i, ++riu) { if (dir == SCST_DATA_READ) { wr.opcode = IB_WR_RDMA_WRITE; - wr.wr_id = encode_wr_id(IB_WC_RDMA_WRITE, + wr.wr_id = encode_wr_id(i == n_rdma - 1 ? + SRPT_RDMA_WRITE_LAST : + SRPT_RDMA_MID, ioctx->ioctx.index); } else { wr.opcode = IB_WR_RDMA_READ; - wr.wr_id = encode_wr_id(IB_WC_RDMA_READ, + wr.wr_id = encode_wr_id(i == n_rdma - 1 ? + SRPT_RDMA_READ_LAST : + SRPT_RDMA_MID, ioctx->ioctx.index); } wr.next = NULL; @@ -2928,12 +2944,33 @@ static int srpt_perform_rdmas(struct srpt_rdma_ch *ch, wr.sg_list = riu->sge; /* only get completion event for the last rdma wr */ - if (i == (ioctx->n_rdma - 1) && dir == SCST_DATA_WRITE) + if (i == (n_rdma - 1) && dir == SCST_DATA_WRITE) wr.send_flags = IB_SEND_SIGNALED; ret = ib_post_send(ch->qp, &wr, &bad_wr); if (ret) - goto out; + break; + } + + if (ret) + PRINT_ERROR("%s[%d]: ib_post_send() returned %d for %d/%d", + __func__, __LINE__, ret, i, n_rdma); + if (ret && i > 0) { + wr.num_sge = 0; + wr.wr_id = encode_wr_id(SRPT_RDMA_ABORT, ioctx->ioctx.index); + wr.send_flags = IB_SEND_SIGNALED; + while (ch->state == CH_LIVE && + ib_post_send(ch->qp, &wr, &bad_wr) != 0) { + PRINT_INFO("Trying to abort failed RDMA transfer [%d]", + ioctx->ioctx.index); + msleep(1000); + } + while (ch->state != CH_RELEASING && !ioctx->rdma_aborted) { + PRINT_INFO("Waiting until RDMA abort finished [%d]", + ioctx->ioctx.index); + msleep(1000); + } + PRINT_INFO("%s[%d]: done", __func__, __LINE__); } out: diff --git a/srpt/src/ib_srpt.h b/srpt/src/ib_srpt.h index 5240257ea..cb03ecbfd 100644 --- a/srpt/src/ib_srpt.h +++ b/srpt/src/ib_srpt.h @@ -132,9 +132,18 @@ enum { DEFAULT_MAX_RDMA_SIZE = 65536, }; -static inline u64 encode_wr_id(u8 opcode, u32 idx) +enum srpt_opcode { + SRPT_RECV, + SRPT_SEND, + SRPT_RDMA_MID, + SRPT_RDMA_ABORT, + SRPT_RDMA_READ_LAST, + SRPT_RDMA_WRITE_LAST, +}; + +static inline u64 encode_wr_id(enum srpt_opcode opcode, u32 idx) { return ((u64)opcode << 32) | idx; } -static inline u8 opcode_from_wr_id(u64 wr_id) +static inline enum srpt_opcode opcode_from_wr_id(u64 wr_id) { return wr_id >> 32; } static inline u32 idx_from_wr_id(u64 wr_id) { return (u32)wr_id; }