diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -848,9 +848,8 @@ `.lhs()` to access the first operand and `.rhs()` to access the second operand. The operand adaptor class lives in the same namespace as the operation class, -and has the name of the operation followed by `OperandAdaptor`. A template -declaration `OperandAdaptor<>` is provided to look up the operand adaptor for -the given operation. +and has the name of the operation followed by `Adaptor` (e.g. `AddOpAdaptor`). It is +also available as the alias `Adaptor` nested inside the op class (`AddOp::Adaptor`). Operand adaptors can be used in function templates that also process operations: @@ -862,7 +861,7 @@ void process(AddOp op, ArrayRef newOperands) { zip(op); - zip(OperandAdaptor(newOperands)); + zip(AddOp::Adaptor(newOperands)); /*...*/ } ``` diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -124,7 +124,7 @@ // This allows for using the nice named accessors that are generated // by the ODS. This adaptor is automatically provided by the ODS // framework. - TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + TransposeOpAdaptor transposeAdaptor(memRefOperands); mlir::Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -110,7 +110,7 @@ // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. - typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. 
@@ -234,7 +234,7 @@ // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. - toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -110,7 +110,7 @@ // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. - typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. @@ -233,7 +233,7 @@ // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. - toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -110,7 +110,7 @@ // Generate an adaptor for the remapped operands of the BinaryOp. This // allows for using the nice named accessors that are generated by the // ODS. 
- typename BinaryOp::OperandAdaptor binaryAdaptor(memRefOperands); + typename BinaryOp::Adaptor binaryAdaptor(memRefOperands); // Generate loads for the element of 'lhs' and 'rhs' at the inner // loop. @@ -234,7 +234,7 @@ // Generate an adaptor for the remapped operands of the TransposeOp. // This allows for using the nice named accessors that are generated // by the ODS. - toy::TransposeOpOperandAdaptor transposeAdaptor(memRefOperands); + toy::TransposeOpAdaptor transposeAdaptor(memRefOperands); Value input = transposeAdaptor.input(); // Transpose the elements by generating a load from the reverse diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -47,14 +47,6 @@ class ValueRange; template class ValueTypeRange; -/// This is an adaptor from a list of values to named operands of OpTy. In a -/// generic operation context, e.g., in dialect conversions, an ordered array of -/// `Value`s is treated as operands of `OpTy`. This adaptor takes a reference -/// to the array and provides accessors with the same names as `OpTy` for -/// operands. This makes possible to create function templates that operate on -/// either OpTy or OperandAdaptor seamlessly. 
-template using OperandAdaptor = typename OpTy::OperandAdaptor; - class OwningRewritePatternList; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -56,7 +56,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); - gpu::ShuffleOpOperandAdaptor adaptor(operands); + gpu::ShuffleOpAdaptor adaptor(operands); auto dialect = typeConverter.getDialect(); auto valueTy = adaptor.value().getType().cast(); diff --git a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp --- a/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp +++ b/mlir/lib/Conversion/GPUToSPIRV/ConvertGPUToSPIRV.cpp @@ -140,7 +140,7 @@ // latch and the merge block the exit block. The resulting spirv::LoopOp has a // single back edge from the continue to header block, and a single exit from // header to merge. - scf::ForOpOperandAdaptor forOperands(operands); + scf::ForOpAdaptor forOperands(operands); auto loc = forOp.getLoc(); auto loopControl = rewriter.getI32IntegerAttr( static_cast(spirv::LoopControl::None)); @@ -211,7 +211,7 @@ // When lowering `scf::IfOp` we explicitly create a selection header block // before the control flow diverges and a merge block where control flow // subsequently converges. - scf::IfOpOperandAdaptor ifOperands(operands); + scf::IfOpAdaptor ifOperands(operands); auto loc = ifOp.getLoc(); // Create `spv.selection` operation, selection header block and merge block. 
diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -140,7 +140,7 @@ edsc::ScopedContext context(rewriter, op->getLoc()); // Fill in an aggregate value of the descriptor. - RangeOpOperandAdaptor adaptor(operands); + RangeOpAdaptor adaptor(operands); Value desc = llvm_undef(rangeDescriptorTy); desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); @@ -178,7 +178,7 @@ return failure(); edsc::ScopedContext context(rewriter, op->getLoc()); - ReshapeOpOperandAdaptor adaptor(operands); + ReshapeOpAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.src()); BaseViewConversionHelper desc(typeConverter.convertType(dstType)); desc.setAllocatedPtr(baseDesc.allocatedPtr()); @@ -208,7 +208,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext context(rewriter, op->getLoc()); - SliceOpOperandAdaptor adaptor(operands); + SliceOpAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.view()); auto sliceOp = cast(op); @@ -302,7 +302,7 @@ ConversionPatternRewriter &rewriter) const override { // Initialize the common boilerplate and alloca at the top of the FuncOp. 
edsc::ScopedContext context(rewriter, op->getLoc()); - TransposeOpOperandAdaptor adaptor(operands); + TransposeOpAdaptor adaptor(operands); BaseViewConversionHelper baseDesc(adaptor.view()); auto transposeOp = cast(op); diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -28,7 +28,7 @@ LogicalResult matchAndRewrite(SrcOpTy op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - typename SrcOpTy::OperandAdaptor adaptor(operands); + typename SrcOpTy::Adaptor adaptor(operands); rewriter.replaceOpWithNewOp(op.getOperation(), adaptor.lhs(), adaptor.rhs()); return success(); @@ -43,7 +43,7 @@ LogicalResult matchAndRewrite(FromExtentTensorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - FromExtentTensorOpOperandAdaptor transformed(operands); + FromExtentTensorOp::Adaptor transformed(operands); rewriter.replaceOp(op.getOperation(), transformed.input()); return success(); } @@ -56,7 +56,7 @@ LogicalResult matchAndRewrite(IndexToSizeOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - IndexToSizeOpOperandAdaptor transformed(operands); + IndexToSizeOp::Adaptor transformed(operands); rewriter.replaceOp(op.getOperation(), transformed.arg()); return success(); } @@ -69,7 +69,7 @@ LogicalResult matchAndRewrite(SizeToIndexOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - SizeToIndexOpOperandAdaptor transformed(operands); + SizeToIndexOp::Adaptor transformed(operands); rewriter.replaceOp(op.getOperation(), transformed.arg()); return success(); } @@ -83,7 +83,7 @@ LogicalResult matchAndRewrite(ToExtentTensorOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - ToExtentTensorOpOperandAdaptor transformed(operands); + ToExtentTensorOp::Adaptor 
transformed(operands); rewriter.replaceOp(op.getOperation(), transformed.input()); return success(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -1336,7 +1336,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto complexOp = cast(op); - OperandAdaptor transformed(operands); + CreateComplexOp::Adaptor transformed(operands); // Pack real and imaginary part in a complex number struct. auto loc = op->getLoc(); @@ -1356,7 +1356,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - OperandAdaptor transformed(operands); + ReOp::Adaptor transformed(operands); // Extract real part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); @@ -1373,7 +1373,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - OperandAdaptor transformed(operands); + ImOp::Adaptor transformed(operands); // Extract imaginary part from the complex number struct. ComplexStructBuilder complexStruct(transformed.complex()); @@ -1394,7 +1394,7 @@ ConversionPatternRewriter &rewriter) { auto bop = cast(op); auto loc = bop.getLoc(); - OperandAdaptor transformed(operands); + typename OpTy::Adaptor transformed(operands); // Extract real and imaginary values from operands. BinaryComplexOperands unpacked; @@ -1847,7 +1847,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - OperandAdaptor transformed(operands); + typename CallOpType::Adaptor transformed(operands); auto callOp = cast(op); // Pack the result types into a struct. 
@@ -1919,7 +1919,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1 && "dealloc takes one operand"); - OperandAdaptor transformed(operands); + DeallocOp::Adaptor transformed(operands); // Insert the `free` declaration if it is not already present. auto freeFunc = @@ -1949,7 +1949,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - OperandAdaptor transformed(operands); + RsqrtOp::Adaptor transformed(operands); auto operandType = transformed.operand().getType().dyn_cast(); @@ -2029,7 +2029,7 @@ void rewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast(op); - OperandAdaptor transformed(operands); + MemRefCastOp::Adaptor transformed(operands); auto srcType = memRefCastOp.getOperand().getType(); auto dstType = memRefCastOp.getType(); @@ -2098,7 +2098,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto castOp = cast(op); - OperandAdaptor transformed(operands); + LLVM::DialectCastOp::Adaptor transformed(operands); if (transformed.in().getType() != typeConverter.convertType(castOp.getType())) { return failure(); @@ -2117,7 +2117,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); - OperandAdaptor transformed(operands); + DimOp::Adaptor transformed(operands); MemRefType type = dimOp.memrefOrTensor().getType().cast(); Optional index = dimOp.getConstantIndex(); @@ -2163,7 +2163,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); - OperandAdaptor transformed(operands); + LoadOp::Adaptor transformed(operands); auto type = loadOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), @@ -2182,7 +2182,7 @@ 
matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); - OperandAdaptor transformed(operands); + StoreOp::Adaptor transformed(operands); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), transformed.indices(), rewriter, getModule()); @@ -2201,7 +2201,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto prefetchOp = cast(op); - OperandAdaptor transformed(operands); + PrefetchOp::Adaptor transformed(operands); auto type = prefetchOp.getMemRefType(); Value dataPtr = getDataPtr(op->getLoc(), type, transformed.memref(), @@ -2235,7 +2235,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - IndexCastOpOperandAdaptor transformed(operands); + IndexCastOpAdaptor transformed(operands); auto indexCastOp = cast(op); auto targetType = @@ -2271,7 +2271,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); - CmpIOpOperandAdaptor transformed(operands); + CmpIOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpiOp.getResult().getType()), @@ -2290,7 +2290,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto cmpfOp = cast(op); - CmpFOpOperandAdaptor transformed(operands); + CmpFOpAdaptor transformed(operands); rewriter.replaceOpWithNewOp( op, typeConverter.convertType(cmpfOp.getResult().getType()), @@ -2449,7 +2449,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto splatOp = cast(op); - OperandAdaptor adaptor(operands); + SplatOp::Adaptor adaptor(operands); VectorType resultType = splatOp.getType().dyn_cast(); if (!resultType || resultType.getRank() == 1) return failure(); @@ -2647,7 +2647,7 @@ 
ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); auto viewOp = cast(op); - ViewOpOperandAdaptor adaptor(operands); + ViewOpAdaptor adaptor(operands); auto viewMemRefType = viewOp.getType(); auto targetElementTy = @@ -2721,7 +2721,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - OperandAdaptor transformed(operands); + AssumeAlignmentOp::Adaptor transformed(operands); Value memref = transformed.memref(); unsigned alignment = cast(op).alignment().getZExtValue(); @@ -2791,7 +2791,7 @@ auto maybeKind = matchSimpleAtomicOp(atomicOp); if (!maybeKind) return failure(); - OperandAdaptor adaptor(operands); + AtomicRMWOp::Adaptor adaptor(operands); auto resultType = adaptor.value().getType(); auto memRefType = atomicOp.getMemRefType(); auto dataPtr = getDataPtr(op->getLoc(), memRefType, adaptor.memref(), @@ -2840,7 +2840,7 @@ auto atomicOp = cast(op); auto loc = op->getLoc(); - OperandAdaptor adaptor(operands); + GenericAtomicRMWOp::Adaptor adaptor(operands); LLVM::LLVMType valueType = typeConverter.convertType(atomicOp.getResult().getType()) .cast(); diff --git a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.cpp @@ -653,7 +653,7 @@ LogicalResult CmpFOpPattern::matchAndRewrite(CmpFOp cmpFOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - CmpFOpOperandAdaptor cmpFOpOperands(operands); + CmpFOpAdaptor cmpFOpOperands(operands); switch (cmpFOp.getPredicate()) { #define DISPATCH(cmpPredicate, spirvOp) \ @@ -693,7 +693,7 @@ LogicalResult BoolCmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - CmpIOpOperandAdaptor cmpIOpOperands(operands); + CmpIOpAdaptor cmpIOpOperands(operands); Type 
operandType = cmpIOp.lhs().getType(); if (!operandType.isa() || @@ -720,7 +720,7 @@ LogicalResult CmpIOpPattern::matchAndRewrite(CmpIOp cmpIOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - CmpIOpOperandAdaptor cmpIOpOperands(operands); + CmpIOpAdaptor cmpIOpOperands(operands); Type operandType = cmpIOp.lhs().getType(); if (operandType.isa() && @@ -763,7 +763,7 @@ LogicalResult IntLoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - LoadOpOperandAdaptor loadOperands(operands); + LoadOpAdaptor loadOperands(operands); auto loc = loadOp.getLoc(); auto memrefType = loadOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) @@ -838,7 +838,7 @@ LogicalResult LoadOpPattern::matchAndRewrite(LoadOp loadOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - LoadOpOperandAdaptor loadOperands(operands); + LoadOpAdaptor loadOperands(operands); auto memrefType = loadOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); @@ -870,7 +870,7 @@ LogicalResult SelectOpPattern::matchAndRewrite(SelectOp op, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - SelectOpOperandAdaptor selectOperands(operands); + SelectOpAdaptor selectOperands(operands); rewriter.replaceOpWithNewOp(op, selectOperands.condition(), selectOperands.true_value(), selectOperands.false_value()); @@ -884,7 +884,7 @@ LogicalResult IntStoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - StoreOpOperandAdaptor storeOperands(operands); + StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (!memrefType.getElementType().isSignlessInteger()) return failure(); @@ -963,7 +963,7 @@ LogicalResult StoreOpPattern::matchAndRewrite(StoreOp storeOp, ArrayRef operands, ConversionPatternRewriter &rewriter) const { - StoreOpOperandAdaptor 
storeOperands(operands); + StoreOpAdaptor storeOperands(operands); auto memrefType = storeOp.memref().getType().cast(); if (memrefType.getElementType().isSignlessInteger()) return failure(); diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -176,7 +176,7 @@ LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, ArrayRef operands, Value dataPtr) { - auto adaptor = TransferWriteOpOperandAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dataPtr); return success(); } @@ -190,21 +190,21 @@ if (failed(getVectorTransferAlignment(typeConverter, xferOp, align))) return failure(); - auto adaptor = TransferWriteOpOperandAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp( xferOp, adaptor.vector(), dataPtr, mask, rewriter.getI32IntegerAttr(align)); return success(); } -static TransferReadOpOperandAdaptor -getTransferOpAdapter(TransferReadOp xferOp, ArrayRef operands) { - return TransferReadOpOperandAdaptor(operands); +static TransferReadOpAdaptor getTransferOpAdapter(TransferReadOp xferOp, + ArrayRef operands) { + return TransferReadOpAdaptor(operands); } -static TransferWriteOpOperandAdaptor -getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef operands) { - return TransferWriteOpOperandAdaptor(operands); +static TransferWriteOpAdaptor getTransferOpAdapter(TransferWriteOp xferOp, + ArrayRef operands) { + return TransferWriteOpAdaptor(operands); } namespace { @@ -222,7 +222,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto matmulOp = cast(op); - auto adaptor = vector::MatmulOpOperandAdaptor(operands); + auto adaptor = vector::MatmulOpAdaptor(operands); 
rewriter.replaceOpWithNewOp( op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(), adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(), @@ -244,7 +244,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto transOp = cast(op); - auto adaptor = vector::FlatTransposeOpOperandAdaptor(operands); + auto adaptor = vector::FlatTransposeOpAdaptor(operands); rewriter.replaceOpWithNewOp( transOp, typeConverter.convertType(transOp.res().getType()), adaptor.matrix(), transOp.rows(), transOp.columns()); @@ -337,7 +337,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto adaptor = vector::ShuffleOpOperandAdaptor(operands); + auto adaptor = vector::ShuffleOpAdaptor(operands); auto shuffleOp = cast(op); auto v1Type = shuffleOp.getV1VectorType(); auto v2Type = shuffleOp.getV2VectorType(); @@ -394,7 +394,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::ExtractElementOpOperandAdaptor(operands); + auto adaptor = vector::ExtractElementOpAdaptor(operands); auto extractEltOp = cast(op); auto vectorType = extractEltOp.getVectorType(); auto llvmType = typeConverter.convertType(vectorType.getElementType()); @@ -420,7 +420,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto adaptor = vector::ExtractOpOperandAdaptor(operands); + auto adaptor = vector::ExtractOpAdaptor(operands); auto extractOp = cast(op); auto vectorType = extractOp.getVectorType(); auto resultType = extractOp.getResult().getType(); @@ -488,7 +488,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::FMAOpOperandAdaptor(operands); + auto adaptor = vector::FMAOpAdaptor(operands); 
vector::FMAOp fmaOp = cast(op); VectorType vType = fmaOp.getVectorType(); if (vType.getRank() != 1) @@ -509,7 +509,7 @@ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - auto adaptor = vector::InsertElementOpOperandAdaptor(operands); + auto adaptor = vector::InsertElementOpAdaptor(operands); auto insertEltOp = cast(op); auto vectorType = insertEltOp.getDestVectorType(); auto llvmType = typeConverter.convertType(vectorType); @@ -535,7 +535,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto adaptor = vector::InsertOpOperandAdaptor(operands); + auto adaptor = vector::InsertOpAdaptor(operands); auto insertOp = cast(op); auto sourceType = insertOp.getSourceType(); auto destVectorType = insertOp.getDestVectorType(); @@ -967,7 +967,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto printOp = cast(op); - auto adaptor = vector::PrintOpOperandAdaptor(operands); + auto adaptor = vector::PrintOpAdaptor(operands); Type printType = printOp.getPrintType(); if (typeConverter.convertType(printType) == nullptr) diff --git a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp --- a/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp +++ b/mlir/lib/Conversion/VectorToROCDL/VectorToROCDL.cpp @@ -27,16 +27,6 @@ using namespace mlir; using namespace mlir::vector; -static TransferReadOpOperandAdaptor -getTransferOpAdapter(TransferReadOp xferOp, ArrayRef operands) { - return OperandAdaptor(operands); -} - -static TransferWriteOpOperandAdaptor -getTransferOpAdapter(TransferWriteOp xferOp, ArrayRef operands) { - return OperandAdaptor(operands); -} - static LogicalResult replaceTransferOpWithMubuf( ConversionPatternRewriter &rewriter, ArrayRef operands, LLVMTypeConverter &typeConverter, Location loc, TransferReadOp 
xferOp, @@ -52,7 +42,7 @@ LLVMTypeConverter &typeConverter, Location loc, TransferWriteOp xferOp, LLVM::LLVMType &vecTy, Value &dwordConfig, Value &vindex, Value &offsetSizeInBytes, Value &glc, Value &slc) { - auto adaptor = TransferWriteOpOperandAdaptor(operands); + auto adaptor = TransferWriteOpAdaptor(operands); rewriter.replaceOpWithNewOp(xferOp, adaptor.vector(), dwordConfig, vindex, offsetSizeInBytes, glc, slc); @@ -76,7 +66,7 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto xferOp = cast(op); - auto adaptor = getTransferOpAdapter(xferOp, operands); + typename ConcreteOp::Adaptor adaptor(operands); if (xferOp.getVectorType().getRank() > 1 || llvm::size(xferOp.indices()) == 0) diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -779,7 +779,7 @@ /*printBlockTerminators=*/false); } -// Namespace avoids ambiguous ReturnOpOperandAdaptor. +// Namespace avoids ambiguous ReturnOpAdaptor. 
namespace mlir { namespace gpu { #define GET_OP_CLASSES diff --git a/mlir/lib/TableGen/OpClass.cpp b/mlir/lib/TableGen/OpClass.cpp --- a/mlir/lib/TableGen/OpClass.cpp +++ b/mlir/lib/TableGen/OpClass.cpp @@ -227,8 +227,7 @@ os << ", " << trait; os << "> {\npublic:\n"; os << " using Op::Op;\n"; - os << " using OperandAdaptor = " << className << "OperandAdaptor;\n"; - os << " using Adaptor = " << className << "OperandAdaptor;\n"; + os << " using Adaptor = " << className << "Adaptor;\n"; bool hasPrivateMethod = false; for (const auto &method : methods) { diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -60,7 +60,7 @@ } std::string tblgen::Operator::getAdaptorName() const { - return std::string(llvm::formatv("{0}OperandAdaptor", getCppClassName())); + return std::string(llvm::formatv("{0}Adaptor", getCppClassName())); } StringRef tblgen::Operator::getDialectName() const { return dialect.getName(); } diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -33,7 +33,7 @@ // Test verify method // --- -// DEF: LogicalResult AOpOperandAdaptor::verify +// DEF: LogicalResult AOpAdaptor::verify // DEF: auto tblgen_aAttr = odsAttrs.get("aAttr"); // DEF-NEXT: if (!tblgen_aAttr) return emitError(loc, "'test.a_op' op ""requires attribute 'aAttr'"); // DEF: if (!((some-condition))) return emitError(loc, "'test.a_op' op ""attribute 'aAttr' failed to satisfy constraint: some attribute kind"); @@ -118,7 +118,7 @@ // Test common attribute kinds' constraints // --- -// DEF-LABEL: BOpOperandAdaptor::verify +// DEF-LABEL: BOpAdaptor::verify // DEF: if (!((true))) // DEF: if (!((tblgen_bool_attr.isa()))) // DEF: if (!(((tblgen_i32_attr.isa())) && ((tblgen_i32_attr.cast().getType().isSignlessInteger(32))))) diff --git a/mlir/test/mlir-tblgen/op-decl.td 
b/mlir/test/mlir-tblgen/op-decl.td --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -47,9 +47,9 @@ // CHECK-LABEL: NS::AOp declarations -// CHECK: class AOpOperandAdaptor { +// CHECK: class AOpAdaptor { // CHECK: public: -// CHECK: AOpOperandAdaptor(ValueRange values +// CHECK: AOpAdaptor(ValueRange values // CHECK: ValueRange getODSOperands(unsigned index); // CHECK: Value a(); // CHECK: ValueRange b(); @@ -63,7 +63,7 @@ // CHECK-NOT: OpTrait::IsIsolatedFromAbove // CHECK: public: // CHECK: using Op::Op; -// CHECK: using OperandAdaptor = AOpOperandAdaptor; +// CHECK: using Adaptor = AOpAdaptor; // CHECK: static StringRef getOperationName(); // CHECK: Operation::operand_range getODSOperands(unsigned index); // CHECK: Value a(); @@ -105,7 +105,7 @@ ); } -// CHECK-LABEL: AttrSizedOperandOpOperandAdaptor( +// CHECK-LABEL: AttrSizedOperandOpAdaptor( // CHECK-SAME: ValueRange values // CHECK-SAME: DictionaryAttr attrs // CHECK: ValueRange a(); diff --git a/mlir/test/mlir-tblgen/op-operand.td b/mlir/test/mlir-tblgen/op-operand.td --- a/mlir/test/mlir-tblgen/op-operand.td +++ b/mlir/test/mlir-tblgen/op-operand.td @@ -14,7 +14,7 @@ // CHECK-LABEL: OpA definitions -// CHECK: OpAOperandAdaptor::OpAOperandAdaptor +// CHECK: OpAAdaptor::OpAAdaptor // CHECK-SAME: odsOperands(values), odsAttrs(attrs) // CHECK: void OpA::build @@ -39,13 +39,13 @@ let arguments = (ins Variadic:$input1, AnyTensor:$input2, Variadic:$input3); } -// CHECK-LABEL: ValueRange OpDOperandAdaptor::input1 +// CHECK-LABEL: ValueRange OpDAdaptor::input1 // CHECK-NEXT: return getODSOperands(0); -// CHECK-LABEL: Value OpDOperandAdaptor::input2 +// CHECK-LABEL: Value OpDAdaptor::input2 // CHECK-NEXT: return *getODSOperands(1).begin(); -// CHECK-LABEL: ValueRange OpDOperandAdaptor::input3 +// CHECK-LABEL: ValueRange OpDAdaptor::input3 // CHECK-NEXT: return getODSOperands(2); // CHECK-LABEL: Operation::operand_range OpD::input1 diff --git a/mlir/test/mlir-tblgen/predicate.td 
b/mlir/test/mlir-tblgen/predicate.td --- a/mlir/test/mlir-tblgen/predicate.td +++ b/mlir/test/mlir-tblgen/predicate.td @@ -32,7 +32,7 @@ let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpFOperandAdaptor::verify +// CHECK-LABEL: OpFAdaptor::verify // CHECK: (tblgen_attr.cast().getInt() >= 10) // CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose minimum value is 10" @@ -40,7 +40,7 @@ let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpFXOperandAdaptor::verify +// CHECK-LABEL: OpFXAdaptor::verify // CHECK: (tblgen_attr.cast().getInt() <= 10) // CHECK-SAME: "attribute 'attr' failed to satisfy constraint: 32-bit signless integer attribute whose maximum value is 10" @@ -48,7 +48,7 @@ let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpGOperandAdaptor::verify +// CHECK-LABEL: OpGAdaptor::verify // CHECK: (tblgen_attr.cast().size() >= 8) // CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute with at least 8 elements" @@ -56,7 +56,7 @@ let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpHOperandAdaptor::verify +// CHECK-LABEL: OpHAdaptor::verify // CHECK: (((tblgen_attr.cast().size() > 0)) && ((tblgen_attr.cast()[0].cast().getInt() == 8))))) // CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be 8" @@ -64,7 +64,7 @@ let arguments = (ins Confined]>:$attr); } -// CHECK-LABEL: OpIOperandAdaptor::verify +// CHECK-LABEL: OpIAdaptor::verify // CHECK: (((tblgen_attr.cast().size() > 0)) && ((tblgen_attr.cast()[0].cast().getInt() >= 8))))) // CHECK-SAME: "attribute 'attr' failed to satisfy constraint: array attribute whose 0-th element must be at least 8" @@ -80,7 +80,7 @@ ); } -// CHECK-LABEL: OpJOperandAdaptor::verify +// CHECK-LABEL: OpJAdaptor::verify // CHECK: llvm::is_splat(llvm::map_range( // CHECK-SAME: llvm::ArrayRef({0, 2, 3}), // CHECK-SAME: [this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); 
}))