diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -81,6 +81,9 @@
     RewritePatternSet &patterns,
     const ControlFusionFn &controlElementwiseOpFusion);
 
+/// Patterns to bubble up or down data layout ops across other operations.
+void populateDataLayoutPropagationPatterns(RewritePatternSet &patterns);
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is effectively DCE for a linalg op.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1776,6 +1776,10 @@
     static ShapedType inferPackedType(ShapedType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
+
+    static Value createDestinationTensor(OpBuilder &b, Location loc,
+        Value source, ArrayRef<OpFoldResult> innerTileSizes,
+        ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm);
   }];
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ConstantFold.cpp
+  DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -0,0 +1,249 @@
+//===- DataLayoutPropagation.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-data-layout-propagation"
+
+namespace {
+
+/// Returns a tuple of the packed operand and the updated indexing_map, under
+/// the assumptions:
+///   1) The generic op is the producer of the pack op.
+///   2) The generic op has only one result.
+///   3) The indexing map of the output operand is identity.
+/// If the operand is a scalar or the packing dimensions are all irrelevant to
+/// the operand, the original operand and the updated indexing map are
+/// returned. Otherwise, it returns the packed operand and the updated
+/// indexing map.
+/// E.g.,
+///
+///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
+///   #map1 = affine_map<(d0, d1) -> (d0)>
+///   #map2 = affine_map<(d0, d1) -> (d1)>
+///   %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
+///                        iterator_types = ["parallel", "parallel"]}
+///       ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+///       outs(%init : tensor<?x?xf32>) {
+///     ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+///       %4 = arith.addf %arg3, %arg4 : f32
+///       linalg.yield %4 : f32
+///   } -> tensor<?x?xf32>
+///   %1 = tensor.pack %0
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 2]
+///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///
+/// Taking the first input operand as an example, its map only uses d0, and the
+/// inner tile size of d0 is 8. Thus, the below operation and
+/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
+///
+///   %pack = tensor.pack %arg0
+///     inner_dims_pos = [0]
+///     inner_tiles = [8]
+///     into %init : tensor<?xf32> -> tensor<?x8xf32>
+static std::tuple<Value, AffineMap>
+getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc,
+                               tensor::PackOp packOp, GenericOp genericOp,
+                               OpOperand *opOperand) {
+  int numOrigLoops = genericOp.getNumLoops();
+  int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+  int64_t numLoops = numOrigLoops + numInnerLoops;
+  AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
+  SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
+
+  if (genericOp.isScalar(opOperand))
+    return std::make_tuple(
+        opOperand->get(),
+        AffineMap::get(numLoops, 0, exprs, packOp.getContext()));
+
+  llvm::SetVector<int64_t> innerDimsPosSet(packOp.getInnerDimsPos().begin(),
+                                           packOp.getInnerDimsPos().end());
+  // Mapping from the AffineDimExpr positions of the indexing map to the
+  // operand shape dimensions.
+  DenseMap<int64_t, int64_t> iterMapToDim;
+  for (auto [index, expr] : llvm::enumerate(origIndexingMap.getResults())) {
+    int64_t dimPos = expr.cast<AffineDimExpr>().getPosition();
+    if (!innerDimsPosSet.contains(dimPos))
+      continue;
+    iterMapToDim[dimPos] = index;
+  }
+
+  // Construct the packing information (inner dims and tile sizes) and the new
+  // indexing map for the operand.
+  SmallVector<int64_t> innerDimsPos;
+  SmallVector<OpFoldResult> innerTileSizes;
+  for (auto [index, value] : llvm::enumerate(
+           llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) {
+    int64_t dimPos = std::get<0>(value);
+    if (!iterMapToDim.count(dimPos))
+      continue;
+    innerDimsPos.push_back(iterMapToDim[dimPos]);
+    innerTileSizes.push_back(std::get<1>(value));
+    exprs.push_back(b.getAffineDimExpr(numOrigLoops + index));
+  }
+  auto indexingMap = AffineMap::get(numLoops, 0, exprs, packOp.getContext());
+
+  SmallVector<int64_t> outerDimsPerm;
+  for (auto outDim : packOp.getOuterDimsPerm()) {
+    if (!iterMapToDim.count(outDim))
+      continue;
+    outerDimsPerm.push_back(iterMapToDim[outDim]);
+  }
+
+  // The operand does not have dimensions that relate to the pack op.
+  if (innerDimsPos.empty() && outerDimsPerm.empty())
+    return std::make_tuple(opOperand->get(), indexingMap);
+
+  auto empty = tensor::PackOp::createDestinationTensor(
+      b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
+  auto packedOperand = b.create<tensor::PackOp>(
+      loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
+      packOp.getPaddingValue(), outerDimsPerm);
+  return std::make_tuple(packedOperand, indexingMap);
+}
+
+/// Bubbles up a tensor.pack op through an elementwise generic op. This swaps
+/// pack(generic) to generic(pack). The new generic op works on the packed
+/// domain; pack ops are created for input and output operands.
+/// E.g.,
+///
+///   #map0 = affine_map<(d0, d1) -> (d0, d1)>
+///   %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+///   %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+///   %3 = linalg.generic {indexing_maps = [#map0, #map0],
+///                        iterator_types = ["parallel", "parallel"]}
+///       ins(%arg0 : tensor<?x?xf32>)
+///       outs(%2 : tensor<?x?xf32>) {
+///     ^bb0(%arg3: f32, %arg4: f32):
+///       %4 = arith.addf %arg3, %arg3 : f32
+///       linalg.yield %4 : f32
+///   } -> tensor<?x?xf32>
+///   %4 = tensor.pack %3
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 2]
+///     into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///
+/// will be converted to
+///
+///   #map = affine_map<()[s0] -> (s0 ceildiv 8)>
+///   #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
+///   #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+///   %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+///   %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+///   %0 = affine.apply #map()[%dim]
+///   %1 = affine.apply #map1()[%dim_0]
+///   %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
+///   %pack = tensor.pack %arg0
+///     inner_dims_pos = [0, 1]
+///     inner_tiles = [8, 2]
+///     into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+///   %3 = linalg.generic {indexing_maps = [#map2, #map2],
+///       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+///       ins(%pack : tensor<?x?x8x2xf32>)
+///       outs(%arg1 : tensor<?x?x8x2xf32>) {
+///     ^bb0(%in: f32, %out: f32):
+///       %4 = arith.addf %in, %in : f32
+///       linalg.yield %4 : f32
+///   } -> tensor<?x?x8x2xf32>
+static FailureOr<GenericOp>
+bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
+                                   tensor::PackOp packOp) {
+  auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
+  if (!genericOp)
+    return failure();
+
+  if (!isElementwise(genericOp))
+    return failure();
+
+  // TODO: Relax the restriction. We are able to bubble up the pack op through
+  // a multi-result generic op. It just needs more work.
+  if (genericOp.getNumResults() != 1)
+    return failure();
+
+  // TODO: Add an option for allowing padding values. Unconditionally
+  // propagating the pack op through all ops could introduce undefined
+  // behavior. E.g., if the padding value is zero and the generic op contains
+  // division ops, some values in the padding area could become NaN (0/0).
+  if (packOp.getPaddingValue())
+    return failure();
+
+  OpOperand *opOperand = genericOp.getDpsInitOperand(0);
+  // TODO: Add support for all permutation indexing maps.
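+  // Requiring an identity indexing map on the init operand keeps the mapping
+  // between loop dimensions and result dimensions trivial, so the pack op's
+  // inner_dims_pos values can be used directly as loop positions when the
+  // iteration space is extended below.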
+  if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity())
+    return rewriter.notifyMatchFailure(
+        packOp, "the result of generic op does not have identity indexing_map");
+
+  Location loc = packOp.getLoc();
+  SmallVector<Value> inputOperands;
+  SmallVector<AffineMap> indexingMaps;
+  for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
+    auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
+        rewriter, loc, packOp, genericOp, inputOperand);
+    inputOperands.push_back(packedOperand);
+    indexingMaps.push_back(packedIndexingMap);
+  }
+
+  int64_t numLoops = genericOp.getNumLoops();
+  int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+  int64_t newNumLoops = numLoops + numInnerLoops;
+  SmallVector<utils::IteratorType> iterTypes =
+      genericOp.getIteratorTypesArray();
+  iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+
+  // Extend the output indexing map with the new inner loop dimensions.
+  SmallVector<AffineExpr> outExprs(
+      genericOp.getMatchingIndexingMap(opOperand).getResults());
+  for (int i = 0; i < numInnerLoops; ++i)
+    outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
+  indexingMaps.push_back(
+      AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()));
+
+  auto newGenericOp = rewriter.create<GenericOp>(
+      loc, packOp.getDestType(), inputOperands, packOp.getDest(), indexingMaps,
+      iterTypes, /*bodyBuild=*/nullptr,
+      linalg::getPrunedAttributeList(genericOp));
+  rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
+                             newGenericOp.getRegion().begin());
+  return newGenericOp;
+}
+
+// Wrapper pattern that applies the bubbleUpPackOpThroughElemGenericOp method.
+struct BubbleUpPackOpThroughElemGenericOpPattern
+    : public OpRewritePattern<tensor::PackOp> {
+  using OpRewritePattern<tensor::PackOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericOp = bubbleUpPackOpThroughElemGenericOp(rewriter, packOp);
+    if (failed(genericOp))
+      return failure();
+    rewriter.replaceOp(packOp, genericOp.value().getResults());
+    return success();
+  }
+};
+} // namespace
+
+void mlir::linalg::populateDataLayoutPropagationPatterns(
+    RewritePatternSet &patterns) {
+  patterns.insert<BubbleUpPackOpThroughElemGenericOpPattern>(
+      patterns.getContext());
+}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3349,6 +3349,37 @@
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
 
+Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
+                                      ArrayRef<OpFoldResult> innerTileSizes,
+                                      ArrayRef<int64_t> innerDimsPos,
+                                      ArrayRef<int64_t> outerDimsPerm) {
+  AffineExpr dim0, dim1;
+  bindDims(b.getContext(), dim0, dim1);
+  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
+    return makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1), {v1, v2});
+  };
+
+  SmallVector<OpFoldResult> mixedSizes;
+  for (auto [index, value] :
+       llvm::enumerate(source.getType().cast<ShapedType>().getShape())) {
+    if (ShapedType::isDynamic(value))
+      mixedSizes.push_back(b.create<DimOp>(loc, source, index).getResult());
+    else
+      mixedSizes.push_back(b.getIndexAttr(value));
+  }
+  // Outer dimensions are the original sizes divided (ceiling) by the
+  // corresponding inner tile sizes.
+  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
+    int64_t dimPos = std::get<0>(it);
+    OpFoldResult tileSize = std::get<1>(it);
+    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
+  }
+  if (!outerDimsPerm.empty())
+    applyPermutationToVector(mixedSizes, outerDimsPerm);
+
+  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
+  auto elemType = source.getType().cast<ShapedType>().getElementType();
+  return b.create<EmptyOp>(loc, mixedSizes, elemType);
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 bool areTilesAndTiledDimsAllConstant(OpTy op) {
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -0,0 +1,230 @@
+// RUN: mlir-opt %s -test-linalg-data-layout-propagation -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = arith.addf %arg3, %arg3 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.pack %3
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @dynamic_elem_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 32]
+    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
+  return %pack : tensor<4x16x16x32xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_inner_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<4x16x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<16x4x32x16xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 32]
+    into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
+  return %pack : tensor<16x4x16x32xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<16x4x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = tensor.dim %arg1, %c0 : tensor<?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map1, #map2, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %4 = arith.addf %arg3, %arg4 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.pack %3
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @dynamic_broadcast_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
+// CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [2]
+// CHECK-SAME: into %[[ARG1_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]], %[[PACK_ARG1]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
+{
+  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
+  %transpose = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
+      outs(%init_transpose : tensor<100x200x128x256xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %0 = arith.addi %b0, %b1 : i32
+      %1 = arith.addi %0, %b2 : i32
+      linalg.yield %1 : i32
+  } -> tensor<100x200x128x256xi32>
+  %4 = tensor.pack %transpose
+    inner_dims_pos = [3, 2]
+    inner_tiles = [16, 32]
+    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
+  return %4 : tensor<100x200x4x16x16x32xi32>
+}
+// CHECK: func.func @transpose_pack
+// CHECK: linalg.generic
+// CHECK: tensor.pack
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRLinalgTestPasses
+  TestDataLayoutPropagation.cpp
   TestLinalgDecomposeOps.cpp
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Dialect/Linalg/TestDataLayoutPropagation.cpp
@@ -0,0 +1,49 @@
+//===- TestDataLayoutPropagation.cpp --------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+using namespace mlir;
+
+namespace {
+struct TestDataLayoutPropagationPass
+    : public PassWrapper<TestDataLayoutPropagationPass,
+                         OperationPass<func::FuncOp>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestDataLayoutPropagationPass)
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry
+        .insert<AffineDialect, linalg::LinalgDialect, tensor::TensorDialect>();
+  }
+
+  StringRef getArgument() const final {
+    return "test-linalg-data-layout-propagation";
+  }
+  StringRef getDescription() const final {
+    return "Test data layout propagation";
+  }
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+    RewritePatternSet patterns(context);
+    linalg::populateDataLayoutPropagationPatterns(patterns);
+    if (failed(
+            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+namespace mlir {
+namespace test {
+void registerTestDataLayoutPropagation() {
+  PassRegistration<TestDataLayoutPropagationPass>();
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -72,6 +72,7 @@
 void registerTestControlFlowSink();
 void registerTestGpuSerializeToCubinPass();
 void registerTestGpuSerializeToHsacoPass();
+void registerTestDataLayoutPropagation();
 void registerTestDataLayoutQuery();
 void registerTestDeadCodeAnalysisPass();
 void registerTestDecomposeCallGraphTypes();
@@ -181,6 +182,7 @@
   mlir::test::registerTestGpuSerializeToHsacoPass();
 #endif
   mlir::test::registerTestDecomposeCallGraphTypes();
+  mlir::test::registerTestDataLayoutPropagation();
   mlir::test::registerTestDataLayoutQuery();
   mlir::test::registerTestDeadCodeAnalysisPass();
   mlir::test::registerTestDominancePass();
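Usage sketch (not part of the patch): the entry point added for downstream pipelines is linalg::populateDataLayoutPropagationPatterns. The snippet below mirrors TestDataLayoutPropagation.cpp above; the pass name PropagateDataLayoutPass is hypothetical and only illustrates how the patterns might be wired into a function pass.

// Minimal sketch: collect the data layout propagation patterns (currently the
// pack(generic) -> generic(pack) bubbling pattern) and apply them greedily.
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

namespace {
struct PropagateDataLayoutPass
    : public PassWrapper<PropagateDataLayoutPass,
                         OperationPass<func::FuncOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PropagateDataLayoutPass)

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    linalg::populateDataLayoutPropagationPatterns(patterns);
    if (failed(
            applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};
} // namespace

With the test pass registered in mlir-opt, the same rewrite is exercised from the command line via `mlir-opt -test-linalg-data-layout-propagation`, as in the RUN line of the new test file.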