diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -13,7 +13,8 @@ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H -#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h" +#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" +#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -13,7 +13,8 @@ #ifndef VECTOR_OPS #define VECTOR_OPS -include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td" +include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td" +include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td" include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" @@ -2140,13 +2141,15 @@ } def Vector_MaskOp : Vector_Op<"mask", [ - SingleBlockImplicitTerminator<"vector::YieldOp">, RecursiveMemoryEffects, - NoRegionArguments + SingleBlockImplicitTerminator<"vector::YieldOp">, + DeclareOpInterfaceMethods<MaskingOpInterface>, + RecursiveMemoryEffects, NoRegionArguments ]> { let summary = "Predicates a maskable vector operation"; let description = [{ - The `vector.mask` operation predicates the execution of another operation. - It takes an `i1` vector mask and an optional pass-thru vector as arguments. + The `vector.mask` is a `MaskingOpInterface` operation that predicates the + execution of another operation. It takes an `i1` vector mask and an + optional passthru vector as arguments.
A `vector.yield`-terminated region encloses the operation to be masked. Values used within the region are captured from above. Only one *maskable* operation can be masked with a `vector.mask` operation at a time. An diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt @@ -1 +1,2 @@ -add_mlir_interface(MaskingInterfaces) +add_mlir_interface(MaskableOpInterface) +add_mlir_interface(MaskingOpInterface) diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h @@ -0,0 +1,23 @@ +//===- MaskableOpInterface.h ----------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the MaskableOpInterface. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_ +#define MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_ + +#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" + +// Include the generated interface declarations.
+#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h.inc" + +#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td @@ -0,0 +1,74 @@ +//===- MaskableOpInterface.td - MaskableOpInterface Decls --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for the MaskableOpInterface. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD +#define MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD + +include "mlir/IR/OpBase.td" + +def MaskableOpInterface : OpInterface<"MaskableOpInterface"> { + let description = [{ + The 'MaskableOpInterface' defines an operation that can be masked using a + MaskingOpInterface (e.g., `vector.mask`) and provides information about its + masking constraints and semantics.
+ }]; + let cppNamespace = "::mlir::vector"; + let methods = [ + InterfaceMethod< + /*desc=*/"Returns true if the operation is masked by a " + "MaskingOpInterface.", + /*retTy=*/"bool", + /*methodName=*/"isMasked", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return llvm::isa_and_nonnull<mlir::vector::MaskingOpInterface>( + $_op->getParentOp()); + }]>, + InterfaceMethod< + /*desc=*/"Returns the MaskingOpInterface masking this operation. The " + "operation must be masked (see `isMasked`).", + /*retTy=*/"mlir::vector::MaskingOpInterface", + /*methodName=*/"getMaskingOp", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return mlir::cast<mlir::vector::MaskingOpInterface>( + $_op->getParentOp()); + }]>, + InterfaceMethod< + /*desc=*/"Returns true if the operation can have a passthru argument when" + " masked.", + /*retTy=*/"bool", + /*methodName=*/"supportsPassthru", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }]>, + InterfaceMethod< + /*desc=*/"Returns the mask type expected by this operation. It requires " + "the operation to be vectorized.", + /*retTy=*/"mlir::VectorType", + /*methodName=*/"getExpectedMaskType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Default implementation is only aimed at operations that implement the + // `getVectorType()` method. + return $_op.getVectorType().cloneWith(/*shape=*/llvm::None, + IntegerType::get($_op.getContext(), /*width=*/1)); + }]>, + ]; +} + +#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKABLEOPINTERFACE_TD diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td deleted file mode 100644 --- a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td +++ /dev/null @@ -1,52 +0,0 @@ -//===- MaskingInterfaces.td - Masking Interfaces Decls === -*- tablegen -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This is the definition file for vector masking related interfaces. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES -#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES - -include "mlir/IR/OpBase.td" - -def MaskableOpInterface : OpInterface<"MaskableOpInterface"> { - let description = [{ - The 'MaskableOpInterface' define an operation that can be masked using the - `vector.mask` operation and provides information about its masking - constraints and semantics. - }]; - let cppNamespace = "::mlir::vector"; - let methods = [ - InterfaceMethod< - /*desc=*/"Returns true if the operation may have a passthru argument when" - " masked.", - /*retTy=*/"bool", - /*methodName=*/"supportsPassthru", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - return false; - }]>, - InterfaceMethod< - /*desc=*/"Returns the mask type expected by this operation. It requires the" - " operation to be vectorized.", - /*retTy=*/"mlir::VectorType", - /*methodName=*/"getExpectedMaskType", - /*args=*/(ins), - /*methodBody=*/"", - /*defaultImplementation=*/[{ - // Default implementation is only aimed for operations that implement the - // `getVectorType()` method.
- return $_op.getVectorType().cloneWith( - /*shape=*/llvm::None, IntegerType::get($_op.getContext(), /*width=*/1)); - }]>, - ]; -} - -#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h rename from mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h rename to mlir/include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h --- a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h +++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h @@ -1,4 +1,4 @@ -//===- MaskingInterfaces.h - Masking interfaces ---------------------------===// +//===- MaskingOpInterface.h -----------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,17 +6,17 @@ // //===----------------------------------------------------------------------===// // -// This file implements the interfaces for masking operations. +// This file implements the MaskingOpInterface. // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_ -#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_ +#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_ +#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/OpDefinition.h" -/// Include the generated interface declarations. +// Include the generated interface declarations.
+#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h.inc" -#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_ +#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_H_ diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td @@ -0,0 +1,58 @@ +//===- MaskingOpInterface.td - MaskingOpInterface Decls ----*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for the MaskingOpInterface. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD +#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD + +include "mlir/IR/OpBase.td" + +def MaskingOpInterface : OpInterface<"MaskingOpInterface"> { + let description = [{ + The 'MaskingOpInterface' defines a vector operation that can apply masking + to itself or to other vector operations. + }]; + let cppNamespace = "::mlir::vector"; + let methods = [ + InterfaceMethod< + /*desc=*/"Returns the mask value of this masking operation.", + /*retTy=*/"mlir::Value", + /*methodName=*/"getMask", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"">, + InterfaceMethod< + /*desc=*/"Returns the operation masked by this masking operation.", + // TODO: Return a MaskableOpInterface when interface infra can handle + // dependencies between interfaces.
+ /*retTy=*/"Operation *", + /*methodName=*/"getMaskableOp", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"">, + InterfaceMethod< + /*desc=*/"Returns true if the masking operation has a passthru value.", + /*retTy=*/"bool", + /*methodName=*/"hasPassthru", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"">, + InterfaceMethod< + /*desc=*/"Returns the passthru value of this masking operation.", + /*retTy=*/"mlir::Value", + /*methodName=*/"getPassthru", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/"">, + ]; +} + +#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGOPINTERFACE_TD diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -15,7 +15,8 @@ MLIRDestinationStyleOpInterface MLIRDialectUtils MLIRIR - MLIRMaskingInterfaces + MLIRMaskableOpInterface + MLIRMaskingOpInterface MLIRMemRefDialect MLIRSideEffectInterfaces MLIRTensorDialect diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5003,6 +5003,14 @@ return success(); } +// MaskingOpInterface definitions. + +/// Returns the operation masked by this 'vector.mask'. +Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); } + +/// Returns true if 'vector.mask' has a passthru value.
+bool MaskOp::hasPassthru() { return getPassthru() != Value(); } + //===----------------------------------------------------------------------===// // ScanOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt @@ -1,5 +1,6 @@ set(LLVM_OPTIONAL_SOURCES - MaskingInterfaces.cpp + MaskableOpInterface.cpp + MaskingOpInterface.cpp ) function(add_mlir_interface_library name) @@ -17,5 +18,5 @@ ) endfunction(add_mlir_interface_library) -add_mlir_interface_library(MaskingInterfaces) - +add_mlir_interface_library(MaskableOpInterface) +add_mlir_interface_library(MaskingOpInterface) diff --git a/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp b/mlir/lib/Dialect/Vector/Interfaces/MaskableOpInterface.cpp rename from mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp rename to mlir/lib/Dialect/Vector/Interfaces/MaskableOpInterface.cpp --- a/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp +++ b/mlir/lib/Dialect/Vector/Interfaces/MaskableOpInterface.cpp @@ -1,4 +1,4 @@ -//===- MaskingInterfaces.cpp - Masking interfaces ----------====-*- C++ -*-===// +//===- MaskableOpInterface.cpp - MaskableOpInterface Defs -------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.
@@ -6,11 +6,13 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h" +#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h" + +using namespace mlir; +using namespace mlir::vector; //===----------------------------------------------------------------------===// -// Masking Interfaces +// MaskableOpInterface Defs //===----------------------------------------------------------------------===// -/// Include the definitions of the masking interfaces. -#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc" +#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.cpp.inc" diff --git a/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp b/mlir/lib/Dialect/Vector/Interfaces/MaskingOpInterface.cpp rename from mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp rename to mlir/lib/Dialect/Vector/Interfaces/MaskingOpInterface.cpp --- a/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp +++ b/mlir/lib/Dialect/Vector/Interfaces/MaskingOpInterface.cpp @@ -1,4 +1,4 @@ -//===- MaskingInterfaces.cpp - Masking interfaces ----------====-*- C++ -*-===// +//===- MaskingOpInterface.cpp - MaskingOpInterface Defs ---------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information.
-#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc" +#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.cpp.inc" diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3250,7 +3250,8 @@ ":DialectUtils", ":IR", ":InferTypeOpInterface", - ":MaskingInterfaces", + ":MaskableOpInterface", + ":MaskingOpInterface", ":MemRefDialect", ":SideEffectInterfaces", ":Support", @@ -8201,8 +8202,15 @@ ##---------------------------------------------------------------------------## td_library( - name = "MaskingInterfacesTdFiles", - srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"], + name = "MaskingOpInterfaceTdFiles", + srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + +td_library( + name = "MaskableOpInterfaceTdFiles", + srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"], includes = ["include"], deps = [":OpBaseTdFiles"], ) @@ -8215,7 +8223,8 @@ ":ControlFlowInterfacesTdFiles", ":DestinationStyleOpInterfaceTdFiles", ":InferTypeOpInterfaceTdFiles", - ":MaskingInterfacesTdFiles", + ":MaskableOpInterfaceTdFiles", + ":MaskingOpInterfaceTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ":VectorInterfacesTdFiles", @@ -8224,21 +8233,39 @@ ) gentbl_cc_library( - name = "MaskingInterfacesIncGen", + name = "MaskableOpInterfaceIncGen", strip_include_prefix = "include", tbl_outs = [ ( ["-gen-op-interface-decls"], - "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc", + "include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h.inc", ), ( ["-gen-op-interface-defs"], - "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc", + "include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.cpp.inc", ), ], tblgen = ":mlir-tblgen", td_file =
"include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td", - deps = [":MaskingInterfacesTdFiles"], + td_file = "include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td", + deps = [":MaskableOpInterfaceTdFiles"], +) + +gentbl_cc_library( + name = "MaskingOpInterfaceIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td", + deps = [":MaskingOpInterfaceTdFiles"], ) gentbl_cc_library( @@ -8294,13 +8321,27 @@ ) cc_library( - name = "MaskingInterfaces", - srcs = ["lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp"], - hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"], + name = "MaskableOpInterface", + srcs = ["lib/Dialect/Vector/Interfaces/MaskableOpInterface.cpp"], + hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"], + includes = ["include"], + deps = [ + ":IR", + ":MaskableOpInterfaceIncGen", + ":MaskingOpInterface", + ":Support", + "//llvm:Support", + ], +) + +cc_library( + name = "MaskingOpInterface", + srcs = ["lib/Dialect/Vector/Interfaces/MaskingOpInterface.cpp"], + hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"], includes = ["include"], deps = [ ":IR", - ":MaskingInterfacesIncGen", + ":MaskingOpInterfaceIncGen", ":Support", "//llvm:Support", ],