diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -461,6 +461,10 @@
 static ArrayAttr collapseReassociationMaps(ArrayRef<AffineMap> mapsProducer,
                                            ArrayRef<AffineMap> mapsConsumer,
                                            MLIRContext *context) {
+  // Handle the corner case of the result being a rank 0 shaped type. Return an
+  // empty ArrayAttr.
+  if (mapsConsumer.empty() && !mapsProducer.empty())
+    return ArrayAttr::get(ArrayRef<Attribute>(), context);
   if (mapsProducer.empty() || mapsConsumer.empty() ||
       mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() ||
       mapsProducer.size() != mapsConsumer[0].getNumDims())
@@ -500,8 +504,7 @@
                                    ShapedType intermediateType,
                                    ShapedType smallerType) -> bool {
     return largerType.getRank() > intermediateType.getRank() &&
-           intermediateType.getRank() > smallerType.getRank() &&
-           smallerType.getRank() > 0;
+           intermediateType.getRank() > smallerType.getRank();
   };
   // Check if producer and consumer are both expanding dims.
   if (areReshapeOpsFoldable(reshapeOp.getResultType(), reshapeOp.getSrcType(),
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -26,6 +26,8 @@
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
+#include <set>
+
 #define DEBUG_TYPE "linalg-drop-unit-dims"
 
 using namespace mlir;
@@ -145,15 +147,42 @@
                        context);
 }
 
+/// Modify the region of indexed generic op to drop arguments corresponding to
+/// loops that are unit trip count.
+template <typename OpTy>
+static LogicalResult
+replaceBlockArgForUnitDimLoops(OpTy op, const DenseSet<unsigned> &unitDims,
+                               PatternRewriter &rewriter) {
+  return success();
+}
+
+template <>
+LogicalResult replaceBlockArgForUnitDimLoops<IndexedGenericOp>(
+    IndexedGenericOp op, const DenseSet<unsigned> &unitDims,
+    PatternRewriter &rewriter) {
+  OpBuilder::InsertionGuard guard(rewriter);
+  Block *entryBlock = &op.getOperation()->getRegion(0).front();
+  rewriter.setInsertionPointToStart(entryBlock);
+  Value zero = rewriter.create<ConstantIndexOp>(op.getLoc(), 0);
+  for (unsigned unitDimLoop : unitDims) {
+    entryBlock->getArgument(unitDimLoop).replaceAllUsesWith(zero);
+  }
+  // Erase the unit-dim arguments from the highest index down so that the
+  // positions of the remaining arguments stay valid while erasing.
+  std::set<unsigned> orderedUnitDims(unitDims.begin(), unitDims.end());
+  for (unsigned i : llvm::reverse(orderedUnitDims))
+    entryBlock->eraseArgument(i);
+  return success();
+}
+
 namespace {
 /// Pattern to fold unit-trip count loops in GenericOps.
-// TODO: Generalize this to indexed-generic as well by modifying the region args
-// as well.
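+// The pattern below is templated on the op class so that the same rewrite
+// covers both linalg.generic and linalg.indexed_generic; for the indexed
+// form, the replaceBlockArgForUnitDimLoops specialization above also rewires
+// the index block arguments of the region.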
-struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+template <typename GenericOpTy>
+struct FoldUnitDimLoops : public OpRewritePattern<GenericOpTy> {
+  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
-    SmallVector<AffineMap, 4> indexingMaps = genericOp.getIndexingMaps();
+    SmallVector<AffineMap, 4> indexingMaps = op.getIndexingMaps();
     if (indexingMaps.empty())
       return failure();
@@ -164,10 +193,10 @@
     if (!invertedMap)
       return failure();
     SmallVector<int64_t, 4> dims;
-    for (ShapedType shapedType : genericOp.getInputOutputShapedTypes())
+    for (ShapedType shapedType : op.getInputOutputShapedTypes())
       dims.append(shapedType.getShape().begin(), shapedType.getShape().end());
     DenseSet<unsigned> unitDims;
-    ArrayAttr iteratorTypes = genericOp.iterator_types();
+    ArrayAttr iteratorTypes = op.iterator_types();
     for (auto expr : enumerate(invertedMap.getResults())) {
       if (AffineDimExpr dimExpr = expr.value().dyn_cast<AffineDimExpr>())
         if (dims[dimExpr.getPosition()] == 1 &&
@@ -183,7 +212,7 @@
     ArrayAttr newIndexingMapAttr =
         replaceUnitDims(unitDims, indexingMaps, context);
     if (!newIndexingMapAttr)
-      return genericOp.emitError("unable to compute modified indexing_maps");
+      return op.emitError("unable to compute modified indexing_maps");
 
     // Compute the iterator types of the modified op by dropping the one-trip
     // count loops.
@@ -193,10 +222,11 @@
         newIteratorTypes.push_back(attr.value());
     }
 
-    rewriter.startRootUpdate(genericOp);
-    genericOp.indexing_mapsAttr(newIndexingMapAttr);
-    genericOp.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
-    rewriter.finalizeRootUpdate(genericOp);
+    rewriter.startRootUpdate(op);
+    op.indexing_mapsAttr(newIndexingMapAttr);
+    op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context));
+    replaceBlockArgForUnitDimLoops(op, unitDims, rewriter);
+    rewriter.finalizeRootUpdate(op);
     return success();
   }
 };
@@ -263,25 +293,27 @@
 
 namespace {
 /// Pattern to replace tensors operands/results that are unit extents.
-struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
-  using OpRewritePattern<GenericOp>::OpRewritePattern;
-  LogicalResult matchAndRewrite(GenericOp genericOp,
+template <typename GenericOpTy>
+struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOpTy> {
+  using OpRewritePattern<GenericOpTy>::OpRewritePattern;
+  LogicalResult matchAndRewrite(GenericOpTy op,
                                 PatternRewriter &rewriter) const override {
     // TODO: support init_tensors and reductions.
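+    // The rewrite proceeds in three steps: compute the unit-extent-free types
+    // and indexing maps, reshape the operands into those types, and rebuild
+    // the op so that its results can be reshaped back to the original types.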
-    if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
+    if (!op.hasTensorSemantics() || !op.init_tensors().empty())
       return failure();
 
     MLIRContext *context = rewriter.getContext();
-    Location loc = genericOp.getLoc();
+    Location loc = op.getLoc();
     SmallVector<AffineMap, 4> newIndexingMaps;
     SmallVector<ArrayAttr, 4> reassociationMaps;
     SmallVector<ShapedType, 4> newInputOutputTypes;
     bool doCanonicalization = false;
-    for (auto it : llvm::zip(genericOp.getIndexingMaps(),
-                             genericOp.getInputOutputShapedTypes())) {
+    for (auto it :
+         llvm::zip(op.getIndexingMaps(), op.getInputOutputShapedTypes())) {
       auto replacementInfo = replaceUnitExtents(
-          std::get<0>(it), std::get<1>(it).cast<RankedTensorType>(), context);
+          std::get<0>(it), std::get<1>(it).template cast<RankedTensorType>(),
+          context);
       reassociationMaps.push_back(replacementInfo.reassociation);
       newIndexingMaps.push_back(replacementInfo.indexMap);
       newInputOutputTypes.push_back(replacementInfo.type);
@@ -313,24 +345,23 @@
       return res;
     };
 
-    SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
+    SmallVector<Value, 4> newInputs = insertReshapes(op.inputs());
     SmallVector<Value, 4> newOutputBuffers =
-        insertReshapes(genericOp.output_buffers());
-    SmallVector<Value, 4> newInitTensors =
-        insertReshapes(genericOp.init_tensors());
+        insertReshapes(op.output_buffers());
+    SmallVector<Value, 4> newInitTensors = insertReshapes(op.init_tensors());
 
     // If any result type change, insert a reshape to convert from the original
     // type to the new type.
     SmallVector<Type, 4> resultTypes;
-    resultTypes.reserve(genericOp.getNumResults());
-    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
-      resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
-    GenericOp replacementOp = rewriter.create<GenericOp>(
+    resultTypes.reserve(op.getNumResults());
+    for (unsigned i : llvm::seq<unsigned>(0, op.getNumResults()))
+      resultTypes.push_back(newInputOutputTypes[i + op.getNumInputs()]);
+    GenericOpTy replacementOp = rewriter.create<GenericOpTy>(
         loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
         newIndexingMaps,
         llvm::to_vector<4>(
-            genericOp.iterator_types().getAsValueRange<StringAttr>()));
-    rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
+            op.iterator_types().template getAsValueRange<StringAttr>()));
+    rewriter.inlineRegionBefore(op.region(), replacementOp.region(),
                                 replacementOp.region().begin());
 
     // If any result tensor has a modified shape, then add reshape to recover
@@ -338,16 +369,16 @@
     SmallVector<Value, 4> resultReplacements;
     for (auto result : llvm::enumerate(replacementOp.getResults())) {
       unsigned index = result.index() + replacementOp.getNumOperands();
-      RankedTensorType origResultType = genericOp.getResult(result.index())
+      RankedTensorType origResultType = op.getResult(result.index())
                                             .getType()
-                                            .cast<RankedTensorType>();
+                                            .template cast<RankedTensorType>();
       if (origResultType != result.value().getType())
         resultReplacements.push_back(rewriter.create<TensorReshapeOp>(
             loc, origResultType, result.value(), reassociationMaps[index]));
       else
         resultReplacements.push_back(result.value());
     }
-    rewriter.replaceOp(genericOp, resultReplacements);
+    rewriter.replaceOp(op, resultReplacements);
     return success();
   }
 };
@@ -467,7 +498,10 @@
 /// broadcasting.
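+// Both the GenericOp and IndexedGenericOp instantiations of the patterns are
+// registered so that unit-extent folding applies uniformly to either op kind.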
 void mlir::populateLinalgFoldUnitExtentDimsPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns) {
-  patterns.insert<FoldUnitDimLoops, ReplaceUnitExtentTensors>(context);
+  patterns
+      .insert<FoldUnitDimLoops<GenericOp>, FoldUnitDimLoops<IndexedGenericOp>,
+              ReplaceUnitExtentTensors<GenericOp>,
+              ReplaceUnitExtentTensors<IndexedGenericOp>>(context);
   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
   patterns.insert<FoldReshapeOpWithUnitExtent>(context);
 }
@@ -481,7 +515,8 @@
   FuncOp funcOp = getFunction();
   MLIRContext *context = funcOp.getContext();
   if (foldOneTripLoopsOnly)
-    patterns.insert<FoldUnitDimLoops>(context);
+    patterns.insert<FoldUnitDimLoops<GenericOp>,
+                    FoldUnitDimLoops<IndexedGenericOp>>(context);
   else
     populateLinalgFoldUnitExtentDimsPatterns(context, patterns);
   applyPatternsAndFoldGreedily(funcOp.getBody(), patterns);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -109,13 +109,19 @@
   // consumer's operand.
   // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
   // generic op. In this case, there are no indices in block arguments.
-  unsigned numProducerIndices =
-      isa<IndexedGenericOp>(producer.getOperation()) ? nloops : 0;
-  unsigned numConsumerIndices =
-      isa<IndexedGenericOp>(consumer.getOperation()) ? nloops : 0;
+  unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
+                                    ? producer.getNumLoops()
+                                    : 0;
+  unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
+                                    ? consumer.getNumLoops()
+                                    : 0;
+  unsigned numFusedOpIndices =
+      (isa<IndexedGenericOp>(producer.getOperation()) ||
+       isa<IndexedGenericOp>(consumer.getOperation()))
+          ? std::max(producer.getNumLoops(), consumer.getNumLoops())
+          : 0;
   // Firstly, add all the indices to the block arguments.
-  for (unsigned i = 0, e = std::max(numProducerIndices, numConsumerIndices);
-       i < e; ++i)
+  for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
     fusedBlock->addArgument(rewriter.getIndexType());
   // Map the arguments for the unmodified args from the consumer.
   for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
@@ -129,7 +135,7 @@
       auto newIndex = rewriter.create<AffineApplyOp>(
           producer.getLoc(),
           consumerToProducerLoopsMap.getSubMap(producerArg.index()),
-          fusedBlock->getArguments().take_front(nloops));
+          fusedBlock->getArguments().take_front(numFusedOpIndices));
       mapper.map(producerArg.value(), newIndex);
     } else {
       mapper.map(producerArg.value(),
diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir
--- a/mlir/test/Dialect/Linalg/canonicalize.mlir
+++ b/mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -43,6 +43,34 @@
 
 // -----
 
+// -----
+
+func @collapsing_tensor_reshapes_to_zero_dim(%arg0 : tensor<1x1x1xf32>)
+                                             -> tensor<f32> {
+  %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       tensor<1x1x1xf32> into tensor<1xf32>
+  %1 = linalg.tensor_reshape %0 [] : tensor<1xf32> into tensor<f32>
+  return %1 : tensor<f32>
+}
+// CHECK-LABEL: collapsing_tensor_reshapes_to_zero
+// CHECK: linalg.tensor_reshape %{{.*}} []
+// CHECK-SAME: tensor<1x1x1xf32> into tensor<f32>
+
+// -----
+
+func @collapsing_memref_reshapes_to_zero_dim(%arg0 : memref<1x1x1xf32>)
+                                             -> memref<f32> {
+  %0 = linalg.reshape %arg0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       memref<1x1x1xf32> into memref<1xf32>
+  %1 = linalg.reshape %0 [] : memref<1xf32> into memref<f32>
+  return %1 : memref<f32>
+}
+// CHECK-LABEL: collapsing_memref_reshapes_to_zero
+// CHECK: linalg.reshape %{{.*}} []
+// CHECK-SAME: memref<1x1x1xf32> into memref<f32>
+
+// -----
+
 func @expanding_tensor_reshapes(%arg0 : tensor<?x?xf32>)
     -> tensor<?x?x?x?x?xf32> {
   %0 = linalg.tensor_reshape %arg0
@@ -106,6 +134,33 @@
 
 // -----
 
+func @expanding_tensor_reshapes_to_zero_dim(%arg0 : tensor<f32>)
+                                            -> tensor<1x1x1xf32> {
+  %0 = linalg.tensor_reshape %arg0 [] : tensor<f32> into tensor<1xf32>
+  %1 = linalg.tensor_reshape %0 [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       tensor<1xf32> into tensor<1x1x1xf32>
+  return %1 : tensor<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_tensor_reshapes_to_zero
+// CHECK: linalg.tensor_reshape %{{.*}} []
+// CHECK-SAME: tensor<f32> into tensor<1x1x1xf32>
+
+// -----
+
+func @expanding_memref_reshapes_to_zero_dim(%arg0 : memref<f32>)
+                                            -> memref<1x1x1xf32> {
+  %0 = linalg.reshape %arg0 [] : memref<f32> into memref<1xf32>
+  %1 = linalg.reshape %0
+       [affine_map<(d0, d1, d2) -> (d0, d1, d2)>] :
+       memref<1xf32> into memref<1x1x1xf32>
+  return %1 : memref<1x1x1xf32>
+}
+// CHECK-LABEL: expanding_memref_reshapes_to_zero
+// CHECK: linalg.reshape %{{.*}} []
+// CHECK-SAME: memref<f32> into memref<1x1x1xf32>
+
+// -----
+
 func @fold_tensor_reshape(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32>
 {
   %0 = linalg.tensor_reshape %arg0
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -36,6 +36,47 @@
 
 // -----
 
+#accesses = [
+  affine_map<(i, j, k, l, m) -> (i, k, m)>,
+  affine_map<(i, j, k, l, m) -> (i, k, j, l, m)>
+]
+
+#trait = {
+  iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
+  indexing_maps = #accesses,
+  library_call = "some_external_func"
+}
+
+func @drop_one_trip_loops_indexed_generic
+  (%arg0 : tensor<?x1x?xi32>) -> tensor<?x1x?x1x?xi32>
+{
+  %0 = linalg.indexed_generic #trait
+    ins(%arg0 : tensor<?x1x?xi32>) {
+      ^bb0(%arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index,
+           %arg5 : index, %arg6 : i32) :
+        %1 = addi %arg1, %arg2 : index
+        %2 = addi %1, %arg3 : index
+        %3 = addi %2, %arg4 : index
+        %4 = addi %3, %arg5 : index
+        %5 = index_cast %4 : index to i32
+        %6 = addi %5, %arg6 : i32
+        linalg.yield %6 : i32
+    } -> tensor<?x1x?x1x?xi32>
+  return %0 : tensor<?x1x?x1x?xi32>
+}
+// CHECK-LABEL: func @drop_one_trip_loops_indexed_generic
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.+}}(
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index, %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index, %[[ARG4:[a-zA-Z0-9]+]]: i32)
+// CHECK: %[[T3:.+]] = addi %[[ARG1]], %[[ARG2]]
+// CHECK: %[[T4:.+]] = addi %[[T3]], %[[ARG3]]
+// CHECK: %[[T5:.+]] = index_cast %[[T4]] : index to i32
+// CHECK: %[[T6:.+]] = addi %[[T5]], %[[ARG4]] : i32
+// CHECK: linalg.yield %[[T6]] : i32
+
+// -----
+
 #map0 = affine_map<(i, j) -> (i, j)>
 #access = [#map0, #map0]
 #trait = {
@@ -62,6 +103,35 @@
 
 // -----
 
+#map0 = affine_map<(i, j) -> (i, j)>
+#access = [#map0, #map0]
+#trait = {
+  iterator_types = ["parallel", "parallel"],
+  indexing_maps = #access,
+  library_call = "some_external_func"
+}
+
+func @drop_all_loops_indexed_generic
+  (%arg0 : tensor<1x1xi32>) -> tensor<1x1xi32>
+{
+  %0 = linalg.indexed_generic #trait
+    ins(%arg0 : tensor<1x1xi32>) {
+      ^bb0(%arg1 : index, %arg2 : index, %arg3: i32) :
+        %1 = addi %arg1, %arg2 : index
+        %2 = index_cast %1 : index to i32
+        %3 = addi %2, %arg3 : i32
+        linalg.yield %3 : i32
+    } -> tensor<1x1xi32>
+  return %0 : tensor<1x1xi32>
+}
+
+// CHECK-LABEL: func @drop_all_loops_indexed_generic
+// CHECK: linalg.indexed_generic
+// CHECK: ^{{.+}}(%[[ARG1:.+]]: i32)
+// CHECK: linalg.yield %[[ARG1]] : i32
+
+// -----
+
 #accesses = [
   affine_map<(d0) -> (0, d0)>,
   affine_map<(d0) -> (d0)>
diff --git a/mlir/test/Dialect/Linalg/fusion-tensor.mlir b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
--- a/mlir/test/Dialect/Linalg/fusion-tensor.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-tensor.mlir
@@ -381,3 +381,43 @@
 // CHECK: %[[VAL4:.+]] = subi %[[VAL3]], %[[SUB_OPERAND2]] : i32
 // CHECK: linalg.yield %[[VAL4]] : i32
 // CHECK-NOT: linalg.indexed_generic
+
+// -----
+
+func @scalar_indexed_generic_fusion
+  (%arg0: tensor<5x1x1xf32>, %arg1 : tensor<i32>) -> tensor<10xf32>
+{
+  %c0 = constant 0 : index
+  %cst = constant dense<1.000000e+00> : tensor<10xf32>
+  %0 = linalg.indexed_generic
+    {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
+     iterator_types = []}
+    ins(%arg1 : tensor<i32>) {
+    ^bb0(%arg2: i32):  // no predecessors
+      %3 = index_cast %arg2 : i32 to index
+      %4 = extract_element %arg0[%3, %c0, %c0] : tensor<5x1x1xf32>
+      linalg.yield %4 : f32
+    } -> tensor<f32>
+  %1 = linalg.generic
+    {indexing_maps = [affine_map<(d0) -> ()>, affine_map<(d0) -> (d0)>,
+                      affine_map<(d0) -> (d0)>],
+     iterator_types = ["parallel"]}
+    ins(%0, %cst : tensor<f32>, tensor<10xf32>) {
+    ^bb0(%arg2: f32, %arg3: f32):  // no predecessors
+      %3 = mulf %arg2, %arg3 : f32
+      linalg.yield %3 : f32
+    } -> tensor<10xf32>
+  return %1 : tensor<10xf32>
+}
+// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0) -> ()>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0) -> (d0)>
+// CHECK: func @scalar_indexed_generic_fusion
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<5x1x1xf32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<i32>
+// CHECK: %[[T0:.+]] = linalg.indexed_generic
+// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]]]
+// CHECK-SAME: iterator_types = ["parallel"]
+// CHECK-SAME: ins(%[[ARG1]] : tensor<i32>)
+// CHECK: extract_element %[[ARG0]]
+// CHECK: linalg.yield
+// CHECK: return %[[T0]]
\ No newline at end of file