diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -20,6 +20,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/FoldUtils.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/CommandLine.h"
@@ -256,7 +257,7 @@
 } // namespace
 
 /// Utility function for replacing operands/results to a linalg generic
-/// operation on tensors with unit-extent dimensions. These can be replaced with
+/// operation with unit-extent dimensions. These can be replaced with
 /// an operand/result with the unit-extent dimension removed. This is only done
 /// if the indexing map used to access that dimension has an
 /// AffineConstantExpr of value 0. Given the `type` of a result/operand of a
@@ -301,10 +302,19 @@
     ++dim;
   }
   // Compute the tensor or scalar replacement type.
+  Type actualType = opOperand->get().getType();
   Type elementType = getElementTypeOrSelf(opOperand->get());
-  Type replacementType = elementType == opOperand->get().getType()
-                             ? elementType
-                             : RankedTensorType::get(newShape, elementType);
+  Type replacementType;
+  if (elementType == opOperand->get().getType()) {
+    replacementType = elementType;
+  } else if (actualType.isa<RankedTensorType>()) {
+    replacementType = RankedTensorType::get(newShape, elementType);
+  } else if (actualType.isa<MemRefType>()) {
+    assert(actualType.cast<MemRefType>().getAffineMaps().empty() &&
+           "unsupported strided memrefs");
+    replacementType = MemRefType::get(newShape, elementType);
+  }
+  assert(replacementType && "unsupported shaped type");
   UnitExtentReplacementInfo info = {replacementType,
                                     AffineMap::get(indexingMap.getNumDims(),
                                                    indexingMap.getNumSymbols(),
@@ -324,14 +334,53 @@
   return reassociationExprs;
 }
 
-/// Pattern to replace tensors operands/results that are unit extents.
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
+/// Pattern to replace tensor/buffer operands/results that are unit extents.
+struct ReplaceUnitExtents : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  // Return the original value if the type is unchanged, or reshape it. Return
+  // a nullptr if this is an unsupported type.
+  Value maybeExpand(Value result, Type origResultType,
+                    ArrayAttr reassociationMap, Location loc,
+                    PatternRewriter &rewriter) const {
+    if (origResultType == result.getType())
+      return result;
+    if (origResultType.isa<RankedTensorType>()) {
+      return rewriter.create<TensorExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (origResultType.isa<MemRefType>()) {
+      return rewriter.create<ExpandShapeOp>(
+          loc, origResultType, result,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  };
+
+  // Return the original value if the type is unchanged, or reshape it. Return
+  // a nullptr if this is an unsupported type.
+  Value maybeCollapse(Value operand, Type newInputOutputType,
+                      ArrayAttr reassociationMap, Location loc,
+                      PatternRewriter &rewriter) const {
+    auto operandType = operand.getType();
+    if (operandType == newInputOutputType)
+      return operand;
+    if (operandType.isa<RankedTensorType>()) {
+      return rewriter.create<TensorCollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    if (operandType.isa<MemRefType>()) {
+      return rewriter.create<CollapseShapeOp>(
+          loc, newInputOutputType, operand,
+          convertAffineMapArrayToExprs(reassociationMap));
+    }
+    return nullptr;
+  };
+
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    if (!genericOp.hasTensorSemantics())
-      return failure();
-
     MLIRContext *context = rewriter.getContext();
     Location loc = genericOp.getLoc();
@@ -339,7 +388,6 @@
     SmallVector<ArrayAttr, 4> reassociationMaps;
     SmallVector<Type, 4> newInputOutputTypes;
     bool doCanonicalization = false;
-
     for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
       UnitExtentReplacementInfo replacementInfo =
          replaceUnitExtents(genericOp, opOperand, context);
@@ -362,14 +410,13 @@
     auto insertReshapes = [&](ValueRange values) {
       SmallVector<Value, 4> res;
       res.reserve(values.size());
-      for (auto operand : llvm::enumerate(values)) {
-        if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
-          res.push_back(operand.value());
-        else {
-          res.push_back(rewriter.create<TensorCollapseShapeOp>(
-              loc, newInputOutputTypes[flattenedIdx], operand.value(),
-              convertAffineMapArrayToExprs(reassociationMaps[flattenedIdx])));
-        }
+      for (auto operand : values) {
+        auto reshapedValue =
+            maybeCollapse(operand, newInputOutputTypes[flattenedIdx],
+                          reassociationMaps[flattenedIdx], loc, rewriter);
+        assert(reshapedValue &&
+               "expected ranked MemRef or Tensor operand type");
+        res.push_back(reshapedValue);
         ++flattenedIdx;
       }
       return res;
@@ -396,15 +443,13 @@
     SmallVector<Value, 4> resultReplacements;
     for (auto result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumInputs();
-      RankedTensorType origResultType = genericOp.getResult(result.index())
-                                            .getType()
-                                            .template cast<RankedTensorType>();
-      if (origResultType != result.value().getType()) {
-        resultReplacements.push_back(rewriter.create<TensorExpandShapeOp>(
-            loc, origResultType, result.value(),
-            convertAffineMapArrayToExprs(reassociationMaps[index])));
-      } else
-        resultReplacements.push_back(result.value());
+      auto origResultType = genericOp.getResult(result.index()).getType();
+
+      auto newResult = maybeExpand(result.value(), origResultType,
+                                   reassociationMaps[index], loc, rewriter);
+      assert(newResult &&
+             "unexpected output type other than ranked MemRef or Tensor");
+      resultReplacements.push_back(newResult);
     }
     rewriter.replaceOp(genericOp, resultReplacements);
     return success();
@@ -501,9 +546,8 @@
 void mlir::linalg::populateFoldUnitExtentDimsPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<FoldUnitDimLoops, ReplaceUnitExtentTensors>(
-      context);
+  patterns.add<FoldUnitDimLoops, ReplaceUnitExtents>(context);
   TensorCollapseShapeOp::getCanonicalizationPatterns(patterns, context);
   TensorExpandShapeOp::getCanonicalizationPatterns(patterns, context);
 }
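
The populated patterns are driven by the standard greedy rewriter. As a minimal sketch of how they could be exercised, assuming the declaration of populateFoldUnitExtentDimsPatterns from Linalg's Transforms.h (the pass below is hypothetical and illustrative only; in-tree, the existing -linalg-fold-unit-extent-dims pass is what wires these patterns up):

// Hypothetical test pass, not part of this patch: applies the unit-dim
// folding patterns, including the new memref support, to a fixed point.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace {
struct TestFoldUnitDimsPass
    : public mlir::PassWrapper<TestFoldUnitDimsPass, mlir::FunctionPass> {
  void runOnFunction() override {
    mlir::RewritePatternSet patterns(&getContext());
    // Adds FoldUnitDimLoops, ReplaceUnitExtents, and the reshape
    // canonicalization patterns registered above.
    mlir::linalg::populateFoldUnitExtentDimsPatterns(patterns);
    (void)mlir::applyPatternsAndFoldGreedily(getFunction(),
                                             std::move(patterns));
  }
};
} // namespace
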
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -451,3 +451,303 @@
 // CHECK: %[[RESULT:.+]] = subtensor_insert %[[RESHAPE]]
 // CHECK-SAME: tensor<f32> into tensor<1x3xf32>
 // CHECK: return %[[RESULT]]
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> ()>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops(%arg0 : memref<?x1x?xf32>, %arg1 : f32, %shape: memref<?x1x?x1x?xf32>) -> memref<?x1x?x1x?xf32> {
+  linalg.generic #trait
+     ins(%arg0, %arg1 : memref<?x1x?xf32>, f32)
+    outs(%shape : memref<?x1x?x1x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32) :
+    linalg.yield %arg3 : f32
+  }
+  return %shape : memref<?x1x?x1x?xf32>
+}
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
+// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-LABEL: func @drop_one_trip_loops
+// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1], [2]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+
+// -----
+
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops_indexed
+  (%arg0 : memref<?x1x?xi32>, %shape: memref<?x1x?x1x?xi32>) -> memref<?x1x?x1x?xi32>
+{
+  linalg.generic #trait
+     ins(%arg0 : memref<?x1x?xi32>)
+    outs(%shape: memref<?x1x?x1x?xi32>) {
+  ^bb0(%arg6 : i32, %arg7 : i32) :
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %idx2 = linalg.index 2 : index
+    %idx3 = linalg.index 3 : index
+    %idx4 = linalg.index 4 : index
+    %1 = addi %idx0, %idx1 : index
+    %2 = subi %1, %idx2 : index
+    %3 = subi %2, %idx3 : index
+    %4 = addi %3, %idx4 : index
+    %5 = index_cast %4 : index to i32
+    %6 = addi %5, %arg6 : i32
+    linalg.yield %6 : i32
+  }
+  return %shape : memref<?x1x?x1x?xi32>
+}
+// The subtractions disappear because the access map of the output memref maps
+// its unit dimensions 1 and 3 to the index dimensions 2 and 3.
+// CHECK-LABEL: func @drop_one_trip_loops_indexed
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i32, %{{.*}}: i32)
+// CHECK: %[[IDX0:.+]] = linalg.index 0 : index
+// CHECK: %[[IDX1:.+]] = linalg.index 1 : index
+// CHECK: %[[IDX2:.+]] = linalg.index 2 : index
+// CHECK: %[[T3:.+]] = addi %[[IDX0]], %[[IDX1]]
+// CHECK: %[[T4:.+]] = addi %[[T3]], %[[IDX2]]
+// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops(%arg0 : memref<1x1xf32>) -> memref<1x1xf32>
+{
+  linalg.generic #trait
+     ins(%arg0 : memref<1x1xf32>)
+    outs(%arg0 : memref<1x1xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32) :
+    linalg.yield %arg1 : f32
+  }
+  return %arg0 : memref<1x1xf32>
+}
+// CHECK: #[[$MAP0:.*]] = affine_map<() -> ()>
+// CHECK-LABEL: func @drop_all_loops
+// CHECK: linalg.collapse_shape %{{.*}} []
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP0]]]
+// CHECK-SAME: iterator_types = []
+
+// -----
+
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops_indexed
+  (%arg0 : memref<1x1xi32>) -> memref<1x1xi32>{
+  linalg.generic #trait
+     ins(%arg0 : memref<1x1xi32>)
+    outs(%arg0 : memref<1x1xi32>) {
+  ^bb0(%arg3: i32, %arg4: i32) :
+    %idx0 = linalg.index 0 : index
+    %idx1 = linalg.index 1 : index
+    %1 = addi %idx0, %idx1 : index
+    %2 = index_cast %1 : index to i32
+    %3 = addi %2, %arg3 : i32
+    linalg.yield %3 : i32
+  }
+  return %arg0 : memref<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed
+// CHECK: linalg.generic
+// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32, %[[ARG2:.+]]: i32)
+// CHECK: linalg.yield %[[ARG1]] : i32
+
+// -----
+
+#accesses = [
+  affine_map<(d0) -> (0, d0)>,
+  affine_map<(d0) -> (d0)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel"],
+  library_call = "some_external_fn"
+}
+
+func @leading_dim_1_canonicalization(%arg0: memref<1x5xf32>, %shape: memref<5xf32>) -> memref<5xf32> {
+  linalg.generic #trait
+     ins(%arg0 : memref<1x5xf32>)
+    outs(%shape : memref<5xf32>) {
+  ^bb0(%arg2: f32, %arg3: f32):   // no predecessors
+    linalg.yield %arg2 : f32
+  }
+  return %shape : memref<5xf32>
+}
+// CHECK: #[[$MAP1:.*]] = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @leading_dim_1_canonicalization
+// CHECK: linalg.collapse_shape %{{.*}} {{\[}}[0, 1]]
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, d1)>,
+  affine_map<(d0, d1) -> (d0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @broadcast_test(%arg0 : memref<5xf32>, %arg1 : memref<5xf32>, %shape : memref<5x5xf32>) -> memref<5x5xf32>
+{
+  %0 = linalg.expand_shape %arg0 [[0, 1]] : memref<5xf32> into memref<1x5xf32>
+  %1 = linalg.expand_shape %arg1 [[0, 1]] : memref<5xf32> into memref<5x1xf32>
+  linalg.generic #trait
+     ins(%0, %1 : memref<1x5xf32>, memref<5x1xf32>)
+    outs(%shape : memref<5x5xf32>) {
+  ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
+    %3 = addf %arg3, %arg4 : f32
+    linalg.yield %3 : f32
+  }
+  return %shape : memref<5x5xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0)>
+// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_test
+// CHECK-NOT: linalg.memref_{{.*}}shape
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]], #[[$MAP2]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-NOT: linalg.memref_{{.*}}shape
+
+// -----
+
+#accesses = [
+  affine_map<(d0, d1) -> (0, 0)>,
+  affine_map<(d0, d1) -> (d0, d1)>
+]
+
+#trait = {
+  indexing_maps = #accesses,
+  iterator_types = ["parallel", "parallel"],
+  library_call = "some_external_fn"
+}
+
+func @broadcast_scalar(%arg0 : memref<1x1xf32>, %shape : memref<?x?xf32>) -> memref<?x?xf32>
+{
+  linalg.generic #trait
+     ins(%arg0 : memref<1x1xf32>)
+    outs(%shape : memref<?x?xf32>) {
+  ^bb0(%arg2 : f32, %arg3 : f32):
+    linalg.yield %arg2 : f32
+  }
+  return %shape : memref<?x?xf32>
+}
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func @broadcast_scalar
+// CHECK-SAME: %[[ARG0:.*]]: memref<1x1xf32>
+// CHECK: %[[A:.*]] = linalg.collapse_shape %[[ARG0]] []
+// CHECK-SAME: memref<1x1xf32> into memref<f32>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP0]], #[[$MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel"]
+// CHECK-SAME: %[[A]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d2)>
+
+func @fold_unit_dim_memref_reshape_op(%arg0 : memref<5xf32>) -> memref<2x5xf32>
+{
+  %1 = memref.alloc() : memref<1x2x5xf32>
+  linalg.generic {indexing_maps = [#map1, #map0],
+    iterator_types = ["parallel", "parallel", "parallel"]}
+    ins(%arg0 : memref<5xf32>) outs(%1 : memref<1x2x5xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):   // no predecessors
+    linalg.yield %arg1 : f32
+  }
+  %3 = linalg.collapse_shape %1 [[0, 1], [2]]
+    : memref<1x2x5xf32> into memref<2x5xf32>
+  return %3 : memref<2x5xf32>
+}
+// CHECK-LABEL: func @fold_unit_dim_memref_reshape_op
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x2x5xf32>
+// CHECK: %[[OUT:.*]] = linalg.collapse_shape %[[ALLOC]]
+// CHECK: linalg.generic
+// CHECK-SAME: outs(%[[OUT]] :
+// CHECK: %[[RESULT:.*]] = linalg.collapse_shape %[[ALLOC]]
+// CHECK: return %[[RESULT]]
+
+// -----
+
+func @fold_unit_dim_for_init_memref(%input: memref<1x1000xf32>) -> memref<1xf32> {
+  %cst = constant 0.0 : f32
+  %init = memref.alloc() : memref<1xf32>
+  linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>],
+      iterator_types = ["parallel", "reduction"]}
+    ins(%input : memref<1x1000xf32>) outs(%init : memref<1xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %1823 = addf %arg1, %arg2 : f32
+    linalg.yield %1823 : f32
+  }
+  return %init : memref<1xf32>
+}
+
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0) -> ()>
+
+// CHECK: func @fold_unit_dim_for_init_memref
+// CHECK: %[[INIT:.+]] = memref.alloc() : memref<1xf32>
+// CHECK: %[[INPUT_RESHAPE:.+]] = linalg.collapse_shape %{{.+}} {{\[}}[0, 1]] : memref<1x1000xf32> into memref<1000xf32>
+// CHECK: %[[INIT_RESHAPE:.+]] = linalg.collapse_shape %[[INIT]] [] : memref<1xf32> into memref<f32>
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: iterator_types = ["reduction"]
+// CHECK-SAME: ins(%[[INPUT_RESHAPE]] : memref<1000xf32>)
+// CHECK-SAME: outs(%[[INIT_RESHAPE]] : memref<f32>)
+// CHECK: return %[[INIT]] : memref<1xf32>
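
The unit-dimension criterion these tests exercise is an AffineConstantExpr of value 0 in an indexing map, e.g. the (d0) -> (0, d0) access in @leading_dim_1_canonicalization above. A standalone, illustrative sketch of building such a map and its folded form with the C++ API:

// Illustrative only: the kind of indexing map whose unit dimension the
// pattern drops, and the map that remains after dropping it.
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  mlir::AffineExpr d0;
  mlir::bindDims(&ctx, d0);
  // (d0) -> (0, d0): the first result is an AffineConstantExpr of value 0,
  // so the matching unit extent of the operand can be removed.
  mlir::AffineExpr zero = mlir::getAffineConstantExpr(0, &ctx);
  mlir::AffineMap before =
      mlir::AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {zero, d0}, &ctx);
  // (d0) -> (d0): the replacement map once the memref<1x5xf32> operand has
  // been collapsed to memref<5xf32>.
  mlir::AffineMap after =
      mlir::AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, {d0}, &ctx);
  (void)before;
  (void)after;
  return 0;
}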