diff --git a/src/fiber/channel.rs b/src/fiber/channel.rs index eeb602a9b..52e588dfa 100644 --- a/src/fiber/channel.rs +++ b/src/fiber/channel.rs @@ -73,9 +73,9 @@ use super::{ serde_utils::EntityHex, types::{ AcceptChannel, AddTlc, ChannelAnnouncement, ChannelReady, ClosingSigned, CommitmentSigned, - EcdsaSignature, FiberChannelMessage, FiberMessage, Hash256, OpenChannel, PaymentHopData, - PaymentOnionPacket, PeeledOnionPacket, Privkey, Pubkey, ReestablishChannel, RemoveTlc, - RemoveTlcFulfill, RemoveTlcReason, RevokeAndAck, TxCollaborationMsg, TxComplete, TxUpdate, + EcdsaSignature, FiberChannelMessage, FiberMessage, Hash256, OpenChannel, + PaymentOnionPacket, Privkey, Pubkey, ReestablishChannel, RemoveTlc, RemoveTlcFulfill, + RemoveTlcReason, RevokeAndAck, TxCollaborationMsg, TxComplete, TxUpdate, }, NetworkActorCommand, NetworkActorEvent, NetworkActorMessage, ASSUME_NETWORK_ACTOR_ALIVE, }; @@ -760,12 +760,12 @@ where } } - async fn check_add_tlc_info( + async fn apply_add_tlc_operation( &self, state: &mut ChannelActorState, add_tlc: &AddTlcInfo, ) -> Result<(), ProcessingChannelError> { - let mut update_invoice_payment_hash: Option = None; + state.check_tlc_expiry(add_tlc.expiry)?; let tlc = state .get_received_tlc(add_tlc.tlc_id.into()) @@ -807,7 +807,9 @@ where if invoice_status != CkbInvoiceStatus::Open { return Err(ProcessingChannelError::FinalInvoiceInvalid(invoice_status)); } - update_invoice_payment_hash = Some(payment_hash); + self.store + .update_invoice_status(&payment_hash, CkbInvoiceStatus::Received) + .expect("update invoice status failed"); } // if this is the last hop, store the preimage. @@ -827,8 +829,6 @@ where return Err(ProcessingChannelError::FinalIncorrectPaymentHash); } } else { - // if this is not the last hop, store the peeled packet. - state.set_received_tlc_peeled_packet(add_tlc.tlc_id.into(), peeled_onion_packet); assert!(received_amount >= forward_amount); let forward_fee = received_amount.saturating_sub(forward_amount); let fee_rate: u128 = state @@ -845,14 +845,16 @@ where ); return Err(ProcessingChannelError::TlcForwardFeeIsTooLow); } + // if this is not the last hop, forward TLC to next hop + self.handle_forward_onion_packet( + state, + peeled_onion_packet.clone(), + add_tlc.tlc_id.into(), + ) + .await?; } } - if let Some(payment_hash) = update_invoice_payment_hash { - self.store - .update_invoice_status(&payment_hash, CkbInvoiceStatus::Received) - .expect("update invoice status failed"); - } if let Some(ref udt_type_script) = state.funding_udt_type_script { self.subscribers .pending_received_tlcs_subscribers @@ -902,31 +904,6 @@ where Ok(()) } - async fn apply_add_tlc_operation( - &self, - state: &mut ChannelActorState, - add_tlc: &AddTlcInfo, - ) -> Result<(), ProcessingChannelError> { - state.check_tlc_expiry(add_tlc.expiry)?; - self.check_add_tlc_info(state, add_tlc).await?; - - // retrieve the tlc from the state since we updated onion packet and preimage - let tlc = state - .get_received_tlc(add_tlc.tlc_id.into()) - .expect("tlc exists"); - - if tlc.need_forward_next_hop() { - let peeled_onion_packet = tlc - .peeled_onion_packet - .as_ref() - .expect("peeled onion packet exists in tlcs for forwarding") - .clone(); - self.handle_forward_onion_packet(state, peeled_onion_packet, add_tlc.tlc_id.into()) - .await?; - } - Ok(()) - } - async fn apply_remove_tlc_operation( &self, state: &mut ChannelActorState, @@ -2062,14 +2039,6 @@ impl AddTlcInfo { .try_into() .expect("short hash from payment hash") } - - fn need_forward_next_hop(&self) -> bool { - !self.is_last_hop() - && self.peeled_onion_packet.is_some() - && self.previous_tlc.is_none() - && self.relay_status == TlcRelayStatus::WaitingForward - && self.payment_preimage.is_none() - } } #[derive(Clone, Debug, Serialize, Deserialize, Eq, PartialEq)] @@ -4030,20 +3999,6 @@ impl ChannelActorState { } } - pub(crate) fn set_received_tlc_peeled_packet( - &mut self, - tlc_id: u64, - peeled_packet: &PeeledOnionPacket, - ) { - if let Some(tlc) = self.tlc_state.get_mut(&TLCId::Received(tlc_id)) { - if !peeled_packet.is_last() { - tlc.relay_status = TlcRelayStatus::WaitingForward; - } - - tlc.peeled_onion_packet = Some(peeled_packet.clone()); - } - } - pub fn check_insert_tlc(&mut self, tlc: &AddTlcInfo) -> Result<(), ProcessingChannelError> { let payment_hash = tlc.payment_hash; if let Some(tlc) = self