diff --git a/mlir/include/mlir/Dialect/GPU/InferIntRangeInterfaceImpl.h b/mlir/include/mlir/Dialect/GPU/InferIntRangeInterfaceImpl.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/GPU/InferIntRangeInterfaceImpl.h @@ -0,0 +1,21 @@ +//===- InferIntRangeInterfaceImpl.h - Impl. of InferIntRangeInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_GPU_INFERINTRANGEINTERFACEIMPL_H +#define MLIR_DIALECT_GPU_INFERINTRANGEINTERFACEIMPL_H + +namespace mlir { +class DialectRegistry; + +namespace gpu { +/// Registers the InferIntRangeInterface external models for GPU dialect ops. +void registerInferIntRangeInterfaceExternalModels(DialectRegistry &registry); +} // namespace gpu +} // namespace mlir + +#endif // MLIR_DIALECT_GPU_INFERINTRANGEINTERFACEIMPL_H diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -28,6 +28,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/Dialect/GPU/InferIntRangeInterfaceImpl.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -106,6 +107,8 @@ tensor::registerInferTypeOpInterfaceExternalModels(registry); tensor::registerTilingOpInterfaceExternalModels(registry); vector::registerBufferizableOpInterfaceExternalModels(registry); + + gpu::registerInferIntRangeInterfaceExternalModels(registry); } /// Append all the MLIR dialects to the registry contained in the given context. 
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -43,6 +43,7 @@ add_mlir_dialect_library(MLIRGPUTransforms Transforms/AllReduceLowering.cpp Transforms/AsyncRegionRewriter.cpp + Transforms/InferIntRangeInterfaceImpl.cpp Transforms/KernelOutlining.cpp Transforms/MemoryPromotion.cpp Transforms/ParallelLoopMapper.cpp diff --git a/mlir/lib/Dialect/GPU/Transforms/InferIntRangeInterfaceImpl.cpp b/mlir/lib/Dialect/GPU/Transforms/InferIntRangeInterfaceImpl.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/GPU/Transforms/InferIntRangeInterfaceImpl.cpp @@ -0,0 +1,51 @@ +//===- InferIntRangeInterfaceImpl.cpp - Impl. of InferIntRangeInterface ---===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/GPU/InferIntRangeInterfaceImpl.h" + +#include "mlir/Dialect/Arithmetic/IR/InferIntRangeInterface.h" +#include "mlir/Dialect/GPU/GPUDialect.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/Operation.h" + +using namespace mlir; +using namespace mlir::arith; +using namespace mlir::gpu; + +namespace { +/// Declare that index parameters such as block and thread IDs should be +/// interpreted as unsigned, but do not impose a bound on them, such as +/// requiring them to be below size_max. 
+template <typename Op> +struct IndexOpUnsigned + : InferIntRangeInterface::ExternalModel<IndexOpUnsigned<Op>, Op> { + using InferIntRangeInterface::ExternalModel<IndexOpUnsigned<Op>, + Op>::ExternalModel; + void inferResultRanges(Operation *rawOp, ArrayRef argRanges, + SmallVectorImpl &resultRanges) const { + Op theOp = cast<Op>(rawOp); + resultRanges.push_back( + {IntegerAttr::get(theOp->getResultTypes()[0], 0), {}}); + } +}; +} // end namespace + +void mlir::gpu::registerInferIntRangeInterfaceExternalModels( + DialectRegistry &registry) { + registry.addExtension(+[](MLIRContext *ctx, GPUDialect *dialect) { + BlockDimOp::attachInterface<IndexOpUnsigned<BlockDimOp>>(*ctx); + BlockIdOp::attachInterface<IndexOpUnsigned<BlockIdOp>>(*ctx); + ThreadIdOp::attachInterface<IndexOpUnsigned<ThreadIdOp>>(*ctx); + LaneIdOp::attachInterface<IndexOpUnsigned<LaneIdOp>>(*ctx); + GlobalIdOp::attachInterface<IndexOpUnsigned<GlobalIdOp>>(*ctx); + SubgroupIdOp::attachInterface<IndexOpUnsigned<SubgroupIdOp>>(*ctx); + NumSubgroupsOp::attachInterface<IndexOpUnsigned<NumSubgroupsOp>>(*ctx); + SubgroupSizeOp::attachInterface<IndexOpUnsigned<SubgroupSizeOp>>(*ctx); + }); +} diff --git a/mlir/test/Dialect/GPU/infer-int-range.mlir b/mlir/test/Dialect/GPU/infer-int-range.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/GPU/infer-int-range.mlir @@ -0,0 +1,54 @@ +// RUN: mlir-opt -arith-fold-inferred-constants -canonicalize -arith-unsigned-when-equivalent %s | FileCheck %s + +// Checks that range inference treats gpu.block_id as non-negative, so the +// cmpi-slt/select sign handling around the div and rem ops folds away. + +// Generated from the function body +// %1 = affine.apply affine_map<(d0) -> (d0 floordiv 9 + ((d0 mod 9) floordiv 3) * 3 + (d0 mod 3) * 9)>(%0) +// return %1 : index +// via -lower-affine -canonicalize +// CHECK-LABEL: func @like_affine_map +// CHECK-DAG: %[[c3:.*]] = arith.constant 3 +// CHECK-DAG: %[[c9:.*]] = arith.constant 9 +// CHECK-DAG: %[[blk:.*]] = gpu.block_id +// CHECK: %[[i0:.*]] = arith.divsi %[[blk]], %[[c9]] +// CHECK: %[[i1a:.*]] = arith.remsi %[[blk]], %[[c9]] +// CHECK: %[[i1b:.*]] = arith.divui %[[i1a]], %[[c3]] +// CHECK: %[[i1:.*]] = arith.muli %[[i1b]], %[[c3]] +// CHECK: %[[part1:.*]] = arith.addi %[[i0]], %[[i1]] +// CHECK: %[[i2a:.*]] = arith.remsi %[[blk]], %[[c3]] +// CHECK: %[[i2:.*]] = arith.muli %[[i2a]], %[[c9]] +// CHECK: %[[ret:.*]] = arith.addi %[[part1]], 
%[[i2]] +// CHECK: return %[[ret]] +func @like_affine_map() -> index { + %c9 = arith.constant 9 : index + %c0 = arith.constant 0 : index + %c3 = arith.constant 3 : index + %c-1 = arith.constant -1 : index + %0 = gpu.block_id x + %1 = arith.cmpi slt, %0, %c0 : index + %2 = arith.subi %c-1, %0 : index + %3 = arith.select %1, %2, %0 : index + %4 = arith.divsi %3, %c9 : index + %5 = arith.subi %c-1, %4 : index + %6 = arith.select %1, %5, %4 : index + %7 = arith.remsi %0, %c9 : index + %8 = arith.cmpi slt, %7, %c0 : index + %9 = arith.addi %7, %c9 : index + %10 = arith.select %8, %9, %7 : index + %11 = arith.cmpi slt, %10, %c0 : index + %12 = arith.subi %c-1, %10 : index + %13 = arith.select %11, %12, %10 : index + %14 = arith.divsi %13, %c3 : index + %15 = arith.subi %c-1, %14 : index + %16 = arith.select %11, %15, %14 : index + %17 = arith.muli %16, %c3 : index + %18 = arith.addi %6, %17 : index + %19 = arith.remsi %0, %c3 : index + %20 = arith.cmpi slt, %19, %c0 : index + %21 = arith.addi %19, %c3 : index + %22 = arith.select %20, %21, %19 : index + %23 = arith.muli %22, %c9 : index + %24 = arith.addi %18, %23 : index + return %24 : index +}