diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -251,7 +251,7 @@
 };
 
 struct UnitExtentReplacementInfo {
-  RankedTensorType type;
+  ShapedType type;
   AffineMap indexMap;
   ArrayAttr reassociation;
 };
@@ -267,8 +267,8 @@
 /// - modified index map that can be used to access the replaced result/operand
 /// - the reassociation that converts from the original tensor type to the
 ///   modified tensor type.
-static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
-                                                    RankedTensorType type,
+template <typename T>
+static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap, T type,
                                                     MLIRContext *context) {
   ArrayRef<int64_t> shape = type.getShape();
   ArrayRef<AffineExpr> exprs = indexMap.getResults();
@@ -301,11 +301,11 @@
     reassociations.clear();
     ++dim;
   }
-  UnitExtentReplacementInfo info = {
-      RankedTensorType::get(newShape, type.getElementType()),
-      AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(),
-                     newIndexExprs, context),
-      ArrayAttr::get(context, reassociationMaps)};
+  UnitExtentReplacementInfo info = {T::get(newShape, type.getElementType()),
+                                    AffineMap::get(indexMap.getNumDims(),
+                                                   indexMap.getNumSymbols(),
+                                                   newIndexExprs, context),
+                                    ArrayAttr::get(context, reassociationMaps)};
   return info;
 }
 
@@ -320,14 +320,12 @@
   return reassociationExprs;
 }
 
-/// Pattern to replace tensors operands/results that are unit extents.
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+/// Pattern to replace tensor/buffer operands/results that are unit extents.
+template <typename TypeTy, typename ReshapeOpTy>
+struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
 
@@ -337,9 +335,12 @@
     bool doCanonicalization = false;
     for (auto it : llvm::zip(genericOp.getIndexingMaps(),
                              genericOp.getShapedOperandTypes())) {
-      auto replacementInfo = replaceUnitExtents(
-          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
-          context);
+      auto operand = std::get<1>(it).template dyn_cast<TypeTy>();
+      if (!operand)
+        return failure();
+
+      auto replacementInfo =
+          replaceUnitExtents(std::get<0>(it), operand, context);
       reassociationMaps.push_back(replacementInfo.reassociation);
       newIndexingMaps.push_back(replacementInfo.indexMap);
       newInputOutputTypes.push_back(replacementInfo.type);
@@ -363,7 +364,7 @@
       if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
         res.push_back(operand.value());
       else
-        res.push_back(rewriter.create<TensorReshapeOp>(
+        res.push_back(rewriter.create<ReshapeOpTy>(
             loc, newInputOutputTypes[flattenedIdx], operand.value(),
             convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
       ++flattenedIdx;
@@ -392,11 +393,11 @@
     SmallVector<Value, 4> resultReplacements;
     for (auto result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumInputs();
-      RankedTensorType origResultType = genericOp.getResult(result.index())
-                                            .getType()
-                                            .template cast<RankedTensorType>();
+      auto origResultType = genericOp.getResult(result.index())
+                                .getType()
+                                .template cast<TypeTy>();
       if (origResultType != result.value().getType())
-        resultReplacements.push_back(rewriter.create<TensorReshapeOp>(
+        resultReplacements.push_back(rewriter.create<ReshapeOpTy>(
             loc, origResultType, result.value(),
             convertAffineMapArrayToExprs(reassociationMaps[index])));
       else
@@ -497,9 +498,10 @@
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors,
-               UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(
-      context);
+  patterns
+      .add<FoldUnitDimLoops, ReplaceUnitExtents<RankedTensorType, TensorReshapeOp>,
+           ReplaceUnitExtents<MemRefType, ReshapeOp>,
+           UseRankReducedSubTensorOp, UseRankReducedSubTensorInsertOp>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
 }
 
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -449,3 +449,40 @@
 // CHECK:   %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]]
 // CHECK-SAME:   tensor<?xf32> into tensor<1x3xf32>
 // CHECK:   return %[[RESULT]]
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+// Confirm the linalg.generic patterns also work on memrefs
+func @broadcast_scalar_memref(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> memref<?x?xf32>
+{
+  linalg.generic #trait
+     ins(%arg0 : memref<1x1xf32>)
+    outs(%shape : memref<?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  }
+  return %shape : memref<?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_scalar_memref
+//  CHECK-SAME:   %[[ARG0:.*]]: memref<1x1xf32>
+//       CHECK: %[[A:.*]] = linalg.reshape %[[ARG0]] []
+//  CHECK-SAME:   memref<1x1xf32> into memref<f32>
+//       CHECK: linalg.generic
+//  CHECK-SAME:   indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+//  CHECK-SAME:   iterator_types = ["parallel", "parallel"]
+//  CHECK-SAME:   %[[A]]
+
+
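For reference, downstream users normally do not instantiate these patterns by hand; they either run the `-linalg-fold-unit-extent-dims` pass (which is what drop-unit-extent-dims.mlir exercises) or call `populateFoldUnitExtentDimsPatterns` from their own pass. The sketch below is illustrative rather than part of this patch: it assumes the MLIR API of this era (`FuncOp`, `RewritePatternSet`, `applyPatternsAndFoldGreedily`) and that the populate function is declared in mlir/Dialect/Linalg/Transforms/Transforms.h; the helper name `applyDropUnitDims` is hypothetical.

    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/IR/BuiltinOps.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    // Hypothetical driver: applies the unit-dim folding patterns, including
    // the memref instantiation of ReplaceUnitExtents added above, to one
    // function.
    static mlir::LogicalResult applyDropUnitDims(mlir::FuncOp func) {
      mlir::RewritePatternSet patterns(func.getContext());
      // Registers FoldUnitDimLoops, both ReplaceUnitExtents instantiations,
      // and the rank-reducing subtensor/subtensor_insert patterns.
      mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns);
      // Rewrite to a fixed point; with this patch the patterns fire on both
      // tensor-semantics and buffer-semantics linalg.generic ops.
      return mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
    }

Equivalently, the new test can be checked from the command line with something like `mlir-opt -split-input-file -linalg-fold-unit-extent-dims drop-unit-extent-dims.mlir`, matching the RUN-line convention of this test file.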