diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -45,12 +45,6 @@
 /// Check if iterator type has "reduction" semantics.
 bool isReductionIterator(utils::IteratorType iteratorType);
 
-/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
-/// the type of `source`.
-Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
-OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
-                               int64_t dim);
-
 /// Given an operation, retrieves the value of each dynamic dimension through
 /// constructing the necessary DimOp operators.
 SmallVector<Value> getDynOperands(Location loc, Value val, OpBuilder &b);
diff --git a/mlir/include/mlir/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.h b/mlir/include/mlir/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- ReifyShapeDimTypeInterfaceImpl.h - Interface Implementation --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_MEMREF_IR_REIFYSHAPEDIMTYPEINTERFACEIMPL_H
+#define MLIR_DIALECT_MEMREF_IR_REIFYSHAPEDIMTYPEINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace memref {
+void registerReifyShapeDimTypeInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace memref
+} // namespace mlir
+
+#endif // MLIR_DIALECT_MEMREF_IR_REIFYSHAPEDIMTYPEINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.h b/mlir/include/mlir/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.h
@@ -0,0 +1,21 @@
+//===- ReifyShapeDimTypeInterfaceImpl.h - Interface Implementation --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_TENSOR_IR_REIFYSHAPEDIMTYPEINTERFACEIMPL_H
+#define MLIR_DIALECT_TENSOR_IR_REIFYSHAPEDIMTYPEINTERFACEIMPL_H
+
+namespace mlir {
+class DialectRegistry;
+
+namespace tensor {
+void registerReifyShapeDimTypeInterfaceExternalModels(
+    DialectRegistry &registry);
+} // namespace tensor
+} // namespace mlir
+
+#endif // MLIR_DIALECT_TENSOR_IR_REIFYSHAPEDIMTYPEINTERFACEIMPL_H
diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
--- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
+++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h
@@ -30,12 +30,6 @@
 class VectorType;
 class VectorTransferOpInterface;
 
-namespace vector {
-/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
-/// the type of `source`.
-Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim);
-} // namespace vector
-
 /// Constructs a permutation map of invariant memref indices to vector
 /// dimension.
 ///
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -44,6 +44,7 @@
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.h"
 #include "mlir/Dialect/MemRef/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/MemRef/Transforms/RuntimeOpVerification.h"
@@ -61,6 +62,7 @@
 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
+#include "mlir/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
@@ -133,12 +135,14 @@
   linalg::registerBufferizableOpInterfaceExternalModels(registry);
   linalg::registerTilingInterfaceExternalModels(registry);
   memref::registerBufferizableOpInterfaceExternalModels(registry);
+  memref::registerReifyShapeDimTypeInterfaceExternalModels(registry);
   memref::registerRuntimeVerifiableOpInterfaceExternalModels(registry);
   scf::registerBufferizableOpInterfaceExternalModels(registry);
   shape::registerBufferizableOpInterfaceExternalModels(registry);
   sparse_tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerBufferizableOpInterfaceExternalModels(registry);
   tensor::registerInferTypeOpInterfaceExternalModels(registry);
+  tensor::registerReifyShapeDimTypeInterfaceExternalModels(registry);
   tensor::registerTilingInterfaceExternalModels(registry);
   vector::registerBufferizableOpInterfaceExternalModels(registry);
 }
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -10,6 +10,7 @@
 add_mlir_interface(ParallelCombiningOpInterface)
 add_mlir_interface(RuntimeVerifiableOpInterface)
 add_mlir_interface(ShapedOpInterfaces)
+add_mlir_interface(ShapedTypeInterfaces)
 add_mlir_interface(SideEffectInterfaces)
 add_mlir_interface(TilingInterface)
 add_mlir_interface(VectorInterfaces)
diff --git a/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.h b/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.h
@@ -0,0 +1,37 @@
+//===- ShapedTypeInterfaces.h - Interfaces for Shaped Types -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces for shaped types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_SHAPEDTYPEINTERFACES_H_
+#define MLIR_INTERFACES_SHAPEDTYPEINTERFACES_H_
+
+#include "mlir/IR/OpDefinition.h"
+
+/// Include the generated interface declarations.
+#include "mlir/Interfaces/ShapedTypeInterfaces.h.inc"
+
+namespace mlir {
+
+/// Reify the specified dimension of the given shaped value. This function
+/// returns an IntegerAttr if `dim` is a constant and the respective dimension
+/// is static. Otherwise, it returns a Value.
+OpFoldResult reifyShapeDim(OpBuilder &builder, Location loc, Value shapedValue,
+                           OpFoldResult dim);
+
+/// Reify the specified dimension of the given shaped value. This function
+/// returns an IntegerAttr if the dimension size is static. Otherwise, it
+/// returns a Value.
+OpFoldResult reifyShapeDim(OpBuilder &builder, Location loc, Value shapedValue,
+                           int64_t dim);
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_SHAPEDTYPEINTERFACES_H_
diff --git a/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.td b/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/ShapedTypeInterfaces.td
@@ -0,0 +1,46 @@
+//===-- ShapedTypeInterfaces.td - Interf. for Shaped Types -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a set of interfaces for shaped types.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_SHAPEDTYPEINTERFACES
+#define MLIR_INTERFACES_SHAPEDTYPEINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+//===----------------------------------------------------------------------===//
+// ReifyShapeDimTypeInterface
+//===----------------------------------------------------------------------===//
+
+def ReifyShapeDimTypeInterface : TypeInterface<"ReifyShapeDimTypeInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    This is an interface for shaped types to reify dimension sizes. A reified
+    dynamic dimension is typically the result of a `DimOp` in the respective
+    dialect.
+  }];
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Reify the specified dimension of the given shaped value. If the
+        dimension is static, an `IntegerAttr` should be returned. Otherwise, a
+        `Value` should be returned.
+      }],
+      /*retTy=*/"::mlir::OpFoldResult",
+      /*methodName=*/"reifyShapeDim",
+      /*args=*/(ins "::mlir::OpBuilder &":$builder,
+                    "::mlir::Location":$loc,
+                    "::mlir::Value":$shapedValue,
+                    "::mlir::OpFoldResult":$dim)
+    >,
+  ];
+}
+
+#endif // MLIR_INTERFACES_SHAPEDTYPEINTERFACES
diff --git a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt
--- a/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToSCF/CMakeLists.txt
@@ -11,6 +11,7 @@
   MLIRArithDialect
   MLIRLLVMDialect
   MLIRMemRefDialect
+  MLIRShapedTypeInterfaces
   MLIRTransforms
   MLIRVectorDialect
   MLIRVectorTransforms
diff --git a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
--- a/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
+++ b/mlir/lib/Conversion/VectorToSCF/VectorToSCF.cpp
@@ -23,6 +23,7 @@
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
+#include "mlir/Interfaces/ShapedTypeInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/Passes.h"
@@ -172,8 +173,8 @@
   Location loc = xferOp.getLoc();
   ImplicitLocOpBuilder lb(xferOp.getLoc(), b);
   if (!xferOp.isDimInBounds(0) && !isBroadcast) {
-    Value memrefDim =
-        vector::createOrFoldDimOp(b, loc, xferOp.getSource(), *dim);
+    Value memrefDim = getValueOrCreateConstantIndexOp(
+        b, loc, reifyShapeDim(b, loc, xferOp.getSource(), *dim));
     AffineExpr d0, d1;
     bindDims(xferOp.getContext(), d0, d1);
     Value base = xferOp.getIndices()[*dim];
diff --git a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/IR/CMakeLists.txt
@@ -23,6 +23,7 @@
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRParser
+  MLIRShapedTypeInterfaces
   MLIRSideEffectInterfaces
   MLIRSparseTensorDialect
   MLIRSCFDialect
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -17,6 +17,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ShapedTypeInterfaces.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
@@ -502,30 +503,12 @@
 // StructuredOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
-/// the type of `source`.
-static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
-                               int64_t dim) {
-  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
-    return b.createOrFold<memref::DimOp>(loc, source, dim);
-  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
-    return b.createOrFold<tensor::DimOp>(loc, source, dim);
-  llvm_unreachable("Expected MemRefType or TensorType");
-}
-static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
-                                      int64_t dim) {
-  auto shapedType = source.getType().cast<ShapedType>();
-  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
-    return createOrFoldDimOp(b, loc, source, dim);
-  return b.getIndexAttr(shapedType.getDimSize(dim));
-}
-
 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                 Location loc) {
   SmallVector<OpFoldResult> res;
   for (OpOperand &opOperand : getOperation()->getOpOperands()) {
     for (int64_t i = 0, e = getRank(&opOperand); i < e; ++i)
-      res.push_back(createFoldedDimOp(b, loc, opOperand.get(), i));
+      res.push_back(reifyShapeDim(b, loc, opOperand.get(), i));
   }
   return res;
 }
@@ -652,7 +635,7 @@
         // Dynamic dim: Return Value.
         OpFoldResult ofr =
             checkDimExpr.visit(shapeExprs[pos])
-                ? createOrFoldDimOp(b, loc, opOperand->get(), dim)
+                ? reifyShapeDim(b, loc, opOperand->get(), dim)
                 : allResultDimValues[pos];
         shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
       }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -55,14 +55,15 @@
   MLIRFuncTransforms
   MLIRInferTypeOpInterface
   MLIRIR
-  MLIRMemRefDialect
-  MLIRMemRefTransforms
   MLIRLinalgDialect
   MLIRLinalgUtils
+  MLIRMemRefDialect
+  MLIRMemRefTransforms
+  MLIRPass
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRSCFUtils
-  MLIRPass
+  MLIRShapedTypeInterfaces
   MLIRSparseTensorDialect
   MLIRTensorDialect
   MLIRTensorTilingInterfaceImpl
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -21,6 +21,7 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
+#include "mlir/Interfaces/ShapedTypeInterfaces.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "mlir/Transforms/RegionUtils.h"
@@ -119,7 +120,7 @@
   for (unsigned i = 0, e = producer.getNumLoops(); i < e; ++i) {
     auto shapeDim = getShapeDefiningLoopRange(producer, i);
     OpFoldResult dim =
-        createFoldedDimOp(b, loc, shapeDim.shape, shapeDim.dimension);
+        reifyShapeDim(b, loc, shapeDim.shape, shapeDim.dimension);
     sizeBounds.push_back(dim);
     auto it = fusedLoopsAndRanges.find(i);
     if (it != fusedLoopsAndRanges.end()) {
diff --git a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
@@ -13,6 +13,7 @@
   MLIRIR
   MLIRLinalgDialect
   MLIRSCFDialect
+  MLIRShapedTypeInterfaces
   MLIRPass
   MLIRTensorUtils
   MLIRTransformUtils
diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
--- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp
@@ -32,6 +32,7 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/ShapedTypeInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
@@ -184,24 +185,6 @@
   return iteratorType == utils::IteratorType::reduction;
 }
 
-/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
-/// the type of `source`.
-Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim) {
-  if (source.getType().isa<UnrankedMemRefType, MemRefType>())
-    return b.createOrFold<memref::DimOp>(loc, source, dim);
-  if (source.getType().isa<UnrankedTensorType, RankedTensorType>())
-    return b.createOrFold<tensor::DimOp>(loc, source, dim);
-  llvm_unreachable("Expected MemRefType or TensorType");
-}
-
-OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
-                               int64_t dim) {
-  auto shapedType = source.getType().cast<ShapedType>();
-  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
-    return createOrFoldDimOp(b, loc, source, dim);
-  return b.getIndexAttr(shapedType.getDimSize(dim));
-}
-
 /// Given an operation, retrieves the value of each dynamic dimension through
 /// constructing the necessary DimOp operators.
 SmallVector<Value> getDynOperands(Location loc, Value val, OpBuilder &b) {
@@ -209,7 +192,8 @@
   auto shapedType = val.getType().cast<ShapedType>();
   for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
     if (dim.value() == ShapedType::kDynamic)
-      dynOperands.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
+      dynOperands.push_back(getValueOrCreateConstantIndexOp(
+          b, loc, reifyShapeDim(b, loc, val, dim.index())));
   }
   return dynOperands;
 }
@@ -763,7 +747,7 @@
     LLVM_DEBUG(llvm::dbgs() << "computeSliceParameters: for dim#" << r);
     if (!isTiled(map.getSubMap({r}), tileSizes)) {
       sliceParams.offsets.push_back(builder.getIndexAttr(0));
-      OpFoldResult dim = createFoldedDimOp(builder, loc, valueToTile, r);
+      OpFoldResult dim = reifyShapeDim(builder, loc, valueToTile, r);
       sliceParams.sizes.push_back(dim);
       sliceParams.strides.push_back(builder.getIndexAttr(1));
       LLVM_DEBUG(llvm::dbgs() << ": not tiled: use size: " << dim << "\n");
diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_mlir_dialect_library(MLIRMemRefDialect
   MemRefDialect.cpp
   MemRefOps.cpp
+  ReifyShapeDimTypeInterfaceImpl.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/inlude/mlir/Dialect/MemRefDialect
@@ -20,6 +21,7 @@
   MLIRInferTypeOpInterface
   MLIRIR
   MLIRShapedOpInterfaces
+  MLIRShapedTypeInterfaces
   MLIRSideEffectInterfaces
   MLIRViewLikeInterface
   )
diff --git a/mlir/lib/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.cpp b/mlir/lib/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.cpp
@@ -0,0 +1,53 @@
+//===- ReifyShapeDimTypeInterfaceImpl.cpp - Interface Implementation ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.h"
+
+#include "mlir/Dialect/Arith/Utils/Utils.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Interfaces/ShapedTypeInterfaces.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+namespace mlir {
+namespace memref {
+namespace {
+
+template <typename T>
+struct MemRefTypeInterface
+    : public ReifyShapeDimTypeInterface::ExternalModel<MemRefTypeInterface<T>,
+                                                       T> {
+  OpFoldResult reifyShapeDim(Type type, OpBuilder &builder, Location loc,
+                             Value shapedValue, OpFoldResult dim) const {
+    BaseMemRefType memrefType = type.cast<BaseMemRefType>();
+
+    // Static dimension size: Return OpFoldResult.
+    auto constDim = getConstantIntValue(dim);
+    if (constDim.has_value() && memrefType.hasRank() &&
+        !memrefType.isDynamicDim(*constDim))
+      return builder.getIndexAttr(memrefType.getDimSize(*constDim));
+
+    // Dynamic dimension size: Create memref.dim op.
+    return builder.createOrFold<memref::DimOp>(
+        loc, shapedValue, getValueOrCreateConstantIndexOp(builder, loc, dim));
+  }
+};
+
+} // namespace
+} // namespace memref
+} // namespace mlir
+
+void mlir::memref::registerReifyShapeDimTypeInterfaceExternalModels(
+    DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, memref::MemRefDialect *dialect) {
+    MemRefType::attachInterface<MemRefTypeInterface<MemRefType>>(*ctx);
+    UnrankedMemRefType::attachInterface<
+        MemRefTypeInterface<UnrankedMemRefType>>(*ctx);
+  });
+}
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CMakeLists.txt
@@ -35,6 +35,7 @@
   MLIRSCFDialect
   MLIRSCFTransforms
   MLIRSCFUtils
+  MLIRShapedTypeInterfaces
   MLIRSparseTensorDialect
   MLIRSparseTensorEnums
   MLIRSparseTensorUtils
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/CodegenUtils.cpp
@@ -18,6 +18,7 @@
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/Types.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Interfaces/ShapedTypeInterfaces.h"
 #include <optional>
 
 using namespace mlir;
@@ -484,7 +485,8 @@
     const Value srcCoordinates = splitSrc->first;
     const Value srcValues = splitSrc->second;
     lo.push_back(zero);
-    hi.push_back(linalg::createOrFoldDimOp(builder, loc, srcValues, 0));
+    hi.push_back(getValueOrCreateConstantIndexOp(
+        builder, loc, reifyShapeDim(builder, loc, srcValues, 0)));
     st.push_back(one);
     scf::buildLoopNest(builder, loc, lo, hi, st, {},
                        [&](OpBuilder &builder, Location loc, ValueRange ivs,
@@ -498,7 +500,8 @@
   } else {
     for (unsigned i = 0; i < rank; i++) {
       lo.push_back(zero);
-      hi.push_back(linalg::createOrFoldDimOp(builder, loc, src, i));
+      hi.push_back(getValueOrCreateConstantIndexOp(
+          builder, loc, reifyShapeDim(builder, loc, src, i)));
       st.push_back(one);
     }
     scf::buildLoopNest(builder, loc, lo, hi, st, {},
@@ -517,7 +520,8 @@
                                        Location loc, Value src) {
   const Dimension dimRank = getSparseTensorType(src).getDimRank();
   for (Dimension d = 0; d < dimRank; d++)
-    sizes.push_back(linalg::createOrFoldDimOp(builder, loc, src, d));
+    sizes.push_back(getValueOrCreateConstantIndexOp(
+        builder, loc, reifyShapeDim(builder, loc, src, d)));
 }
 
 Operation *mlir::sparse_tensor::getTop(Operation *op) {
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
--- a/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/LoopEmitter.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/ShapedTypeInterfaces.h"
 
 using namespace mlir;
 using namespace mlir::sparse_tensor;
@@ -320,8 +321,8 @@
       // FIXME: `toOrigDim` is deprecated
       // Since we do not have HigherOrdering now, we can always rely on the 1:1
       // mapping from level to dimension to retrieve the level size.
-      Value lvlSz = mlir::linalg::createOrFoldDimOp(builder, loc, tensor,
-                                                    toOrigDim(enc, l));
+      Value lvlSz = getValueOrCreateConstantIndexOp(
+          builder, loc, reifyShapeDim(builder, loc, tensor, toOrigDim(enc, l)));
       // Find upper bound in current dimension.
highs[t][l] = lvlSizes[t][l] = lvlSz; if (isSparseSlices[t]) { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp @@ -30,6 +30,7 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Transforms/DialectConversion.h" #include @@ -579,7 +580,8 @@ auto retTp = MemRefType::get(ArrayRef{len}, memTp.getElementType()); Value targetLen = constantIndex(builder, loc, len); - Value bufferLen = linalg::createOrFoldDimOp(builder, loc, buffer, 0); + Value bufferLen = getValueOrCreateConstantIndexOp( + builder, loc, reifyShapeDim(builder, loc, buffer, 0)); // Reallocates if target length is greater than the actual buffer len. Value reallocP = builder.create(loc, arith::CmpIPredicate::ugt, targetLen, bufferLen); @@ -1086,7 +1088,8 @@ // TODO: We can instead use the actual memSize in specifier, that // would require a subViewOp to avoid overflow when copying // values. - Value sz = linalg::createOrFoldDimOp(rewriter, loc, srcMem, 0); + Value sz = getValueOrCreateConstantIndexOp( + rewriter, loc, reifyShapeDim(rewriter, loc, srcMem, 0)); auto dstMem = rewriter.create( loc, fTp.cast(), sz); if (fTp != srcMem.getType()) { @@ -1274,7 +1277,8 @@ }); MutSparseTensorDescriptor desc(rtp, fields); - auto noe = linalg::createOrFoldDimOp(rewriter, loc, op.getValues(), 0); + auto noe = getValueOrCreateConstantIndexOp( + rewriter, loc, reifyShapeDim(rewriter, loc, op.getValues(), 0)); // FIXME: should use `SparseTensorType::getLvlRank` in lieu of // `RankedTensorType::getRank`, because the latter introduces dim/lvl // ambiguity. 
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -28,6 +28,7 @@ #include "mlir/Dialect/SparseTensor/IR/SparseTensorType.h" #include "mlir/Dialect/SparseTensor/Transforms/Passes.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -108,8 +109,7 @@ /// Looks up a dimension-size by returning a constant from the shape /// (for static sizes), or by calling `genDimSizeCall` (for dynamic sizes -/// of sparse tensors) or `linalg::createOrFoldDimOp` (for dynamic sizes -/// of dense tensors). +/// of sparse tensors) or `reifyShapeDim` (for dynamic sizes of dense tensors). static Value createOrFoldDimCall(OpBuilder &builder, Location loc, SparseTensorType stt, Value tensor, Dimension dim) { @@ -117,7 +117,8 @@ return constantIndex(builder, loc, *sz); if (stt.hasEncoding()) return genDimSizeCall(builder, loc, tensor, dim); - return linalg::createOrFoldDimOp(builder, loc, tensor, dim); + return getValueOrCreateConstantIndexOp( + builder, loc, reifyShapeDim(builder, loc, tensor, dim)); } /// Populates the array with the dimension-sizes of the given tensor. @@ -614,7 +615,8 @@ // Fill out loop iteration information. 
for (Dimension d = 0; d < dimRank; d++) { lo.push_back(zero); - hi.push_back(linalg::createOrFoldDimOp(rewriter, loc, t, d)); + hi.push_back(getValueOrCreateConstantIndexOp( + rewriter, loc, reifyShapeDim(rewriter, loc, t, d))); st.push_back(one); } diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorRewriting.cpp @@ -25,6 +25,7 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Support/LLVM.h" using namespace mlir; @@ -195,7 +196,8 @@ } else { // Else, compute the shape dynamically. for (const auto &src : srcs.drop_front()) { - Value srcSz = linalg::createOrFoldDimOp(builder, loc, src, dim); + Value srcSz = getValueOrCreateConstantIndexOp( + builder, loc, reifyShapeDim(builder, loc, src, dim)); // Sum up all the sizes. 
sizes[dim] = builder.create(loc, sizes[dim], srcSz); } diff --git a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/IR/CMakeLists.txt @@ -6,6 +6,7 @@ ) add_mlir_dialect_library(MLIRTensorDialect + ReifyShapeDimTypeInterfaceImpl.cpp TensorDialect.cpp TensorOps.cpp @@ -30,6 +31,7 @@ MLIRInferTypeOpInterface MLIRParallelCombiningOpInterface MLIRShapedOpInterfaces + MLIRShapedTypeInterfaces MLIRSideEffectInterfaces MLIRSupport MLIRViewLikeInterface diff --git a/mlir/lib/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.cpp @@ -0,0 +1,54 @@ +//===- ReifyShapeDimTypeInterfaceImpl.cpp - Interface Implementation ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.h" + +#include "mlir/Dialect/Arith/Utils/Utils.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" + +using namespace mlir; +using namespace mlir::tensor; + +namespace mlir { +namespace tensor { +namespace { + +template +struct TensorTypeInterface + : public ReifyShapeDimTypeInterface::ExternalModel, + T> { + OpFoldResult reifyShapeDim(Type type, OpBuilder &builder, Location loc, + Value shapedValue, OpFoldResult dim) const { + TensorType tensorType = type.cast(); + + // Static dimension size: Return OpFoldResult. 
+ auto constDim = getConstantIntValue(dim); + if (constDim.has_value() && tensorType.hasRank() && + !tensorType.isDynamicDim(*constDim)) + return builder.getIndexAttr(tensorType.getDimSize(*constDim)); + + // Dynamic dimension size: Create tensor.dim op. + return builder.createOrFold( + loc, shapedValue, getValueOrCreateConstantIndexOp(builder, loc, dim)); + } +}; + +} // namespace +} // namespace tensor +} // namespace mlir + +void mlir::tensor::registerReifyShapeDimTypeInterfaceExternalModels( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *ctx, tensor::TensorDialect *dialect) { + RankedTensorType::attachInterface>( + *ctx); + UnrankedTensorType::attachInterface< + TensorTypeInterface>(*ctx); + }); +} diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt @@ -26,6 +26,7 @@ MLIRMemRefDialect MLIRPass MLIRSCFDialect + MLIRShapedTypeInterfaces MLIRTensorDialect MLIRTilingInterface MLIRTransforms diff --git a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/ExtractSliceFromReshapeUtils.cpp @@ -10,6 +10,7 @@ // aggregated slices of the reshape source. 
// //===----------------------------------------------------------------------===// + #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" @@ -20,29 +21,19 @@ #include "mlir/Dialect/Utils/StaticValueUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; using namespace mlir::tensor; -/// Get the dimension size of a value of RankedTensor type at the -static OpFoldResult getShapeDimSize(OpBuilder &b, Location loc, - Value rankedTensor, int64_t dimIdx) { - RankedTensorType tensorType = rankedTensor.getType().cast(); - if (!tensorType.isDynamicDim(dimIdx)) { - return b.getIndexAttr(tensorType.getDimSize(dimIdx)); - } - Value idxValue = b.create(loc, dimIdx); - return b.createOrFold(loc, rankedTensor, idxValue); -} - /// Get all the dimension sizes of a value of RankedTensor type. static SmallVector getShapeDimSizes(OpBuilder &b, Location loc, Value rankedTensor) { SmallVector dimSizes; RankedTensorType tensorType = rankedTensor.getType().cast(); for (unsigned i = 0; i < tensorType.getRank(); i++) - dimSizes.push_back(getShapeDimSize(b, loc, rankedTensor, i)); + dimSizes.push_back(reifyShapeDim(b, loc, rankedTensor, i)); return dimSizes; } diff --git a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt @@ -31,6 +31,7 @@ MLIRLinalgDialect MLIRMemRefDialect MLIRSCFDialect + MLIRShapedTypeInterfaces MLIRSideEffectInterfaces MLIRTensorDialect MLIRTransforms diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp +++ 
b/mlir/lib/Dialect/Vector/Transforms/VectorTransferSplitRewritePatterns.cpp @@ -20,10 +20,10 @@ #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" - #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "llvm/ADT/DenseSet.h" @@ -79,7 +79,9 @@ Value sum = makeComposedAffineApply(b, loc, d0 + vs, xferOp.indices()[indicesIdx]); Value cond = createFoldedSLE( - b, sum, vector::createOrFoldDimOp(b, loc, xferOp.source(), indicesIdx)); + b, sum, + getValueOrCreateConstantIndexOp( + b, loc, reifyShapeDim(b, loc, xferOp.source(), indicesIdx))); if (!cond) return; // Conjunction over all dims for which we are in-bounds. @@ -201,8 +203,9 @@ auto isaWrite = isa(xferOp); xferOp.zipResultAndIndexing([&](int64_t resultIdx, int64_t indicesIdx) { using MapList = ArrayRef>; - Value dimMemRef = vector::createOrFoldDimOp(b, xferOp.getLoc(), - xferOp.source(), indicesIdx); + Value dimMemRef = getValueOrCreateConstantIndexOp( + b, xferOp.getLoc(), + reifyShapeDim(b, xferOp.getLoc(), xferOp.source(), indicesIdx)); Value dimAlloc = b.create(loc, alloc, resultIdx); Value index = xferOp.indices()[indicesIdx]; AffineExpr i, j, k; diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -32,6 +32,7 @@ #include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/ShapedTypeInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Support/LogicalResult.h" @@ -2706,8 +2707,9 @@ // dimensions here. 
unsigned lastIndex = llvm::size(xferOp.getIndices()) - 1; Value off = xferOp.getIndices()[lastIndex]; - Value dim = - vector::createOrFoldDimOp(rewriter, loc, xferOp.getSource(), lastIndex); + Value dim = getValueOrCreateConstantIndexOp( + rewriter, loc, + reifyShapeDim(rewriter, loc, xferOp.getSource(), lastIndex)); Value b = rewriter.create(loc, dim.getType(), dim, off); Value mask = rewriter.create( loc, diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -32,17 +32,6 @@ using namespace mlir; -/// Helper function that creates a memref::DimOp or tensor::DimOp depending on -/// the type of `source`. -Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source, - int64_t dim) { - if (source.getType().isa()) - return b.createOrFold(loc, source, dim); - if (source.getType().isa()) - return b.createOrFold(loc, source, dim); - llvm_unreachable("Expected MemRefType or TensorType"); -} - /// Constructs a permutation map from memref indices to vector dimension. 
/// /// The implementation uses the knowledge of the mapping of enclosing loop to diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -12,6 +12,7 @@ ParallelCombiningOpInterface.cpp RuntimeVerifiableOpInterface.cpp ShapedOpInterfaces.cpp + ShapedTypeInterfaces.cpp SideEffectInterfaces.cpp TilingInterface.cpp VectorInterfaces.cpp @@ -47,6 +48,7 @@ add_mlir_interface_library(ParallelCombiningOpInterface) add_mlir_interface_library(RuntimeVerifiableOpInterface) add_mlir_interface_library(ShapedOpInterfaces) +add_mlir_interface_library(ShapedTypeInterfaces) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) diff --git a/mlir/lib/Interfaces/ShapedTypeInterfaces.cpp b/mlir/lib/Interfaces/ShapedTypeInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/ShapedTypeInterfaces.cpp @@ -0,0 +1,27 @@ +//===- ShapedTypeInterfaces.cpp - Interfaces for Shaped Types -------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/ShapedTypeInterfaces.h" + +#include "mlir/IR/Builders.h" + +#include "mlir/Interfaces/ShapedTypeInterfaces.cpp.inc" + +using namespace mlir; + +OpFoldResult mlir::reifyShapeDim(OpBuilder &builder, Location loc, + Value shapedValue, OpFoldResult dim) { + auto shapedTypeInterface = + shapedValue.getType().cast(); + return shapedTypeInterface.reifyShapeDim(builder, loc, shapedValue, dim); +} + +OpFoldResult mlir::reifyShapeDim(OpBuilder &builder, Location loc, + Value shapedValue, int64_t dim) { + return reifyShapeDim(builder, loc, shapedValue, builder.getIndexAttr(dim)); +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -996,6 +996,13 @@ deps = [":OpBaseTdFiles"], ) +td_library( + name = "ShapedTypeInterfacesTdFiles", + srcs = ["include/mlir/Interfaces/ShapedTypeInterfaces.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + td_library( name = "ParallelCombiningOpInterfaceTdFiles", srcs = ["include/mlir/Interfaces/ParallelCombiningOpInterface.td"], @@ -2241,6 +2248,7 @@ ":Pass", ":SCFDialect", ":SCFTransforms", + ":ShapedTypeInterfaces", ":SparseTensorDialect", ":SparseTensorEnums", ":SparseTensorPassIncGen", @@ -2995,6 +3003,18 @@ ], ) +cc_library( + name = "ShapedTypeInterfaces", + srcs = ["lib/Interfaces/ShapedTypeInterfaces.cpp"], + hdrs = ["include/mlir/Interfaces/ShapedTypeInterfaces.h"], + includes = ["include"], + deps = [ + ":IR", + ":ShapedTypeInterfacesIncGen", + "//llvm:Support", + ], +) + cc_library( name = "ParallelCombiningOpInterface", srcs = ["lib/Interfaces/ParallelCombiningOpInterface.cpp"], @@ -3504,6 +3524,7 @@ ":MemRefDialect", ":Pass", ":SCFDialect", + ":ShapedTypeInterfaces", 
":SideEffectInterfaces", ":Support", ":TensorDialect", @@ -5452,10 +5473,14 @@ name = "TensorDialect", srcs = [ "include/mlir/Transforms/InliningUtils.h", + "lib/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.cpp", "lib/Dialect/Tensor/IR/TensorDialect.cpp", "lib/Dialect/Tensor/IR/TensorOps.cpp", ], - hdrs = ["include/mlir/Dialect/Tensor/IR/Tensor.h"], + hdrs = [ + "include/mlir/Dialect/Tensor/IR/ReifyShapeDimTypeInterfaceImpl.h", + "include/mlir/Dialect/Tensor/IR/Tensor.h", + ], includes = ["include"], deps = [ ":AffineDialect", @@ -5470,6 +5495,7 @@ ":InferTypeOpInterface", ":ParallelCombiningOpInterface", ":ShapedOpInterfaces", + ":ShapedTypeInterfaces", ":SideEffectInterfaces", ":Support", ":TensorOpsIncGen", @@ -5576,6 +5602,7 @@ ":MemRefDialect", ":Pass", ":SCFDialect", + ":ShapedTypeInterfaces", ":TensorDialect", ":TensorPassIncGen", ":TilingInterface", @@ -5794,6 +5821,24 @@ deps = [":ShapedOpInterfacesTdFiles"], ) +gentbl_cc_library( + name = "ShapedTypeInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-type-interface-decls"], + "include/mlir/Interfaces/ShapedTypeInterfaces.h.inc", + ), + ( + ["-gen-type-interface-defs"], + "include/mlir/Interfaces/ShapedTypeInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Interfaces/ShapedTypeInterfaces.td", + deps = [":ShapedTypeInterfacesTdFiles"], +) + gentbl_cc_library( name = "ParallelCombiningOpInterfaceIncGen", strip_include_prefix = "include", @@ -8362,6 +8407,7 @@ ":MemRefDialect", ":Parser", ":SCFDialect", + ":ShapedTypeInterfaces", ":SideEffectInterfaces", ":SparseTensorDialect", ":Support", @@ -8463,6 +8509,7 @@ ":MemRefDialect", ":Pass", ":SCFDialect", + ":ShapedTypeInterfaces", ":TensorDialect", ":TensorUtils", "//llvm:Support", @@ -8512,6 +8559,7 @@ ":SCFDialect", ":SCFTransforms", ":SCFUtils", + ":ShapedTypeInterfaces", ":SparseTensorDialect", ":Support", ":TensorDialect", @@ -8827,6 +8875,7 @@ ":MemRefDialect", ":Pass", ":SCFDialect", + 
":ShapedTypeInterfaces", ":Support", ":TensorDialect", ":Transforms", @@ -9893,6 +9942,7 @@ ), hdrs = [ "include/mlir/Dialect/MemRef/IR/MemRef.h", + "include/mlir/Dialect/MemRef/IR/ReifyShapeDimTypeInterfaceImpl.h", "include/mlir/Dialect/MemRef/Utils/MemRefUtils.h", ], includes = ["include"], @@ -9907,6 +9957,7 @@ ":MemRefBaseIncGen", ":MemRefOpsIncGen", ":ShapedOpInterfaces", + ":ShapedTypeInterfaces", ":ViewLikeInterface", "//llvm:Support", "//llvm:TargetParser",