diff --git a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Vector/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Vector/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) +add_subdirectory(Interfaces) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -13,6 +13,7 @@ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOROPS_H #define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H +#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -49,6 +50,10 @@ struct BitmaskEnumStorage; } // namespace detail +/// Default callback to build a region with a 'vector.yield' terminator with no +/// arguments. +void buildTerminatedBody(OpBuilder &builder, Location loc); + /// Return whether `srcType` can be broadcast to `dstVectorType` under the /// semantics of the `vector.broadcast` op. 
enum class BroadcastableToResult { diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -13,6 +13,7 @@ #ifndef VECTOR_OPS #define VECTOR_OPS +include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td" include "mlir/IR/EnumAttr.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" @@ -283,6 +284,7 @@ Vector_Op<"reduction", [NoSideEffect, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, + DeclareOpInterfaceMethods<MaskableOpInterface>, DeclareOpInterfaceMethods]>, Arguments<(ins Vector_CombiningKindAttr:$kind, @@ -360,7 +362,7 @@ }]; let builders = [ OpBuilder<(ins "Value":$source, "Value":$acc, - "ArrayRef":$reductionMask, "CombiningKind":$kind)> + "ArrayRef":$reductionMask, "CombiningKind":$kind)> ]; let extraClassDeclaration = [{ static StringRef getKindAttrStrName() { return "kind"; } @@ -1050,6 +1052,7 @@ Vector_Op<"transfer_read", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods<MaskableOpInterface>, DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, @@ -1246,6 +1249,12 @@ "ValueRange":$indices, CArg<"Optional>", "::llvm::None">:$inBounds)>, ]; + + let extraClassDeclaration = [{ + // MaskableOpInterface methods.
+ bool supportsPassthru() { return true; } + }]; + let hasCanonicalizer = 1; let hasCustomAssemblyFormat = 1; let hasFolder = 1; @@ -1256,6 +1265,7 @@ Vector_Op<"transfer_write", [ DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods<MaskableOpInterface>, DeclareOpInterfaceMethods, AttrSizedOperandSegments ]>, @@ -2120,6 +2130,78 @@ let assemblyFormat = "$operands attr-dict `:` type(results)"; } +def Vector_MaskOp : Vector_Op<"mask", [ + SingleBlockImplicitTerminator<"vector::YieldOp">, RecursiveSideEffects, + NoRegionArguments +]> { + let summary = "Predicates a maskable vector operation"; + let description = [{ + The `vector.mask` operation predicates the execution of another operation. + It takes an `i1` vector mask and an optional pass-thru vector as arguments. + A `vector.yield`-terminated region encloses the operation to be masked. + Values used within the region are captured from above. Only one *maskable* + operation can be masked with a `vector.mask` operation at a time. An + operation is *maskable* if it implements the `MaskableOpInterface`. + + The vector mask argument holds a bit for each vector lane and determines + which vector lanes should execute the maskable operation and which ones + should not. The `vector.mask` operation returns the value produced by the + masked execution of the nested operation, if any. The masked-off lanes in + the result vector are taken from the corresponding lanes of the pass-thru + argument, if provided, or left unmodified, otherwise. + + The `vector.mask` operation does not prescribe how a maskable operation + should be masked or how a masked operation should be lowered. Masking + constraints and some semantic details are provided by each maskable + operation through the `MaskableOpInterface`. Lowering of masked operations + is implementation defined.
For instance, scalarizing the masked operation + or executing the operation for the masked-off lanes are valid lowerings as + long as the execution of masked-off lanes does not change the observable + behavior of the program. + + Examples: + + ``` + %0 = vector.mask %mask { vector.reduction <add>, %a : vector<8xi32> into i32 } : vector<8xi1> -> i32 + ``` + + ``` + %0 = vector.mask %mask, %passthru { arith.divsi %a, %b : vector<8xi32> } : vector<8xi1> -> vector<8xi32> + ``` + + ``` + vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1> + ``` + }]; + + // TODO: Support multiple results and passthru values. + let arguments = (ins VectorOf<[I1]>:$mask, + Optional<AnyType>:$passthru); + let results = (outs Optional<AnyType>:$results); + let regions = (region SizedRegion<1>:$maskRegion); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "Value":$mask, + CArg<"function_ref<void(OpBuilder &, Location)>", + "buildTerminatedBody">:$maskRegion)>, + OpBuilder<(ins "Type":$resultType, "Value":$mask, + CArg<"function_ref<void(OpBuilder &, Location)>", + "buildTerminatedBody">:$maskRegion)>, + OpBuilder<(ins "Type":$resultType, "Value":$mask, + "Value":$passthru, + CArg<"function_ref<void(OpBuilder &, Location)>", + "buildTerminatedBody">:$maskRegion)> + ]; + + let extraClassDeclaration = [{ + static void ensureTerminator(Region &region, Builder &builder, Location loc); + }]; + + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; +} + def Vector_TransposeOp : Vector_Op<"transpose", [NoSideEffect, DeclareOpInterfaceMethods, diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt b/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Interfaces/CMakeLists.txt @@ -0,0 +1 @@ +add_mlir_interface(MaskingInterfaces) diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h new file mode 100644 --- /dev/null +++
b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h @@ -0,0 +1,22 @@ +//===- MaskingInterfaces.h - Masking interfaces ---------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file declares the interfaces for masking operations. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_ +#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/OpDefinition.h" + +/// Include the generated interface declarations. +#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc" + +#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES_H_ diff --git a/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td @@ -0,0 +1,52 @@ +//===- MaskingInterfaces.td - Masking Interfaces Decls === -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This is the definition file for vector masking related interfaces.
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES +#define MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES + +include "mlir/IR/OpBase.td" + +def MaskableOpInterface : OpInterface<"MaskableOpInterface"> { + let description = [{ + The 'MaskableOpInterface' defines an operation that can be masked using the + `vector.mask` operation and provides information about its masking + constraints and semantics. + }]; + let cppNamespace = "::mlir::vector"; + let methods = [ + InterfaceMethod< + /*desc=*/"Returns true if the operation may have a passthru argument when" + " masked.", + /*retTy=*/"bool", + /*methodName=*/"supportsPassthru", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return false; + }]>, + InterfaceMethod< + /*desc=*/"Returns the mask type expected by this operation. It requires the" + " operation to be vectorized.", + /*retTy=*/"mlir::VectorType", + /*methodName=*/"getExpectedMaskType", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + // Default implementation is only aimed at operations that implement the + // `getVectorType()` method.
+ return $_op.getVectorType().cloneWith( + /*shape=*/llvm::None, IntegerType::get($_op.getContext(), /*width=*/1)); + }]>, + ]; +} + +#endif // MLIR_DIALECT_VECTOR_INTERFACES_MASKINGINTERFACES diff --git a/mlir/lib/Dialect/Vector/CMakeLists.txt b/mlir/lib/Dialect/Vector/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) +add_subdirectory(Interfaces) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt --- a/mlir/lib/Dialect/Vector/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Vector/IR/CMakeLists.txt @@ -14,6 +14,7 @@ MLIRDataLayoutInterfaces MLIRDialectUtils MLIRIR + MLIRMaskingInterfaces MLIRMemRefDialect MLIRSideEffectInterfaces MLIRTensorDialect diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -96,6 +96,12 @@ return MaskFormat::Unknown; } +/// Default callback to build a region with a 'vector.yield' terminator with no +/// arguments. +void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) { + builder.create<vector::YieldOp>(loc); +} + // Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind, Type elementType) { @@ -4860,6 +4866,192 @@ results.add(context); } +//===----------------------------------------------------------------------===// +// MaskOp +//===----------------------------------------------------------------------===// + +/// Builds a `vector.mask` op without results or passthru. The body of the +/// single-block mask region is populated by `maskRegionBuilder`. +void MaskOp::build( + OpBuilder &builder, OperationState &result, Value mask, + function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { + assert(maskRegionBuilder && + "builder callback for 'maskRegion' must be present"); + + result.addOperands(mask); + OpBuilder::InsertionGuard guard(builder); + Region *maskRegion = result.addRegion(); + builder.createBlock(maskRegion); + maskRegionBuilder(builder, result.location); +} + +/// Builds a `vector.mask` op producing `resultType`, without passthru. +void MaskOp::build( + OpBuilder &builder, OperationState &result, Type resultType, Value mask, + function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { + build(builder, result, resultType, mask, /*passthru=*/Value(), + maskRegionBuilder); +} + +/// Builds a `vector.mask` op producing `resultType`. `passthru` is only added +/// as an operand when it is non-null. +void MaskOp::build( + OpBuilder &builder, OperationState &result, Type resultType, Value mask, + Value passthru, + function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { + build(builder, result, mask, maskRegionBuilder); + if (passthru) + result.addOperands(passthru); + result.addTypes(resultType); +} + +/// Parses `vector.mask %mask (, %passthru)? { maskable-op } attr-dict : +/// mask-type (-> result-type)?`. The passthru type is taken from the result. +ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) { + // Create the op region. + result.regions.reserve(1); + Region &maskRegion = *result.addRegion(); + + auto &builder = parser.getBuilder(); + + // Parse all the operands. 
+ OpAsmParser::UnresolvedOperand mask; + if (parser.parseOperand(mask)) + return failure(); + + // Optional passthru operand. + OpAsmParser::UnresolvedOperand passthru; + ParseResult parsePassthru = parser.parseOptionalComma(); + if (parsePassthru.succeeded() && parser.parseOperand(passthru)) + return failure(); + + // Parse op region.
+ if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{})) + return failure(); + + MaskOp::ensureTerminator(maskRegion, builder, result.location); + + // Parse the optional attribute list. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + // Parse all the types. + Type maskType; + if (parser.parseColonType(maskType)) + return failure(); + + SmallVector<Type> resultTypes; + if (parser.parseOptionalArrowTypeList(resultTypes)) + return failure(); + result.types.append(resultTypes); + + // Resolve operands. + if (parser.resolveOperand(mask, maskType, result.operands)) + return failure(); + + if (parsePassthru.succeeded()) { + // The passthru value supplies the masked-off lanes of the result, so a + // result type is required to resolve it against. + if (resultTypes.empty()) + return parser.emitError(parser.getNameLoc(), + "expects a result if passthru operand is provided"); + if (parser.resolveOperand(passthru, resultTypes[0], result.operands)) + return failure(); + } + + return success(); +} + +/// Prints the op in the form accepted by `MaskOp::parse`, eliding the +/// implicit `vector.yield` terminator. +void MaskOp::print(OpAsmPrinter &p) { + p << " " << getMask(); + if (getPassthru()) + p << ", " << getPassthru(); + + // Print single masked operation and skip terminator. + p << " { "; + Block &singleBlock = getMaskRegion().front(); + if (singleBlock.getOperations().size() > 1) + p.printCustomOrGenericOp(&singleBlock.front()); + p << " }"; + + p.printOptionalAttrDict(getOperation()->getAttrs()); + + p << " : " << getMask().getType(); + if (getNumResults() > 0) + p << " -> " << getResultTypes(); +} + +/// Ensures the mask region ends in a `vector.yield`. When the region holds +/// exactly one masked op, the yield forwards that op's results. 
+void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) { + OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl< + MaskOp>::ensureTerminator(region, builder, loc); + // Keep the default yield terminator if the region does not hold exactly one + // masked operation; the verifier reports that case. + if (region.front().getOperations().size() != 2) + return; + + // Replace default yield terminator with a new one that returns the results + // from the masked operation.
+ OpBuilder opBuilder(builder.getContext()); + Operation *maskedOp = &region.front().front(); + Operation *oldYieldOp = &region.front().back(); + assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp"); + + opBuilder.setInsertionPoint(oldYieldOp); + opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults()); + oldYieldOp->dropAllReferences(); + oldYieldOp->erase(); +} + +/// Verifies that the region masks exactly one maskable op and that the mask, +/// passthru and result types agree with that op. +LogicalResult MaskOp::verify() { + // Structural checks. + Block &block = getMaskRegion().getBlocks().front(); + if (block.getOperations().size() < 2) + return emitOpError("expects an operation to mask"); + if (block.getOperations().size() > 2) + return emitOpError("expects only one operation to mask"); + + auto maskableOp = dyn_cast<MaskableOpInterface>(block.front()); + if (!maskableOp) + return emitOpError("expects a maskable operation"); + + // Result checks. + if (maskableOp->getNumResults() != getNumResults()) + return emitOpError("expects number of results to match maskable operation " + "number of results"); + + if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes())) + return emitOpError( + "expects result type to match maskable operation result type"); + + // Mask checks. + VectorType expectedMaskType = maskableOp.getExpectedMaskType(); + if (getMask().getType() != expectedMaskType) + return emitOpError("expects a ") + << expectedMaskType << " mask for the maskable operation"; + + // Passthru checks.
+ Value passthru = getPassthru(); + if (passthru) { + if (!maskableOp.supportsPassthru()) + return emitOpError( + "doesn't expect a passthru argument for this maskable operation"); + + if (maskableOp->getNumResults() != 1) + return emitOpError("expects result when passthru argument is provided"); + + if (passthru.getType() != maskableOp->getResultTypes()[0]) + return emitOpError("expects passthru type to match result type"); + } + + return success(); +} + //===----------------------------------------------------------------------===// // ScanOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt b/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Interfaces/CMakeLists.txt @@ -0,0 +1,21 @@ +set(LLVM_OPTIONAL_SOURCES + MaskingInterfaces.cpp + ) + +function(add_mlir_interface_library name) + add_mlir_library(MLIR${name} + ${name}.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/Interfaces + + DEPENDS + MLIR${name}IncGen + + LINK_LIBS PUBLIC + MLIRIR + ) +endfunction(add_mlir_interface_library) + +add_mlir_interface_library(MaskingInterfaces) + diff --git a/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp b/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp @@ -0,0 +1,16 @@ +//===- MaskingInterfaces.cpp - Masking interfaces ---------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h" + +//===----------------------------------------------------------------------===// +// Masking Interfaces +//===----------------------------------------------------------------------===// + +/// Include the definitions of the masking interfaces. +#include "mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc" diff --git a/mlir/lib/Interfaces/CMakeLists.txt b/mlir/lib/Interfaces/CMakeLists.txt --- a/mlir/lib/Interfaces/CMakeLists.txt +++ b/mlir/lib/Interfaces/CMakeLists.txt @@ -40,10 +40,10 @@ add_mlir_interface_library(DerivedAttributeOpInterface) add_mlir_interface_library(InferIntRangeInterface) add_mlir_interface_library(InferTypeOpInterface) +add_mlir_interface_library(LoopLikeInterface) add_mlir_interface_library(ParallelCombiningOpInterface) add_mlir_interface_library(ShapedOpInterfaces) add_mlir_interface_library(SideEffectInterfaces) add_mlir_interface_library(TilingInterface) add_mlir_interface_library(VectorInterfaces) add_mlir_interface_library(ViewLikeInterface) -add_mlir_interface_library(LoopLikeInterface) diff --git a/mlir/test/Dialect/Vector/invalid.mlir b/mlir/test/Dialect/Vector/invalid.mlir --- a/mlir/test/Dialect/Vector/invalid.mlir +++ b/mlir/test/Dialect/Vector/invalid.mlir @@ -57,7 +57,7 @@ } // ----- - + func.func @shuffle_rank_mismatch_0d(%arg0: vector, %arg1: vector<1xf32>) { // expected-error@+1 {{'vector.shuffle' op rank mismatch}} %1 = vector.shuffle %arg0, %arg1 [0, 1] : vector, vector<1xf32> @@ -1166,7 +1166,7 @@ } // ----- - + func.func @transpose_length_mismatch_0d(%arg0: vector) { // expected-error@+1 {{'vector.transpose' op transposition length mismatch: 1}} %0 = vector.transpose %arg0, [1] : vector to vector @@ -1586,3 +1586,49 @@ } return } + +// ----- + +func.func @vector_mask_empty(%m0: vector<16xi1>) -> i32 { + //
expected-error@+1 {{'vector.mask' op expects an operation to mask}} + vector.mask %m0 { } : vector<16xi1> +} + +// ----- + +func.func @vector_mask_multiple_ops(%t0: tensor<?xf32>, %t1: tensor<?xf32>, %idx: index, %val: vector<16xf32>, %m0: vector<16xi1>) { + %ft0 = arith.constant 0.0 : f32 + // expected-error@+1 {{'vector.mask' op expects only one operation to mask}} + vector.mask %m0 { + vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> + vector.transfer_write %val, %t1[%idx] : vector<16xf32>, tensor<?xf32> + } : vector<16xi1> + return +} + +// ----- + +func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) -> i32 { + // expected-error@+1 {{'vector.mask' op expects a 'vector<8xi1>' mask for the maskable operation}} + %0 = vector.mask %m0 { vector.reduction <add>, %a : vector<8xi32> into i32 } : vector<16xi1> -> i32 + return %0 : i32 +} + +// ----- + +// expected-note@+1 {{prior use here}} +func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> { + %ft0 = arith.constant 0.0 : f32 + // expected-error@+1 {{use of value '%pt0' expects different type than prior uses: 'vector<16xf32>' vs 'vector<16xi32>'}} + %0 = vector.mask %m0, %pt0 { vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32> + return %0 : vector<16xf32> +} + +// ----- + +func.func @vector_mask_passthru_no_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xf32>) { + // expected-error@+1 {{'vector.mask' op expects result type to match maskable operation result type}} + vector.mask %m0, %pt0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> vector<16xf32> + return +} + diff --git a/mlir/test/Dialect/Vector/ops.mlir b/mlir/test/Dialect/Vector/ops.mlir --- a/mlir/test/Dialect/Vector/ops.mlir +++ b/mlir/test/Dialect/Vector/ops.mlir @@ -786,7 +786,7 @@ func.func @test_splat_op(%s : f32) { // CHECK:
vector.splat [[S]] : vector<8xf32> %v = vector.splat %s : vector<8xf32> - + // CHECK: vector.splat [[S]] : vector<4xf32> %u = "vector.splat"(%s) : (f32) -> vector<4xf32> return @@ -824,4 +824,32 @@ return %2 : vector<4xi32> } +// CHECK-LABEL: func @vector_mask +func.func @vector_mask(%a: vector<8xi32>, %m0: vector<8xi1>) -> i32 { +// CHECK-NEXT: %{{.*}} = vector.mask %{{.*}} { vector.reduction <add>, %{{.*}} : vector<8xi32> into i32 } : vector<8xi1> -> i32 + %0 = vector.mask %m0 { vector.reduction <add>, %a : vector<8xi32> into i32 } : vector<8xi1> -> i32 + return %0 : i32 +} + +// CHECK-LABEL: func @vector_mask_passthru +func.func @vector_mask_passthru(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xf32>) -> vector<16xf32> { + %ft0 = arith.constant 0.0 : f32 +// CHECK: %{{.*}} = vector.mask %{{.*}}, %{{.*}} { vector.transfer_read %{{.*}}[%{{.*}}], %{{.*}} : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32> + %0 = vector.mask %m0, %pt0 { vector.transfer_read %t0[%idx], %ft0 : tensor<?xf32>, vector<16xf32> } : vector<16xi1> -> vector<16xf32> + return %0 : vector<16xf32> +} + +// CHECK-LABEL: func @vector_mask_no_return +func.func @vector_mask_no_return(%val: vector<16xf32>, %t0: memref<?xf32>, %idx: index, %m0: vector<16xi1>) { +// CHECK-NEXT: vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}] : vector<16xf32>, memref<?xf32> } : vector<16xi1> + vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1> + return +} + +// CHECK-LABEL: func @vector_mask_tensor_return +func.func @vector_mask_tensor_return(%val: vector<16xf32>, %t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>) { +// CHECK-NEXT: vector.mask %{{.*}} { vector.transfer_write %{{.*}}, %{{.*}}[%{{.*}}] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32> + vector.mask %m0 { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32> + return +} diff --git 
a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3213,6 +3213,7 @@ ":DialectUtils", ":IR", ":InferTypeOpInterface", + ":MaskingInterfaces", ":MemRefDialect", ":SideEffectInterfaces", ":Support", @@ -7954,6 +7955,17 @@ ], ) +##---------------------------------------------------------------------------## +# Vector dialect. +##---------------------------------------------------------------------------## + +td_library( + name = "MaskingInterfacesTdFiles", + srcs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td"], + includes = ["include"], + deps = [":OpBaseTdFiles"], +) + td_library( name = "VectorOpsTdFiles", srcs = ["include/mlir/Dialect/Vector/IR/VectorOps.td"], @@ -7961,6 +7973,7 @@ deps = [ ":ControlFlowInterfacesTdFiles", ":InferTypeOpInterfaceTdFiles", + ":MaskingInterfacesTdFiles", ":OpBaseTdFiles", ":SideEffectInterfacesTdFiles", ":VectorInterfacesTdFiles", @@ -7968,6 +7981,24 @@ ], ) +gentbl_cc_library( + name = "MaskingInterfacesIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.td", + deps = [":MaskingInterfacesTdFiles"], +) + gentbl_cc_library( name = "VectorOpsIncGen", strip_include_prefix = "include", @@ -8020,6 +8051,19 @@ deps = [":VectorOpsTdFiles"], ) +cc_library( + name = "MaskingInterfaces", + srcs = ["lib/Dialect/Vector/Interfaces/MaskingInterfaces.cpp"], + hdrs = ["include/mlir/Dialect/Vector/Interfaces/MaskingInterfaces.h"], + includes = ["include"], + deps = [ + ":IR", + ":MaskingInterfacesIncGen", + ":Support", + "//llvm:Support", + ], +) + cc_library( name = "VectorToLLVM", 
srcs = glob([