mirror of
https://github.com/amnezia-vpn/amneziawg-go.git
synced 2025-04-25 10:27:44 +02:00
Compare commits
No commits in common. "master" and "0.0.20180514" have entirely different histories.
master
...
0.0.201805
143 changed files with 7109 additions and 16146 deletions
41
.github/workflows/build-if-tag.yml
vendored
41
.github/workflows/build-if-tag.yml
vendored
|
@ -1,41 +0,0 @@
|
||||||
name: build-if-tag
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- 'v[0-9]+.[0-9]+.[0-9]+'
|
|
||||||
|
|
||||||
env:
|
|
||||||
APP: amneziawg-go
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
name: build
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
with:
|
|
||||||
ref: ${{ github.ref_name }}
|
|
||||||
|
|
||||||
- name: Login to Docker Hub
|
|
||||||
uses: docker/login-action@v3
|
|
||||||
with:
|
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
|
||||||
|
|
||||||
- name: Setup metadata
|
|
||||||
uses: docker/metadata-action@v5
|
|
||||||
id: metadata
|
|
||||||
with:
|
|
||||||
images: amneziavpn/${{ env.APP }}
|
|
||||||
tags: type=semver,pattern={{version}}
|
|
||||||
|
|
||||||
- name: Set up Docker Buildx
|
|
||||||
uses: docker/setup-buildx-action@v3
|
|
||||||
|
|
||||||
- name: Build
|
|
||||||
uses: docker/build-push-action@v5
|
|
||||||
with:
|
|
||||||
push: true
|
|
||||||
tags: ${{ steps.metadata.outputs.tags }}
|
|
3
.gitignore
vendored
3
.gitignore
vendored
|
@ -1 +1,2 @@
|
||||||
amneziawg-go
|
wireguard-go
|
||||||
|
vendor
|
||||||
|
|
338
COPYING
Normal file
338
COPYING
Normal file
|
@ -0,0 +1,338 @@
|
||||||
|
GNU GENERAL PUBLIC LICENSE
|
||||||
|
Version 2, June 1991
|
||||||
|
|
||||||
|
Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
|
||||||
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
Everyone is permitted to copy and distribute verbatim copies
|
||||||
|
of this license document, but changing it is not allowed.
|
||||||
|
|
||||||
|
Preamble
|
||||||
|
|
||||||
|
The licenses for most software are designed to take away your
|
||||||
|
freedom to share and change it. By contrast, the GNU General Public
|
||||||
|
License is intended to guarantee your freedom to share and change free
|
||||||
|
software--to make sure the software is free for all its users. This
|
||||||
|
General Public License applies to most of the Free Software
|
||||||
|
Foundation's software and to any other program whose authors commit to
|
||||||
|
using it. (Some other Free Software Foundation software is covered by
|
||||||
|
the GNU Lesser General Public License instead.) You can apply it to
|
||||||
|
your programs, too.
|
||||||
|
|
||||||
|
When we speak of free software, we are referring to freedom, not
|
||||||
|
price. Our General Public Licenses are designed to make sure that you
|
||||||
|
have the freedom to distribute copies of free software (and charge for
|
||||||
|
this service if you wish), that you receive source code or can get it
|
||||||
|
if you want it, that you can change the software or use pieces of it
|
||||||
|
in new free programs; and that you know you can do these things.
|
||||||
|
|
||||||
|
To protect your rights, we need to make restrictions that forbid
|
||||||
|
anyone to deny you these rights or to ask you to surrender the rights.
|
||||||
|
These restrictions translate to certain responsibilities for you if you
|
||||||
|
distribute copies of the software, or if you modify it.
|
||||||
|
|
||||||
|
For example, if you distribute copies of such a program, whether
|
||||||
|
gratis or for a fee, you must give the recipients all the rights that
|
||||||
|
you have. You must make sure that they, too, receive or can get the
|
||||||
|
source code. And you must show them these terms so they know their
|
||||||
|
rights.
|
||||||
|
|
||||||
|
We protect your rights with two steps: (1) copyright the software, and
|
||||||
|
(2) offer you this license which gives you legal permission to copy,
|
||||||
|
distribute and/or modify the software.
|
||||||
|
|
||||||
|
Also, for each author's protection and ours, we want to make certain
|
||||||
|
that everyone understands that there is no warranty for this free
|
||||||
|
software. If the software is modified by someone else and passed on, we
|
||||||
|
want its recipients to know that what they have is not the original, so
|
||||||
|
that any problems introduced by others will not reflect on the original
|
||||||
|
authors' reputations.
|
||||||
|
|
||||||
|
Finally, any free program is threatened constantly by software
|
||||||
|
patents. We wish to avoid the danger that redistributors of a free
|
||||||
|
program will individually obtain patent licenses, in effect making the
|
||||||
|
program proprietary. To prevent this, we have made it clear that any
|
||||||
|
patent must be licensed for everyone's free use or not licensed at all.
|
||||||
|
|
||||||
|
The precise terms and conditions for copying, distribution and
|
||||||
|
modification follow.
|
||||||
|
|
||||||
|
GNU GENERAL PUBLIC LICENSE
|
||||||
|
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
|
||||||
|
|
||||||
|
0. This License applies to any program or other work which contains
|
||||||
|
a notice placed by the copyright holder saying it may be distributed
|
||||||
|
under the terms of this General Public License. The "Program", below,
|
||||||
|
refers to any such program or work, and a "work based on the Program"
|
||||||
|
means either the Program or any derivative work under copyright law:
|
||||||
|
that is to say, a work containing the Program or a portion of it,
|
||||||
|
either verbatim or with modifications and/or translated into another
|
||||||
|
language. (Hereinafter, translation is included without limitation in
|
||||||
|
the term "modification".) Each licensee is addressed as "you".
|
||||||
|
|
||||||
|
Activities other than copying, distribution and modification are not
|
||||||
|
covered by this License; they are outside its scope. The act of
|
||||||
|
running the Program is not restricted, and the output from the Program
|
||||||
|
is covered only if its contents constitute a work based on the
|
||||||
|
Program (independent of having been made by running the Program).
|
||||||
|
Whether that is true depends on what the Program does.
|
||||||
|
|
||||||
|
1. You may copy and distribute verbatim copies of the Program's
|
||||||
|
source code as you receive it, in any medium, provided that you
|
||||||
|
conspicuously and appropriately publish on each copy an appropriate
|
||||||
|
copyright notice and disclaimer of warranty; keep intact all the
|
||||||
|
notices that refer to this License and to the absence of any warranty;
|
||||||
|
and give any other recipients of the Program a copy of this License
|
||||||
|
along with the Program.
|
||||||
|
|
||||||
|
You may charge a fee for the physical act of transferring a copy, and
|
||||||
|
you may at your option offer warranty protection in exchange for a fee.
|
||||||
|
|
||||||
|
2. You may modify your copy or copies of the Program or any portion
|
||||||
|
of it, thus forming a work based on the Program, and copy and
|
||||||
|
distribute such modifications or work under the terms of Section 1
|
||||||
|
above, provided that you also meet all of these conditions:
|
||||||
|
|
||||||
|
a) You must cause the modified files to carry prominent notices
|
||||||
|
stating that you changed the files and the date of any change.
|
||||||
|
|
||||||
|
b) You must cause any work that you distribute or publish, that in
|
||||||
|
whole or in part contains or is derived from the Program or any
|
||||||
|
part thereof, to be licensed as a whole at no charge to all third
|
||||||
|
parties under the terms of this License.
|
||||||
|
|
||||||
|
c) If the modified program normally reads commands interactively
|
||||||
|
when run, you must cause it, when started running for such
|
||||||
|
interactive use in the most ordinary way, to print or display an
|
||||||
|
announcement including an appropriate copyright notice and a
|
||||||
|
notice that there is no warranty (or else, saying that you provide
|
||||||
|
a warranty) and that users may redistribute the program under
|
||||||
|
these conditions, and telling the user how to view a copy of this
|
||||||
|
License. (Exception: if the Program itself is interactive but
|
||||||
|
does not normally print such an announcement, your work based on
|
||||||
|
the Program is not required to print an announcement.)
|
||||||
|
|
||||||
|
These requirements apply to the modified work as a whole. If
|
||||||
|
identifiable sections of that work are not derived from the Program,
|
||||||
|
and can be reasonably considered independent and separate works in
|
||||||
|
themselves, then this License, and its terms, do not apply to those
|
||||||
|
sections when you distribute them as separate works. But when you
|
||||||
|
distribute the same sections as part of a whole which is a work based
|
||||||
|
on the Program, the distribution of the whole must be on the terms of
|
||||||
|
this License, whose permissions for other licensees extend to the
|
||||||
|
entire whole, and thus to each and every part regardless of who wrote it.
|
||||||
|
|
||||||
|
Thus, it is not the intent of this section to claim rights or contest
|
||||||
|
your rights to work written entirely by you; rather, the intent is to
|
||||||
|
exercise the right to control the distribution of derivative or
|
||||||
|
collective works based on the Program.
|
||||||
|
|
||||||
|
In addition, mere aggregation of another work not based on the Program
|
||||||
|
with the Program (or with a work based on the Program) on a volume of
|
||||||
|
a storage or distribution medium does not bring the other work under
|
||||||
|
the scope of this License.
|
||||||
|
|
||||||
|
3. You may copy and distribute the Program (or a work based on it,
|
||||||
|
under Section 2) in object code or executable form under the terms of
|
||||||
|
Sections 1 and 2 above provided that you also do one of the following:
|
||||||
|
|
||||||
|
a) Accompany it with the complete corresponding machine-readable
|
||||||
|
source code, which must be distributed under the terms of Sections
|
||||||
|
1 and 2 above on a medium customarily used for software interchange; or,
|
||||||
|
|
||||||
|
b) Accompany it with a written offer, valid for at least three
|
||||||
|
years, to give any third party, for a charge no more than your
|
||||||
|
cost of physically performing source distribution, a complete
|
||||||
|
machine-readable copy of the corresponding source code, to be
|
||||||
|
distributed under the terms of Sections 1 and 2 above on a medium
|
||||||
|
customarily used for software interchange; or,
|
||||||
|
|
||||||
|
c) Accompany it with the information you received as to the offer
|
||||||
|
to distribute corresponding source code. (This alternative is
|
||||||
|
allowed only for noncommercial distribution and only if you
|
||||||
|
received the program in object code or executable form with such
|
||||||
|
an offer, in accord with Subsection b above.)
|
||||||
|
|
||||||
|
The source code for a work means the preferred form of the work for
|
||||||
|
making modifications to it. For an executable work, complete source
|
||||||
|
code means all the source code for all modules it contains, plus any
|
||||||
|
associated interface definition files, plus the scripts used to
|
||||||
|
control compilation and installation of the executable. However, as a
|
||||||
|
special exception, the source code distributed need not include
|
||||||
|
anything that is normally distributed (in either source or binary
|
||||||
|
form) with the major components (compiler, kernel, and so on) of the
|
||||||
|
operating system on which the executable runs, unless that component
|
||||||
|
itself accompanies the executable.
|
||||||
|
|
||||||
|
If distribution of executable or object code is made by offering
|
||||||
|
access to copy from a designated place, then offering equivalent
|
||||||
|
access to copy the source code from the same place counts as
|
||||||
|
distribution of the source code, even though third parties are not
|
||||||
|
compelled to copy the source along with the object code.
|
||||||
|
|
||||||
|
4. You may not copy, modify, sublicense, or distribute the Program
|
||||||
|
except as expressly provided under this License. Any attempt
|
||||||
|
otherwise to copy, modify, sublicense or distribute the Program is
|
||||||
|
void, and will automatically terminate your rights under this License.
|
||||||
|
However, parties who have received copies, or rights, from you under
|
||||||
|
this License will not have their licenses terminated so long as such
|
||||||
|
parties remain in full compliance.
|
||||||
|
|
||||||
|
5. You are not required to accept this License, since you have not
|
||||||
|
signed it. However, nothing else grants you permission to modify or
|
||||||
|
distribute the Program or its derivative works. These actions are
|
||||||
|
prohibited by law if you do not accept this License. Therefore, by
|
||||||
|
modifying or distributing the Program (or any work based on the
|
||||||
|
Program), you indicate your acceptance of this License to do so, and
|
||||||
|
all its terms and conditions for copying, distributing or modifying
|
||||||
|
the Program or works based on it.
|
||||||
|
|
||||||
|
6. Each time you redistribute the Program (or any work based on the
|
||||||
|
Program), the recipient automatically receives a license from the
|
||||||
|
original licensor to copy, distribute or modify the Program subject to
|
||||||
|
these terms and conditions. You may not impose any further
|
||||||
|
restrictions on the recipients' exercise of the rights granted herein.
|
||||||
|
You are not responsible for enforcing compliance by third parties to
|
||||||
|
this License.
|
||||||
|
|
||||||
|
7. If, as a consequence of a court judgment or allegation of patent
|
||||||
|
infringement or for any other reason (not limited to patent issues),
|
||||||
|
conditions are imposed on you (whether by court order, agreement or
|
||||||
|
otherwise) that contradict the conditions of this License, they do not
|
||||||
|
excuse you from the conditions of this License. If you cannot
|
||||||
|
distribute so as to satisfy simultaneously your obligations under this
|
||||||
|
License and any other pertinent obligations, then as a consequence you
|
||||||
|
may not distribute the Program at all. For example, if a patent
|
||||||
|
license would not permit royalty-free redistribution of the Program by
|
||||||
|
all those who receive copies directly or indirectly through you, then
|
||||||
|
the only way you could satisfy both it and this License would be to
|
||||||
|
refrain entirely from distribution of the Program.
|
||||||
|
|
||||||
|
If any portion of this section is held invalid or unenforceable under
|
||||||
|
any particular circumstance, the balance of the section is intended to
|
||||||
|
apply and the section as a whole is intended to apply in other
|
||||||
|
circumstances.
|
||||||
|
|
||||||
|
It is not the purpose of this section to induce you to infringe any
|
||||||
|
patents or other property right claims or to contest validity of any
|
||||||
|
such claims; this section has the sole purpose of protecting the
|
||||||
|
integrity of the free software distribution system, which is
|
||||||
|
implemented by public license practices. Many people have made
|
||||||
|
generous contributions to the wide range of software distributed
|
||||||
|
through that system in reliance on consistent application of that
|
||||||
|
system; it is up to the author/donor to decide if he or she is willing
|
||||||
|
to distribute software through any other system and a licensee cannot
|
||||||
|
impose that choice.
|
||||||
|
|
||||||
|
This section is intended to make thoroughly clear what is believed to
|
||||||
|
be a consequence of the rest of this License.
|
||||||
|
|
||||||
|
8. If the distribution and/or use of the Program is restricted in
|
||||||
|
certain countries either by patents or by copyrighted interfaces, the
|
||||||
|
original copyright holder who places the Program under this License
|
||||||
|
may add an explicit geographical distribution limitation excluding
|
||||||
|
those countries, so that distribution is permitted only in or among
|
||||||
|
countries not thus excluded. In such case, this License incorporates
|
||||||
|
the limitation as if written in the body of this License.
|
||||||
|
|
||||||
|
9. The Free Software Foundation may publish revised and/or new versions
|
||||||
|
of the General Public License from time to time. Such new versions will
|
||||||
|
be similar in spirit to the present version, but may differ in detail to
|
||||||
|
address new problems or concerns.
|
||||||
|
|
||||||
|
Each version is given a distinguishing version number. If the Program
|
||||||
|
specifies a version number of this License which applies to it and "any
|
||||||
|
later version", you have the option of following the terms and conditions
|
||||||
|
either of that version or of any later version published by the Free
|
||||||
|
Software Foundation. If the Program does not specify a version number of
|
||||||
|
this License, you may choose any version ever published by the Free Software
|
||||||
|
Foundation.
|
||||||
|
|
||||||
|
10. If you wish to incorporate parts of the Program into other free
|
||||||
|
programs whose distribution conditions are different, write to the author
|
||||||
|
to ask for permission. For software which is copyrighted by the Free
|
||||||
|
Software Foundation, write to the Free Software Foundation; we sometimes
|
||||||
|
make exceptions for this. Our decision will be guided by the two goals
|
||||||
|
of preserving the free status of all derivatives of our free software and
|
||||||
|
of promoting the sharing and reuse of software generally.
|
||||||
|
|
||||||
|
NO WARRANTY
|
||||||
|
|
||||||
|
11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
|
||||||
|
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
|
||||||
|
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
|
||||||
|
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
|
||||||
|
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||||
|
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
|
||||||
|
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
|
||||||
|
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
|
||||||
|
REPAIR OR CORRECTION.
|
||||||
|
|
||||||
|
12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||||
|
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
|
||||||
|
REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
|
||||||
|
INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
|
||||||
|
OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
|
||||||
|
TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
|
||||||
|
YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
|
||||||
|
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
|
||||||
|
POSSIBILITY OF SUCH DAMAGES.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
How to Apply These Terms to Your New Programs
|
||||||
|
|
||||||
|
If you develop a new program, and you want it to be of the greatest
|
||||||
|
possible use to the public, the best way to achieve this is to make it
|
||||||
|
free software which everyone can redistribute and change under these terms.
|
||||||
|
|
||||||
|
To do so, attach the following notices to the program. It is safest
|
||||||
|
to attach them to the start of each source file to most effectively
|
||||||
|
convey the exclusion of warranty; and each file should have at least
|
||||||
|
the "copyright" line and a pointer to where the full notice is found.
|
||||||
|
|
||||||
|
<one line to give the program's name and a brief idea of what it does.>
|
||||||
|
Copyright (C) <year> <name of author>
|
||||||
|
|
||||||
|
This program is free software; you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License version 2
|
||||||
|
as published by the Free Software Foundation.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License along
|
||||||
|
with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
|
||||||
|
Also add information on how to contact you by electronic and paper mail.
|
||||||
|
|
||||||
|
If the program is interactive, make it output a short notice like this
|
||||||
|
when it starts in an interactive mode:
|
||||||
|
|
||||||
|
Gnomovision version 69, Copyright (C) year name of author
|
||||||
|
Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||||
|
This is free software, and you are welcome to redistribute it
|
||||||
|
under certain conditions; type `show c' for details.
|
||||||
|
|
||||||
|
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||||
|
parts of the General Public License. Of course, the commands you use may
|
||||||
|
be called something other than `show w' and `show c'; they could even be
|
||||||
|
mouse-clicks or menu items--whatever suits your program.
|
||||||
|
|
||||||
|
You should also get your employer (if you work as a programmer) or your
|
||||||
|
school, if any, to sign a "copyright disclaimer" for the program, if
|
||||||
|
necessary. Here is a sample; alter the names:
|
||||||
|
|
||||||
|
Yoyodyne, Inc., hereby disclaims all copyright interest in the program
|
||||||
|
`Gnomovision' (which makes passes at compilers) written by James Hacker.
|
||||||
|
|
||||||
|
<signature of Ty Coon>, 1 April 1989
|
||||||
|
Ty Coon, President of Vice
|
||||||
|
|
||||||
|
This General Public License does not permit incorporating your program into
|
||||||
|
proprietary programs. If your program is a subroutine library, you may
|
||||||
|
consider it more useful to permit linking proprietary applications with the
|
||||||
|
library. If this is what you want to do, use the GNU Lesser General
|
||||||
|
Public License instead of this License.
|
17
Dockerfile
17
Dockerfile
|
@ -1,17 +0,0 @@
|
||||||
FROM golang:1.24 as awg
|
|
||||||
COPY . /awg
|
|
||||||
WORKDIR /awg
|
|
||||||
RUN go mod download && \
|
|
||||||
go mod verify && \
|
|
||||||
go build -ldflags '-linkmode external -extldflags "-fno-PIC -static"' -v -o /usr/bin
|
|
||||||
|
|
||||||
FROM alpine:3.19
|
|
||||||
ARG AWGTOOLS_RELEASE="1.0.20241018"
|
|
||||||
RUN apk --no-cache add iproute2 iptables bash && \
|
|
||||||
cd /usr/bin/ && \
|
|
||||||
wget https://github.com/amnezia-vpn/amneziawg-tools/releases/download/v${AWGTOOLS_RELEASE}/alpine-3.19-amneziawg-tools.zip && \
|
|
||||||
unzip -j alpine-3.19-amneziawg-tools.zip && \
|
|
||||||
chmod +x /usr/bin/awg /usr/bin/awg-quick && \
|
|
||||||
ln -s /usr/bin/awg /usr/bin/wg && \
|
|
||||||
ln -s /usr/bin/awg-quick /usr/bin/wg-quick
|
|
||||||
COPY --from=awg /usr/bin/amneziawg-go /usr/bin/amneziawg-go
|
|
16
Gopkg.lock
generated
Normal file
16
Gopkg.lock
generated
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
# This was generated by ./generate-vendor.sh
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "golang.org/x/crypto"
|
||||||
|
revision = "1a580b3eff7814fc9b40602fd35256c63b50f491"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "golang.org/x/net"
|
||||||
|
revision = "2491c5de3490fced2f6cff376127c667efeed857"
|
||||||
|
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "golang.org/x/sys"
|
||||||
|
revision = "7c87d13f8e835d2fb3a70a2912c811ed0c1d241b"
|
||||||
|
|
13
Gopkg.toml
Normal file
13
Gopkg.toml
Normal file
|
@ -0,0 +1,13 @@
|
||||||
|
# This was generated by ./generate-vendor.sh
|
||||||
|
[[constraint]]
|
||||||
|
branch = "master"
|
||||||
|
name = "golang.org/x/crypto"
|
||||||
|
|
||||||
|
[[constraint]]
|
||||||
|
branch = "master"
|
||||||
|
name = "golang.org/x/net"
|
||||||
|
|
||||||
|
[[constraint]]
|
||||||
|
branch = "master"
|
||||||
|
name = "golang.org/x/sys"
|
||||||
|
|
17
LICENSE
17
LICENSE
|
@ -1,17 +0,0 @@
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
||||||
this software and associated documentation files (the "Software"), to deal in
|
|
||||||
the Software without restriction, including without limitation the rights to
|
|
||||||
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
|
||||||
of the Software, and to permit persons to whom the Software is furnished to do
|
|
||||||
so, subject to the following conditions:
|
|
||||||
|
|
||||||
The above copyright notice and this permission notice shall be included in all
|
|
||||||
copies or substantial portions of the Software.
|
|
||||||
|
|
||||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
||||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
||||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
||||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
||||||
SOFTWARE.
|
|
29
Makefile
29
Makefile
|
@ -1,31 +1,16 @@
|
||||||
PREFIX ?= /usr
|
PREFIX ?= /usr
|
||||||
DESTDIR ?=
|
DESTDIR ?=
|
||||||
BINDIR ?= $(PREFIX)/bin
|
BINDIR ?= $(PREFIX)/bin
|
||||||
export GO111MODULE := on
|
|
||||||
|
|
||||||
all: generate-version-and-build
|
all: wireguard-go
|
||||||
|
|
||||||
MAKEFLAGS += --no-print-directory
|
wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
||||||
|
go build -v -o $@
|
||||||
|
|
||||||
generate-version-and-build:
|
install: wireguard-go
|
||||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
@install -v -d "$(DESTDIR)$(BINDIR)" && install -m 0755 -v wireguard-go "$(DESTDIR)$(BINDIR)/wireguard-go"
|
||||||
tag="$$(git describe --tags --dirty 2>/dev/null)" && \
|
|
||||||
ver="$$(printf 'package main\n\nconst Version = "%s"\n' "$$tag")" && \
|
|
||||||
[ "$$(cat version.go 2>/dev/null)" != "$$ver" ] && \
|
|
||||||
echo "$$ver" > version.go && \
|
|
||||||
git update-index --assume-unchanged version.go || true
|
|
||||||
@$(MAKE) amneziawg-go
|
|
||||||
|
|
||||||
amneziawg-go: $(wildcard *.go) $(wildcard */*.go)
|
|
||||||
go build -v -o "$@"
|
|
||||||
|
|
||||||
install: amneziawg-go
|
|
||||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/amneziawg-go"
|
|
||||||
|
|
||||||
test:
|
|
||||||
go test ./...
|
|
||||||
|
|
||||||
clean:
|
clean:
|
||||||
rm -f amneziawg-go
|
rm -f wireguard-go
|
||||||
|
|
||||||
.PHONY: all clean test install generate-version-and-build
|
.PHONY: clean install
|
||||||
|
|
68
README.md
68
README.md
|
@ -1,52 +1,36 @@
|
||||||
# Go Implementation of AmneziaWG
|
### Do not use this Go code.
|
||||||
|
|
||||||
AmneziaWG is a contemporary version of the WireGuard protocol. It's a fork of WireGuard-Go and offers protection against detection by Deep Packet Inspection (DPI) systems. At the same time, it retains the simplified architecture and high performance of the original.
|
This is not a complete implementation of WireGuard. If you're interested in using WireGuard, use the implementation for Linux [found here](https://git.zx2c4.com/WireGuard/) and described on the [main wireguard website](https://www.wireguard.io/). There is no group of users that should be using the code in this repository here under any circumstances at the moment, not even beta testers or dare devils. It simply isn't complete. However, if you're interested in assisting with the Go development of WireGuard and contributing to this repository, by all means dig in and help out. But users: stay far away, at least for now.
|
||||||
|
|
||||||
The precursor, WireGuard, is known for its efficiency but had issues with detection due to its distinctive packet signatures.
|
-------
|
||||||
AmneziaWG addresses this problem by employing advanced obfuscation methods, allowing its traffic to blend seamlessly with regular internet traffic.
|
|
||||||
As a result, AmneziaWG maintains high performance while adding an extra layer of stealth, making it a superb choice for those seeking a fast and discreet VPN connection.
|
|
||||||
|
|
||||||
## Usage
|
# Go Implementation of WireGuard
|
||||||
|
|
||||||
Simply run:
|
This is a work in progress for implementing WireGuard in Go.
|
||||||
|
|
||||||
```
|
## License
|
||||||
$ amneziawg-go wg0
|
|
||||||
```
|
|
||||||
|
|
||||||
This will create an interface and fork into the background. To remove the interface, use the usual `ip link del wg0`, or if your system does not support removing interfaces directly, you may instead remove the control socket via `rm -f /var/run/amneziawg/wg0.sock`, which will result in amneziawg-go shutting down.
|
This program is free software; you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License version 2 as
|
||||||
|
published by the Free Software Foundation.
|
||||||
|
|
||||||
To run amneziawg-go without forking to the background, pass `-f` or `--foreground`:
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
```
|
You should have received a copy of the GNU General Public License along
|
||||||
$ amneziawg-go -f wg0
|
with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
```
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
When an interface is running, you may use [`amneziawg-tools `](https://github.com/amnezia-vpn/amneziawg-tools) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
|
||||||
|
|
||||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
---------------------------------------------------------------------------
|
||||||
|
Additional Permissions For Submission to Apple App Store: Provided that you
|
||||||
|
are otherwise in compliance with the GPLv2 for each covered work you convey
|
||||||
|
(including without limitation making the Corresponding Source available in
|
||||||
|
compliance with Section 3 of the GPLv2), you are granted the additional
|
||||||
|
the additional permission to convey through the Apple App Store
|
||||||
|
non-source executable versions of the Program as incorporated into each
|
||||||
|
applicable covered work as Executable Versions only under the Mozilla
|
||||||
|
Public License version 2.0 (https://www.mozilla.org/en-US/MPL/2.0/).
|
||||||
|
|
||||||
|
|
||||||
## Platforms
|
|
||||||
|
|
||||||
### Linux
|
|
||||||
|
|
||||||
This will run on Linux; you should run amnezia-wg instead of using default linux kernel module.
|
|
||||||
|
|
||||||
### macOS
|
|
||||||
|
|
||||||
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
|
|
||||||
This runs on MacOS, you should use it from [amneziawg-apple](https://github.com/amnezia-vpn/amneziawg-apple)
|
|
||||||
|
|
||||||
### Windows
|
|
||||||
|
|
||||||
This runs on Windows, you should use it from [amneziawg-windows](https://github.com/amnezia-vpn/amneziawg-windows), which uses this as a module.
|
|
||||||
|
|
||||||
|
|
||||||
## Building
|
|
||||||
|
|
||||||
This requires an installation of the latest version of [Go](https://go.dev/).
|
|
||||||
|
|
||||||
```
|
|
||||||
$ git clone https://github.com/amnezia-vpn/amneziawg-go
|
|
||||||
$ cd amneziawg-go
|
|
||||||
$ make
|
|
||||||
```
|
|
||||||
|
|
251
allowedips.go
Normal file
251
allowedips.go
Normal file
|
@ -0,0 +1,251 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"math/bits"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type trieEntry struct {
|
||||||
|
cidr uint
|
||||||
|
child [2]*trieEntry
|
||||||
|
bits net.IP
|
||||||
|
peer *Peer
|
||||||
|
|
||||||
|
// index of "branching" bit
|
||||||
|
|
||||||
|
bit_at_byte uint
|
||||||
|
bit_at_shift uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func isLittleEndian() bool {
|
||||||
|
one := uint32(1)
|
||||||
|
return *(*byte)(unsafe.Pointer(&one)) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func swapU32(i uint32) uint32 {
|
||||||
|
if !isLittleEndian() {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
return bits.ReverseBytes32(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func swapU64(i uint64) uint64 {
|
||||||
|
if !isLittleEndian() {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
return bits.ReverseBytes64(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
func commonBits(ip1 net.IP, ip2 net.IP) uint {
|
||||||
|
size := len(ip1)
|
||||||
|
if size == net.IPv4len {
|
||||||
|
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
||||||
|
b := (*uint32)(unsafe.Pointer(&ip2[0]))
|
||||||
|
x := *a ^ *b
|
||||||
|
return uint(bits.LeadingZeros32(swapU32(x)))
|
||||||
|
} else if size == net.IPv6len {
|
||||||
|
a := (*uint64)(unsafe.Pointer(&ip1[0]))
|
||||||
|
b := (*uint64)(unsafe.Pointer(&ip2[0]))
|
||||||
|
x := *a ^ *b
|
||||||
|
if x != 0 {
|
||||||
|
return uint(bits.LeadingZeros64(swapU64(x)))
|
||||||
|
}
|
||||||
|
a = (*uint64)(unsafe.Pointer(&ip1[8]))
|
||||||
|
b = (*uint64)(unsafe.Pointer(&ip2[8]))
|
||||||
|
x = *a ^ *b
|
||||||
|
return 64 + uint(bits.LeadingZeros64(swapU64(x)))
|
||||||
|
} else {
|
||||||
|
panic("Wrong size bit string")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
|
||||||
|
if node == nil {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// walk recursively
|
||||||
|
|
||||||
|
node.child[0] = node.child[0].removeByPeer(p)
|
||||||
|
node.child[1] = node.child[1].removeByPeer(p)
|
||||||
|
|
||||||
|
if node.peer != p {
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove peer & merge
|
||||||
|
|
||||||
|
node.peer = nil
|
||||||
|
if node.child[0] == nil {
|
||||||
|
return node.child[1]
|
||||||
|
}
|
||||||
|
return node.child[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) choose(ip net.IP) byte {
|
||||||
|
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
|
||||||
|
|
||||||
|
// at leaf
|
||||||
|
|
||||||
|
if node == nil {
|
||||||
|
return &trieEntry{
|
||||||
|
bits: ip,
|
||||||
|
peer: peer,
|
||||||
|
cidr: cidr,
|
||||||
|
bit_at_byte: cidr / 8,
|
||||||
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// traverse deeper
|
||||||
|
|
||||||
|
common := commonBits(node.bits, ip)
|
||||||
|
if node.cidr <= cidr && common >= node.cidr {
|
||||||
|
if node.cidr == cidr {
|
||||||
|
node.peer = peer
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
bit := node.choose(ip)
|
||||||
|
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
|
||||||
|
// split node
|
||||||
|
|
||||||
|
newNode := &trieEntry{
|
||||||
|
bits: ip,
|
||||||
|
peer: peer,
|
||||||
|
cidr: cidr,
|
||||||
|
bit_at_byte: cidr / 8,
|
||||||
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
|
||||||
|
cidr = min(cidr, common)
|
||||||
|
|
||||||
|
// check for shorter prefix
|
||||||
|
|
||||||
|
if newNode.cidr == cidr {
|
||||||
|
bit := newNode.choose(node.bits)
|
||||||
|
newNode.child[bit] = node
|
||||||
|
return newNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// create new parent for node & newNode
|
||||||
|
|
||||||
|
parent := &trieEntry{
|
||||||
|
bits: ip,
|
||||||
|
peer: nil,
|
||||||
|
cidr: cidr,
|
||||||
|
bit_at_byte: cidr / 8,
|
||||||
|
bit_at_shift: 7 - (cidr % 8),
|
||||||
|
}
|
||||||
|
|
||||||
|
bit := parent.choose(ip)
|
||||||
|
parent.child[bit] = newNode
|
||||||
|
parent.child[bit^1] = node
|
||||||
|
|
||||||
|
return parent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||||
|
var found *Peer
|
||||||
|
size := uint(len(ip))
|
||||||
|
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||||
|
if node.peer != nil {
|
||||||
|
found = node.peer
|
||||||
|
}
|
||||||
|
if node.bit_at_byte == size {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
bit := node.choose(ip)
|
||||||
|
node = node.child[bit]
|
||||||
|
}
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
|
||||||
|
if node == nil {
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
if node.peer == p {
|
||||||
|
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
||||||
|
results = append(results, net.IPNet{
|
||||||
|
Mask: mask,
|
||||||
|
IP: node.bits.Mask(mask),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
results = node.child[0].entriesForPeer(p, results)
|
||||||
|
results = node.child[1].entriesForPeer(p, results)
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
type AllowedIPs struct {
|
||||||
|
IPv4 *trieEntry
|
||||||
|
IPv6 *trieEntry
|
||||||
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
|
||||||
|
allowed := make([]net.IPNet, 0, 10)
|
||||||
|
allowed = table.IPv4.entriesForPeer(peer, allowed)
|
||||||
|
allowed = table.IPv6.entriesForPeer(peer, allowed)
|
||||||
|
return allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) Reset() {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
table.IPv4 = nil
|
||||||
|
table.IPv6 = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
table.IPv4 = table.IPv4.removeByPeer(peer)
|
||||||
|
table.IPv6 = table.IPv6.removeByPeer(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
|
||||||
|
table.mutex.Lock()
|
||||||
|
defer table.mutex.Unlock()
|
||||||
|
|
||||||
|
switch len(ip) {
|
||||||
|
case net.IPv6len:
|
||||||
|
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
|
||||||
|
case net.IPv4len:
|
||||||
|
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
|
||||||
|
default:
|
||||||
|
panic(errors.New("inserting unknown address type"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
return table.IPv4.lookup(address)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
|
||||||
|
table.mutex.RLock()
|
||||||
|
defer table.mutex.RUnlock()
|
||||||
|
return table.IPv6.lookup(address)
|
||||||
|
}
|
131
allowedips_rand_test.go
Normal file
131
allowedips_rand_test.go
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"sort"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NumberOfPeers = 100
|
||||||
|
NumberOfAddresses = 250
|
||||||
|
NumberOfTests = 10000
|
||||||
|
)
|
||||||
|
|
||||||
|
type SlowNode struct {
|
||||||
|
peer *Peer
|
||||||
|
cidr uint
|
||||||
|
bits []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type SlowRouter []*SlowNode
|
||||||
|
|
||||||
|
func (r SlowRouter) Len() int {
|
||||||
|
return len(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r SlowRouter) Less(i, j int) bool {
|
||||||
|
return r[i].cidr > r[j].cidr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r SlowRouter) Swap(i, j int) {
|
||||||
|
r[i], r[j] = r[j], r[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
|
||||||
|
for _, t := range r {
|
||||||
|
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
||||||
|
t.peer = peer
|
||||||
|
t.bits = addr
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r = append(r, &SlowNode{
|
||||||
|
cidr: cidr,
|
||||||
|
bits: addr,
|
||||||
|
peer: peer,
|
||||||
|
})
|
||||||
|
sort.Sort(r)
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r SlowRouter) Lookup(addr []byte) *Peer {
|
||||||
|
for _, t := range r {
|
||||||
|
common := commonBits(t.bits, addr)
|
||||||
|
if common >= t.cidr {
|
||||||
|
return t.peer
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrieRandomIPv4(t *testing.T) {
|
||||||
|
var trie *trieEntry
|
||||||
|
var slow SlowRouter
|
||||||
|
var peers []*Peer
|
||||||
|
|
||||||
|
rand.Seed(1)
|
||||||
|
|
||||||
|
const AddressLength = 4
|
||||||
|
|
||||||
|
for n := 0; n < NumberOfPeers; n += 1 {
|
||||||
|
peers = append(peers, &Peer{})
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < NumberOfAddresses; n += 1 {
|
||||||
|
var addr [AddressLength]byte
|
||||||
|
rand.Read(addr[:])
|
||||||
|
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
||||||
|
index := rand.Int() % NumberOfPeers
|
||||||
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < NumberOfTests; n += 1 {
|
||||||
|
var addr [AddressLength]byte
|
||||||
|
rand.Read(addr[:])
|
||||||
|
peer1 := slow.Lookup(addr[:])
|
||||||
|
peer2 := trie.lookup(addr[:])
|
||||||
|
if peer1 != peer2 {
|
||||||
|
t.Error("Trie did not match naive implementation, for:", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTrieRandomIPv6(t *testing.T) {
|
||||||
|
var trie *trieEntry
|
||||||
|
var slow SlowRouter
|
||||||
|
var peers []*Peer
|
||||||
|
|
||||||
|
rand.Seed(1)
|
||||||
|
|
||||||
|
const AddressLength = 16
|
||||||
|
|
||||||
|
for n := 0; n < NumberOfPeers; n += 1 {
|
||||||
|
peers = append(peers, &Peer{})
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < NumberOfAddresses; n += 1 {
|
||||||
|
var addr [AddressLength]byte
|
||||||
|
rand.Read(addr[:])
|
||||||
|
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
||||||
|
index := rand.Int() % NumberOfPeers
|
||||||
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
|
slow = slow.Insert(addr[:], cidr, peers[index])
|
||||||
|
}
|
||||||
|
|
||||||
|
for n := 0; n < NumberOfTests; n += 1 {
|
||||||
|
var addr [AddressLength]byte
|
||||||
|
rand.Read(addr[:])
|
||||||
|
peer1 := slow.Lookup(addr[:])
|
||||||
|
peer2 := trie.lookup(addr[:])
|
||||||
|
if peer1 != peer2 {
|
||||||
|
t.Error("Trie did not match naive implementation, for:", addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,24 +1,47 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
/* Todo: More comprehensive
|
||||||
|
*/
|
||||||
|
|
||||||
type testPairCommonBits struct {
|
type testPairCommonBits struct {
|
||||||
s1 []byte
|
s1 []byte
|
||||||
s2 []byte
|
s2 []byte
|
||||||
match uint8
|
match uint
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPairTrieInsert struct {
|
||||||
|
key []byte
|
||||||
|
cidr uint
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
type testPairTrieLookup struct {
|
||||||
|
key []byte
|
||||||
|
peer *Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
func printTrie(t *testing.T, p *trieEntry) {
|
||||||
|
if p == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t.Log(p)
|
||||||
|
printTrie(t, p.child[0])
|
||||||
|
printTrie(t, p.child[1])
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommonBits(t *testing.T) {
|
func TestCommonBits(t *testing.T) {
|
||||||
|
|
||||||
tests := []testPairCommonBits{
|
tests := []testPairCommonBits{
|
||||||
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
|
||||||
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
|
||||||
|
@ -39,28 +62,27 @@ func TestCommonBits(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func benchmarkTrie(peerNumber, addressNumber, addressLength int, b *testing.B) {
|
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
|
||||||
var trie *trieEntry
|
var trie *trieEntry
|
||||||
var peers []*Peer
|
var peers []*Peer
|
||||||
root := parentIndirection{&trie, 2}
|
|
||||||
|
|
||||||
rand.Seed(1)
|
rand.Seed(1)
|
||||||
|
|
||||||
const AddressLength = 4
|
const AddressLength = 4
|
||||||
|
|
||||||
for n := 0; n < peerNumber; n++ {
|
for n := 0; n < peerNumber; n += 1 {
|
||||||
peers = append(peers, &Peer{})
|
peers = append(peers, &Peer{})
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < addressNumber; n++ {
|
for n := 0; n < addressNumber; n += 1 {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
cidr := uint8(rand.Uint32() % (AddressLength * 8))
|
cidr := uint(rand.Uint32() % (AddressLength * 8))
|
||||||
index := rand.Int() % peerNumber
|
index := rand.Int() % peerNumber
|
||||||
root.insert(addr[:], cidr, peers[index])
|
trie = trie.insert(addr[:], cidr, peers[index])
|
||||||
}
|
}
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n += 1 {
|
||||||
var addr [AddressLength]byte
|
var addr [AddressLength]byte
|
||||||
rand.Read(addr[:])
|
rand.Read(addr[:])
|
||||||
trie.lookup(addr[:])
|
trie.lookup(addr[:])
|
||||||
|
@ -95,21 +117,21 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var allowedIPs AllowedIPs
|
var trie *trieEntry
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
|
||||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
p := trie.lookup([]byte{a, b, c, d})
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
assertNEQ := func(peer *Peer, a, b, c, d byte) {
|
||||||
p := allowedIPs.Lookup([]byte{a, b, c, d})
|
p := trie.lookup([]byte{a, b, c, d})
|
||||||
if p == peer {
|
if p == peer {
|
||||||
t.Error("Assert NEQ failed")
|
t.Error("Assert NEQ failed")
|
||||||
}
|
}
|
||||||
|
@ -151,7 +173,7 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
assertEQ(a, 192, 0, 0, 0)
|
assertEQ(a, 192, 0, 0, 0)
|
||||||
assertEQ(a, 255, 0, 0, 0)
|
assertEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
allowedIPs.RemoveByPeer(a)
|
trie = trie.removeByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 1, 0, 0, 0)
|
assertNEQ(a, 1, 0, 0, 0)
|
||||||
assertNEQ(a, 64, 0, 0, 0)
|
assertNEQ(a, 64, 0, 0, 0)
|
||||||
|
@ -159,21 +181,12 @@ func TestTrieIPv4(t *testing.T) {
|
||||||
assertNEQ(a, 192, 0, 0, 0)
|
assertNEQ(a, 192, 0, 0, 0)
|
||||||
assertNEQ(a, 255, 0, 0, 0)
|
assertNEQ(a, 255, 0, 0, 0)
|
||||||
|
|
||||||
allowedIPs.RemoveByPeer(a)
|
trie = nil
|
||||||
allowedIPs.RemoveByPeer(b)
|
|
||||||
allowedIPs.RemoveByPeer(c)
|
|
||||||
allowedIPs.RemoveByPeer(d)
|
|
||||||
allowedIPs.RemoveByPeer(e)
|
|
||||||
allowedIPs.RemoveByPeer(g)
|
|
||||||
allowedIPs.RemoveByPeer(h)
|
|
||||||
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
|
|
||||||
t.Error("Expected removing all the peers to empty trie, but it did not")
|
|
||||||
}
|
|
||||||
|
|
||||||
insert(a, 192, 168, 0, 0, 16)
|
insert(a, 192, 168, 0, 0, 16)
|
||||||
insert(a, 192, 168, 0, 0, 24)
|
insert(a, 192, 168, 0, 0, 24)
|
||||||
|
|
||||||
allowedIPs.RemoveByPeer(a)
|
trie = trie.removeByPeer(a)
|
||||||
|
|
||||||
assertNEQ(a, 192, 168, 0, 1)
|
assertNEQ(a, 192, 168, 0, 1)
|
||||||
}
|
}
|
||||||
|
@ -191,7 +204,7 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
g := &Peer{}
|
g := &Peer{}
|
||||||
h := &Peer{}
|
h := &Peer{}
|
||||||
|
|
||||||
var allowedIPs AllowedIPs
|
var trie *trieEntry
|
||||||
|
|
||||||
expand := func(a uint32) []byte {
|
expand := func(a uint32) []byte {
|
||||||
var out [4]byte
|
var out [4]byte
|
||||||
|
@ -202,13 +215,13 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
return out[:]
|
return out[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
insert := func(peer *Peer, a, b, c, d uint32, cidr uint8) {
|
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
|
||||||
var addr []byte
|
var addr []byte
|
||||||
addr = append(addr, expand(a)...)
|
addr = append(addr, expand(a)...)
|
||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
trie = trie.insert(addr, cidr, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||||
|
@ -217,7 +230,7 @@ func TestTrieIPv6(t *testing.T) {
|
||||||
addr = append(addr, expand(b)...)
|
addr = append(addr, expand(b)...)
|
||||||
addr = append(addr, expand(c)...)
|
addr = append(addr, expand(c)...)
|
||||||
addr = append(addr, expand(d)...)
|
addr = append(addr, expand(d)...)
|
||||||
p := allowedIPs.Lookup(addr)
|
p := trie.lookup(addr)
|
||||||
if p != peer {
|
if p != peer {
|
||||||
t.Error("Assert EQ failed")
|
t.Error("Assert EQ failed")
|
||||||
}
|
}
|
|
@ -1,24 +1,23 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import "errors"
|
||||||
"errors"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DummyDatagram struct {
|
type DummyDatagram struct {
|
||||||
msg []byte
|
msg []byte
|
||||||
endpoint conn.Endpoint
|
endpoint Endpoint
|
||||||
|
world bool // better type
|
||||||
}
|
}
|
||||||
|
|
||||||
type DummyBind struct {
|
type DummyBind struct {
|
||||||
in6 chan DummyDatagram
|
in6 chan DummyDatagram
|
||||||
|
ou6 chan DummyDatagram
|
||||||
in4 chan DummyDatagram
|
in4 chan DummyDatagram
|
||||||
|
ou4 chan DummyDatagram
|
||||||
closed bool
|
closed bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -26,21 +25,21 @@ func (b *DummyBind) SetMark(v uint32) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv6(buf []byte) (int, conn.Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
datagram, ok := <-b.in6
|
datagram, ok := <-b.in6
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
}
|
}
|
||||||
copy(buf, datagram.msg)
|
copy(buff, datagram.msg)
|
||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) ReceiveIPv4(buf []byte) (int, conn.Endpoint, error) {
|
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
datagram, ok := <-b.in4
|
datagram, ok := <-b.in4
|
||||||
if !ok {
|
if !ok {
|
||||||
return 0, nil, errors.New("closed")
|
return 0, nil, errors.New("closed")
|
||||||
}
|
}
|
||||||
copy(buf, datagram.msg)
|
copy(buff, datagram.msg)
|
||||||
return len(datagram.msg), datagram.endpoint, nil
|
return len(datagram.msg), datagram.endpoint, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,6 +50,6 @@ func (b *DummyBind) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *DummyBind) Send(buf []byte, end conn.Endpoint) error {
|
func (b *DummyBind) Send(buff []byte, end Endpoint) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
6
build.cmd
Executable file
6
build.cmd
Executable file
|
@ -0,0 +1,6 @@
|
||||||
|
@echo off
|
||||||
|
|
||||||
|
REM builds wireguard for windows
|
||||||
|
|
||||||
|
go get
|
||||||
|
go build -o wireguard-go.exe
|
170
conn.go
Normal file
170
conn.go
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"golang.org/x/net/ipv4"
|
||||||
|
"golang.org/x/net/ipv6"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
|
||||||
|
*/
|
||||||
|
type Bind interface {
|
||||||
|
SetMark(value uint32) error
|
||||||
|
ReceiveIPv6(buff []byte) (int, Endpoint, error)
|
||||||
|
ReceiveIPv4(buff []byte) (int, Endpoint, error)
|
||||||
|
Send(buff []byte, end Endpoint) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
/* An Endpoint maintains the source/destination caching for a peer
|
||||||
|
*
|
||||||
|
* dst : the remote address of a peer ("endpoint" in uapi terminology)
|
||||||
|
* src : the local address from which datagrams originate going to the peer
|
||||||
|
*/
|
||||||
|
type Endpoint interface {
|
||||||
|
ClearSrc() // clears the source address
|
||||||
|
SrcToString() string // returns the local source address (ip:port)
|
||||||
|
DstToString() string // returns the destination address (ip:port)
|
||||||
|
DstToBytes() []byte // used for mac2 cookie calculations
|
||||||
|
DstIP() net.IP
|
||||||
|
SrcIP() net.IP
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||||
|
|
||||||
|
// ensure that the host is an IP address
|
||||||
|
|
||||||
|
host, _, err := net.SplitHostPort(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if ip := net.ParseIP(host); ip == nil {
|
||||||
|
return nil, errors.New("Failed to parse IP address: " + host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse address and port
|
||||||
|
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return addr, err
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Must hold device and net lock
|
||||||
|
*/
|
||||||
|
func unsafeCloseBind(device *Device) error {
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
if netc.bind != nil {
|
||||||
|
err = netc.bind.Close()
|
||||||
|
netc.bind = nil
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindSetMark(mark uint32) error {
|
||||||
|
|
||||||
|
device.net.mutex.Lock()
|
||||||
|
defer device.net.mutex.Unlock()
|
||||||
|
|
||||||
|
// check if modified
|
||||||
|
|
||||||
|
if device.net.fwmark == mark {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// update fwmark on existing bind
|
||||||
|
|
||||||
|
device.net.fwmark = mark
|
||||||
|
if device.isUp.Get() && device.net.bind != nil {
|
||||||
|
if err := device.net.bind.SetMark(mark); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
|
||||||
|
device.peers.mutex.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.mutex.Lock()
|
||||||
|
defer peer.mutex.Unlock()
|
||||||
|
if peer.endpoint != nil {
|
||||||
|
peer.endpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.mutex.RUnlock()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindUpdate() error {
|
||||||
|
|
||||||
|
device.net.mutex.Lock()
|
||||||
|
defer device.net.mutex.Unlock()
|
||||||
|
|
||||||
|
// close existing sockets
|
||||||
|
|
||||||
|
if err := unsafeCloseBind(device); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// open new sockets
|
||||||
|
|
||||||
|
if device.isUp.Get() {
|
||||||
|
|
||||||
|
// bind to new port
|
||||||
|
|
||||||
|
var err error
|
||||||
|
netc := &device.net
|
||||||
|
netc.bind, netc.port, err = CreateBind(netc.port, device)
|
||||||
|
if err != nil {
|
||||||
|
netc.bind = nil
|
||||||
|
netc.port = 0
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set fwmark
|
||||||
|
|
||||||
|
if netc.fwmark != 0 {
|
||||||
|
err = netc.bind.SetMark(netc.fwmark)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear cached source addresses
|
||||||
|
|
||||||
|
device.peers.mutex.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.mutex.Lock()
|
||||||
|
defer peer.mutex.Unlock()
|
||||||
|
if peer.endpoint != nil {
|
||||||
|
peer.endpoint.ClearSrc()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.mutex.RUnlock()
|
||||||
|
|
||||||
|
// start receiving routines
|
||||||
|
|
||||||
|
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||||
|
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||||
|
|
||||||
|
device.log.Debug.Println("UDP bind has been updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) BindClose() error {
|
||||||
|
device.net.mutex.Lock()
|
||||||
|
err := unsafeCloseBind(device)
|
||||||
|
device.net.mutex.Unlock()
|
||||||
|
return err
|
||||||
|
}
|
544
conn/bind_std.go
544
conn/bind_std.go
|
@ -1,544 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ Bind = (*StdNetBind)(nil)
|
|
||||||
)
|
|
||||||
|
|
||||||
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
|
|
||||||
// (see bind_windows.go), it may fall back to StdNetBind.
|
|
||||||
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
|
|
||||||
// methods for sending and receiving multiple datagrams per-syscall. See the
|
|
||||||
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
|
|
||||||
type StdNetBind struct {
|
|
||||||
mu sync.Mutex // protects all fields except as specified
|
|
||||||
ipv4 *net.UDPConn
|
|
||||||
ipv6 *net.UDPConn
|
|
||||||
ipv4PC *ipv4.PacketConn // will be nil on non-Linux
|
|
||||||
ipv6PC *ipv6.PacketConn // will be nil on non-Linux
|
|
||||||
ipv4TxOffload bool
|
|
||||||
ipv4RxOffload bool
|
|
||||||
ipv6TxOffload bool
|
|
||||||
ipv6RxOffload bool
|
|
||||||
|
|
||||||
// these two fields are not guarded by mu
|
|
||||||
udpAddrPool sync.Pool
|
|
||||||
msgsPool sync.Pool
|
|
||||||
|
|
||||||
blackhole4 bool
|
|
||||||
blackhole6 bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewStdNetBind() Bind {
|
|
||||||
return &StdNetBind{
|
|
||||||
udpAddrPool: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
return &net.UDPAddr{
|
|
||||||
IP: make([]byte, 16),
|
|
||||||
}
|
|
||||||
},
|
|
||||||
},
|
|
||||||
|
|
||||||
msgsPool: sync.Pool{
|
|
||||||
New: func() any {
|
|
||||||
// ipv6.Message and ipv4.Message are interchangeable as they are
|
|
||||||
// both aliases for x/net/internal/socket.Message.
|
|
||||||
msgs := make([]ipv6.Message, IdealBatchSize)
|
|
||||||
for i := range msgs {
|
|
||||||
msgs[i].Buffers = make(net.Buffers, 1)
|
|
||||||
msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
|
|
||||||
}
|
|
||||||
return &msgs
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type StdNetEndpoint struct {
|
|
||||||
// AddrPort is the endpoint destination.
|
|
||||||
netip.AddrPort
|
|
||||||
// src is the current sticky source address and interface index, if
|
|
||||||
// supported. Typically this is a PKTINFO structure from/for control
|
|
||||||
// messages, see unix.PKTINFO for an example.
|
|
||||||
src []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ Bind = (*StdNetBind)(nil)
|
|
||||||
_ Endpoint = &StdNetEndpoint{}
|
|
||||||
)
|
|
||||||
|
|
||||||
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
|
|
||||||
e, err := netip.ParseAddrPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &StdNetEndpoint{
|
|
||||||
AddrPort: e,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) ClearSrc() {
|
|
||||||
if e.src != nil {
|
|
||||||
// Truncate src, no need to reallocate.
|
|
||||||
e.src = e.src[:0]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstIP() netip.Addr {
|
|
||||||
return e.AddrPort.Addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
// See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToBytes() []byte {
|
|
||||||
b, _ := e.AddrPort.MarshalBinary()
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) DstToString() string {
|
|
||||||
return e.AddrPort.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
|
||||||
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve port.
|
|
||||||
laddr := conn.LocalAddr()
|
|
||||||
uaddr, err := net.ResolveUDPAddr(
|
|
||||||
laddr.Network(),
|
|
||||||
laddr.String(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
return conn.(*net.UDPConn), uaddr.Port, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
var tries int
|
|
||||||
|
|
||||||
if s.ipv4 != nil || s.ipv6 != nil {
|
|
||||||
return nil, 0, ErrBindAlreadyOpen
|
|
||||||
}
|
|
||||||
|
|
||||||
// Attempt to open ipv4 and ipv6 listeners on the same port.
|
|
||||||
// If uport is 0, we can retry on failure.
|
|
||||||
again:
|
|
||||||
port := int(uport)
|
|
||||||
var v4conn, v6conn *net.UDPConn
|
|
||||||
var v4pc *ipv4.PacketConn
|
|
||||||
var v6pc *ipv6.PacketConn
|
|
||||||
|
|
||||||
v4conn, port, err = listenNet("udp4", port)
|
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen on the same port as we're using for ipv4.
|
|
||||||
v6conn, port, err = listenNet("udp6", port)
|
|
||||||
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
|
|
||||||
v4conn.Close()
|
|
||||||
tries++
|
|
||||||
goto again
|
|
||||||
}
|
|
||||||
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
|
|
||||||
v4conn.Close()
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
var fns []ReceiveFunc
|
|
||||||
if v4conn != nil {
|
|
||||||
s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
|
|
||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
||||||
v4pc = ipv4.NewPacketConn(v4conn)
|
|
||||||
s.ipv4PC = v4pc
|
|
||||||
}
|
|
||||||
fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
|
|
||||||
s.ipv4 = v4conn
|
|
||||||
}
|
|
||||||
if v6conn != nil {
|
|
||||||
s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
|
|
||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
||||||
v6pc = ipv6.NewPacketConn(v6conn)
|
|
||||||
s.ipv6PC = v6pc
|
|
||||||
}
|
|
||||||
fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
|
|
||||||
s.ipv6 = v6conn
|
|
||||||
}
|
|
||||||
if len(fns) == 0 {
|
|
||||||
return nil, 0, syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
|
|
||||||
return fns, uint16(port), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
|
|
||||||
for i := range *msgs {
|
|
||||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
|
|
||||||
(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
|
|
||||||
}
|
|
||||||
s.msgsPool.Put(msgs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) getMessages() *[]ipv6.Message {
|
|
||||||
return s.msgsPool.Get().(*[]ipv6.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
// If compilation fails here these are no longer the same underlying type.
|
|
||||||
_ ipv6.Message = ipv4.Message{}
|
|
||||||
)
|
|
||||||
|
|
||||||
type batchReader interface {
|
|
||||||
ReadBatch([]ipv6.Message, int) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
type batchWriter interface {
|
|
||||||
WriteBatch([]ipv6.Message, int) (int, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) receiveIP(
|
|
||||||
br batchReader,
|
|
||||||
conn *net.UDPConn,
|
|
||||||
rxOffload bool,
|
|
||||||
bufs [][]byte,
|
|
||||||
sizes []int,
|
|
||||||
eps []Endpoint,
|
|
||||||
) (n int, err error) {
|
|
||||||
msgs := s.getMessages()
|
|
||||||
for i := range bufs {
|
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
|
||||||
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
|
|
||||||
}
|
|
||||||
defer s.putMessages(msgs)
|
|
||||||
var numMsgs int
|
|
||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
||||||
if rxOffload {
|
|
||||||
readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
|
|
||||||
numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
numMsgs, err = br.ReadBatch(*msgs, 0)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msg := &(*msgs)[0]
|
|
||||||
msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
numMsgs = 1
|
|
||||||
}
|
|
||||||
for i := 0; i < numMsgs; i++ {
|
|
||||||
msg := &(*msgs)[i]
|
|
||||||
sizes[i] = msg.N
|
|
||||||
if sizes[i] == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
|
||||||
ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
|
||||||
getSrcFromControl(msg.OOB[:msg.NN], ep)
|
|
||||||
eps[i] = ep
|
|
||||||
}
|
|
||||||
return numMsgs, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
|
||||||
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
|
|
||||||
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
|
||||||
// rename the IdealBatchSize constant to BatchSize.
|
|
||||||
func (s *StdNetBind) BatchSize() int {
|
|
||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
||||||
return IdealBatchSize
|
|
||||||
}
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) Close() error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
var err1, err2 error
|
|
||||||
if s.ipv4 != nil {
|
|
||||||
err1 = s.ipv4.Close()
|
|
||||||
s.ipv4 = nil
|
|
||||||
s.ipv4PC = nil
|
|
||||||
}
|
|
||||||
if s.ipv6 != nil {
|
|
||||||
err2 = s.ipv6.Close()
|
|
||||||
s.ipv6 = nil
|
|
||||||
s.ipv6PC = nil
|
|
||||||
}
|
|
||||||
s.blackhole4 = false
|
|
||||||
s.blackhole6 = false
|
|
||||||
s.ipv4TxOffload = false
|
|
||||||
s.ipv4RxOffload = false
|
|
||||||
s.ipv6TxOffload = false
|
|
||||||
s.ipv6RxOffload = false
|
|
||||||
if err1 != nil {
|
|
||||||
return err1
|
|
||||||
}
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
|
|
||||||
type ErrUDPGSODisabled struct {
|
|
||||||
onLaddr string
|
|
||||||
RetryErr error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e ErrUDPGSODisabled) Error() string {
|
|
||||||
return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload or peer MTU with protocol headers is greater than path MTU", e.onLaddr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e ErrUDPGSODisabled) Unwrap() error {
|
|
||||||
return e.RetryErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
blackhole := s.blackhole4
|
|
||||||
conn := s.ipv4
|
|
||||||
offload := s.ipv4TxOffload
|
|
||||||
br := batchWriter(s.ipv4PC)
|
|
||||||
is6 := false
|
|
||||||
if endpoint.DstIP().Is6() {
|
|
||||||
blackhole = s.blackhole6
|
|
||||||
conn = s.ipv6
|
|
||||||
br = s.ipv6PC
|
|
||||||
is6 = true
|
|
||||||
offload = s.ipv6TxOffload
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
|
|
||||||
if blackhole {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if conn == nil {
|
|
||||||
return syscall.EAFNOSUPPORT
|
|
||||||
}
|
|
||||||
|
|
||||||
msgs := s.getMessages()
|
|
||||||
defer s.putMessages(msgs)
|
|
||||||
ua := s.udpAddrPool.Get().(*net.UDPAddr)
|
|
||||||
defer s.udpAddrPool.Put(ua)
|
|
||||||
if is6 {
|
|
||||||
as16 := endpoint.DstIP().As16()
|
|
||||||
copy(ua.IP, as16[:])
|
|
||||||
ua.IP = ua.IP[:16]
|
|
||||||
} else {
|
|
||||||
as4 := endpoint.DstIP().As4()
|
|
||||||
copy(ua.IP, as4[:])
|
|
||||||
ua.IP = ua.IP[:4]
|
|
||||||
}
|
|
||||||
ua.Port = int(endpoint.(*StdNetEndpoint).Port())
|
|
||||||
var (
|
|
||||||
retried bool
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
retry:
|
|
||||||
if offload {
|
|
||||||
n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
|
|
||||||
err = s.send(conn, br, (*msgs)[:n])
|
|
||||||
if err != nil && offload && errShouldDisableUDPGSO(err) {
|
|
||||||
offload = false
|
|
||||||
s.mu.Lock()
|
|
||||||
if is6 {
|
|
||||||
s.ipv6TxOffload = false
|
|
||||||
} else {
|
|
||||||
s.ipv4TxOffload = false
|
|
||||||
}
|
|
||||||
s.mu.Unlock()
|
|
||||||
retried = true
|
|
||||||
goto retry
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for i := range bufs {
|
|
||||||
(*msgs)[i].Addr = ua
|
|
||||||
(*msgs)[i].Buffers[0] = bufs[i]
|
|
||||||
setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
|
|
||||||
}
|
|
||||||
err = s.send(conn, br, (*msgs)[:len(bufs)])
|
|
||||||
}
|
|
||||||
if retried {
|
|
||||||
return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
|
|
||||||
var (
|
|
||||||
n int
|
|
||||||
err error
|
|
||||||
start int
|
|
||||||
)
|
|
||||||
if runtime.GOOS == "linux" || runtime.GOOS == "android" {
|
|
||||||
for {
|
|
||||||
n, err = pc.WriteBatch(msgs[start:], 0)
|
|
||||||
if err != nil || n == len(msgs[start:]) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
start += n
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for _, msg := range msgs {
|
|
||||||
_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Exceeding these values results in EMSGSIZE. They account for layer3 and
|
|
||||||
// layer4 headers. IPv6 does not need to account for itself as the payload
|
|
||||||
// length field is self excluding.
|
|
||||||
maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8
|
|
||||||
maxIPv6PayloadLen = 1<<16 - 1 - 8
|
|
||||||
|
|
||||||
// This is a hard limit imposed by the kernel.
|
|
||||||
udpSegmentMaxDatagrams = 64
|
|
||||||
)
|
|
||||||
|
|
||||||
type setGSOFunc func(control *[]byte, gsoSize uint16)
|
|
||||||
|
|
||||||
func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
|
|
||||||
var (
|
|
||||||
base = -1 // index of msg we are currently coalescing into
|
|
||||||
gsoSize int // segmentation size of msgs[base]
|
|
||||||
dgramCnt int // number of dgrams coalesced into msgs[base]
|
|
||||||
endBatch bool // tracking flag to start a new batch on next iteration of bufs
|
|
||||||
)
|
|
||||||
maxPayloadLen := maxIPv4PayloadLen
|
|
||||||
if ep.DstIP().Is6() {
|
|
||||||
maxPayloadLen = maxIPv6PayloadLen
|
|
||||||
}
|
|
||||||
for i, buf := range bufs {
|
|
||||||
if i > 0 {
|
|
||||||
msgLen := len(buf)
|
|
||||||
baseLenBefore := len(msgs[base].Buffers[0])
|
|
||||||
freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
|
|
||||||
if msgLen+baseLenBefore <= maxPayloadLen &&
|
|
||||||
msgLen <= gsoSize &&
|
|
||||||
msgLen <= freeBaseCap &&
|
|
||||||
dgramCnt < udpSegmentMaxDatagrams &&
|
|
||||||
!endBatch {
|
|
||||||
msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
|
|
||||||
if i == len(bufs)-1 {
|
|
||||||
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
|
||||||
}
|
|
||||||
dgramCnt++
|
|
||||||
if msgLen < gsoSize {
|
|
||||||
// A smaller than gsoSize packet on the tail is legal, but
|
|
||||||
// it must end the batch.
|
|
||||||
endBatch = true
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if dgramCnt > 1 {
|
|
||||||
setGSO(&msgs[base].OOB, uint16(gsoSize))
|
|
||||||
}
|
|
||||||
// Reset prior to incrementing base since we are preparing to start a
|
|
||||||
// new potential batch.
|
|
||||||
endBatch = false
|
|
||||||
base++
|
|
||||||
gsoSize = len(buf)
|
|
||||||
setSrcControl(&msgs[base].OOB, ep)
|
|
||||||
msgs[base].Buffers[0] = buf
|
|
||||||
msgs[base].Addr = addr
|
|
||||||
dgramCnt = 1
|
|
||||||
}
|
|
||||||
return base + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
type getGSOFunc func(control []byte) (int, error)
|
|
||||||
|
|
||||||
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
|
|
||||||
for i := firstMsgAt; i < len(msgs); i++ {
|
|
||||||
msg := &msgs[i]
|
|
||||||
if msg.N == 0 {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
var (
|
|
||||||
gsoSize int
|
|
||||||
start int
|
|
||||||
end = msg.N
|
|
||||||
numToSplit = 1
|
|
||||||
)
|
|
||||||
gsoSize, err = getGSO(msg.OOB[:msg.NN])
|
|
||||||
if err != nil {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
if gsoSize > 0 {
|
|
||||||
numToSplit = (msg.N + gsoSize - 1) / gsoSize
|
|
||||||
end = gsoSize
|
|
||||||
}
|
|
||||||
for j := 0; j < numToSplit; j++ {
|
|
||||||
if n > i {
|
|
||||||
return n, errors.New("splitting coalesced packet resulted in overflow")
|
|
||||||
}
|
|
||||||
copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
|
|
||||||
msgs[n].N = copied
|
|
||||||
msgs[n].Addr = msg.Addr
|
|
||||||
start = end
|
|
||||||
end += gsoSize
|
|
||||||
if end > msg.N {
|
|
||||||
end = msg.N
|
|
||||||
}
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
if i != n-1 {
|
|
||||||
// It is legal for bytes to move within msg.Buffers[0] as a result
|
|
||||||
// of splitting, so we only zero the source msg len when it is not
|
|
||||||
// the destination of the last split operation above.
|
|
||||||
msg.N = 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return n, nil
|
|
||||||
}
|
|
|
@ -1,250 +0,0 @@
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"net"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
|
|
||||||
bind := NewStdNetBind().(*StdNetBind)
|
|
||||||
fns, _, err := bind.Open(0)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
bind.Close()
|
|
||||||
bufs := make([][]byte, 1)
|
|
||||||
bufs[0] = make([]byte, 1)
|
|
||||||
sizes := make([]int, 1)
|
|
||||||
eps := make([]Endpoint, 1)
|
|
||||||
for _, fn := range fns {
|
|
||||||
// The ReceiveFuncs must not access conn-related fields on StdNetBind
|
|
||||||
// unguarded. Close() nils the conn-related fields resulting in a panic
|
|
||||||
// if they violate the mutex.
|
|
||||||
fn(bufs, sizes, eps)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func mockSetGSOSize(control *[]byte, gsoSize uint16) {
|
|
||||||
*control = (*control)[:cap(*control)]
|
|
||||||
binary.LittleEndian.PutUint16(*control, gsoSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_coalesceMessages(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
buffs [][]byte
|
|
||||||
wantLens []int
|
|
||||||
wantGSO []int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "one message no coalesce",
|
|
||||||
buffs: [][]byte{
|
|
||||||
make([]byte, 1, 1),
|
|
||||||
},
|
|
||||||
wantLens: []int{1},
|
|
||||||
wantGSO: []int{0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "two messages equal len coalesce",
|
|
||||||
buffs: [][]byte{
|
|
||||||
make([]byte, 1, 2),
|
|
||||||
make([]byte, 1, 1),
|
|
||||||
},
|
|
||||||
wantLens: []int{2},
|
|
||||||
wantGSO: []int{1},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "two messages unequal len coalesce",
|
|
||||||
buffs: [][]byte{
|
|
||||||
make([]byte, 2, 3),
|
|
||||||
make([]byte, 1, 1),
|
|
||||||
},
|
|
||||||
wantLens: []int{3},
|
|
||||||
wantGSO: []int{2},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "three messages second unequal len coalesce",
|
|
||||||
buffs: [][]byte{
|
|
||||||
make([]byte, 2, 3),
|
|
||||||
make([]byte, 1, 1),
|
|
||||||
make([]byte, 2, 2),
|
|
||||||
},
|
|
||||||
wantLens: []int{3, 2},
|
|
||||||
wantGSO: []int{2, 0},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "three messages limited cap coalesce",
|
|
||||||
buffs: [][]byte{
|
|
||||||
make([]byte, 2, 4),
|
|
||||||
make([]byte, 2, 2),
|
|
||||||
make([]byte, 2, 2),
|
|
||||||
},
|
|
||||||
wantLens: []int{4, 2},
|
|
||||||
wantGSO: []int{2, 0},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range cases {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
addr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP("127.0.0.1").To4(),
|
|
||||||
Port: 1,
|
|
||||||
}
|
|
||||||
msgs := make([]ipv6.Message, len(tt.buffs))
|
|
||||||
for i := range msgs {
|
|
||||||
msgs[i].Buffers = make([][]byte, 1)
|
|
||||||
msgs[i].OOB = make([]byte, 0, 2)
|
|
||||||
}
|
|
||||||
got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
|
|
||||||
if got != len(tt.wantLens) {
|
|
||||||
t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
|
|
||||||
}
|
|
||||||
for i := 0; i < got; i++ {
|
|
||||||
if msgs[i].Addr != addr {
|
|
||||||
t.Errorf("msgs[%d].Addr != passed addr", i)
|
|
||||||
}
|
|
||||||
gotLen := len(msgs[i].Buffers[0])
|
|
||||||
if gotLen != tt.wantLens[i] {
|
|
||||||
t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
|
|
||||||
}
|
|
||||||
gotGSO, err := mockGetGSOSize(msgs[i].OOB)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
|
|
||||||
}
|
|
||||||
if gotGSO != tt.wantGSO[i] {
|
|
||||||
t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func mockGetGSOSize(control []byte) (int, error) {
|
|
||||||
if len(control) < 2 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return int(binary.LittleEndian.Uint16(control)), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_splitCoalescedMessages(t *testing.T) {
|
|
||||||
newMsg := func(n, gso int) ipv6.Message {
|
|
||||||
msg := ipv6.Message{
|
|
||||||
Buffers: [][]byte{make([]byte, 1<<16-1)},
|
|
||||||
N: n,
|
|
||||||
OOB: make([]byte, 2),
|
|
||||||
}
|
|
||||||
binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
|
|
||||||
if gso > 0 {
|
|
||||||
msg.NN = 2
|
|
||||||
}
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
msgs []ipv6.Message
|
|
||||||
firstMsgAt int
|
|
||||||
wantNumEval int
|
|
||||||
wantMsgLens []int
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "second last split last empty",
|
|
||||||
msgs: []ipv6.Message{
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(3, 1),
|
|
||||||
newMsg(0, 0),
|
|
||||||
},
|
|
||||||
firstMsgAt: 2,
|
|
||||||
wantNumEval: 3,
|
|
||||||
wantMsgLens: []int{1, 1, 1, 0},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "second last no split last empty",
|
|
||||||
msgs: []ipv6.Message{
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(1, 0),
|
|
||||||
newMsg(0, 0),
|
|
||||||
},
|
|
||||||
firstMsgAt: 2,
|
|
||||||
wantNumEval: 1,
|
|
||||||
wantMsgLens: []int{1, 0, 0, 0},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "second last no split last no split",
|
|
||||||
msgs: []ipv6.Message{
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(1, 0),
|
|
||||||
newMsg(1, 0),
|
|
||||||
},
|
|
||||||
firstMsgAt: 2,
|
|
||||||
wantNumEval: 2,
|
|
||||||
wantMsgLens: []int{1, 1, 0, 0},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "second last no split last split",
|
|
||||||
msgs: []ipv6.Message{
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(1, 0),
|
|
||||||
newMsg(3, 1),
|
|
||||||
},
|
|
||||||
firstMsgAt: 2,
|
|
||||||
wantNumEval: 4,
|
|
||||||
wantMsgLens: []int{1, 1, 1, 1},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "second last split last split",
|
|
||||||
msgs: []ipv6.Message{
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(2, 1),
|
|
||||||
newMsg(2, 1),
|
|
||||||
},
|
|
||||||
firstMsgAt: 2,
|
|
||||||
wantNumEval: 4,
|
|
||||||
wantMsgLens: []int{1, 1, 1, 1},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "second last no split last split overflow",
|
|
||||||
msgs: []ipv6.Message{
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(0, 0),
|
|
||||||
newMsg(1, 0),
|
|
||||||
newMsg(4, 1),
|
|
||||||
},
|
|
||||||
firstMsgAt: 2,
|
|
||||||
wantNumEval: 4,
|
|
||||||
wantMsgLens: []int{1, 1, 1, 1},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range cases {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
|
|
||||||
if err != nil && !tt.wantErr {
|
|
||||||
t.Fatalf("err: %v", err)
|
|
||||||
}
|
|
||||||
if got != tt.wantNumEval {
|
|
||||||
t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
|
|
||||||
}
|
|
||||||
for i, msg := range tt.msgs {
|
|
||||||
if msg.N != tt.wantMsgLens[i] {
|
|
||||||
t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,601 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn/winrio"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
packetsPerRing = 1024
|
|
||||||
bytesPerPacket = 2048 - 32
|
|
||||||
receiveSpins = 15
|
|
||||||
)
|
|
||||||
|
|
||||||
type ringPacket struct {
|
|
||||||
addr WinRingEndpoint
|
|
||||||
data [bytesPerPacket]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type ringBuffer struct {
|
|
||||||
packets uintptr
|
|
||||||
head, tail uint32
|
|
||||||
id winrio.BufferId
|
|
||||||
iocp windows.Handle
|
|
||||||
isFull bool
|
|
||||||
cq winrio.Cq
|
|
||||||
mu sync.Mutex
|
|
||||||
overlapped windows.Overlapped
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rb *ringBuffer) Push() *ringPacket {
|
|
||||||
for rb.isFull {
|
|
||||||
panic("ring is full")
|
|
||||||
}
|
|
||||||
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
|
|
||||||
rb.tail += 1
|
|
||||||
if rb.tail%packetsPerRing == rb.head%packetsPerRing {
|
|
||||||
rb.isFull = true
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func (rb *ringBuffer) Return(count uint32) {
|
|
||||||
if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
rb.head += count
|
|
||||||
rb.isFull = false
|
|
||||||
}
|
|
||||||
|
|
||||||
type afWinRingBind struct {
|
|
||||||
sock windows.Handle
|
|
||||||
rx, tx ringBuffer
|
|
||||||
rq winrio.Rq
|
|
||||||
mu sync.Mutex
|
|
||||||
blackhole bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
|
|
||||||
type WinRingBind struct {
|
|
||||||
v4, v6 afWinRingBind
|
|
||||||
mu sync.RWMutex
|
|
||||||
isOpen atomic.Uint32 // 0, 1, or 2
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDefaultBind() Bind { return NewWinRingBind() }
|
|
||||||
|
|
||||||
func NewWinRingBind() Bind {
|
|
||||||
if !winrio.Initialize() {
|
|
||||||
return NewStdNetBind()
|
|
||||||
}
|
|
||||||
return new(WinRingBind)
|
|
||||||
}
|
|
||||||
|
|
||||||
type WinRingEndpoint struct {
|
|
||||||
family uint16
|
|
||||||
data [30]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ Bind = (*WinRingBind)(nil)
|
|
||||||
_ Endpoint = (*WinRingEndpoint)(nil)
|
|
||||||
)
|
|
||||||
|
|
||||||
func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
|
|
||||||
host, port, err := net.SplitHostPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
host16, err := windows.UTF16PtrFromString(host)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
port16, err := windows.UTF16PtrFromString(port)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
hints := windows.AddrinfoW{
|
|
||||||
Flags: windows.AI_NUMERICHOST,
|
|
||||||
Family: windows.AF_UNSPEC,
|
|
||||||
Socktype: windows.SOCK_DGRAM,
|
|
||||||
Protocol: windows.IPPROTO_UDP,
|
|
||||||
}
|
|
||||||
var addrinfo *windows.AddrinfoW
|
|
||||||
err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer windows.FreeAddrInfoW(addrinfo)
|
|
||||||
if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
|
|
||||||
return nil, windows.ERROR_INVALID_ADDRESS
|
|
||||||
}
|
|
||||||
var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
|
|
||||||
copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
|
|
||||||
return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*WinRingEndpoint) ClearSrc() {}
|
|
||||||
|
|
||||||
func (e *WinRingEndpoint) DstIP() netip.Addr {
|
|
||||||
switch e.family {
|
|
||||||
case windows.AF_INET:
|
|
||||||
return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
|
|
||||||
case windows.AF_INET6:
|
|
||||||
return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
|
|
||||||
}
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *WinRingEndpoint) SrcIP() netip.Addr {
|
|
||||||
return netip.Addr{} // not supported
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *WinRingEndpoint) DstToBytes() []byte {
|
|
||||||
switch e.family {
|
|
||||||
case windows.AF_INET:
|
|
||||||
b := make([]byte, 0, 6)
|
|
||||||
b = append(b, e.data[2:6]...)
|
|
||||||
b = append(b, e.data[1], e.data[0])
|
|
||||||
return b
|
|
||||||
case windows.AF_INET6:
|
|
||||||
b := make([]byte, 0, 18)
|
|
||||||
b = append(b, e.data[6:22]...)
|
|
||||||
b = append(b, e.data[1], e.data[0])
|
|
||||||
return b
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *WinRingEndpoint) DstToString() string {
|
|
||||||
switch e.family {
|
|
||||||
case windows.AF_INET:
|
|
||||||
return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
|
|
||||||
case windows.AF_INET6:
|
|
||||||
var zone string
|
|
||||||
if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
|
|
||||||
zone = strconv.FormatUint(uint64(scope), 10)
|
|
||||||
}
|
|
||||||
return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *WinRingEndpoint) SrcToString() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ring *ringBuffer) CloseAndZero() {
|
|
||||||
if ring.cq != 0 {
|
|
||||||
winrio.CloseCompletionQueue(ring.cq)
|
|
||||||
ring.cq = 0
|
|
||||||
}
|
|
||||||
if ring.iocp != 0 {
|
|
||||||
windows.CloseHandle(ring.iocp)
|
|
||||||
ring.iocp = 0
|
|
||||||
}
|
|
||||||
if ring.id != 0 {
|
|
||||||
winrio.DeregisterBuffer(ring.id)
|
|
||||||
ring.id = 0
|
|
||||||
}
|
|
||||||
if ring.packets != 0 {
|
|
||||||
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
|
|
||||||
ring.packets = 0
|
|
||||||
}
|
|
||||||
ring.head = 0
|
|
||||||
ring.tail = 0
|
|
||||||
ring.isFull = false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *afWinRingBind) CloseAndZero() {
|
|
||||||
bind.rx.CloseAndZero()
|
|
||||||
bind.tx.CloseAndZero()
|
|
||||||
if bind.sock != 0 {
|
|
||||||
windows.CloseHandle(bind.sock)
|
|
||||||
bind.sock = 0
|
|
||||||
}
|
|
||||||
bind.blackhole = false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) closeAndZero() {
|
|
||||||
bind.isOpen.Store(0)
|
|
||||||
bind.v4.CloseAndZero()
|
|
||||||
bind.v6.CloseAndZero()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ring *ringBuffer) Open() error {
|
|
||||||
var err error
|
|
||||||
packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
|
|
||||||
ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
|
|
||||||
var err error
|
|
||||||
bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = bind.rx.Open()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = bind.tx.Open()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = windows.Bind(bind.sock, sa)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sa, err = windows.Getsockname(bind.sock)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return sa, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
|
|
||||||
bind.mu.Lock()
|
|
||||||
defer bind.mu.Unlock()
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
bind.closeAndZero()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
if bind.isOpen.Load() != 0 {
|
|
||||||
return nil, 0, ErrBindAlreadyOpen
|
|
||||||
}
|
|
||||||
var sa windows.Sockaddr
|
|
||||||
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
|
|
||||||
for i := 0; i < packetsPerRing; i++ {
|
|
||||||
err = bind.v4.InsertReceiveRequest()
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
err = bind.v6.InsertReceiveRequest()
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bind.isOpen.Store(1)
|
|
||||||
return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) Close() error {
|
|
||||||
bind.mu.RLock()
|
|
||||||
if bind.isOpen.Load() != 1 {
|
|
||||||
bind.mu.RUnlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
bind.isOpen.Store(2)
|
|
||||||
windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
|
|
||||||
windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
|
|
||||||
windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
|
|
||||||
windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
|
|
||||||
bind.mu.RUnlock()
|
|
||||||
bind.mu.Lock()
|
|
||||||
defer bind.mu.Unlock()
|
|
||||||
bind.closeAndZero()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
|
|
||||||
// rename the IdealBatchSize constant to BatchSize.
|
|
||||||
func (bind *WinRingBind) BatchSize() int {
|
|
||||||
// TODO: implement batching in and out of the ring
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) SetMark(mark uint32) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *afWinRingBind) InsertReceiveRequest() error {
|
|
||||||
packet := bind.rx.Push()
|
|
||||||
dataBuffer := &winrio.Buffer{
|
|
||||||
Id: bind.rx.id,
|
|
||||||
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
|
|
||||||
Length: uint32(len(packet.data)),
|
|
||||||
}
|
|
||||||
addressBuffer := &winrio.Buffer{
|
|
||||||
Id: bind.rx.id,
|
|
||||||
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
|
|
||||||
Length: uint32(unsafe.Sizeof(packet.addr)),
|
|
||||||
}
|
|
||||||
bind.mu.Lock()
|
|
||||||
defer bind.mu.Unlock()
|
|
||||||
return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
|
|
||||||
}
|
|
||||||
|
|
||||||
//go:linkname procyield runtime.procyield
|
|
||||||
func procyield(cycles uint32)
|
|
||||||
|
|
||||||
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
|
|
||||||
if isOpen.Load() != 1 {
|
|
||||||
return 0, nil, net.ErrClosed
|
|
||||||
}
|
|
||||||
bind.rx.mu.Lock()
|
|
||||||
defer bind.rx.mu.Unlock()
|
|
||||||
|
|
||||||
var err error
|
|
||||||
var count uint32
|
|
||||||
var results [1]winrio.Result
|
|
||||||
retry:
|
|
||||||
count = 0
|
|
||||||
for tries := 0; count == 0 && tries < receiveSpins; tries++ {
|
|
||||||
if tries > 0 {
|
|
||||||
if isOpen.Load() != 1 {
|
|
||||||
return 0, nil, net.ErrClosed
|
|
||||||
}
|
|
||||||
procyield(1)
|
|
||||||
}
|
|
||||||
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
|
||||||
}
|
|
||||||
if count == 0 {
|
|
||||||
err = winrio.Notify(bind.rx.cq)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
var bytes uint32
|
|
||||||
var key uintptr
|
|
||||||
var overlapped *windows.Overlapped
|
|
||||||
err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
if isOpen.Load() != 1 {
|
|
||||||
return 0, nil, net.ErrClosed
|
|
||||||
}
|
|
||||||
count = winrio.DequeueCompletion(bind.rx.cq, results[:])
|
|
||||||
if count == 0 {
|
|
||||||
return 0, nil, io.ErrNoProgress
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bind.rx.Return(1)
|
|
||||||
err = bind.InsertReceiveRequest()
|
|
||||||
if err != nil {
|
|
||||||
return 0, nil, err
|
|
||||||
}
|
|
||||||
// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
|
|
||||||
// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
|
|
||||||
// attacker bandwidth, just like the rest of the receive path.
|
|
||||||
if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
|
|
||||||
if isOpen.Load() != 1 {
|
|
||||||
return 0, nil, net.ErrClosed
|
|
||||||
}
|
|
||||||
goto retry
|
|
||||||
}
|
|
||||||
if results[0].Status != 0 {
|
|
||||||
return 0, nil, windows.Errno(results[0].Status)
|
|
||||||
}
|
|
||||||
packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
|
|
||||||
ep := packet.addr
|
|
||||||
n := copy(buf, packet.data[:results[0].BytesTransferred])
|
|
||||||
return n, &ep, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
|
||||||
bind.mu.RLock()
|
|
||||||
defer bind.mu.RUnlock()
|
|
||||||
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
|
|
||||||
sizes[0] = n
|
|
||||||
eps[0] = ep
|
|
||||||
return 1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
|
|
||||||
bind.mu.RLock()
|
|
||||||
defer bind.mu.RUnlock()
|
|
||||||
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
|
|
||||||
sizes[0] = n
|
|
||||||
eps[0] = ep
|
|
||||||
return 1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
|
|
||||||
if isOpen.Load() != 1 {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
if len(buf) > bytesPerPacket {
|
|
||||||
return io.ErrShortBuffer
|
|
||||||
}
|
|
||||||
bind.tx.mu.Lock()
|
|
||||||
defer bind.tx.mu.Unlock()
|
|
||||||
var results [packetsPerRing]winrio.Result
|
|
||||||
count := winrio.DequeueCompletion(bind.tx.cq, results[:])
|
|
||||||
if count == 0 && bind.tx.isFull {
|
|
||||||
err := winrio.Notify(bind.tx.cq)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var bytes uint32
|
|
||||||
var key uintptr
|
|
||||||
var overlapped *windows.Overlapped
|
|
||||||
err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if isOpen.Load() != 1 {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
count = winrio.DequeueCompletion(bind.tx.cq, results[:])
|
|
||||||
if count == 0 {
|
|
||||||
return io.ErrNoProgress
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if count > 0 {
|
|
||||||
bind.tx.Return(count)
|
|
||||||
}
|
|
||||||
packet := bind.tx.Push()
|
|
||||||
packet.addr = *nend
|
|
||||||
copy(packet.data[:], buf)
|
|
||||||
dataBuffer := &winrio.Buffer{
|
|
||||||
Id: bind.tx.id,
|
|
||||||
Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
|
|
||||||
Length: uint32(len(buf)),
|
|
||||||
}
|
|
||||||
addressBuffer := &winrio.Buffer{
|
|
||||||
Id: bind.tx.id,
|
|
||||||
Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
|
|
||||||
Length: uint32(unsafe.Sizeof(packet.addr)),
|
|
||||||
}
|
|
||||||
bind.mu.Lock()
|
|
||||||
defer bind.mu.Unlock()
|
|
||||||
return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
|
|
||||||
nend, ok := endpoint.(*WinRingEndpoint)
|
|
||||||
if !ok {
|
|
||||||
return ErrWrongEndpointType
|
|
||||||
}
|
|
||||||
bind.mu.RLock()
|
|
||||||
defer bind.mu.RUnlock()
|
|
||||||
for _, buf := range bufs {
|
|
||||||
switch nend.family {
|
|
||||||
case windows.AF_INET:
|
|
||||||
if bind.v4.blackhole {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case windows.AF_INET6:
|
|
||||||
if bind.v6.blackhole {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
sysconn, err := s.ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err2 := sysconn.Control(func(fd uintptr) {
|
|
||||||
err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
|
|
||||||
})
|
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.blackhole4 = blackhole
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
sysconn, err := s.ipv6.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err2 := sysconn.Control(func(fd uintptr) {
|
|
||||||
err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
|
|
||||||
})
|
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.blackhole6 = blackhole
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
|
||||||
bind.mu.RLock()
|
|
||||||
defer bind.mu.RUnlock()
|
|
||||||
if bind.isOpen.Load() != 1 {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
bind.v4.blackhole = blackhole
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
|
||||||
bind.mu.RLock()
|
|
||||||
defer bind.mu.RUnlock()
|
|
||||||
if bind.isOpen.Load() != 1 {
|
|
||||||
return net.ErrClosed
|
|
||||||
}
|
|
||||||
err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
bind.v6.blackhole = blackhole
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
|
|
||||||
const IP_UNICAST_IF = 31
|
|
||||||
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
|
||||||
var bytes [4]byte
|
|
||||||
binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
|
|
||||||
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
|
||||||
err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
|
|
||||||
const IPV6_UNICAST_IF = 31
|
|
||||||
return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
|
|
||||||
}
|
|
|
@ -1,136 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package bindtest
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
)
|
|
||||||
|
|
||||||
type ChannelBind struct {
|
|
||||||
rx4, tx4 *chan []byte
|
|
||||||
rx6, tx6 *chan []byte
|
|
||||||
closeSignal chan bool
|
|
||||||
source4, source6 ChannelEndpoint
|
|
||||||
target4, target6 ChannelEndpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
type ChannelEndpoint uint16
|
|
||||||
|
|
||||||
var (
|
|
||||||
_ conn.Bind = (*ChannelBind)(nil)
|
|
||||||
_ conn.Endpoint = (*ChannelEndpoint)(nil)
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewChannelBinds() [2]conn.Bind {
|
|
||||||
arx4 := make(chan []byte, 8192)
|
|
||||||
brx4 := make(chan []byte, 8192)
|
|
||||||
arx6 := make(chan []byte, 8192)
|
|
||||||
brx6 := make(chan []byte, 8192)
|
|
||||||
var binds [2]ChannelBind
|
|
||||||
binds[0].rx4 = &arx4
|
|
||||||
binds[0].tx4 = &brx4
|
|
||||||
binds[1].rx4 = &brx4
|
|
||||||
binds[1].tx4 = &arx4
|
|
||||||
binds[0].rx6 = &arx6
|
|
||||||
binds[0].tx6 = &brx6
|
|
||||||
binds[1].rx6 = &brx6
|
|
||||||
binds[1].tx6 = &arx6
|
|
||||||
binds[0].target4 = ChannelEndpoint(1)
|
|
||||||
binds[1].target4 = ChannelEndpoint(2)
|
|
||||||
binds[0].target6 = ChannelEndpoint(3)
|
|
||||||
binds[1].target6 = ChannelEndpoint(4)
|
|
||||||
binds[0].source4 = binds[1].target4
|
|
||||||
binds[0].source6 = binds[1].target6
|
|
||||||
binds[1].source4 = binds[0].target4
|
|
||||||
binds[1].source6 = binds[0].target6
|
|
||||||
return [2]conn.Bind{&binds[0], &binds[1]}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c ChannelEndpoint) ClearSrc() {}
|
|
||||||
|
|
||||||
func (c ChannelEndpoint) SrcToString() string { return "" }
|
|
||||||
|
|
||||||
func (c ChannelEndpoint) DstToString() string { return fmt.Sprintf("127.0.0.1:%d", c) }
|
|
||||||
|
|
||||||
func (c ChannelEndpoint) DstToBytes() []byte { return []byte{byte(c)} }
|
|
||||||
|
|
||||||
func (c ChannelEndpoint) DstIP() netip.Addr { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }
|
|
||||||
|
|
||||||
func (c ChannelEndpoint) SrcIP() netip.Addr { return netip.Addr{} }
|
|
||||||
|
|
||||||
func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
|
||||||
c.closeSignal = make(chan bool)
|
|
||||||
fns = append(fns, c.makeReceiveFunc(*c.rx4))
|
|
||||||
fns = append(fns, c.makeReceiveFunc(*c.rx6))
|
|
||||||
if rand.Uint32()&1 == 0 {
|
|
||||||
return fns, uint16(c.source4), nil
|
|
||||||
} else {
|
|
||||||
return fns, uint16(c.source6), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ChannelBind) Close() error {
|
|
||||||
if c.closeSignal != nil {
|
|
||||||
select {
|
|
||||||
case <-c.closeSignal:
|
|
||||||
default:
|
|
||||||
close(c.closeSignal)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ChannelBind) BatchSize() int { return 1 }
|
|
||||||
|
|
||||||
func (c *ChannelBind) SetMark(mark uint32) error { return nil }
|
|
||||||
|
|
||||||
func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
|
|
||||||
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
|
|
||||||
select {
|
|
||||||
case <-c.closeSignal:
|
|
||||||
return 0, net.ErrClosed
|
|
||||||
case rx := <-ch:
|
|
||||||
copied := copy(bufs[0], rx)
|
|
||||||
sizes[0] = copied
|
|
||||||
eps[0] = c.target6
|
|
||||||
return 1, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ChannelBind) Send(bufs [][]byte, ep conn.Endpoint) error {
|
|
||||||
for _, b := range bufs {
|
|
||||||
select {
|
|
||||||
case <-c.closeSignal:
|
|
||||||
return net.ErrClosed
|
|
||||||
default:
|
|
||||||
bc := make([]byte, len(b))
|
|
||||||
copy(bc, b)
|
|
||||||
if ep.(ChannelEndpoint) == c.target4 {
|
|
||||||
*c.tx4 <- bc
|
|
||||||
} else if ep.(ChannelEndpoint) == c.target6 {
|
|
||||||
*c.tx6 <- bc
|
|
||||||
} else {
|
|
||||||
return os.ErrInvalid
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *ChannelBind) ParseEndpoint(s string) (conn.Endpoint, error) {
|
|
||||||
addr, err := netip.ParseAddrPort(s)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ChannelEndpoint(addr.Port()), nil
|
|
||||||
}
|
|
|
@ -1,34 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
func (s *StdNetBind) PeekLookAtSocketFd4() (fd int, err error) {
|
|
||||||
sysconn, err := s.ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
err = sysconn.Control(func(f uintptr) {
|
|
||||||
fd = int(f)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) PeekLookAtSocketFd6() (fd int, err error) {
|
|
||||||
sysconn, err := s.ipv6.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
err = sysconn.Control(func(f uintptr) {
|
|
||||||
fd = int(f)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
133
conn/conn.go
133
conn/conn.go
|
@ -1,133 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Package conn implements WireGuard's network connections.
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/netip"
|
|
||||||
"reflect"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
IdealBatchSize = 128 // maximum number of packets handled per read and write
|
|
||||||
)
|
|
||||||
|
|
||||||
// A ReceiveFunc receives at least one packet from the network and writes them
|
|
||||||
// into packets. On a successful read it returns the number of elements of
|
|
||||||
// sizes, packets, and endpoints that should be evaluated. Some elements of
|
|
||||||
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
|
|
||||||
// and eps slice with a length greater than or equal to the length of packets.
|
|
||||||
// These lengths must not exceed the length of the associated Bind.BatchSize().
|
|
||||||
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
|
|
||||||
|
|
||||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
|
||||||
//
|
|
||||||
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
|
||||||
// depending on the platform-specific implementation.
|
|
||||||
type Bind interface {
|
|
||||||
// Open puts the Bind into a listening state on a given port and reports the actual
|
|
||||||
// port that it bound to. Passing zero results in a random selection.
|
|
||||||
// fns is the set of functions that will be called to receive packets.
|
|
||||||
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
|
|
||||||
|
|
||||||
// Close closes the Bind listener.
|
|
||||||
// All fns returned by Open must return net.ErrClosed after a call to Close.
|
|
||||||
Close() error
|
|
||||||
|
|
||||||
// SetMark sets the mark for each packet sent through this Bind.
|
|
||||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
|
||||||
SetMark(mark uint32) error
|
|
||||||
|
|
||||||
// Send writes one or more packets in bufs to address ep. The length of
|
|
||||||
// bufs must not exceed BatchSize().
|
|
||||||
Send(bufs [][]byte, ep Endpoint) error
|
|
||||||
|
|
||||||
// ParseEndpoint creates a new endpoint from a string.
|
|
||||||
ParseEndpoint(s string) (Endpoint, error)
|
|
||||||
|
|
||||||
// BatchSize is the number of buffers expected to be passed to
|
|
||||||
// the ReceiveFuncs, and the maximum expected to be passed to SendBatch.
|
|
||||||
BatchSize() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// BindSocketToInterface is implemented by Bind objects that support being
|
|
||||||
// tied to a single network interface. Used by wireguard-windows.
|
|
||||||
type BindSocketToInterface interface {
|
|
||||||
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
|
||||||
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
|
||||||
// file descriptor peeked at. Used by wireguard-android.
|
|
||||||
type PeekLookAtSocketFd interface {
|
|
||||||
PeekLookAtSocketFd4() (fd int, err error)
|
|
||||||
PeekLookAtSocketFd6() (fd int, err error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// An Endpoint maintains the source/destination caching for a peer.
|
|
||||||
//
|
|
||||||
// dst: the remote address of a peer ("endpoint" in uapi terminology)
|
|
||||||
// src: the local address from which datagrams originate going to the peer
|
|
||||||
type Endpoint interface {
|
|
||||||
ClearSrc() // clears the source address
|
|
||||||
SrcToString() string // returns the local source address (ip:port)
|
|
||||||
DstToString() string // returns the destination address (ip:port)
|
|
||||||
DstToBytes() []byte // used for mac2 cookie calculations
|
|
||||||
DstIP() netip.Addr
|
|
||||||
SrcIP() netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrBindAlreadyOpen = errors.New("bind is already open")
|
|
||||||
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
|
|
||||||
)
|
|
||||||
|
|
||||||
func (fn ReceiveFunc) PrettyName() string {
|
|
||||||
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
|
|
||||||
// 0. cheese/taco.beansIPv6.func12.func21218-fm
|
|
||||||
name = strings.TrimSuffix(name, "-fm")
|
|
||||||
// 1. cheese/taco.beansIPv6.func12.func21218
|
|
||||||
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
|
|
||||||
name = name[idx+1:]
|
|
||||||
// 2. taco.beansIPv6.func12.func21218
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
var idx int
|
|
||||||
for idx = len(name) - 1; idx >= 0; idx-- {
|
|
||||||
if name[idx] < '0' || name[idx] > '9' {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if idx == len(name)-1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
const dotFunc = ".func"
|
|
||||||
if !strings.HasSuffix(name[:idx+1], dotFunc) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
name = name[:idx+1-len(dotFunc)]
|
|
||||||
// 3. taco.beansIPv6.func12
|
|
||||||
// 4. taco.beansIPv6
|
|
||||||
}
|
|
||||||
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
|
|
||||||
name = name[idx+1:]
|
|
||||||
// 5. beansIPv6
|
|
||||||
}
|
|
||||||
if name == "" {
|
|
||||||
return fmt.Sprintf("%p", fn)
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(name, "IPv4") {
|
|
||||||
return "v4"
|
|
||||||
}
|
|
||||||
if strings.HasSuffix(name, "IPv6") {
|
|
||||||
return "v6"
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
|
@ -1,24 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPrettyName(t *testing.T) {
|
|
||||||
var (
|
|
||||||
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
|
|
||||||
)
|
|
||||||
|
|
||||||
const want = "TestPrettyName"
|
|
||||||
|
|
||||||
t.Run("ReceiveFunc.PrettyName", func(t *testing.T) {
|
|
||||||
if got := recvFunc.PrettyName(); got != want {
|
|
||||||
t.Errorf("PrettyName() = %v, want %v", got, want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,43 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"syscall"
|
|
||||||
)
|
|
||||||
|
|
||||||
// UDP socket read/write buffer size (7MB). The value of 7MB is chosen as it is
|
|
||||||
// the max supported by a default configuration of macOS. Some platforms will
|
|
||||||
// silently clamp the value to other maximums, such as linux clamping to
|
|
||||||
// net.core.{r,w}mem_max (see _linux.go for additional implementation that works
|
|
||||||
// around this limitation)
|
|
||||||
const socketBufferSize = 7 << 20
|
|
||||||
|
|
||||||
// controlFn is the callback function signature from net.ListenConfig.Control.
|
|
||||||
// It is used to apply platform specific configuration to the socket prior to
|
|
||||||
// bind.
|
|
||||||
type controlFn func(network, address string, c syscall.RawConn) error
|
|
||||||
|
|
||||||
// controlFns is a list of functions that are called from the listen config
|
|
||||||
// that can apply socket options.
|
|
||||||
var controlFns = []controlFn{}
|
|
||||||
|
|
||||||
// listenConfig returns a net.ListenConfig that applies the controlFns to the
|
|
||||||
// socket prior to bind. This is used to apply socket buffer sizing and packet
|
|
||||||
// information OOB configuration for sticky sockets.
|
|
||||||
func listenConfig() *net.ListenConfig {
|
|
||||||
return &net.ListenConfig{
|
|
||||||
Control: func(network, address string, c syscall.RawConn) error {
|
|
||||||
for _, fn := range controlFns {
|
|
||||||
if err := fn(network, address, c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,61 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
controlFns = append(controlFns,
|
|
||||||
|
|
||||||
// Attempt to set the socket buffer size beyond net.core.{r,w}mem_max by
|
|
||||||
// using SO_*BUFFORCE. This requires CAP_NET_ADMIN, and is allowed here to
|
|
||||||
// fail silently - the result of failure is lower performance on very fast
|
|
||||||
// links or high latency links.
|
|
||||||
func(network, address string, c syscall.RawConn) error {
|
|
||||||
return c.Control(func(fd uintptr) {
|
|
||||||
// Set up to *mem_max
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
|
||||||
// Set beyond *mem_max if CAP_NET_ADMIN
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUFFORCE, socketBufferSize)
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUFFORCE, socketBufferSize)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
|
|
||||||
// Enable receiving of the packet information (IP_PKTINFO for IPv4,
|
|
||||||
// IPV6_PKTINFO for IPv6) that is used to implement sticky socket support.
|
|
||||||
func(network, address string, c syscall.RawConn) error {
|
|
||||||
var err error
|
|
||||||
switch network {
|
|
||||||
case "udp4":
|
|
||||||
if runtime.GOOS != "android" {
|
|
||||||
c.Control(func(fd uintptr) {
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO, 1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
case "udp6":
|
|
||||||
c.Control(func(fd uintptr) {
|
|
||||||
if runtime.GOOS != "android" {
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO, 1)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
|
||||||
})
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("unhandled network: %s: %w", network, unix.EINVAL)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
|
@ -1,35 +0,0 @@
|
||||||
//go:build !windows && !linux && !wasm
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
controlFns = append(controlFns,
|
|
||||||
func(network, address string, c syscall.RawConn) error {
|
|
||||||
return c.Control(func(fd uintptr) {
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_RCVBUF, socketBufferSize)
|
|
||||||
_ = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, unix.SO_SNDBUF, socketBufferSize)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
|
|
||||||
func(network, address string, c syscall.RawConn) error {
|
|
||||||
var err error
|
|
||||||
if network == "udp6" {
|
|
||||||
c.Control(func(fd uintptr) {
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_V6ONLY, 1)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
|
@ -1,23 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
controlFns = append(controlFns,
|
|
||||||
func(network, address string, c syscall.RawConn) error {
|
|
||||||
return c.Control(func(fd uintptr) {
|
|
||||||
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_RCVBUF, socketBufferSize)
|
|
||||||
_ = windows.SetsockoptInt(windows.Handle(fd), windows.SOL_SOCKET, windows.SO_SNDBUF, socketBufferSize)
|
|
||||||
})
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
|
@ -1,10 +0,0 @@
|
||||||
//go:build !windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
func NewDefaultBind() Bind { return NewStdNetBind() }
|
|
|
@ -1,12 +0,0 @@
|
||||||
//go:build !linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
func errShouldDisableUDPGSO(err error) bool {
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -1,28 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func errShouldDisableUDPGSO(err error) bool {
|
|
||||||
var serr *os.SyscallError
|
|
||||||
if errors.As(err, &serr) {
|
|
||||||
// EIO is returned by udp_send_skb() if the device driver does not have
|
|
||||||
// tx checksumming enabled, which is a hard requirement of UDP_SEGMENT.
|
|
||||||
// See:
|
|
||||||
// https://git.kernel.org/pub/scm/docs/man-pages/man-pages.git/tree/man7/udp.7?id=806eabd74910447f21005160e90957bde4db0183#n228
|
|
||||||
// https://git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git/tree/net/ipv4/udp.c?h=v6.2&id=c9c3395d5e3dcc6daee66c6908354d47bf98cb0c#n942
|
|
||||||
// If gso_size + udp + ip headers > fragment size EINVAL is returned.
|
|
||||||
// It occurs when the peer mtu + wg headers is greater than path mtu.
|
|
||||||
return serr.Err == unix.EIO || serr.Err == unix.EINVAL
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
|
@ -1,15 +0,0 @@
|
||||||
//go:build !linux
|
|
||||||
// +build !linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import "net"
|
|
||||||
|
|
||||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
|
||||||
return
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func supportsUDPOffload(conn *net.UDPConn) (txOffload, rxOffload bool) {
|
|
||||||
rc, err := conn.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = rc.Control(func(fd uintptr) {
|
|
||||||
_, errSyscall := unix.GetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_SEGMENT)
|
|
||||||
txOffload = errSyscall == nil
|
|
||||||
// getsockopt(IPPROTO_UDP, UDP_GRO) is not supported in android
|
|
||||||
// use setsockopt workaround
|
|
||||||
errSyscall = unix.SetsockoptInt(int(fd), unix.IPPROTO_UDP, unix.UDP_GRO, 1)
|
|
||||||
rxOffload = errSyscall == nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return false, false
|
|
||||||
}
|
|
||||||
return txOffload, rxOffload
|
|
||||||
}
|
|
|
@ -1,21 +0,0 @@
|
||||||
//go:build !linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
|
||||||
func getGSOSize(control []byte) (int, error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize.
|
|
||||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// gsoControlSize returns the recommended buffer size for pooling sticky and UDP
|
|
||||||
// offloading control data.
|
|
||||||
const gsoControlSize = 0
|
|
|
@ -1,65 +0,0 @@
|
||||||
//go:build linux
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
sizeOfGSOData = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
// getGSOSize parses control for UDP_GRO and if found returns its GSO size data.
|
|
||||||
func getGSOSize(control []byte) (int, error) {
|
|
||||||
var (
|
|
||||||
hdr unix.Cmsghdr
|
|
||||||
data []byte
|
|
||||||
rem = control
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
for len(rem) > unix.SizeofCmsghdr {
|
|
||||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("error parsing socket control message: %w", err)
|
|
||||||
}
|
|
||||||
if hdr.Level == unix.SOL_UDP && hdr.Type == unix.UDP_GRO && len(data) >= sizeOfGSOData {
|
|
||||||
var gso uint16
|
|
||||||
copy(unsafe.Slice((*byte)(unsafe.Pointer(&gso)), sizeOfGSOData), data[:sizeOfGSOData])
|
|
||||||
return int(gso), nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setGSOSize sets a UDP_SEGMENT in control based on gsoSize. It leaves existing
|
|
||||||
// data in control untouched.
|
|
||||||
func setGSOSize(control *[]byte, gsoSize uint16) {
|
|
||||||
existingLen := len(*control)
|
|
||||||
avail := cap(*control) - existingLen
|
|
||||||
space := unix.CmsgSpace(sizeOfGSOData)
|
|
||||||
if avail < space {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*control = (*control)[:cap(*control)]
|
|
||||||
gsoControl := (*control)[existingLen:]
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&(gsoControl)[0]))
|
|
||||||
hdr.Level = unix.SOL_UDP
|
|
||||||
hdr.Type = unix.UDP_SEGMENT
|
|
||||||
hdr.SetLen(unix.CmsgLen(sizeOfGSOData))
|
|
||||||
copy((gsoControl)[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&gsoSize)), sizeOfGSOData))
|
|
||||||
*control = (*control)[:existingLen+space]
|
|
||||||
}
|
|
||||||
|
|
||||||
// gsoControlSize returns the recommended buffer size for pooling UDP
|
|
||||||
// offloading control data.
|
|
||||||
var gsoControlSize = unix.CmsgSpace(sizeOfGSOData)
|
|
|
@ -1,12 +0,0 @@
|
||||||
//go:build !linux && !openbsd && !freebsd
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,65 +0,0 @@
|
||||||
//go:build linux || openbsd || freebsd
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
var fwmarkIoctl int
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
switch runtime.GOOS {
|
|
||||||
case "linux", "android":
|
|
||||||
fwmarkIoctl = 36 /* unix.SO_MARK */
|
|
||||||
case "freebsd":
|
|
||||||
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
|
|
||||||
case "openbsd":
|
|
||||||
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *StdNetBind) SetMark(mark uint32) error {
|
|
||||||
var operr error
|
|
||||||
if fwmarkIoctl == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if s.ipv4 != nil {
|
|
||||||
fd, err := s.ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = fd.Control(func(fd uintptr) {
|
|
||||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
err = operr
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if s.ipv6 != nil {
|
|
||||||
fd, err := s.ipv6.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
err = fd.Control(func(fd uintptr) {
|
|
||||||
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
|
|
||||||
})
|
|
||||||
if err == nil {
|
|
||||||
err = operr
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
//go:build !linux || android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import "net/netip"
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcToString() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: macOS, FreeBSD and other BSDs likely do support the sticky sockets
|
|
||||||
// {get,set}srcControl feature set, but use alternatively named flags and need
|
|
||||||
// ports and require testing.
|
|
||||||
|
|
||||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
|
||||||
// the source information found.
|
|
||||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// setSrcControl parses the control for PKTINFO and if found updates ep with
|
|
||||||
// the source information found.
|
|
||||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
|
||||||
}
|
|
||||||
|
|
||||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
|
||||||
// offloading control data.
|
|
||||||
const stickyControlSize = 0
|
|
||||||
|
|
||||||
const StdNetSupportsStickySockets = false
|
|
|
@ -1,112 +0,0 @@
|
||||||
//go:build linux && !android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/netip"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIP() netip.Addr {
|
|
||||||
switch len(e.src) {
|
|
||||||
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
|
||||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
|
||||||
return netip.AddrFrom4(info.Spec_dst)
|
|
||||||
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
|
||||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
|
||||||
// TODO: set zone. in order to do so we need to check if the address is
|
|
||||||
// link local, and if it is perform a syscall to turn the ifindex into a
|
|
||||||
// zone string because netip uses string zones.
|
|
||||||
return netip.AddrFrom16(info.Addr)
|
|
||||||
}
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcIfidx() int32 {
|
|
||||||
switch len(e.src) {
|
|
||||||
case unix.CmsgSpace(unix.SizeofInet4Pktinfo):
|
|
||||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
|
||||||
return info.Ifindex
|
|
||||||
case unix.CmsgSpace(unix.SizeofInet6Pktinfo):
|
|
||||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&e.src[unix.CmsgLen(0)]))
|
|
||||||
return int32(info.Ifindex)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *StdNetEndpoint) SrcToString() string {
|
|
||||||
return e.SrcIP().String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// getSrcFromControl parses the control for PKTINFO and if found updates ep with
|
|
||||||
// the source information found.
|
|
||||||
func getSrcFromControl(control []byte, ep *StdNetEndpoint) {
|
|
||||||
ep.ClearSrc()
|
|
||||||
|
|
||||||
var (
|
|
||||||
hdr unix.Cmsghdr
|
|
||||||
data []byte
|
|
||||||
rem []byte = control
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
for len(rem) > unix.SizeofCmsghdr {
|
|
||||||
hdr, data, rem, err = unix.ParseOneSocketControlMessage(rem)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdr.Level == unix.IPPROTO_IP &&
|
|
||||||
hdr.Type == unix.IP_PKTINFO {
|
|
||||||
|
|
||||||
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet4Pktinfo) {
|
|
||||||
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
|
||||||
}
|
|
||||||
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet4Pktinfo)]
|
|
||||||
|
|
||||||
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
|
||||||
copy(ep.src, hdrBuf)
|
|
||||||
copy(ep.src[unix.CmsgLen(0):], data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if hdr.Level == unix.IPPROTO_IPV6 &&
|
|
||||||
hdr.Type == unix.IPV6_PKTINFO {
|
|
||||||
|
|
||||||
if ep.src == nil || cap(ep.src) < unix.CmsgSpace(unix.SizeofInet6Pktinfo) {
|
|
||||||
ep.src = make([]byte, 0, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
|
||||||
}
|
|
||||||
|
|
||||||
ep.src = ep.src[:unix.CmsgSpace(unix.SizeofInet6Pktinfo)]
|
|
||||||
|
|
||||||
hdrBuf := unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), unix.SizeofCmsghdr)
|
|
||||||
copy(ep.src, hdrBuf)
|
|
||||||
copy(ep.src[unix.CmsgLen(0):], data)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// setSrcControl sets an IP{V6}_PKTINFO in control based on the source address
|
|
||||||
// and source ifindex found in ep. control's len will be set to 0 in the event
|
|
||||||
// that ep is a default value.
|
|
||||||
func setSrcControl(control *[]byte, ep *StdNetEndpoint) {
|
|
||||||
if cap(*control) < len(ep.src) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
*control = (*control)[:0]
|
|
||||||
*control = append(*control, ep.src...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// stickyControlSize returns the recommended buffer size for pooling sticky
|
|
||||||
// offloading control data.
|
|
||||||
var stickyControlSize = unix.CmsgSpace(unix.SizeofInet6Pktinfo)
|
|
||||||
|
|
||||||
const StdNetSupportsStickySockets = true
|
|
|
@ -1,266 +0,0 @@
|
||||||
//go:build linux && !android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package conn
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setSrc(ep *StdNetEndpoint, addr netip.Addr, ifidx int32) {
|
|
||||||
var buf []byte
|
|
||||||
if addr.Is4() {
|
|
||||||
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
|
||||||
hdr := unix.Cmsghdr{
|
|
||||||
Level: unix.IPPROTO_IP,
|
|
||||||
Type: unix.IP_PKTINFO,
|
|
||||||
}
|
|
||||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet4Pktinfo))
|
|
||||||
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
|
||||||
|
|
||||||
info := unix.Inet4Pktinfo{
|
|
||||||
Ifindex: ifidx,
|
|
||||||
Spec_dst: addr.As4(),
|
|
||||||
}
|
|
||||||
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet4Pktinfo))
|
|
||||||
} else {
|
|
||||||
buf = make([]byte, unix.CmsgSpace(unix.SizeofInet6Pktinfo))
|
|
||||||
hdr := unix.Cmsghdr{
|
|
||||||
Level: unix.IPPROTO_IPV6,
|
|
||||||
Type: unix.IPV6_PKTINFO,
|
|
||||||
}
|
|
||||||
hdr.SetLen(unix.CmsgLen(unix.SizeofInet6Pktinfo))
|
|
||||||
copy(buf, unsafe.Slice((*byte)(unsafe.Pointer(&hdr)), int(unsafe.Sizeof(hdr))))
|
|
||||||
|
|
||||||
info := unix.Inet6Pktinfo{
|
|
||||||
Ifindex: uint32(ifidx),
|
|
||||||
Addr: addr.As16(),
|
|
||||||
}
|
|
||||||
copy(buf[unix.CmsgLen(0):], unsafe.Slice((*byte)(unsafe.Pointer(&info)), unix.SizeofInet6Pktinfo))
|
|
||||||
}
|
|
||||||
|
|
||||||
ep.src = buf
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_setSrcControl(t *testing.T) {
|
|
||||||
t.Run("IPv4", func(t *testing.T) {
|
|
||||||
ep := &StdNetEndpoint{
|
|
||||||
AddrPort: netip.MustParseAddrPort("127.0.0.1:1234"),
|
|
||||||
}
|
|
||||||
setSrc(ep, netip.MustParseAddr("127.0.0.1"), 5)
|
|
||||||
|
|
||||||
control := make([]byte, stickyControlSize)
|
|
||||||
|
|
||||||
setSrcControl(&control, ep)
|
|
||||||
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
if hdr.Level != unix.IPPROTO_IP {
|
|
||||||
t.Errorf("unexpected level: %d", hdr.Level)
|
|
||||||
}
|
|
||||||
if hdr.Type != unix.IP_PKTINFO {
|
|
||||||
t.Errorf("unexpected type: %d", hdr.Type)
|
|
||||||
}
|
|
||||||
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{})))) {
|
|
||||||
t.Errorf("unexpected length: %d", hdr.Len)
|
|
||||||
}
|
|
||||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
|
||||||
if info.Spec_dst[0] != 127 || info.Spec_dst[1] != 0 || info.Spec_dst[2] != 0 || info.Spec_dst[3] != 1 {
|
|
||||||
t.Errorf("unexpected address: %v", info.Spec_dst)
|
|
||||||
}
|
|
||||||
if info.Ifindex != 5 {
|
|
||||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("IPv6", func(t *testing.T) {
|
|
||||||
ep := &StdNetEndpoint{
|
|
||||||
AddrPort: netip.MustParseAddrPort("[::1]:1234"),
|
|
||||||
}
|
|
||||||
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
|
||||||
|
|
||||||
control := make([]byte, stickyControlSize)
|
|
||||||
|
|
||||||
setSrcControl(&control, ep)
|
|
||||||
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
if hdr.Level != unix.IPPROTO_IPV6 {
|
|
||||||
t.Errorf("unexpected level: %d", hdr.Level)
|
|
||||||
}
|
|
||||||
if hdr.Type != unix.IPV6_PKTINFO {
|
|
||||||
t.Errorf("unexpected type: %d", hdr.Type)
|
|
||||||
}
|
|
||||||
if uint(hdr.Len) != uint(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{})))) {
|
|
||||||
t.Errorf("unexpected length: %d", hdr.Len)
|
|
||||||
}
|
|
||||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
|
||||||
if info.Addr != ep.SrcIP().As16() {
|
|
||||||
t.Errorf("unexpected address: %v", info.Addr)
|
|
||||||
}
|
|
||||||
if info.Ifindex != 5 {
|
|
||||||
t.Errorf("unexpected ifindex: %d", info.Ifindex)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("ClearOnNoSrc", func(t *testing.T) {
|
|
||||||
control := make([]byte, stickyControlSize)
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
hdr.Level = 1
|
|
||||||
hdr.Type = 2
|
|
||||||
hdr.Len = 3
|
|
||||||
|
|
||||||
setSrcControl(&control, &StdNetEndpoint{})
|
|
||||||
|
|
||||||
if len(control) != 0 {
|
|
||||||
t.Errorf("unexpected control: %v", control)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_getSrcFromControl(t *testing.T) {
|
|
||||||
t.Run("IPv4", func(t *testing.T) {
|
|
||||||
control := make([]byte, stickyControlSize)
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
hdr.Level = unix.IPPROTO_IP
|
|
||||||
hdr.Type = unix.IP_PKTINFO
|
|
||||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
|
||||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
|
||||||
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
|
||||||
info.Ifindex = 5
|
|
||||||
|
|
||||||
ep := &StdNetEndpoint{}
|
|
||||||
getSrcFromControl(control, ep)
|
|
||||||
|
|
||||||
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
|
||||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
|
||||||
}
|
|
||||||
if ep.SrcIfidx() != 5 {
|
|
||||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("IPv6", func(t *testing.T) {
|
|
||||||
control := make([]byte, stickyControlSize)
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
hdr.Level = unix.IPPROTO_IPV6
|
|
||||||
hdr.Type = unix.IPV6_PKTINFO
|
|
||||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet6Pktinfo{}))))
|
|
||||||
info := (*unix.Inet6Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
|
||||||
info.Addr = [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}
|
|
||||||
info.Ifindex = 5
|
|
||||||
|
|
||||||
ep := &StdNetEndpoint{}
|
|
||||||
getSrcFromControl(control, ep)
|
|
||||||
|
|
||||||
if ep.SrcIP() != netip.MustParseAddr("::1") {
|
|
||||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
|
||||||
}
|
|
||||||
if ep.SrcIfidx() != 5 {
|
|
||||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("ClearOnEmpty", func(t *testing.T) {
|
|
||||||
var control []byte
|
|
||||||
ep := &StdNetEndpoint{}
|
|
||||||
setSrc(ep, netip.MustParseAddr("::1"), 5)
|
|
||||||
|
|
||||||
getSrcFromControl(control, ep)
|
|
||||||
if ep.SrcIP().IsValid() {
|
|
||||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
|
||||||
}
|
|
||||||
if ep.SrcIfidx() != 0 {
|
|
||||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("Multiple", func(t *testing.T) {
|
|
||||||
zeroControl := make([]byte, unix.CmsgSpace(0))
|
|
||||||
zeroHdr := (*unix.Cmsghdr)(unsafe.Pointer(&zeroControl[0]))
|
|
||||||
zeroHdr.SetLen(unix.CmsgLen(0))
|
|
||||||
|
|
||||||
control := make([]byte, unix.CmsgSpace(unix.SizeofInet4Pktinfo))
|
|
||||||
hdr := (*unix.Cmsghdr)(unsafe.Pointer(&control[0]))
|
|
||||||
hdr.Level = unix.IPPROTO_IP
|
|
||||||
hdr.Type = unix.IP_PKTINFO
|
|
||||||
hdr.SetLen(unix.CmsgLen(int(unsafe.Sizeof(unix.Inet4Pktinfo{}))))
|
|
||||||
info := (*unix.Inet4Pktinfo)(unsafe.Pointer(&control[unix.CmsgLen(0)]))
|
|
||||||
info.Spec_dst = [4]byte{127, 0, 0, 1}
|
|
||||||
info.Ifindex = 5
|
|
||||||
|
|
||||||
combined := make([]byte, 0)
|
|
||||||
combined = append(combined, zeroControl...)
|
|
||||||
combined = append(combined, control...)
|
|
||||||
|
|
||||||
ep := &StdNetEndpoint{}
|
|
||||||
getSrcFromControl(combined, ep)
|
|
||||||
|
|
||||||
if ep.SrcIP() != netip.MustParseAddr("127.0.0.1") {
|
|
||||||
t.Errorf("unexpected address: %v", ep.SrcIP())
|
|
||||||
}
|
|
||||||
if ep.SrcIfidx() != 5 {
|
|
||||||
t.Errorf("unexpected ifindex: %d", ep.SrcIfidx())
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_listenConfig(t *testing.T) {
|
|
||||||
t.Run("IPv4", func(t *testing.T) {
|
|
||||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp4", ":0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
var i int
|
|
||||||
sc.Control(func(fd uintptr) {
|
|
||||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_PKTINFO)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if i != 1 {
|
|
||||||
t.Error("IP_PKTINFO not set!")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("IPv6", func(t *testing.T) {
|
|
||||||
conn, err := listenConfig().ListenPacket(context.Background(), "udp6", ":0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
sc, err := conn.(*net.UDPConn).SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
var i int
|
|
||||||
sc.Control(func(fd uintptr) {
|
|
||||||
i, err = unix.GetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_RECVPKTINFO)
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if i != 1 {
|
|
||||||
t.Error("IPV6_PKTINFO not set!")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
t.Logf("listenConfig() does not set IPV6_RECVPKTINFO on %s", runtime.GOOS)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,254 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winrio
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
MsgDontNotify = 1
|
|
||||||
MsgDefer = 2
|
|
||||||
MsgWaitAll = 4
|
|
||||||
MsgCommitOnly = 8
|
|
||||||
|
|
||||||
MaxCqSize = 0x8000000
|
|
||||||
|
|
||||||
invalidBufferId = 0xFFFFFFFF
|
|
||||||
invalidCq = 0
|
|
||||||
invalidRq = 0
|
|
||||||
corruptCq = 0xFFFFFFFF
|
|
||||||
)
|
|
||||||
|
|
||||||
var extensionFunctionTable struct {
|
|
||||||
cbSize uint32
|
|
||||||
rioReceive uintptr
|
|
||||||
rioReceiveEx uintptr
|
|
||||||
rioSend uintptr
|
|
||||||
rioSendEx uintptr
|
|
||||||
rioCloseCompletionQueue uintptr
|
|
||||||
rioCreateCompletionQueue uintptr
|
|
||||||
rioCreateRequestQueue uintptr
|
|
||||||
rioDequeueCompletion uintptr
|
|
||||||
rioDeregisterBuffer uintptr
|
|
||||||
rioNotify uintptr
|
|
||||||
rioRegisterBuffer uintptr
|
|
||||||
rioResizeCompletionQueue uintptr
|
|
||||||
rioResizeRequestQueue uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
type Cq uintptr
|
|
||||||
|
|
||||||
type Rq uintptr
|
|
||||||
|
|
||||||
type BufferId uintptr
|
|
||||||
|
|
||||||
type Buffer struct {
|
|
||||||
Id BufferId
|
|
||||||
Offset uint32
|
|
||||||
Length uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type Result struct {
|
|
||||||
Status int32
|
|
||||||
BytesTransferred uint32
|
|
||||||
SocketContext uint64
|
|
||||||
RequestContext uint64
|
|
||||||
}
|
|
||||||
|
|
||||||
type notificationCompletionType uint32
|
|
||||||
|
|
||||||
const (
|
|
||||||
eventCompletion notificationCompletionType = 1
|
|
||||||
iocpCompletion notificationCompletionType = 2
|
|
||||||
)
|
|
||||||
|
|
||||||
type eventNotificationCompletion struct {
|
|
||||||
completionType notificationCompletionType
|
|
||||||
event windows.Handle
|
|
||||||
notifyReset uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
type iocpNotificationCompletion struct {
|
|
||||||
completionType notificationCompletionType
|
|
||||||
iocp windows.Handle
|
|
||||||
key uintptr
|
|
||||||
overlapped *windows.Overlapped
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
initialized sync.Once
|
|
||||||
available bool
|
|
||||||
)
|
|
||||||
|
|
||||||
func Initialize() bool {
|
|
||||||
initialized.Do(func() {
|
|
||||||
var (
|
|
||||||
err error
|
|
||||||
socket windows.Handle
|
|
||||||
cq Cq
|
|
||||||
)
|
|
||||||
defer func() {
|
|
||||||
if err == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Printf("Registered I/O is unavailable: %v", err)
|
|
||||||
}()
|
|
||||||
socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer windows.CloseHandle(socket)
|
|
||||||
WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
|
|
||||||
const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
|
|
||||||
ob := uint32(0)
|
|
||||||
err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
|
|
||||||
(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
|
|
||||||
(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
|
|
||||||
&ob, nil, 0)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
|
|
||||||
// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
|
|
||||||
var iocp windows.Handle
|
|
||||||
iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer windows.CloseHandle(iocp)
|
|
||||||
var overlapped windows.Overlapped
|
|
||||||
cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer CloseCompletionQueue(cq)
|
|
||||||
_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
available = true
|
|
||||||
})
|
|
||||||
return available
|
|
||||||
}
|
|
||||||
|
|
||||||
func Socket(af, typ, proto int32) (windows.Handle, error) {
|
|
||||||
return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CloseCompletionQueue(cq Cq) {
|
|
||||||
_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
|
|
||||||
notificationCompletion := &eventNotificationCompletion{
|
|
||||||
completionType: eventCompletion,
|
|
||||||
event: event,
|
|
||||||
}
|
|
||||||
if notifyReset {
|
|
||||||
notificationCompletion.notifyReset = 1
|
|
||||||
}
|
|
||||||
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
|
|
||||||
if ret == invalidCq {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return Cq(ret), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
|
|
||||||
notificationCompletion := &iocpNotificationCompletion{
|
|
||||||
completionType: iocpCompletion,
|
|
||||||
iocp: iocp,
|
|
||||||
key: key,
|
|
||||||
overlapped: overlapped,
|
|
||||||
}
|
|
||||||
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
|
|
||||||
if ret == invalidCq {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return Cq(ret), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
|
|
||||||
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
|
|
||||||
if ret == invalidCq {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return Cq(ret), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
|
|
||||||
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
|
|
||||||
if ret == invalidRq {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return Rq(ret), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DequeueCompletion(cq Cq, results []Result) uint32 {
|
|
||||||
var array uintptr
|
|
||||||
if len(results) > 0 {
|
|
||||||
array = uintptr(unsafe.Pointer(&results[0]))
|
|
||||||
}
|
|
||||||
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
|
|
||||||
if ret == corruptCq {
|
|
||||||
panic("cq is corrupt")
|
|
||||||
}
|
|
||||||
return uint32(ret)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DeregisterBuffer(id BufferId) {
|
|
||||||
_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterBuffer(buffer []byte) (BufferId, error) {
|
|
||||||
var buf unsafe.Pointer
|
|
||||||
if len(buffer) > 0 {
|
|
||||||
buf = unsafe.Pointer(&buffer[0])
|
|
||||||
}
|
|
||||||
return RegisterPointer(buf, uint32(len(buffer)))
|
|
||||||
}
|
|
||||||
|
|
||||||
func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
|
|
||||||
ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
|
|
||||||
if ret == invalidBufferId {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return BufferId(ret), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
|
|
||||||
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
|
|
||||||
if ret == 0 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
|
|
||||||
ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
|
|
||||||
if ret == 0 {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Notify(cq Cq) error {
|
|
||||||
ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
|
|
||||||
if ret != 0 {
|
|
||||||
return windows.Errno(ret)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
136
conn_default.go
Normal file
136
conn_default.go
Normal file
|
@ -0,0 +1,136 @@
|
||||||
|
// +build !linux
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* This code is meant to be a temporary solution
|
||||||
|
* on platforms for which the sticky socket / source caching behavior
|
||||||
|
* has not yet been implemented.
|
||||||
|
*
|
||||||
|
* See conn_linux.go for an implementation on the linux platform.
|
||||||
|
*/
|
||||||
|
|
||||||
|
type NativeBind struct {
|
||||||
|
ipv4 *net.UDPConn
|
||||||
|
ipv6 *net.UDPConn
|
||||||
|
}
|
||||||
|
|
||||||
|
type NativeEndpoint net.UDPAddr
|
||||||
|
|
||||||
|
var _ Bind = (*NativeBind)(nil)
|
||||||
|
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||||
|
|
||||||
|
func CreateEndpoint(s string) (Endpoint, error) {
|
||||||
|
addr, err := parseEndpoint(s)
|
||||||
|
return (*NativeEndpoint)(addr), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (_ *NativeEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) DstIP() net.IP {
|
||||||
|
return (*net.UDPAddr)(e).IP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) SrcIP() net.IP {
|
||||||
|
return nil // not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) DstToBytes() []byte {
|
||||||
|
addr := (*net.UDPAddr)(e)
|
||||||
|
out := addr.IP
|
||||||
|
out = append(out, byte(addr.Port&0xff))
|
||||||
|
out = append(out, byte((addr.Port>>8)&0xff))
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) DstToString() string {
|
||||||
|
return (*net.UDPAddr)(e).String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *NativeEndpoint) SrcToString() string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||||
|
|
||||||
|
// listen
|
||||||
|
|
||||||
|
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// retrieve port
|
||||||
|
|
||||||
|
laddr := conn.LocalAddr()
|
||||||
|
uaddr, err := net.ResolveUDPAddr(
|
||||||
|
laddr.Network(),
|
||||||
|
laddr.String(),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
return conn, uaddr.Port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
|
||||||
|
var err error
|
||||||
|
var bind NativeBind
|
||||||
|
|
||||||
|
port := int(uport)
|
||||||
|
|
||||||
|
bind.ipv4, port, err = listenNet("udp4", port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.ipv6, port, err = listenNet("udp6", port)
|
||||||
|
if err != nil {
|
||||||
|
bind.ipv4.Close()
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &bind, uint16(port), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) Close() error {
|
||||||
|
err1 := bind.ipv4.Close()
|
||||||
|
err2 := bind.ipv6.Close()
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
|
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
|
||||||
|
return n, (*NativeEndpoint)(endpoint), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
|
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
|
||||||
|
return n, (*NativeEndpoint)(endpoint), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
|
||||||
|
var err error
|
||||||
|
nend := endpoint.(*NativeEndpoint)
|
||||||
|
if nend.IP.To16() != nil {
|
||||||
|
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
|
} else {
|
||||||
|
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) SetMark(_ uint32) error {
|
||||||
|
return nil
|
||||||
|
}
|
690
conn_linux.go
Normal file
690
conn_linux.go
Normal file
|
@ -0,0 +1,690 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*
|
||||||
|
* This implements userspace semantics of "sticky sockets", modeled after
|
||||||
|
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||||
|
* of the sticky-sockets.c example code:
|
||||||
|
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
||||||
|
*
|
||||||
|
* Currently there is no way to achieve this within the net package:
|
||||||
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
|
* So this code is remains platform dependent.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"./rwcancel"
|
||||||
|
"errors"
|
||||||
|
"golang.org/x/sys/unix"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
type IPv4Source struct {
|
||||||
|
src [4]byte
|
||||||
|
ifindex int32
|
||||||
|
}
|
||||||
|
|
||||||
|
type IPv6Source struct {
|
||||||
|
src [16]byte
|
||||||
|
//ifindex belongs in dst.ZoneId
|
||||||
|
}
|
||||||
|
|
||||||
|
type NativeEndpoint struct {
|
||||||
|
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
||||||
|
src [unsafe.Sizeof(IPv6Source{})]byte
|
||||||
|
isV6 bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
||||||
|
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *NativeEndpoint) src6() *IPv6Source {
|
||||||
|
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
|
||||||
|
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
|
||||||
|
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
type NativeBind struct {
|
||||||
|
sock4 int
|
||||||
|
sock6 int
|
||||||
|
netlinkSock int
|
||||||
|
netlinkCancel *rwcancel.RWCancel
|
||||||
|
lastMark uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ Endpoint = (*NativeEndpoint)(nil)
|
||||||
|
var _ Bind = (*NativeBind)(nil)
|
||||||
|
|
||||||
|
func CreateEndpoint(s string) (Endpoint, error) {
|
||||||
|
var end NativeEndpoint
|
||||||
|
addr, err := parseEndpoint(s)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv4 := addr.IP.To4()
|
||||||
|
if ipv4 != nil {
|
||||||
|
dst := end.dst4()
|
||||||
|
end.isV6 = false
|
||||||
|
dst.Port = addr.Port
|
||||||
|
copy(dst.Addr[:], ipv4)
|
||||||
|
end.ClearSrc()
|
||||||
|
return &end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ipv6 := addr.IP.To16()
|
||||||
|
if ipv6 != nil {
|
||||||
|
zone, err := zoneToUint32(addr.Zone)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
dst := end.dst6()
|
||||||
|
end.isV6 = true
|
||||||
|
dst.Port = addr.Port
|
||||||
|
dst.ZoneId = zone
|
||||||
|
copy(dst.Addr[:], ipv6[:])
|
||||||
|
end.ClearSrc()
|
||||||
|
return &end, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, errors.New("Invalid IP address")
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNetlinkRouteSocket() (int, error) {
|
||||||
|
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
||||||
|
if err != nil {
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
saddr := &unix.SockaddrNetlink{
|
||||||
|
Family: unix.AF_NETLINK,
|
||||||
|
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
|
||||||
|
}
|
||||||
|
err = unix.Bind(sock, saddr)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(sock)
|
||||||
|
return -1, err
|
||||||
|
}
|
||||||
|
return sock, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
|
||||||
|
var err error
|
||||||
|
var bind NativeBind
|
||||||
|
|
||||||
|
bind.netlinkSock, err = createNetlinkRouteSocket()
|
||||||
|
if err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
|
||||||
|
if err != nil {
|
||||||
|
unix.Close(bind.netlinkSock)
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
go bind.routineRouteListener(device)
|
||||||
|
|
||||||
|
bind.sock6, port, err = create6(port)
|
||||||
|
if err != nil {
|
||||||
|
bind.netlinkCancel.Cancel()
|
||||||
|
return nil, port, err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.sock4, port, err = create4(port)
|
||||||
|
if err != nil {
|
||||||
|
bind.netlinkCancel.Cancel()
|
||||||
|
unix.Close(bind.sock6)
|
||||||
|
}
|
||||||
|
return &bind, port, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) SetMark(value uint32) error {
|
||||||
|
err := unix.SetsockoptInt(
|
||||||
|
bind.sock6,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_MARK,
|
||||||
|
int(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = unix.SetsockoptInt(
|
||||||
|
bind.sock4,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_MARK,
|
||||||
|
int(value),
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
bind.lastMark = value
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func closeUnblock(fd int) error {
|
||||||
|
// shutdown to unblock readers and writers
|
||||||
|
unix.Shutdown(fd, unix.SHUT_RDWR)
|
||||||
|
return unix.Close(fd)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) Close() error {
|
||||||
|
err1 := closeUnblock(bind.sock6)
|
||||||
|
err2 := closeUnblock(bind.sock4)
|
||||||
|
err3 := bind.netlinkCancel.Cancel()
|
||||||
|
|
||||||
|
if err1 != nil {
|
||||||
|
return err1
|
||||||
|
}
|
||||||
|
if err2 != nil {
|
||||||
|
return err2
|
||||||
|
}
|
||||||
|
return err3
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||||
|
var end NativeEndpoint
|
||||||
|
n, err := receive6(
|
||||||
|
bind.sock6,
|
||||||
|
buff,
|
||||||
|
&end,
|
||||||
|
)
|
||||||
|
return n, &end, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||||
|
var end NativeEndpoint
|
||||||
|
n, err := receive4(
|
||||||
|
bind.sock4,
|
||||||
|
buff,
|
||||||
|
&end,
|
||||||
|
)
|
||||||
|
return n, &end, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
|
||||||
|
nend := end.(*NativeEndpoint)
|
||||||
|
if !nend.isV6 {
|
||||||
|
return send4(bind.sock4, nend, buff)
|
||||||
|
} else {
|
||||||
|
return send6(bind.sock6, nend, buff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) SrcIP() net.IP {
|
||||||
|
if !end.isV6 {
|
||||||
|
return net.IPv4(
|
||||||
|
end.src4().src[0],
|
||||||
|
end.src4().src[1],
|
||||||
|
end.src4().src[2],
|
||||||
|
end.src4().src[3],
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
return end.src6().src[:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) DstIP() net.IP {
|
||||||
|
if !end.isV6 {
|
||||||
|
return net.IPv4(
|
||||||
|
end.dst4().Addr[0],
|
||||||
|
end.dst4().Addr[1],
|
||||||
|
end.dst4().Addr[2],
|
||||||
|
end.dst4().Addr[3],
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
return end.dst6().Addr[:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) DstToBytes() []byte {
|
||||||
|
if !end.isV6 {
|
||||||
|
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
|
||||||
|
} else {
|
||||||
|
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) SrcToString() string {
|
||||||
|
return end.SrcIP().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) DstToString() string {
|
||||||
|
var udpAddr net.UDPAddr
|
||||||
|
udpAddr.IP = end.DstIP()
|
||||||
|
if !end.isV6 {
|
||||||
|
udpAddr.Port = end.dst4().Port
|
||||||
|
} else {
|
||||||
|
udpAddr.Port = end.dst6().Port
|
||||||
|
}
|
||||||
|
return udpAddr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) ClearDst() {
|
||||||
|
for i := range end.dst {
|
||||||
|
end.dst[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (end *NativeEndpoint) ClearSrc() {
|
||||||
|
for i := range end.src {
|
||||||
|
end.src[i] = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func zoneToUint32(zone string) (uint32, error) {
|
||||||
|
if zone == "" {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
if intr, err := net.InterfaceByName(zone); err == nil {
|
||||||
|
return uint32(intr.Index), nil
|
||||||
|
}
|
||||||
|
n, err := strconv.ParseUint(zone, 10, 32)
|
||||||
|
return uint32(n), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func create4(port uint16) (int, uint16, error) {
|
||||||
|
|
||||||
|
// create socket
|
||||||
|
|
||||||
|
fd, err := unix.Socket(
|
||||||
|
unix.AF_INET,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return -1, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := unix.SockaddrInet4{
|
||||||
|
Port: int(port),
|
||||||
|
}
|
||||||
|
|
||||||
|
// set sockopts and bind
|
||||||
|
|
||||||
|
if err := func() error {
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_REUSEADDR,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.IPPROTO_IP,
|
||||||
|
unix.IP_PKTINFO,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return unix.Bind(fd, &addr)
|
||||||
|
}(); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return -1, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fd, uint16(addr.Port), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func create6(port uint16) (int, uint16, error) {
|
||||||
|
|
||||||
|
// create socket
|
||||||
|
|
||||||
|
fd, err := unix.Socket(
|
||||||
|
unix.AF_INET6,
|
||||||
|
unix.SOCK_DGRAM,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return -1, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// set sockopts and bind
|
||||||
|
|
||||||
|
addr := unix.SockaddrInet6{
|
||||||
|
Port: int(port),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := func() error {
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.SOL_SOCKET,
|
||||||
|
unix.SO_REUSEADDR,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.IPPROTO_IPV6,
|
||||||
|
unix.IPV6_RECVPKTINFO,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := unix.SetsockoptInt(
|
||||||
|
fd,
|
||||||
|
unix.IPPROTO_IPV6,
|
||||||
|
unix.IPV6_V6ONLY,
|
||||||
|
1,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return unix.Bind(fd, &addr)
|
||||||
|
|
||||||
|
}(); err != nil {
|
||||||
|
unix.Close(fd)
|
||||||
|
return -1, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return fd, uint16(addr.Port), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
||||||
|
|
||||||
|
// construct message header
|
||||||
|
|
||||||
|
cmsg := struct {
|
||||||
|
cmsghdr unix.Cmsghdr
|
||||||
|
pktinfo unix.Inet4Pktinfo
|
||||||
|
}{
|
||||||
|
unix.Cmsghdr{
|
||||||
|
Level: unix.IPPROTO_IP,
|
||||||
|
Type: unix.IP_PKTINFO,
|
||||||
|
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||||
|
},
|
||||||
|
unix.Inet4Pktinfo{
|
||||||
|
Spec_dst: end.src4().src,
|
||||||
|
Ifindex: end.src4().ifindex,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear src and retry
|
||||||
|
|
||||||
|
if err == unix.EINVAL {
|
||||||
|
end.ClearSrc()
|
||||||
|
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||||
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
||||||
|
|
||||||
|
// construct message header
|
||||||
|
|
||||||
|
cmsg := struct {
|
||||||
|
cmsghdr unix.Cmsghdr
|
||||||
|
pktinfo unix.Inet6Pktinfo
|
||||||
|
}{
|
||||||
|
unix.Cmsghdr{
|
||||||
|
Level: unix.IPPROTO_IPV6,
|
||||||
|
Type: unix.IPV6_PKTINFO,
|
||||||
|
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
|
||||||
|
},
|
||||||
|
unix.Inet6Pktinfo{
|
||||||
|
Addr: end.src6().src,
|
||||||
|
Ifindex: end.dst6().ZoneId,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmsg.pktinfo.Addr == [16]byte{} {
|
||||||
|
cmsg.pktinfo.Ifindex = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// clear src and retry
|
||||||
|
|
||||||
|
if err == unix.EINVAL {
|
||||||
|
end.ClearSrc()
|
||||||
|
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||||
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
|
// contruct message header
|
||||||
|
|
||||||
|
var cmsg struct {
|
||||||
|
cmsghdr unix.Cmsghdr
|
||||||
|
pktinfo unix.Inet4Pktinfo
|
||||||
|
}
|
||||||
|
|
||||||
|
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
end.isV6 = false
|
||||||
|
|
||||||
|
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
|
||||||
|
*end.dst4() = *newDst4
|
||||||
|
}
|
||||||
|
|
||||||
|
// update source cache
|
||||||
|
|
||||||
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||||
|
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||||
|
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||||
|
end.src4().src = cmsg.pktinfo.Spec_dst
|
||||||
|
end.src4().ifindex = cmsg.pktinfo.Ifindex
|
||||||
|
}
|
||||||
|
|
||||||
|
return size, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
|
// contruct message header
|
||||||
|
|
||||||
|
var cmsg struct {
|
||||||
|
cmsghdr unix.Cmsghdr
|
||||||
|
pktinfo unix.Inet6Pktinfo
|
||||||
|
}
|
||||||
|
|
||||||
|
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
end.isV6 = true
|
||||||
|
|
||||||
|
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
|
||||||
|
*end.dst6() = *newDst6
|
||||||
|
}
|
||||||
|
|
||||||
|
// update source cache
|
||||||
|
|
||||||
|
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
|
||||||
|
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
|
||||||
|
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
|
||||||
|
end.src6().src = cmsg.pktinfo.Addr
|
||||||
|
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
|
||||||
|
}
|
||||||
|
|
||||||
|
return size, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (bind *NativeBind) routineRouteListener(device *Device) {
|
||||||
|
type peerEndpointPtr struct {
|
||||||
|
peer *Peer
|
||||||
|
endpoint *Endpoint
|
||||||
|
}
|
||||||
|
var reqPeer map[uint32]peerEndpointPtr
|
||||||
|
|
||||||
|
defer unix.Close(bind.netlinkSock)
|
||||||
|
|
||||||
|
for msg := make([]byte, 1<<16); ; {
|
||||||
|
var err error
|
||||||
|
var msgn int
|
||||||
|
for {
|
||||||
|
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
|
||||||
|
if err == nil || !rwcancel.ErrorIsEAGAIN(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if !bind.netlinkCancel.ReadyRead() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||||
|
|
||||||
|
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||||
|
|
||||||
|
if uint(hdr.Len) > uint(len(remain)) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
switch hdr.Type {
|
||||||
|
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||||
|
if hdr.Seq <= MaxPeers {
|
||||||
|
if uint(len(remain)) < uint(hdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
||||||
|
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
||||||
|
for {
|
||||||
|
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
||||||
|
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||||
|
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||||
|
if reqPeer == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr, ok := reqPeer[hdr.Seq]
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.mutex.Lock()
|
||||||
|
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||||
|
pePtr.peer.mutex.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
|
||||||
|
pePtr.peer.mutex.Unlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
|
||||||
|
pePtr.peer.mutex.Unlock()
|
||||||
|
}
|
||||||
|
attr = attr[attrhdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
reqPeer = make(map[uint32]peerEndpointPtr)
|
||||||
|
go func() {
|
||||||
|
device.peers.mutex.RLock()
|
||||||
|
i := uint32(1)
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.mutex.RLock()
|
||||||
|
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
|
||||||
|
peer.mutex.RUnlock()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
|
||||||
|
peer.mutex.RUnlock()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
nlmsg := struct {
|
||||||
|
hdr unix.NlMsghdr
|
||||||
|
msg unix.RtMsg
|
||||||
|
dsthdr unix.RtAttr
|
||||||
|
dst [4]byte
|
||||||
|
srchdr unix.RtAttr
|
||||||
|
src [4]byte
|
||||||
|
markhdr unix.RtAttr
|
||||||
|
mark uint32
|
||||||
|
}{
|
||||||
|
unix.NlMsghdr{
|
||||||
|
Type: uint16(unix.RTM_GETROUTE),
|
||||||
|
Flags: unix.NLM_F_REQUEST,
|
||||||
|
Seq: i,
|
||||||
|
},
|
||||||
|
unix.RtMsg{
|
||||||
|
Family: unix.AF_INET,
|
||||||
|
Dst_len: 32,
|
||||||
|
Src_len: 32,
|
||||||
|
},
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_DST,
|
||||||
|
},
|
||||||
|
peer.endpoint.(*NativeEndpoint).dst4().Addr,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: unix.RTA_SRC,
|
||||||
|
},
|
||||||
|
peer.endpoint.(*NativeEndpoint).src4().src,
|
||||||
|
unix.RtAttr{
|
||||||
|
Len: 8,
|
||||||
|
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
|
||||||
|
},
|
||||||
|
uint32(bind.lastMark),
|
||||||
|
}
|
||||||
|
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||||
|
reqPeer[i] = peerEndpointPtr{
|
||||||
|
peer: peer,
|
||||||
|
endpoint: &peer.endpoint,
|
||||||
|
}
|
||||||
|
peer.mutex.RUnlock()
|
||||||
|
i++
|
||||||
|
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.mutex.RUnlock()
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
remain = remain[hdr.Len:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
45
constants.go
Normal file
45
constants.go
Normal file
|
@ -0,0 +1,45 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Specification constants */
|
||||||
|
|
||||||
|
const (
|
||||||
|
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
|
||||||
|
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
||||||
|
RekeyAfterTime = time.Second * 120
|
||||||
|
RekeyAttemptTime = time.Second * 90
|
||||||
|
RekeyTimeout = time.Second * 5
|
||||||
|
MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
|
||||||
|
RekeyTimeoutJitterMaxMs = 334
|
||||||
|
RejectAfterTime = time.Second * 180
|
||||||
|
KeepaliveTimeout = time.Second * 10
|
||||||
|
CookieRefreshTime = time.Second * 120
|
||||||
|
HandshakeInitationRate = time.Second / 20
|
||||||
|
PaddingMultiple = 16
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Implementation specific constants */
|
||||||
|
|
||||||
|
const (
|
||||||
|
QueueOutboundSize = 1024
|
||||||
|
QueueInboundSize = 1024
|
||||||
|
QueueHandshakeSize = 1024
|
||||||
|
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
||||||
|
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
|
||||||
|
MaxMessageSize = MaxSegmentSize // maximum size of transport message
|
||||||
|
MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
UnderLoadQueueSize = QueueHandshakeSize / 8
|
||||||
|
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
|
||||||
|
MaxPeers = 1 << 16 // maximum number of configured peers
|
||||||
|
)
|
|
@ -1,23 +1,23 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"./xchacha20poly1305"
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CookieChecker struct {
|
type CookieChecker struct {
|
||||||
sync.RWMutex
|
mutex sync.RWMutex
|
||||||
mac1 struct {
|
mac1 struct {
|
||||||
key [blake2s.Size]byte
|
key [blake2s.Size]byte
|
||||||
}
|
}
|
||||||
mac2 struct {
|
mac2 struct {
|
||||||
|
@ -28,8 +28,8 @@ type CookieChecker struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type CookieGenerator struct {
|
type CookieGenerator struct {
|
||||||
sync.RWMutex
|
mutex sync.RWMutex
|
||||||
mac1 struct {
|
mac1 struct {
|
||||||
key [blake2s.Size]byte
|
key [blake2s.Size]byte
|
||||||
}
|
}
|
||||||
mac2 struct {
|
mac2 struct {
|
||||||
|
@ -42,8 +42,8 @@ type CookieGenerator struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieChecker) Init(pk NoisePublicKey) {
|
func (st *CookieChecker) Init(pk NoisePublicKey) {
|
||||||
st.Lock()
|
st.mutex.Lock()
|
||||||
defer st.Unlock()
|
defer st.mutex.Unlock()
|
||||||
|
|
||||||
// mac1 state
|
// mac1 state
|
||||||
|
|
||||||
|
@ -67,8 +67,8 @@ func (st *CookieChecker) Init(pk NoisePublicKey) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
||||||
st.RLock()
|
st.mutex.RLock()
|
||||||
defer st.RUnlock()
|
defer st.mutex.RUnlock()
|
||||||
|
|
||||||
size := len(msg)
|
size := len(msg)
|
||||||
smac2 := size - blake2s.Size128
|
smac2 := size - blake2s.Size128
|
||||||
|
@ -83,11 +83,11 @@ func (st *CookieChecker) CheckMAC1(msg []byte) bool {
|
||||||
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
return hmac.Equal(mac1[:], msg[smac1:smac2])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieChecker) CheckMAC2(msg, src []byte) bool {
|
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
|
||||||
st.RLock()
|
st.mutex.RLock()
|
||||||
defer st.RUnlock()
|
defer st.mutex.RUnlock()
|
||||||
|
|
||||||
if time.Since(st.mac2.secretSet) > CookieRefreshTime {
|
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,21 +119,22 @@ func (st *CookieChecker) CreateReply(
|
||||||
recv uint32,
|
recv uint32,
|
||||||
src []byte,
|
src []byte,
|
||||||
) (*MessageCookieReply, error) {
|
) (*MessageCookieReply, error) {
|
||||||
st.RLock()
|
|
||||||
|
st.mutex.RLock()
|
||||||
|
|
||||||
// refresh cookie secret
|
// refresh cookie secret
|
||||||
|
|
||||||
if time.Since(st.mac2.secretSet) > CookieRefreshTime {
|
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
|
||||||
st.RUnlock()
|
st.mutex.RUnlock()
|
||||||
st.Lock()
|
st.mutex.Lock()
|
||||||
_, err := rand.Read(st.mac2.secret[:])
|
_, err := rand.Read(st.mac2.secret[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
st.Unlock()
|
st.mutex.Unlock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
st.mac2.secretSet = time.Now()
|
st.mac2.secretSet = time.Now()
|
||||||
st.Unlock()
|
st.mutex.Unlock()
|
||||||
st.RLock()
|
st.mutex.RLock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// derive cookie
|
// derive cookie
|
||||||
|
@ -158,21 +159,26 @@ func (st *CookieChecker) CreateReply(
|
||||||
|
|
||||||
_, err := rand.Read(reply.Nonce[:])
|
_, err := rand.Read(reply.Nonce[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
st.RUnlock()
|
st.mutex.RUnlock()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
|
xchacha20poly1305.Encrypt(
|
||||||
xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
|
reply.Cookie[:0],
|
||||||
|
&reply.Nonce,
|
||||||
|
cookie[:],
|
||||||
|
msg[smac1:smac2],
|
||||||
|
&st.mac2.encryptionKey,
|
||||||
|
)
|
||||||
|
|
||||||
st.RUnlock()
|
st.mutex.RUnlock()
|
||||||
|
|
||||||
return reply, nil
|
return reply, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieGenerator) Init(pk NoisePublicKey) {
|
func (st *CookieGenerator) Init(pk NoisePublicKey) {
|
||||||
st.Lock()
|
st.mutex.Lock()
|
||||||
defer st.Unlock()
|
defer st.mutex.Unlock()
|
||||||
|
|
||||||
func() {
|
func() {
|
||||||
hash, _ := blake2s.New256(nil)
|
hash, _ := blake2s.New256(nil)
|
||||||
|
@ -192,8 +198,8 @@ func (st *CookieGenerator) Init(pk NoisePublicKey) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
||||||
st.Lock()
|
st.mutex.Lock()
|
||||||
defer st.Unlock()
|
defer st.mutex.Unlock()
|
||||||
|
|
||||||
if !st.mac2.hasLastMAC1 {
|
if !st.mac2.hasLastMAC1 {
|
||||||
return false
|
return false
|
||||||
|
@ -201,8 +207,14 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
||||||
|
|
||||||
var cookie [blake2s.Size128]byte
|
var cookie [blake2s.Size128]byte
|
||||||
|
|
||||||
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
|
_, err := xchacha20poly1305.Decrypt(
|
||||||
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
|
cookie[:0],
|
||||||
|
&msg.Nonce,
|
||||||
|
msg.Cookie[:],
|
||||||
|
st.mac2.lastMAC1[:],
|
||||||
|
&st.mac2.encryptionKey,
|
||||||
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -213,6 +225,7 @@ func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (st *CookieGenerator) AddMacs(msg []byte) {
|
func (st *CookieGenerator) AddMacs(msg []byte) {
|
||||||
|
|
||||||
size := len(msg)
|
size := len(msg)
|
||||||
|
|
||||||
smac2 := size - blake2s.Size128
|
smac2 := size - blake2s.Size128
|
||||||
|
@ -221,8 +234,8 @@ func (st *CookieGenerator) AddMacs(msg []byte) {
|
||||||
mac1 := msg[smac1:smac2]
|
mac1 := msg[smac1:smac2]
|
||||||
mac2 := msg[smac2:]
|
mac2 := msg[smac2:]
|
||||||
|
|
||||||
st.Lock()
|
st.mutex.Lock()
|
||||||
defer st.Unlock()
|
defer st.mutex.Unlock()
|
||||||
|
|
||||||
// set mac1
|
// set mac1
|
||||||
|
|
||||||
|
@ -236,7 +249,7 @@ func (st *CookieGenerator) AddMacs(msg []byte) {
|
||||||
|
|
||||||
// set mac2
|
// set mac2
|
||||||
|
|
||||||
if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
|
if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,16 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCookieMAC1(t *testing.T) {
|
func TestCookieMAC1(t *testing.T) {
|
||||||
|
|
||||||
// setup generator / checker
|
// setup generator / checker
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -131,12 +132,12 @@ func TestCookieMAC1(t *testing.T) {
|
||||||
|
|
||||||
msg[5] ^= 0x20
|
msg[5] ^= 0x20
|
||||||
|
|
||||||
srcBad1 := []byte{192, 168, 13, 37, 40, 1}
|
srcBad1 := []byte{192, 168, 13, 37, 40, 01}
|
||||||
if checker.CheckMAC2(msg, srcBad1) {
|
if checker.CheckMAC2(msg, srcBad1) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
srcBad2 := []byte{192, 168, 13, 38, 40, 1}
|
srcBad2 := []byte{192, 168, 13, 38, 40, 01}
|
||||||
if checker.CheckMAC2(msg, srcBad2) {
|
if checker.CheckMAC2(msg, srcBad2) {
|
||||||
t.Fatal("MAC2 generation/verification failed")
|
t.Fatal("MAC2 generation/verification failed")
|
||||||
}
|
}
|
390
device.go
Normal file
390
device.go
Normal file
|
@ -0,0 +1,390 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"./ratelimiter"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
DeviceRoutineNumberPerCPU = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
type Device struct {
|
||||||
|
isUp AtomicBool // device is (going) up
|
||||||
|
isClosed AtomicBool // device is closed? (acting as guard)
|
||||||
|
log *Logger
|
||||||
|
|
||||||
|
// synchronized resources (locks acquired in order)
|
||||||
|
|
||||||
|
state struct {
|
||||||
|
stopping sync.WaitGroup
|
||||||
|
mutex sync.Mutex
|
||||||
|
changing AtomicBool
|
||||||
|
current bool
|
||||||
|
}
|
||||||
|
|
||||||
|
net struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
bind Bind // bind interface
|
||||||
|
port uint16 // listening port
|
||||||
|
fwmark uint32 // mark value (0 = disabled)
|
||||||
|
}
|
||||||
|
|
||||||
|
staticIdentity struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
privateKey NoisePrivateKey
|
||||||
|
publicKey NoisePublicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
peers struct {
|
||||||
|
mutex sync.RWMutex
|
||||||
|
keyMap map[NoisePublicKey]*Peer
|
||||||
|
}
|
||||||
|
|
||||||
|
// unprotected / "self-synchronising resources"
|
||||||
|
|
||||||
|
allowedips AllowedIPs
|
||||||
|
indexTable IndexTable
|
||||||
|
cookieChecker CookieChecker
|
||||||
|
|
||||||
|
rate struct {
|
||||||
|
underLoadUntil atomic.Value
|
||||||
|
limiter ratelimiter.Ratelimiter
|
||||||
|
}
|
||||||
|
|
||||||
|
pool struct {
|
||||||
|
messageBuffers sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
queue struct {
|
||||||
|
encryption chan *QueueOutboundElement
|
||||||
|
decryption chan *QueueInboundElement
|
||||||
|
handshake chan QueueHandshakeElement
|
||||||
|
}
|
||||||
|
|
||||||
|
signals struct {
|
||||||
|
stop chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
tun struct {
|
||||||
|
device TUNDevice
|
||||||
|
mtu int32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Converts the peer into a "zombie", which remains in the peer map,
|
||||||
|
* but processes no packets and does not exists in the routing table.
|
||||||
|
*
|
||||||
|
* Must hold device.peers.mutex.
|
||||||
|
*/
|
||||||
|
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
|
||||||
|
|
||||||
|
// stop routing and processing of packets
|
||||||
|
|
||||||
|
device.allowedips.RemoveByPeer(peer)
|
||||||
|
peer.Stop()
|
||||||
|
|
||||||
|
// remove from peer map
|
||||||
|
|
||||||
|
delete(device.peers.keyMap, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func deviceUpdateState(device *Device) {
|
||||||
|
|
||||||
|
// check if state already being updated (guard)
|
||||||
|
|
||||||
|
if device.state.changing.Swap(true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// compare to current state of device
|
||||||
|
|
||||||
|
device.state.mutex.Lock()
|
||||||
|
|
||||||
|
newIsUp := device.isUp.Get()
|
||||||
|
|
||||||
|
if newIsUp == device.state.current {
|
||||||
|
device.state.changing.Set(false)
|
||||||
|
device.state.mutex.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// change state of device
|
||||||
|
|
||||||
|
switch newIsUp {
|
||||||
|
case true:
|
||||||
|
if err := device.BindUpdate(); err != nil {
|
||||||
|
device.isUp.Set(false)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
device.peers.mutex.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Start()
|
||||||
|
}
|
||||||
|
device.peers.mutex.RUnlock()
|
||||||
|
|
||||||
|
case false:
|
||||||
|
device.BindClose()
|
||||||
|
device.peers.mutex.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.Stop()
|
||||||
|
}
|
||||||
|
device.peers.mutex.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// update state variables
|
||||||
|
|
||||||
|
device.state.current = newIsUp
|
||||||
|
device.state.changing.Set(false)
|
||||||
|
device.state.mutex.Unlock()
|
||||||
|
|
||||||
|
// check for state change in the mean time
|
||||||
|
|
||||||
|
deviceUpdateState(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Up() {
|
||||||
|
|
||||||
|
// closed device cannot be brought up
|
||||||
|
|
||||||
|
if device.isClosed.Get() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
device.state.mutex.Lock()
|
||||||
|
device.isUp.Set(true)
|
||||||
|
device.state.mutex.Unlock()
|
||||||
|
deviceUpdateState(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Down() {
|
||||||
|
device.state.mutex.Lock()
|
||||||
|
device.isUp.Set(false)
|
||||||
|
device.state.mutex.Unlock()
|
||||||
|
deviceUpdateState(device)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) IsUnderLoad() bool {
|
||||||
|
|
||||||
|
// check if currently under load
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
|
||||||
|
if underLoad {
|
||||||
|
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if recently under load
|
||||||
|
|
||||||
|
until := device.rate.underLoadUntil.Load().(time.Time)
|
||||||
|
return until.After(now)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
|
|
||||||
|
// lock required resources
|
||||||
|
|
||||||
|
device.staticIdentity.mutex.Lock()
|
||||||
|
defer device.staticIdentity.mutex.Unlock()
|
||||||
|
|
||||||
|
device.peers.mutex.Lock()
|
||||||
|
defer device.peers.mutex.Unlock()
|
||||||
|
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.handshake.mutex.RLock()
|
||||||
|
defer peer.handshake.mutex.RUnlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove peers with matching public keys
|
||||||
|
|
||||||
|
publicKey := sk.publicKey()
|
||||||
|
for key, peer := range device.peers.keyMap {
|
||||||
|
if peer.handshake.remoteStatic.Equals(publicKey) {
|
||||||
|
unsafeRemovePeer(device, peer, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// update key material
|
||||||
|
|
||||||
|
device.staticIdentity.privateKey = sk
|
||||||
|
device.staticIdentity.publicKey = publicKey
|
||||||
|
device.cookieChecker.Init(publicKey)
|
||||||
|
|
||||||
|
// do static-static DH pre-computations
|
||||||
|
|
||||||
|
rmKey := device.staticIdentity.privateKey.IsZero()
|
||||||
|
|
||||||
|
for key, peer := range device.peers.keyMap {
|
||||||
|
|
||||||
|
handshake := &peer.handshake
|
||||||
|
|
||||||
|
if rmKey {
|
||||||
|
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
|
||||||
|
} else {
|
||||||
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
unsafeRemovePeer(device, peer, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
||||||
|
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
|
||||||
|
device.pool.messageBuffers.Put(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDevice(tun TUNDevice, logger *Logger) *Device {
|
||||||
|
device := new(Device)
|
||||||
|
|
||||||
|
device.isUp.Set(false)
|
||||||
|
device.isClosed.Set(false)
|
||||||
|
|
||||||
|
device.log = logger
|
||||||
|
|
||||||
|
device.tun.device = tun
|
||||||
|
mtu, err := device.tun.device.MTU()
|
||||||
|
if err != nil {
|
||||||
|
logger.Error.Println("Trouble determining MTU, assuming default:", err)
|
||||||
|
mtu = DefaultMTU
|
||||||
|
}
|
||||||
|
device.tun.mtu = int32(mtu)
|
||||||
|
|
||||||
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||||
|
|
||||||
|
device.rate.limiter.Init()
|
||||||
|
device.rate.underLoadUntil.Store(time.Time{})
|
||||||
|
|
||||||
|
device.indexTable.Init()
|
||||||
|
device.allowedips.Reset()
|
||||||
|
|
||||||
|
device.pool.messageBuffers = sync.Pool{
|
||||||
|
New: func() interface{} {
|
||||||
|
return new([MaxMessageSize]byte)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// create queues
|
||||||
|
|
||||||
|
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
|
||||||
|
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||||
|
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
|
||||||
|
|
||||||
|
// prepare signals
|
||||||
|
|
||||||
|
device.signals.stop = make(chan struct{})
|
||||||
|
|
||||||
|
// prepare net
|
||||||
|
|
||||||
|
device.net.port = 0
|
||||||
|
device.net.bind = nil
|
||||||
|
|
||||||
|
// start workers
|
||||||
|
|
||||||
|
cpus := runtime.NumCPU()
|
||||||
|
device.state.stopping.Add(DeviceRoutineNumberPerCPU * cpus)
|
||||||
|
for i := 0; i < cpus; i += 1 {
|
||||||
|
go device.RoutineEncryption()
|
||||||
|
go device.RoutineDecryption()
|
||||||
|
go device.RoutineHandshake()
|
||||||
|
}
|
||||||
|
|
||||||
|
go device.RoutineReadFromTUN()
|
||||||
|
go device.RoutineTUNEventReader()
|
||||||
|
|
||||||
|
return device
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
||||||
|
device.peers.mutex.RLock()
|
||||||
|
defer device.peers.mutex.RUnlock()
|
||||||
|
|
||||||
|
return device.peers.keyMap[pk]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) RemovePeer(key NoisePublicKey) {
|
||||||
|
device.peers.mutex.Lock()
|
||||||
|
defer device.peers.mutex.Unlock()
|
||||||
|
|
||||||
|
// stop peer and remove from routing
|
||||||
|
|
||||||
|
peer, ok := device.peers.keyMap[key]
|
||||||
|
if ok {
|
||||||
|
unsafeRemovePeer(device, peer, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) RemoveAllPeers() {
|
||||||
|
device.peers.mutex.Lock()
|
||||||
|
defer device.peers.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, peer := range device.peers.keyMap {
|
||||||
|
unsafeRemovePeer(device, peer, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) FlushPacketQueues() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case elem, ok := <-device.queue.decryption:
|
||||||
|
if ok {
|
||||||
|
elem.Drop()
|
||||||
|
}
|
||||||
|
case elem, ok := <-device.queue.encryption:
|
||||||
|
if ok {
|
||||||
|
elem.Drop()
|
||||||
|
}
|
||||||
|
case <-device.queue.handshake:
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Close() {
|
||||||
|
if device.isClosed.Swap(true) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
device.log.Info.Println("Device closing")
|
||||||
|
device.state.changing.Set(true)
|
||||||
|
device.state.mutex.Lock()
|
||||||
|
defer device.state.mutex.Unlock()
|
||||||
|
|
||||||
|
device.tun.device.Close()
|
||||||
|
device.BindClose()
|
||||||
|
|
||||||
|
device.isUp.Set(false)
|
||||||
|
|
||||||
|
close(device.signals.stop)
|
||||||
|
|
||||||
|
device.state.stopping.Wait()
|
||||||
|
device.FlushPacketQueues()
|
||||||
|
|
||||||
|
device.RemoveAllPeers()
|
||||||
|
device.rate.limiter.Close()
|
||||||
|
|
||||||
|
device.state.changing.Set(false)
|
||||||
|
device.log.Info.Println("Interface closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) Wait() chan struct{} {
|
||||||
|
return device.signals.stop
|
||||||
|
}
|
|
@ -1,294 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"container/list"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"math/bits"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"sync"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
type parentIndirection struct {
|
|
||||||
parentBit **trieEntry
|
|
||||||
parentBitType uint8
|
|
||||||
}
|
|
||||||
|
|
||||||
type trieEntry struct {
|
|
||||||
peer *Peer
|
|
||||||
child [2]*trieEntry
|
|
||||||
parent parentIndirection
|
|
||||||
cidr uint8
|
|
||||||
bitAtByte uint8
|
|
||||||
bitAtShift uint8
|
|
||||||
bits []byte
|
|
||||||
perPeerElem *list.Element
|
|
||||||
}
|
|
||||||
|
|
||||||
func commonBits(ip1, ip2 []byte) uint8 {
|
|
||||||
size := len(ip1)
|
|
||||||
if size == net.IPv4len {
|
|
||||||
a := binary.BigEndian.Uint32(ip1)
|
|
||||||
b := binary.BigEndian.Uint32(ip2)
|
|
||||||
x := a ^ b
|
|
||||||
return uint8(bits.LeadingZeros32(x))
|
|
||||||
} else if size == net.IPv6len {
|
|
||||||
a := binary.BigEndian.Uint64(ip1)
|
|
||||||
b := binary.BigEndian.Uint64(ip2)
|
|
||||||
x := a ^ b
|
|
||||||
if x != 0 {
|
|
||||||
return uint8(bits.LeadingZeros64(x))
|
|
||||||
}
|
|
||||||
a = binary.BigEndian.Uint64(ip1[8:])
|
|
||||||
b = binary.BigEndian.Uint64(ip2[8:])
|
|
||||||
x = a ^ b
|
|
||||||
return 64 + uint8(bits.LeadingZeros64(x))
|
|
||||||
} else {
|
|
||||||
panic("Wrong size bit string")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) addToPeerEntries() {
|
|
||||||
node.perPeerElem = node.peer.trieEntries.PushBack(node)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) removeFromPeerEntries() {
|
|
||||||
if node.perPeerElem != nil {
|
|
||||||
node.peer.trieEntries.Remove(node.perPeerElem)
|
|
||||||
node.perPeerElem = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) choose(ip []byte) byte {
|
|
||||||
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) maskSelf() {
|
|
||||||
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
|
|
||||||
for i := 0; i < len(mask); i++ {
|
|
||||||
node.bits[i] &= mask[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) zeroizePointers() {
|
|
||||||
// Make the garbage collector's life slightly easier
|
|
||||||
node.peer = nil
|
|
||||||
node.child[0] = nil
|
|
||||||
node.child[1] = nil
|
|
||||||
node.parent.parentBit = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
|
|
||||||
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
|
||||||
parent = node
|
|
||||||
if parent.cidr == cidr {
|
|
||||||
exact = true
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bit := node.choose(ip)
|
|
||||||
node = node.child[bit]
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
|
|
||||||
if *trie.parentBit == nil {
|
|
||||||
node := &trieEntry{
|
|
||||||
peer: peer,
|
|
||||||
parent: trie,
|
|
||||||
bits: ip,
|
|
||||||
cidr: cidr,
|
|
||||||
bitAtByte: cidr / 8,
|
|
||||||
bitAtShift: 7 - (cidr % 8),
|
|
||||||
}
|
|
||||||
node.maskSelf()
|
|
||||||
node.addToPeerEntries()
|
|
||||||
*trie.parentBit = node
|
|
||||||
return
|
|
||||||
}
|
|
||||||
node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
|
|
||||||
if exact {
|
|
||||||
node.removeFromPeerEntries()
|
|
||||||
node.peer = peer
|
|
||||||
node.addToPeerEntries()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
newNode := &trieEntry{
|
|
||||||
peer: peer,
|
|
||||||
bits: ip,
|
|
||||||
cidr: cidr,
|
|
||||||
bitAtByte: cidr / 8,
|
|
||||||
bitAtShift: 7 - (cidr % 8),
|
|
||||||
}
|
|
||||||
newNode.maskSelf()
|
|
||||||
newNode.addToPeerEntries()
|
|
||||||
|
|
||||||
var down *trieEntry
|
|
||||||
if node == nil {
|
|
||||||
down = *trie.parentBit
|
|
||||||
} else {
|
|
||||||
bit := node.choose(ip)
|
|
||||||
down = node.child[bit]
|
|
||||||
if down == nil {
|
|
||||||
newNode.parent = parentIndirection{&node.child[bit], bit}
|
|
||||||
node.child[bit] = newNode
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
common := commonBits(down.bits, ip)
|
|
||||||
if common < cidr {
|
|
||||||
cidr = common
|
|
||||||
}
|
|
||||||
parent := node
|
|
||||||
|
|
||||||
if newNode.cidr == cidr {
|
|
||||||
bit := newNode.choose(down.bits)
|
|
||||||
down.parent = parentIndirection{&newNode.child[bit], bit}
|
|
||||||
newNode.child[bit] = down
|
|
||||||
if parent == nil {
|
|
||||||
newNode.parent = trie
|
|
||||||
*trie.parentBit = newNode
|
|
||||||
} else {
|
|
||||||
bit := parent.choose(newNode.bits)
|
|
||||||
newNode.parent = parentIndirection{&parent.child[bit], bit}
|
|
||||||
parent.child[bit] = newNode
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
node = &trieEntry{
|
|
||||||
bits: append([]byte{}, newNode.bits...),
|
|
||||||
cidr: cidr,
|
|
||||||
bitAtByte: cidr / 8,
|
|
||||||
bitAtShift: 7 - (cidr % 8),
|
|
||||||
}
|
|
||||||
node.maskSelf()
|
|
||||||
|
|
||||||
bit := node.choose(down.bits)
|
|
||||||
down.parent = parentIndirection{&node.child[bit], bit}
|
|
||||||
node.child[bit] = down
|
|
||||||
bit = node.choose(newNode.bits)
|
|
||||||
newNode.parent = parentIndirection{&node.child[bit], bit}
|
|
||||||
node.child[bit] = newNode
|
|
||||||
if parent == nil {
|
|
||||||
node.parent = trie
|
|
||||||
*trie.parentBit = node
|
|
||||||
} else {
|
|
||||||
bit := parent.choose(node.bits)
|
|
||||||
node.parent = parentIndirection{&parent.child[bit], bit}
|
|
||||||
parent.child[bit] = node
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (node *trieEntry) lookup(ip []byte) *Peer {
|
|
||||||
var found *Peer
|
|
||||||
size := uint8(len(ip))
|
|
||||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
|
||||||
if node.peer != nil {
|
|
||||||
found = node.peer
|
|
||||||
}
|
|
||||||
if node.bitAtByte == size {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
bit := node.choose(ip)
|
|
||||||
node = node.child[bit]
|
|
||||||
}
|
|
||||||
return found
|
|
||||||
}
|
|
||||||
|
|
||||||
type AllowedIPs struct {
|
|
||||||
IPv4 *trieEntry
|
|
||||||
IPv6 *trieEntry
|
|
||||||
mutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
|
|
||||||
table.mutex.RLock()
|
|
||||||
defer table.mutex.RUnlock()
|
|
||||||
|
|
||||||
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
|
||||||
node := elem.Value.(*trieEntry)
|
|
||||||
a, _ := netip.AddrFromSlice(node.bits)
|
|
||||||
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
var next *list.Element
|
|
||||||
for elem := peer.trieEntries.Front(); elem != nil; elem = next {
|
|
||||||
next = elem.Next()
|
|
||||||
node := elem.Value.(*trieEntry)
|
|
||||||
|
|
||||||
node.removeFromPeerEntries()
|
|
||||||
node.peer = nil
|
|
||||||
if node.child[0] != nil && node.child[1] != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
bit := 0
|
|
||||||
if node.child[0] == nil {
|
|
||||||
bit = 1
|
|
||||||
}
|
|
||||||
child := node.child[bit]
|
|
||||||
if child != nil {
|
|
||||||
child.parent = node.parent
|
|
||||||
}
|
|
||||||
*node.parent.parentBit = child
|
|
||||||
if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
|
|
||||||
node.zeroizePointers()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
|
|
||||||
if parent.peer != nil {
|
|
||||||
node.zeroizePointers()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
child = parent.child[node.parent.parentBitType^1]
|
|
||||||
if child != nil {
|
|
||||||
child.parent = parent.parent
|
|
||||||
}
|
|
||||||
*parent.parent.parentBit = child
|
|
||||||
node.zeroizePointers()
|
|
||||||
parent.zeroizePointers()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
|
||||||
table.mutex.Lock()
|
|
||||||
defer table.mutex.Unlock()
|
|
||||||
|
|
||||||
if prefix.Addr().Is6() {
|
|
||||||
ip := prefix.Addr().As16()
|
|
||||||
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
|
||||||
} else if prefix.Addr().Is4() {
|
|
||||||
ip := prefix.Addr().As4()
|
|
||||||
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
|
||||||
} else {
|
|
||||||
panic(errors.New("inserting unknown address type"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
|
|
||||||
table.mutex.RLock()
|
|
||||||
defer table.mutex.RUnlock()
|
|
||||||
switch len(ip) {
|
|
||||||
case net.IPv6len:
|
|
||||||
return table.IPv6.lookup(ip)
|
|
||||||
case net.IPv4len:
|
|
||||||
return table.IPv4.lookup(ip)
|
|
||||||
default:
|
|
||||||
panic(errors.New("looking up unknown address type"))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,141 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
NumberOfPeers = 100
|
|
||||||
NumberOfPeerRemovals = 4
|
|
||||||
NumberOfAddresses = 250
|
|
||||||
NumberOfTests = 10000
|
|
||||||
)
|
|
||||||
|
|
||||||
type SlowNode struct {
|
|
||||||
peer *Peer
|
|
||||||
cidr uint8
|
|
||||||
bits []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type SlowRouter []*SlowNode
|
|
||||||
|
|
||||||
func (r SlowRouter) Len() int {
|
|
||||||
return len(r)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r SlowRouter) Less(i, j int) bool {
|
|
||||||
return r[i].cidr > r[j].cidr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r SlowRouter) Swap(i, j int) {
|
|
||||||
r[i], r[j] = r[j], r[i]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r SlowRouter) Insert(addr []byte, cidr uint8, peer *Peer) SlowRouter {
|
|
||||||
for _, t := range r {
|
|
||||||
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
|
|
||||||
t.peer = peer
|
|
||||||
t.bits = addr
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
}
|
|
||||||
r = append(r, &SlowNode{
|
|
||||||
cidr: cidr,
|
|
||||||
bits: addr,
|
|
||||||
peer: peer,
|
|
||||||
})
|
|
||||||
sort.Sort(r)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r SlowRouter) Lookup(addr []byte) *Peer {
|
|
||||||
for _, t := range r {
|
|
||||||
common := commonBits(t.bits, addr)
|
|
||||||
if common >= t.cidr {
|
|
||||||
return t.peer
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r SlowRouter) RemoveByPeer(peer *Peer) SlowRouter {
|
|
||||||
n := 0
|
|
||||||
for _, x := range r {
|
|
||||||
if x.peer != peer {
|
|
||||||
r[n] = x
|
|
||||||
n++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return r[:n]
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTrieRandom(t *testing.T) {
|
|
||||||
var slow4, slow6 SlowRouter
|
|
||||||
var peers []*Peer
|
|
||||||
var allowedIPs AllowedIPs
|
|
||||||
|
|
||||||
rand.Seed(1)
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfPeers; n++ {
|
|
||||||
peers = append(peers, &Peer{})
|
|
||||||
}
|
|
||||||
|
|
||||||
for n := 0; n < NumberOfAddresses; n++ {
|
|
||||||
var addr4 [4]byte
|
|
||||||
rand.Read(addr4[:])
|
|
||||||
cidr := uint8(rand.Intn(32) + 1)
|
|
||||||
index := rand.Intn(NumberOfPeers)
|
|
||||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
|
||||||
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
|
||||||
|
|
||||||
var addr6 [16]byte
|
|
||||||
rand.Read(addr6[:])
|
|
||||||
cidr = uint8(rand.Intn(128) + 1)
|
|
||||||
index = rand.Intn(NumberOfPeers)
|
|
||||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
|
||||||
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
|
|
||||||
}
|
|
||||||
|
|
||||||
var p int
|
|
||||||
for p = 0; ; p++ {
|
|
||||||
for n := 0; n < NumberOfTests; n++ {
|
|
||||||
var addr4 [4]byte
|
|
||||||
rand.Read(addr4[:])
|
|
||||||
peer1 := slow4.Lookup(addr4[:])
|
|
||||||
peer2 := allowedIPs.Lookup(addr4[:])
|
|
||||||
if peer1 != peer2 {
|
|
||||||
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr4[:]), peer1, peer2)
|
|
||||||
}
|
|
||||||
|
|
||||||
var addr6 [16]byte
|
|
||||||
rand.Read(addr6[:])
|
|
||||||
peer1 = slow6.Lookup(addr6[:])
|
|
||||||
peer2 = allowedIPs.Lookup(addr6[:])
|
|
||||||
if peer1 != peer2 {
|
|
||||||
t.Errorf("Trie did not match naive implementation, for %v: want %p, got %p", net.IP(addr6[:]), peer1, peer2)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if p >= len(peers) || p >= NumberOfPeerRemovals {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
allowedIPs.RemoveByPeer(peers[p])
|
|
||||||
slow4 = slow4.RemoveByPeer(peers[p])
|
|
||||||
slow6 = slow6.RemoveByPeer(peers[p])
|
|
||||||
}
|
|
||||||
for ; p < len(peers); p++ {
|
|
||||||
allowedIPs.RemoveByPeer(peers[p])
|
|
||||||
}
|
|
||||||
|
|
||||||
if allowedIPs.IPv4 != nil || allowedIPs.IPv6 != nil {
|
|
||||||
t.Error("Failed to remove all nodes from trie by peer")
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,137 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// An outboundQueue is a channel of QueueOutboundElements awaiting encryption.
|
|
||||||
// An outboundQueue is ref-counted using its wg field.
|
|
||||||
// An outboundQueue created with newOutboundQueue has one reference.
|
|
||||||
// Every additional writer must call wg.Add(1).
|
|
||||||
// Every completed writer must call wg.Done().
|
|
||||||
// When no further writers will be added,
|
|
||||||
// call wg.Done to remove the initial reference.
|
|
||||||
// When the refcount hits 0, the queue's channel is closed.
|
|
||||||
type outboundQueue struct {
|
|
||||||
c chan *QueueOutboundElementsContainer
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func newOutboundQueue() *outboundQueue {
|
|
||||||
q := &outboundQueue{
|
|
||||||
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
|
||||||
}
|
|
||||||
q.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
q.wg.Wait()
|
|
||||||
close(q.c)
|
|
||||||
}()
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// A inboundQueue is similar to an outboundQueue; see those docs.
|
|
||||||
type inboundQueue struct {
|
|
||||||
c chan *QueueInboundElementsContainer
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func newInboundQueue() *inboundQueue {
|
|
||||||
q := &inboundQueue{
|
|
||||||
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
|
||||||
}
|
|
||||||
q.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
q.wg.Wait()
|
|
||||||
close(q.c)
|
|
||||||
}()
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// A handshakeQueue is similar to an outboundQueue; see those docs.
|
|
||||||
type handshakeQueue struct {
|
|
||||||
c chan QueueHandshakeElement
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
|
||||||
|
|
||||||
func newHandshakeQueue() *handshakeQueue {
|
|
||||||
q := &handshakeQueue{
|
|
||||||
c: make(chan QueueHandshakeElement, QueueHandshakeSize),
|
|
||||||
}
|
|
||||||
q.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
q.wg.Wait()
|
|
||||||
close(q.c)
|
|
||||||
}()
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
type autodrainingInboundQueue struct {
|
|
||||||
c chan *QueueInboundElementsContainer
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAutodrainingInboundQueue returns a channel that will be drained when it gets GC'd.
|
|
||||||
// It is useful in cases in which is it hard to manage the lifetime of the channel.
|
|
||||||
// The returned channel must not be closed. Senders should signal shutdown using
|
|
||||||
// some other means, such as sending a sentinel nil values.
|
|
||||||
func newAutodrainingInboundQueue(device *Device) *autodrainingInboundQueue {
|
|
||||||
q := &autodrainingInboundQueue{
|
|
||||||
c: make(chan *QueueInboundElementsContainer, QueueInboundSize),
|
|
||||||
}
|
|
||||||
runtime.SetFinalizer(q, device.flushInboundQueue)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elemsContainer := <-q.c:
|
|
||||||
elemsContainer.Lock()
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutInboundElement(elem)
|
|
||||||
}
|
|
||||||
device.PutInboundElementsContainer(elemsContainer)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type autodrainingOutboundQueue struct {
|
|
||||||
c chan *QueueOutboundElementsContainer
|
|
||||||
}
|
|
||||||
|
|
||||||
// newAutodrainingOutboundQueue returns a channel that will be drained when it gets GC'd.
|
|
||||||
// It is useful in cases in which is it hard to manage the lifetime of the channel.
|
|
||||||
// The returned channel must not be closed. Senders should signal shutdown using
|
|
||||||
// some other means, such as sending a sentinel nil values.
|
|
||||||
// All sends to the channel must be best-effort, because there may be no receivers.
|
|
||||||
func newAutodrainingOutboundQueue(device *Device) *autodrainingOutboundQueue {
|
|
||||||
q := &autodrainingOutboundQueue{
|
|
||||||
c: make(chan *QueueOutboundElementsContainer, QueueOutboundSize),
|
|
||||||
}
|
|
||||||
runtime.SetFinalizer(q, device.flushOutboundQueue)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elemsContainer := <-q.c:
|
|
||||||
elemsContainer.Lock()
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
device.PutOutboundElementsContainer(elemsContainer)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,40 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* Specification constants */
|
|
||||||
|
|
||||||
const (
|
|
||||||
RekeyAfterMessages = (1 << 60)
|
|
||||||
RejectAfterMessages = (1 << 64) - (1 << 13) - 1
|
|
||||||
RekeyAfterTime = time.Second * 120
|
|
||||||
RekeyAttemptTime = time.Second * 90
|
|
||||||
RekeyTimeout = time.Second * 5
|
|
||||||
MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
|
|
||||||
RekeyTimeoutJitterMaxMs = 334
|
|
||||||
RejectAfterTime = time.Second * 180
|
|
||||||
KeepaliveTimeout = time.Second * 10
|
|
||||||
CookieRefreshTime = time.Second * 120
|
|
||||||
HandshakeInitationRate = time.Second / 50
|
|
||||||
PaddingMultiple = 16
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
|
|
||||||
MaxMessageSize = MaxSegmentSize // maximum size of transport message
|
|
||||||
MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
|
|
||||||
)
|
|
||||||
|
|
||||||
/* Implementation constants */
|
|
||||||
|
|
||||||
const (
|
|
||||||
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
|
|
||||||
MaxPeers = 1 << 16 // maximum number of configured peers
|
|
||||||
)
|
|
807
device/device.go
807
device/device.go
|
@ -1,807 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/ratelimiter"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
|
||||||
"github.com/tevino/abool/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Device struct {
|
|
||||||
state struct {
|
|
||||||
// state holds the device's state. It is accessed atomically.
|
|
||||||
// Use the device.deviceState method to read it.
|
|
||||||
// device.deviceState does not acquire the mutex, so it captures only a snapshot.
|
|
||||||
// During state transitions, the state variable is updated before the device itself.
|
|
||||||
// The state is thus either the current state of the device or
|
|
||||||
// the intended future state of the device.
|
|
||||||
// For example, while executing a call to Up, state will be deviceStateUp.
|
|
||||||
// There is no guarantee that that intended future state of the device
|
|
||||||
// will become the actual state; Up can fail.
|
|
||||||
// The device can also change state multiple times between time of check and time of use.
|
|
||||||
// Unsynchronized uses of state must therefore be advisory/best-effort only.
|
|
||||||
state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
|
|
||||||
// stopping blocks until all inputs to Device have been closed.
|
|
||||||
stopping sync.WaitGroup
|
|
||||||
// mu protects state changes.
|
|
||||||
sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
net struct {
|
|
||||||
stopping sync.WaitGroup
|
|
||||||
sync.RWMutex
|
|
||||||
bind conn.Bind // bind interface
|
|
||||||
netlinkCancel *rwcancel.RWCancel
|
|
||||||
port uint16 // listening port
|
|
||||||
fwmark uint32 // mark value (0 = disabled)
|
|
||||||
brokenRoaming bool
|
|
||||||
}
|
|
||||||
|
|
||||||
staticIdentity struct {
|
|
||||||
sync.RWMutex
|
|
||||||
privateKey NoisePrivateKey
|
|
||||||
publicKey NoisePublicKey
|
|
||||||
}
|
|
||||||
|
|
||||||
peers struct {
|
|
||||||
sync.RWMutex // protects keyMap
|
|
||||||
keyMap map[NoisePublicKey]*Peer
|
|
||||||
}
|
|
||||||
|
|
||||||
rate struct {
|
|
||||||
underLoadUntil atomic.Int64
|
|
||||||
limiter ratelimiter.Ratelimiter
|
|
||||||
}
|
|
||||||
|
|
||||||
allowedips AllowedIPs
|
|
||||||
indexTable IndexTable
|
|
||||||
cookieChecker CookieChecker
|
|
||||||
|
|
||||||
pool struct {
|
|
||||||
inboundElementsContainer *WaitPool
|
|
||||||
outboundElementsContainer *WaitPool
|
|
||||||
messageBuffers *WaitPool
|
|
||||||
inboundElements *WaitPool
|
|
||||||
outboundElements *WaitPool
|
|
||||||
}
|
|
||||||
|
|
||||||
queue struct {
|
|
||||||
encryption *outboundQueue
|
|
||||||
decryption *inboundQueue
|
|
||||||
handshake *handshakeQueue
|
|
||||||
}
|
|
||||||
|
|
||||||
tun struct {
|
|
||||||
device tun.Device
|
|
||||||
mtu atomic.Int32
|
|
||||||
}
|
|
||||||
|
|
||||||
ipcMutex sync.RWMutex
|
|
||||||
closed chan struct{}
|
|
||||||
log *Logger
|
|
||||||
|
|
||||||
isASecOn abool.AtomicBool
|
|
||||||
aSecMux sync.RWMutex
|
|
||||||
aSecCfg aSecCfgType
|
|
||||||
junkCreator junkCreator
|
|
||||||
}
|
|
||||||
|
|
||||||
type aSecCfgType struct {
|
|
||||||
isSet bool
|
|
||||||
junkPacketCount int
|
|
||||||
junkPacketMinSize int
|
|
||||||
junkPacketMaxSize int
|
|
||||||
initPacketJunkSize int
|
|
||||||
responsePacketJunkSize int
|
|
||||||
initPacketMagicHeader uint32
|
|
||||||
responsePacketMagicHeader uint32
|
|
||||||
underloadPacketMagicHeader uint32
|
|
||||||
transportPacketMagicHeader uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
// deviceState represents the state of a Device.
|
|
||||||
// There are three states: down, up, closed.
|
|
||||||
// Transitions:
|
|
||||||
//
|
|
||||||
// down -----+
|
|
||||||
// ↑↓ ↓
|
|
||||||
// up -> closed
|
|
||||||
type deviceState uint32
|
|
||||||
|
|
||||||
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
|
|
||||||
const (
|
|
||||||
deviceStateDown deviceState = iota
|
|
||||||
deviceStateUp
|
|
||||||
deviceStateClosed
|
|
||||||
)
|
|
||||||
|
|
||||||
// deviceState returns device.state.state as a deviceState
|
|
||||||
// See those docs for how to interpret this value.
|
|
||||||
func (device *Device) deviceState() deviceState {
|
|
||||||
return deviceState(device.state.state.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
// isClosed reports whether the device is closed (or is closing).
|
|
||||||
// See device.state.state comments for how to interpret this value.
|
|
||||||
func (device *Device) isClosed() bool {
|
|
||||||
return device.deviceState() == deviceStateClosed
|
|
||||||
}
|
|
||||||
|
|
||||||
// isUp reports whether the device is up (or is attempting to come up).
|
|
||||||
// See device.state.state comments for how to interpret this value.
|
|
||||||
func (device *Device) isUp() bool {
|
|
||||||
return device.deviceState() == deviceStateUp
|
|
||||||
}
|
|
||||||
|
|
||||||
// Must hold device.peers.Lock()
|
|
||||||
func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
|
|
||||||
// stop routing and processing of packets
|
|
||||||
device.allowedips.RemoveByPeer(peer)
|
|
||||||
peer.Stop()
|
|
||||||
|
|
||||||
// remove from peer map
|
|
||||||
delete(device.peers.keyMap, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
// changeState attempts to change the device state to match want.
|
|
||||||
func (device *Device) changeState(want deviceState) (err error) {
|
|
||||||
device.state.Lock()
|
|
||||||
defer device.state.Unlock()
|
|
||||||
old := device.deviceState()
|
|
||||||
if old == deviceStateClosed {
|
|
||||||
// once closed, always closed
|
|
||||||
device.log.Verbosef("Interface closed, ignored requested state %s", want)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
switch want {
|
|
||||||
case old:
|
|
||||||
return nil
|
|
||||||
case deviceStateUp:
|
|
||||||
device.state.state.Store(uint32(deviceStateUp))
|
|
||||||
err = device.upLocked()
|
|
||||||
if err == nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
fallthrough // up failed; bring the device all the way back down
|
|
||||||
case deviceStateDown:
|
|
||||||
device.state.state.Store(uint32(deviceStateDown))
|
|
||||||
errDown := device.downLocked()
|
|
||||||
if err == nil {
|
|
||||||
err = errDown
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.log.Verbosef(
|
|
||||||
"Interface state was %s, requested %s, now %s", old, want, device.deviceState())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// upLocked attempts to bring the device up and reports whether it succeeded.
|
|
||||||
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
|
||||||
func (device *Device) upLocked() error {
|
|
||||||
if err := device.BindUpdate(); err != nil {
|
|
||||||
device.log.Errorf("Unable to update bind: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The IPC set operation waits for peers to be created before calling Start() on them,
|
|
||||||
// so if there's a concurrent IPC set request happening, we should wait for it to complete.
|
|
||||||
device.ipcMutex.Lock()
|
|
||||||
defer device.ipcMutex.Unlock()
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Start()
|
|
||||||
if peer.persistentKeepaliveInterval.Load() > 0 {
|
|
||||||
peer.SendKeepalive()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// downLocked attempts to bring the device down.
|
|
||||||
// The caller must hold device.state.mu and is responsible for updating device.state.state.
|
|
||||||
func (device *Device) downLocked() error {
|
|
||||||
err := device.BindClose()
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("Bind close failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.Stop()
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) Up() error {
|
|
||||||
return device.changeState(deviceStateUp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) Down() error {
|
|
||||||
return device.changeState(deviceStateDown)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) IsUnderLoad() bool {
|
|
||||||
// check if currently under load
|
|
||||||
now := time.Now()
|
|
||||||
underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
|
|
||||||
if underLoad {
|
|
||||||
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// check if recently under load
|
|
||||||
return device.rate.underLoadUntil.Load() > now.UnixNano()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|
||||||
// lock required resources
|
|
||||||
|
|
||||||
device.staticIdentity.Lock()
|
|
||||||
defer device.staticIdentity.Unlock()
|
|
||||||
|
|
||||||
if sk.Equals(device.staticIdentity.privateKey) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
device.peers.Lock()
|
|
||||||
defer device.peers.Unlock()
|
|
||||||
|
|
||||||
lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.handshake.mutex.RLock()
|
|
||||||
lockedPeers = append(lockedPeers, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove peers with matching public keys
|
|
||||||
|
|
||||||
publicKey := sk.publicKey()
|
|
||||||
for key, peer := range device.peers.keyMap {
|
|
||||||
if peer.handshake.remoteStatic.Equals(publicKey) {
|
|
||||||
peer.handshake.mutex.RUnlock()
|
|
||||||
removePeerLocked(device, peer, key)
|
|
||||||
peer.handshake.mutex.RLock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// update key material
|
|
||||||
|
|
||||||
device.staticIdentity.privateKey = sk
|
|
||||||
device.staticIdentity.publicKey = publicKey
|
|
||||||
device.cookieChecker.Init(publicKey)
|
|
||||||
|
|
||||||
// do static-static DH pre-computations
|
|
||||||
|
|
||||||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
handshake := &peer.handshake
|
|
||||||
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
|
||||||
expiredPeers = append(expiredPeers, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range lockedPeers {
|
|
||||||
peer.handshake.mutex.RUnlock()
|
|
||||||
}
|
|
||||||
for _, peer := range expiredPeers {
|
|
||||||
peer.ExpireCurrentKeypairs()
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
|
||||||
device := new(Device)
|
|
||||||
device.state.state.Store(uint32(deviceStateDown))
|
|
||||||
device.closed = make(chan struct{})
|
|
||||||
device.log = logger
|
|
||||||
device.net.bind = bind
|
|
||||||
device.tun.device = tunDevice
|
|
||||||
mtu, err := device.tun.device.MTU()
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
|
|
||||||
mtu = DefaultMTU
|
|
||||||
}
|
|
||||||
device.tun.mtu.Store(int32(mtu))
|
|
||||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
||||||
device.rate.limiter.Init()
|
|
||||||
device.indexTable.Init()
|
|
||||||
|
|
||||||
device.PopulatePools()
|
|
||||||
|
|
||||||
// create queues
|
|
||||||
|
|
||||||
device.queue.handshake = newHandshakeQueue()
|
|
||||||
device.queue.encryption = newOutboundQueue()
|
|
||||||
device.queue.decryption = newInboundQueue()
|
|
||||||
|
|
||||||
// start workers
|
|
||||||
|
|
||||||
cpus := runtime.NumCPU()
|
|
||||||
device.state.stopping.Wait()
|
|
||||||
device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
|
|
||||||
for i := 0; i < cpus; i++ {
|
|
||||||
go device.RoutineEncryption(i + 1)
|
|
||||||
go device.RoutineDecryption(i + 1)
|
|
||||||
go device.RoutineHandshake(i + 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
device.state.stopping.Add(1) // RoutineReadFromTUN
|
|
||||||
device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
|
|
||||||
go device.RoutineReadFromTUN()
|
|
||||||
go device.RoutineTUNEventReader()
|
|
||||||
|
|
||||||
return device
|
|
||||||
}
|
|
||||||
|
|
||||||
// BatchSize returns the BatchSize for the device as a whole which is the max of
|
|
||||||
// the bind batch size and the tun batch size. The batch size reported by device
|
|
||||||
// is the size used to construct memory pools, and is the allowed batch size for
|
|
||||||
// the lifetime of the device.
|
|
||||||
func (device *Device) BatchSize() int {
|
|
||||||
size := device.net.bind.BatchSize()
|
|
||||||
dSize := device.tun.device.BatchSize()
|
|
||||||
if size < dSize {
|
|
||||||
size = dSize
|
|
||||||
}
|
|
||||||
return size
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
|
||||||
device.peers.RLock()
|
|
||||||
defer device.peers.RUnlock()
|
|
||||||
|
|
||||||
return device.peers.keyMap[pk]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) RemovePeer(key NoisePublicKey) {
|
|
||||||
device.peers.Lock()
|
|
||||||
defer device.peers.Unlock()
|
|
||||||
// stop peer and remove from routing
|
|
||||||
|
|
||||||
peer, ok := device.peers.keyMap[key]
|
|
||||||
if ok {
|
|
||||||
removePeerLocked(device, peer, key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) RemoveAllPeers() {
|
|
||||||
device.peers.Lock()
|
|
||||||
defer device.peers.Unlock()
|
|
||||||
|
|
||||||
for key, peer := range device.peers.keyMap {
|
|
||||||
removePeerLocked(device, peer, key)
|
|
||||||
}
|
|
||||||
|
|
||||||
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) Close() {
|
|
||||||
device.state.Lock()
|
|
||||||
defer device.state.Unlock()
|
|
||||||
device.ipcMutex.Lock()
|
|
||||||
defer device.ipcMutex.Unlock()
|
|
||||||
if device.isClosed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
device.state.state.Store(uint32(deviceStateClosed))
|
|
||||||
device.log.Verbosef("Device closing")
|
|
||||||
|
|
||||||
device.tun.device.Close()
|
|
||||||
device.downLocked()
|
|
||||||
|
|
||||||
// Remove peers before closing queues,
|
|
||||||
// because peers assume that queues are active.
|
|
||||||
device.RemoveAllPeers()
|
|
||||||
|
|
||||||
// We kept a reference to the encryption and decryption queues,
|
|
||||||
// in case we started any new peers that might write to them.
|
|
||||||
// No new peers are coming; we are done with these queues.
|
|
||||||
device.queue.encryption.wg.Done()
|
|
||||||
device.queue.decryption.wg.Done()
|
|
||||||
device.queue.handshake.wg.Done()
|
|
||||||
device.state.stopping.Wait()
|
|
||||||
|
|
||||||
device.rate.limiter.Close()
|
|
||||||
|
|
||||||
device.resetProtocol()
|
|
||||||
|
|
||||||
device.log.Verbosef("Device closed")
|
|
||||||
close(device.closed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) Wait() chan struct{} {
|
|
||||||
return device.closed
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
|
||||||
if !device.isUp() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.keypairs.RLock()
|
|
||||||
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
|
|
||||||
peer.keypairs.RUnlock()
|
|
||||||
if sendKeepalive {
|
|
||||||
peer.SendKeepalive()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeBindLocked closes the device's net.bind.
|
|
||||||
// The caller must hold the net mutex.
|
|
||||||
func closeBindLocked(device *Device) error {
|
|
||||||
var err error
|
|
||||||
netc := &device.net
|
|
||||||
if netc.netlinkCancel != nil {
|
|
||||||
netc.netlinkCancel.Cancel()
|
|
||||||
}
|
|
||||||
if netc.bind != nil {
|
|
||||||
err = netc.bind.Close()
|
|
||||||
}
|
|
||||||
netc.stopping.Wait()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) Bind() conn.Bind {
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
return device.net.bind
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindSetMark(mark uint32) error {
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// check if modified
|
|
||||||
if device.net.fwmark == mark {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// update fwmark on existing bind
|
|
||||||
device.net.fwmark = mark
|
|
||||||
if device.isUp() && device.net.bind != nil {
|
|
||||||
if err := device.net.bind.SetMark(mark); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.markEndpointSrcForClearing()
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindUpdate() error {
|
|
||||||
device.net.Lock()
|
|
||||||
defer device.net.Unlock()
|
|
||||||
|
|
||||||
// close existing sockets
|
|
||||||
if err := closeBindLocked(device); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// open new sockets
|
|
||||||
if !device.isUp() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// bind to new port
|
|
||||||
var err error
|
|
||||||
var recvFns []conn.ReceiveFunc
|
|
||||||
netc := &device.net
|
|
||||||
|
|
||||||
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
|
||||||
if err != nil {
|
|
||||||
netc.port = 0
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
|
||||||
if err != nil {
|
|
||||||
netc.bind.Close()
|
|
||||||
netc.port = 0
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// set fwmark
|
|
||||||
if netc.fwmark != 0 {
|
|
||||||
err = netc.bind.SetMark(netc.fwmark)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear cached source addresses
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.markEndpointSrcForClearing()
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
|
|
||||||
// start receiving routines
|
|
||||||
device.net.stopping.Add(len(recvFns))
|
|
||||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
|
||||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
|
||||||
batchSize := netc.bind.BatchSize()
|
|
||||||
for _, fn := range recvFns {
|
|
||||||
go device.RoutineReceiveIncoming(batchSize, fn)
|
|
||||||
}
|
|
||||||
|
|
||||||
device.log.Verbosef("UDP bind has been updated")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindClose() error {
|
|
||||||
device.net.Lock()
|
|
||||||
err := closeBindLocked(device)
|
|
||||||
device.net.Unlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
func (device *Device) isAdvancedSecurityOn() bool {
|
|
||||||
return device.isASecOn.IsSet()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) resetProtocol() {
|
|
||||||
// restore default message type values
|
|
||||||
MessageInitiationType = 1
|
|
||||||
MessageResponseType = 2
|
|
||||||
MessageCookieReplyType = 3
|
|
||||||
MessageTransportType = 4
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) handlePostConfig(tempASecCfg *aSecCfgType) (err error) {
|
|
||||||
|
|
||||||
if !tempASecCfg.isSet {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
isASecOn := false
|
|
||||||
device.aSecMux.Lock()
|
|
||||||
if tempASecCfg.junkPacketCount < 0 {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"JunkPacketCount should be non negative",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
device.aSecCfg.junkPacketCount = tempASecCfg.junkPacketCount
|
|
||||||
if tempASecCfg.junkPacketCount != 0 {
|
|
||||||
isASecOn = true
|
|
||||||
}
|
|
||||||
|
|
||||||
device.aSecCfg.junkPacketMinSize = tempASecCfg.junkPacketMinSize
|
|
||||||
if tempASecCfg.junkPacketMinSize != 0 {
|
|
||||||
isASecOn = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if device.aSecCfg.junkPacketCount > 0 &&
|
|
||||||
tempASecCfg.junkPacketMaxSize == tempASecCfg.junkPacketMinSize {
|
|
||||||
|
|
||||||
tempASecCfg.junkPacketMaxSize++ // to make rand gen work
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.junkPacketMaxSize >= MaxSegmentSize {
|
|
||||||
device.aSecCfg.junkPacketMinSize = 0
|
|
||||||
device.aSecCfg.junkPacketMaxSize = 1
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d; %w",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"JunkPacketMaxSize: %d; should be smaller than maxSegmentSize: %d",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else if tempASecCfg.junkPacketMaxSize < tempASecCfg.junkPacketMinSize {
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"maxSize: %d; should be greater than minSize: %d; %w",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
tempASecCfg.junkPacketMinSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
"maxSize: %d; should be greater than minSize: %d",
|
|
||||||
tempASecCfg.junkPacketMaxSize,
|
|
||||||
tempASecCfg.junkPacketMinSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
device.aSecCfg.junkPacketMaxSize = tempASecCfg.junkPacketMaxSize
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.junkPacketMaxSize != 0 {
|
|
||||||
isASecOn = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if MessageInitiationSize+tempASecCfg.initPacketJunkSize >= MaxSegmentSize {
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
|
||||||
tempASecCfg.initPacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`init header size(148) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
|
||||||
tempASecCfg.initPacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
device.aSecCfg.initPacketJunkSize = tempASecCfg.initPacketJunkSize
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.initPacketJunkSize != 0 {
|
|
||||||
isASecOn = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if MessageResponseSize+tempASecCfg.responsePacketJunkSize >= MaxSegmentSize {
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d; %w`,
|
|
||||||
tempASecCfg.responsePacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`response header size(92) + junkSize:%d; should be smaller than maxSegmentSize: %d`,
|
|
||||||
tempASecCfg.responsePacketJunkSize,
|
|
||||||
MaxSegmentSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
device.aSecCfg.responsePacketJunkSize = tempASecCfg.responsePacketJunkSize
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.responsePacketJunkSize != 0 {
|
|
||||||
isASecOn = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.initPacketMagicHeader > 4 {
|
|
||||||
isASecOn = true
|
|
||||||
device.log.Verbosef("UAPI: Updating init_packet_magic_header")
|
|
||||||
device.aSecCfg.initPacketMagicHeader = tempASecCfg.initPacketMagicHeader
|
|
||||||
MessageInitiationType = device.aSecCfg.initPacketMagicHeader
|
|
||||||
} else {
|
|
||||||
device.log.Verbosef("UAPI: Using default init type")
|
|
||||||
MessageInitiationType = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.responsePacketMagicHeader > 4 {
|
|
||||||
isASecOn = true
|
|
||||||
device.log.Verbosef("UAPI: Updating response_packet_magic_header")
|
|
||||||
device.aSecCfg.responsePacketMagicHeader = tempASecCfg.responsePacketMagicHeader
|
|
||||||
MessageResponseType = device.aSecCfg.responsePacketMagicHeader
|
|
||||||
} else {
|
|
||||||
device.log.Verbosef("UAPI: Using default response type")
|
|
||||||
MessageResponseType = 2
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.underloadPacketMagicHeader > 4 {
|
|
||||||
isASecOn = true
|
|
||||||
device.log.Verbosef("UAPI: Updating underload_packet_magic_header")
|
|
||||||
device.aSecCfg.underloadPacketMagicHeader = tempASecCfg.underloadPacketMagicHeader
|
|
||||||
MessageCookieReplyType = device.aSecCfg.underloadPacketMagicHeader
|
|
||||||
} else {
|
|
||||||
device.log.Verbosef("UAPI: Using default underload type")
|
|
||||||
MessageCookieReplyType = 3
|
|
||||||
}
|
|
||||||
|
|
||||||
if tempASecCfg.transportPacketMagicHeader > 4 {
|
|
||||||
isASecOn = true
|
|
||||||
device.log.Verbosef("UAPI: Updating transport_packet_magic_header")
|
|
||||||
device.aSecCfg.transportPacketMagicHeader = tempASecCfg.transportPacketMagicHeader
|
|
||||||
MessageTransportType = device.aSecCfg.transportPacketMagicHeader
|
|
||||||
} else {
|
|
||||||
device.log.Verbosef("UAPI: Using default transport type")
|
|
||||||
MessageTransportType = 4
|
|
||||||
}
|
|
||||||
|
|
||||||
isSameMap := map[uint32]bool{}
|
|
||||||
isSameMap[MessageInitiationType] = true
|
|
||||||
isSameMap[MessageResponseType] = true
|
|
||||||
isSameMap[MessageCookieReplyType] = true
|
|
||||||
isSameMap[MessageTransportType] = true
|
|
||||||
|
|
||||||
// size will be different if same values
|
|
||||||
if len(isSameMap) != 4 {
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d; %w`,
|
|
||||||
MessageInitiationType,
|
|
||||||
MessageResponseType,
|
|
||||||
MessageCookieReplyType,
|
|
||||||
MessageTransportType,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`magic headers should differ; got: init:%d; recv:%d; unde:%d; tran:%d`,
|
|
||||||
MessageInitiationType,
|
|
||||||
MessageResponseType,
|
|
||||||
MessageCookieReplyType,
|
|
||||||
MessageTransportType,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
newInitSize := MessageInitiationSize + device.aSecCfg.initPacketJunkSize
|
|
||||||
newResponseSize := MessageResponseSize + device.aSecCfg.responsePacketJunkSize
|
|
||||||
|
|
||||||
if newInitSize == newResponseSize {
|
|
||||||
if err != nil {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`new init size:%d; and new response size:%d; should differ; %w`,
|
|
||||||
newInitSize,
|
|
||||||
newResponseSize,
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
err = ipcErrorf(
|
|
||||||
ipc.IpcErrorInvalid,
|
|
||||||
`new init size:%d; and new response size:%d; should differ`,
|
|
||||||
newInitSize,
|
|
||||||
newResponseSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
packetSizeToMsgType = map[int]uint32{
|
|
||||||
newInitSize: MessageInitiationType,
|
|
||||||
newResponseSize: MessageResponseType,
|
|
||||||
MessageCookieReplySize: MessageCookieReplyType,
|
|
||||||
MessageTransportSize: MessageTransportType,
|
|
||||||
}
|
|
||||||
|
|
||||||
msgTypeToJunkSize = map[uint32]int{
|
|
||||||
MessageInitiationType: device.aSecCfg.initPacketJunkSize,
|
|
||||||
MessageResponseType: device.aSecCfg.responsePacketJunkSize,
|
|
||||||
MessageCookieReplyType: 0,
|
|
||||||
MessageTransportType: 0,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
device.isASecOn.SetTo(isASecOn)
|
|
||||||
device.junkCreator, err = NewJunkCreator(device)
|
|
||||||
device.aSecMux.Unlock()
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
|
@ -1,572 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net/netip"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"runtime/pprof"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
|
||||||
)
|
|
||||||
|
|
||||||
// uapiCfg returns a string that contains cfg formatted use with IpcSet.
|
|
||||||
// cfg is a series of alternating key/value strings.
|
|
||||||
// uapiCfg exists because editors and humans like to insert
|
|
||||||
// whitespace into configs, which can cause failures, some of which are silent.
|
|
||||||
// For example, a leading blank newline causes the remainder
|
|
||||||
// of the config to be silently ignored.
|
|
||||||
func uapiCfg(cfg ...string) string {
|
|
||||||
if len(cfg)%2 != 0 {
|
|
||||||
panic("odd number of args to uapiReader")
|
|
||||||
}
|
|
||||||
buf := new(bytes.Buffer)
|
|
||||||
for i, s := range cfg {
|
|
||||||
buf.WriteString(s)
|
|
||||||
sep := byte('\n')
|
|
||||||
if i%2 == 0 {
|
|
||||||
sep = '='
|
|
||||||
}
|
|
||||||
buf.WriteByte(sep)
|
|
||||||
}
|
|
||||||
return buf.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// genConfigs generates a pair of configs that connect to each other.
|
|
||||||
// The configs use distinct, probably-usable ports.
|
|
||||||
func genConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|
||||||
var key1, key2 NoisePrivateKey
|
|
||||||
_, err := rand.Read(key1[:])
|
|
||||||
if err != nil {
|
|
||||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
|
||||||
}
|
|
||||||
_, err = rand.Read(key2[:])
|
|
||||||
if err != nil {
|
|
||||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
|
||||||
}
|
|
||||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
|
||||||
|
|
||||||
cfgs[0] = uapiCfg(
|
|
||||||
"private_key", hex.EncodeToString(key1[:]),
|
|
||||||
"listen_port", "0",
|
|
||||||
"replace_peers", "true",
|
|
||||||
"public_key", hex.EncodeToString(pub2[:]),
|
|
||||||
"protocol_version", "1",
|
|
||||||
"replace_allowed_ips", "true",
|
|
||||||
"allowed_ip", "1.0.0.2/32",
|
|
||||||
)
|
|
||||||
endpointCfgs[0] = uapiCfg(
|
|
||||||
"public_key", hex.EncodeToString(pub2[:]),
|
|
||||||
"endpoint", "127.0.0.1:%d",
|
|
||||||
)
|
|
||||||
cfgs[1] = uapiCfg(
|
|
||||||
"private_key", hex.EncodeToString(key2[:]),
|
|
||||||
"listen_port", "0",
|
|
||||||
"replace_peers", "true",
|
|
||||||
"public_key", hex.EncodeToString(pub1[:]),
|
|
||||||
"protocol_version", "1",
|
|
||||||
"replace_allowed_ips", "true",
|
|
||||||
"allowed_ip", "1.0.0.1/32",
|
|
||||||
)
|
|
||||||
endpointCfgs[1] = uapiCfg(
|
|
||||||
"public_key", hex.EncodeToString(pub1[:]),
|
|
||||||
"endpoint", "127.0.0.1:%d",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func genASecurityConfigs(tb testing.TB) (cfgs, endpointCfgs [2]string) {
|
|
||||||
var key1, key2 NoisePrivateKey
|
|
||||||
_, err := rand.Read(key1[:])
|
|
||||||
if err != nil {
|
|
||||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
|
||||||
}
|
|
||||||
_, err = rand.Read(key2[:])
|
|
||||||
if err != nil {
|
|
||||||
tb.Errorf("unable to generate private key random bytes: %v", err)
|
|
||||||
}
|
|
||||||
pub1, pub2 := key1.publicKey(), key2.publicKey()
|
|
||||||
|
|
||||||
cfgs[0] = uapiCfg(
|
|
||||||
"private_key", hex.EncodeToString(key1[:]),
|
|
||||||
"listen_port", "0",
|
|
||||||
"replace_peers", "true",
|
|
||||||
"jc", "5",
|
|
||||||
"jmin", "500",
|
|
||||||
"jmax", "1000",
|
|
||||||
"s1", "30",
|
|
||||||
"s2", "40",
|
|
||||||
"h1", "123456",
|
|
||||||
"h2", "67543",
|
|
||||||
"h4", "32345",
|
|
||||||
"h3", "123123",
|
|
||||||
"public_key", hex.EncodeToString(pub2[:]),
|
|
||||||
"protocol_version", "1",
|
|
||||||
"replace_allowed_ips", "true",
|
|
||||||
"allowed_ip", "1.0.0.2/32",
|
|
||||||
)
|
|
||||||
endpointCfgs[0] = uapiCfg(
|
|
||||||
"public_key", hex.EncodeToString(pub2[:]),
|
|
||||||
"endpoint", "127.0.0.1:%d",
|
|
||||||
)
|
|
||||||
cfgs[1] = uapiCfg(
|
|
||||||
"private_key", hex.EncodeToString(key2[:]),
|
|
||||||
"listen_port", "0",
|
|
||||||
"replace_peers", "true",
|
|
||||||
"jc", "5",
|
|
||||||
"jmin", "500",
|
|
||||||
"jmax", "1000",
|
|
||||||
"s1", "30",
|
|
||||||
"s2", "40",
|
|
||||||
"h1", "123456",
|
|
||||||
"h2", "67543",
|
|
||||||
"h4", "32345",
|
|
||||||
"h3", "123123",
|
|
||||||
"public_key", hex.EncodeToString(pub1[:]),
|
|
||||||
"protocol_version", "1",
|
|
||||||
"replace_allowed_ips", "true",
|
|
||||||
"allowed_ip", "1.0.0.1/32",
|
|
||||||
)
|
|
||||||
endpointCfgs[1] = uapiCfg(
|
|
||||||
"public_key", hex.EncodeToString(pub1[:]),
|
|
||||||
"endpoint", "127.0.0.1:%d",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// A testPair is a pair of testPeers.
|
|
||||||
type testPair [2]testPeer
|
|
||||||
|
|
||||||
// A testPeer is a peer used for testing.
|
|
||||||
type testPeer struct {
|
|
||||||
tun *tuntest.ChannelTUN
|
|
||||||
dev *Device
|
|
||||||
ip netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
type SendDirection bool
|
|
||||||
|
|
||||||
const (
|
|
||||||
Ping SendDirection = true
|
|
||||||
Pong SendDirection = false
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d SendDirection) String() string {
|
|
||||||
if d == Ping {
|
|
||||||
return "ping"
|
|
||||||
}
|
|
||||||
return "pong"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (pair *testPair) Send(
|
|
||||||
tb testing.TB,
|
|
||||||
ping SendDirection,
|
|
||||||
done chan struct{},
|
|
||||||
) {
|
|
||||||
tb.Helper()
|
|
||||||
p0, p1 := pair[0], pair[1]
|
|
||||||
if !ping {
|
|
||||||
// pong is the new ping
|
|
||||||
p0, p1 = p1, p0
|
|
||||||
}
|
|
||||||
msg := tuntest.Ping(p0.ip, p1.ip)
|
|
||||||
p1.tun.Outbound <- msg
|
|
||||||
timer := time.NewTimer(5 * time.Second)
|
|
||||||
defer timer.Stop()
|
|
||||||
var err error
|
|
||||||
select {
|
|
||||||
case msgRecv := <-p0.tun.Inbound:
|
|
||||||
if !bytes.Equal(msg, msgRecv) {
|
|
||||||
err = fmt.Errorf("%s did not transit correctly", ping)
|
|
||||||
}
|
|
||||||
case <-timer.C:
|
|
||||||
err = fmt.Errorf("%s did not transit", ping)
|
|
||||||
case <-done:
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
// The error may have occurred because the test is done.
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
// Real error.
|
|
||||||
tb.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// genTestPair creates a testPair.
|
|
||||||
func genTestPair(
|
|
||||||
tb testing.TB,
|
|
||||||
realSocket, withASecurity bool,
|
|
||||||
) (pair testPair) {
|
|
||||||
var cfg, endpointCfg [2]string
|
|
||||||
if withASecurity {
|
|
||||||
cfg, endpointCfg = genASecurityConfigs(tb)
|
|
||||||
} else {
|
|
||||||
cfg, endpointCfg = genConfigs(tb)
|
|
||||||
}
|
|
||||||
var binds [2]conn.Bind
|
|
||||||
if realSocket {
|
|
||||||
binds[0], binds[1] = conn.NewDefaultBind(), conn.NewDefaultBind()
|
|
||||||
} else {
|
|
||||||
binds = bindtest.NewChannelBinds()
|
|
||||||
}
|
|
||||||
// Bring up a ChannelTun for each config.
|
|
||||||
for i := range pair {
|
|
||||||
p := &pair[i]
|
|
||||||
p.tun = tuntest.NewChannelTUN()
|
|
||||||
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
|
|
||||||
level := LogLevelVerbose
|
|
||||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
|
||||||
level = LogLevelError
|
|
||||||
}
|
|
||||||
p.dev = NewDevice(p.tun.TUN(), binds[i], NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
|
||||||
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
|
||||||
tb.Errorf("failed to configure device %d: %v", i, err)
|
|
||||||
p.dev.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if err := p.dev.Up(); err != nil {
|
|
||||||
tb.Errorf("failed to bring up device %d: %v", i, err)
|
|
||||||
p.dev.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
endpointCfg[i^1] = fmt.Sprintf(endpointCfg[i^1], p.dev.net.port)
|
|
||||||
}
|
|
||||||
for i := range pair {
|
|
||||||
p := &pair[i]
|
|
||||||
if err := p.dev.IpcSet(endpointCfg[i]); err != nil {
|
|
||||||
tb.Errorf("failed to configure device endpoint %d: %v", i, err)
|
|
||||||
p.dev.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// The device is ready. Close it when the test completes.
|
|
||||||
tb.Cleanup(p.dev.Close)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTwoDevicePing(t *testing.T) {
|
|
||||||
goroutineLeakCheck(t)
|
|
||||||
pair := genTestPair(t, true, false)
|
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
|
||||||
pair.Send(t, Ping, nil)
|
|
||||||
})
|
|
||||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
|
||||||
pair.Send(t, Pong, nil)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestASecurityTwoDevicePing(t *testing.T) {
|
|
||||||
goroutineLeakCheck(t)
|
|
||||||
pair := genTestPair(t, true, true)
|
|
||||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
|
||||||
pair.Send(t, Ping, nil)
|
|
||||||
})
|
|
||||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
|
||||||
pair.Send(t, Pong, nil)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUpDown(t *testing.T) {
|
|
||||||
goroutineLeakCheck(t)
|
|
||||||
const itrials = 50
|
|
||||||
const otrials = 10
|
|
||||||
|
|
||||||
for n := 0; n < otrials; n++ {
|
|
||||||
pair := genTestPair(t, false, false)
|
|
||||||
for i := range pair {
|
|
||||||
for k := range pair[i].dev.peers.keyMap {
|
|
||||||
pair[i].dev.IpcSet(fmt.Sprintf("public_key=%s\npersistent_keepalive_interval=1\n", hex.EncodeToString(k[:])))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(len(pair))
|
|
||||||
for i := range pair {
|
|
||||||
go func(d *Device) {
|
|
||||||
defer wg.Done()
|
|
||||||
for i := 0; i < itrials; i++ {
|
|
||||||
if err := d.Up(); err != nil {
|
|
||||||
t.Errorf("failed up bring up device: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
|
||||||
if err := d.Down(); err != nil {
|
|
||||||
t.Errorf("failed to bring down device: %v", err)
|
|
||||||
}
|
|
||||||
time.Sleep(time.Duration(rand.Intn(int(time.Nanosecond * (0x10000 - 1)))))
|
|
||||||
}
|
|
||||||
}(pair[i].dev)
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
for i := range pair {
|
|
||||||
pair[i].dev.Up()
|
|
||||||
pair[i].dev.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TestConcurrencySafety does other things concurrently with tunnel use.
|
|
||||||
// It is intended to be used with the race detector to catch data races.
|
|
||||||
func TestConcurrencySafety(t *testing.T) {
|
|
||||||
pair := genTestPair(t, true, false)
|
|
||||||
done := make(chan struct{})
|
|
||||||
|
|
||||||
const warmupIters = 10
|
|
||||||
var warmup sync.WaitGroup
|
|
||||||
warmup.Add(warmupIters)
|
|
||||||
go func() {
|
|
||||||
// Send data continuously back and forth until we're done.
|
|
||||||
// Note that we may continue to attempt to send data
|
|
||||||
// even after done is closed.
|
|
||||||
i := warmupIters
|
|
||||||
for ping := Ping; ; ping = !ping {
|
|
||||||
pair.Send(t, ping, done)
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
if i > 0 {
|
|
||||||
warmup.Done()
|
|
||||||
i--
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
warmup.Wait()
|
|
||||||
|
|
||||||
applyCfg := func(cfg string) {
|
|
||||||
err := pair[0].dev.IpcSet(cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Change persistent_keepalive_interval concurrently with tunnel use.
|
|
||||||
t.Run("persistentKeepaliveInterval", func(t *testing.T) {
|
|
||||||
var pub NoisePublicKey
|
|
||||||
for key := range pair[0].dev.peers.keyMap {
|
|
||||||
pub = key
|
|
||||||
break
|
|
||||||
}
|
|
||||||
cfg := uapiCfg(
|
|
||||||
"public_key", hex.EncodeToString(pub[:]),
|
|
||||||
"persistent_keepalive_interval", "1",
|
|
||||||
)
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
applyCfg(cfg)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Change private keys concurrently with tunnel use.
|
|
||||||
t.Run("privateKey", func(t *testing.T) {
|
|
||||||
bad := uapiCfg("private_key", "7777777777777777777777777777777777777777777777777777777777777777")
|
|
||||||
good := uapiCfg("private_key", hex.EncodeToString(pair[0].dev.staticIdentity.privateKey[:]))
|
|
||||||
// Set iters to a large number like 1000 to flush out data races quickly.
|
|
||||||
// Don't leave it large. That can cause logical races
|
|
||||||
// in which the handshake is interleaved with key changes
|
|
||||||
// such that the private key appears to be unchanging but
|
|
||||||
// other state gets reset, which can cause handshake failures like
|
|
||||||
// "Received packet with invalid mac1".
|
|
||||||
const iters = 1
|
|
||||||
for i := 0; i < iters; i++ {
|
|
||||||
applyCfg(bad)
|
|
||||||
applyCfg(good)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
// Perform bind updates and keepalive sends concurrently with tunnel use.
|
|
||||||
t.Run("bindUpdate and keepalive", func(t *testing.T) {
|
|
||||||
const iters = 10
|
|
||||||
for i := 0; i < iters; i++ {
|
|
||||||
for _, peer := range pair {
|
|
||||||
peer.dev.BindUpdate()
|
|
||||||
peer.dev.SendKeepalivesToPeersWithCurrentKeypair()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
close(done)
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkLatency(b *testing.B) {
|
|
||||||
pair := genTestPair(b, true, false)
|
|
||||||
|
|
||||||
// Establish a connection.
|
|
||||||
pair.Send(b, Ping, nil)
|
|
||||||
pair.Send(b, Pong, nil)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
pair.Send(b, Ping, nil)
|
|
||||||
pair.Send(b, Pong, nil)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkThroughput(b *testing.B) {
|
|
||||||
pair := genTestPair(b, true, false)
|
|
||||||
|
|
||||||
// Establish a connection.
|
|
||||||
pair.Send(b, Ping, nil)
|
|
||||||
pair.Send(b, Pong, nil)
|
|
||||||
|
|
||||||
// Measure how long it takes to receive b.N packets,
|
|
||||||
// starting when we receive the first packet.
|
|
||||||
var recv atomic.Uint64
|
|
||||||
var elapsed time.Duration
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
var start time.Time
|
|
||||||
for {
|
|
||||||
<-pair[0].tun.Inbound
|
|
||||||
new := recv.Add(1)
|
|
||||||
if new == 1 {
|
|
||||||
start = time.Now()
|
|
||||||
}
|
|
||||||
// Careful! Don't change this to else if; b.N can be equal to 1.
|
|
||||||
if new == uint64(b.N) {
|
|
||||||
elapsed = time.Since(start)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Send packets as fast as we can until we've received enough.
|
|
||||||
ping := tuntest.Ping(pair[0].ip, pair[1].ip)
|
|
||||||
pingc := pair[1].tun.Outbound
|
|
||||||
var sent uint64
|
|
||||||
for recv.Load() != uint64(b.N) {
|
|
||||||
sent++
|
|
||||||
pingc <- ping
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
b.ReportMetric(float64(elapsed)/float64(b.N), "ns/op")
|
|
||||||
b.ReportMetric(1-float64(b.N)/float64(sent), "packet-loss")
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkUAPIGet(b *testing.B) {
|
|
||||||
pair := genTestPair(b, true, false)
|
|
||||||
pair.Send(b, Ping, nil)
|
|
||||||
pair.Send(b, Pong, nil)
|
|
||||||
b.ReportAllocs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
pair[0].dev.IpcGetOperation(io.Discard)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func goroutineLeakCheck(t *testing.T) {
|
|
||||||
goroutines := func() (int, []byte) {
|
|
||||||
p := pprof.Lookup("goroutine")
|
|
||||||
b := new(bytes.Buffer)
|
|
||||||
p.WriteTo(b, 1)
|
|
||||||
return p.Count(), b.Bytes()
|
|
||||||
}
|
|
||||||
|
|
||||||
startGoroutines, startStacks := goroutines()
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if t.Failed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Give goroutines time to exit, if they need it.
|
|
||||||
for i := 0; i < 10000; i++ {
|
|
||||||
if runtime.NumGoroutine() <= startGoroutines {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Millisecond)
|
|
||||||
}
|
|
||||||
endGoroutines, endStacks := goroutines()
|
|
||||||
t.Logf("starting stacks:\n%s\n", startStacks)
|
|
||||||
t.Logf("ending stacks:\n%s\n", endStacks)
|
|
||||||
t.Fatalf("expected %d goroutines, got %d, leak?", startGoroutines, endGoroutines)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type fakeBindSized struct {
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *fakeBindSized) Open(
|
|
||||||
port uint16,
|
|
||||||
) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
|
|
||||||
return nil, 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *fakeBindSized) Close() error { return nil }
|
|
||||||
|
|
||||||
func (b *fakeBindSized) SetMark(mark uint32) error { return nil }
|
|
||||||
|
|
||||||
func (b *fakeBindSized) Send(bufs [][]byte, ep conn.Endpoint) error { return nil }
|
|
||||||
|
|
||||||
func (b *fakeBindSized) ParseEndpoint(s string) (conn.Endpoint, error) { return nil, nil }
|
|
||||||
|
|
||||||
func (b *fakeBindSized) BatchSize() int { return b.size }
|
|
||||||
|
|
||||||
type fakeTUNDeviceSized struct {
|
|
||||||
size int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) File() *os.File { return nil }
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) MTU() (int, error) { return 0, nil }
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) Name() (string, error) { return "", nil }
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) Events() <-chan tun.Event { return nil }
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) Close() error { return nil }
|
|
||||||
|
|
||||||
func (t *fakeTUNDeviceSized) BatchSize() int { return t.size }
|
|
||||||
|
|
||||||
func TestBatchSize(t *testing.T) {
|
|
||||||
d := Device{}
|
|
||||||
|
|
||||||
d.net.bind = &fakeBindSized{1}
|
|
||||||
d.tun.device = &fakeTUNDeviceSized{1}
|
|
||||||
if want, got := 1, d.BatchSize(); got != want {
|
|
||||||
t.Errorf("expected batch size %d, got %d", want, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
d.net.bind = &fakeBindSized{1}
|
|
||||||
d.tun.device = &fakeTUNDeviceSized{128}
|
|
||||||
if want, got := 128, d.BatchSize(); got != want {
|
|
||||||
t.Errorf("expected batch size %d, got %d", want, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
d.net.bind = &fakeBindSized{128}
|
|
||||||
d.tun.device = &fakeTUNDeviceSized{1}
|
|
||||||
if want, got := 128, d.BatchSize(); got != want {
|
|
||||||
t.Errorf("expected batch size %d, got %d", want, got)
|
|
||||||
}
|
|
||||||
|
|
||||||
d.net.bind = &fakeBindSized{128}
|
|
||||||
d.tun.device = &fakeTUNDeviceSized{128}
|
|
||||||
if want, got := 128, d.BatchSize(); got != want {
|
|
||||||
t.Errorf("expected batch size %d, got %d", want, got)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,16 +0,0 @@
|
||||||
// Code generated by "stringer -type deviceState -trimprefix=deviceState"; DO NOT EDIT.
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import "strconv"
|
|
||||||
|
|
||||||
const _deviceState_name = "DownUpClosed"
|
|
||||||
|
|
||||||
var _deviceState_index = [...]uint8{0, 4, 6, 12}
|
|
||||||
|
|
||||||
func (i deviceState) String() string {
|
|
||||||
if i >= deviceState(len(_deviceState_index)-1) {
|
|
||||||
return "deviceState(" + strconv.FormatInt(int64(i), 10) + ")"
|
|
||||||
}
|
|
||||||
return _deviceState_name[_deviceState_index[i]:_deviceState_index[i+1]]
|
|
||||||
}
|
|
|
@ -1,49 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"net/netip"
|
|
||||||
)
|
|
||||||
|
|
||||||
type DummyEndpoint struct {
|
|
||||||
src, dst netip.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
|
||||||
var src, dst [16]byte
|
|
||||||
if _, err := rand.Read(src[:]); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
_, err := rand.Read(dst[:])
|
|
||||||
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *DummyEndpoint) ClearSrc() {}
|
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcToString() string {
|
|
||||||
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstToString() string {
|
|
||||||
return netip.AddrPortFrom(e.DstIP(), 1000).String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstToBytes() []byte {
|
|
||||||
out := e.DstIP().AsSlice()
|
|
||||||
out = append(out, byte(1000&0xff))
|
|
||||||
out = append(out, byte((1000>>8)&0xff))
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *DummyEndpoint) DstIP() netip.Addr {
|
|
||||||
return e.dst
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *DummyEndpoint) SrcIP() netip.Addr {
|
|
||||||
return e.src
|
|
||||||
}
|
|
|
@ -1,69 +0,0 @@
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
crand "crypto/rand"
|
|
||||||
"fmt"
|
|
||||||
v2 "math/rand/v2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type junkCreator struct {
|
|
||||||
device *Device
|
|
||||||
cha8Rand *v2.ChaCha8
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewJunkCreator(d *Device) (junkCreator, error) {
|
|
||||||
buf := make([]byte, 32)
|
|
||||||
_, err := crand.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
return junkCreator{}, err
|
|
||||||
}
|
|
||||||
return junkCreator{device: d, cha8Rand: v2.NewChaCha8([32]byte(buf))}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
|
||||||
func (jc *junkCreator) createJunkPackets() ([][]byte, error) {
|
|
||||||
if jc.device.aSecCfg.junkPacketCount == 0 {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
junks := make([][]byte, 0, jc.device.aSecCfg.junkPacketCount)
|
|
||||||
for i := 0; i < jc.device.aSecCfg.junkPacketCount; i++ {
|
|
||||||
packetSize := jc.randomPacketSize()
|
|
||||||
junk, err := jc.randomJunkWithSize(packetSize)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("Failed to create junk packet: %v", err)
|
|
||||||
}
|
|
||||||
junks = append(junks, junk)
|
|
||||||
}
|
|
||||||
return junks, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
|
||||||
func (jc *junkCreator) randomPacketSize() int {
|
|
||||||
return int(
|
|
||||||
jc.cha8Rand.Uint64()%uint64(
|
|
||||||
jc.device.aSecCfg.junkPacketMaxSize-jc.device.aSecCfg.junkPacketMinSize,
|
|
||||||
),
|
|
||||||
) + jc.device.aSecCfg.junkPacketMinSize
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
|
||||||
func (jc *junkCreator) appendJunk(writer *bytes.Buffer, size int) error {
|
|
||||||
headerJunk, err := jc.randomJunkWithSize(size)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create header junk: %v", err)
|
|
||||||
}
|
|
||||||
_, err = writer.Write(headerJunk)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to write header junk: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should be called with aSecMux RLocked
|
|
||||||
func (jc *junkCreator) randomJunkWithSize(size int) ([]byte, error) {
|
|
||||||
junk := make([]byte, size)
|
|
||||||
_, err := jc.cha8Rand.Read(junk)
|
|
||||||
return junk, err
|
|
||||||
}
|
|
|
@ -1,124 +0,0 @@
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn/bindtest"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setUpJunkCreator(t *testing.T) (junkCreator, error) {
|
|
||||||
cfg, _ := genASecurityConfigs(t)
|
|
||||||
tun := tuntest.NewChannelTUN()
|
|
||||||
binds := bindtest.NewChannelBinds()
|
|
||||||
level := LogLevelVerbose
|
|
||||||
dev := NewDevice(
|
|
||||||
tun.TUN(),
|
|
||||||
binds[0],
|
|
||||||
NewLogger(level, ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
if err := dev.IpcSet(cfg[0]); err != nil {
|
|
||||||
t.Errorf("failed to configure device %v", err)
|
|
||||||
dev.Close()
|
|
||||||
return junkCreator{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
jc, err := NewJunkCreator(dev)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to create junk creator %v", err)
|
|
||||||
dev.Close()
|
|
||||||
return junkCreator{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return jc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_junkCreator_createJunkPackets(t *testing.T) {
|
|
||||||
jc, err := setUpJunkCreator(t)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Run("", func(t *testing.T) {
|
|
||||||
got, err := jc.createJunkPackets()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf(
|
|
||||||
"junkCreator.createJunkPackets() = %v; failed",
|
|
||||||
err,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
seen := make(map[string]bool)
|
|
||||||
for _, junk := range got {
|
|
||||||
key := string(junk)
|
|
||||||
if seen[key] {
|
|
||||||
t.Errorf(
|
|
||||||
"junkCreator.createJunkPackets() = %v, duplicate key: %v",
|
|
||||||
got,
|
|
||||||
junk,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
seen[key] = true
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_junkCreator_randomJunkWithSize(t *testing.T) {
|
|
||||||
t.Run("", func(t *testing.T) {
|
|
||||||
jc, err := setUpJunkCreator(t)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
r1, _ := jc.randomJunkWithSize(10)
|
|
||||||
r2, _ := jc.randomJunkWithSize(10)
|
|
||||||
fmt.Printf("%v\n%v\n", r1, r2)
|
|
||||||
if bytes.Equal(r1, r2) {
|
|
||||||
t.Errorf("same junks %v", err)
|
|
||||||
jc.device.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_junkCreator_randomPacketSize(t *testing.T) {
|
|
||||||
jc, err := setUpJunkCreator(t)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for range [30]struct{}{} {
|
|
||||||
t.Run("", func(t *testing.T) {
|
|
||||||
if got := jc.randomPacketSize(); jc.device.aSecCfg.junkPacketMinSize > got ||
|
|
||||||
got > jc.device.aSecCfg.junkPacketMaxSize {
|
|
||||||
t.Errorf(
|
|
||||||
"junkCreator.randomPacketSize() = %v, not between range [%v,%v]",
|
|
||||||
got,
|
|
||||||
jc.device.aSecCfg.junkPacketMinSize,
|
|
||||||
jc.device.aSecCfg.junkPacketMaxSize,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_junkCreator_appendJunk(t *testing.T) {
|
|
||||||
jc, err := setUpJunkCreator(t)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
t.Run("", func(t *testing.T) {
|
|
||||||
s := "apple"
|
|
||||||
buffer := bytes.NewBuffer([]byte(s))
|
|
||||||
err := jc.appendJunk(buffer, 30)
|
|
||||||
if err != nil &&
|
|
||||||
buffer.Len() != len(s)+30 {
|
|
||||||
t.Errorf("appendWithJunk() size don't match")
|
|
||||||
}
|
|
||||||
read := make([]byte, 50)
|
|
||||||
buffer.Read(read)
|
|
||||||
fmt.Println(string(read))
|
|
||||||
})
|
|
||||||
}
|
|
|
@ -1,48 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"os"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A Logger provides logging for a Device.
|
|
||||||
// The functions are Printf-style functions.
|
|
||||||
// They must be safe for concurrent use.
|
|
||||||
// They do not require a trailing newline in the format.
|
|
||||||
// If nil, that level of logging will be silent.
|
|
||||||
type Logger struct {
|
|
||||||
Verbosef func(format string, args ...any)
|
|
||||||
Errorf func(format string, args ...any)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Log levels for use with NewLogger.
|
|
||||||
const (
|
|
||||||
LogLevelSilent = iota
|
|
||||||
LogLevelError
|
|
||||||
LogLevelVerbose
|
|
||||||
)
|
|
||||||
|
|
||||||
// Function for use in Logger for discarding logged lines.
|
|
||||||
func DiscardLogf(format string, args ...any) {}
|
|
||||||
|
|
||||||
// NewLogger constructs a Logger that writes to stdout.
|
|
||||||
// It logs at the specified log level and above.
|
|
||||||
// It decorates log lines with the log level, date, time, and prepend.
|
|
||||||
func NewLogger(level int, prepend string) *Logger {
|
|
||||||
logger := &Logger{DiscardLogf, DiscardLogf}
|
|
||||||
logf := func(prefix string) func(string, ...any) {
|
|
||||||
return log.New(os.Stdout, prefix+": "+prepend, log.Ldate|log.Ltime).Printf
|
|
||||||
}
|
|
||||||
if level >= LogLevelVerbose {
|
|
||||||
logger.Verbosef = logf("DEBUG")
|
|
||||||
}
|
|
||||||
if level >= LogLevelError {
|
|
||||||
logger.Errorf = logf("ERROR")
|
|
||||||
}
|
|
||||||
return logger
|
|
||||||
}
|
|
|
@ -1,19 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
// DisableSomeRoamingForBrokenMobileSemantics should ideally be called before peers are created,
|
|
||||||
// though it will try to deal with it, and race maybe, if called after.
|
|
||||||
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
|
||||||
device.net.brokenRoaming = true
|
|
||||||
device.peers.RLock()
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
peer.endpoint.disableRoaming = peer.endpoint.val != nil
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
}
|
|
296
device/peer.go
296
device/peer.go
|
@ -1,296 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"container/list"
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Peer struct {
|
|
||||||
isRunning atomic.Bool
|
|
||||||
keypairs Keypairs
|
|
||||||
handshake Handshake
|
|
||||||
device *Device
|
|
||||||
stopping sync.WaitGroup // routines pending stop
|
|
||||||
txBytes atomic.Uint64 // bytes send to peer (endpoint)
|
|
||||||
rxBytes atomic.Uint64 // bytes received from peer
|
|
||||||
lastHandshakeNano atomic.Int64 // nano seconds since epoch
|
|
||||||
|
|
||||||
endpoint struct {
|
|
||||||
sync.Mutex
|
|
||||||
val conn.Endpoint
|
|
||||||
clearSrcOnTx bool // signal to val.ClearSrc() prior to next packet transmission
|
|
||||||
disableRoaming bool
|
|
||||||
}
|
|
||||||
|
|
||||||
timers struct {
|
|
||||||
retransmitHandshake *Timer
|
|
||||||
sendKeepalive *Timer
|
|
||||||
newHandshake *Timer
|
|
||||||
zeroKeyMaterial *Timer
|
|
||||||
persistentKeepalive *Timer
|
|
||||||
handshakeAttempts atomic.Uint32
|
|
||||||
needAnotherKeepalive atomic.Bool
|
|
||||||
sentLastMinuteHandshake atomic.Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
state struct {
|
|
||||||
sync.Mutex // protects against concurrent Start/Stop
|
|
||||||
}
|
|
||||||
|
|
||||||
queue struct {
|
|
||||||
staged chan *QueueOutboundElementsContainer // staged packets before a handshake is available
|
|
||||||
outbound *autodrainingOutboundQueue // sequential ordering of udp transmission
|
|
||||||
inbound *autodrainingInboundQueue // sequential ordering of tun writing
|
|
||||||
}
|
|
||||||
|
|
||||||
cookieGenerator CookieGenerator
|
|
||||||
trieEntries list.List
|
|
||||||
persistentKeepaliveInterval atomic.Uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|
||||||
if device.isClosed() {
|
|
||||||
return nil, errors.New("device closed")
|
|
||||||
}
|
|
||||||
|
|
||||||
// lock resources
|
|
||||||
device.staticIdentity.RLock()
|
|
||||||
defer device.staticIdentity.RUnlock()
|
|
||||||
|
|
||||||
device.peers.Lock()
|
|
||||||
defer device.peers.Unlock()
|
|
||||||
|
|
||||||
// check if over limit
|
|
||||||
if len(device.peers.keyMap) >= MaxPeers {
|
|
||||||
return nil, errors.New("too many peers")
|
|
||||||
}
|
|
||||||
|
|
||||||
// create peer
|
|
||||||
peer := new(Peer)
|
|
||||||
|
|
||||||
peer.cookieGenerator.Init(pk)
|
|
||||||
peer.device = device
|
|
||||||
peer.queue.outbound = newAutodrainingOutboundQueue(device)
|
|
||||||
peer.queue.inbound = newAutodrainingInboundQueue(device)
|
|
||||||
peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
|
|
||||||
|
|
||||||
// map public key
|
|
||||||
_, ok := device.peers.keyMap[pk]
|
|
||||||
if ok {
|
|
||||||
return nil, errors.New("adding existing peer")
|
|
||||||
}
|
|
||||||
|
|
||||||
// pre-compute DH
|
|
||||||
handshake := &peer.handshake
|
|
||||||
handshake.mutex.Lock()
|
|
||||||
handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
|
|
||||||
handshake.remoteStatic = pk
|
|
||||||
handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
// reset endpoint
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
peer.endpoint.val = nil
|
|
||||||
peer.endpoint.disableRoaming = false
|
|
||||||
peer.endpoint.clearSrcOnTx = false
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
|
|
||||||
// init timers
|
|
||||||
peer.timersInit()
|
|
||||||
|
|
||||||
// add
|
|
||||||
device.peers.keyMap[pk] = peer
|
|
||||||
|
|
||||||
return peer, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) SendBuffers(buffers [][]byte) error {
|
|
||||||
peer.device.net.RLock()
|
|
||||||
defer peer.device.net.RUnlock()
|
|
||||||
|
|
||||||
if peer.device.isClosed() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
endpoint := peer.endpoint.val
|
|
||||||
if endpoint == nil {
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
return errors.New("no known endpoint for peer")
|
|
||||||
}
|
|
||||||
if peer.endpoint.clearSrcOnTx {
|
|
||||||
endpoint.ClearSrc()
|
|
||||||
peer.endpoint.clearSrcOnTx = false
|
|
||||||
}
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
|
|
||||||
err := peer.device.net.bind.Send(buffers, endpoint)
|
|
||||||
if err == nil {
|
|
||||||
var totalLen uint64
|
|
||||||
for _, b := range buffers {
|
|
||||||
totalLen += uint64(len(b))
|
|
||||||
}
|
|
||||||
peer.txBytes.Add(totalLen)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) String() string {
|
|
||||||
// The awful goo that follows is identical to:
|
|
||||||
//
|
|
||||||
// base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
|
|
||||||
// abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
|
|
||||||
// return fmt.Sprintf("peer(%s)", abbreviatedKey)
|
|
||||||
//
|
|
||||||
// except that it is considerably more efficient.
|
|
||||||
src := peer.handshake.remoteStatic
|
|
||||||
b64 := func(input byte) byte {
|
|
||||||
return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
|
|
||||||
}
|
|
||||||
b := []byte("peer(____…____)")
|
|
||||||
const first = len("peer(")
|
|
||||||
const second = len("peer(____…")
|
|
||||||
b[first+0] = b64((src[0] >> 2) & 63)
|
|
||||||
b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
|
|
||||||
b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
|
|
||||||
b[first+3] = b64(src[2] & 63)
|
|
||||||
b[second+0] = b64(src[29] & 63)
|
|
||||||
b[second+1] = b64((src[30] >> 2) & 63)
|
|
||||||
b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
|
|
||||||
b[second+3] = b64((src[31] << 2) & 63)
|
|
||||||
return string(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) Start() {
|
|
||||||
// should never start a peer on a closed device
|
|
||||||
if peer.device.isClosed() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// prevent simultaneous start/stop operations
|
|
||||||
peer.state.Lock()
|
|
||||||
defer peer.state.Unlock()
|
|
||||||
|
|
||||||
if peer.isRunning.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
device := peer.device
|
|
||||||
device.log.Verbosef("%v - Starting", peer)
|
|
||||||
|
|
||||||
// reset routine state
|
|
||||||
peer.stopping.Wait()
|
|
||||||
peer.stopping.Add(2)
|
|
||||||
|
|
||||||
peer.handshake.mutex.Lock()
|
|
||||||
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
|
||||||
peer.handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes
|
|
||||||
|
|
||||||
peer.timersStart()
|
|
||||||
|
|
||||||
device.flushInboundQueue(peer.queue.inbound)
|
|
||||||
device.flushOutboundQueue(peer.queue.outbound)
|
|
||||||
|
|
||||||
// Use the device batch size, not the bind batch size, as the device size is
|
|
||||||
// the size of the batch pools.
|
|
||||||
batchSize := peer.device.BatchSize()
|
|
||||||
go peer.RoutineSequentialSender(batchSize)
|
|
||||||
go peer.RoutineSequentialReceiver(batchSize)
|
|
||||||
|
|
||||||
peer.isRunning.Store(true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) ZeroAndFlushAll() {
|
|
||||||
device := peer.device
|
|
||||||
|
|
||||||
// clear key pairs
|
|
||||||
|
|
||||||
keypairs := &peer.keypairs
|
|
||||||
keypairs.Lock()
|
|
||||||
device.DeleteKeypair(keypairs.previous)
|
|
||||||
device.DeleteKeypair(keypairs.current)
|
|
||||||
device.DeleteKeypair(keypairs.next.Load())
|
|
||||||
keypairs.previous = nil
|
|
||||||
keypairs.current = nil
|
|
||||||
keypairs.next.Store(nil)
|
|
||||||
keypairs.Unlock()
|
|
||||||
|
|
||||||
// clear handshake state
|
|
||||||
|
|
||||||
handshake := &peer.handshake
|
|
||||||
handshake.mutex.Lock()
|
|
||||||
device.indexTable.Delete(handshake.localIndex)
|
|
||||||
handshake.Clear()
|
|
||||||
handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
peer.FlushStagedPackets()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) ExpireCurrentKeypairs() {
|
|
||||||
handshake := &peer.handshake
|
|
||||||
handshake.mutex.Lock()
|
|
||||||
peer.device.indexTable.Delete(handshake.localIndex)
|
|
||||||
handshake.Clear()
|
|
||||||
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
|
||||||
handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
keypairs := &peer.keypairs
|
|
||||||
keypairs.Lock()
|
|
||||||
if keypairs.current != nil {
|
|
||||||
keypairs.current.sendNonce.Store(RejectAfterMessages)
|
|
||||||
}
|
|
||||||
if next := keypairs.next.Load(); next != nil {
|
|
||||||
next.sendNonce.Store(RejectAfterMessages)
|
|
||||||
}
|
|
||||||
keypairs.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) Stop() {
|
|
||||||
peer.state.Lock()
|
|
||||||
defer peer.state.Unlock()
|
|
||||||
|
|
||||||
if !peer.isRunning.Swap(false) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.device.log.Verbosef("%v - Stopping", peer)
|
|
||||||
|
|
||||||
peer.timersStop()
|
|
||||||
// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
|
|
||||||
peer.queue.inbound.c <- nil
|
|
||||||
peer.queue.outbound.c <- nil
|
|
||||||
peer.stopping.Wait()
|
|
||||||
peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us
|
|
||||||
|
|
||||||
peer.ZeroAndFlushAll()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
defer peer.endpoint.Unlock()
|
|
||||||
if peer.endpoint.disableRoaming {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
peer.endpoint.clearSrcOnTx = false
|
|
||||||
peer.endpoint.val = endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) markEndpointSrcForClearing() {
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
defer peer.endpoint.Unlock()
|
|
||||||
if peer.endpoint.val == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
peer.endpoint.clearSrcOnTx = true
|
|
||||||
}
|
|
120
device/pools.go
120
device/pools.go
|
@ -1,120 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
type WaitPool struct {
|
|
||||||
pool sync.Pool
|
|
||||||
cond sync.Cond
|
|
||||||
lock sync.Mutex
|
|
||||||
count atomic.Uint32
|
|
||||||
max uint32
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWaitPool(max uint32, new func() any) *WaitPool {
|
|
||||||
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
|
|
||||||
p.cond = sync.Cond{L: &p.lock}
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WaitPool) Get() any {
|
|
||||||
if p.max != 0 {
|
|
||||||
p.lock.Lock()
|
|
||||||
for p.count.Load() >= p.max {
|
|
||||||
p.cond.Wait()
|
|
||||||
}
|
|
||||||
p.count.Add(1)
|
|
||||||
p.lock.Unlock()
|
|
||||||
}
|
|
||||||
return p.pool.Get()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *WaitPool) Put(x any) {
|
|
||||||
p.pool.Put(x)
|
|
||||||
if p.max == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
p.count.Add(^uint32(0))
|
|
||||||
p.cond.Signal()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PopulatePools() {
|
|
||||||
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
|
||||||
s := make([]*QueueInboundElement, 0, device.BatchSize())
|
|
||||||
return &QueueInboundElementsContainer{elems: s}
|
|
||||||
})
|
|
||||||
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
|
||||||
s := make([]*QueueOutboundElement, 0, device.BatchSize())
|
|
||||||
return &QueueOutboundElementsContainer{elems: s}
|
|
||||||
})
|
|
||||||
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
|
||||||
return new([MaxMessageSize]byte)
|
|
||||||
})
|
|
||||||
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
|
||||||
return new(QueueInboundElement)
|
|
||||||
})
|
|
||||||
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
|
|
||||||
return new(QueueOutboundElement)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) GetInboundElementsContainer() *QueueInboundElementsContainer {
|
|
||||||
c := device.pool.inboundElementsContainer.Get().(*QueueInboundElementsContainer)
|
|
||||||
c.Mutex = sync.Mutex{}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PutInboundElementsContainer(c *QueueInboundElementsContainer) {
|
|
||||||
for i := range c.elems {
|
|
||||||
c.elems[i] = nil
|
|
||||||
}
|
|
||||||
c.elems = c.elems[:0]
|
|
||||||
device.pool.inboundElementsContainer.Put(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) GetOutboundElementsContainer() *QueueOutboundElementsContainer {
|
|
||||||
c := device.pool.outboundElementsContainer.Get().(*QueueOutboundElementsContainer)
|
|
||||||
c.Mutex = sync.Mutex{}
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsContainer) {
|
|
||||||
for i := range c.elems {
|
|
||||||
c.elems[i] = nil
|
|
||||||
}
|
|
||||||
c.elems = c.elems[:0]
|
|
||||||
device.pool.outboundElementsContainer.Put(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
|
|
||||||
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
|
|
||||||
device.pool.messageBuffers.Put(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) GetInboundElement() *QueueInboundElement {
|
|
||||||
return device.pool.inboundElements.Get().(*QueueInboundElement)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PutInboundElement(elem *QueueInboundElement) {
|
|
||||||
elem.clearPointers()
|
|
||||||
device.pool.inboundElements.Put(elem)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) GetOutboundElement() *QueueOutboundElement {
|
|
||||||
return device.pool.outboundElements.Get().(*QueueOutboundElement)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
|
|
||||||
elem.clearPointers()
|
|
||||||
device.pool.outboundElements.Put(elem)
|
|
||||||
}
|
|
|
@ -1,139 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"math/rand"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestWaitPool(t *testing.T) {
|
|
||||||
t.Skip("Currently disabled")
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var trials atomic.Int32
|
|
||||||
startTrials := int32(100000)
|
|
||||||
if raceEnabled {
|
|
||||||
// This test can be very slow with -race.
|
|
||||||
startTrials /= 10
|
|
||||||
}
|
|
||||||
trials.Store(startTrials)
|
|
||||||
workers := runtime.NumCPU() + 2
|
|
||||||
if workers-4 <= 0 {
|
|
||||||
t.Skip("Not enough cores")
|
|
||||||
}
|
|
||||||
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
|
||||||
wg.Add(workers)
|
|
||||||
var max atomic.Uint32
|
|
||||||
updateMax := func() {
|
|
||||||
count := p.count.Load()
|
|
||||||
if count > p.max {
|
|
||||||
t.Errorf("count (%d) > max (%d)", count, p.max)
|
|
||||||
}
|
|
||||||
for {
|
|
||||||
old := max.Load()
|
|
||||||
if count <= old {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if max.CompareAndSwap(old, count) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for i := 0; i < workers; i++ {
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for trials.Add(-1) > 0 {
|
|
||||||
updateMax()
|
|
||||||
x := p.Get()
|
|
||||||
updateMax()
|
|
||||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
|
||||||
updateMax()
|
|
||||||
p.Put(x)
|
|
||||||
updateMax()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
if max.Load() != p.max {
|
|
||||||
t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max, p.max)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkWaitPool(b *testing.B) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var trials atomic.Int32
|
|
||||||
trials.Store(int32(b.N))
|
|
||||||
workers := runtime.NumCPU() + 2
|
|
||||||
if workers-4 <= 0 {
|
|
||||||
b.Skip("Not enough cores")
|
|
||||||
}
|
|
||||||
p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) })
|
|
||||||
wg.Add(workers)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < workers; i++ {
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for trials.Add(-1) > 0 {
|
|
||||||
x := p.Get()
|
|
||||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
|
||||||
p.Put(x)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkWaitPoolEmpty(b *testing.B) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var trials atomic.Int32
|
|
||||||
trials.Store(int32(b.N))
|
|
||||||
workers := runtime.NumCPU() + 2
|
|
||||||
if workers-4 <= 0 {
|
|
||||||
b.Skip("Not enough cores")
|
|
||||||
}
|
|
||||||
p := NewWaitPool(0, func() any { return make([]byte, 16) })
|
|
||||||
wg.Add(workers)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < workers; i++ {
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for trials.Add(-1) > 0 {
|
|
||||||
x := p.Get()
|
|
||||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
|
||||||
p.Put(x)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkSyncPool(b *testing.B) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
var trials atomic.Int32
|
|
||||||
trials.Store(int32(b.N))
|
|
||||||
workers := runtime.NumCPU() + 2
|
|
||||||
if workers-4 <= 0 {
|
|
||||||
b.Skip("Not enough cores")
|
|
||||||
}
|
|
||||||
p := sync.Pool{New: func() any { return make([]byte, 16) }}
|
|
||||||
wg.Add(workers)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < workers; i++ {
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
for trials.Add(-1) > 0 {
|
|
||||||
x := p.Get()
|
|
||||||
time.Sleep(time.Duration(rand.Intn(100)) * time.Microsecond)
|
|
||||||
p.Put(x)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
|
@ -1,19 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
|
|
||||||
/* Reduce memory consumption for Android */
|
|
||||||
|
|
||||||
const (
|
|
||||||
QueueStagedSize = conn.IdealBatchSize
|
|
||||||
QueueOutboundSize = 1024
|
|
||||||
QueueInboundSize = 1024
|
|
||||||
QueueHandshakeSize = 1024
|
|
||||||
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
|
||||||
PreallocatedBuffersPerPool = 4096
|
|
||||||
)
|
|
|
@ -1,19 +0,0 @@
|
||||||
//go:build !android && !ios && !windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import "github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
|
|
||||||
const (
|
|
||||||
QueueStagedSize = conn.IdealBatchSize
|
|
||||||
QueueOutboundSize = 1024
|
|
||||||
QueueInboundSize = 1024
|
|
||||||
QueueHandshakeSize = 1024
|
|
||||||
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
|
|
||||||
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
|
|
||||||
)
|
|
|
@ -1,21 +0,0 @@
|
||||||
//go:build ios
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
// Fit within memory limits for iOS's Network Extension API, which has stricter requirements.
|
|
||||||
// These are vars instead of consts, because heavier network extensions might want to reduce
|
|
||||||
// them further.
|
|
||||||
var (
|
|
||||||
QueueStagedSize = 128
|
|
||||||
QueueOutboundSize = 1024
|
|
||||||
QueueInboundSize = 1024
|
|
||||||
QueueHandshakeSize = 1024
|
|
||||||
PreallocatedBuffersPerPool uint32 = 1024
|
|
||||||
)
|
|
||||||
|
|
||||||
const MaxSegmentSize = 1700
|
|
|
@ -1,15 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
const (
|
|
||||||
QueueStagedSize = 128
|
|
||||||
QueueOutboundSize = 1024
|
|
||||||
QueueInboundSize = 1024
|
|
||||||
QueueHandshakeSize = 1024
|
|
||||||
MaxSegmentSize = 2048 - 32 // largest possible UDP datagram
|
|
||||||
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
|
|
||||||
)
|
|
|
@ -1,10 +0,0 @@
|
||||||
//go:build !race
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
const raceEnabled = false
|
|
|
@ -1,10 +0,0 @@
|
||||||
//go:build race
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
const raceEnabled = true
|
|
|
@ -1,577 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
type QueueHandshakeElement struct {
|
|
||||||
msgType uint32
|
|
||||||
packet []byte
|
|
||||||
endpoint conn.Endpoint
|
|
||||||
buffer *[MaxMessageSize]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueueInboundElement struct {
|
|
||||||
buffer *[MaxMessageSize]byte
|
|
||||||
packet []byte
|
|
||||||
counter uint64
|
|
||||||
keypair *Keypair
|
|
||||||
endpoint conn.Endpoint
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueueInboundElementsContainer struct {
|
|
||||||
sync.Mutex
|
|
||||||
elems []*QueueInboundElement
|
|
||||||
}
|
|
||||||
|
|
||||||
// clearPointers clears elem fields that contain pointers.
|
|
||||||
// This makes the garbage collector's life easier and
|
|
||||||
// avoids accidentally keeping other objects around unnecessarily.
|
|
||||||
// It also reduces the possible collateral damage from use-after-free bugs.
|
|
||||||
func (elem *QueueInboundElement) clearPointers() {
|
|
||||||
elem.buffer = nil
|
|
||||||
elem.packet = nil
|
|
||||||
elem.keypair = nil
|
|
||||||
elem.endpoint = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Called when a new authenticated message has been received
|
|
||||||
*
|
|
||||||
* NOTE: Not thread safe, but called by sequential receiver!
|
|
||||||
*/
|
|
||||||
func (peer *Peer) keepKeyFreshReceiving() {
|
|
||||||
if peer.timers.sentLastMinuteHandshake.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
keypair := peer.keypairs.Current()
|
|
||||||
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
|
||||||
peer.timers.sentLastMinuteHandshake.Store(true)
|
|
||||||
peer.SendHandshakeInitiation(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Receives incoming datagrams for the device
|
|
||||||
*
|
|
||||||
* Every time the bind is updated a new routine is started for
|
|
||||||
* IPv4 and IPv6 (separately)
|
|
||||||
*/
|
|
||||||
func (device *Device) RoutineReceiveIncoming(
|
|
||||||
maxBatchSize int,
|
|
||||||
recv conn.ReceiveFunc,
|
|
||||||
) {
|
|
||||||
recvName := recv.PrettyName()
|
|
||||||
defer func() {
|
|
||||||
device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
|
|
||||||
device.queue.decryption.wg.Done()
|
|
||||||
device.queue.handshake.wg.Done()
|
|
||||||
device.net.stopping.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
device.log.Verbosef("Routine: receive incoming %s - started", recvName)
|
|
||||||
|
|
||||||
// receive datagrams until conn is closed
|
|
||||||
|
|
||||||
var (
|
|
||||||
bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
|
|
||||||
bufs = make([][]byte, maxBatchSize)
|
|
||||||
err error
|
|
||||||
sizes = make([]int, maxBatchSize)
|
|
||||||
count int
|
|
||||||
endpoints = make([]conn.Endpoint, maxBatchSize)
|
|
||||||
deathSpiral int
|
|
||||||
elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
|
|
||||||
)
|
|
||||||
|
|
||||||
for i := range bufsArrs {
|
|
||||||
bufsArrs[i] = device.GetMessageBuffer()
|
|
||||||
bufs[i] = bufsArrs[i][:]
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
for i := 0; i < maxBatchSize; i++ {
|
|
||||||
if bufsArrs[i] != nil {
|
|
||||||
device.PutMessageBuffer(bufsArrs[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
count, err = recv(bufs, sizes, endpoints)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, net.ErrClosed) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
|
|
||||||
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if deathSpiral < 10 {
|
|
||||||
deathSpiral++
|
|
||||||
time.Sleep(time.Second / 3)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
deathSpiral = 0
|
|
||||||
|
|
||||||
device.aSecMux.RLock()
|
|
||||||
// handle each packet in the batch
|
|
||||||
for i, size := range sizes[:count] {
|
|
||||||
if size < MinMessageSize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// check size of packet
|
|
||||||
|
|
||||||
packet := bufsArrs[i][:size]
|
|
||||||
var msgType uint32
|
|
||||||
if device.isAdvancedSecurityOn() {
|
|
||||||
if assumedMsgType, ok := packetSizeToMsgType[size]; ok {
|
|
||||||
junkSize := msgTypeToJunkSize[assumedMsgType]
|
|
||||||
// transport size can align with other header types;
|
|
||||||
// making sure we have the right msgType
|
|
||||||
msgType = binary.LittleEndian.Uint32(packet[junkSize : junkSize+4])
|
|
||||||
if msgType == assumedMsgType {
|
|
||||||
packet = packet[junkSize:]
|
|
||||||
} else {
|
|
||||||
device.log.Verbosef("Transport packet lined up with another msg type")
|
|
||||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
|
||||||
if msgType != MessageTransportType {
|
|
||||||
device.log.Verbosef("ASec: Received message with unknown type")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
msgType = binary.LittleEndian.Uint32(packet[:4])
|
|
||||||
}
|
|
||||||
switch msgType {
|
|
||||||
|
|
||||||
// check if transport
|
|
||||||
|
|
||||||
case MessageTransportType:
|
|
||||||
|
|
||||||
// check size
|
|
||||||
|
|
||||||
if len(packet) < MessageTransportSize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookup key pair
|
|
||||||
|
|
||||||
receiver := binary.LittleEndian.Uint32(
|
|
||||||
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
|
|
||||||
)
|
|
||||||
value := device.indexTable.Lookup(receiver)
|
|
||||||
keypair := value.keypair
|
|
||||||
if keypair == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// check keypair expiry
|
|
||||||
|
|
||||||
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// create work element
|
|
||||||
peer := value.peer
|
|
||||||
elem := device.GetInboundElement()
|
|
||||||
elem.packet = packet
|
|
||||||
elem.buffer = bufsArrs[i]
|
|
||||||
elem.keypair = keypair
|
|
||||||
elem.endpoint = endpoints[i]
|
|
||||||
elem.counter = 0
|
|
||||||
|
|
||||||
elemsForPeer, ok := elemsByPeer[peer]
|
|
||||||
if !ok {
|
|
||||||
elemsForPeer = device.GetInboundElementsContainer()
|
|
||||||
elemsForPeer.Lock()
|
|
||||||
elemsByPeer[peer] = elemsForPeer
|
|
||||||
}
|
|
||||||
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
|
||||||
bufsArrs[i] = device.GetMessageBuffer()
|
|
||||||
bufs[i] = bufsArrs[i][:]
|
|
||||||
continue
|
|
||||||
|
|
||||||
// otherwise it is a fixed size & handshake related packet
|
|
||||||
|
|
||||||
case MessageInitiationType:
|
|
||||||
if len(packet) != MessageInitiationSize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
case MessageResponseType:
|
|
||||||
if len(packet) != MessageResponseSize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
case MessageCookieReplyType:
|
|
||||||
if len(packet) != MessageCookieReplySize {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
device.log.Verbosef("Received message with unknown type")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case device.queue.handshake.c <- QueueHandshakeElement{
|
|
||||||
msgType: msgType,
|
|
||||||
buffer: bufsArrs[i],
|
|
||||||
packet: packet,
|
|
||||||
endpoint: endpoints[i],
|
|
||||||
}:
|
|
||||||
bufsArrs[i] = device.GetMessageBuffer()
|
|
||||||
bufs[i] = bufsArrs[i][:]
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
for peer, elemsContainer := range elemsByPeer {
|
|
||||||
if peer.isRunning.Load() {
|
|
||||||
peer.queue.inbound.c <- elemsContainer
|
|
||||||
device.queue.decryption.c <- elemsContainer
|
|
||||||
} else {
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutInboundElement(elem)
|
|
||||||
}
|
|
||||||
device.PutInboundElementsContainer(elemsContainer)
|
|
||||||
}
|
|
||||||
delete(elemsByPeer, peer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) RoutineDecryption(id int) {
|
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
|
||||||
|
|
||||||
defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
|
|
||||||
device.log.Verbosef("Routine: decryption worker %d - started", id)
|
|
||||||
|
|
||||||
for elemsContainer := range device.queue.decryption.c {
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
// split message into fields
|
|
||||||
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
|
|
||||||
content := elem.packet[MessageTransportOffsetContent:]
|
|
||||||
|
|
||||||
// decrypt and release to consumer
|
|
||||||
var err error
|
|
||||||
elem.counter = binary.LittleEndian.Uint64(counter)
|
|
||||||
// copy counter to nonce
|
|
||||||
binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
|
|
||||||
elem.packet, err = elem.keypair.receive.Open(
|
|
||||||
content[:0],
|
|
||||||
nonce[:],
|
|
||||||
content,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
elem.packet = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
elemsContainer.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Handles incoming packets related to handshake
|
|
||||||
*/
|
|
||||||
func (device *Device) RoutineHandshake(id int) {
|
|
||||||
defer func() {
|
|
||||||
device.log.Verbosef("Routine: handshake worker %d - stopped", id)
|
|
||||||
device.queue.encryption.wg.Done()
|
|
||||||
}()
|
|
||||||
device.log.Verbosef("Routine: handshake worker %d - started", id)
|
|
||||||
|
|
||||||
for elem := range device.queue.handshake.c {
|
|
||||||
|
|
||||||
device.aSecMux.RLock()
|
|
||||||
|
|
||||||
// handle cookie fields and ratelimiting
|
|
||||||
|
|
||||||
switch elem.msgType {
|
|
||||||
|
|
||||||
case MessageCookieReplyType:
|
|
||||||
|
|
||||||
// unmarshal packet
|
|
||||||
|
|
||||||
var reply MessageCookieReply
|
|
||||||
reader := bytes.NewReader(elem.packet)
|
|
||||||
err := binary.Read(reader, binary.LittleEndian, &reply)
|
|
||||||
if err != nil {
|
|
||||||
device.log.Verbosef("Failed to decode cookie reply")
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookup peer from index
|
|
||||||
|
|
||||||
entry := device.indexTable.Lookup(reply.Receiver)
|
|
||||||
|
|
||||||
if entry.peer == nil {
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// consume reply
|
|
||||||
|
|
||||||
if peer := entry.peer; peer.isRunning.Load() {
|
|
||||||
device.log.Verbosef(
|
|
||||||
"Receiving cookie response from %s",
|
|
||||||
elem.endpoint.DstToString(),
|
|
||||||
)
|
|
||||||
if !peer.cookieGenerator.ConsumeReply(&reply) {
|
|
||||||
device.log.Verbosef(
|
|
||||||
"Could not decrypt invalid cookie response",
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
goto skip
|
|
||||||
|
|
||||||
case MessageInitiationType, MessageResponseType:
|
|
||||||
|
|
||||||
// check mac fields and maybe ratelimit
|
|
||||||
|
|
||||||
if !device.cookieChecker.CheckMAC1(elem.packet) {
|
|
||||||
device.log.Verbosef("Received packet with invalid mac1")
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// endpoints destination address is the source of the datagram
|
|
||||||
|
|
||||||
if device.IsUnderLoad() {
|
|
||||||
|
|
||||||
// verify MAC2 field
|
|
||||||
|
|
||||||
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
|
|
||||||
device.SendHandshakeCookie(&elem)
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// check ratelimiter
|
|
||||||
|
|
||||||
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
device.log.Errorf("Invalid packet ended up in the handshake queue")
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle handshake initiation/response content
|
|
||||||
|
|
||||||
switch elem.msgType {
|
|
||||||
case MessageInitiationType:
|
|
||||||
// unmarshal
|
|
||||||
var msg MessageInitiation
|
|
||||||
reader := bytes.NewReader(elem.packet)
|
|
||||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("Failed to decode initiation message")
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// consume initiation
|
|
||||||
peer := device.ConsumeMessageInitiation(&msg)
|
|
||||||
if peer == nil {
|
|
||||||
device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// update timers
|
|
||||||
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
|
||||||
|
|
||||||
// update endpoint
|
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
|
||||||
|
|
||||||
device.log.Verbosef("%v - Received handshake initiation", peer)
|
|
||||||
peer.rxBytes.Add(uint64(len(elem.packet)))
|
|
||||||
|
|
||||||
peer.SendHandshakeResponse()
|
|
||||||
|
|
||||||
case MessageResponseType:
|
|
||||||
|
|
||||||
// unmarshal
|
|
||||||
|
|
||||||
var msg MessageResponse
|
|
||||||
reader := bytes.NewReader(elem.packet)
|
|
||||||
err := binary.Read(reader, binary.LittleEndian, &msg)
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("Failed to decode response message")
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// consume response
|
|
||||||
|
|
||||||
peer := device.ConsumeMessageResponse(&msg)
|
|
||||||
if peer == nil {
|
|
||||||
device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
// update endpoint
|
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
|
||||||
|
|
||||||
device.log.Verbosef("%v - Received handshake response", peer)
|
|
||||||
peer.rxBytes.Add(uint64(len(elem.packet)))
|
|
||||||
|
|
||||||
// update timers
|
|
||||||
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
|
||||||
|
|
||||||
// derive keypair
|
|
||||||
|
|
||||||
err = peer.BeginSymmetricSession()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
|
||||||
goto skip
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.timersSessionDerived()
|
|
||||||
peer.timersHandshakeComplete()
|
|
||||||
peer.SendKeepalive()
|
|
||||||
}
|
|
||||||
skip:
|
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
|
|
||||||
device := peer.device
|
|
||||||
defer func() {
|
|
||||||
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
|
|
||||||
peer.stopping.Done()
|
|
||||||
}()
|
|
||||||
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
|
|
||||||
|
|
||||||
bufs := make([][]byte, 0, maxBatchSize)
|
|
||||||
|
|
||||||
for elemsContainer := range peer.queue.inbound.c {
|
|
||||||
if elemsContainer == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
elemsContainer.Lock()
|
|
||||||
validTailPacket := -1
|
|
||||||
dataPacketReceived := false
|
|
||||||
rxBytesLen := uint64(0)
|
|
||||||
for i, elem := range elemsContainer.elems {
|
|
||||||
if elem.packet == nil {
|
|
||||||
// decryption failed
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
validTailPacket = i
|
|
||||||
if peer.ReceivedWithKeypair(elem.keypair) {
|
|
||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
|
||||||
peer.timersHandshakeComplete()
|
|
||||||
peer.SendStagedPackets()
|
|
||||||
}
|
|
||||||
rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
|
|
||||||
|
|
||||||
if len(elem.packet) == 0 {
|
|
||||||
device.log.Verbosef("%v - Receiving keepalive packet", peer)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dataPacketReceived = true
|
|
||||||
|
|
||||||
switch elem.packet[0] >> 4 {
|
|
||||||
case 4:
|
|
||||||
if len(elem.packet) < ipv4.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
|
|
||||||
length := binary.BigEndian.Uint16(field)
|
|
||||||
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
elem.packet = elem.packet[:length]
|
|
||||||
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
|
|
||||||
if device.allowedips.Lookup(src) != peer {
|
|
||||||
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
case 6:
|
|
||||||
if len(elem.packet) < ipv6.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
|
|
||||||
length := binary.BigEndian.Uint16(field)
|
|
||||||
length += ipv6.HeaderLen
|
|
||||||
if int(length) > len(elem.packet) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
elem.packet = elem.packet[:length]
|
|
||||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
|
||||||
if device.allowedips.Lookup(src) != peer {
|
|
||||||
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
device.log.Verbosef(
|
|
||||||
"Packet with invalid IP version from %v",
|
|
||||||
peer,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
bufs = append(
|
|
||||||
bufs,
|
|
||||||
elem.buffer[:MessageTransportOffsetContent+len(elem.packet)],
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.rxBytes.Add(rxBytesLen)
|
|
||||||
if validTailPacket >= 0 {
|
|
||||||
peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
|
|
||||||
peer.keepKeyFreshReceiving()
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
|
||||||
}
|
|
||||||
if dataPacketReceived {
|
|
||||||
peer.timersDataReceived()
|
|
||||||
}
|
|
||||||
if len(bufs) > 0 {
|
|
||||||
_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
|
|
||||||
if err != nil && !device.isClosed() {
|
|
||||||
device.log.Errorf("Failed to write packets to TUN device: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutInboundElement(elem)
|
|
||||||
}
|
|
||||||
bufs = bufs[:0]
|
|
||||||
device.PutInboundElementsContainer(elemsContainer)
|
|
||||||
}
|
|
||||||
}
|
|
608
device/send.go
608
device/send.go
|
@ -1,608 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
|
||||||
"golang.org/x/net/ipv4"
|
|
||||||
"golang.org/x/net/ipv6"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* Outbound flow
|
|
||||||
*
|
|
||||||
* 1. TUN queue
|
|
||||||
* 2. Routing (sequential)
|
|
||||||
* 3. Nonce assignment (sequential)
|
|
||||||
* 4. Encryption (parallel)
|
|
||||||
* 5. Transmission (sequential)
|
|
||||||
*
|
|
||||||
* The functions in this file occur (roughly) in the order in
|
|
||||||
* which the packets are processed.
|
|
||||||
*
|
|
||||||
* Locking, Producers and Consumers
|
|
||||||
*
|
|
||||||
* The order of packets (per peer) must be maintained,
|
|
||||||
* but encryption of packets happen out-of-order:
|
|
||||||
*
|
|
||||||
* The sequential consumers will attempt to take the lock,
|
|
||||||
* workers release lock when they have completed work (encryption) on the packet.
|
|
||||||
*
|
|
||||||
* If the element is inserted into the "encryption queue",
|
|
||||||
* the content is preceded by enough "junk" to contain the transport header
|
|
||||||
* (to allow the construction of transport messages in-place)
|
|
||||||
*/
|
|
||||||
|
|
||||||
type QueueOutboundElement struct {
|
|
||||||
buffer *[MaxMessageSize]byte // slice holding the packet data
|
|
||||||
packet []byte // slice of "buffer" (always!)
|
|
||||||
nonce uint64 // nonce for encryption
|
|
||||||
keypair *Keypair // keypair for encryption
|
|
||||||
peer *Peer // related peer
|
|
||||||
}
|
|
||||||
|
|
||||||
type QueueOutboundElementsContainer struct {
|
|
||||||
sync.Mutex
|
|
||||||
elems []*QueueOutboundElement
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) NewOutboundElement() *QueueOutboundElement {
|
|
||||||
elem := device.GetOutboundElement()
|
|
||||||
elem.buffer = device.GetMessageBuffer()
|
|
||||||
elem.nonce = 0
|
|
||||||
// keypair and peer were cleared (if necessary) by clearPointers.
|
|
||||||
return elem
|
|
||||||
}
|
|
||||||
|
|
||||||
// clearPointers clears elem fields that contain pointers.
|
|
||||||
// This makes the garbage collector's life easier and
|
|
||||||
// avoids accidentally keeping other objects around unnecessarily.
|
|
||||||
// It also reduces the possible collateral damage from use-after-free bugs.
|
|
||||||
func (elem *QueueOutboundElement) clearPointers() {
|
|
||||||
elem.buffer = nil
|
|
||||||
elem.packet = nil
|
|
||||||
elem.keypair = nil
|
|
||||||
elem.peer = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Queues a keepalive if no packets are queued for peer
|
|
||||||
*/
|
|
||||||
func (peer *Peer) SendKeepalive() {
|
|
||||||
if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
|
|
||||||
elem := peer.device.NewOutboundElement()
|
|
||||||
elemsContainer := peer.device.GetOutboundElementsContainer()
|
|
||||||
elemsContainer.elems = append(elemsContainer.elems, elem)
|
|
||||||
select {
|
|
||||||
case peer.queue.staged <- elemsContainer:
|
|
||||||
peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
|
|
||||||
default:
|
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
|
||||||
peer.device.PutOutboundElement(elem)
|
|
||||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
peer.SendStagedPackets()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|
||||||
if !isRetry {
|
|
||||||
peer.timers.handshakeAttempts.Store(0)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.handshake.mutex.RLock()
|
|
||||||
if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
|
|
||||||
peer.handshake.mutex.RUnlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
peer.handshake.mutex.RUnlock()
|
|
||||||
|
|
||||||
peer.handshake.mutex.Lock()
|
|
||||||
if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
|
|
||||||
peer.handshake.mutex.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
peer.handshake.lastSentHandshake = time.Now()
|
|
||||||
peer.handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
|
|
||||||
|
|
||||||
msg, err := peer.device.CreateMessageInitiation(peer)
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var sendBuffer [][]byte
|
|
||||||
// so only packet processed for cookie generation
|
|
||||||
var junkedHeader []byte
|
|
||||||
if peer.device.isAdvancedSecurityOn() {
|
|
||||||
peer.device.aSecMux.RLock()
|
|
||||||
junks, err := peer.device.junkCreator.createJunkPackets()
|
|
||||||
peer.device.aSecMux.RUnlock()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(junks) > 0 {
|
|
||||||
err = peer.SendBuffers(junks)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - Failed to send junk packets: %v", peer, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.device.aSecMux.RLock()
|
|
||||||
if peer.device.aSecCfg.initPacketJunkSize != 0 {
|
|
||||||
buf := make([]byte, 0, peer.device.aSecCfg.initPacketJunkSize)
|
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
|
||||||
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.initPacketJunkSize)
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
|
||||||
peer.device.aSecMux.RUnlock()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
junkedHeader = writer.Bytes()
|
|
||||||
}
|
|
||||||
peer.device.aSecMux.RUnlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf [MessageInitiationSize]byte
|
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
|
||||||
binary.Write(writer, binary.LittleEndian, msg)
|
|
||||||
packet := writer.Bytes()
|
|
||||||
peer.cookieGenerator.AddMacs(packet)
|
|
||||||
junkedHeader = append(junkedHeader, packet...)
|
|
||||||
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
|
||||||
|
|
||||||
sendBuffer = append(sendBuffer, junkedHeader)
|
|
||||||
|
|
||||||
err = peer.SendBuffers(sendBuffer)
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
|
|
||||||
}
|
|
||||||
peer.timersHandshakeInitiated()
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) SendHandshakeResponse() error {
|
|
||||||
peer.handshake.mutex.Lock()
|
|
||||||
peer.handshake.lastSentHandshake = time.Now()
|
|
||||||
peer.handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
peer.device.log.Verbosef("%v - Sending handshake response", peer)
|
|
||||||
|
|
||||||
response, err := peer.device.CreateMessageResponse(peer)
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
var junkedHeader []byte
|
|
||||||
if peer.device.isAdvancedSecurityOn() {
|
|
||||||
peer.device.aSecMux.RLock()
|
|
||||||
if peer.device.aSecCfg.responsePacketJunkSize != 0 {
|
|
||||||
buf := make([]byte, 0, peer.device.aSecCfg.responsePacketJunkSize)
|
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
|
||||||
err = peer.device.junkCreator.appendJunk(writer, peer.device.aSecCfg.responsePacketJunkSize)
|
|
||||||
if err != nil {
|
|
||||||
peer.device.aSecMux.RUnlock()
|
|
||||||
peer.device.log.Errorf("%v - %v", peer, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
junkedHeader = writer.Bytes()
|
|
||||||
}
|
|
||||||
peer.device.aSecMux.RUnlock()
|
|
||||||
}
|
|
||||||
var buf [MessageResponseSize]byte
|
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
|
||||||
|
|
||||||
binary.Write(writer, binary.LittleEndian, response)
|
|
||||||
packet := writer.Bytes()
|
|
||||||
peer.cookieGenerator.AddMacs(packet)
|
|
||||||
junkedHeader = append(junkedHeader, packet...)
|
|
||||||
|
|
||||||
err = peer.BeginSymmetricSession()
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.timersSessionDerived()
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
|
||||||
|
|
||||||
// TODO: allocation could be avoided
|
|
||||||
err = peer.SendBuffers([][]byte{junkedHeader})
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) SendHandshakeCookie(
|
|
||||||
initiatingElem *QueueHandshakeElement,
|
|
||||||
) error {
|
|
||||||
device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
|
|
||||||
|
|
||||||
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
|
|
||||||
reply, err := device.cookieChecker.CreateReply(
|
|
||||||
initiatingElem.packet,
|
|
||||||
sender,
|
|
||||||
initiatingElem.endpoint.DstToBytes(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("Failed to create cookie reply: %v", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
var buf [MessageCookieReplySize]byte
|
|
||||||
writer := bytes.NewBuffer(buf[:0])
|
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
|
||||||
// TODO: allocation could be avoided
|
|
||||||
device.net.bind.Send([][]byte{writer.Bytes()}, initiatingElem.endpoint)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) keepKeyFreshSending() {
|
|
||||||
keypair := peer.keypairs.Current()
|
|
||||||
if keypair == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
nonce := keypair.sendNonce.Load()
|
|
||||||
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
|
||||||
peer.SendHandshakeInitiation(false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) RoutineReadFromTUN() {
|
|
||||||
defer func() {
|
|
||||||
device.log.Verbosef("Routine: TUN reader - stopped")
|
|
||||||
device.state.stopping.Done()
|
|
||||||
device.queue.encryption.wg.Done()
|
|
||||||
}()
|
|
||||||
|
|
||||||
device.log.Verbosef("Routine: TUN reader - started")
|
|
||||||
|
|
||||||
var (
|
|
||||||
batchSize = device.BatchSize()
|
|
||||||
readErr error
|
|
||||||
elems = make([]*QueueOutboundElement, batchSize)
|
|
||||||
bufs = make([][]byte, batchSize)
|
|
||||||
elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
|
|
||||||
count = 0
|
|
||||||
sizes = make([]int, batchSize)
|
|
||||||
offset = MessageTransportHeaderSize
|
|
||||||
)
|
|
||||||
|
|
||||||
for i := range elems {
|
|
||||||
elems[i] = device.NewOutboundElement()
|
|
||||||
bufs[i] = elems[i].buffer[:]
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
for _, elem := range elems {
|
|
||||||
if elem != nil {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
// read packets
|
|
||||||
count, readErr = device.tun.device.Read(bufs, sizes, offset)
|
|
||||||
for i := 0; i < count; i++ {
|
|
||||||
if sizes[i] < 1 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
elem := elems[i]
|
|
||||||
elem.packet = bufs[i][offset : offset+sizes[i]]
|
|
||||||
|
|
||||||
// lookup peer
|
|
||||||
var peer *Peer
|
|
||||||
switch elem.packet[0] >> 4 {
|
|
||||||
case 4:
|
|
||||||
if len(elem.packet) < ipv4.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
|
|
||||||
peer = device.allowedips.Lookup(dst)
|
|
||||||
|
|
||||||
case 6:
|
|
||||||
if len(elem.packet) < ipv6.HeaderLen {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
|
|
||||||
peer = device.allowedips.Lookup(dst)
|
|
||||||
|
|
||||||
default:
|
|
||||||
device.log.Verbosef("Received packet with unknown IP version")
|
|
||||||
}
|
|
||||||
|
|
||||||
if peer == nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
elemsForPeer, ok := elemsByPeer[peer]
|
|
||||||
if !ok {
|
|
||||||
elemsForPeer = device.GetOutboundElementsContainer()
|
|
||||||
elemsByPeer[peer] = elemsForPeer
|
|
||||||
}
|
|
||||||
elemsForPeer.elems = append(elemsForPeer.elems, elem)
|
|
||||||
elems[i] = device.NewOutboundElement()
|
|
||||||
bufs[i] = elems[i].buffer[:]
|
|
||||||
}
|
|
||||||
|
|
||||||
for peer, elemsForPeer := range elemsByPeer {
|
|
||||||
if peer.isRunning.Load() {
|
|
||||||
peer.StagePackets(elemsForPeer)
|
|
||||||
peer.SendStagedPackets()
|
|
||||||
} else {
|
|
||||||
for _, elem := range elemsForPeer.elems {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
device.PutOutboundElementsContainer(elemsForPeer)
|
|
||||||
}
|
|
||||||
delete(elemsByPeer, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
if readErr != nil {
|
|
||||||
if errors.Is(readErr, tun.ErrTooManySegments) {
|
|
||||||
// TODO: record stat for this
|
|
||||||
// This will happen if MSS is surprisingly small (< 576)
|
|
||||||
// coincident with reasonably high throughput.
|
|
||||||
device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !device.isClosed() {
|
|
||||||
if !errors.Is(readErr, os.ErrClosed) {
|
|
||||||
device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
|
|
||||||
}
|
|
||||||
go device.Close()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case peer.queue.staged <- elems:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case tooOld := <-peer.queue.staged:
|
|
||||||
for _, elem := range tooOld.elems {
|
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
|
||||||
peer.device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
peer.device.PutOutboundElementsContainer(tooOld)
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) SendStagedPackets() {
|
|
||||||
top:
|
|
||||||
if len(peer.queue.staged) == 0 || !peer.device.isUp() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
keypair := peer.keypairs.Current()
|
|
||||||
if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
|
|
||||||
peer.SendHandshakeInitiation(false)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
var elemsContainerOOO *QueueOutboundElementsContainer
|
|
||||||
select {
|
|
||||||
case elemsContainer := <-peer.queue.staged:
|
|
||||||
i := 0
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
elem.peer = peer
|
|
||||||
elem.nonce = keypair.sendNonce.Add(1) - 1
|
|
||||||
if elem.nonce >= RejectAfterMessages {
|
|
||||||
keypair.sendNonce.Store(RejectAfterMessages)
|
|
||||||
if elemsContainerOOO == nil {
|
|
||||||
elemsContainerOOO = peer.device.GetOutboundElementsContainer()
|
|
||||||
}
|
|
||||||
elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
elemsContainer.elems[i] = elem
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
|
|
||||||
elem.keypair = keypair
|
|
||||||
}
|
|
||||||
elemsContainer.Lock()
|
|
||||||
elemsContainer.elems = elemsContainer.elems[:i]
|
|
||||||
|
|
||||||
if elemsContainerOOO != nil {
|
|
||||||
peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(elemsContainer.elems) == 0 {
|
|
||||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
|
||||||
goto top
|
|
||||||
}
|
|
||||||
|
|
||||||
// add to parallel and sequential queue
|
|
||||||
if peer.isRunning.Load() {
|
|
||||||
peer.queue.outbound.c <- elemsContainer
|
|
||||||
peer.device.queue.encryption.c <- elemsContainer
|
|
||||||
} else {
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
|
||||||
peer.device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
|
||||||
}
|
|
||||||
|
|
||||||
if elemsContainerOOO != nil {
|
|
||||||
goto top
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) FlushStagedPackets() {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case elemsContainer := <-peer.queue.staged:
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
peer.device.PutMessageBuffer(elem.buffer)
|
|
||||||
peer.device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
peer.device.PutOutboundElementsContainer(elemsContainer)
|
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func calculatePaddingSize(packetSize, mtu int) int {
|
|
||||||
lastUnit := packetSize
|
|
||||||
if mtu == 0 {
|
|
||||||
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
|
|
||||||
}
|
|
||||||
if lastUnit > mtu {
|
|
||||||
lastUnit %= mtu
|
|
||||||
}
|
|
||||||
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
|
|
||||||
if paddedSize > mtu {
|
|
||||||
paddedSize = mtu
|
|
||||||
}
|
|
||||||
return paddedSize - lastUnit
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Encrypts the elements in the queue
|
|
||||||
* and marks them for sequential consumption (by releasing the mutex)
|
|
||||||
*
|
|
||||||
* Obs. One instance per core
|
|
||||||
*/
|
|
||||||
func (device *Device) RoutineEncryption(id int) {
|
|
||||||
var paddingZeros [PaddingMultiple]byte
|
|
||||||
var nonce [chacha20poly1305.NonceSize]byte
|
|
||||||
|
|
||||||
defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
|
|
||||||
device.log.Verbosef("Routine: encryption worker %d - started", id)
|
|
||||||
|
|
||||||
for elemsContainer := range device.queue.encryption.c {
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
// populate header fields
|
|
||||||
header := elem.buffer[:MessageTransportHeaderSize]
|
|
||||||
|
|
||||||
fieldType := header[0:4]
|
|
||||||
fieldReceiver := header[4:8]
|
|
||||||
fieldNonce := header[8:16]
|
|
||||||
|
|
||||||
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
|
|
||||||
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
|
|
||||||
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
|
|
||||||
|
|
||||||
// pad content to multiple of 16
|
|
||||||
paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
|
|
||||||
elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
|
|
||||||
|
|
||||||
// encrypt content and release to consumer
|
|
||||||
|
|
||||||
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
|
|
||||||
elem.packet = elem.keypair.send.Seal(
|
|
||||||
header,
|
|
||||||
nonce[:],
|
|
||||||
elem.packet,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
elemsContainer.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
|
|
||||||
device := peer.device
|
|
||||||
defer func() {
|
|
||||||
defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
|
|
||||||
peer.stopping.Done()
|
|
||||||
}()
|
|
||||||
device.log.Verbosef("%v - Routine: sequential sender - started", peer)
|
|
||||||
|
|
||||||
bufs := make([][]byte, 0, maxBatchSize)
|
|
||||||
|
|
||||||
for elemsContainer := range peer.queue.outbound.c {
|
|
||||||
bufs = bufs[:0]
|
|
||||||
if elemsContainer == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !peer.isRunning.Load() {
|
|
||||||
// peer has been stopped; return re-usable elems to the shared pool.
|
|
||||||
// This is an optimization only. It is possible for the peer to be stopped
|
|
||||||
// immediately after this check, in which case, elem will get processed.
|
|
||||||
// The timers and SendBuffers code are resilient to a few stragglers.
|
|
||||||
// TODO: rework peer shutdown order to ensure
|
|
||||||
// that we never accidentally keep timers alive longer than necessary.
|
|
||||||
elemsContainer.Lock()
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
dataSent := false
|
|
||||||
elemsContainer.Lock()
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
if len(elem.packet) != MessageKeepaliveSize {
|
|
||||||
dataSent = true
|
|
||||||
}
|
|
||||||
bufs = append(bufs, elem.packet)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
|
||||||
peer.timersAnyAuthenticatedPacketSent()
|
|
||||||
|
|
||||||
err := peer.SendBuffers(bufs)
|
|
||||||
if dataSent {
|
|
||||||
peer.timersDataSent()
|
|
||||||
}
|
|
||||||
for _, elem := range elemsContainer.elems {
|
|
||||||
device.PutMessageBuffer(elem.buffer)
|
|
||||||
device.PutOutboundElement(elem)
|
|
||||||
}
|
|
||||||
device.PutOutboundElementsContainer(elemsContainer)
|
|
||||||
if err != nil {
|
|
||||||
var errGSO conn.ErrUDPGSODisabled
|
|
||||||
if errors.As(err, &errGSO) {
|
|
||||||
device.log.Verbosef(err.Error())
|
|
||||||
err = errGSO.RetryErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.keepKeyFreshSending()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,12 +0,0 @@
|
||||||
//go:build !linux
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
|
@ -1,224 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*
|
|
||||||
* This implements userspace semantics of "sticky sockets", modeled after
|
|
||||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
|
||||||
* of the sticky-sockets.c example code:
|
|
||||||
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
|
||||||
*
|
|
||||||
* Currently there is no way to achieve this within the net package:
|
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
|
||||||
* So this code is remains platform dependent.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/rwcancel"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
|
||||||
if !conn.StdNetSupportsStickySockets {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if _, ok := bind.(*conn.StdNetBind); !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
netlinkSock, err := createNetlinkRouteSocket()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(netlinkSock)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
|
|
||||||
|
|
||||||
return netlinkCancel, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
|
||||||
type peerEndpointPtr struct {
|
|
||||||
peer *Peer
|
|
||||||
endpoint *conn.Endpoint
|
|
||||||
}
|
|
||||||
var reqPeer map[uint32]peerEndpointPtr
|
|
||||||
var reqPeerLock sync.Mutex
|
|
||||||
|
|
||||||
defer netlinkCancel.Close()
|
|
||||||
defer unix.Close(netlinkSock)
|
|
||||||
|
|
||||||
for msg := make([]byte, 1<<16); ; {
|
|
||||||
var err error
|
|
||||||
var msgn int
|
|
||||||
for {
|
|
||||||
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
|
|
||||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !netlinkCancel.ReadyRead() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
|
||||||
|
|
||||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
|
||||||
|
|
||||||
if uint(hdr.Len) > uint(len(remain)) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
switch hdr.Type {
|
|
||||||
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
|
||||||
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
|
||||||
if uint(len(remain)) < uint(hdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
|
||||||
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
|
||||||
for {
|
|
||||||
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
|
||||||
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
|
||||||
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
if reqPeer == nil {
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr, ok := reqPeer[hdr.Seq]
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
if !ok {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.endpoint.Lock()
|
|
||||||
if &pePtr.peer.endpoint.val != pePtr.endpoint {
|
|
||||||
pePtr.peer.endpoint.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
|
|
||||||
pePtr.peer.endpoint.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
pePtr.peer.endpoint.clearSrcOnTx = true
|
|
||||||
pePtr.peer.endpoint.Unlock()
|
|
||||||
}
|
|
||||||
attr = attr[attrhdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer = make(map[uint32]peerEndpointPtr)
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
go func() {
|
|
||||||
device.peers.RLock()
|
|
||||||
i := uint32(1)
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
if peer.endpoint.val == nil {
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
|
|
||||||
if nativeEP == nil {
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
nlmsg := struct {
|
|
||||||
hdr unix.NlMsghdr
|
|
||||||
msg unix.RtMsg
|
|
||||||
dsthdr unix.RtAttr
|
|
||||||
dst [4]byte
|
|
||||||
srchdr unix.RtAttr
|
|
||||||
src [4]byte
|
|
||||||
markhdr unix.RtAttr
|
|
||||||
mark uint32
|
|
||||||
}{
|
|
||||||
unix.NlMsghdr{
|
|
||||||
Type: uint16(unix.RTM_GETROUTE),
|
|
||||||
Flags: unix.NLM_F_REQUEST,
|
|
||||||
Seq: i,
|
|
||||||
},
|
|
||||||
unix.RtMsg{
|
|
||||||
Family: unix.AF_INET,
|
|
||||||
Dst_len: 32,
|
|
||||||
Src_len: 32,
|
|
||||||
},
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_DST,
|
|
||||||
},
|
|
||||||
nativeEP.DstIP().As4(),
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_SRC,
|
|
||||||
},
|
|
||||||
nativeEP.SrcIP().As4(),
|
|
||||||
unix.RtAttr{
|
|
||||||
Len: 8,
|
|
||||||
Type: unix.RTA_MARK,
|
|
||||||
},
|
|
||||||
device.net.fwmark,
|
|
||||||
}
|
|
||||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
|
||||||
reqPeerLock.Lock()
|
|
||||||
reqPeer[i] = peerEndpointPtr{
|
|
||||||
peer: peer,
|
|
||||||
endpoint: &peer.endpoint.val,
|
|
||||||
}
|
|
||||||
reqPeerLock.Unlock()
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
i++
|
|
||||||
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
|
||||||
if err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
device.peers.RUnlock()
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
remain = remain[hdr.Len:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func createNetlinkRouteSocket() (int, error) {
|
|
||||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
saddr := &unix.SockaddrNetlink{
|
|
||||||
Family: unix.AF_NETLINK,
|
|
||||||
Groups: unix.RTMGRP_IPV4_ROUTE,
|
|
||||||
}
|
|
||||||
err = unix.Bind(sock, saddr)
|
|
||||||
if err != nil {
|
|
||||||
unix.Close(sock)
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
return sock, nil
|
|
||||||
}
|
|
|
@ -1,53 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
const DefaultMTU = 1420
|
|
||||||
|
|
||||||
func (device *Device) RoutineTUNEventReader() {
|
|
||||||
device.log.Verbosef("Routine: event worker - started")
|
|
||||||
|
|
||||||
for event := range device.tun.device.Events() {
|
|
||||||
if event&tun.EventMTUUpdate != 0 {
|
|
||||||
mtu, err := device.tun.device.MTU()
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("Failed to load updated MTU of device: %v", err)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if mtu < 0 {
|
|
||||||
device.log.Errorf("MTU not updated to negative value: %v", mtu)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var tooLarge string
|
|
||||||
if mtu > MaxContentSize {
|
|
||||||
tooLarge = fmt.Sprintf(" (too large, capped at %v)", MaxContentSize)
|
|
||||||
mtu = MaxContentSize
|
|
||||||
}
|
|
||||||
old := device.tun.mtu.Swap(int32(mtu))
|
|
||||||
if int(old) != mtu {
|
|
||||||
device.log.Verbosef("MTU updated: %v%s", mtu, tooLarge)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if event&tun.EventUp != 0 {
|
|
||||||
device.log.Verbosef("Interface up requested")
|
|
||||||
device.Up()
|
|
||||||
}
|
|
||||||
|
|
||||||
if event&tun.EventDown != 0 {
|
|
||||||
device.log.Verbosef("Interface down requested")
|
|
||||||
device.Down()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
device.log.Verbosef("Routine: event worker - stopped")
|
|
||||||
}
|
|
583
device/uapi.go
583
device/uapi.go
|
@ -1,583 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
|
||||||
)
|
|
||||||
|
|
||||||
type IPCError struct {
|
|
||||||
code int64 // error code
|
|
||||||
err error // underlying/wrapped error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s IPCError) Error() string {
|
|
||||||
return fmt.Sprintf("IPC error %d: %v", s.code, s.err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s IPCError) Unwrap() error {
|
|
||||||
return s.err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s IPCError) ErrorCode() int64 {
|
|
||||||
return s.code
|
|
||||||
}
|
|
||||||
|
|
||||||
func ipcErrorf(code int64, msg string, args ...any) *IPCError {
|
|
||||||
return &IPCError{code: code, err: fmt.Errorf(msg, args...)}
|
|
||||||
}
|
|
||||||
|
|
||||||
var byteBufferPool = &sync.Pool{
|
|
||||||
New: func() any { return new(bytes.Buffer) },
|
|
||||||
}
|
|
||||||
|
|
||||||
// IpcGetOperation implements the WireGuard configuration protocol "get" operation.
|
|
||||||
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
|
|
||||||
func (device *Device) IpcGetOperation(w io.Writer) error {
|
|
||||||
device.ipcMutex.RLock()
|
|
||||||
defer device.ipcMutex.RUnlock()
|
|
||||||
|
|
||||||
buf := byteBufferPool.Get().(*bytes.Buffer)
|
|
||||||
buf.Reset()
|
|
||||||
defer byteBufferPool.Put(buf)
|
|
||||||
sendf := func(format string, args ...any) {
|
|
||||||
fmt.Fprintf(buf, format, args...)
|
|
||||||
buf.WriteByte('\n')
|
|
||||||
}
|
|
||||||
keyf := func(prefix string, key *[32]byte) {
|
|
||||||
buf.Grow(len(key)*2 + 2 + len(prefix))
|
|
||||||
buf.WriteString(prefix)
|
|
||||||
buf.WriteByte('=')
|
|
||||||
const hex = "0123456789abcdef"
|
|
||||||
for i := 0; i < len(key); i++ {
|
|
||||||
buf.WriteByte(hex[key[i]>>4])
|
|
||||||
buf.WriteByte(hex[key[i]&0xf])
|
|
||||||
}
|
|
||||||
buf.WriteByte('\n')
|
|
||||||
}
|
|
||||||
|
|
||||||
func() {
|
|
||||||
// lock required resources
|
|
||||||
|
|
||||||
device.net.RLock()
|
|
||||||
defer device.net.RUnlock()
|
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
|
||||||
defer device.staticIdentity.RUnlock()
|
|
||||||
|
|
||||||
device.peers.RLock()
|
|
||||||
defer device.peers.RUnlock()
|
|
||||||
|
|
||||||
// serialize device related values
|
|
||||||
|
|
||||||
if !device.staticIdentity.privateKey.IsZero() {
|
|
||||||
keyf("private_key", (*[32]byte)(&device.staticIdentity.privateKey))
|
|
||||||
}
|
|
||||||
|
|
||||||
if device.net.port != 0 {
|
|
||||||
sendf("listen_port=%d", device.net.port)
|
|
||||||
}
|
|
||||||
|
|
||||||
if device.net.fwmark != 0 {
|
|
||||||
sendf("fwmark=%d", device.net.fwmark)
|
|
||||||
}
|
|
||||||
|
|
||||||
if device.isAdvancedSecurityOn() {
|
|
||||||
if device.aSecCfg.junkPacketCount != 0 {
|
|
||||||
sendf("jc=%d", device.aSecCfg.junkPacketCount)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.junkPacketMinSize != 0 {
|
|
||||||
sendf("jmin=%d", device.aSecCfg.junkPacketMinSize)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.junkPacketMaxSize != 0 {
|
|
||||||
sendf("jmax=%d", device.aSecCfg.junkPacketMaxSize)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.initPacketJunkSize != 0 {
|
|
||||||
sendf("s1=%d", device.aSecCfg.initPacketJunkSize)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.responsePacketJunkSize != 0 {
|
|
||||||
sendf("s2=%d", device.aSecCfg.responsePacketJunkSize)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.initPacketMagicHeader != 0 {
|
|
||||||
sendf("h1=%d", device.aSecCfg.initPacketMagicHeader)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.responsePacketMagicHeader != 0 {
|
|
||||||
sendf("h2=%d", device.aSecCfg.responsePacketMagicHeader)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.underloadPacketMagicHeader != 0 {
|
|
||||||
sendf("h3=%d", device.aSecCfg.underloadPacketMagicHeader)
|
|
||||||
}
|
|
||||||
if device.aSecCfg.transportPacketMagicHeader != 0 {
|
|
||||||
sendf("h4=%d", device.aSecCfg.transportPacketMagicHeader)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, peer := range device.peers.keyMap {
|
|
||||||
// Serialize peer state.
|
|
||||||
peer.handshake.mutex.RLock()
|
|
||||||
keyf("public_key", (*[32]byte)(&peer.handshake.remoteStatic))
|
|
||||||
keyf("preshared_key", (*[32]byte)(&peer.handshake.presharedKey))
|
|
||||||
peer.handshake.mutex.RUnlock()
|
|
||||||
sendf("protocol_version=1")
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
if peer.endpoint.val != nil {
|
|
||||||
sendf("endpoint=%s", peer.endpoint.val.DstToString())
|
|
||||||
}
|
|
||||||
peer.endpoint.Unlock()
|
|
||||||
|
|
||||||
nano := peer.lastHandshakeNano.Load()
|
|
||||||
secs := nano / time.Second.Nanoseconds()
|
|
||||||
nano %= time.Second.Nanoseconds()
|
|
||||||
|
|
||||||
sendf("last_handshake_time_sec=%d", secs)
|
|
||||||
sendf("last_handshake_time_nsec=%d", nano)
|
|
||||||
sendf("tx_bytes=%d", peer.txBytes.Load())
|
|
||||||
sendf("rx_bytes=%d", peer.rxBytes.Load())
|
|
||||||
sendf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval.Load())
|
|
||||||
|
|
||||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
|
||||||
sendf("allowed_ip=%s", prefix.String())
|
|
||||||
return true
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// send lines (does not require resource locks)
|
|
||||||
if _, err := w.Write(buf.Bytes()); err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorIO, "failed to write output: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IpcSetOperation implements the WireGuard configuration protocol "set" operation.
|
|
||||||
// See https://www.wireguard.com/xplatform/#configuration-protocol for details.
|
|
||||||
func (device *Device) IpcSetOperation(r io.Reader) (err error) {
|
|
||||||
device.ipcMutex.Lock()
|
|
||||||
defer device.ipcMutex.Unlock()
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if err != nil {
|
|
||||||
device.log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
peer := new(ipcSetPeer)
|
|
||||||
deviceConfig := true
|
|
||||||
|
|
||||||
tempASecCfg := aSecCfgType{}
|
|
||||||
scanner := bufio.NewScanner(r)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := scanner.Text()
|
|
||||||
if line == "" {
|
|
||||||
// Blank line means terminate operation.
|
|
||||||
err := device.handlePostConfig(&tempASecCfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
peer.handlePostConfig()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
key, value, ok := strings.Cut(line, "=")
|
|
||||||
if !ok {
|
|
||||||
return ipcErrorf(
|
|
||||||
ipc.IpcErrorProtocol,
|
|
||||||
"failed to parse line %q",
|
|
||||||
line,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if key == "public_key" {
|
|
||||||
if deviceConfig {
|
|
||||||
deviceConfig = false
|
|
||||||
}
|
|
||||||
peer.handlePostConfig()
|
|
||||||
// Load/create the peer we are now configuring.
|
|
||||||
err := device.handlePublicKeyLine(peer, value)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
if deviceConfig {
|
|
||||||
err = device.handleDeviceLine(key, value, &tempASecCfg)
|
|
||||||
} else {
|
|
||||||
err = device.handlePeerLine(peer, key, value)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
err = device.handlePostConfig(&tempASecCfg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
peer.handlePostConfig()
|
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorIO, "failed to read input: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) handleDeviceLine(key, value string, tempASecCfg *aSecCfgType) error {
|
|
||||||
switch key {
|
|
||||||
case "private_key":
|
|
||||||
var sk NoisePrivateKey
|
|
||||||
err := sk.FromMaybeZeroHex(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set private_key: %w", err)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("UAPI: Updating private key")
|
|
||||||
device.SetPrivateKey(sk)
|
|
||||||
|
|
||||||
case "listen_port":
|
|
||||||
port, err := strconv.ParseUint(value, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to parse listen_port: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// update port and rebind
|
|
||||||
device.log.Verbosef("UAPI: Updating listen port")
|
|
||||||
|
|
||||||
device.net.Lock()
|
|
||||||
device.net.port = uint16(port)
|
|
||||||
device.net.Unlock()
|
|
||||||
|
|
||||||
if err := device.BindUpdate(); err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to set listen_port: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "fwmark":
|
|
||||||
mark, err := strconv.ParseUint(value, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid fwmark: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
device.log.Verbosef("UAPI: Updating fwmark")
|
|
||||||
if err := device.BindSetMark(uint32(mark)); err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorPortInUse, "failed to update fwmark: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "replace_peers":
|
|
||||||
if value != "true" {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set replace_peers, invalid value: %v", value)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("UAPI: Removing all peers")
|
|
||||||
device.RemoveAllPeers()
|
|
||||||
|
|
||||||
case "jc":
|
|
||||||
junkPacketCount, err := strconv.Atoi(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_count %w", err)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("UAPI: Updating junk_packet_count")
|
|
||||||
tempASecCfg.junkPacketCount = junkPacketCount
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "jmin":
|
|
||||||
junkPacketMinSize, err := strconv.Atoi(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_min_size %w", err)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("UAPI: Updating junk_packet_min_size")
|
|
||||||
tempASecCfg.junkPacketMinSize = junkPacketMinSize
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "jmax":
|
|
||||||
junkPacketMaxSize, err := strconv.Atoi(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse junk_packet_max_size %w", err)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("UAPI: Updating junk_packet_max_size")
|
|
||||||
tempASecCfg.junkPacketMaxSize = junkPacketMaxSize
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "s1":
|
|
||||||
initPacketJunkSize, err := strconv.Atoi(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_junk_size %w", err)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("UAPI: Updating init_packet_junk_size")
|
|
||||||
tempASecCfg.initPacketJunkSize = initPacketJunkSize
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "s2":
|
|
||||||
responsePacketJunkSize, err := strconv.Atoi(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_junk_size %w", err)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("UAPI: Updating response_packet_junk_size")
|
|
||||||
tempASecCfg.responsePacketJunkSize = responsePacketJunkSize
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "h1":
|
|
||||||
initPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse init_packet_magic_header %w", err)
|
|
||||||
}
|
|
||||||
tempASecCfg.initPacketMagicHeader = uint32(initPacketMagicHeader)
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "h2":
|
|
||||||
responsePacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse response_packet_magic_header %w", err)
|
|
||||||
}
|
|
||||||
tempASecCfg.responsePacketMagicHeader = uint32(responsePacketMagicHeader)
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "h3":
|
|
||||||
underloadPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse underload_packet_magic_header %w", err)
|
|
||||||
}
|
|
||||||
tempASecCfg.underloadPacketMagicHeader = uint32(underloadPacketMagicHeader)
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
case "h4":
|
|
||||||
transportPacketMagicHeader, err := strconv.ParseUint(value, 10, 32)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "faield to parse transport_packet_magic_header %w", err)
|
|
||||||
}
|
|
||||||
tempASecCfg.transportPacketMagicHeader = uint32(transportPacketMagicHeader)
|
|
||||||
tempASecCfg.isSet = true
|
|
||||||
|
|
||||||
default:
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI device key: %v", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// An ipcSetPeer is the current state of an IPC set operation on a peer.
|
|
||||||
type ipcSetPeer struct {
|
|
||||||
*Peer // Peer is the current peer being operated on
|
|
||||||
dummy bool // dummy reports whether this peer is a temporary, placeholder peer
|
|
||||||
created bool // new reports whether this is a newly created peer
|
|
||||||
pkaOn bool // pkaOn reports whether the peer had the persistent keepalive turn on
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *ipcSetPeer) handlePostConfig() {
|
|
||||||
if peer.Peer == nil || peer.dummy {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if peer.created {
|
|
||||||
peer.endpoint.disableRoaming = peer.device.net.brokenRoaming && peer.endpoint.val != nil
|
|
||||||
}
|
|
||||||
if peer.device.isUp() {
|
|
||||||
peer.Start()
|
|
||||||
if peer.pkaOn {
|
|
||||||
peer.SendKeepalive()
|
|
||||||
}
|
|
||||||
peer.SendStagedPackets()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) handlePublicKeyLine(
|
|
||||||
peer *ipcSetPeer,
|
|
||||||
value string,
|
|
||||||
) error {
|
|
||||||
// Load/create the peer we are configuring.
|
|
||||||
var publicKey NoisePublicKey
|
|
||||||
err := publicKey.FromHex(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to get peer by public key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Ignore peer with the same public key as this device.
|
|
||||||
device.staticIdentity.RLock()
|
|
||||||
peer.dummy = device.staticIdentity.publicKey.Equals(publicKey)
|
|
||||||
device.staticIdentity.RUnlock()
|
|
||||||
|
|
||||||
if peer.dummy {
|
|
||||||
peer.Peer = &Peer{}
|
|
||||||
} else {
|
|
||||||
peer.Peer = device.LookupPeer(publicKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.created = peer.Peer == nil
|
|
||||||
if peer.created {
|
|
||||||
peer.Peer, err = device.NewPeer(publicKey)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to create new peer: %w", err)
|
|
||||||
}
|
|
||||||
device.log.Verbosef("%v - UAPI: Created", peer.Peer)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) handlePeerLine(
|
|
||||||
peer *ipcSetPeer,
|
|
||||||
key, value string,
|
|
||||||
) error {
|
|
||||||
switch key {
|
|
||||||
case "update_only":
|
|
||||||
// allow disabling of creation
|
|
||||||
if value != "true" {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set update only, invalid value: %v", value)
|
|
||||||
}
|
|
||||||
if peer.created && !peer.dummy {
|
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
|
||||||
peer.Peer = &Peer{}
|
|
||||||
peer.dummy = true
|
|
||||||
}
|
|
||||||
|
|
||||||
case "remove":
|
|
||||||
// remove currently selected peer from device
|
|
||||||
if value != "true" {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set remove, invalid value: %v", value)
|
|
||||||
}
|
|
||||||
if !peer.dummy {
|
|
||||||
device.log.Verbosef("%v - UAPI: Removing", peer.Peer)
|
|
||||||
device.RemovePeer(peer.handshake.remoteStatic)
|
|
||||||
}
|
|
||||||
peer.Peer = &Peer{}
|
|
||||||
peer.dummy = true
|
|
||||||
|
|
||||||
case "preshared_key":
|
|
||||||
device.log.Verbosef("%v - UAPI: Updating preshared key", peer.Peer)
|
|
||||||
|
|
||||||
peer.handshake.mutex.Lock()
|
|
||||||
err := peer.handshake.presharedKey.FromHex(value)
|
|
||||||
peer.handshake.mutex.Unlock()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set preshared key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
case "endpoint":
|
|
||||||
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
|
|
||||||
endpoint, err := device.net.bind.ParseEndpoint(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
|
||||||
}
|
|
||||||
peer.endpoint.Lock()
|
|
||||||
defer peer.endpoint.Unlock()
|
|
||||||
peer.endpoint.val = endpoint
|
|
||||||
|
|
||||||
case "persistent_keepalive_interval":
|
|
||||||
device.log.Verbosef("%v - UAPI: Updating persistent keepalive interval", peer.Peer)
|
|
||||||
|
|
||||||
secs, err := strconv.ParseUint(value, 10, 16)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set persistent keepalive interval: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
old := peer.persistentKeepaliveInterval.Swap(uint32(secs))
|
|
||||||
|
|
||||||
// Send immediate keepalive if we're turning it on and before it wasn't on.
|
|
||||||
peer.pkaOn = old == 0 && secs != 0
|
|
||||||
|
|
||||||
case "replace_allowed_ips":
|
|
||||||
device.log.Verbosef("%v - UAPI: Removing all allowedips", peer.Peer)
|
|
||||||
if value != "true" {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to replace allowedips, invalid value: %v", value)
|
|
||||||
}
|
|
||||||
if peer.dummy {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
device.allowedips.RemoveByPeer(peer.Peer)
|
|
||||||
|
|
||||||
case "allowed_ip":
|
|
||||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
|
||||||
prefix, err := netip.ParsePrefix(value)
|
|
||||||
if err != nil {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
|
||||||
}
|
|
||||||
if peer.dummy {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
device.allowedips.Insert(prefix, peer.Peer)
|
|
||||||
|
|
||||||
case "protocol_version":
|
|
||||||
if value != "1" {
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid protocol version: %v", value)
|
|
||||||
}
|
|
||||||
|
|
||||||
default:
|
|
||||||
return ipcErrorf(ipc.IpcErrorInvalid, "invalid UAPI peer key: %v", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) IpcGet() (string, error) {
|
|
||||||
buf := new(strings.Builder)
|
|
||||||
if err := device.IpcGetOperation(buf); err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return buf.String(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) IpcSet(uapiConf string) error {
|
|
||||||
return device.IpcSetOperation(strings.NewReader(uapiConf))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) IpcHandle(socket net.Conn) {
|
|
||||||
defer socket.Close()
|
|
||||||
|
|
||||||
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
|
|
||||||
reader := bufio.NewReader(s)
|
|
||||||
writer := bufio.NewWriter(s)
|
|
||||||
return bufio.NewReadWriter(reader, writer)
|
|
||||||
}(socket)
|
|
||||||
|
|
||||||
for {
|
|
||||||
op, err := buffered.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle operation
|
|
||||||
switch op {
|
|
||||||
case "set=1\n":
|
|
||||||
err = device.IpcSetOperation(buffered.Reader)
|
|
||||||
case "get=1\n":
|
|
||||||
var nextByte byte
|
|
||||||
nextByte, err = buffered.ReadByte()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if nextByte != '\n' {
|
|
||||||
err = ipcErrorf(ipc.IpcErrorInvalid, "trailing character in UAPI get: %q", nextByte)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
err = device.IpcGetOperation(buffered.Writer)
|
|
||||||
default:
|
|
||||||
device.log.Errorf("invalid UAPI operation: %v", op)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// write status
|
|
||||||
var status *IPCError
|
|
||||||
if err != nil && !errors.As(err, &status) {
|
|
||||||
// shouldn't happen
|
|
||||||
status = ipcErrorf(ipc.IpcErrorUnknown, "other UAPI error: %w", err)
|
|
||||||
}
|
|
||||||
if status != nil {
|
|
||||||
device.log.Errorf("%v", status)
|
|
||||||
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
|
|
||||||
} else {
|
|
||||||
fmt.Fprintf(buffered, "errno=0\n\n")
|
|
||||||
}
|
|
||||||
buffered.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
48
device_test.go
Normal file
48
device_test.go
Normal file
|
@ -0,0 +1,48 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
/* Create two device instances and simulate full WireGuard interaction
|
||||||
|
* without network dependencies
|
||||||
|
*/
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestDevice(t *testing.T) {
|
||||||
|
|
||||||
|
// prepare tun devices for generating traffic
|
||||||
|
|
||||||
|
tun1, err := CreateDummyTUN("tun1")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to create tun:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
tun2, err := CreateDummyTUN("tun2")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to create tun:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
println(tun1)
|
||||||
|
println(tun2)
|
||||||
|
|
||||||
|
// prepare endpoints
|
||||||
|
|
||||||
|
end1, err := CreateDummyEndpoint()
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to create endpoint:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
end2, err := CreateDummyEndpoint()
|
||||||
|
if err != nil {
|
||||||
|
t.Error("failed to create endpoint:", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
println(end1)
|
||||||
|
println(end2)
|
||||||
|
|
||||||
|
// create binds
|
||||||
|
|
||||||
|
}
|
53
endpoint_test.go
Normal file
53
endpoint_test.go
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"math/rand"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
type DummyEndpoint struct {
|
||||||
|
src [16]byte
|
||||||
|
dst [16]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
||||||
|
var end DummyEndpoint
|
||||||
|
if _, err := rand.Read(end.src[:]); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, err := rand.Read(end.dst[:])
|
||||||
|
return &end, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DummyEndpoint) ClearSrc() {}
|
||||||
|
|
||||||
|
func (e *DummyEndpoint) SrcToString() string {
|
||||||
|
var addr net.UDPAddr
|
||||||
|
addr.IP = e.SrcIP()
|
||||||
|
addr.Port = 1000
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DummyEndpoint) DstToString() string {
|
||||||
|
var addr net.UDPAddr
|
||||||
|
addr.IP = e.DstIP()
|
||||||
|
addr.Port = 1000
|
||||||
|
return addr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DummyEndpoint) SrcToBytes() []byte {
|
||||||
|
return e.src[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DummyEndpoint) DstIP() net.IP {
|
||||||
|
return e.dst[:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *DummyEndpoint) SrcIP() net.IP {
|
||||||
|
return e.src[:]
|
||||||
|
}
|
|
@ -1,51 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"go/format"
|
|
||||||
"io/fs"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestFormatting(t *testing.T) {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
filepath.WalkDir(".", func(path string, d fs.DirEntry, err error) error {
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to walk %s: %v", path, err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if d.IsDir() || filepath.Ext(path) != ".go" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go func(path string) {
|
|
||||||
defer wg.Done()
|
|
||||||
src, err := os.ReadFile(path)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to read %s: %v", path, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
src = bytes.ReplaceAll(src, []byte{'\r', '\n'}, []byte{'\n'})
|
|
||||||
}
|
|
||||||
formatted, err := format.Source(src)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("unable to format %s: %v", path, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !bytes.Equal(src, formatted) {
|
|
||||||
t.Errorf("unformatted code: %s", path)
|
|
||||||
}
|
|
||||||
}(path)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
20
generate-vendor.sh
Executable file
20
generate-vendor.sh
Executable file
|
@ -0,0 +1,20 @@
|
||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
echo "# This was generated by ./generate-vendor.sh" > Gopkg.lock
|
||||||
|
echo "# This was generated by ./generate-vendor.sh" > Gopkg.toml
|
||||||
|
|
||||||
|
while read -r package; do
|
||||||
|
cat >> Gopkg.lock <<-_EOF
|
||||||
|
[[projects]]
|
||||||
|
branch = "master"
|
||||||
|
name = "$package"
|
||||||
|
revision = "$(< "$GOPATH/src/$package/.git/refs/heads/master")"
|
||||||
|
|
||||||
|
_EOF
|
||||||
|
cat >> Gopkg.toml <<-_EOF
|
||||||
|
[[constraint]]
|
||||||
|
branch = "master"
|
||||||
|
name = "$package"
|
||||||
|
|
||||||
|
_EOF
|
||||||
|
done < <(sed -n 's/.*"\(golang.org\/x\/[^/]\+\)\/\?.*".*/\1/p' *.go */*.go | sort | uniq)
|
17
go.mod
17
go.mod
|
@ -1,17 +0,0 @@
|
||||||
module github.com/amnezia-vpn/amneziawg-go
|
|
||||||
|
|
||||||
go 1.24
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/tevino/abool/v2 v2.1.0
|
|
||||||
golang.org/x/crypto v0.36.0
|
|
||||||
golang.org/x/net v0.37.0
|
|
||||||
golang.org/x/sys v0.31.0
|
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2
|
|
||||||
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6
|
|
||||||
)
|
|
||||||
|
|
||||||
require (
|
|
||||||
github.com/google/btree v1.1.3 // indirect
|
|
||||||
golang.org/x/time v0.9.0 // indirect
|
|
||||||
)
|
|
20
go.sum
20
go.sum
|
@ -1,20 +0,0 @@
|
||||||
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
|
|
||||||
github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
|
||||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
|
||||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
|
||||||
github.com/tevino/abool/v2 v2.1.0 h1:7w+Vf9f/5gmKT4m4qkayb33/92M+Um45F2BkHOR+L/c=
|
|
||||||
github.com/tevino/abool/v2 v2.1.0/go.mod h1:+Lmlqk6bHDWHqN1cbxqhwEAwMPXgc8I1SDEamtseuXY=
|
|
||||||
golang.org/x/crypto v0.36.0 h1:AnAEvhDddvBdpY+uR+MyHmuZzzNqXSe/GvuDeob5L34=
|
|
||||||
golang.org/x/crypto v0.36.0/go.mod h1:Y4J0ReaxCR1IMaabaSMugxJES1EpwhBHhv2bDHklZvc=
|
|
||||||
golang.org/x/mod v0.21.0 h1:vvrHzRwRfVKSiLrG+d4FMl/Qi4ukBCE6kZlTUkDYRT0=
|
|
||||||
golang.org/x/mod v0.21.0/go.mod h1:6SkKJ3Xj0I0BrPOZoBy3bdMptDDU9oJrpohJ3eWZ1fY=
|
|
||||||
golang.org/x/net v0.37.0 h1:1zLorHbz+LYj7MQlSf1+2tPIIgibq2eL5xkrGk6f+2c=
|
|
||||||
golang.org/x/net v0.37.0/go.mod h1:ivrbrMbzFq5J41QOQh0siUuly180yBYtLp+CKbEaFx8=
|
|
||||||
golang.org/x/sys v0.31.0 h1:ioabZlmFYtWhL+TRYpcnNlLwhyxaM9kWTDEmfnprqik=
|
|
||||||
golang.org/x/sys v0.31.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
|
||||||
golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY=
|
|
||||||
golang.org/x/time v0.9.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
|
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg=
|
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI=
|
|
||||||
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6 h1:6B7MdW3OEbJqOMr7cEYU9bkzvCjUBX/JlXk12xcANuQ=
|
|
||||||
gvisor.dev/gvisor v0.0.0-20250130013005-04f9204697c6/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM=
|
|
84
helper_test.go
Normal file
84
helper_test.go
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Helpers for writing unit tests
|
||||||
|
*/
|
||||||
|
|
||||||
|
type DummyTUN struct {
|
||||||
|
name string
|
||||||
|
mtu int
|
||||||
|
packets chan []byte
|
||||||
|
events chan TUNEvent
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) File() *os.File {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Name() (string, error) {
|
||||||
|
return tun.name, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) MTU() (int, error) {
|
||||||
|
return tun.mtu, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Write(d []byte, offset int) (int, error) {
|
||||||
|
tun.packets <- d[offset:]
|
||||||
|
return len(d), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Events() chan TUNEvent {
|
||||||
|
return tun.events
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tun *DummyTUN) Read(d []byte, offset int) (int, error) {
|
||||||
|
t := <-tun.packets
|
||||||
|
copy(d[offset:], t)
|
||||||
|
return len(t), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func CreateDummyTUN(name string) (TUNDevice, error) {
|
||||||
|
var dummy DummyTUN
|
||||||
|
dummy.mtu = 0
|
||||||
|
dummy.packets = make(chan []byte, 100)
|
||||||
|
return &dummy, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertNil(t *testing.T, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEqual(t *testing.T, a []byte, b []byte) {
|
||||||
|
if bytes.Compare(a, b) != 0 {
|
||||||
|
t.Fatal(a, "!=", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randDevice(t *testing.T) *Device {
|
||||||
|
sk, err := newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tun, _ := CreateDummyTUN("dummy")
|
||||||
|
logger := NewLogger(LogLevelError, "")
|
||||||
|
device := NewDevice(tun, logger)
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
return device
|
||||||
|
}
|
|
@ -1,14 +1,14 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
"unsafe"
|
||||||
)
|
)
|
||||||
|
|
||||||
type IndexTableEntry struct {
|
type IndexTableEntry struct {
|
||||||
|
@ -18,32 +18,31 @@ type IndexTableEntry struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type IndexTable struct {
|
type IndexTable struct {
|
||||||
sync.RWMutex
|
mutex sync.RWMutex
|
||||||
table map[uint32]IndexTableEntry
|
table map[uint32]IndexTableEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
func randUint32() (uint32, error) {
|
func randUint32() (uint32, error) {
|
||||||
var integer [4]byte
|
var integer [4]byte
|
||||||
_, err := rand.Read(integer[:])
|
_, err := rand.Read(integer[:])
|
||||||
// Arbitrary endianness; both are intrinsified by the Go compiler.
|
return *(*uint32)(unsafe.Pointer(&integer[0])), err
|
||||||
return binary.LittleEndian.Uint32(integer[:]), err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) Init() {
|
func (table *IndexTable) Init() {
|
||||||
table.Lock()
|
table.mutex.Lock()
|
||||||
defer table.Unlock()
|
defer table.mutex.Unlock()
|
||||||
table.table = make(map[uint32]IndexTableEntry)
|
table.table = make(map[uint32]IndexTableEntry)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) Delete(index uint32) {
|
func (table *IndexTable) Delete(index uint32) {
|
||||||
table.Lock()
|
table.mutex.Lock()
|
||||||
defer table.Unlock()
|
defer table.mutex.Unlock()
|
||||||
delete(table.table, index)
|
delete(table.table, index)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
|
func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
|
||||||
table.Lock()
|
table.mutex.Lock()
|
||||||
defer table.Unlock()
|
defer table.mutex.Unlock()
|
||||||
entry, ok := table.table[index]
|
entry, ok := table.table[index]
|
||||||
if !ok {
|
if !ok {
|
||||||
return
|
return
|
||||||
|
@ -66,19 +65,19 @@ func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake)
|
||||||
|
|
||||||
// check if index used
|
// check if index used
|
||||||
|
|
||||||
table.RLock()
|
table.mutex.RLock()
|
||||||
_, ok := table.table[index]
|
_, ok := table.table[index]
|
||||||
table.RUnlock()
|
table.mutex.RUnlock()
|
||||||
if ok {
|
if ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// check again while locked
|
// check again while locked
|
||||||
|
|
||||||
table.Lock()
|
table.mutex.Lock()
|
||||||
_, found := table.table[index]
|
_, found := table.table[index]
|
||||||
if found {
|
if found {
|
||||||
table.Unlock()
|
table.mutex.Unlock()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
table.table[index] = IndexTableEntry{
|
table.table[index] = IndexTableEntry{
|
||||||
|
@ -86,13 +85,13 @@ func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake)
|
||||||
handshake: handshake,
|
handshake: handshake,
|
||||||
keypair: nil,
|
keypair: nil,
|
||||||
}
|
}
|
||||||
table.Unlock()
|
table.mutex.Unlock()
|
||||||
return index, nil
|
return index, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
|
func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
|
||||||
table.RLock()
|
table.mutex.RLock()
|
||||||
defer table.RUnlock()
|
defer table.mutex.RUnlock()
|
||||||
return table.table[id]
|
return table.table[id]
|
||||||
}
|
}
|
|
@ -1,9 +1,9 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
|
@ -1,287 +0,0 @@
|
||||||
// Copyright 2021 The Go Authors. All rights reserved.
|
|
||||||
// Copyright 2015 Microsoft
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
package namedpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
type timeoutChan chan struct{}
|
|
||||||
|
|
||||||
var (
|
|
||||||
ioInitOnce sync.Once
|
|
||||||
ioCompletionPort windows.Handle
|
|
||||||
)
|
|
||||||
|
|
||||||
// ioResult contains the result of an asynchronous IO operation
|
|
||||||
type ioResult struct {
|
|
||||||
bytes uint32
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// ioOperation represents an outstanding asynchronous Win32 IO
|
|
||||||
type ioOperation struct {
|
|
||||||
o windows.Overlapped
|
|
||||||
ch chan ioResult
|
|
||||||
}
|
|
||||||
|
|
||||||
func initIo() {
|
|
||||||
h, err := windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
ioCompletionPort = h
|
|
||||||
go ioCompletionProcessor(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// file implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
|
||||||
// It takes ownership of this handle and will close it if it is garbage collected.
|
|
||||||
type file struct {
|
|
||||||
handle windows.Handle
|
|
||||||
wg sync.WaitGroup
|
|
||||||
wgLock sync.RWMutex
|
|
||||||
closing atomic.Bool
|
|
||||||
socket bool
|
|
||||||
readDeadline deadlineHandler
|
|
||||||
writeDeadline deadlineHandler
|
|
||||||
}
|
|
||||||
|
|
||||||
type deadlineHandler struct {
|
|
||||||
setLock sync.Mutex
|
|
||||||
channel timeoutChan
|
|
||||||
channelLock sync.RWMutex
|
|
||||||
timer *time.Timer
|
|
||||||
timedout atomic.Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// makeFile makes a new file from an existing file handle
|
|
||||||
func makeFile(h windows.Handle) (*file, error) {
|
|
||||||
f := &file{handle: h}
|
|
||||||
ioInitOnce.Do(initIo)
|
|
||||||
_, err := windows.CreateIoCompletionPort(h, ioCompletionPort, 0, 0)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
err = windows.SetFileCompletionNotificationModes(h, windows.FILE_SKIP_COMPLETION_PORT_ON_SUCCESS|windows.FILE_SKIP_SET_EVENT_ON_HANDLE)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
f.readDeadline.channel = make(timeoutChan)
|
|
||||||
f.writeDeadline.channel = make(timeoutChan)
|
|
||||||
return f, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// closeHandle closes the resources associated with a Win32 handle
|
|
||||||
func (f *file) closeHandle() {
|
|
||||||
f.wgLock.Lock()
|
|
||||||
// Atomically set that we are closing, releasing the resources only once.
|
|
||||||
if f.closing.Swap(true) == false {
|
|
||||||
f.wgLock.Unlock()
|
|
||||||
// cancel all IO and wait for it to complete
|
|
||||||
windows.CancelIoEx(f.handle, nil)
|
|
||||||
f.wg.Wait()
|
|
||||||
// at this point, no new IO can start
|
|
||||||
windows.Close(f.handle)
|
|
||||||
f.handle = 0
|
|
||||||
} else {
|
|
||||||
f.wgLock.Unlock()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes a file.
|
|
||||||
func (f *file) Close() error {
|
|
||||||
f.closeHandle()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// prepareIo prepares for a new IO operation.
|
|
||||||
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
|
||||||
func (f *file) prepareIo() (*ioOperation, error) {
|
|
||||||
f.wgLock.RLock()
|
|
||||||
if f.closing.Load() {
|
|
||||||
f.wgLock.RUnlock()
|
|
||||||
return nil, os.ErrClosed
|
|
||||||
}
|
|
||||||
f.wg.Add(1)
|
|
||||||
f.wgLock.RUnlock()
|
|
||||||
c := &ioOperation{}
|
|
||||||
c.ch = make(chan ioResult)
|
|
||||||
return c, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ioCompletionProcessor processes completed async IOs forever
|
|
||||||
func ioCompletionProcessor(h windows.Handle) {
|
|
||||||
for {
|
|
||||||
var bytes uint32
|
|
||||||
var key uintptr
|
|
||||||
var op *ioOperation
|
|
||||||
err := windows.GetQueuedCompletionStatus(h, &bytes, &key, (**windows.Overlapped)(unsafe.Pointer(&op)), windows.INFINITE)
|
|
||||||
if op == nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
op.ch <- ioResult{bytes, err}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
|
||||||
// the operation has actually completed.
|
|
||||||
func (f *file) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
|
||||||
if err != windows.ERROR_IO_PENDING {
|
|
||||||
return int(bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
if f.closing.Load() {
|
|
||||||
windows.CancelIoEx(f.handle, &c.o)
|
|
||||||
}
|
|
||||||
|
|
||||||
var timeout timeoutChan
|
|
||||||
if d != nil {
|
|
||||||
d.channelLock.Lock()
|
|
||||||
timeout = d.channel
|
|
||||||
d.channelLock.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
var r ioResult
|
|
||||||
select {
|
|
||||||
case r = <-c.ch:
|
|
||||||
err = r.err
|
|
||||||
if err == windows.ERROR_OPERATION_ABORTED {
|
|
||||||
if f.closing.Load() {
|
|
||||||
err = os.ErrClosed
|
|
||||||
}
|
|
||||||
} else if err != nil && f.socket {
|
|
||||||
// err is from Win32. Query the overlapped structure to get the winsock error.
|
|
||||||
var bytes, flags uint32
|
|
||||||
err = windows.WSAGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
|
||||||
}
|
|
||||||
case <-timeout:
|
|
||||||
windows.CancelIoEx(f.handle, &c.o)
|
|
||||||
r = <-c.ch
|
|
||||||
err = r.err
|
|
||||||
if err == windows.ERROR_OPERATION_ABORTED {
|
|
||||||
err = os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// runtime.KeepAlive is needed, as c is passed via native
|
|
||||||
// code to ioCompletionProcessor, c must remain alive
|
|
||||||
// until the channel read is complete.
|
|
||||||
runtime.KeepAlive(c)
|
|
||||||
return int(r.bytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads from a file handle.
|
|
||||||
func (f *file) Read(b []byte) (int, error) {
|
|
||||||
c, err := f.prepareIo()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer f.wg.Done()
|
|
||||||
|
|
||||||
if f.readDeadline.timedout.Load() {
|
|
||||||
return 0, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
|
|
||||||
var bytes uint32
|
|
||||||
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
|
||||||
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
|
||||||
runtime.KeepAlive(b)
|
|
||||||
|
|
||||||
// Handle EOF conditions.
|
|
||||||
if err == nil && n == 0 && len(b) != 0 {
|
|
||||||
return 0, io.EOF
|
|
||||||
} else if err == windows.ERROR_BROKEN_PIPE {
|
|
||||||
return 0, io.EOF
|
|
||||||
} else {
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes to a file handle.
|
|
||||||
func (f *file) Write(b []byte) (int, error) {
|
|
||||||
c, err := f.prepareIo()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer f.wg.Done()
|
|
||||||
|
|
||||||
if f.writeDeadline.timedout.Load() {
|
|
||||||
return 0, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
|
|
||||||
var bytes uint32
|
|
||||||
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
|
||||||
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
|
||||||
runtime.KeepAlive(b)
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *file) SetReadDeadline(deadline time.Time) error {
|
|
||||||
return f.readDeadline.set(deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *file) SetWriteDeadline(deadline time.Time) error {
|
|
||||||
return f.writeDeadline.set(deadline)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *file) Flush() error {
|
|
||||||
return windows.FlushFileBuffers(f.handle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *file) Fd() uintptr {
|
|
||||||
return uintptr(f.handle)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *deadlineHandler) set(deadline time.Time) error {
|
|
||||||
d.setLock.Lock()
|
|
||||||
defer d.setLock.Unlock()
|
|
||||||
|
|
||||||
if d.timer != nil {
|
|
||||||
if !d.timer.Stop() {
|
|
||||||
<-d.channel
|
|
||||||
}
|
|
||||||
d.timer = nil
|
|
||||||
}
|
|
||||||
d.timedout.Store(false)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-d.channel:
|
|
||||||
d.channelLock.Lock()
|
|
||||||
d.channel = make(chan struct{})
|
|
||||||
d.channelLock.Unlock()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
if deadline.IsZero() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
timeoutIO := func() {
|
|
||||||
d.timedout.Store(true)
|
|
||||||
close(d.channel)
|
|
||||||
}
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
duration := deadline.Sub(now)
|
|
||||||
if deadline.After(now) {
|
|
||||||
// Deadline is in the future, set a timer to wait
|
|
||||||
d.timer = time.AfterFunc(duration, timeoutIO)
|
|
||||||
} else {
|
|
||||||
// Deadline is in the past. Cancel all pending IO now.
|
|
||||||
timeoutIO()
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,485 +0,0 @@
|
||||||
// Copyright 2021 The Go Authors. All rights reserved.
|
|
||||||
// Copyright 2015 Microsoft
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
// Package namedpipe implements a net.Conn and net.Listener around Windows named pipes.
|
|
||||||
package namedpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
type pipe struct {
|
|
||||||
*file
|
|
||||||
path string
|
|
||||||
}
|
|
||||||
|
|
||||||
type messageBytePipe struct {
|
|
||||||
pipe
|
|
||||||
writeClosed atomic.Bool
|
|
||||||
readEOF bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type pipeAddress string
|
|
||||||
|
|
||||||
func (f *pipe) LocalAddr() net.Addr {
|
|
||||||
return pipeAddress(f.path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *pipe) RemoteAddr() net.Addr {
|
|
||||||
return pipeAddress(f.path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *pipe) SetDeadline(t time.Time) error {
|
|
||||||
f.SetReadDeadline(t)
|
|
||||||
f.SetWriteDeadline(t)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CloseWrite closes the write side of a message pipe in byte mode.
|
|
||||||
func (f *messageBytePipe) CloseWrite() error {
|
|
||||||
if !f.writeClosed.CompareAndSwap(false, true) {
|
|
||||||
return io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
err := f.file.Flush()
|
|
||||||
if err != nil {
|
|
||||||
f.writeClosed.Store(false)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = f.file.Write(nil)
|
|
||||||
if err != nil {
|
|
||||||
f.writeClosed.Store(false)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
|
||||||
// they are used to implement CloseWrite.
|
|
||||||
func (f *messageBytePipe) Write(b []byte) (int, error) {
|
|
||||||
if f.writeClosed.Load() {
|
|
||||||
return 0, io.ErrClosedPipe
|
|
||||||
}
|
|
||||||
if len(b) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
return f.file.Write(b)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
|
|
||||||
// mode pipe will return io.EOF, as will all subsequent reads.
|
|
||||||
func (f *messageBytePipe) Read(b []byte) (int, error) {
|
|
||||||
if f.readEOF {
|
|
||||||
return 0, io.EOF
|
|
||||||
}
|
|
||||||
n, err := f.file.Read(b)
|
|
||||||
if err == io.EOF {
|
|
||||||
// If this was the result of a zero-byte read, then
|
|
||||||
// it is possible that the read was due to a zero-size
|
|
||||||
// message. Since we are simulating CloseWrite with a
|
|
||||||
// zero-byte message, ensure that all future Read calls
|
|
||||||
// also return EOF.
|
|
||||||
f.readEOF = true
|
|
||||||
} else if err == windows.ERROR_MORE_DATA {
|
|
||||||
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
|
||||||
// and the message still has more bytes. Treat this as a success, since
|
|
||||||
// this package presents all named pipes as byte streams.
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *pipe) Handle() windows.Handle {
|
|
||||||
return f.handle
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s pipeAddress) Network() string {
|
|
||||||
return "pipe"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s pipeAddress) String() string {
|
|
||||||
return string(s)
|
|
||||||
}
|
|
||||||
|
|
||||||
// tryDialPipe attempts to dial the specified pipe until cancellation or timeout.
|
|
||||||
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return 0, ctx.Err()
|
|
||||||
default:
|
|
||||||
path16, err := windows.UTF16PtrFromString(*path)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
h, err := windows.CreateFile(path16, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
|
||||||
if err == nil {
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
if err != windows.ERROR_PIPE_BUSY {
|
|
||||||
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
|
||||||
}
|
|
||||||
// Wait 10 msec and try again. This is a rather simplistic
|
|
||||||
// view, as we always try each 10 milliseconds.
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialConfig exposes various options for use in Dial and DialContext.
|
|
||||||
type DialConfig struct {
|
|
||||||
ExpectedOwner *windows.SID // If non-nil, the pipe is verified to be owned by this SID.
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialTimeout connects to the specified named pipe by path, timing out if the
|
|
||||||
// connection takes longer than the specified duration. If timeout is zero, then
|
|
||||||
// we use a default timeout of 2 seconds.
|
|
||||||
func (config *DialConfig) DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
|
||||||
if timeout == 0 {
|
|
||||||
timeout = time.Second * 2
|
|
||||||
}
|
|
||||||
absTimeout := time.Now().Add(timeout)
|
|
||||||
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
|
||||||
conn, err := config.DialContext(ctx, path)
|
|
||||||
if err == context.DeadlineExceeded {
|
|
||||||
return nil, os.ErrDeadlineExceeded
|
|
||||||
}
|
|
||||||
return conn, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext attempts to connect to the specified named pipe by path.
|
|
||||||
func (config *DialConfig) DialContext(ctx context.Context, path string) (net.Conn, error) {
|
|
||||||
var err error
|
|
||||||
var h windows.Handle
|
|
||||||
h, err = tryDialPipe(ctx, &path)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.ExpectedOwner != nil {
|
|
||||||
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
|
|
||||||
if err != nil {
|
|
||||||
windows.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
realOwner, _, err := sd.Owner()
|
|
||||||
if err != nil {
|
|
||||||
windows.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !realOwner.Equals(config.ExpectedOwner) {
|
|
||||||
windows.Close(h)
|
|
||||||
return nil, windows.ERROR_ACCESS_DENIED
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var flags uint32
|
|
||||||
err = windows.GetNamedPipeInfo(h, &flags, nil, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
windows.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := makeFile(h)
|
|
||||||
if err != nil {
|
|
||||||
windows.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the pipe is in message mode, return a message byte pipe, which
|
|
||||||
// supports CloseWrite.
|
|
||||||
if flags&windows.PIPE_TYPE_MESSAGE != 0 {
|
|
||||||
return &messageBytePipe{
|
|
||||||
pipe: pipe{file: f, path: path},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &pipe{file: f, path: path}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultDialer DialConfig
|
|
||||||
|
|
||||||
// DialTimeout calls DialConfig.DialTimeout using an empty configuration.
|
|
||||||
func DialTimeout(path string, timeout time.Duration) (net.Conn, error) {
|
|
||||||
return defaultDialer.DialTimeout(path, timeout)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialContext calls DialConfig.DialContext using an empty configuration.
|
|
||||||
func DialContext(ctx context.Context, path string) (net.Conn, error) {
|
|
||||||
return defaultDialer.DialContext(ctx, path)
|
|
||||||
}
|
|
||||||
|
|
||||||
type acceptResponse struct {
|
|
||||||
f *file
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
type pipeListener struct {
|
|
||||||
firstHandle windows.Handle
|
|
||||||
path string
|
|
||||||
config ListenConfig
|
|
||||||
acceptCh chan chan acceptResponse
|
|
||||||
closeCh chan int
|
|
||||||
doneCh chan int
|
|
||||||
}
|
|
||||||
|
|
||||||
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *ListenConfig, isFirstPipe bool) (windows.Handle, error) {
|
|
||||||
path16, err := windows.UTF16PtrFromString(path)
|
|
||||||
if err != nil {
|
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
var oa windows.OBJECT_ATTRIBUTES
|
|
||||||
oa.Length = uint32(unsafe.Sizeof(oa))
|
|
||||||
|
|
||||||
var ntPath windows.NTUnicodeString
|
|
||||||
if err := windows.RtlDosPathNameToNtPathName(path16, &ntPath, nil, nil); err != nil {
|
|
||||||
if ntstatus, ok := err.(windows.NTStatus); ok {
|
|
||||||
err = ntstatus.Errno()
|
|
||||||
}
|
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
|
||||||
}
|
|
||||||
defer windows.LocalFree(windows.Handle(unsafe.Pointer(ntPath.Buffer)))
|
|
||||||
oa.ObjectName = &ntPath
|
|
||||||
|
|
||||||
// The security descriptor is only needed for the first pipe.
|
|
||||||
if isFirstPipe {
|
|
||||||
if sd != nil {
|
|
||||||
oa.SecurityDescriptor = sd
|
|
||||||
} else {
|
|
||||||
// Construct the default named pipe security descriptor.
|
|
||||||
var acl *windows.ACL
|
|
||||||
if err := windows.RtlDefaultNpAcl(&acl); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
defer windows.LocalFree(windows.Handle(unsafe.Pointer(acl)))
|
|
||||||
sd, err = windows.NewSecurityDescriptor()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if err = sd.SetDACL(acl, true, false); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
oa.SecurityDescriptor = sd
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
typ := uint32(windows.FILE_PIPE_REJECT_REMOTE_CLIENTS)
|
|
||||||
if c.MessageMode {
|
|
||||||
typ |= windows.FILE_PIPE_MESSAGE_TYPE
|
|
||||||
}
|
|
||||||
|
|
||||||
disposition := uint32(windows.FILE_OPEN)
|
|
||||||
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
|
||||||
if isFirstPipe {
|
|
||||||
disposition = windows.FILE_CREATE
|
|
||||||
// By not asking for read or write access, the named pipe file system
|
|
||||||
// will put this pipe into an initially disconnected state, blocking
|
|
||||||
// client connections until the next call with isFirstPipe == false.
|
|
||||||
access = windows.SYNCHRONIZE
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := int64(-50 * 10000) // 50ms
|
|
||||||
|
|
||||||
var (
|
|
||||||
h windows.Handle
|
|
||||||
iosb windows.IO_STATUS_BLOCK
|
|
||||||
)
|
|
||||||
err = windows.NtCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout)
|
|
||||||
if err != nil {
|
|
||||||
if ntstatus, ok := err.(windows.NTStatus); ok {
|
|
||||||
err = ntstatus.Errno()
|
|
||||||
}
|
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
|
||||||
}
|
|
||||||
|
|
||||||
runtime.KeepAlive(ntPath)
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *pipeListener) makeServerPipe() (*file, error) {
|
|
||||||
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
f, err := makeFile(h)
|
|
||||||
if err != nil {
|
|
||||||
windows.Close(h)
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return f, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *pipeListener) makeConnectedServerPipe() (*file, error) {
|
|
||||||
p, err := l.makeServerPipe()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for the client to connect.
|
|
||||||
ch := make(chan error)
|
|
||||||
go func(p *file) {
|
|
||||||
ch <- connectPipe(p)
|
|
||||||
}(p)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-ch:
|
|
||||||
if err != nil {
|
|
||||||
p.Close()
|
|
||||||
p = nil
|
|
||||||
}
|
|
||||||
case <-l.closeCh:
|
|
||||||
// Abort the connect request by closing the handle.
|
|
||||||
p.Close()
|
|
||||||
p = nil
|
|
||||||
err = <-ch
|
|
||||||
if err == nil || err == os.ErrClosed {
|
|
||||||
err = net.ErrClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return p, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *pipeListener) listenerRoutine() {
|
|
||||||
closed := false
|
|
||||||
for !closed {
|
|
||||||
select {
|
|
||||||
case <-l.closeCh:
|
|
||||||
closed = true
|
|
||||||
case responseCh := <-l.acceptCh:
|
|
||||||
var (
|
|
||||||
p *file
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
for {
|
|
||||||
p, err = l.makeConnectedServerPipe()
|
|
||||||
// If the connection was immediately closed by the client, try
|
|
||||||
// again.
|
|
||||||
if err != windows.ERROR_NO_DATA {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
responseCh <- acceptResponse{p, err}
|
|
||||||
closed = err == net.ErrClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
windows.Close(l.firstHandle)
|
|
||||||
l.firstHandle = 0
|
|
||||||
// Notify Close and Accept callers that the handle has been closed.
|
|
||||||
close(l.doneCh)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListenConfig contains configuration for the pipe listener.
|
|
||||||
type ListenConfig struct {
|
|
||||||
// SecurityDescriptor contains a Windows security descriptor. If nil, the default from RtlDefaultNpAcl is used.
|
|
||||||
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
|
||||||
|
|
||||||
// MessageMode determines whether the pipe is in byte or message mode. In either
|
|
||||||
// case the pipe is read in byte mode by default. The only practical difference in
|
|
||||||
// this implementation is that CloseWrite is only supported for message mode pipes;
|
|
||||||
// CloseWrite is implemented as a zero-byte write, but zero-byte writes are only
|
|
||||||
// transferred to the reader (and returned as io.EOF in this implementation)
|
|
||||||
// when the pipe is in message mode.
|
|
||||||
MessageMode bool
|
|
||||||
|
|
||||||
// InputBufferSize specifies the initial size of the input buffer, in bytes, which the OS will grow as needed.
|
|
||||||
InputBufferSize int32
|
|
||||||
|
|
||||||
// OutputBufferSize specifies the initial size of the output buffer, in bytes, which the OS will grow as needed.
|
|
||||||
OutputBufferSize int32
|
|
||||||
}
|
|
||||||
|
|
||||||
// Listen creates a listener on a Windows named pipe path,such as \\.\pipe\mypipe.
|
|
||||||
// The pipe must not already exist.
|
|
||||||
func (c *ListenConfig) Listen(path string) (net.Listener, error) {
|
|
||||||
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
l := &pipeListener{
|
|
||||||
firstHandle: h,
|
|
||||||
path: path,
|
|
||||||
config: *c,
|
|
||||||
acceptCh: make(chan chan acceptResponse),
|
|
||||||
closeCh: make(chan int),
|
|
||||||
doneCh: make(chan int),
|
|
||||||
}
|
|
||||||
// The first connection is swallowed on Windows 7 & 8, so synthesize it.
|
|
||||||
if maj, min, _ := windows.RtlGetNtVersionNumbers(); maj < 6 || (maj == 6 && min < 4) {
|
|
||||||
path16, err := windows.UTF16PtrFromString(path)
|
|
||||||
if err == nil {
|
|
||||||
h, err = windows.CreateFile(path16, 0, 0, nil, windows.OPEN_EXISTING, windows.SECURITY_SQOS_PRESENT|windows.SECURITY_ANONYMOUS, 0)
|
|
||||||
if err == nil {
|
|
||||||
windows.CloseHandle(h)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
go l.listenerRoutine()
|
|
||||||
return l, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var defaultListener ListenConfig
|
|
||||||
|
|
||||||
// Listen calls ListenConfig.Listen using an empty configuration.
|
|
||||||
func Listen(path string) (net.Listener, error) {
|
|
||||||
return defaultListener.Listen(path)
|
|
||||||
}
|
|
||||||
|
|
||||||
func connectPipe(p *file) error {
|
|
||||||
c, err := p.prepareIo()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer p.wg.Done()
|
|
||||||
|
|
||||||
err = windows.ConnectNamedPipe(p.handle, &c.o)
|
|
||||||
_, err = p.asyncIo(c, nil, 0, err)
|
|
||||||
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *pipeListener) Accept() (net.Conn, error) {
|
|
||||||
ch := make(chan acceptResponse)
|
|
||||||
select {
|
|
||||||
case l.acceptCh <- ch:
|
|
||||||
response := <-ch
|
|
||||||
err := response.err
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if l.config.MessageMode {
|
|
||||||
return &messageBytePipe{
|
|
||||||
pipe: pipe{file: response.f, path: l.path},
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
return &pipe{file: response.f, path: l.path}, nil
|
|
||||||
case <-l.doneCh:
|
|
||||||
return nil, net.ErrClosed
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *pipeListener) Close() error {
|
|
||||||
select {
|
|
||||||
case l.closeCh <- 1:
|
|
||||||
<-l.doneCh
|
|
||||||
case <-l.doneCh:
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *pipeListener) Addr() net.Addr {
|
|
||||||
return pipeAddress(l.path)
|
|
||||||
}
|
|
|
@ -1,674 +0,0 @@
|
||||||
// Copyright 2021 The Go Authors. All rights reserved.
|
|
||||||
// Copyright 2015 Microsoft
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE file.
|
|
||||||
|
|
||||||
//go:build windows
|
|
||||||
|
|
||||||
package namedpipe_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bufio"
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
func randomPipePath() string {
|
|
||||||
guid, err := windows.GenerateGUID()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return `\\.\PIPE\go-namedpipe-test-` + guid.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPingPong(t *testing.T) {
|
|
||||||
const (
|
|
||||||
ping = 42
|
|
||||||
pong = 24
|
|
||||||
)
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
listener, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to listen on pipe: %v", err)
|
|
||||||
}
|
|
||||||
defer listener.Close()
|
|
||||||
go func() {
|
|
||||||
incoming, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to accept pipe connection: %v", err)
|
|
||||||
}
|
|
||||||
defer incoming.Close()
|
|
||||||
var data [1]byte
|
|
||||||
_, err = incoming.Read(data[:])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to read ping from pipe: %v", err)
|
|
||||||
}
|
|
||||||
if data[0] != ping {
|
|
||||||
t.Fatalf("expected ping, got %d", data[0])
|
|
||||||
}
|
|
||||||
data[0] = pong
|
|
||||||
_, err = incoming.Write(data[:])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to write pong to pipe: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to dial pipe: %v", err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
client.SetDeadline(time.Now().Add(time.Second * 5))
|
|
||||||
var data [1]byte
|
|
||||||
data[0] = ping
|
|
||||||
_, err = client.Write(data[:])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to write ping to pipe: %v", err)
|
|
||||||
}
|
|
||||||
_, err = client.Read(data[:])
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to read pong from pipe: %v", err)
|
|
||||||
}
|
|
||||||
if data[0] != pong {
|
|
||||||
t.Fatalf("expected pong, got %d", data[0])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDialUnknownFailsImmediately(t *testing.T) {
|
|
||||||
_, err := namedpipe.DialTimeout(randomPipePath(), time.Duration(0))
|
|
||||||
if !errors.Is(err, syscall.ENOENT) {
|
|
||||||
t.Fatalf("expected ENOENT got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDialListenerTimesOut(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
pipe, err := namedpipe.DialTimeout(pipePath, 10*time.Millisecond)
|
|
||||||
if err == nil {
|
|
||||||
pipe.Close()
|
|
||||||
}
|
|
||||||
if err != os.ErrDeadlineExceeded {
|
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDialContextListenerTimesOut(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
d := 10 * time.Millisecond
|
|
||||||
ctx, _ := context.WithTimeout(context.Background(), d)
|
|
||||||
pipe, err := namedpipe.DialContext(ctx, pipePath)
|
|
||||||
if err == nil {
|
|
||||||
pipe.Close()
|
|
||||||
}
|
|
||||||
if err != context.DeadlineExceeded {
|
|
||||||
t.Fatalf("expected context.DeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDialListenerGetsCancelled(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
ch := make(chan error)
|
|
||||||
go func(ctx context.Context, ch chan error) {
|
|
||||||
_, err := namedpipe.DialContext(ctx, pipePath)
|
|
||||||
ch <- err
|
|
||||||
}(ctx, ch)
|
|
||||||
time.Sleep(time.Millisecond * 30)
|
|
||||||
cancel()
|
|
||||||
err = <-ch
|
|
||||||
if err != context.Canceled {
|
|
||||||
t.Fatalf("expected context.Canceled, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDialAccessDeniedWithRestrictedSD(t *testing.T) {
|
|
||||||
if windows.NewLazySystemDLL("ntdll.dll").NewProc("wine_get_version").Find() == nil {
|
|
||||||
t.Skip("dacls on named pipes are broken on wine")
|
|
||||||
}
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
sd, _ := windows.SecurityDescriptorFromString("D:")
|
|
||||||
l, err := (&namedpipe.ListenConfig{
|
|
||||||
SecurityDescriptor: sd,
|
|
||||||
}).Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err == nil {
|
|
||||||
pipe.Close()
|
|
||||||
}
|
|
||||||
if !errors.Is(err, windows.ERROR_ACCESS_DENIED) {
|
|
||||||
t.Fatalf("expected ERROR_ACCESS_DENIED, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getConnection(cfg *namedpipe.ListenConfig) (client, server net.Conn, err error) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
if cfg == nil {
|
|
||||||
cfg = &namedpipe.ListenConfig{}
|
|
||||||
}
|
|
||||||
l, err := cfg.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
type response struct {
|
|
||||||
c net.Conn
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
ch := make(chan response)
|
|
||||||
go func() {
|
|
||||||
c, err := l.Accept()
|
|
||||||
ch <- response{c, err}
|
|
||||||
}()
|
|
||||||
|
|
||||||
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r := <-ch
|
|
||||||
if err = r.err; err != nil {
|
|
||||||
c.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
client = c
|
|
||||||
server = r.c
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestReadTimeout(t *testing.T) {
|
|
||||||
c, s, err := getConnection(nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
defer s.Close()
|
|
||||||
|
|
||||||
c.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
|
||||||
|
|
||||||
buf := make([]byte, 10)
|
|
||||||
_, err = c.Read(buf)
|
|
||||||
if err != os.ErrDeadlineExceeded {
|
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func server(l net.Listener, ch chan int) {
|
|
||||||
c, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
|
||||||
s, err := rw.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
_, err = rw.WriteString("got " + s)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
err = rw.Flush()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
c.Close()
|
|
||||||
ch <- 1
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFullListenDialReadWrite(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
ch := make(chan int)
|
|
||||||
go server(l, ch)
|
|
||||||
|
|
||||||
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
|
|
||||||
rw := bufio.NewReadWriter(bufio.NewReader(c), bufio.NewWriter(c))
|
|
||||||
_, err = rw.WriteString("hello world\n")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
err = rw.Flush()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s, err := rw.ReadString('\n')
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ms := "got hello world\n"
|
|
||||||
if s != ms {
|
|
||||||
t.Errorf("expected '%s', got '%s'", ms, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
<-ch
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseAbortsListen(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := make(chan error)
|
|
||||||
go func() {
|
|
||||||
_, err := l.Accept()
|
|
||||||
ch <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(30 * time.Millisecond)
|
|
||||||
l.Close()
|
|
||||||
|
|
||||||
err = <-ch
|
|
||||||
if err != net.ErrClosed {
|
|
||||||
t.Fatalf("expected net.ErrClosed, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensureEOFOnClose(t *testing.T, r io.Reader, w io.Closer) {
|
|
||||||
b := make([]byte, 10)
|
|
||||||
w.Close()
|
|
||||||
n, err := r.Read(b)
|
|
||||||
if n > 0 {
|
|
||||||
t.Errorf("unexpected byte count %d", n)
|
|
||||||
}
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Errorf("expected EOF: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseClientEOFServer(t *testing.T) {
|
|
||||||
c, s, err := getConnection(nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
defer s.Close()
|
|
||||||
ensureEOFOnClose(t, c, s)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseServerEOFClient(t *testing.T) {
|
|
||||||
c, s, err := getConnection(nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
defer s.Close()
|
|
||||||
ensureEOFOnClose(t, s, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCloseWriteEOF(t *testing.T) {
|
|
||||||
cfg := &namedpipe.ListenConfig{
|
|
||||||
MessageMode: true,
|
|
||||||
}
|
|
||||||
c, s, err := getConnection(cfg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
defer s.Close()
|
|
||||||
|
|
||||||
type closeWriter interface {
|
|
||||||
CloseWrite() error
|
|
||||||
}
|
|
||||||
|
|
||||||
err = c.(closeWriter).CloseWrite()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
b := make([]byte, 10)
|
|
||||||
_, err = s.Read(b)
|
|
||||||
if err != io.EOF {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAcceptAfterCloseFails(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
l.Close()
|
|
||||||
_, err = l.Accept()
|
|
||||||
if err != net.ErrClosed {
|
|
||||||
t.Fatalf("expected net.ErrClosed, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDialTimesOutByDefault(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
pipe, err := namedpipe.DialTimeout(pipePath, time.Duration(0)) // Should timeout after 2 seconds.
|
|
||||||
if err == nil {
|
|
||||||
pipe.Close()
|
|
||||||
}
|
|
||||||
if err != os.ErrDeadlineExceeded {
|
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTimeoutPendingRead(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
serverDone := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
s.Close()
|
|
||||||
close(serverDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
clientErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
buf := make([]byte, 10)
|
|
||||||
_, err = client.Read(buf)
|
|
||||||
clientErr <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is reading before we set the deadline
|
|
||||||
client.SetReadDeadline(time.Unix(1, 0))
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-clientErr:
|
|
||||||
if err != os.ErrDeadlineExceeded {
|
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatalf("timed out while waiting for read to cancel")
|
|
||||||
<-clientErr
|
|
||||||
}
|
|
||||||
<-serverDone
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTimeoutPendingWrite(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
serverDone := make(chan struct{})
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
s, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
time.Sleep(1 * time.Second)
|
|
||||||
s.Close()
|
|
||||||
close(serverDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
client, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
clientErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
_, err = client.Write([]byte("this should timeout"))
|
|
||||||
clientErr <- err
|
|
||||||
}()
|
|
||||||
|
|
||||||
time.Sleep(100 * time.Millisecond) // make *sure* the pipe is writing before we set the deadline
|
|
||||||
client.SetWriteDeadline(time.Unix(1, 0))
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err = <-clientErr:
|
|
||||||
if err != os.ErrDeadlineExceeded {
|
|
||||||
t.Fatalf("expected os.ErrDeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
case <-time.After(100 * time.Millisecond):
|
|
||||||
t.Fatalf("timed out while waiting for write to cancel")
|
|
||||||
<-clientErr
|
|
||||||
}
|
|
||||||
<-serverDone
|
|
||||||
}
|
|
||||||
|
|
||||||
type CloseWriter interface {
|
|
||||||
CloseWrite() error
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEchoWithMessaging(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := (&namedpipe.ListenConfig{
|
|
||||||
MessageMode: true, // Use message mode so that CloseWrite() is supported
|
|
||||||
InputBufferSize: 65536, // Use 64KB buffers to improve performance
|
|
||||||
OutputBufferSize: 65536,
|
|
||||||
}).Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
listenerDone := make(chan bool)
|
|
||||||
clientDone := make(chan bool)
|
|
||||||
go func() {
|
|
||||||
// server echo
|
|
||||||
conn, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
time.Sleep(500 * time.Millisecond) // make *sure* we don't begin to read before eof signal is sent
|
|
||||||
_, err = io.Copy(conn, conn)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
conn.(CloseWriter).CloseWrite()
|
|
||||||
close(listenerDone)
|
|
||||||
}()
|
|
||||||
client, err := namedpipe.DialTimeout(pipePath, time.Second)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
// client read back
|
|
||||||
bytes := make([]byte, 2)
|
|
||||||
n, e := client.Read(bytes)
|
|
||||||
if e != nil {
|
|
||||||
t.Fatal(e)
|
|
||||||
}
|
|
||||||
if n != 2 || bytes[0] != 0 || bytes[1] != 1 {
|
|
||||||
t.Fatalf("expected 2 bytes, got %v", n)
|
|
||||||
}
|
|
||||||
close(clientDone)
|
|
||||||
}()
|
|
||||||
|
|
||||||
payload := make([]byte, 2)
|
|
||||||
payload[0] = 0
|
|
||||||
payload[1] = 1
|
|
||||||
|
|
||||||
n, err := client.Write(payload)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if n != 2 {
|
|
||||||
t.Fatalf("expected 2 bytes, got %v", n)
|
|
||||||
}
|
|
||||||
client.(CloseWriter).CloseWrite()
|
|
||||||
<-listenerDone
|
|
||||||
<-clientDone
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectRace(t *testing.T) {
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
s, err := l.Accept()
|
|
||||||
if err == net.ErrClosed {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
s.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageReadMode(t *testing.T) {
|
|
||||||
if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 8 {
|
|
||||||
t.Skipf("Skipping on Windows %d", maj)
|
|
||||||
}
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
defer wg.Wait()
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
l, err := (&namedpipe.ListenConfig{MessageMode: true}).Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer l.Close()
|
|
||||||
|
|
||||||
msg := ([]byte)("hello world")
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
s, err := l.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
_, err = s.Write(msg)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
s.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer c.Close()
|
|
||||||
|
|
||||||
mode := uint32(windows.PIPE_READMODE_MESSAGE)
|
|
||||||
err = windows.SetNamedPipeHandleState(c.(interface{ Handle() windows.Handle }).Handle(), &mode, nil, nil)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ch := make([]byte, 1)
|
|
||||||
var vmsg []byte
|
|
||||||
for {
|
|
||||||
n, err := c.Read(ch)
|
|
||||||
if err == io.EOF {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
if n != 1 {
|
|
||||||
t.Fatalf("expected 1, got %d", n)
|
|
||||||
}
|
|
||||||
vmsg = append(vmsg, ch[0])
|
|
||||||
}
|
|
||||||
if !bytes.Equal(msg, vmsg) {
|
|
||||||
t.Fatalf("expected %s, got %s", msg, vmsg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestListenConnectRace(t *testing.T) {
|
|
||||||
if testing.Short() {
|
|
||||||
t.Skip("Skipping long race test")
|
|
||||||
}
|
|
||||||
pipePath := randomPipePath()
|
|
||||||
for i := 0; i < 50 && !t.Failed(); i++ {
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
c, err := namedpipe.DialTimeout(pipePath, time.Duration(0))
|
|
||||||
if err == nil {
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
wg.Done()
|
|
||||||
}()
|
|
||||||
s, err := namedpipe.Listen(pipePath)
|
|
||||||
if err != nil {
|
|
||||||
t.Error(i, err)
|
|
||||||
} else {
|
|
||||||
s.Close()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,66 +0,0 @@
|
||||||
//go:build linux || darwin || freebsd || openbsd
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ipc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
IpcErrorIO = -int64(unix.EIO)
|
|
||||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
|
||||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
|
||||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
|
||||||
IpcErrorUnknown = -55 // ENOANO
|
|
||||||
)
|
|
||||||
|
|
||||||
// socketDirectory is variable because it is modified by a linker
|
|
||||||
// flag in wireguard-android.
|
|
||||||
var socketDirectory = "/var/run/amneziawg"
|
|
||||||
|
|
||||||
func sockPath(iface string) string {
|
|
||||||
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UAPIOpen(name string) (*os.File, error) {
|
|
||||||
if err := os.MkdirAll(socketDirectory, 0o755); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
socketPath := sockPath(name)
|
|
||||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
oldUmask := unix.Umask(0o077)
|
|
||||||
defer unix.Umask(oldUmask)
|
|
||||||
|
|
||||||
listener, err := net.ListenUnix("unix", addr)
|
|
||||||
if err == nil {
|
|
||||||
return listener.File()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test socket, if not in use cleanup and try again.
|
|
||||||
if _, err := net.Dial("unix", socketPath); err == nil {
|
|
||||||
return nil, errors.New("unix socket in use")
|
|
||||||
}
|
|
||||||
if err := os.Remove(socketPath); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
listener, err = net.ListenUnix("unix", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return listener.File()
|
|
||||||
}
|
|
|
@ -1,15 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ipc
|
|
||||||
|
|
||||||
// Made up sentinel error codes for {js,wasip1}/wasm.
|
|
||||||
const (
|
|
||||||
IpcErrorIO = 1
|
|
||||||
IpcErrorInvalid = 2
|
|
||||||
IpcErrorPortInUse = 3
|
|
||||||
IpcErrorUnknown = 4
|
|
||||||
IpcErrorProtocol = 5
|
|
||||||
)
|
|
|
@ -1,88 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package ipc
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/ipc/namedpipe"
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TODO: replace these with actual standard windows error numbers from the win package
|
|
||||||
const (
|
|
||||||
IpcErrorIO = -int64(5)
|
|
||||||
IpcErrorProtocol = -int64(71)
|
|
||||||
IpcErrorInvalid = -int64(22)
|
|
||||||
IpcErrorPortInUse = -int64(98)
|
|
||||||
IpcErrorUnknown = -int64(55)
|
|
||||||
)
|
|
||||||
|
|
||||||
type UAPIListener struct {
|
|
||||||
listener net.Listener // unix socket listener
|
|
||||||
connNew chan net.Conn
|
|
||||||
connErr chan error
|
|
||||||
kqueueFd int
|
|
||||||
keventFd int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *UAPIListener) Accept() (net.Conn, error) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case conn := <-l.connNew:
|
|
||||||
return conn, nil
|
|
||||||
|
|
||||||
case err := <-l.connErr:
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *UAPIListener) Close() error {
|
|
||||||
return l.listener.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *UAPIListener) Addr() net.Addr {
|
|
||||||
return l.listener.Addr()
|
|
||||||
}
|
|
||||||
|
|
||||||
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
var err error
|
|
||||||
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)(A;;GA;;;BA)S:(ML;;NWNRNX;;;HI)")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
|
||||||
listener, err := (&namedpipe.ListenConfig{
|
|
||||||
SecurityDescriptor: UAPISecurityDescriptor,
|
|
||||||
}).Listen(`\\.\pipe\ProtectedPrefix\Administrators\AmneziaWG\` + name)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
uapi := &UAPIListener{
|
|
||||||
listener: listener,
|
|
||||||
connNew: make(chan net.Conn, 1),
|
|
||||||
connErr: make(chan error, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
go func(l *UAPIListener) {
|
|
||||||
for {
|
|
||||||
conn, err := l.listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
l.connErr <- err
|
|
||||||
break
|
|
||||||
}
|
|
||||||
l.connNew <- conn
|
|
||||||
}
|
|
||||||
}(uapi)
|
|
||||||
|
|
||||||
return uapi, nil
|
|
||||||
}
|
|
|
@ -1,15 +1,14 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type KDFTest struct {
|
type KDFTest struct {
|
||||||
|
@ -20,7 +19,7 @@ type KDFTest struct {
|
||||||
t2 string
|
t2 string
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertEquals(t *testing.T, a, b string) {
|
func assertEquals(t *testing.T, a string, b string) {
|
||||||
if a != b {
|
if a != b {
|
||||||
t.Fatal("expected", a, "=", b)
|
t.Fatal("expected", a, "=", b)
|
||||||
}
|
}
|
|
@ -1,17 +1,14 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/replay"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
/* Due to limitations in Go and /x/crypto there is currently
|
/* Due to limitations in Go and /x/crypto there is currently
|
||||||
|
@ -22,10 +19,10 @@ import (
|
||||||
*/
|
*/
|
||||||
|
|
||||||
type Keypair struct {
|
type Keypair struct {
|
||||||
sendNonce atomic.Uint64
|
sendNonce uint64
|
||||||
send cipher.AEAD
|
send cipher.AEAD
|
||||||
receive cipher.AEAD
|
receive cipher.AEAD
|
||||||
replayFilter replay.Filter
|
replayFilter ReplayFilter
|
||||||
isInitiator bool
|
isInitiator bool
|
||||||
created time.Time
|
created time.Time
|
||||||
localIndex uint32
|
localIndex uint32
|
||||||
|
@ -33,15 +30,15 @@ type Keypair struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Keypairs struct {
|
type Keypairs struct {
|
||||||
sync.RWMutex
|
mutex sync.RWMutex
|
||||||
current *Keypair
|
current *Keypair
|
||||||
previous *Keypair
|
previous *Keypair
|
||||||
next atomic.Pointer[Keypair]
|
next *Keypair
|
||||||
}
|
}
|
||||||
|
|
||||||
func (kp *Keypairs) Current() *Keypair {
|
func (kp *Keypairs) Current() *Keypair {
|
||||||
kp.RLock()
|
kp.mutex.RLock()
|
||||||
defer kp.RUnlock()
|
defer kp.mutex.RUnlock()
|
||||||
return kp.current
|
return kp.current
|
||||||
}
|
}
|
||||||
|
|
59
logger.go
Normal file
59
logger.go
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
LogLevelSilent = iota
|
||||||
|
LogLevelError
|
||||||
|
LogLevelInfo
|
||||||
|
LogLevelDebug
|
||||||
|
)
|
||||||
|
|
||||||
|
type Logger struct {
|
||||||
|
Debug *log.Logger
|
||||||
|
Info *log.Logger
|
||||||
|
Error *log.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewLogger(level int, prepend string) *Logger {
|
||||||
|
output := os.Stdout
|
||||||
|
logger := new(Logger)
|
||||||
|
|
||||||
|
logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
|
||||||
|
if level >= LogLevelDebug {
|
||||||
|
return output, output, output
|
||||||
|
}
|
||||||
|
if level >= LogLevelInfo {
|
||||||
|
return output, output, ioutil.Discard
|
||||||
|
}
|
||||||
|
if level >= LogLevelError {
|
||||||
|
return output, ioutil.Discard, ioutil.Discard
|
||||||
|
}
|
||||||
|
return ioutil.Discard, ioutil.Discard, ioutil.Discard
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Debug = log.New(logDebug,
|
||||||
|
"DEBUG: "+prepend,
|
||||||
|
log.Ldate|log.Ltime,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.Info = log.New(logInfo,
|
||||||
|
"INFO: "+prepend,
|
||||||
|
log.Ldate|log.Ltime,
|
||||||
|
)
|
||||||
|
logger.Error = log.New(logErr,
|
||||||
|
"ERROR: "+prepend,
|
||||||
|
log.Ldate|log.Ltime,
|
||||||
|
)
|
||||||
|
return logger
|
||||||
|
}
|
134
main.go
134
main.go
|
@ -1,8 +1,6 @@
|
||||||
//go:build !windows
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package main
|
package main
|
||||||
|
@ -13,12 +11,6 @@ import (
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -33,38 +25,56 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func printUsage() {
|
func printUsage() {
|
||||||
fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
|
fmt.Printf("usage:\n")
|
||||||
|
fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
func warning() {
|
func warning() {
|
||||||
switch runtime.GOOS {
|
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
||||||
case "linux", "freebsd", "openbsd":
|
|
||||||
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────────────┐")
|
shouldQuit := false
|
||||||
fmt.Fprintln(os.Stderr, "│ │")
|
|
||||||
fmt.Fprintln(os.Stderr, "│ Running amneziawg-go is not required because this │")
|
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
||||||
fmt.Fprintln(os.Stderr, "│ kernel has first class support for AmneziaWG. For │")
|
fmt.Fprintln(os.Stderr, "W G")
|
||||||
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G")
|
||||||
fmt.Fprintln(os.Stderr, "│ please visit: │")
|
fmt.Fprintln(os.Stderr, "W do what it is supposed to do, and things may go G")
|
||||||
fmt.Fprintln(os.Stderr, "| https://github.com/amnezia-vpn/amneziawg-linux-kernel-module │")
|
fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G")
|
||||||
fmt.Fprintln(os.Stderr, "│ │")
|
fmt.Fprintln(os.Stderr, "W at your own risk. G")
|
||||||
fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────────────┘")
|
if runtime.GOOS == "linux" {
|
||||||
|
shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
|
||||||
|
|
||||||
|
fmt.Fprintln(os.Stderr, "W G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W foolish. This is because the Linux kernel has G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W built-in first class support for WireGuard, and G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W this support is much more refined than this G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W program. For more information on installing the G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W kernel module, please visit: G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
|
||||||
|
if shouldQuit {
|
||||||
|
fmt.Fprintln(os.Stderr, "W G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W the sage advice here, please first export this G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W environment variable: G")
|
||||||
|
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintln(os.Stderr, "W G")
|
||||||
|
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
||||||
|
|
||||||
|
if shouldQuit {
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
if len(os.Args) == 2 && os.Args[1] == "--version" {
|
|
||||||
fmt.Printf("amneziawg-go %s\n\nUserspace AmneziaWG daemon for %s-%s.\nInformation available at https://amnezia.org\n", Version, runtime.GOOS, runtime.GOARCH)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
warning()
|
warning()
|
||||||
|
|
||||||
|
// parse arguments
|
||||||
|
|
||||||
var foreground bool
|
var foreground bool
|
||||||
var interfaceName string
|
var interfaceName string
|
||||||
if len(os.Args) < 2 || len(os.Args) > 3 {
|
if len(os.Args) < 2 || len(os.Args) > 3 {
|
||||||
|
@ -99,22 +109,24 @@ func main() {
|
||||||
|
|
||||||
logLevel := func() int {
|
logLevel := func() int {
|
||||||
switch os.Getenv("LOG_LEVEL") {
|
switch os.Getenv("LOG_LEVEL") {
|
||||||
case "verbose", "debug":
|
case "debug":
|
||||||
return device.LogLevelVerbose
|
return LogLevelDebug
|
||||||
|
case "info":
|
||||||
|
return LogLevelInfo
|
||||||
case "error":
|
case "error":
|
||||||
return device.LogLevelError
|
return LogLevelError
|
||||||
case "silent":
|
case "silent":
|
||||||
return device.LogLevelSilent
|
return LogLevelSilent
|
||||||
}
|
}
|
||||||
return device.LogLevelError
|
return LogLevelInfo
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// open TUN device (or use supplied fd)
|
// open TUN device (or use supplied fd)
|
||||||
|
|
||||||
tdev, err := func() (tun.Device, error) {
|
tun, err := func() (TUNDevice, error) {
|
||||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
if tunFdStr == "" {
|
if tunFdStr == "" {
|
||||||
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
return CreateTUN(interfaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// construct tun device from supplied fd
|
// construct tun device from supplied fd
|
||||||
|
@ -124,31 +136,26 @@ func main() {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = unix.SetNonblock(int(fd), true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
file := os.NewFile(uintptr(fd), "")
|
file := os.NewFile(uintptr(fd), "")
|
||||||
return tun.CreateTUNFromFile(file, device.DefaultMTU)
|
return CreateTUNFromFile(file)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
realInterfaceName, err2 := tdev.Name()
|
realInterfaceName, err2 := tun.Name()
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
interfaceName = realInterfaceName
|
interfaceName = realInterfaceName
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger := device.NewLogger(
|
logger := NewLogger(
|
||||||
logLevel,
|
logLevel,
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
fmt.Sprintf("(%s) ", interfaceName),
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
logger.Debug.Println("Debug log enabled")
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to create TUN device: %v", err)
|
logger.Error.Println("Failed to create TUN device:", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -157,7 +164,7 @@ func main() {
|
||||||
fileUAPI, err := func() (*os.File, error) {
|
fileUAPI, err := func() (*os.File, error) {
|
||||||
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
|
uapiFdStr := os.Getenv(ENV_WG_UAPI_FD)
|
||||||
if uapiFdStr == "" {
|
if uapiFdStr == "" {
|
||||||
return ipc.UAPIOpen(interfaceName)
|
return UAPIOpen(interfaceName)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use supplied fd
|
// use supplied fd
|
||||||
|
@ -169,8 +176,9 @@ func main() {
|
||||||
|
|
||||||
return os.NewFile(uintptr(fd), ""), nil
|
return os.NewFile(uintptr(fd), ""), nil
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("UAPI listen error: %v", err)
|
logger.Error.Println("UAPI listen error:", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -182,7 +190,7 @@ func main() {
|
||||||
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
|
env = append(env, fmt.Sprintf("%s=4", ENV_WG_UAPI_FD))
|
||||||
env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND))
|
env = append(env, fmt.Sprintf("%s=1", ENV_WG_PROCESS_FOREGROUND))
|
||||||
files := [3]*os.File{}
|
files := [3]*os.File{}
|
||||||
if os.Getenv("LOG_LEVEL") != "" && logLevel != device.LogLevelSilent {
|
if os.Getenv("LOG_LEVEL") != "" && logLevel != LogLevelSilent {
|
||||||
files[0], _ = os.Open(os.DevNull)
|
files[0], _ = os.Open(os.DevNull)
|
||||||
files[1] = os.Stdout
|
files[1] = os.Stdout
|
||||||
files[2] = os.Stderr
|
files[2] = os.Stderr
|
||||||
|
@ -196,7 +204,7 @@ func main() {
|
||||||
files[0], // stdin
|
files[0], // stdin
|
||||||
files[1], // stdout
|
files[1], // stdout
|
||||||
files[2], // stderr
|
files[2], // stderr
|
||||||
tdev.File(),
|
tun.File(),
|
||||||
fileUAPI,
|
fileUAPI,
|
||||||
},
|
},
|
||||||
Dir: ".",
|
Dir: ".",
|
||||||
|
@ -205,7 +213,7 @@ func main() {
|
||||||
|
|
||||||
path, err := os.Executable()
|
path, err := os.Executable()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to determine executable: %v", err)
|
logger.Error.Println("Failed to determine executable:", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -215,23 +223,23 @@ func main() {
|
||||||
attr,
|
attr,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to daemonize: %v", err)
|
logger.Error.Println("Failed to daemonize:", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
process.Release()
|
process.Release()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
device := device.NewDevice(tdev, conn.NewDefaultBind(), logger)
|
device := NewDevice(tun, logger)
|
||||||
|
|
||||||
logger.Verbosef("Device started")
|
logger.Info.Println("Device started")
|
||||||
|
|
||||||
errs := make(chan error)
|
errs := make(chan error)
|
||||||
term := make(chan os.Signal, 1)
|
term := make(chan os.Signal)
|
||||||
|
|
||||||
uapi, err := ipc.UAPIListen(interfaceName, fileUAPI)
|
uapi, err := UAPIListen(interfaceName, fileUAPI)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Errorf("Failed to listen on uapi socket: %v", err)
|
logger.Error.Println("Failed to listen on uapi socket:", err)
|
||||||
os.Exit(ExitSetupFailed)
|
os.Exit(ExitSetupFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -242,15 +250,15 @@ func main() {
|
||||||
errs <- err
|
errs <- err
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go device.IpcHandle(conn)
|
go ipcHandle(device, conn)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
logger.Verbosef("UAPI listener started")
|
logger.Info.Println("UAPI listener started")
|
||||||
|
|
||||||
// wait for program to terminate
|
// wait for program to terminate
|
||||||
|
|
||||||
signal.Notify(term, unix.SIGTERM)
|
signal.Notify(term, os.Kill)
|
||||||
signal.Notify(term, os.Interrupt)
|
signal.Notify(term, os.Interrupt)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
@ -264,5 +272,5 @@ func main() {
|
||||||
uapi.Close()
|
uapi.Close()
|
||||||
device.Close()
|
device.Close()
|
||||||
|
|
||||||
logger.Verbosef("Shutting down")
|
logger.Info.Println("Shutting down")
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,99 +0,0 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/device"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/ipc"
|
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
ExitSetupSuccess = 0
|
|
||||||
ExitSetupFailed = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
if len(os.Args) != 2 {
|
|
||||||
os.Exit(ExitSetupFailed)
|
|
||||||
}
|
|
||||||
interfaceName := os.Args[1]
|
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real AmneziaWG for Windows client, please visit: https://amnezia.org")
|
|
||||||
|
|
||||||
logger := device.NewLogger(
|
|
||||||
device.LogLevelVerbose,
|
|
||||||
fmt.Sprintf("(%s) ", interfaceName),
|
|
||||||
)
|
|
||||||
logger.Verbosef("Starting amneziawg-go version %s", Version)
|
|
||||||
|
|
||||||
tun, err := tun.CreateTUN(interfaceName, 0)
|
|
||||||
if err == nil {
|
|
||||||
realInterfaceName, err2 := tun.Name()
|
|
||||||
if err2 == nil {
|
|
||||||
interfaceName = realInterfaceName
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
logger.Errorf("Failed to create TUN device: %v", err)
|
|
||||||
os.Exit(ExitSetupFailed)
|
|
||||||
}
|
|
||||||
|
|
||||||
device := device.NewDevice(tun, conn.NewDefaultBind(), logger)
|
|
||||||
err = device.Up()
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to bring up device: %v", err)
|
|
||||||
os.Exit(ExitSetupFailed)
|
|
||||||
}
|
|
||||||
logger.Verbosef("Device started")
|
|
||||||
|
|
||||||
uapi, err := ipc.UAPIListen(interfaceName)
|
|
||||||
if err != nil {
|
|
||||||
logger.Errorf("Failed to listen on uapi socket: %v", err)
|
|
||||||
os.Exit(ExitSetupFailed)
|
|
||||||
}
|
|
||||||
|
|
||||||
errs := make(chan error)
|
|
||||||
term := make(chan os.Signal, 1)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
conn, err := uapi.Accept()
|
|
||||||
if err != nil {
|
|
||||||
errs <- err
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go device.IpcHandle(conn)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
logger.Verbosef("UAPI listener started")
|
|
||||||
|
|
||||||
// wait for program to terminate
|
|
||||||
|
|
||||||
signal.Notify(term, os.Interrupt)
|
|
||||||
signal.Notify(term, os.Kill)
|
|
||||||
signal.Notify(term, windows.SIGTERM)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-term:
|
|
||||||
case <-errs:
|
|
||||||
case <-device.Wait():
|
|
||||||
}
|
|
||||||
|
|
||||||
// clean up
|
|
||||||
|
|
||||||
uapi.Close()
|
|
||||||
device.Close()
|
|
||||||
|
|
||||||
logger.Verbosef("Shutting down")
|
|
||||||
}
|
|
62
misc.go
Normal file
62
misc.go
Normal file
|
@ -0,0 +1,62 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
/* Atomic Boolean */
|
||||||
|
|
||||||
|
const (
|
||||||
|
AtomicFalse = int32(iota)
|
||||||
|
AtomicTrue
|
||||||
|
)
|
||||||
|
|
||||||
|
type AtomicBool struct {
|
||||||
|
flag int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AtomicBool) Get() bool {
|
||||||
|
return atomic.LoadInt32(&a.flag) == AtomicTrue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AtomicBool) Swap(val bool) bool {
|
||||||
|
flag := AtomicFalse
|
||||||
|
if val {
|
||||||
|
flag = AtomicTrue
|
||||||
|
}
|
||||||
|
return atomic.SwapInt32(&a.flag, flag) == AtomicTrue
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *AtomicBool) Set(val bool) {
|
||||||
|
flag := AtomicFalse
|
||||||
|
if val {
|
||||||
|
flag = AtomicTrue
|
||||||
|
}
|
||||||
|
atomic.StoreInt32(&a.flag, flag)
|
||||||
|
}
|
||||||
|
|
||||||
|
/* Integer manipulation */
|
||||||
|
|
||||||
|
func toInt32(n uint32) int32 {
|
||||||
|
mask := uint32(1 << 31)
|
||||||
|
return int32(-(n & mask) + (n & ^mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
func min(a, b uint) uint {
|
||||||
|
if a > b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
func minUint64(a uint64, b uint64) uint64 {
|
||||||
|
if a > b {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
|
@ -1,19 +1,17 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/hmac"
|
"crypto/hmac"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"errors"
|
|
||||||
"hash"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/curve25519"
|
"golang.org/x/crypto/curve25519"
|
||||||
|
"hash"
|
||||||
)
|
)
|
||||||
|
|
||||||
/* KDF related functions.
|
/* KDF related functions.
|
||||||
|
@ -43,6 +41,7 @@ func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
|
||||||
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
|
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
|
||||||
HMAC1(t0, key, input)
|
HMAC1(t0, key, input)
|
||||||
HMAC1(t0, t0[:], []byte{0x1})
|
HMAC1(t0, t0[:], []byte{0x1})
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
|
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
|
||||||
|
@ -51,6 +50,7 @@ func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
|
||||||
HMAC1(t0, prk[:], []byte{0x1})
|
HMAC1(t0, prk[:], []byte{0x1})
|
||||||
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
||||||
setZero(prk[:])
|
setZero(prk[:])
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
|
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
|
||||||
|
@ -60,6 +60,7 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
|
||||||
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
||||||
HMAC2(t2, prk[:], t1[:], []byte{0x3})
|
HMAC2(t2, prk[:], t1[:], []byte{0x3})
|
||||||
setZero(prk[:])
|
setZero(prk[:])
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func isZero(val []byte) bool {
|
func isZero(val []byte) bool {
|
||||||
|
@ -77,14 +78,12 @@ func setZero(arr []byte) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sk *NoisePrivateKey) clamp() {
|
|
||||||
sk[0] &= 248
|
|
||||||
sk[31] = (sk[31] & 127) | 64
|
|
||||||
}
|
|
||||||
|
|
||||||
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
func newPrivateKey() (sk NoisePrivateKey, err error) {
|
||||||
|
// clamping: https://cr.yp.to/ecdh.html
|
||||||
_, err = rand.Read(sk[:])
|
_, err = rand.Read(sk[:])
|
||||||
sk.clamp()
|
sk[0] &= 248
|
||||||
|
sk[31] &= 127
|
||||||
|
sk[31] |= 64
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -95,14 +94,9 @@ func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var errInvalidPublicKey = errors.New("invalid public key")
|
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
|
||||||
|
|
||||||
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte, err error) {
|
|
||||||
apk := (*[NoisePublicKeySize]byte)(&pk)
|
apk := (*[NoisePublicKeySize]byte)(&pk)
|
||||||
ask := (*[NoisePrivateKeySize]byte)(sk)
|
ask := (*[NoisePrivateKeySize]byte)(sk)
|
||||||
curve25519.ScalarMult(&ss, ask, apk)
|
curve25519.ScalarMult(&ss, ask, apk)
|
||||||
if isZero(ss[:]) {
|
return ss
|
||||||
return ss, errInvalidPublicKey
|
|
||||||
}
|
|
||||||
return ss, nil
|
|
||||||
}
|
}
|
|
@ -1,50 +1,28 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"./tai64n"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/blake2s"
|
"golang.org/x/crypto/blake2s"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
"golang.org/x/crypto/poly1305"
|
"golang.org/x/crypto/poly1305"
|
||||||
|
"sync"
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tai64n"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type handshakeState int
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
handshakeZeroed = handshakeState(iota)
|
HandshakeZeroed = iota
|
||||||
handshakeInitiationCreated
|
HandshakeInitiationCreated
|
||||||
handshakeInitiationConsumed
|
HandshakeInitiationConsumed
|
||||||
handshakeResponseCreated
|
HandshakeResponseCreated
|
||||||
handshakeResponseConsumed
|
HandshakeResponseConsumed
|
||||||
)
|
)
|
||||||
|
|
||||||
func (hs handshakeState) String() string {
|
|
||||||
switch hs {
|
|
||||||
case handshakeZeroed:
|
|
||||||
return "handshakeZeroed"
|
|
||||||
case handshakeInitiationCreated:
|
|
||||||
return "handshakeInitiationCreated"
|
|
||||||
case handshakeInitiationConsumed:
|
|
||||||
return "handshakeInitiationConsumed"
|
|
||||||
case handshakeResponseCreated:
|
|
||||||
return "handshakeResponseCreated"
|
|
||||||
case handshakeResponseConsumed:
|
|
||||||
return "handshakeResponseConsumed"
|
|
||||||
default:
|
|
||||||
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
||||||
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
||||||
|
@ -52,21 +30,21 @@ const (
|
||||||
WGLabelCookie = "cookie--"
|
WGLabelCookie = "cookie--"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
MessageInitiationType uint32 = 1
|
MessageInitiationType = 1
|
||||||
MessageResponseType uint32 = 2
|
MessageResponseType = 2
|
||||||
MessageCookieReplyType uint32 = 3
|
MessageCookieReplyType = 3
|
||||||
MessageTransportType uint32 = 4
|
MessageTransportType = 4
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MessageInitiationSize = 148 // size of handshake initiation message
|
MessageInitiationSize = 148 // size of handshake initation message
|
||||||
MessageResponseSize = 92 // size of response message
|
MessageResponseSize = 92 // size of response message
|
||||||
MessageCookieReplySize = 64 // size of cookie reply message
|
MessageCookieReplySize = 64 // size of cookie reply message
|
||||||
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
|
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
|
||||||
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
||||||
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
||||||
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
|
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -75,10 +53,6 @@ const (
|
||||||
MessageTransportOffsetContent = 16
|
MessageTransportOffsetContent = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
var packetSizeToMsgType map[int]uint32
|
|
||||||
|
|
||||||
var msgTypeToJunkSize map[uint32]int
|
|
||||||
|
|
||||||
/* Type is an 8-bit field, followed by 3 nul bytes,
|
/* Type is an 8-bit field, followed by 3 nul bytes,
|
||||||
* by marshalling the messages in little-endian byteorder
|
* by marshalling the messages in little-endian byteorder
|
||||||
* we can treat these as a 32-bit unsigned int (for now)
|
* we can treat these as a 32-bit unsigned int (for now)
|
||||||
|
@ -115,16 +89,16 @@ type MessageTransport struct {
|
||||||
type MessageCookieReply struct {
|
type MessageCookieReply struct {
|
||||||
Type uint32
|
Type uint32
|
||||||
Receiver uint32
|
Receiver uint32
|
||||||
Nonce [chacha20poly1305.NonceSizeX]byte
|
Nonce [24]byte
|
||||||
Cookie [blake2s.Size128 + poly1305.TagSize]byte
|
Cookie [blake2s.Size128 + poly1305.TagSize]byte
|
||||||
}
|
}
|
||||||
|
|
||||||
type Handshake struct {
|
type Handshake struct {
|
||||||
state handshakeState
|
state int
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
hash [blake2s.Size]byte // hash value
|
hash [blake2s.Size]byte // hash value
|
||||||
chainKey [blake2s.Size]byte // chain key
|
chainKey [blake2s.Size]byte // chain key
|
||||||
presharedKey NoisePresharedKey // psk
|
presharedKey NoiseSymmetricKey // psk
|
||||||
localEphemeral NoisePrivateKey // ephemeral secret key
|
localEphemeral NoisePrivateKey // ephemeral secret key
|
||||||
localIndex uint32 // used to clear hash-table
|
localIndex uint32 // used to clear hash-table
|
||||||
remoteIndex uint32 // index for sending
|
remoteIndex uint32 // index for sending
|
||||||
|
@ -142,11 +116,11 @@ var (
|
||||||
ZeroNonce [chacha20poly1305.NonceSize]byte
|
ZeroNonce [chacha20poly1305.NonceSize]byte
|
||||||
)
|
)
|
||||||
|
|
||||||
func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
|
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
|
||||||
KDF1(dst, c[:], data)
|
KDF1(dst, c[:], data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
|
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
|
||||||
hash, _ := blake2s.New256(nil)
|
hash, _ := blake2s.New256(nil)
|
||||||
hash.Write(h[:])
|
hash.Write(h[:])
|
||||||
hash.Write(data)
|
hash.Write(data)
|
||||||
|
@ -160,7 +134,7 @@ func (h *Handshake) Clear() {
|
||||||
setZero(h.chainKey[:])
|
setZero(h.chainKey[:])
|
||||||
setZero(h.hash[:])
|
setZero(h.hash[:])
|
||||||
h.localIndex = 0
|
h.localIndex = 0
|
||||||
h.state = handshakeZeroed
|
h.state = HandshakeZeroed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handshake) mixHash(data []byte) {
|
func (h *Handshake) mixHash(data []byte) {
|
||||||
|
@ -179,14 +153,20 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||||
device.staticIdentity.RLock()
|
|
||||||
defer device.staticIdentity.RUnlock()
|
device.staticIdentity.mutex.RLock()
|
||||||
|
defer device.staticIdentity.mutex.RUnlock()
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
return nil, errors.New("static shared secret is zero")
|
||||||
|
}
|
||||||
|
|
||||||
// create ephemeral key
|
// create ephemeral key
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
handshake.hash = InitialHash
|
handshake.hash = InitialHash
|
||||||
handshake.chainKey = InitialChainKey
|
handshake.chainKey = InitialChainKey
|
||||||
|
@ -195,58 +175,59 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// assign index
|
||||||
|
|
||||||
|
device.indexTable.Delete(handshake.localIndex)
|
||||||
|
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
handshake.mixHash(handshake.remoteStatic[:])
|
handshake.mixHash(handshake.remoteStatic[:])
|
||||||
|
|
||||||
device.aSecMux.RLock()
|
|
||||||
msg := MessageInitiation{
|
msg := MessageInitiation{
|
||||||
Type: MessageInitiationType,
|
Type: MessageInitiationType,
|
||||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||||
|
Sender: handshake.localIndex,
|
||||||
}
|
}
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
|
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
|
|
||||||
// encrypt static key
|
// encrypt static key
|
||||||
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
|
||||||
if err != nil {
|
func() {
|
||||||
return nil, err
|
var key [chacha20poly1305.KeySize]byte
|
||||||
}
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
var key [chacha20poly1305.KeySize]byte
|
KDF2(
|
||||||
KDF2(
|
&handshake.chainKey,
|
||||||
&handshake.chainKey,
|
&key,
|
||||||
&key,
|
handshake.chainKey[:],
|
||||||
handshake.chainKey[:],
|
ss[:],
|
||||||
ss[:],
|
)
|
||||||
)
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
||||||
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
}()
|
||||||
handshake.mixHash(msg.Static[:])
|
handshake.mixHash(msg.Static[:])
|
||||||
|
|
||||||
// encrypt timestamp
|
// encrypt timestamp
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
return nil, errInvalidPublicKey
|
|
||||||
}
|
|
||||||
KDF2(
|
|
||||||
&handshake.chainKey,
|
|
||||||
&key,
|
|
||||||
handshake.chainKey[:],
|
|
||||||
handshake.precomputedStaticStatic[:],
|
|
||||||
)
|
|
||||||
timestamp := tai64n.Now()
|
|
||||||
aead, _ = chacha20poly1305.New(key[:])
|
|
||||||
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
|
||||||
|
|
||||||
// assign index
|
timestamp := tai64n.Now()
|
||||||
device.indexTable.Delete(handshake.localIndex)
|
func() {
|
||||||
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
var key [chacha20poly1305.KeySize]byte
|
||||||
if err != nil {
|
KDF2(
|
||||||
return nil, err
|
&handshake.chainKey,
|
||||||
}
|
&key,
|
||||||
handshake.localIndex = msg.Sender
|
handshake.chainKey[:],
|
||||||
|
handshake.precomputedStaticStatic[:],
|
||||||
|
)
|
||||||
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
|
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
||||||
|
}()
|
||||||
|
|
||||||
handshake.mixHash(msg.Timestamp[:])
|
handshake.mixHash(msg.Timestamp[:])
|
||||||
handshake.state = handshakeInitiationCreated
|
handshake.state = HandshakeInitiationCreated
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,30 +237,28 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
chainKey [blake2s.Size]byte
|
chainKey [blake2s.Size]byte
|
||||||
)
|
)
|
||||||
|
|
||||||
device.aSecMux.RLock()
|
|
||||||
if msg.Type != MessageInitiationType {
|
if msg.Type != MessageInitiationType {
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.mutex.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.mutex.RUnlock()
|
||||||
|
|
||||||
mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
|
mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
|
||||||
mixHash(&hash, &hash, msg.Ephemeral[:])
|
mixHash(&hash, &hash, msg.Ephemeral[:])
|
||||||
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
// decrypt static key
|
// decrypt static key
|
||||||
|
|
||||||
|
var err error
|
||||||
var peerPK NoisePublicKey
|
var peerPK NoisePublicKey
|
||||||
var key [chacha20poly1305.KeySize]byte
|
func() {
|
||||||
ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
var key [chacha20poly1305.KeySize]byte
|
||||||
if err != nil {
|
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
return nil
|
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
||||||
}
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
}()
|
||||||
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -288,29 +267,28 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
// lookup peer
|
// lookup peer
|
||||||
|
|
||||||
peer := device.LookupPeer(peerPK)
|
peer := device.LookupPeer(peerPK)
|
||||||
if peer == nil || !peer.isRunning.Load() {
|
if peer == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// verify identity
|
// verify identity
|
||||||
|
|
||||||
var timestamp tai64n.Timestamp
|
var timestamp tai64n.Timestamp
|
||||||
|
var key [chacha20poly1305.KeySize]byte
|
||||||
|
|
||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
|
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
|
||||||
handshake.mutex.RUnlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
KDF2(
|
KDF2(
|
||||||
&chainKey,
|
&chainKey,
|
||||||
&key,
|
&key,
|
||||||
chainKey[:],
|
chainKey[:],
|
||||||
handshake.precomputedStaticStatic[:],
|
handshake.precomputedStaticStatic[:],
|
||||||
)
|
)
|
||||||
aead, _ = chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
handshake.mutex.RUnlock()
|
handshake.mutex.RUnlock()
|
||||||
|
@ -320,15 +298,11 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
|
|
||||||
// protect against replay & flood
|
// protect against replay & flood
|
||||||
|
|
||||||
replay := !timestamp.After(handshake.lastTimestamp)
|
var ok bool
|
||||||
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
|
ok = timestamp.After(handshake.lastTimestamp)
|
||||||
|
ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate
|
||||||
handshake.mutex.RUnlock()
|
handshake.mutex.RUnlock()
|
||||||
if replay {
|
if !ok {
|
||||||
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if flood {
|
|
||||||
device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -340,14 +314,9 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
if timestamp.After(handshake.lastTimestamp) {
|
handshake.lastTimestamp = timestamp
|
||||||
handshake.lastTimestamp = timestamp
|
handshake.lastInitiationConsumption = time.Now()
|
||||||
}
|
handshake.state = HandshakeInitiationConsumed
|
||||||
now := time.Now()
|
|
||||||
if now.After(handshake.lastInitiationConsumption) {
|
|
||||||
handshake.lastInitiationConsumption = now
|
|
||||||
}
|
|
||||||
handshake.state = handshakeInitiationConsumed
|
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -362,7 +331,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
defer handshake.mutex.Unlock()
|
defer handshake.mutex.Unlock()
|
||||||
|
|
||||||
if handshake.state != handshakeInitiationConsumed {
|
if handshake.state != HandshakeInitiationConsumed {
|
||||||
return nil, errors.New("handshake initiation must be consumed first")
|
return nil, errors.New("handshake initiation must be consumed first")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -376,9 +345,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
}
|
}
|
||||||
|
|
||||||
var msg MessageResponse
|
var msg MessageResponse
|
||||||
device.aSecMux.RLock()
|
|
||||||
msg.Type = MessageResponseType
|
msg.Type = MessageResponseType
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
msg.Sender = handshake.localIndex
|
msg.Sender = handshake.localIndex
|
||||||
msg.Receiver = handshake.remoteIndex
|
msg.Receiver = handshake.remoteIndex
|
||||||
|
|
||||||
|
@ -392,16 +359,12 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
handshake.mixHash(msg.Ephemeral[:])
|
handshake.mixHash(msg.Ephemeral[:])
|
||||||
handshake.mixKey(msg.Ephemeral[:])
|
handshake.mixKey(msg.Ephemeral[:])
|
||||||
|
|
||||||
ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
func() {
|
||||||
if err != nil {
|
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
|
||||||
return nil, err
|
handshake.mixKey(ss[:])
|
||||||
}
|
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||||
handshake.mixKey(ss[:])
|
handshake.mixKey(ss[:])
|
||||||
ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
}()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
handshake.mixKey(ss[:])
|
|
||||||
|
|
||||||
// add preshared key
|
// add preshared key
|
||||||
|
|
||||||
|
@ -418,22 +381,21 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||||
|
|
||||||
handshake.mixHash(tau[:])
|
handshake.mixHash(tau[:])
|
||||||
|
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
func() {
|
||||||
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
handshake.mixHash(msg.Empty[:])
|
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
|
||||||
|
handshake.mixHash(msg.Empty[:])
|
||||||
|
}()
|
||||||
|
|
||||||
handshake.state = handshakeResponseCreated
|
handshake.state = HandshakeResponseCreated
|
||||||
|
|
||||||
return &msg, nil
|
return &msg, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
device.aSecMux.RLock()
|
|
||||||
if msg.Type != MessageResponseType {
|
if msg.Type != MessageResponseType {
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
device.aSecMux.RUnlock()
|
|
||||||
|
|
||||||
// lookup handshake by receiver
|
// lookup handshake by receiver
|
||||||
|
|
||||||
|
@ -449,38 +411,37 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
)
|
)
|
||||||
|
|
||||||
ok := func() bool {
|
ok := func() bool {
|
||||||
|
|
||||||
// lock handshake state
|
// lock handshake state
|
||||||
|
|
||||||
handshake.mutex.RLock()
|
handshake.mutex.RLock()
|
||||||
defer handshake.mutex.RUnlock()
|
defer handshake.mutex.RUnlock()
|
||||||
|
|
||||||
if handshake.state != handshakeInitiationCreated {
|
if handshake.state != HandshakeInitiationCreated {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// lock private key for reading
|
// lock private key for reading
|
||||||
|
|
||||||
device.staticIdentity.RLock()
|
device.staticIdentity.mutex.RLock()
|
||||||
defer device.staticIdentity.RUnlock()
|
defer device.staticIdentity.mutex.RUnlock()
|
||||||
|
|
||||||
// finish 3-way DH
|
// finish 3-way DH
|
||||||
|
|
||||||
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
|
||||||
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
|
||||||
|
|
||||||
ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
func() {
|
||||||
if err != nil {
|
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
|
||||||
return false
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
}
|
setZero(ss[:])
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
}()
|
||||||
setZero(ss[:])
|
|
||||||
|
|
||||||
ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
func() {
|
||||||
if err != nil {
|
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||||
return false
|
mixKey(&chainKey, &chainKey, ss[:])
|
||||||
}
|
setZero(ss[:])
|
||||||
mixKey(&chainKey, &chainKey, ss[:])
|
}()
|
||||||
setZero(ss[:])
|
|
||||||
|
|
||||||
// add preshared key (psk)
|
// add preshared key (psk)
|
||||||
|
|
||||||
|
@ -498,7 +459,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
// authenticate transcript
|
// authenticate transcript
|
||||||
|
|
||||||
aead, _ := chacha20poly1305.New(key[:])
|
aead, _ := chacha20poly1305.New(key[:])
|
||||||
_, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -517,7 +478,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||||
handshake.hash = hash
|
handshake.hash = hash
|
||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.state = handshakeResponseConsumed
|
handshake.state = HandshakeResponseConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
@ -542,7 +503,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
var sendKey [chacha20poly1305.KeySize]byte
|
var sendKey [chacha20poly1305.KeySize]byte
|
||||||
var recvKey [chacha20poly1305.KeySize]byte
|
var recvKey [chacha20poly1305.KeySize]byte
|
||||||
|
|
||||||
if handshake.state == handshakeResponseConsumed {
|
if handshake.state == HandshakeResponseConsumed {
|
||||||
KDF2(
|
KDF2(
|
||||||
&sendKey,
|
&sendKey,
|
||||||
&recvKey,
|
&recvKey,
|
||||||
|
@ -550,7 +511,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
nil,
|
nil,
|
||||||
)
|
)
|
||||||
isInitiator = true
|
isInitiator = true
|
||||||
} else if handshake.state == handshakeResponseCreated {
|
} else if handshake.state == HandshakeResponseCreated {
|
||||||
KDF2(
|
KDF2(
|
||||||
&recvKey,
|
&recvKey,
|
||||||
&sendKey,
|
&sendKey,
|
||||||
|
@ -559,7 +520,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
)
|
)
|
||||||
isInitiator = false
|
isInitiator = false
|
||||||
} else {
|
} else {
|
||||||
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
|
return errors.New("invalid state for keypair derivation")
|
||||||
}
|
}
|
||||||
|
|
||||||
// zero handshake
|
// zero handshake
|
||||||
|
@ -567,7 +528,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
setZero(handshake.chainKey[:])
|
setZero(handshake.chainKey[:])
|
||||||
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
||||||
setZero(handshake.localEphemeral[:])
|
setZero(handshake.localEphemeral[:])
|
||||||
peer.handshake.state = handshakeZeroed
|
peer.handshake.state = HandshakeZeroed
|
||||||
|
|
||||||
// create AEAD instances
|
// create AEAD instances
|
||||||
|
|
||||||
|
@ -579,7 +540,8 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
setZero(recvKey[:])
|
setZero(recvKey[:])
|
||||||
|
|
||||||
keypair.created = time.Now()
|
keypair.created = time.Now()
|
||||||
keypair.replayFilter.Reset()
|
keypair.sendNonce = 0
|
||||||
|
keypair.replayFilter.Init()
|
||||||
keypair.isInitiator = isInitiator
|
keypair.isInitiator = isInitiator
|
||||||
keypair.localIndex = peer.handshake.localIndex
|
keypair.localIndex = peer.handshake.localIndex
|
||||||
keypair.remoteIndex = peer.handshake.remoteIndex
|
keypair.remoteIndex = peer.handshake.remoteIndex
|
||||||
|
@ -592,16 +554,16 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
// rotate key pairs
|
// rotate key pairs
|
||||||
|
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
keypairs.Lock()
|
keypairs.mutex.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.mutex.Unlock()
|
||||||
|
|
||||||
previous := keypairs.previous
|
previous := keypairs.previous
|
||||||
next := keypairs.next.Load()
|
next := keypairs.next
|
||||||
current := keypairs.current
|
current := keypairs.current
|
||||||
|
|
||||||
if isInitiator {
|
if isInitiator {
|
||||||
if next != nil {
|
if next != nil {
|
||||||
keypairs.next.Store(nil)
|
keypairs.next = nil
|
||||||
keypairs.previous = next
|
keypairs.previous = next
|
||||||
device.DeleteKeypair(current)
|
device.DeleteKeypair(current)
|
||||||
} else {
|
} else {
|
||||||
|
@ -610,7 +572,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
keypairs.current = keypair
|
keypairs.current = keypair
|
||||||
} else {
|
} else {
|
||||||
keypairs.next.Store(keypair)
|
keypairs.next = keypair
|
||||||
device.DeleteKeypair(next)
|
device.DeleteKeypair(next)
|
||||||
keypairs.previous = nil
|
keypairs.previous = nil
|
||||||
device.DeleteKeypair(previous)
|
device.DeleteKeypair(previous)
|
||||||
|
@ -621,19 +583,18 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||||
|
|
||||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||||
keypairs := &peer.keypairs
|
keypairs := &peer.keypairs
|
||||||
|
if keypairs.next != receivedKeypair {
|
||||||
if keypairs.next.Load() != receivedKeypair {
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
keypairs.Lock()
|
keypairs.mutex.Lock()
|
||||||
defer keypairs.Unlock()
|
defer keypairs.mutex.Unlock()
|
||||||
if keypairs.next.Load() != receivedKeypair {
|
if keypairs.next != receivedKeypair {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
old := keypairs.previous
|
old := keypairs.previous
|
||||||
keypairs.previous = keypairs.current
|
keypairs.previous = keypairs.current
|
||||||
peer.device.DeleteKeypair(old)
|
peer.device.DeleteKeypair(old)
|
||||||
keypairs.current = keypairs.next.Load()
|
keypairs.current = keypairs.next
|
||||||
keypairs.next.Store(nil)
|
keypairs.next = nil
|
||||||
return true
|
return true
|
||||||
}
|
}
|
|
@ -1,26 +1,26 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
NoisePublicKeySize = 32
|
NoisePublicKeySize = 32
|
||||||
NoisePrivateKeySize = 32
|
NoisePrivateKeySize = 32
|
||||||
NoisePresharedKeySize = 32
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
type (
|
||||||
NoisePublicKey [NoisePublicKeySize]byte
|
NoisePublicKey [NoisePublicKeySize]byte
|
||||||
NoisePrivateKey [NoisePrivateKeySize]byte
|
NoisePrivateKey [NoisePrivateKeySize]byte
|
||||||
NoisePresharedKey [NoisePresharedKeySize]byte
|
NoiseSymmetricKey [chacha20poly1305.KeySize]byte
|
||||||
NoiseNonce uint64 // padded to 12-bytes
|
NoiseNonce uint64 // padded to 12-bytes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,25 +45,22 @@ func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
|
||||||
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePrivateKey) FromHex(src string) (err error) {
|
func (key *NoisePrivateKey) FromHex(src string) error {
|
||||||
err = loadExactHex(key[:], src)
|
return loadExactHex(key[:], src)
|
||||||
key.clamp()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
|
func (key NoisePrivateKey) ToHex() string {
|
||||||
err = loadExactHex(key[:], src)
|
return hex.EncodeToString(key[:])
|
||||||
if key.IsZero() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
key.clamp()
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePublicKey) FromHex(src string) error {
|
func (key *NoisePublicKey) FromHex(src string) error {
|
||||||
return loadExactHex(key[:], src)
|
return loadExactHex(key[:], src)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (key NoisePublicKey) ToHex() string {
|
||||||
|
return hex.EncodeToString(key[:])
|
||||||
|
}
|
||||||
|
|
||||||
func (key NoisePublicKey) IsZero() bool {
|
func (key NoisePublicKey) IsZero() bool {
|
||||||
var zero NoisePublicKey
|
var zero NoisePublicKey
|
||||||
return key.Equals(zero)
|
return key.Equals(zero)
|
||||||
|
@ -73,6 +70,10 @@ func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
|
||||||
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
|
||||||
}
|
}
|
||||||
|
|
||||||
func (key *NoisePresharedKey) FromHex(src string) error {
|
func (key *NoiseSymmetricKey) FromHex(src string) error {
|
||||||
return loadExactHex(key[:], src)
|
return loadExactHex(key[:], src)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (key NoiseSymmetricKey) ToHex() string {
|
||||||
|
return hex.EncodeToString(key[:])
|
||||||
|
}
|
|
@ -1,17 +1,14 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package device
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/conn"
|
|
||||||
"github.com/amnezia-vpn/amneziawg-go/tun/tuntest"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCurveWrappers(t *testing.T) {
|
func TestCurveWrappers(t *testing.T) {
|
||||||
|
@ -24,38 +21,14 @@ func TestCurveWrappers(t *testing.T) {
|
||||||
pk1 := sk1.publicKey()
|
pk1 := sk1.publicKey()
|
||||||
pk2 := sk2.publicKey()
|
pk2 := sk2.publicKey()
|
||||||
|
|
||||||
ss1, err1 := sk1.sharedSecret(pk2)
|
ss1 := sk1.sharedSecret(pk2)
|
||||||
ss2, err2 := sk2.sharedSecret(pk1)
|
ss2 := sk2.sharedSecret(pk1)
|
||||||
|
|
||||||
if ss1 != ss2 || err1 != nil || err2 != nil {
|
if ss1 != ss2 {
|
||||||
t.Fatal("Failed to compute shared secet")
|
t.Fatal("Failed to compute shared secet")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func randDevice(t *testing.T) *Device {
|
|
||||||
sk, err := newPrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
tun := tuntest.NewChannelTUN()
|
|
||||||
logger := NewLogger(LogLevelError, "")
|
|
||||||
device := NewDevice(tun.TUN(), conn.NewDefaultBind(), logger)
|
|
||||||
device.SetPrivateKey(sk)
|
|
||||||
return device
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertNil(t *testing.T, err error) {
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertEqual(t *testing.T, a, b []byte) {
|
|
||||||
if !bytes.Equal(a, b) {
|
|
||||||
t.Fatal(a, "!=", b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNoiseHandshake(t *testing.T) {
|
func TestNoiseHandshake(t *testing.T) {
|
||||||
dev1 := randDevice(t)
|
dev1 := randDevice(t)
|
||||||
dev2 := randDevice(t)
|
dev2 := randDevice(t)
|
||||||
|
@ -63,16 +36,8 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
defer dev1.Close()
|
defer dev1.Close()
|
||||||
defer dev2.Close()
|
defer dev2.Close()
|
||||||
|
|
||||||
peer1, err := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
|
peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
|
||||||
if err != nil {
|
peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
peer2, err := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
peer1.Start()
|
|
||||||
peer2.Start()
|
|
||||||
|
|
||||||
assertEqual(
|
assertEqual(
|
||||||
t,
|
t,
|
||||||
|
@ -92,7 +57,6 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
packet := make([]byte, 0, 256)
|
packet := make([]byte, 0, 256)
|
||||||
writer := bytes.NewBuffer(packet)
|
writer := bytes.NewBuffer(packet)
|
||||||
err = binary.Write(writer, binary.LittleEndian, msg1)
|
err = binary.Write(writer, binary.LittleEndian, msg1)
|
||||||
assertNil(t, err)
|
|
||||||
peer := dev2.ConsumeMessageInitiation(msg1)
|
peer := dev2.ConsumeMessageInitiation(msg1)
|
||||||
if peer == nil {
|
if peer == nil {
|
||||||
t.Fatal("handshake failed at initiation message")
|
t.Fatal("handshake failed at initiation message")
|
||||||
|
@ -148,7 +112,7 @@ func TestNoiseHandshake(t *testing.T) {
|
||||||
t.Fatal("failed to derive keypair for peer 2", err)
|
t.Fatal("failed to derive keypair for peer 2", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
key1 := peer1.keypairs.next.Load()
|
key1 := peer1.keypairs.next
|
||||||
key2 := peer2.keypairs.current
|
key2 := peer2.keypairs.current
|
||||||
|
|
||||||
// encrypting / decryption test
|
// encrypting / decryption test
|
258
peer.go
Normal file
258
peer.go
Normal file
|
@ -0,0 +1,258 @@
|
||||||
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PeerRoutineNumber = 3
|
||||||
|
)
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
isRunning AtomicBool
|
||||||
|
mutex sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
|
||||||
|
keypairs Keypairs
|
||||||
|
handshake Handshake
|
||||||
|
device *Device
|
||||||
|
endpoint Endpoint
|
||||||
|
persistentKeepaliveInterval uint16
|
||||||
|
|
||||||
|
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
|
||||||
|
stats struct {
|
||||||
|
txBytes uint64 // bytes send to peer (endpoint)
|
||||||
|
rxBytes uint64 // bytes received from peer
|
||||||
|
lastHandshakeNano int64 // nano seconds since epoch
|
||||||
|
}
|
||||||
|
|
||||||
|
timers struct {
|
||||||
|
retransmitHandshake *Timer
|
||||||
|
sendKeepalive *Timer
|
||||||
|
newHandshake *Timer
|
||||||
|
zeroKeyMaterial *Timer
|
||||||
|
persistentKeepalive *Timer
|
||||||
|
handshakeAttempts uint
|
||||||
|
needAnotherKeepalive bool
|
||||||
|
sentLastMinuteHandshake bool
|
||||||
|
}
|
||||||
|
|
||||||
|
signals struct {
|
||||||
|
newKeypairArrived chan struct{}
|
||||||
|
flushNonceQueue chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
queue struct {
|
||||||
|
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
|
||||||
|
outbound chan *QueueOutboundElement // sequential ordering of work
|
||||||
|
inbound chan *QueueInboundElement // sequential ordering of work
|
||||||
|
packetInNonceQueueIsAwaitingKey bool
|
||||||
|
}
|
||||||
|
|
||||||
|
routines struct {
|
||||||
|
mutex sync.Mutex // held when stopping / starting routines
|
||||||
|
starting sync.WaitGroup // routines pending start
|
||||||
|
stopping sync.WaitGroup // routines pending stop
|
||||||
|
stop chan struct{} // size 0, stop all go routines in peer
|
||||||
|
}
|
||||||
|
|
||||||
|
cookieGenerator CookieGenerator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
|
|
||||||
|
if device.isClosed.Get() {
|
||||||
|
return nil, errors.New("device closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// lock resources
|
||||||
|
|
||||||
|
device.staticIdentity.mutex.RLock()
|
||||||
|
defer device.staticIdentity.mutex.RUnlock()
|
||||||
|
|
||||||
|
device.peers.mutex.Lock()
|
||||||
|
defer device.peers.mutex.Unlock()
|
||||||
|
|
||||||
|
// check if over limit
|
||||||
|
|
||||||
|
if len(device.peers.keyMap) >= MaxPeers {
|
||||||
|
return nil, errors.New("too many peers")
|
||||||
|
}
|
||||||
|
|
||||||
|
// create peer
|
||||||
|
|
||||||
|
peer := new(Peer)
|
||||||
|
peer.mutex.Lock()
|
||||||
|
defer peer.mutex.Unlock()
|
||||||
|
|
||||||
|
peer.cookieGenerator.Init(pk)
|
||||||
|
peer.device = device
|
||||||
|
peer.isRunning.Set(false)
|
||||||
|
|
||||||
|
// map public key
|
||||||
|
|
||||||
|
_, ok := device.peers.keyMap[pk]
|
||||||
|
if ok {
|
||||||
|
return nil, errors.New("adding existing peer")
|
||||||
|
}
|
||||||
|
device.peers.keyMap[pk] = peer
|
||||||
|
|
||||||
|
// pre-compute DH
|
||||||
|
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
handshake.remoteStatic = pk
|
||||||
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||||
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
// reset endpoint
|
||||||
|
|
||||||
|
peer.endpoint = nil
|
||||||
|
|
||||||
|
// start peer
|
||||||
|
|
||||||
|
if peer.device.isUp.Get() {
|
||||||
|
peer.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||||
|
peer.device.net.mutex.RLock()
|
||||||
|
defer peer.device.net.mutex.RUnlock()
|
||||||
|
|
||||||
|
if peer.device.net.bind == nil {
|
||||||
|
return errors.New("no bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.mutex.RLock()
|
||||||
|
defer peer.mutex.RUnlock()
|
||||||
|
|
||||||
|
if peer.endpoint == nil {
|
||||||
|
return errors.New("no known endpoint for peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
return peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) String() string {
|
||||||
|
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
|
||||||
|
abbreviatedKey := "invalid"
|
||||||
|
if len(base64Key) == 44 {
|
||||||
|
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("peer(%s)", abbreviatedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) Start() {
|
||||||
|
|
||||||
|
// should never start a peer on a closed device
|
||||||
|
|
||||||
|
if peer.device.isClosed.Get() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// prevent simultaneous start/stop operations
|
||||||
|
|
||||||
|
peer.routines.mutex.Lock()
|
||||||
|
defer peer.routines.mutex.Unlock()
|
||||||
|
|
||||||
|
if peer.isRunning.Get() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
device := peer.device
|
||||||
|
device.log.Debug.Println(peer, ": Starting...")
|
||||||
|
|
||||||
|
// reset routine state
|
||||||
|
|
||||||
|
peer.routines.starting.Wait()
|
||||||
|
peer.routines.stopping.Wait()
|
||||||
|
peer.routines.stop = make(chan struct{})
|
||||||
|
peer.routines.starting.Add(PeerRoutineNumber)
|
||||||
|
peer.routines.stopping.Add(PeerRoutineNumber)
|
||||||
|
|
||||||
|
// prepare queues
|
||||||
|
|
||||||
|
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||||
|
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||||
|
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
|
||||||
|
|
||||||
|
peer.timersInit()
|
||||||
|
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
||||||
|
peer.signals.newKeypairArrived = make(chan struct{}, 1)
|
||||||
|
peer.signals.flushNonceQueue = make(chan struct{}, 1)
|
||||||
|
|
||||||
|
// wait for routines to start
|
||||||
|
|
||||||
|
go peer.RoutineNonce()
|
||||||
|
go peer.RoutineSequentialSender()
|
||||||
|
go peer.RoutineSequentialReceiver()
|
||||||
|
|
||||||
|
peer.routines.starting.Wait()
|
||||||
|
peer.isRunning.Set(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) ZeroAndFlushAll() {
|
||||||
|
device := peer.device
|
||||||
|
|
||||||
|
// clear key pairs
|
||||||
|
|
||||||
|
keypairs := &peer.keypairs
|
||||||
|
keypairs.mutex.Lock()
|
||||||
|
device.DeleteKeypair(keypairs.previous)
|
||||||
|
device.DeleteKeypair(keypairs.current)
|
||||||
|
device.DeleteKeypair(keypairs.next)
|
||||||
|
keypairs.previous = nil
|
||||||
|
keypairs.current = nil
|
||||||
|
keypairs.next = nil
|
||||||
|
keypairs.mutex.Unlock()
|
||||||
|
|
||||||
|
// clear handshake state
|
||||||
|
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
device.indexTable.Delete(handshake.localIndex)
|
||||||
|
handshake.Clear()
|
||||||
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
|
peer.FlushNonceQueue()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) Stop() {
|
||||||
|
|
||||||
|
// prevent simultaneous start/stop operations
|
||||||
|
|
||||||
|
peer.routines.mutex.Lock()
|
||||||
|
defer peer.routines.mutex.Unlock()
|
||||||
|
|
||||||
|
if !peer.isRunning.Swap(false) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
peer.device.log.Debug.Println(peer, ": Stopping...")
|
||||||
|
|
||||||
|
peer.timersStop()
|
||||||
|
|
||||||
|
// stop & wait for ongoing peer routines
|
||||||
|
|
||||||
|
peer.routines.starting.Wait()
|
||||||
|
close(peer.routines.stop)
|
||||||
|
peer.routines.stopping.Wait()
|
||||||
|
|
||||||
|
// close queues
|
||||||
|
|
||||||
|
close(peer.queue.nonce)
|
||||||
|
close(peer.queue.outbound)
|
||||||
|
close(peer.queue.inbound)
|
||||||
|
|
||||||
|
peer.ZeroAndFlushAll()
|
||||||
|
}
|
|
@ -1,12 +1,12 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
@ -20,106 +20,118 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
type RatelimiterEntry struct {
|
type RatelimiterEntry struct {
|
||||||
mu sync.Mutex
|
mutex sync.Mutex
|
||||||
lastTime time.Time
|
lastTime time.Time
|
||||||
tokens int64
|
tokens int64
|
||||||
}
|
}
|
||||||
|
|
||||||
type Ratelimiter struct {
|
type Ratelimiter struct {
|
||||||
mu sync.RWMutex
|
mutex sync.RWMutex
|
||||||
timeNow func() time.Time
|
stop chan struct{}
|
||||||
|
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
||||||
stopReset chan struct{} // send to reset, close to stop
|
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
||||||
table map[netip.Addr]*RatelimiterEntry
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Close() {
|
func (rate *Ratelimiter) Close() {
|
||||||
rate.mu.Lock()
|
rate.mutex.Lock()
|
||||||
defer rate.mu.Unlock()
|
defer rate.mutex.Unlock()
|
||||||
|
|
||||||
if rate.stopReset != nil {
|
if rate.stop != nil {
|
||||||
close(rate.stopReset)
|
close(rate.stop)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) Init() {
|
func (rate *Ratelimiter) Init() {
|
||||||
rate.mu.Lock()
|
rate.mutex.Lock()
|
||||||
defer rate.mu.Unlock()
|
defer rate.mutex.Unlock()
|
||||||
|
|
||||||
if rate.timeNow == nil {
|
|
||||||
rate.timeNow = time.Now
|
|
||||||
}
|
|
||||||
|
|
||||||
// stop any ongoing garbage collection routine
|
// stop any ongoing garbage collection routine
|
||||||
if rate.stopReset != nil {
|
|
||||||
close(rate.stopReset)
|
if rate.stop != nil {
|
||||||
|
close(rate.stop)
|
||||||
}
|
}
|
||||||
|
|
||||||
rate.stopReset = make(chan struct{})
|
rate.stop = make(chan struct{})
|
||||||
rate.table = make(map[netip.Addr]*RatelimiterEntry)
|
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
||||||
|
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
||||||
|
|
||||||
stopReset := rate.stopReset // store in case Init is called again.
|
// start garbage collection routine
|
||||||
|
|
||||||
// Start garbage collection routine.
|
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(time.Second)
|
ticker := time.NewTicker(time.Second)
|
||||||
ticker.Stop()
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case _, ok := <-stopReset:
|
case <-rate.stop:
|
||||||
ticker.Stop()
|
ticker.Stop()
|
||||||
if !ok {
|
return
|
||||||
return
|
|
||||||
}
|
|
||||||
ticker = time.NewTicker(time.Second)
|
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
if rate.cleanup() {
|
func() {
|
||||||
ticker.Stop()
|
rate.mutex.Lock()
|
||||||
}
|
defer rate.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, entry := range rate.tableIPv4 {
|
||||||
|
entry.mutex.Lock()
|
||||||
|
if time.Now().Sub(entry.lastTime) > garbageCollectTime {
|
||||||
|
delete(rate.tableIPv4, key)
|
||||||
|
}
|
||||||
|
entry.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, entry := range rate.tableIPv6 {
|
||||||
|
entry.mutex.Lock()
|
||||||
|
if time.Now().Sub(entry.lastTime) > garbageCollectTime {
|
||||||
|
delete(rate.tableIPv6, key)
|
||||||
|
}
|
||||||
|
entry.mutex.Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (rate *Ratelimiter) cleanup() (empty bool) {
|
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||||
rate.mu.Lock()
|
var entry *RatelimiterEntry
|
||||||
defer rate.mu.Unlock()
|
var keyIPv4 [net.IPv4len]byte
|
||||||
|
var keyIPv6 [net.IPv6len]byte
|
||||||
|
|
||||||
for key, entry := range rate.table {
|
// lookup entry
|
||||||
entry.mu.Lock()
|
|
||||||
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
IPv4 := ip.To4()
|
||||||
delete(rate.table, key)
|
IPv6 := ip.To16()
|
||||||
}
|
|
||||||
entry.mu.Unlock()
|
rate.mutex.RLock()
|
||||||
|
|
||||||
|
if IPv4 != nil {
|
||||||
|
copy(keyIPv4[:], IPv4)
|
||||||
|
entry = rate.tableIPv4[keyIPv4]
|
||||||
|
} else {
|
||||||
|
copy(keyIPv6[:], IPv6)
|
||||||
|
entry = rate.tableIPv6[keyIPv6]
|
||||||
}
|
}
|
||||||
|
|
||||||
return len(rate.table) == 0
|
rate.mutex.RUnlock()
|
||||||
}
|
|
||||||
|
|
||||||
func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
|
|
||||||
var entry *RatelimiterEntry
|
|
||||||
// lookup entry
|
|
||||||
rate.mu.RLock()
|
|
||||||
entry = rate.table[ip]
|
|
||||||
rate.mu.RUnlock()
|
|
||||||
|
|
||||||
// make new entry if not found
|
// make new entry if not found
|
||||||
|
|
||||||
if entry == nil {
|
if entry == nil {
|
||||||
entry = new(RatelimiterEntry)
|
entry = new(RatelimiterEntry)
|
||||||
entry.tokens = maxTokens - packetCost
|
entry.tokens = maxTokens - packetCost
|
||||||
entry.lastTime = rate.timeNow()
|
entry.lastTime = time.Now()
|
||||||
rate.mu.Lock()
|
rate.mutex.Lock()
|
||||||
rate.table[ip] = entry
|
if IPv4 != nil {
|
||||||
if len(rate.table) == 1 {
|
rate.tableIPv4[keyIPv4] = entry
|
||||||
rate.stopReset <- struct{}{}
|
} else {
|
||||||
|
rate.tableIPv6[keyIPv6] = entry
|
||||||
}
|
}
|
||||||
rate.mu.Unlock()
|
rate.mutex.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// add tokens to entry
|
// add tokens to entry
|
||||||
entry.mu.Lock()
|
|
||||||
now := rate.timeNow()
|
entry.mutex.Lock()
|
||||||
|
now := time.Now()
|
||||||
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
||||||
entry.lastTime = now
|
entry.lastTime = now
|
||||||
if entry.tokens > maxTokens {
|
if entry.tokens > maxTokens {
|
||||||
|
@ -127,11 +139,12 @@ func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// subtract cost of packet
|
// subtract cost of packet
|
||||||
|
|
||||||
if entry.tokens > packetCost {
|
if entry.tokens > packetCost {
|
||||||
entry.tokens -= packetCost
|
entry.tokens -= packetCost
|
||||||
entry.mu.Unlock()
|
entry.mutex.Unlock()
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
entry.mu.Unlock()
|
entry.mutex.Unlock()
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,31 +1,32 @@
|
||||||
/* SPDX-License-Identifier: MIT
|
/* SPDX-License-Identifier: GPL-2.0
|
||||||
*
|
*
|
||||||
* Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
|
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
package ratelimiter
|
package ratelimiter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type result struct {
|
type RatelimiterResult struct {
|
||||||
allowed bool
|
allowed bool
|
||||||
text string
|
text string
|
||||||
wait time.Duration
|
wait time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRatelimiter(t *testing.T) {
|
func TestRatelimiter(t *testing.T) {
|
||||||
var rate Ratelimiter
|
|
||||||
var expectedResults []result
|
|
||||||
|
|
||||||
nano := func(nano int64) time.Duration {
|
var ratelimiter Ratelimiter
|
||||||
|
var expectedResults []RatelimiterResult
|
||||||
|
|
||||||
|
Nano := func(nano int64) time.Duration {
|
||||||
return time.Nanosecond * time.Duration(nano)
|
return time.Nanosecond * time.Duration(nano)
|
||||||
}
|
}
|
||||||
|
|
||||||
add := func(res result) {
|
Add := func(res RatelimiterResult) {
|
||||||
expectedResults = append(
|
expectedResults = append(
|
||||||
expectedResults,
|
expectedResults,
|
||||||
res,
|
res,
|
||||||
|
@ -33,86 +34,69 @@ func TestRatelimiter(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < packetsBurstable; i++ {
|
for i := 0; i < packetsBurstable; i++ {
|
||||||
add(result{
|
Add(RatelimiterResult{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "initial burst",
|
text: "inital burst",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
add(result{
|
Add(RatelimiterResult{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "after burst",
|
text: "after burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
add(result{
|
Add(RatelimiterResult{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
|
wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
|
||||||
text: "filling tokens for single packet",
|
text: "filling tokens for single packet",
|
||||||
})
|
})
|
||||||
|
|
||||||
add(result{
|
Add(RatelimiterResult{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "not having refilled enough",
|
text: "not having refilled enough",
|
||||||
})
|
})
|
||||||
|
|
||||||
add(result{
|
Add(RatelimiterResult{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
||||||
text: "filling tokens for two packet burst",
|
text: "filling tokens for two packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
add(result{
|
Add(RatelimiterResult{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "second packet in 2 packet burst",
|
text: "second packet in 2 packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
add(result{
|
Add(RatelimiterResult{
|
||||||
allowed: false,
|
allowed: false,
|
||||||
text: "packet following 2 packet burst",
|
text: "packet following 2 packet burst",
|
||||||
})
|
})
|
||||||
|
|
||||||
ips := []netip.Addr{
|
ips := []net.IP{
|
||||||
netip.MustParseAddr("127.0.0.1"),
|
net.ParseIP("127.0.0.1"),
|
||||||
netip.MustParseAddr("192.168.1.1"),
|
net.ParseIP("192.168.1.1"),
|
||||||
netip.MustParseAddr("172.167.2.3"),
|
net.ParseIP("172.167.2.3"),
|
||||||
netip.MustParseAddr("97.231.252.215"),
|
net.ParseIP("97.231.252.215"),
|
||||||
netip.MustParseAddr("248.97.91.167"),
|
net.ParseIP("248.97.91.167"),
|
||||||
netip.MustParseAddr("188.208.233.47"),
|
net.ParseIP("188.208.233.47"),
|
||||||
netip.MustParseAddr("104.2.183.179"),
|
net.ParseIP("104.2.183.179"),
|
||||||
netip.MustParseAddr("72.129.46.120"),
|
net.ParseIP("72.129.46.120"),
|
||||||
netip.MustParseAddr("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
|
||||||
netip.MustParseAddr("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
|
||||||
netip.MustParseAddr("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
|
||||||
netip.MustParseAddr("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
|
||||||
netip.MustParseAddr("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
|
||||||
netip.MustParseAddr("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
ratelimiter.Init()
|
||||||
rate.timeNow = func() time.Time {
|
|
||||||
return now
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
// Lock to avoid data race with cleanup goroutine from Init.
|
|
||||||
rate.mu.Lock()
|
|
||||||
defer rate.mu.Unlock()
|
|
||||||
|
|
||||||
rate.timeNow = time.Now
|
|
||||||
}()
|
|
||||||
timeSleep := func(d time.Duration) {
|
|
||||||
now = now.Add(d + 1)
|
|
||||||
rate.cleanup()
|
|
||||||
}
|
|
||||||
|
|
||||||
rate.Init()
|
|
||||||
defer rate.Close()
|
|
||||||
|
|
||||||
for i, res := range expectedResults {
|
for i, res := range expectedResults {
|
||||||
timeSleep(res.wait)
|
time.Sleep(res.wait)
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
allowed := rate.Allow(ip)
|
allowed := ratelimiter.Allow(ip)
|
||||||
if allowed != res.allowed {
|
if allowed != res.allowed {
|
||||||
t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
|
t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Reference in a new issue