diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -21,6 +21,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Interfaces/MaskingInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" @@ -49,6 +50,10 @@ struct BitmaskEnumStorage; } // namespace detail +/// Default callback to build a region with a 'vector.yield' terminator with no +/// arguments. +void buildTerminatedBody(OpBuilder &builder, Location loc); + /// Return whether `srcType` can be broadcast to `dstVectorType` under the /// semantics of the `vector.broadcast` op. enum class BroadcastableToResult { diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -16,6 +16,7 @@ include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" +include "mlir/Interfaces/MaskingInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" @@ -283,6 +284,7 @@ Vector_Op<"reduction", [NoSideEffect, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]>, Arguments<(ins Vector_CombiningKindAttr:$kind, @@ -360,7 +362,7 @@ }]; let builders = [ OpBuilder<(ins "Value":$source, "Value":$acc, - "ArrayRef":$reductionMask, "CombiningKind":$kind)> + "ArrayRef":$reductionMask, "CombiningKind":$kind)> ]; let extraClassDeclaration = [{ 
static StringRef getKindAttrStrName() { return "kind"; } @@ -1049,6 +1051,7 @@ Vector_Op<"transfer_read", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, @@ -1245,6 +1248,12 @@ "ValueRange":$indices, CArg<"Optional>", "::llvm::None">:$inBounds)>, ]; + + let extraClassDeclaration = [{ + // MaskableOpInterface methods. + bool supportsPassthru() { return true; } + }]; + let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasFolder = 1; @@ -1255,6 +1264,7 @@ Vector_Op<"transfer_write", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, @@ -2119,6 +2129,78 @@ let assemblyFormat = "$operands attr-dict `:` type(results)"; } +def Vector_MaskOp : Vector_Op<"mask", [ + SingleBlockImplicitTerminator<"vector::YieldOp">, RecursiveSideEffects, + NoRegionArguments +]> { + let summary = "Predicates a maskable vector operation"; + let description = [{ + The `vector.mask` operation predicates the execution of another operation. + It takes an `i1` vector mask and an optional pass-thru vector as arguments. + A `vector.yield`-terminated region encloses the operation to be masked. + Values used within the region are captured from above. Only one *maskable* + operation can be masked with a `vector.mask` operation at a time. An + operation is *maskable* if it implements the `MaskableOpInterface`. + + The vector mask argument holds a bit for each vector lane and determines + which vector lanes should execute the maskable operation and which ones + should not. The `vector.mask` operation returns the value produced by the + masked execution of the nested operation, if any. The masked-off lanes in + the result vector are taken from the corresponding lanes of the pass-thru + argument, if provided, or left unmodified, otherwise. 
+ + The `vector.mask` operation does not prescribe how a maskable operation + should be masked or how a masked operation should be lowered. Masking + constraints and some semantic details are provided by each maskable + operation through the `MaskableOpInterface`. Lowering of masked operations + is implementation defined. For instance, scalarizing the masked operation + or executing the operation for the masked-off lanes are valid lowerings as + long as the execution of masked-off lanes does not change the observable + behavior of the program. + + Examples: + + ``` + %0 = vector.mask %mask { vector.reduction , %a : vector<8xi32> into i32 } : vector<8xi1>, i32 + ``` + + ``` + %0 = vector.mask %mask, %passthru { arith.divsi %a, %b : vector<8xi32> } : vector<8xi1>, vector<8xi32> + ``` + + ``` + vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor } : vector<16xi1>, vector<16xf32> + ``` + }]; + + // TODO: Support multiple results and passthru values. + let arguments = (ins VectorOf<[I1]>:$mask, + Optional:$passthru); + let results = (outs Optional:$results); + let regions = (region SizedRegion<1>:$maskRegion); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "Value":$mask, + CArg<"function_ref", + "buildTerminatedBody">:$maskRegion)>, + OpBuilder<(ins "Type":$resultType, "Value":$mask, + CArg<"function_ref", + "buildTerminatedBody">:$maskRegion)>, + OpBuilder<(ins "Type":$resultType, "Value":$mask, + "Value":$passthru, + CArg<"function_ref", + "buildTerminatedBody">:$maskRegion)> + ]; + + let extraClassDeclaration = [{ + static void ensureTerminator(Region ®ion, Builder &builder, Location loc); + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + def Vector_TransposeOp : Vector_Op<"transpose", [NoSideEffect, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h --- a/mlir/include/mlir/IR/OpImplementation.h +++ 
b/mlir/include/mlir/IR/OpImplementation.h @@ -340,6 +340,9 @@ ArrayRef argAttrs = {}, bool omitType = false) = 0; + /// Print an operation omitting its results, including the sign '='. + void printOperationWithoutResults(Operation *); + /// Print implementations for various things an operation contains. virtual void printOperand(Value value) = 0; virtual void printOperand(Value value, raw_ostream &os) = 0; diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -788,6 +788,9 @@ /// Print users of values as comments. OpPrintingFlags &printValueUsers(); + /// Print op results. + OpPrintingFlags &printResults(bool printResults = true); + /// Return if the given ElementsAttr should be elided. bool shouldElideElementsAttr(ElementsAttr attr) const; @@ -812,6 +815,9 @@ /// Return if the printer should print users of values. bool shouldPrintValueUsers() const; + /// Return if the printer should print the op results. + bool shouldPrintResults() const; + private: /// Elide large elements attributes if the number of elements is larger than /// the upper limit. @@ -832,6 +838,9 @@ /// Print users of values. bool printValueUsersFlag : 1; + + /// Print op results. 
+ bool printResultsFlag : 1; }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Interfaces/CMakeLists.txt b/mlir/include/mlir/Interfaces/CMakeLists.txt --- a/mlir/include/mlir/Interfaces/CMakeLists.txt +++ b/mlir/include/mlir/Interfaces/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_interface(DerivedAttributeOpInterface) add_mlir_interface(InferIntRangeInterface) add_mlir_interface(InferTypeOpInterface) +add_mlir_interface(MaskingInterfaces) add_mlir_interface(LoopLikeInterface) add_mlir_interface(ParallelCombiningOpInterface) add_mlir_interface(SideEffectInterfaces) diff --git a/mlir/include/mlir/Interfaces/MaskingInterfaces.h b/mlir/include/mlir/Interfaces/MaskingInterfaces.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Interfaces/MaskingInterfaces.h @@ -0,0 +1,22 @@ +//===- MaskingInterfaces.h - Masking interfaces ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements the interfaces for masking operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INTERFACES_MASKINGINTERFACES_H_ +#define MLIR_INTERFACES_MASKINGINTERFACES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. 
+#include "mlir/Interfaces/MaskingInterfaces.h.inc"
+
+#endif // MLIR_INTERFACES_MASKINGINTERFACES_H_
diff --git a/mlir/include/mlir/Interfaces/MaskingInterfaces.td b/mlir/include/mlir/Interfaces/MaskingInterfaces.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Interfaces/MaskingInterfaces.td
@@ -0,0 +1,52 @@
+//===- MaskingInterfaces.td - Masking Interfaces Decls === -*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the definition file for vector masking related interfaces.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_INTERFACES_MASKINGINTERFACES
+#define MLIR_INTERFACES_MASKINGINTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
+  let description = [{
+    The 'MaskableOpInterface' defines an operation that can be masked using the
+    `vector.mask` operation and provides information about its masking
+    constraints and semantics.
+  }];
+  let cppNamespace = "::mlir::vector";
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Returns true if the operation may have a passthru argument when"
+               " masked.",
+      /*retTy=*/"bool",
+      /*methodName=*/"supportsPassthru",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return false;
+      }]>,
+    InterfaceMethod<
+      /*desc=*/"Returns the mask type expected by this operation. It requires the"
+               " operation to be vectorized.",
+      /*retTy=*/"mlir::VectorType",
+      /*methodName=*/"getExpectedMaskType",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Default implementation is only aimed for operations that implement the
+        // `getVectorType()` method.
+        return $_op.getVectorType().cloneWith(
+          /*shape=*/llvm::None, IntegerType::get($_op.getContext(), /*width=*/1));
+      }]>,
+  ];
+}
+
+#endif // MLIR_INTERFACES_MASKINGINTERFACES
diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
--- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
+++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt
@@ -14,6 +14,7 @@
   MLIRDataLayoutInterfaces
   MLIRDialectUtils
   MLIRIR
+  MLIRMaskingInterfaces
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRTensorDialect
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -96,6 +96,12 @@
   return MaskFormat::Unknown;
 }
 
+/// Default callback to build a region with a 'vector.yield' terminator with no
+/// arguments.
+void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
+  builder.create<vector::YieldOp>(loc);
+}
+
 // Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType) {
@@ -4808,6 +4814,198 @@
   results.add(context);
 }
 
+//===----------------------------------------------------------------------===//
+// MaskOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a `vector.mask` without results, with `mask` as its only operand.
+/// `maskRegionBuilder` populates the single block of the mask region.
+void MaskOp::build(
+    OpBuilder &builder, OperationState &result, Value mask,
+    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  assert(maskRegionBuilder &&
+         "builder callback for 'maskRegion' must be present");
+
+  result.addOperands(mask);
+  OpBuilder::InsertionGuard guard(builder);
+  Region *maskRegion = result.addRegion();
+  builder.createBlock(maskRegion);
+  maskRegionBuilder(builder, result.location);
+}
+
+/// Builds a `vector.mask` that returns `resultType` and has no passthru
+/// operand.
+void MaskOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
+    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  build(builder, result, resultType, mask, /*passthru=*/Value(),
+        maskRegionBuilder);
+}
+
+/// Builds a `vector.mask` that returns `resultType`. `passthru` is appended to
+/// the operands only when it is non-null.
+void MaskOp::build(
+    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
+    Value passthru,
+    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  build(builder, result, mask, maskRegionBuilder);
+  if (passthru)
+    result.addOperands(passthru);
+  result.addTypes(resultType);
+}
+
+/// Parses the custom assembly form:
+///   vector.mask %mask (`,` %passthru)? { masked-op } attr-dict?
+///       `:` mask-type (`,` result-type)*
+/// Returns failure if any piece fails to parse or if the trailing type list is
+/// too short to resolve the operands that were parsed.
+ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
+  // Create the op region.
+  result.regions.reserve(1);
+  Region &maskRegion = *result.addRegion();
+
+  auto &builder = parser.getBuilder();
+
+  // Parse all the operands.
+  OpAsmParser::UnresolvedOperand mask;
+  if (parser.parseOperand(mask))
+    return failure();
+
+  // Optional passthru operand. Once the comma has been consumed an operand
+  // must follow, so a failure to parse it is propagated.
+  OpAsmParser::UnresolvedOperand passthru;
+  ParseResult parsePassthru = parser.parseOptionalComma();
+  if (parsePassthru.succeeded() && parser.parseOperand(passthru))
+    return failure();
+
+  // Parse op region.
+  if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  MaskOp::ensureTerminator(maskRegion, builder, result.location);
+
+  // Parse the optional attribute list.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  // Parse all the types and classify them: the first one is the mask type, the
+  // remaining ones are the result types.
+  SmallVector<Type> types;
+  if (parser.parseColonTypeList(types))
+    return failure();
+
+  // `types[0]` resolves the mask and `types[1]` resolves the passthru, so make
+  // sure both accesses below are in bounds before touching the list.
+  const size_t numRequiredTypes = parsePassthru.succeeded() ? 2 : 1;
+  if (types.size() < numRequiredTypes)
+    return parser.emitError(parser.getNameLoc())
+           << "expects at least " << numRequiredTypes
+           << " types in the trailing type list";
+
+  result.types.append(std::next(types.begin()), types.end());
+
+  // Resolve operands.
+  if (parser.resolveOperand(mask, types[0], result.operands))
+    return failure();
+
+  if (parsePassthru.succeeded() &&
+      parser.resolveOperand(passthru, types[1], result.operands))
+    return failure();
+
+  return success();
+}
+
+/// Prints the custom assembly form. Only the masked operation is printed
+/// between the braces; its results and the region terminator are omitted.
+void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
+  p << " " << getMask();
+  if (getPassthru())
+    p << ", " << getPassthru();
+
+  // Print single masked operation and skip terminator.
+  p << " { ";
+  Block &singleBlock = getMaskRegion().getBlocks().front();
+  if (singleBlock.getOperations().size() > 1)
+    p.printOperationWithoutResults(&singleBlock.front());
+  p << " }";
+
+  p.printOptionalAttrDict(getOperation()->getAttrs());
+
+  p << " : " << getMask().getType();
+  if (getNumResults() > 0)
+    p << ", " << getResultTypes();
+}
+
+/// Runs the default implicit-terminator logic on `region` and then, if the
+/// block holds exactly two operations, replaces the trailing `vector.yield`
+/// with one that returns the results of the masked operation.
+void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
+  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
+      MaskOp>::ensureTerminator(region, builder, loc);
+  // Keep the default yield terminator if the number of masked operations is not
+  // the expected.
+  if (region.front().getOperations().size() != 2)
+    return;
+
+  // Replace default yield terminator with a new one that returns the results
+  // from the masked operation.
+  OpBuilder opBuilder(builder.getContext());
+  Operation *maskedOp = &region.front().front();
+  Operation *oldYieldOp = &region.front().back();
+  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
+
+  opBuilder.setInsertionPoint(oldYieldOp);
+  opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
+  oldYieldOp->dropAllReferences();
+  oldYieldOp->erase();
+}
+
+/// Verifies that the region holds exactly one maskable operation plus the
+/// terminator, and that the results, mask type, and optional passthru of this
+/// op are consistent with that operation.
+LogicalResult MaskOp::verify() {
+  // Structural checks.
+  Block &block = getMaskRegion().getBlocks().front();
+  if (block.getOperations().size() < 2)
+    return emitOpError("expects an operation to mask");
+  if (block.getOperations().size() > 2)
+    return emitOpError("expects only one operation to mask");
+
+  auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
+  if (!maskableOp)
+    return emitOpError("expects a maskable operation");
+
+  // Result checks.
+  if (maskableOp->getNumResults() != getNumResults())
+    return emitOpError("expects number of results to match maskable operation "
+                       "number of results");
+
+  if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
+    return emitOpError(
+        "expects result type to match maskable operation result type");
+
+  // Mask checks.
+  if (getMask().getType() != maskableOp.getExpectedMaskType())
+    return emitOpError("expects a ") << maskableOp.getExpectedMaskType()
+                                     << " mask for the maskable operation";
+
+  // Passthru checks.
+  Value passthru = getPassthru();
+  if (passthru) {
+    if (!maskableOp.supportsPassthru())
+      return emitOpError(
+          "doesn't expect a passthru argument for this maskable operation");
+
+    if (maskableOp->getNumResults() != 1)
+      return emitOpError("expects result when passthru argument is provided");
+
+    if (passthru.getType() != maskableOp->getResultTypes()[0])
+      return emitOpError("expects passthru type to match result type");
+  }
+
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // ScanOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -78,6 +78,11 @@
 
 OpAsmPrinter::~OpAsmPrinter() = default;
 
+/// Print an operation omitting its results, including the '=' sign.
+void OpAsmPrinter::printOperationWithoutResults(Operation *op) { + op->print(getStream(), OpPrintingFlags().useLocalScope().printResults(false)); +} + void OpAsmPrinter::printFunctionalType(Operation *op) { auto &os = getStream(); os << '('; @@ -184,7 +189,8 @@ OpPrintingFlags::OpPrintingFlags() : printDebugInfoFlag(false), printDebugInfoPrettyFormFlag(false), printGenericOpFormFlag(false), assumeVerifiedFlag(false), - printLocalScope(false), printValueUsersFlag(false) { + printLocalScope(false), printValueUsersFlag(false), + printResultsFlag(true) { // Initialize based upon command line options, if they are available. if (!clOptions.isConstructed()) return; @@ -242,6 +248,12 @@ return *this; } +/// Print op results. +OpPrintingFlags &OpPrintingFlags::printResults(bool printResults) { + printResultsFlag = printResults; + return *this; +} + /// Return if the given ElementsAttr should be elided. bool OpPrintingFlags::shouldElideElementsAttr(ElementsAttr attr) const { return elementsAttrElementLimit && @@ -282,6 +294,9 @@ return printValueUsersFlag; } +/// Return if the printer should print the op results. +bool OpPrintingFlags::shouldPrintResults() const { return printResultsFlag; } + /// Returns true if an ElementsAttr with the given number of elements should be /// printed with hex. 
static bool shouldPrintElementsAttrWithHex(int64_t numElements) { @@ -2934,7 +2949,8 @@ } void OperationPrinter::printOperation(Operation *op) { - if (size_t numResults = op->getNumResults()) { + size_t numResults = op->getNumResults(); + if (printerFlags.shouldPrintResults() && numResults > 0) { auto printResultGroup = [&](size_t resultNo, size_t resultCount) { printValueID(op->getResult(resultNo), /*printResultNo=*/false); if (resultCount > 1) diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -7,6 +7,7 @@ DerivedAttributeOpInterface.cpp InferIntRangeInterface.cpp InferTypeOpInterface.cpp + MaskingInterfaces.cpp LoopLikeInterface.cpp ParallelCombiningOpInterface.cpp SideEffectInterfaces.cpp @@ -39,9 +40,10 @@ add_mlir_interface_library(DerivedAttributeOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) +add_mlir_interface_library(LoopLikeInterface) +add_mlir_interface_library(MaskingInterfaces) add_mlir_interface_library(ParallelCombiningOpInterface) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) -add_mlir_interface_library(LoopLikeInterface) diff --git a/mlir/lib/Interfaces/MaskingInterfaces.cpp b/mlir/lib/Interfaces/MaskingInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Interfaces/MaskingInterfaces.cpp @@ -0,0 +1,16 @@ +//===- MaskingInterfaces.cpp - Masking interfaces ----------====-*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Interfaces/MaskingInterfaces.h" + +//===----------------------------------------------------------------------===// +// Masking Interfaces +//===----------------------------------------------------------------------===// + +/// Include the definitions of the masking interfaces. +#include "mlir/Interfaces/MaskingInterfaces.cpp.inc" diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -57,7 +57,7 @@ } // ----- - + func.func @shuffle_rank_mismatch_0d(%arg0: vector, %arg1: vector<1xf32>) { // expected-error@+1 {{'vector.shuffle' op rank mismatch}} %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector, vector<1xf32> @@ -1166,7 +1166,7 @@ } // ----- - + func.func @transpose_length_mismatch_0d(%arg0: vector) { // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 1}} %0 = vector.transpose %arg0, [1] : vector to vector @@ -1586,3 +1586,49 @@ } return } + +// ----- + +func.func @vector_mask_empty(%m0: vector<16xi1>) -> i32 { + // expected-error@+1 {{'vector.mask' op expects an operation to mask}} + vector.mask %m0 { } : vector<16xi1> +} + +// ----- + +func.func @vector_mask_multiple_ops(%t0: tensor, %t1: tensor, %idx: index, %val: vector<16xf32>, %m0: vector<16xi1>) { + %ft0 = arith.constant 0.0 : f32 + // expected-error@+1 {{'vector.mask' op expects only one operation to mask}} + vector.mask %m0 { + vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor + vector.transfer_write %val, %t1[%idx] : vector<16xf32>, tensor + } : vector<16xi1> + return +} + +// ----- + +func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) -> i32 { + // expected-error@+1 {{'vector.mask' op expects a 'vector<8xi1>' mask for the maskable operation}} + %0 = 
vector.mask %m0 { vector.reduction , %a : vector<8xi32> into i32 } : vector<16xi1>, i32 + return %0 : i32 +} + +// ----- + +// expected-note@+1 {{prior use here}} +func.func @vector_mask_passthru_type_mismatch(%t0: tensor, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> { + %ft0 = arith.constant 0.0 : f32 + // expected-error@+1 {{use of value '%pt0' expects different type than prior uses: 'vector<16xf32>' vs 'vector<16xi32>'}} + %0 = vector.mask %m0, %pt0 { vector.transfer_read %t0[%idx], %ft0 : tensor, vector<16xf32> } : vector<16xi1>, vector<16xf32> + return %0 : vector<16xf32> +} + +// ----- + +func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xf32>) { + // expected-error@+1 {{'vector.mask' op expects result type to match maskable operation result type}} + vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor } : vector<16xi1>, vector<16xf32> + return +} + diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -786,7 +786,7 @@ func.func @test_splat_op(%s : f32) { // CHECK: vector.splat [[S]] : vector<8xf32> %v = vector.splat %s : vector<8xf32> - + // CHECK: vector.splat [[S]] : vector<4xf32> %u = "vector.splat"(%s) : (f32) -> vector<4xf32> return @@ -824,4 +824,25 @@ return %2 : vector<4xi32> } +// CHECK-LABEL: func @vector_mask +func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 { +// CHECK-NEXT: %{{.*}} = vector.mask %{{.*}} { vector.reduction , %{{.*}} : vector<8xi32> into i32 } : vector<8xi1>, i32 + %0 = vector.mask %m0 { vector.reduction , %a : vector<8xi32> into i32 } : vector<8xi1>, i32 + return %0 : i32 +} + +// CHECK-LABEL: func @vector_mask_passthru +func.func @vector_mask_passthru(%t0: tensor, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xf32>) -> vector<16xf32> { + %ft0 = arith.constant 0.0 
: f32 +// CHECK: %{{.*}} = vector.mask %{{.*}}, %{{.*}} { vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : tensor, vector<16xf32> } : vector<16xi1>, vector<16xf32> + %0 = vector.mask %m0, %pt0 { vector.transfer_read %t0[%idx], %ft0 : tensor, vector<16xf32> } : vector<16xi1>, vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @vector_mask_no_return +func.func @vector_mask_no_return(%val: vector<16xf32>, %t0: memref, %idx: index, %m0: vector<16xi1>) { +// CHECK-NEXT: vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}] : vector<16xf32>, memref } : vector<16xi1> + vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref } : vector<16xi1> + return +} diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -971,6 +971,13 @@ deps = [":OpBaseTdFiles"], ) +td_library( + name = "MaskingInterfacesTdFiles", + srcs = ["include/mlir/Interfaces/MaskingInterfaces.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + td_library( name = "LoopLikeInterfaceTdFiles", srcs = ["include/mlir/Interfaces/LoopLikeInterface.td"], @@ -3169,6 +3176,7 @@ ":DialectUtils", ":IR", ":InferTypeOpInterface", + ":MaskingInterfaces", ":MemRefDialect", ":SideEffectInterfaces", ":Support", @@ -5889,6 +5897,37 @@ ], ) +gentbl_cc_library( + name = "MaskingInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Interfaces/MaskingInterfaces.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Interfaces/MaskingInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Interfaces/MaskingInterfaces.td", + deps = [":MaskingInterfacesTdFiles"], +) + +cc_library( + name = "MaskingInterfaces", + srcs = ["lib/Interfaces/MaskingInterfaces.cpp"], + hdrs = 
["include/mlir/Interfaces/MaskingInterfaces.h"], + includes = ["include"], + deps = [ + ":IR", + ":MaskingInterfacesIncGen", + ":Support", + "//llvm:Support", + ], +) + gentbl_cc_library( name = "SideEffectInterfacesIncGen", strip_include_prefix = "include", @@ -7774,6 +7813,7 @@ deps = [ ":ControlFlowInterfacesTdFiles", ":InferTypeOpInterfaceTdFiles", + ":MaskingInterfacesTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ":VectorInterfacesTdFiles",