diff --git a/srpt/src/ib_srpt.c b/srpt/src/ib_srpt.c index 7b2db32f9..43a961251 100644 --- a/srpt/src/ib_srpt.c +++ b/srpt/src/ib_srpt.c @@ -696,6 +696,59 @@ static void srpt_free_ioctx_ring(struct srpt_device *sdev) } } +/** Atomically get the state of a command. */ +static enum srpt_command_state srpt_get_cmd_state(struct srpt_ioctx *ioctx) +{ + return atomic_read(&ioctx->state); +} + +/** + * Atomically set the state of a command. + * @new: New state to be set. + * + * Does not modify the state of aborted commands. + * + * Returns the previous command state. + */ +static enum srpt_command_state srpt_set_cmd_state(struct srpt_ioctx *ioctx, + enum srpt_command_state new) +{ + enum srpt_command_state previous; + + do { + previous = atomic_read(&ioctx->state); + } while (previous != SRPT_STATE_ABORTED + && atomic_cmpxchg(&ioctx->state, previous, new) != previous); + + return previous; +} + +/** + * Atomically test and set the state of a command. + * @expected: State to compare against. + * @new: New state to be set if the current state matches 'expected'. + * + * Returns the previous command state. + */ +static enum srpt_command_state +srpt_test_and_set_cmd_state(struct srpt_ioctx *ioctx, + enum srpt_command_state expected, + enum srpt_command_state new) +{ + enum srpt_command_state previous; + + WARN_ON(expected == SRPT_STATE_ABORTED); + WARN_ON(new == SRPT_STATE_NEW); + + do { + previous = atomic_read(&ioctx->state); + } while (previous != SRPT_STATE_ABORTED + && previous == expected + && atomic_cmpxchg(&ioctx->state, previous, new) != previous); + + return previous; +} + /* * Post a receive request on the work queue of InfiniBand device 'sdev'. */ @@ -914,17 +967,20 @@ static void srpt_abort_scst_cmd(struct srpt_device *sdev, struct srpt_ioctx *ioctx; scst_data_direction dir; struct srpt_rdma_ch *ch; + enum srpt_command_state previous_state; ioctx = scst_cmd_get_tgt_priv(scmnd); BUG_ON(!ioctx); dir = scst_cmd_get_data_direction(scmnd); - if (dir != SCST_DATA_NONE) + if (dir != SCST_DATA_NONE && scst_cmd_get_sg(scmnd)) ib_dma_unmap_sg(sdev->device, scst_cmd_get_sg(scmnd), scst_cmd_get_sg_cnt(scmnd), scst_to_tgt_dma_dir(dir)); - if (ioctx->state == SRPT_STATE_NEW) { + previous_state = srpt_set_cmd_state(ioctx, SRPT_STATE_ABORTED); + switch (previous_state) { + case SRPT_STATE_NEW: /* * Do not try to abort the SCST command here but wait until * the SCST core has called srpt_rdy_to_xfer() or @@ -940,23 +996,24 @@ static void srpt_abort_scst_cmd(struct srpt_device *sdev, list_del(&ioctx->scmnd_list); ch->active_scmnd_cnt--; spin_unlock_irq(&ch->spinlock); - } else if (ioctx->state == SRPT_STATE_NEED_DATA) { + break; + case SRPT_STATE_NEED_DATA: WARN_ON(scst_cmd_get_data_direction(ioctx->scmnd) == SCST_DATA_READ); scst_rx_data(scmnd, tell_initiator ? SCST_RX_STATUS_ERROR : SCST_RX_STATUS_ERROR_FATAL, SCST_CONTEXT_THREAD); - } else if (ioctx->state == SRPT_STATE_PROCESSED) { + break; + case SRPT_STATE_PROCESSED: scst_set_delivery_status(scmnd, SCST_CMD_DELIVERY_FAILED); WARN_ON(scmnd->state != SCST_CMD_STATE_XMIT_WAIT); scst_tgt_cmd_done(scmnd, scst_estimate_context()); - } else { - TRACE_DBG("Aborting cmd with state %d", ioctx->state); + break; + default: + TRACE_DBG("Aborting cmd with state %d", previous_state); WARN_ON("ERROR: unexpected command state"); } - - ioctx->state = SRPT_STATE_ABORTED; } static void srpt_handle_err_comp(struct srpt_rdma_ch *ch, struct ib_wc *wc) @@ -985,7 +1042,7 @@ static void srpt_handle_send_comp(struct srpt_rdma_ch *ch, scst_data_direction dir = scst_cmd_get_data_direction(ioctx->scmnd); - if (dir != SCST_DATA_NONE) + if (dir != SCST_DATA_NONE && scst_cmd_get_sg(ioctx->scmnd)) ib_dma_unmap_sg(ch->sport->sdev->device, scst_cmd_get_sg(ioctx->scmnd), scst_cmd_get_sg_cnt(ioctx->scmnd), @@ -1012,10 +1069,10 @@ static void srpt_handle_rdma_comp(struct srpt_rdma_ch *ch, * command, tell SCST that processing can continue by calling * scst_rx_data(). */ - if (ioctx->state == SRPT_STATE_NEED_DATA) { + if (srpt_test_and_set_cmd_state(ioctx, SRPT_STATE_NEED_DATA, + SRPT_STATE_DATA_IN) == SRPT_STATE_NEED_DATA) { WARN_ON(scst_cmd_get_data_direction(ioctx->scmnd) == SCST_DATA_READ); - ioctx->state = SRPT_STATE_DATA_IN; scst_rx_data(ioctx->scmnd, SCST_RX_STATUS_SUCCESS, scst_estimate_context()); } @@ -1365,7 +1422,7 @@ static void srpt_handle_new_iu(struct srpt_rdma_ch *ch, ioctx->rdma_ius = NULL; ioctx->scmnd = NULL; ioctx->ch = ch; - ioctx->state = SRPT_STATE_NEW; + srpt_set_cmd_state(ioctx, SRPT_STATE_NEW); srp_cmd = ioctx->buf; srp_rsp = ioctx->buf; @@ -2086,6 +2143,7 @@ static int srpt_map_sg_to_ib_sge(struct srpt_rdma_ch *ch, scat = scst_cmd_get_sg(scmnd); dir = scst_cmd_get_data_direction(scmnd); + WARN_ON(scat == 0); count = ib_dma_map_sg(ch->sport->sdev->device, scat, scst_cmd_get_sg_cnt(scmnd), scst_to_tgt_dma_dir(dir)); @@ -2101,6 +2159,7 @@ static int srpt_map_sg_to_ib_sge(struct srpt_rdma_ch *ch, scst_cmd_atomic(scmnd) ? GFP_ATOMIC : GFP_KERNEL); if (!ioctx->rdma_ius) { + WARN_ON(scat == 0); ib_dma_unmap_sg(ch->sport->sdev->device, scat, scst_cmd_get_sg_cnt(scmnd), scst_to_tgt_dma_dir(dir)); @@ -2238,6 +2297,7 @@ free_mem: kfree(ioctx->rdma_ius); + WARN_ON(scat == 0); ib_dma_unmap_sg(ch->sport->sdev->device, scat, scst_cmd_get_sg_cnt(scmnd), scst_to_tgt_dma_dir(dir)); @@ -2322,7 +2382,7 @@ static int srpt_rdy_to_xfer(struct scst_cmd *scmnd) ioctx = scst_cmd_get_tgt_priv(scmnd); BUG_ON(!ioctx); - if (ioctx->state == SRPT_STATE_ABORTED) + if (srpt_get_cmd_state(ioctx) == SRPT_STATE_ABORTED) return SCST_TGT_RES_FATAL_ERROR; ch = ioctx->ch; @@ -2334,7 +2394,7 @@ static int srpt_rdy_to_xfer(struct scst_cmd *scmnd) else if (ch->state == RDMA_CHANNEL_CONNECTING) return SCST_TGT_RES_QUEUE_FULL; - ioctx->state = SRPT_STATE_NEED_DATA; + srpt_set_cmd_state(ioctx, SRPT_STATE_NEED_DATA); return srpt_xfer_data(ch, ioctx, scmnd); } @@ -2356,7 +2416,7 @@ static int srpt_xmit_response(struct scst_cmd *scmnd) ioctx = scst_cmd_get_tgt_priv(scmnd); BUG_ON(!ioctx); - if (ioctx->state == SRPT_STATE_ABORTED) { + if (srpt_get_cmd_state(ioctx) == SRPT_STATE_ABORTED) { ret = SCST_TGT_RES_FATAL_ERROR; goto out; } @@ -2366,7 +2426,7 @@ static int srpt_xmit_response(struct scst_cmd *scmnd) tag = scst_cmd_get_tag(scmnd); - ioctx->state = SRPT_STATE_PROCESSED; + srpt_set_cmd_state(ioctx, SRPT_STATE_PROCESSED); if (ch->state != RDMA_CHANNEL_LIVE) { PRINT_ERROR("%s: tag= %lld channel in bad state %d", @@ -2443,7 +2503,7 @@ out: out_aborted: ret = SCST_TGT_RES_SUCCESS; scst_set_delivery_status(scmnd, SCST_CMD_DELIVERY_ABORTED); - ioctx->state = SRPT_STATE_ABORTED; + srpt_set_cmd_state(ioctx, SRPT_STATE_ABORTED); WARN_ON(scmnd->state != SCST_CMD_STATE_XMIT_WAIT); scst_tgt_cmd_done(scmnd, SCST_CONTEXT_SAME); goto out; @@ -2473,7 +2533,7 @@ static void srpt_tsk_mgmt_done(struct scst_mgmt_cmd *mcmnd) __func__, (unsigned long long)mgmt_ioctx->tag, scst_mgmt_cmd_get_status(mcmnd)); - ioctx->state = SRPT_STATE_PROCESSED; + srpt_set_cmd_state(ioctx, SRPT_STATE_PROCESSED); rsp_len = srpt_build_tskmgmt_rsp(ch, ioctx, (scst_mgmt_cmd_get_status(mcmnd) == diff --git a/srpt/src/ib_srpt.h b/srpt/src/ib_srpt.h index cdc4ce013..774a38429 100644 --- a/srpt/src/ib_srpt.h +++ b/srpt/src/ib_srpt.h @@ -142,7 +142,7 @@ struct srpt_ioctx { struct srpt_rdma_ch *ch; struct scst_cmd *scmnd; u64 data_len; - enum srpt_command_state state; + atomic_t state; /* enum srpt_command_state */ }; /* Additional context information for SCST management commands. */