diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" +#include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "llvm/Support/CommandLine.h" @@ -251,14 +252,14 @@ }; struct UnitExtentReplacementInfo { - RankedTensorType type; + ShapedType type; AffineMap indexMap; ArrayAttr reassociation; }; } // namespace /// Utility function for replacing operands/results to a linalg generic -/// operation on tensors with unit-extent dimensions. These can be replaced with +/// operation with unit-extent dimensions. These can be replaced with /// an operand/result with the unit-extent dimension removed. This is only done /// if the indexing map used to access that didimensionmension has a /// AffineConstantExpr of value 0. Given the `type` of an result/operand of a @@ -267,8 +268,8 @@ /// - modified index map that can be used to access the replaced result/operand /// - the reassociation that converts from the original tensor type to the /// modified tensor type. 
-static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap, - RankedTensorType type, +template <typename T> +static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap, T type, MLIRContext *context) { ArrayRef<int64_t> shape = type.getShape(); ArrayRef<AffineExpr> exprs = indexMap.getResults(); @@ -301,11 +302,11 @@ reassociations.clear(); ++dim; } - UnitExtentReplacementInfo info = { - RankedTensorType::get(newShape, type.getElementType()), - AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(), - newIndexExprs, context), - ArrayAttr::get(context, reassociationMaps)}; + UnitExtentReplacementInfo info = {T::get(newShape, type.getElementType()), + AffineMap::get(indexMap.getNumDims(), + indexMap.getNumSymbols(), + newIndexExprs, context), + ArrayAttr::get(context, reassociationMaps)}; return info; } @@ -320,14 +321,11 @@ return reassociationExprs; } -/// Pattern to replace tensors operands/results that are unit extents. -struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> { +/// Pattern to replace tensor/buffer operands/results that are unit extents. 
+struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> { using OpRewritePattern<GenericOp>::OpRewritePattern; LogicalResult matchAndRewrite(GenericOp genericOp, PatternRewriter &rewriter) const override { - if (!genericOp.hasTensorSemantics()) - return failure(); - MLIRContext *context = rewriter.getContext(); Location loc = genericOp.getLoc(); @@ -337,9 +335,20 @@ bool doCanonicalization = false; for (auto it : llvm::zip(genericOp.getIndexingMaps(), genericOp.getShapedOperandTypes())) { - auto replacementInfo = replaceUnitExtents( - std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(), - context); + auto map = std::get<0>(it); + auto operandType = std::get<1>(it); + UnitExtentReplacementInfo replacementInfo; + if (auto tensorType = operandType.template dyn_cast<RankedTensorType>()) { + replacementInfo = replaceUnitExtents(map, tensorType, context); + } else if (auto memRefType = + operandType.template dyn_cast<MemRefType>()) { + if (!memRefType.getAffineMaps().empty()) + return failure(); + replacementInfo = replaceUnitExtents(map, memRefType, context); + } else { + return failure(); + } + reassociationMaps.push_back(replacementInfo.reassociation); newIndexingMaps.push_back(replacementInfo.indexMap); newInputOutputTypes.push_back(replacementInfo.type); @@ -360,12 +369,21 @@ SmallVector<Value, 4> res; res.reserve(values.size()); for (auto operand : llvm::enumerate(values)) { - if (operand.value().getType() == newInputOutputTypes[flattenedIdx]) - res.push_back(operand.value()); - else - res.push_back(rewriter.create<TensorReshapeOp>( + auto operandType = operand.value().getType(); + Value reshapedValue; + if (operandType == newInputOutputTypes[flattenedIdx]) + reshapedValue = operand.value(); + else if (operandType.isa<RankedTensorType>()) + reshapedValue = rewriter.create<TensorReshapeOp>( loc, newInputOutputTypes[flattenedIdx], operand.value(), - convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx]))); + convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])); + else if (operandType.isa<MemRefType>()) + reshapedValue = rewriter.create<linalg::ReshapeOp>( + loc, 
newInputOutputTypes[flattenedIdx], operand.value(), + convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])); + assert(reshapedValue && + "expected ranked MemRef or Tensor operand type"); + res.push_back(reshapedValue); ++flattenedIdx; } return res; @@ -392,15 +410,22 @@ SmallVector<Value, 4> resultReplacements; for (auto result : llvm::enumerate(replacementOp.getResults())) { unsigned index = result.index() + replacementOp.getNumInputs(); - RankedTensorType origResultType = genericOp.getResult(result.index()) - .getType() - .template cast<RankedTensorType>(); - if (origResultType != result.value().getType()) - resultReplacements.push_back(rewriter.create<TensorReshapeOp>( + auto origResultType = genericOp.getResult(result.index()).getType(); + Value newResult; + if (origResultType == result.value().getType()) + newResult = result.value(); + else if (origResultType.isa<RankedTensorType>()) + newResult = rewriter.create<TensorReshapeOp>( loc, origResultType, result.value(), - convertAffineMapArrayToExprs(reassociationMaps[index]))); - else - resultReplacements.push_back(result.value()); + convertAffineMapArrayToExprs(reassociationMaps[index])); + else if (origResultType.isa<MemRefType>()) + newResult = rewriter.create<linalg::ReshapeOp>( + loc, origResultType, result.value(), + convertAffineMapArrayToExprs(reassociationMaps[index])); + + assert(newResult && + "unexpected output type other than ranked MemRef or Tensor"); + resultReplacements.push_back(newResult); } rewriter.replaceOp(genericOp, resultReplacements); return success(); } @@ -497,9 +522,8 @@ void mlir::linalg::populateFoldUnitExtentDimsPatterns( RewritePatternSet &patterns) { auto *context = patterns.getContext(); - patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors>( - context); + patterns.add<FoldUnitDimLoops, ReplaceUnitExtents>(context); TensorReshapeOp::getCanonicalizationPatterns(patterns, context); } diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir --- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir +++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir @@ -449,3 +449,40 @@ // CHECK: %[[RESULT:.+]] 
= subtensor_insert %[[RESHAPE]] // CHECK-SAME: tensor into tensor<1x3xf32> // CHECK: return %[[RESULT]] + +// ----- + +#accesses = [ + affine_map<(d0, d1) -> (0, 0)>, + affine_map<(d0, d1) -> (d0, d1)> ] + +#trait = { + indexing_maps = #accesses, + iterator_types = ["parallel", "parallel"], + library_call = "some_external_fn" +} + +// Confirm the linalg.generic patterns also work on memrefs +func @broadcast_scalar_memref(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> memref<?x?xf32> { + linalg.generic #trait + ins(%arg0 : memref<1x1xf32>) + outs(%shape : memref<?x?xf32>) { + ^bb0(%arg2 : f32, %arg3 : f32): + linalg.yield %arg2 : f32 + } + return %shape : memref<?x?xf32> } +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)> +// CHECK-LABEL: func @broadcast_scalar_memref +// CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32> +// CHECK: %[[A:.*]] = linalg.reshape %[[ARG0]] [] +// CHECK-SAME: memref<1x1xf32> into memref<f32> +// CHECK: linalg.generic +// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]] +// CHECK-SAME: iterator_types = ["parallel", "parallel"] +// CHECK-SAME: %[[A]] + +