diff --git a/mlir/include/mlir/Dialect/Linalg/TilingInterface/Tiling.h b/mlir/include/mlir/Dialect/Linalg/TilingInterface/Tiling.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Linalg/TilingInterface/Tiling.h @@ -0,0 +1,61 @@ +//===- Tiling.h - Tiling transformation using TilingInterface ---*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef DIALECT_LINALG_TILINGINTERFACE_TILING_H_ +#define DIALECT_LINALG_TILINGINTERFACE_TILING_H_ + +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Interfaces/TilingInterface.h" + +namespace mlir { + +/// Structure to represent the result of tiling operation using TilingInterface. +struct TiledOp { + /// Tiled op. + Operation *op; + /// Loops generated during tiling. + SmallVector loops; + /// Values that are replacements for the results of the untiled operation. + SmallVector results; +}; + +/// Main entry point for tiling an op that implements `TilingInterface`. Returns +/// failure on error; if no loop is tiled, returns `TiledOp{tilableOp, {}, {}}`. +FailureOr +tileOpUsingInterface(OpBuilder &b, TilingInterface tilableOp, + const linalg::LinalgTilingOptions &options); + +/// Base pattern for tiling ops that implement TilingInterface. Patterns can +/// inherit from this class and implement the `matchAndRewrite` method to call +/// into the `matchAndRewriteBase` method of this class. Note that this method +/// does not delete or replace the operation; tiled values are returned +/// through `result`. It returns failure if the filter or options reject the +/// op or if tiling fails. If nothing is tiled, `result.op` is `tilableOp`. 
+struct TilingInterfaceBasePattern : public RewritePattern { + TilingInterfaceBasePattern(StringRef opName, MLIRContext *context, + linalg::LinalgTilingOptions options, + linalg::LinalgTransformationFilter filter = + linalg::LinalgTransformationFilter(), + PatternBenefit benefit = 1) + : RewritePattern(opName, benefit, context), filter(filter), + options(options) {} + + LogicalResult matchAndRewriteBase(TilingInterface tilableOp, + PatternRewriter &rewriter, + TiledOp &result) const; + +private: + /// LinalgTransformationFilter handles special attribute manipulations. + linalg::LinalgTransformationFilter filter; + /// Options to control tiling. + linalg::LinalgTilingOptions options; +}; + +} // namespace mlir + +#endif // DIALECT_LINALG_TILINGINTERFACE_TILING_H_ diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -6,6 +6,7 @@ add_mlir_interface(InferTypeOpInterface) add_mlir_interface(LoopLikeInterface) add_mlir_interface(SideEffectInterfaces) +add_mlir_interface(TilingInterface) add_mlir_interface(VectorInterfaces) add_mlir_interface(ViewLikeInterface) diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/TilingInterface.h @@ -0,0 +1,26 @@ +//===- TilingInterface.h - Interface for tiling operations ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the TilingInterface defined in +// `TilingInterface.td`. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_TILINGINTERFACE_H_ +#define MLIR_INTERFACES_TILINGINTERFACE_H_ + +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Operation.h" +#include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/Support/LLVM.h" + +/// Include the ODS generated interface header files. +#include "mlir/Interfaces/TilingInterface.h.inc" + +#endif // MLIR_INTERFACES_TILINGINTERFACE_H_ diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/TilingInterface.td @@ -0,0 +1,100 @@ +//===- TilingInterface.td - Interface for tiling operations *- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains an interface to allow operations to generate a tiled +// implementation of themselves. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_TILINGINTERFACE +#define MLIR_TILINGINTERFACE + +include "mlir/IR/OpBase.td" + +def TilingInterface : OpInterface<"TilingInterface"> { + let description = [{ + Interface for allowing operations to expose information needed to + tile them (similar to LinalgOp, but without having access to + indexing maps). + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Returns a list of operands into which the result of the + tiled implementation is written. With `tensor` + operands, this will be used as the initial tensor into which + the tiled results are inserted. 
With `memref` operands, + this will be the operand into which the result of the tiled + operation is written. + }], + /*retType=*/"SmallVector", + /*methodName=*/"getDestinationOperands", + /*args=*/(ins "OpBuilder &":$b), + /*methodBody=*/"", + /*defaultImplementation=*/"return ValueRange{};" + >, + InterfaceMethod< + /*desc=*/[{ + Returns a list of `StringRef`s that describe the number of + loops and the iterator types of the operation. The list is + expected to use + `getParallelIteratorTypeName()`/`getReductionIteratorTypeName()` + from MLIR Structured Op Utils. + }], + /*retType=*/"SmallVector", + /*methodName=*/"getLoopIteratorTypes" + >, + InterfaceMethod< + /*desc=*/[{ + Returns a list of ranges that describe the loop bounds and + step for the loops of the operation. + }], + /*retType=*/"SmallVector", + /*methodName=*/"getLoopBounds", + /*args=*/(ins "OpBuilder &":$b) + >, + InterfaceMethod< + /*desc=*/[{ + Method to generate the tiled implementation of an operation. + + The iteration space of the operation is returned by + `getLoopBounds`. The caller provides the information of the + tile within this iteration space whose implementation the + caller needs. + - `offsets` provides the offset of the tile within the + iteration space. + - `sizes` provides the size of the tile. + - `dest` are the Values into which the result of the tiled + operation is to be inserted. The types of the `dest` + Values are the same as the types returned by the + `getDestinationOperands` method. + - When the operands of the operation are `tensor` types, the + results of the tiled operation are inserted into the + corresponding `dest` values, and a new tensor is created + (using the destructive-update paradigm). These new values are + to be returned to the caller in the `results` vector. When + the operands of the operation are `memref` types, this + vector can be left empty. A `nullptr` return signals failure. 
+ }], + /*retType=*/"Operation *", + /*methodName=*/"getTiledImplementation", + /*args=*/(ins + "OpBuilder &":$b, + "ValueRange ":$dest, + "ArrayRef ":$offsets, + "ArrayRef ":$sizes, + "SmallVector &":$results), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return nullptr; + }] + > + ]; +} +#endif // MLIR_TILINGINTERFACE diff --git a/mlir/lib/Dialect/Linalg/CMakeLists.txt b/mlir/lib/Dialect/Linalg/CMakeLists.txt --- a/mlir/lib/Dialect/Linalg/CMakeLists.txt +++ b/mlir/lib/Dialect/Linalg/CMakeLists.txt @@ -1,4 +1,5 @@ add_subdirectory(Analysis) add_subdirectory(IR) +add_subdirectory(TilingInterface) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Linalg/TilingInterface/CMakeLists.txt b/mlir/lib/Dialect/Linalg/TilingInterface/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/TilingInterface/CMakeLists.txt @@ -0,0 +1,16 @@ +add_mlir_dialect_library(MLIRTilingTransform + Tiling.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Linalg/ + + LINK_LIBS_PUBLIC + MLIRAffine + MLIRIR + MLIRLinalgTransforms + MLIRMemRef + MLIRSCF + MLIRStandard + MLIRTensor + MLIRTilingInterface +) diff --git a/mlir/lib/Dialect/Linalg/TilingInterface/Tiling.cpp b/mlir/lib/Dialect/Linalg/TilingInterface/Tiling.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Linalg/TilingInterface/Tiling.cpp @@ -0,0 +1,256 @@ +//===- Tiling.cpp - Implementation of Tiling using TilingInterface --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements tiling using the TilingInterface. 
+// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Linalg/TilingInterface/Tiling.h" +#include "mlir/Dialect/Affine/IR/AffineOps.h" +#include "mlir/Dialect/SCF/SCF.h" +#include "mlir/Dialect/StandardOps/IR/Ops.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Interfaces/TilingInterface.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Utility methods for tiling an operation that implements the +// TilingInterface. +//===----------------------------------------------------------------------===// + +/// Returns failure, via `notifyMatchFailure`, if the options are unsupported. +static LogicalResult +verifySupportedTilingOptions(PatternRewriter &rewriter, Operation *op, + const linalg::LinalgTilingOptions &options) { + if (!options.interchangeVector.empty()) { + return rewriter.notifyMatchFailure(op, + "unsupported interchange during tiling"); + } + if (options.paddingValueComputationFunction) { + return rewriter.notifyMatchFailure(op, "unsupported tile + pad option"); + } + if (options.loopType != linalg::LinalgTilingLoopType::Loops) { + return rewriter.notifyMatchFailure(op, + "only tiling with scf.for is supported"); + } + if (options.distribution) { + if (llvm::any_of(options.distribution->distributionMethod, + [](linalg::DistributionMethod method) { + return method != linalg::DistributionMethod::Cyclic; + })) { + return rewriter.notifyMatchFailure(op, + "only cyclic distibution is allowed"); + } + } + return success(); +} + +/// Converts an `OpFoldResult` to a `Value` by building a constant op +/// if the `OpFoldResult` is an `IntegerAttr`; otherwise returns its `Value`. 
+static Value getValue(OpBuilder &builder, Location loc, + OpFoldResult valueOrAttr) { + if (auto attr = valueOrAttr.dyn_cast()) { + return builder.create(loc, + attr.cast().getInt()); + } + return valueOrAttr.get(); +} + +/// Checks if `valueOrAttr` is an attribute holding the constant value `val`. +static bool isValue(OpFoldResult valueOrAttr, int64_t val) { + auto attr = valueOrAttr.dyn_cast(); + return attr && attr.cast().getValue() == val; +} + +/// Returns true if loop is untiled. Only checks if the value is statically +/// zero. It is assumed that a `Value` defined by a constant op is already +/// converted to an `IntegerAttr` of that value. So here just return true if +/// this is an attribute with a zero value. +static bool isUntiledLoop(OpFoldResult valueOrAttr) { + return isValue(valueOrAttr, 0); +} + +/// Generates the tiled loops and the body by invoking the interface methods of +/// TilingInterface. This method generates a single tiled loop and calls itself +/// recursively to generate all the tiled loops (the recursive call increases +/// the `loopDepth` value by 1). For every invocation the function appends to +/// the `offsets` list the offset to be used within the loop body for +/// the dimension of the iteration space being tiled. +/// - `outputs` are the operands to use for outputs of the tiled operation. +/// - `tileSizes` are tile sizes specified for all loops of the operation. If a +/// loop is to be untiled it is set to 0. +/// - `iteratorTypes` are the types of the loop iterators returned by the +/// TilingInterface. +/// - `loopBounds` are the bounds of all the loops of the op returned by the +/// TilingInterface. +/// - `loopDepth` is the current loop depth being processed. +/// - `offsets` are the values/attributes giving the position of the tile being +/// operated on. The offsets are computed as the tiled loops are being +/// generated. +/// - `distributionInfo` is the proc_id and nprocs `Value`s to be used for +/// distributed loops. 
It is a stack, and once an entry at the top of the +/// stack is used for distribution it is popped before processing the inner +/// loops. +static FailureOr tileOpUsingInterfaceImpl( + OpBuilder &builder, TilingInterface tilableOp, ValueRange outputs, + MutableArrayRef tileSizes, ArrayRef iteratorTypes, + ArrayRef loopBounds, unsigned loopDepth, + SmallVectorImpl &offsets, + ArrayRef distributionInfo) { + Location loc = tilableOp.getLoc(); + // If this is the innermost loop, then generate the tiled implementation of + // the op by invoking the TilingInterface methods. + if (loopDepth == tileSizes.size()) { + TiledOp ret; + ret.op = tilableOp.getTiledImplementation(builder, outputs, offsets, + tileSizes, ret.results); + if (!ret.op) { + return static_cast( + tilableOp.emitOpError("failed to get tiled implementation")); + } + return ret; + } + + // If the loop at this depth is untiled, use offset 0 and the full loop size. + if (isUntiledLoop(tileSizes[loopDepth])) { + auto zeroAttr = builder.getI64IntegerAttr(0); + offsets.push_back(zeroAttr); + assert(matchPattern(loopBounds[loopDepth].offset, m_Zero()) && + "expected loop bounds to have lower bound of zero"); + tileSizes[loopDepth] = getAsOpFoldResult(loopBounds[loopDepth].size); + return tileOpUsingInterfaceImpl(builder, tilableOp, outputs, tileSizes, + iteratorTypes, loopBounds, loopDepth + 1, + offsets, distributionInfo); + } + + // Generate an scf.for for the current loop depth. + Value lb = loopBounds[loopDepth].offset; + Value ub = loopBounds[loopDepth].size; + if (!matchPattern(loopBounds[loopDepth].stride, m_One())) { + return static_cast( + tilableOp.emitOpError("expected stride to be 1")); + } + Value step = getValue(builder, loc, tileSizes[loopDepth]); + + // Update lb, ub and step for cyclic distribution. 
+ if (!distributionInfo.empty() && + iteratorTypes[loopDepth] == getParallelIteratorTypeName()) { + linalg::updateBoundsForCyclicDistribution( + builder, loc, distributionInfo.front().procId, + distributionInfo.front().nprocs, lb, ub, step); + distributionInfo = distributionInfo.drop_front(); + } + FailureOr innerReturnValue; + bool isBufferTiling = tilableOp->getNumResults() == 0; + ValueRange initValues(isBufferTiling ? ValueRange{} : outputs); + auto forOp = builder.create( + loc, lb, ub, step, initValues, + [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { + offsets.push_back(iv); + auto affineMaps = AffineMap::inferFromExprList({ArrayRef{ + b.getAffineSymbolExpr(0), + b.getAffineSymbolExpr(1) - b.getAffineDimExpr(0)}})[0]; + // Similar to linalg tiling, the tile size is the min(tileSize, ub - + // iv) to account for cases where tile size does not divide (ub - lb) + // exactly. + Value inBoundsTileSize = b.create( + loc, affineMaps, + ValueRange{iv, getValue(builder, loc, tileSizes[loopDepth]), ub}); + tileSizes[loopDepth] = getAsOpFoldResult(inBoundsTileSize); + // Recursively generate the tiled loop for the next level. NOTE(review): on failure the early `return` below skips building this loop body's terminator. + innerReturnValue = tileOpUsingInterfaceImpl( + b, tilableOp, (isBufferTiling ? 
outputs : args), tileSizes, + iteratorTypes, loopBounds, loopDepth + 1, offsets, + distributionInfo); + if (failed(innerReturnValue)) + return; + b.create(loc, innerReturnValue->results); + }); + if (failed(innerReturnValue)) + return innerReturnValue; + innerReturnValue->loops.insert(innerReturnValue->loops.begin(), + forOp.getOperation()); + innerReturnValue->results = forOp.getResults(); + return innerReturnValue; +} + +FailureOr +mlir::tileOpUsingInterface(OpBuilder &b, TilingInterface tilableOp, + const linalg::LinalgTilingOptions &options) { + SmallVector iteratorTypes = tilableOp.getLoopIteratorTypes(); + SmallVector tileSizesVals = + options.tileSizeComputationFunction(b, tilableOp); + auto zeroAttr = b.getI64IntegerAttr(0); + + // Tile sizes given as constant-0 `Value`s become zero attributes; missing + // entries are padded with zero (untiled). Tiling loops whose iterator type + // is not "parallel" is unsupported, so their tile size must be zero. + auto tileSizes = getAsOpFoldResult(tileSizesVals); + tileSizes.resize(iteratorTypes.size(), zeroAttr); + for (auto en : llvm::enumerate(iteratorTypes)) { + if (en.value() == getParallelIteratorTypeName()) + continue; + if (!isUntiledLoop(tileSizes[en.index()])) { + return static_cast(tilableOp.emitOpError( + "unimplemented tiling of non-parallel loop iterator type")); + } + } + + // Trivial early exit case of tile sizes being zero for all loops. + if (llvm::all_of(tileSizes, isUntiledLoop)) + return TiledOp{tilableOp, {}, {}}; + + SmallVector loopBounds = tilableOp.getLoopBounds(b); + SmallVector distributionInfo; + // If the tiled loops are distributed, get the proc_id and nprocs for the + // distributed loops. First collect the parallel loops by iterating over the + // tileSizes and getting the loops that are distributed, i.e., + // - parallel, i.e. iteratorTypes is "parallel" + // - tiled, i.e. 
tileSize != 0 + if (options.distribution) { + SmallVector distributedLoopRange; + for (auto i : llvm::seq(0, tileSizes.size())) { + if (isUntiledLoop(tileSizes[i])) + continue; + if (iteratorTypes[i] != getParallelIteratorTypeName()) + continue; + distributedLoopRange.push_back(loopBounds[i]); + } + distributionInfo = options.distribution->procInfo(b, tilableOp.getLoc(), + distributedLoopRange); + } + + SmallVector offsets; + SmallVector dest = tilableOp.getDestinationOperands(b); + return tileOpUsingInterfaceImpl(b, tilableOp, dest, tileSizes, iteratorTypes, + loopBounds, 0, offsets, distributionInfo); +} + +//===----------------------------------------------------------------------===// +// Definition of methods for `TilingInterfaceBasePattern`. +//===----------------------------------------------------------------------===// + +LogicalResult +mlir::TilingInterfaceBasePattern::matchAndRewriteBase(TilingInterface tilableOp, + PatternRewriter &rewriter, + TiledOp &result) const { + if (failed(filter.checkAndNotify(rewriter, tilableOp))) + return failure(); + if (failed(verifySupportedTilingOptions(rewriter, tilableOp, options))) + return failure(); + + FailureOr res = tileOpUsingInterface(rewriter, tilableOp, options); + if (failed(res)) + return res; + result = *res; + if (result.op) + filter.replaceLinalgTransformationFilter(rewriter, result.op); + return success(); +} diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -8,6 +8,7 @@ InferTypeOpInterface.cpp LoopLikeInterface.cpp SideEffectInterfaces.cpp + TilingInterface.cpp VectorInterfaces.cpp ViewLikeInterface.cpp ) @@ -37,6 +38,6 @@ add_mlir_interface_library(InferTypeOpInterface) add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(SideEffectInterfaces) +add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) 
add_mlir_interface_library(ViewLikeInterface) - diff --git a/mlir/lib/Interfaces/TilingInterface.cpp b/mlir/lib/Interfaces/TilingInterface.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/TilingInterface.cpp @@ -0,0 +1,17 @@ +//===- TilingInterface.cpp - Tiling interface -------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the definitions of the interface in `TilingInterface.td`. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/TilingInterface.h" + +namespace mlir { +#include "mlir/Interfaces/TilingInterface.cpp.inc" +} // namespace mlir diff --git a/mlir/test/Interfaces/TilingInterface/tiling.mlir b/mlir/test/Interfaces/TilingInterface/tiling.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Interfaces/TilingInterface/tiling.mlir @@ -0,0 +1,427 @@ +// RUN: mlir-opt -test-tiling-interface -split-input-file %s | FileCheck %s + +func @scatter_tiling( + %original: tensor, %indices: tensor, + %update : tensor) -> tensor { + %0 = test.scatter {__internal_linalg_transform__ = "tiling_input"} + %update, %indices, %original : tensor, tensor, tensor -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func @scatter_tiling( +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[TILESIZEY:.+]] = constant 10 : index +// CHECK-DAG: %[[TILESIZEX:.+]] = constant 20 : 
index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]] +// CHECK: %[[RESULT0:.+]] = scf.for %[[IV0:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZEY]] +// CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ORIGINAL]]) +// CHECK-DAG: %[[USED_TILESIZEY:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[TILESIZEY]], %[[D0]]] +// CHECK: %[[RESULT1:.+]] = scf.for %[[IV1:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZEX]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) +// CHECK-DAG: %[[USED_TILESIZEX:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[TILESIZEX]], %[[D1]]] +// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV0]], %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZEY]], %[[USED_TILESIZEX]]] +// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV0]], 0] +// CHECK-SAME: [%[[USED_TILESIZEY]], 1] +// CHECK: %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] +// CHECK: %[[SOURCE_SLICE:.+]] = tensor.extract_slice %[[INIT1]][0, %[[IV1]]] +// CHECK-SAME: [%[[SLICE_D0]], %[[USED_TILESIZEX]]] +// CHECK: %[[SCATTER_TILE:.+]] = test.scatter +// CHECK-SAME: __internal_linalg_transform__ = "tiling_output" +// CHECK-SAME: %[[UPDATE_SLICE]], %[[INDEX_SLICE]], %[[SOURCE_SLICE]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT1]][0, %[[IV1]]] +// CHECK-SAME: [%[[SLICE_D0]], %[[USED_TILESIZEX]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: scf.yield %[[RESULT1]] +// CHECK: return %[[RESULT0]] + +// ----- + +func @scatter_tiling_memref( + %original: memref, %indices: memref, + %update : memref) { + test.scatter {__internal_linalg_transform__ = "tiling_input"} + %update, %indices, %original : memref, memref, memref + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func @scatter_tiling_memref( +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: 
%[[INDICES:[a-zA-Z0-9_]+]]: memref +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[TILESIZEY:.+]] = constant 10 : index +// CHECK-DAG: %[[TILESIZEX:.+]] = constant 20 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[UPDATES]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[UPDATES]], %[[C1]] +// CHECK: scf.for %[[IV0:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZEY]] +// CHECK-DAG: %[[USED_TILESIZEY:.+]] = affine.min #[[MAP0]](%[[IV0]])[%[[TILESIZEY]], %[[D0]]] +// CHECK: scf.for %[[IV1:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZEX]] +// CHECK-DAG: %[[USED_TILESIZEX:.+]] = affine.min #[[MAP1]](%[[IV1]])[%[[TILESIZEX]], %[[D1]]] +// CHECK: %[[UPDATE_SLICE:.+]] = memref.subview %[[UPDATES]][%[[IV0]], %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZEY]], %[[USED_TILESIZEX]]] +// CHECK: %[[INDEX_SLICE:.+]] = memref.subview %[[INDICES]][%[[IV0]], 0] +// CHECK-SAME: [%[[USED_TILESIZEY]], 1] +// CHECK: %[[SLICE_D0:.+]] = memref.dim %[[ORIGINAL]], %[[C0]] +// CHECK: %[[SOURCE_SLICE:.+]] = memref.subview %[[ORIGINAL]][0, %[[IV1]]] +// CHECK-SAME: [%[[SLICE_D0]], %[[USED_TILESIZEX]]] +// CHECK: test.scatter +// CHECK-SAME: __internal_linalg_transform__ = "tiling_output" +// CHECK-SAME: %[[UPDATE_SLICE]], %[[INDEX_SLICE]], %[[SOURCE_SLICE]] + +// ----- + +func @scatter_tiling_distribution( + %original: tensor, %indices: tensor, + %update : tensor) -> tensor { + %0 = test.scatter {__internal_linalg_transform__ = "distribute_input"} + %update, %indices, %original : tensor, tensor, tensor -> tensor + return %0 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @scatter_tiling_distribution( +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: 
%[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[UPDATES]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[UPDATES]], %[[C1]] +// CHECK-DAG: %[[ID:.+]] = "gpu.block_id"() {dimension = "x"} +// CHECK-DAG: %[[COUNT:.+]] = "gpu.grid_dim"() {dimension = "x"} +// CHECK-DAG: %[[OFFSET:.+]] = affine.apply #[[MAP0]]()[%[[ID]]] +// CHECK-DAG: %[[STEP:.+]] = affine.apply #[[MAP0]]()[%[[COUNT]]] +// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[OFFSET]] to %[[D0]] step %[[STEP]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[ORIGINAL]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP1]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK: %[[UPDATE_SLICE:.+]] = tensor.extract_slice %[[UPDATES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[INDEX_SLICE:.+]] = tensor.extract_slice %[[INDICES]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], 1] +// CHECK: %[[SLICE_D0:.+]] = tensor.dim %[[ORIGINAL]], %[[C0]] +// CHECK: %[[SOURCE_SLICE:.+]] = tensor.extract_slice %[[INIT]][0, 0] [%[[SLICE_D0]], %[[D1]]] +// CHECK: %[[SCATTER_TILE:.+]] = test.scatter +// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" +// CHECK-SAME: %[[UPDATE_SLICE]], %[[INDEX_SLICE]], %[[SOURCE_SLICE]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SCATTER_TILE]] into %[[INIT]][0, 0] +// CHECK-SAME: [%[[SLICE_D0]], %[[D1]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- + +func @scatter_no_tiling( + %original: tensor, %indices: tensor, + %update : tensor) -> tensor { + %0 = test.scatter {__internal_linalg_transform__ = "no_tiling_input"} + %update, %indices, %original : tensor, tensor, tensor -> tensor + return %0 : tensor +} +// CHECK: func @scatter_no_tiling +// CHECK-SAME: %[[ORIGINAL:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[INDICES:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[UPDATES:[a-zA-Z0-9_]+]]: tensor +// 
CHECK: %[[RESULT:.+]] = test.scatter +// CHECK-SAME: __internal_linalg_transform__ = "no_tiling_output" +// CHECK-SAME: %[[UPDATES]], %[[INDICES]], %[[ORIGINAL]] +// CHECK: return %[[RESULT]] + +// ----- + +func @sort_1d(%arg0: tensor) -> tensor { + %0 = test.sort reduce(0) {__internal_linalg_transform__ = "outer_reduce_input"} + %arg0 : tensor -> tensor + return %0 : tensor +} +// CHECK: func @sort_1d( +// CHECK-SAME: %[[OPERAND:.+]]: tensor +// CHECK: %[[RESULT:.+]] = test.sort reduce(0) +// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" +// CHECK-SAME: %[[OPERAND]] +// CHECK: return %[[RESULT]] + +// ----- + +func @sort_2d(%arg0: tensor) -> tensor { + %0 = test.sort reduce(1) {__internal_linalg_transform__ = "inner_reduce_input"} + %arg0 : tensor -> tensor + return %0 : tensor +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @sort_2d( +// CHECK-SAME: %[[OPERAND:.+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[SORT_TILE:.+]] = test.sort +// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output" +// CHECK-SAME: %[[OPERAND_SLICE]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- + +func @sort_2d_inner_parallel(%arg0: tensor) -> tensor { + %0 = 
test.sort reduce(0) {__internal_linalg_transform__ = "outer_reduce_input"} + %arg0 : tensor -> tensor + return %0 : tensor +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func @sort_2d_inner_parallel( +// CHECK-SAME: %[[OPERAND:.+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 20 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]] +// CHECK-SAME: iter_args(%[[INIT:.+]] = %[[OPERAND]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]] +// CHECK: %[[OPERAND_SLICE:.+]] = tensor.extract_slice %[[INIT]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] +// CHECK: %[[SORT_TILE:.+]] = test.sort +// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" +// CHECK-SAME: %[[OPERAND_SLICE]] +// CHECK: %[[YIELD:.+]] = tensor.insert_slice %[[SORT_TILE]] into %[[INIT]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] +// CHECK: scf.yield %[[YIELD]] +// CHECK: return %[[RESULT]] + +// ----- + +func @sort_2d_multi_result( + %arg0: tensor, %arg1: tensor) + -> (tensor, tensor) { + %0:2 = test.sort reduce(1) {__internal_linalg_transform__ = "inner_reduce_input"} + %arg0, %arg1 : tensor, tensor -> tensor, tensor + return %0#0, %0#1 : tensor, tensor +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK: func @sort_2d_multi_result( +// CHECK-SAME: %[[OPERAND1:.+]]: tensor +// CHECK-SAME: %[[OPERAND2:.+]]: tensor +// CHECK-DAG: %[[TILESIZE:.+]] = constant 10 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]] +// CHECK: 
%[[RESULT:.+]]:2 = scf.for %[[IV:.+]] = %[[C0]] to %[[D0]] step %[[TILESIZE]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = %[[OPERAND2]]) +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D0]]] +// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT1]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT2]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[SORT_TILE:.+]]:2 = test.sort +// CHECK-SAME: __internal_linalg_transform__ = "inner_reduce_output" +// CHECK-SAME: %[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] +// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_TILE]]#0 into %[[INIT1]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_TILE]]#1 into %[[INIT2]][%[[IV]], 0] +// CHECK-SAME: [%[[USED_TILESIZE]], %[[D1]]] +// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]] +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- + +func @sort_2d_multi_result_memref( + %arg0: memref, %arg1: memref) { + test.sort reduce(0) {__internal_linalg_transform__ = "outer_reduce_input"} + %arg0, %arg1 : memref, memref + return +} +// CHECK: #[[MAP:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK: func @sort_2d_multi_result_memref( +// CHECK-SAME: %[[OPERAND1:.+]]: memref +// CHECK-SAME: %[[OPERAND2:.+]]: memref +// CHECK-DAG: %[[TILESIZE:.+]] = constant 20 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]] +// CHECK: scf.for %[[IV:.+]] = %[[C0]] to %[[D1]] step %[[TILESIZE]] +// CHECK-DAG: %[[USED_TILESIZE:.+]] = affine.min #[[MAP]](%[[IV]])[%[[TILESIZE]], %[[D1]]] +// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], 
%[[USED_TILESIZE]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][0, %[[IV]]] +// CHECK-SAME: [%[[D0]], %[[USED_TILESIZE]]] +// CHECK: test.sort +// CHECK-SAME: __internal_linalg_transform__ = "outer_reduce_output" +// CHECK-SAME: %[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] + +// ----- + +func @sort_3d_multi_result_distribute( + %arg0: tensor, %arg1 : tensor) + -> (tensor, tensor) { + %0, %1 = test.sort reduce(1) {__internal_linalg_transform__ = "distribute_input"} + %arg0, %arg1 : tensor, tensor -> tensor, tensor + return %0, %1 : tensor, tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK: func @sort_3d_multi_result_distribute( +// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: tensor +// CHECK-DAG: %[[TILESIZE1:.+]] = constant 10 : index +// CHECK-DAG: %[[TILESIZE2:.+]] = constant 30 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[OPERAND1]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[OPERAND1]], %[[C2]] +// CHECK-DAG: %[[IDX:.+]] = "gpu.block_id"() {dimension = "x"} +// CHECK-DAG: %[[COUNTX:.+]] = "gpu.grid_dim"() {dimension = "x"} +// CHECK-DAG: %[[IDY:.+]] = "gpu.block_id"() {dimension = "y"} +// CHECK-DAG: %[[COUNTY:.+]] = "gpu.grid_dim"() {dimension = "y"} +// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]] +// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]] +// CHECK: %[[RESULT:.+]]:2 = scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]] +// CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[OPERAND1]], %[[INIT2:.+]] = 
%[[OPERAND2]]) +// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]] +// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]] +// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]] +// CHECK: %[[RESULT_INNER:.+]]:2 = scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]] +// CHECK-SAME: iter_args(%[[INIT3:.+]] = %[[INIT1]], %[[INIT4:.+]] = %[[INIT2]]) +// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]] +// CHECK: %[[OPERAND1_SLICE:.+]] = tensor.extract_slice %[[INIT3]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = tensor.extract_slice %[[INIT4]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: %[[SORT_SLICE:.+]]:2 = test.sort +// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" +// CHECK-SAME: %[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] +// CHECK: %[[YIELD1:.+]] = tensor.insert_slice %[[SORT_SLICE]]#0 +// CHECK-SAME: into %[[INIT3]][%[[IV0]], 0, %[[IV1]]] +// CHECK: %[[YIELD2:.+]] = tensor.insert_slice %[[SORT_SLICE]]#1 +// CHECK-SAME: into %[[INIT4]][%[[IV0]], 0, %[[IV1]]] +// CHECK: scf.yield %[[YIELD1]], %[[YIELD2]] +// CHECK: scf.yield %[[RESULT_INNER]]#0, %[[RESULT_INNER]]#1 +// CHECK: return %[[RESULT]]#0, %[[RESULT]]#1 + +// ----- + +func @sort_3d_multi_result_distribute_memref( + %arg0: memref, %arg1 : memref) { + test.sort reduce(1) {__internal_linalg_transform__ = "distribute_input"} + %arg0, %arg1 : memref, memref + return +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 * 10)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0] -> (s0 * 30)> +// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0)[s0, s1] -> (30, -d0 + s1)> +// CHECK: func @sort_3d_multi_result_distribute_memref( +// CHECK-SAME: %[[OPERAND1:[a-zA-Z0-9_]+]]: 
memref +// CHECK-SAME: %[[OPERAND2:[a-zA-Z0-9_]+]]: memref +// CHECK-DAG: %[[TILESIZE1:.+]] = constant 10 : index +// CHECK-DAG: %[[TILESIZE2:.+]] = constant 30 : index +// CHECK-DAG: %[[C0:.+]] = constant 0 : index +// CHECK-DAG: %[[C1:.+]] = constant 1 : index +// CHECK-DAG: %[[C2:.+]] = constant 2 : index +// CHECK-DAG: %[[D0:.+]] = memref.dim %[[OPERAND1]], %[[C0]] +// CHECK-DAG: %[[D1:.+]] = memref.dim %[[OPERAND1]], %[[C1]] +// CHECK-DAG: %[[D2:.+]] = memref.dim %[[OPERAND1]], %[[C2]] +// CHECK-DAG: %[[IDX:.+]] = "gpu.block_id"() {dimension = "x"} +// CHECK-DAG: %[[COUNTX:.+]] = "gpu.grid_dim"() {dimension = "x"} +// CHECK-DAG: %[[IDY:.+]] = "gpu.block_id"() {dimension = "y"} +// CHECK-DAG: %[[COUNTY:.+]] = "gpu.grid_dim"() {dimension = "y"} +// CHECK-DAG: %[[OFFSETY:.+]] = affine.apply #[[MAP0]]()[%[[IDY]]] +// CHECK-DAG: %[[STEPY:.+]] = affine.apply #[[MAP0]]()[%[[COUNTY]]] +// CHECK: scf.for %[[IV0:.+]] = %[[OFFSETY]] to %[[D0]] step %[[STEPY]] +// CHECK-DAG: %[[USED_TILESIZE1:.+]] = affine.min #[[MAP1]](%[[IV0]])[%[[TILESIZE1]], %[[D0]]] +// CHECK-DAG: %[[OFFSETX:.+]] = affine.apply #[[MAP2]]()[%[[IDX]]] +// CHECK-DAG: %[[STEPX:.+]] = affine.apply #[[MAP2]]()[%[[COUNTX]]] +// CHECK: scf.for %[[IV1:.+]] = %[[OFFSETX]] to %[[D2]] step %[[STEPX]] +// CHECK-DAG: %[[USED_TILESIZE2:.+]] = affine.min #[[MAP3]](%[[IV1]])[%[[TILESIZE2]], %[[D2]]] +// CHECK: %[[OPERAND1_SLICE:.+]] = memref.subview %[[OPERAND1]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: %[[OPERAND2_SLICE:.+]] = memref.subview %[[OPERAND2]][%[[IV0]], 0, %[[IV1]]] +// CHECK-SAME: [%[[USED_TILESIZE1]], %[[D1]], %[[USED_TILESIZE2]]] +// CHECK: test.sort +// CHECK-SAME: __internal_linalg_transform__ = "distribute_output" +// CHECK-SAME: %[[OPERAND1_SLICE]], %[[OPERAND2_SLICE]] + +// ----- + +func @slice_insert(%source :tensor, %dest: tensor, + %idx0 : index, %idx1 : index) -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = 
tensor.dim %source, %c0 : tensor + %1 = tensor.dim %source, %c1 : tensor + %2 = tensor.insert_slice %source into %dest[%idx0, %idx1] [%0, %1] [1, 1] + {__internal_linalg_transform__ = "tiling_input"} : tensor into tensor + return %2 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func @slice_insert( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK-DAG: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-DAG: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]] +// CHECK-DAG: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]] +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]] +// CHECK-SAME: into %{{.+}}[%[[OFFSET0]], %[[OFFSET1]]] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[YIELD1]] +// CHECK: return %[[RESULT]] + +// ----- + +func @slice_insert_rank_reduce(%source :tensor, %dest: tensor, + %idx0 : index, %idx1 : index) -> tensor { + %c0 = constant 0 : index + %c1 = constant 1 : index + %0 = tensor.dim %source, %c0 : tensor + %1 = tensor.dim %source, %c1 : tensor + %2 = tensor.insert_slice %source into %dest[%idx0, 0, %idx1] [%0, 1, %1] [1, 1, 1] + {__internal_linalg_transform__ = "tiling_input"} : tensor into tensor + return %2 : tensor +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)> +// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0)[s0] -> (d0 + s0)> +// CHECK: func @slice_insert_rank_reduce( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor +// 
CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index +// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index +// CHECK: %[[RESULT:.+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = +// CHECK: %[[YIELD1:.+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = +// CHECK-DAG: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][%[[IV0]], %[[IV1]]] +// CHECK-DAG: %[[OFFSET0:.+]] = affine.apply #[[MAP2]](%[[IV0]])[%[[ARG2]]] +// CHECK-DAG: %[[OFFSET1:.+]] = affine.apply #[[MAP2]](%[[IV1]])[%[[ARG3]]] +// CHECK: %[[UPDATE:.+]] = tensor.insert_slice %[[SLICE]] +// CHECK-SAME: into %{{.+}}[%[[OFFSET0]], 0, %[[OFFSET1]]] +// CHECK: scf.yield %[[UPDATE]] +// CHECK: scf.yield %[[YIELD1]] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/lib/CMakeLists.txt b/mlir/test/lib/CMakeLists.txt --- a/mlir/test/lib/CMakeLists.txt +++ b/mlir/test/lib/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(IR) +add_subdirectory(Interfaces) add_subdirectory(Pass) add_subdirectory(Reducer) add_subdirectory(Rewrite) diff --git a/mlir/test/lib/Dialect/Test/TestDialect.h b/mlir/test/lib/Dialect/Test/TestDialect.h --- a/mlir/test/lib/Dialect/Test/TestDialect.h +++ b/mlir/test/lib/Dialect/Test/TestDialect.h @@ -31,6 +31,7 @@ #include "mlir/Interfaces/DerivedAttributeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/TilingInterface.h" namespace mlir { class DLTIDialect; diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -19,6 +19,7 @@ include "mlir/Interfaces/CopyOpInterface.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "TestInterfaces.td" @@ -2040,6 +2041,82 @@ }]; } 
+//===----------------------------------------------------------------------===//
+// Test TilingInterface
+//===----------------------------------------------------------------------===//
+
+def TestScatter : TEST_Op<"scatter", []> {
+  let description = [{
+    Operation to test tiling interface.
+
+    This operation represents a "scatter"-like pattern. The
+    computation represented is shown below when all operands are of
+    `memref` types. The operands could be all `tensor` types as well.
+
+    ```mlir
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %c2 = constant 2 : index
+    %d0 = memref.dim %update, %c0 : memref<?x?x?xf32>
+    %d1 = memref.dim %update, %c1 : memref<?x?x?xf32>
+    %d2 = memref.dim %update, %c2 : memref<?x?x?xf32>
+    scf.for %iv0 = %c0 to %d0 step %c1
+      scf.for %iv1 = %c0 to %d1 step %c1
+        scf.for %iv2 = %c0 to %d2 step %c1
+          %i0 = memref.load %indices[%iv0, %c0] : memref<?x2xi32>
+          %indx0 = index_cast %i0 : i32 to index
+          %i1 = memref.load %indices[%iv0, %c1] : memref<?x2xi32>
+          %indx1 = index_cast %i1 : i32 to index
+          %val = memref.load %update[%iv0, %iv1, %iv2] : memref<?x?x?xf32>
+          memref.store %val, %source[%indx0, %indx1, %iv1, %iv2] : memref<?x?x?x?xf32>
+    ```
+  }];
+  let arguments = (ins AnyType:$update, AnyType:$indices, AnyType:$source);
+  let results = (outs Optional<AnyType>:$result);
+  let assemblyFormat = [{
+    attr-dict $update `,` $indices `,` $source `:`
+    type($update) `,` type($indices) `,` type($source) (`->` type($result)^)?
+  }];
+}
+
+def TestSort : TEST_Op<"sort", []> {
+  let description = [{
+    Operation to test tiling interface.
+
+    This operation represents a "sort-like" operation, where values
+    along the single dimension `reduce_dim` are sorted. The computation
+    represented is shown below for `reduce_dim` = 1 when all operands are
+    of `memref` types. The operands could be all `tensor` types as well.
+
+    ```mlir
+    %c0 = constant 0 : index
+    %c1 = constant 1 : index
+    %c2 = constant 2 : index
+    %d0 = memref.dim %operand0, %c0 : memref<?x?x?xf32>
+    %d1 = memref.dim %operand0, %c1 : memref<?x?x?xf32>
+    %d2 = memref.dim %operand0, %c2 : memref<?x?x?xf32>
+    scf.for %iv0 = %c0 to %d0 step %c1
+      scf.for %iv1 = %c0 to %d2 step %c1
+        %sorted_slice0 = memref.subview %operand0[%iv0, 0, %iv1][1, %d1, 1][1, 1, 1]
+            : memref<?x?x?xf32> to memref<?xf32, #strided>
+        %sorted_slice1 = memref.subview %operand1[%iv0, 0, %iv1][1, %d1, 1][1, 1, 1]
+            : memref<?x?x?xf32> to memref<?xf32, #strided>
+        call @sort_in_place(%sorted_slice0, %sorted_slice1)
+            : (memref<?xf32, #strided>, memref<?xf32, #strided>) -> ()
+    ```
+  }];
+
+  let arguments = (ins Variadic<AnyType>:$sources, I64Attr:$reduce_dim);
+  let results = (outs Variadic<AnyType>:$results);
+  let assemblyFormat = [{
+    `reduce` `(` $reduce_dim `)` attr-dict $sources
+    `:` type($sources) (`->` type($results)^)?
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+
 //===----------------------------------------------------------------------===//
 // Test TableGen generated build() methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Interfaces/CMakeLists.txt b/mlir/test/lib/Interfaces/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Interfaces/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(TilingInterface)
diff --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt
@@ -0,0 +1,24 @@
+# Exclude tests from libMLIR.so
+add_mlir_library(MLIRTestTilingInterface
+  TestTilingInterface.cpp
+
+  EXCLUDE_FROM_LIBMLIR
+
+  LINK_LIBS PUBLIC
+  MLIRAffine
+  MLIRGPUOps
+  MLIRIR
+  MLIRLinalgTransforms
+  MLIRMemRef
+  MLIRPass
+  MLIRSCF
+  MLIRStandard
+  MLIRTensor
+  MLIRTestDialect
+  MLIRTilingInterface
+  MLIRTilingTransform
+  MLIRTransformUtils
+)
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../Dialect/Test)
+include_directories(${CMAKE_CURRENT_BINARY_DIR}/../../Dialect/Test)
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -0,0 +1,499 @@
+//===- TestTilingInterface.cpp - Test pass for Tiling Interface patterns --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file tests the tiling transformations implemented using TilingInterface.
+//
+//===----------------------------------------------------------------------===//
+
+#include "TestDialect.h"
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/Linalg/TilingInterface/Tiling.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/TilingInterface.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/TypeSwitch.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// Utility methods for `OpFoldResult`s and shaped values
+//===----------------------------------------------------------------------===//
+
+/// Converts an `OpFoldResult` to a `Value`, materializing a constant index op
+/// when the `OpFoldResult` holds an `IntegerAttr`.
+static Value getValue(OpBuilder &builder, Location loc,
+                      OpFoldResult valueOrAttr) {
+  if (auto attr = valueOrAttr.dyn_cast<Attribute>()) {
+    return builder.create<ConstantIndexOp>(loc,
+                                           attr.cast<IntegerAttr>().getInt());
+  }
+  return valueOrAttr.get<Value>();
+}
+
+/// Returns the constant value in `valueOrAttr` if it is not a dynamic `Value`.
+static Optional<int64_t> getConstantValue(OpFoldResult valueOrAttr) {
+  if (auto attr = valueOrAttr.dyn_cast<Attribute>())
+    return attr.cast<IntegerAttr>().getInt();
+  return {};
+}
+
+/// Checks if `valueOrAttr` represents a constant value `val`.
+static bool isValue(OpFoldResult valueOrAttr, int64_t val) {
+  auto attr = valueOrAttr.dyn_cast<Attribute>();
+  return attr && attr.cast<IntegerAttr>().getInt() == val;
+}
+
+/// Returns a memref.subview or a tensor.extract_slice based on the type of the
+/// `source`, or a null `Value` for any other type.
+static Value getSlice(OpBuilder &b, Location loc, Value source,
+                      ArrayRef<OpFoldResult> offsets,
+                      ArrayRef<OpFoldResult> sizes,
+                      ArrayRef<OpFoldResult> strides) {
+  return TypeSwitch<Type, Value>(source.getType())
+      .Case([&](RankedTensorType t) -> Value {
+        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
+                                                strides);
+      })
+      .Case([&](MemRefType type) -> Value {
+        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
+                                           strides);
+      })
+      .Default([&](Type t) { return Value(); });
+}
+
+/// Returns the size of dimension `dim` of `v` as a `Value` built with
+/// `tensor.dim` or `memref.dim`, or a null `Value` for any other type.
+static Value getDimValue(OpBuilder &builder, Location loc, Value v,
+                         int64_t dim) {
+  return TypeSwitch<Type, Value>(v.getType())
+      .Case([&](RankedTensorType t) -> Value {
+        return builder.create<tensor::DimOp>(loc, v, dim);
+      })
+      .Case([&](MemRefType t) -> Value {
+        return builder.create<memref::DimOp>(loc, v, dim);
+      })
+      .Default([&](Type t) { return Value(); });
+}
+
+/// Returns the size of dimension `dim` of the shaped value `v`: an i64
+/// attribute when the size is static, otherwise a `*.dim` op result.
+static OpFoldResult getDim(OpBuilder &builder, Location loc, Value v,
+                           int64_t dim) {
+  auto t = v.getType().cast<ShapedType>();
+  if (t.isDynamicDim(dim))
+    return getDimValue(builder, loc, v, dim);
+  return builder.getI64IntegerAttr(t.getDimSize(dim));
+}
+
+//===----------------------------------------------------------------------===//
+// Tiling Interface for Test Dialect operations.
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct TestScatterTilingInterface
+    : public TilingInterface::ExternalModel<TestScatterTilingInterface, test::TestScatter> {
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    return {cast<test::TestScatter>(op).source()};
+  }
+
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    auto testOp = cast<test::TestScatter>(op);
+    SmallVector<StringRef> iteratorTypes(
+        testOp.update().getType().cast<ShapedType>().getRank(),
+        getParallelIteratorTypeName());
+    return iteratorTypes;
+  }
+
+  SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &builder) const {
+    auto testOp = cast<test::TestScatter>(op);
+    Location loc = op->getLoc();
+    Value zero = builder.create<ConstantIndexOp>(loc, 0);
+    Value one = builder.create<ConstantIndexOp>(loc, 1);
+    SmallVector<Range> ranges;
+    Value updates = testOp.update();
+    for (auto dim : llvm::seq<int64_t>(
+             0, updates.getType().cast<ShapedType>().getRank())) {
+      Value ub = getDimValue(builder, loc, updates, dim);
+      ranges.emplace_back(Range{zero, ub, one});
+    }
+    return ranges;
+  }
+
+  Operation *getTiledImplementation(Operation *op, OpBuilder &builder,
+                                    ValueRange dest,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVectorImpl<Value> &results) const {
+    auto testOp = cast<test::TestScatter>(op);
+    Location loc = op->getLoc();
+    auto zeroAttr = builder.getI64IntegerAttr(0);
+    auto oneAttr = builder.getI64IntegerAttr(1);
+
+    // Slice of the updates.
+    Value updates = testOp.update();
+    auto updateRank = updates.getType().cast<ShapedType>().getRank();
+    SmallVector<OpFoldResult> updateStrides(updateRank, oneAttr);
+    Value tiledUpdate =
+        getSlice(builder, loc, updates, offsets, sizes, updateStrides);
+    assert(tiledUpdate && "failed to get slice of update");
+
+    // Slice of indices.
+    Value indices = testOp.indices();
+    auto indicesRank = indices.getType().cast<ShapedType>().getRank();
+    SmallVector<OpFoldResult> indicesOffsets(indicesRank, zeroAttr);
+    SmallVector<OpFoldResult> indicesSizes(indicesRank, zeroAttr);
+    indicesOffsets[0] = offsets[0];
+    indicesSizes[0] = sizes[0];
+    for (auto dim : llvm::seq<int64_t>(1, indicesRank))
+      indicesSizes[dim] = getDim(builder, loc, indices, dim);
+    SmallVector<OpFoldResult> indicesStrides(indicesRank, oneAttr);
+    Value tiledIndices = getSlice(builder, loc, indices, indicesOffsets,
+                                  indicesSizes, indicesStrides);
+    assert(tiledIndices && "failed to get slice of indices");
+
+    // Slice of the original `source` operand, taken from its tiled `dest[0]`.
+    Value source = testOp.source();
+    auto sourceRank = source.getType().cast<ShapedType>().getRank();
+    SmallVector<OpFoldResult> sourceOffsets(sourceRank, zeroAttr);
+    SmallVector<OpFoldResult> sourceSizes(sourceRank);
+    for (auto dim : llvm::seq<int64_t>(0, sourceRank - updateRank + 1))
+      sourceSizes[dim] = getDim(builder, loc, source, dim);
+    for (auto dim :
+         llvm::seq<int64_t>(sourceRank - updateRank + 1, sourceRank)) {
+      sourceOffsets[dim] = offsets[dim - (sourceRank - updateRank)];
+      sourceSizes[dim] = sizes[dim - (sourceRank - updateRank)];
+    }
+    SmallVector<OpFoldResult> sourceStrides(sourceRank, oneAttr);
+    Value tiledSource = getSlice(builder, loc, dest[0], sourceOffsets,
+                                 sourceSizes, sourceStrides);
+    assert(tiledSource && "failed to get slice of source tensor");
+
+    SmallVector<Type> resultTypes;
+    if (op->getNumResults()) {
+      resultTypes.push_back(tiledSource.getType());
+    }
+    Operation *tiledOp = builder.create<test::TestScatter>(
+        loc, resultTypes, tiledUpdate, tiledIndices, tiledSource);
+    for (auto result : llvm::enumerate(tiledOp->getResults())) {
+      auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+          loc, result.value(), dest[0], sourceOffsets, sourceSizes,
+          sourceStrides);
+      results.push_back(insertSliceOp.getResult());
+    }
+    return tiledOp;
+  }
+};
+
+struct TestSortTilingInterface
+    : public TilingInterface::ExternalModel<TestSortTilingInterface, test::TestSort> {
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    return cast<test::TestSort>(op).sources();
+  }
+
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    auto testOp = cast<test::TestSort>(op);
+    // All loops except the dimension to sort along are parallel.
+    int64_t operandRank =
+        testOp.sources()[0].getType().cast<ShapedType>().getRank();
+    SmallVector<StringRef> iteratorTypes(operandRank,
+                                         getParallelIteratorTypeName());
+    iteratorTypes[testOp.reduce_dim()] = getReductionIteratorTypeName();
+    return iteratorTypes;
+  }
+
+  SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &builder) const {
+    auto testOp = cast<test::TestSort>(op);
+    int64_t operandRank =
+        testOp.sources()[0].getType().cast<ShapedType>().getRank();
+    SmallVector<Range> loopBounds(operandRank);
+    Location loc = op->getLoc();
+    Value zero = builder.create<ConstantIndexOp>(loc, 0);
+    Value one = builder.create<ConstantIndexOp>(loc, 1);
+    Value source = testOp.sources()[0];
+    for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
+      loopBounds[dim].offset = zero;
+      loopBounds[dim].size = getDimValue(builder, loc, source, dim);
+      loopBounds[dim].stride = one;
+    }
+    return loopBounds;
+  }
+
+  Operation *getTiledImplementation(Operation *op, OpBuilder &builder,
+                                    ValueRange dest,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVectorImpl<Value> &results) const {
+    auto testOp = cast<test::TestSort>(op);
+    assert(dest.size() == testOp.sources().size());
+    int64_t rank = testOp.sources()[0].getType().cast<ShapedType>().getRank();
+    assert(offsets.size() == static_cast<size_t>(rank) &&
+           sizes.size() == static_cast<size_t>(rank));
+    auto oneAttr = builder.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(rank, oneAttr);
+    Location loc = op->getLoc();
+    SmallVector<Value> tiledOperands(dest.size());
+    for (auto en : llvm::enumerate(dest)) {
+      tiledOperands[en.index()] =
+          getSlice(builder, loc, en.value(), offsets, sizes, strides);
+      assert(tiledOperands[en.index()] && "failed to get slice of operand");
+    }
+    SmallVector<Type> resultTypes;
+    if (op->getNumResults()) {
+      resultTypes = llvm::to_vector<4>(
+          llvm::map_range(tiledOperands, [&](Value v) { return v.getType(); }));
+    }
+    Operation *tiledOp = builder.create<test::TestSort>(
+        loc, resultTypes, tiledOperands, testOp.reduce_dim());
+    for (auto result : llvm::enumerate(tiledOp->getResults())) {
+      auto insertSliceOp = builder.create<tensor::InsertSliceOp>(
+          loc, result.value(), dest[result.index()], offsets, sizes, strides);
+      results.push_back(insertSliceOp.getResult());
+    }
+    return tiledOp;
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// Interface implementations for external operations.
+//===----------------------------------------------------------------------===//
+
+struct InsertSliceTilingInterface
+    : public TilingInterface::ExternalModel<InsertSliceTilingInterface, tensor::InsertSliceOp> {
+  SmallVector<Value> getDestinationOperands(Operation *op, OpBuilder &b) const {
+    return {cast<tensor::InsertSliceOp>(op).dest()};
+  }
+
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    return SmallVector<StringRef>(insertSliceOp.getSourceType().getRank(),
+                                  getParallelIteratorTypeName());
+  }
+
+  SmallVector<Range> getLoopBounds(Operation *op, OpBuilder &b) const {
+    auto insertSliceOp = cast<tensor::InsertSliceOp>(op);
+    Value source = insertSliceOp.source();
+    RankedTensorType sourceType = insertSliceOp.getSourceType();
+    Location loc = op->getLoc();
+    Value zero = b.create<ConstantIndexOp>(loc, 0);
+    Value one = b.create<ConstantIndexOp>(loc, 1);
+    SmallVector<Range> loopBounds(sourceType.getRank(),
+                                  Range{zero, nullptr, one});
+    for (auto dim :
+         llvm::seq<int64_t>(0, insertSliceOp.getSourceType().getRank()))
+      loopBounds[dim].size = b.create<tensor::DimOp>(loc, source, dim);
+    return loopBounds;
+  }
+
+  Operation *getTiledImplementation(Operation *op, OpBuilder &b,
+                                    ValueRange dest,
+                                    ArrayRef<OpFoldResult> offsets,
+                                    ArrayRef<OpFoldResult> sizes,
+                                    SmallVectorImpl<Value> &results) const {
+    auto insertOp = cast<tensor::InsertSliceOp>(op);
+    // Compute a slice of the source based on the tile offsets and sizes.
+    auto opStrides = insertOp.getMixedStrides();
+    if (!llvm::all_of(opStrides, [&](OpFoldResult valueOrAttr) {
+          return isValue(valueOrAttr, 1);
+        })) {
+      op->emitOpError("unable to tile operation with non-unit stride");
+      return nullptr;
+    }
+    Location loc = insertOp.getLoc();
+    auto oneAttr = b.getI64IntegerAttr(1);
+    SmallVector<OpFoldResult> strides(offsets.size(), oneAttr);
+    auto extractSliceOp = b.create<tensor::ExtractSliceOp>(
+        loc, insertOp.source(), offsets, sizes, strides);
+
+    // The offsets for the insert are based on the op offsets plus the offsets
+    // of the loops passed in.
+    auto opOffsets = insertOp.getMixedOffsets();
+    auto opSizes = insertOp.getMixedSizes();
+    unsigned offsetIndex = 0;
+    ArrayRef<int64_t> sourceShape = insertOp.getSourceType().getShape();
+    int64_t destRank = insertOp.getType().getRank();
+    SmallVector<OpFoldResult> resultOffsets(destRank);
+    SmallVector<OpFoldResult> resultSizes(destRank);
+    size_t sourceRank = sourceShape.size();
+    for (auto opOffset : llvm::enumerate(opOffsets)) {
+      // A unit-sized dest dim is rank-reduced (absent from the source) when
+      // all source dims are consumed or the current source dim is not 1. It
+      // keeps the op's own offset (not 0) and a unit size.
+      unsigned opOffsetIndex = opOffset.index();
+      if (isValue(opSizes[opOffsetIndex], 1) &&
+          (offsetIndex >= sourceRank || sourceShape[offsetIndex] != 1)) {
+        resultOffsets[opOffsetIndex] = opOffset.value();
+        resultSizes[opOffsetIndex] = oneAttr;
+        continue;
+      }
+      OpFoldResult opOffsetVal = opOffset.value();
+      OpFoldResult offset = offsets[offsetIndex];
+      if (opOffsetVal.is<Attribute>() && offset.is<Attribute>()) {
+        resultOffsets[opOffsetIndex] = b.getI64IntegerAttr(
+            *getConstantValue(opOffsetVal) + *getConstantValue(offset));
+      } else {
+        AffineMap map = AffineMap::get(
+            1, 1, {b.getAffineDimExpr(0) + b.getAffineSymbolExpr(0)});
+        resultOffsets[opOffsetIndex] =
+            b.create<AffineApplyOp>(loc, map,
+                                    ValueRange{getValue(b, loc, offset),
+                                               getValue(b, loc, opOffsetVal)})
+                .getResult();
+      }
+      resultSizes[opOffsetIndex] = sizes[offsetIndex];
+      offsetIndex++;
+    }
+    SmallVector<OpFoldResult> resultStrides(destRank, oneAttr);
+    auto tiledInsertOp = b.create<tensor::InsertSliceOp>(
+        loc, extractSliceOp.result(), dest[0], resultOffsets, resultSizes,
+        resultStrides);
+    results.push_back(tiledInsertOp.result());
+    return extractSliceOp;
+  }
+};
+
+template <typename OpTy>
+struct TestOpTilingPattern : public TilingInterfaceBasePattern {
+  TestOpTilingPattern(MLIRContext *context, linalg::LinalgTilingOptions options,
+                      linalg::LinalgTransformationFilter filter =
+                          linalg::LinalgTransformationFilter(),
+                      PatternBenefit benefit = 1)
+      : TilingInterfaceBasePattern(OpTy::getOperationName(), context, options,
+                                   filter, benefit) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    auto tilableOp = dyn_cast<TilingInterface>(op);
+    if (!tilableOp)
+      return failure();
+    TiledOp tiledOp{};
+    // Check for failure.
+    if (failed(TilingInterfaceBasePattern::matchAndRewriteBase(
+            tilableOp, rewriter, tiledOp)))
+      return failure();
+
+    // Check for do-nothing case (`tiledOp.op` stays null unless it was set).
+    if (!tiledOp.op)
+      return failure();
+    if (tiledOp.op != op) {
+      if (tiledOp.results.empty())
+        rewriter.eraseOp(op);
+      else
+        rewriter.replaceOp(op, tiledOp.results);
+    }
+    return success();
+  }
+};
+
+struct TestTilingInterfacePass
+    : public PassWrapper<TestTilingInterfacePass, FunctionPass> {
+  StringRef getArgument() const final { return "test-tiling-interface"; }
+  StringRef getDescription() const final { return "Test Tiling Interface."; }
+  TestTilingInterfacePass() = default;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, gpu::GPUDialect, memref::MemRefDialect,
+                    scf::SCFDialect, StandardOpsDialect, tensor::TensorDialect>();
+  }
+
+  LogicalResult initialize(MLIRContext *context) override;
+  void runOnFunction() override;
+};
+} // namespace
+
+void TestTilingInterfacePass::runOnFunction() {
+  FuncOp funcOp = getOperation();
+  MLIRContext *context = funcOp.getContext();
+
+  RewritePatternSet patterns(context);
+  patterns.add<TestOpTilingPattern<test::TestScatter>>(
+      context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("tiling_input", context),
+          Identifier::get("tiling_output", context)));
+  patterns.add<TestOpTilingPattern<test::TestScatter>>(
+      context, linalg::LinalgTilingOptions().setTileSizes(ArrayRef<int64_t>{0}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("no_tiling_input", context),
+          Identifier::get("no_tiling_output", context)));
+
+  patterns.add<TestOpTilingPattern<test::TestSort>>(
+      context, linalg::LinalgTilingOptions().setTileSizes({0, 20}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("outer_reduce_input", context),
+          Identifier::get("outer_reduce_output", context)));
+  patterns.add<TestOpTilingPattern<test::TestSort>>(
+      context, linalg::LinalgTilingOptions().setTileSizes({10, 0, 0}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("inner_reduce_input", context),
+          Identifier::get("inner_reduce_output", context)));
+
+  static linalg::LinalgLoopDistributionOptions workgroupDistributionOptions = {
+      [](OpBuilder &builder, Location loc, ArrayRef<Range> parallelLoopRanges) {
+        auto numParallelDims = parallelLoopRanges.size();
+        assert(numParallelDims <= 3 && "can distribute at most x, y and z");
+        SmallVector<linalg::ProcInfo, 2> procInfo(numParallelDims);
+        Type indexType = builder.getIndexType();
+        std::string dimStr[3] = {"x", "y", "z"};
+        for (size_t dim = 0; dim < numParallelDims; ++dim) {
+          StringAttr attr = builder.getStringAttr(dimStr[dim]);
+          procInfo[numParallelDims - dim - 1] = {
+              builder.create<gpu::BlockIdOp>(loc, indexType, attr),
+              builder.create<gpu::GridDimOp>(loc, indexType, attr)};
+        }
+        return procInfo;
+      },
+      {linalg::DistributionMethod::Cyclic, linalg::DistributionMethod::Cyclic,
+       linalg::DistributionMethod::Cyclic},
+      DenseMap<StringRef, std::function<linalg::ProcInfo(OpBuilder &, Location)>>()};
+
+  patterns.add<TestOpTilingPattern<test::TestScatter>,
+               TestOpTilingPattern<test::TestSort>>(
+      context,
+      linalg::LinalgTilingOptions()
+          .setTileSizes(ArrayRef<int64_t>{10, 0, 30})
+          .setDistributionOptions(workgroupDistributionOptions),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("distribute_input", context),
+          Identifier::get("distribute_output", context)));
+
+  patterns.add<TestOpTilingPattern<tensor::InsertSliceOp>>(
+      context, linalg::LinalgTilingOptions().setTileSizes({10, 20}),
+      linalg::LinalgTransformationFilter(
+          Identifier::get("tiling_input", context),
+          Identifier::get("tiling_output", context)));
+
+  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
+    return signalPassFailure();
+}
+
+LogicalResult TestTilingInterfacePass::initialize(MLIRContext *context) {
+  tensor::InsertSliceOp::attachInterface<InsertSliceTilingInterface>(*context);
+  test::TestScatter::attachInterface<TestScatterTilingInterface>(*context);
+  test::TestSort::attachInterface<TestSortTilingInterface>(*context);
+  return success();
+}
+
+namespace mlir {
+namespace test {
+void registerTestTilingInterfacePass() {
+  PassRegistration<TestTilingInterfacePass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/CMakeLists.txt b/mlir/tools/mlir-opt/CMakeLists.txt
--- a/mlir/tools/mlir-opt/CMakeLists.txt
+++ b/mlir/tools/mlir-opt/CMakeLists.txt
@@ -29,6 +29,7 @@
     MLIRTestPass
     MLIRTestReducer
     MLIRTestRewrite
+    MLIRTestTilingInterface
     MLIRTestTransforms
     )
 endif()
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -103,6 +103,7 @@
 void registerTestPreparationPassWithAllowedMemrefResults();
 void registerTestRecursiveTypesPass();
 void registerTestSCFUtilsPass();
+void registerTestTilingInterfacePass();
 void registerTestVectorConversions();
 } // namespace test
 } // namespace mlir
@@ -189,6 +190,7 @@
   mlir::test::registerTestPDLByteCodePass();
   mlir::test::registerTestRecursiveTypesPass();
   mlir::test::registerTestSCFUtilsPass();
+  mlir::test::registerTestTilingInterfacePass();
   mlir::test::registerTestVectorConversions();
 }
 #endif