diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -320,6 +320,12 @@
     for (Value operand : fusedOp->getOperands())
       if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
         candidates.push_back(sliceOp);
+
+    if (fusedOp->getNumRegions() > 0) {
+      fusedOp->walk([&](tensor::ExtractSliceOp sliceOp) {
+        candidates.push_back(sliceOp);
+      });
+    }
   };
 
   std::deque<tensor::ExtractSliceOp> candidates;
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -8,30 +8,351 @@
 #include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/Interfaces/TilingInterface.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
 using namespace mlir::tensor;
 
+static Value getAsValue(OpBuilder &b, Location loc, OpFoldResult ofr) {
+  Optional<int64_t> constValue = getConstantIntValue(ofr);
+  if (constValue.hasValue())
+    return b.create<arith::ConstantIndexOp>(loc, *constValue);
+  return ofr.dyn_cast<Value>();
+}
+
+/// For an operation that implements the `ReifyRankedShapedTypeOpInterface`,
+/// use that interface to construct `Range`s for the output shape.
+static SmallVector<Range>
+getIterationDomainUsingReifyShapedTypeInterface(Operation *op, OpBuilder &b) {
+  // Materialize the output shape values.
+  ReifiedRankedShapedTypeDims reifiedShapes;
+  ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+      dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
+
+  Location loc = op->getLoc();
+  Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
+  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+  // Initialize all the ranges to {zero, one, one}. All the `ub`s are
+  // overwritten.
+  SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
+  for (const auto &ub : enumerate(reifiedShapes[0]))
+    loopRanges[ub.index()].size = ub.value();
+  return loopRanges;
+}
+
+/// For an operation that implements `ReifyRankedShapedTypeOpInterface`, use
+/// that interface to create a new destination operand using
+/// `linalg.init_tensor`.
+static Value getDestinationOperandValue(Operation *op, OpBuilder &b) {
+  ReifiedRankedShapedTypeDims reifiedShapes;
+  ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
+      dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
+  (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
+
+  SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
+  Value initTensor = b.create<linalg::InitTensorOp>(
+      op->getLoc(), mixedSizes,
+      op->getResultTypes()[0].cast<RankedTensorType>().getElementType());
+  return initTensor;
+}
+
+// The input parameters `offsets` and `sizes` specify a rectangular slice of
+// the collapse_shape output. Try to find which dimensions have been sliced
+// and which dimensions are not sliced (offset = 0, size = dim).
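+// For example, given
+//   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
+//       : tensor<4x5x6xf32> into tensor<20x6xf32>
+// a slice with offsets [0, 0] and sizes [20, 4] leaves output dimension 0
+// un-sliced (offset 0, full size 20) and slices output dimension 1.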
+static llvm::SmallBitVector getSlicedDimensions(OpBuilder &b,
+                                                CollapseShapeOp op,
+                                                ArrayRef<OpFoldResult> offsets,
+                                                ArrayRef<OpFoldResult> sizes) {
+  SmallVector<Range> ranges =
+      getIterationDomainUsingReifyShapedTypeInterface(op, b);
+  llvm::SmallBitVector result(op.getResultType().getRank());
+  for (const auto &size : llvm::enumerate(sizes)) {
+    result[size.index()] =
+        !isEqualConstantIntOrValue(size.value(), ranges[size.index()].size);
+  }
+  return result;
+}
+
+/// Determine which dimensions are linearized.
+static llvm::SmallBitVector getLinearizedDimensions(CollapseShapeOp op) {
+  llvm::SmallBitVector result(op.getResultType().getRank());
+  for (const auto &it : llvm::enumerate(op.getReassociationIndices())) {
+    result[it.index()] = it.value().size() > 1;
+  }
+  return result;
+}
+
+static scf::LoopNest getEmptyLoopNest(OpBuilder &b, Location loc,
+                                      ArrayRef<Value> nestLowerBounds,
+                                      ArrayRef<Value> nestUpperBounds,
+                                      Value iterArgInit) {
+  // Create the loop nest.
+  scf::LoopNest nest;
+  Value iterArg = iterArgInit;
+  Location currLoc = loc;
+  Value one = b.create<arith::ConstantIndexOp>(loc, 1);
+  for (const auto &it : llvm::zip(nestLowerBounds, nestUpperBounds)) {
+    auto loop = b.create<scf::ForOp>(
+        currLoc, std::get<0>(it), std::get<1>(it), one, iterArg,
+        [&](OpBuilder &nB, Location nLoc, Value nIvs, ValueRange nIterArgs) {
+          currLoc = nLoc;
+          iterArg = nIterArgs[0];
+        });
+    b.setInsertionPointToStart(loop.getBody());
+    nest.loops.push_back(loop);
+  }
+  for (unsigned i = 0, e = nest.loops.size() - 1; i < e; i++) {
+    b.setInsertionPointToEnd(nest.loops[i].getBody());
+    b.create<scf::YieldOp>(nest.loops[i]->getLoc(),
+                           nest.loops[i + 1].getResults());
+  }
+  return nest;
+}
+
+/// Given a `tensor.collapse_shape` op and information regarding which
+/// dimensions have been collapsed and which are being tiled, as well as the
+/// multi-index elements for each delinearized index, create the tiled slice
+/// of the collapse_shape result.
+static FailureOr<Value> createCollapseShapeTiledResult(
+    CollapseShapeOp collapseOp, Value producer, OpBuilder &b,
+    const llvm::SmallBitVector &slicedOutputDims,
+    const llvm::SmallBitVector &linearizedOutputDims,
+    ArrayRef<SmallVector<Value>> multiIndices, ArrayRef<OpFoldResult> offsets,
+    ArrayRef<OpFoldResult> sizes) {
+  // Construct offsets to extract from the result of the producer op (which is
+  // the input to the `tensor.collapse_shape`).
+  SmallVector<OpFoldResult> sliceOffsets;
+  SmallVector<OpFoldResult> sliceSizes;
+  int64_t loopIdx = 0;
+  RankedTensorType srcType = collapseOp.getSrcType();
+  Location loc = collapseOp->getLoc();
+  SmallVector<ReassociationIndices> reassociationIndices =
+      collapseOp.getReassociationIndices();
+  for (const auto &it : llvm::enumerate(reassociationIndices)) {
+    // Case 1: De-linearized dimensions that have also been sliced. These have
+    // a size of 1 because we are iterating over these dimensions. The offset
+    // is exactly the de-linearized multi-index created from the ivs.
+    if (slicedOutputDims[it.index()] && linearizedOutputDims[it.index()]) {
+      sliceSizes.append(it.value().size(), b.getIndexAttr(1));
+      sliceOffsets.append(llvm::to_vector(llvm::map_range(
+          multiIndices[loopIdx++],
+          [&](Value v) -> OpFoldResult { return getAsOpFoldResult(v); })));
+      continue;
+    }
+
+    // Case 2: One or possibly multiple combined input dimensions, but we have
+    // proven that these are not sliced. In this case we just take the full
+    // extent of each dimension in the reassociation list.
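+    // For example, if source dimensions [1, 2] with extents 5 and 6 collapse
+    // into an un-sliced output dimension, the extracted slice uses offsets
+    // [0, 0] and sizes [5, 6] for those two source dimensions.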
+    if (linearizedOutputDims[it.index()]) {
+      for (int64_t srcIndex : it.value()) {
+        if (!srcType.isDynamicDim(srcIndex))
+          sliceSizes.push_back(b.getIndexAttr(srcType.getDimSize(srcIndex)));
+        else
+          sliceSizes.push_back(b.createOrFold<tensor::DimOp>(
+              loc, producer,
+              b.create<arith::ConstantIndexOp>(loc, srcIndex)));
+        sliceOffsets.push_back(b.getIndexAttr(0));
+      }
+      continue;
+    }
+
+    // Case 3: A single index, but it may be sliced.
+    sliceSizes.push_back(sizes[it.index()]);
+    sliceOffsets.push_back(offsets[it.index()]);
+  }
+
+  Value tileResult = b.create<tensor::ExtractSliceOp>(
+      loc, producer, sliceOffsets, sliceSizes,
+      SmallVector<OpFoldResult>(sliceSizes.size(), b.getIndexAttr(1)));
+
+  // Collapse the dimensions back down.
+  Value collapsedResult = b.create<tensor::CollapseShapeOp>(
+      loc, tileResult, reassociationIndices);
+  return collapsedResult;
+}
+
 namespace {
 
-struct PadOpTiling
-    : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
+struct CollapseShapeOpTiling
+    : public TilingInterface::ExternalModel<CollapseShapeOpTiling,
+                                            CollapseShapeOp> {
+  /// Materializes the `linalg.init_tensor` destination.
   SmallVector<Value> getDestinationOperands(Operation *op,
                                             OpBuilder &b) const {
-    ReifiedRankedShapedTypeDims reifiedShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
-    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
+    return {getDestinationOperandValue(op, b)};
+  }
 
-    auto padOp = cast<PadOp>(op);
-    SmallVector<OpFoldResult> mixedSizes = getAsOpFoldResult(reifiedShapes[0]);
-    Value initTensor = b.create<linalg::InitTensorOp>(
-        op->getLoc(), mixedSizes, padOp.getResultType().getElementType());
-    return {initTensor};
+  SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
+    auto collapseOp = cast<CollapseShapeOp>(op);
+    SmallVector<StringRef> iteratorTypes(collapseOp.getResultType().getRank(),
+                                         getParallelIteratorTypeName());
+    return iteratorTypes;
+  }
+
+  SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
+    return getIterationDomainUsingReifyShapedTypeInterface(op, b);
+  }
+
+  struct LinearizedDimensionInfo {
+    int64_t outputDimension;
+    SmallVector<int64_t> basis;
+
+    LinearizedDimensionInfo(int64_t outputDimension,
+                            ArrayRef<int64_t> linearizedInputDims,
+                            ArrayRef<int64_t> inputShape)
+        : outputDimension(outputDimension) {
+      for (auto x : linearizedInputDims)
+        basis.push_back(inputShape[x]);
+    }
+  };
+
+  SmallVector<Operation *>
+  getTiledImplementation(Operation *op, OpBuilder &b, ValueRange dest,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         bool tileDestOperands) const {
+    FailureOr<Value> resultTileValue = generateResultTileValue(
+        op, b, /*resultNumber=*/0, dest, offsets, sizes, tileDestOperands);
+    if (failed(resultTileValue))
+      return {};
+    return {(*resultTileValue).getDefiningOp()};
+  }
+
+  // Produces a tile of `TilingInterface -> tensor.collapse_shape` by
+  // materializing a loop nest that assembles the tile. Any linearized
+  // dimensions are assembled by looping over the delinearized output
+  // (effectively, tiling by 1) and stitching the tile together.
+  FailureOr<Value> generateResultTileValue(Operation *op, OpBuilder &b,
+                                           unsigned resultNumber,
+                                           ValueRange dest,
+                                           ArrayRef<OpFoldResult> offsets,
+                                           ArrayRef<OpFoldResult> sizes,
+                                           bool tileDestOperands) const {
+    auto collapseOp = cast<CollapseShapeOp>(op);
+    Location loc = op->getLoc();
+
+    // Materialize the source shapes.
+    SmallVector<Value> sourceShape;
+    for (unsigned i = 0; i < collapseOp.getSrcType().getRank(); i++) {
+      Value dimIdx = b.createOrFold<arith::ConstantIndexOp>(loc, i);
+      sourceShape.push_back(
+          b.createOrFold<tensor::DimOp>(loc, collapseOp.getSrc(), dimIdx));
+    }
+
+    // Try to find which dimensions are sliced and/or linearized.
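+    // For example, with reassociation [[0, 1], [2]], output dimension 0 is
+    // linearized (it combines two source dimensions) while output dimension
+    // 1 maps to a single source dimension.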
+    llvm::SmallBitVector slicedOutputDims =
+        getSlicedDimensions(b, collapseOp, offsets, sizes);
+    llvm::SmallBitVector linearizedOutputDims =
+        getLinearizedDimensions(collapseOp);
+
+    // If no dimension is both sliced and linearized, then we cannot proceed.
+    llvm::SmallBitVector slicedAndLinearized =
+        slicedOutputDims & linearizedOutputDims;
+    if (!slicedAndLinearized.any())
+      return failure();
+
+    // Any caller of this function has passed an output buffer via `dest`.
+    // That output buffer is the size of the entire collapse_shape output. We
+    // need to slice this buffer to get the tiled output buffer.
+    Value iterArgInit = b.create<tensor::ExtractSliceOp>(
+        loc, dest[0], offsets, sizes,
+        SmallVector<OpFoldResult>(offsets.size(), b.getIndexAttr(1)));
+
+    // Create the bounds for the loop nest.
+    SmallVector<Value> nestLowerBounds;
+    SmallVector<Value> nestUpperBounds;
+    for (int idx = slicedAndLinearized.find_first(); idx != -1;
+         idx = slicedAndLinearized.find_next(idx)) {
+      nestLowerBounds.push_back(getAsValue(b, loc, offsets[idx]));
+      nestUpperBounds.push_back(makeComposedAffineApply(
+          b, loc, b.getAffineDimExpr(0) + b.getAffineDimExpr(1),
+          {nestLowerBounds.back(), getAsValue(b, loc, sizes[idx])}));
+    }
+    scf::LoopNest nest = getEmptyLoopNest(b, loc, nestLowerBounds,
+                                          nestUpperBounds, iterArgInit);
+
+    // Create the de-linearized multi indices at the start of each loop body.
+    SmallVector<SmallVector<Value>> multiIndices;
+    SmallVector<ReassociationIndices> reassociationIndices =
+        collapseOp.getReassociationIndices();
+    for (unsigned i = 0, loopIdx = 0; i < reassociationIndices.size(); i++) {
+      if (!linearizedOutputDims[i] || !slicedOutputDims[i])
+        continue;
+      assert(loopIdx < nest.loops.size());
+      auto loop = nest.loops[loopIdx++];
+      b.setInsertionPointToStart(loop.getBody());
+      Value iv = loop.getInductionVar();
+
+      SmallVector<Value> basis;
+      for (auto idx : reassociationIndices[i]) {
+        basis.push_back(b.createOrFold<tensor::DimOp>(
+            loc, collapseOp.src(),
+            b.create<arith::ConstantIndexOp>(loc, idx)));
+      }
+      auto delinOp = b.create<tensor::DelinearizeIndexOp>(loop->getLoc(),
+                                                          /*linear_index=*/iv,
+                                                          /*basis=*/basis);
+      multiIndices.push_back(llvm::to_vector(llvm::map_range(
+          delinOp.getResults(), [](OpResult r) -> Value { return r; })));
+    }
+
+    // Fill out the first part of the loop body: the sub-tile for a single
+    // iteration.
+    scf::ForOp innerLoop = nest.loops.back();
+    b.setInsertionPointToEnd(innerLoop.getBody());
+    Value iterArg = innerLoop.getRegionIterArgs()[0];
+    FailureOr<Value> tiledResult = createCollapseShapeTiledResult(
+        collapseOp, collapseOp.src(), b, slicedOutputDims,
+        linearizedOutputDims, multiIndices, offsets, sizes);
+    if (failed(tiledResult))
+      return failure();
+
+    // Insert the collapse_shape sub-tile into the iteration argument.
+    SmallVector<OpFoldResult> insertOffsets;
+    SmallVector<OpFoldResult> insertSizes;
+    for (unsigned i = 0, loopIdx = 0; i < reassociationIndices.size(); i++) {
+      // Case 1: Linearized dimensions that have been sliced. The insert size
+      // is 1, and the offset is the iv.
+      if (linearizedOutputDims[i] && slicedOutputDims[i]) {
+        insertOffsets.push_back(nest.loops[loopIdx++].getInductionVar());
+        insertSizes.push_back(b.getIndexAttr(1));
+        continue;
+      }
+      // Case 2: Otherwise, the insert covers the full extent of the
+      // iteration argument dimension, because this output dimension is not
+      // being iterated over in the loop nest.
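+      // For example, with a 16x2 iteration argument where only dimension 0
+      // is iterated, each insert uses offsets [iv, 0] and sizes [1, 2].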
+      insertOffsets.push_back(b.getIndexAttr(0));
+      RankedTensorType iterArgsType =
+          iterArg.getType().cast<RankedTensorType>();
+      if (iterArgsType.isDynamicDim(i))
+        insertSizes.push_back(b.createOrFold<tensor::DimOp>(
+            loc, iterArg, b.createOrFold<arith::ConstantIndexOp>(loc, i)));
+      else
+        insertSizes.push_back(b.getIndexAttr(iterArgsType.getDimSize(i)));
+    }
+
+    Value result = b.create<tensor::InsertSliceOp>(
+        loc, *tiledResult, iterArg, insertOffsets, insertSizes,
+        /*strides=*/
+        SmallVector<OpFoldResult>(insertOffsets.size(), b.getIndexAttr(1)));
+    b.create<scf::YieldOp>(loc, result);
+    b.setInsertionPointAfter(op);
+    return nest.loops.begin()->getResult(0);
+  }
+};
+
+struct PadOpTiling
+    : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
+
+  SmallVector<Value> getDestinationOperands(Operation *op,
+                                            OpBuilder &b) const {
+    return {getDestinationOperandValue(op, b)};
   }
 
   SmallVector<StringRef> getLoopIteratorTypes(Operation *op) const {
@@ -42,20 +363,7 @@
   }
 
   SmallVector<Range> getIterationDomain(Operation *op, OpBuilder &b) const {
-    ReifiedRankedShapedTypeDims reifiedShapes;
-    ReifyRankedShapedTypeOpInterface reifyShapedTypeInterface =
-        dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
-    (void)reifyShapedTypeInterface.reifyResultShapes(b, reifiedShapes);
-
-    Location loc = op->getLoc();
-    Value zero = b.create<arith::ConstantIndexOp>(loc, 0);
-    Value one = b.create<arith::ConstantIndexOp>(loc, 1);
-    // Initialize all the ranges to {zero, one, one}. All the `ub`s are
-    // overwritten.
-    SmallVector<Range> loopRanges(reifiedShapes[0].size(), {zero, one, one});
-    for (const auto &ub : enumerate(reifiedShapes[0]))
-      loopRanges[ub.index()].size = ub.value();
-    return loopRanges;
+    return getIterationDomainUsingReifyShapedTypeInterface(op, b);
   }
 
   SmallVector<Operation *>
@@ -285,5 +593,6 @@
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, TensorDialect *dialect) {
     tensor::PadOp::attachInterface<PadOpTiling>(*ctx);
+    tensor::CollapseShapeOp::attachInterface<CollapseShapeOpTiling>(*ctx);
   });
 }
diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
--- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
+++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir
@@ -183,3 +183,161 @@
 // CHECK-SAME:     outs(%[[OUTS_TILE]] :
 // CHECK:   %[[INSERT:.+]] = tensor.insert_slice %[[GENERIC_TILE]] into %[[ITERARG1]][%[[IV1]], %[[IV0]]]
 // CHECK    scf.yield %[[INSERT]]
+
+// -----
+
+func.func @collapsed_shape_generic_contraction(%arg0 : tensor<?x10x20x30x4xf32>, %arg1 : tensor<40x20x30x4xf32>,
+    %arg2 : tensor<?x40xf32>) -> tensor<?x10x40xf32> {
+  %0 = tensor.collapse_shape %arg0 [[0, 1], [2, 3], [4]] : tensor<?x10x20x30x4xf32> into tensor<?x600x4xf32>
+  %1 = tensor.collapse_shape %arg1 [[0], [1, 2], [3]] : tensor<40x20x30x4xf32> into tensor<40x600x4xf32>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x10x20x30x4xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x10x20x30x4xf32>
+  %d = arith.muli %d0, %d1 : index
+  %2 = linalg.init_tensor [%d, 40] : tensor<?x40xf32>
+  %3 = linalg.generic {
+      __internal_linalg_transform__ = "collapsed_shape_generic_contraction",
+      indexing_maps = [
+        affine_map<(d0,d1,d2,d3)->(d0,d2,d3)>,
+        affine_map<(d0,d1,d2,d3)->(d1,d2,d3)>,
+        affine_map<(d0,d1,d2,d3)->(d0,d1)>
+      ],
+      iterator_types = [
+        "parallel",
+        "parallel",
+        "reduction",
+        "reduction"
+      ]
+    }
+    ins(%0, %1 : tensor<?x600x4xf32>, tensor<40x600x4xf32>)
+    outs(%2 : tensor<?x40xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+    %m = arith.mulf %arg3, %arg4 : f32
+    %a = arith.addf %m, %arg5 : f32
+    linalg.yield %a : f32
+  } -> tensor<?x40xf32>
+
+  %4 = tensor.expand_shape %3 [[0, 1], [2]] : tensor<?x40xf32> into tensor<?x10x40xf32>
+  return %4 : tensor<?x10x40xf32>
+}
+
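+// Tile sizes are [10, 20, 1, 0]: the two parallel dimensions are tiled by 10
+// and 20, and the collapsed reduction dimension is tiled by 1 so that the
+// producer collapse_shapes can be fused via index delinearization.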
+// CHECK-DAG: #[[map0:.+]] = affine_map<(d0)[s0, s1] -> (10, -d0 + s1)>
+// CHECK-DAG: #[[map1:.+]] = affine_map<(d0)[s0, s1] -> (20, -d0 + s1)>
+// CHECK-DAG: #[[map2:.+]] = affine_map<(d0)[s0, s1] -> (1, -d0 + s1)>
+// CHECK-DAG: #[[map3:.+]] = affine_map<()[s0] -> (s0 * 10)>
+// CHECK-DAG: #[[map4:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK: func.func @collapsed_shape_generic_contraction(
+// CHECK-SAME: %[[arg0:.+]]: tensor<?x10x20x30x4xf32>,
+// CHECK-SAME: %[[arg1:.+]]: tensor<40x20x30x4xf32>,
+// CHECK-DAG: %[[c10:.+]] = arith.constant 10 :
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 :
+// CHECK-DAG: %[[c40:.+]] = arith.constant 40 :
+// CHECK-DAG: %[[c600:.+]] = arith.constant 600 :
+// CHECK-DAG: %[[c20:.+]] = arith.constant 20 :
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 :
+// CHECK-DAG: %[[c30:.+]] = arith.constant 30 :
+// CHECK: scf.for %[[IV1:.+]] = %[[c0]] to %{{.*}} step %[[c10]] iter_args(%[[IA1:.+]] =
+// CHECK: %[[TS1:.+]] = affine.min #[[map0]](%[[IV1]])[%[[c10]], %{{.*}}]
+// CHECK: scf.for %[[IV2:.+]] = %[[c0]] to %[[c40]] step %[[c20]] iter_args(%[[IA2:.+]] = %[[IA1]])
+// CHECK: %[[TS2:.+]] = affine.min #[[map1]](%[[IV2]])[%[[c20]], %[[c40]]]
+// CHECK: scf.for %[[IV3:.+]] = %[[c0]] to %[[c600]] step %[[c1]] iter_args(%[[IA3:.+]] = %[[IA2]])
+// CHECK: %[[TS3:.+]] = affine.min #[[map2]](%[[IV3]])[%[[c1]], %[[c600]]]
+// CHECK: %[[D0:.+]] = tensor.dim %[[arg0]], %[[c0]] :
+// CHECK: %[[OFFT1:.+]] = affine.apply #[[map3]]()[%[[D0]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[TA_INIT:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], %[[IV3]], 0] [%[[TS1]], %[[TS3]], 4] [1, 1, 1] :
+// CHECK: %[[OFFT1:.+]] = affine.apply #[[map4]](%[[IV1]], %[[TS1]])
+// CHECK: %[[OFFT2:.+]] = affine.apply #[[map4]](%[[IV3]], %[[TS3]])
+// CHECK: %[[TA:.+]] = scf.for %[[IV4:.+]] = %[[IV1]] to %[[OFFT1]] step %[[c1]] iter_args(%[[arg10:.+]] = %[[TA_INIT]])
+// CHECK: %[[D0:.+]] = tensor.dim %[[arg0]], %[[c0]]
+// CHECK: %[[MI1:.+]]:2 = tensor.delinearize_index %[[IV4]](%[[D0]], %[[c10]] :
+// CHECK: scf.for %[[IV5:.+]] = %[[IV3]] to %[[OFFT2]] step %[[c1]] iter_args(%[[arg12:.+]] = %[[arg10]])
+// CHECK: %[[MI2:.+]]:2 = tensor.delinearize_index %[[IV5]](%[[c20]], %[[c30]] :
+// CHECK: %[[S1:.+]] = tensor.extract_slice %[[arg0]][%[[MI1]]#0, %[[MI1]]#1, %[[MI2]]#0, %[[MI2]]#1, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] :
+// CHECK: %[[S2:.+]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2, 3], [4]{{\]}} :
+// CHECK: tensor.insert_slice %[[S2]] into %[[arg12]][%[[IV4]], %[[IV5]], 0] [1, 1, 4] [1, 1, 1] :
+// CHECK: %[[INIT:.+]] = linalg.init_tensor
+// CHECK: %[[TB_INIT:.+]] = tensor.extract_slice %[[INIT]]
+// CHECK: %[[OFFT1:.+]] = affine.apply #[[map4]](%[[IV3]], %[[TS3]])
+// CHECK: %[[TB:.+]] = scf.for %[[IV4:.+]] = %[[IV3]] to %[[OFFT1]] step %[[c1]] iter_args(%[[arg10:.+]] = %[[TB_INIT]])
+// CHECK: %[[MI1:.+]]:2 = tensor.delinearize_index %[[IV4]](%[[c20]], %[[c30]] :
+// CHECK: %[[S1:.+]] = tensor.extract_slice %[[arg1]][%[[IV2]], %[[MI1]]#0, %[[MI1]]#1, 0] [%[[TS2]], 1, 1, 4] [1, 1, 1, 1]
+// CHECK: %[[S2:.+]] = tensor.collapse_shape %[[S1]] {{\[}}[0], [1, 2], [3]{{\]}}
+// CHECK: %[[D0:.+]] = tensor.dim %[[arg10]], %[[c0]] :
+// CHECK: tensor.insert_slice %[[S2]] into %[[arg10]][0, %[[IV4]], 0] [%[[D0]], 1, 4] [1, 1, 1] :
+// CHECK: %[[TC:.+]] = tensor.extract_slice %[[IA3]][%[[IV1]], %[[IV2]]] [%[[TS1]], %[[TS2]]] [1, 1] :
+// CHECK: %[[TD:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[TA]], %[[TB]] :
+// CHECK-SAME: outs(%[[TC]] :
+// CHECK: tensor.insert_slice %[[TD]] into %[[IA3]][%[[IV1]], %[[IV2]]] [%[[TS1]], %[[TS2]]] [1, 1] :
+
+// -----
+
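+// Fuses an elementwise producer through tensor.collapse_shape into an
+// elementwise consumer; tile sizes are [8, 0], so only the collapsed
+// dimension is tiled.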
+func.func @simple_collapse_shape_fusion(%arg0: tensor<4x4x2xf32>,
+    %arg1: tensor<4x4x2xf32>) -> tensor<16x2xf32> {
+  %0 = linalg.init_tensor [4, 4, 2] : tensor<4x4x2xf32>
+  %1 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+        affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+        affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+      ],
+      iterator_types = ["parallel", "parallel", "parallel"]
+    } ins(%arg0, %arg1 : tensor<4x4x2xf32>, tensor<4x4x2xf32>)
+      outs(%0 : tensor<4x4x2xf32>) {
+  ^bb0(%arg2 : f32, %arg3: f32, %accum: f32):
+    %5 = arith.addf %arg2, %arg3 : f32
+    linalg.yield %5 : f32
+  } -> tensor<4x4x2xf32>
+
+  %2 = tensor.collapse_shape %1 [[0, 1], [2]] : tensor<4x4x2xf32> into tensor<16x2xf32>
+  %3 = linalg.init_tensor [16, 2] : tensor<16x2xf32>
+  %4 = linalg.generic {
+      indexing_maps = [
+        affine_map<(d0, d1) -> (d0, d1)>,
+        affine_map<(d0, d1) -> (d0, d1)>
+      ],
+      iterator_types = ["parallel", "parallel"],
+      __internal_linalg_transform__ = "simple_collapse_shape_fusion"
+    } ins(%2 : tensor<16x2xf32>)
+      outs(%3 : tensor<16x2xf32>) {
+  ^bb0(%arg4 : f32, %arg5: f32):
+    %c1 = arith.constant 1.0 : f32
+    %6 = arith.addf %arg4, %c1 : f32
+    linalg.yield %6 : f32
+  } -> tensor<16x2xf32>
+
+  return %4 : tensor<16x2xf32>
+}
+
+// CHECK-DAG: #[[map0:.+]] = affine_map<(d0)[s0, s1] -> (8, -d0 + s1)>
+// CHECK-DAG: #[[map1:.+]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[map2:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[map3:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK: func.func @simple_collapse_shape_fusion(%[[arg0:.+]]: tensor<{{.*}}>, %[[arg1:.+]]: tensor<{{.*}}>)
+// CHECK-DAG: %[[c0:.+]] = arith.constant 0 : index
+// CHECK-DAG: %[[c16:.+]] = arith.constant 16 : index
+// CHECK-DAG: %[[c8:.+]] = arith.constant 8 : index
+// CHECK-DAG: %[[c1:.+]] = arith.constant 1 : index
+// CHECK-DAG: %[[c4:.+]] = arith.constant 4 : index
+// CHECK-DAG: %[[cst:.+]] = arith.constant 1.0
+// CHECK: scf.for %[[IV1:.+]] = %[[c0]] to %[[c16]] step %[[c8]] iter_args(%[[IA1:.+]] =
+// CHECK: %[[TS:.+]] = affine.min #[[map0]](%[[IV1]])[%[[c8]], %[[c16]]]
+// CHECK: %[[INIT:.+]] = linalg.init_tensor [16, 2] :
+// CHECK: %[[IA_INIT:.+]] = tensor.extract_slice %[[INIT]][%[[IV1]], 0] [%[[TS]], 2] [1, 1] :
+// CHECK: %[[UB:.+]] = affine.apply #[[map1]](%[[IV1]], %[[TS]])
+// CHECK: %[[TA:.+]] = scf.for %[[IV2:.+]] = %[[IV1]] to %[[UB]] step %[[c1]] iter_args(%[[IA2:.+]] =
+// CHECK: %[[MI:.+]]:2 = tensor.delinearize_index %[[IV2]](%[[c4]], %[[c4]] :
+// CHECK: %[[AS:.+]] = tensor.extract_slice %[[arg0]][%[[MI]]#0, %[[MI]]#1, 0] [1, 1, 2] [1, 1, 1] :
+// CHECK: %[[BS:.+]] = tensor.extract_slice %[[arg1]][%[[MI]]#0, %[[MI]]#1, 0] [1, 1, 2] [1, 1, 1] :
+// CHECK: %[[CS:.+]] = tensor.extract_slice %0[%[[MI]]#0, %[[MI]]#1, 0] [1, 1, 2] [1, 1, 1] :
+// CHECK: %[[ST:.+]] = linalg.generic {
+// CHECK-SAME: ins(%[[AS]], %[[BS]] :
+// CHECK-SAME: outs(%[[CS]] :
+// CHECK: %[[TS_FLAT:.+]] = tensor.collapse_shape %[[ST]] {{\[}}[0, 1], [2]{{\]}} :
+// CHECK: tensor.insert_slice %[[TS_FLAT]] into %[[IA2]][%[[IV2]], 0] [1, 2] [1, 1] :
+// CHECK: %[[TB:.+]] = tensor.extract_slice %[[IA1]][%[[IV1]], 0] [%[[TS]], 2] [1, 1] :
+// CHECK: %[[TILE:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[TA]] :
+// CHECK-SAME: outs(%[[TB]] :
+// CHECK: tensor.insert_slice %[[TILE]] into %[[IA1]][%[[IV1]], 0] [%[[TS]], 2] [1, 1] :
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
--- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp
@@ -184,6 +184,15 @@
   addPatternForTiling<
       TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
       context, {10}, "gemm_fusion", patterns);
+
+  addPatternForTiling<
+      TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+      context, {8, 0}, "simple_collapse_shape_fusion", patterns);
+  addPatternForTiling<
+      TestTileConsumerAndFuseProducersUsingSCFForOpWithFilter>(
+      context, {10, 20, 1, 0}, "collapsed_shape_generic_contraction",
+      patterns);
+
   return;
  }
 }
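
For reference, a minimal sketch of the rewrite exercised by
@simple_collapse_shape_fusion above (names are illustrative, and the
tensor.delinearize_index syntax follows the CHECK lines in the new test
rather than verbatim compiler output): each loop iteration delinearizes the
collapsed induction variable, extracts a rank-3 slice from the producer,
collapses it, and inserts it into the accumulated tile.

  scf.for %i = %lb to %ub step %c1 iter_args(%acc = %init) -> (tensor<16x2xf32>) {
    %mi:2 = tensor.delinearize_index %i(%c4, %c4 : index, index)
    %slice = tensor.extract_slice %producer[%mi#0, %mi#1, 0] [1, 1, 2] [1, 1, 1]
        : tensor<4x4x2xf32> to tensor<1x1x2xf32>
    %flat = tensor.collapse_shape %slice [[0, 1], [2]]
        : tensor<1x1x2xf32> into tensor<1x2xf32>
    %next = tensor.insert_slice %flat into %acc[%i, 0] [1, 2] [1, 1]
        : tensor<1x2xf32> into tensor<16x2xf32>
    scf.yield %next : tensor<16x2xf32>
  }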