diff --git a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
--- a/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
@@ -72,6 +72,19 @@
 SmallVector<AffineExpr, 4> concat(ArrayRef<AffineExpr> a,
                                   ArrayRef<AffineExpr> b);
 
+/// Create one memref::DimOp or tensor::DimOp depending on the type of `val`.
+/// This is a polymorphic convenience function to abstract away the rank and
+/// concrete type of `val`.
+/// Asserts that `val` is a memref or tensor type.
+Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
+
+/// Create one memref::DimOp or tensor::DimOp depending on the type of `val`.
+/// This is a polymorphic convenience function to abstract away the rank and
+/// concrete type of `val`.
+/// Asserts that `val` is a memref or tensor type.
+OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
+                               int64_t dim);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Linalg/Utils/IndexingUtils.h
deleted file mode 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/IndexingUtils.h
+++ /dev/null
@@ -1,47 +0,0 @@
-//===- IndexingUtils.h - Indexing utilities supporting Linalg ---*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#ifndef MLIR_DIALECT_LINALG_UTILS_INDEXINGUTILS_H
-#define MLIR_DIALECT_LINALG_UTILS_INDEXINGUTILS_H
-
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
-#include "llvm/ADT/StringSet.h"
-#include <optional>
-
-namespace mlir {
-namespace linalg {
-
-/// Create one memref::DimOp or tensor::DimOp depending on the type of `val`.
-/// This is a polymorphic convenience function to abstract away the rank and
-/// concrete type of `val`.
-/// Asserts that `val` is a memref or tensor type.
-Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim);
-
-/// Create one memref::DimOp or tensor::DimOp depending on the type of `val`.
-/// This is a polymorphic convenience function to abstract away the rank and
-/// concrete type of `val`.
-/// Asserts that `val` is a memref or tensor type.
-OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
-                               int64_t dim);
-
-/// Build the list of DimOp for the dynamic dimensions of `val`.
-/// Asserts that `val` is a ranked shaped type.
-SmallVector<Value> createDynamicDimensions(OpBuilder &b, Location loc,
-                                           Value val);
-
-/// Build the list of all dimensions for `val`, mixing static attributes and
-/// dynamic values where appropriate.
-/// Asserts that `val` is a ranked shaped type.
-SmallVector<OpFoldResult> getMixedDimensions(OpBuilder &b, Location loc,
-                                             Value val);
-
-} // namespace linalg
-} // namespace mlir
-#endif // MLIR_DIALECT_LINALG_UTILS_INDEXINGUTILS_H
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -10,7 +10,6 @@
 #define MLIR_DIALECT_LINALG_UTILS_UTILS_H
 
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Utils/IndexingUtils.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "llvm/ADT/StringSet.h"
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -14,7 +14,6 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/Linalg/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
@@ -2041,7 +2040,7 @@
   computeOutputShape(OpBuilder &builder, Location loc, Value input,
                      llvm::SmallVectorImpl<Value> &dynamicSizes) {
     // Get [N, H, W]
-    auto dims = linalg::getMixedDimensions(builder, loc, input);
+    auto dims = tensor::getMixedSizes(builder, loc, input);
 
     // Set W = (W / 2) + 1 to account for the half-sized W dimension of the
     // output tensors.
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineExprVisitor.h"
@@ -622,24 +623,6 @@
 // StructuredOpInterface implementation
 //===----------------------------------------------------------------------===//
 
-/// Helper function that creates a memref::DimOp or tensor::DimOp depending on
-/// the type of `source`.
-static Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
-                               int64_t dim) {
-  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
-    return b.createOrFold<memref::DimOp>(loc, source, dim);
-  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
-    return b.createOrFold<tensor::DimOp>(loc, source, dim);
-  llvm_unreachable("Expected MemRefType or TensorType");
-}
-
-static OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value source,
-                                      int64_t dim) {
-  auto shapedType = llvm::cast<ShapedType>(source.getType());
-  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
-    return createOrFoldDimOp(b, loc, source, dim);
-  return b.getIndexAttr(shapedType.getDimSize(dim));
-}
-
 SmallVector<OpFoldResult> LinalgOp::createFlatListOfOperandDims(OpBuilder &b,
                                                                 Location loc) {
   SmallVector<OpFoldResult> res;
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -46,6 +46,27 @@
 using namespace mlir;
 using namespace mlir::linalg;
 
+//===----------------------------------------------------------------------===//
+// Helper functions
+//===----------------------------------------------------------------------===//
+
+Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
+                                int64_t dim) {
+  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
+    return b.createOrFold<memref::DimOp>(loc, source, dim);
+  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
+    return b.createOrFold<tensor::DimOp>(loc, source, dim);
+  llvm_unreachable("Expected MemRefType or TensorType");
+}
+
+OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
+                                       int64_t dim) {
+  auto shapedType = llvm::cast<ShapedType>(source.getType());
+  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
+    return createOrFoldDimOp(b, loc, source, dim);
+  return b.getIndexAttr(shapedType.getDimSize(dim));
+}
+
 //===----------------------------------------------------------------------===//
 // Support for named Linalg ops defined in ods-gen.
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseToLinalg.cpp
@@ -66,13 +66,9 @@
       continue;
 
     // Extract static / dynamic shape mix from the first operand.
-    Value firstOperand = operands.front();
-    auto rankedTensorType = cast<RankedTensorType>(t);
-    auto staticShape = llvm::to_vector<4>(rankedTensorType.getShape());
-    auto dynamicShape = linalg::createDynamicDimensions(b, loc, firstOperand);
-
     res.push_back(b.create<tensor::EmptyOp>(
-        loc, staticShape, rankedTensorType.getElementType(), dynamicShape));
+        loc, tensor::getMixedSizes(b, loc, operands.front()),
+        cast<RankedTensorType>(t).getElementType()));
   }
   return res;
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -330,7 +330,7 @@
   // Strides.
   SmallVector<OpFoldResult> ones(packedRank, rewriter.getIndexAttr(1));
   SmallVector<OpFoldResult> sizes =
-      getMixedDimensions(rewriter, loc, packOp.getDest());
+      tensor::getMixedSizes(rewriter, loc, packOp.getDest());
 
   auto insertSliceOp = rewriter.create<tensor::InsertSliceOp>(
       loc, /*source=*/padOp, /*dest=*/emptyOp,
@@ -395,7 +395,7 @@
   // The inner dimensions stay the same as the destination tensor, but the
   // outer ones are additional 1s.
   SmallVector<OpFoldResult> sizes(packedRank - destShape.size(), one);
-  sizes.append(getMixedDimensions(rewriter, loc, unPackOp.getDest()));
+  sizes.append(tensor::getMixedSizes(rewriter, loc, unPackOp.getDest()));
 
   auto extractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
       loc, destTensorType, unPackOp.getSource(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -1419,7 +1419,7 @@
   auto emptyOp = rewriter.create<tensor::EmptyOp>(loc, reifiedReturnShapes[0],
                                                   padValue.getType());
   SmallVector<OpFoldResult> mixedSourceDims =
-      getMixedDimensions(rewriter, loc, padOp.getSource());
+      tensor::getMixedSizes(rewriter, loc, padOp.getSource());
   Value mask =
       rewriter.create<vector::CreateMaskOp>(loc, maskType, mixedSourceDims);
   auto zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
diff --git a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Utils/CMakeLists.txt
@@ -1,6 +1,5 @@
 add_mlir_dialect_library(MLIRLinalgUtils
   Utils.cpp
-  IndexingUtils.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg
diff --git a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp b/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
deleted file mode 100644
--- a/mlir/lib/Dialect/Linalg/Utils/IndexingUtils.cpp
+++ /dev/null
@@ -1,82 +0,0 @@
-//===- IndexingUtils.cpp - Indexing utilities supporting Linalg -----------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements indexing utilities for the Linalg dialect.
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Linalg/Utils/Utils.h"
-
-#include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
-#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
-#include "mlir/Dialect/Affine/LoopUtils.h"
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/Arith/Utils/Utils.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Utils/Utils.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Utils/StaticValueUtils.h"
-#include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/AffineExprVisitor.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/Matchers.h"
-#include "mlir/IR/OpImplementation.h"
-#include "mlir/Pass/Pass.h"
-#include "llvm/ADT/SetOperations.h"
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
-#include <optional>
-
-#define DEBUG_TYPE "linalg-utils"
-
-namespace mlir {
-namespace linalg {
-Value createOrFoldDimOp(OpBuilder &b, Location loc, Value val, int64_t dim) {
-  if (isa<UnrankedMemRefType, MemRefType>(val.getType()))
-    return b.createOrFold<memref::DimOp>(loc, val, dim);
-  if (isa<UnrankedTensorType, RankedTensorType>(val.getType()))
-    return b.createOrFold<tensor::DimOp>(loc, val, dim);
-  llvm_unreachable("Expected MemRefType or TensorType");
-}
-
-OpFoldResult createFoldedDimOp(OpBuilder &b, Location loc, Value val,
-                               int64_t dim) {
-  auto shapedType = cast<ShapedType>(val.getType());
-  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
-    return createOrFoldDimOp(b, loc, val, dim);
-  return b.getIndexAttr(shapedType.getDimSize(dim));
-}
-
-SmallVector<Value> createDynamicDimensions(OpBuilder &b, Location loc,
-                                           Value val) {
-  auto shapedType = cast<ShapedType>(val.getType());
-  assert(shapedType.hasRank() && "`val` must have a static rank");
-  SmallVector<Value> res;
-  res.reserve(shapedType.getRank());
-  for (const auto &dim : llvm::enumerate(shapedType.getShape())) {
-    if (dim.value() == ShapedType::kDynamic)
-      res.push_back(createOrFoldDimOp(b, loc, val, dim.index()));
-  }
-  return res;
-}
-
-SmallVector<OpFoldResult> getMixedDimensions(OpBuilder &b, Location loc,
-                                             Value val) {
-  auto shapedType = cast<ShapedType>(val.getType());
-  assert(shapedType.hasRank() && "`val` must have a static rank");
-  SmallVector<Value> dynamicDims = createDynamicDimensions(b, loc, val);
-  return getMixedValues(shapedType.getShape(), dynamicDims, b);
-}
-} // namespace linalg
-} // namespace mlir
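
Note (reviewer aid, not part of the patch): the sketch below shows how a downstream call site of the removed helpers migrates. It assumes an in-scope OpBuilder, Location, and ranked-tensor Value; the wrapper names buildSizes and buildDim0 are hypothetical. tensor::getMixedSizes is the existing tensor-dialect helper the call sites above now use, and the two dim helpers remain reachable through mlir/Dialect/Linalg/IR/Linalg.h.

// Migration sketch (hypothetical wrappers, not in the patch).
#include "mlir/Dialect/Linalg/IR/Linalg.h"  // linalg::createOrFoldDimOp, createFoldedDimOp
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // tensor::getMixedSizes

using namespace mlir;

// Mixed static/dynamic sizes of a ranked tensor value.
// Before this patch: linalg::getMixedDimensions(builder, loc, value).
static SmallVector<OpFoldResult> buildSizes(OpBuilder &builder, Location loc,
                                            Value value) {
  return tensor::getMixedSizes(builder, loc, value);
}

// The dim helpers survive the move; they are now declared in Linalg.h and
// implemented in LinalgOps.cpp.
static Value buildDim0(OpBuilder &builder, Location loc, Value value) {
  return linalg::createOrFoldDimOp(builder, loc, value, /*dim=*/0);
}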