diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -714,6 +714,19 @@
         return *(indexingMaps.begin() + opOperand->getOperandNumber());
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the input or output indexing map attribute for `opOperand`.
+      }],
+      /*retTy=*/"Attribute",
+      /*methodName=*/"getTiedIndexingMapAttr",
+      /*args=*/(ins "OpOperand*":$opOperand),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        assert(opOperand->getOwner() == this->getOperation());
+        return $_op.indexing_maps()[opOperand->getOperandNumber()];
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Return the indexing map for a `result`.
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -165,6 +165,12 @@
   let regions = (region AnyRegion:$region);
 
   let builders = [
+    OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
+      "ValueRange":$outputs, "ArrayAttr":$indexingMaps,
+      "ArrayAttr":$iteratorTypes, "StringAttr":$doc,
+      "StringAttr":$libraryCall,
+      "function_ref<void(OpBuilder &, Location, ValueRange)>",
+      CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes)>,
     OpBuilder<(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
       "ValueRange":$outputs, "ArrayRef<AffineMap>":$indexingMaps,
      "ArrayRef<StringRef>":$iteratorTypes, "StringRef":$doc,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -12,12 +12,15 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arithmetic/Utils/Utils.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/PatternMatch.h"
@@ -509,16 +512,12 @@
 //===----------------------------------------------------------------------===//
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
-    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
-    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
+    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
     ArrayRef<NamedAttribute> attributes) {
-  build(builder, result, resultTensorTypes, inputs, outputs,
-        builder.getAffineMapArrayAttr(indexingMaps),
-        builder.getStrArrayAttr(iteratorTypes),
-        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
-        libraryCall.empty() ? StringAttr()
-                            : builder.getStringAttr(libraryCall));
+  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
+        iteratorTypes, doc, libraryCall);
   result.addAttributes(attributes);
   if (!bodyBuild)
     return;
@@ -539,6 +538,20 @@
   bodyBuild(builder, result.location, bodyBlock->getArguments());
 }
 
+void GenericOp::build(
+    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
+    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
+    ArrayRef<StringRef> iteratorTypes, StringRef doc, StringRef libraryCall,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
+    ArrayRef<NamedAttribute> attributes) {
+  build(builder, result, resultTensorTypes, inputs, outputs,
+        builder.getAffineMapArrayAttr(indexingMaps),
+        builder.getStrArrayAttr(iteratorTypes),
+        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
+        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
+        bodyBuild, attributes);
+}
+
 void GenericOp::build(
     OpBuilder &builder, OperationState &result, ValueRange inputs,
     ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
@@ -690,64 +703,124 @@
 LogicalResult GenericOp::verify() { return verifyGenericOp(*this); }
 
 namespace {
-// Deduplicate redundant args of a linalg generic op.
-// An arg is redundant if it has the same Value and indexing map as another.
-struct DeduplicateGenericOpInputs : public OpRewritePattern<GenericOp> {
+
+struct DeduplicateAndRemoveDeadOperandsAndResults
+    : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
-    // Associate each input to an equivalent "canonical" input that has the same
-    // Value and indexing map.
-    //
-    // In the non-duplicate case, input `i` will have canonical input `i`. But
-    // in the case of duplicated inputs, the canonical input could be some other
-    // input `< i`. That is, a later input will have some earlier input as its
-    // canonical input.
-    llvm::SmallDenseMap<std::pair<Value, AffineMap>, unsigned> canonicalInput;
-    // For later remapping tasks like deduplicating payload block arguments,
-    // having a simple "inputIndex -> canonicalInputIndex" integer mapping is
-    // convenient.
-    SmallVector<unsigned> canonicalInputIndices;
-    for (OpOperand *opOperand : genericOp.getInputOperands()) {
-      AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
-      // STL-like maps have a convenient behavior for our use case here. In the
-      // case of duplicate keys, the insertion is rejected, and the returned
-      // iterator gives access to the value already in the map.
-      auto pair = canonicalInput.insert(
-          {{opOperand->get(), indexingMap}, opOperand->getOperandNumber()});
-      canonicalInputIndices.push_back(pair.first->second);
+    // Create a map from argument position in the original op to the argument
+    // position in the new op. If the argument is dropped it won't have an
+    // entry.
+    llvm::SmallDenseMap<unsigned, unsigned> origToModifiedPos;
+    unsigned numNewArgs = 0;
+    llvm::SmallDenseSet<unsigned> droppedOutputs;
+
+    // Information needed to build the modified op.
+    SmallVector<Value> modifiedInputOperands, modifiedOutputOperands;
+    SmallVector<Attribute> modifiedIndexingMapsAttrs;
+    SmallVector<AffineMap> modifiedIndexingMaps;
+    SmallVector<Type> modifiedResultTypes;
+
+    // Helper to determine whether a list of indexing maps is sufficient to
+    // compute the op iteration domain. It takes two lists: the indexing maps
+    // already known to be preserved in the new op, and the indexing maps from
+    // the original op that might still make it into the canonicalized op.
+    SmallVector<AffineMap> concatList;
+    auto getShapesToLoopMap = [&concatList](ArrayRef<AffineMap> list1,
+                                            ArrayRef<AffineMap> list2) -> bool {
+      concatList.assign(list1.begin(), list1.end());
+      concatList.append(list2.begin(), list2.end());
+      return inversePermutation(concatAffineMaps(concatList)) != AffineMap();
+    };
+
+    // An input operand can be dropped if
+    // - it has no uses in the payload, or
+    // - there is a duplicate operand that is accessed using the same
+    //   indexing map.
+    llvm::SmallDenseMap<std::pair<Value, Attribute>, unsigned> dedupedInputs;
+    auto indexingMaps = genericOp.getIndexingMaps();
+    for (OpOperand *inputOpOperand : genericOp.getInputOperands()) {
+      BlockArgument arg = genericOp.getTiedBlockArgument(inputOpOperand);
+      unsigned argNum = arg.getArgNumber();
+
+      // Check if the operand is dead and if the remaining indexing maps still
+      // allow computing the loops-to-shape map.
+      // NOTE: This relies on the implicit assumption that the first n indexing
+      // maps in the list are for the inputs.
+      if (!genericOp.payloadUsesValueFromOperand(inputOpOperand) &&
+          getShapesToLoopMap(
+              modifiedIndexingMaps,
+              ArrayRef<AffineMap>(indexingMaps).drop_front(argNum + 1)))
+        continue;
+
+      // Check if this operand is a duplicate.
+      Attribute indexingMap = genericOp.getTiedIndexingMapAttr(inputOpOperand);
+      auto it = dedupedInputs.find(
+          std::make_pair(inputOpOperand->get(), indexingMap));
+      if (it != dedupedInputs.end()) {
+        origToModifiedPos[argNum] = it->second;
+        continue;
+      }
+
+      // This is a preserved argument.
+      origToModifiedPos[argNum] = numNewArgs++;
+      dedupedInputs[{inputOpOperand->get(), indexingMap}] = argNum;
+      modifiedInputOperands.push_back(inputOpOperand->get());
+      modifiedIndexingMapsAttrs.push_back(indexingMap);
+      modifiedIndexingMaps.push_back(
+          indexingMap.cast<AffineMapAttr>().getValue());
     }
 
-    // If there are no duplicate args, then bail out.
-    if (canonicalInput.size() == genericOp.getNumInputs())
+    // If the op doesn't have tensor semantics, keep all the outputs as
+    // preserved.
+    if (!genericOp.hasTensorSemantics()) {
+      for (OpOperand *outputOpOperand : genericOp.getOutputOperands()) {
+        BlockArgument arg = genericOp.getTiedBlockArgument(outputOpOperand);
+        origToModifiedPos[arg.getArgNumber()] = numNewArgs++;
+        modifiedOutputOperands.push_back(outputOpOperand->get());
+        modifiedIndexingMapsAttrs.push_back(
+            genericOp.getTiedIndexingMapAttr(outputOpOperand));
+      }
+    } else {
+      // An output operand can be dropped if the corresponding result has no
+      // users and it is not used in the payload.
+      for (auto outputOpOperand :
+           llvm::enumerate(genericOp.getOutputOperands())) {
+        Value result = genericOp.getResult(outputOpOperand.index());
+        BlockArgument arg =
+            genericOp.getTiedBlockArgument(outputOpOperand.value());
+        if (result.use_empty() &&
+            !genericOp.payloadUsesValueFromOperand(outputOpOperand.value())) {
+          // Drop the output.
+          droppedOutputs.insert(outputOpOperand.index());
+          continue;
+        }
+
+        origToModifiedPos[arg.getArgNumber()] = numNewArgs++;
+        modifiedOutputOperands.push_back(outputOpOperand.value()->get());
+        modifiedIndexingMapsAttrs.push_back(
+            genericOp.getTiedIndexingMapAttr(outputOpOperand.value()));
+        modifiedResultTypes.push_back(result.getType());
+      }
+    }
+
+    if (modifiedInputOperands.size() + modifiedOutputOperands.size() ==
+        genericOp.getNumInputsAndOutputs()) {
+      // Nothing to do.
       return failure();
+    }
 
-    // The operands for the newly canonicalized op.
-    SmallVector<Value> newInputOperands;
-    for (OpOperand *opOperand : genericOp.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newInputOperands.push_back(opOperand->get());
-
-    // Repair the indexing maps by filtering out the ones that have been
-    // eliminated.
-    SmallVector<AffineMap> newIndexingMaps;
-    for (OpOperand *opOperand : genericOp.getInputOperands())
-      if (canonicalInputIndices[opOperand->getOperandNumber()] ==
-          opOperand->getOperandNumber())
-        newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-    for (OpOperand *opOperand : genericOp.getOutputOperands())
-      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-
-    // Clone the old op with new operands.
-    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
+    // Create the new op with the body being empty.
+    Location loc = genericOp.getLoc();
     auto newOp = rewriter.create<GenericOp>(
-        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
-        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
+        loc, modifiedResultTypes, modifiedInputOperands, modifiedOutputOperands,
+        rewriter.getArrayAttr(modifiedIndexingMapsAttrs),
         genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr());
-
+        genericOp.library_callAttr(),
+        [](OpBuilder & /*builder*/, Location /*loc*/, ValueRange /*args*/) {
+          return;
+        });
     // Copy over unknown attributes. They might be load bearing for some flow.
     ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
     for (NamedAttribute kv : genericOp->getAttrs()) {
@@ -756,27 +829,44 @@
       }
     }
 
-    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
-                                newOp.region().begin());
-
-    // Repair the payload entry block by RAUW'ing redundant arguments and
-    // erasing them.
-    Block &payload = newOp.region().front();
-    SmallVector<OpOperand *> inputOperands = genericOp.getInputOperands();
-    for (OpOperand *opOperand : llvm::reverse(inputOperands)) {
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      unsigned operandNumber = opOperand->getOperandNumber();
-      if (canonicalInputIndices[operandNumber] == operandNumber)
-        continue;
-      payload.getArgument(operandNumber)
-          .replaceAllUsesWith(
-              payload.getArgument(canonicalInputIndices[operandNumber]));
-      payload.eraseArgument(operandNumber);
+    // Merge the body of the original op with the new op.
+    Block *newOpBlock = &newOp.region().front();
+    Block *origOpBlock = &genericOp.region().front();
+    SmallVector<Value> replacements(origOpBlock->getNumArguments(), nullptr);
+    for (auto argNum : llvm::seq<unsigned>(0, origOpBlock->getNumArguments())) {
+      auto it = origToModifiedPos.find(argNum);
+      if (it != origToModifiedPos.end()) {
+        replacements[argNum] = newOpBlock->getArgument(it->second);
+      }
+    }
+    rewriter.mergeBlocks(origOpBlock, newOpBlock, replacements);
+
+    // Drop the unused yield args.
+    Block *block = &newOp.region().front();
+    if (!droppedOutputs.empty()) {
+      OpBuilder::InsertionGuard g(rewriter);
+      SmallVector<Value> newYieldVals;
+      YieldOp origYieldOp = cast<YieldOp>(block->getTerminator());
+      rewriter.setInsertionPoint(origYieldOp);
+      SetVector<Operation *> deadOps;
+      auto filter = [&block](Operation *op) { return op->getBlock() == block; };
+      for (auto yieldOpOperands : llvm::enumerate(origYieldOp.values())) {
+        if (!droppedOutputs.count(yieldOpOperands.index())) {
+          newYieldVals.push_back(yieldOpOperands.value());
+          continue;
+        }
+      }
+      rewriter.replaceOpWithNewOp<YieldOp>(origYieldOp, newYieldVals);
     }
 
-    rewriter.replaceOp(genericOp, newOp->getResults());
+    // Replace all live uses of the op.
+    unsigned newResultNum = 0;
+    for (auto result : llvm::enumerate(genericOp.getResults())) {
+      if (droppedOutputs.count(result.index()))
+        continue;
+      result.value().replaceAllUsesWith(newOp.getResult(newResultNum++));
+    }
+    rewriter.eraseOp(genericOp);
     return success();
   }
 };
@@ -853,72 +943,13 @@
     return success();
   }
 };
-
-/// Drop dead args of a linalg generic op.
-/// An arg is dead if it has zero uses in the op region.
-struct DeadArgsGenericOpInputs : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap> oldIndexingMaps = genericOp.getIndexingMaps();
-    // Maps must be projected permutations.
-    if (llvm::any_of(genericOp.getIndexingMaps(), [](AffineMap map) {
-          return !map.isProjectedPermutation();
-        }))
-      return failure();
-    Block &payload = genericOp.region().front();
-    SmallVector<Value> newInputOperands;
-    SmallVector<AffineMap> newIndexingMaps;
-    bool deadArgFound = false;
-    int inputSize = genericOp.getInputOperands().size();
-    for (int i = inputSize - 1; i >= 0; i--) {
-      OpOperand *opOperand = genericOp.getInputOperand(i);
-      // Iterate in reverse, so that we erase later args first, preventing the
-      // argument list from shifting unexpectedly and invalidating all our
-      // indices.
-      if (payload.getArgument(i).use_empty() &&
-          !hasaUniqueDim(oldIndexingMaps, i)) {
-        payload.eraseArgument(i);
-        deadArgFound = true;
-        // remove this indexing map out of consideration for hasaUniqueDim check
-        oldIndexingMaps.erase(oldIndexingMaps.begin() + i);
-      } else {
-        newInputOperands.insert(newInputOperands.begin(), opOperand->get());
-        newIndexingMaps.insert(newIndexingMaps.begin(),
-                               genericOp.getTiedIndexingMap(opOperand));
-      }
-    }
-    // Bail out if there are no dead args.
-    if (!deadArgFound)
-      return failure();
-    for (OpOperand *opOperand : genericOp.getOutputOperands())
-      newIndexingMaps.push_back(genericOp.getTiedIndexingMap(opOperand));
-    SmallVector<Value> outputOperands = genericOp.getOutputOperands();
-
-    auto newOp = rewriter.create<GenericOp>(
-        genericOp.getLoc(), genericOp->getResultTypes(), newInputOperands,
-        outputOperands, rewriter.getAffineMapArrayAttr(newIndexingMaps),
-        genericOp.iterator_types(), genericOp.docAttr(),
-        genericOp.library_callAttr());
-    // Copy over unknown attributes. They might be load bearing for some flow.
-    ArrayRef<StringRef> odsAttrs = genericOp.getAttributeNames();
-    for (NamedAttribute kv : genericOp->getAttrs()) {
-      if (!llvm::is_contained(odsAttrs, kv.getName().getValue())) {
-        newOp->setAttr(kv.getName(), kv.getValue());
-      }
-    }
-    rewriter.inlineRegionBefore(genericOp.region(), newOp.region(),
-                                newOp.region().begin());
-    rewriter.replaceOp(genericOp, newOp->getResults());
-    return success();
-  }
-};
 } // namespace
 
 void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
-  results.add<DeadArgsGenericOpInputs, DeduplicateGenericOpInputs,
-              EraseIdentityGenericOp>(context);
+  results
+      .add<DeduplicateAndRemoveDeadOperandsAndResults, EraseIdentityGenericOp>(
+          context);
 }
 
 LogicalResult GenericOp::fold(ArrayRef<Attribute>,
diff --git a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize-duplicate-inputs.mlir
@@ -90,3 +90,95 @@
   } -> tensor
   return %0 : tensor
 }
+
+// -----
+
+// Drop dead result.
+
+#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+#map1 = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+#map2 = affine_map<(d0, d1, d2) -> (d1, d2, d0)>
+#map3 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d0, d1)>
+func @drop_dead_results(%arg0 : tensor<?x?x?xf32>) -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+  %0:4 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map2, #map3, #map4],
+      iterator_types = ["parallel", "parallel", "parallel"]}
+      ins(%arg0 : tensor<?x?x?xf32>)
+      outs(%arg0, %arg0, %arg0, %arg0
+          : tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>) {
+    ^bb0(%b0 : f32, %b1 : f32, %b2 : f32, %b3 : f32, %b4 : f32) :
+      %1 = arith.addf %b0, %b0: f32
+      linalg.yield %1, %1, %1, %1 : f32, f32, f32, f32
+  } -> (tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>, tensor<?x?x?xf32>)
+  return %0#0, %0#2 : tensor<?x?x?xf32>, tensor<?x?x?xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2) -> (d0, d2, d1)>
+// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
+// CHECK: func @drop_dead_results(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32>)
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]]
+// CHECK-SAME: outs(%[[ARG0]], %[[ARG0]] :
+// CHECK: return %[[GENERIC]]#0, %[[GENERIC]]#1
+
+// -----
+
+// Current argmax lowering to `linalg.generic`. Cannot drop the
+// first result even though it isn't used, since it has an internal
+// use.
+#map0 = affine_map<(d0) -> (d0)>
+#map1 = affine_map<(d0) -> ()>
+func @argmax_lowering(%arg0 : tensor<?xf32>) -> tensor<i32> {
+  %init0 = linalg.init_tensor [] : tensor<f32>
+  %init1 = linalg.init_tensor [] : tensor<i32>
+  %0:2 = linalg.generic {
+      indexing_maps = [#map0, #map1, #map1],
+      iterator_types = ["reduction"]}
+      ins(%arg0 : tensor<?xf32>)
+      outs(%init0, %init1 : tensor<f32>, tensor<i32>) {
+    ^bb0(%b0: f32, %b1: f32, %b2: i32):
+      %8 = linalg.index 0 : index
+      %9 = arith.index_cast %8 : index to i32
+      %10 = arith.cmpf oge, %b0, %b1 : f32
+      %11 = arith.select %10, %b0, %b1 : f32
+      %12 = arith.cmpf oeq, %b0, %b1 : f32
+      %13 = arith.minsi %9, %b2 : i32
+      %14 = arith.select %10, %9, %b2 : i32
+      %15 = arith.select %12, %13, %14 : i32
+      linalg.yield %11, %15 : f32, i32
+  } -> (tensor<f32>, tensor<i32>)
+  return %0#1 : tensor<i32>
+}
+// CHECK: func @argmax_lowering(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK-DAG: %[[INIT0:.+]] = linalg.init_tensor [] : tensor<f32>
+// CHECK-DAG: %[[INIT1:.+]] = linalg.init_tensor [] : tensor<i32>
+// CHECK: %[[GENERIC:.+]]:2 = linalg.generic
+// CHECK-SAME: outs(%[[INIT0]], %[[INIT1]] :
+// CHECK: return %[[GENERIC]]#1
+
+// -----
+
+// Do not remove operand needed for loop dim.
+func @loop_dim_operand(%arg0 : tensor<?xf32>) -> tensor<i32> {
+  %cst = arith.constant 0 : i32
+  %init = linalg.init_tensor [] : tensor<i32>
+  %fill = linalg.fill ins(%cst : i32) outs(%init : tensor<i32>) -> tensor<i32>
+  %0 = linalg.generic {
+      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
+      iterator_types = ["reduction"]}
+      ins(%arg0 : tensor<?xf32>) outs(%fill : tensor<i32>) {
+    ^bb0(%b0: f32, %b1: i32):
+      %1 = linalg.index 0 : index
+      %2 = arith.index_cast %1 : index to i32
+      %3 = arith.addi %b1, %2 : i32
+      linalg.yield %3 : i32
+  } -> tensor<i32>
+  return %0 : tensor<i32>
+}
+// CHECK: func @loop_dim_operand(
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?xf32>
+// CHECK: linalg.generic
+// CHECK-SAME: ins(%[[ARG0]] :
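
Note: the combined pattern above is registered through GenericOp::getCanonicalizationPatterns, so the tests exercise it via the standard canonicalizer (e.g. mlir-opt -canonicalize). As a minimal sketch of driving the same patterns programmatically, assuming the greedy pattern rewrite driver; the helper name cleanupGenericOps is hypothetical and not part of this change:

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the linalg.generic canonicalization patterns (which now include
// DeduplicateAndRemoveDeadOperandsAndResults) and apply them greedily to the
// ops nested under `root`.
static LogicalResult cleanupGenericOps(Operation *root) {
  MLIRContext *ctx = root->getContext();
  RewritePatternSet patterns(ctx);
  linalg::GenericOp::getCanonicalizationPatterns(patterns, ctx);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}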