Skip to content

Commit ec3b49f

Browse files
refactor(billing): Billing simplification entitlement usage (#1312)
Co-authored-by: Claude <[email protected]>
1 parent 52877bb commit ec3b49f

File tree

9 files changed

+8114
-8028
lines changed

9 files changed

+8114
-8028
lines changed

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ TAG := $(shell git rev-list --tags --max-count=1)
44
VERSION := $(shell git describe --tags ${TAG})
55
.PHONY: build check fmt lint test test-race vet test-cover-html help install proto ui compose-up-dev
66
.DEFAULT_GOAL := build
7-
PROTON_COMMIT := "80fc5ba1e538e38d5ca190386af1e69ee64584ee"
7+
PROTON_COMMIT := "4144445eb0f9cbd1a801a3d0aa5cfce4cc0ea551"
88

99
ui:
1010
@echo " > generating ui build"

internal/api/v1beta1connect/billing_check.go

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package v1beta1connect
22

33
import (
44
"context"
5+
"errors"
56

67
"connectrpc.com/connect"
78
"github.com/raystack/frontier/billing/customer"
@@ -12,10 +13,25 @@ import (
1213
func (h *ConnectHandler) CheckFeatureEntitlement(ctx context.Context, request *connect.Request[frontierv1beta1.CheckFeatureEntitlementRequest]) (*connect.Response[frontierv1beta1.CheckFeatureEntitlementResponse], error) {
1314
errorLogger := NewErrorLogger()
1415

15-
checkStatus, err := h.entitlementService.Check(ctx, request.Msg.GetBillingId(), request.Msg.GetFeature())
16+
// Always infer billing_id from org_id
17+
cust, err := h.customerService.GetByOrgID(ctx, request.Msg.GetOrgId())
18+
if err != nil {
19+
if errors.Is(err, customer.ErrNotFound) {
20+
return connect.NewResponse(&frontierv1beta1.CheckFeatureEntitlementResponse{}), nil
21+
}
22+
if errors.Is(err, customer.ErrInvalidUUID) || errors.Is(err, customer.ErrInvalidID) {
23+
return nil, connect.NewError(connect.CodeInvalidArgument, err)
24+
}
25+
errorLogger.LogServiceError(ctx, request, "CheckFeatureEntitlement.GetByOrgID", err,
26+
zap.String("org_id", request.Msg.GetOrgId()))
27+
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
28+
}
29+
30+
checkStatus, err := h.entitlementService.Check(ctx, cust.ID, request.Msg.GetFeature())
1631
if err != nil {
1732
errorLogger.LogServiceError(ctx, request, "CheckFeatureEntitlement", err,
18-
zap.String("billing_id", request.Msg.GetBillingId()),
33+
zap.String("billing_id", cust.ID),
34+
zap.String("org_id", request.Msg.GetOrgId()),
1935
zap.String("feature", request.Msg.GetFeature()))
2036
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
2137
}

internal/api/v1beta1connect/billing_check_test.go

Lines changed: 49 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,23 @@ import (
1515

1616
func TestConnectHandler_CheckFeatureEntitlement(t *testing.T) {
1717
tests := []struct {
18-
name string
19-
setup func(es *mocks.EntitlementService)
20-
request *connect.Request[frontierv1beta1.CheckFeatureEntitlementRequest]
21-
want *connect.Response[frontierv1beta1.CheckFeatureEntitlementResponse]
22-
wantErr error
23-
errCode connect.Code
18+
name string
19+
customerSetup func(cs *mocks.CustomerService)
20+
setup func(es *mocks.EntitlementService)
21+
request *connect.Request[frontierv1beta1.CheckFeatureEntitlementRequest]
22+
want *connect.Response[frontierv1beta1.CheckFeatureEntitlementResponse]
23+
wantErr error
24+
errCode connect.Code
2425
}{
2526
{
2627
name: "should return internal server error when entitlement service returns error",
2728
request: connect.NewRequest(&frontierv1beta1.CheckFeatureEntitlementRequest{
28-
BillingId: "billing-123",
29-
Feature: "feature-abc",
29+
OrgId: "org-123",
30+
Feature: "feature-abc",
3031
}),
32+
customerSetup: func(cs *mocks.CustomerService) {
33+
cs.EXPECT().GetByOrgID(mock.Anything, "org-123").Return(customer.Customer{ID: "billing-123"}, nil)
34+
},
3135
setup: func(es *mocks.EntitlementService) {
3236
es.EXPECT().Check(mock.Anything, "billing-123", "feature-abc").Return(false, errors.New("service error"))
3337
},
@@ -37,13 +41,16 @@ func TestConnectHandler_CheckFeatureEntitlement(t *testing.T) {
3741
},
3842
{
3943
name: "should return false when feature is not entitled",
44+
request: connect.NewRequest(&frontierv1beta1.CheckFeatureEntitlementRequest{
45+
OrgId: "org-123",
46+
Feature: "feature-abc",
47+
}),
48+
customerSetup: func(cs *mocks.CustomerService) {
49+
cs.EXPECT().GetByOrgID(mock.Anything, "org-123").Return(customer.Customer{ID: "billing-123"}, nil)
50+
},
4051
setup: func(es *mocks.EntitlementService) {
4152
es.EXPECT().Check(mock.Anything, "billing-123", "feature-abc").Return(false, nil)
4253
},
43-
request: connect.NewRequest(&frontierv1beta1.CheckFeatureEntitlementRequest{
44-
BillingId: "billing-123",
45-
Feature: "feature-abc",
46-
}),
4754
want: connect.NewResponse(&frontierv1beta1.CheckFeatureEntitlementResponse{
4855
Status: false,
4956
}),
@@ -52,57 +59,63 @@ func TestConnectHandler_CheckFeatureEntitlement(t *testing.T) {
5259
},
5360
{
5461
name: "should return true when feature is entitled",
62+
request: connect.NewRequest(&frontierv1beta1.CheckFeatureEntitlementRequest{
63+
OrgId: "org-123",
64+
Feature: "feature-abc",
65+
}),
66+
customerSetup: func(cs *mocks.CustomerService) {
67+
cs.EXPECT().GetByOrgID(mock.Anything, "org-123").Return(customer.Customer{ID: "billing-123"}, nil)
68+
},
5569
setup: func(es *mocks.EntitlementService) {
5670
es.EXPECT().Check(mock.Anything, "billing-123", "feature-abc").Return(true, nil)
5771
},
58-
request: connect.NewRequest(&frontierv1beta1.CheckFeatureEntitlementRequest{
59-
BillingId: "billing-123",
60-
Feature: "feature-abc",
61-
}),
6272
want: connect.NewResponse(&frontierv1beta1.CheckFeatureEntitlementResponse{
6373
Status: true,
6474
}),
6575
wantErr: nil,
6676
errCode: connect.Code(0),
6777
},
6878
{
69-
name: "should handle empty billing id",
70-
setup: func(es *mocks.EntitlementService) {
71-
es.EXPECT().Check(mock.Anything, "", "feature-abc").Return(false, nil)
72-
},
79+
name: "should return empty response when billing account not found",
7380
request: connect.NewRequest(&frontierv1beta1.CheckFeatureEntitlementRequest{
74-
BillingId: "",
75-
Feature: "feature-abc",
76-
}),
77-
want: connect.NewResponse(&frontierv1beta1.CheckFeatureEntitlementResponse{
78-
Status: false,
81+
OrgId: "org-123",
82+
Feature: "feature-abc",
7983
}),
84+
customerSetup: func(cs *mocks.CustomerService) {
85+
cs.EXPECT().GetByOrgID(mock.Anything, "org-123").Return(customer.Customer{}, customer.ErrNotFound)
86+
},
87+
setup: func(es *mocks.EntitlementService) {},
88+
want: connect.NewResponse(&frontierv1beta1.CheckFeatureEntitlementResponse{}),
8089
wantErr: nil,
8190
errCode: connect.Code(0),
8291
},
8392
{
84-
name: "should handle empty feature",
85-
setup: func(es *mocks.EntitlementService) {
86-
es.EXPECT().Check(mock.Anything, "billing-123", "").Return(false, nil)
87-
},
93+
name: "should return invalid argument when org_id is invalid",
8894
request: connect.NewRequest(&frontierv1beta1.CheckFeatureEntitlementRequest{
89-
BillingId: "billing-123",
90-
Feature: "",
95+
OrgId: "",
96+
Feature: "feature-abc",
9197
}),
92-
want: connect.NewResponse(&frontierv1beta1.CheckFeatureEntitlementResponse{
93-
Status: false,
94-
}),
95-
wantErr: nil,
96-
errCode: connect.Code(0),
98+
customerSetup: func(cs *mocks.CustomerService) {
99+
cs.EXPECT().GetByOrgID(mock.Anything, "").Return(customer.Customer{}, customer.ErrInvalidUUID)
100+
},
101+
setup: func(es *mocks.EntitlementService) {},
102+
want: nil,
103+
wantErr: customer.ErrInvalidUUID,
104+
errCode: connect.CodeInvalidArgument,
97105
},
98106
}
99107
for _, tt := range tests {
100108
t.Run(tt.name, func(t *testing.T) {
109+
mockCustomerSvc := new(mocks.CustomerService)
101110
mockEntitlementSvc := new(mocks.EntitlementService)
111+
if tt.customerSetup != nil {
112+
tt.customerSetup(mockCustomerSvc)
113+
}
102114
if tt.setup != nil {
103115
tt.setup(mockEntitlementSvc)
104116
}
105117
h := &ConnectHandler{
118+
customerService: mockCustomerSvc,
106119
entitlementService: mockEntitlementSvc,
107120
}
108121
got, err := h.CheckFeatureEntitlement(context.Background(), tt.request)

internal/api/v1beta1connect/billing_usage.go

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ import (
2020
func (h *ConnectHandler) CreateBillingUsage(ctx context.Context, request *connect.Request[frontierv1beta1.CreateBillingUsageRequest]) (*connect.Response[frontierv1beta1.CreateBillingUsageResponse], error) {
2121
errorLogger := NewErrorLogger()
2222

23+
// Always infer billing_id from org_id
24+
cust, err := h.customerService.GetByOrgID(ctx, request.Msg.GetOrgId())
25+
if err != nil {
26+
if errors.Is(err, customer.ErrNotFound) {
27+
return nil, connect.NewError(connect.CodeNotFound, err)
28+
}
29+
if errors.Is(err, customer.ErrInvalidUUID) || errors.Is(err, customer.ErrInvalidID) {
30+
return nil, connect.NewError(connect.CodeInvalidArgument, err)
31+
}
32+
errorLogger.LogServiceError(ctx, request, "CreateBillingUsage.GetByOrgID", err,
33+
zap.String("org_id", request.Msg.GetOrgId()))
34+
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
35+
}
36+
2337
createRequests := make([]usage.Usage, 0, len(request.Msg.GetUsages()))
2438
for _, v := range request.Msg.GetUsages() {
2539
usageType := usage.CreditType
@@ -29,7 +43,7 @@ func (h *ConnectHandler) CreateBillingUsage(ctx context.Context, request *connec
2943

3044
createRequests = append(createRequests, usage.Usage{
3145
ID: v.GetId(),
32-
CustomerID: request.Msg.GetBillingId(),
46+
CustomerID: cust.ID,
3347
Type: usageType,
3448
Amount: v.GetAmount(),
3549
Source: strings.ToLower(v.GetSource()), // source in lower case looks nicer
@@ -47,7 +61,8 @@ func (h *ConnectHandler) CreateBillingUsage(ctx context.Context, request *connec
4761
return nil, connect.NewError(connect.CodeAlreadyExists, ErrAlreadyApplied)
4862
}
4963
errorLogger.LogServiceError(ctx, request, "CreateBillingUsage.Report", err,
50-
zap.String("billing_id", request.Msg.GetBillingId()),
64+
zap.String("billing_id", cust.ID),
65+
zap.String("org_id", request.Msg.GetOrgId()),
5166
zap.Int("usage_count", len(createRequests)))
5267
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
5368
}
@@ -185,7 +200,21 @@ func transformTransactionToPB(t credit.Transaction) (*frontierv1beta1.BillingTra
185200
func (h *ConnectHandler) RevertBillingUsage(ctx context.Context, request *connect.Request[frontierv1beta1.RevertBillingUsageRequest]) (*connect.Response[frontierv1beta1.RevertBillingUsageResponse], error) {
186201
errorLogger := NewErrorLogger()
187202

188-
if err := h.usageService.Revert(ctx, request.Msg.GetBillingId(),
203+
// Always infer billing_id from org_id
204+
cust, err := h.customerService.GetByOrgID(ctx, request.Msg.GetOrgId())
205+
if err != nil {
206+
if errors.Is(err, customer.ErrNotFound) {
207+
return nil, connect.NewError(connect.CodeNotFound, err)
208+
}
209+
if errors.Is(err, customer.ErrInvalidUUID) || errors.Is(err, customer.ErrInvalidID) {
210+
return nil, connect.NewError(connect.CodeInvalidArgument, err)
211+
}
212+
errorLogger.LogServiceError(ctx, request, "RevertBillingUsage.GetByOrgID", err,
213+
zap.String("org_id", request.Msg.GetOrgId()))
214+
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
215+
}
216+
217+
if err := h.usageService.Revert(ctx, cust.ID,
189218
request.Msg.GetUsageId(), request.Msg.GetAmount()); err != nil {
190219
if errors.Is(err, usage.ErrRevertAmountExceeds) {
191220
return nil, connect.NewError(connect.CodeInvalidArgument, err)
@@ -199,7 +228,8 @@ func (h *ConnectHandler) RevertBillingUsage(ctx context.Context, request *connec
199228
return nil, connect.NewError(connect.CodeInvalidArgument, err)
200229
}
201230
errorLogger.LogServiceError(ctx, request, "RevertBillingUsage.Revert", err,
202-
zap.String("billing_id", request.Msg.GetBillingId()),
231+
zap.String("billing_id", cust.ID),
232+
zap.String("org_id", request.Msg.GetOrgId()),
203233
zap.String("usage_id", request.Msg.GetUsageId()),
204234
zap.Int64("amount", request.Msg.GetAmount()))
205235
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)

0 commit comments

Comments
 (0)