diff --git a/mlir/include/mlir/Dialect/SparseTensor/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/SparseTensor/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/SparseTensor/CMakeLists.txt
@@ -1,2 +1,3 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/CMakeLists.txt
@@ -0,0 +1,4 @@
+set(LLVM_TARGET_DEFINITIONS SparseTensorTransformOps.td)
+mlir_tablegen(SparseTensorTransformOps.h.inc -gen-op-decls)
+mlir_tablegen(SparseTensorTransformOps.cpp.inc -gen-op-defs)
+add_public_tablegen_target(MLIRSparseTensorTransformOpsIncGen)
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h
@@ -0,0 +1,41 @@
+//===- SparseTensorTransformOps.h - sparse tensor transform ops -*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMOPS_SPARSETENSORTRANSFORMOPS_H
+#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMOPS_SPARSETENSORTRANSFORMOPS_H
+
+#include "mlir/Dialect/Transform/IR/MatchInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/RegionKindInterface.h"
+
+namespace mlir {
+namespace transform {
+class TransformHandleTypeInterface;
+} // namespace transform
+} // namespace mlir
+
+namespace mlir {
+class DialectRegistry;
+
+namespace sparse_tensor {
+/// Adds the sparse tensor transform dialect extension to `registry`.
+void registerTransformDialectExtension(DialectRegistry &registry);
+} // namespace sparse_tensor
+} // namespace mlir
+
+//===----------------------------------------------------------------------===//
+// SparseTensor Transform Operations
+//===----------------------------------------------------------------------===//
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h.inc"
+
+#endif // MLIR_DIALECT_SPARSETENSOR_TRANSFORMOPS_SPARSETENSORTRANSFORMOPS_H
diff --git a/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.td
@@ -0,0 +1,44 @@
+//=== SparseTensorTransformOps.td ----------------------------*-tablegen *-===//
+//
+// Tablegen rules for Sparse Tensor transformation operations.
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SPARSETENSOR_TRANSFORM_OPS
+#define SPARSETENSOR_TRANSFORM_OPS
+
+include "mlir/Dialect/Transform/IR/MatchInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformAttrs.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+//===----------------------------------------------------------------------===//
+// SparseTensorMatchOps
+//===----------------------------------------------------------------------===//
+
+// NOTE(review): the `Op` template arguments below are a reconstruction -- the
+// mnemonic follows the op name used in match_sparse_ops.mlir, the trait list
+// is a best guess; confirm both before landing.
+def MatchSparseInOut : Op<Transform_Dialect, "sparse_tensor.match.sparse_inout", [
+    MatchOpInterface,
+    SingleOpMatcher,
+    MemoryEffectsOpInterface]> {
+  let description = [{
+    Checks if the payload op has any sparse inputs and/or outputs.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target);
+  let results = (outs TransformHandleTypeInterface:$result);
+
+  let assemblyFormat = "$target attr-dict `:` custom<SemiFunctionType>(type($target), type($result))";
+  let extraClassDeclaration = SingleOpMatcher.extraDeclaration # [{
+    ::mlir::Value getOperandHandle() { return getTarget(); }
+  }];
+}
+
+#endif // SPARSETENSOR_TRANSFORM_OPS
diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h
--- a/mlir/include/mlir/InitAllDialects.h
+++ b/mlir/include/mlir/InitAllDialects.h
@@ -69,6 +69,7 @@
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
 #include "mlir/Dialect/SparseTensor/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.h"
@@ -142,6 +143,7 @@
   memref::registerTransformDialectExtension(registry);
   nvgpu::registerTransformDialectExtension(registry);
   scf::registerTransformDialectExtension(registry);
+  sparse_tensor::registerTransformDialectExtension(registry);
   tensor::registerTransformDialectExtension(registry);
   transform::registerPDLExtension(registry);
   vector::registerTransformDialectExtension(registry);
diff --git a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
--- a/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
+++ b/mlir/lib/Dialect/SparseTensor/CMakeLists.txt
@@ -1,4 +1,5 @@
 add_subdirectory(IR)
 add_subdirectory(Transforms)
+add_subdirectory(TransformOps)
 add_subdirectory(Pipelines)
 add_subdirectory(Utils)
diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/CMakeLists.txt b/mlir/lib/Dialect/SparseTensor/TransformOps/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/TransformOps/CMakeLists.txt
@@ -0,0 +1,17 @@
+add_mlir_dialect_library(MLIRSparseTensorTransformOps
+  SparseTensorTransformOps.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/SparseTensor/TransformOps
+
+  DEPENDS
+  MLIRSparseTensorTransformOpsIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSparseTensorDialect
+  MLIRParser
+  MLIRSideEffectInterfaces
+  MLIRTransformDialect
+  MLIRTransformDialectUtils
+  )
diff --git a/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp
@@ -0,0 +1,63 @@
+//===- SparseTensorTransformOps.cpp - sparse tensor transform ops impl ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.h"
+#include "mlir/Dialect/Linalg/TransformOps/Syntax.h"
+#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
+
+using namespace mlir;
+using namespace mlir::sparse_tensor;
+
+//===----------------------------------------------------------------------===//
+// Transform op implementation
+//===----------------------------------------------------------------------===//
+
+/// Succeeds when `hasAnySparseOperandOrResult` holds for the payload op
+/// `current`, in which case the payload ops of the target handle are
+/// associated with the result handle; otherwise emits a silenceable failure
+/// at the location of `current`.
+DiagnosedSilenceableFailure transform::MatchSparseInOut::matchOperation(
+    mlir::Operation *current, mlir::transform::TransformResults &results,
+    mlir::transform::TransformState &state) {
+  if (!hasAnySparseOperandOrResult(current)) {
+    return emitSilenceableFailure(current->getLoc(),
+                                  "operation has no sparse input or output");
+  }
+  results.set(cast<OpResult>(getResult()), state.getPayloadOps(getTarget()));
+  return DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// Transform op registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// Transform dialect extension that declares the SparseTensor dialect as a
+/// generated dialect and registers the ops from the generated op list.
+class SparseTensorTransformDialectExtension
+    : public transform::TransformDialectExtension<
+          SparseTensorTransformDialectExtension> {
+public:
+  SparseTensorTransformDialectExtension() {
+    declareGeneratedDialect<sparse_tensor::SparseTensorDialect>();
+    registerTransformOps<
+#define GET_OP_LIST
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/SparseTensor/TransformOps/SparseTensorTransformOps.cpp.inc"
+
+/// Adds the sparse tensor transform dialect extension to `registry`.
+void mlir::sparse_tensor::registerTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<SparseTensorTransformDialectExtension>();
+}
diff --git a/mlir/test/Integration/Dialect/SparseTensor/TransformOps/match_sparse_ops.mlir b/mlir/test/Integration/Dialect/SparseTensor/TransformOps/match_sparse_ops.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Integration/Dialect/SparseTensor/TransformOps/match_sparse_ops.mlir
@@ -0,0 +1,48 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter --verify-diagnostics --split-input-file
+
+module attributes { transform.with_named_sequence } {
+  transform.named_sequence @match_sparse_structured(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
+    %0 = transform.match.structured %arg0 : (!transform.any_op) -> !transform.any_op {
+      ^bb0(%struct: !transform.any_op):
+        %sp_kernel = transform.sparse_tensor.match.sparse_inout %struct
+          : (!transform.any_op) -> !transform.any_op
+        transform.match.structured.yield %sp_kernel : !transform.any_op
+    }
+    transform.yield %0 : !transform.any_op
+  }
+
+  transform.named_sequence @print_sparse_structured(%arg0: !transform.any_op {transform.readonly}) {
+    transform.test_print_remark_at_operand %arg0, "sparse_kernel" : !transform.any_op
+    transform.yield
+  }
+
+  // Entry point. Match any structured sparse operation and emit a remark.
+  transform.sequence failures(propagate) attributes { transform.target_tag = "transform" } {
+    ^bb0(%arg0: !transform.any_op):
+      transform.foreach_match in %arg0
+          @match_sparse_structured -> @print_sparse_structured
+          : (!transform.any_op) -> !transform.any_op
+  }
+}
+
+#CSR = #sparse_tensor.encoding<{lvlTypes = ["dense", "compressed"]}>
+
+func.func @payload(%lhs: tensor<10x20xf16>,
+                   %sp_lhs: tensor<10x20xf16, #CSR>,
+                   %rhs: tensor<20x15xf32>) -> tensor<10x15xf64>{
+  %cst = arith.constant 0.0 : f64
+  %empty = tensor.empty() : tensor<10x15xf64>
+  %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<10x15xf64>) -> tensor<10x15xf64>
+
+  %result = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+                         outs(%fill: tensor<10x15xf64>) -> tensor<10x15xf64>
+  // expected-remark @below {{sparse_kernel}}
+  %sp_in = linalg.matmul ins(%sp_lhs, %rhs: tensor<10x20xf16, #CSR>, tensor<20x15xf32>)
+                         outs(%fill: tensor<10x15xf64>) -> tensor<10x15xf64>
+
+  %sp_empty = tensor.empty() : tensor<10x15xf64, #CSR>
+  // expected-remark @below {{sparse_kernel}}
+  %sp_out = linalg.matmul ins(%lhs, %rhs: tensor<10x20xf16>, tensor<20x15xf32>)
+                          outs(%sp_empty: tensor<10x15xf64, #CSR>) -> tensor<10x15xf64, #CSR>
+  return %result : tensor<10x15xf64>
+}