Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -2000,6 +2000,7 @@ cc_library(
deps = [
":attr_type_builder_util",
":mlir_builder",
":stablehlo_broadcast_lowering",
":stablehlo_builder_inc",
":stablehlo_ops",
":stablehlo_type_inference",
Expand Down
4 changes: 2 additions & 2 deletions WORKSPACE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ workspace(name = "stablehlo")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

LLVM_COMMIT = "125948a058dcd89b7fe377872a5fc1a7f9d34e70"
LLVM_COMMIT = "f6d0a512972a74ef100723b9526a6a0ddb23f894"

LLVM_SHA256 = "8f79c221169dec52116dea2202be64aa653584f98b03a715c30d74ea9141328b"
LLVM_SHA256 = "75dba7f15864c9ddc25dd621dcaf2d325a9ca8f23957ff4eb6b01df5b493b5d5"

http_archive(
name = "llvm-raw",
Expand Down
2 changes: 1 addition & 1 deletion build_tools/llvm_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
125948a058dcd89b7fe377872a5fc1a7f9d34e70
f6d0a512972a74ef100723b9526a6a0ddb23f894
39 changes: 39 additions & 0 deletions stablehlo/integrations/cpp/builder/StablehloBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <cstdint>
#include <optional>

#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/ErrorHandling.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
Expand All @@ -30,6 +31,7 @@ limitations under the License.
#include "stablehlo/dialect/TypeInference.h"
#include "stablehlo/integrations/cpp/builder/AttrTypeBuilderUtil.h"
#include "stablehlo/integrations/cpp/builder/MlirBuilder.h"
#include "stablehlo/transforms/StablehloBroadcastLowering.h"

namespace mlir {
namespace stablehlo {
Expand Down Expand Up @@ -94,6 +96,43 @@ MlirOp Constant(MlirBuilder& builder, std::vector<int64_t> value) {
value));
}


MlirOp IotaLike(MlirOp input, int64_t dim, Type elementType) {
auto inputType = mlir::cast<RankedTensorType>(input.getType());
if (inputType.hasStaticShape()) {
return stablehlo::Iota(input.getBuilder(), inputType.clone(elementType),
dim);
}

// Use input's static shape and slice to the dynamic shape.
auto dims = mlir::stablehlo::getDimensions(input.getValue());
if (mlir::failed(dims)) llvm::report_fatal_error(
"failed to create dynamically shaped iota op, with MLIR error: ");

mlir::SmallVector<int64_t> iotaShape = llvm::map_to_vector(
*dims,
[&](mlir::stablehlo::DimensionInfo dim_size) { return dim_size.size; });
auto iotaType =
mlir::makeTensorType(input.getContext(), iotaShape, elementType);
mlir::MlirOp iota = mlir::stablehlo::Iota(input.getBuilder(), iotaType, dim);

// Slice bounded dimensions to the dynamic shape.
for (const mlir::stablehlo::DimensionInfo& dim : *dims) {
if (!dim.boundOp.has_value()) continue;

auto runtime_dim_size =
mlir::stablehlo::GetDimensionSize(input, dim.boundOpDim);
iota = mlir::stablehlo::SetDimensionSize(iota, runtime_dim_size,
dim.boundOpDim);
}
return iota;
}

MlirOp IotaLike(MlirOp input, int64_t dim, ElementType elementType) {
auto resultElementType = getElementType(input.getContext(), elementType);
return IotaLike(input, dim, resultElementType);
}

namespace {

// Use preferred element type, if not use LHS element type.
Expand Down
6 changes: 6 additions & 0 deletions stablehlo/integrations/cpp/builder/StablehloBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,12 @@ MlirOp ConvertElementType(MlirOp input, Type resultElementType);
MlirOp Constant(MlirBuilder& builder, int64_t value);
MlirOp Constant(MlirBuilder& builder, std::vector<int64_t> value);

// IotaLike is a sugar API for iota that accounts for bounded dynamism in the
// input tensor. Eventually this should be a chlo.iota_like op with a StableHLO
// decomposition, but for now it will be housed as a builder API.
MlirOp IotaLike(MlirOp input, int64_t dim, ElementType elementType);
MlirOp IotaLike(MlirOp input, int64_t dim, Type elementType);

// Better Dot / DotGeneral builders.
// These ops don't support full type inference because the result element type
// cannot be inferred from operands, however the result shape can be.
Expand Down
50 changes: 50 additions & 0 deletions stablehlo/integrations/cpp/builder/StablehloBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,56 @@ TEST(MlirBuilderTest, DotGeneralOp) {
EXPECT_EQ(expected, debugString(*module));
}

TEST(MlirBuilderTest, IotaLikeStatic) {
std::string expected = R"mlir(module {
func.func @main(%arg0: tensor<2x3xi64>) -> tensor<2x3xi64> {
%0 = stablehlo.iota dim = 1 : tensor<2x3xi64>
return %0 : tensor<2x3xi64>
}
})mlir";
StablehloModuleBuilder mb;
{ // Build Main Func
func::FunctionBuilder fb(mb.get(), "main");
auto& ctx = fb.getContext();
auto type2x3xi64 = makeTensorType(ctx, {2, 3}, ElementType::I64);
auto arg0 = func::Argument(fb, type2x3xi64);
auto iota = stablehlo::IotaLike(arg0, 1, type2x3xi64.getElementType());
func::Return(fb, iota);
}

OwningOpRef<ModuleOp> module = mb->build();
EXPECT_TRUE(succeeded(mlir::verify(*module)));
EXPECT_EQ(expected, debugString(*module));
}

TEST(MlirBuilderTest, IotaLikeDynamic) {
std::string expected = R"mlir(module {
func.func @main(%arg0: tensor<2x3xi64>, %arg1: tensor<i32>) -> tensor<?x3xi64, #stablehlo.bounds<2, ?>> {
%0 = stablehlo.set_dimension_size %arg0, %arg1, dim = 0 : (tensor<2x3xi64>, tensor<i32>) -> tensor<?x3xi64, #stablehlo.bounds<2, ?>>
%1 = stablehlo.iota dim = 1 : tensor<2x3xi64>
%2 = stablehlo.get_dimension_size %0, dim = 0 : (tensor<?x3xi64, #stablehlo.bounds<2, ?>>) -> tensor<i32>
%3 = stablehlo.set_dimension_size %1, %2, dim = 0 : (tensor<2x3xi64>, tensor<i32>) -> tensor<?x3xi64, #stablehlo.bounds<2, ?>>
return %3 : tensor<?x3xi64, #stablehlo.bounds<2, ?>>
}
})mlir";
StablehloModuleBuilder mb;
{ // Build Main Func
func::FunctionBuilder fb(mb.get(), "main");
auto& ctx = fb.getContext();
auto type2x3xi64 = makeTensorType(ctx, {2, 3}, ElementType::I64);
auto typei32 = makeTensorType(ctx, {}, ElementType::I32);
auto arg0 = func::Argument(fb, type2x3xi64);
auto arg1 = func::Argument(fb, typei32);
auto sds = stablehlo::SetDimensionSize(arg0, arg1, 0);
auto iota = stablehlo::IotaLike(sds, 1, type2x3xi64.getElementType());
func::Return(fb, iota);
}

OwningOpRef<ModuleOp> module = mb->build();
EXPECT_TRUE(succeeded(mlir::verify(*module)));
EXPECT_EQ(expected, debugString(*module));
}

TEST(MlirBuilderTest, ReduceOp) {
std::string expected = R"mlir(module {
func.func @main(%arg0: tensor<2xi64>) -> tensor<i64> {
Expand Down
Loading