diff --git a/mlir/include/mlir/Dialect/Linalg/TilingInterface/Tiling.h b/mlir/include/mlir/Dialect/Linalg/TilingInterface/Tiling.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Linalg/TilingInterface/Tiling.h
@@ -0,0 +1,68 @@
+//===- Tiling.h - Tiling transformation using TilingInterface ---*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef DIALECT_LINALG_TILINGINTERFACE_TILING_H_
+#define DIALECT_LINALG_TILINGINTERFACE_TILING_H_
+
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+
+namespace mlir {
+
+/// Result of tiling an operation using the TilingInterface.
+struct TiledOp {
+  /// Tiled op. Null if the operation does not implement the TilingInterface.
+  Operation *op = nullptr;
+  /// Loops generated during tiling, outermost first.
+  SmallVector<Operation *> loops;
+  /// Values that are replacements for the results of the untiled operation,
+  /// i.e. the results of the outermost generated loop. Empty when no loop was
+  /// generated or when the operation has no results (e.g. tiling on buffers).
+  SmallVector<Value> results;
+};
+
+/// Main entry point for tiling an operation that implements the
+/// `TilingInterface`. `dest` provides the values used as outputs of the tiled
+/// operation and `options` provides the tile sizes and, optionally, the loop
+/// distribution. Returns a `TiledOp{}` value if `op` does not implement the
+/// `TilingInterface`, and failure if the operation could not be tiled.
+FailureOr<TiledOp>
+tileOpUsingInterface(OpBuilder &b, Operation *op, ValueRange dest,
+                     const linalg::LinalgTilingOptions &options);
+
+/// Base pattern for tiling operations that implement the `TilingInterface`.
+/// Patterns can inherit from this class and implement the `matchAndRewrite`
+/// method to call into the `matchAndRewriteBase` method of this class. Note
+/// that `matchAndRewriteBase` does not replace or delete the original
+/// operation. It returns failure if an error was encountered, and sets
+/// `result` to `TiledOp{}` if the operation does not implement the
+/// `TilingInterface`.
+struct TilingInterfaceBasePattern : public RewritePattern {
+  TilingInterfaceBasePattern(StringRef opName, MLIRContext *context,
+                             linalg::LinalgTilingOptions options,
+                             linalg::LinalgTransformationFilter filter =
+                                 linalg::LinalgTransformationFilter(),
+                             PatternBenefit benefit = 1)
+      : RewritePattern(opName, benefit, context), filter(filter),
+        options(options) {}
+
+  /// Applies the filter, checks that the tiling options are supported and
+  /// tiles `op` using `dest` as outputs, returning the tiled op in `result`.
+  LogicalResult matchAndRewriteBase(Operation *op, ValueRange dest,
+                                    PatternRewriter &rewriter,
+                                    TiledOp &result) const;
+
+private:
+  /// Filter used to match ops and to update their transformation marker.
+  linalg::LinalgTransformationFilter filter;
+  /// Options to control tiling.
+  linalg::LinalgTilingOptions options;
+};
+
+} // namespace mlir
+
+#endif // DIALECT_LINALG_TILINGINTERFACE_TILING_H_
diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt
--- a/mlir/include/mlir/Interfaces/CMakeLists.txt
+++ b/mlir/include/mlir/Interfaces/CMakeLists.txt
@@ -6,6 +6,7 @@
 add_mlir_interface(InferTypeOpInterface)
 add_mlir_interface(LoopLikeInterface)
 add_mlir_interface(SideEffectInterfaces)
+add_mlir_interface(TilingInterface)
 add_mlir_interface(VectorInterfaces)
 add_mlir_interface(ViewLikeInterface)
 
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -0,0 +1,26 @@
+//===- TilingInterface.h - Interface for tiling operations ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the definitions of the TilingInterface defined in
+// `TilingInterface.td`.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_TILINGINTERFACE_H_
+#define MLIR_INTERFACES_TILINGINTERFACE_H_
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"
+#include "mlir/Support/LLVM.h"
+
+/// Include the ODS generated interface header files.
+#include "mlir/Interfaces/TilingInterface.h.inc"
+
+#endif // MLIR_INTERFACES_TILINGINTERFACE_H_
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -0,0 +1,81 @@
+//===- TilingInterface.td - Interface for tiling operations *- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains an interface to allow operations to generate a tiled
+// implementation of themselves.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TILINGINTERFACE
+#define MLIR_TILINGINTERFACE
+
+include "mlir/IR/OpBase.td"
+
+def TilingInterface : OpInterface<"TilingInterface"> {
+  let description = [{
+    Interface for allowing operations to expose information needed to
+    tile them (similar to LinalgOp, but without having access to
+    indexing maps).
+  }];
+  let cppNamespace = "::mlir";
+  let methods = [
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of `StringRef`s that describe the number of
+          loops and the iterator types of the operation. The list is
+          expected to use
+          `getParallelIteratorTypeName()`/`getReductionIteratorTypeName()`
+          from MLIR Structured Op Utils.
+        }],
+        /*retType=*/"SmallVector<StringRef>",
+        /*methodName=*/"getLoopIteratorTypes"
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Returns a list of ranges that describe the loop bounds and
+          step for the loops of the operation, one range per entry of
+          `getLoopIteratorTypes`. Tiling currently expects tiled loops
+          to have a unit step and untiled loops to have a zero offset.
+        }],
+        /*retType=*/"SmallVector<Range>",
+        /*methodName=*/"getLoopBounds",
+        /*args=*/(ins "OpBuilder &":$b)
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Generates a tiled version of the operation given the tile
+          offsets and sizes for the loops.
+          - `outputs` are the values to use as outputs of the tiled
+            operation.
+          - `offsets` and `sizes` hold, for every loop of the operation,
+            the offset and size of the tile to generate.
+          - `resultOffsets` and `resultSizes` are to be populated with one
+            list of offsets (resp. sizes) per result of the tiled
+            operation, describing where that result is inserted into the
+            corresponding `outputs` value.
+          Returns the tiled operation, or `nullptr` if no tiled
+          implementation could be generated (the default).
+        }],
+        /*retType=*/"Operation *",
+        /*methodName=*/"getTiledImplementation",
+        /*args=*/(ins
+            "OpBuilder &":$b,
+            "ValueRange":$outputs,
+            "ArrayRef<OpFoldResult>":$offsets,
+            "ArrayRef<OpFoldResult>":$sizes,
+            "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultOffsets,
+            "SmallVectorImpl<SmallVector<OpFoldResult, 4>> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return nullptr;
+        }]
+      >
+  ];
+}
+
+#endif // MLIR_TILINGINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(Analysis)
 add_subdirectory(IR)
+add_subdirectory(TilingInterface)
 add_subdirectory(Transforms)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/Linalg/TilingInterface/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TilingInterface/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TilingInterface/CMakeLists.txt
@@ -0,0 +1,16 @@
+add_mlir_dialect_library(MLIRTilingTransform
+  Tiling.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/
+
+  LINK_LIBS_PUBLIC
+  MLIRAffine
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRef
+  MLIRSCF
+  MLIRStandard
+  MLIRTensor
+  MLIRTilingInterface
+)
diff --git a/mlir/lib/Dialect/Linalg/TilingInterface/Tiling.cpp b/mlir/lib/Dialect/Linalg/TilingInterface/Tiling.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/TilingInterface/Tiling.cpp
@@ -0,0 +1,290 @@
+//===- Tiling.cpp - Implementation of Tiling using TilingInterface --------===//
+//
+// Part of the LLVM
Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements tiling of operations using the TilingInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/TilingInterface/Tiling.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility methods for tiling an operation that implements the
+// TilingInterface.
+//===----------------------------------------------------------------------===//
+
+/// Returns failure (reported through `rewriter.notifyMatchFailure`) if
+/// `options` request something that tiling through the TilingInterface does
+/// not support: loop interchange, tile + pad, loops other than `scf.for`, or a
+/// distribution method other than cyclic.
+static LogicalResult
+verifySupportedTilingOptions(PatternRewriter &rewriter, Operation *op,
+                             const linalg::LinalgTilingOptions &options) {
+  if (!options.interchangeVector.empty())
+    return rewriter.notifyMatchFailure(op,
+                                       "unsupported interchange during tiling");
+  if (options.paddingValueComputationFunction)
+    return rewriter.notifyMatchFailure(op, "unsupported tile + pad option");
+  if (options.loopType != linalg::LinalgTilingLoopType::Loops)
+    return rewriter.notifyMatchFailure(op,
+                                       "only tiling with scf.for is supported");
+  if (options.distribution &&
+      llvm::any_of(options.distribution->distributionMethod,
+                   [](linalg::DistributionMethod method) {
+                     return method != linalg::DistributionMethod::Cyclic;
+                   })) {
+    return rewriter.notifyMatchFailure(op,
+                                       "only cyclic distribution is allowed");
+  }
+  return success();
+}
+
+/// Converts a `Value` to an `OpFoldResult`, replacing it by an `IntegerAttr`
+/// holding the constant when the value is defined by a constant integer op.
+static OpFoldResult getOpFoldResult(Value value) {
+  IntegerAttr::ValueType attr;
+  if (matchPattern(value, m_ConstantInt(&attr)))
+    return IntegerAttr::get(value.getType(), attr);
+  return value;
+}
+
+/// Element-wise version of `getOpFoldResult(Value)`.
+static SmallVector<OpFoldResult, 4> getOpFoldResult(ArrayRef<Value> values) {
+  return llvm::to_vector<4>(llvm::map_range(
+      values, [](Value value) { return getOpFoldResult(value); }));
+}
+
+/// Converts an `OpFoldResult` to a `Value`, materializing an index constant
+/// (`ConstantIndexOp`) if the `OpFoldResult` holds an `IntegerAttr`.
+static Value getValue(OpBuilder &builder, Location loc,
+                      OpFoldResult valueOrAttr) {
+  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+    return builder.create<ConstantIndexOp>(loc,
+                                           attr.cast<IntegerAttr>().getInt());
+  return valueOrAttr.get<Value>();
+}
+
+/// Returns true if `valueOrAttr` is an `IntegerAttr` holding the value `val`.
+/// `Value`s (even constant ones) and non-integer attributes return false.
+static bool isValue(OpFoldResult valueOrAttr, int64_t val) {
+  auto attr = valueOrAttr.dyn_cast<Attribute>();
+  if (!attr)
+    return false;
+  auto intAttr = attr.dyn_cast<IntegerAttr>();
+  return intAttr && intAttr.getValue() == val;
+}
+
+/// Returns true if loop is untiled.
Only checks if the value is statically
+/// zero. It is assumed that a `Value` defined by a constant op is already
+/// converted to an `IntegerAttr` of that value. So here just return true if
+/// this is an attribute with a zero value.
+static bool isUntiledLoop(OpFoldResult valueOrAttr) {
+  return isValue(valueOrAttr, 0);
+}
+
+/// Generates the tiled loops and the body by invoking the interface methods of
+/// TilingInterface. Recurses once per loop of the operation.
+/// - `outputs` are the operands to use for outputs of the tiled operation.
+/// - `tileSizes` are tile sizes specified for all loops of the operation. If a
+///   loop is to be untiled it is set to 0. Entries are overwritten while the
+///   loops are generated: with the full loop size for untiled loops, and with
+///   the in-bounds tile size (an `affine.min`) for tiled loops.
+/// - `iteratorTypes` are the types of the loop iterators returned by the
+///   TilingInterface.
+/// - `loopBounds` are the bounds of all the loops of the op returned by the
+///   TilingInterface.
+/// - `loopDepth` is the current loop depth being processed.
+/// - `offsets` are the `OpFoldResult`s that represent the position of the tile
+///   being operated on. One entry is appended per processed loop depth.
+/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for
+///   distributed loops. It is a stack, and once an entry at the top of the
+///   stack is used for distribution it is popped before processing the inner
+///   loops.
+/// NOTE(review): on failure the IR created so far is not cleaned up (e.g. an
+/// `scf.for` whose body is missing its `scf.yield`).
+static FailureOr<TiledOp> tileOpUsingInterfaceImpl(
+    OpBuilder &builder, TilingInterface tilableOp, ValueRange outputs,
+    MutableArrayRef<OpFoldResult> tileSizes, ArrayRef<StringRef> iteratorTypes,
+    ArrayRef<Range> loopBounds, unsigned loopDepth,
+    SmallVectorImpl<OpFoldResult> &offsets,
+    ArrayRef<linalg::ProcInfo> distributionInfo) {
+  Location loc = tilableOp.getLoc();
+  // If this is the innermost loop, then generate the tiled implementation of
+  // the op by invoking the TilingInterface methods.
+  if (loopDepth == tileSizes.size()) {
+    SmallVector<SmallVector<OpFoldResult, 4>> resultOffsets;
+    SmallVector<SmallVector<OpFoldResult, 4>> resultSizes;
+    Operation *tiledOp = tilableOp.getTiledImplementation(
+        builder, outputs, offsets, tileSizes, resultOffsets, resultSizes);
+    if (!tiledOp) {
+      return static_cast<LogicalResult>(
+          tilableOp.emitOpError("failed to get tiled implementation"));
+    }
+    unsigned numResults = tiledOp->getNumResults();
+    assert((numResults == 0 || (resultOffsets.size() == numResults &&
+                                resultSizes.size() == numResults &&
+                                outputs.size() >= numResults)) &&
+           "expected offsets, sizes and an output for every tiled result");
+    TiledOp ret;
+    ret.op = tiledOp;
+
+    // If the operation has results, then each result of the tiled operation
+    // is inserted into the corresponding `outputs` value and returned.
+    if (numResults) {
+      auto oneAttr = builder.getI64IntegerAttr(1);
+      ret.results.reserve(numResults);
+      for (auto en : llvm::enumerate(tiledOp->getResults())) {
+        ArrayRef<OpFoldResult> resultOffset(resultOffsets[en.index()]);
+        ArrayRef<OpFoldResult> resultSize(resultSizes[en.index()]);
+        SmallVector<OpFoldResult> resultStrides(resultOffset.size(), oneAttr);
+        Value insert = builder.create<tensor::InsertSliceOp>(
+            loc, en.value(), outputs[en.index()], resultOffset, resultSize,
+            resultStrides);
+        ret.results.push_back(insert);
+      }
+    }
+    return ret;
+  }
+
+  // If the tile size at this depth is zero, the loop is not tiled: the tile
+  // starts at offset 0 and spans the whole loop range.
+  if (isUntiledLoop(tileSizes[loopDepth])) {
+    auto zeroAttr = builder.getI64IntegerAttr(0);
+    offsets.push_back(zeroAttr);
+    assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) &&
+           "expected loop bounds to have lower bound of zero");
+    tileSizes[loopDepth] = getOpFoldResult(loopBounds[loopDepth].size);
+    return tileOpUsingInterfaceImpl(builder, tilableOp, outputs, tileSizes,
+                                    iteratorTypes, loopBounds, loopDepth + 1,
+                                    offsets, distributionInfo);
+  }
+
+  // Generate an scf.for for the current loop depth.
+  Value lb = loopBounds[loopDepth].offset;
+  Value ub = loopBounds[loopDepth].size;
+  if (!matchPattern(loopBounds[loopDepth].stride, m_One())) {
+    return static_cast<LogicalResult>(
+        tilableOp.emitOpError("expected stride to be 1"));
+  }
+  Value step = getValue(builder, loc, tileSizes[loopDepth]);
+
+  // Update lb, ub and step for cyclic distribution. Only parallel loops are
+  // distributed, each one consuming the front entry of `distributionInfo`.
+  if (!distributionInfo.empty() &&
+      iteratorTypes[loopDepth] == getParallelIteratorTypeName()) {
+    linalg::updateBoundsForCyclicDistribution(
+        builder, loc, distributionInfo.front().procId,
+        distributionInfo.front().nprocs, lb, ub, step);
+    distributionInfo = distributionInfo.drop_front();
+  }
+  FailureOr<TiledOp> innerReturnValue;
+  bool isBufferTiling = tilableOp->getNumResults() == 0;
+  ValueRange initValues(isBufferTiling ? ValueRange{} : outputs);
+  auto forOp = builder.create<scf::ForOp>(
+      loc, lb, ub, step, initValues,
+      [&](OpBuilder &b, Location nestedLoc, Value iv, ValueRange args) {
+        offsets.push_back(iv);
+        // Similar to linalg tiling, the tile size is min(tileSize, ub - iv) to
+        // account for cases where the tile size does not divide (ub - lb)
+        // exactly. The map is (d0)[s0, s1] -> (s0, s1 - d0).
+        AffineMap minMap = AffineMap::get(
+            /*dimCount=*/1, /*symbolCount=*/2,
+            {b.getAffineSymbolExpr(0),
+             b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)},
+            b.getContext());
+        Value inBoundsTileSize = b.create<AffineMinOp>(
+            nestedLoc, minMap,
+            ValueRange{iv, getValue(b, nestedLoc, tileSizes[loopDepth]), ub});
+        tileSizes[loopDepth] = getOpFoldResult(inBoundsTileSize);
+        // Recursively proceed to generate the tiled loop for the next level.
+        innerReturnValue = tileOpUsingInterfaceImpl(
+            b, tilableOp, (isBufferTiling ? outputs : args), tileSizes,
+            iteratorTypes, loopBounds, loopDepth + 1, offsets,
+            distributionInfo);
+        if (failed(innerReturnValue))
+          return;
+        b.create<scf::YieldOp>(nestedLoc, innerReturnValue->results);
+      });
+  if (failed(innerReturnValue))
+    return innerReturnValue;
+  innerReturnValue->loops.insert(innerReturnValue->loops.begin(),
+                                 forOp.getOperation());
+  innerReturnValue->results = forOp.getResults();
+  return innerReturnValue;
+}
+
+FailureOr<TiledOp>
+mlir::tileOpUsingInterface(OpBuilder &b, Operation *op, ValueRange dest,
+                           const linalg::LinalgTilingOptions &options) {
+  TilingInterface tilableOp = dyn_cast<TilingInterface>(op);
+  if (!tilableOp)
+    return TiledOp{};
+
+  // Without a tile size computation function there is nothing to tile with;
+  // calling the empty std::function would throw.
+  if (!options.tileSizeComputationFunction)
+    return failure();
+
+  SmallVector<StringRef> iteratorTypes = tilableOp.getLoopIteratorTypes();
+  SmallVector<Value, 4> tileSizesVals =
+      options.tileSizeComputationFunction(b, op);
+  auto zeroAttr = b.getI64IntegerAttr(0);
+
+  // The actual tile sizes used convert a `Value` defined as constant 0 to a
+  // zero integer attribute, and missing trailing tile sizes are treated as
+  // zero. Tiling of loops whose iterator type is not "parallel" is not
+  // supported: a non-zero tile size for such a loop is an error.
+  auto tileSizes = getOpFoldResult(tileSizesVals);
+  tileSizes.resize(iteratorTypes.size(), zeroAttr);
+  for (auto en : llvm::enumerate(iteratorTypes)) {
+    if (en.value() == getParallelIteratorTypeName())
+      continue;
+    if (!isUntiledLoop(tileSizes[en.index()])) {
+      return static_cast<LogicalResult>(tilableOp.emitOpError(
+          "unimplemented tiling of non-parallel loop iterator type"));
+    }
+  }
+
+  // Trivial early exit case of tile sizes being zero for all loops.
+  if (llvm::all_of(tileSizes, isUntiledLoop))
+    return TiledOp{op, {}, {}};
+
+  SmallVector<Range> loopBounds = tilableOp.getLoopBounds(b);
+  SmallVector<linalg::ProcInfo, 2> distributionInfo;
+  // If the tiled loops are distributed, get the proc_id and nprocs for the
+  // distributed loops. First collect the loops that are distributed, i.e.
+  // loops that are
+  // - parallel, i.e. iteratorTypes is "parallel"
+  // - tiled, i.e. tileSize != 0
+  if (options.distribution) {
+    SmallVector<Range> distributedLoopRange;
+    for (auto i : llvm::seq<unsigned>(0, tileSizes.size())) {
+      if (isUntiledLoop(tileSizes[i]))
+        continue;
+      if (iteratorTypes[i] != getParallelIteratorTypeName())
+        continue;
+      distributedLoopRange.push_back(loopBounds[i]);
+    }
+    distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(),
+                                                      distributedLoopRange);
+  }
+
+  SmallVector<OpFoldResult> offsets;
+  return tileOpUsingInterfaceImpl(b, tilableOp, dest, tileSizes, iteratorTypes,
+                                  loopBounds, 0, offsets, distributionInfo);
+}
+
+//===----------------------------------------------------------------------===//
+// Definition of methods for `TilingInterfaceBasePattern`.
+//===----------------------------------------------------------------------===//
+
+LogicalResult mlir::TilingInterfaceBasePattern::matchAndRewriteBase(
+    Operation *op, ValueRange dest, PatternRewriter &rewriter,
+    TiledOp &result) const {
+  if (failed(filter.checkAndNotify(rewriter, op)))
+    return failure();
+  if (failed(verifySupportedTilingOptions(rewriter, op, options)))
+    return failure();
+
+  FailureOr<TiledOp> res = tileOpUsingInterface(rewriter, op, dest, options);
+  if (failed(res))
+    return failure();
+  result = std::move(*res);
+  // Update the transformation marker on the tiled op.
+  if (result.op)
+    filter.replaceLinalgTransformationFilter(rewriter, result.op);
+  return success();
+}
diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt
--- a/mlir/lib/Interfaces/CMakeLists.txt
+++ b/mlir/lib/Interfaces/CMakeLists.txt
@@ -8,6 +8,7 @@
   InferTypeOpInterface.cpp
   LoopLikeInterface.cpp
   SideEffectInterfaces.cpp
+  TilingInterface.cpp
   VectorInterfaces.cpp
   ViewLikeInterface.cpp
   )
@@ -37,6 +38,6 @@
 add_mlir_interface_library(InferTypeOpInterface)
 add_mlir_interface_library(LoopLikeInterface)
 add_mlir_interface_library(SideEffectInterfaces)
+add_mlir_interface_library(TilingInterface)
 add_mlir_interface_library(VectorInterfaces)
 add_mlir_interface_library(ViewLikeInterface)
-
diff --git
a/mlir/lib/Interfaces/TilingInterface.cpp b/mlir/lib/Interfaces/TilingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/TilingInterface.cpp @@ -0,0 +1,17 @@ +//===- TilingInterface.cpp - Tiling interface -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the interface in `TilingInterface.td`. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/TilingInterface.h" + +namespace mlir { +#include "mlir/Interfaces/TilingInterface.cpp.inc" +} // namespace mlir diff --git a/mlir/test/Interfaces/TilingInterface/tiling.mlir b/mlir/test/Interfaces/TilingInterface/tiling.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tiling.mlir @@ -0,0 +1,434 @@ +// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s +// NOTE(review): the `tensor`/`memref` types in this file appear without their `<...>` shape and element-type parameters, which the MLIR parser requires; restore them (something like `tensor<?x?xf32>`, exact types to be confirmed) before running. +func @scatter_tiling( + %original: tensor, %indices: tensor, + %update : tensor) -> tensor { + %0 = test.full_size_output_tile + {__internal_linalg_transform__ = "tiling_input"} + inputs(%update, %indices : tensor, tensor) + outputs(%original : tensor) -> tensor + return %0 : tensor +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @scatter_tiling( +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]] +// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] +//
CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ORIGINAL]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]] +// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], 1] +// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] +// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[ORIGINAL]], %[[C1]] +// CHECK: %[[SCATTER_TILE:.+]] = test.full_size_output_tile +// CHECK-SAME: __internal_linalg_transform__ = "tiling_output" +// CHECK-SAME: inputs(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] +// CHECK-SAME: outputs(%[[INIT]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0] +// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- +// Buffer variant of @scatter_tiling: the scf.for carries no iter_args, the inputs are sliced with memref.subview and the original memref is passed as output. +func @scatter_tiling_memref( + %original: memref, %indices: memref, + %update : memref) { + test.full_size_output_tile + {__internal_linalg_transform__ = "tiling_input"} + inputs(%update, %indices : memref, memref) + outputs(%original : memref) + return +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @scatter_tiling_memref( +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]] +// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]] +// CHECK: %[[UPDATE_SLICE:.+]] =
memref.subview %[[UPDATES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[INDEX_SLICE:.+]] = memref.subview %[[INDICES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], 1] +// CHECK: test.full_size_output_tile +// CHECK-SAME: __internal_linalg_transform__ = "tiling_output" +// CHECK-SAME: inputs(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] +// CHECK-SAME: outputs(%[[ORIGINAL]] + +// ----- +// @scatter_tiling with the tiled loop cyclically distributed: the lower bound and step are scaled from gpu.block_id / gpu.grid_dim along x. +func @scatter_tiling_distribution( + %original: tensor, %indices: tensor, + %update : tensor) -> tensor { + %0 = test.full_size_output_tile + {__internal_linalg_transform__ = "distribute_input"} + inputs(%update, %indices : tensor, tensor) + outputs(%original : tensor) -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @scatter_tiling_distribution( +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]] +// CHECK-DAG: %[[ID:.+]] = "gpu.block_id"() {dimension = "x"} +// CHECK-DAG: %[[COUNT:.+]] = "gpu.grid_dim"() {dimension = "x"} +// CHECK-DAG: %[[OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[ID]]] +// CHECK-DAG: %[[STEP:.+]] = affine.apply #[[MAP0]]()[%[[COUNT]]] +// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[OFFSET]] to %[[D0]] step %[[STEP]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ORIGINAL]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]] +// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice
%[[INDICES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], 1] +// CHECK-DAG: %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] +// CHECK-DAG: %[[SLICE_D1:.+]] = tensor.dim %[[ORIGINAL]], %[[C1]] +// CHECK: %[[SCATTER_TILE:.+]] = test.full_size_output_tile +// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" +// CHECK-SAME: inputs(%[[UPDATE_SLICE]], %[[INDEX_SLICE]] +// CHECK-SAME: outputs(%[[INIT]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0] +// CHECK-SAME: [%[[SLICE_D0]], %[[SLICE_D1]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- +// With the no_tiling marker the op is expected to stay untiled; only its marker changes. +func @scatter_no_tiling( + %original: tensor, %indices: tensor, + %update : tensor) -> tensor { + %0 = test.full_size_output_tile + {__internal_linalg_transform__ = "no_tiling_input"} + inputs(%update, %indices : tensor, tensor) + outputs(%original : tensor) -> tensor + return %0 : tensor +} +// CHECK: func @scatter_no_tiling +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor +// CHECK: %[[RESULT:.+]] = test.full_size_output_tile +// CHECK-SAME: __internal_linalg_transform__ = "no_tiling_output" +// CHECK-SAME: inputs(%[[UPDATES]], %[[INDICES]] +// CHECK-SAME: outputs(%[[ORIGINAL]] +// CHECK: return %[[RESULT]] + +// ----- +// reduce(0) with the outer_reduce configuration: the op is expected to stay untiled (presumably the operand is 1-D, leaving no parallel loop to tile). +func @sort_1d(%arg0: tensor) -> tensor { + %0 = test.mixed_parallel_reduce reduce(0) + {__internal_linalg_transform__ = "outer_reduce_input"} + outputs(%arg0 : tensor) -> tensor + return %0 : tensor +} +// CHECK: func @sort_1d( +// CHECK-SAME: %[[OPERAND:.+]]: tensor +// CHECK: %[[RESULT:.+]] = test.mixed_parallel_reduce +// CHECK-SAME: {__internal_linalg_transform__ = "outer_reduce_output"} +// CHECK-SAME: outputs(%[[OPERAND]] : +// CHECK: return %[[RESULT]] + +// ----- +// reduce(1): the parallel dim 0 is tiled by 10 and the reduction dim 1 is left untiled. +func @sort_2d(%arg0: tensor) -> tensor { + %0 = test.mixed_parallel_reduce reduce(1) + {__internal_linalg_transform__ = "inner_reduce_input"}
outputs(%arg0 : tensor) -> tensor + return %0 : tensor +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @sort_2d( +// CHECK-SAME: %[[OPERAND:.+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[SORT_TILE:.+]] = test.mixed_parallel_reduce +// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output" +// CHECK-SAME: outputs(%[[OPERAND_SLICE]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- +// reduce(0): the parallel dim 1 is tiled by 20 and the reduction dim 0 is left untiled. +func @sort_2d_inner_parallel(%arg0: tensor) -> tensor { + %0 = test.mixed_parallel_reduce reduce(0) + {__internal_linalg_transform__ = "outer_reduce_input"} + outputs(%arg0 : tensor) -> tensor + return %0 : tensor +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func @sort_2d_inner_parallel( +// CHECK-SAME: %[[OPERAND:.+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 20 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]]) +// CHECK-DAG:
%[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]] +// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] +// CHECK: %[[SORT_TILE:.+]] = test.mixed_parallel_reduce +// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" +// CHECK-SAME: outputs(%[[OPERAND_SLICE]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- +// Two-result version of @sort_2d: both outputs are sliced, carried as iter_args and inserted back. +func @sort_2d_multi_result( + %arg0: tensor, %arg1: tensor) + -> (tensor, tensor) { + %0:2 = test.mixed_parallel_reduce reduce(1) + {__internal_linalg_transform__ = "inner_reduce_input"} + outputs(%arg0, %arg1 : tensor, tensor) + -> tensor, tensor + return %0#0, %0#1 : tensor, tensor +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @sort_2d_multi_result( +// CHECK-SAME: %[[OPERAND1:.+]]: tensor +// CHECK-SAME: %[[OPERAND2:.+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]] +// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT2]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[SORT_TILE:.+]]:2 = test.mixed_parallel_reduce +// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output" +// CHECK-SAME:
outputs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] +// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_TILE]]#0 into %[[INIT1]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_TILE]]#1 into %[[INIT2]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]] +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- +// Two-output buffer version of @sort_2d_inner_parallel: both memrefs are sliced with memref.subview. +func @sort_2d_multi_result_memref( + %arg0: memref, %arg1: memref) { + test.mixed_parallel_reduce reduce(0) + {__internal_linalg_transform__ = "outer_reduce_input"} + outputs(%arg0, %arg1 : memref, memref) + return +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func @sort_2d_multi_result_memref( +// CHECK-SAME: %[[OPERAND1:.+]]: memref +// CHECK-SAME: %[[OPERAND2:.+]]: memref +// CHECK-DAG: %[[TILESIZE:.+]] = constant 20 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]] +// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]] +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]] +// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] +// CHECK: test.mixed_parallel_reduce +// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" +// CHECK-SAME: outputs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] + +// ----- +// reduce(1) on three dims: dims 0 and 2 are tiled by 10 and 30 and distributed along block y and x respectively; dim 1 is left untiled. +func @sort_3d_multi_result_distribute( + %arg0: tensor, %arg1 : tensor) + -> (tensor, tensor) { + %0, %1 = test.mixed_parallel_reduce reduce(1) + {__internal_linalg_transform__ = "distribute_input"} + outputs(%arg0, %arg1 : tensor, tensor) + -> tensor, tensor + return %0, %1 :
tensor, tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK: func @sort_3d_multi_result_distribute( +// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: %[[TILESIZE1:.+]] = constant 10 : index +// CHECK-DAG: %[[TILESIZE2:.+]] = constant 30 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[OPERAND1]], %[[C2]] +// CHECK-DAG: %[[IDX:.+]] = "gpu.block_id"() {dimension = "x"} +// CHECK-DAG: %[[COUNTX:.+]] = "gpu.grid_dim"() {dimension = "x"} +// CHECK-DAG: %[[IDY:.+]] = "gpu.block_id"() {dimension = "y"} +// CHECK-DAG: %[[COUNTY:.+]] = "gpu.grid_dim"() {dimension = "y"} +// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]] +// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]] +// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]]) +// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]] +// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]] +// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]] +// CHECK: %[[RESULT_INNER:.+]]:2 = scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]] +// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT1]], %[[INIT4:.+]] = %[[INIT2]]) +// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]] +// CHECK:
%[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT3]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT4]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: %[[SORT_SLICE:.+]]:2 = test.mixed_parallel_reduce +// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" +// CHECK-SAME: outputs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] +// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_SLICE]]#0 +// CHECK-SAME: into %[[INIT3]][%[[IV0]], 0, %[[IV1]]] +// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_SLICE]]#1 +// CHECK-SAME: into %[[INIT4]][%[[IV0]], 0, %[[IV1]]] +// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]] +// CHECK: scf.yield %[[RESULT_INNER]]#0, %[[RESULT_INNER]]#1 +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- + +func @sort_3d_multi_result_distribute_memref( + %arg0: memref, %arg1 : memref) { + test.mixed_parallel_reduce reduce(1) + {__internal_linalg_transform__ = "distribute_input"} + outputs(%arg0, %arg1 : memref, memref) + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK: func @sort_3d_multi_result_distribute_memref( +// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[TILESIZE1:.+]] = constant 10 : index +// CHECK-DAG: %[[TILESIZE2:.+]] = constant 30 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = memref.dim 
%[[OPERAND1]], %[[C2]] +// CHECK-DAG: %[[IDX:.+]] = "gpu.block_id"() {dimension = "x"} +// CHECK-DAG: %[[COUNTX:.+]] = "gpu.grid_dim"() {dimension = "x"} +// CHECK-DAG: %[[IDY:.+]] = "gpu.block_id"() {dimension = "y"} +// CHECK-DAG: %[[COUNTY:.+]] = "gpu.grid_dim"() {dimension = "y"} +// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]] +// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]] +// CHECK: scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]] +// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]] +// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]] +// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]] +// CHECK: scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]] +// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]] +// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: test.mixed_parallel_reduce +// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" +// CHECK-SAME: outputs(%[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] + +// ----- + +func @slice_insert(%source :tensor, %dest: tensor, + %idx0 : index, %idx1 : index) -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = tensor.dim %source, %c0 : tensor + %1 = tensor.dim %source, %c1 : tensor + %2 = tensor.insert_slice %source into %dest[%idx0, %idx1] [%0, %1] [1, 1] + {__internal_linalg_transform__ = "tiling_input"} : tensor into tensor + return %2 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 
+ s0)> +// CHECK: func @slice_insert( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]] +// CHECK: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]] +// CHECK-SAME: into %{{.+}}[%[[OFFSET0]], %[[OFFSET1]]] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[YIELD1]] +// CHECK: return %[[RESULT]] + +// ----- + +func @slice_insert_rank_reduce(%source :tensor, %dest: tensor, + %idx0 : index, %idx1 : index) -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = tensor.dim %source, %c0 : tensor + %1 = tensor.dim %source, %c1 : tensor + %2 = tensor.insert_slice %source into %dest[%idx0, 0, %idx1] [%0, 1, %1] [1, 1, 1] + {__internal_linalg_transform__ = "tiling_input"} : tensor into tensor + return %2 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func @slice_insert_rank_reduce( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]] +// CHECK: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]] +// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], 
%[[IV1]]] +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]] +// CHECK-SAME: into %{{.+}}[%[[OFFSET0]], 0, %[[OFFSET1]]] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[YIELD1]] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt --- a/mlir/test/lib/CMakeLists.txt +++ b/mlir/test/lib/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(IR) +add_subdirectory(Interfaces) add_subdirectory(Pass) add_subdirectory(Reducer) add_subdirectory(Rewrite) diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -31,6 +31,7 @@ #include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" namespace mlir { class DLTIDialect; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -19,6 +19,7 @@ include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "TestInterfaces.td" @@ -2027,6 +2028,36 @@ }]; } +//===----------------------------------------------------------------------===// +// Test TilingInterface +//===----------------------------------------------------------------------===// + +def TestFullSizeOutputTile : TEST_Op<"full_size_output_tile", + [AttrSizedOperandSegments]> { + let arguments = (ins Variadic:$inputs, Variadic:$outputs); + let results = (outs Variadic:$results); + let assemblyFormat = [{ + attr-dict (`inputs` `(` $inputs^ `:` type($inputs) `)`)? 
+ (`outputs` `(` $outputs^ `:` type($outputs) `)`)? (`->` type($results)^)? + }]; +} + +def TestMixedReduceParallel : TEST_Op<"mixed_parallel_reduce", + [AttrSizedOperandSegments]> { + let arguments = (ins + Variadic:$inputs, Variadic:$outputs, + I64Attr:$reduce_dim); + let results = (outs Variadic:$results); + let assemblyFormat = [{ + `reduce` `(` $reduce_dim `)` attr-dict + (`inputs` `(` $inputs^ `:` type($inputs) `)`)? + (`outputs` `(` $outputs^ `:` type($outputs) `)`)? (`->` type($results)^)? + }]; +} + +//===----------------------------------------------------------------------===// + + //===----------------------------------------------------------------------===// // Test TableGen generated build() methods //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TilingInterface) diff --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt @@ -0,0 +1,21 @@ +# Exclude tests from libMLIR.so +add_mlir_library(MLIRTestTilingInterface + TestTilingInterface.cpp + + EXCLUDE_FROM_LIBMLIR + + LINK_LIBS_PUBLIC + MLIRAffine + MLIRGPUOps + MLIRIR + MLIRLinalgTransforms + MLIRMemRef + MLIRPass + MLIRSCF + MLIRStandard + MLIRTransformUtils + MLIRTensor + MLIRTilingInterface + MLIRTilingTransform + MLIRTransformUtils +) diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp @@ -0,0 +1,495 @@ +//===- TestTilingInterface.cpp - Test pass for Tiling Interface patterns --===// +// +// Part of the 
LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file tests the tiling transformations implemented using TilingInterface. +// +//===----------------------------------------------------------------------===// + +#include "TestDialect.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/Linalg/TilingInterface/Tiling.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/TilingInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "llvm/ADT/TypeSwitch.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Utility methods to convert from `OpFoldResult` to `Value` +//===----------------------------------------------------------------------===// + +/// Converts an `OpFoldResult` to a `Value` by building a constant op if +/// if the `OpFoldResult` is an `IntegerAttr`. +static Value getValue(OpBuilder &builder, Location loc, + OpFoldResult valueOrAttr) { + if (auto attr = valueOrAttr.dyn_cast()) { + return builder.create(loc, + attr.cast().getInt()); + } + return valueOrAttr.get(); +} + +/// Returns the constant value in `valueOrAttr` if it is not a dynamic `Value`. +static Optional getConstantValue(OpFoldResult valueOrAttr) { + if (auto attr = valueOrAttr.dyn_cast()) + return attr.cast().getInt(); + return {}; +} + +/// Checks if `valueOrAttr` represents a constant value `val`. 
+static bool isValue(OpFoldResult valueOrAttr, int64_t val) { + auto attr = valueOrAttr.dyn_cast(); + return attr && attr.cast().getValue() == val; +} + +/// Returns a memref.subview or a tensor.extract_slice based on the type of the +/// `source`. Any type other than a ranked tensor or a memref yields a null +/// Value. +static Value getSlice(OpBuilder &b, Location loc, Value source, + ArrayRef offsets, + ArrayRef sizes, + ArrayRef strides) { + return TypeSwitch(source.getType()) + .Case([&](RankedTensorType t) -> Value { + return b.create(loc, source, offsets, sizes, + strides); + }) + .Case([&](MemRefType type) -> Value { + return b.create(loc, source, offsets, sizes, + strides); + }) + .Default([&](Type t) { return nullptr; }); +} + +/// Materializes the size of dimension `dim` of `v` as a Value by creating a +/// dim op. Only ranked tensors and memrefs are handled; any other type +/// yields a null Value. +static Value getDimValue(OpBuilder &builder, Location loc, Value v, + int64_t dim) { + return TypeSwitch(v.getType()) + .Case([&](RankedTensorType t) -> Value { + return builder.create(loc, v, dim); + }) + .Case([&](MemRefType t) -> Value { + return builder.create(loc, v, dim); + }) + .Default([&](Type t) { return Value(); }); +} + +/// Returns dimension `dim` of `v` as an OpFoldResult: an i64 IntegerAttr when +/// the size is static, otherwise the Value produced by getDimValue. +static OpFoldResult getDim(OpBuilder &builder, Location loc, Value v, + int64_t dim) { + auto t = v.getType().cast(); + if (t.isDynamicDim(dim)) + return getDimValue(builder, loc, v, dim); + return builder.getI64IntegerAttr(t.getDimSize(dim)); +} + +//===----------------------------------------------------------------------===// +// Tiling Interface for Test Dialect operations.
+//===----------------------------------------------------------------------===// + +namespace { +/// External TilingInterface model for `test.full_size_output_tile`: a single +/// parallel loop over dimension 0 of `inputs()[0]`. +struct TestFullSizeOutputTileTilingInterface + : public TilingInterface::ExternalModel< + TestFullSizeOutputTileTilingInterface, test::TestFullSizeOutputTile> { + SmallVector getLoopIteratorTypes(Operation *) const { + return {getParallelIteratorTypeName()}; + } + + SmallVector getLoopBounds(Operation *op, OpBuilder &builder) const { + auto testOp = cast(op); + Location loc = op->getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value ub = getDimValue(builder, loc, testOp.inputs()[0], 0); + return {Range{zero, ub, one}}; + } + + /// Slices `inputs()[0]` and `inputs()[1]` along dimension 0 with the given + /// offset/size (keeping their other dimensions whole) and builds the tiled + /// op with those slices and `dest[0]` as operands. The result tile always + /// covers all of `outputs()[0]`. + Operation *getTiledImplementation( + Operation *op, OpBuilder &builder, ValueRange dest, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl> &resultOffsets, + SmallVectorImpl> &resultSizes) const { + assert(dest.size() == 1 && offsets.size() == 1 && sizes.size() == 1); + auto testOp = cast(op); + Location loc = op->getLoc(); + auto zeroAttr = builder.getI64IntegerAttr(0); + auto oneAttr = builder.getI64IntegerAttr(1); + Value lhs = testOp.inputs()[0]; + Value rhs = testOp.inputs()[1]; + Value source = testOp.outputs()[0]; + + // Slice of the updates. + auto updateRank = lhs.getType().cast().getRank(); + SmallVector updateOffsets(updateRank, zeroAttr); + SmallVector updateSizes(updateRank, zeroAttr); + updateOffsets[0] = offsets[0]; + updateSizes[0] = sizes[0]; + for (auto dim : llvm::seq(1, updateRank)) + updateSizes[dim] = getDim(builder, loc, lhs, dim); + + SmallVector updateStrides(updateRank, oneAttr); + Value tiledUpdate = + getSlice(builder, loc, lhs, updateOffsets, updateSizes, updateStrides); + assert(tiledUpdate && "failed to get slice of update"); + + // Slice of indices.
+ auto indicesRank = rhs.getType().cast().getRank(); + SmallVector indicesOffsets(indicesRank, zeroAttr); + SmallVector indicesSizes(indicesRank, zeroAttr); + indicesOffsets[0] = offsets[0]; + indicesSizes[0] = sizes[0]; + for (auto dim : llvm::seq(1, indicesRank)) + indicesSizes[dim] = getDim(builder, loc, rhs, dim); + SmallVector indicesStrides(indicesRank, oneAttr); + Value tiledIndices = getSlice(builder, loc, rhs, indicesOffsets, + indicesSizes, indicesStrides); + assert(tiledIndices && "failed to get slice of indices"); + + // The result tile always spans the whole of `source`: zero offsets and + // full sizes. It is sized by `source`'s own rank, which need not match + // the rank of `inputs()[0]`. + int64_t resultRank = source.getType().cast<ShapedType>().getRank(); + resultOffsets.resize(1); + resultOffsets[0].resize(resultRank, zeroAttr); + resultSizes.resize(1); + resultSizes[0].resize(resultRank); + for (auto dim : llvm::seq<int64_t>(0, resultRank)) + resultSizes[0][dim] = getDim(builder, loc, source, dim); + SmallVector resultTypes; + if (op->getNumResults()) + resultTypes.push_back(op->getResultTypes()[0]); + return builder.create( + loc, resultTypes, ValueRange{tiledUpdate, tiledIndices}, + ValueRange{dest[0]}); + } +}; + +/// External TilingInterface model for `test.mixed_parallel_reduce`: one loop +/// per dimension of `outputs()[0]`, where `reduce_dim` is a reduction loop +/// and every other loop is parallel. +struct TestMixedReduceParallelTilingInterface + : public TilingInterface::ExternalModel< + TestMixedReduceParallelTilingInterface, + test::TestMixedReduceParallel> { + SmallVector getLoopIteratorTypes(Operation *op) const { + auto testOp = cast(op); + // All loops except the `reduce_dim` loop are parallel.
+ int64_t operandRank = + testOp.outputs()[0].getType().cast().getRank(); + SmallVector iteratorTypes(operandRank, + getParallelIteratorTypeName()); + // NOTE(review): `reduce_dim` is not range-checked here and the op + // definition declares no verifier; a value >= `operandRank` indexes past + // the end of `iteratorTypes`. + iteratorTypes[testOp.reduce_dim()] = getReductionIteratorTypeName(); + return iteratorTypes; + } + + SmallVector getLoopBounds(Operation *op, OpBuilder &builder) const { + auto testOp = cast(op); + int64_t operandRank = + testOp.outputs()[0].getType().cast().getRank(); + SmallVector loopBounds(operandRank); + Location loc = op->getLoc(); + Value zero = builder.create(loc, 0); + Value one = builder.create(loc, 1); + Value source = testOp.outputs()[0]; + for (auto dim : llvm::seq(0, operandRank)) { + loopBounds[dim].offset = zero; + loopBounds[dim].size = getDimValue(builder, loc, source, dim); + loopBounds[dim].stride = one; + } + return loopBounds; + } + + /// Slices every `dest` operand with the same `offsets`/`sizes` and unit + /// strides; the result position of each result is exactly that tile. + Operation *getTiledImplementation( + Operation *op, OpBuilder &builder, ValueRange dest, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl> &resultOffsets, + SmallVectorImpl> &resultSizes) const { + auto testOp = cast(op); + assert(dest.size() == testOp.outputs().size()); + int64_t rank = testOp.outputs()[0].getType().cast().getRank(); + assert(offsets.size() == static_cast(rank) && + sizes.size() == static_cast(rank)); + auto oneAttr = builder.getI64IntegerAttr(1); + SmallVector strides(rank, oneAttr); + Location loc = op->getLoc(); + SmallVector tiledOperands(dest.size()); + resultOffsets.resize(dest.size()); + resultSizes.resize(dest.size()); + for (auto en : llvm::enumerate(dest)) { + tiledOperands[en.index()] = + getSlice(builder, loc, en.value(), offsets, sizes, strides); + assert(tiledOperands[en.index()] && "failed to get slice of operand"); + resultOffsets[en.index()].assign(offsets.begin(), offsets.end()); + resultSizes[en.index()].assign(sizes.begin(), sizes.end()); + } + SmallVector resultTypes; + if (op->getNumResults()) { + resultTypes = llvm::to_vector<4>( + llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); })); + } + return
builder.create( + loc, resultTypes, ValueRange{}, tiledOperands, testOp.reduce_dim()); + } +}; + +//===----------------------------------------------------------------------===// +// Interface implementations for external operations. +//===----------------------------------------------------------------------===// + +/// External TilingInterface model for `tensor.insert_slice`: one parallel loop +/// per source dimension. A tile is a slice of the source, and its result +/// position is the tile offset shifted by the op's own destination offset. +struct InsertSliceTilingInterface + : public TilingInterface::ExternalModel { + SmallVector getLoopIteratorTypes(Operation *op) const { + auto insertSliceOp = cast(op); + return SmallVector(insertSliceOp.getSourceType().getRank(), + getParallelIteratorTypeName()); + } + + SmallVector getLoopBounds(Operation *op, OpBuilder &b) const { + auto insertSliceOp = cast(op); + Value source = insertSliceOp.source(); + RankedTensorType sourceType = insertSliceOp.getSourceType(); + Location loc = op->getLoc(); + Value zero = b.create(loc, 0); + Value one = b.create(loc, 1); + SmallVector loopBounds(sourceType.getRank(), + Range{zero, nullptr, one}); + for (auto dim : + llvm::seq(0, insertSliceOp.getSourceType().getRank())) + loopBounds[dim].size = b.create(loc, source, dim); + return loopBounds; + } + + Operation *getTiledImplementation( + Operation *op, OpBuilder &b, ValueRange dest, + ArrayRef offsets, ArrayRef sizes, + SmallVectorImpl> &resultOffsets, + SmallVectorImpl> &resultSizes) const { + // Compute a slice of the source based on the offsets. + auto insertOp = cast(op); + auto opOffsets = insertOp.getMixedOffsets(); + auto opSizes = insertOp.getMixedSizes(); + auto opStrides = insertOp.getMixedStrides(); + if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) { + return isValue(valueOrAttr, 1); + })) { + op->emitOpError("unable to tile operation with non-unit stride"); + return nullptr; + } + // The operation returned is just a tensor.extract_slice of the source with + // the given offsets, sizes and strides.
Setting the correct result offset + // will make sure the tiling algorithm will insert this slice into the + // correct place in the destination. + // The result offset is just the offset passed in plus the offset specified + // in the op (since all strides are checked to be 1). + unsigned offsetIndex = 0; + ArrayRef sourceShape = insertOp.getSourceType().getShape(); + int64_t destRank = insertOp.getType().getRank(); + resultOffsets.resize(1); + resultOffsets[0].resize(destRank); + resultSizes.resize(1); + resultSizes[0].resize(destRank); + Location loc = insertOp.getLoc(); + auto zeroAttr = b.getI64IntegerAttr(0); + auto oneAttr = b.getI64IntegerAttr(1); + SmallVector strides(offsets.size(), oneAttr); + for (auto opOffset : llvm::enumerate(opOffsets)) { + // Check for rank-reducing by checking that + // 1) The corresponding opSize value is 1, and + // 2) Either every source dimension has already been matched, or the + // size of the current source dimension is not 1. + // Then the opOffset is for the rank-reduced dimension. Skip. + // The bounds check comes first: once all source dimensions are matched, + // `offsetIndex` equals `sourceShape.size()`, and indexing `sourceShape` + // with it (a rank-reduced trailing dimension) would be out of range. + unsigned opOffsetIndex = opOffset.index(); + if (isValue(opSizes[opOffsetIndex], 1) && + (offsetIndex >= sourceShape.size() || + sourceShape[offsetIndex] != 1)) { + resultOffsets[0][opOffsetIndex] = zeroAttr; + resultSizes[0][opOffsetIndex] = oneAttr; + continue; + } + OpFoldResult opOffsetVal = opOffset.value(); + OpFoldResult offset = offsets[offsetIndex]; + if (opOffsetVal.is() && offset.is()) { + resultOffsets[0][opOffsetIndex] = b.getI64IntegerAttr( + *getConstantValue(opOffsetVal) + *getConstantValue(offset)); + } else { + AffineMap map = AffineMap::get( + 1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)}); + resultOffsets[0][opOffsetIndex] = + b.create(loc, map, + ValueRange{getValue(b, loc, offset), + getValue(b, loc, opOffsetVal)}) + .getResult(); + } + resultSizes[0][opOffsetIndex] = sizes[offsetIndex]; + offsetIndex++; + } + return b.create(loc, insertOp.source(), offsets, + sizes, strides); + } +}; + +/// Tiles an `OpTy` through matchAndRewriteBase, using the op's `outputs()` as +/// the destinations, then replaces the op with the tiled results (or erases +/// it when there are none). +template +struct TestOpTilingPattern : public TilingInterfaceBasePattern { + TestOpTilingPattern(MLIRContext *context,
linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : TilingInterfaceBasePattern(OpTy::getOperationName(), context, options, + filter, benefit) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + OpTy testOp = dyn_cast(op); + if (!testOp) + return failure(); + TiledOp tiledOp; + // Check for failure. + if (failed(TilingInterfaceBasePattern::matchAndRewriteBase( + op, testOp.outputs(), rewriter, tiledOp))) + return failure(); + + // Check for do-nothing case. + if (!tiledOp.op) + return failure(); + // NOTE(review): when the base hands back `op` itself, success() is + // returned without any rewrite in this function; presumably + // matchAndRewriteBase has already updated the transformation marker on + // `op` -- confirm, otherwise the greedy driver may keep re-matching it. + if (tiledOp.op != op) { + if (tiledOp.results.empty()) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, tiledOp.results); + } + return success(); + } +}; + +/// Same flow as TestOpTilingPattern, for `tensor.insert_slice`, using its +/// `dest` operand as the destination. +struct InsertSliceTilingPattern : public TilingInterfaceBasePattern { + InsertSliceTilingPattern(MLIRContext *context, + linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : TilingInterfaceBasePattern(tensor::InsertSliceOp::getOperationName(), + context, options, filter, benefit) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // The pattern is rooted at tensor.insert_slice, so this cast cannot fail. + tensor::InsertSliceOp insertSliceOp = cast(op); + TiledOp tiledOp; + // Check for failure. + if (failed(TilingInterfaceBasePattern::matchAndRewriteBase( + insertSliceOp, insertSliceOp.dest(), rewriter, tiledOp))) + return failure(); + + // Check for do-nothing case.
+ if (!tiledOp.op) + return failure(); + if (tiledOp.op != op) { + if (tiledOp.results.empty()) + rewriter.eraseOp(op); + else + rewriter.replaceOp(op, tiledOp.results); + } + return success(); + } +}; + +/// Test pass (`-test-tiling-interface`) that applies the TilingInterface +/// tiling patterns defined in this file to a function. +struct TestTilingInterfacePass + : public PassWrapper { + StringRef getArgument() const final { return "test-tiling-interface"; } + StringRef getDescription() const final { return "Test Tiling Interface."; } + TestTilingInterfacePass() = default; + TestTilingInterfacePass(const TestTilingInterfacePass &pass) + : PassWrapper(pass) {} + void getDependentDialects(DialectRegistry &registry) const override { + registry.insert(); + } + + LogicalResult initialize(MLIRContext *context) override; + void runOnFunction() override; +}; +} // namespace + +void TestTilingInterfacePass::runOnFunction() { + FuncOp funcOp = getOperation(); + MLIRContext *context = funcOp.getContext(); + + RewritePatternSet patterns(context); + patterns.add>( + context, linalg::LinalgTilingOptions().setTileSizes({10, 20}), + linalg::LinalgTransformationFilter( + Identifier::get("tiling_input", context), + Identifier::get("tiling_output", context))); + patterns.add>( + context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef{0}), + linalg::LinalgTransformationFilter( + Identifier::get("no_tiling_input", context), + Identifier::get("no_tiling_output", context))); + + patterns.add>( + context, linalg::LinalgTilingOptions().setTileSizes({0, 20}), + linalg::LinalgTransformationFilter( + Identifier::get("outer_reduce_input", context), + Identifier::get("outer_reduce_output", context))); + patterns.add>( + context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}), + linalg::LinalgTransformationFilter( + Identifier::get("inner_reduce_input", context), + Identifier::get("inner_reduce_output", context))); + + static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = { + [](OpBuilder &builder, Location loc, ArrayRef parallelLoopRanges) { + auto numParallelDims = parallelLoopRanges.size(); + + SmallVector
procInfo(numParallelDims); + Type indexType = builder.getIndexType(); + std::string dimStr[3] = {"x", "y", "z"}; + // Only x, y and z processor ids exist; more parallel loops would index + // past the end of `dimStr`. + assert(numParallelDims <= 3 && + "at most three parallel loops can be distributed"); + // Processor dimension `x` is assigned to the last (innermost) parallel + // loop, `y` to the one before it, and so on. + for (size_t dim = 0; dim < numParallelDims; ++dim) { + StringAttr attr = builder.getStringAttr(dimStr[dim]); + procInfo[numParallelDims - dim - 1] = { + builder.create(loc, indexType, attr), + builder.create(loc, indexType, attr)}; + } + return procInfo; + }, + {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic, + linalg::DistributionMethod::Cyclic}, + DenseMap>()}; + + patterns.add, + TestOpTilingPattern>( + context, + linalg::LinalgTilingOptions() + .setTileSizes(ArrayRef{10, 0, 30}) + .setDistributionOptions(workgroupDistributionOptions), + linalg::LinalgTransformationFilter( + Identifier::get("distribute_input", context), + Identifier::get("distribute_output", context))); + + patterns.add( + context, linalg::LinalgTilingOptions().setTileSizes({10, 20}), + linalg::LinalgTransformationFilter( + Identifier::get("tiling_input", context), + Identifier::get("tiling_output", context))); + + if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) + return signalPassFailure(); +} + +/// Attaches the external TilingInterface models defined above to +/// tensor.insert_slice and the two test ops. +LogicalResult TestTilingInterfacePass::initialize(MLIRContext *context) { + tensor::InsertSliceOp::attachInterface(*context); + test::TestFullSizeOutputTile::attachInterface< + TestFullSizeOutputTileTilingInterface>(*context); + test::TestMixedReduceParallel::attachInterface< + TestMixedReduceParallelTilingInterface>(*context); + return success(); +} + +namespace mlir { +namespace test { +void registerTestTilingInterfacePass() { + PassRegistration(); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -102,6 +102,7 @@ void registerTestPreparationPassWithAllowedMemrefResults(); void registerTestRecursiveTypesPass(); void registerTestSCFUtilsPass(); +void registerTestTilingInterfacePass();
void registerTestVectorConversions(); } // namespace test } // namespace mlir @@ -182,6 +183,7 @@ test::registerTestPDLByteCodePass(); test::registerTestRecursiveTypesPass(); test::registerTestSCFUtilsPass(); + test::registerTestTilingInterfacePass(); test::registerTestVectorConversions(); } #endif diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -731,6 +731,13 @@ actual = ":SideEffectInterfacesTdFiles", ) +td_library( + name = "TilingInterfaceTdFiles", + srcs = ["include/mlir/Interfaces/TilingInterface.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + td_library( name = "VectorInterfacesTdFiles", srcs = ["include/mlir/Interfaces/VectorInterfaces.td"], @@ -4515,6 +4522,24 @@ actual = "SideEffectInterfaces", ) +gentbl_cc_library( + name = "TilingInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Interfaces/TilingInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Interfaces/TilingInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Interfaces/TilingInterface.td", + deps = [":TilingInterfaceTdFiles"], +) + cc_library( name = "Analysis", srcs = glob( @@ -4988,6 +5013,7 @@ ":StandardToSPIRV", ":TensorDialect", ":TensorTransforms", + ":TilingTransform", ":TosaDialect", ":TosaToLinalg", ":Transforms", @@ -5046,6 +5072,7 @@ "//mlir/test:TestShapeDialect", "//mlir/test:TestStandardOps", "//mlir/test:TestStandardToLLVM", + "//mlir/test:TestTilingInterface", "//mlir/test:TestTosaDialect", "//mlir/test:TestTransforms", "//mlir/test:TestTypeDialect", @@ -6041,6 +6068,42 @@ ], ) +cc_library( + name = "TilingInterface", + srcs = ["lib/Interfaces/TilingInterface.cpp"], + hdrs = ["include/mlir/Interfaces/TilingInterface.h"], + includes = ["include"], + deps = [ + ":IR",
+ ":Support", + ":TilingInterfaceIncGen", + ":ViewLikeInterface", + "//llvm:Support", + ], +) + +cc_library( + name = "TilingTransform", + srcs = glob([ + "lib/Dialect/Linalg/TilingInterface/*.cpp", + "lib/Dialect/Linalg/TilingInterface/*.h", + ]), + hdrs = glob([ + "include/mlir/Dialect/Linalg/TilingInterface/*.h", + ]), + includes = ["include"], + deps = [ + ":Affine", + ":IR", + ":LinalgTransforms", + ":MemRefDialect", + ":SCFDialect", + ":StandardOps", + ":TensorDialect", + ":TilingInterface", + ], +) + td_library( name = "VectorOpsTdFiles", srcs = ["include/mlir/Dialect/Vector/VectorOps.td"], diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -45,6 +45,7 @@ "//mlir:include/mlir/Interfaces/CopyOpInterface.td", "//mlir:include/mlir/Interfaces/DataLayoutInterfaces.td", "//mlir:include/mlir/Interfaces/InferTypeOpInterface.td", + "//mlir:include/mlir/Interfaces/TilingInterface.td", ], deps = [ "//mlir:OpBaseTdFiles", @@ -234,6 +235,7 @@ "//mlir:StandardOps", "//mlir:StandardOpsTransforms", "//mlir:TensorDialect", + "//mlir:TilingInterface", "//mlir:TransformUtils", "//mlir:Transforms", ], @@ -479,6 +481,28 @@ ], ) +cc_library( + name = "TestTilingInterface", + srcs = glob(["lib/Interfaces/TilingInterface/*.cpp"]), + includes = ["lib/Dialect/Test"], + deps = [ + ":TestDialect", + "//llvm:Support", + "//mlir:Affine", + "//mlir:GPUDialect", + "//mlir:IR", + "//mlir:LinalgTransforms", + "//mlir:MemRefDialect", + "//mlir:Pass", + "//mlir:SCFDialect", + "//mlir:StandardOps", + "//mlir:TensorDialect", + "//mlir:TilingInterface", + "//mlir:TilingTransform", + "//mlir:TransformUtils", + ], +) + cc_library( name = "TestVector", srcs = glob(["lib/Dialect/Vector/*.cpp"]),