diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -139,10 +139,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::CastOp op, ArrayRef operands, + matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, operands[0]); + rewriter.replaceOpWithNewOp(op, resultType, adaptor.source()); return success(); } }; diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h --- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h +++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h @@ -131,6 +131,8 @@ template class ConvertOpToLLVMPattern : public ConvertToLLVMPattern { public: + using OpAdaptor = typename SourceOp::Adaptor; + explicit ConvertOpToLLVMPattern(LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) : ConvertToLLVMPattern(SourceOp::getOperationName(), @@ -140,7 +142,8 @@ /// Wrappers around the RewritePattern methods that pass the derived op type. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); + rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } LogicalResult match(Operation *op) const final { return match(cast(op)); @@ -148,28 +151,53 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), operands, rewriter); + return matchAndRewrite(cast(op), + OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } /// Rewrite and Match methods that operate on the SourceOp type. These must be /// overridden by the derived pattern class. + /// NOTICE: These methods are deprecated and will be removed. All new code + /// should use the adaptor methods below instead. virtual void rewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override rewrite or matchAndRewrite"); } - virtual LogicalResult match(SourceOp op) const { - llvm_unreachable("must override match or matchAndRewrite"); - } virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (succeeded(match(op))) { - rewrite(op, operands, rewriter); + rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter); return success(); } return failure(); } + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual LogicalResult match(SourceOp op) const { + llvm_unreachable("must override match or matchAndRewrite"); + } + virtual void rewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + rewrite(op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + virtual LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + return matchAndRewrite( + op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + private: using ConvertToLLVMPattern::match; using ConvertToLLVMPattern::matchAndRewrite; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -366,79 +366,121 @@ using RewritePattern::rewrite; }; -namespace detail { -/// OpOrInterfaceConversionPatternBase is a wrapper around ConversionPattern -/// that allows for matching and rewriting against an instance of a derived -/// operation class or an Interface as opposed to a raw Operation. +/// OpConversionPattern is a wrapper around ConversionPattern that allows for +/// matching and rewriting against an instance of a derived operation class as +/// opposed to a raw Operation. template -struct OpOrInterfaceConversionPatternBase : public ConversionPattern { - using ConversionPattern::ConversionPattern; +class OpConversionPattern : public ConversionPattern { +public: + using OpAdaptor = typename SourceOp::Adaptor; + + OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) + : ConversionPattern(SourceOp::getOperationName(), benefit, context) {} + OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, + PatternBenefit benefit = 1) + : ConversionPattern(typeConverter, SourceOp::getOperationName(), benefit, + context) {} /// Wrappers around the ConversionPattern methods that pass the derived op /// type. void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - rewrite(cast(op), operands, rewriter); + rewrite(cast(op), OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - return matchAndRewrite(cast(op), operands, rewriter); + return matchAndRewrite(cast(op), + OpAdaptor(operands, op->getAttrDictionary()), + rewriter); } - // TODO: Use OperandAdaptor when it supports access to unnamed operands. - - /// Rewrite and Match methods that operate on the SourceOp type. These must be - /// overridden by the derived pattern class. + /// Rewrite and Match methods that operate on the SourceOp type and accept the + /// raw operand values. + /// NOTICE: These methods are deprecated and will be removed. All new code + /// should use the adaptor methods below instead. virtual void rewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { llvm_unreachable("must override matchAndRewrite or a rewrite method"); } - virtual LogicalResult matchAndRewrite(SourceOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { if (failed(match(op))) return failure(); - rewrite(op, operands, rewriter); + rewrite(op, OpAdaptor(operands, op->getAttrDictionary()), rewriter); return success(); } + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + rewrite(op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + virtual LogicalResult + matchAndRewrite(SourceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + return matchAndRewrite( + op, + ArrayRef(operands.getBase().get(), + operands.size()), + rewriter); + } + private: using ConversionPattern::matchAndRewrite; }; -} // namespace detail - -/// OpConversionPattern is a wrapper around ConversionPattern that allows for -/// matching and rewriting against an instance of a derived operation class as -/// opposed to a raw Operation. -template -struct OpConversionPattern - : public detail::OpOrInterfaceConversionPatternBase { - OpConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - SourceOp::getOperationName(), benefit, context) {} - OpConversionPattern(TypeConverter &typeConverter, MLIRContext *context, - PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - typeConverter, SourceOp::getOperationName(), benefit, context) {} -}; /// OpInterfaceConversionPattern is a wrapper around ConversionPattern that /// allows for matching and rewriting against an instance of an OpInterface /// class as opposed to a raw Operation. template -struct OpInterfaceConversionPattern - : public detail::OpOrInterfaceConversionPatternBase { +class OpInterfaceConversionPattern : public ConversionPattern { +public: OpInterfaceConversionPattern(MLIRContext *context, PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), - benefit, context) {} + : ConversionPattern(Pattern::MatchInterfaceOpTypeTag(), + SourceOp::getInterfaceID(), benefit, context) {} OpInterfaceConversionPattern(TypeConverter &typeConverter, MLIRContext *context, PatternBenefit benefit = 1) - : detail::OpOrInterfaceConversionPatternBase( - typeConverter, Pattern::MatchInterfaceOpTypeTag(), - SourceOp::getInterfaceID(), benefit, context) {} + : ConversionPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(), + SourceOp::getInterfaceID(), benefit, context) {} + + /// Wrappers around the ConversionPattern methods that pass the derived op + /// type. + void rewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + rewrite(cast(op), operands, rewriter); + } + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + return matchAndRewrite(cast(op), operands, rewriter); + } + + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } + virtual LogicalResult + matchAndRewrite(SourceOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { + if (failed(match(op))) + return failure(); + rewrite(op, operands, rewriter); + return success(); + } + +private: + using ConversionPattern::matchAndRewrite; }; /// Add a pattern to the given pattern list to convert the signature of a diff --git a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp --- a/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp +++ b/mlir/lib/Conversion/AsyncToLLVM/AsyncToLLVM.cpp @@ -326,7 +326,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroIdOp op, ArrayRef operands, + matchAndRewrite(CoroIdOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto token = AsyncAPI::tokenType(op->getContext()); auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); @@ -356,7 +356,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroBeginOp op, ArrayRef operands, + matchAndRewrite(CoroBeginOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); @@ -371,7 +371,7 @@ ValueRange(coroSize.getResult())); // Begin a coroutine: @llvm.coro.begin. - auto coroId = CoroBeginOpAdaptor(operands).id(); + auto coroId = CoroBeginOpAdaptor(adaptor.getOperands()).id(); rewriter.replaceOpWithNewOp( op, i8Ptr, ValueRange({coroId, coroAlloc.getResult(0)})); @@ -390,13 +390,14 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroFreeOp op, ArrayRef operands, + matchAndRewrite(CoroFreeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto i8Ptr = AsyncAPI::opaquePointerType(op->getContext()); auto loc = op->getLoc(); // Get a pointer to the coroutine frame memory: @llvm.coro.free. - auto coroMem = rewriter.create(loc, i8Ptr, operands); + auto coroMem = + rewriter.create(loc, i8Ptr, adaptor.getOperands()); // Free the memory. rewriter.replaceOpWithNewOp( @@ -418,14 +419,14 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroEndOp op, ArrayRef operands, + matchAndRewrite(CoroEndOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We are not in the block that is part of the unwind sequence. auto constFalse = rewriter.create( op->getLoc(), rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Mark the end of a coroutine: @llvm.coro.end. - auto coroHdl = CoroEndOpAdaptor(operands).handle(); + auto coroHdl = adaptor.handle(); rewriter.create(op->getLoc(), rewriter.getI1Type(), ValueRange({coroHdl, constFalse})); rewriter.eraseOp(op); @@ -445,11 +446,11 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroSaveOp op, ArrayRef operands, + matchAndRewrite(CoroSaveOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Save the coroutine state: @llvm.coro.save rewriter.replaceOpWithNewOp( - op, AsyncAPI::tokenType(op->getContext()), operands); + op, AsyncAPI::tokenType(op->getContext()), adaptor.getOperands()); return success(); } @@ -491,7 +492,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CoroSuspendOp op, ArrayRef operands, + matchAndRewrite(CoroSuspendOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto i8 = rewriter.getIntegerType(8); auto i32 = rewriter.getI32Type(); @@ -502,7 +503,7 @@ loc, rewriter.getI1Type(), rewriter.getBoolAttr(false)); // Suspend a coroutine: @llvm.coro.suspend - auto coroState = CoroSuspendOpAdaptor(operands).state(); + auto coroState = adaptor.state(); auto coroSuspend = rewriter.create( loc, i8, ValueRange({coroState, constFalse})); @@ -541,7 +542,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeCreateOp op, ArrayRef operands, + matchAndRewrite(RuntimeCreateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TypeConverter *converter = getTypeConverter(); Type resultType = op->getResultTypes()[0]; @@ -595,13 +596,14 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeCreateGroupOp op, ArrayRef operands, + matchAndRewrite(RuntimeCreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TypeConverter *converter = getTypeConverter(); Type resultType = op.getResult().getType(); - rewriter.replaceOpWithNewOp( - op, kCreateGroup, converter->convertType(resultType), operands); + rewriter.replaceOpWithNewOp(op, kCreateGroup, + converter->convertType(resultType), + adaptor.getOperands()); return success(); } }; @@ -618,14 +620,15 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeSetAvailableOp op, ArrayRef operands, + matchAndRewrite(RuntimeSetAvailableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) .Case([](Type) { return kEmplaceToken; }) .Case([](Type) { return kEmplaceValue; }); - rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), operands); + rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), + adaptor.getOperands()); return success(); } @@ -643,14 +646,15 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeSetErrorOp op, ArrayRef operands, + matchAndRewrite(RuntimeSetErrorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) .Case([](Type) { return kSetTokenError; }) .Case([](Type) { return kSetValueError; }); - rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), operands); + rewriter.replaceOpWithNewOp(op, apiFuncName, TypeRange(), + adaptor.getOperands()); return success(); } @@ -667,7 +671,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeIsErrorOp op, ArrayRef operands, + matchAndRewrite(RuntimeIsErrorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) @@ -676,7 +680,7 @@ .Case([](Type) { return kIsValueError; }); rewriter.replaceOpWithNewOp(op, apiFuncName, rewriter.getI1Type(), - operands); + adaptor.getOperands()); return success(); } }; @@ -692,7 +696,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeAwaitOp op, ArrayRef operands, + matchAndRewrite(RuntimeAwaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) @@ -700,7 +704,8 @@ .Case([](Type) { return kAwaitValue; }) .Case([](Type) { return kAwaitGroup; }); - rewriter.create(op->getLoc(), apiFuncName, TypeRange(), operands); + rewriter.create(op->getLoc(), apiFuncName, TypeRange(), + adaptor.getOperands()); rewriter.eraseOp(op); return success(); @@ -719,7 +724,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeAwaitAndResumeOp op, ArrayRef operands, + matchAndRewrite(RuntimeAwaitAndResumeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { StringRef apiFuncName = TypeSwitch(op.operand().getType()) @@ -727,8 +732,8 @@ .Case([](Type) { return kAwaitValueAndExecute; }) .Case([](Type) { return kAwaitAllAndExecute; }); - Value operand = RuntimeAwaitAndResumeOpAdaptor(operands).operand(); - Value handle = RuntimeAwaitAndResumeOpAdaptor(operands).handle(); + Value operand = adaptor.operand(); + Value handle = adaptor.handle(); // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); @@ -755,7 +760,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeResumeOp op, ArrayRef operands, + matchAndRewrite(RuntimeResumeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // A pointer to coroutine resume intrinsic wrapper. addResumeFunction(op->getParentOfType()); @@ -764,7 +769,7 @@ op->getLoc(), LLVM::LLVMPointerType::get(resumeFnTy), kResume); // Call async runtime API to execute a coroutine in the managed thread. - auto coroHdl = RuntimeResumeOpAdaptor(operands).handle(); + auto coroHdl = adaptor.handle(); rewriter.replaceOpWithNewOp(op, TypeRange(), kExecute, ValueRange({coroHdl, resumePtr.res()})); @@ -783,13 +788,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeStoreOp op, ArrayRef operands, + matchAndRewrite(RuntimeStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); // Get a pointer to the async value storage from the runtime. auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); - auto storage = RuntimeStoreOpAdaptor(operands).storage(); + auto storage = adaptor.storage(); auto storagePtr = rewriter.create(loc, kGetValueStorage, TypeRange(i8Ptr), storage); @@ -805,7 +810,7 @@ storagePtr.getResult(0)); // Store the yielded value into the async value storage. - auto value = RuntimeStoreOpAdaptor(operands).value(); + auto value = adaptor.value(); rewriter.create(loc, value, castedStoragePtr.getResult()); // Erase the original runtime store operation. @@ -826,13 +831,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeLoadOp op, ArrayRef operands, + matchAndRewrite(RuntimeLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); // Get a pointer to the async value storage from the runtime. auto i8Ptr = AsyncAPI::opaquePointerType(rewriter.getContext()); - auto storage = RuntimeLoadOpAdaptor(operands).storage(); + auto storage = adaptor.storage(); auto storagePtr = rewriter.create(loc, kGetValueStorage, TypeRange(i8Ptr), storage); @@ -866,15 +871,15 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(RuntimeAddToGroupOp op, ArrayRef operands, + matchAndRewrite(RuntimeAddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Currently we can only add tokens to the group. if (!op.operand().getType().isa()) return rewriter.notifyMatchFailure(op, "only token type is supported"); // Replace with a runtime API function call. - rewriter.replaceOpWithNewOp(op, kAddTokenToGroup, - rewriter.getI64Type(), operands); + rewriter.replaceOpWithNewOp( + op, kAddTokenToGroup, rewriter.getI64Type(), adaptor.getOperands()); return success(); } @@ -896,13 +901,13 @@ apiFunctionName(apiFunctionName) {} LogicalResult - matchAndRewrite(RefCountingOp op, ArrayRef operands, + matchAndRewrite(RefCountingOp op, typename RefCountingOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto count = rewriter.create(op->getLoc(), rewriter.getI32Type(), rewriter.getI32IntegerAttr(op.count())); - auto operand = typename RefCountingOp::Adaptor(operands).operand(); + auto operand = adaptor.operand(); rewriter.replaceOpWithNewOp(op, TypeRange(), apiFunctionName, ValueRange({operand, count})); @@ -937,9 +942,9 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -1032,7 +1037,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ExecuteOp op, ArrayRef operands, + matchAndRewrite(ExecuteOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { ExecuteOp newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); @@ -1040,7 +1045,7 @@ newOp.getRegion().end()); // Set operands and update block argument and result types. - newOp->setOperands(operands); + newOp->setOperands(adaptor.getOperands()); if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *typeConverter))) return failure(); for (auto result : newOp.getResults()) @@ -1056,9 +1061,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AwaitOp op, ArrayRef operands, + matchAndRewrite(AwaitOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands.front()); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands().front()); return success(); } }; @@ -1068,9 +1073,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(async::YieldOp op, ArrayRef operands, + matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -26,16 +26,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::AbsOp op, ArrayRef operands, + matchAndRewrite(complex::AbsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::AbsOp::Adaptor transformed(operands); auto loc = op.getLoc(); auto type = op.getType(); - Value real = - rewriter.create(loc, type, transformed.complex()); - Value imag = - rewriter.create(loc, type, transformed.complex()); + Value real = rewriter.create(loc, type, adaptor.complex()); + Value imag = rewriter.create(loc, type, adaptor.complex()); Value realSqr = rewriter.create(loc, real, real); Value imagSqr = rewriter.create(loc, imag, imag); Value sqNorm = rewriter.create(loc, realSqr, imagSqr); @@ -53,23 +50,16 @@ AndOp, OrOp>; LogicalResult - matchAndRewrite(ComparisonOp op, ArrayRef operands, + matchAndRewrite(ComparisonOp op, typename ComparisonOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename ComparisonOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.lhs() - .getType() - .template cast() - .getElementType(); - - Value realLhs = - rewriter.create(loc, type, transformed.lhs()); - Value imagLhs = - rewriter.create(loc, type, transformed.lhs()); - Value realRhs = - rewriter.create(loc, type, transformed.rhs()); - Value imagRhs = - rewriter.create(loc, type, transformed.rhs()); + auto type = + adaptor.lhs().getType().template cast().getElementType(); + + Value realLhs = rewriter.create(loc, type, adaptor.lhs()); + Value imagLhs = rewriter.create(loc, type, adaptor.lhs()); + Value realRhs = rewriter.create(loc, type, adaptor.rhs()); + Value imagRhs = rewriter.create(loc, type, adaptor.rhs()); Value realComparison = rewriter.create(loc, p, realLhs, realRhs); Value imagComparison = rewriter.create(loc, p, imagLhs, imagRhs); @@ -87,19 +77,18 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(BinaryComplexOp op, ArrayRef operands, + matchAndRewrite(BinaryComplexOp op, typename BinaryComplexOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename BinaryComplexOp::Adaptor transformed(operands); - auto type = transformed.lhs().getType().template cast(); + auto type = adaptor.lhs().getType().template cast(); auto elementType = type.getElementType().template cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value realLhs = b.create(elementType, transformed.lhs()); - Value realRhs = b.create(elementType, transformed.rhs()); + Value realLhs = b.create(elementType, adaptor.lhs()); + Value realRhs = b.create(elementType, adaptor.rhs()); Value resultReal = b.create(elementType, realLhs, realRhs); - Value imagLhs = b.create(elementType, transformed.lhs()); - Value imagRhs = b.create(elementType, transformed.rhs()); + Value imagLhs = b.create(elementType, adaptor.lhs()); + Value imagRhs = b.create(elementType, adaptor.rhs()); Value resultImag = b.create(elementType, imagLhs, imagRhs); rewriter.replaceOpWithNewOp(op, type, resultReal, @@ -112,21 +101,20 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::DivOp op, ArrayRef operands, + matchAndRewrite(complex::DivOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::DivOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.lhs().getType().cast(); + auto type = adaptor.lhs().getType().cast(); auto elementType = type.getElementType().cast(); Value lhsReal = - rewriter.create(loc, elementType, transformed.lhs()); + rewriter.create(loc, elementType, adaptor.lhs()); Value lhsImag = - rewriter.create(loc, elementType, transformed.lhs()); + rewriter.create(loc, elementType, adaptor.lhs()); Value rhsReal = - rewriter.create(loc, elementType, transformed.rhs()); + rewriter.create(loc, elementType, adaptor.rhs()); Value rhsImag = - rewriter.create(loc, elementType, transformed.rhs()); + rewriter.create(loc, elementType, adaptor.rhs()); // Smith's algorithm to divide complex numbers. It is just a bit smarter // way to compute the following formula: @@ -321,17 +309,16 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::ExpOp op, ArrayRef operands, + matchAndRewrite(complex::ExpOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::ExpOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); Value real = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value imag = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value expReal = rewriter.create(loc, real); Value cosImag = rewriter.create(loc, imag); Value resultReal = rewriter.create(loc, expReal, cosImag); @@ -348,17 +335,16 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::LogOp op, ArrayRef operands, + matchAndRewrite(complex::LogOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::LogOp::Adaptor transformed(operands); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value abs = b.create(elementType, transformed.complex()); + Value abs = b.create(elementType, adaptor.complex()); Value resultReal = b.create(elementType, abs); - Value real = b.create(elementType, transformed.complex()); - Value imag = b.create(elementType, transformed.complex()); + Value real = b.create(elementType, adaptor.complex()); + Value imag = b.create(elementType, adaptor.complex()); Value resultImag = b.create(elementType, imag, real); rewriter.replaceOpWithNewOp(op, type, resultReal, resultImag); @@ -370,15 +356,14 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::Log1pOp op, ArrayRef operands, + matchAndRewrite(complex::Log1pOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::Log1pOp::Adaptor transformed(operands); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(elementType, transformed.complex()); - Value imag = b.create(elementType, transformed.complex()); + Value real = b.create(elementType, adaptor.complex()); + Value imag = b.create(elementType, adaptor.complex()); Value one = b.create(elementType, b.getFloatAttr(elementType, 1)); Value realPlusOne = b.create(real, one); @@ -392,20 +377,19 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::MulOp op, ArrayRef operands, + matchAndRewrite(complex::MulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::MulOp::Adaptor transformed(operands); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - auto type = transformed.lhs().getType().cast(); + auto type = adaptor.lhs().getType().cast(); auto elementType = type.getElementType().cast(); - Value lhsReal = b.create(elementType, transformed.lhs()); + Value lhsReal = b.create(elementType, adaptor.lhs()); Value lhsRealAbs = b.create(lhsReal); - Value lhsImag = b.create(elementType, transformed.lhs()); + Value lhsImag = b.create(elementType, adaptor.lhs()); Value lhsImagAbs = b.create(lhsImag); - Value rhsReal = b.create(elementType, transformed.rhs()); + Value rhsReal = b.create(elementType, adaptor.rhs()); Value rhsRealAbs = b.create(rhsReal); - Value rhsImag = b.create(elementType, transformed.rhs()); + Value rhsImag = b.create(elementType, adaptor.rhs()); Value rhsImagAbs = b.create(rhsImag); Value lhsRealTimesRhsReal = b.create(lhsReal, rhsReal); @@ -530,17 +514,16 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::NegOp op, ArrayRef operands, + matchAndRewrite(complex::NegOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::NegOp::Adaptor transformed(operands); auto loc = op.getLoc(); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); Value real = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value imag = - rewriter.create(loc, elementType, transformed.complex()); + rewriter.create(loc, elementType, adaptor.complex()); Value negReal = rewriter.create(loc, real); Value negImag = rewriter.create(loc, imag); rewriter.replaceOpWithNewOp(op, type, negReal, negImag); @@ -552,25 +535,23 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(complex::SignOp op, ArrayRef operands, + matchAndRewrite(complex::SignOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - complex::SignOp::Adaptor transformed(operands); - auto type = transformed.complex().getType().cast(); + auto type = adaptor.complex().getType().cast(); auto elementType = type.getElementType().cast(); mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value real = b.create(elementType, transformed.complex()); - Value imag = b.create(elementType, transformed.complex()); + Value real = b.create(elementType, adaptor.complex()); + Value imag = b.create(elementType, adaptor.complex()); Value zero = b.create(elementType, b.getZeroAttr(elementType)); Value realIsZero = b.create(CmpFPredicate::OEQ, real, zero); Value imagIsZero = b.create(CmpFPredicate::OEQ, imag, zero); Value isZero = b.create(realIsZero, imagIsZero); - auto abs = b.create(elementType, transformed.complex()); + auto abs = b.create(elementType, adaptor.complex()); Value realSign = b.create(real, abs); Value imagSign = b.create(imag, abs); Value sign = b.create(type, realSign, imagSign); - rewriter.replaceOpWithNewOp(op, isZero, transformed.complex(), - sign); + rewriter.replaceOpWithNewOp(op, isZero, adaptor.complex(), sign); return success(); } }; diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp @@ -33,7 +33,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -45,7 +45,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -58,7 +58,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::BlockDimOp op, ArrayRef operands, + matchAndRewrite(gpu::BlockDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -68,7 +68,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::GPUFuncOp funcOp, ArrayRef operands, + matchAndRewrite(gpu::GPUFuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; private: @@ -81,7 +81,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::GPUModuleOp moduleOp, ArrayRef operands, + matchAndRewrite(gpu::GPUModuleOp moduleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -91,7 +91,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef operands, + matchAndRewrite(gpu::ModuleEndOp endOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(endOp); return success(); @@ -105,7 +105,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(gpu::ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(gpu::ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -129,7 +129,7 @@ template LogicalResult LaunchConfigConversion::matchAndRewrite( - SourceOp op, ArrayRef operands, + SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { auto index = getLaunchConfigIndex(op); if (!index) @@ -150,7 +150,7 @@ template LogicalResult SingleDimLaunchConfigConversion::matchAndRewrite( - SourceOp op, ArrayRef operands, + SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const { auto *typeConverter = this->template getTypeConverter(); auto indexType = typeConverter->getIndexType(); @@ -162,7 +162,7 @@ } LogicalResult WorkGroupSizeConversion::matchAndRewrite( - gpu::BlockDimOp op, ArrayRef operands, + gpu::BlockDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto index = getLaunchConfigIndex(op); if (!index) @@ -264,7 +264,7 @@ } LogicalResult GPUFuncOpConversion::matchAndRewrite( - gpu::GPUFuncOp funcOp, ArrayRef operands, + gpu::GPUFuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!gpu::GPUDialect::isKernel(funcOp)) return failure(); @@ -306,7 +306,7 @@ //===----------------------------------------------------------------------===// LogicalResult GPUModuleConversion::matchAndRewrite( - gpu::GPUModuleOp moduleOp, ArrayRef operands, + gpu::GPUModuleOp moduleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { spirv::TargetEnvAttr targetEnv = spirv::lookupTargetEnvOrDefault(moduleOp); spirv::AddressingModel addressingModel = spirv::getAddressingModel(targetEnv); @@ -336,9 +336,9 @@ //===----------------------------------------------------------------------===// LogicalResult GPUReturnOpConversion::matchAndRewrite( - gpu::ReturnOp returnOp, ArrayRef operands, + gpu::ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - if (!operands.empty()) + if (!adaptor.getOperands().empty()) return failure(); rewriter.replaceOpWithNewOp(returnOp); diff --git a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp --- a/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp +++ b/mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp @@ -55,7 +55,7 @@ matchAsPerformingReduction(linalg::GenericOp genericOp); LogicalResult - matchAndRewrite(linalg::GenericOp genericOp, ArrayRef operands, + matchAndRewrite(linalg::GenericOp genericOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -109,7 +109,7 @@ } LogicalResult SingleWorkgroupReduction::matchAndRewrite( - linalg::GenericOp genericOp, ArrayRef operands, + linalg::GenericOp genericOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Operation *op = genericOp.getOperation(); auto originalInputType = op->getOperand(0).getType().cast(); @@ -134,7 +134,8 @@ // TODO: Query the target environment to make sure the current // workload fits in a local workgroup. - Value convertedInput = operands[0], convertedOutput = operands[1]; + Value convertedInput = adaptor.getOperands()[0]; + Value convertedOutput = adaptor.getOperands()[1]; Location loc = genericOp.getLoc(); auto *typeConverter = getTypeConverter(); diff --git a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp --- a/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp +++ b/mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp @@ -37,9 +37,9 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() <= 2); + assert(adaptor.getOperands().size() <= 2); auto dstType = this->getTypeConverter()->convertType(operation.getType()); if (!dstType) return failure(); @@ -48,7 +48,8 @@ return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } - rewriter.template replaceOpWithNewOp(operation, dstType, operands); + rewriter.template replaceOpWithNewOp(operation, dstType, + adaptor.getOperands()); return success(); } }; @@ -62,14 +63,15 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(math::Log1pOp operation, ArrayRef operands, + matchAndRewrite(math::Log1pOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 1); + assert(adaptor.getOperands().size() == 1); Location loc = operation.getLoc(); auto type = this->getTypeConverter()->convertType(operation.operand().getType()); auto one = spirv::ConstantOp::getOne(type, operation.getLoc(), rewriter); - auto onePlus = rewriter.create(loc, one, operands[0]); + auto onePlus = + rewriter.create(loc, one, adaptor.getOperands()[0]); rewriter.replaceOpWithNewOp(operation, type, onePlus); return success(); } diff --git a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp --- a/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp +++ b/mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp @@ -158,7 +158,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::AllocOp operation, ArrayRef operands, + matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -169,7 +169,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::DeallocOp operation, ArrayRef operands, + matchAndRewrite(memref::DeallocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -179,7 +179,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -189,7 +189,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, + matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -199,7 +199,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, + matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -209,7 +209,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::StoreOp storeOp, ArrayRef operands, + matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -220,8 +220,7 @@ //===----------------------------------------------------------------------===// LogicalResult -AllocOpPattern::matchAndRewrite(memref::AllocOp operation, - ArrayRef operands, +AllocOpPattern::matchAndRewrite(memref::AllocOp operation, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType allocType = operation.getType(); if (!isAllocationSupported(allocType)) @@ -260,7 +259,7 @@ LogicalResult DeallocOpPattern::matchAndRewrite(memref::DeallocOp operation, - ArrayRef operands, + OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { MemRefType deallocType = operation.memref().getType().cast(); if (!isAllocationSupported(deallocType)) @@ -274,10 +273,8 @@ //===----------------------------------------------------------------------===// LogicalResult -IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, - ArrayRef operands, +IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::LoadOpAdaptor loadOperands(operands); auto loc = loadOp.getLoc(); auto memrefType = loadOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) @@ -285,8 +282,8 @@ auto &typeConverter = *getTypeConverter(); spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, loadOperands.memref(), - loadOperands.indices(), loc, rewriter); + spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), + adaptor.indices(), loc, rewriter); if (!accessChainOp) return failure(); @@ -372,15 +369,14 @@ } LogicalResult -LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, ArrayRef operands, +LoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::LoadOpAdaptor loadOperands(operands); auto memrefType = loadOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); auto loadPtr = spirv::getElementPtr( - *getTypeConverter(), memrefType, - loadOperands.memref(), loadOperands.indices(), loadOp.getLoc(), rewriter); + *getTypeConverter(), memrefType, adaptor.memref(), + adaptor.indices(), loadOp.getLoc(), rewriter); if (!loadPtr) return failure(); @@ -390,10 +386,8 @@ } LogicalResult -IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, - ArrayRef operands, +IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) return failure(); @@ -401,8 +395,8 @@ auto loc = storeOp.getLoc(); auto &typeConverter = *getTypeConverter(); spirv::AccessChainOp accessChainOp = - spirv::getElementPtr(typeConverter, memrefType, storeOperands.memref(), - storeOperands.indices(), loc, rewriter); + spirv::getElementPtr(typeConverter, memrefType, adaptor.memref(), + adaptor.indices(), loc, rewriter); if (!accessChainOp) return failure(); @@ -427,7 +421,7 @@ assert(dstBits % srcBits == 0); if (srcBits == dstBits) { - Value storeVal = storeOperands.value(); + Value storeVal = adaptor.value(); if (isBool) storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); rewriter.replaceOpWithNewOp( @@ -458,7 +452,7 @@ rewriter.create(loc, dstType, mask, offset); clearBitsMask = rewriter.create(loc, dstType, clearBitsMask); - Value storeVal = storeOperands.value(); + Value storeVal = adaptor.value(); if (isBool) storeVal = castBoolToIntN(loc, storeVal, dstType, rewriter); storeVal = shiftValue(loc, storeVal, offset, mask, dstBits, rewriter); @@ -487,23 +481,20 @@ } LogicalResult -StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, - ArrayRef operands, +StoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - memref::StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); - auto storePtr = - spirv::getElementPtr(*getTypeConverter(), memrefType, - storeOperands.memref(), storeOperands.indices(), - storeOp.getLoc(), rewriter); + auto storePtr = spirv::getElementPtr( + *getTypeConverter(), memrefType, adaptor.memref(), + adaptor.indices(), storeOp.getLoc(), rewriter); if (!storePtr) return failure(); rewriter.replaceOpWithNewOp(storeOp, storePtr, - storeOperands.value()); + adaptor.value()); return success(); } diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -84,7 +84,7 @@ using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult - matchAndRewrite(scf::ForOp forOp, ArrayRef operands, + matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -95,7 +95,7 @@ using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult - matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, + matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -104,7 +104,7 @@ using SCFToSPIRVPattern::SCFToSPIRVPattern; LogicalResult - matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef operands, + matchAndRewrite(scf::YieldOp terminatorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace @@ -146,14 +146,13 @@ //===----------------------------------------------------------------------===// LogicalResult -ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef operands, +ForOpConversion::matchAndRewrite(scf::ForOp forOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // scf::ForOp can be lowered to the structured control flow represented by // spirv::LoopOp by making the continue block of the spirv::LoopOp the loop // latch and the merge block the exit block. The resulting spirv::LoopOp has a // single back edge from the continue to header block, and a single exit from // header to merge. - scf::ForOpAdaptor forOperands(operands); auto loc = forOp.getLoc(); auto loopOp = rewriter.create(loc, spirv::LoopControl::None); loopOp.addEntryAndMergeBlock(); @@ -165,9 +164,8 @@ loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header); // Create the new induction variable to use. - BlockArgument newIndVar = - header->addArgument(forOperands.lowerBound().getType()); - for (Value arg : forOperands.initArgs()) + BlockArgument newIndVar = header->addArgument(adaptor.lowerBound().getType()); + for (Value arg : adaptor.initArgs()) header->addArgument(arg.getType()); Block *body = forOp.getBody(); @@ -187,8 +185,8 @@ rewriter.inlineRegionBefore(forOp->getRegion(0), loopOp.body(), std::next(loopOp.body().begin(), 2)); - SmallVector args(1, forOperands.lowerBound()); - args.append(forOperands.initArgs().begin(), forOperands.initArgs().end()); + SmallVector args(1, adaptor.lowerBound()); + args.append(adaptor.initArgs().begin(), adaptor.initArgs().end()); // Branch into it from the entry. rewriter.setInsertionPointToEnd(&(loopOp.body().front())); rewriter.create(loc, header, args); @@ -197,7 +195,7 @@ rewriter.setInsertionPointToEnd(header); auto *mergeBlock = loopOp.getMergeBlock(); auto cmpOp = rewriter.create( - loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound()); + loc, rewriter.getI1Type(), newIndVar, adaptor.upperBound()); rewriter.create( loc, cmpOp, body, ArrayRef(), mergeBlock, ArrayRef()); @@ -209,7 +207,7 @@ // Add the step to the induction variable and branch to the header. Value updatedIndVar = rewriter.create( - loc, newIndVar.getType(), newIndVar, forOperands.step()); + loc, newIndVar.getType(), newIndVar, adaptor.step()); rewriter.create(loc, header, updatedIndVar); // Infer the return types from the init operands. Vector type may get @@ -217,7 +215,7 @@ // extra logic to figure out the right type we just infer it from the Init // operands. SmallVector initTypes; - for (auto arg : forOperands.initArgs()) + for (auto arg : adaptor.initArgs()) initTypes.push_back(arg.getType()); replaceSCFOutputValue(forOp, loopOp, rewriter, scfToSPIRVContext, initTypes); return success(); @@ -228,12 +226,11 @@ //===----------------------------------------------------------------------===// LogicalResult -IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef operands, +IfOpConversion::matchAndRewrite(scf::IfOp ifOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // When lowering `scf::IfOp` we explicitly create a selection header block // before the control flow diverges and a merge block where control flow // subsequently converges. - scf::IfOpAdaptor ifOperands(operands); auto loc = ifOp.getLoc(); // Create `spv.selection` operation, selection header block and merge block. @@ -267,7 +264,7 @@ // Create a `spv.BranchConditional` operation for selection header block. rewriter.setInsertionPointToEnd(selectionHeaderBlock); - rewriter.create(loc, ifOperands.condition(), + rewriter.create(loc, adaptor.condition(), thenBlock, ArrayRef(), elseBlock, ArrayRef()); @@ -289,8 +286,10 @@ /// parent region. For loops we also need to update the branch looping back to /// the header with the loop carried values. LogicalResult TerminatorOpConversion::matchAndRewrite( - scf::YieldOp terminatorOp, ArrayRef operands, + scf::YieldOp terminatorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + ValueRange operands = adaptor.getOperands(); + // If the region is return values, store each value into the associated // VariableOp created during lowering of the parent region. if (!operands.empty()) { diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -302,7 +302,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::AccessChainOp op, ArrayRef operands, + matchAndRewrite(spirv::AccessChainOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.component_ptr().getType()); if (!dstType) @@ -327,7 +327,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::AddressOfOp op, ArrayRef operands, + matchAndRewrite(spirv::AddressOfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(op.pointer().getType()); if (!dstType) @@ -343,7 +343,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BitFieldInsertOp op, ArrayRef operands, + matchAndRewrite(spirv::BitFieldInsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -387,7 +387,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ConstantOp constOp, ArrayRef operands, + matchAndRewrite(spirv::ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = constOp.getType(); if (!srcType.isa() && !srcType.isIntOrFloat()) @@ -419,8 +419,8 @@ rewriter.replaceOpWithNewOp(constOp, dstType, dstAttr); return success(); } - rewriter.replaceOpWithNewOp(constOp, dstType, operands, - constOp->getAttrs()); + rewriter.replaceOpWithNewOp( + constOp, dstType, adaptor.getOperands(), constOp->getAttrs()); return success(); } }; @@ -431,7 +431,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BitFieldSExtractOp op, ArrayRef operands, + matchAndRewrite(spirv::BitFieldSExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -484,7 +484,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BitFieldUExtractOp op, ArrayRef operands, + matchAndRewrite(spirv::BitFieldUExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -518,9 +518,9 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BranchOp branchOp, ArrayRef operands, + matchAndRewrite(spirv::BranchOp branchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(branchOp, operands, + rewriter.replaceOpWithNewOp(branchOp, adaptor.getOperands(), branchOp.getTarget()); return success(); } @@ -533,7 +533,7 @@ spirv::BranchConditionalOp>::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::BranchConditionalOp op, ArrayRef operands, + matchAndRewrite(spirv::BranchConditionalOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // If branch weights exist, map them to 32-bit integer vector. ElementsAttr branchWeights = nullptr; @@ -560,7 +560,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::CompositeExtractOp op, ArrayRef operands, + matchAndRewrite(spirv::CompositeExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) @@ -590,7 +590,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::CompositeInsertOp op, ArrayRef operands, + matchAndRewrite(spirv::CompositeInsertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(op.getType()); if (!dstType) @@ -619,13 +619,13 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); if (!dstType) return failure(); - rewriter.template replaceOpWithNewOp(operation, dstType, operands, - operation->getAttrs()); + rewriter.template replaceOpWithNewOp( + operation, dstType, adaptor.getOperands(), operation->getAttrs()); return success(); } }; @@ -638,7 +638,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ExecutionModeOp op, ArrayRef operands, + matchAndRewrite(spirv::ExecutionModeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // First, create the global struct's name that would be associated with // this entry point's execution mode. We set it to be: @@ -717,7 +717,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GlobalVariableOp op, ArrayRef operands, + matchAndRewrite(spirv::GlobalVariableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Currently, there is no support of initialization with a constant value in // SPIR-V dialect. Specialization constants are not considered as well. @@ -767,7 +767,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type fromType = operation.operand().getType(); @@ -779,12 +779,12 @@ if (getBitWidth(fromType) < getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); return success(); } if (getBitWidth(fromType) > getBitWidth(toType)) { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); return success(); } return failure(); @@ -797,18 +797,18 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::FunctionCallOp callOp, ArrayRef operands, + matchAndRewrite(spirv::FunctionCallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (callOp.getNumResults() == 0) { - rewriter.replaceOpWithNewOp(callOp, llvm::None, operands, - callOp->getAttrs()); + rewriter.replaceOpWithNewOp( + callOp, llvm::None, adaptor.getOperands(), callOp->getAttrs()); return success(); } // Function returns a single result. auto dstType = typeConverter.convertType(callOp.getType(0)); - rewriter.replaceOpWithNewOp(callOp, dstType, operands, - callOp->getAttrs()); + rewriter.replaceOpWithNewOp( + callOp, dstType, adaptor.getOperands(), callOp->getAttrs()); return success(); } }; @@ -820,7 +820,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); @@ -841,7 +841,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); @@ -861,7 +861,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GLSLInverseSqrtOp op, ArrayRef operands, + matchAndRewrite(spirv::GLSLInverseSqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = op.getType(); auto dstType = typeConverter.convertType(srcType); @@ -877,15 +877,14 @@ }; /// Converts `spv.Load` and `spv.Store` to LLVM dialect. -template -class LoadStorePattern : public SPIRVToLLVMConversion { +template +class LoadStorePattern : public SPIRVToLLVMConversion { public: - using SPIRVToLLVMConversion::SPIRVToLLVMConversion; + using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVop op, ArrayRef operands, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (!op.memory_access().hasValue()) { return replaceWithLoadOrStore( op, rewriter, this->typeConverter, /*alignment=*/0, @@ -918,9 +917,8 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp notOp, ArrayRef operands, + matchAndRewrite(SPIRVOp notOp, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = notOp.getType(); auto dstType = this->typeConverter.convertType(srcType); if (!dstType) @@ -947,7 +945,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp op, ArrayRef operands, + matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.eraseOp(op); return success(); @@ -959,7 +957,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(spirv::ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnOp, ArrayRef(), ArrayRef()); @@ -972,10 +970,10 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ReturnValueOp returnValueOp, ArrayRef operands, + matchAndRewrite(spirv::ReturnValueOp returnValueOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(returnValueOp, ArrayRef(), - operands); + adaptor.getOperands()); return success(); } }; @@ -1033,7 +1031,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::LoopOp loopOp, ArrayRef operands, + matchAndRewrite(spirv::LoopOp loopOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // There is no support of loop control at the moment. if (loopOp.loop_control() != spirv::LoopControl::None) @@ -1080,7 +1078,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::SelectionOp op, ArrayRef operands, + matchAndRewrite(spirv::SelectionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // There is no support for `Flatten` or `DontFlatten` selection control at // the moment. This are just compiler hints and can be performed during the @@ -1149,7 +1147,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(SPIRVOp operation, ArrayRef operands, + matchAndRewrite(SPIRVOp operation, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->typeConverter.convertType(operation.getType()); @@ -1161,7 +1159,7 @@ if (op1Type == op2Type) { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); return success(); } @@ -1186,7 +1184,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GLSLTanOp tanOp, ArrayRef operands, + matchAndRewrite(spirv::GLSLTanOp tanOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = typeConverter.convertType(tanOp.getType()); if (!dstType) @@ -1211,7 +1209,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::GLSLTanhOp tanhOp, ArrayRef operands, + matchAndRewrite(spirv::GLSLTanhOp tanhOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = tanhOp.getType(); auto dstType = typeConverter.convertType(srcType); @@ -1239,7 +1237,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::VariableOp varOp, ArrayRef operands, + matchAndRewrite(spirv::VariableOp varOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcType = varOp.getType(); // Initialization is supported for scalars and vectors only. @@ -1274,7 +1272,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, + matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Convert function signature. At the moment LLVMType converter is enough @@ -1337,7 +1335,7 @@ using SPIRVToLLVMConversion::SPIRVToLLVMConversion; LogicalResult - matchAndRewrite(spirv::ModuleOp spvModuleOp, ArrayRef operands, + matchAndRewrite(spirv::ModuleOp spvModuleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newModuleOp = diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -29,19 +29,17 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AnyOp op, ArrayRef operands, + matchAndRewrite(AnyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -AnyOpConversion::matchAndRewrite(AnyOp op, ArrayRef operands, +AnyOpConversion::matchAndRewrite(AnyOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - AnyOp::Adaptor transformed(operands); - // Replace `any` with its first operand. // Any operand would be a valid substitution. - rewriter.replaceOp(op, {transformed.inputs().front()}); + rewriter.replaceOp(op, {adaptor.inputs().front()}); return success(); } @@ -52,16 +50,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SrcOpTy op, ArrayRef operands, + matchAndRewrite(SrcOpTy op, typename SrcOpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename SrcOpTy::Adaptor transformed(operands); - // For now, only error-free types are supported by this lowering. if (op.getType().template isa()) return failure(); - rewriter.replaceOpWithNewOp(op, transformed.lhs(), - transformed.rhs()); + rewriter.replaceOpWithNewOp(op, adaptor.lhs(), adaptor.rhs()); return success(); } }; @@ -72,7 +67,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(BroadcastOp op, ArrayRef operands, + matchAndRewrite(BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -120,7 +115,7 @@ } // namespace LogicalResult BroadcastOpConverter::matchAndRewrite( - BroadcastOp op, ArrayRef operands, + BroadcastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands, not // on shapes. @@ -129,7 +124,6 @@ auto loc = op.getLoc(); ImplicitLocOpBuilder lb(loc, rewriter); - BroadcastOp::Adaptor transformed(operands); Value zero = lb.create(0); Type indexTy = lb.getIndexType(); @@ -138,7 +132,7 @@ // representing the shape extents, the rank is the extent of the only // dimension in the tensor. SmallVector ranks, rankDiffs; - llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { + llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { return lb.create(v, zero); })); @@ -157,9 +151,8 @@ Value replacement = lb.create( getExtentTensorType(lb.getContext()), ValueRange{maxRank}, [&](OpBuilder &b, Location loc, ValueRange args) { - Value broadcastedDim = - getBroadcastedDim(ImplicitLocOpBuilder(loc, b), - transformed.shapes(), rankDiffs, args[0]); + Value broadcastedDim = getBroadcastedDim( + ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, args[0]); b.create(loc, broadcastedDim); }); @@ -175,13 +168,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstShapeOp op, ArrayRef operands, + matchAndRewrite(ConstShapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult ConstShapeOpConverter::matchAndRewrite( - ConstShapeOp op, ArrayRef operands, + ConstShapeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering supports only extent tensors, not `shape.shape` @@ -209,13 +202,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstSizeOp op, ArrayRef operands, + matchAndRewrite(ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult ConstSizeOpConversion::matchAndRewrite( - ConstSizeOp op, ArrayRef operands, + ConstSizeOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { rewriter.replaceOpWithNewOp(op, op.value().getSExtValue()); return success(); @@ -227,17 +220,16 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(IsBroadcastableOp op, ArrayRef operands, + matchAndRewrite(IsBroadcastableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult IsBroadcastableOpConverter::matchAndRewrite( - IsBroadcastableOp op, ArrayRef operands, + IsBroadcastableOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands, not // on shapes. - IsBroadcastableOp::Adaptor transformed(operands); if (!llvm::all_of(op.shapes(), [](Value v) { return !v.getType().isa(); })) return failure(); @@ -252,7 +244,7 @@ // representing the shape extents, the rank is the extent of the only // dimension in the tensor. SmallVector ranks, rankDiffs; - llvm::append_range(ranks, llvm::map_range(transformed.shapes(), [&](Value v) { + llvm::append_range(ranks, llvm::map_range(adaptor.shapes(), [&](Value v) { return lb.create(v, zero); })); @@ -279,10 +271,10 @@ // could reuse the Broadcast lowering entirely, but we redo the work // here to make optimizations easier between the two loops. Value broadcastedDim = getBroadcastedDim( - ImplicitLocOpBuilder(loc, b), transformed.shapes(), rankDiffs, iv); + ImplicitLocOpBuilder(loc, b), adaptor.shapes(), rankDiffs, iv); Value broadcastable = iterArgs[0]; - for (auto tup : llvm::zip(transformed.shapes(), rankDiffs)) { + for (auto tup : llvm::zip(adaptor.shapes(), rankDiffs)) { Value shape, rankDiff; std::tie(shape, rankDiff) = tup; Value outOfBounds = @@ -327,16 +319,14 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(GetExtentOp op, ArrayRef operands, + matchAndRewrite(GetExtentOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult GetExtentOpConverter::matchAndRewrite( - GetExtentOp op, ArrayRef operands, + GetExtentOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - GetExtentOp::Adaptor transformed(operands); - // For now, only error-free types are supported by this lowering. if (op.getType().isa()) return failure(); @@ -346,14 +336,13 @@ if (auto shapeOfOp = op.shape().getDefiningOp()) { if (shapeOfOp.arg().getType().isa()) { rewriter.replaceOpWithNewOp(op, shapeOfOp.arg(), - transformed.dim()); + adaptor.dim()); return success(); } } - rewriter.replaceOpWithNewOp(op, rewriter.getIndexType(), - transformed.shape(), - ValueRange{transformed.dim()}); + rewriter.replaceOpWithNewOp( + op, rewriter.getIndexType(), adaptor.shape(), ValueRange{adaptor.dim()}); return success(); } @@ -363,20 +352,19 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(shape::RankOp op, ArrayRef operands, + matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -RankOpConverter::matchAndRewrite(shape::RankOp op, ArrayRef operands, +RankOpConverter::matchAndRewrite(shape::RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering supports only error-free types. if (op.getType().isa()) return failure(); - shape::RankOp::Adaptor transformed(operands); - rewriter.replaceOpWithNewOp(op, transformed.shape(), 0); + rewriter.replaceOpWithNewOp(op, adaptor.shape(), 0); return success(); } @@ -387,32 +375,30 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(shape::ReduceOp op, ArrayRef operands, + matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final; }; } // namespace LogicalResult -ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef operands, +ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, this lowering is only defined on `tensor` operands. if (op.shape().getType().isa()) return failure(); auto loc = op.getLoc(); - shape::ReduceOp::Adaptor transformed(operands); Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); Type indexTy = rewriter.getIndexType(); Value rank = - rewriter.create(loc, indexTy, transformed.shape(), zero); + rewriter.create(loc, indexTy, adaptor.shape(), zero); auto loop = rewriter.create( loc, zero, rank, one, op.initVals(), [&](OpBuilder &b, Location loc, Value iv, ValueRange args) { - Value extent = - b.create(loc, transformed.shape(), iv); + Value extent = b.create(loc, adaptor.shape(), iv); SmallVector mappedValues{iv, extent}; mappedValues.append(args.begin(), args.end()); @@ -468,13 +454,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ShapeEqOp op, ArrayRef operands, + matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, ArrayRef operands, +ShapeEqOpConverter::matchAndRewrite(ShapeEqOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!llvm::all_of(op.shapes(), [](Value v) { return !v.getType().isa(); })) @@ -487,16 +473,15 @@ return success(); } - ShapeEqOp::Adaptor transformed(operands); auto loc = op.getLoc(); Type indexTy = rewriter.getIndexType(); Value zero = rewriter.create(loc, 0); - Value firstShape = transformed.shapes().front(); + Value firstShape = adaptor.shapes().front(); Value firstRank = rewriter.create(loc, indexTy, firstShape, zero); Value result = nullptr; // Generate a linear sequence of compares, all with firstShape as lhs. - for (Value shape : transformed.shapes().drop_front(1)) { + for (Value shape : adaptor.shapes().drop_front(1)) { Value rank = rewriter.create(loc, indexTy, shape, zero); Value eqRank = rewriter.create(loc, CmpIPredicate::eq, firstRank, rank); @@ -536,13 +521,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ShapeOfOp op, ArrayRef operands, + matchAndRewrite(ShapeOfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult ShapeOfOpConversion::matchAndRewrite( - ShapeOfOp op, ArrayRef operands, + ShapeOfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // For now, only error-free types are supported by this lowering. @@ -551,8 +536,7 @@ // For ranked tensor arguments, lower to `tensor.from_elements`. auto loc = op.getLoc(); - ShapeOfOp::Adaptor transformed(operands); - Value tensor = transformed.arg(); + Value tensor = adaptor.arg(); Type tensorTy = tensor.getType(); if (tensorTy.isa()) { @@ -599,13 +583,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SplitAtOp op, ArrayRef operands, + matchAndRewrite(SplitAtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult SplitAtOpConversion::matchAndRewrite( - SplitAtOp op, ArrayRef operands, + SplitAtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { // Error conditions are not implemented, only lower if all operands and // results are extent tensors. @@ -613,13 +597,12 @@ [](Value v) { return v.getType().isa(); })) return failure(); - SplitAtOp::Adaptor transformed(op); ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value zero = b.create(0); - Value rank = b.create(transformed.operand(), zero); + Value rank = b.create(adaptor.operand(), zero); // index < 0 ? index + rank : index - Value originalIndex = transformed.index(); + Value originalIndex = adaptor.index(); Value add = b.create(originalIndex, rank); Value indexIsNegative = b.create(CmpIPredicate::slt, originalIndex, zero); @@ -627,10 +610,10 @@ Value one = b.create(1); Value head = - b.create(transformed.operand(), zero, index, one); + b.create(adaptor.operand(), zero, index, one); Value tailSize = b.create(rank, index); - Value tail = b.create(transformed.operand(), index, - tailSize, one); + Value tail = + b.create(adaptor.operand(), index, tailSize, one); rewriter.replaceOp(op, {head, tail}); return success(); } @@ -642,10 +625,8 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToExtentTensorOp op, ArrayRef operands, + matchAndRewrite(ToExtentTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - ToExtentTensorOpAdaptor adaptor(operands); - if (!adaptor.input().getType().isa()) return rewriter.notifyMatchFailure(op, "input needs to be a tensor"); diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -292,7 +292,7 @@ : FuncOpConversionBase(converter) {} LogicalResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); if (!newFuncOp) @@ -319,7 +319,7 @@ using FuncOpConversionBase::FuncOpConversionBase; LogicalResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO: bare ptr conversion could be handled by argument materialization @@ -442,10 +442,9 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(AssertOp op, ArrayRef operands, + matchAndRewrite(AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); - AssertOp::Adaptor transformed(operands); // Insert the `abort` declaration if necessary. auto module = op->getParentOfType(); @@ -471,7 +470,7 @@ // Generate assertion test. rewriter.setInsertionPointToEnd(opBlock); rewriter.replaceOpWithNewOp( - op, transformed.arg(), continuationBlock, failureBlock); + op, adaptor.arg(), continuationBlock, failureBlock); return success(); } @@ -481,7 +480,7 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ConstantOp op, ArrayRef operands, + matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // If constant refers to a function, convert it to "addressof". if (auto symbolRef = op.getValue().dyn_cast()) { @@ -506,8 +505,8 @@ op, "referring to a symbol outside of the current module"); return LLVM::detail::oneToOneRewrite( - op, LLVM::ConstantOp::getOperationName(), operands, *getTypeConverter(), - rewriter); + op, LLVM::ConstantOp::getOperationName(), adaptor.getOperands(), + *getTypeConverter(), rewriter); } }; @@ -520,10 +519,8 @@ using Base = ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(CallOpType callOp, ArrayRef operands, + matchAndRewrite(CallOpType callOp, typename CallOpType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - typename CallOpType::Adaptor transformed(operands); - // Pack the result types into a struct. Type packedResult = nullptr; unsigned numResults = callOp.getNumResults(); @@ -536,8 +533,8 @@ } auto promoted = this->getTypeConverter()->promoteOperands( - callOp.getLoc(), /*opOperands=*/callOp->getOperands(), operands, - rewriter); + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); auto newOp = rewriter.create( callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(), promoted, callOp->getAttrs()); @@ -591,22 +588,21 @@ UnrealizedConversionCastOp>::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(UnrealizedConversionCastOp op, ArrayRef operands, + matchAndRewrite(UnrealizedConversionCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - UnrealizedConversionCastOp::Adaptor transformed(operands); SmallVector convertedTypes; if (succeeded(typeConverter->convertTypes(op.outputs().getTypes(), convertedTypes)) && - convertedTypes == transformed.inputs().getTypes()) { - rewriter.replaceOp(op, transformed.inputs()); + convertedTypes == adaptor.inputs().getTypes()) { + rewriter.replaceOp(op, adaptor.inputs()); return success(); } convertedTypes.clear(); - if (succeeded(typeConverter->convertTypes(transformed.inputs().getTypes(), + if (succeeded(typeConverter->convertTypes(adaptor.inputs().getTypes(), convertedTypes)) && convertedTypes == op.outputs().getType()) { - rewriter.replaceOp(op, transformed.inputs()); + rewriter.replaceOp(op, adaptor.inputs()); return success(); } return failure(); @@ -617,12 +613,12 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(RankOp op, ArrayRef operands, + matchAndRewrite(RankOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); Type operandType = op.memrefOrTensor().getType(); if (auto unrankedMemRefType = operandType.dyn_cast()) { - UnrankedMemRefDescriptor desc(RankOp::Adaptor(operands).memrefOrTensor()); + UnrankedMemRefDescriptor desc(adaptor.memrefOrTensor()); rewriter.replaceOp(op, {desc.rank(rewriter, loc)}); return success(); } @@ -658,10 +654,8 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(IndexCastOp indexCastOp, ArrayRef operands, + matchAndRewrite(IndexCastOp indexCastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - IndexCastOpAdaptor transformed(operands); - auto targetType = typeConverter->convertType(indexCastOp.getResult().getType()); auto targetElementType = @@ -669,18 +663,18 @@ ->convertType(getElementTypeOrSelf(indexCastOp.getResult())) .cast(); auto sourceElementType = - getElementTypeOrSelf(transformed.in()).cast(); + getElementTypeOrSelf(adaptor.in()).cast(); unsigned targetBits = targetElementType.getWidth(); unsigned sourceBits = sourceElementType.getWidth(); if (targetBits == sourceBits) - rewriter.replaceOp(indexCastOp, transformed.in()); + rewriter.replaceOp(indexCastOp, adaptor.in()); else if (targetBits < sourceBits) rewriter.replaceOpWithNewOp(indexCastOp, targetType, - transformed.in()); + adaptor.in()); else rewriter.replaceOpWithNewOp(indexCastOp, targetType, - transformed.in()); + adaptor.in()); return success(); } }; @@ -696,10 +690,9 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(CmpIOp cmpiOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpiOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - CmpIOpAdaptor transformed(operands); - auto operandType = transformed.lhs().getType(); + auto operandType = adaptor.lhs().getType(); auto resultType = cmpiOp.getResult().getType(); // Handle the scalar and 1D vector cases. @@ -707,7 +700,7 @@ rewriter.replaceOpWithNewOp( cmpiOp, typeConverter->convertType(resultType), convertCmpPredicate(cmpiOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); return success(); } @@ -716,13 +709,13 @@ return rewriter.notifyMatchFailure(cmpiOp, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( - cmpiOp.getOperation(), operands, *getTypeConverter(), + cmpiOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { - CmpIOpAdaptor transformed(operands); + CmpIOpAdaptor adaptor(operands); return rewriter.create( cmpiOp.getLoc(), llvm1DVectorTy, convertCmpPredicate(cmpiOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); }, rewriter); @@ -734,10 +727,9 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(CmpFOp cmpfOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpfOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - CmpFOpAdaptor transformed(operands); - auto operandType = transformed.lhs().getType(); + auto operandType = adaptor.lhs().getType(); auto resultType = cmpfOp.getResult().getType(); // Handle the scalar and 1D vector cases. @@ -745,7 +737,7 @@ rewriter.replaceOpWithNewOp( cmpfOp, typeConverter->convertType(resultType), convertCmpPredicate(cmpfOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); return success(); } @@ -754,13 +746,13 @@ return rewriter.notifyMatchFailure(cmpfOp, "expected vector result type"); return LLVM::detail::handleMultidimensionalVectors( - cmpfOp.getOperation(), operands, *getTypeConverter(), + cmpfOp.getOperation(), adaptor.getOperands(), *getTypeConverter(), [&](Type llvm1DVectorTy, ValueRange operands) { - CmpFOpAdaptor transformed(operands); + CmpFOpAdaptor adaptor(operands); return rewriter.create( cmpfOp.getLoc(), llvm1DVectorTy, convertCmpPredicate(cmpfOp.getPredicate()), - transformed.lhs(), transformed.rhs()); + adaptor.lhs(), adaptor.rhs()); }, rewriter); } @@ -774,10 +766,10 @@ using Super = OneToOneLLVMTerminatorLowering; LogicalResult - matchAndRewrite(SourceOp op, ArrayRef operands, + matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands, op->getSuccessors(), - op->getAttrs()); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands(), + op->getSuccessors(), op->getAttrs()); return success(); } }; @@ -792,7 +784,7 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); unsigned numArguments = op.getNumOperands(); @@ -801,7 +793,7 @@ if (getTypeConverter()->getOptions().useBarePtrCallConv) { // For the bare-ptr calling convention, extract the aligned pointer to // be returned from the memref descriptor. - for (auto it : llvm::zip(op->getOperands(), operands)) { + for (auto it : llvm::zip(op->getOperands(), adaptor.getOperands())) { Type oldTy = std::get<0>(it).getType(); Value newOperand = std::get<1>(it); if (oldTy.isa()) { @@ -815,7 +807,7 @@ updatedOperands.push_back(newOperand); } } else { - updatedOperands = llvm::to_vector<4>(operands); + updatedOperands = llvm::to_vector<4>(adaptor.getOperands()); (void)copyUnrankedDescriptors(rewriter, loc, op.getOperands().getTypes(), updatedOperands, /*toDynamic=*/true); @@ -870,14 +862,12 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(SplatOp splatOp, ArrayRef operands, + matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() != 1) return failure(); - SplatOp::Adaptor adaptor(operands); - // First insert it into an undef vector so we can shuffle it. auto vectorType = typeConverter->convertType(splatOp.getType()); Value undef = rewriter.create(splatOp.getLoc(), vectorType); @@ -907,9 +897,8 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(SplatOp splatOp, ArrayRef operands, + matchAndRewrite(SplatOp splatOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - SplatOp::Adaptor adaptor(operands); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() == 1) return failure(); @@ -984,14 +973,13 @@ using Base::Base; LogicalResult - matchAndRewrite(AtomicRMWOp atomicOp, ArrayRef operands, + matchAndRewrite(AtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(match(atomicOp))) return failure(); auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); - AtomicRMWOp::Adaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = @@ -1036,11 +1024,10 @@ using Base::Base; LogicalResult - matchAndRewrite(GenericAtomicRMWOp atomicOp, ArrayRef operands, + matchAndRewrite(GenericAtomicRMWOp atomicOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = atomicOp.getLoc(); - GenericAtomicRMWOp::Adaptor adaptor(operands); Type valueType = typeConverter->convertType(atomicOp.getResult().getType()); // Split the block into initial, loop, and ending parts. diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp @@ -144,9 +144,9 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() <= 2); + assert(adaptor.getOperands().size() <= 2); auto dstType = this->getTypeConverter()->convertType(operation.getType()); if (!dstType) return failure(); @@ -155,7 +155,8 @@ return operation.emitError( "bitwidth emulation is not implemented yet on unsigned op"); } - rewriter.template replaceOpWithNewOp(operation, dstType, operands); + rewriter.template replaceOpWithNewOp(operation, dstType, + adaptor.getOperands()); return success(); } }; @@ -169,7 +170,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SignedRemIOp remOp, ArrayRef operands, + matchAndRewrite(SignedRemIOp remOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -183,19 +184,19 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 2); + assert(adaptor.getOperands().size() == 2); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (!dstType) return failure(); - if (isBoolScalarOrVector(operands.front().getType())) { - rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) { + rewriter.template replaceOpWithNewOp( + operation, dstType, adaptor.getOperands()); } else { - rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + rewriter.template replaceOpWithNewOp( + operation, dstType, adaptor.getOperands()); } return success(); } @@ -208,7 +209,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstantOp constOp, ArrayRef operands, + matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -218,7 +219,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConstantOp constOp, ArrayRef operands, + matchAndRewrite(ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -228,7 +229,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -239,7 +240,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -250,7 +251,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, + matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -260,7 +261,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -270,7 +271,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, + matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -280,7 +281,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp returnOp, ArrayRef operands, + matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -289,7 +290,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SelectOp op, ArrayRef operands, + matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -299,7 +300,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SplatOp op, ArrayRef operands, + matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -310,9 +311,9 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ZeroExtendIOp op, ArrayRef operands, + matchAndRewrite(ZeroExtendIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = operands.front().getType(); + auto srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); @@ -322,7 +323,7 @@ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( - op, dstType, operands.front(), one, zero); + op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; @@ -338,7 +339,7 @@ byteCountThreshold(threshold) {} LogicalResult - matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef operands, + matchAndRewrite(tensor::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { TensorType tensorType = extractOp.tensor().getType().cast(); @@ -351,7 +352,6 @@ "exceeding byte count threshold"); Location loc = extractOp.getLoc(); - tensor::ExtractOp::Adaptor adaptor(operands); int64_t rank = tensorType.getRank(); SmallVector strides(rank, 1); @@ -396,7 +396,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(TruncateIOp op, ArrayRef operands, + matchAndRewrite(TruncateIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = this->getTypeConverter()->convertType(op.getResult().getType()); @@ -404,11 +404,11 @@ return failure(); Location loc = op.getLoc(); - auto srcType = operands.front().getType(); + auto srcType = adaptor.getOperands().front().getType(); // Check if (x & 1) == 1. Value mask = spirv::ConstantOp::getOne(srcType, loc, rewriter); - Value maskedSrc = - rewriter.create(loc, srcType, operands[0], mask); + Value maskedSrc = rewriter.create( + loc, srcType, adaptor.getOperands()[0], mask); Value isOne = rewriter.create(loc, maskedSrc, mask); Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); @@ -425,9 +425,9 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(UIToFPOp op, ArrayRef operands, + matchAndRewrite(UIToFPOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto srcType = operands.front().getType(); + auto srcType = adaptor.getOperands().front().getType(); if (!isBoolScalarOrVector(srcType)) return failure(); @@ -437,7 +437,7 @@ Value zero = spirv::ConstantOp::getZero(dstType, loc, rewriter); Value one = spirv::ConstantOp::getOne(dstType, loc, rewriter); rewriter.template replaceOpWithNewOp( - op, dstType, operands.front(), one, zero); + op, dstType, adaptor.getOperands().front(), one, zero); return success(); } }; @@ -449,10 +449,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(StdOp operation, ArrayRef operands, + matchAndRewrite(StdOp operation, typename StdOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - assert(operands.size() == 1); - auto srcType = operands.front().getType(); + assert(adaptor.getOperands().size() == 1); + auto srcType = adaptor.getOperands().front().getType(); auto dstType = this->getTypeConverter()->convertType(operation.getResult().getType()); if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType)) @@ -460,10 +460,10 @@ if (dstType == srcType) { // Due to type conversion, we are seeing the same source and target type. // Then we can just erase this operation by forwarding its operand. - rewriter.replaceOp(operation, operands.front()); + rewriter.replaceOp(operation, adaptor.getOperands().front()); } else { rewriter.template replaceOpWithNewOp(operation, dstType, - operands); + adaptor.getOperands()); } return success(); } @@ -475,7 +475,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(XOrOp xorOp, ArrayRef operands, + matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -486,7 +486,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(XOrOp xorOp, ArrayRef operands, + matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -497,10 +497,11 @@ //===----------------------------------------------------------------------===// LogicalResult SignedRemIOpPattern::matchAndRewrite( - SignedRemIOp remOp, ArrayRef operands, + SignedRemIOp remOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - Value result = emulateSignedRemainder(remOp.getLoc(), operands[0], - operands[1], operands[0], rewriter); + Value result = emulateSignedRemainder( + remOp.getLoc(), adaptor.getOperands()[0], adaptor.getOperands()[1], + adaptor.getOperands()[0], rewriter); rewriter.replaceOp(remOp, result); return success(); @@ -514,7 +515,7 @@ // so that the tensor case can be moved to TensorToSPIRV conversion. But, // std.constant is for the standard dialect though. LogicalResult ConstantCompositeOpPattern::matchAndRewrite( - ConstantOp constOp, ArrayRef operands, + ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto srcType = constOp.getType().dyn_cast(); if (!srcType) @@ -599,7 +600,7 @@ //===----------------------------------------------------------------------===// LogicalResult ConstantScalarOpPattern::matchAndRewrite( - ConstantOp constOp, ArrayRef operands, + ConstantOp constOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { Type srcType = constOp.getType(); if (!srcType.isIntOrIndexOrFloat()) @@ -653,16 +654,13 @@ //===----------------------------------------------------------------------===// LogicalResult -CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, +CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpFOpAdaptor cmpFOpOperands(operands); - switch (cmpFOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpFOp, cmpFOp.getResult().getType(), \ - cmpFOpOperands.lhs(), \ - cmpFOpOperands.rhs()); \ + adaptor.lhs(), adaptor.rhs()); \ return success(); // Ordered. @@ -689,19 +687,17 @@ } LogicalResult CmpFOpNanKernelPattern::matchAndRewrite( - CmpFOp cmpFOp, ArrayRef operands, + CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpFOpAdaptor cmpFOpOperands(operands); - if (cmpFOp.getPredicate() == CmpFPredicate::ORD) { - rewriter.replaceOpWithNewOp(cmpFOp, cmpFOpOperands.lhs(), - cmpFOpOperands.rhs()); + rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), + adaptor.rhs()); return success(); } if (cmpFOp.getPredicate() == CmpFPredicate::UNO) { - rewriter.replaceOpWithNewOp( - cmpFOp, cmpFOpOperands.lhs(), cmpFOpOperands.rhs()); + rewriter.replaceOpWithNewOp(cmpFOp, adaptor.lhs(), + adaptor.rhs()); return success(); } @@ -709,17 +705,16 @@ } LogicalResult CmpFOpNanNonePattern::matchAndRewrite( - CmpFOp cmpFOp, ArrayRef operands, + CmpFOp cmpFOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (cmpFOp.getPredicate() != CmpFPredicate::ORD && cmpFOp.getPredicate() != CmpFPredicate::UNO) return failure(); - CmpFOpAdaptor cmpFOpOperands(operands); Location loc = cmpFOp.getLoc(); - Value lhsIsNan = rewriter.create(loc, cmpFOpOperands.lhs()); - Value rhsIsNan = rewriter.create(loc, cmpFOpOperands.rhs()); + Value lhsIsNan = rewriter.create(loc, adaptor.lhs()); + Value rhsIsNan = rewriter.create(loc, adaptor.rhs()); Value replace = rewriter.create(loc, lhsIsNan, rhsIsNan); if (cmpFOp.getPredicate() == CmpFPredicate::ORD) @@ -734,10 +729,8 @@ //===----------------------------------------------------------------------===// LogicalResult -BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, +BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpIOpAdaptor cmpIOpOperands(operands); - Type operandType = cmpIOp.lhs().getType(); if (!isBoolScalarOrVector(operandType)) return failure(); @@ -746,8 +739,7 @@ #define DISPATCH(cmpPredicate, spirvOp) \ case cmpPredicate: \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - cmpIOpOperands.lhs(), \ - cmpIOpOperands.rhs()); \ + adaptor.lhs(), adaptor.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::LogicalEqualOp); @@ -760,10 +752,8 @@ } LogicalResult -CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, +CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - CmpIOpAdaptor cmpIOpOperands(operands); - Type operandType = cmpIOp.lhs().getType(); if (isBoolScalarOrVector(operandType)) return failure(); @@ -777,8 +767,7 @@ "bitwidth emulation is not implemented yet on unsigned op"); \ } \ rewriter.replaceOpWithNewOp(cmpIOp, cmpIOp.getResult().getType(), \ - cmpIOpOperands.lhs(), \ - cmpIOpOperands.rhs()); \ + adaptor.lhs(), adaptor.rhs()); \ return success(); DISPATCH(CmpIPredicate::eq, spirv::IEqualOp); @@ -802,13 +791,14 @@ //===----------------------------------------------------------------------===// LogicalResult -ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, ArrayRef operands, +ReturnOpPattern::matchAndRewrite(ReturnOp returnOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (returnOp.getNumOperands() > 1) return failure(); if (returnOp.getNumOperands() == 1) { - rewriter.replaceOpWithNewOp(returnOp, operands[0]); + rewriter.replaceOpWithNewOp(returnOp, + adaptor.getOperands()[0]); } else { rewriter.replaceOpWithNewOp(returnOp); } @@ -820,12 +810,10 @@ //===----------------------------------------------------------------------===// LogicalResult -SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef operands, +SelectOpPattern::matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - SelectOpAdaptor selectOperands(operands); - rewriter.replaceOpWithNewOp(op, selectOperands.condition(), - selectOperands.true_value(), - selectOperands.false_value()); + rewriter.replaceOpWithNewOp( + op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); return success(); } @@ -834,12 +822,11 @@ //===----------------------------------------------------------------------===// LogicalResult -SplatPattern::matchAndRewrite(SplatOp op, ArrayRef operands, +SplatPattern::matchAndRewrite(SplatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto dstVecType = op.getType().dyn_cast(); if (!dstVecType || !spirv::CompositeType::isValid(dstVecType)) return failure(); - SplatOp::Adaptor adaptor(operands); SmallVector source(dstVecType.getNumElements(), adaptor.input()); rewriter.replaceOpWithNewOp(op, dstVecType, source); @@ -851,34 +838,35 @@ //===----------------------------------------------------------------------===// LogicalResult -XOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, +XOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - assert(operands.size() == 2); + assert(adaptor.getOperands().size() == 2); - if (isBoolScalarOrVector(operands.front().getType())) + if (isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); auto dstType = getTypeConverter()->convertType(xorOp.getType()); if (!dstType) return failure(); - rewriter.replaceOpWithNewOp(xorOp, dstType, operands); + rewriter.replaceOpWithNewOp(xorOp, dstType, + adaptor.getOperands()); return success(); } LogicalResult -BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, ArrayRef operands, +BoolXOrOpPattern::matchAndRewrite(XOrOp xorOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - assert(operands.size() == 2); + assert(adaptor.getOperands().size() == 2); - if (!isBoolScalarOrVector(operands.front().getType())) + if (!isBoolScalarOrVector(adaptor.getOperands().front().getType())) return failure(); auto dstType = getTypeConverter()->convertType(xorOp.getType()); if (!dstType) return failure(); rewriter.replaceOpWithNewOp(xorOp, dstType, - operands); + adaptor.getOperands()); return success(); } diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -947,7 +947,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::Conv2DOp op, ArrayRef args, + matchAndRewrite(tosa::Conv2DOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); @@ -1111,7 +1111,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::DepthwiseConv2DOp op, ArrayRef args, + matchAndRewrite(tosa::DepthwiseConv2DOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); @@ -1266,7 +1266,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::TransposeConv2DOp op, ArrayRef args, + matchAndRewrite(tosa::TransposeConv2DOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op->getLoc(); Value input = op->getOperand(0); @@ -1336,10 +1336,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::MatMulOp op, ArrayRef args, + matchAndRewrite(tosa::MatMulOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - tosa::MatMulOp::Adaptor adaptor(args); - Location loc = op.getLoc(); auto outputTy = op.getType().cast(); @@ -1377,7 +1375,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::FullyConnectedOp op, ArrayRef args, + matchAndRewrite(tosa::FullyConnectedOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); auto outputTy = op.getType().cast(); @@ -1486,15 +1484,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::ReshapeOp reshape, ArrayRef args, + matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - typename tosa::ReshapeOp::Adaptor operands(args); - - ShapedType operandTy = operands.input1().getType().cast(); + ShapedType operandTy = adaptor.input1().getType().cast(); ShapedType resultTy = reshape.getType().template cast(); if (operandTy == resultTy) { - rewriter.replaceOp(reshape, args[0]); + rewriter.replaceOp(reshape, adaptor.getOperands()[0]); return success(); } @@ -1575,19 +1571,20 @@ auto collapsedTy = RankedTensorType::get({totalElems}, elemTy); Value collapsedOp = rewriter.create( - loc, collapsedTy, args[0], collapsingMap); + loc, collapsedTy, adaptor.getOperands()[0], collapsingMap); rewriter.replaceOpWithNewOp( reshape, resultTy, collapsedOp, expandingMap); return success(); } - if (resultTy.getRank() < args[0].getType().cast().getRank()) + if (resultTy.getRank() < + adaptor.getOperands()[0].getType().cast().getRank()) rewriter.replaceOpWithNewOp( - reshape, resultTy, args[0], reassociationMap); + reshape, resultTy, adaptor.getOperands()[0], reassociationMap); else rewriter.replaceOpWithNewOp( - reshape, resultTy, args[0], reassociationMap); + reshape, resultTy, adaptor.getOperands()[0], reassociationMap); return success(); } @@ -2117,7 +2114,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::ConcatOp op, ArrayRef args, + matchAndRewrite(tosa::ConcatOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto resultType = op.getType().dyn_cast(); if (!resultType || !resultType.hasStaticShape()) { @@ -2136,11 +2133,12 @@ offsets.resize(rank, rewriter.create(loc, 0)); for (int i = 0; i < rank; ++i) { - sizes.push_back(rewriter.create(loc, args[0], i)); + sizes.push_back( + rewriter.create(loc, adaptor.getOperands()[0], i)); } Value resultDimSize = sizes[axis]; - for (auto arg : args.drop_front()) { + for (auto arg : adaptor.getOperands().drop_front()) { auto size = rewriter.create(loc, arg, axisValue); resultDimSize = rewriter.create(loc, resultDimSize, size); } @@ -2154,7 +2152,7 @@ Value result = rewriter.create(loc, zeroVal, init).getResult(0); - for (auto arg : args) { + for (auto arg : adaptor.getOperands()) { sizes[axis] = rewriter.create(loc, arg, axisValue); result = rewriter.create(loc, arg, result, offsets, sizes, strides); @@ -2230,7 +2228,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::TileOp op, ArrayRef args, + matchAndRewrite(tosa::TileOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); auto input = op.input1(); @@ -2488,10 +2486,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tosa::GatherOp op, ArrayRef args, + matchAndRewrite(tosa::GatherOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - auto input = args[0]; - auto indices = args[1]; + auto input = adaptor.getOperands()[0]; + auto indices = adaptor.getOperands()[1]; auto inputTy = input.getType().cast(); auto indicesTy = indices.getType().cast(); diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -36,13 +36,12 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::BitCastOp bitcastOp, ArrayRef operands, + matchAndRewrite(vector::BitCastOp bitcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = getTypeConverter()->convertType(bitcastOp.getType()); if (!dstType) return failure(); - vector::BitCastOp::Adaptor adaptor(operands); if (dstType == adaptor.source().getType()) rewriter.replaceOp(bitcastOp, adaptor.source()); else @@ -58,12 +57,11 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::BroadcastOp broadcastOp, ArrayRef operands, + matchAndRewrite(vector::BroadcastOp broadcastOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (broadcastOp.source().getType().isa() || !spirv::CompositeType::isValid(broadcastOp.getVectorType())) return failure(); - vector::BroadcastOp::Adaptor adaptor(operands); SmallVector source(broadcastOp.getVectorType().getNumElements(), adaptor.source()); rewriter.replaceOpWithNewOp( @@ -77,7 +75,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::ExtractOp extractOp, ArrayRef operands, + matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Only support extracting a scalar value now. VectorType resultVectorType = extractOp.getType().dyn_cast(); @@ -88,7 +86,6 @@ if (!dstType) return failure(); - vector::ExtractOp::Adaptor adaptor(operands); if (adaptor.vector().getType().isa()) { rewriter.replaceOp(extractOp, adaptor.vector()); return success(); @@ -106,8 +103,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::ExtractStridedSliceOp extractOp, - ArrayRef operands, + matchAndRewrite(vector::ExtractStridedSliceOp extractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto dstType = getTypeConverter()->convertType(extractOp.getType()); if (!dstType) @@ -120,7 +116,7 @@ if (stride != 1) return failure(); - Value srcVector = operands.front(); + Value srcVector = adaptor.getOperands().front(); // Extract vector<1xT> case. if (dstType.isa()) { @@ -144,11 +140,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::FMAOp fmaOp, ArrayRef operands, + matchAndRewrite(vector::FMAOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(fmaOp.getVectorType())) return failure(); - vector::FMAOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( fmaOp, fmaOp.getType(), adaptor.lhs(), adaptor.rhs(), adaptor.acc()); return success(); @@ -160,12 +155,11 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::InsertOp insertOp, ArrayRef operands, + matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (insertOp.getSourceType().isa() || !spirv::CompositeType::isValid(insertOp.getDestVectorType())) return failure(); - vector::InsertOp::Adaptor adaptor(operands); int32_t id = getFirstIntValue(insertOp.position()); rewriter.replaceOpWithNewOp( insertOp, adaptor.source(), adaptor.dest(), id); @@ -178,12 +172,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::ExtractElementOp extractElementOp, - ArrayRef operands, + matchAndRewrite(vector::ExtractElementOp extractElementOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(extractElementOp.getVectorType())) return failure(); - vector::ExtractElementOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( extractElementOp, extractElementOp.getType(), adaptor.vector(), extractElementOp.position()); @@ -196,12 +188,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::InsertElementOp insertElementOp, - ArrayRef operands, + matchAndRewrite(vector::InsertElementOp insertElementOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!spirv::CompositeType::isValid(insertElementOp.getDestVectorType())) return failure(); - vector::InsertElementOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( insertElementOp, insertElementOp.getType(), insertElementOp.dest(), adaptor.source(), insertElementOp.position()); @@ -214,11 +204,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::InsertStridedSliceOp insertOp, - ArrayRef operands, + matchAndRewrite(vector::InsertStridedSliceOp insertOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Value srcVector = operands.front(); - Value dstVector = operands.back(); + Value srcVector = adaptor.getOperands().front(); + Value dstVector = adaptor.getOperands().back(); // Insert scalar values not supported yet. if (srcVector.getType().isa() || diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp @@ -84,7 +84,7 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileZeroOp op, ArrayRef operands, + matchAndRewrite(TileZeroOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { VectorType vType = op.getVectorType(); // Determine m x n tile sizes. @@ -102,9 +102,8 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileLoadOp op, ArrayRef operands, + matchAndRewrite(TileLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileLoadOp::Adaptor adaptor(operands); MemRefType mType = op.getMemRefType(); VectorType vType = op.getVectorType(); // Determine m x n tile sizes. @@ -130,9 +129,8 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileStoreOp op, ArrayRef operands, + matchAndRewrite(TileStoreOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileStoreOp::Adaptor adaptor(operands); MemRefType mType = op.getMemRefType(); VectorType vType = op.getVectorType(); // Determine m x n tile sizes. @@ -156,9 +154,8 @@ struct TileMulFConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileMulFOp op, ArrayRef operands, + matchAndRewrite(TileMulFOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileMulFOp::Adaptor adaptor(operands); VectorType aType = op.getLhsVectorType(); VectorType bType = op.getRhsVectorType(); VectorType cType = op.getVectorType(); @@ -179,9 +176,8 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(TileMulIOp op, ArrayRef operands, + matchAndRewrite(TileMulIOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - TileMulIOp::Adaptor adaptor(operands); VectorType aType = op.getLhsVectorType(); VectorType bType = op.getRhsVectorType(); VectorType cType = op.getVectorType(); diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp @@ -46,12 +46,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(OpTy op, ArrayRef operands, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { - if (ValueRange(operands).getTypes() == op->getOperands().getTypes()) + if (adaptor.getOperands().getTypes() == op->getOperands().getTypes()) return rewriter.notifyMatchFailure(op, "operand types already match"); - rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); return success(); } }; @@ -61,9 +62,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); return success(); } }; @@ -118,13 +120,12 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ScalableLoadOp loadOp, ArrayRef operands, + matchAndRewrite(ScalableLoadOp loadOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = loadOp.getMemRefType(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); - ScalableLoadOp::Adaptor transformed(operands); LLVMTypeConverter converter(loadOp.getContext()); auto resultType = loadOp.result().getType(); @@ -138,9 +139,8 @@ converter) .getValue()); } - Value dataPtr = - getStridedElementPtr(loadOp.getLoc(), type, transformed.base(), - transformed.index(), rewriter); + Value dataPtr = getStridedElementPtr(loadOp.getLoc(), type, adaptor.base(), + adaptor.index(), rewriter); Value bitCastedPtr = rewriter.create( loadOp.getLoc(), llvmDataTypePtr, dataPtr); rewriter.replaceOpWithNewOp(loadOp, bitCastedPtr); @@ -155,13 +155,12 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(ScalableStoreOp storeOp, ArrayRef operands, + matchAndRewrite(ScalableStoreOp storeOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = storeOp.getMemRefType(); if (!isConvertibleAndHasIdentityMaps(type)) return failure(); - ScalableStoreOp::Adaptor transformed(operands); LLVMTypeConverter converter(storeOp.getContext()); auto resultType = storeOp.value().getType(); @@ -175,12 +174,11 @@ converter) .getValue()); } - Value dataPtr = - getStridedElementPtr(storeOp.getLoc(), type, transformed.base(), - transformed.index(), rewriter); + Value dataPtr = getStridedElementPtr(storeOp.getLoc(), type, adaptor.base(), + adaptor.index(), rewriter); Value bitCastedPtr = rewriter.create( storeOp.getLoc(), llvmDataTypePtr, dataPtr); - rewriter.replaceOpWithNewOp(storeOp, transformed.value(), + rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), bitCastedPtr); return success(); } diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp --- a/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp +++ b/mlir/lib/Dialect/Async/Transforms/AsyncToAsyncRuntime.cpp @@ -337,10 +337,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(CreateGroupOp op, ArrayRef operands, + matchAndRewrite(CreateGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, GroupType::get(op->getContext()), operands); + op, GroupType::get(op->getContext()), adaptor.getOperands()); return success(); } }; @@ -356,10 +356,10 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AddToGroupOp op, ArrayRef operands, + matchAndRewrite(AddToGroupOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp( - op, rewriter.getIndexType(), operands); + op, rewriter.getIndexType(), adaptor.getOperands()); return success(); } }; @@ -382,7 +382,7 @@ outlinedFunctions(outlinedFunctions) {} LogicalResult - matchAndRewrite(AwaitType op, ArrayRef operands, + matchAndRewrite(AwaitType op, typename AwaitType::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { // We can only await on one the `AwaitableType` (for `await` it can be // a `token` or a `value`, for `await_all` it must be a `group`). @@ -395,7 +395,7 @@ const bool isInCoroutine = outlined != outlinedFunctions.end(); Location loc = op->getLoc(); - Value operand = AwaitAdaptor(operands).operand(); + Value operand = adaptor.operand(); Type i1 = rewriter.getI1Type(); @@ -520,7 +520,7 @@ outlinedFunctions(outlinedFunctions) {} LogicalResult - matchAndRewrite(async::YieldOp op, ArrayRef operands, + matchAndRewrite(async::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if yield operation is inside the async coroutine function. auto func = op->template getParentOfType(); @@ -534,7 +534,7 @@ // Store yielded values into the async values storage and switch async // values state to available. - for (auto tuple : llvm::zip(operands, coro.returnValues)) { + for (auto tuple : llvm::zip(adaptor.getOperands(), coro.returnValues)) { Value yieldValue = std::get<0>(tuple); Value asyncValue = std::get<1>(tuple); rewriter.create(loc, yieldValue, asyncValue); @@ -563,7 +563,7 @@ outlinedFunctions(outlinedFunctions) {} LogicalResult - matchAndRewrite(AssertOp op, ArrayRef operands, + matchAndRewrite(AssertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check if assert operation is inside the async coroutine function. auto func = op->template getParentOfType(); @@ -577,7 +577,7 @@ Block *cont = rewriter.splitBlock(op->getBlock(), Block::iterator(op)); rewriter.setInsertionPointToEnd(cont->getPrevNode()); - rewriter.create(loc, AssertOpAdaptor(operands).arg(), + rewriter.create(loc, adaptor.arg(), /*trueDest=*/cont, /*trueArgs=*/ArrayRef(), /*falseDest=*/setupSetErrorBlock(coro), diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp @@ -104,9 +104,8 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(InitTensorOp op, ArrayRef operands, + matchAndRewrite(InitTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - linalg::InitTensorOpAdaptor adaptor(operands, op->getAttrDictionary()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()).cast(), adaptor.sizes()); @@ -126,9 +125,8 @@ memref::ExpandShapeOp, memref::CollapseShapeOp>; LogicalResult - matchAndRewrite(TensorReshapeOp op, ArrayRef operands, + matchAndRewrite(TensorReshapeOp op, Adaptor adaptor, ConversionPatternRewriter &rewriter) const final { - Adaptor adaptor(operands, op->getAttrDictionary()); rewriter.replaceOpWithNewOp(op, this->getTypeConverter() ->convertType(op.getType()) @@ -145,9 +143,8 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(FillOp op, ArrayRef operands, + matchAndRewrite(FillOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - linalg::FillOpAdaptor adaptor(operands, op->getAttrDictionary()); if (!op.output().getType().isa()) return rewriter.notifyMatchFailure(op, "operand must be of a tensor type"); @@ -208,9 +205,8 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::ExtractSliceOp op, ArrayRef operands, + matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - tensor::ExtractSliceOpAdaptor adaptor(operands, op->getAttrDictionary()); Value sourceMemref = adaptor.source(); assert(sourceMemref.getType().isa()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -60,7 +60,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(GenericOp op, ArrayRef operands, + matchAndRewrite(GenericOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Block *originalBlock = op->getBlock(); @@ -78,7 +78,7 @@ rewriter.replaceOp(op, yieldOp->getOperands()); // No need for these intermediate blocks, merge them into 1. - rewriter.mergeBlocks(opEntryBlock, originalBlock, operands); + rewriter.mergeBlocks(opEntryBlock, originalBlock, adaptor.getOperands()); rewriter.mergeBlocks(newBlock, originalBlock, {}); rewriter.eraseOp(&*Block::iterator(yieldOp)); diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp --- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp @@ -21,7 +21,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ForOp op, ArrayRef operands, + matchAndRewrite(ForOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { SmallVector newResultTypes; for (auto type : op.getResultTypes()) { @@ -63,7 +63,7 @@ } // Change the clone to use the updated operands. We could have cloned with // a BlockAndValueMapping, but this seems a bit more direct. - newOp->setOperands(operands); + newOp->setOperands(adaptor.getOperands()); // Update the result types to the new converted types. for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); @@ -79,7 +79,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(IfOp op, ArrayRef operands, + matchAndRewrite(IfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // TODO: Generalize this to any type conversion, not just 1:1. // @@ -108,7 +108,7 @@ newOp.elseRegion().end()); // Update the operands and types. - newOp->setOperands(operands); + newOp->setOperands(adaptor.getOperands()); for (auto t : llvm::zip(newOp.getResults(), newResultTypes)) std::get<0>(t).setType(std::get<1>(t)); rewriter.replaceOp(op, newOp.getResults()); @@ -125,9 +125,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(scf::YieldOp op, ArrayRef operands, + matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -139,7 +139,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(WhileOp op, ArrayRef operands, + matchAndRewrite(WhileOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto *converter = getTypeConverter(); assert(converter); @@ -147,7 +147,6 @@ if (failed(converter->convertTypes(op.getResultTypes(), newResultTypes))) return failure(); - WhileOp::Adaptor adaptor(operands); auto newOp = rewriter.create(op.getLoc(), newResultTypes, adaptor.getOperands()); for (auto i : {0u, 1u}) { @@ -167,9 +166,10 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConditionOp op, ArrayRef operands, + matchAndRewrite(ConditionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.updateRootInPlace(op, [&]() { op->setOperands(operands); }); + rewriter.updateRootInPlace( + op, [&]() { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -156,7 +156,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(spirv::FuncOp funcOp, ArrayRef operands, + matchAndRewrite(spirv::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; @@ -168,7 +168,7 @@ } // namespace LogicalResult ProcessInterfaceVarABI::matchAndRewrite( - spirv::FuncOp funcOp, ArrayRef operands, + spirv::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { if (!funcOp->getAttrOfType( spirv::getEntryPointABIAttrName())) { diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp @@ -541,13 +541,13 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(FuncOp funcOp, ArrayRef operands, + matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override; }; } // namespace LogicalResult -FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef operands, +FuncOpConversion::matchAndRewrite(FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { auto fnType = funcOp.getType(); if (fnType.getNumResults() > 1) diff --git a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp --- a/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/StructuralTypeConversions.cpp @@ -20,7 +20,7 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AssumingOp op, ArrayRef operands, + matchAndRewrite(AssumingOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector newResultTypes; newResultTypes.reserve(op.getNumResults()); @@ -48,9 +48,9 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AssumingYieldOp op, ArrayRef operands, + matchAndRewrite(AssumingYieldOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -249,9 +249,9 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, operands); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -262,7 +262,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::DimOp op, ArrayRef operands, + matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); auto enc = getSparseTensorEncoding(op.source().getType()); @@ -278,7 +278,7 @@ // Generate the call. StringRef name = "sparseDimSize"; SmallVector params; - params.push_back(operands[0]); + params.push_back(adaptor.getOperands()[0]); params.push_back( rewriter.create(op.getLoc(), rewriter.getIndexAttr(idx))); rewriter.replaceOpWithNewOp( @@ -291,14 +291,15 @@ class SparseTensorNewConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(NewOp op, ArrayRef operands, + matchAndRewrite(NewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); auto enc = getSparseTensorEncoding(resType); if (!enc) return failure(); Value perm; - rewriter.replaceOp(op, genNewCall(rewriter, op, enc, 0, perm, operands[0])); + rewriter.replaceOp( + op, genNewCall(rewriter, op, enc, 0, perm, adaptor.getOperands()[0])); return success(); } }; @@ -307,7 +308,7 @@ class SparseTensorConvertConverter : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ConvertOp op, ArrayRef operands, + matchAndRewrite(ConvertOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); auto encDst = getSparseTensorEncoding(resType); @@ -320,7 +321,8 @@ // yield the fastest conversion but avoids the need for a full // O(N^2) conversion matrix. Value perm; - Value coo = genNewCall(rewriter, op, encDst, 3, perm, operands[0]); + Value coo = + genNewCall(rewriter, op, encDst, 3, perm, adaptor.getOperands()[0]); rewriter.replaceOp(op, genNewCall(rewriter, op, encDst, 1, perm, coo)); return success(); } @@ -349,7 +351,7 @@ MemRefType::get({ShapedType::kDynamicSize}, rewriter.getIndexType()); Value perm; Value ptr = genNewCall(rewriter, op, encDst, 2, perm); - Value tensor = operands[0]; + Value tensor = adaptor.getOperands()[0]; Value arg = rewriter.create( loc, rewriter.getIndexAttr(shape.getRank())); Value ind = rewriter.create(loc, memTp, ValueRange{arg}); @@ -381,7 +383,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToPointersOp op, ArrayRef operands, + matchAndRewrite(ToPointersOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); @@ -398,10 +400,11 @@ name = "sparsePointers8"; else return failure(); - rewriter.replaceOpWithNewOp( - op, resType, - getFunc(op, name, resType, operands, /*emitCInterface=*/true), - operands); + rewriter.replaceOpWithNewOp(op, resType, + getFunc(op, name, resType, + adaptor.getOperands(), + /*emitCInterface=*/true), + adaptor.getOperands()); return success(); } }; @@ -411,7 +414,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToIndicesOp op, ArrayRef operands, + matchAndRewrite(ToIndicesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); @@ -428,10 +431,11 @@ name = "sparseIndices8"; else return failure(); - rewriter.replaceOpWithNewOp( - op, resType, - getFunc(op, name, resType, operands, /*emitCInterface=*/true), - operands); + rewriter.replaceOpWithNewOp(op, resType, + getFunc(op, name, resType, + adaptor.getOperands(), + /*emitCInterface=*/true), + adaptor.getOperands()); return success(); } }; @@ -441,7 +445,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ToValuesOp op, ArrayRef operands, + matchAndRewrite(ToValuesOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type resType = op.getType(); Type eltType = resType.cast().getElementType(); @@ -460,10 +464,11 @@ name = "sparseValuesI8"; else return failure(); - rewriter.replaceOpWithNewOp( - op, resType, - getFunc(op, name, resType, operands, /*emitCInterface=*/true), - operands); + rewriter.replaceOpWithNewOp(op, resType, + getFunc(op, name, resType, + adaptor.getOperands(), + /*emitCInterface=*/true), + adaptor.getOperands()); return success(); } }; @@ -474,12 +479,12 @@ using OpConversionPattern::OpConversionPattern; LogicalResult // Simply fold the operator into the pointer to the sparse storage scheme. - matchAndRewrite(ToTensorOp op, ArrayRef operands, + matchAndRewrite(ToTensorOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Check that all arguments of the tensor reconstruction operators are calls // into the support library that query exactly the same opaque pointer. Value ptr; - for (Value op : operands) { + for (Value op : adaptor.getOperands()) { if (auto call = op.getDefiningOp()) { Value arg = call.getOperand(0); if (!arg.getType().isa()) diff --git a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/Bufferize.cpp @@ -27,9 +27,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(IndexCastOp op, ArrayRef operands, + matchAndRewrite(IndexCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - IndexCastOp::Adaptor adaptor(operands); auto tensorType = op.getType().cast(); rewriter.replaceOpWithNewOp( op, adaptor.in(), @@ -42,12 +41,11 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(SelectOp op, ArrayRef operands, + matchAndRewrite(SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.condition().getType().isa()) return rewriter.notifyMatchFailure(op, "requires scalar condition"); - SelectOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp( op, adaptor.condition(), adaptor.true_value(), adaptor.false_value()); return success(); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/DecomposeCallGraphTypes.cpp @@ -61,7 +61,7 @@ DecomposeCallGraphTypesOpConversionPattern; LogicalResult - matchAndRewrite(FuncOp op, ArrayRef operands, + matchAndRewrite(FuncOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { auto functionType = op.getType(); @@ -106,10 +106,10 @@ using DecomposeCallGraphTypesOpConversionPattern:: DecomposeCallGraphTypesOpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { SmallVector newOperands; - for (Value operand : operands) + for (Value operand : adaptor.getOperands()) decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), operand, newOperands); rewriter.replaceOpWithNewOp(op, newOperands); @@ -131,12 +131,12 @@ DecomposeCallGraphTypesOpConversionPattern; LogicalResult - matchAndRewrite(CallOp op, ArrayRef operands, + matchAndRewrite(CallOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // Create the operands list of the new `CallOp`. SmallVector newOperands; - for (Value operand : operands) + for (Value operand : adaptor.getOperands()) decomposer.decomposeValue(rewriter, op.getLoc(), operand.getType(), operand, newOperands); diff --git a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/FuncConversions.cpp @@ -20,7 +20,7 @@ /// Hook for derived classes to implement combined matching and rewriting. LogicalResult - matchAndRewrite(CallOp callOp, ArrayRef operands, + matchAndRewrite(CallOp callOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { // Convert the original function results. SmallVector convertedResults; @@ -30,8 +30,8 @@ // Substitute with the new result types from the corresponding FuncType // conversion. - rewriter.replaceOpWithNewOp(callOp, callOp.callee(), - convertedResults, operands); + rewriter.replaceOpWithNewOp( + callOp, callOp.callee(), convertedResults, adaptor.getOperands()); return success(); } }; @@ -96,13 +96,12 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(ReturnOp op, ArrayRef operands, + matchAndRewrite(ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // For a return, all operands go to the results of the parent, so // rewrite them all. - Operation *operation = op.getOperation(); - rewriter.updateRootInPlace( - op, [operands, operation]() { operation->setOperands(operands); }); + rewriter.updateRootInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); return success(); } }; diff --git a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp --- a/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp +++ b/mlir/lib/Dialect/StandardOps/Transforms/TensorConstantBufferize.cpp @@ -66,7 +66,7 @@ globals(globals) {} LogicalResult - matchAndRewrite(ConstantOp op, ArrayRef operands, + matchAndRewrite(ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto type = op.getType().dyn_cast(); if (!type) diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -26,10 +26,11 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::CastOp op, ArrayRef operands, + matchAndRewrite(tensor::CastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto resultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, resultType, operands[0]); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getOperands()[0]); return success(); } }; @@ -40,9 +41,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::DimOp op, ArrayRef operands, + matchAndRewrite(tensor::DimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - tensor::DimOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.source(), adaptor.index()); return success(); @@ -55,9 +55,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::ExtractOp op, ArrayRef operands, + matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - tensor::ExtractOp::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op, adaptor.tensor(), adaptor.indices()); return success(); @@ -71,7 +70,7 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::FromElementsOp op, ArrayRef operands, + matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { int numberOfElements = op.elements().size(); auto resultType = MemRefType::get( @@ -95,16 +94,15 @@ using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(tensor::GenerateOp op, ArrayRef operands, + matchAndRewrite(tensor::GenerateOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { // Allocate memory. Location loc = op.getLoc(); - tensor::GenerateOp::Adaptor transformed(operands); RankedTensorType tensorType = op.getType().cast(); MemRefType memrefType = MemRefType::get(tensorType.getShape(), tensorType.getElementType()); - Value result = rewriter.create( - loc, memrefType, transformed.dynamicExtents()); + Value result = rewriter.create(loc, memrefType, + adaptor.dynamicExtents()); // Collect loop bounds. int64_t rank = tensorType.getRank(); @@ -117,7 +115,7 @@ for (int i = 0; i < rank; i++) { Value upperBound = tensorType.isDynamicDim(i) - ? transformed.dynamicExtents()[nextDynamicIndex++] + ? adaptor.dynamicExtents()[nextDynamicIndex++] : rewriter.create(loc, memrefType.getDimSize(i)); upperBounds.push_back(upperBound); } diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp --- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp +++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp @@ -46,18 +46,18 @@ } LogicalResult - matchAndRewrite(OpTy op, ArrayRef operands, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type elementType = getSrcVectorElementType(op); unsigned bitwidth = elementType.getIntOrFloatBitWidth(); if (bitwidth == 32) return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(), - operands, getTypeConverter(), - rewriter); + adaptor.getOperands(), + getTypeConverter(), rewriter); if (bitwidth == 64) return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(), - operands, getTypeConverter(), - rewriter); + adaptor.getOperands(), + getTypeConverter(), rewriter); return rewriter.notifyMatchFailure( op, "expected 'src' to be either f32 or f64"); } @@ -68,9 +68,8 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(MaskCompressOp op, ArrayRef operands, + matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - MaskCompressOp::Adaptor adaptor(operands); auto opType = adaptor.a().getType(); Value src; @@ -95,10 +94,8 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(RsqrtOp op, ArrayRef operands, + matchAndRewrite(RsqrtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RsqrtOp::Adaptor adaptor(operands); - auto opType = adaptor.a().getType(); rewriter.replaceOpWithNewOp(op, opType, adaptor.a()); return success(); @@ -109,9 +106,8 @@ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; LogicalResult - matchAndRewrite(DotOp op, ArrayRef operands, + matchAndRewrite(DotOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - DotOp::Adaptor adaptor(operands); auto opType = adaptor.a().getType(); Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8); // Dot product of all elements, broadcasted to all elements. diff --git a/mlir/lib/Transforms/Bufferize.cpp b/mlir/lib/Transforms/Bufferize.cpp --- a/mlir/lib/Transforms/Bufferize.cpp +++ b/mlir/lib/Transforms/Bufferize.cpp @@ -58,9 +58,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::TensorLoadOp op, ArrayRef operands, + matchAndRewrite(memref::TensorLoadOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::TensorLoadOp::Adaptor adaptor(operands); rewriter.replaceOp(op, adaptor.memref()); return success(); } @@ -74,9 +73,8 @@ public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(memref::BufferCastOp op, ArrayRef operands, + matchAndRewrite(memref::BufferCastOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - memref::BufferCastOp::Adaptor adaptor(operands); rewriter.replaceOp(op, adaptor.tensor()); return success(); }