Index: mlir/include/mlir/Dialect/CMakeLists.txt
===================================================================
--- mlir/include/mlir/Dialect/CMakeLists.txt
+++ mlir/include/mlir/Dialect/CMakeLists.txt
@@ -15,6 +15,7 @@
 add_subdirectory(Linalg)
 add_subdirectory(LLVMIR)
 add_subdirectory(MemRef)
+add_subdirectory(NVPTX)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
Index: mlir/include/mlir/Dialect/NVPTX/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/NVPTX/CMakeLists.txt
@@ -0,0 +1,4 @@
+add_mlir_dialect(NVPTX nvptx)
+add_mlir_doc(NVPTX NVPTX Dialects/ -gen-dialect-doc)
+
+set(LLVM_TARGET_DEFINITIONS NVPTX.td)
Index: mlir/include/mlir/Dialect/NVPTX/NVPTX.td
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/NVPTX/NVPTX.td
@@ -0,0 +1,72 @@
+//===-- NVPTX.td - NVPTX dialect operation definitions *- tablegen -*------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the basic operations for the NVPTX dialect.
+//
+// The NVPTX dialect bridges the target-agnostic GPU and Vector dialects with
+// the lower level NVVM dialect. It allows representing PTX-specific
+// operations while using high-level MLIR concepts like memref and 2-D vector.
+//
+// Op semantics are based on the vendor-specific PTX ISA definition:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef NVPTX
+#define NVPTX
+
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+def NVPTX_Dialect : Dialect {
+  let name = "nvptx";
+  let cppNamespace = "::mlir::nvptx";
+  let description = [{
+    This `NVPTX` dialect provides a bridge between the target agnostic GPU and
+    Vector dialects and the lower level LLVM IR based NVVM dialect. This allows
+    representing PTX specific operations while using MLIR high level concepts
+    like memref and 2-D vector.
+  }];
+}
+
+//===----------------------------------------------------------------------===//
+// NVPTX Op definitions
+//===----------------------------------------------------------------------===//
+
+class NVPTX_Op<string mnemonic, list<Trait> traits = []> :
+  Op<NVPTX_Dialect, mnemonic, traits> {}
+
+def NVPTX_LdMatrixOp : NVPTX_Op<"ldmatrix",
+                                [MemoryEffects<[MemRead]>]> {
+  let description = [{
+    The `nvptx.ldmatrix` op represents loading a matrix fragment from
+    memory. The load source and result type must be compatible with lowering
+    to the `nvvm.ldmatrix` instruction. This op is meant to represent
+    the distributed version of a `vector.transfer_read` as an intermediate
+    step between lowering from `vector.transfer_read` to `nvvm.ldmatrix`.
+
+    This operation follows the semantics described here:
+    https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
+
+    Example:
+    ```mlir
+    %0 = nvptx.ldmatrix %sm[%c0, %c0] {numTiles = 4 : i32, transpose = false} :
+      memref<?x?xf16, 3> -> vector<4x2xf16>
+    ```
+  }];
+
+  let arguments = (ins Arg<AnyMemRef, "", [MemRead]>:$srcMemref,
+                           Variadic<Index>:$indices, BoolAttr:$transpose,
+                           I32Attr:$numTiles);
+  let results = (outs AnyVector:$res);
+  let assemblyFormat = [{
+    $srcMemref`[` $indices `]` attr-dict `:` type($srcMemref) `->` type($res)
+  }];
+}
+
+#endif // NVPTX
Index: mlir/include/mlir/Dialect/NVPTX/NVPTXDialect.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Dialect/NVPTX/NVPTXDialect.h
@@ -0,0 +1,26 @@
+//===- NVPTXDialect.h - MLIR Dialect for NVPTX ------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares the NVPTX target dialect and its operations in MLIR.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_NVPTX_NVPTXDIALECT_H_
+#define MLIR_DIALECT_NVPTX_NVPTXDIALECT_H_
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+
+#include "mlir/Dialect/NVPTX/NVPTXDialect.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVPTX/NVPTX.h.inc"
+
+#endif // MLIR_DIALECT_NVPTX_NVPTXDIALECT_H_
Index: mlir/include/mlir/InitAllDialects.h
===================================================================
--- mlir/include/mlir/InitAllDialects.h
+++ mlir/include/mlir/InitAllDialects.h
@@ -35,6 +35,7 @@
 #include "mlir/Dialect/Linalg/Transforms/BufferizableOpInterfaceImpl.h"
 #include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/NVPTX/NVPTXDialect.h"
 #include "mlir/Dialect/OpenACC/OpenACC.h"
 #include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/Dialect/PDL/IR/PDL.h"
@@ -77,6 +78,7 @@
                   linalg::LinalgDialect,
                   math::MathDialect,
                   memref::MemRefDialect,
+                  nvptx::NVPTXDialect,
                   scf::SCFDialect,
                   omp::OpenMPDialect,
                   pdl::PDLDialect,
Index: mlir/lib/Dialect/CMakeLists.txt
===================================================================
--- mlir/lib/Dialect/CMakeLists.txt
+++ mlir/lib/Dialect/CMakeLists.txt
@@ -15,6 +15,7 @@
 add_subdirectory(LLVMIR)
 add_subdirectory(Math)
 add_subdirectory(MemRef)
+add_subdirectory(NVPTX)
 add_subdirectory(OpenACC)
 add_subdirectory(OpenMP)
 add_subdirectory(PDL)
Index: mlir/lib/Dialect/NVPTX/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/NVPTX/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
Index: mlir/lib/Dialect/NVPTX/IR/CMakeLists.txt
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/NVPTX/IR/CMakeLists.txt
@@ -0,0 +1,13 @@
+add_mlir_dialect_library(MLIRNVPTX
+  NVPTXDialect.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/NVPTX
+
+  DEPENDS
+  MLIRNVPTXIncGen
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSideEffectInterfaces
+  )
Index: mlir/lib/Dialect/NVPTX/IR/NVPTXDialect.cpp
===================================================================
--- /dev/null
+++ mlir/lib/Dialect/NVPTX/IR/NVPTXDialect.cpp
@@ -0,0 +1,30 @@
+//===- NVPTXDialect.cpp - MLIR NVPTX ops implementation -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the NVPTX dialect and its operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/NVPTX/NVPTXDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/TypeUtilities.h"
+
+using namespace mlir;
+
+#include "mlir/Dialect/NVPTX/NVPTXDialect.cpp.inc"
+
+void nvptx::NVPTXDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/NVPTX/NVPTX.cpp.inc"
+      >();
+}
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/NVPTX/NVPTX.cpp.inc"
Index: mlir/test/Dialect/NVPTX/roundtrip.mlir
===================================================================
--- /dev/null
+++ mlir/test/Dialect/NVPTX/roundtrip.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+
+// CHECK-LABEL: func @ldmatrix(
+func @ldmatrix(%arg0: memref<?x?xf16, 3>, %x: index, %y: index) {
+// CHECK: nvptx.ldmatrix %{{.*}}[%{{.*}}, %{{.*}}]
+// CHECK-SAME: {numTiles = 4 : i32, transpose = false} : memref<?x?xf16, 3> -> vector<4x2xf16>
+  %l = nvptx.ldmatrix %arg0[%x, %y] {numTiles = 4 : i32, transpose = false} :
+    memref<?x?xf16, 3> -> vector<4x2xf16>
+  return
+}
Index: mlir/test/mlir-opt/commandline.mlir
===================================================================
--- mlir/test/mlir-opt/commandline.mlir
+++ mlir/test/mlir-opt/commandline.mlir
@@ -19,6 +19,7 @@
 // CHECK-NEXT: llvm
 // CHECK-NEXT: math
 // CHECK-NEXT: memref
+// CHECK-NEXT: nvptx
 // CHECK-NEXT: nvvm
 // CHECK-NEXT: omp
 // CHECK-NEXT: pdl
Index: utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
===================================================================
--- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -1995,6 +1995,69 @@
     ],
 )
 
+##---------------------------------------------------------------------------##
+# NVPTX dialect: PTX-specific ops between the GPU/Vector and NVVM dialects.
+##---------------------------------------------------------------------------##
+
+td_library(
+    name = "NVPTXTdFiles",
+    srcs = ["include/mlir/Dialect/NVPTX/NVPTX.td"],
+    includes = ["include"],
+    deps = [
+        ":SideEffectInterfacesTdFiles",
+    ],
+)
+
+gentbl_cc_library(
+    name = "NVPTXIncGen",
+    strip_include_prefix = "include",
+    tbl_outs = [
+        (
+            [
+                "-gen-dialect-decls",
+                "-dialect=nvptx",
+            ],
+            "include/mlir/Dialect/NVPTX/NVPTXDialect.h.inc",
+        ),
+        (
+            [
+                "-gen-dialect-defs",
+                "-dialect=nvptx",
+            ],
+            "include/mlir/Dialect/NVPTX/NVPTXDialect.cpp.inc",
+        ),
+        (
+            ["-gen-op-decls"],
+            "include/mlir/Dialect/NVPTX/NVPTX.h.inc",
+        ),
+        (
+            ["-gen-op-defs"],
+            "include/mlir/Dialect/NVPTX/NVPTX.cpp.inc",
+        ),
+        (
+            ["-gen-op-doc"],
+            "g3doc/Dialects/NVPTX/NVPTX.md",
+        ),
+    ],
+    tblgen = ":mlir-tblgen",
+    td_file = "include/mlir/Dialect/NVPTX/NVPTX.td",
+    deps = [":NVPTXTdFiles"],
+)
+
+cc_library(
+    name = "NVPTX",
+    srcs = ["lib/Dialect/NVPTX/IR/NVPTXDialect.cpp"],
+    hdrs = ["include/mlir/Dialect/NVPTX/NVPTXDialect.h"],
+    includes = ["include"],
+    deps = [
+        ":IR",
+        ":NVPTXIncGen",
+        ":SideEffectInterfaces",
+        "//llvm:Core",
+        "//llvm:Support",
+    ],
+)
+
 td_library(
     name = "FuncTdFiles",
     srcs = [
@@ -5946,6 +6009,7 @@
         ":MemRefToLLVM",
         ":MemRefToSPIRV",
         ":MemRefTransforms",
+        ":NVPTX",
         ":NVVMDialect",
         ":OpenACCDialect",
         ":OpenMPDialect",