From 0bf67bf67ca9594b4539e3faf4f77f7d3afef625 Mon Sep 17 00:00:00 2001
From: travisladuke <travisladuke@gmail.com>
Date: Thu, 1 Feb 2024 13:46:05 -0800
Subject: [PATCH] Fix rules engine quirks

See #2200

Mostly makes Tag based rules work as expected
---
 node/Network.cpp | 71 ++++++++++++++++++++++++++++--------------------
 1 file changed, 42 insertions(+), 29 deletions(-)

diff --git a/node/Network.cpp b/node/Network.cpp
index 233c10641..0d61d2652 100644
--- a/node/Network.cpp
+++ b/node/Network.cpp
@@ -107,6 +107,7 @@ static _doZtFilterResult _doZtFilter(
 	// The default match state for each set of entries starts as 'true' since an
 	// ACTION with no MATCH entries preceding it is always taken.
 	uint8_t thisSetMatches = 1;
+	uint8_t skipDrop = 0;
 
 	rrl.clear();
 
@@ -121,11 +122,16 @@ static _doZtFilterResult _doZtFilter(
 						qosBucket = (rules[rn].v.qosBucket <= 8) ? rules[rn].v.qosBucket : 4; // 4 = default bucket (no priority)
 						return DOZTFILTER_ACCEPT;
 
-					case ZT_NETWORK_RULE_ACTION_DROP:
+					case ZT_NETWORK_RULE_ACTION_DROP: {
+						if (!!skipDrop) {
+							skipDrop = 0; continue;
+						}
 						return DOZTFILTER_DROP;
+					}
 
-					case ZT_NETWORK_RULE_ACTION_ACCEPT:
+					case ZT_NETWORK_RULE_ACTION_ACCEPT: {
 						return (superAccept ? DOZTFILTER_SUPER_ACCEPT : DOZTFILTER_ACCEPT); // match, accept packet
+					}
 
 					// These are initially handled together since preliminary logic is common
 					case ZT_NETWORK_RULE_ACTION_TEE:
@@ -192,6 +198,9 @@ static _doZtFilterResult _doZtFilter(
 		// If this was not an ACTION evaluate next MATCH and update thisSetMatches with (AND [result])
 		uint8_t thisRuleMatches = 0;
 		uint64_t ownershipVerificationMask = 1; // this magic value means it hasn't been computed yet -- this is done lazily the first time it's needed
+		uint8_t hardYes = (rules[rn].t >> 7) ^ 1; // XOR with the NOT bit of the rule
+		uint8_t hardNo = (rules[rn].t >> 7) ^ 0;
+
 		switch(rt) {
 			case ZT_NETWORK_RULE_MATCH_SOURCE_ZEROTIER_ADDRESS:
 				thisRuleMatches = (uint8_t)(rules[rn].v.zt == ztSource.toInt());
@@ -220,28 +229,28 @@ static _doZtFilterResult _doZtFilter(
 				if ((etherType == ZT_ETHERTYPE_IPV4)&&(frameLen >= 20)) {
 					thisRuleMatches = (uint8_t)(InetAddress((const void *)&(rules[rn].v.ipv4.ip),4,rules[rn].v.ipv4.mask).containsAddress(InetAddress((const void *)(frameData + 12),4,0)));
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_IPV4_DEST:
 				if ((etherType == ZT_ETHERTYPE_IPV4)&&(frameLen >= 20)) {
 					thisRuleMatches = (uint8_t)(InetAddress((const void *)&(rules[rn].v.ipv4.ip),4,rules[rn].v.ipv4.mask).containsAddress(InetAddress((const void *)(frameData + 16),4,0)));
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_IPV6_SOURCE:
 				if ((etherType == ZT_ETHERTYPE_IPV6)&&(frameLen >= 40)) {
 					thisRuleMatches = (uint8_t)(InetAddress((const void *)rules[rn].v.ipv6.ip,16,rules[rn].v.ipv6.mask).containsAddress(InetAddress((const void *)(frameData + 8),16,0)));
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_IPV6_DEST:
 				if ((etherType == ZT_ETHERTYPE_IPV6)&&(frameLen >= 40)) {
 					thisRuleMatches = (uint8_t)(InetAddress((const void *)rules[rn].v.ipv6.ip,16,rules[rn].v.ipv6.mask).containsAddress(InetAddress((const void *)(frameData + 24),16,0)));
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_IP_TOS:
@@ -252,7 +261,7 @@ static _doZtFilterResult _doZtFilter(
 					const uint8_t tosMasked = (((frameData[0] << 4) & 0xf0) | ((frameData[1] >> 4) & 0x0f)) & rules[rn].v.ipTos.mask;
 					thisRuleMatches = (uint8_t)((tosMasked >= rules[rn].v.ipTos.value[0])&&(tosMasked <= rules[rn].v.ipTos.value[1]));
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_IP_PROTOCOL:
@@ -263,10 +272,10 @@ static _doZtFilterResult _doZtFilter(
 					if (_ipv6GetPayload(frameData,frameLen,pos,proto)) {
 						thisRuleMatches = (uint8_t)(rules[rn].v.ipProtocol == (uint8_t)proto);
 					} else {
-						thisRuleMatches = 0;
+						thisRuleMatches = hardNo;
 					}
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_ETHERTYPE:
@@ -281,16 +290,16 @@ static _doZtFilterResult _doZtFilter(
 								if ((rules[rn].v.icmp.flags & 0x01) != 0) {
 									thisRuleMatches = (uint8_t)(frameData[ihl+1] == rules[rn].v.icmp.code);
 								} else {
-									thisRuleMatches = 1;
+									thisRuleMatches = hardYes;
 								}
 							} else {
-								thisRuleMatches = 0;
+								thisRuleMatches = hardNo;
 							}
 						} else {
-							thisRuleMatches = 0;
+							thisRuleMatches = hardNo;
 						}
 					} else {
-						thisRuleMatches = 0;
+						thisRuleMatches = hardNo;
 					}
 				} else if (etherType == ZT_ETHERTYPE_IPV6) {
 					unsigned int pos = 0,proto = 0;
@@ -300,19 +309,19 @@ static _doZtFilterResult _doZtFilter(
 								if ((rules[rn].v.icmp.flags & 0x01) != 0) {
 									thisRuleMatches = (uint8_t)(frameData[pos+1] == rules[rn].v.icmp.code);
 								} else {
-									thisRuleMatches = 1;
+									thisRuleMatches = hardYes;
 								}
 							} else {
-								thisRuleMatches = 0;
+								thisRuleMatches = hardNo;
 							}
 						} else {
-							thisRuleMatches = 0;
+							thisRuleMatches = hardNo;
 						}
 					} else {
-						thisRuleMatches = 0;
+						thisRuleMatches = hardNo;
 					}
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_IP_SOURCE_PORT_RANGE:
@@ -356,10 +365,10 @@ static _doZtFilterResult _doZtFilter(
 						}
 						thisRuleMatches = (p > 0) ? (uint8_t)((p >= (int)rules[rn].v.port[0])&&(p <= (int)rules[rn].v.port[1])) : (uint8_t)0;
 					} else {
-						thisRuleMatches = 0;
+						thisRuleMatches = hardNo;
 					}
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 				break;
 			case ZT_NETWORK_RULE_MATCH_CHARACTERISTICS: {
@@ -459,28 +468,32 @@ static _doZtFilterResult _doZtFilter(
 						} else if (rt == ZT_NETWORK_RULE_MATCH_TAGS_EQUAL) {
 							thisRuleMatches = (uint8_t)((ltv == rules[rn].v.tag.value)&&(rtv == rules[rn].v.tag.value));
 						} else { // sanity check, can't really happen
-							thisRuleMatches = 0;
+							thisRuleMatches = hardNo;
 						}
 					} else {
 						if ((inbound)&&(!superAccept)) {
-							thisRuleMatches = 0;
+							thisRuleMatches = hardNo;
 						} else {
 							// Outbound side is not strict since if we have to match both tags and
 							// we are sending a first packet to a recipient, we probably do not know
 							// about their tags yet. They will filter on inbound and we will filter
 							// once we get their tag. If we are a tee/redirect target we are also
 							// not strict since we likely do not have these tags.
-							thisRuleMatches = 1;
+							skipDrop = 1;
+							thisRuleMatches = hardYes;
 						}
 					}
 				} else {
-					thisRuleMatches = 0;
+					thisRuleMatches = hardNo;
 				}
 			}	break;
 			case ZT_NETWORK_RULE_MATCH_TAG_SENDER:
 			case ZT_NETWORK_RULE_MATCH_TAG_RECEIVER: {
+					const Tag *const remoteTag = ((membership) ? membership->getTag(nconf,rules[rn].v.tag.id) : (const Tag *)0);
+					const Tag *const localTag = std::lower_bound(&(nconf.tags[0]),&(nconf.tags[nconf.tagCount]),rules[rn].v.tag.id,Tag::IdComparePredicate());
 				if (superAccept) {
-					thisRuleMatches = 1;
+					skipDrop = 1;
+					thisRuleMatches = hardYes;
 				} else if ( ((rt == ZT_NETWORK_RULE_MATCH_TAG_SENDER)&&(inbound)) || ((rt == ZT_NETWORK_RULE_MATCH_TAG_RECEIVER)&&(!inbound)) ) {
 					const Tag *const remoteTag = ((membership) ? membership->getTag(nconf,rules[rn].v.tag.id) : (const Tag *)0);
 					if (remoteTag) {
@@ -489,17 +502,17 @@ static _doZtFilterResult _doZtFilter(
 						if (rt == ZT_NETWORK_RULE_MATCH_TAG_RECEIVER) {
 							// If we are checking the receiver and this is an outbound packet, we
 							// can't be strict since we may not yet know the receiver's tag.
-							thisRuleMatches = 1;
+							skipDrop = 1;
+							thisRuleMatches = hardYes;
 						} else {
-							thisRuleMatches = 0;
+							thisRuleMatches = hardNo;
 						}
 					}
 				} else { // sender and outbound or receiver and inbound
-					const Tag *const localTag = std::lower_bound(&(nconf.tags[0]),&(nconf.tags[nconf.tagCount]),rules[rn].v.tag.id,Tag::IdComparePredicate());
 					if ((localTag != &(nconf.tags[nconf.tagCount]))&&(localTag->id() == rules[rn].v.tag.id)) {
 						thisRuleMatches = (uint8_t)(localTag->value() == rules[rn].v.tag.value);
 					} else {
-						thisRuleMatches = 0;
+						thisRuleMatches = hardNo;
 					}
 				}
 			}	break;