diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1164,10 +1164,11 @@
 /// Pattern to fold a generic op with a splat constant/scalar constant. Does not
 /// handle cases where the constant is not single-valued.
-class FoldConstants : public OpRewritePattern<GenericOp> {
+class FoldScalarOrSplatConstant : public OpRewritePattern<GenericOp> {
 public:
-  FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
-                PatternBenefit benefit = 1)
+  FoldScalarOrSplatConstant(MLIRContext *context,
+                            ControlElementwiseOpsFusionFn &fun,
+                            PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
                                 PatternRewriter &rewriter) const override {
@@ -1268,6 +1269,237 @@
 private:
   ControlElementwiseOpsFusionFn controlFn;
 };
+
+/// Base class for constant folding linalg.generic ops with N inputs, 1 output,
+/// and permutation indexing maps.
+///
+/// `ConcreteType` should provide methods with signatures
+///
+/// ```c++
+///   bool matchIndexingMaps(GenericOp genericOp) const;
+///   RegionComputationFn getRegionComputeFn(GenericOp) const;
+/// ```
+///
+/// The latter inspects the region and returns the computation inside as a
+/// functor. The functor will be invoked with constant elements for all inputs
+/// and should return the corresponding computed constant element for output.
+template <typename ConcreteType>
+class FoldConstantBase : public OpRewritePattern<GenericOp> {
+public:
+  struct APIntOrFloatArray {
+    SmallVector<APInt> apInts;
+    SmallVector<APFloat> apFloats;
+  };
+  using RegionComputationFn =
+      std::function<APIntOrFloatArray(APIntOrFloatArray)>;
+
+  FoldConstantBase(MLIRContext *context,
+                   const ControlElementwiseOpsFusionFn &controlFn,
+                   PatternBenefit benefit = 1)
+      : OpRewritePattern<GenericOp>(context, benefit), controlFn(controlFn) {}
+
+  LogicalResult matchAndRewrite(GenericOp genericOp,
+                                PatternRewriter &rewriter) const override {
+    if (genericOp.hasBufferSemantics())
+      return failure();
+
+    // Only support ops generating one output for now.
+    if (genericOp.getNumOutputs() != 1)
+      return failure();
+
+    auto outputType = genericOp.getResultTypes().front().dyn_cast<ShapedType>();
+    // Require the output type to be static given we are generating constants.
+    if (!outputType || !outputType.hasStaticShape())
+      return failure();
+
+    if (!llvm::all_of(genericOp.getInputOperands(), [](OpOperand *operand) {
+          return operand->get().getType().isa<ShapedType>();
+        }))
+      return failure();
+
+    // Make sure all element types are the same.
+    auto getOperandElementType = [](OpOperand *operand) {
+      return operand->get().getType().cast<ShapedType>().getElementType();
+    };
+    if (!llvm::is_splat(llvm::map_range(genericOp.getInputAndOutputOperands(),
+                                        getOperandElementType)))
+      return failure();
+
+    // We can only handle the case where we have int/float elements.
+    auto elementType = outputType.getElementType();
+    if (!elementType.isIntOrFloat())
+      return failure();
+
+    // Require all indexing maps to be permutations for now. This is common and
+    // it simplifies input/output access greatly: we can do the data shuffling
+    // entirely in the compiler, without needing to turn all indices into
+    // Values, apply the affine maps on them, and then match back the constant
+    // again.
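+    //
+    // As an illustration (not part of the original patch text): an indexing
+    // map like (d0, d1) -> (d1, d0) is a permutation and is acceptable here,
+    // while a broadcasting map such as (d0, d1) -> (d0) is not.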
+    if (!llvm::all_of(genericOp.getIndexingMaps(),
+                      [](AffineMap map) { return map.isPermutation(); }))
+      return failure();
+
+    for (OpOperand *operand : genericOp.getOutputOperands()) {
+      if (genericOp.payloadUsesValueFromOperand(operand))
+        return failure();
+    }
+
+    // Further check that the indexing maps are okay for the ConcreteType.
+    if (!static_cast<const ConcreteType *>(this)->matchIndexingMaps(genericOp))
+      return failure();
+
+    // Defer to the concrete type to check the region and discover the
+    // computation inside.
+    RegionComputationFn computeFn =
+        static_cast<const ConcreteType *>(this)->getRegionComputeFn(genericOp);
+    if (!computeFn)
+      return failure();
+
+    // All inputs should be constants.
+    int numInputs = genericOp.getNumInputs();
+    SmallVector<DenseIntOrFPElementsAttr> inputValues(numInputs);
+    for (auto operand : llvm::enumerate(genericOp.getInputOperands())) {
+      if (!matchPattern(operand.value()->get(),
+                        m_Constant(&inputValues[operand.index()])))
+        return failure();
+    }
+
+    // Identified this as a potential candidate for folding. Now check the
+    // policy to see whether we are allowed to proceed.
+    for (int i = 0; i < numInputs; ++i) {
+      OpOperand *consumer = genericOp.getInputOperand(i);
+      OpResult producer = consumer->get().cast<OpResult>();
+      if (!controlFn(producer, *consumer))
+        return failure();
+    }
+
+    auto linalgOp = cast<LinalgOp>(genericOp.getOperation());
+    SmallVector<int64_t, 4> loopBounds = linalgOp.computeStaticLoopSizes();
+    int64_t numElements = outputType.getNumElements();
+
+    // Use APInt/APFloat instead of Attribute here for constructing the output.
+    // This helps to avoid blowing up compiler memory usage: Attributes would
+    // unify the following cases but they live as long as the MLIRContext.
+    SmallVector<APInt> intOutputValues;
+    SmallVector<APFloat> fpOutputValues;
+    if (elementType.template isa<FloatType>())
+      fpOutputValues.resize(numElements, APFloat(0.f));
+    else
+      intOutputValues.resize(numElements);
+
+    // Return the constant dim positions from the given permutation map.
+    auto getDimPositions = [](AffineMap map) {
+      SmallVector<unsigned> dims;
+      dims.reserve(map.getNumResults());
+      for (AffineExpr result : map.getResults()) {
+        dims.push_back(result.cast<AffineDimExpr>().getPosition());
+      }
+      return dims;
+    };
+
+    SmallVector<SmallVector<unsigned>> inputDims;
+    for (int i = 0; i < numInputs; ++i)
+      inputDims.push_back(getDimPositions(genericOp.getIndexingMaps()[i]));
+    auto outputDims = getDimPositions(genericOp.getIndexingMaps().back());
+    auto outputShape = outputType.getShape();
+
+    // Transpose the input constant. Because we don't know its rank in advance,
+    // we need to loop over the range [0, element count) and delinearize the
+    // index.
+    for (int linearIndex0 = 0; linearIndex0 < numElements; ++linearIndex0) {
+      SmallVector<uint64_t> indices(loopBounds.size(), 0);
+      int totalCount = linearIndex0;
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        indices[dim] = totalCount % loopBounds[dim];
+        totalCount /= loopBounds[dim];
+      }
+
+      SmallVector<SmallVector<uint64_t>> srcIndices;
+      for (int i = 0; i < numInputs; ++i)
+        srcIndices.emplace_back(loopBounds.size(), 0);
+      SmallVector<uint64_t> dstIndices(loopBounds.size(), 0);
+
+      for (int dim = loopBounds.size() - 1; dim >= 0; --dim) {
+        for (int i = 0; i < numInputs; ++i)
+          srcIndices[i][dim] = indices[inputDims[i][dim]];
+        dstIndices[dim] = indices[outputDims[dim]];
+      }
+
+      uint64_t linearIndex1 = dstIndices.front();
+      for (int dim = 1; dim < outputType.getRank(); ++dim)
+        linearIndex1 = linearIndex1 * outputShape[dim] + dstIndices[dim];
+
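+      // As a concrete illustration (not part of the original patch): for the
+      // 2x3 -> 3x2 transpose in the tests below, with indexing maps
+      // (d0, d1) -> (d1, d0) and (d0, d1) -> (d0, d1), the loop bounds are
+      // [3, 2]. linearIndex0 == 4 delinearizes to indices == [2, 0], which
+      // reads the source element at [0, 2] and writes it to the destination
+      // at [2, 0], i.e. linearIndex1 == 4.
+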
+      // Collect constant elements for all inputs at this loop iteration.
+      SmallVector<APInt> intValues;
+      SmallVector<APFloat> fpValues;
+      if (elementType.isa<FloatType>()) {
+        for (int i = 0; i < numInputs; ++i)
+          fpValues.push_back(inputValues[i].getValue<APFloat>(srcIndices[i]));
+      } else {
+        for (int i = 0; i < numInputs; ++i)
+          intValues.push_back(inputValues[i].getValue<APInt>(srcIndices[i]));
+      }
+
+      // Invoke the computation to get the corresponding constant output
+      // element.
+      APIntOrFloatArray inputs = {intValues, fpValues};
+      APIntOrFloatArray outputs = computeFn(inputs);
+
+      if (elementType.isa<FloatType>()) {
+        fpOutputValues[linearIndex1] = outputs.apFloats.front();
+      } else {
+        intOutputValues[linearIndex1] = outputs.apInts.front();
+      }
+    }
+
+    DenseIntOrFPElementsAttr outputAttr;
+    if (elementType.isa<FloatType>()) {
+      outputAttr = DenseFPElementsAttr::get(outputType, fpOutputValues);
+    } else {
+      outputAttr = DenseIntElementsAttr::get(outputType, intOutputValues);
+    }
+    rewriter.replaceOpWithNewOp<ConstantOp>(genericOp, outputAttr);
+    return success();
+  }
+
+private:
+  ControlElementwiseOpsFusionFn controlFn;
+};
+
+// Folds linalg.generic ops that are actually transposes on constant values.
+struct FoldConstantTranspose : public FoldConstantBase<FoldConstantTranspose> {
+  using FoldConstantBase::FoldConstantBase;
+
+  bool matchIndexingMaps(GenericOp genericOp) const {
+    // We should have one input and one output.
+    return genericOp.getIndexingMaps().size() == 2;
+  }
+
+  RegionComputationFn getRegionComputeFn(GenericOp genericOp) const {
+    // Make sure the region only contains a yield op.
+    Block &body = genericOp.region().front();
+    if (!llvm::hasSingleElement(body))
+      return nullptr;
+    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
+    if (!yieldOp)
+      return nullptr;
+
+    // The yield op should return the block argument corresponding to the
+    // input.
+    for (Value yieldVal : yieldOp.values()) {
+      auto yieldArg = yieldVal.dyn_cast<BlockArgument>();
+      if (!yieldArg || yieldArg.getOwner() != &body)
+        return nullptr;
+      if (yieldArg.getArgNumber() != 0)
+        return nullptr;
+    }
+
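+    // Illustrative examples (not from the original patch): a region of the
+    // form
+    //   ^bb0(%in: f32, %out: f32):
+    //     linalg.yield %in : f32
+    // passes the checks above, whereas a region that yields %out or yields a
+    // value computed inside the region does not.
+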
+    // No computation; just return the original value.
+    return [](APIntOrFloatArray inputs) { return inputs; };
+  }
+
+  ControlElementwiseOpsFusionFn controlFn;
+};
+
 } // namespace
 
 static Optional<SmallVector<Value, 1>>
@@ -1442,8 +1674,9 @@
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();
-  patterns.add<FuseElementwiseOps, FoldConstants>(
-      context, options.controlElementwiseOpsFusionFn);
+  patterns.add<FuseElementwiseOps, FoldScalarOrSplatConstant,
+               FoldConstantTranspose>(context,
+                                      options.controlElementwiseOpsFusionFn);
   patterns.add(context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
                                             options.controlFoldingReshapesFn);
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -755,15 +755,15 @@
   %2:2 = linalg.generic {
     indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
                      affine_map<(d0, d1) -> ()>,
-                     affine_map<(d0, d1) -> ()>,
-                     affine_map<(d0, d1) -> (d0, d1)>,
-                     affine_map<(d0, d1) -> (d0, d1)>],
+                     affine_map<(d0, d1) -> ()>,
+                     affine_map<(d0, d1) -> (d0, d1)>,
+                     affine_map<(d0, d1) -> (d0, d1)>],
     iterator_types = ["parallel", "parallel"]}
    ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
    outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
   ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
     %3 = addf %arg1, %arg2 : f32
-    linalg.yield %3, %arg3 : f32, i32
+    linalg.yield %3, %arg3 : f32, i32
   } -> (tensor<?x?xf32>, tensor<?x?xi32>)
   return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
 }
@@ -774,3 +774,136 @@
 // CHECK-SAME: ins(%{{.+}} : tensor<?x?xf32>)
 // CHECK: %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
 // CHECK: linalg.yield %[[YIELD]], %[[C42]] : f32, i32
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp32
+func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  // CHECK: %[[CST:.+]] = constant
+  // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_2d_fp64
+func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64>
+  // CHECK: %[[CST:.+]] = constant
+  // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) {
+  ^bb0(%arg1: f64, %arg2: f64):
+    linalg.yield %arg1 : f64
+  } -> tensor<3x2xf64>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x2xf64>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i32
+func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> {
+  %input = constant dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi32>
+  // CHECK: %[[CST:.+]] = constant dense<[
+  // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) {
+  ^bb0(%arg1: i32, %arg2: i32):
+    linalg.yield %arg1 : i32
+  } -> tensor<3x1x4x2xi32>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_fold_4d_i16
+func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> {
+  %input = constant dense<[[
+    [[ 0,  1,  2,  3], [ 4,  5,  6,  7], [ 8,  9, 10, 11]],
+    [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]]
+  ]]> : tensor<1x2x3x4xi16>
+  // CHECK: %[[CST:.+]] = constant dense<[
+  // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]],
+  // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]],
+  // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]]
+  // CHECK-SAME{LITERAL}: ]>
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>],
+    iterator_types = ["parallel", "parallel", "parallel", "parallel"]
+  } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) {
+  ^bb0(%arg1: i16, %arg2: i16):
+    linalg.yield %arg1 : i16
+  } -> tensor<3x1x4x2xi16>
+  // CHECK: return %[[CST]]
+  return %1 : tensor<3x1x4x2xi16>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_non_cst_input
+func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %arg1 : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_yield_const
+func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  %cst = constant 8.0 : f32
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    linalg.yield %cst : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @transpose_nofold_multi_ops_in_region
+func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> {
+  %input = constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>
+  // CHECK: linalg.generic
+  %1 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) {
+  ^bb0(%arg1: f32, %arg2: f32):
+    %add = addf %arg1, %arg1 : f32
+    linalg.yield %add : f32
+  } -> tensor<3x2xf32>
+  return %1 : tensor<3x2xf32>
+}
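
A minimal usage sketch (not part of the patch; the helper name is hypothetical), showing how a downstream pipeline could pick up the new constant-transpose folding through the existing populateElementwiseOpsFusionPatterns entry point touched in the hunk above, with a control function that allows every candidate:

```c++
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"

// Hypothetical downstream helper; LinalgElementwiseFusionOptions and
// populateElementwiseOpsFusionPatterns are the existing upstream APIs used in
// this patch.
static void populateMyLinalgFusionPatterns(mlir::RewritePatternSet &patterns) {
  mlir::linalg::LinalgElementwiseFusionOptions options;
  // Allow constant folding / fusion for every producer-consumer pair.
  options.controlElementwiseOpsFusionFn =
      [](const mlir::OpResult &producer, mlir::OpOperand &consumer) {
        return true;
      };
  mlir::linalg::populateElementwiseOpsFusionPatterns(patterns, options);
}
```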