diff --git a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
--- a/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
+++ b/mlir/include/mlir/Conversion/StandardToSPIRV/StandardToSPIRV.h
@@ -25,6 +25,23 @@
                                      SPIRVTypeConverter &typeConverter,
                                      OwningRewritePatternList &patterns);
 
+/// Appends to a pattern list additional patterns for translating tensor ops
+/// to SPIR-V ops.
+///
+/// Note: Normally tensors are stored in buffers before converting to
+/// SPIR-V, since that is how large amounts of data are sent to the GPU.
+/// However, SPIR-V supports converting tensors directly too. This is meant
+/// for cases where the tensor only contains a small number of elements and
+/// it makes sense to inline them directly as a small data array in the
+/// shader. To handle this, the conversion may internally create new local
+/// variables. SPIR-V consumers in GPU drivers may or may not optimize those
+/// away, so this has implications for register pressure. Therefore, a
+/// threshold is used to control when these patterns should kick in.
+void populateTensorToSPIRVPatterns(MLIRContext *context,
+                                   SPIRVTypeConverter &typeConverter,
+                                   int64_t byteCountThreshold,
+                                   OwningRewritePatternList &patterns);
+
 /// Appends to a pattern list patterns to legalize ops that are not directly
 /// lowered to SPIR-V.
 void populateStdLegalizationPatternsForSPIRVLowering(
diff --git a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
--- a/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
+++ b/mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h
@@ -104,6 +104,11 @@
 Value getBuiltinVariableValue(Operation *op, BuiltIn builtin,
                               OpBuilder &builder);
 
+/// Generates IR to perform index linearization with the given `indices` and
+/// their corresponding `strides`, adding an initial `offset`.
+Value linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
+                     int64_t offset, Location loc, OpBuilder &builder);
+
 /// Performs the index computation to get to the element at `indices` of the
 /// memory pointed to by `basePtr`, using the layout map of `baseType`.
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
 #include "mlir/Dialect/SPIRV/Utils/LayoutUtils.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/SetVector.h"
@@ -512,6 +513,65 @@
   }
 };
 
+/// Converts tensor.extract to a load via an access chain into a SPIR-V local
+/// variable.
+class TensorExtractPattern final
+    : public OpConversionPattern<tensor::ExtractOp> {
+public:
+  TensorExtractPattern(TypeConverter &typeConverter, MLIRContext *context,
+                       int64_t threshold, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit),
+        byteCountThreshold(threshold) {}
+
+  LogicalResult
+  matchAndRewrite(tensor::ExtractOp extractOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    TensorType tensorType = extractOp.tensor().getType().cast<TensorType>();
+
+    if (!tensorType.hasStaticShape())
+      return rewriter.notifyMatchFailure(extractOp, "non-static tensor");
+
+    if (tensorType.getNumElements() * tensorType.getElementTypeBitWidth() >
+        byteCountThreshold * 8)
+      return rewriter.notifyMatchFailure(extractOp,
+                                         "exceeding byte count threshold");
+
+    Location loc = extractOp.getLoc();
+    tensor::ExtractOp::Adaptor adaptor(operands);
+
+    int64_t rank = tensorType.getRank();
+    SmallVector<int64_t, 4> strides(rank, 1);
+    for (int i = rank - 2; i >= 0; --i) {
+      strides[i] = strides[i + 1] * tensorType.getDimSize(i + 1);
+    }
+
+    Type varType = spirv::PointerType::get(adaptor.tensor().getType(),
+                                           spirv::StorageClass::Function);
+
+    spirv::VariableOp varOp;
+    if (adaptor.tensor().getDefiningOp<spirv::ConstantOp>()) {
+      varOp = rewriter.create<spirv::VariableOp>(
+          loc, varType, spirv::StorageClass::Function,
+          /*initializer=*/adaptor.tensor());
+    } else {
+      // Need to store the tensor value into the local variable first. It's
+      // questionable whether we want to support such a case, though.
+      return failure();
+    }
+
+    Value index = spirv::linearizeIndex(adaptor.indices(), strides,
+                                        /*offset=*/0, loc, rewriter);
+    auto acOp = rewriter.create<spirv::AccessChainOp>(loc, varOp, index);
+
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(extractOp, acOp);
+
+    return success();
+  }
+
+private:
+  int64_t byteCountThreshold;
+};
+
 /// Converts std.trunci to spv.Select if the type of result is i1 or vector of
 /// i1.
 class TruncI1Pattern final : public OpConversionPattern<TruncateIOp> {
@@ -622,6 +682,9 @@
 // ConstantOp with composite type.
 //===----------------------------------------------------------------------===//
 
+// TODO: This should probably be split into the vector case and the tensor
+// case, so that the tensor case can be moved to the TensorToSPIRV
+// conversion. However, std.constant itself belongs to the standard dialect.
 LogicalResult ConstantCompositeOpPattern::matchAndRewrite(
     ConstantOp constOp, ArrayRef<Value> operands,
     ConversionPatternRewriter &rewriter) const {
@@ -1170,6 +1233,7 @@
       UnaryAndBinaryOpPattern,
       UnaryAndBinaryOpPattern,
       UnaryAndBinaryOpPattern,
+      // Unary and binary patterns
       BitwiseOpPattern,
       BitwiseOpPattern,
@@ -1224,4 +1288,13 @@
   patterns.insert(typeConverter, context, /*benefit=*/2);
 }
 
+
+void populateTensorToSPIRVPatterns(MLIRContext *context,
+                                   SPIRVTypeConverter &typeConverter,
+                                   int64_t byteCountThreshold,
+                                   OwningRewritePatternList &patterns) {
+  patterns.insert<TensorExtractPattern>(typeConverter, context,
+                                        byteCountThreshold);
+}
+
 } // namespace mlir
diff --git a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
--- a/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
+++ b/mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp
@@ -37,6 +37,8 @@
   SPIRVTypeConverter typeConverter(targetAttr);
   OwningRewritePatternList patterns;
   populateStandardToSPIRVPatterns(context, typeConverter, patterns);
+  populateTensorToSPIRVPatterns(context, typeConverter,
+                                /*byteCountThreshold=*/64, patterns);
   populateBuiltinFuncToSPIRVPatterns(context, typeConverter, patterns);
 
   if (failed(applyPartialConversion(module, *target, std::move(patterns))))
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -607,6 +607,31 @@
 // Index calculation
 //===----------------------------------------------------------------------===//
 
+Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
+                                  int64_t offset, Location loc,
+                                  OpBuilder &builder) {
+  assert(indices.size() == strides.size() &&
+         "must provide indices for all dimensions");
+
+  auto indexType = SPIRVTypeConverter::getIndexType(builder.getContext());
+
+  // TODO: Consider moving to use affine.apply and patterns converting
+  // affine.apply to standard ops. That requires the conversion to SPIR-V to
+  // be broken down into progressively smaller steps so we can have
+  // intermediate stages using other dialects. At the moment SPIR-V is the
+  // final sink.
+
+  Value linearizedIndex = builder.create<spirv::ConstantOp>(
+      loc, indexType, IntegerAttr::get(indexType, offset));
+  for (auto index : llvm::enumerate(indices)) {
+    Value strideVal = builder.create<spirv::ConstantOp>(
+        loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
+    Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
+    linearizedIndex =
+        builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
+  }
+  return linearizedIndex;
+}
+
 spirv::AccessChainOp mlir::spirv::getElementPtr(
     SPIRVTypeConverter &typeConverter, MemRefType baseType, Value basePtr,
     ValueRange indices, Location loc, OpBuilder &builder) {
@@ -623,28 +648,16 @@
   auto indexType = typeConverter.getIndexType(builder.getContext());
 
   SmallVector<Value, 2> linearizedIndices;
-  // Add a '0' at the start to index into the struct.
   auto zero = spirv::ConstantOp::getZero(indexType, loc, builder);
+
+  // Add a '0' at the start to index into the struct.
   linearizedIndices.push_back(zero);
   if (baseType.getRank() == 0) {
     linearizedIndices.push_back(zero);
   } else {
-    // TODO: Instead of this logic, use affine.apply and add patterns for
-    // lowering affine.apply to standard ops. These will get lowered to SPIR-V
-    // ops by the DialectConversion framework.
-    Value ptrLoc = builder.create<spirv::ConstantOp>(
-        loc, indexType, IntegerAttr::get(indexType, offset));
-    assert(indices.size() == strides.size() &&
-           "must provide indices for all dimensions");
-    for (auto index : llvm::enumerate(indices)) {
-      Value strideVal = builder.create<spirv::ConstantOp>(
-          loc, indexType, IntegerAttr::get(indexType, strides[index.index()]));
-      Value update =
-          builder.create<spirv::IMulOp>(loc, strideVal, index.value());
-      ptrLoc = builder.create<spirv::IAddOp>(loc, ptrLoc, update);
-    }
-    linearizedIndices.push_back(ptrLoc);
+    linearizedIndices.push_back(
+        linearizeIndex(indices, strides, offset, loc, builder));
   }
   return builder.create<spirv::AccessChainOp>(loc, basePtr, linearizedIndices);
 }
diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
--- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
+++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir
@@ -1148,3 +1148,32 @@
 
   }
 }
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// tensor.extract
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: func @tensor_extract_constant
+// CHECK-SAME: (%[[A:.+]]: i32, %[[B:.+]]: i32, %[[C:.+]]: i32)
+func @tensor_extract_constant(%a : index, %b: index, %c: index) -> i32 {
+  // CHECK: %[[CST:.+]] = spv.Constant dense<[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]>
+  %cst = constant dense<[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]> : tensor<2x2x3xi32>
+  // CHECK: %[[VAR:.+]] = spv.Variable init(%[[CST]]) : !spv.ptr<!spv.array<12 x i32, stride=4>, Function>
+  // CHECK: %[[C0:.+]] = spv.Constant 0 : i32
+  // CHECK: %[[C6:.+]] = spv.Constant 6 : i32
+  // CHECK: %[[MUL0:.+]] = spv.IMul %[[C6]], %[[A]] : i32
+  // CHECK: %[[ADD0:.+]] = spv.IAdd %[[C0]], %[[MUL0]] : i32
+  // CHECK: %[[C3:.+]] = spv.Constant 3 : i32
+  // CHECK: %[[MUL1:.+]] = spv.IMul %[[C3]], %[[B]] : i32
+  // CHECK: %[[ADD1:.+]] = spv.IAdd %[[ADD0]], %[[MUL1]] : i32
+  // CHECK: %[[C1:.+]] = spv.Constant 1 : i32
+  // CHECK: %[[MUL2:.+]] = spv.IMul %[[C1]], %[[C]] : i32
+  // CHECK: %[[ADD2:.+]] = spv.IAdd %[[ADD1]], %[[MUL2]] : i32
+  // CHECK: %[[AC:.+]] = spv.AccessChain %[[VAR]][%[[ADD2]]]
+  // CHECK: %[[VAL:.+]] = spv.Load "Function" %[[AC]] : i32
+  %extract = tensor.extract %cst[%a, %b, %c] : tensor<2x2x3xi32>
+  // CHECK: spv.ReturnValue %[[VAL]]
+  return %extract : i32
+}
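
For reference, the arithmetic that TensorExtractPattern and spirv::linearizeIndex emit for the test case above works out as follows: for tensor<2x2x3xi32> the row-major strides are [6, 3, 1], so extracting element [%a, %b, %c] produces index = 0 + 6*a + 3*b + 1*c, which is exactly the spv.IMul/spv.IAdd chain the CHECK lines verify. The standalone C++ sketch below (not part of the patch; shape and sample indices are illustrative) mirrors that stride and linearization computation on the host.

#include <cstdint>
#include <cstdio>

int main() {
  // Shape of the tensor in the test above and one sample index (1, 0, 2),
  // which addresses the element with value 9.
  const int64_t shape[3] = {2, 2, 3};
  const int64_t indices[3] = {1, 0, 2};

  // Row-major strides, computed the same way as in TensorExtractPattern:
  // the innermost dimension has stride 1.
  int64_t strides[3] = {1, 1, 1};
  for (int i = 3 - 2; i >= 0; --i)
    strides[i] = strides[i + 1] * shape[i + 1];
  // strides is now {6, 3, 1}.

  // Linearization as done by spirv::linearizeIndex, with offset 0.
  int64_t index = 0;
  for (int i = 0; i < 3; ++i)
    index += strides[i] * indices[i];

  std::printf("linearized index = %lld\n", (long long)index); // prints 8
  return 0;
}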