diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -22,6 +22,7 @@
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/SymbolTable.h"
+#include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td
@@ -20,6 +20,7 @@
 include "mlir/IR/FunctionInterfaces.td"
 include "mlir/IR/SymbolInterfaces.td"
 include "mlir/Interfaces/DataLayoutInterfaces.td"
+include "mlir/Interfaces/InferIntRangeInterface.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
@@ -465,8 +466,9 @@
   let hasVerifier = 1;
 }
 
-def GPU_LaunchOp : GPU_Op<"launch",
-    [AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface]>,
+def GPU_LaunchOp : GPU_Op<"launch", [
+    AutomaticAllocationScope, AttrSizedOperandSegments, GPU_AsyncOpInterface,
+    DeclareOpInterfaceMethods<InferIntRangeInterface>]>,
     Arguments<(ins Variadic<GPU_AsyncToken>:$asyncDependencies,
                Index:$gridSizeX, Index:$gridSizeY, Index:$gridSizeZ,
                Index:$blockSizeX, Index:$blockSizeY, Index:$blockSizeZ,
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -680,6 +680,31 @@
   rewrites.add<FoldLaunchArguments>(context);
 }
 
+/// Infers integer ranges for the region arguments of `gpu.launch` from the
+/// ranges of its grid- and block-size operands. Each size operand is clamped
+/// to [1, 2^31-1]; region argument `i` (for i in [0, 6)) is given the range
+/// [0, max(size_i) - 1] and region argument `i + 6` gets the clamped range.
+void LaunchOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
+                                 SetIntRangeFn setResultRange) {
+  unsigned width = IndexType::kInternalStorageBitWidth;
+  // Number of blocks/threads is at least 1, at most 2^31-1.
+  ConstantIntRanges dimLimits = ConstantIntRanges::fromSigned(
+      APInt(width, 1), APInt(width, std::numeric_limits<int32_t>::max()));
+  // Skip the async dependencies; the six size operands follow them.
+  argRanges = argRanges.drop_front(asyncDependencies().size());
+  for (int i = 0; i < 6; ++i) {
+    // Leave operands whose inferred range is not index-width untouched.
+    if (argRanges[i].umin().getBitWidth() != width)
+      continue;
+    ConstantIntRanges dimRange = argRanges[i].intersection(dimLimits);
+    APInt umax = dimRange.umax() - APInt(width, 1);
+    APInt smax = dimRange.smax() - APInt(width, 1);
+    ConstantIntRanges idxRange(APInt(width, 0), umax, APInt(width, 0), smax);
+    setResultRange(body().getArgument(i), idxRange);
+    setResultRange(body().getArgument(i + 6), dimRange);
+  }
+}
+
 //===----------------------------------------------------------------------===//
 // LaunchFuncOp
 //===----------------------------------------------------------------------===//
diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
--- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
+++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel
@@ -3490,6 +3490,7 @@
         ":DLTIDialectTdFiles",
         ":DataLayoutInterfacesTdFiles",
         ":FunctionInterfacesTdFiles",
+        ":InferIntRangeInterfaceTdFiles",
         ":LLVMOpsTdFiles",
         ":OpBaseTdFiles",
         ":SideEffectInterfacesTdFiles",
@@ -3581,6 +3582,7 @@
         ":GPUBaseIncGen",
         ":GPUOpsIncGen",
         ":IR",
+        ":InferIntRangeInterface",
         ":InferTypeOpInterface",
         ":LLVMDialect",
         ":MemRefDialect",