diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -604,11 +604,72 @@
   let hasVerifier = 1;
 }
 
+def SparseTensor_SelectOp : SparseTensor_Op<"select", [NoSideEffect, SameOperandsAndResultType]>,
+    Arguments<(ins AnyType:$x)>,
+    Results<(outs AnyType:$output)> {
+  let summary = "Select operation utilized within linalg.generic";
+  let description = [{
+      Defines an evaluation within a `linalg.generic` operation that takes a single
+      operand and decides whether or not to keep that operand in the output.
+
+      The single region must contain exactly one block taking one argument. The block
+      must end with a sparse_tensor.yield and the yielded type must be boolean (i1).
+
+      Value thresholding is an obvious use of the select operation. However, by using
+      `linalg.index`, other useful selections can be achieved, such as selecting the
+      upper triangle of a matrix.
+
+      Example of selecting A >= 4.0:
+
+      ```mlir
+      %C = bufferization.alloc_tensor...
+      %0 = linalg.generic #trait
+         ins(%A: tensor<?xf64, #SparseVector>)
+         outs(%C: tensor<?xf64, #SparseVector>) {
+         ^bb0(%a: f64, %c: f64) :
+            %result = sparse_tensor.select %a : f64 {
+                ^bb0(%arg0: f64):
+                  %cf4 = arith.constant 4.0 : f64
+                  %keep = arith.cmpf "uge", %arg0, %cf4 : f64
+                  sparse_tensor.yield %keep : i1
+              }
+            linalg.yield %result : f64
+      } -> tensor<?xf64, #SparseVector>
+      ```
+
+      Example of selecting the lower triangle of a matrix:
+
+      ```mlir
+      %C = bufferization.alloc_tensor...
+      %0 = linalg.generic #trait
+         ins(%A: tensor<?x?xf64, #CSR>)
+         outs(%C: tensor<?x?xf64, #CSR>) {
+         ^bb0(%a: f64, %c: f64) :
+            %row = linalg.index 0 : index
+            %col = linalg.index 1 : index
+            %result = sparse_tensor.select %a : f64 {
+                ^bb0(%arg0: f64):
+                  %keep = arith.cmpi "ult", %col, %row : index
+                  sparse_tensor.yield %keep : i1
+              }
+            linalg.yield %result : f64
+      } -> tensor<?x?xf64, #CSR>
+      ```
+  }];
+
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = [{
+         $x attr-dict `:` type($x) $region
+  }];
+  let hasVerifier = 1;
+}
+
 def SparseTensor_YieldOp : SparseTensor_Op<"yield", [NoSideEffect, Terminator]>,
     Arguments<(ins AnyType:$result)> {
   let summary = "Yield from sparse_tensor set-like operations";
   let description = [{
-      Yields a value from within a `binary` or `unary` block.
+      Yields a value from within a `binary`, `unary`, `reduce`,
+      or `select` block.
 
       Example:
 
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -458,12 +458,27 @@
 
   // Check correct number of block arguments and return type.
   Region &formula = getRegion();
-  if (!formula.empty()) {
-    regionResult = verifyNumBlockArgs(
-        this, formula, "reduce", TypeRange{inputType, inputType}, inputType);
-    if (failed(regionResult))
-      return regionResult;
-  }
+  regionResult = verifyNumBlockArgs(this, formula, "reduce",
+                                    TypeRange{inputType, inputType}, inputType);
+  if (failed(regionResult))
+    return regionResult;
+
+  return success();
+}
+
+LogicalResult SelectOp::verify() {
+  Builder b(getContext());
+
+  Type inputType = getX().getType();
+  Type boolType = b.getI1Type();
+  LogicalResult regionResult = success();
+
+  // Check correct number of block arguments and return type.
+  Region &formula = getRegion();
+  regionResult = verifyNumBlockArgs(this, formula, "select",
+                                    TypeRange{inputType}, boolType);
+  if (failed(regionResult))
+    return regionResult;
 
   return success();
 }
@@ -472,11 +487,11 @@
   // Check for compatible parent.
   auto *parentOp = (*this)->getParentOp();
   if (isa<BinaryOp>(parentOp) || isa<UnaryOp>(parentOp) ||
-      isa<ReduceOp>(parentOp))
+      isa<ReduceOp>(parentOp) || isa<SelectOp>(parentOp))
     return success();
 
-  return emitOpError(
-      "expected parent op to be sparse_tensor unary, binary, or reduce");
+  return emitOpError("expected parent op to be sparse_tensor unary, binary, "
+                     "reduce, or select");
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SparseTensor/invalid.mlir b/mlir/test/Dialect/SparseTensor/invalid.mlir
--- a/mlir/test/Dialect/SparseTensor/invalid.mlir
+++ b/mlir/test/Dialect/SparseTensor/invalid.mlir
@@ -355,6 +355,40 @@
 
 // -----
 
+func.func @invalid_select_num_args_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region must have exactly 1 arguments}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64, %y: f64):
+        %ret = arith.constant 1 : i1
+        sparse_tensor.yield %ret : i1
+  }
+  return %r : f64
+}
+
+// -----
+
+func.func @invalid_select_return_type_mismatch(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region yield type mismatch}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        sparse_tensor.yield %x : f64
+  }
+  return %r : f64
+}
+
+// -----
+
+func.func @invalid_select_wrong_yield(%arg0: f64) -> f64 {
+  // expected-error@+1 {{select region must end with sparse_tensor.yield}}
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        tensor.yield %x : f64
+  }
+  return %r : f64
+}
+
+// -----
+
 #DC = #sparse_tensor.encoding<{dimLevelType = ["dense", "compressed"]}>
 func.func @invalid_concat_less_inputs(%arg: tensor<9x4xf64, #DC>) -> tensor<9x4xf64, #DC> {
   // expected-error@+1 {{Need at least two tensors to concatenate.}}
diff --git a/mlir/test/Dialect/SparseTensor/roundtrip.mlir b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
--- a/mlir/test/Dialect/SparseTensor/roundtrip.mlir
+++ b/mlir/test/Dialect/SparseTensor/roundtrip.mlir
@@ -291,6 +291,30 @@
 
 #SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
 
+// CHECK-LABEL: func @sparse_select(
+// CHECK-SAME:    %[[A:.*]]: f64) -> f64 {
+// CHECK:         %[[Z:.*]] = arith.constant 0.000000e+00 : f64
+// CHECK:         %[[C1:.*]] = sparse_tensor.select %[[A]] : f64 {
+// CHECK:           ^bb0(%[[A1:.*]]: f64):
+// CHECK:             %[[B1:.*]] = arith.cmpf ogt, %[[A1]], %[[Z]] : f64
+// CHECK:             sparse_tensor.yield %[[B1]] : i1
+// CHECK:         }
+// CHECK:         return %[[C1]] : f64
+// CHECK:       }
+func.func @sparse_select(%arg0: f64) -> f64 {
+  %cf0 = arith.constant 0.0 : f64
+  %r = sparse_tensor.select %arg0 : f64 {
+      ^bb0(%x: f64):
+        %cmp = arith.cmpf "ogt", %x, %cf0 : f64
+        sparse_tensor.yield %cmp : i1
+  }
+  return %r : f64
+}
+
+// -----
+
+#SparseMatrix = #sparse_tensor.encoding<{dimLevelType = ["compressed", "compressed"]}>
+
 // CHECK-LABEL: func @concat_sparse_sparse(
 // CHECK-SAME:    %[[A0:.*]]: tensor<2x4xf64
 // CHECK-SAME:    %[[A1:.*]]: tensor<3x4xf64
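
The op description above mentions selecting the upper triangle of a matrix but only illustrates the lower-triangle case. A minimal sketch of the upper-triangle variant, reusing the hypothetical `#trait` indexing maps and `#CSR` encoding from the documentation examples (both are placeholders, not defined by this patch), could look like:

```mlir
// Sketch only: keep entries on or above the diagonal (col >= row).
%C = bufferization.alloc_tensor...
%0 = linalg.generic #trait
   ins(%A: tensor<?x?xf64, #CSR>)
   outs(%C: tensor<?x?xf64, #CSR>) {
   ^bb0(%a: f64, %c: f64):
      %row = linalg.index 0 : index
      %col = linalg.index 1 : index
      %result = sparse_tensor.select %a : f64 {
          ^bb0(%arg0: f64):
            // Compare the index values captured from the enclosing block.
            %keep = arith.cmpi "uge", %col, %row : index
            sparse_tensor.yield %keep : i1
        }
      linalg.yield %result : f64
} -> tensor<?x?xf64, #CSR>
```

As in the documented lower-triangle example, the select region only receives the stored value as a block argument; the `linalg.index` results are captured from the surrounding `linalg.generic` body.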