diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h
@@ -15,6 +15,7 @@
 #include "mlir/Interfaces/CastInterfaces.h"
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/CopyOpInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
 
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
--- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
+++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
@@ -14,6 +14,7 @@
 include "mlir/Interfaces/CastInterfaces.td"
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/CopyOpInterface.td"
+include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
 include "mlir/IR/SymbolInterfaces.td"
@@ -535,7 +536,8 @@
 // DimOp
 //===----------------------------------------------------------------------===//
 
-def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
+def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable,
+                                     ShapedDimOpInterface]> {
   let summary = "dimension index operation";
   let description = [{
     The `dim` operation takes a memref and a dimension operand of type `index`.
@@ -577,6 +579,12 @@
   let extraClassDeclaration = [{
     /// Helper function to get the index as a simple integer if it is constant.
     Optional<int64_t> getConstantIndex();
+
+    /// Interface method of ShapedDimOpInterface: Return the source memref.
+    Value getShapedValue() { return getSource(); }
+
+    /// Interface method of ShapedDimOpInterface: Return the dimension.
+    OpFoldResult getDimension() { return getIndex(); }
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
--- a/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
+++ b/mlir/include/mlir/Dialect/Tensor/IR/Tensor.h
@@ -18,6 +18,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/ParallelCombiningOpInterface.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Interfaces/ViewLikeInterface.h"
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -14,6 +14,7 @@
 include "mlir/Interfaces/ControlFlowInterfaces.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/ParallelCombiningOpInterface.td"
+include "mlir/Interfaces/ShapedOpInterfaces.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/Interfaces/TilingInterface.td"
 include "mlir/Interfaces/ViewLikeInterface.td"
@@ -81,7 +82,7 @@
 // DimOp
 //===----------------------------------------------------------------------===//
 
-def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
+def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect, ShapedDimOpInterface]> {
   let summary = "dimension index operation";
   let description = [{
     The `tensor.dim` operation takes a tensor and a dimension operand of type
@@ -122,6 +123,12 @@
   let extraClassDeclaration = [{
     /// Helper function to get the index as a simple integer if it is constant.
     Optional<int64_t> getConstantIndex();
+
+    /// Interface method of ShapedDimOpInterface: Return the source tensor.
+    Value getShapedValue() { return getSource(); }
+
+    /// Interface method of ShapedDimOpInterface: Return the dimension.
+    OpFoldResult getDimension() { return getIndex(); }
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -7,6 +7,7 @@
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(ParallelCombiningOpInterface)
+add_mlir_interface(ShapedOpInterfaces)
 add_mlir_interface(SideEffectInterfaces)
 add_mlir_interface(TilingInterface)
 add_mlir_interface(VectorInterfaces)
diff --git a/mlir/include/mlir/Interfaces/ShapedOpInterfaces.h b/mlir/include/mlir/Interfaces/ShapedOpInterfaces.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ShapedOpInterfaces.h
@@ -0,0 +1,30 @@
+//===- ShapedOpInterfaces.h - Interfaces for Shaped Ops ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces for ops that operate on shaped values.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_SHAPEDOPINTERFACES_H_
+#define MLIR_INTERFACES_SHAPEDOPINTERFACES_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace detail {
+
+/// Verify invariants of ops that implement the ShapedDimOpInterface.
+LogicalResult verifyShapedDimOpInterface(Operation *op);
+
+} // namespace detail
+} // namespace mlir
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ShapedOpInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_SHAPEDOPINTERFACES_H_
diff --git a/mlir/include/mlir/Interfaces/ShapedOpInterfaces.td b/mlir/include/mlir/Interfaces/ShapedOpInterfaces.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ShapedOpInterfaces.td
@@ -0,0 +1,56 @@
+//===-- ShapedOpInterfaces.td - Interfaces for Shaped Ops --*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces for ops that operate on shaped values.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_SHAPEDOPINTERFACES
+#define MLIR_INTERFACES_SHAPEDOPINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+// Ops that return the dimension of a shaped value.
+def ShapedDimOpInterface : OpInterface<"ShapedDimOpInterface"> {
+  let description = [{
+    An interface for ops that return the dimension of a shaped value (such as a
+    tensor or a memref). It provides access to the source shaped value and to
+    the dimension.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the shaped value operand. This is the value that the dimension
+        is taken from.
+      }],
+      /*retTy=*/"::mlir::Value",
+      /*methodName=*/"getShapedValue",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dimension operand. This can be a constant or an SSA value.
+      }],
+      /*retTy=*/"::mlir::OpFoldResult",
+      /*methodName=*/"getDimension",
+      /*args=*/(ins)
+    >
+  ];
+
+  let verify = [{
+    return verifyShapedDimOpInterface($_op);
+  }];
+}
+
+#endif // MLIR_INTERFACES_SHAPEDOPINTERFACES
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -9,13 +9,13 @@
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/BlockAndValueMapping.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
 #include "mlir/Transforms/InliningUtils.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallBitVector.h"
@@ -65,9 +65,9 @@
   // `dim`, which can appear anywhere and be valid, since the defining
   // op won't be top-level anymore after inlining.
   Attribute operandCst;
+  bool isDimLikeOp = isa<ShapedDimOpInterface>(value.getDefiningOp());
   return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
-         value.getDefiningOp<memref::DimOp>() ||
-         value.getDefiningOp<tensor::DimOp>();
+         isDimLikeOp;
 }
 
 /// Checks if all values known to be legal affine dimensions or symbols in `src`
@@ -300,10 +300,8 @@
     return applyOp.isValidDim(region);
   // The dim op is okay if its operand memref/tensor is defined at the top
   // level.
-  if (auto dimOp = dyn_cast<memref::DimOp>(op))
-    return isTopLevelValue(dimOp.getSource());
-  if (auto dimOp = dyn_cast<tensor::DimOp>(op))
-    return isTopLevelValue(dimOp.getSource());
+  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(op))
+    return isTopLevelValue(dimOp.getShapedValue());
   return false;
 }
 
@@ -324,24 +322,23 @@
 }
 
 /// Returns true if the result of the dim op is a valid symbol for `region`.
-template <typename OpTy>
-static bool isDimOpValidSymbol(OpTy dimOp, Region *region) {
+static bool isDimOpValidSymbol(ShapedDimOpInterface dimOp, Region *region) {
   // The dim op is okay if its source is defined at the top level.
-  if (isTopLevelValue(dimOp.getSource()))
+  if (isTopLevelValue(dimOp.getShapedValue()))
     return true;
 
   // Conservatively handle remaining BlockArguments as non-valid symbols.
   // E.g. scf.for iterArgs.
-  if (dimOp.getSource().template isa<BlockArgument>())
+  if (dimOp.getShapedValue().isa<BlockArgument>())
     return false;
 
   // The dim op is also okay if its operand memref is a view/subview whose
   // corresponding size is a valid symbol.
-  Optional<int64_t> index = dimOp.getConstantIndex();
-  assert(index.has_value() &&
-         "expect only `dim` operations with a constant index");
+  Optional<int64_t> index = getConstantIntValue(dimOp.getDimension());
+  if (!index.has_value())
+    return false; // Unknown (non-constant) dimension: be conservative.
   int64_t i = index.value();
-  return TypeSwitch<Operation *, bool>(dimOp.getSource().getDefiningOp())
+  return TypeSwitch<Operation *, bool>(dimOp.getShapedValue().getDefiningOp())
       .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
           [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
       .Default([](Operation *) { return false; });
@@ -414,9 +411,7 @@
     return applyOp.isValidSymbol(region);
 
   // Dim op results could be valid symbols at any level.
-  if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
-    return isDimOpValidSymbol(dimOp, region);
-  if (auto dimOp = dyn_cast<tensor::DimOp>(defOp))
+  if (auto dimOp = dyn_cast<ShapedDimOpInterface>(defOp))
     return isDimOpValidSymbol(dimOp, region);
 
   // Check for values dominating `region`'s parent op.
diff --git a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Affine/IR/CMakeLists.txt
@@ -16,6 +16,6 @@
   MLIRIR
   MLIRLoopLikeInterface
   MLIRMemRefDialect
+  MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
-  MLIRTensorDialect
   )
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -19,6 +19,7 @@
   MLIRDialectUtils
   MLIRInferTypeOpInterface
   MLIRIR
+  MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
   )
diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt
@@ -27,6 +27,7 @@
   MLIRIR
   MLIRInferTypeOpInterface
   MLIRParallelCombiningOpInterface
+  MLIRShapedOpInterfaces
   MLIRSideEffectInterfaces
   MLIRSupport
   MLIRViewLikeInterface
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -9,6 +9,7 @@
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   ParallelCombiningOpInterface.cpp
+  ShapedOpInterfaces.cpp
   SideEffectInterfaces.cpp
   TilingInterface.cpp
   VectorInterfaces.cpp
@@ -40,6 +41,7 @@
 add_mlir_interface_library(InferIntRangeInterface)
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(ParallelCombiningOpInterface)
+add_mlir_interface_library(ShapedOpInterfaces)
 add_mlir_interface_library(SideEffectInterfaces)
 add_mlir_interface_library(TilingInterface)
 add_mlir_interface_library(VectorInterfaces)
diff --git a/mlir/lib/Interfaces/ShapedOpInterfaces.cpp b/mlir/lib/Interfaces/ShapedOpInterfaces.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Interfaces/ShapedOpInterfaces.cpp
@@ -0,0 +1,26 @@
+//===- ShapedOpInterfaces.cpp - Interfaces for Shaped Ops -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Interfaces/ShapedOpInterfaces.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// ShapedDimOpInterface
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::detail::verifyShapedDimOpInterface(Operation *op) {
+  if (op->getNumResults() != 1)
+    return op->emitOpError("expected single op result");
+  if (!op->getResult(0).getType().isIndex())
+    return op->emitOpError("expected index result type");
+  return success();
+}
+
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/ShapedOpInterfaces.cpp.inc"
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -978,6 +978,13 @@
     deps = [":OpBaseTdFiles"],
 )
 
+td_library(
+    name = "ShapedOpInterfacesTdFiles",
+    srcs = ["include/mlir/Interfaces/ShapedOpInterfaces.td"],
+    includes = ["include"],
+    deps = [":OpBaseTdFiles"],
+)
+
 td_library(
     name = "ParallelCombiningOpInterfaceTdFiles",
     srcs = ["include/mlir/Interfaces/ParallelCombiningOpInterface.td"],
@@ -2373,8 +2380,8 @@
         ":IR",
         ":LoopLikeInterface",
         ":MemRefDialect",
+        ":ShapedOpInterfaces",
         ":SideEffectInterfaces",
-        ":TensorDialect",
         "//llvm:Support",
     ],
 )
@@ -2791,6 +2798,18 @@
     ],
 )
 
+cc_library(
+    name = "ShapedOpInterfaces",
+    srcs = ["lib/Interfaces/ShapedOpInterfaces.cpp"],
+    hdrs = ["include/mlir/Interfaces/ShapedOpInterfaces.h"],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":ShapedOpInterfacesIncGen",
+        "//llvm:Support",
+    ],
+)
+
 cc_library(
     name = "ParallelCombiningOpInterface",
     srcs = ["lib/Interfaces/ParallelCombiningOpInterface.cpp"],
@@ -4964,6 +4983,7 @@
         ":InferTypeOpInterfaceTdFiles",
         ":OpBaseTdFiles",
         ":ParallelCombiningOpInterfaceTdFiles",
+        ":ShapedOpInterfacesTdFiles",
         ":SideEffectInterfacesTdFiles",
         ":TilingInterfaceTdFiles",
         ":ViewLikeInterfaceTdFiles",
@@ -5021,6 +5041,7 @@
         ":IR",
         ":InferTypeOpInterface",
         ":ParallelCombiningOpInterface",
+        ":ShapedOpInterfaces",
         ":SideEffectInterfaces",
         ":TensorOpsIncGen",
         ":TilingInterface",
@@ -5289,6 +5310,24 @@
     deps = [":LoopLikeInterfaceTdFiles"],
 )
 
+gentbl_cc_library(
+    name = "ShapedOpInterfacesIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            ["-gen-op-interface-decls"],
+            "include/mlir/Interfaces/ShapedOpInterfaces.h.inc",
+        ),
+        (
+            ["-gen-op-interface-defs"],
+            "include/mlir/Interfaces/ShapedOpInterfaces.cpp.inc",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Interfaces/ShapedOpInterfaces.td",
+    deps = [":ShapedOpInterfacesTdFiles"],
+)
+
 gentbl_cc_library(
     name = "ParallelCombiningOpInterfaceIncGen",
     strip_include_prefix = "include",
@@ -8871,6 +8910,7 @@
         ":ControlFlowInterfacesTdFiles",
         ":CopyOpInterfaceTdFiles",
         ":OpBaseTdFiles",
+        ":ShapedOpInterfacesTdFiles",
         ":SideEffectInterfacesTdFiles",
         ":ViewLikeInterfaceTdFiles",
     ],
@@ -8942,6 +8982,7 @@
         ":InferTypeOpInterface",
         ":MemRefBaseIncGen",
         ":MemRefOpsIncGen",
+        ":ShapedOpInterfaces",
         ":ViewLikeInterface",
         "//llvm:Support",
     ],