diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -15,6 +15,7 @@
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -12,6 +12,7 @@
 include "mlir/Dialect/Tensor/IR/TensorBase.td"
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 
@@ -99,8 +100,7 @@
   }];
 
   let builders = [
-    OpBuilder<(ins "Value":$source, "int64_t":$index)>,
-    OpBuilder<(ins "Value":$source, "Value":$index)>
+    OpBuilder<(ins "Value":$source, "int64_t":$index)>
   ];
 
   let extraClassDeclaration = [{
@@ -432,6 +432,8 @@
 def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
     Tensor_Dialect, "insert_slice",
     [NoSideEffect, AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface,
+     DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
+         ["reifyReturnTypeShapesPerResultDim"]>,
      TypesMatchWith<"expected result type to match dest type",
                     "dest", "result", "$_self">]> {
   let summary = "insert_slice operation";
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -15,6 +15,7 @@
   MLIRCastInterfaces
   MLIRDialectUtils
   MLIRIR
+  MLIRInferTypeOpInterface
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRStandard
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -203,12 +203,6 @@
   build(builder, result, source, indexValue);
 }
 
-void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
-                  Value index) {
-  auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, source, index);
-}
-
 Optional<int64_t> DimOp::getConstantIndex() {
   if (auto constantOp = index().getDefiningOp<ConstantOp>())
     return constantOp.getValue().cast<IntegerAttr>().getInt();
@@ -1048,6 +1042,19 @@
   return OpFoldResult();
 }
 
+/// The result always has the type of `dest` (TypesMatchWith constraint), so
+/// each result dimension is reified as the matching dimension of `dest`.
+LogicalResult InsertSliceOp::reifyReturnTypeShapesPerResultDim(
+    OpBuilder &builder,
+    SmallVectorImpl<SmallVector<Value>> &reifiedReturnShapes) {
+  reifiedReturnShapes.resize(1, SmallVector<Value>(getType().getRank()));
+  for (auto dim : llvm::seq<int64_t>(0, getType().getRank())) {
+    reifiedReturnShapes[0][dim] =
+        builder.createOrFold<tensor::DimOp>(getLoc(), dest(), dim);
+  }
+  return success();
+}
+
 namespace {
 /// Pattern to rewrite a insert_slice op with constant arguments.
 class InsertSliceOpConstantArgumentFolder final
diff --git a/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/resolve-shaped-type-result-dims.mlir
@@ -0,0 +1,27 @@
+// RUN: mlir-opt -resolve-shaped-type-result-dims -split-input-file %s | FileCheck %s
+
+func @insert_slice(
+    %arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>,
+    %arg2 : index, %arg3 : index, %arg4 : index) -> (index, index, index) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c2 = constant 2 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %d2 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %0 = tensor.insert_slice %arg0 into %arg1[%arg2, %arg3, %arg4] [%d0, %d1, %d2] [1, 1, 1] : tensor<?x?x?xf32> into tensor<?x?x?xf32>
+  %1 = tensor.dim %0, %c0 : tensor<?x?x?xf32>
+  %2 = tensor.dim %0, %c1 : tensor<?x?x?xf32>
+  %3 = tensor.dim %0, %c2 : tensor<?x?x?xf32>
+  return %1, %2, %3 : index, index, index
+}
+// CHECK-LABEL: func @insert_slice(
+//  CHECK-SAME:   %[[ARG0:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//  CHECK-SAME:   %[[ARG1:[a-zA-Z0-9_]+]]: tensor<?x?x?xf32>
+//   CHECK-DAG:   %[[C0:.+]] = constant 0 : index
+//   CHECK-DAG:   %[[C1:.+]] = constant 1 : index
+//   CHECK-DAG:   %[[C2:.+]] = constant 2 : index
+//   CHECK-DAG:   %[[D0:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+//   CHECK-DAG:   %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C1]]
+//   CHECK-DAG:   %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
+//       CHECK:   return %[[D0]], %[[D1]], %[[D2]]
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3761,6 +3761,7 @@
     deps = [
         ":CastInterfacesTdFiles",
         ":ControlFlowInterfacesTdFiles",
+        ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
         ":ViewLikeInterfaceTdFiles",
@@ -3814,6 +3815,7 @@
         ":ControlFlowInterfaces",
         ":DialectUtils",
         ":IR",
+        ":InferTypeOpInterface",
         ":SideEffectInterfaces",
         ":StandardOps",
         ":Support",