@@ -64,7 +64,7 @@ BruteForce::Train(const DatasetPtr& data) {
6464}
6565
6666std::vector<int64_t >
67- BruteForce::Add (const DatasetPtr& data) {
67+ BruteForce::Add (const DatasetPtr& data, AddMode mode ) {
6868 std::vector<int64_t > failed_ids;
6969 auto base_dim = data->GetDim ();
7070 CHECK_ARGUMENT (base_dim == dim_,
@@ -146,34 +146,44 @@ BruteForce::Add(const DatasetPtr& data) {
146146 return failed_ids;
147147}
148148
149- bool
150- BruteForce::Remove (int64_t label ) {
149+ uint32_t
150+ BruteForce::Remove (const std::vector< int64_t >& ids, RemoveMode mode ) {
151151 CHECK_ARGUMENT (not use_attribute_filter_,
152152 " remove is not supported when use_attribute_filter is true" );
153153
154+ uint32_t delete_count = 0 ;
155+ if (mode == RemoveMode::MARK_REMOVE) {
156+ std::scoped_lock label_lock (this ->label_lookup_mutex_ );
157+ delete_count = this ->label_table_ ->MarkRemove (ids);
158+ delete_count_ += delete_count;
159+ return delete_count;
160+ }
161+
154162 std::scoped_lock lock (this ->add_mutex_ , this ->label_lookup_mutex_ );
155- const auto last_inner_id = static_cast <InnerIdType>(this ->total_count_ - 1 );
156- const auto inner_id = this ->label_table_ ->GetIdByLabel (label);
163+ for (auto label : ids) {
164+ const auto last_inner_id = static_cast <InnerIdType>(this ->total_count_ - 1 );
165+ const auto inner_id = this ->label_table_ ->GetIdByLabel (label);
157166
158- CHECK_ARGUMENT (inner_id <= last_inner_id, " the element to be remove is invalid" );
167+ CHECK_ARGUMENT (inner_id <= last_inner_id, " the element to be remove is invalid" );
159168
160- const auto last_label = this ->label_table_ ->GetLabelById (last_inner_id);
161- this ->label_table_ ->Remove (label);
162- --this ->label_table_ ->total_count_ ;
169+ const auto last_label = this ->label_table_ ->GetLabelById (last_inner_id);
170+ this ->label_table_ ->MarkRemove (label);
171+ --this ->label_table_ ->total_count_ ;
163172
164- if (inner_id < last_inner_id) {
165- Vector<float > data (dim_, allocator_);
166- GetVectorByInnerId (last_inner_id, data.data ());
173+ if (inner_id < last_inner_id) {
174+ Vector<float > data (dim_, allocator_);
175+ GetVectorByInnerId (last_inner_id, data.data ());
167176
168- this ->label_table_ ->Remove (last_label);
169- --this ->label_table_ ->total_count_ ;
177+ this ->label_table_ ->MarkRemove (last_label);
178+ --this ->label_table_ ->total_count_ ;
170179
171- this ->inner_codes_ ->InsertVector (data.data (), inner_id);
172- this ->label_table_ ->Insert (inner_id, last_label);
173- }
180+ this ->inner_codes_ ->InsertVector (data.data (), inner_id);
181+ this ->label_table_ ->Insert (inner_id, last_label);
182+ }
174183
175- this ->total_count_ --;
176- return true ;
184+ this ->total_count_ --;
185+ }
186+ return 1 ;
177187}
178188
179189DatasetPtr
@@ -199,10 +209,18 @@ BruteForce::SearchWithRequest(const SearchRequest& request) const {
199209 DistHeapPtr heap = nullptr ;
200210 ExecutorPtr executor = nullptr ;
201211 Filter* attr_filter = nullptr ;
202- Filter* filter = nullptr ;
212+
213+ auto combined_filter = std::make_shared<CombinedFilter>();
214+ combined_filter->AppendFilter (this ->label_table_ ->GetDeletedIdsFilter ());
203215 if (request.filter_ != nullptr ) {
204- filter = request.filter_ .get ();
216+ combined_filter->AppendFilter (
217+ std::make_shared<InnerIdWrapperFilter>(request.filter_ , *this ->label_table_ ));
205218 }
219+ FilterPtr ft = nullptr ;
220+ if (not combined_filter->IsEmpty ()) {
221+ ft = combined_filter;
222+ }
223+
206224 if (request.enable_attribute_filter_ ) {
207225 auto & schema = this ->attr_filter_index_ ->field_type_map_ ;
208226 auto expr = AstParse (request.attribute_filter_str_ , &schema);
@@ -228,7 +246,7 @@ BruteForce::SearchWithRequest(const SearchRequest& request) const {
228246 if (attr_filter != nullptr and not attr_filter->CheckValid (i)) {
229247 continue ;
230248 }
231- if (filter == nullptr or filter ->CheckValid (this -> label_table_ -> GetLabelById (i) )) {
249+ if (ft == nullptr or ft ->CheckValid (i )) {
232250 inner_codes_->Query (&dist, computer, &i, 1 );
233251 ++dist_cmp_local;
234252 cur_heap->Push (dist, i);
0 commit comments