diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.h b/mlir/include/mlir/Dialect/Linalg/Passes.h
--- a/mlir/include/mlir/Dialect/Linalg/Passes.h
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.h
@@ -35,6 +35,7 @@
 std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
 std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
+std::unique_ptr<Pass> createLinalgDataLayoutPropagationPass();
 std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -46,6 +46,14 @@
   ];
 }
 
+def LinalgDataLayoutPropagation : Pass<"linalg-data-layout-propagation"> {
+  let summary = "Propagate data layout ops on tensors";
+  let constructor = "mlir::createLinalgDataLayoutPropagationPass()";
+  let dependentDialects = [
+    "AffineDialect", "linalg::LinalgDialect", "tensor::TensorDialect"
+  ];
+}
+
 def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
   let summary = "Convert from one named linalg op to another.";
   let constructor = "mlir::createLinalgNamedOpConversionPass()";
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -81,6 +81,9 @@
     RewritePatternSet &patterns,
     const ControlFusionFn &controlElementwiseOpFusion);
 
+/// Function type used to control the propagation of data layout ops.
+using ControlPropagationFn = std::function<bool(OpOperand *opOperand)>;
+
 /// Pattern to remove dead operands and results of `linalg.generic` operations.
 /// This is effectively DCE for a linalg op.
 void populateEraseUnusedOperandsAndResultsPatterns(RewritePatternSet &patterns);
diff --git a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
--- a/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
+++ b/mlir/include/mlir/Dialect/Utils/IndexingUtils.h
@@ -69,12 +69,27 @@
 template <typename T, unsigned N>
 void applyPermutationToVector(SmallVector<T, N> &inVec,
                               ArrayRef<int64_t> permutation) {
+  assert(inVec.size() == permutation.size());
   SmallVector<T, N> auxVec(inVec.size());
   for (const auto &en : enumerate(permutation))
     auxVec[en.index()] = inVec[en.value()];
   inVec = auxVec;
 }
 
+/// Apply the projected permutation defined by `permutation` to `inVec`:
+/// position `i` of the result takes the element at position `permutation[i]`
+/// in `inVec`; trailing positions keep their original elements. E.g.: for
+/// `inVec = ['a', 'b', 'c']` and a projected permutation `permutation =
+/// [1, 0]`, this function leaves `inVec = ['b', 'a', 'c']`.
+template <typename T, unsigned N>
+void applyProjectedPermToVector(SmallVector<T, N> &inVec,
+                                ArrayRef<int64_t> permutation) {
+  SmallVector<T, N> auxVec = inVec;
+  for (const auto &en : enumerate(permutation))
+    auxVec[en.index()] = inVec[en.value()];
+  inVec = auxVec;
+}
+
 /// Helper that returns a subset of `arrayAttr` as a vector of int64_t.
 SmallVector<int64_t, 4> getI64SubArray(ArrayAttr arrayAttr,
                                        unsigned dropFront = 0,
                                        unsigned dropBack = 0);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -3,6 +3,7 @@
   BufferizableOpInterfaceImpl.cpp
   Bufferize.cpp
   ConstantFold.cpp
+  DataLayoutPropagation.cpp
   DecomposeLinalgOps.cpp
   Detensorize.cpp
   DropUnitDims.cpp
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -0,0 +1,207 @@
+//===- DataLayoutPropagation.cpp -----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Linalg/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
+#include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::linalg;
+
+#define DEBUG_TYPE "linalg-data-layout-propagation"
+
+namespace {
+
+// TODO: Move this helper to tensor/Utils.
+Value getEmptyOpForPackOp(OpBuilder &b, Location loc, Value source,
+                          ArrayRef<OpFoldResult> innerTileSizes,
+                          ArrayRef<int64_t> innerDimsPos,
+                          ArrayRef<int64_t> outerDimsPerm) {
+  AffineExpr dim0, dim1;
+  bindDims(b.getContext(), dim0, dim1);
+  auto ceilDiv = [&](OpFoldResult v1, OpFoldResult v2) -> OpFoldResult {
+    return makeComposedFoldedAffineApply(b, loc, dim0.ceilDiv(dim1), {v1, v2});
+  };
+
+  SmallVector<OpFoldResult> mixedSizes =
+      tensor::createDimValues(b, loc, source);
+  mixedSizes.append(innerTileSizes.begin(), innerTileSizes.end());
+  for (auto it : llvm::zip(innerDimsPos, innerTileSizes)) {
+    int64_t dimPos = std::get<0>(it);
+    OpFoldResult tileSize = std::get<1>(it);
+    mixedSizes[dimPos] = ceilDiv(mixedSizes[dimPos], tileSize);
+  }
+  applyProjectedPermToVector(mixedSizes, outerDimsPerm);
+
+  auto elemType = source.getType().cast<RankedTensorType>().getElementType();
+  return b.create<tensor::EmptyOp>(loc, mixedSizes, elemType);
+}
+
+/// Pattern to bubble up a tensor.pack op through an elementwise generic op.
+class BubbleUpPackOpThroughElemGenericOp
+    : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackOpThroughElemGenericOp(MLIRContext *context,
+                                     ControlPropagationFn controlFn,
+                                     PatternBenefit benefit = 1)
+      : OpRewritePattern<tensor::PackOp>(context, benefit),
+        controlPropagationFn(std::move(controlFn)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
+    if (!genericOp)
+      return failure();
+
+    if (!isElementwise(genericOp))
+      return failure();
+
+    // TODO: Relax the restriction. Bubbling up the pack op through a
+    // multi-result generic op is possible; it just needs more work.
+    if (genericOp.getNumResults() != 1)
+      return failure();
+
+    int64_t numLoops = genericOp.getNumLoops();
+    int64_t numInnerLoops = packOp.getInnerDimsPos().size();
+    int64_t newNumLoops = numLoops + numInnerLoops;
+
+    OpOperand *opOperand = genericOp.getDpsInitOperand(0);
+    if (!controlPropagationFn(opOperand))
+      return failure();
+
+    SmallVector<AffineExpr> outExprs(
+        genericOp.getMatchingIndexingMap(opOperand).getResults());
+    // TODO: Add support for all permutation indexing maps.
+    if (!genericOp.getMatchingIndexingMap(opOperand).isIdentity()) {
+      return rewriter.notifyMatchFailure(
+          packOp,
+          "the result of the generic op does not have an identity "
+          "indexing_map");
+    }
+
+    Location loc = packOp.getLoc();
+    DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
+        packOp.getDimAndTileMapping();
+    SmallVector<Value> inputOperands;
+    SmallVector<AffineMap> indexingMaps;
+    for (OpOperand *opOperand : genericOp.getDpsInputOperands()) {
+      Value operand = opOperand->get();
+      AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
+      if (genericOp.isScalar(opOperand)) {
+        inputOperands.push_back(operand);
+        indexingMaps.push_back(
+            AffineMap::get(newNumLoops, 0, {}, rewriter.getContext()));
+        continue;
+      }
+      SmallVector<AffineExpr> exprs(indexingMap.getResults());
+
+      DenseMap<int64_t, int64_t> iterMapToDim;
+      for (auto en : llvm::enumerate(indexingMap.getResults())) {
+        int64_t dimPos = en.value().cast<AffineDimExpr>().getPosition();
+        if (!dimAndTileMapping.count(dimPos))
+          continue;
+        iterMapToDim[dimPos] = en.index();
+      }
+      SmallVector<int64_t> innerDimsPos;
+      SmallVector<OpFoldResult> innerTileSizes;
+      for (auto it : llvm::enumerate(
+               llvm::zip(packOp.getInnerDimsPos(), packOp.getMixedTiles()))) {
+        int64_t dimPos = std::get<0>(it.value());
+        if (!iterMapToDim.count(dimPos))
+          continue;
+        innerDimsPos.push_back(iterMapToDim[dimPos]);
+        innerTileSizes.push_back(std::get<1>(it.value()));
+        exprs.push_back(rewriter.getAffineDimExpr(numLoops + it.index()));
+      }
+      SmallVector<int64_t> outerDimsPerm;
+      for (auto outDim : packOp.getOuterDimsPerm()) {
+        if (!iterMapToDim.count(outDim))
+          continue;
+        outerDimsPerm.push_back(iterMapToDim[outDim]);
+      }
+
+      // The operand does not have dimensions that relate to the pack op.
+      if (innerDimsPos.empty() && outerDimsPerm.empty()) {
+        inputOperands.push_back(operand);
+        indexingMaps.push_back(
+            AffineMap::get(newNumLoops, 0, exprs, rewriter.getContext()));
+        continue;
+      }
+
+      auto empty = getEmptyOpForPackOp(rewriter, loc, operand, innerTileSizes,
+                                       innerDimsPos, outerDimsPerm);
+      auto packedOperand = rewriter.create<tensor::PackOp>(
+          loc, operand, empty, innerDimsPos, innerTileSizes, llvm::None,
+          outerDimsPerm);
+      inputOperands.push_back(packedOperand);
+      indexingMaps.push_back(
+          AffineMap::get(newNumLoops, 0, exprs, rewriter.getContext()));
+    }
+
+    SmallVector<utils::IteratorType> iterTypes =
+        genericOp.getIteratorTypesArray();
+    iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
+
+    for (int i = 0; i < numInnerLoops; ++i)
+      outExprs.push_back(rewriter.getAffineDimExpr(numLoops + i));
+    indexingMaps.push_back(
+        AffineMap::get(newNumLoops, 0, outExprs, rewriter.getContext()));
+
+    auto newGenericOp = rewriter.create<linalg::GenericOp>(
+        loc, packOp.getDestType(), inputOperands, packOp.getDest(),
+        indexingMaps, iterTypes, /*bodyBuild=*/nullptr,
+        linalg::getPrunedAttributeList(genericOp));
+    BlockAndValueMapping bvm;
+    genericOp.getRegion().cloneInto(&newGenericOp.getRegion(), bvm);
+    rewriter.replaceOp(packOp, newGenericOp.getResult(0));
+
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlPropagationFn;
+};
+
+struct LinalgDataLayoutPropagationPass
+    : public impl::LinalgDataLayoutPropagationBase<
+          LinalgDataLayoutPropagationPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect,
+                    tensor::TensorDialect>();
+  }
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    MLIRContext *context = op->getContext();
+    RewritePatternSet patterns(context);
+
+    ControlPropagationFn defaultControlFn = [](OpOperand *producer) {
+      return true;
+    };
+
+    patterns.insert<BubbleUpPackOpThroughElemGenericOp>(context,
+                                                        defaultControlFn);
+    if (failed(applyPatternsAndFoldGreedily(op->getRegions(),
+                                            std::move(patterns))))
+      return signalPassFailure();
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::createLinalgDataLayoutPropagationPass() {
+  return std::make_unique<LinalgDataLayoutPropagationPass>();
+}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -0,0 +1,230 @@
+// RUN: mlir-opt %s -linalg-data-layout-propagation -split-input-file | FileCheck %s
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @dynamic_elem_pack(%arg0: tensor<?x?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32):
+      %4 = arith.addf %arg3, %arg3 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.pack %3
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @dynamic_elem_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
+// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]], %[[OUTER_D1]]) : tensor<?x?x8x2xf32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [8, 2]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_dims(%arg0: tensor<128x256xi32>, %dest: tensor<4x16x16x32xi32>) -> tensor<4x16x16x32xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 32]
+    into %dest : tensor<128x256xi32> -> tensor<4x16x16x32xi32>
+  return %pack : tensor<4x16x16x32xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_inner_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<4x16x16x32xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<4x16x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x32x16xi32>) -> tensor<16x4x32x16xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %dest : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<16x4x32x16xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_inner_and_outer_dims(%arg0: tensor<128x256xi32>, %dest: tensor<16x4x16x32xi32>) -> tensor<16x4x16x32xi32>{
+  %init = tensor.empty() : tensor<128x256xi32>
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %pack = tensor.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [1, 0]
+    inner_tiles = [16, 32]
+    into %dest : tensor<128x256xi32> -> tensor<16x4x16x32xi32>
+  return %pack : tensor<16x4x16x32xi32>
+}
+// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @elem_pack_transpose_inner_and_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x16x32xi32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 32]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<16x4x16x32xi32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @dynamic_broadcast_pack(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>, %dest: tensor<?x?x8x2xf32>) -> tensor<?x?x8x2xf32>
+{
+  %c0 = arith.constant 0 : index
+  %0 = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %1 = tensor.dim %arg1, %c0 : tensor<?xf32>
+  %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
+  %3 = linalg.generic {indexing_maps = [#map1, #map2, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+      outs(%2 : tensor<?x?xf32>) {
+    ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+      %4 = arith.addf %arg3, %arg4 : f32
+      linalg.yield %4 : f32
+  } -> tensor<?x?xf32>
+  %4 = tensor.pack %3
+    inner_dims_pos = [0, 1]
+    inner_tiles = [8, 2]
+    into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
+  return %4 : tensor<?x?x8x2xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
+// CHECK-DAG: #[[MAP3:.+]] = affine_map<(d0, d1, d2, d3) -> (d1, d3)>
+// CHECK-DAG: #[[MAP4:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK: func.func @dynamic_broadcast_pack
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
+// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
+// CHECK-DAG: %[[OUTER_D0:.+]] = affine.apply #[[MAP0]]()[%[[D0]]]
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty(%[[OUTER_D0]]) : tensor<?x8xf32>
+// CHECK: %[[PACK_ARG0:.+]] = tensor.pack %[[ARG0]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [8]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG1]], %[[C0]]
+// CHECK-DAG: %[[OUTER_D1:.+]] = affine.apply #[[MAP1]]()[%[[D1]]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty(%[[OUTER_D1]]) : tensor<?x2xf32>
+// CHECK: %[[PACK_ARG1:.+]] = tensor.pack %[[ARG1]]
+// CHECK-SAME: inner_dims_pos = [0] inner_tiles = [2]
+// CHECK-SAME: into %[[ARG1_EMPTY]]
+// CHECK: %[[ELEM:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP2]], #[[MAP3]], #[[MAP4]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+// CHECK-SAME: ins(%[[PACK_ARG0]], %[[PACK_ARG1]]
+// CHECK-SAME: outs(%[[DEST]]
+// CHECK: return %[[ELEM]] : tensor<?x?x8x2xf32>
+
+// -----
+
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d0)>
+#map2 = affine_map<(d0, d1) -> (d1)>
+func.func @transpose_pack(%arg0: tensor<100x128x200x256xi32>, %arg1: tensor<100xi32>, %arg2: tensor<128xi32>, %dest: tensor<100x200x4x16x16x32xi32>) -> tensor<100x200x4x16x16x32xi32>
+{
+  %init_transpose = tensor.empty() : tensor<100x200x128x256xi32>
+  %transpose = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0)>,
+                       affine_map<(d0, d1, d2, d3) -> (d1)>,
+                       affine_map<(d0, d1, d2, d3) -> (d0, d2, d1, d3)>],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1, %arg2 : tensor<100x128x200x256xi32>, tensor<100xi32>, tensor<128xi32>)
+      outs(%init_transpose : tensor<100x200x128x256xi32>) {
+    ^bb0(%b0 : i32, %b1 : i32, %b2 : i32, %b3 : i32):
+      %0 = arith.addi %b0, %b1 : i32
+      %1 = arith.addi %0, %b2 : i32
+      linalg.yield %1 : i32
+  } -> tensor<100x200x128x256xi32>
+  %4 = tensor.pack %transpose
+    inner_dims_pos = [3, 2]
+    inner_tiles = [16, 32]
+    into %dest : tensor<100x200x128x256xi32> -> tensor<100x200x4x16x16x32xi32>
+  return %4 : tensor<100x200x4x16x16x32xi32>
+}
+// CHECK: func.func @transpose_pack
+// CHECK: linalg.generic
+// CHECK: tensor.pack
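
Reviewer note, not part of the patch: the standalone C++ sketch below mirrors the semantics of the new applyProjectedPermToVector helper, with std::vector standing in for llvm::SmallVector so it compiles outside of MLIR. It illustrates the intended behavior only: leading positions are overwritten by the gathered elements, while trailing positions keep their original values.

// Reviewer sketch: projected-permutation application, assuming the same
// gather-style semantics as the patch's applyProjectedPermToVector.
#include <cassert>
#include <cstdint>
#include <vector>

template <typename T>
void applyProjectedPermToVector(std::vector<T> &inVec,
                                const std::vector<int64_t> &permutation) {
  // Copy first so positions not written by the loop keep their elements.
  std::vector<T> auxVec = inVec;
  for (size_t i = 0, e = permutation.size(); i < e; ++i)
    auxVec[i] = inVec[permutation[i]];
  inVec = auxVec;
}

int main() {
  std::vector<char> v = {'a', 'b', 'c'};
  applyProjectedPermToVector(v, {2, 0, 1}); // full permutation
  assert((v == std::vector<char>{'c', 'a', 'b'}));

  v = {'a', 'b', 'c'};
  applyProjectedPermToVector(v, {1, 0}); // projected: tail stays put
  assert((v == std::vector<char>{'b', 'a', 'c'}));
  return 0;
}

This tail-preserving behavior is what lets getEmptyOpForPackOp append the inner tile sizes to mixedSizes before applying outerDimsPerm: the permutation reorders only the leading outer-dimension sizes and leaves the appended tile sizes untouched.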