diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -772,23 +772,80 @@
   }
 };
 
-/// Remove generic operations (on tensors) that are just copying
-/// the values from inputs to the results. Requirements are
-/// 1) All iterator types are parallel
-/// 2) The body contains just a yield operation with the yielded values being
-///    the arguments corresponding to the operands.
-struct EraseIdentityGenericOp : public OpRewritePattern<GenericOp> {
+/// Given the indexing maps and types of an {operand, result} pair, check if
+/// all expressions used to access the operand and the result are the same,
+/// except for unit dimensions.
+static Optional<SmallVector<ReassociationIndices>>
+checkForTrivialBroadcast(AffineMap higherDimIndexingMap, Type higherDimType,
+                         AffineMap lowerDimIndexingMap, Type lowerDimType) {
+  if (!higherDimIndexingMap.isProjectedPermutation() ||
+      !lowerDimIndexingMap.isProjectedPermutation())
+    return llvm::None;
+  unsigned higherDimPos = 0, lowerDimPos = 0;
+  ArrayRef<int64_t> higherDimShape = {}, lowerDimShape = {};
+  if (auto higherDimTensorType = higherDimType.dyn_cast<RankedTensorType>())
+    higherDimShape = higherDimTensorType.getShape();
+  if (auto lowerDimTensorType = lowerDimType.dyn_cast<RankedTensorType>())
+    lowerDimShape = lowerDimTensorType.getShape();
+  ArrayRef<AffineExpr> higherDimExprs = higherDimIndexingMap.getResults(),
+                       lowerDimExprs = lowerDimIndexingMap.getResults();
+  SmallVector<ReassociationIndices> reassociation;
+  ReassociationIndices currReassociation;
+  unsigned foldedDim = 0;
+  while (higherDimPos < higherDimShape.size() &&
+         lowerDimPos < lowerDimShape.size()) {
+    // If the access expressions are the same, this could still be a copy.
+    if (higherDimExprs[higherDimPos] == lowerDimExprs[lowerDimPos]) {
+      currReassociation.push_back(foldedDim++);
+      reassociation.emplace_back(ReassociationIndices{});
+      std::swap(currReassociation, reassociation.back());
+      higherDimPos++, lowerDimPos++;
+      continue;
+    }
+    // If the higher-dimensional operand has a unit dimension here, this could
+    // still be a copy.
+    if (higherDimShape[higherDimPos] == 1) {
+      currReassociation.push_back(foldedDim++);
+      higherDimPos++;
+      continue;
+    }
+    // If neither holds, this is not a trivial broadcast or copy.
+    return llvm::None;
+  }
+  // currReassociation should always be empty at this stage. If it is not, push
+  // it into the list.
+  if (!currReassociation.empty()) {
+    reassociation.emplace_back(std::move(currReassociation));
+  }
+
+  // For the case where lowerDimType is rank 0 (or scalar), the reassociation
+  // will be empty. In other cases, check for trailing unit dimensions in
+  // higherDim and append the folding to the last ReassociationIndices.
+  ReassociationIndices trailingDimFolding;
+  while (higherDimPos < higherDimShape.size()) {
+    if (higherDimShape[higherDimPos] != 1)
+      return llvm::None;
+    higherDimPos++;
+    trailingDimFolding.push_back(foldedDim++);
+  }
+
+  if (!reassociation.empty())
+    reassociation.back().append(trailingDimFolding);
+
+  return reassociation;
+}
+
+/// Erase generic ops that are either copies or trivial broadcasts (where the
+/// broadcast dimension is 1) of the operands.
+struct EraseTrivialCopyOrBroadcastOp : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    // Check all indexing maps are identity.
-    if (llvm::any_of(genericOp.getIndexingMaps(),
-                     [](AffineMap map) { return !map.isIdentity(); }))
+    LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
+    if (!genericOp.hasTensorSemantics())
       return failure();
 
-    // Check that the body of the linalg operation is just a linalg.yield
-    // operation.
+    // The body must be just a yield operation.
     Block &body = genericOp.region().front();
     if (!llvm::hasSingleElement(body))
       return failure();
@@ -796,38 +853,115 @@
     if (!yieldOp)
       return failure();
 
-    // In the buffer case, we need to check exact buffer equality.
-    if (genericOp.hasBufferSemantics()) {
-      if (genericOp.getNumInputs() == 1 && genericOp.getNumOutputs() == 1 &&
-          genericOp.getInputOperand(0)->get() ==
-              genericOp.getOutputOperand(0)->get()) {
-        rewriter.eraseOp(genericOp);
-        return success();
-      }
-      return failure();
-    }
-
-    // Get the argument number of the returned values. That is the operand
-    // number to use for replacing uses of this operation.
-    SmallVector<Value> returnedArgs;
+    struct ReplacementInfo {
+      unsigned operandNumber;
+      SmallVector<ReassociationIndices> reassociation;
+    };
+    SmallVector<ReplacementInfo> replacementInfo;
+    // Check that all operands are copies or trivial broadcasts. Collect the
+    // information needed without creating any ops (to avoid generating dead
+    // code on failure).
+    // TODO: It is possible to keep track of the non-copy/non-trivial-broadcast
+    // operands and create a new generic op for them. This is avoided here to
+    // reduce complexity.
     for (const auto &yieldVal : llvm::enumerate(yieldOp.values())) {
       auto yieldArg = yieldVal.value().dyn_cast<BlockArgument>();
       if (!yieldArg || yieldArg.getOwner() != &body)
         return failure();
       unsigned argumentNumber = yieldArg.getArgNumber();
-      Value returnedArg = genericOp->getOperand(argumentNumber);
-      Type resultType = genericOp->getResult(yieldVal.index()).getType();
-      // The input can have a different type than the result, e.g. a dynamic
-      // input dimension can be turned into a static output dimension.
-      if (returnedArg.getType() != resultType)
-        returnedArg = rewriter.create<tensor::CastOp>(genericOp.getLoc(),
-                                                      resultType, returnedArg);
-      returnedArgs.push_back(returnedArg);
+      OpOperand &operand = genericOp->getOpOperand(argumentNumber);
+      auto operandType = operand.get().getType();
+      auto operandRank = operandType.isa<RankedTensorType>()
+                             ? operandType.cast<RankedTensorType>().getRank()
+                             : 0;
+      AffineMap operandIndexingMap = linalgOp.getTiedIndexingMap(&operand);
+
+      OpResult result = genericOp->getResult(yieldVal.index());
+      auto resultType = result.getType().cast<ShapedType>();
+      auto resultRank = resultType.isa<RankedTensorType>()
+                            ? resultType.cast<RankedTensorType>().getRank()
+                            : 0;
+      AffineMap resultIndexingMap =
+          linalgOp.getTiedIndexingMapForResult(result);
+
+      Optional<SmallVector<ReassociationIndices>> reassociation;
+      if (operandRank > resultRank) {
+        reassociation = checkForTrivialBroadcast(
+            operandIndexingMap, operandType, resultIndexingMap, resultType);
+      } else {
+        reassociation = checkForTrivialBroadcast(
+            resultIndexingMap, resultType, operandIndexingMap, operandType);
+      }
+
+      if (!reassociation)
+        return failure();
+      ReplacementInfo info{operand.getOperandNumber(),
+                           std::move(reassociation.getValue())};
+      replacementInfo.emplace_back(std::move(info));
+    }
+
+    // All operands are copies or trivial broadcasts; find the replacement
+    // values. They are either the operand value itself, or the operand
+    // followed by a reshape to match the result shape.
+    SmallVector<Value> replacements;
+    Location loc = genericOp->getLoc();
+    replacements.reserve(genericOp->getNumResults());
+    for (const auto &info : llvm::enumerate(replacementInfo)) {
+      Value operand = genericOp->getOperand(info.value().operandNumber);
+      Value result = genericOp->getResult(info.index());
+      Value replacement = operand;
+      int64_t operandRank = 0, resultRank = 0;
+      if (auto operandTensorType =
+              operand.getType().dyn_cast<RankedTensorType>())
+        operandRank = operandTensorType.getRank();
+      if (auto resultTensorType =
+              result.getType().dyn_cast<RankedTensorType>())
+        resultRank = resultTensorType.getRank();
+      if (operandRank > resultRank) {
+        replacement = rewriter.create<tensor::CollapseShapeOp>(
+            loc, replacement, info.value().reassociation);
+      } else if (operandRank < resultRank) {
+        replacement = rewriter.create<tensor::ExpandShapeOp>(
+            loc, result.getType(), replacement, info.value().reassociation);
+      }
+      if (replacement.getType() != result.getType()) {
+        replacement = rewriter.create<tensor::CastOp>(loc, result.getType(),
+                                                      replacement);
+      }
+      replacements.push_back(replacement);
     }
 
-    if (returnedArgs.size() != genericOp->getNumResults())
+    rewriter.replaceOp(genericOp, replacements);
+    return success();
+  }
+};
+
+/// Remove generic operations (on buffers) that are just copying
+/// the values from inputs to the results. Requirements are
+/// 1) All iterator types are parallel
+/// 2) The body contains just a yield operation with the yielded values being
+///    the arguments corresponding to the operands.
+struct EraseBufferCopyOp : public OpRewritePattern<GenericOp> {
+  using OpRewritePattern<GenericOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    // In the buffer case, we need to check exact buffer equality.
+    if (!genericOp.hasBufferSemantics())
+      return failure();
+    if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1 ||
+        genericOp.getInputOperand(0)->get() !=
+            genericOp.getOutputOperand(0)->get())
+      return failure();
+
+    // The body must be just a yield operation.
+    Block &body = genericOp.region().front();
+    if (!llvm::hasSingleElement(body))
+      return failure();
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
       return failure();
-    rewriter.replaceOp(genericOp, returnedArgs);
+
+    rewriter.eraseOp(genericOp);
     return success();
   }
 };
@@ -835,7 +969,8 @@
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeduplicateGenericOpInputs, EraseIdentityGenericOp>(context);
+  results.add<DeduplicateGenericOpInputs, EraseBufferCopyOp,
+              EraseTrivialCopyOrBroadcastOp>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -650,3 +650,136 @@
   } : tensor<400x273xf32> to tensor<412x276xf32>
   return %pad : tensor<412x276xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d0, d2, d3, d4, d1, d6, d5, d7, d8, d9, d10)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d8, d1, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d3, d1, d6, d8)>
+func @trivial_broadcast(%arg0 : tensor<1x1x3x1x4x5x1x1x6x1x1xf32>, %arg1 : tensor<6x4x5xf32>) ->
+    (tensor<6x4x5xf32>, tensor<3x4x5x6xf32>) {
+  %init = linalg.init_tensor [3, 4, 5, 6] : tensor<3x4x5x6xf32>
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map1, #map2],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel",
+                        "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<1x1x3x1x4x5x1x1x6x1x1xf32>, tensor<6x4x5xf32>)
+      outs(%arg1, %init : tensor<6x4x5xf32>, tensor<3x4x5x6xf32>) {
+    ^bb0(%b0: f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      linalg.yield %b1, %b0 : f32, f32
+  } -> (tensor<6x4x5xf32>, tensor<3x4x5x6xf32>)
+  return %0#0, %0#1 : tensor<6x4x5xf32>, tensor<3x4x5x6xf32>
+}
+// CHECK: func @trivial_broadcast(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1x3x1x4x5x1x1x6x1x1xf32>
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<6x4x5xf32>
+// CHECK:        %[[RESHAPE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5], [6, 7, 8, 9, 10]{{\]}}
+// CHECK:        return %[[ARG1]], %[[RESHAPE]]
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d0, d2, d3, d4, d1, d6, d5, d7, d8, d9, d10)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d8, d1, d6)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8, d9, d10) -> (d3, d1, d6, d8)>
+func @trivial_broadcast_result(%arg0 : tensor<3x4x5x6xf32>, %arg1 : tensor<6x4x5xf32>) ->
+    (tensor<6x4x5xf32>, tensor<1x1x3x1x4x5x1x1x6x1x1xf32>) {
+  %init = linalg.init_tensor [1, 1, 3, 1, 4, 5, 1, 1, 6, 1, 1] : tensor<1x1x3x1x4x5x1x1x6x1x1xf32>
+  %0:2 = linalg.generic {
+      indexing_maps = [#map2, #map1, #map1, #map0],
+      iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel",
+                        "parallel", "parallel", "parallel", "parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<3x4x5x6xf32>, tensor<6x4x5xf32>)
+      outs(%arg1, %init : tensor<6x4x5xf32>, tensor<1x1x3x1x4x5x1x1x6x1x1xf32>) {
+    ^bb0(%b0: f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      linalg.yield %b1, %b0 : f32, f32
+  } -> (tensor<6x4x5xf32>, tensor<1x1x3x1x4x5x1x1x6x1x1xf32>)
+  return %0#0, %0#1 : tensor<6x4x5xf32>, tensor<1x1x3x1x4x5x1x1x6x1x1xf32>
+}
+// CHECK: func @trivial_broadcast_result(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<3x4x5x6xf32>
+// CHECK-SAME:   %[[ARG1:.+]]: tensor<6x4x5xf32>
+// CHECK:        %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0, 1, 2], [3, 4], [5], [6, 7, 8, 9, 10]{{\]}}
+// CHECK:        return %[[ARG1]], %[[RESHAPE]]
+
+// -----
+
+func @scalar_copy(%arg0 : tensor<f32>) -> tensor<1x1xf32> {
+  %init = linalg.init_tensor [1, 1] : tensor<1x1xf32>
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> ()>, affine_map<(d0, d1) -> (d1, d0)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<f32>) outs(%init : tensor<1x1xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      linalg.yield %b0 : f32
+  } -> tensor<1x1xf32>
+  return %0 : tensor<1x1xf32>
+}
+// CHECK: func @scalar_copy(
+// CHECK-SAME:   %[[ARG0:.+]]: tensor<f32>
+// CHECK:        %[[RESHAPE:.+]] = tensor.expand_shape %[[ARG0]] []
+// CHECK:        return %[[RESHAPE]]
+
+// -----
+
+// The op cannot be removed if it is not just a copy.
+func @no_remove_op(%arg0 : tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%arg0 : tensor<?xf32>) outs(%arg0 : tensor<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %0 = arith.addf %b0, %b0 : f32
+      linalg.yield %0 : f32
+  } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+// CHECK: func @no_remove_op(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?xf32>)
+// CHECK-NEXT:   %[[RETURN:.+]] = linalg.generic
+// CHECK-SAME:     ins(%[[ARG0]] :
+// CHECK-SAME:     outs(%[[ARG0]] :
+// CHECK:        return %[[RETURN]]
+
+// -----
+
+func @no_remove_op_memref(%arg0 : memref<?xf32>, %arg1 : memref<?xf32>) {
+  linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
+      iterator_types = ["parallel"]}
+      ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32):
+      %0 = arith.addf %b0, %b0 : f32
+      linalg.yield %0 : f32
+  }
+  return
+}
+// CHECK: func @no_remove_op_memref(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: memref<?xf32>
+// CHECK-NEXT:   linalg.generic
+// CHECK-SAME:     ins(%[[ARG0]] :
+// CHECK-SAME:     outs(%[[ARG1]] :
+
+// -----
+
+// Check that the op is not canonicalized away when one of the operands is not
+// a trivial copy or broadcast.
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+#map1 = affine_map<(d0, d1) -> (d1)>
+func @no_remove_op(%arg0 : tensor<?x?xf32>, %arg1 : tensor<?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map0, #map0],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?xf32>) outs(%arg0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32):
+      linalg.yield %b0, %b1 : f32, f32
+  } -> (tensor<?x?xf32>, tensor<?x?xf32>)
+  return %0#0, %0#1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK: func @no_remove_op(
+// CHECK-SAME:   %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?xf32>,
+// CHECK-SAME:   %[[ARG1:[a-zA-Z0-9]+]]: tensor<?xf32>)
+// CHECK-NEXT:   %[[RETURN:.+]]:2 = linalg.generic
+// CHECK-SAME:     ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME:     outs(%[[ARG0]], %[[ARG0]] :
+// CHECK:        return %[[RETURN]]#0, %[[RETURN]]#1
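
Illustrative example (not part of the patch): a minimal sketch of the simplest kind of generic op the new EraseTrivialCopyOrBroadcastOp pattern targets. The function name `@unit_dim_copy` and the shapes are hypothetical, chosen only to show the reassociation that checkForTrivialBroadcast would compute ([[0, 1], [2]] here); the "expected" IR below is what the replacement logic should produce under those assumptions, not output taken from the patch's tests.

#id = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// The generic op only forwards %arg0; the only extra dimension is a unit dim.
func @unit_dim_copy(%arg0 : tensor<1x4x5xf32>) -> tensor<4x5xf32> {
  %init = linalg.init_tensor [4, 5] : tensor<4x5xf32>
  %0 = linalg.generic {
      indexing_maps = [#id, affine_map<(d0, d1, d2) -> (d1, d2)>],
      iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%arg0 : tensor<1x4x5xf32>) outs(%init : tensor<4x5xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      linalg.yield %b0 : f32
  } -> tensor<4x5xf32>
  return %0 : tensor<4x5xf32>
}
// Expected canonical form (illustrative): since the operand rank (3) is higher
// than the result rank (2), the operand is collapsed and the generic op goes
// away; no tensor.cast is needed because the types already match.
//   %0 = tensor.collapse_shape %arg0 [[0, 1], [2]]
//       : tensor<1x4x5xf32> into tensor<4x5xf32>
//   return %0 : tensor<4x5xf32>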