Skip to content

Commit 0b465cc

Browse files
committed
Allow a custom port range for EC2 VMs
Set the additional text with a comma-separated list of ports i.e. 22,443,80,8080 and these will be added to the security group. Signed-off-by: Alex Ellis (OpenFaaS Ltd) <alexellis2@gmail.com>
1 parent 8351174 commit 0b465cc

File tree

2 files changed

+103
-7
lines changed

2 files changed

+103
-7
lines changed

provision/ec2.go

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package provision
22

33
import (
44
"fmt"
5-
"github.com/aws/aws-sdk-go/aws/credentials"
65
"strconv"
76
"strings"
87

8+
"github.com/aws/aws-sdk-go/aws/credentials"
9+
910
"github.com/aws/aws-sdk-go/aws"
1011
"github.com/aws/aws-sdk-go/aws/session"
1112
"github.com/aws/aws-sdk-go/service/ec2"
@@ -40,10 +41,12 @@ func (p *EC2Provisioner) Provision(host BasicHost) (*ProvisionedHost, error) {
4041
}
4142
pro := host.Additional["pro"]
4243

44+
ports := host.Additional["ports"]
45+
4346
var vpcID = host.Additional["vpc-id"]
4447
var subnetID = host.Additional["subnet-id"]
4548

46-
groupID, name, err := p.createEC2SecurityGroup(vpcID, port, pro)
49+
groupID, name, err := p.createEC2SecurityGroup(vpcID, port, pro, ports)
4750
if err != nil {
4851
return nil, err
4952
}
@@ -85,6 +88,7 @@ func (p *EC2Provisioner) Provision(host BasicHost) (*ProvisionedHost, error) {
8588
return nil, fmt.Errorf("could not create host: %s", runResult.String())
8689
}
8790

91+
// AE: not sure why this error isn't handled?
8892
_, err = p.ec2Provisioner.CreateTags(&ec2.CreateTagsInput{
8993
Resources: []*string{runResult.Instances[0].InstanceId},
9094
Tags: []*ec2.Tag{
@@ -247,9 +251,21 @@ func (p *EC2Provisioner) lookupID(request HostDeleteRequest) (string, error) {
247251
}
248252

249253
// createEC2SecurityGroup creates a security group for the exit-node
250-
func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, pro string) (*string, *string, error) {
251-
ports := []int{80, 443, controlPort}
252-
proPorts := []int{1024, 65535}
254+
func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, pro, extraPorts string) (*string, *string, error) {
255+
ports := []int{controlPort}
256+
257+
proPortRange := []int{1024, 65535}
258+
259+
if len(extraPorts) > 0 {
260+
extraPorts, err := parsePorts(extraPorts)
261+
if err != nil {
262+
return nil, nil, err
263+
}
264+
ports = append(ports, extraPorts...)
265+
266+
proPortRange = []int{}
267+
}
268+
253269
groupName := "inlets-" + uuid.New().String()
254270
var input = &ec2.CreateSecurityGroupInput{
255271
Description: aws.String("inlets security group"),
@@ -271,8 +287,9 @@ func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, p
271287
return group.GroupId, &groupName, err
272288
}
273289
}
274-
if pro == "true" {
275-
err = p.createEC2SecurityGroupRule(*group.GroupId, proPorts[0], proPorts[1])
290+
291+
if pro == "true" && len(proPortRange) == 2 {
292+
err = p.createEC2SecurityGroupRule(*group.GroupId, proPortRange[0], proPortRange[1])
276293
if err != nil {
277294
return group.GroupId, &groupName, err
278295
}
@@ -281,6 +298,22 @@ func (p *EC2Provisioner) createEC2SecurityGroup(vpcID string, controlPort int, p
281298
return group.GroupId, &groupName, nil
282299
}
283300

301+
func parsePorts(extraPorts string) ([]int, error) {
302+
var ports []int
303+
parts := strings.Split(extraPorts, ",")
304+
for _, part := range parts {
305+
if trimmed := strings.TrimSpace(part); len(trimmed) > 0 {
306+
port, err := strconv.Atoi(trimmed)
307+
if err != nil {
308+
return nil, err
309+
}
310+
ports = append(ports, port)
311+
}
312+
}
313+
314+
return ports, nil
315+
}
316+
284317
func (p *EC2Provisioner) createEC2SecurityGroupRule(groupID string, fromPort, toPort int) error {
285318
_, err := p.ec2Provisioner.AuthorizeSecurityGroupIngress(&ec2.AuthorizeSecurityGroupIngressInput{
286319
CidrIp: aws.String("0.0.0.0/0"),

provision/ec2_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package provision
2+
3+
import "testing"
4+
5+
func Test_parsePorts_empty(t *testing.T) {
6+
7+
ports, err := parsePorts("")
8+
if err != nil {
9+
t.Fatal(err)
10+
}
11+
12+
if len(ports) != 0 {
13+
t.Fatalf("Expected empty slice, got %d", len(ports))
14+
}
15+
}
16+
17+
func Test_parsePorts_single(t *testing.T) {
18+
19+
wantPort := 80
20+
str := "80"
21+
ports, err := parsePorts(str)
22+
if err != nil {
23+
t.Fatal(err)
24+
}
25+
26+
if len(ports) != 1 {
27+
t.Fatalf("Want single port, got %d", len(ports))
28+
}
29+
30+
if ports[0] != wantPort {
31+
t.Fatalf("Want port %d, got %d", wantPort, ports[0])
32+
}
33+
}
34+
35+
func Test_parsePorts_multiple(t *testing.T) {
36+
37+
wantPorts := []int{27017, 22}
38+
39+
str := "27017,22"
40+
41+
ports, err := parsePorts(str)
42+
if err != nil {
43+
t.Fatal(err)
44+
}
45+
46+
if len(ports) != len(wantPorts) {
47+
t.Fatalf("Want %d ports, got %d", len(wantPorts), len(ports))
48+
}
49+
50+
found := 0
51+
52+
for _, port := range ports {
53+
for _, wantPort := range wantPorts {
54+
if port == wantPort {
55+
found++
56+
}
57+
}
58+
}
59+
60+
if found != len(wantPorts) {
61+
t.Fatalf("Want %v ports, got %v", wantPorts, ports)
62+
}
63+
}

0 commit comments

Comments
 (0)