Skip to content
Merged
1 change: 1 addition & 0 deletions x-pack/plugin/core/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
exports org.elasticsearch.xpack.core.indexing;
exports org.elasticsearch.xpack.core.inference.action;
exports org.elasticsearch.xpack.core.inference.results;
exports org.elasticsearch.xpack.core.inference.usage;
exports org.elasticsearch.xpack.core.inference;
exports org.elasticsearch.xpack.core.logstash;
exports org.elasticsearch.xpack.core.ml.action;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,10 @@
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.XPackFeatureUsage;
import org.elasticsearch.xpack.core.XPackField;
import org.elasticsearch.xpack.core.inference.usage.ModelStats;

import java.io.IOException;
import java.util.Collection;
Expand All @@ -25,83 +23,6 @@

public class InferenceFeatureSetUsage extends XPackFeatureUsage {

public static class ModelStats implements ToXContentObject, Writeable {

private final String service;
private final TaskType taskType;
private long count;

public ModelStats(String service, TaskType taskType) {
this(service, taskType, 0L);
}

public ModelStats(String service, TaskType taskType, long count) {
this.service = service;
this.taskType = taskType;
this.count = count;
}

public ModelStats(ModelStats stats) {
this(stats.service, stats.taskType, stats.count);
}

public ModelStats(StreamInput in) throws IOException {
this.service = in.readString();
this.taskType = in.readEnum(TaskType.class);
this.count = in.readLong();
}

public void add() {
count++;
}

public String service() {
return service;
}

public TaskType taskType() {
return taskType;
}

public long count() {
return count;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
addXContentFragment(builder, params);
builder.endObject();
return builder;
}

public void addXContentFragment(XContentBuilder builder, Params params) throws IOException {
builder.field("service", service);
builder.field("task_type", taskType.name());
builder.field("count", count);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(service);
out.writeEnum(taskType);
out.writeLong(count);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ModelStats that = (ModelStats) o;
return count == that.count && Objects.equals(service, that.service) && taskType == that.taskType;
}

@Override
public int hashCode() {
return Objects.hash(service, taskType, count);
}
}

public static final InferenceFeatureSetUsage EMPTY = new InferenceFeatureSetUsage(List.of());

private final Collection<ModelStats> modelStats;
Expand Down Expand Up @@ -144,4 +65,8 @@ public boolean equals(Object o) {
public int hashCode() {
return Objects.hashCode(modelStats);
}

Collection<ModelStats> modelStats() {
return modelStats;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.usage;

import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;

import java.io.IOException;
import java.util.Objects;

public class ModelStats implements ToXContentObject, Writeable {

private final String service;
private final TaskType taskType;
private long count;

public ModelStats(String service, TaskType taskType) {
this(service, taskType, 0L);
}

public ModelStats(String service, TaskType taskType, long count) {
this.service = service;
this.taskType = taskType;
this.count = count;
}

public ModelStats(ModelStats stats) {
this(stats.service, stats.taskType, stats.count);
}

public ModelStats(StreamInput in) throws IOException {
this.service = in.readString();
this.taskType = in.readEnum(TaskType.class);
this.count = in.readLong();
}

public void add() {
count++;
}
Comment on lines +46 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a very minor point, but could there be a test of this method added to ModelStatsTests?


public String service() {
return service;
}

public TaskType taskType() {
return taskType;
}

public long count() {
return count;
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
addXContentFragment(builder, params);
builder.endObject();
return builder;
}

public void addXContentFragment(XContentBuilder builder, Params params) throws IOException {
builder.field("service", service);
builder.field("task_type", taskType.name());
builder.field("count", count);
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(service);
out.writeEnum(taskType);
out.writeLong(count);
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ModelStats that = (ModelStats) o;
return count == that.count && Objects.equals(service, that.service) && taskType == that.taskType;
}

@Override
public int hashCode() {
return Objects.hash(service, taskType, count);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,35 +7,35 @@

package org.elasticsearch.xpack.core.inference;

import com.carrotsearch.randomizedtesting.generators.RandomStrings;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xpack.core.inference.usage.ModelStats;
import org.elasticsearch.xpack.core.inference.usage.ModelStatsTests;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

public class InferenceFeatureSetUsageTests extends AbstractWireSerializingTestCase<InferenceFeatureSetUsage.ModelStats> {
public class InferenceFeatureSetUsageTests extends AbstractWireSerializingTestCase<InferenceFeatureSetUsage> {

@Override
protected Writeable.Reader<InferenceFeatureSetUsage.ModelStats> instanceReader() {
return InferenceFeatureSetUsage.ModelStats::new;
protected Writeable.Reader<InferenceFeatureSetUsage> instanceReader() {
return InferenceFeatureSetUsage::new;
}

@Override
protected InferenceFeatureSetUsage.ModelStats createTestInstance() {
RandomStrings.randomAsciiLettersOfLength(random(), 10);
return new InferenceFeatureSetUsage.ModelStats(
randomIdentifier(),
TaskType.values()[randomInt(TaskType.values().length - 1)],
randomInt(10)
);
protected InferenceFeatureSetUsage createTestInstance() {
return new InferenceFeatureSetUsage(randomList(10, ModelStatsTests::createRandomInstance));
}

@Override
protected InferenceFeatureSetUsage.ModelStats mutateInstance(InferenceFeatureSetUsage.ModelStats modelStats) throws IOException {
InferenceFeatureSetUsage.ModelStats newModelStats = new InferenceFeatureSetUsage.ModelStats(modelStats);
newModelStats.add();
return newModelStats;
protected InferenceFeatureSetUsage mutateInstance(InferenceFeatureSetUsage instance) throws IOException {
List<ModelStats> mutatedModelStats = new ArrayList<>(instance.modelStats());
if (mutatedModelStats.isEmpty()) {
mutatedModelStats.add(ModelStatsTests.createRandomInstance());
} else {
mutatedModelStats.remove(randomIntBetween(0, mutatedModelStats.size() - 1));
}
return new InferenceFeatureSetUsage(mutatedModelStats);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.core.inference.usage;

import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;

import java.io.IOException;

public class ModelStatsTests extends AbstractWireSerializingTestCase<ModelStats> {

@Override
protected Writeable.Reader<ModelStats> instanceReader() {
return ModelStats::new;
}

@Override
protected ModelStats createTestInstance() {
return createRandomInstance();
}

@Override
protected ModelStats mutateInstance(ModelStats modelStats) throws IOException {
ModelStats newModelStats = new ModelStats(modelStats);
newModelStats.add();
return newModelStats;
}
Comment on lines +28 to +33
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This implementation will only test that changes to count in ModelStats are correctly accounted for in the equals() method. Since the equals() method also considers the service and taskType fields, those should be included here too to give full coverage. Something like this maybe:

protected ModelStats mutateInstance(ModelStats modelStats) throws IOException {
    String service = modelStats.service();
    TaskType taskType = modelStats.taskType();
    long count = modelStats.count();
    return switch (randomInt(2)) {
        case 0 -> new ModelStats(randomValueOtherThan(service, ESTestCase::randomIdentifier), taskType, count);
        case 1 -> new ModelStats(service, randomValueOtherThan(taskType, () -> randomFrom(TaskType.values())), count);
        case 2 -> new ModelStats(service, taskType, randomValueOtherThan(count, ESTestCase::randomLong));
        default -> throw new IllegalArgumentException();
    };
}


public static ModelStats createRandomInstance() {
return new ModelStats(randomIdentifier(), TaskType.values()[randomInt(TaskType.values().length - 1)], randomInt(10));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor nitpick, but the second argument here would be more readable as randomFrom(TaskType.values())

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import org.elasticsearch.xpack.core.action.XPackUsageFeatureTransportAction;
import org.elasticsearch.xpack.core.inference.InferenceFeatureSetUsage;
import org.elasticsearch.xpack.core.inference.action.GetInferenceModelAction;
import org.elasticsearch.xpack.core.inference.usage.ModelStats;

import java.util.Map;
import java.util.TreeMap;
Expand Down Expand Up @@ -61,13 +62,10 @@ protected void localClusterStateOperation(
) {
GetInferenceModelAction.Request getInferenceModelAction = new GetInferenceModelAction.Request("_all", TaskType.ANY, false);
client.execute(GetInferenceModelAction.INSTANCE, getInferenceModelAction, ActionListener.wrap(response -> {
Map<String, InferenceFeatureSetUsage.ModelStats> stats = new TreeMap<>();
Map<String, ModelStats> stats = new TreeMap<>();
for (ModelConfigurations model : response.getEndpoints()) {
String statKey = model.getService() + ":" + model.getTaskType().name();
InferenceFeatureSetUsage.ModelStats stat = stats.computeIfAbsent(
statKey,
key -> new InferenceFeatureSetUsage.ModelStats(model.getService(), model.getTaskType())
);
ModelStats stat = stats.computeIfAbsent(statKey, key -> new ModelStats(model.getService(), model.getTaskType()));
stat.add();
}
InferenceFeatureSetUsage usage = new InferenceFeatureSetUsage(stats.values());
Expand Down