diff --git a/CHANGELOG_PENDING.md b/CHANGELOG_PENDING.md index ad3579acb..39e3d2f1f 100644 --- a/CHANGELOG_PENDING.md +++ b/CHANGELOG_PENDING.md @@ -81,6 +81,7 @@ Friendly reminder, we have a [bug bounty program](https://hackerone.com/tendermi - `BlockStoreStateJSON` is now `BlockStoreState` and is encoded as binary in the database - [store] \#4778 Migrate store module to Protobuf encoding - [types] [\#4792](https://github.com/tendermint/tendermint/pull/4792) Sort validators by voting power to enable faster commit verification (@melekes) + - [mempool] Add RemoveTxByKey() exported function for custom mempool cleaning (@p4u) ### FEATURES: diff --git a/mempool/clist_mempool.go b/mempool/clist_mempool.go index abc696909..c4b63f8bb 100644 --- a/mempool/clist_mempool.go +++ b/mempool/clist_mempool.go @@ -20,6 +20,9 @@ import ( "github.com/tendermint/tendermint/types" ) +// TxKeySize is the size of the transaction key index +const TxKeySize = sha256.Size + //-------------------------------------------------------------------------------- // CListMempool is an ordered in-memory pool for transactions before they are @@ -255,7 +258,7 @@ func (mem *CListMempool) CheckTx(tx types.Tx, cb func(*abci.Response), txInfo Tx // Note it's possible a tx is still in the cache but no longer in the mempool // (eg. after committing a block, txs are removed from mempool but not cache), // so we only record the sender for txs still in the mempool. - if e, ok := mem.txsMap.Load(txKey(tx)); ok { + if e, ok := mem.txsMap.Load(TxKey(tx)); ok { memTx := e.(*clist.CElement).Value.(*mempoolTx) memTx.senders.LoadOrStore(txInfo.SenderID, true) // TODO: consider punishing peer for dups, @@ -351,7 +354,7 @@ func (mem *CListMempool) reqResCb( // - resCbFirstTime (lock not held) if tx is valid func (mem *CListMempool) addTx(memTx *mempoolTx) { e := mem.txs.PushBack(memTx) - mem.txsMap.Store(txKey(memTx.tx), e) + mem.txsMap.Store(TxKey(memTx.tx), e) atomic.AddInt64(&mem.txsBytes, int64(len(memTx.tx))) mem.metrics.TxSizeBytes.Observe(float64(len(memTx.tx))) } @@ -362,7 +365,7 @@ func (mem *CListMempool) addTx(memTx *mempoolTx) { func (mem *CListMempool) removeTx(tx types.Tx, elem *clist.CElement, removeFromCache bool) { mem.txs.Remove(elem) elem.DetachPrev() - mem.txsMap.Delete(txKey(tx)) + mem.txsMap.Delete(TxKey(tx)) atomic.AddInt64(&mem.txsBytes, int64(-len(tx))) if removeFromCache { @@ -370,6 +373,16 @@ func (mem *CListMempool) removeTx(tx types.Tx, elem *clist.CElement, removeFromC } } +// RemoveTxByKey removes a transaction from the mempool by its TxKey index. +func (mem *CListMempool) RemoveTxByKey(txKey [TxKeySize]byte, removeFromCache bool) { + if e, ok := mem.txsMap.Load(txKey); ok { + memTx := e.(*clist.CElement).Value.(*mempoolTx) + if memTx != nil { + mem.removeTx(memTx.tx, e.(*clist.CElement), removeFromCache) + } + } +} + func (mem *CListMempool) isFull(txSize int) error { var ( memSize = mem.Size() @@ -593,7 +606,7 @@ func (mem *CListMempool) Update( // Mempool after: // 100 // https://github.com/tendermint/tendermint/issues/3322. - if e, ok := mem.txsMap.Load(txKey(tx)); ok { + if e, ok := mem.txsMap.Load(TxKey(tx)); ok { mem.removeTx(tx, e.(*clist.CElement), false) } } @@ -670,7 +683,7 @@ type txCache interface { type mapTxCache struct { mtx sync.Mutex size int - cacheMap map[[sha256.Size]byte]*list.Element + cacheMap map[[TxKeySize]byte]*list.Element list *list.List } @@ -680,7 +693,7 @@ var _ txCache = (*mapTxCache)(nil) func newMapTxCache(cacheSize int) *mapTxCache { return &mapTxCache{ size: cacheSize, - cacheMap: make(map[[sha256.Size]byte]*list.Element, cacheSize), + cacheMap: make(map[[TxKeySize]byte]*list.Element, cacheSize), list: list.New(), } } @@ -688,7 +701,7 @@ func newMapTxCache(cacheSize int) *mapTxCache { // Reset resets the cache to an empty state. func (cache *mapTxCache) Reset() { cache.mtx.Lock() - cache.cacheMap = make(map[[sha256.Size]byte]*list.Element, cache.size) + cache.cacheMap = make(map[[TxKeySize]byte]*list.Element, cache.size) cache.list.Init() cache.mtx.Unlock() } @@ -700,7 +713,7 @@ func (cache *mapTxCache) Push(tx types.Tx) bool { defer cache.mtx.Unlock() // Use the tx hash in the cache - txHash := txKey(tx) + txHash := TxKey(tx) if moved, exists := cache.cacheMap[txHash]; exists { cache.list.MoveToBack(moved) return false @@ -709,7 +722,7 @@ func (cache *mapTxCache) Push(tx types.Tx) bool { if cache.list.Len() >= cache.size { popped := cache.list.Front() if popped != nil { - poppedTxHash := popped.Value.([sha256.Size]byte) + poppedTxHash := popped.Value.([TxKeySize]byte) delete(cache.cacheMap, poppedTxHash) cache.list.Remove(popped) } @@ -722,7 +735,7 @@ func (cache *mapTxCache) Push(tx types.Tx) bool { // Remove removes the given tx from the cache. func (cache *mapTxCache) Remove(tx types.Tx) { cache.mtx.Lock() - txHash := txKey(tx) + txHash := TxKey(tx) popped := cache.cacheMap[txHash] delete(cache.cacheMap, txHash) if popped != nil { @@ -742,8 +755,8 @@ func (nopTxCache) Remove(types.Tx) {} //-------------------------------------------------------------------------------- -// txKey is the fixed length array sha256 hash used as the key in maps. -func txKey(tx types.Tx) [sha256.Size]byte { +// TxKey is the fixed length array hash used as the key in maps. +func TxKey(tx types.Tx) [TxKeySize]byte { return sha256.Sum256(tx) } diff --git a/mempool/clist_mempool_test.go b/mempool/clist_mempool_test.go index 245fda53d..e5af51d7c 100644 --- a/mempool/clist_mempool_test.go +++ b/mempool/clist_mempool_test.go @@ -528,6 +528,16 @@ func TestMempoolTxsBytes(t *testing.T) { // Pretend like we committed nothing so txBytes gets rechecked and removed. mempool.Update(1, []types.Tx{}, abciResponses(0, abci.CodeTypeOK), nil, nil) assert.EqualValues(t, 0, mempool.TxsBytes()) + + // 7. Test RemoveTxByKey function + err = mempool.CheckTx([]byte{0x06}, nil, TxInfo{}) + require.NoError(t, err) + assert.EqualValues(t, 1, mempool.TxsBytes()) + mempool.RemoveTxByKey(TxKey([]byte{0x07}), true) + assert.EqualValues(t, 1, mempool.TxsBytes()) + mempool.RemoveTxByKey(TxKey([]byte{0x06}), true) + assert.EqualValues(t, 0, mempool.TxsBytes()) + } // This will non-deterministically catch some concurrency failures like