diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp @@ -1164,10 +1164,11 @@ /// Pattern to fold a generic op with a splat constant/scalar constant. Does not /// handle cases where the constant is not single-valued. -class FoldConstants : public OpRewritePattern { +class FoldScalarOrSplatConstant : public OpRewritePattern { public: - FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun, - PatternBenefit benefit = 1) + FoldScalarOrSplatConstant(MLIRContext *context, + ControlElementwiseOpsFusionFn &fun, + PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), controlFn(fun) {} LogicalResult matchAndRewrite(GenericOp genericOp, @@ -1264,6 +1265,142 @@ private: ControlElementwiseOpsFusionFn controlFn; }; + +// Folds linalg.generic ops that are actually transposes on constant values. +class FoldConstantTranspose : public OpRewritePattern { +public: + FoldConstantTranspose(MLIRContext *context, + const ControlElementwiseOpsFusionFn &controlFn, + PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit), controlFn(controlFn) {} + + LogicalResult matchAndRewrite(GenericOp genericOp, + PatternRewriter &rewriter) const override { + if (genericOp.hasBufferSemantics()) + return failure(); + + // Transpose should just have one input and one output. + if (genericOp.inputs().size() != 1 || genericOp.outputs().size() != 1) + return failure(); + + // All indexing maps should be permutations. 
+ if (!llvm::all_of(genericOp.getIndexingMaps(), + [](AffineMap map) { return map.isPermutation(); })) + return failure(); + + Value input = genericOp.inputs().front(); + Value output = genericOp.getResult(0); + auto inputType = input.getType().dyn_cast(); + auto outputType = output.getType().dyn_cast(); + // Check the casts before touching the element type: a non-shaped input + // (e.g. a scalar operand of a 0-d generic) makes dyn_cast return null. + if (!inputType || !outputType) + return failure(); + auto elementType = inputType.getElementType(); + if (!elementType.isIntOrFloat()) + return failure(); + if (!outputType.hasStaticShape()) + return failure(); + + // Make sure the region only contains a yield op. + Block &body = genericOp.region().front(); + if (!llvm::hasSingleElement(body)) + return failure(); + auto yieldOp = dyn_cast(body.getTerminator()); + if (!yieldOp) + return failure(); + + // The yield op should only return the block argument of the input. + for (Value yieldVal : yieldOp.values()) { + auto yieldArg = yieldVal.dyn_cast(); + if (!yieldArg || yieldArg.getOwner() != &body) + return failure(); + if (yieldArg.getArgNumber() != 0) + return failure(); + } + + DenseIntOrFPElementsAttr inputValues; + if (!matchPattern(input, m_Constant(&inputValues))) + return failure(); + + // Identified this as a potential candidate for folding. Now check the + // policy to see whether we are allowed to proceed. + OpOperand *consumer = genericOp.getInputOperand(0); + OpResult producer = consumer->get().cast(); + if (!controlFn(producer, *consumer)) + return failure(); + + auto linalgOp = cast(genericOp.getOperation()); + SmallVector loopBounds = linalgOp.computeStaticLoopSizes(); + int64_t numElements = inputType.getNumElements(); + + SmallVector intOutputValues; + SmallVector fpOutputValues; + if (elementType.isa()) + fpOutputValues.resize(numElements, APFloat(0.f)); + else + intOutputValues.resize(numElements); + + // Returns the dim position of each result of the given permutation map. 
+ auto getDimPositions = [](AffineMap map) { + SmallVector dims; + dims.reserve(map.getNumResults()); + for (AffineExpr result : map.getResults()) { + dims.push_back(result.cast().getPosition()); + } + return dims; + }; + + auto inputDims = getDimPositions(genericOp.getIndexingMaps()[0]); + auto outputDims = getDimPositions(genericOp.getIndexingMaps()[1]); + auto outputShape = outputType.getShape(); + + // Transpose the input constant. Because we don't know its rank in advance, + // we need to loop over the range [0, element count) and delinearize the + // index. + for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) { + SmallVector indices(loopBounds.size(), 0); + int totalCount = linearIndex0; + for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { + indices[dim] = totalCount % loopBounds[dim]; + totalCount /= loopBounds[dim]; + } + + SmallVector srcIndices(loopBounds.size(), 0); + SmallVector dstIndices(loopBounds.size(), 0); + for (int dim = loopBounds.size() - 1; dim >= 0; --dim) { + srcIndices[dim] = indices[inputDims[dim]]; + dstIndices[dim] = indices[outputDims[dim]]; + } + + // Row-major linearization of dstIndices into the output buffer. Starting + // from 0 keeps rank-0 outputs (empty dstIndices) well-defined. + uint64_t linearIndex1 = 0; + for (int dim = 0; dim < outputType.getRank(); ++dim) + linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim]; + + if (elementType.isa()) { + fpOutputValues[linearIndex1] = + inputValues.getValue(srcIndices); + } else { + intOutputValues[linearIndex1] = inputValues.getValue(srcIndices); + } + } + + DenseIntOrFPElementsAttr outputAttr; + if (elementType.isa()) { + outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues); + } else { + outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues); + } + rewriter.replaceOpWithNewOp(genericOp, outputAttr); + return success(); + } + +private: + ControlElementwiseOpsFusionFn controlFn; +}; + } // namespace static Optional> @@ -1438,8 +1575,9 @@ void mlir::linalg::populateElementwiseOpsFusionPatterns( RewritePatternSet &patterns, 
LinalgElementwiseFusionOptions options) { auto *context = patterns.getContext(); - patterns.add( - context, options.controlElementwiseOpsFusionFn); + patterns.add(context, + options.controlElementwiseOpsFusionFn); patterns.add(context); populateFoldReshapeOpsByExpansionPatterns(patterns, options.controlFoldingReshapesFn); diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -755,15 +755,15 @@ %2:2 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> ()>, - affine_map<(d0, d1) -> ()>, - affine_map<(d0, d1) -> (d0, d1)>, - affine_map<(d0, d1) -> (d0, d1)>], + affine_map<(d0, d1) -> ()>, + affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0, %cst, %c42 : tensor, f32, i32) outs(%0, %1 : tensor, tensor) { ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) : %3 = addf %arg1, %arg2 : f32 - linalg.yield %3, %arg3 : f32, i32 + linalg.yield %3, %arg3 : f32, i32 } -> (tensor, tensor) return %2#0, %2#1 : tensor, tensor } @@ -774,3 +774,136 @@ // CHECK-SAME: ins(%{{.+}} : tensor) // CHECK: %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32 // CHECK: linalg.yield %[[YIELD]], %[[C42]] : f32, i32 + +// ----- + +// CHECK-LABEL: @transpose_fold_2d_fp32 +func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { + %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> + // CHECK: %[[CST:.+]] = constant + // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init 
: tensor<3x2xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_2d_fp64 +func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> { + %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64> + // CHECK: %[[CST:.+]] = constant + // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) { + ^bb0(%arg1: f64, %arg2: f64): + linalg.yield %arg1 : f64 + } -> tensor<3x2xf64> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf64> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_4d_i32 +func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> { + %input = constant dense<[[ + [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ]]> : tensor<1x2x3x4xi32> + // CHECK: %[[CST:.+]] = constant dense<[ + // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], + // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], + // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] + // CHECK-SAME{LITERAL}: ]> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] + } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) { + ^bb0(%arg1: i32, %arg2: i32): + linalg.yield %arg1 : i32 + } -> tensor<3x1x4x2xi32> + // CHECK: return %[[CST]] + return %1 : tensor<3x1x4x2xi32> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_4d_i16 +func 
@transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> { + %input = constant dense<[[ + [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ]]> : tensor<1x2x3x4xi16> + // CHECK: %[[CST:.+]] = constant dense<[ + // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], + // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], + // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] + // CHECK-SAME{LITERAL}: ]> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] + } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) { + ^bb0(%arg1: i16, %arg2: i16): + linalg.yield %arg1 : i16 + } -> tensor<3x1x4x2xi16> + // CHECK: return %[[CST]] + return %1 : tensor<3x1x4x2xi16> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_non_cst_input +func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> { + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_yield_const +func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { + %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> + %cst = constant 8.0 : f32 + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { + ^bb0(%arg1: 
f32, %arg2: f32): + linalg.yield %cst : f32 + } -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_multi_ops_in_region +func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { + %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %add = addf %arg1, %arg1 : f32 + linalg.yield %add : f32 + } -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +}