Skip to content

Commit fbefc66

Browse files
[mlir][Transforms] Context-aware Type Converter
Co-authored-by: Markus Böck <markus.boeck02@gmail.com>
1 parent bfc331e commit fbefc66

File tree

7 files changed

+304
-56
lines changed

7 files changed

+304
-56
lines changed

mlir/docs/DialectConversion.md

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,15 @@ target types. If the source type is converted to itself, we say it is a "legal"
280280
type. Type conversions are specified via the `addConversion` method described
281281
below.
282282

283+
There are two kind of conversion functions: context-aware and context-unaware
284+
conversions. A context-unaware conversion function converts a `Type` into a
285+
`Type`. A context-aware conversion function converts a `Value` into a type. The
286+
latter allows users to customize type conversion rules based on the IR.
287+
288+
Note: When there is at least one context-aware type conversion function, the
289+
result of type conversions can no longer be cached, which can increase
290+
compilation time. Use this feature with caution!
291+
283292
A `materialization` describes how a list of values should be converted to a
284293
list of values with specific types. An important distinction from a
285294
`conversion` is that a `materialization` can produce IR, whereas a `conversion`
@@ -332,29 +341,31 @@ Several of the available hooks are detailed below:
332341
```c++
333342
class TypeConverter {
334343
public:
335-
/// Register a conversion function. A conversion function defines how a given
336-
/// source type should be converted. A conversion function must be convertible
337-
/// to any of the following forms(where `T` is a class derived from `Type`:
338-
/// * Optional<Type>(T)
344+
/// Register a conversion function. A conversion function must be convertible
345+
/// to any of the following forms (where `T` is `Value` or a class derived
346+
/// from `Type`, including `Type` itself):
347+
///
348+
/// * std::optional<Type>(T)
339349
/// - This form represents a 1-1 type conversion. It should return nullptr
340-
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned, the
341-
/// converter is allowed to try another conversion function to perform
342-
/// the conversion.
343-
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &)
350+
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
351+
/// the converter is allowed to try another conversion function to
352+
/// perform the conversion.
353+
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &)
344354
/// - This form represents a 1-N type conversion. It should return
345-
/// `failure` or `std::nullopt` to signify a failed conversion. If the new
346-
/// set of types is empty, the type is removed and any usages of the
355+
/// `failure` or `std::nullopt` to signify a failed conversion. If the
356+
/// new set of types is empty, the type is removed and any usages of the
347357
/// existing value are expected to be removed during conversion. If
348358
/// `std::nullopt` is returned, the converter is allowed to try another
349359
/// conversion function to perform the conversion.
350-
/// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &, ArrayRef<Type>)
351-
/// - This form represents a 1-N type conversion supporting recursive
352-
/// types. The first two arguments and the return value are the same as
353-
/// for the regular 1-N form. The third argument is contains is the
354-
/// "call stack" of the recursive conversion: it contains the list of
355-
/// types currently being converted, with the current type being the
356-
/// last one. If it is present more than once in the list, the
357-
/// conversion concerns a recursive type.
360+
///
361+
/// Conversion functions that accept `Value` as the first argument are
362+
/// context-aware. I.e., they can take into account IR when converting the
363+
/// type of the given value. Context-unaware conversion functions accept
364+
/// `Type` or a derived class as the first argument.
365+
///
366+
/// Note: Context-unaware conversions are cached, but context-aware
367+
/// conversions are not.
368+
///
358369
/// Note: When attempting to convert a type, e.g. via 'convertType', the
359370
/// mostly recently added conversions will be invoked first.
360371
template <typename FnT,

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 92 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ class TypeConverter {
139139
};
140140

141141
/// Register a conversion function. A conversion function must be convertible
142-
/// to any of the following forms (where `T` is a class derived from `Type`):
142+
/// to any of the following forms (where `T` is `Value` or a class derived
143+
/// from `Type`, including `Type` itself):
143144
///
144145
/// * std::optional<Type>(T)
145146
/// - This form represents a 1-1 type conversion. It should return nullptr
@@ -154,6 +155,14 @@ class TypeConverter {
154155
/// `std::nullopt` is returned, the converter is allowed to try another
155156
/// conversion function to perform the conversion.
156157
///
158+
/// Conversion functions that accept `Value` as the first argument are
159+
/// context-aware. I.e., they can take into account IR when converting the
160+
/// type of the given value. Context-unaware conversion functions accept
161+
/// `Type` or a derived class as the first argument.
162+
///
163+
/// Note: Context-unaware conversions are cached, but context-aware
164+
/// conversions are not.
165+
///
157166
/// Note: When attempting to convert a type, e.g. via 'convertType', the
158167
/// mostly recently added conversions will be invoked first.
159168
template <typename FnT, typename T = typename llvm::function_traits<
@@ -242,15 +251,28 @@ class TypeConverter {
242251
wrapTypeAttributeConversion<T, A>(std::forward<FnT>(callback)));
243252
}
244253

245-
/// Convert the given type. This function should return failure if no valid
254+
/// Convert the given type. This function returns failure if no valid
246255
/// conversion exists, success otherwise. If the new set of types is empty,
247256
/// the type is removed and any usages of the existing value are expected to
248257
/// be removed during conversion.
258+
///
259+
/// Note: This overload invokes only context-unaware type conversion
260+
/// functions. Users should call the other overload if possible.
249261
LogicalResult convertType(Type t, SmallVectorImpl<Type> &results) const;
250262

263+
/// Convert the type of the given value. This function returns failure if no
264+
/// valid conversion exists, success otherwise. If the new set of types is
265+
/// empty, the type is removed and any usages of the existing value are
266+
/// expected to be removed during conversion.
267+
///
268+
/// Note: This overload invokes both context-aware and context-unaware type
269+
/// conversion functions.
270+
LogicalResult convertType(Value v, SmallVectorImpl<Type> &results) const;
271+
251272
/// This hook simplifies defining 1-1 type conversions. This function returns
252273
/// the type to convert to on success, and a null type on failure.
253274
Type convertType(Type t) const;
275+
Type convertType(Value v) const;
254276

255277
/// Attempts a 1-1 type conversion, expecting the result type to be
256278
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
@@ -259,25 +281,36 @@ class TypeConverter {
259281
TargetType convertType(Type t) const {
260282
return dyn_cast_or_null<TargetType>(convertType(t));
261283
}
284+
template <typename TargetType>
285+
TargetType convertType(Value v) const {
286+
return dyn_cast_or_null<TargetType>(convertType(v));
287+
}
262288

263-
/// Convert the given set of types, filling 'results' as necessary. This
264-
/// returns failure if the conversion of any of the types fails, success
289+
/// Convert the given types, filling 'results' as necessary. This returns
290+
/// "failure" if the conversion of any of the types fails, "success"
265291
/// otherwise.
266292
LogicalResult convertTypes(TypeRange types,
267293
SmallVectorImpl<Type> &results) const;
268294

295+
/// Convert the types of the given values, filling 'results' as necessary.
296+
/// This returns "failure" if the conversion of any of the types fails,
297+
/// "success" otherwise.
298+
LogicalResult convertTypes(ValueRange values,
299+
SmallVectorImpl<Type> &results) const;
300+
269301
/// Return true if the given type is legal for this type converter, i.e. the
270302
/// type converts to itself.
271303
bool isLegal(Type type) const;
304+
bool isLegal(Value value) const;
272305

273306
/// Return true if all of the given types are legal for this type converter.
274-
template <typename RangeT>
275-
std::enable_if_t<!std::is_convertible<RangeT, Type>::value &&
276-
!std::is_convertible<RangeT, Operation *>::value,
277-
bool>
278-
isLegal(RangeT &&range) const {
307+
bool isLegal(TypeRange range) const {
279308
return llvm::all_of(range, [this](Type type) { return isLegal(type); });
280309
}
310+
bool isLegal(ValueRange range) const {
311+
return llvm::all_of(range, [this](Value value) { return isLegal(value); });
312+
}
313+
281314
/// Return true if the given operation has legal operand and result types.
282315
bool isLegal(Operation *op) const;
283316

@@ -296,6 +329,11 @@ class TypeConverter {
296329
LogicalResult convertSignatureArgs(TypeRange types,
297330
SignatureConversion &result,
298331
unsigned origInputOffset = 0) const;
332+
LogicalResult convertSignatureArg(unsigned inputNo, Value value,
333+
SignatureConversion &result) const;
334+
LogicalResult convertSignatureArgs(ValueRange values,
335+
SignatureConversion &result,
336+
unsigned origInputOffset = 0) const;
299337

300338
/// This function converts the type signature of the given block, by invoking
301339
/// 'convertSignatureArg' for each argument. This function should return a
@@ -329,7 +367,7 @@ class TypeConverter {
329367
/// types is empty, the type is removed and any usages of the existing value
330368
/// are expected to be removed during conversion.
331369
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
332-
Type, SmallVectorImpl<Type> &)>;
370+
PointerUnion<Type, Value>, SmallVectorImpl<Type> &)>;
333371

334372
/// The signature of the callback used to materialize a source conversion.
335373
///
@@ -349,13 +387,14 @@ class TypeConverter {
349387

350388
/// Generate a wrapper for the given callback. This allows for accepting
351389
/// different callback forms, that all compose into a single version.
352-
/// With callback of form: `std::optional<Type>(T)`
390+
/// With callback of form: `std::optional<Type>(T)`, where `T` can be a
391+
/// `Value` or a `Type` (or a class derived from `Type`).
353392
template <typename T, typename FnT>
354393
std::enable_if_t<std::is_invocable_v<FnT, T>, ConversionCallbackFn>
355-
wrapCallback(FnT &&callback) const {
394+
wrapCallback(FnT &&callback) {
356395
return wrapCallback<T>([callback = std::forward<FnT>(callback)](
357-
T type, SmallVectorImpl<Type> &results) {
358-
if (std::optional<Type> resultOpt = callback(type)) {
396+
T typeOrValue, SmallVectorImpl<Type> &results) {
397+
if (std::optional<Type> resultOpt = callback(typeOrValue)) {
359398
bool wasSuccess = static_cast<bool>(*resultOpt);
360399
if (wasSuccess)
361400
results.push_back(*resultOpt);
@@ -365,20 +404,49 @@ class TypeConverter {
365404
});
366405
}
367406
/// With callback of form: `std::optional<LogicalResult>(
368-
/// T, SmallVectorImpl<Type> &, ArrayRef<Type>)`.
407+
/// T, SmallVectorImpl<Type> &)`, where `T` is a type.
369408
template <typename T, typename FnT>
370-
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &>,
409+
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
410+
std::is_base_of_v<Type, T>,
371411
ConversionCallbackFn>
372412
wrapCallback(FnT &&callback) const {
373413
return [callback = std::forward<FnT>(callback)](
374-
Type type,
414+
PointerUnion<Type, Value> typeOrValue,
375415
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
376-
T derivedType = dyn_cast<T>(type);
416+
T derivedType;
417+
if (Type t = dyn_cast<Type>(typeOrValue)) {
418+
derivedType = dyn_cast<T>(t);
419+
} else if (Value v = dyn_cast<Value>(typeOrValue)) {
420+
derivedType = dyn_cast<T>(v.getType());
421+
} else {
422+
llvm_unreachable("unexpected variant");
423+
}
377424
if (!derivedType)
378425
return std::nullopt;
379426
return callback(derivedType, results);
380427
};
381428
}
429+
/// With callback of form: `std::optional<LogicalResult>(
430+
/// T, SmallVectorImpl<Type>)`, where `T` is a `Value`.
431+
template <typename T, typename FnT>
432+
std::enable_if_t<std::is_invocable_v<FnT, T, SmallVectorImpl<Type> &> &&
433+
std::is_same_v<T, Value>,
434+
ConversionCallbackFn>
435+
wrapCallback(FnT &&callback) {
436+
hasContextAwareTypeConversions = true;
437+
return [callback = std::forward<FnT>(callback)](
438+
PointerUnion<Type, Value> typeOrValue,
439+
SmallVectorImpl<Type> &results) -> std::optional<LogicalResult> {
440+
if (Type t = dyn_cast<Type>(typeOrValue)) {
441+
// Context-aware type conversion was called with a type.
442+
return std::nullopt;
443+
} else if (Value v = dyn_cast<Value>(typeOrValue)) {
444+
return callback(v, results);
445+
}
446+
llvm_unreachable("unexpected variant");
447+
return std::nullopt;
448+
};
449+
}
382450

383451
/// Register a type conversion.
384452
void registerConversion(ConversionCallbackFn callback) {
@@ -505,6 +573,12 @@ class TypeConverter {
505573
mutable DenseMap<Type, SmallVector<Type, 2>> cachedMultiConversions;
506574
/// A mutex used for cache access
507575
mutable llvm::sys::SmartRWMutex<true> cacheMutex;
576+
/// Whether the type converter has context-aware type conversions. I.e.,
577+
/// conversion rules that depend on the SSA value instead of just the type.
578+
/// Type conversion caching is deactivated when there are context-aware
579+
/// conversions because the type converter may return different results for
580+
/// the same input type.
581+
bool hasContextAwareTypeConversions = false;
508582
};
509583

510584
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ class Structural1ToNConversionPattern : public OpConversionPattern<SourceOp> {
5252
SmallVector<unsigned> offsets;
5353
offsets.push_back(0);
5454
// Do the type conversion and record the offsets.
55-
for (Type type : op.getResultTypes()) {
56-
if (failed(typeConverter->convertTypes(type, dstTypes)))
55+
for (Value v : op.getResults()) {
56+
if (failed(typeConverter->convertType(v, dstTypes)))
5757
return rewriter.notifyMatchFailure(op, "could not convert result type");
5858
offsets.push_back(dstTypes.size());
5959
}
@@ -126,7 +126,6 @@ class ConvertForOpTypes
126126
// Inline the type converted region from the original operation.
127127
rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
128128
newOp.getRegion().end());
129-
130129
return newOp;
131130
}
132131
};
@@ -225,15 +224,14 @@ void mlir::scf::populateSCFStructuralTypeConversions(
225224

226225
void mlir::scf::populateSCFStructuralTypeConversionTarget(
227226
const TypeConverter &typeConverter, ConversionTarget &target) {
228-
target.addDynamicallyLegalOp<ForOp, IfOp>([&](Operation *op) {
229-
return typeConverter.isLegal(op->getResultTypes());
230-
});
227+
target.addDynamicallyLegalOp<ForOp, IfOp>(
228+
[&](Operation *op) { return typeConverter.isLegal(op->getResults()); });
231229
target.addDynamicallyLegalOp<scf::YieldOp>([&](scf::YieldOp op) {
232230
// We only have conversions for a subset of ops that use scf.yield
233231
// terminators.
234232
if (!isa<ForOp, IfOp, WhileOp>(op->getParentOp()))
235233
return true;
236-
return typeConverter.isLegal(op.getOperandTypes());
234+
return typeConverter.isLegal(op.getOperands());
237235
});
238236
target.addDynamicallyLegalOp<WhileOp, ConditionOp>(
239237
[&](Operation *op) { return typeConverter.isLegal(op); });

0 commit comments

Comments
 (0)