diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -1178,7 +1178,8 @@ DeclareOpInterfaceMethods<MaskableOpInterface>, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>, DeclareOpInterfaceMethods<ConditionallySpeculatable>, - AttrSizedOperandSegments + AttrSizedOperandSegments, + DestinationStyleOpInterface ]>, Arguments<(ins AnyShaped:$source, Variadic<Index>:$indices, @@ -1400,6 +1401,10 @@ let extraClassDeclaration = [{ // MaskableOpInterface methods. bool supportsPassthru() { return true; } + + std::pair<int64_t, int64_t> getDpsInitsPositionRange() { + return {0, 0}; // empty range (no init operands) + } }]; let hasCanonicalizer = 1; diff --git a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td --- a/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td +++ b/mlir/include/mlir/Interfaces/DestinationStyleOpInterface.td @@ -24,15 +24,15 @@ position [start, end). The positions are defined by getDpsInitsPositionRange method. - If the op has "tensor semantics", then the input operands are either scalars - or ranked tensors. The init operands are ranked tensors and every tensor - init is tied to a corresponding tensor OpResult in a 1-to-1 fashion. - The i-th init tensor is tied to the i-th OpResult. The op may not have any - additional OpResults. Init operands and their tied OpResults have the same - type. + If the op has "tensor semantics", then the input operands are either ranked + tensors or other non-tensor/memref types ("scalars"). The init operands are + ranked tensors and every tensor init is tied to a corresponding tensor + OpResult in a 1-to-1 fashion. The i-th init tensor is tied to the i-th + OpResult. The op may not have any additional OpResults. Init operands and + their tied OpResults have the same type. 
If the op has "buffer semantics", then the input operands are either ranked - memrefs or other non-tensor types, e.g. scalar types. Furthermore, the + memrefs or other non-tensor/memref types ("scalar" types). Furthermore, the init operands are ranked memrefs and the op has no results. Destination-passing style abstraction makes certain transformations easier. @@ -194,14 +194,17 @@ }] >, InterfaceMethod< - /*desc=*/"Return true if the `opOperand` is a scalar value.", + /*desc=*/[{ + Return true if the `opOperand` is a scalar value. A scalar is defined + as neither a memref nor a tensor value. + }], /*retTy=*/"bool", /*methodName=*/"isScalar", /*args=*/(ins "::mlir::OpOperand *":$opOperand), /*methodBody=*/"", /*defaultImplementation=*/[{ assert(opOperand->getOwner() == $_op.getOperation()); - return !::llvm::isa<::mlir::ShapedType>(opOperand->get().getType()); + return !::llvm::isa<::mlir::MemRefType, ::mlir::TensorType>(opOperand->get().getType()); }] >, InterfaceMethod< @@ -235,32 +238,49 @@ // Other interface methods. //===------------------------------------------------------------------===// InterfaceMethod< - /*desc=*/"Return whether the op has only ranked MemRef input/inits.", + /*desc=*/[{ + Return whether the op has buffer semantics. That is the case if the op + has no tensor operands and at least one memref operand. + }], /*retTy=*/"bool", /*methodName=*/"hasBufferSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return $_op->getNumResults() == 0 && - ::llvm::all_of($_op->getOpOperands(), - [&](::mlir::OpOperand &opOperand) { - return isScalar(&opOperand) || - ::llvm::isa<::mlir::MemRefType>(opOperand.get().getType()); - }); + // No tensors. + auto isTensor = [](Value v){ + return ::llvm::isa<::mlir::RankedTensorType>(v.getType()); + }; + if (::llvm::any_of($_op->getOperands(), isTensor)) + return false; + // At least one memref. 
+ auto isMemref = [](Value v){ + return ::llvm::isa<::mlir::MemRefType>(v.getType()); + }; + return llvm::any_of($_op->getOperands(), isMemref); }] >, InterfaceMethod< - /*desc=*/"Return whether the op has only ranked tensor inputs/inits.", + /*desc=*/[{ + Return whether the op has tensor semantics. That is the case if the op + has no memref operands and at least one tensor operand. + }], /*retTy=*/"bool", /*methodName=*/"hasTensorSemantics", /*args=*/(ins), /*methodBody=*/"", /*defaultImplementation=*/[{ - return ::llvm::all_of($_op->getOpOperands(), - [&](::mlir::OpOperand &opOperand) { - return isScalar(&opOperand) || - ::llvm::isa<::mlir::RankedTensorType>(opOperand.get().getType()); - }); + // No memrefs. + auto isMemref = [](Value v){ + return ::llvm::isa<::mlir::MemRefType>(v.getType()); + }; + if (::llvm::any_of($_op->getOperands(), isMemref)) + return false; + // At least one tensor. + auto isTensor = [](Value v){ + return ::llvm::isa<::mlir::RankedTensorType>(v.getType()); + }; + return llvm::any_of($_op->getOperands(), isTensor); }] > ];