diff --git a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp @@ -33,6 +33,8 @@ // the given `namedOp` does not have a region builder. static GenericOp createGenericOpFromNamedOp(LinalgOp namedOp, PatternRewriter &rewriter) { + SmallVector inputOperands = namedOp.getInputOperands(); + SmallVector outputOperands = namedOp.getOutputOperands(); SmallVector indexingMaps = namedOp.getIndexingMaps(); SmallVector iterators = llvm::to_vector<4>( namedOp.iterator_types().getAsValueRange()); @@ -41,9 +43,9 @@ // Inline the existing region if the named operation has a region attached. if (namedOp->getNumRegions() == 1) { - GenericOp genericOp = rewriter.create( - namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(), - indexingMaps, iterators); + GenericOp genericOp = + rewriter.create(namedOp.getLoc(), types, inputOperands, + outputOperands, indexingMaps, iterators); rewriter.inlineRegionBefore(namedOp->getRegion(0), genericOp.region(), genericOp.region().begin()); return genericOp; @@ -57,8 +59,8 @@ return nullptr; } return rewriter.create( - namedOp.getLoc(), types, namedOp.getInputs(), namedOp.getOutputs(), - indexingMaps, iterators, + namedOp.getLoc(), types, inputOperands, outputOperands, indexingMaps, + iterators, [®ionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) { ImplicitLocOpBuilder b(loc, bodyBuilder); regionBuilder(b, *bodyBuilder.getBlock(), @@ -163,13 +165,14 @@ GenericOp GeneralizeConvOp::createGenericOp(ConvOp convOp, OpBuilder &builder) const { - SmallVector indexingMaps = convOp.getIndexingMaps(); + SmallVector indexingMaps = convOp.getIndexingMaps(); auto iterators = llvm::to_vector<4>(convOp.iterator_types().getAsValueRange()); + SmallVector inputBuffers = convOp.getInputBufferOperands(); + SmallVector outputBuffers = convOp.getOutputBufferOperands(); return builder.create( - convOp.getLoc(), /*resultTensorTypes=*/ArrayRef(), - convOp.getInputBuffers(), convOp.getOutputBuffers(), indexingMaps, - iterators, + convOp.getLoc(), /*resultTensorTypes=*/ArrayRef(), inputBuffers, + outputBuffers, indexingMaps, iterators, [](OpBuilder &bodyBuilder, Location bodyLoc, ValueRange bodyArgs) { Value mul = bodyBuilder.create(bodyLoc, bodyArgs[0], bodyArgs[1]);