diff --git a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1311,10 +1312,12 @@
       return failure();
     LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
     for (auto operand : llvm::enumerate(linalgOp.getInputOpOperands())) {
-      ConstantOp constantOp = operand.value().get().getDefiningOp<ConstantOp>();
-      if (!constantOp ||
-          !constantOp.value().cast<DenseElementsAttr>().isSplat() ||
-          !controlFn(constantOp->getResult(0), operand.value()))
+      Operation *def = operand.value().get().getDefiningOp();
+      DenseElementsAttr constantAttr;
+      if (!def ||
+          !matchPattern(def, m_Constant(&constantAttr)) ||
+          !constantAttr.isSplat() ||
+          !controlFn(def->getResult(0), operand.value()))
         continue;
 
       // The indexing_maps for the operands of the fused operation are same as
@@ -1337,8 +1340,7 @@
 
       // Create a constant scalar value from the splat constant.
       Value scalarConstant = rewriter.create<ConstantOp>(
-          constantOp.getLoc(),
-          constantOp.value().cast<DenseElementsAttr>().getSplatValue());
+          def->getLoc(), constantAttr.getSplatValue());
 
       LinalgOp fusedOp = createLinalgOpOfSameType(
           linalgOp, rewriter, rewriter.getUnknownLoc(),