diff --git a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
@@ -1162,11 +1162,12 @@
   ControlElementwiseOpsFusionFn controlFoldingReshapes;
 };
 
-/// Pattern to fold a generic op with a splat constant.
-class FoldSplatConstants : public OpRewritePattern<GenericOp> {
+/// Pattern to fold a generic op with a splat constant/scalar constant. Does not
+/// handle cases where the constant is not single-valued.
+class FoldConstants : public OpRewritePattern<GenericOp> {
 public:
-  FoldSplatConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
-                     PatternBenefit benefit = 1)
+  FoldConstants(MLIRContext *context, ControlElementwiseOpsFusionFn &fun,
+                PatternBenefit benefit = 1)
       : OpRewritePattern<GenericOp>(context, benefit), controlFn(fun) {}
 
   LogicalResult matchAndRewrite(GenericOp genericOp,
@@ -1175,10 +1176,42 @@
       return failure();
     for (OpOperand *opOperand : genericOp.getInputOperands()) {
       Operation *def = opOperand->get().getDefiningOp();
-      DenseElementsAttr constantAttr;
-      if (!def ||
-          !matchPattern(def, m_Constant<DenseElementsAttr>(&constantAttr)) ||
-          !constantAttr.isSplat() || !controlFn(def->getResult(0), *opOperand))
+      Attribute constantAttr;
+      // Returns true if `def` is a constant holding a single value (a splat
+      // DenseElementsAttr of int/index/float elements, or a scalar IntegerAttr
+      // or FloatAttr) and records that value in `constantAttr`.
+      auto isScalarOrSplatConstantOp = [&constantAttr](Operation *def) -> bool {
+        {
+          // Restrict to element types whose splat value can be rebuilt as a
+          // scalar ConstantOp below.
+          DenseElementsAttr splatAttr;
+          if (matchPattern(def, m_Constant<DenseElementsAttr>(&splatAttr)) &&
+              splatAttr.isSplat() &&
+              splatAttr.getType().getElementType().isIntOrIndexOrFloat()) {
+            constantAttr = splatAttr.getSplatValue();
+            return true;
+          }
+        }
+        {
+          IntegerAttr intAttr;
+          if (matchPattern(def, m_Constant<IntegerAttr>(&intAttr))) {
+            constantAttr = intAttr;
+            return true;
+          }
+        }
+        {
+          FloatAttr floatAttr;
+          if (matchPattern(def, m_Constant<FloatAttr>(&floatAttr))) {
+            constantAttr = floatAttr;
+            return true;
+          }
+        }
+        return false;
+      };
+
+      auto resultValue = opOperand->get().dyn_cast<OpResult>();
+      if (!def || !resultValue || !isScalarOrSplatConstantOp(def) ||
+          !controlFn(resultValue, *opOperand))
         continue;
 
       // The operands and the indexing_maps of the fused operation the same as
@@ -1205,8 +1238,7 @@
 
-      // Create a constant scalar value from the splat constant.
+      // Create a scalar constant from the splat/scalar constant value.
       Value scalarConstant = rewriter.create<ConstantOp>(
-          def->getLoc(), constantAttr.getSplatValue(),
-          constantAttr.getType().getElementType());
+          def->getLoc(), constantAttr, constantAttr.getType());
 
       SmallVector<Value> outputOperands = genericOp.getOutputOperands();
       auto fusedOp = rewriter.create<GenericOp>(
@@ -1411,7 +1443,7 @@
 void mlir::linalg::populateElementwiseOpsFusionPatterns(
     RewritePatternSet &patterns, LinalgElementwiseFusionOptions options) {
   auto *context = patterns.getContext();
-  patterns.add<FuseElementwiseOps, FoldSplatConstants>(
+  patterns.add<FuseElementwiseOps, FoldConstants>(
       context, options.controlElementwiseOpsFusionFn);
   patterns.add<RemoveOutsDependency>(context);
   populateFoldReshapeOpsByExpansionPatterns(patterns,
diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
--- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
+++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir
@@ -740,3 +740,37 @@
 //   CHECK-DAG:   %[[INIT:.+]] = linalg.init_tensor [%[[D0]], %[[D1]]]
 //       CHECK:   %[[RESULT:.+]] = linalg.generic
 //  CHECK-SAME:     outs(%[[INIT]] : tensor<?x?xf32>)
+
+// -----
+
+func @fuse_scalar_constant(%arg0 : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xi32>) {
+  %cst = constant 4.0 : f32
+  %c42 = constant 42 : i32
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %d0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %d1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %0 = linalg.init_tensor[%d0, %d1] : tensor<?x?xf32>
+  %1 = linalg.init_tensor[%d0, %d1] : tensor<?x?xi32>
+  %2:2 = linalg.generic {
+      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> ()>,
+                       affine_map<(d0, d1) -> ()>,
+                       affine_map<(d0, d1) -> (d0, d1)>,
+                       affine_map<(d0, d1) -> (d0, d1)>],
+      iterator_types = ["parallel", "parallel"]}
+      ins(%arg0, %cst, %c42 : tensor<?x?xf32>, f32, i32)
+      outs(%0, %1 : tensor<?x?xf32>, tensor<?x?xi32>) {
+      ^bb0(%arg1 : f32, %arg2 : f32, %arg3 : i32, %arg4 : f32, %arg5 : i32) :
+        %3 = addf %arg1, %arg2 : f32
+        linalg.yield %3, %arg3 : f32, i32
+      } -> (tensor<?x?xf32>, tensor<?x?xi32>)
+  return %2#0, %2#1 : tensor<?x?xf32>, tensor<?x?xi32>
+}
+// CHECK-LABEL: func @fuse_scalar_constant
+//   CHECK-DAG:   %[[CST:.+]] = constant 4.000000e+00 : f32
+//   CHECK-DAG:   %[[C42:.+]] = constant 42 : i32
+//       CHECK:   linalg.generic
+//  CHECK-SAME:       ins(%{{.+}} : tensor<?x?xf32>)
+//       CHECK:     %[[YIELD:.+]] = addf %{{.+}}, %[[CST]] : f32
+//       CHECK:     linalg.yield %[[YIELD]], %[[C42]] : f32, i32