Index: mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
===================================================================
--- mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -179,7 +179,120 @@
   }
 }
 
+/// Expand the given value based on reassociation.
+static Value expandValue(Value result, Value origOutput,
+                         ArrayRef<ReassociationIndices> reassociation,
+                         RankReductionStrategy rankReductionStrategy,
+                         Location loc, PatternRewriter &rewriter) {
+  // There are no results for memref outputs.
+  auto origResultType = origOutput.getType().cast<RankedTensorType>();
+  if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+    unsigned rank = origResultType.getRank();
+    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
+    SmallVector<OpFoldResult> sizes =
+        tensor::getMixedSizes(rewriter, loc, origOutput);
+    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
+    return rewriter.createOrFold<tensor::InsertSliceOp>(
+        loc, result, origOutput, offsets, sizes, strides);
+  }
+
+  assert(rankReductionStrategy == RankReductionStrategy::ReassociativeReshape &&
+         "unknown rank reduction strategy");
+  return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
+                                                reassociation);
+}
+
+/// Collapse the given value based on reassociation.
+static Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
+                           ArrayRef<ReassociationIndices> reassociation,
+                           RankReductionStrategy rankReductionStrategy,
+                           Location loc, PatternRewriter &rewriter) {
+  if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
+    if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+      FailureOr<Value> rankReducingExtract =
+          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
+                                                targetShape);
+      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+      return *rankReducingExtract;
+    }
+
+    assert(rankReductionStrategy ==
+               RankReductionStrategy::ReassociativeReshape &&
+           "unknown rank reduction strategy");
+    MemRefLayoutAttrInterface layout;
+    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
+                                      layout, memrefType.getMemorySpace());
+    return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
+                                                    reassociation);
+  }
+  if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
+    if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
+      FailureOr<Value> rankReducingExtract =
+          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
+                                                     targetShape);
+      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
+      return *rankReducingExtract;
+    }
+
+    assert(rankReductionStrategy ==
+               RankReductionStrategy::ReassociativeReshape &&
+           "unknown rank reduction strategy");
+    auto targetType =
+        RankedTensorType::get(targetShape, tensorType.getElementType());
+    return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
+                                                    reassociation);
+  }
+  llvm_unreachable("unsupported operand type");
+}
+
 namespace {
+/// Pattern to convert a linalg.batch_matmul with a unit batch dimension into
+/// a linalg.matmul.
+struct FoldUnitBatchMatmulToMatmul : public OpRewritePattern<BatchMatmulOp> {
+  FoldUnitBatchMatmulToMatmul(MLIRContext *ctx,
+                              RankReductionStrategy rankReductionStrategy)
+      : OpRewritePattern<BatchMatmulOp>(ctx),
+        rankReductionStrategy(rankReductionStrategy) {}
+
+  LogicalResult matchAndRewrite(BatchMatmulOp batchMatmul,
+                                PatternRewriter &rewriter) const override {
+    // Only handle a statically known unit batch dimension.
+    SmallVector<int64_t> dims = batchMatmul.getStaticShape();
+    if (dims[0] != 1)
+      return failure();
+    SmallVector<ReassociationIndices> reassoc = {ReassociationIndices({0, 1}),
+                                                 ReassociationIndices({2})};
+    Location loc = batchMatmul.getLoc();
+    Value lhs = batchMatmul.getDpsInputOperand(0)->get();
+    Value rhs = batchMatmul.getDpsInputOperand(1)->get();
+    Value acc = batchMatmul.getDpsInitOperand(0)->get();
+    lhs = collapseValue(
+        lhs, lhs.getType().cast<ShapedType>().getShape().take_back(2), reassoc,
+        rankReductionStrategy, loc, rewriter);
+    rhs = collapseValue(
+        rhs, rhs.getType().cast<ShapedType>().getShape().take_back(2), reassoc,
+        rankReductionStrategy, loc, rewriter);
+    acc = collapseValue(
+        acc, acc.getType().cast<ShapedType>().getShape().take_back(2), reassoc,
+        rankReductionStrategy, loc, rewriter);
+    SmallVector<Type, 1> resultType;
+    bool hasTensorResult = batchMatmul.getNumResults() > 0;
+    if (hasTensorResult)
+      resultType.push_back(acc.getType());
+    auto matmul = rewriter.create<MatmulOp>(loc, resultType,
+                                            ValueRange{lhs, rhs}, acc);
+    if (hasTensorResult) {
+      Value expand = expandValue(matmul.getResult(0),
+                                 batchMatmul.getDpsInitOperand(0)->get(),
+                                 reassoc, rankReductionStrategy, loc, rewriter);
+      rewriter.replaceOp(batchMatmul, expand);
+    } else {
+      rewriter.eraseOp(batchMatmul);
+    }
+    return success();
+  }
+
+private:
+  RankReductionStrategy rankReductionStrategy;
+};
+
 /// Pattern to fold unit-trip count loops in GenericOps.
 struct FoldUnitDimLoops : public OpRewritePattern<GenericOp> {
   using OpRewritePattern<GenericOp>::OpRewritePattern;
@@ -432,72 +545,6 @@
       : OpRewritePattern<GenericOp>(ctx),
         rankReductionStrategy(rankReductionStrategy) {}
 
-  // Expand the given value.
-  Value expandValue(Value result, Value origOutput,
-                    ArrayRef<ReassociationIndices> reassociation, Location loc,
-                    PatternRewriter &rewriter) const {
-    // There are no results for memref outputs.
-    auto origResultType = origOutput.getType().cast<RankedTensorType>();
-    if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
-      unsigned rank = origResultType.getRank();
-      SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
-      SmallVector<OpFoldResult> sizes =
-          tensor::getMixedSizes(rewriter, loc, origOutput);
-      SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
-      return rewriter.createOrFold<tensor::InsertSliceOp>(
-          loc, result, origOutput, offsets, sizes, strides);
-    }
-
-    assert(rankReductionStrategy ==
-               RankReductionStrategy::ReassociativeReshape &&
-           "unknown rank reduction strategy");
-    return rewriter.create<tensor::ExpandShapeOp>(loc, origResultType, result,
-                                                  reassociation);
-  }
-
-  // Collapse the given value.
-  Value collapseValue(Value operand, ArrayRef<int64_t> targetShape,
-                      ArrayRef<ReassociationIndices> reassociation,
-                      Location loc, PatternRewriter &rewriter) const {
-    if (auto memrefType = operand.getType().dyn_cast<MemRefType>()) {
-      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
-        FailureOr<Value> rankReducingExtract =
-            memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
-                                                  targetShape);
-        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
-        return *rankReducingExtract;
-      }
-
-      assert(rankReductionStrategy ==
-                 RankReductionStrategy::ReassociativeReshape &&
-             "unknown rank reduction strategy");
-      MemRefLayoutAttrInterface layout;
-      auto targetType =
-          MemRefType::get(targetShape, memrefType.getElementType(), layout,
-                          memrefType.getMemorySpace());
-      return rewriter.create<memref::CollapseShapeOp>(loc, targetType, operand,
-                                                      reassociation);
-    }
-    if (auto tensorType = operand.getType().dyn_cast<RankedTensorType>()) {
-      if (rankReductionStrategy == RankReductionStrategy::ExtractInsertSlice) {
-        FailureOr<Value> rankReducingExtract =
-            tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
-                                                       targetShape);
-        assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
-        return *rankReducingExtract;
-      }
-
-      assert(rankReductionStrategy ==
-                 RankReductionStrategy::ReassociativeReshape &&
-             "unknown rank reduction strategy");
-      auto targetType =
-          RankedTensorType::get(targetShape, tensorType.getElementType());
-      return rewriter.create<tensor::CollapseShapeOp>(loc, targetType, operand,
-                                                      reassociation);
-    }
-    llvm_unreachable("unsupported operand type");
-  }
-
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
     // Skip the pattern if the op has any tensor with special encoding.
@@ -547,8 +594,9 @@
         newOperands.push_back(opOperand.get());
         continue;
       }
-      newOperands.push_back(collapseValue(opOperand.get(), targetShapes[idx],
-                                          reassociations[idx], loc, rewriter));
+      newOperands.push_back(
+          collapseValue(opOperand.get(), targetShapes[idx], reassociations[idx],
+                        rankReductionStrategy, loc, rewriter));
     }
 
     // If any result type changes, insert a reshape to convert from the original
@@ -578,8 +626,9 @@
         resultReplacements.push_back(result.value());
         continue;
       }
-      resultReplacements.push_back(expandValue(
-          result.value(), origOutput, reassociations[index], loc, rewriter));
+      resultReplacements.push_back(
+          expandValue(result.value(), origOutput, reassociations[index],
+                      rankReductionStrategy, loc, rewriter));
     }
 
     rewriter.replaceOp(genericOp, resultReplacements);
@@ -664,8 +713,8 @@
 void mlir::linalg::populateFoldUnitExtentDimsViaReshapesPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<ReplaceUnitExtents>(context,
-                                   RankReductionStrategy::ReassociativeReshape);
+  patterns.add<FoldUnitBatchMatmulToMatmul, ReplaceUnitExtents>(
+      context, RankReductionStrategy::ReassociativeReshape);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
   patterns.add<FoldUnitDimLoops, RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
@@ -683,8 +732,8 @@
 void mlir::linalg::populateFoldUnitExtentDimsViaSlicesPatterns(
     RewritePatternSet &patterns) {
   auto *context = patterns.getContext();
-  patterns.add<ReplaceUnitExtents>(context,
-                                   RankReductionStrategy::ExtractInsertSlice);
+  patterns.add<FoldUnitBatchMatmulToMatmul, ReplaceUnitExtents>(
+      context, RankReductionStrategy::ExtractInsertSlice);
   patterns.add<FoldUnitDimLoops>(context);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
Index: mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
===================================================================
--- mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -923,3 +923,47 @@
 // CHECK-SLICES: memref.subview %{{.*}}[0, 0] [1, 1] [1, 1] : memref<1x1xf32, 3> to memref<f32, strided<[]>, 3>
 // CHECK-SLICES: linalg.generic{{.*}}memref<f32, strided<[]>, 3>
 
+// -----
+
+func.func @batch_matmul_memref(%A: memref<1x?x?xf32>, %B: memref<1x?x?xf32>, %C: memref<1x?x?xf32>) {
+  linalg.batch_matmul ins(%A, %B : memref<1x?x?xf32>, memref<1x?x?xf32>)
+                      outs(%C : memref<1x?x?xf32>)
+  return
+}
+
+// CHECK-LABEL: func @batch_matmul_memref
+// CHECK-SAME: (%[[A:.+]]: memref<1x?x?xf32>, %[[B:.+]]: memref<1x?x?xf32>, %[[C:.+]]: memref<1x?x?xf32>)
+// CHECK: %[[RA:.+]] = memref.collapse_shape %[[A]] {{\[}}[0, 1], [2]] : memref<1x?x?xf32> into memref<?x?xf32>
+// CHECK: %[[RB:.+]] = memref.collapse_shape %[[B]] {{\[}}[0, 1], [2]] : memref<1x?x?xf32> into memref<?x?xf32>
+// CHECK: %[[RC:.+]] = memref.collapse_shape %[[C]] {{\[}}[0, 1], [2]] : memref<1x?x?xf32> into memref<?x?xf32>
+// CHECK: linalg.matmul ins(%[[RA]], %[[RB]] : memref<?x?xf32>, memref<?x?xf32>) outs(%[[RC]] : memref<?x?xf32>)
+
+// CHECK-SLICES-LABEL: func @batch_matmul_memref
+// CHECK-SLICES: memref.subview %{{.*}}[0, 0, 0] [1, %{{.*}}, %{{.*}}] [1, 1, 1]
+// CHECK-SLICES: memref.subview %{{.*}}[0, 0, 0] [1, %{{.*}}, %{{.*}}] [1, 1, 1]
+// CHECK-SLICES: memref.subview %{{.*}}[0, 0, 0] [1, %{{.*}}, %{{.*}}] [1, 1, 1]
+// CHECK-SLICES: linalg.matmul ins(%{{.*}}, %{{.*}} : memref<?x?xf32, strided<[?, 1], offset: ?>>, memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%{{.*}} : memref<?x?xf32, strided<[?, 1], offset: ?>>)
+
+// -----
+
+func.func @batch_matmul_tensor(%A: tensor<1x?x?xf32>, %B: tensor<1x?x?xf32>, %C: tensor<1x?x?xf32>) -> tensor<1x?x?xf32> {
+  %0 = linalg.batch_matmul ins(%A, %B : tensor<1x?x?xf32>, tensor<1x?x?xf32>)
+                           outs(%C : tensor<1x?x?xf32>) -> tensor<1x?x?xf32>
+  return %0 : tensor<1x?x?xf32>
+}
+
+// CHECK-LABEL: func @batch_matmul_tensor
+// CHECK-SAME: (%[[A:.+]]: tensor<1x?x?xf32>, %[[B:.+]]: tensor<1x?x?xf32>, %[[C:.+]]: tensor<1x?x?xf32>)
+// CHECK: %[[RA:.+]] = tensor.collapse_shape %[[A]] {{\[}}[0, 1], [2]] : tensor<1x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[RB:.+]] = tensor.collapse_shape %[[B]] {{\[}}[0, 1], [2]] : tensor<1x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[RC:.+]] = tensor.collapse_shape %[[C]] {{\[}}[0, 1], [2]] : tensor<1x?x?xf32> into tensor<?x?xf32>
+// CHECK: %[[D:.+]] = linalg.matmul ins(%[[RA]], %[[RB]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[RC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK: %[[E:.+]] = tensor.expand_shape %[[D]] {{\[}}[0, 1], [2]] : tensor<?x?xf32> into tensor<1x?x?xf32>
+
+// CHECK-SLICES-LABEL: func @batch_matmul_tensor
+// CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0] [1, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<1x?x?xf32> to tensor<?x?xf32>
+// CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0] [1, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<1x?x?xf32> to tensor<?x?xf32>
+// CHECK-SLICES: tensor.extract_slice %{{.*}}[0, 0, 0] [1, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<1x?x?xf32> to tensor<?x?xf32>
+// CHECK-SLICES: linalg.matmul ins(%{{.*}}, %{{.*}} : tensor<?x?xf32>, tensor<?x?xf32>) outs(%{{.*}} : tensor<?x?xf32>) -> tensor<?x?xf32>
+// CHECK-SLICES: tensor.insert_slice %{{.*}} into %{{.*}}[0, 0, 0] [1, %{{.*}}, %{{.*}}] [1, 1, 1] : tensor<?x?xf32> into tensor<1x?x?xf32>
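
For reference, a minimal sketch (not part of the patch) of how a downstream pass might pick up the new FoldUnitBatchMatmulToMatmul pattern through the populate entry point this change extends. The helper name applyUnitDimFolding and the surrounding wiring are illustrative assumptions; only the populate function and the greedy rewrite driver are existing APIs.

// Illustrative sketch only: applyUnitDimFolding is a hypothetical helper.
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult applyUnitDimFolding(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  // Registers FoldUnitBatchMatmulToMatmul together with the existing
  // unit-extent-dim folding patterns (reshape-based strategy), so
  // unit-batch linalg.batch_matmul ops are rewritten to linalg.matmul
  // with collapse_shape/expand_shape around the operands and result.
  linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}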